567 lines
16 KiB
Python
567 lines
16 KiB
Python
from dataclasses import field, Field, dataclass
|
|
from enum import Enum
|
|
from itertools import chain
|
|
from typing import (
|
|
Any, List, Optional, Dict, Union, Type, TypeVar, Callable
|
|
)
|
|
from dbt.dataclass_schema import (
|
|
dbtClassMixin, ValidationError, register_pattern,
|
|
)
|
|
from dbt.contracts.graph.unparsed import AdditionalPropertiesAllowed
|
|
from dbt.exceptions import InternalException, CompilationException
|
|
from dbt.contracts.util import Replaceable, list_str
|
|
from dbt import hooks
|
|
from dbt.node_types import NodeType
|
|
|
|
|
|
# Type variable bound to Metadata: lets the helpers and classmethods below be
# annotated as accepting a Metadata subclass and returning a member of it.
M = TypeVar('M', bound='Metadata')
|
|
|
|
|
|
def _get_meta_value(cls: Type[M], fld: Field, key: str, default: Any) -> M:
    """Read ``key`` from a dataclass field's metadata as a member of ``cls``.

    :param cls: The class used to convert the raw value (called as
        ``cls(raw)``).
    :param fld: The dataclass field whose ``metadata`` mapping is consulted.
    :param key: The metadata key to look up.
    :param default: The value used when the field has no metadata or the
        metadata has no entry for ``key``.
    :return: ``cls(raw)`` for the value that was found (or the default).
    :raises InternalException: If ``cls`` rejects the value with a
        ``ValueError``.
    """
    # Falsy (e.g. empty) metadata is never queried; the default is used as-is.
    raw = fld.metadata.get(key, default) if fld.metadata else default

    try:
        member = cls(raw)
    except ValueError as exc:
        raise InternalException(
            f'Invalid {cls} value: {raw}'
        ) from exc
    return member
