# dbt-selly/dbt-env/lib/python3.8/site-packages/dbt/task/runnable.py
#
# 605 lines
# 21 KiB
# Python

import os
import time
import json
from abc import abstractmethod
from concurrent.futures import as_completed
from datetime import datetime
from multiprocessing.dummy import Pool as ThreadPool
from typing import Optional, Dict, List, Set, Tuple, Iterable, AbstractSet
from pathlib import PosixPath, WindowsPath
from .printer import (
print_run_result_error,
print_run_end_messages,
print_cancel_line,
)
from dbt import ui
from dbt.clients.system import write_file
from dbt.task.base import ConfiguredTask
from dbt.adapters.base import BaseRelation
from dbt.adapters.factory import get_adapter
from dbt.logger import (
GLOBAL_LOGGER as logger,
DbtProcessState,
TextOnly,
UniqueID,
TimestampNamed,
DbtModelState,
ModelMetadata,
NodeCount,
print_timestamped_line,
)
from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.parsed import ParsedSourceDefinition
from dbt.contracts.results import NodeStatus, RunExecutionResult
from dbt.contracts.state import PreviousState
from dbt.exceptions import (
InternalException,
NotImplementedException,
RuntimeException,
FailFastException,
)
from dbt.graph import (
GraphQueue,
NodeSelector,
SelectionSpec,
parse_difference,
Graph
)
from dbt.parser.manifest import ManifestLoader
import dbt.exceptions
from dbt import flags
import dbt.utils
# File name, joined onto the configured target path, for the run results
# (see GraphRunnableTask.result_path).
RESULT_FILE_NAME = 'run_results.json'
# File name, joined onto the configured target path, for the written manifest
# (see ManifestTask.write_manifest).
MANIFEST_FILE_NAME = 'manifest.json'
# Process-state logging context entered around every node run
# (see GraphRunnableTask.call_runner).
RUNNING_STATE = DbtProcessState('running')
class ManifestTask(ConfiguredTask):
    """Configured task that loads the project manifest and compiles its graph."""

    def __init__(self, args, config):
        super().__init__(args, config)
        # Both remain None until load_manifest() / compile_manifest() have run.
        self.graph: Optional[Graph] = None
        self.manifest: Optional[Manifest] = None

    def write_manifest(self):
        """Write the manifest (and optionally its file listing) to the target path.

        The manifest is written only when ``flags.WRITE_JSON`` is set. When the
        ``DBT_WRITE_FILES`` environment variable is set to a non-empty value,
        ``manifest.files`` is additionally dumped as JSON to ``files.json``.
        """
        if flags.WRITE_JSON:
            manifest_path = os.path.join(
                self.config.target_path, MANIFEST_FILE_NAME
            )
            self.manifest.write(manifest_path)

        if not os.getenv('DBT_WRITE_FILES'):
            return

        files_path = os.path.join(self.config.target_path, 'files.json')
        serialized_files = json.dumps(
            self.manifest.files, cls=dbt.utils.JSONEncoder, indent=4
        )
        write_file(files_path, serialized_files)

    def load_manifest(self):
        """Load the full manifest for this config, then write it out."""
        self.manifest = ManifestLoader.get_full_manifest(self.config)
        self.write_manifest()

    def compile_manifest(self):
        """Compile the loaded manifest into ``self.graph``.

        Raises:
            InternalException: if no manifest has been loaded yet.
        """
        if self.manifest is None:
            raise InternalException(
                'compile_manifest called before manifest was loaded'
            )
        compiler = get_adapter(self.config).get_compiler()
        self.graph = compiler.compile(self.manifest)

    def _runtime_initialize(self):
        """Load the manifest and then compile it, in that order."""
        self.load_manifest()
        self.compile_manifest()
