from dataclasses import field, Field, dataclass
from enum import Enum
from itertools import chain
from typing import (
    Any, List, Optional, Dict, Union, Type, TypeVar, Callable
)

from dbt.dataclass_schema import (
    dbtClassMixin, ValidationError, register_pattern,
)
from dbt.contracts.graph.unparsed import AdditionalPropertiesAllowed
from dbt.exceptions import InternalException, CompilationException
from dbt.contracts.util import Replaceable, list_str
from dbt import hooks
from dbt.node_types import NodeType


M = TypeVar('M', bound='Metadata')


def _get_meta_value(cls: Type[M], fld: Field, key: str, default: Any) -> M:
    """Read the Metadata enum value stored under `key` in a dataclass
    field's metadata mapping, falling back to `default`.

    Raises InternalException if the stored value is not valid for `cls`.
    """
    # a metadata field might exist. If it does, it might have a matching key.
    # If it has both, make sure the value is valid and return it. If it
    # doesn't, return the default.
    if fld.metadata:
        value = fld.metadata.get(key, default)
    else:
        value = default

    try:
        return cls(value)
    except ValueError as exc:
        raise InternalException(
            f'Invalid {cls} value: {value}'
        ) from exc


def _set_meta_value(
    obj: M, key: str, existing: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Return a copy of `existing` (or a new dict) with `obj` stored under
    `key`. The input dict is never mutated.
    """
    if existing is None:
        result = {}
    else:
        result = existing.copy()
    result.update({key: obj})
    return result


class Metadata(Enum):
    """Base class for enum values stored in a dataclass field's `metadata`
    mapping. Subclasses define the metadata key and the default member.
    """
    @classmethod
    def from_field(cls: Type[M], fld: Field) -> M:
        """Extract this Metadata's value from a dataclass field, using the
        subclass default when the field carries no entry for our key.
        """
        default = cls.default_field()
        key = cls.metadata_key()

        return _get_meta_value(cls, fld, key, default)

    def meta(
        self, existing: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Return a field-metadata dict containing this value, merged
        non-destructively into `existing` if given.
        """
        key = self.metadata_key()
        return _set_meta_value(self, key, existing)

    @classmethod
    def default_field(cls) -> 'Metadata':
        raise NotImplementedError('Not implemented')

    @classmethod
    def metadata_key(cls) -> str:
        raise NotImplementedError('Not implemented')


class MergeBehavior(Metadata):
    """How a config field combines with an incoming value: list-append,
    dict-update, or plain replacement (the default).
    """
    Append = 1
    Update = 2
    Clobber = 3

    @classmethod
    def default_field(cls) -> 'MergeBehavior':
        return cls.Clobber

    @classmethod
    def metadata_key(cls) -> str:
        return 'merge'


class ShowBehavior(Metadata):
    """Whether a config field is shown (the default) or hidden."""
    Show = 1
    Hide = 2

    @classmethod
    def default_field(cls) -> 'ShowBehavior':
        return cls.Show

    @classmethod
    def metadata_key(cls) -> str:
        return 'show_hide'

    @classmethod
    def should_show(cls, fld: Field) -> bool:
        return cls.from_field(fld) == cls.Show


class CompareBehavior(Metadata):
    """Whether a config field participates in config comparison
    (included by default).
    """
    Include = 1
    Exclude = 2

    @classmethod
    def default_field(cls) -> 'CompareBehavior':
        return cls.Include

    @classmethod
    def metadata_key(cls) -> str:
        return 'compare'

    @classmethod
    def should_include(cls, fld: Field) -> bool:
        return cls.from_field(fld) == cls.Include


def metas(*metas: Metadata) -> Dict[str, Any]:
    """Combine multiple Metadata values into one field-metadata dict."""
    existing: Dict[str, Any] = {}
    for m in metas:
        existing = m.meta(existing)
    return existing


def _listify(value: Any) -> List:
    """Return a shallow copy if `value` is a list, else wrap it in a list."""
    if isinstance(value, list):
        return value[:]
    else:
        return [value]


def _merge_field_value(
    merge_behavior: MergeBehavior,
    self_value: Any,
    other_value: Any,
):
    """Merge two config values according to `merge_behavior`.

    Clobber returns `other_value`; Append concatenates the listified values;
    Update requires both values to be dicts and overlays `other_value` onto a
    copy of `self_value`. Raises InternalException on type mismatch or an
    unknown behavior.
    """
    if merge_behavior == MergeBehavior.Clobber:
        return other_value
    elif merge_behavior == MergeBehavior.Append:
        return _listify(self_value) + _listify(other_value)
    elif merge_behavior == MergeBehavior.Update:
        if not isinstance(self_value, dict):
            raise InternalException(f'expected dict, got {self_value}')
        if not isinstance(other_value, dict):
            raise InternalException(f'expected dict, got {other_value}')
        value = self_value.copy()
        value.update(other_value)
        return value
    else:
        raise InternalException(
            f'Got an invalid merge_behavior: {merge_behavior}'
        )


def insensitive_patterns(*patterns: str):
    """Build a regex that matches any of `patterns` case-insensitively,
    anchored at both ends (e.g. 'ab' -> '^([Aa][Bb])$').
    """
    lowercased = []
    for pattern in patterns:
        lowercased.append(
            ''.join('[{}{}]'.format(s.upper(), s.lower()) for s in pattern)
        )
    return '^({})$'.format('|'.join(lowercased))


class Severity(str):
    pass


# test severity must (case-insensitively) be 'warn' or 'error'
register_pattern(Severity, insensitive_patterns('warn', 'error'))


@dataclass
class Hook(dbtClassMixin, Replaceable):
    sql: str
    transaction: bool = True
    index: Optional[int] = None


T = TypeVar('T', bound='BaseConfig')


@dataclass
class BaseConfig(
    AdditionalPropertiesAllowed, Replaceable
):
    """Base class for all node/source configs: declared dataclass fields plus
    arbitrary extra keys (stored in `self._extra`), with dict-like access.
    """

    # enable syntax like: config['key']
    def __getitem__(self, key):
        return self.get(key)

    # like doing 'get' on a dictionary
    def get(self, key, default=None):
        if hasattr(self, key):
            return getattr(self, key)
        elif key in self._extra:
            return self._extra[key]
        else:
            return default

    # enable syntax like: config['key'] = value
    def __setitem__(self, key, value):
        if hasattr(self, key):
            setattr(self, key, value)
        else:
            self._extra[key] = value

    def __delitem__(self, key):
        # declared fields cannot be deleted; only extra keys can
        if hasattr(self, key):
            msg = (
                'Error, tried to delete config key "{}": Cannot delete '
                'built-in keys'
            ).format(key)
            raise CompilationException(msg)
        else:
            del self._extra[key]

    def _content_iterator(self, include_condition: Callable[[Field], bool]):
        """Yield names of declared fields passing `include_condition`, then
        any extra keys not shadowing a declared field.
        """
        seen = set()
        for fld, _ in self._get_fields():
            seen.add(fld.name)
            if include_condition(fld):
                yield fld.name

        for key in self._extra:
            if key not in seen:
                seen.add(key)
                yield key

    def __iter__(self):
        yield from self._content_iterator(include_condition=lambda f: True)

    def __len__(self):
        return len(self._get_fields()) + len(self._extra)

    @staticmethod
    def compare_key(
        unrendered: Dict[str, Any],
        other: Dict[str, Any],
        key: str,
    ) -> bool:
        """True when `key` is absent from both dicts or present in both with
        equal values; False when it exists in only one, or differs.
        """
        if key not in unrendered and key not in other:
            return True
        elif key not in unrendered and key in other:
            return False
        elif key in unrendered and key not in other:
            return False
        else:
            return unrendered[key] == other[key]

    @classmethod
    def same_contents(
        cls, unrendered: Dict[str, Any], other: Dict[str, Any]
    ) -> bool:
        """This is like __eq__, except it ignores some fields."""
        seen = set()
        for fld, target_name in cls._get_fields():
            key = target_name
            seen.add(key)
            if CompareBehavior.should_include(fld):
                if not cls.compare_key(unrendered, other, key):
                    return False

        # extra keys present in either dict are always compared
        for key in chain(unrendered, other):
            if key not in seen:
                seen.add(key)
                if not cls.compare_key(unrendered, other, key):
                    return False
        return True

    # This is used in 'add_config_call' to create the combined
    # config_call_dict. 'meta' moved here from node
    mergebehavior = {
        "append": ['pre-hook', 'pre_hook', 'post-hook', 'post_hook', 'tags'],
        "update": ['quoting', 'column_types', 'meta'],
    }

    @classmethod
    def _merge_dicts(
        cls, src: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Find all the items in data that match a target_field on this class,
        and merge them with the data found in `src` for target_field, using
        the field's specified merge behavior. Matching items will be removed
        from `data` (but _not_ `src`!).

        Returns a dict with the merge results.

        That means this method mutates its input! Any remaining values in data
        were not merged.
        """
        result = {}

        for fld, target_field in cls._get_fields():
            if target_field not in data:
                continue

            data_attr = data.pop(target_field)
            if target_field not in src:
                # nothing to merge against: take the incoming value as-is
                result[target_field] = data_attr
                continue

            merge_behavior = MergeBehavior.from_field(fld)
            self_attr = src[target_field]

            result[target_field] = _merge_field_value(
                merge_behavior=merge_behavior,
                self_value=self_attr,
                other_value=data_attr,
            )
        return result

    def update_from(
        self: T, data: Dict[str, Any], adapter_type: str, validate: bool = True
    ) -> T:
        """Given a dict of keys, update the current config from them, validate
        it, and return a new config with the updated values
        """
        # sadly, this is a circular import
        from dbt.adapters.factory import get_config_class_by_name
        dct = self.to_dict(omit_none=False)

        adapter_config_cls = get_config_class_by_name(adapter_type)

        self_merged = self._merge_dicts(dct, data)
        dct.update(self_merged)

        adapter_merged = adapter_config_cls._merge_dicts(dct, data)
        dct.update(adapter_merged)

        # any remaining fields must be "clobber"
        dct.update(data)

        # any validation failures must have come from the update
        if validate:
            self.validate(dct)
        return self.from_dict(dct)

    def finalize_and_validate(self: T) -> T:
        dct = self.to_dict(omit_none=False)
        self.validate(dct)
        return self.from_dict(dct)

    def replace(self, **kwargs):
        dct = self.to_dict(omit_none=True)

        # translate caller-facing names (e.g. 'post_hook') to serialized
        # names before applying overrides
        mapping = self.field_mapping()
        for key, value in kwargs.items():
            new_key = mapping.get(key, key)
            dct[new_key] = value
        return self.from_dict(dct)


@dataclass
class SourceConfig(BaseConfig):
    enabled: bool = True


@dataclass
class NodeAndTestConfig(BaseConfig):
    enabled: bool = True
    # these fields are included in serialized output, but are not part of
    # config comparison (they are part of database_representation)
    alias: Optional[str] = field(
        default=None,
        metadata=CompareBehavior.Exclude.meta(),
    )
    schema: Optional[str] = field(
        default=None,
        metadata=CompareBehavior.Exclude.meta(),
    )
    database: Optional[str] = field(
        default=None,
        metadata=CompareBehavior.Exclude.meta(),
    )
    tags: Union[List[str], str] = field(
        default_factory=list_str,
        metadata=metas(ShowBehavior.Hide,
                       MergeBehavior.Append,
                       CompareBehavior.Exclude),
    )
    meta: Dict[str, Any] = field(
        default_factory=dict,
        metadata=MergeBehavior.Update.meta(),
    )


@dataclass
class NodeConfig(NodeAndTestConfig):
    # Note: if any new fields are added with MergeBehavior, also update the
    # 'mergebehavior' dictionary
    materialized: str = 'view'
    persist_docs: Dict[str, Any] = field(default_factory=dict)
    post_hook: List[Hook] = field(
        default_factory=list,
        metadata=MergeBehavior.Append.meta(),
    )
    pre_hook: List[Hook] = field(
        default_factory=list,
        metadata=MergeBehavior.Append.meta(),
    )
    quoting: Dict[str, Any] = field(
        default_factory=dict,
        metadata=MergeBehavior.Update.meta(),
    )
    # This is actually only used by seeds. Should it be available to others?
    # That would be a breaking change!
    column_types: Dict[str, Any] = field(
        default_factory=dict,
        metadata=MergeBehavior.Update.meta(),
    )
    full_refresh: Optional[bool] = None
    on_schema_change: Optional[str] = 'ignore'

    @classmethod
    def __pre_deserialize__(cls, data):
        """Normalize incoming data: coerce hook entries to dicts and rename
        'pre-hook'/'post-hook' keys to their underscore field names.
        """
        data = super().__pre_deserialize__(data)
        field_map = {'post-hook': 'post_hook', 'pre-hook': 'pre_hook'}
        # create a new dict because otherwise it gets overwritten in
        # tests
        new_dict = {}
        for key in data:
            new_dict[key] = data[key]
        data = new_dict
        for key in hooks.ModelHookType:
            if key in data:
                data[key] = [hooks.get_hook_dict(h) for h in data[key]]
        for field_name in field_map:
            if field_name in data:
                new_name = field_map[field_name]
                data[new_name] = data.pop(field_name)
        return data

    def __post_serialize__(self, dct):
        # invert __pre_deserialize__'s renaming for serialized output
        dct = super().__post_serialize__(dct)
        field_map = {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'}
        for field_name in field_map:
            if field_name in dct:
                dct[field_map[field_name]] = dct.pop(field_name)
        return dct

    # this is still used by jsonschema validation
    @classmethod
    def field_mapping(cls):
        return {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'}


@dataclass
class SeedConfig(NodeConfig):
    materialized: str = 'seed'
    quote_columns: Optional[bool] = None


@dataclass
class TestConfig(NodeAndTestConfig):
    # this is repeated because of a different default
    schema: Optional[str] = field(
        default='dbt_test__audit',
        metadata=CompareBehavior.Exclude.meta(),
    )
    materialized: str = 'test'
    severity: Severity = Severity('ERROR')
    store_failures: Optional[bool] = None
    where: Optional[str] = None
    limit: Optional[int] = None
    fail_calc: str = 'count(*)'
    warn_if: str = '!= 0'
    error_if: str = '!= 0'

    @classmethod
    def same_contents(
        cls, unrendered: Dict[str, Any], other: Dict[str, Any]
    ) -> bool:
        """This is like __eq__, except it explicitly checks certain fields."""
        modifiers = [
            'severity', 'where', 'limit', 'fail_calc',
            'warn_if', 'error_if', 'store_failures'
        ]

        seen = set()
        for _, target_name in cls._get_fields():
            key = target_name
            seen.add(key)
            if key in modifiers:
                if not cls.compare_key(unrendered, other, key):
                    return False
        return True


@dataclass
class EmptySnapshotConfig(NodeConfig):
    materialized: str = 'snapshot'


@dataclass
class SnapshotConfig(EmptySnapshotConfig):
    strategy: Optional[str] = None
    unique_key: Optional[str] = None
    target_schema: Optional[str] = None
    target_database: Optional[str] = None
    updated_at: Optional[str] = None
    check_cols: Optional[Union[str, List[str]]] = None

    @classmethod
    def validate(cls, data):
        """Validate snapshot config: strategy/unique_key/target_schema are
        required; 'check' and 'timestamp' strategies impose extra constraints.
        """
        super().validate(data)
        if not data.get('strategy') or not data.get('unique_key') or not \
                data.get('target_schema'):
            raise ValidationError(
                "Snapshots must be configured with a 'strategy', 'unique_key', "
                "and 'target_schema'.")
        if data.get('strategy') == 'check':
            if not data.get('check_cols'):
                raise ValidationError(
                    "A snapshot configured with the check strategy must "
                    "specify a check_cols configuration.")
            if (isinstance(data['check_cols'], str) and
                    data['check_cols'] != 'all'):
                raise ValidationError(
                    f"Invalid value for 'check_cols': {data['check_cols']}. "
                    "Expected 'all' or a list of strings.")
        elif data.get('strategy') == 'timestamp':
            if not data.get('updated_at'):
                raise ValidationError(
                    "A snapshot configured with the timestamp strategy "
                    "must specify an updated_at configuration.")
            if data.get('check_cols'):
                raise ValidationError(
                    "A 'timestamp' snapshot should not have 'check_cols'")
        # If the strategy is not 'check' or 'timestamp' it's a custom strategy,
        # formerly supported with GenericSnapshotConfig

    def finalize_and_validate(self):
        data = self.to_dict(omit_none=True)
        self.validate(data)
        return self.from_dict(data)


RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = {
    NodeType.Source: SourceConfig,
    NodeType.Seed: SeedConfig,
    NodeType.Test: TestConfig,
    NodeType.Model: NodeConfig,
    NodeType.Snapshot: SnapshotConfig,
}


# base resource types are like resource types, except nothing has mandatory
# configs.
BASE_RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = RESOURCE_TYPES.copy()
BASE_RESOURCE_TYPES.update({
    NodeType.Snapshot: EmptySnapshotConfig
})


def get_config_for(resource_type: NodeType, base=False) -> Type[BaseConfig]:
    """Look up the config class for a node type, defaulting to NodeConfig;
    with base=True, use the mandatory-config-free variants.
    """
    if base:
        lookup = BASE_RESOURCE_TYPES
    else:
        lookup = RESOURCE_TYPES
    return lookup.get(resource_type, NodeConfig)