
626 lines
18 KiB
Raw Normal View History

2022-03-22 15:13:27 +00:00
import collections
import concurrent.futures
import copy
import datetime
import decimal
import functools
import hashlib
import itertools
import jinja2
import json
import os
import requests
import time
from contextlib import contextmanager
from dbt.exceptions import ConnectionException
from dbt.logger import GLOBAL_LOGGER as logger
from enum import Enum
from typing_extensions import Protocol
from typing import (
Tuple, Type, Any, Optional, TypeVar, Dict, Union, Callable, List, Iterator,
Mapping, Iterable, AbstractSet, Set, Sequence
import dbt.exceptions
from dbt.node_types import NodeType
# Decimal types that the JSON encoder and `_coerce_decimal` convert to float.
# The third-party `cdecimal` module is included only when it can be imported.
DECIMALS: Tuple[Type[Any], ...]
try:
    import cdecimal  # type: ignore
except ImportError:
    DECIMALS = (decimal.Decimal,)
else:
    DECIMALS = (decimal.Decimal, cdecimal.Decimal)
class ExitCodes(int, Enum):
    """Integer-valued exit statuses.

    Mixing in ``int`` makes every member an ``int`` equal to its value.
    NOTE(review): presumably these are the process exit codes returned by the
    CLI -- confirm at the call site.
    """
    Success = 0
    ModelError = 1
    UnhandledError = 2
def coalesce(*args):
    """Return the first argument that is not None, or None if every argument is.

    Falsy values such as ``0``, ``''`` and ``False`` count as present.
    """
    return next((candidate for candidate in args if candidate is not None), None)
def get_profile_from_project(project):
    """Return the output dict selected by the project's ``target``.

    Reads ``project['target']`` and looks it up in ``project['outputs']``.
    Returns ``{}`` when there is no target, or when the target has no entry
    in the outputs.
    """
    target_name = project.get('target')
    if target_name is None:
        # The previous default of ``{}`` was then used as a dict key, which
        # raises TypeError (dicts are unhashable) whenever 'target' is absent.
        return {}
    profile = project.get('outputs', {}).get(target_name, {})
    return profile
def get_model_name_or_none(model):
    """Return a human-readable name for ``model``.

    - ``None`` -> ``'<None>'``
    - a string is returned unchanged
    - a dict -> its ``'alias'`` entry, else its ``'name'`` entry (``None`` if
      neither key is present)
    - an object with an ``alias`` attribute -> that attribute, else its
      ``name`` attribute
    - anything else -> ``str(model)``
    """
    if model is None:
        name = '<None>'
    elif isinstance(model, str):
        name = model
    elif isinstance(model, dict):
        name = model.get('alias', model.get('name'))
    elif hasattr(model, 'alias'):
        name = model.alias
    elif hasattr(model, 'name'):
        name = model.name
    else:
        # Last resort only: this must not override the branches above.
        name = str(model)
    return name
# Prefixes prepended by get_dbt_macro_name / get_dbt_docs_name respectively.
MACRO_PREFIX, DOCS_PREFIX = 'dbt_macro__', 'dbt_docs__'
def get_dbt_macro_name(name):
    """Return ``name`` with MACRO_PREFIX prepended.

    :raises: dbt.exceptions.InternalException if ``name`` is None.
    """
    if name is not None:
        return f'{MACRO_PREFIX}{name}'
    raise dbt.exceptions.InternalException('Got None for a macro name!')
def get_dbt_docs_name(name):
    """Return ``name`` with DOCS_PREFIX prepended.

    :raises: dbt.exceptions.InternalException if ``name`` is None.
    """
    if name is not None:
        return f'{DOCS_PREFIX}{name}'
    raise dbt.exceptions.InternalException('Got None for a doc name!')
def get_materialization_macro_name(materialization_name, adapter_type=None,
                                   with_prefix=True):
    """Return the macro name implementing a materialization for an adapter.

    The name has the form ``materialization_<materialization>_<adapter>``,
    with ``adapter_type`` defaulting to ``'default'``. When ``with_prefix`` is
    truthy the result is passed through ``get_dbt_macro_name``.
    """
    if adapter_type is None:
        adapter_type = 'default'
    name = f'materialization_{materialization_name}_{adapter_type}'
    return get_dbt_macro_name(name) if with_prefix else name
def get_docs_macro_name(docs_name, with_prefix=True):
    """Return the docs macro name for ``docs_name``.

    The name is prefixed via ``get_dbt_docs_name`` unless ``with_prefix`` is
    falsy, in which case ``docs_name`` is returned as-is.
    """
    if not with_prefix:
        return docs_name
    return get_dbt_docs_name(docs_name)
def get_test_macro_name(test_name, with_prefix=True):
    """Return ``'test_<test_name>'``, prefixed via ``get_dbt_macro_name``
    unless ``with_prefix`` is falsy."""
    macro = f'test_{test_name}'
    if with_prefix:
        macro = get_dbt_macro_name(macro)
    return macro
def split_path(path):
    """Split ``path`` on the platform separator (``os.sep``) into components."""
    separator = os.sep
    components = path.split(separator)
    return components
def merge(*args):
    """Shallow-merge dicts left to right; later dicts win on key conflicts.

    With no arguments returns None; with a single argument returns that
    object itself (uncopied). Otherwise folds ``_merge`` over the arguments,
    so the first dict is copied and the inputs are never mutated.
    """
    if len(args) < 2:
        return args[0] if args else None
    return functools.reduce(_merge, args)
def _merge(a, b):
to_return = a.copy()
return to_return
# http://stackoverflow.com/questions/20656135/python-deep-merge-dictionary-data
def deep_merge(*args):
    """Deep-merge the arguments left to right, without mutating any of them.

    Every argument is deep-copied before merging; later arguments win. See
    ``deep_merge_item`` for how dicts, lists/tuples and scalars combine.
    Returns None when called with no arguments.

    >>> dbt.utils.deep_merge({'a': 1, 'b': 2, 'c': 3}, {'a': 2}, {'a': 3, 'b': 1})  # noqa
    {'a': 3, 'b': 1, 'c': 3}
    """
    if len(args) == 0:
        return None

    if len(args) == 1:
        return copy.deepcopy(args[0])

    *head, tail = args
    last = copy.deepcopy(tail)

    return _deep_merge(deep_merge(*head), last)
def _deep_merge(destination, source):
if isinstance(source, dict):
for key, value in source.items():
deep_merge_item(destination, key, value)
return destination
def deep_merge_item(destination, key, value):
    """Merge ``value`` into ``destination[key]`` in place.

    - dict values are deep-merged (via ``deep_merge``) with any existing dict
      at ``key``;
    - list/tuple values are combined as ``list(value) + list(existing)``, so
      the incoming items come first; if ``key`` is absent, ``value`` is stored
      as-is;
    - any other value replaces whatever was stored at ``key``.
    """
    if isinstance(value, dict):
        node = destination.setdefault(key, {})
        destination[key] = deep_merge(node, value)
    elif isinstance(value, (tuple, list)):
        if key in destination:
            destination[key] = list(value) + list(destination[key])
        else:
            destination[key] = value
    else:
        destination[key] = value
def _deep_map(
func: Callable[[Any, Tuple[Union[str, int], ...]], Any],
value: Any,
keypath: Tuple[Union[str, int], ...],
) -> Any:
atomic_types: Tuple[Type[Any], ...] = (int, float, str, type(None), bool)
ret: Any
if isinstance(value, list):
ret = [
_deep_map(func, v, (keypath + (idx,)))
for idx, v in enumerate(value)
elif isinstance(value, dict):
ret = {
k: _deep_map(func, v, (keypath + (str(k),)))
for k, v in value.items()
elif isinstance(value, atomic_types):
ret = func(value, keypath)
container_types: Tuple[Type[Any], ...] = (list, dict)
ok_types = container_types + atomic_types
raise dbt.exceptions.DbtConfigError(
'in _deep_map, expected one of {!r}, got {!r}'
.format(ok_types, type(value))
return ret
def deep_map(
    func: Callable[[Any, Tuple[Union[str, int], ...]], Any],
    value: Any
) -> Any:
    """map the function func() onto each non-container value in 'value'
    recursively, returning a new value. As long as func does not manipulate
    value, then deep_map will also not manipulate it.

    value should be a value returned by `yaml.safe_load` or `json.load` - the
    only expected types are list, dict, native python number, str, NoneType,
    and bool.

    func() will be called on numbers, strings, Nones, and booleans. Its first
    parameter will be the value, and the second will be its keypath, an
    iterable over the __getitem__ keys needed to get to it.

    :raises: If there are cycles in the value, raises a
        dbt.exceptions.RecursionException. Values of unsupported types raise
        dbt.exceptions.DbtConfigError (from _deep_map).
    """
    try:
        return _deep_map(func, value, ())
    except RuntimeError as exc:
        # A cyclic structure recurses until Python's recursion limit is hit
        # (RecursionError subclasses RuntimeError); report that as a cycle.
        if 'maximum recursion depth exceeded' in str(exc):
            raise dbt.exceptions.RecursionException(
                'Cycle detected in deep_map'
            ) from exc
        # Any other RuntimeError is unrelated and must not be swallowed.
        raise
class AttrDict(dict):
    """A dict whose entries are also readable and writable as attributes.

    The instance's ``__dict__`` is the dict itself, so ``d.x`` and ``d['x']``
    refer to the same storage.
    """
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        # Alias attribute storage onto the mapping storage.
        self.__dict__ = self
def get_pseudo_test_path(node_name, source_path, test_type):
    """Build a fake ``.sql`` path for a schema test.

    Schema tests all come from schema.yml files, so there is no real source
    SQL file. The result is ``<dir of source_path>/<test_type>/<node_name>.sql``.
    """
    # Drop the file name component (the schema.yml itself).
    directory_parts = source_path.split(os.sep)[:-1]
    return os.path.join(*directory_parts, test_type, f'{node_name}.sql')
def get_pseudo_hook_path(hook_name):
    """Return the fake path ``hooks/<hook_name>.sql`` used for a hook."""
    return os.path.join('hooks', f'{hook_name}.sql')
def md5(string):
    """Return the hex MD5 digest of ``string`` encoded as UTF-8."""
    digest = hashlib.md5(string.encode('utf-8'))
    return digest.hexdigest()
def get_hash(model):
    """Return the hex MD5 digest of ``model.unique_id`` (UTF-8 encoded)."""
    payload = model.unique_id.encode('utf-8')
    return hashlib.md5(payload).hexdigest()
def get_hashed_contents(model):
    """Return the hex MD5 digest of ``model.raw_sql`` (UTF-8 encoded)."""
    payload = model.raw_sql.encode('utf-8')
    return hashlib.md5(payload).hexdigest()
def flatten_nodes(dep_list):
    """Flatten one level of nesting: an iterable of iterables -> one list."""
    return [node for group in dep_list for node in group]
class memoized:
    '''Decorator. Caches a function's return value each time it is called. If
    called later with the same arguments, the cached value is returned (not
    reevaluated).

    Calls whose positional arguments cannot be hashed (a list, for instance)
    are passed straight through to the function and are not cached. Keyword
    arguments are not supported.

    Taken from https://wiki.python.org/moin/PythonDecoratorLibrary#Memoize'''
    def __init__(self, func):
        self.func = func
        self.cache = {}

    def __call__(self, *args):
        try:
            # ``args`` is always a tuple, and tuples are always instances of
            # collections.abc.Hashable. Hashing still fails when an element is
            # unhashable, so actually try it.
            hash(args)
        except TypeError:
            # uncacheable. a list, for instance.
            # better to not cache than blow up.
            return self.func(*args)
        if args in self.cache:
            return self.cache[args]
        value = self.func(*args)
        self.cache[args] = value
        return value

    def __repr__(self):
        '''Return the function's docstring.'''
        doc = self.func.__doc__
        # __repr__ must return a str; undocumented functions have __doc__ None.
        return doc if doc is not None else repr(self.func)

    def __get__(self, obj, objtype):
        '''Support instance methods.'''
        return functools.partial(self.__call__, obj)
K_T = TypeVar('K_T')
V_T = TypeVar('V_T')


def filter_null_values(input: Dict[K_T, Optional[V_T]]) -> Dict[K_T, V_T]:
    """Return a new dict with only the entries of ``input`` whose value is not None.

    Falsy-but-present values such as ``0``, ``''`` and ``False`` are kept.
    """
    kept: Dict[K_T, V_T] = {}
    for key, val in input.items():
        if val is None:
            continue
        kept[key] = val
    return kept
def add_ephemeral_model_prefix(s: str) -> str:
    """Prefix ``s`` with ``'__dbt__cte__'``."""
    return f'__dbt__cte__{s}'
def timestring() -> str:
    """Get the current datetime as an RFC 3339-compliant string"""
    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # "now" and drop the tzinfo so isoformat() emits the same naive text.
    now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    # isoformat doesn't include the mandatory trailing 'Z' for UTC.
    return now.isoformat() + 'Z'
class JSONEncoder(json.JSONEncoder):
    """A 'custom' json encoder that does normal json encoder things, but also
    handles `Decimal`s and `Undefined`s. Decimals can lose precision because
    they get converted to floats. `Undefined`s are serialized to an empty
    string. Dates, datetimes and times are written via ``isoformat()``, and
    objects that have a ``to_dict`` method are serialized as
    ``obj.to_dict(omit_none=True)``.
    """
    def default(self, obj):
        if isinstance(obj, DECIMALS):
            return float(obj)
        if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
            return obj.isoformat()
        if isinstance(obj, jinja2.Undefined):
            return ""
        if hasattr(obj, 'to_dict'):
            # if we have a to_dict we should try to serialize the result of
            # that!
            return obj.to_dict(omit_none=True)
        # Unknown type: the stdlib encoder raises TypeError.
        return super().default(obj)
class ForgivingJSONEncoder(JSONEncoder):
    """JSON encoder that never fails on an unsupported type.

    Anything dbt's ``JSONEncoder`` cannot serialize is written as
    ``str(obj)``.
    """
    def default(self, obj):
        # let dbt's default JSON encoder handle it if possible, fallback to
        # str()
        try:
            return super().default(obj)
        except TypeError:
            return str(obj)
class Translator:
    """Rename mapping keys to their canonical names using an alias table.

    ``aliases`` maps alias -> canonical key; keys not in the table are kept.
    When ``recursive`` is true, nested mappings are translated as well, and so
    are mappings inside lists/tuples (sequences are returned as lists).
    """
    def __init__(self, aliases: Mapping[str, str], recursive: bool = False):
        self.aliases = aliases
        self.recursive = recursive

    def translate_mapping(
        self, kwargs: Mapping[str, Any]
    ) -> Dict[str, Any]:
        """Return a new dict with every key of ``kwargs`` canonicalized."""
        result: Dict[str, Any] = {}

        for key, value in kwargs.items():
            canonical_key = self.aliases.get(key, key)
            if canonical_key in result:
                # Two keys (e.g. an alias and its canonical name) collapse to
                # the same canonical key; this helper raises AliasException.
                dbt.exceptions.raise_duplicate_alias(
                    kwargs, self.aliases, canonical_key
                )
            result[canonical_key] = self.translate_value(value)
        return result

    def translate_sequence(self, value: Sequence[Any]) -> List[Any]:
        """Translate each element of ``value``; always returns a list."""
        return [self.translate_value(v) for v in value]

    def translate_value(self, value: Any) -> Any:
        """Translate ``value`` if recursing into containers, else return it."""
        if self.recursive:
            if isinstance(value, Mapping):
                return self.translate_mapping(value)
            elif isinstance(value, (list, tuple)):
                return self.translate_sequence(value)
        return value

    def translate(self, value: Mapping[str, Any]) -> Dict[str, Any]:
        """Translate ``value``, reporting cyclic input as a RecursionException."""
        try:
            return self.translate_mapping(value)
        except RuntimeError as exc:
            # Cyclic input recurses until the interpreter's recursion limit.
            if 'maximum recursion depth exceeded' in str(exc):
                raise dbt.exceptions.RecursionException(
                    'Cycle detected in a value passed to translate!'
                ) from exc
            # Unrelated RuntimeErrors propagate unchanged.
            raise
def translate_aliases(
    kwargs: Dict[str, Any], aliases: Dict[str, str], recurse: bool = False,
) -> Dict[str, Any]:
    """Given a dict of keyword arguments and a dict mapping aliases to their
    canonical values, canonicalize the keys in the kwargs dict.

    If recurse is True, perform this operation recursively.

    :return: A dict containing all the values in kwargs referenced by their
        canonical key.
    :raises: `AliasException`, if a canonical key is defined more than once.
    :raises: `RecursionException`, if ``kwargs`` contains a cycle.
    """
    translator = Translator(aliases, recurse)
    return translator.translate(kwargs)
def _pluralize(string: Union[str, NodeType]) -> str:
    """Pluralize ``string``.

    Uses ``NodeType.pluralize()`` when the value names a NodeType, and falls
    back to appending ``'s'`` otherwise.
    """
    try:
        convert = NodeType(string)
    except ValueError:
        # Not a NodeType value: naive English plural.
        return f'{string}s'
    else:
        return convert.pluralize()
def pluralize(count, string: Union[str, NodeType]):
    """Return ``'<count> <noun>'``, using the plural noun unless count == 1."""
    noun: str = _pluralize(string) if count != 1 else str(string)
    return f'{count} {noun}'
# This metadata only affects hologram JSON validation; it has no effect on
# mashumaro serialization.
def restrict_to(*restrictions):
    """Build the field metadata restricting a dataclass field to ``restrictions``."""
    return {'restrict': [*restrictions]}
def coerce_dict_str(value: Any) -> Optional[Dict[str, Any]]:
    """For annoying mypy reasons, this helper makes dealing with nested dicts
    easier. You get either `None` if it's not a Dict[str, Any], or the
    Dict[str, Any] you expected (to pass it to dbtClassMixin.from_dict(...)).

    The dict is returned as-is (not copied); only its keys are checked.
    """
    if isinstance(value, dict) and all(isinstance(k, str) for k in value):
        return value
    return None
def _coerce_decimal(value):
    """Convert any DECIMALS instance to float; return other values unchanged."""
    return float(value) if isinstance(value, DECIMALS) else value
def lowercase(value: Optional[str]) -> Optional[str]:
    """Lower-case ``value``, passing None through unchanged."""
    return None if value is None else value.lower()
# Some types need to expose constants to the jinja context as attributes, and
# regular properties only work on instances. Maybe this should be handled by
# the RelationProxy?
class classproperty:
    """Read-only property computed from the owning class.

    Works both on the class and on its instances; the wrapped function always
    receives the class (``objtype``), never the instance.
    """
    def __init__(self, func):
        self.func = func

    def __get__(self, obj, objtype):
        owner = objtype
        return self.func(owner)
def format_bytes(num_bytes):
    """Render a byte count with a 1024-based unit, e.g. ``'1.5 KB'``.

    Values of 1024 PB or more are still expressed in PB.
    """
    units = ('Bytes', 'KB', 'MB', 'GB', 'TB', 'PB')
    size = num_bytes
    for unit in units:
        if abs(size) < 1024.0:
            return f"{size:3.1f} {unit}"
        size /= 1024.0
    # Every unit was exhausted: undo the final division and stay on PB.
    return f"{size * 1024.0:3.1f} {units[-1]}"
def format_rows_number(rows_number):
    """Render a row count with a 1000-based suffix, e.g. ``'1.5k'``.

    Counts of 1000 trillion or more are still expressed with the ``'t'``
    suffix.
    """
    value = rows_number
    for suffix in ('', 'k', 'm', 'b', 't'):
        if abs(value) < 1000.0:
            return f"{value:3.1f}{suffix}".strip()
        value /= 1000.0
    # Every suffix was exhausted: undo the final division and keep 't'.
    value *= 1000.0
    return f"{value:3.1f}{suffix}".strip()
class ConnectingExecutor(concurrent.futures.Executor):
    """Executor that can run a submitted call inside a named adapter connection.

    Subclasses supply ``submit`` and a ``connection_named(adapter, name)``
    context manager; ``submit_connected`` combines the two.
    """
    def submit_connected(self, adapter, conn_name, func, *args, **kwargs):
        def connected(conn_name, func, *args, **kwargs):
            connection_ctx = self.connection_named(adapter, conn_name)
            with connection_ctx:
                return func(*args, **kwargs)

        submit = self.submit
        return submit(connected, conn_name, func, *args, **kwargs)
# a little concurrent.futures.Executor for single-threaded mode
class SingleThreadedExecutor(ConnectingExecutor):
    """Executor that runs each submitted callable immediately in the calling thread.

    ``submit`` returns an already-completed Future holding either the
    callable's result or the exception it raised.
    """
    def submit(*args, **kwargs):
        # this basic pattern comes from concurrent.futures.Executor itself,
        # but without handling the `fn=` form.
        if len(args) >= 2:
            self, fn, *args = args
        elif not args:
            raise TypeError(
                "descriptor 'submit' of 'SingleThreadedExecutor' object needs "
                "an argument"
            )
        else:
            raise TypeError(
                'submit expected at least 1 positional argument, '
                'got %d' % (len(args) - 1)
            )
        fut = concurrent.futures.Future()
        try:
            result = fn(*args, **kwargs)
        except Exception as exc:
            # Mirror a real executor: the error surfaces from fut.result().
            fut.set_exception(exc)
        else:
            fut.set_result(result)
        return fut

    @contextmanager
    def connection_named(self, adapter, name):
        # No per-call connection is opened here.
        # NOTE(review): presumably single-threaded mode reuses the adapter's
        # existing connection -- confirm.
        yield
class MultiThreadedExecutor(
    ConnectingExecutor,
    concurrent.futures.ThreadPoolExecutor,
):
    """Thread-pool executor whose submitted calls run inside
    ``adapter.connection_named(name)``."""
    @contextmanager
    def connection_named(self, adapter, name):
        with adapter.connection_named(name):
            yield
class ThreadedArgs(Protocol):
    """Structural type for an args object carrying a ``single_threaded`` flag.

    NOTE(review): presumably the parsed CLI arguments -- confirm. `executor`
    reads this flag to choose single- vs multi-threaded execution.
    """
    single_threaded: bool
class HasThreadingConfig(Protocol):
    """Structural type for a config object consumed by `executor`."""
    args: ThreadedArgs  # provides the single_threaded flag
    threads: Optional[int]  # passed as max_workers to the thread pool
def executor(config: HasThreadingConfig) -> ConnectingExecutor:
    """Build the executor selected by ``config``.

    Returns a SingleThreadedExecutor when ``config.args.single_threaded`` is
    truthy, otherwise a MultiThreadedExecutor with ``config.threads`` workers.
    """
    single = config.args.single_threaded
    return (
        SingleThreadedExecutor() if single
        else MultiThreadedExecutor(max_workers=config.threads)
    )
def fqn_search(
    root: Dict[str, Any], fqn: List[str]
) -> Iterator[Dict[str, Any]]:
    """Iterate into a nested dictionary, looking for keys in the fqn as levels.
    Yield the level config.

    ``root`` is always yielded first. For each name in ``fqn``, the dict stored
    under that name at the current level is yielded and becomes the new
    level. The walk stops at the first name that is missing or whose value is
    not a dict.
    """
    yield root

    for level in fqn:
        level_config = root.get(level, None)
        if not isinstance(level_config, dict):
            break
        # This used to do a 'deepcopy',
        # but it didn't seem to be necessary
        yield level_config
        root = level_config
# Read-only, string-keyed mapping: the element type that MultiDict layers over.
StringMap = Mapping[str, Any]
# A concrete list, or any iterable, of such mappings.
StringMapList = List[Mapping[str, Any]]
StringMapIter = Iterable[Mapping[str, Any]]
class MultiDict(Mapping[str, Any]):
    """Implement the mapping protocol using a list of mappings. The most
    recently added mapping "wins".

    Lookups scan ``sources`` from the end backwards, so a key in a later
    mapping shadows the same key in an earlier one. A list passed to
    ``__init__`` is stored by reference (not copied), so ``add``/``add_from``
    also mutate the caller's list.
    """
    def __init__(self, sources: Optional[StringMapList] = None) -> None:
        super().__init__()
        self.sources: StringMapList

        if sources is None:
            self.sources = []
        else:
            self.sources = sources

    def add_from(self, sources: StringMapIter):
        """Append each mapping in ``sources``; later ones take precedence."""
        self.sources.extend(sources)

    def add(self, source: StringMap):
        """Append ``source``; it takes precedence over everything before it."""
        self.sources.append(source)

    def _keyset(self) -> AbstractSet[str]:
        # return the set of keys
        keys: Set[str] = set()
        for entry in self._itersource():
            keys.update(entry)
        return keys

    def _itersource(self) -> StringMapIter:
        # Newest mapping first, so the first match is the winning one.
        return reversed(self.sources)

    def __iter__(self) -> Iterator[str]:
        # we need to avoid duplicate keys
        return iter(self._keyset())

    def __len__(self):
        return len(self._keyset())

    def __getitem__(self, name: str) -> Any:
        for entry in self._itersource():
            if name in entry:
                return entry[name]
        raise KeyError(name)

    def __contains__(self, name) -> bool:
        return any((name in entry for entry in self._itersource()))
def _connection_exception_retry(fn, max_attempts: int, attempt: int = 0):
"""Attempts to run a function that makes an external call, if the call fails
on a connection error or timeout, it will be tried up to 5 more times.
return fn()
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as exc:
if attempt <= max_attempts - 1:
logger.debug('Retrying external call. Attempt: ' +
f'{attempt} Max attempts: {max_attempts}')
_connection_exception_retry(fn, max_attempts, attempt + 1)
raise ConnectionException('External connection exception occurred: ' + str(exc))