# dbt-selly/dbt-env/lib/python3.8/site-packages/dbt/task/test.py
#
# 217 lines
# 6.7 KiB
# Python

from distutils.util import strtobool
from dataclasses import dataclass
from dbt import utils
from dbt.dataclass_schema import dbtClassMixin
import threading
from typing import Dict, Any, Union
from .compile import CompileRunner
from .run import RunTask
from .printer import print_start_line, print_test_result_line
from dbt.contracts.graph.compiled import (
CompiledDataTestNode,
CompiledSchemaTestNode,
CompiledTestNode,
)
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.results import TestStatus, PrimitiveDict, RunResult
from dbt.context.providers import generate_runtime_model
from dbt.clients.jinja import MacroGenerator
from dbt.exceptions import (
InternalException,
invalid_bool_error,
missing_materialization
)
from dbt.graph import (
ResourceTypeSelector,
SelectionSpec,
parse_test_selectors,
)
from dbt.node_types import NodeType, RunHookType
from dbt import flags
@dataclass
class TestResultData(dbtClassMixin):
    """Parsed result row returned by a test's materialization query."""

    # Failure count reported by the test query.
    failures: int
    # True when the test should warn (compare config.warn_if).
    should_warn: bool
    # True when the test should error (compare config.error_if).
    should_error: bool

    @classmethod
    def validate(cls, data):
        """Coerce the boolean flags in ``data`` in place, then validate.

        The query may return ``should_warn`` / ``should_error`` as strings or
        0/1 integers rather than real booleans. They are normalised to
        ``bool`` before the base-class validation runs. ``data`` is mutated,
        so a subsequent ``from_dict(data)`` sees the coerced values.
        """
        data['should_warn'] = cls.convert_bool_type(data['should_warn'])
        data['should_error'] = cls.convert_bool_type(data['should_error'])
        super().validate(data)

    @staticmethod
    def convert_bool_type(field) -> bool:
        """Convert a raw column value to ``bool``.

        Strings are parsed with ``distutils.util.strtobool``. A string it does
        not recognise is reported through ``invalid_bool_error`` (attributed
        to the ``get_test_sql`` macro). Any other value, such as a real bool
        or 0/1, goes through ``bool()``.

        Declared ``@staticmethod`` so it works whether it is reached through
        the class (as ``validate`` does) or through an instance.
        """
        # if it's type string let python decide if it's a valid value to convert to bool
        if isinstance(field, str):
            try:
                return bool(strtobool(field))  # type: ignore
            except ValueError:
                raise invalid_bool_error(field, 'get_test_sql')
        # need this so we catch both true bools and 0/1
        return bool(field)
