dbt-selly/dbt-env/lib/python3.8/site-packages/dbt/parser/schema_test_builders.py

439 lines
13 KiB
Python
Raw Normal View History

2022-03-22 15:13:27 +00:00
import hashlib
import re
from copy import deepcopy
from dataclasses import dataclass
from typing import (
Generic, TypeVar, Dict, Any, Tuple, Optional, List,
)
from dbt.clients.jinja import get_rendered, SCHEMA_TEST_KWARGS_NAME
from dbt.contracts.graph.parsed import UnpatchedSourceDefinition
from dbt.contracts.graph.unparsed import (
TestDef,
UnparsedAnalysisUpdate,
UnparsedMacroUpdate,
UnparsedNodeUpdate,
UnparsedExposure,
)
from dbt.exceptions import raise_compiler_error
from dbt.parser.search import FileBlock
def get_nice_schema_test_name(
test_type: str, test_name: str, args: Dict[str, Any]
) -> Tuple[str, str]:
flat_args = []
for arg_name in sorted(args):
# the model is already embedded in the name, so skip it
if arg_name == 'model':
continue
arg_val = args[arg_name]
if isinstance(arg_val, dict):
parts = list(arg_val.values())
elif isinstance(arg_val, (list, tuple)):
parts = list(arg_val)
else:
parts = [arg_val]
flat_args.extend([str(part) for part in parts])
clean_flat_args = [re.sub('[^0-9a-zA-Z_]+', '_', arg) for arg in flat_args]
unique = "__".join(clean_flat_args)
# for the file path + alias, the name must be <64 characters
# if the full name is too long, include the first 30 identifying chars plus
# a 32-character hash of the full contents
test_identifier = '{}_{}'.format(test_type, test_name)
full_name = '{}_{}'.format(test_identifier, unique)
if len(full_name) >= 64:
test_trunc_identifier = test_identifier[:30]
label = hashlib.md5(full_name.encode('utf-8')).hexdigest()
short_name = '{}_{}'.format(test_trunc_identifier, label)
else:
short_name = full_name
return short_name, full_name
@dataclass
class YamlBlock(FileBlock):
data: Dict[str, Any]
@classmethod
def from_file_block(cls, src: FileBlock, data: Dict[str, Any]):
return cls(
file=src.file,
data=data,
)
Testable = TypeVar(
'Testable', UnparsedNodeUpdate, UnpatchedSourceDefinition
)
ColumnTarget = TypeVar(
'ColumnTarget',
UnparsedNodeUpdate,
UnparsedAnalysisUpdate,
UnpatchedSourceDefinition,
)
Target = TypeVar(
'Target',
UnparsedNodeUpdate,
UnparsedMacroUpdate,
UnparsedAnalysisUpdate,
UnpatchedSourceDefinition,
UnparsedExposure,
)
@dataclass
class TargetBlock(YamlBlock, Generic[Target]):
target: Target
@property
def name(self):
return self.target.name
@property
def columns(self):
return []
@property
def tests(self) -> List[TestDef]:
return []
@classmethod
def from_yaml_block(
cls, src: YamlBlock, target: Target
) -> 'TargetBlock[Target]':
return cls(
file=src.file,
data=src.data,
target=target,
)
@dataclass
class TargetColumnsBlock(TargetBlock[ColumnTarget], Generic[ColumnTarget]):
@property
def columns(self):
if self.target.columns is None:
return []
else:
return self.target.columns
@dataclass
class TestBlock(TargetColumnsBlock[Testable], Generic[Testable]):
@property
def tests(self) -> List[TestDef]:
if self.target.tests is None:
return []
else:
return self.target.tests
@property
def quote_columns(self) -> Optional[bool]:
return self.target.quote_columns
@classmethod
def from_yaml_block(
cls, src: YamlBlock, target: Testable
) -> 'TestBlock[Testable]':
return cls(
file=src.file,
data=src.data,
target=target,
)
@dataclass
class SchemaTestBlock(TestBlock[Testable], Generic[Testable]):
test: Dict[str, Any]
column_name: Optional[str]
tags: List[str]
@classmethod
def from_test_block(
cls,
src: TestBlock,
test: Dict[str, Any],
column_name: Optional[str],
tags: List[str],
) -> 'SchemaTestBlock':
return cls(
file=src.file,
data=src.data,
target=src.target,
test=test,
column_name=column_name,
tags=tags,
)
class TestBuilder(Generic[Testable]):
"""An object to hold assorted test settings and perform basic parsing
Test names have the following pattern:
- the test name itself may be namespaced (package.test)
- or it may not be namespaced (test)
"""
# The 'test_name' is used to find the 'macro' that implements the test
TEST_NAME_PATTERN = re.compile(
r'((?P<test_namespace>([a-zA-Z_][0-9a-zA-Z_]*))\.)?'
r'(?P<test_name>([a-zA-Z_][0-9a-zA-Z_]*))'
)
# kwargs representing test configs
CONFIG_ARGS = (
'severity', 'tags', 'enabled', 'where', 'limit', 'warn_if', 'error_if',
'fail_calc', 'store_failures', 'meta', 'database', 'schema', 'alias',
)
def __init__(
self,
test: Dict[str, Any],
target: Testable,
package_name: str,
render_ctx: Dict[str, Any],
column_name: str = None,
) -> None:
test_name, test_args = self.extract_test_args(test, column_name)
self.args: Dict[str, Any] = test_args
if 'model' in self.args:
raise_compiler_error(
'Test arguments include "model", which is a reserved argument',
)
self.package_name: str = package_name
self.target: Testable = target
self.args['model'] = self.build_model_str()
match = self.TEST_NAME_PATTERN.match(test_name)
if match is None:
raise_compiler_error(
'Test name string did not match expected pattern: {}'
.format(test_name)
)
groups = match.groupdict()
self.name: str = groups['test_name']
self.namespace: str = groups['test_namespace']
self.config: Dict[str, Any] = {}
for key in self.CONFIG_ARGS:
value = self.args.pop(key, None)
# 'modifier' config could be either top level arg or in config
if value and 'config' in self.args and key in self.args['config']:
raise_compiler_error(
'Test cannot have the same key at the top-level and in config'
)
if not value and 'config' in self.args:
value = self.args['config'].pop(key, None)
if isinstance(value, str):
value = get_rendered(value, render_ctx, native=True)
if value is not None:
self.config[key] = value
if 'config' in self.args:
del self.args['config']
if self.namespace is not None:
self.package_name = self.namespace
compiled_name, fqn_name = self.get_test_name()
self.compiled_name: str = compiled_name
self.fqn_name: str = fqn_name
# use hashed name as alias if too long
if compiled_name != fqn_name and 'alias' not in self.config:
self.config['alias'] = compiled_name
def _bad_type(self) -> TypeError:
return TypeError('invalid target type "{}"'.format(type(self.target)))
@staticmethod
def extract_test_args(test, name=None) -> Tuple[str, Dict[str, Any]]:
if not isinstance(test, dict):
raise_compiler_error(
'test must be dict or str, got {} (value {})'.format(
type(test), test
)
)
test = list(test.items())
if len(test) != 1:
raise_compiler_error(
'test definition dictionary must have exactly one key, got'
' {} instead ({} keys)'.format(test, len(test))
)
test_name, test_args = test[0]
if not isinstance(test_args, dict):
raise_compiler_error(
'test arguments must be dict, got {} (value {})'.format(
type(test_args), test_args
)
)
if not isinstance(test_name, str):
raise_compiler_error(
'test name must be a str, got {} (value {})'.format(
type(test_name), test_name
)
)
test_args = deepcopy(test_args)
if name is not None:
test_args['column_name'] = name
return test_name, test_args
@property
def enabled(self) -> Optional[bool]:
return self.config.get('enabled')
@property
def alias(self) -> Optional[str]:
return self.config.get('alias')
@property
def severity(self) -> Optional[str]:
sev = self.config.get('severity')
if sev:
return sev.upper()
else:
return None
@property
def store_failures(self) -> Optional[bool]:
return self.config.get('store_failures')
@property
def where(self) -> Optional[str]:
return self.config.get('where')
@property
def limit(self) -> Optional[int]:
return self.config.get('limit')
@property
def warn_if(self) -> Optional[str]:
return self.config.get('warn_if')
@property
def error_if(self) -> Optional[str]:
return self.config.get('error_if')
@property
def fail_calc(self) -> Optional[str]:
return self.config.get('fail_calc')
@property
def meta(self) -> Optional[dict]:
return self.config.get('meta')
@property
def database(self) -> Optional[str]:
return self.config.get('database')
@property
def schema(self) -> Optional[str]:
return self.config.get('schema')
def get_static_config(self):
config = {}
if self.alias is not None:
config['alias'] = self.alias
if self.severity is not None:
config['severity'] = self.severity
if self.enabled is not None:
config['enabled'] = self.enabled
if self.where is not None:
config['where'] = self.where
if self.limit is not None:
config['limit'] = self.limit
if self.warn_if is not None:
config['warn_if'] = self.warn_if
if self.error_if is not None:
config['error_if'] = self.error_if
if self.fail_calc is not None:
config['fail_calc'] = self.fail_calc
if self.store_failures is not None:
config['store_failures'] = self.store_failures
if self.meta is not None:
config['meta'] = self.meta
if self.database is not None:
config['database'] = self.database
if self.schema is not None:
config['schema'] = self.schema
return config
def tags(self) -> List[str]:
tags = self.config.get('tags', [])
if isinstance(tags, str):
tags = [tags]
if not isinstance(tags, list):
raise_compiler_error(
f'got {tags} ({type(tags)}) for tags, expected a list of '
f'strings'
)
for tag in tags:
if not isinstance(tag, str):
raise_compiler_error(
f'got {tag} ({type(tag)}) for tag, expected a str'
)
return tags[:]
def macro_name(self) -> str:
macro_name = 'test_{}'.format(self.name)
if self.namespace is not None:
macro_name = "{}.{}".format(self.namespace, macro_name)
return macro_name
def get_test_name(self) -> Tuple[str, str]:
if isinstance(self.target, UnparsedNodeUpdate):
name = self.name
elif isinstance(self.target, UnpatchedSourceDefinition):
name = 'source_' + self.name
else:
raise self._bad_type()
if self.namespace is not None:
name = '{}_{}'.format(self.namespace, name)
return get_nice_schema_test_name(name, self.target.name, self.args)
def construct_config(self) -> str:
configs = ",".join([
f"{key}=" + (
("\"" + value.replace('\"', '\\\"') + "\"") if isinstance(value, str)
else str(value)
)
for key, value
in self.config.items()
])
if configs:
return f"{{{{ config({configs}) }}}}"
else:
return ""
# this is the 'raw_sql' that's used in 'render_update' and execution
# of the test macro
def build_raw_sql(self) -> str:
return (
"{{{{ {macro}(**{kwargs_name}) }}}}{config}"
).format(
macro=self.macro_name(),
config=self.construct_config(),
kwargs_name=SCHEMA_TEST_KWARGS_NAME,
)
def build_model_str(self):
targ = self.target
if isinstance(self.target, UnparsedNodeUpdate):
target_str = f"ref('{targ.name}')"
elif isinstance(self.target, UnpatchedSourceDefinition):
target_str = f"source('{targ.source.name}', '{targ.table.name}')"
return f"{{{{ get_where_subquery({target_str}) }}}}"