217 lines
6.7 KiB
Python
217 lines
6.7 KiB
Python
from distutils.util import strtobool
|
|
|
|
from dataclasses import dataclass
|
|
from dbt import utils
|
|
from dbt.dataclass_schema import dbtClassMixin
|
|
import threading
|
|
from typing import Dict, Any, Union
|
|
|
|
from .compile import CompileRunner
|
|
from .run import RunTask
|
|
from .printer import print_start_line, print_test_result_line
|
|
|
|
from dbt.contracts.graph.compiled import (
|
|
CompiledDataTestNode,
|
|
CompiledSchemaTestNode,
|
|
CompiledTestNode,
|
|
)
|
|
from dbt.contracts.graph.manifest import Manifest
|
|
from dbt.contracts.results import TestStatus, PrimitiveDict, RunResult
|
|
from dbt.context.providers import generate_runtime_model
|
|
from dbt.clients.jinja import MacroGenerator
|
|
from dbt.exceptions import (
|
|
InternalException,
|
|
invalid_bool_error,
|
|
missing_materialization
|
|
)
|
|
from dbt.graph import (
|
|
ResourceTypeSelector,
|
|
SelectionSpec,
|
|
parse_test_selectors,
|
|
)
|
|
from dbt.node_types import NodeType, RunHookType
|
|
from dbt import flags
|
|
|
|
|
|
@dataclass
class TestResultData(dbtClassMixin):
    """Parsed outcome of running one test query.

    Instances are built from the single result row a test's materialization
    returns: the number of failing records plus two flags saying whether the
    run should be reported as a warning and/or an error.
    """

    # Number of records the test reported as failing.
    failures: int
    # True when the result should be surfaced as a warning.
    should_warn: bool
    # True when the result should be surfaced as an error.
    should_error: bool

    @classmethod
    def validate(cls, data):
        """Normalize the boolean columns of ``data`` in place, then validate.

        Databases may hand back ``should_warn`` / ``should_error`` as strings
        (e.g. ``'true'``) or as 0/1 integers instead of real booleans, so they
        are coerced with :meth:`convert_bool_type` before the schema check.
        The mutation is intentional: callers pass the same dict on to
        ``from_dict`` afterwards.

        Keys that are absent are left untouched and the decision is deferred
        to ``super().validate`` (expected to reject missing required fields),
        rather than failing here with an uninformative ``KeyError``.
        """
        for key in ('should_warn', 'should_error'):
            if key in data:
                data[key] = cls.convert_bool_type(data[key])
        super().validate(data)

    @staticmethod
    def convert_bool_type(field) -> bool:
        """Coerce a raw result-column value to ``bool``.

        Strings are parsed with ``strtobool`` (true: y/yes/t/true/on/1,
        false: n/no/f/false/off/0, case-insensitive). Any other value goes
        through ``bool()`` so both real booleans and 0/1 integers are handled.

        Declared as a ``staticmethod`` so it works whether it is reached via
        the class (as ``validate`` does) or via an instance.

        Raises (via ``invalid_bool_error``) when a string is not a recognized
        boolean literal.
        """
        # NOTE(review): distutils (and strtobool with it) is deprecated by
        # PEP 632 and removed in Python 3.12 -- consider a local replacement.
        if isinstance(field, str):
            try:
                return bool(strtobool(field))  # type: ignore
            except ValueError:
                raise invalid_bool_error(field, 'get_test_sql')

        # need this so we catch both true bools and 0/1
        return bool(field)