# Manifest-backed task that pulls the selected nodes off a graph queue and
# executes each one through a runner; subclasses provide the node selector
# (get_node_selector) and the runner class (get_runner_type).
class GraphRunnableTask(ManifestTask):
# Result statuses for which _handle_result records the finished node's
# dependents in _skipped_children (via _mark_dependent_errors).
MARK_DEPENDENT_ERRORS_STATUSES = [NodeStatus.Error]
def __init__(self, args, config):
    """Initialise empty run bookkeeping, then record any previous state."""
    super().__init__(args, config)

    # Populated by _runtime_initialize(): the queue of selected nodes and the
    # flat list of the matching manifest entries.
    self.job_queue: Optional[GraphQueue] = None
    self._flattened_nodes: Optional[List[CompileResultNode]] = None

    # Progress counters handed to each runner (see get_runner).
    self.num_nodes: int = 0
    self.run_count: int = 0

    # Collected results, nodes to skip (unique_id -> cause), and an error
    # stashed by call_runner to be raised from the submitting loop.
    self.node_results = []
    self._skipped_children = {}
    self._raise_next_tick = None

    # Kept last so the hook sees a fully initialised instance.
    self.previous_state: Optional[PreviousState] = None
    self.set_previous_state()
def set_previous_state(self):
    """Build ``self.previous_state`` from ``args.state`` when one was given.

    Leaves ``self.previous_state`` untouched if ``args.state`` is None.
    """
    state_path = self.args.state
    if state_path is None:
        return
    self.previous_state = PreviousState(state_path)
def index_offset(self, value: int) -> int:
    """Return the node index to report in logs.

    The base implementation hands ``value`` back unchanged; call_runner
    passes ``runner.node_index`` through this hook.
    """
    return value
@property
def selection_arg(self):
    """The selection criteria for this task, read from ``args.select``."""
    return self.args.select
@property
def exclusion_arg(self):
    """The exclusion criteria for this task, read from ``args.exclude``."""
    return self.args.exclude
def get_selection_spec(self) -> SelectionSpec:
    """Resolve the selection spec to run with.

    Precedence: an explicitly named selector, then the configured default
    selector (only when neither selection nor exclusion args were given),
    then the difference of the selection and exclusion args.
    """
    default_selector_name = self.config.get_default_selector_name()

    if self.args.selector_name:
        # use pre-defined selector (--selector)
        return self.config.get_selector(self.args.selector_name)

    explicit_criteria = self.selection_arg or self.exclusion_arg
    if not explicit_criteria and default_selector_name:
        # use pre-defined selector (--selector) with default: true
        logger.info(f"Using default selector {default_selector_name}")
        return self.config.get_selector(default_selector_name)

    # use --select and --exclude args
    return parse_difference(self.selection_arg, self.exclusion_arg)
@abstractmethod
def get_node_selector(self) -> NodeSelector:
    """Return the node selector for this task; subclasses must override.

    Raises:
        NotImplementedException: always, in this base implementation.
    """
    message = f'get_node_selector not implemented for task {type(self)}'
    raise NotImplementedException(message)
def get_graph_queue(self) -> GraphQueue:
    """Build the graph queue for this task's selector and selection spec."""
    # The selector is obtained before the spec, as in the original ordering.
    node_selector = self.get_node_selector()
    selection = self.get_selection_spec()
    graph_queue = node_selector.get_graph_queue(selection)
    return graph_queue
def _runtime_initialize(self):
    """Load/compile via the parent, then build the job queue and node list.

    Populates ``job_queue``, ``_flattened_nodes`` (manifest nodes or sources
    for every selected unique id) and ``num_nodes`` (the selected entries
    that are not ephemeral models).

    Raises:
        InternalException: if the manifest or graph is still missing, or a
            selected unique id is neither a manifest node nor a source.
    """
    super()._runtime_initialize()

    if self.manifest is None or self.graph is None:
        raise InternalException(
            '_runtime_initialize never loaded the manifest and graph!'
        )

    self.job_queue = self.get_graph_queue()

    # we use this a couple of times. order does not matter.
    # `collected` aliases the attribute so a mid-loop failure still leaves
    # the partially filled list on the instance.
    collected = self._flattened_nodes = []
    for unique_id in self.job_queue.get_selected_nodes():
        if unique_id in self.manifest.nodes:
            collected.append(self.manifest.nodes[unique_id])
            continue
        if unique_id in self.manifest.sources:
            collected.append(self.manifest.sources[unique_id])
            continue
        raise InternalException(
            f'Node selection returned {unique_id}, expected a node or a '
            f'source'
        )

    self.num_nodes = sum(
        1 for entry in collected if not entry.is_ephemeral_model
    )
def raise_on_first_error(self):
    """Whether a node result with Error status should abort the run.

    call_runner consults this hook; the base implementation opts out.
    """
    return False
