import os import time import json from abc import abstractmethod from concurrent.futures import as_completed from datetime import datetime from multiprocessing.dummy import Pool as ThreadPool from typing import Optional, Dict, List, Set, Tuple, Iterable, AbstractSet from pathlib import PosixPath, WindowsPath from .printer import ( print_run_result_error, print_run_end_messages, print_cancel_line, ) from dbt import ui from dbt.clients.system import write_file from dbt.task.base import ConfiguredTask from dbt.adapters.base import BaseRelation from dbt.adapters.factory import get_adapter from dbt.logger import ( GLOBAL_LOGGER as logger, DbtProcessState, TextOnly, UniqueID, TimestampNamed, DbtModelState, ModelMetadata, NodeCount, print_timestamped_line, ) from dbt.contracts.graph.compiled import CompileResultNode from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.parsed import ParsedSourceDefinition from dbt.contracts.results import NodeStatus, RunExecutionResult from dbt.contracts.state import PreviousState from dbt.exceptions import ( InternalException, NotImplementedException, RuntimeException, FailFastException, ) from dbt.graph import ( GraphQueue, NodeSelector, SelectionSpec, parse_difference, Graph ) from dbt.parser.manifest import ManifestLoader import dbt.exceptions from dbt import flags import dbt.utils RESULT_FILE_NAME = 'run_results.json' MANIFEST_FILE_NAME = 'manifest.json' RUNNING_STATE = DbtProcessState('running') class ManifestTask(ConfiguredTask): def __init__(self, args, config): super().__init__(args, config) self.manifest: Optional[Manifest] = None self.graph: Optional[Graph] = None def write_manifest(self): if flags.WRITE_JSON: path = os.path.join(self.config.target_path, MANIFEST_FILE_NAME) self.manifest.write(path) if os.getenv('DBT_WRITE_FILES'): path = os.path.join(self.config.target_path, 'files.json') write_file(path, json.dumps(self.manifest.files, cls=dbt.utils.JSONEncoder, indent=4)) def load_manifest(self): self.manifest = ManifestLoader.get_full_manifest(self.config) self.write_manifest() def compile_manifest(self): if self.manifest is None: raise InternalException( 'compile_manifest called before manifest was loaded' ) adapter = get_adapter(self.config) compiler = adapter.get_compiler() self.graph = compiler.compile(self.manifest) def _runtime_initialize(self): self.load_manifest() self.compile_manifest() class GraphRunnableTask(ManifestTask): MARK_DEPENDENT_ERRORS_STATUSES = [NodeStatus.Error] def __init__(self, args, config): super().__init__(args, config) self.job_queue: Optional[GraphQueue] = None self._flattened_nodes: Optional[List[CompileResultNode]] = None self.run_count: int = 0 self.num_nodes: int = 0 self.node_results = [] self._skipped_children = {} self._raise_next_tick = None self.previous_state: Optional[PreviousState] = None self.set_previous_state() def set_previous_state(self): if self.args.state is not None: self.previous_state = PreviousState(self.args.state) def index_offset(self, value: int) -> int: return value @property def selection_arg(self): return self.args.select @property def exclusion_arg(self): return self.args.exclude def get_selection_spec(self) -> SelectionSpec: default_selector_name = self.config.get_default_selector_name() if self.args.selector_name: # use pre-defined selector (--selector) spec = self.config.get_selector(self.args.selector_name) elif not (self.selection_arg or self.exclusion_arg) and default_selector_name: # use pre-defined selector (--selector) with default: true logger.info(f"Using default selector {default_selector_name}") spec = self.config.get_selector(default_selector_name) else: # use --select and --exclude args spec = parse_difference(self.selection_arg, self.exclusion_arg) return spec @abstractmethod def get_node_selector(self) -> NodeSelector: raise NotImplementedException( f'get_node_selector not implemented for task {type(self)}' ) def get_graph_queue(self) -> GraphQueue: selector = self.get_node_selector() spec = self.get_selection_spec() return selector.get_graph_queue(spec) def _runtime_initialize(self): super()._runtime_initialize() if self.manifest is None or self.graph is None: raise InternalException( '_runtime_initialize never loaded the manifest and graph!' ) self.job_queue = self.get_graph_queue() # we use this a couple of times. order does not matter. self._flattened_nodes = [] for uid in self.job_queue.get_selected_nodes(): if uid in self.manifest.nodes: self._flattened_nodes.append(self.manifest.nodes[uid]) elif uid in self.manifest.sources: self._flattened_nodes.append(self.manifest.sources[uid]) else: raise InternalException( f'Node selection returned {uid}, expected a node or a ' f'source' ) self.num_nodes = len([ n for n in self._flattened_nodes if not n.is_ephemeral_model ]) def raise_on_first_error(self): return False def get_runner_type(self, node): raise NotImplementedException('Not Implemented') def result_path(self): return os.path.join(self.config.target_path, RESULT_FILE_NAME) def get_runner(self, node): adapter = get_adapter(self.config) if node.is_ephemeral_model: run_count = 0 num_nodes = 0 else: self.run_count += 1 run_count = self.run_count num_nodes = self.num_nodes cls = self.get_runner_type(node) return cls(self.config, adapter, node, run_count, num_nodes) def call_runner(self, runner): uid_context = UniqueID(runner.node.unique_id) with RUNNING_STATE, uid_context: startctx = TimestampNamed('node_started_at') index = self.index_offset(runner.node_index) extended_metadata = ModelMetadata(runner.node, index) with startctx, extended_metadata: logger.debug('Began running node {}'.format( runner.node.unique_id)) status: Dict[str, str] try: result = runner.run_with_hooks(self.manifest) status = runner.get_result_status(result) finally: finishctx = TimestampNamed('node_finished_at') with finishctx, DbtModelState(status): logger.debug('Finished running node {}'.format( runner.node.unique_id)) fail_fast = getattr(self.config.args, 'fail_fast', False) if result.status in (NodeStatus.Error, NodeStatus.Fail) and fail_fast: self._raise_next_tick = FailFastException( message='Failing early due to test failure or runtime error', result=result, node=getattr(result, 'node', None) ) elif result.status == NodeStatus.Error and self.raise_on_first_error(): # if we raise inside a thread, it'll just get silently swallowed. # stash the error message we want here, and it will check the # next 'tick' - should be soon since our thread is about to finish! self._raise_next_tick = RuntimeException(result.message) return result def _submit(self, pool, args, callback): """If the caller has passed the magic 'single-threaded' flag, call the function directly instead of pool.apply_async. The single-threaded flag is intended for gathering more useful performance information about what happens beneath `call_runner`, since python's default profiling tools ignore child threads. This does still go through the callback path for result collection. """ if self.config.args.single_threaded: callback(self.call_runner(*args)) else: pool.apply_async(self.call_runner, args=args, callback=callback) def _raise_set_error(self): if self._raise_next_tick is not None: raise self._raise_next_tick def run_queue(self, pool): """Given a pool, submit jobs from the queue to the pool. """ if self.job_queue is None: raise InternalException( 'Got to run_queue with no job queue set' ) def callback(result): """Note: mark_done, at a minimum, must happen here or dbt will deadlock during ephemeral result error handling! """ self._handle_result(result) if self.job_queue is None: raise InternalException( 'Got to run_queue callback with no job queue set' ) self.job_queue.mark_done(result.node.unique_id) while not self.job_queue.empty(): node = self.job_queue.get() self._raise_set_error() runner = self.get_runner(node) # we finally know what we're running! Make sure we haven't decided # to skip it due to upstream failures if runner.node.unique_id in self._skipped_children: cause = self._skipped_children.pop(runner.node.unique_id) runner.do_skip(cause=cause) args = (runner,) self._submit(pool, args, callback) # block on completion if getattr(self.config.args, 'fail_fast', False): # checkout for an errors after task completion in case of # fast failure while self.job_queue.wait_until_something_was_done(): self._raise_set_error() else: # wait until every task will be complete self.job_queue.join() # if an error got set during join(), raise it. self._raise_set_error() return def _handle_result(self, result): """Mark the result as completed, insert the `CompileResultNode` into the manifest, and mark any descendants (potentially with a 'cause' if the result was an ephemeral model) as skipped. """ is_ephemeral = result.node.is_ephemeral_model if not is_ephemeral: self.node_results.append(result) node = result.node if self.manifest is None: raise InternalException('manifest was None in _handle_result') if isinstance(node, ParsedSourceDefinition): self.manifest.update_source(node) else: self.manifest.update_node(node) if result.status in self.MARK_DEPENDENT_ERRORS_STATUSES: if is_ephemeral: cause = result else: cause = None self._mark_dependent_errors(node.unique_id, result, cause) def _cancel_connections(self, pool): """Given a pool, cancel all adapter connections and wait until all runners gentle terminates. """ pool.close() pool.terminate() adapter = get_adapter(self.config) if not adapter.is_cancelable(): msg = ("The {} adapter does not support query " "cancellation. Some queries may still be " "running!".format(adapter.type())) yellow = ui.COLOR_FG_YELLOW print_timestamped_line(msg, yellow) else: with adapter.connection_named('master'): for conn_name in adapter.cancel_open_connections(): if self.manifest is not None: node = self.manifest.nodes.get(conn_name) if node is not None and node.is_ephemeral_model: continue # if we don't have a manifest/don't have a node, print # anyway. print_cancel_line(conn_name) pool.join() def execute_nodes(self): num_threads = self.config.threads target_name = self.config.target_name text = "Concurrency: {} threads (target='{}')" concurrency_line = text.format(num_threads, target_name) with NodeCount(self.num_nodes): print_timestamped_line(concurrency_line) with TextOnly(): print_timestamped_line("") pool = ThreadPool(num_threads) try: self.run_queue(pool) except FailFastException as failure: self._cancel_connections(pool) print_run_result_error(failure.result) raise except KeyboardInterrupt: self._cancel_connections(pool) print_run_end_messages(self.node_results, keyboard_interrupt=True) raise pool.close() pool.join() return self.node_results def _mark_dependent_errors(self, node_id, result, cause): if self.graph is None: raise InternalException('graph is None in _mark_dependent_errors') for dep_node_id in self.graph.get_dependent_nodes(node_id): self._skipped_children[dep_node_id] = cause def populate_adapter_cache(self, adapter): adapter.set_relations_cache(self.manifest) def before_hooks(self, adapter): pass def before_run(self, adapter, selected_uids: AbstractSet[str]): with adapter.connection_named('master'): self.populate_adapter_cache(adapter) def after_run(self, adapter, results): pass def after_hooks(self, adapter, results, elapsed): pass def execute_with_hooks(self, selected_uids: AbstractSet[str]): adapter = get_adapter(self.config) try: self.before_hooks(adapter) started = time.time() self.before_run(adapter, selected_uids) res = self.execute_nodes() self.after_run(adapter, res) elapsed = time.time() - started self.after_hooks(adapter, res, elapsed) finally: adapter.cleanup_connections() result = self.get_result( results=res, elapsed_time=elapsed, generated_at=datetime.utcnow() ) return result def write_result(self, result): result.write(self.result_path()) def run(self): """ Run dbt for the query, based on the graph. """ self._runtime_initialize() if self._flattened_nodes is None: raise InternalException( 'after _runtime_initialize, _flattened_nodes was still None' ) if len(self._flattened_nodes) == 0: logger.warning("\nWARNING: Nothing to do. Try checking your model " "configs and model specification args") result = self.get_result( results=[], generated_at=datetime.utcnow(), elapsed_time=0.0, ) else: with TextOnly(): logger.info("") selected_uids = frozenset(n.unique_id for n in self._flattened_nodes) result = self.execute_with_hooks(selected_uids) if flags.WRITE_JSON: self.write_manifest() self.write_result(result) self.task_end_messages(result.results) return result def interpret_results(self, results): if results is None: return False failures = [ r for r in results if r.status in ( NodeStatus.RuntimeErr, NodeStatus.Error, NodeStatus.Fail, NodeStatus.Skipped # propogate error message causing skip ) ] return len(failures) == 0 def get_model_schemas( self, adapter, selected_uids: Iterable[str] ) -> Set[BaseRelation]: if self.manifest is None: raise InternalException('manifest was None in get_model_schemas') result: Set[BaseRelation] = set() for node in self.manifest.nodes.values(): if node.unique_id not in selected_uids: continue if node.is_relational and not node.is_ephemeral: relation = adapter.Relation.create_from(self.config, node) result.add(relation.without_identifier()) return result def create_schemas(self, adapter, selected_uids: Iterable[str]): required_schemas = self.get_model_schemas(adapter, selected_uids) # we want the string form of the information schema database required_databases: Set[BaseRelation] = set() for required in required_schemas: db_only = required.include( database=True, schema=False, identifier=False ) required_databases.add(db_only) existing_schemas_lowered: Set[Tuple[Optional[str], Optional[str]]] existing_schemas_lowered = set() def list_schemas( db_only: BaseRelation ) -> List[Tuple[Optional[str], str]]: # the database can be None on some warehouses that don't support it database_quoted: Optional[str] db_lowercase = dbt.utils.lowercase(db_only.database) if db_only.database is None: database_quoted = None else: database_quoted = str(db_only) # we should never create a null schema, so just filter them out return [ (db_lowercase, s.lower()) for s in adapter.list_schemas(database_quoted) if s is not None ] def create_schema(relation: BaseRelation) -> None: db = relation.database or '' schema = relation.schema with adapter.connection_named(f'create_{db}_{schema}'): adapter.create_schema(relation) list_futures = [] create_futures = [] with dbt.utils.executor(self.config) as tpe: for req in required_databases: if req.database is None: name = 'list_schemas' else: name = f'list_{req.database}' fut = tpe.submit_connected(adapter, name, list_schemas, req) list_futures.append(fut) for ls_future in as_completed(list_futures): existing_schemas_lowered.update(ls_future.result()) for info in required_schemas: if info.schema is None: # we are not in the business of creating null schemas, so # skip this continue db: Optional[str] = info.database db_lower: Optional[str] = dbt.utils.lowercase(db) schema: str = info.schema db_schema = (db_lower, schema.lower()) if db_schema not in existing_schemas_lowered: existing_schemas_lowered.add(db_schema) fut = tpe.submit_connected( adapter, f'create_{info.database or ""}_{info.schema}', create_schema, info ) create_futures.append(fut) for create_future in as_completed(create_futures): # trigger/re-raise any excceptions while creating schemas create_future.result() def get_result(self, results, elapsed_time, generated_at): return RunExecutionResult( results=results, elapsed_time=elapsed_time, generated_at=generated_at, args=self.args_to_dict(), ) def args_to_dict(self): var_args = vars(self.args) dict_args = {} # remove args keys that clutter up the dictionary for key in var_args: if key == 'cls': continue if var_args[key] is None: continue default_false_keys = ( 'debug', 'full_refresh', 'fail_fast', 'warn_error', 'single_threaded', 'test_new_parser', 'log_cache_events', 'strict' ) if key in default_false_keys and var_args[key] is False: continue if key == 'vars' and var_args[key] == '{}': continue # this was required for a test case if (isinstance(var_args[key], PosixPath) or isinstance(var_args[key], WindowsPath)): var_args[key] = str(var_args[key]) dict_args[key] = var_args[key] return dict_args def task_end_messages(self, results): print_run_end_messages(results)