|
|
|
|
|
|
class TestRunner(CompileRunner):
    """Runner that executes one compiled test node and reports its outcome."""

    def describe_node(self):
        """Return the label shown for this node in progress output."""
        return f"test {self.node.name}"

    def print_result_line(self, result):
        """Emit the result line for ``result`` through the shared printer."""
        print_test_result_line(result, self.node_index, self.num_nodes)

    def print_start_line(self):
        """Emit the "starting" line for this test."""
        # Inside this body the bare name refers to the module-level printer
        # function imported at the top of the file, not to this method.
        print_start_line(self.describe_node(), self.node_index, self.num_nodes)

    def before_execute(self):
        """Announce the test before it runs."""
        self.print_start_line()

    def execute_test(
        self,
        test: Union[CompiledDataTestNode, CompiledSchemaTestNode],
        manifest: Manifest
    ) -> TestResultData:
        """Run ``test``'s materialization and parse its single result row.

        The materialization macro is looked up by project name, the test's
        materialization name and the adapter type, then executed inside a
        runtime context. The query result is read back from that context
        under the name ``'main'``; it must be exactly one row with three
        columns. Lower-cased column names are paired with the row's values
        (decimals coerced) and turned into a :class:`TestResultData`.

        :raises InternalException: if the generated context lacks ``config``,
            or the result is not exactly 1 row x 3 columns.
        """
        context = generate_runtime_model(test, self.config, manifest)

        macro = manifest.find_materialization_macro_by_name(
            self.config.project_name,
            test.get_materialization(),
            self.adapter.type(),
        )
        if macro is None:
            # NOTE(review): presumably raises; otherwise None would reach
            # MacroGenerator below.
            missing_materialization(test, self.adapter.type())

        if 'config' not in context:
            raise InternalException(
                'Invalid materialization context generated, missing config: {}'
                .format(context)
            )

        # Render and run the materialization, then pull its 'main' result
        # back out of the context (it could eventually be returned directly).
        run_materialization = MacroGenerator(macro, context)
        run_materialization()
        table = context['load_result']('main')['table']

        row_count = len(table.rows)
        if row_count != 1:
            raise InternalException(
                f"dbt internally failed to execute {test.unique_id}: "
                f"Returned {row_count} rows, but expected "
                "1 row"
            )

        col_count = len(table.columns)
        if col_count != 3:
            raise InternalException(
                f"dbt internally failed to execute {test.unique_id}: "
                f"Returned {col_count} columns, but expected "
                "3 columns"
            )

        keys = [name.lower() for name in table.column_names]
        values = [utils._coerce_decimal(cell) for cell in table.rows[0]]
        test_result_dct: PrimitiveDict = dict(zip(keys, values))

        # validate() normalizes the dict in place before from_dict reads it.
        TestResultData.validate(test_result_dct)
        return TestResultData.from_dict(test_result_dct)

    def execute(self, test: CompiledTestNode, manifest: Manifest):
        """Run ``test`` and translate its outcome into a :class:`RunResult`.

        Status rules:

        * severity ``ERROR`` and ``should_error`` -> Fail
        * otherwise, ``should_warn`` -> Warn, or Fail when
          ``flags.WARN_ERROR`` is set
        * otherwise -> Pass, with no message and zero failures
        """
        outcome = self.execute_test(test, manifest)

        severity = test.config.severity.upper()
        thread_name = threading.current_thread().name
        num_errors = utils.pluralize(outcome.failures, 'result')

        if severity == "ERROR" and outcome.should_error:
            status = TestStatus.Fail
            message = f'Got {num_errors}, configured to fail if {test.config.error_if}'
            failures = outcome.failures
        elif not outcome.should_warn:
            status = TestStatus.Pass
            message = None
            failures = 0
        elif flags.WARN_ERROR:
            # flags.WARN_ERROR promotes a warning to a failure.
            status = TestStatus.Fail
            message = f'Got {num_errors}, configured to fail if {test.config.warn_if}'
            failures = outcome.failures
        else:
            status = TestStatus.Warn
            message = f'Got {num_errors}, configured to warn if {test.config.warn_if}'
            failures = outcome.failures

        return RunResult(
            node=test,
            status=status,
            timing=[],
            thread_id=thread_name,
            execution_time=0,
            message=message,
            adapter_response={},
            failures=failures,
        )

    def after_execute(self, result):
        """Report the finished test's result line."""
        self.print_result_line(result)
|
|
|
|
|
|
class TestSelector(ResourceTypeSelector):
    """Resource-type selector that only ever considers test nodes."""

    def __init__(self, graph, manifest, previous_state):
        wanted_types = [NodeType.Test]
        super().__init__(
            graph=graph,
            manifest=manifest,
            previous_state=previous_state,
            resource_types=wanted_types,
        )
|
|
|
|
|
|
class TestTask(RunTask):
    """
    Testing:

    Read schema files and custom data tests, then check that the
    constraints they describe are satisfied.
    """

    def raise_on_first_error(self):
        """Return ``False``: this task does not raise on the first error."""
        return False

    def safe_run_hooks(
        self, adapter, hook_type: RunHookType, extra_context: Dict[str, Any]
    ) -> None:
        """Intentional no-op: ``on-run-*`` hooks are not executed for tests."""
        return None

    def get_selection_spec(self) -> SelectionSpec:
        """Combine the inherited spec with the ``data``/``schema`` options."""
        inherited = super().get_selection_spec()
        args = self.args
        return parse_test_selectors(
            data=args.data,
            schema=args.schema,
            base=inherited,
        )

    def get_node_selector(self) -> TestSelector:
        """Build a :class:`TestSelector` over this task's manifest and graph.

        :raises InternalException: if the manifest or the graph is ``None``.
        """
        if self.manifest is None or self.graph is None:
            raise InternalException(
                'manifest and graph must be set to get perform node selection'
            )
        selector = TestSelector(
            graph=self.graph,
            manifest=self.manifest,
            previous_state=self.previous_state,
        )
        return selector

    def get_runner_type(self, _):
        """Every selected node is executed by :class:`TestRunner`."""
        return TestRunner
|