def get_runner_type(self, node):
    """Return the runner class to use for ``node``; subclasses must override.

    Raises:
        NotImplementedException: always, in this base implementation.
    """
    raise NotImplementedException('Not Implemented')
def result_path(self):
    """Return the path of the results file inside the configured target path."""
    target_dir = self.config.target_path
    return os.path.join(target_dir, RESULT_FILE_NAME)
def get_runner(self, node):
    """Instantiate the runner for ``node``.

    Ephemeral models are given a position and total of 0 and do not advance
    ``self.run_count``; every other node increments the counter and receives
    its new value together with ``self.num_nodes``.
    """
    adapter = get_adapter(self.config)

    if not node.is_ephemeral_model:
        self.run_count += 1
        position, total = self.run_count, self.num_nodes
    else:
        position, total = 0, 0

    runner_cls = self.get_runner_type(node)
    return runner_cls(self.config, adapter, node, position, total)
def call_runner(self, runner):
    """Run one node through its runner inside the logging contexts.

    Logs the start of the node, calls ``runner.run_with_hooks`` with the
    manifest, and always logs the end of the node. Depending on the result
    status, an exception may be stashed in ``self._raise_next_tick`` for the
    submitting loop to raise (see ``_raise_set_error``).

    Args:
        runner: the runner built by ``get_runner`` for the node.

    Returns:
        The result object returned by ``runner.run_with_hooks``.

    Raises:
        Whatever ``runner.run_with_hooks`` or ``runner.get_result_status``
        raises; the "finished" log line is still emitted first.
    """
    uid_context = UniqueID(runner.node.unique_id)
    with RUNNING_STATE, uid_context:
        startctx = TimestampNamed('node_started_at')
        index = self.index_offset(runner.node_index)
        extended_metadata = ModelMetadata(runner.node, index)
        with startctx, extended_metadata:
            logger.debug('Began running node {}'.format(
                runner.node.unique_id))
        # Bind a default up front: if the runner raises before a status is
        # computed, the ``finally`` block below would otherwise hit an
        # UnboundLocalError and mask the original exception.
        status: Dict[str, str] = {}
        try:
            result = runner.run_with_hooks(self.manifest)
            status = runner.get_result_status(result)
        finally:
            finishctx = TimestampNamed('node_finished_at')
            with finishctx, DbtModelState(status):
                logger.debug('Finished running node {}'.format(
                    runner.node.unique_id))

    fail_fast = getattr(self.config.args, 'fail_fast', False)

    if result.status in (NodeStatus.Error, NodeStatus.Fail) and fail_fast:
        self._raise_next_tick = FailFastException(
            message='Failing early due to test failure or runtime error',
            result=result,
            node=getattr(result, 'node', None)
        )
    elif result.status == NodeStatus.Error and self.raise_on_first_error():
        # if we raise inside a thread, it'll just get silently swallowed.
        # stash the error message we want here, and it will check the
        # next 'tick' - should be soon since our thread is about to finish!
        self._raise_next_tick = RuntimeException(result.message)

    return result
