import hashlib
import re
from copy import deepcopy
from dataclasses import dataclass
from typing import (
    Generic, TypeVar, Dict, Any, Tuple, Optional, List,
)

from dbt.clients.jinja import get_rendered, SCHEMA_TEST_KWARGS_NAME
from dbt.contracts.graph.parsed import UnpatchedSourceDefinition
from dbt.contracts.graph.unparsed import (
    TestDef,
    UnparsedAnalysisUpdate,
    UnparsedMacroUpdate,
    UnparsedNodeUpdate,
    UnparsedExposure,
)
from dbt.exceptions import raise_compiler_error
from dbt.parser.search import FileBlock


def get_nice_schema_test_name(
    test_type: str, test_name: str, args: Dict[str, Any]
) -> Tuple[str, str]:
    """Build the (short, full) name pair for a schema test.

    The full name is ``<test_type>_<test_name>_<args>``, where ``<args>`` is
    every argument value (except ``model``), visited in sorted argument-name
    order, stringified, sanitized to ``[0-9a-zA-Z_]`` and joined with ``__``.

    :param test_type: prefix identifying the kind of test.
    :param test_name: second component of the identifier.
    :param args: the test's keyword arguments; dict values contribute their
        values, lists/tuples contribute their items, anything else
        contributes itself.
    :return: ``(short_name, full_name)``. They are equal unless the full
        name is 64 characters or longer, in which case the short name is
        the first 30 characters of ``<test_type>_<test_name>`` followed by
        ``_`` and the md5 hex digest of the full name.
    """
    flat_args: List[str] = []
    for arg_name in sorted(args):
        # the model is already embedded in the name, so skip it
        if arg_name == 'model':
            continue
        arg_val = args[arg_name]

        if isinstance(arg_val, dict):
            parts = list(arg_val.values())
        elif isinstance(arg_val, (list, tuple)):
            parts = list(arg_val)
        else:
            parts = [arg_val]

        flat_args.extend(str(part) for part in parts)

    # collapse every run of non-identifier characters into one underscore
    clean_flat_args = [re.sub('[^0-9a-zA-Z_]+', '_', arg) for arg in flat_args]
    unique = "__".join(clean_flat_args)

    # for the file path + alias, the name must be <64 characters
    # if the full name is too long, include the first 30 identifying chars plus
    # a 32-character hash of the full contents

    test_identifier = '{}_{}'.format(test_type, test_name)
    full_name = '{}_{}'.format(test_identifier, unique)

    if len(full_name) >= 64:
        test_trunc_identifier = test_identifier[:30]
        label = hashlib.md5(full_name.encode('utf-8')).hexdigest()
        short_name = '{}_{}'.format(test_trunc_identifier, label)
    else:
        short_name = full_name

    return short_name, full_name


@dataclass
class YamlBlock(FileBlock):
    """A FileBlock that also carries a ``data`` dictionary."""
    data: Dict[str, Any]

    @classmethod
    def from_file_block(cls, src: FileBlock, data: Dict[str, Any]):
        """Create a block from ``src.file`` and the given ``data``."""
        return cls(
            file=src.file,
            data=data,
        )


# Target types that expose ``tests`` (see TestBlock).
Testable = TypeVar(
    'Testable', UnparsedNodeUpdate, UnpatchedSourceDefinition
)

# Target types that expose ``columns`` (see TargetColumnsBlock).
ColumnTarget = TypeVar(
    'ColumnTarget',
    UnparsedNodeUpdate,
    UnparsedAnalysisUpdate,
    UnpatchedSourceDefinition,
)

# Every target type a TargetBlock may wrap.
Target = TypeVar(
    'Target',
    UnparsedNodeUpdate,
    UnparsedMacroUpdate,
    UnparsedAnalysisUpdate,
    UnpatchedSourceDefinition,
    UnparsedExposure,
)


@dataclass
class TargetBlock(YamlBlock, Generic[Target]):
    """A YamlBlock paired with the target object it describes."""
    target: Target

    @property
    def name(self):
        """The name of the wrapped target."""
        return self.target.name

    @property
    def columns(self):
        """Columns of the target; the base block exposes none."""
        return []

    @property
    def tests(self) -> List[TestDef]:
        """Tests of the target; the base block exposes none."""
        return []

    @classmethod
    def from_yaml_block(
        cls, src: YamlBlock, target: Target
    ) -> 'TargetBlock[Target]':
        """Create a block from ``src``'s file and data plus ``target``."""
        return cls(
            file=src.file,
            data=src.data,
            target=target,
        )


