626 lines
18 KiB
Python
626 lines
18 KiB
Python
import collections
|
|
import concurrent.futures
|
|
import copy
|
|
import datetime
|
|
import decimal
|
|
import functools
|
|
import hashlib
|
|
import itertools
|
|
import jinja2
|
|
import json
|
|
import os
|
|
import requests
|
|
import time
|
|
|
|
from contextlib import contextmanager
|
|
from dbt.exceptions import ConnectionException
|
|
from dbt.logger import GLOBAL_LOGGER as logger
|
|
from enum import Enum
|
|
from typing_extensions import Protocol
|
|
from typing import (
|
|
Tuple, Type, Any, Optional, TypeVar, Dict, Union, Callable, List, Iterator,
|
|
Mapping, Iterable, AbstractSet, Set, Sequence
|
|
)
|
|
|
|
import dbt.exceptions
|
|
|
|
from dbt.node_types import NodeType
|
|
|
|
DECIMALS: Tuple[Type[Any], ...]
|
|
try:
|
|
import cdecimal # typing: ignore
|
|
except ImportError:
|
|
DECIMALS = (decimal.Decimal,)
|
|
else:
|
|
DECIMALS = (decimal.Decimal, cdecimal.Decimal)
|
|
|
|
|
|
class ExitCodes(int, Enum):
|
|
Success = 0
|
|
ModelError = 1
|
|
UnhandledError = 2
|
|
|
|
|
|
def coalesce(*args):
|
|
for arg in args:
|
|
if arg is not None:
|
|
return arg
|
|
return None
|
|
|
|
|
|
def get_profile_from_project(project):
|
|
target_name = project.get('target', {})
|
|
profile = project.get('outputs', {}).get(target_name, {})
|
|
return profile
|
|
|
|
|
|
def get_model_name_or_none(model):
|
|
if model is None:
|
|
name = '<None>'
|
|
|
|
elif isinstance(model, str):
|
|
name = model
|
|
elif isinstance(model, dict):
|
|
name = model.get('alias', model.get('name'))
|
|
elif hasattr(model, 'alias'):
|
|
name = model.alias
|
|
elif hasattr(model, 'name'):
|
|
name = model.name
|
|
else:
|
|
name = str(model)
|
|
return name
|
|
|
|
|
|
MACRO_PREFIX = 'dbt_macro__'
|
|
DOCS_PREFIX = 'dbt_docs__'
|
|
|
|
|
|
def get_dbt_macro_name(name):
|
|
if name is None:
|
|
raise dbt.exceptions.InternalException('Got None for a macro name!')
|
|
return f'{MACRO_PREFIX}{name}'
|
|
|
|
|
|
def get_dbt_docs_name(name):
|
|
if name is None:
|
|
raise dbt.exceptions.InternalException('Got None for a doc name!')
|
|
return f'{DOCS_PREFIX}{name}'
|
|
|
|
|
|
def get_materialization_macro_name(materialization_name, adapter_type=None,
|
|
with_prefix=True):
|
|
if adapter_type is None:
|
|
adapter_type = 'default'
|
|
name = f'materialization_{materialization_name}_{adapter_type}'
|
|
return get_dbt_macro_name(name) if with_prefix else name
|
|
|
|
|
|
def get_docs_macro_name(docs_name, with_prefix=True):
|
|
return get_dbt_docs_name(docs_name) if with_prefix else docs_name
|
|
|
|
|
|
def get_test_macro_name(test_name, with_prefix=True):
|
|
name = f'test_{test_name}'
|
|
return get_dbt_macro_name(name) if with_prefix else name
|
|
|
|
|
|
def split_path(path):
|
|
return path.split(os.sep)
|
|
|
|
|
|
def merge(*args):
|
|
if len(args) == 0:
|
|
return None
|
|
|
|
if len(args) == 1:
|
|
return args[0]
|
|
|
|
lst = list(args)
|
|
last = lst.pop(len(lst) - 1)
|
|
|
|
return _merge(merge(*lst), last)
|
|
|
|
|
|
def _merge(a, b):
|
|
to_return = a.copy()
|
|
to_return.update(b)
|
|
return to_return
|
|
|
|
|
|
# http://stackoverflow.com/questions/20656135/python-deep-merge-dictionary-data
|
|
def deep_merge(*args):
|
|
"""
|
|
>>> dbt.utils.deep_merge({'a': 1, 'b': 2, 'c': 3}, {'a': 2}, {'a': 3, 'b': 1}) # noqa
|
|
{'a': 3, 'b': 1, 'c': 3}
|
|
"""
|
|
if len(args) == 0:
|
|
return None
|
|
|
|
if len(args) == 1:
|
|
return copy.deepcopy(args[0])
|
|
|
|
lst = list(args)
|
|
last = copy.deepcopy(lst.pop(len(lst) - 1))
|
|
|
|
return _deep_merge(deep_merge(*lst), last)
|
|
|
|
|
|
def _deep_merge(destination, source):
|
|
if isinstance(source, dict):
|
|
for key, value in source.items():
|
|
deep_merge_item(destination, key, value)
|
|
return destination
|
|
|
|
|
|
def deep_merge_item(destination, key, value):
|
|
if isinstance(value, dict):
|
|
node = destination.setdefault(key, {})
|
|
destination[key] = deep_merge(node, value)
|
|
elif isinstance(value, tuple) or isinstance(value, list):
|
|
if key in destination:
|
|
destination[key] = list(value) + list(destination[key])
|
|
else:
|
|
destination[key] = value
|
|
else:
|
|
destination[key] = value
|
|
|
|
|
|
def _deep_map(
|
|
func: Callable[[Any, Tuple[Union[str, int], ...]], Any],
|
|
value: Any,
|
|
keypath: Tuple[Union[str, int], ...],
|
|
) -> Any:
|
|
atomic_types: Tuple[Type[Any], ...] = (int, float, str, type(None), bool)
|
|
|
|
ret: Any
|
|
|
|
if isinstance(value, list):
|
|
ret = [
|
|
_deep_map(func, v, (keypath + (idx,)))
|
|
for idx, v in enumerate(value)
|
|
]
|
|
elif isinstance(value, dict):
|
|
ret = {
|
|
k: _deep_map(func, v, (keypath + (str(k),)))
|
|
for k, v in value.items()
|
|
}
|
|
elif isinstance(value, atomic_types):
|
|
ret = func(value, keypath)
|
|
else:
|
|
container_types: Tuple[Type[Any], ...] = (list, dict)
|
|
ok_types = container_types + atomic_types
|
|
raise dbt.exceptions.DbtConfigError(
|
|
'in _deep_map, expected one of {!r}, got {!r}'
|
|
.format(ok_types, type(value))
|
|
)
|
|
|
|
return ret
|
|
|
|
|
|
def deep_map(
|
|
func: Callable[[Any, Tuple[Union[str, int], ...]], Any],
|
|
value: Any
|
|
) -> Any:
|
|
"""map the function func() onto each non-container value in 'value'
|
|
recursively, returning a new value. As long as func does not manipulate
|
|
value, then deep_map will also not manipulate it.
|
|
|
|
value should be a value returned by `yaml.safe_load` or `json.load` - the
|
|
only expected types are list, dict, native python number, str, NoneType,
|
|
and bool.
|
|
|
|
func() will be called on numbers, strings, Nones, and booleans. Its first
|
|
parameter will be the value, and the second will be its keypath, an
|
|
iterable over the __getitem__ keys needed to get to it.
|
|
|
|
:raises: If there are cycles in the value, raises a
|
|
dbt.exceptions.RecursionException
|
|
"""
|
|
try:
|
|
return _deep_map(func, value, ())
|
|
except RuntimeError as exc:
|
|
if 'maximum recursion depth exceeded' in str(exc):
|
|
raise dbt.exceptions.RecursionException(
|
|
'Cycle detected in deep_map'
|
|
)
|
|
raise
|
|
|
|
|
|
class AttrDict(dict):
|
|
def __init__(self, *args, **kwargs):
|
|
super().__init__(*args, **kwargs)
|
|
self.__dict__ = self
|
|
|
|
|
|
def get_pseudo_test_path(node_name, source_path, test_type):
|
|
"schema tests all come from schema.yml files. fake a source sql file"
|
|
source_path_parts = split_path(source_path)
|
|
source_path_parts.pop() # ignore filename
|
|
suffix = [test_type, "{}.sql".format(node_name)]
|
|
pseudo_path_parts = source_path_parts + suffix
|
|
return os.path.join(*pseudo_path_parts)
|
|
|
|
|
|
def get_pseudo_hook_path(hook_name):
|
|
path_parts = ['hooks', "{}.sql".format(hook_name)]
|
|
return os.path.join(*path_parts)
|
|
|
|
|
|
def md5(string):
|
|
return hashlib.md5(string.encode('utf-8')).hexdigest()
|
|
|
|
|
|
def get_hash(model):
|
|
return hashlib.md5(model.unique_id.encode('utf-8')).hexdigest()
|
|
|
|
|
|
def get_hashed_contents(model):
|
|
return hashlib.md5(model.raw_sql.encode('utf-8')).hexdigest()
|
|
|
|
|
|
def flatten_nodes(dep_list):
|
|
return list(itertools.chain.from_iterable(dep_list))
|
|
|
|
|
|
class memoized:
|
|
'''Decorator. Caches a function's return value each time it is called. If
|
|
called later with the same arguments, the cached value is returned (not
|
|
reevaluated).
|
|
|
|
Taken from https://wiki.python.org/moin/PythonDecoratorLibrary#Memoize'''
|
|
def __init__(self, func):
|
|
self.func = func
|
|
self.cache = {}
|
|
|
|
def __call__(self, *args):
|
|
if not isinstance(args, collections.abc.Hashable):
|
|
# uncacheable. a list, for instance.
|
|
# better to not cache than blow up.
|
|
return self.func(*args)
|
|
if args in self.cache:
|
|
return self.cache[args]
|
|
value = self.func(*args)
|
|
self.cache[args] = value
|
|
return value
|
|
|
|
def __repr__(self):
|
|
'''Return the function's docstring.'''
|
|
return self.func.__doc__
|
|
|
|
def __get__(self, obj, objtype):
|
|
'''Support instance methods.'''
|
|
return functools.partial(self.__call__, obj)
|
|
|
|
|
|
K_T = TypeVar('K_T')
|
|
V_T = TypeVar('V_T')
|
|
|
|
|
|
def filter_null_values(input: Dict[K_T, Optional[V_T]]) -> Dict[K_T, V_T]:
|
|
return {k: v for k, v in input.items() if v is not None}
|
|
|
|
|
|
def add_ephemeral_model_prefix(s: str) -> str:
|
|
return '__dbt__cte__{}'.format(s)
|
|
|
|
|
|
def timestring() -> str:
|
|
"""Get the current datetime as an RFC 3339-compliant string"""
|
|
# isoformat doesn't include the mandatory trailing 'Z' for UTC.
|
|
return datetime.datetime.utcnow().isoformat() + 'Z'
|
|
|
|
|
|
class JSONEncoder(json.JSONEncoder):
|
|
"""A 'custom' json encoder that does normal json encoder things, but also
|
|
handles `Decimal`s. and `Undefined`s. Decimals can lose precision because
|
|
they get converted to floats. Undefined's are serialized to an empty string
|
|
"""
|
|
def default(self, obj):
|
|
if isinstance(obj, DECIMALS):
|
|
return float(obj)
|
|
if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
|
|
return obj.isoformat()
|
|
if isinstance(obj, jinja2.Undefined):
|
|
return ""
|
|
if hasattr(obj, 'to_dict'):
|
|
# if we have a to_dict we should try to serialize the result of
|
|
# that!
|
|
return obj.to_dict(omit_none=True)
|
|
return super().default(obj)
|
|
|
|
|
|
class ForgivingJSONEncoder(JSONEncoder):
|
|
def default(self, obj):
|
|
# let dbt's default JSON encoder handle it if possible, fallback to
|
|
# str()
|
|
try:
|
|
return super().default(obj)
|
|
except TypeError:
|
|
return str(obj)
|
|
|
|
|
|
class Translator:
|
|
def __init__(self, aliases: Mapping[str, str], recursive: bool = False):
|
|
self.aliases = aliases
|
|
self.recursive = recursive
|
|
|
|
def translate_mapping(
|
|
self, kwargs: Mapping[str, Any]
|
|
) -> Dict[str, Any]:
|
|
result: Dict[str, Any] = {}
|
|
|
|
for key, value in kwargs.items():
|
|
canonical_key = self.aliases.get(key, key)
|
|
if canonical_key in result:
|
|
dbt.exceptions.raise_duplicate_alias(
|
|
kwargs, self.aliases, canonical_key
|
|
)
|
|
result[canonical_key] = self.translate_value(value)
|
|
return result
|
|
|
|
def translate_sequence(self, value: Sequence[Any]) -> List[Any]:
|
|
return [self.translate_value(v) for v in value]
|
|
|
|
def translate_value(self, value: Any) -> Any:
|
|
if self.recursive:
|
|
if isinstance(value, Mapping):
|
|
return self.translate_mapping(value)
|
|
elif isinstance(value, (list, tuple)):
|
|
return self.translate_sequence(value)
|
|
return value
|
|
|
|
def translate(self, value: Mapping[str, Any]) -> Dict[str, Any]:
|
|
try:
|
|
return self.translate_mapping(value)
|
|
except RuntimeError as exc:
|
|
if 'maximum recursion depth exceeded' in str(exc):
|
|
raise dbt.exceptions.RecursionException(
|
|
'Cycle detected in a value passed to translate!'
|
|
)
|
|
raise
|
|
|
|
|
|
def translate_aliases(
|
|
kwargs: Dict[str, Any], aliases: Dict[str, str], recurse: bool = False,
|
|
) -> Dict[str, Any]:
|
|
"""Given a dict of keyword arguments and a dict mapping aliases to their
|
|
canonical values, canonicalize the keys in the kwargs dict.
|
|
|
|
If recurse is True, perform this operation recursively.
|
|
|
|
:return: A dict containing all the values in kwargs referenced by their
|
|
canonical key.
|
|
:raises: `AliasException`, if a canonical key is defined more than once.
|
|
"""
|
|
translator = Translator(aliases, recurse)
|
|
return translator.translate(kwargs)
|
|
|
|
|
|
def _pluralize(string: Union[str, NodeType]) -> str:
|
|
try:
|
|
convert = NodeType(string)
|
|
except ValueError:
|
|
return f'{string}s'
|
|
else:
|
|
return convert.pluralize()
|
|
|
|
|
|
def pluralize(count, string: Union[str, NodeType]):
|
|
pluralized: str = str(string)
|
|
if count != 1:
|
|
pluralized = _pluralize(string)
|
|
return f'{count} {pluralized}'
|
|
|
|
|
|
# Note that this only affects hologram json validation.
|
|
# It has no effect on mashumaro serialization.
|
|
def restrict_to(*restrictions):
|
|
"""Create the metadata for a restricted dataclass field"""
|
|
return {'restrict': list(restrictions)}
|
|
|
|
|
|
def coerce_dict_str(value: Any) -> Optional[Dict[str, Any]]:
|
|
"""For annoying mypy reasons, this helper makes dealing with nested dicts
|
|
easier. You get either `None` if it's not a Dict[str, Any], or the
|
|
Dict[str, Any] you expected (to pass it to dbtClassMixin.from_dict(...)).
|
|
"""
|
|
if (isinstance(value, dict) and all(isinstance(k, str) for k in value)):
|
|
return value
|
|
else:
|
|
return None
|
|
|
|
|
|
def _coerce_decimal(value):
|
|
if isinstance(value, DECIMALS):
|
|
return float(value)
|
|
return value
|
|
|
|
|
|
def lowercase(value: Optional[str]) -> Optional[str]:
|
|
if value is None:
|
|
return None
|
|
else:
|
|
return value.lower()
|
|
|
|
|
|
# some types need to make constants available to the jinja context as
|
|
# attributes, and regular properties only work with objects. maybe this should
|
|
# be handled by the RelationProxy?
|
|
|
|
class classproperty(object):
|
|
def __init__(self, func):
|
|
self.func = func
|
|
|
|
def __get__(self, obj, objtype):
|
|
return self.func(objtype)
|
|
|
|
|
|
def format_bytes(num_bytes):
|
|
for unit in ['Bytes', 'KB', 'MB', 'GB', 'TB', 'PB']:
|
|
if abs(num_bytes) < 1024.0:
|
|
return f"{num_bytes:3.1f} {unit}"
|
|
num_bytes /= 1024.0
|
|
|
|
num_bytes *= 1024.0
|
|
return f"{num_bytes:3.1f} {unit}"
|
|
|
|
|
|
def format_rows_number(rows_number):
|
|
for unit in ['', 'k', 'm', 'b', 't']:
|
|
if abs(rows_number) < 1000.0:
|
|
return f"{rows_number:3.1f}{unit}".strip()
|
|
rows_number /= 1000.0
|
|
|
|
rows_number *= 1000.0
|
|
return f"{rows_number:3.1f}{unit}".strip()
|
|
|
|
|
|
class ConnectingExecutor(concurrent.futures.Executor):
|
|
def submit_connected(self, adapter, conn_name, func, *args, **kwargs):
|
|
def connected(conn_name, func, *args, **kwargs):
|
|
with self.connection_named(adapter, conn_name):
|
|
return func(*args, **kwargs)
|
|
return self.submit(connected, conn_name, func, *args, **kwargs)
|
|
|
|
|
|
# a little concurrent.futures.Executor for single-threaded mode
|
|
class SingleThreadedExecutor(ConnectingExecutor):
|
|
def submit(*args, **kwargs):
|
|
# this basic pattern comes from concurrent.futures.Executor itself,
|
|
# but without handling the `fn=` form.
|
|
if len(args) >= 2:
|
|
self, fn, *args = args
|
|
elif not args:
|
|
raise TypeError(
|
|
"descriptor 'submit' of 'SingleThreadedExecutor' object needs "
|
|
"an argument"
|
|
)
|
|
else:
|
|
raise TypeError(
|
|
'submit expected at least 1 positional argument, '
|
|
'got %d' % (len(args) - 1)
|
|
)
|
|
fut = concurrent.futures.Future()
|
|
try:
|
|
result = fn(*args, **kwargs)
|
|
except Exception as exc:
|
|
fut.set_exception(exc)
|
|
else:
|
|
fut.set_result(result)
|
|
return fut
|
|
|
|
@contextmanager
|
|
def connection_named(self, adapter, name):
|
|
yield
|
|
|
|
|
|
class MultiThreadedExecutor(
|
|
ConnectingExecutor,
|
|
concurrent.futures.ThreadPoolExecutor,
|
|
):
|
|
@contextmanager
|
|
def connection_named(self, adapter, name):
|
|
with adapter.connection_named(name):
|
|
yield
|
|
|
|
|
|
class ThreadedArgs(Protocol):
|
|
single_threaded: bool
|
|
|
|
|
|
class HasThreadingConfig(Protocol):
|
|
args: ThreadedArgs
|
|
threads: Optional[int]
|
|
|
|
|
|
def executor(config: HasThreadingConfig) -> ConnectingExecutor:
|
|
if config.args.single_threaded:
|
|
return SingleThreadedExecutor()
|
|
else:
|
|
return MultiThreadedExecutor(max_workers=config.threads)
|
|
|
|
|
|
def fqn_search(
|
|
root: Dict[str, Any], fqn: List[str]
|
|
) -> Iterator[Dict[str, Any]]:
|
|
"""Iterate into a nested dictionary, looking for keys in the fqn as levels.
|
|
Yield the level config.
|
|
"""
|
|
yield root
|
|
|
|
for level in fqn:
|
|
level_config = root.get(level, None)
|
|
if not isinstance(level_config, dict):
|
|
break
|
|
# This used to do a 'deepcopy',
|
|
# but it didn't seem to be necessary
|
|
yield level_config
|
|
root = level_config
|
|
|
|
|
|
StringMap = Mapping[str, Any]
|
|
StringMapList = List[StringMap]
|
|
StringMapIter = Iterable[StringMap]
|
|
|
|
|
|
class MultiDict(Mapping[str, Any]):
|
|
"""Implement the mapping protocol using a list of mappings. The most
|
|
recently added mapping "wins".
|
|
"""
|
|
def __init__(self, sources: Optional[StringMapList] = None) -> None:
|
|
super().__init__()
|
|
self.sources: StringMapList
|
|
|
|
if sources is None:
|
|
self.sources = []
|
|
else:
|
|
self.sources = sources
|
|
|
|
def add_from(self, sources: StringMapIter):
|
|
self.sources.extend(sources)
|
|
|
|
def add(self, source: StringMap):
|
|
self.sources.append(source)
|
|
|
|
def _keyset(self) -> AbstractSet[str]:
|
|
# return the set of keys
|
|
keys: Set[str] = set()
|
|
for entry in self._itersource():
|
|
keys.update(entry)
|
|
return keys
|
|
|
|
def _itersource(self) -> StringMapIter:
|
|
return reversed(self.sources)
|
|
|
|
def __iter__(self) -> Iterator[str]:
|
|
# we need to avoid duplicate keys
|
|
return iter(self._keyset())
|
|
|
|
def __len__(self):
|
|
return len(self._keyset())
|
|
|
|
def __getitem__(self, name: str) -> Any:
|
|
for entry in self._itersource():
|
|
if name in entry:
|
|
return entry[name]
|
|
raise KeyError(name)
|
|
|
|
def __contains__(self, name) -> bool:
|
|
return any((name in entry for entry in self._itersource()))
|
|
|
|
|
|
def _connection_exception_retry(fn, max_attempts: int, attempt: int = 0):
|
|
"""Attempts to run a function that makes an external call, if the call fails
|
|
on a connection error or timeout, it will be tried up to 5 more times.
|
|
"""
|
|
try:
|
|
return fn()
|
|
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as exc:
|
|
if attempt <= max_attempts - 1:
|
|
logger.debug('Retrying external call. Attempt: ' +
|
|
f'{attempt} Max attempts: {max_attempts}')
|
|
time.sleep(1)
|
|
_connection_exception_retry(fn, max_attempts, attempt + 1)
|
|
else:
|
|
raise ConnectionException('External connection exception occurred: ' + str(exc))
|