def _submit(self, pool, args, callback):
"""If the caller has passed the magic 'single-threaded' flag, call the
function directly instead of pool.apply_async. The single-threaded flag
is intended for gathering more useful performance information about
what happens beneath `call_runner`, since python's default profiling
tools ignore child threads.
This does still go through the callback path for result collection.
"""
if self.config.args.single_threaded:
callback(self.call_runner(*args))
else:
pool.apply_async(self.call_runner, args=args, callback=callback)
def _raise_set_error(self):
if self._raise_next_tick is not None:
raise self._raise_next_tick
def run_queue(self, pool):
"""Given a pool, submit jobs from the queue to the pool.

Drains ``self.job_queue``: each node gets a runner (marked as skipped
first when its unique id is in ``self._skipped_children``) which is
submitted through ``_submit``. Afterwards this blocks until the queue
reports completion, raising any error stashed in ``_raise_next_tick``.

Raises InternalException if no job queue has been set.
"""
if self.job_queue is None:
raise InternalException(
'Got to run_queue with no job queue set'
)
# Completion callback handed to _submit for every runner.
def callback(result):
"""Note: mark_done, at a minimum, must happen here or dbt will
deadlock during ephemeral result error handling!
"""
self._handle_result(result)
if self.job_queue is None:
raise InternalException(
'Got to run_queue callback with no job queue set'
)
self.job_queue.mark_done(result.node.unique_id)
while not self.job_queue.empty():
node = self.job_queue.get()
# surface an error stashed by an already-finished runner before
# starting any more work
self._raise_set_error()
runner = self.get_runner(node)
# we finally know what we're running! Make sure we haven't decided
# to skip it due to upstream failures
if runner.node.unique_id in self._skipped_children:
cause = self._skipped_children.pop(runner.node.unique_id)
runner.do_skip(cause=cause)
args = (runner,)
self._submit(pool, args, callback)
# block on completion
if getattr(self.config.args, 'fail_fast', False):
# check for errors after each task completion so that a fast
# failure is raised promptly
while self.job_queue.wait_until_something_was_done():
self._raise_set_error()
else:
# wait until every task is complete
self.job_queue.join()
# if an error got set during join(), raise it.
self._raise_set_error()
return
def _handle_result(self, result):
    """Record a finished result and propagate failures to dependents.

    Non-ephemeral results are appended to ``self.node_results``. The result's
    node is written back into the manifest (as a source or as a node). When
    the status is one of ``MARK_DEPENDENT_ERRORS_STATUSES``, the node's
    dependents are marked to be skipped; the result itself is recorded as the
    cause only for ephemeral models.

    Raises:
        InternalException: if the manifest is not set.
    """
    node = result.node
    is_ephemeral = node.is_ephemeral_model

    if not is_ephemeral:
        self.node_results.append(result)

    if self.manifest is None:
        raise InternalException('manifest was None in _handle_result')

    if isinstance(node, ParsedSourceDefinition):
        store_in_manifest = self.manifest.update_source
    else:
        store_in_manifest = self.manifest.update_node
    store_in_manifest(node)

    if result.status not in self.MARK_DEPENDENT_ERRORS_STATUSES:
        return

    cause = result if is_ephemeral else None
    self._mark_dependent_errors(node.unique_id, result, cause)
def _cancel_connections(self, pool):
    """Shut the pool down and cancel the adapter's open connections.

    The pool is closed and terminated first. If the adapter supports
    cancellation, every cancelled connection name is printed unless it maps
    to an ephemeral model in the manifest; otherwise a warning line is
    printed instead. Finally waits for the pool's workers to finish.
    """
    pool.close()
    pool.terminate()

    adapter = get_adapter(self.config)

    if adapter.is_cancelable():
        with adapter.connection_named('master'):
            for conn_name in adapter.cancel_open_connections():
                node = None
                if self.manifest is not None:
                    node = self.manifest.nodes.get(conn_name)
                # if we don't have a manifest/don't have a node, print
                # anyway.
                if node is None or not node.is_ephemeral_model:
                    print_cancel_line(conn_name)
    else:
        warning = ("The {} adapter does not support query "
                   "cancellation. Some queries may still be "
                   "running!".format(adapter.type()))
        print_timestamped_line(warning, ui.COLOR_FG_YELLOW)

    pool.join()
def execute_nodes(self):
    """Run the job queue on a thread pool and return the collected results.

    Prints the concurrency banner, then drains the queue via ``run_queue``.
    On a fail-fast failure or a keyboard interrupt the connections are
    cancelled, the relevant message is printed, and the exception is
    re-raised.

    Returns:
        ``self.node_results``.
    """
    thread_count = self.config.threads
    target = self.config.target_name

    banner = "Concurrency: {} threads (target='{}')".format(
        thread_count, target
    )
    with NodeCount(self.num_nodes):
        print_timestamped_line(banner)
    with TextOnly():
        print_timestamped_line("")

    pool = ThreadPool(thread_count)
    try:
        self.run_queue(pool)
    except FailFastException as failure:
        self._cancel_connections(pool)
        print_run_result_error(failure.result)
        raise
    except KeyboardInterrupt:
        self._cancel_connections(pool)
        print_run_end_messages(self.node_results, keyboard_interrupt=True)
        raise

    # Normal completion: stop accepting work and wait for the workers.
    pool.close()
    pool.join()

    return self.node_results