|
|
|
|
|
|
def _set_meta_value(
    obj: M, key: str, existing: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Return a metadata dict in which ``key`` maps to ``obj``.

    :param obj: The value to store.
    :param key: The key to store it under.
    :param existing: Optional metadata to start from. It is copied, never
        modified in place.
    :return: A new dict holding the entries of ``existing`` (if any) plus
        ``key -> obj``; an existing entry for ``key`` is overwritten.
    """
    combined = {} if existing is None else existing.copy()
    combined[key] = obj
    return combined
|
|
|
|
|
|
class Metadata(Enum):
    """Base enum for behaviors stored in dataclass field metadata.

    Subclasses define the members and must override ``metadata_key`` (the
    key their value is stored under) and ``default_field`` (the member used
    when a field has no value for that key).
    """

    @classmethod
    def from_field(cls: Type[M], fld: Field) -> M:
        """Return the member of this enum that applies to ``fld``.

        Falls back to ``default_field()`` when the field's metadata has no
        entry under ``metadata_key()``.
        """
        fallback = cls.default_field()
        meta_key = cls.metadata_key()
        return _get_meta_value(cls, fld, key=meta_key, default=fallback)

    def meta(
        self, existing: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Return a metadata dict with this member stored under its key.

        ``existing`` (if given) is copied and extended, not modified.
        """
        return _set_meta_value(self, self.metadata_key(), existing)

    @classmethod
    def default_field(cls) -> 'Metadata':
        """Return the default member; must be overridden by subclasses."""
        raise NotImplementedError('Not implemented')

    @classmethod
    def metadata_key(cls) -> str:
        """Return the metadata key; must be overridden by subclasses."""
        raise NotImplementedError('Not implemented')
|
|
|
|
|
|
class MergeBehavior(Metadata):
    """How a field's existing value is combined with an incoming one.

    The members are interpreted by ``_merge_field_value``. The value is
    stored in field metadata under the ``'merge'`` key, and fields without
    one default to ``Clobber``.
    """

    Append = 1
    Update = 2
    Clobber = 3

    @classmethod
    def default_field(cls) -> 'MergeBehavior':
        """Fields without an explicit merge behavior are clobbered."""
        return cls.Clobber

    @classmethod
    def metadata_key(cls) -> str:
        """Metadata key under which the merge behavior is stored."""
        return 'merge'
|
|
|
|
|
|
class ShowBehavior(Metadata):
    """Show/hide marker for a field.

    Stored in field metadata under the ``'show_hide'`` key; fields without
    one default to ``Show``.
    """

    Show = 1
    Hide = 2

    @classmethod
    def default_field(cls) -> 'ShowBehavior':
        """Fields without an explicit marker are shown."""
        return cls.Show

    @classmethod
    def metadata_key(cls) -> str:
        """Metadata key under which the show/hide marker is stored."""
        return 'show_hide'

    @classmethod
    def should_show(cls, fld: Field) -> bool:
        """Return True when ``fld`` resolves to ``Show``."""
        configured = cls.from_field(fld)
        return configured == cls.Show
|
|
|
|
|
|
class CompareBehavior(Metadata):
    """Whether a field takes part in config comparison.

    Stored in field metadata under the ``'compare'`` key; fields without one
    default to ``Include``.
    """

    Include = 1
    Exclude = 2

    @classmethod
    def default_field(cls) -> 'CompareBehavior':
        """Fields without an explicit marker are included."""
        return cls.Include

    @classmethod
    def metadata_key(cls) -> str:
        """Metadata key under which the compare marker is stored."""
        return 'compare'

    @classmethod
    def should_include(cls, fld: Field) -> bool:
        """Return True when ``fld`` resolves to ``Include``."""
        configured = cls.from_field(fld)
        return configured == cls.Include
|
|
|
|
|
|
def metas(*metas: Metadata) -> Dict[str, Any]:
    """Combine several metadata markers into a single metadata dict.

    Each argument's ``meta()`` is called with the dict built so far, in the
    order given, and its return value becomes the new accumulated dict.

    :return: The accumulated dict (empty when called without arguments).
    """
    combined: Dict[str, Any] = {}
    for marker in metas:
        combined = marker.meta(combined)
    return combined
|
|
|
|
|
|
def _listify(value: Any) -> List:
    """Return ``value`` as a list.

    A list is shallow-copied (so the caller's list is never aliased); any
    other value, including tuples and ``None``, is wrapped in a new
    one-element list.
    """
    return value[:] if isinstance(value, list) else [value]
|
|
|
|
|
|
def _merge_field_value(
    merge_behavior: MergeBehavior,
    self_value: Any,
    other_value: Any,
):
    """Combine two values for one field according to ``merge_behavior``.

    * ``Clobber``: ``other_value`` replaces ``self_value``.
    * ``Append``: both values are turned into lists and concatenated,
      ``self_value`` first.
    * ``Update``: both values must be dicts; returns a copy of
      ``self_value`` updated with ``other_value``.

    :raises InternalException: If ``Update`` is requested for a non-dict
        value, or ``merge_behavior`` is none of the three members above.
    """
    if merge_behavior == MergeBehavior.Clobber:
        return other_value

    if merge_behavior == MergeBehavior.Append:
        return _listify(self_value) + _listify(other_value)

    if merge_behavior != MergeBehavior.Update:
        raise InternalException(
            f'Got an invalid merge_behavior: {merge_behavior}'
        )

    # Update: validate the existing value first, then the incoming one.
    for candidate in (self_value, other_value):
        if not isinstance(candidate, dict):
            raise InternalException(f'expected dict, got {candidate}')

    merged = self_value.copy()
    merged.update(other_value)
    return merged
|
|
|
|
|
|
def insensitive_patterns(*patterns: str):
    """Build an anchored regex matching any of ``patterns`` in any letter case.

    Every cased character ``c`` becomes the class ``[Cc]``. Characters that
    have no distinct upper/lower form (digits, punctuation, ...) are emitted
    as escaped literals, so regex metacharacters such as ``^`` or ``]`` in a
    pattern are matched literally instead of corrupting the expression.

    :param patterns: The literal strings to accept.
    :return: A regex string of the form ``^(alt1|alt2|...)$``.
    """
    # Function-scope import: the only new dependency of this helper.
    import re

    def _either_case(char: str) -> str:
        # One regex fragment that matches ``char`` regardless of case.
        upper, lower = char.upper(), char.lower()
        if upper == lower:
            # Caseless character: a class adds nothing, and wrapping a
            # metacharacter in [..] can change its meaning (e.g. "[^^]").
            return re.escape(char)
        return '[{}{}]'.format(upper, lower)

    alternatives = [
        ''.join(_either_case(char) for char in pattern)
        for pattern in patterns
    ]
    return '^({})$'.format('|'.join(alternatives))
|
|
|
|
|
|
class Severity(str):
    """Plain ``str`` subclass that adds no behavior of its own.

    It exists as a distinct type so that a validation pattern can be
    registered for it and fields can be annotated with it.
    """
|
|
|
|
|
|
# Register, for the Severity type, a regex that accepts 'warn' or 'error' in
# any letter case (built by insensitive_patterns).
# NOTE(review): what register_pattern does with it is defined in
# dbt.dataclass_schema, not in this file - confirm there.
register_pattern(Severity, insensitive_patterns('warn', 'error'))
|
|
|
|
|
|
# A single hook: its SQL text, a ``transaction`` flag (defaults to True) and
# an optional integer ``index``.
# NOTE(review): how ``transaction`` and ``index`` are interpreted is decided
# by the code that consumes hooks, not by this class.
@dataclass
class Hook(dbtClassMixin, Replaceable):
    sql: str
    transaction: bool = True
    index: Optional[int] = None
|
|
|
|
|
|
# Type variable bound to BaseConfig: methods annotated ``self: T`` ... ``-> T``
# are typed as returning the same config subclass they were called on.
T = TypeVar('T', bound='BaseConfig')
|
|
|
|
|
|
# Base class for all config objects in this module. It offers dict-style
# access over both regular attributes and the entries of ``self._extra``,
# plus helpers for comparing and merging configs in their dict form.
@dataclass
class BaseConfig(
    AdditionalPropertiesAllowed, Replaceable
):

    # enable syntax like: config['key']
    def __getitem__(self, key):
        """Return ``self.get(key)``; unknown keys yield ``None``."""
        return self.get(key)

    # like doing 'get' on a dictionary
    def get(self, key, default=None):
        """Look ``key`` up as an attribute first, then in ``self._extra``.

        Returns ``default`` when neither has it.
        """
        if hasattr(self, key):
            return getattr(self, key)
        if key in self._extra:
            return self._extra[key]
        return default

    # enable syntax like: config['key'] = value
    def __setitem__(self, key, value):
        """Set an existing attribute, or store the value in ``self._extra``."""
        if not hasattr(self, key):
            self._extra[key] = value
            return
        setattr(self, key, value)

    def __delitem__(self, key):
        """Delete ``key`` from ``self._extra``.

        :raises CompilationException: If ``key`` names an attribute of the
            object; those cannot be deleted.
        """
        if not hasattr(self, key):
            del self._extra[key]
            return

        msg = (
            'Error, tried to delete config key "{}": Cannot delete '
            'built-in keys'
        ).format(key)
        raise CompilationException(msg)

    def _content_iterator(self, include_condition: Callable[[Field], bool]):
        """Yield config keys: accepted field names, then extra keys.

        A field name is yielded only when ``include_condition(field)`` is
        true. Every field name (yielded or not) suppresses an identically
        named key from ``self._extra``.
        """
        already_seen = set()
        for fld, _ in self._get_fields():
            already_seen.add(fld.name)
            if include_condition(fld):
                yield fld.name

        for extra_key in self._extra:
            if extra_key in already_seen:
                continue
            already_seen.add(extra_key)
            yield extra_key

    def __iter__(self):
        """Iterate over all field names followed by the extra keys."""
        yield from self._content_iterator(include_condition=lambda _fld: True)

    def __len__(self):
        """Number of fields plus number of entries in ``self._extra``."""
        return len(self._get_fields()) + len(self._extra)

    @staticmethod
    def compare_key(
        unrendered: Dict[str, Any],
        other: Dict[str, Any],
        key: str,
    ) -> bool:
        """Report whether two dicts agree on ``key``.

        They agree when the key is missing from both, or present in both
        with equal values. Presence in only one of them is a mismatch.
        """
        in_unrendered = key in unrendered
        in_other = key in other
        if not in_unrendered and not in_other:
            return True
        if in_unrendered != in_other:
            return False
        return unrendered[key] == other[key]

    @classmethod
    def same_contents(
        cls, unrendered: Dict[str, Any], other: Dict[str, Any]
    ) -> bool:
        """This is like __eq__, except it ignores some fields.

        Fields whose CompareBehavior is not ``Include`` are skipped. Keys
        that do not correspond to any field are always compared.
        """
        checked = set()
        for fld, target_name in cls._get_fields():
            checked.add(target_name)
            if not CompareBehavior.should_include(fld):
                continue
            if not cls.compare_key(unrendered, other, target_name):
                return False

        # Whatever is left in either dict is not a declared field.
        for key in chain(unrendered, other):
            if key in checked:
                continue
            checked.add(key)
            if not cls.compare_key(unrendered, other, key):
                return False
        return True

    # This is used in 'add_config_call' to create the combined
    # config_call_dict. 'meta' moved here from node
    mergebehavior = {
        "append": ['pre-hook', 'pre_hook', 'post-hook', 'post_hook', 'tags'],
        "update": ['quoting', 'column_types', 'meta'],
    }

    @classmethod
    def _merge_dicts(
        cls, src: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Merge the entries of ``data`` that belong to fields of this class.

        For every field whose target name is a key of ``data``, the entry is
        popped from ``data``. If ``src`` has the same key, the two values
        are combined using the field's merge behavior; otherwise the value
        from ``data`` is taken as-is.

        Returns a dict with the merge results.

        That means this method mutates its input: matching items are removed
        from ``data`` (but _not_ from ``src``!). Whatever remains in ``data``
        was not merged.
        """
        merged = {}

        for fld, target_field in cls._get_fields():
            if target_field not in data:
                continue

            incoming = data.pop(target_field)
            if target_field in src:
                merged[target_field] = _merge_field_value(
                    merge_behavior=MergeBehavior.from_field(fld),
                    self_value=src[target_field],
                    other_value=incoming,
                )
            else:
                merged[target_field] = incoming
        return merged

    def update_from(
        self: T, data: Dict[str, Any], adapter_type: str, validate: bool = True
    ) -> T:
        """Given a dict of keys, update the current config from them, validate
        it, and return a new config with the updated values.

        ``data`` is consumed along the way: entries merged into a field are
        popped from it (see ``_merge_dicts``).
        """
        # sadly, this is a circular import
        from dbt.adapters.factory import get_config_class_by_name
        dct = self.to_dict(omit_none=False)

        adapter_config_cls = get_config_class_by_name(adapter_type)

        # First merge the keys that are fields of this class ...
        dct.update(self._merge_dicts(dct, data))

        # ... then the ones that are fields of the adapter's config class.
        dct.update(adapter_config_cls._merge_dicts(dct, data))

        # any remaining fields must be "clobber"
        dct.update(data)

        # any validation failures must have come from the update
        if validate:
            self.validate(dct)
        return self.from_dict(dct)

    def finalize_and_validate(self: T) -> T:
        """Validate the dict form of this config and rebuild it from that."""
        as_dict = self.to_dict(omit_none=False)
        self.validate(as_dict)
        return self.from_dict(as_dict)

    def replace(self, **kwargs):
        """Return a new config with the given values replaced.

        Keyword names are translated through ``field_mapping()`` before
        being written into the dict form of this config.
        """
        dct = self.to_dict(omit_none=True)

        mapping = self.field_mapping()
        dct.update({
            mapping.get(name, name): value
            for name, value in kwargs.items()
        })
        return self.from_dict(dct)
|
|
|
|
|
|
# Config for sources (see RESOURCE_TYPES); the only declared setting is
# ``enabled``.
@dataclass
class SourceConfig(BaseConfig):
    enabled: bool = True
|
|
|
|
|
|
# Settings shared by NodeConfig and TestConfig.
@dataclass
class NodeAndTestConfig(BaseConfig):
    enabled: bool = True

    # these fields are included in serialized output, but are not part of
    # config comparison (they are part of database_representation)
    alias: Optional[str] = field(
        default=None,
        metadata=CompareBehavior.Exclude.meta(),
    )
    schema: Optional[str] = field(
        default=None,
        metadata=CompareBehavior.Exclude.meta(),
    )
    database: Optional[str] = field(
        default=None,
        metadata=CompareBehavior.Exclude.meta(),
    )

    # tags: hidden, appended when merged, and left out of comparison
    tags: Union[List[str], str] = field(
        default_factory=list_str,
        metadata=metas(
            ShowBehavior.Hide,
            MergeBehavior.Append,
            CompareBehavior.Exclude,
        ),
    )

    # meta: dict-updated when merged
    meta: Dict[str, Any] = field(
        default_factory=dict,
        metadata=MergeBehavior.Update.meta(),
    )
|
|
|
|
|
|
# Adds materialization, hook, quoting and related settings on top of
# NodeAndTestConfig. Hooks are stored as ``pre_hook``/``post_hook`` but are
# exposed as ``pre-hook``/``post-hook`` in the serialized dict.
@dataclass
class NodeConfig(NodeAndTestConfig):
    # Note: if any new fields are added with MergeBehavior, also update the
    # 'mergebehavior' dictionary
    materialized: str = 'view'
    persist_docs: Dict[str, Any] = field(default_factory=dict)
    post_hook: List[Hook] = field(
        default_factory=list,
        metadata=MergeBehavior.Append.meta(),
    )
    pre_hook: List[Hook] = field(
        default_factory=list,
        metadata=MergeBehavior.Append.meta(),
    )
    quoting: Dict[str, Any] = field(
        default_factory=dict,
        metadata=MergeBehavior.Update.meta(),
    )
    # This is actually only used by seeds. Should it be available to others?
    # That would be a breaking change!
    column_types: Dict[str, Any] = field(
        default_factory=dict,
        metadata=MergeBehavior.Update.meta(),
    )
    full_refresh: Optional[bool] = None
    on_schema_change: Optional[str] = 'ignore'

    @classmethod
    def __pre_deserialize__(cls, data):
        """Normalize hook entries and hook key names before deserializing.

        Works on a copy of ``data``: every key that is a member of
        ``hooks.ModelHookType`` has its entries passed through
        ``hooks.get_hook_dict``, then the dashed hook keys are renamed to
        their underscored field names.
        """
        data = super().__pre_deserialize__(data)
        # create a new dict because otherwise it gets overwritten in
        # tests
        data = {key: data[key] for key in data}

        for hook_key in hooks.ModelHookType:
            if hook_key in data:
                data[hook_key] = [
                    hooks.get_hook_dict(h) for h in data[hook_key]
                ]

        renames = {'post-hook': 'post_hook', 'pre-hook': 'pre_hook'}
        for dashed, underscored in renames.items():
            if dashed in data:
                data[underscored] = data.pop(dashed)
        return data

    def __post_serialize__(self, dct):
        """Rename the underscored hook keys back to their dashed form."""
        dct = super().__post_serialize__(dct)
        renames = {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'}
        for underscored, dashed in renames.items():
            if underscored in dct:
                dct[dashed] = dct.pop(underscored)
        return dct

    # this is still used by jsonschema validation
    @classmethod
    def field_mapping(cls):
        """Map field names to the key names used in the dict form."""
        return {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'}
|
|
|
|
|
|
# NodeConfig variant whose default materialization is 'seed'; adds the
# optional ``quote_columns`` flag.
@dataclass
class SeedConfig(NodeConfig):
    materialized: str = 'seed'
    quote_columns: Optional[bool] = None
|
|
|
|
|
|
# Config for tests: default materialization 'test', results schema
# 'dbt_test__audit', plus the test-specific modifiers declared below.
@dataclass
class TestConfig(NodeAndTestConfig):
    # this is repeated because of a different default
    schema: Optional[str] = field(
        default='dbt_test__audit',
        metadata=CompareBehavior.Exclude.meta(),
    )
    materialized: str = 'test'
    severity: Severity = Severity('ERROR')
    store_failures: Optional[bool] = None
    where: Optional[str] = None
    limit: Optional[int] = None
    fail_calc: str = 'count(*)'
    warn_if: str = '!= 0'
    error_if: str = '!= 0'

    @classmethod
    def same_contents(
        cls, unrendered: Dict[str, Any], other: Dict[str, Any]
    ) -> bool:
        """This is like __eq__, except it explicitly checks certain fields.

        Only the test modifiers listed below are compared (via
        ``compare_key``); every other key in either dict is ignored.

        :param unrendered: The first config dict.
        :param other: The config dict to compare it against.
        :return: True when the two dicts agree on every modifier.
        """
        modifiers = [
            'severity',
            'where',
            'limit',
            'fail_calc',
            'warn_if',
            'error_if',
            'store_failures',
        ]

        # Walk the declared fields and compare only those whose target name
        # is one of the modifiers. (The original also filled a `seen` set
        # here that was never read; it has been dropped.)
        for _, target_name in cls._get_fields():
            if target_name not in modifiers:
                continue
            if not cls.compare_key(unrendered, other, target_name):
                return False
        return True
|
|
|
|
|
|
# NodeConfig variant whose default materialization is 'snapshot'. It declares
# none of the snapshot-specific settings and is registered as the "base"
# snapshot config in BASE_RESOURCE_TYPES.
@dataclass
class EmptySnapshotConfig(NodeConfig):
    materialized: str = 'snapshot'
|
|
|
|
|
|
# Snapshot config with the snapshot-specific settings and the validation
# rules that tie them together.
@dataclass
class SnapshotConfig(EmptySnapshotConfig):
    strategy: Optional[str] = None
    unique_key: Optional[str] = None
    target_schema: Optional[str] = None
    target_database: Optional[str] = None
    updated_at: Optional[str] = None
    check_cols: Optional[Union[str, List[str]]] = None

    @classmethod
    def validate(cls, data):
        """Validate a snapshot config dict.

        After the inherited validation, requires:

        * non-empty 'strategy', 'unique_key' and 'target_schema';
        * for the 'check' strategy: a non-empty 'check_cols' that, if it is
          a string, equals 'all';
        * for the 'timestamp' strategy: a non-empty 'updated_at' and no
          'check_cols'.

        :raises ValidationError: If any of these rules is violated.
        """
        super().validate(data)

        required = ('strategy', 'unique_key', 'target_schema')
        if not all(data.get(name) for name in required):
            raise ValidationError(
                "Snapshots must be configured with a 'strategy', 'unique_key', "
                "and 'target_schema'.")

        strategy = data.get('strategy')
        if strategy == 'check':
            check_cols = data.get('check_cols')
            if not check_cols:
                raise ValidationError(
                    "A snapshot configured with the check strategy must "
                    "specify a check_cols configuration.")
            if isinstance(check_cols, str) and check_cols != 'all':
                raise ValidationError(
                    f"Invalid value for 'check_cols': {check_cols}. "
                    "Expected 'all' or a list of strings.")

        elif strategy == 'timestamp':
            if not data.get('updated_at'):
                raise ValidationError(
                    "A snapshot configured with the timestamp strategy "
                    "must specify an updated_at configuration.")
            if data.get('check_cols'):
                raise ValidationError(
                    "A 'timestamp' snapshot should not have 'check_cols'")
        # If the strategy is not 'check' or 'timestamp' it's a custom strategy,
        # formerly supported with GenericSnapshotConfig

    def finalize_and_validate(self):
        """Validate the dict form (``None`` values omitted) and rebuild."""
        as_dict = self.to_dict(omit_none=True)
        self.validate(as_dict)
        return self.from_dict(as_dict)
|
|
|
|
|
|
# Config class used for each resource type.
RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = {
    NodeType.Source: SourceConfig,
    NodeType.Seed: SeedConfig,
    NodeType.Test: TestConfig,
    NodeType.Model: NodeConfig,
    NodeType.Snapshot: SnapshotConfig,
}


# base resource types are like resource types, except nothing has mandatory
# configs.
BASE_RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = {
    **RESOURCE_TYPES,
    NodeType.Snapshot: EmptySnapshotConfig,
}
|
|
|
|
|
|
def get_config_for(resource_type: NodeType, base=False) -> Type[BaseConfig]:
    """Return the config class registered for ``resource_type``.

    :param resource_type: The node type to look up.
    :param base: When true, consult ``BASE_RESOURCE_TYPES`` instead of
        ``RESOURCE_TYPES``.
    :return: The registered class, or ``NodeConfig`` when the resource type
        has no entry.
    """
    registry = BASE_RESOURCE_TYPES if base else RESOURCE_TYPES
    return registry.get(resource_type, NodeConfig)
|