@dataclass
class TargetColumnsBlock(TargetBlock[ColumnTarget], Generic[ColumnTarget]):
    """A TargetBlock whose target has a ``columns`` attribute."""

    @property
    def columns(self):
        """The target's columns, or an empty list when they are None."""
        if self.target.columns is None:
            return []
        else:
            return self.target.columns


@dataclass
class TestBlock(TargetColumnsBlock[Testable], Generic[Testable]):
    """A TargetColumnsBlock whose target also has ``tests``."""

    @property
    def tests(self) -> List[TestDef]:
        """The target's tests, or an empty list when they are None."""
        if self.target.tests is None:
            return []
        else:
            return self.target.tests

    @property
    def quote_columns(self) -> Optional[bool]:
        """The target's ``quote_columns`` value, passed through unchanged."""
        return self.target.quote_columns

    @classmethod
    def from_yaml_block(
        cls, src: YamlBlock, target: Testable
    ) -> 'TestBlock[Testable]':
        """Create a block from ``src``'s file and data plus ``target``."""
        return cls(
            file=src.file,
            data=src.data,
            target=target,
        )


@dataclass
class SchemaTestBlock(TestBlock[Testable], Generic[Testable]):
    """A TestBlock narrowed to one test definition."""
    # the test definition dictionary
    test: Dict[str, Any]
    # the column the test is attached to, if any
    column_name: Optional[str]
    tags: List[str]

    @classmethod
    def from_test_block(
        cls,
        src: TestBlock,
        test: Dict[str, Any],
        column_name: Optional[str],
        tags: List[str],
    ) -> 'SchemaTestBlock':
        """Create a block from ``src``'s file, data and target plus the
        given test definition, column name and tags.
        """
        return cls(
            file=src.file,
            data=src.data,
            target=src.target,
            test=test,
            column_name=column_name,
            tags=tags,
        )