def _mark_dependent_errors(self, node_id, result, cause):
if self.graph is None:
raise InternalException('graph is None in _mark_dependent_errors')
for dep_node_id in self.graph.get_dependent_nodes(node_id):
self._skipped_children[dep_node_id] = cause
def populate_adapter_cache(self, adapter):
    """Hand the current manifest to the adapter's ``set_relations_cache``."""
    manifest = self.manifest
    adapter.set_relations_cache(manifest)
def before_hooks(self, adapter):
    """Hook called first in execute_with_hooks; does nothing by default."""
    return None
def before_run(self, adapter, selected_uids: AbstractSet[str]):
    """Populate the adapter cache inside the ``'master'`` connection.

    ``selected_uids`` is not used by this base implementation.
    """
    master_connection = adapter.connection_named('master')
    with master_connection:
        self.populate_adapter_cache(adapter)
def after_run(self, adapter, results):
    """Hook called after the nodes have executed; does nothing by default."""
    return None
def after_hooks(self, adapter, results, elapsed):
    """Hook called last in execute_with_hooks; does nothing by default."""
    return None
def execute_with_hooks(self, selected_uids: AbstractSet[str]):
    """Execute the nodes wrapped in the before/after hooks.

    Order: ``before_hooks``, ``before_run``, ``execute_nodes``,
    ``after_run``, ``after_hooks``. The elapsed time runs from just after
    ``before_hooks`` to just after ``after_run``. Adapter connections are
    cleaned up whether or not any step raised.

    Returns:
        The object built by ``get_result`` from the node results.
    """
    adapter = get_adapter(self.config)
    try:
        self.before_hooks(adapter)
        run_start = time.time()
        self.before_run(adapter, selected_uids)
        node_results = self.execute_nodes()
        self.after_run(adapter, node_results)
        duration = time.time() - run_start
        self.after_hooks(adapter, node_results, duration)
    finally:
        adapter.cleanup_connections()

    return self.get_result(
        results=node_results,
        elapsed_time=duration,
        generated_at=datetime.utcnow()
    )
def write_result(self, result):
    """Write ``result`` to the path returned by ``result_path()``."""
    destination = self.result_path()
    result.write(destination)
def run(self):
    """
    Run dbt for the query, based on the graph.

    Initialises the task, executes the selected nodes (or builds an empty
    result with a warning when nothing was selected), writes the manifest
    and results when ``flags.WRITE_JSON`` is set, prints the end-of-run
    messages, and returns the result.
    """
    self._runtime_initialize()

    if self._flattened_nodes is None:
        raise InternalException(
            'after _runtime_initialize, _flattened_nodes was still None'
        )

    if self._flattened_nodes:
        with TextOnly():
            logger.info("")
        selected_uids = frozenset(
            entry.unique_id for entry in self._flattened_nodes
        )
        result = self.execute_with_hooks(selected_uids)
    else:
        logger.warning("\nWARNING: Nothing to do. Try checking your model "
                       "configs and model specification args")
        result = self.get_result(
            results=[],
            generated_at=datetime.utcnow(),
            elapsed_time=0.0,
        )

    if flags.WRITE_JSON:
        self.write_manifest()
        self.write_result(result)

    self.task_end_messages(result.results)
    return result
def interpret_results(self, results):
    """Return True when no result carries a failing status.

    ``None`` counts as failure. A result fails when its status is
    RuntimeErr, Error, Fail or Skipped.
    """
    if results is None:
        return False

    return not any(
        outcome.status in (
            NodeStatus.RuntimeErr,
            NodeStatus.Error,
            NodeStatus.Fail,
            NodeStatus.Skipped  # propagate error message causing skip
        )
        for outcome in results
    )
def get_model_schemas(
    self, adapter, selected_uids: Iterable[str]
) -> Set[BaseRelation]:
    """Collect the identifier-less relations of the selected nodes.

    Only manifest nodes whose unique id is in ``selected_uids`` and that are
    relational and not ephemeral contribute a relation.

    Raises:
        InternalException: if the manifest is not set.
    """
    if self.manifest is None:
        raise InternalException('manifest was None in get_model_schemas')

    return {
        adapter.Relation.create_from(self.config, node).without_identifier()
        for node in self.manifest.nodes.values()
        if node.unique_id in selected_uids
        and node.is_relational
        and not node.is_ephemeral
    }