class TestRunner(CompileRunner):
    """Runs one compiled test node and converts its result row into a RunResult."""

    def describe_node(self):
        """Return the progress label for this node, e.g. ``test <name>``."""
        return f"test {self.node.name}"

    def print_result_line(self, result):
        """Print the outcome line for ``result`` with this node's position."""
        print_test_result_line(result, self.node_index, self.num_nodes)

    def print_start_line(self):
        """Print the line announcing that this test is starting."""
        # The call below resolves to the module-level printer helper,
        # not to this method.
        print_start_line(self.describe_node(), self.node_index, self.num_nodes)

    def before_execute(self):
        """Announce the test before it runs."""
        self.print_start_line()

    def execute_test(
        self,
        test: Union[CompiledDataTestNode, CompiledSchemaTestNode],
        manifest: Manifest
    ) -> TestResultData:
        """Run ``test``'s materialization macro and parse its single result row.

        Raises:
            InternalException: if the generated context lacks ``config``, or
                the query result is not exactly 1 row by 3 columns.
        """
        context = generate_runtime_model(test, self.config, manifest)

        materialization_macro = manifest.find_materialization_macro_by_name(
            self.config.project_name,
            test.get_materialization(),
            self.adapter.type(),
        )
        if materialization_macro is None:
            # NOTE(review): expected to raise; otherwise execution would
            # continue with a None macro.
            missing_materialization(test, self.adapter.type())

        if 'config' not in context:
            raise InternalException(
                'Invalid materialization context generated, missing config: {}'
                .format(context)
            )

        # Render and invoke the materialization, then read back the result
        # stored under the name 'main'.
        MacroGenerator(materialization_macro, context)()
        table = context['load_result']('main')['table']

        row_count = len(table.rows)
        if row_count != 1:
            raise InternalException(
                f"dbt internally failed to execute {test.unique_id}: "
                f"Returned {row_count} rows, but expected "
                f"1 row"
            )

        column_count = len(table.columns)
        if column_count != 3:
            raise InternalException(
                f"dbt internally failed to execute {test.unique_id}: "
                f"Returned {column_count} columns, but expected "
                f"3 columns"
            )

        # Lower-case column names so they line up with TestResultData's
        # (lower-case) field names.
        test_result_dct: PrimitiveDict = {
            column_name.lower(): utils._coerce_decimal(value)
            for column_name, value in zip(table.column_names, table.rows[0])
        }
        TestResultData.validate(test_result_dct)
        return TestResultData.from_dict(test_result_dct)

    def execute(self, test: CompiledTestNode, manifest: Manifest):
        """Run ``test`` and classify the outcome as a ``RunResult``.

        - Fail: severity is ERROR and the error condition holds.
        - Warn: otherwise, when the warn condition holds. This becomes Fail
          when ``flags.WARN_ERROR`` is set.
        - Pass: any other case, with no message and zero failures.
        """
        result = self.execute_test(test, manifest)
        severity = test.config.severity.upper()
        thread_id = threading.current_thread().name
        num_errors = utils.pluralize(result.failures, 'result')

        status = TestStatus.Pass
        message = None
        failures = 0

        if severity == "ERROR" and result.should_error:
            status = TestStatus.Fail
            message = f'Got {num_errors}, configured to fail if {test.config.error_if}'
            failures = result.failures
        elif result.should_warn:
            failures = result.failures
            if flags.WARN_ERROR:
                status = TestStatus.Fail
                message = f'Got {num_errors}, configured to fail if {test.config.warn_if}'
            else:
                status = TestStatus.Warn
                message = f'Got {num_errors}, configured to warn if {test.config.warn_if}'

        return RunResult(
            node=test,
            status=status,
            timing=[],
            thread_id=thread_id,
            execution_time=0,
            message=message,
            adapter_response={},
            failures=failures,
        )

    def after_execute(self, result):
        """Report the test outcome once it has run."""
        self.print_result_line(result)
class TestSelector(ResourceTypeSelector):
    """Node selector restricted to ``NodeType.Test`` resources."""

    def __init__(self, graph, manifest, previous_state):
        selectable_types = [NodeType.Test]
        super().__init__(
            graph=graph,
            manifest=manifest,
            previous_state=previous_state,
            resource_types=selectable_types,
        )
class TestTask(RunTask):
    """
    Testing:
        Read schema files + custom data tests and validate that
        constraints are satisfied.
    """

    def raise_on_first_error(self):
        """Return False so the base task does not stop at the first error.

        NOTE(review): presumably this lets every selected test be reported.
        """
        return False

    def safe_run_hooks(
        self, adapter, hook_type: RunHookType, extra_context: Dict[str, Any]
    ) -> None:
        """No-op override: on-run-* hooks are not executed for tests."""
        # Don't execute on-run-* hooks for tests
        pass

    def get_selection_spec(self) -> SelectionSpec:
        """Narrow the base selection spec with the ``data``/``schema`` test args."""
        base_spec = super().get_selection_spec()
        return parse_test_selectors(
            data=self.args.data,
            schema=self.args.schema,
            base=base_spec
        )

    def get_node_selector(self) -> TestSelector:
        """Build a ``TestSelector`` over the loaded graph and manifest.

        Raises:
            InternalException: if the manifest or graph has not been set.
        """
        if self.manifest is None or self.graph is None:
            raise InternalException(
                'manifest and graph must be set to perform node selection'
            )
        return TestSelector(
            graph=self.graph,
            manifest=self.manifest,
            previous_state=self.previous_state,
        )

    def get_runner_type(self, _):
        """Run every selected node with ``TestRunner``; the argument is unused."""
        return TestRunner