class TestBuilder(Generic[Testable]):
    """An object to hold assorted test settings and perform basic parsing

    Test names have the following pattern:
        - the test name itself may be namespaced (package.test)
        - or it may not be namespaced (test)

    """
    # The 'test_name' is used to find the 'macro' that implements the test
    # NOTE: the named groups are required; __init__ reads them back through
    # match.groupdict().
    TEST_NAME_PATTERN = re.compile(
        r'((?P<test_namespace>([a-zA-Z_][0-9a-zA-Z_]*))\.)?'
        r'(?P<test_name>([a-zA-Z_][0-9a-zA-Z_]*))'
    )
    # kwargs representing test configs
    CONFIG_ARGS = (
        'severity', 'tags', 'enabled', 'where', 'limit', 'warn_if', 'error_if',
        'fail_calc', 'store_failures', 'meta', 'database', 'schema', 'alias',
    )
    # config keys emitted by get_static_config, in output order; each one is
    # also the name of the property that supplies its value
    _STATIC_CONFIG_KEYS = (
        'alias', 'severity', 'enabled', 'where', 'limit', 'warn_if',
        'error_if', 'fail_calc', 'store_failures', 'meta', 'database',
        'schema',
    )

    def __init__(
        self,
        test: Dict[str, Any],
        target: Testable,
        package_name: str,
        render_ctx: Dict[str, Any],
        column_name: Optional[str] = None,
    ) -> None:
        """Parse a single test definition against ``target``.

        :param test: single-key dict mapping the (optionally namespaced)
            test name to a dict of its arguments.
        :param target: the object being tested.
        :param package_name: package to use unless the test name carries a
            namespace, which then replaces it.
        :param render_ctx: context handed to ``get_rendered`` for config
            values that are strings.
        :param column_name: if given, stored in the test arguments under
            ``column_name``.

        Problems are reported through ``raise_compiler_error``; an
        unsupported target type raises ``TypeError``.
        """
        test_name, test_args = self.extract_test_args(test, column_name)
        self.args: Dict[str, Any] = test_args
        if 'model' in self.args:
            raise_compiler_error(
                'Test arguments include "model", which is a reserved argument',
            )
        self.package_name: str = package_name
        self.target: Testable = target

        self.args['model'] = self.build_model_str()

        match = self.TEST_NAME_PATTERN.match(test_name)
        if match is None:
            raise_compiler_error(
                'Test name string did not match expected pattern: {}'
                .format(test_name)
            )

        groups = match.groupdict()
        self.name: str = groups['test_name']
        # None when the test name carries no "package." prefix
        self.namespace: Optional[str] = groups['test_namespace']
        self.config: Dict[str, Any] = {}

        for key in self.CONFIG_ARGS:
            value = self.args.pop(key, None)
            # 'modifier' config could be either top level arg or in config
            in_nested = 'config' in self.args and key in self.args['config']
            if value and in_nested:
                raise_compiler_error(
                    'Test cannot have the same key at the top-level and in config'
                )
            # Fall back to the nested value only when the key is really
            # there; otherwise a falsy top-level value such as
            # `enabled: false` would be replaced by None and silently lost.
            if not value and in_nested:
                value = self.args['config'].pop(key, None)
            if isinstance(value, str):
                value = get_rendered(value, render_ctx, native=True)
            if value is not None:
                self.config[key] = value

        # whatever is left under 'config' is not a test argument
        if 'config' in self.args:
            del self.args['config']

        if self.namespace is not None:
            self.package_name = self.namespace

        compiled_name, fqn_name = self.get_test_name()
        self.compiled_name: str = compiled_name
        self.fqn_name: str = fqn_name

        # use hashed name as alias if too long
        if compiled_name != fqn_name and 'alias' not in self.config:
            self.config['alias'] = compiled_name

    def _bad_type(self) -> TypeError:
        """Build (not raise) the error for an unsupported target type."""
        return TypeError('invalid target type "{}"'.format(type(self.target)))

    @staticmethod
    def extract_test_args(test, name=None) -> Tuple[str, Dict[str, Any]]:
        """Split a test definition into its name and a copy of its arguments.

        :param test: dict with exactly one entry, ``{test_name: test_args}``,
            where ``test_name`` is a str and ``test_args`` is a dict.
        :param name: if not None, stored in the returned arguments under
            ``column_name``.
        :return: ``(test_name, test_args)``; ``test_args`` is a deep copy, so
            the caller's dictionary is never mutated.

        Malformed definitions are reported through ``raise_compiler_error``.
        """
        if not isinstance(test, dict):
            raise_compiler_error(
                'test must be dict or str, got {} (value {})'.format(
                    type(test), test
                )
            )

        test = list(test.items())
        if len(test) != 1:
            raise_compiler_error(
                'test definition dictionary must have exactly one key, got'
                ' {} instead ({} keys)'.format(test, len(test))
            )
        test_name, test_args = test[0]

        if not isinstance(test_args, dict):
            raise_compiler_error(
                'test arguments must be dict, got {} (value {})'.format(
                    type(test_args), test_args
                )
            )
        if not isinstance(test_name, str):
            raise_compiler_error(
                'test name must be a str, got {} (value {})'.format(
                    type(test_name), test_name
                )
            )
        test_args = deepcopy(test_args)
        if name is not None:
            test_args['column_name'] = name
        return test_name, test_args

    @property
    def enabled(self) -> Optional[bool]:
        """The ``enabled`` config value, or None when unset."""
        return self.config.get('enabled')

    @property
    def alias(self) -> Optional[str]:
        """The ``alias`` config value, or None when unset."""
        return self.config.get('alias')

    @property
    def severity(self) -> Optional[str]:
        """The ``severity`` config value upper-cased, or None when it is
        unset or empty.
        """
        sev = self.config.get('severity')
        if sev:
            return sev.upper()
        else:
            return None

    @property
    def store_failures(self) -> Optional[bool]:
        """The ``store_failures`` config value, or None when unset."""
        return self.config.get('store_failures')

    @property
    def where(self) -> Optional[str]:
        """The ``where`` config value, or None when unset."""
        return self.config.get('where')

    @property
    def limit(self) -> Optional[int]:
        """The ``limit`` config value, or None when unset."""
        return self.config.get('limit')

    @property
    def warn_if(self) -> Optional[str]:
        """The ``warn_if`` config value, or None when unset."""
        return self.config.get('warn_if')

    @property
    def error_if(self) -> Optional[str]:
        """The ``error_if`` config value, or None when unset."""
        return self.config.get('error_if')

    @property
    def fail_calc(self) -> Optional[str]:
        """The ``fail_calc`` config value, or None when unset."""
        return self.config.get('fail_calc')

    @property
    def meta(self) -> Optional[dict]:
        """The ``meta`` config value, or None when unset."""
        return self.config.get('meta')

    @property
    def database(self) -> Optional[str]:
        """The ``database`` config value, or None when unset."""
        return self.config.get('database')

    @property
    def schema(self) -> Optional[str]:
        """The ``schema`` config value, or None when unset."""
        return self.config.get('schema')

    def get_static_config(self):
        """Return a dict of the config values that are not None.

        Values are read through the same-named properties, so ``severity``
        appears upper-cased. ``tags`` is not part of the result.
        """
        config = {}
        for key in self._STATIC_CONFIG_KEYS:
            value = getattr(self, key)
            if value is not None:
                config[key] = value
        return config

    def tags(self) -> List[str]:
        """Return a copy of the configured tags as a list of strings.

        A single string is wrapped in a list; a non-list value or a
        non-string element is reported through ``raise_compiler_error``.
        """
        tags = self.config.get('tags', [])
        if isinstance(tags, str):
            tags = [tags]
        if not isinstance(tags, list):
            raise_compiler_error(
                f'got {tags} ({type(tags)}) for tags, expected a list of '
                f'strings'
            )
        for tag in tags:
            if not isinstance(tag, str):
                raise_compiler_error(
                    f'got {tag} ({type(tag)}) for tag, expected a str'
                )
        return tags[:]

    def macro_name(self) -> str:
        """Return ``test_<name>``, prefixed with ``<namespace>.`` when the
        test name was namespaced.
        """
        macro_name = 'test_{}'.format(self.name)
        if self.namespace is not None:
            macro_name = "{}.{}".format(self.namespace, macro_name)
        return macro_name

    def get_test_name(self) -> Tuple[str, str]:
        """Return the ``(short, full)`` names from get_nice_schema_test_name.

        The test-type part is the test name, prefixed with ``source_`` for
        source targets and then with ``<namespace>_`` when namespaced.

        :raises TypeError: if the target is of an unsupported type.
        """
        if isinstance(self.target, UnparsedNodeUpdate):
            name = self.name
        elif isinstance(self.target, UnpatchedSourceDefinition):
            name = 'source_' + self.name
        else:
            raise self._bad_type()
        if self.namespace is not None:
            name = '{}_{}'.format(self.namespace, name)
        return get_nice_schema_test_name(name, self.target.name, self.args)

    def construct_config(self) -> str:
        """Render ``self.config`` as a ``{{ config(...) }}`` call.

        String values are wrapped in double quotes with embedded double
        quotes backslash-escaped; other values go through ``str()``.
        Returns an empty string when there is no config.
        """
        configs = ",".join([
            f"{key}=" + (
                ("\"" + value.replace('\"', '\\\"') + "\"")
                if isinstance(value, str)
                else str(value)
            )
            for key, value
            in self.config.items()
        ])
        if configs:
            return f"{{{{ config({configs}) }}}}"
        else:
            return ""

    # this is the 'raw_sql' that's used in 'render_update' and execution
    # of the test macro
    def build_raw_sql(self) -> str:
        """Return the macro call ``{{ <macro>(**<kwargs>) }}`` followed by
        the rendered config block, if any.
        """
        return (
            "{{{{ {macro}(**{kwargs_name}) }}}}{config}"
        ).format(
            macro=self.macro_name(),
            config=self.construct_config(),
            kwargs_name=SCHEMA_TEST_KWARGS_NAME,
        )

    def build_model_str(self):
        """Return the ``model`` argument: a ``ref(...)`` or ``source(...)``
        expression wrapped in ``get_where_subquery``.

        :raises TypeError: if the target is of an unsupported type.
        """
        targ = self.target
        if isinstance(self.target, UnparsedNodeUpdate):
            target_str = f"ref('{targ.name}')"
        elif isinstance(self.target, UnpatchedSourceDefinition):
            target_str = f"source('{targ.source.name}', '{targ.table.name}')"
        else:
            # same failure mode as get_test_name, instead of an unbound local
            raise self._bad_type()
        return f"{{{{ get_where_subquery({target_str}) }}}}"