def create_schemas(self, adapter, selected_uids: Iterable[str]):
    """Create every schema the selected nodes need that does not exist yet.

    Existing schemas are listed once per required database, compared
    case-insensitively against the required (database, schema) pairs, and
    the missing ones are created. Listing and creation are both submitted
    to the executor from ``dbt.utils.executor``; any exception raised while
    creating a schema is re-raised here.
    """
    required_schemas = self.get_model_schemas(adapter, selected_uids)

    # we want the string form of the information schema database
    required_databases: Set[BaseRelation] = {
        relation.include(database=True, schema=False, identifier=False)
        for relation in required_schemas
    }

    existing_schemas_lowered: Set[Tuple[Optional[str], Optional[str]]] = set()

    def list_schemas(
        db_only: BaseRelation
    ) -> List[Tuple[Optional[str], str]]:
        db_lowercase = dbt.utils.lowercase(db_only.database)
        # the database can be None on some warehouses that don't support it
        database_quoted: Optional[str] = (
            None if db_only.database is None else str(db_only)
        )
        # we should never create a null schema, so just filter them out
        return [
            (db_lowercase, schema_name.lower())
            for schema_name in adapter.list_schemas(database_quoted)
            if schema_name is not None
        ]

    def create_schema(relation: BaseRelation) -> None:
        db = relation.database or ''
        schema = relation.schema
        with adapter.connection_named(f'create_{db}_{schema}'):
            adapter.create_schema(relation)

    with dbt.utils.executor(self.config) as tpe:
        # Fan out one listing job per required database.
        list_futures = [
            tpe.submit_connected(
                adapter,
                'list_schemas' if req.database is None
                else f'list_{req.database}',
                list_schemas,
                req,
            )
            for req in required_databases
        ]
        for ls_future in as_completed(list_futures):
            existing_schemas_lowered.update(ls_future.result())

        create_futures = []
        for info in required_schemas:
            if info.schema is None:
                # we are not in the business of creating null schemas, so
                # skip this
                continue

            db_schema = (
                dbt.utils.lowercase(info.database),
                info.schema.lower(),
            )
            if db_schema in existing_schemas_lowered:
                continue

            # Remember it right away so the same schema is not submitted twice.
            existing_schemas_lowered.add(db_schema)
            create_futures.append(
                tpe.submit_connected(
                    adapter, f'create_{info.database or ""}_{info.schema}',
                    create_schema, info
                )
            )

        for create_future in as_completed(create_futures):
            # trigger/re-raise any exceptions while creating schemas
            create_future.result()
def get_result(self, results, elapsed_time, generated_at):
    """Wrap the node results, timing, and CLI args in a RunExecutionResult."""
    execution_result = RunExecutionResult(
        results=results,
        elapsed_time=elapsed_time,
        generated_at=generated_at,
        args=self.args_to_dict(),
    )
    return execution_result
def args_to_dict(self):
    """Return a trimmed dict of ``self.args`` for inclusion in the results.

    Dropped entries: the ``cls`` key, any value that is None, the
    default-False flags when they are False, and ``vars`` when it is the
    empty mapping string ``'{}'``. ``PosixPath``/``WindowsPath`` values are
    converted to ``str`` in the returned dict only; ``self.args`` itself is
    left unmodified.

    Returns:
        dict: the remaining argument names mapped to their values.
    """
    # Flags that default to False; only worth reporting when switched on.
    default_false_keys = (
        'debug', 'full_refresh', 'fail_fast', 'warn_error',
        'single_threaded', 'test_new_parser', 'log_cache_events',
        'strict'
    )

    dict_args = {}
    # remove args keys that clutter up the dictionary
    for key, value in vars(self.args).items():
        if key == 'cls' or value is None:
            continue
        if key in default_false_keys and value is False:
            continue
        if key == 'vars' and value == '{}':
            continue
        # this was required for a test case
        # vars() exposes the namespace's own __dict__, so the conversion is
        # applied to a local value rather than written back into the args.
        if isinstance(value, (PosixPath, WindowsPath)):
            value = str(value)
        dict_args[key] = value
    return dict_args
def task_end_messages(self, results):
    """Print the end-of-run messages for ``results``."""
    print_run_end_messages(results)