import os
from collections import defaultdict
from typing import List, Dict, Any, Tuple, cast, Optional

import networkx as nx  # type: ignore
import sqlparse

from dbt import flags
from dbt.adapters.factory import get_adapter
from dbt.clients import jinja
from dbt.clients.system import make_directory
from dbt.context.providers import generate_runtime_model
from dbt.contracts.graph.manifest import Manifest, UniqueID
from dbt.contracts.graph.compiled import (
    COMPILED_TYPES,
    CompiledSchemaTestNode,
    GraphMemberNode,
    InjectedCTE,
    ManifestNode,
    NonSourceCompiledNode,
)
from dbt.contracts.graph.parsed import ParsedNode
from dbt.exceptions import (
    dependency_not_found,
    InternalException,
    RuntimeException,
)
from dbt.graph import Graph
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.node_types import NodeType
from dbt.utils import pluralize

import dbt.tracking

graph_file_name = 'graph.gpickle'


def _compiled_type_for(model: ParsedNode):
    if type(model) not in COMPILED_TYPES:
        raise InternalException(
            f'Asked to compile {type(model)} node, but it has no compiled form'
        )
    return COMPILED_TYPES[type(model)]


def print_compile_stats(stats):
    names = {
        NodeType.Model: 'model',
        NodeType.Test: 'test',
        NodeType.Snapshot: 'snapshot',
        NodeType.Analysis: 'analysis',
        NodeType.Macro: 'macro',
        NodeType.Operation: 'operation',
        NodeType.Seed: 'seed file',
        NodeType.Source: 'source',
        NodeType.Exposure: 'exposure',
    }

    results = {k: 0 for k in names.keys()}
    results.update(stats)

    # create tracking event for resource_counts
    if dbt.tracking.active_user is not None:
        resource_counts = {k.pluralize(): v for k, v in results.items()}
        dbt.tracking.track_resource_counts(resource_counts)

    stat_line = ", ".join([
        pluralize(ct, names.get(t)) for t, ct in results.items()
        if t in names
    ])

    logger.info("Found {}".format(stat_line))


def _node_enabled(node: ManifestNode):
    # Disabled models are already excluded from the manifest
    if node.resource_type == NodeType.Test and not node.config.enabled:
        return False
    else:
        return True


def _generate_stats(manifest: Manifest):
    stats: Dict[NodeType, int] = defaultdict(int)
    for node in manifest.nodes.values():
        if _node_enabled(node):
            stats[node.resource_type] += 1
    for source in manifest.sources.values():
        stats[source.resource_type] += 1
    for exposure in manifest.exposures.values():
        stats[exposure.resource_type] += 1
    for macro in manifest.macros.values():
        stats[macro.resource_type] += 1
    return stats


def _add_prepended_cte(prepended_ctes, new_cte):
    for cte in prepended_ctes:
        if cte.id == new_cte.id:
            cte.sql = new_cte.sql
            return
    prepended_ctes.append(new_cte)


def _extend_prepended_ctes(prepended_ctes, new_prepended_ctes):
    for new_cte in new_prepended_ctes:
        _add_prepended_cte(prepended_ctes, new_cte)


def _get_tests_for_node(
    manifest: Manifest, unique_id: UniqueID
) -> List[UniqueID]:
    """Get a list of tests that depend on the node with the
    provided unique id
    """
    tests = []
    if unique_id in manifest.child_map:
        for child_unique_id in manifest.child_map[unique_id]:
            if child_unique_id.startswith('test.'):
                tests.append(child_unique_id)
    return tests


class Linker:
    def __init__(self, data=None):
        if data is None:
            data = {}
        self.graph = nx.DiGraph(**data)

    def edges(self):
        return self.graph.edges()

    def nodes(self):
        return self.graph.nodes()

    def find_cycles(self):
        try:
            cycle = nx.find_cycle(self.graph)
        except nx.NetworkXNoCycle:
            return None
        else:
            # cycles is a List[Tuple[str, ...]]
            return " --> ".join(c[0] for c in cycle)

    def dependency(self, node1, node2):
        "indicate that node1 depends on node2"
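        # The edge points from the dependency (node2) to the dependent
        # (node1), so a topological sort of the graph yields an order in
        # which every node appears after the nodes it depends on.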
        self.graph.add_node(node1)
        self.graph.add_node(node2)
        self.graph.add_edge(node2, node1)

    def add_node(self, node):
        self.graph.add_node(node)

    def write_graph(self, outfile: str, manifest: Manifest):
        """Write the graph to a gpickle file. Before doing so, serialize and
        include all nodes in their corresponding graph entries.
        """
        out_graph = self.graph.copy()
        for node_id in self.graph:
            data = manifest.expect(node_id).to_dict(omit_none=True)
            out_graph.add_node(node_id, **data)
        nx.write_gpickle(out_graph, outfile)


class Compiler:
    def __init__(self, config):
        self.config = config

    def initialize(self):
        make_directory(self.config.target_path)
        make_directory(self.config.modules_path)

    # creates a ModelContext which is converted to
    # a dict for jinja rendering of SQL
    def _create_node_context(
        self,
        node: NonSourceCompiledNode,
        manifest: Manifest,
        extra_context: Dict[str, Any],
    ) -> Dict[str, Any]:
        context = generate_runtime_model(
            node, self.config, manifest
        )
        context.update(extra_context)
        if isinstance(node, CompiledSchemaTestNode):
            # for test nodes, add a special keyword args value to the context
            jinja.add_rendered_test_kwargs(context, node)

        return context

    def add_ephemeral_prefix(self, name: str):
        adapter = get_adapter(self.config)
        relation_cls = adapter.Relation
        return relation_cls.add_ephemeral_prefix(name)

    def _get_relation_name(self, node: ParsedNode):
        relation_name = None
        if node.is_relational and not node.is_ephemeral_model:
            adapter = get_adapter(self.config)
            relation_cls = adapter.Relation
            relation_name = str(relation_cls.create_from(self.config, node))
        return relation_name

    def _inject_ctes_into_sql(self, sql: str, ctes: List[InjectedCTE]) -> str:
        """
        `ctes` is a list of InjectedCTEs like:

            [
                InjectedCTE(
                    id="cte_id_1",
                    sql="__dbt__cte__ephemeral as (select * from table)",
                ),
                InjectedCTE(
                    id="cte_id_2",
                    sql="__dbt__cte__events as (select id, type from events)",
                ),
            ]

        Given `sql` like:

            "with internal_cte as (select * from sessions)
             select * from internal_cte"

        This will spit out:

            "with __dbt__cte__ephemeral as (select * from table),
                  __dbt__cte__events as (select id, type from events),
                  internal_cte as (select * from sessions)
             select * from internal_cte"

        (Whitespace enhanced for readability.)
        """
        if len(ctes) == 0:
            return sql

        parsed_stmts = sqlparse.parse(sql)
        parsed = parsed_stmts[0]

        with_stmt = None
        for token in parsed.tokens:
            if token.is_keyword and token.normalized == 'WITH':
                with_stmt = token
                break

        if with_stmt is None:
            # no with stmt, add one, and inject CTEs right at the beginning
            first_token = parsed.token_first()
            with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, 'with')
            parsed.insert_before(first_token, with_stmt)
        else:
            # stmt exists, add a comma (which will come after injected CTEs)
            trailing_comma = sqlparse.sql.Token(
                sqlparse.tokens.Punctuation, ','
            )
            parsed.insert_after(with_stmt, trailing_comma)

        token = sqlparse.sql.Token(
            sqlparse.tokens.Keyword,
            ", ".join(c.sql for c in ctes)
        )
        parsed.insert_after(with_stmt, token)

        return str(parsed)

    def _recursively_prepend_ctes(
        self,
        model: NonSourceCompiledNode,
        manifest: Manifest,
        extra_context: Optional[Dict[str, Any]],
    ) -> Tuple[NonSourceCompiledNode, List[InjectedCTE]]:
        """This method is called by the 'compile_node' method. Starting
        from the node that it is passed in, it will recursively call
        itself using the 'extra_ctes'. The 'ephemeral' models do not
        produce SQL that is executed directly; instead they are rolled
        up into the models that refer to them by inserting CTEs into
        the SQL.
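
        For example (illustrative only, not taken from a real project): a
        model that selects from an ephemeral model named 'events' keeps its
        original SQL in '_pre_injected_sql' and ends up with compiled SQL of
        the form "with __dbt__cte__events as (...) select ...".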
""" if model.compiled_sql is None: raise RuntimeException( 'Cannot inject ctes into an unparsed node', model ) if model.extra_ctes_injected: return (model, model.extra_ctes) # Just to make it plain that nothing is actually injected for this case if not model.extra_ctes: model.extra_ctes_injected = True manifest.update_node(model) return (model, model.extra_ctes) # This stores the ctes which will all be recursively # gathered and then "injected" into the model. prepended_ctes: List[InjectedCTE] = [] # extra_ctes are added to the model by # RuntimeRefResolver.create_relation, which adds an # extra_cte for every model relation which is an # ephemeral model. for cte in model.extra_ctes: if cte.id not in manifest.nodes: raise InternalException( f'During compilation, found a cte reference that ' f'could not be resolved: {cte.id}' ) cte_model = manifest.nodes[cte.id] if not cte_model.is_ephemeral_model: raise InternalException(f'{cte.id} is not ephemeral') # This model has already been compiled, so it's been # through here before if getattr(cte_model, 'compiled', False): assert isinstance(cte_model, tuple(COMPILED_TYPES.values())) cte_model = cast(NonSourceCompiledNode, cte_model) new_prepended_ctes = cte_model.extra_ctes # if the cte_model isn't compiled, i.e. first time here else: # This is an ephemeral parsed model that we can compile. # Compile and update the node cte_model = self._compile_node( cte_model, manifest, extra_context) # recursively call this method cte_model, new_prepended_ctes = \ self._recursively_prepend_ctes( cte_model, manifest, extra_context ) # Save compiled SQL file and sync manifest self._write_node(cte_model) manifest.sync_update_node(cte_model) _extend_prepended_ctes(prepended_ctes, new_prepended_ctes) new_cte_name = self.add_ephemeral_prefix(cte_model.name) rendered_sql = ( cte_model._pre_injected_sql or cte_model.compiled_sql ) sql = f' {new_cte_name} as (\n{rendered_sql}\n)' _add_prepended_cte(prepended_ctes, InjectedCTE(id=cte.id, sql=sql)) injected_sql = self._inject_ctes_into_sql( model.compiled_sql, prepended_ctes, ) model._pre_injected_sql = model.compiled_sql model.compiled_sql = injected_sql model.extra_ctes_injected = True model.extra_ctes = prepended_ctes model.validate(model.to_dict(omit_none=True)) manifest.update_node(model) return model, prepended_ctes # creates a compiled_node from the ManifestNode passed in, # creates a "context" dictionary for jinja rendering, # and then renders the "compiled_sql" using the node, the # raw_sql and the context. 
    def _compile_node(
        self,
        node: ManifestNode,
        manifest: Manifest,
        extra_context: Optional[Dict[str, Any]] = None,
    ) -> NonSourceCompiledNode:
        if extra_context is None:
            extra_context = {}

        logger.debug("Compiling {}".format(node.unique_id))

        data = node.to_dict(omit_none=True)
        data.update({
            'compiled': False,
            'compiled_sql': None,
            'extra_ctes_injected': False,
            'extra_ctes': [],
        })
        compiled_node = _compiled_type_for(node).from_dict(data)

        context = self._create_node_context(
            compiled_node, manifest, extra_context
        )

        compiled_node.compiled_sql = jinja.get_rendered(
            node.raw_sql,
            context,
            node,
        )

        compiled_node.relation_name = self._get_relation_name(node)

        compiled_node.compiled = True

        return compiled_node

    def write_graph_file(self, linker: Linker, manifest: Manifest):
        filename = graph_file_name
        graph_path = os.path.join(self.config.target_path, filename)
        if flags.WRITE_JSON:
            linker.write_graph(graph_path, manifest)

    def link_node(
        self, linker: Linker, node: GraphMemberNode, manifest: Manifest
    ):
        linker.add_node(node.unique_id)

        for dependency in node.depends_on_nodes:
            if dependency in manifest.nodes:
                linker.dependency(
                    node.unique_id,
                    (manifest.nodes[dependency].unique_id)
                )
            elif dependency in manifest.sources:
                linker.dependency(
                    node.unique_id,
                    (manifest.sources[dependency].unique_id)
                )
            else:
                dependency_not_found(node, dependency)

    def link_graph(
        self, linker: Linker, manifest: Manifest, add_test_edges: bool = False
    ):
        for source in manifest.sources.values():
            linker.add_node(source.unique_id)
        for node in manifest.nodes.values():
            self.link_node(linker, node, manifest)
        for exposure in manifest.exposures.values():
            self.link_node(linker, exposure, manifest)

        cycle = linker.find_cycles()

        if cycle:
            raise RuntimeError("Found a cycle: {}".format(cycle))

        if add_test_edges:
            manifest.build_parent_and_child_maps()
            self.add_test_edges(linker, manifest)

    def add_test_edges(self, linker: Linker, manifest: Manifest) -> None:
        """This method adds additional edges to the DAG. For a given
        non-test executable node, add an edge from an upstream test
        to the given node if the set of nodes the test depends on is
        a subset of the upstream nodes for the given node.
        """
        # Given a graph:
        # model1 --> model2 --> model3
        #   |             |
        #   |            \/
        #  \/          test2
        # test1
        #
        # Produce the following graph:
        # model1 --> model2 --> model3
        #   |       /\    |      /\ /\
        #   |       |    \/      |  |
        #  \/       |  test2 ----|  |
        # test1 ----|---------------|

        for node_id in linker.graph:
            # If node is executable (in manifest.nodes) and does _not_
            # represent a test, continue.
            if (
                node_id in manifest.nodes and
                manifest.nodes[node_id].resource_type != NodeType.Test
            ):
                # Get *everything* upstream of the node
                all_upstream_nodes = nx.traversal.bfs_tree(
                    linker.graph, node_id, reverse=True
                )
                # Get the set of upstream nodes not including the current node.
                upstream_nodes = set([
                    n for n in all_upstream_nodes if n != node_id
                ])

                # Get all tests that depend on any upstream nodes.
                upstream_tests = []
                for upstream_node in upstream_nodes:
                    upstream_tests += _get_tests_for_node(
                        manifest, upstream_node
                    )

                for upstream_test in upstream_tests:
                    # Get the set of all nodes that the test depends on
                    # including the upstream_node itself. This is necessary
                    # because tests can depend on multiple nodes (ex:
                    # relationship tests). Test nodes do not distinguish
                    # between what node the test is "testing" and what
                    # node(s) it depends on.
                    test_depends_on = set(
                        manifest.nodes[upstream_test].depends_on_nodes
                    )

                    # If the set of nodes that an upstream test depends on
                    # is a subset of all upstream nodes of the current node,
                    # add an edge from the upstream test to the current node.
                    if test_depends_on.issubset(upstream_nodes):
                        linker.graph.add_edge(upstream_test, node_id)

    def compile(
        self, manifest: Manifest, write=True, add_test_edges=False
    ) -> Graph:
        self.initialize()
        linker = Linker()

        self.link_graph(linker, manifest, add_test_edges)

        stats = _generate_stats(manifest)

        if write:
            self.write_graph_file(linker, manifest)
        print_compile_stats(stats)

        return Graph(linker.graph)

    # writes the "compiled_sql" into the target/compiled directory
    def _write_node(self, node: NonSourceCompiledNode) -> ManifestNode:
        if (not node.extra_ctes_injected or
                node.resource_type == NodeType.Snapshot):
            return node
        logger.debug(f'Writing injected SQL for node "{node.unique_id}"')

        if node.compiled_sql:
            node.compiled_path = node.write_node(
                self.config.target_path,
                'compiled',
                node.compiled_sql
            )
        return node

    def compile_node(
        self,
        node: ManifestNode,
        manifest: Manifest,
        extra_context: Optional[Dict[str, Any]] = None,
        write: bool = True,
    ) -> NonSourceCompiledNode:
        """This is the main entry point into this code. It's called by
        CompileRunner.compile, GenericRPCRunner.compile, and
        RunTask.get_hook_sql. It calls '_compile_node' to convert the
        node into a compiled node, and then calls the recursive method
        to "prepend" the ctes.
        """
        node = self._compile_node(node, manifest, extra_context)

        node, _ = self._recursively_prepend_ctes(
            node, manifest, extra_context
        )

        if write:
            self._write_node(node)

        return node
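

# Illustrative usage sketch (hypothetical driver code, not part of this
# module; assumes a loaded RuntimeConfig `config`, a parsed `manifest`, and a
# model with the unique id 'model.my_project.my_model'):
#
#     compiler = Compiler(config)
#     graph = compiler.compile(manifest, write=True, add_test_edges=False)
#     compiled = compiler.compile_node(
#         manifest.nodes['model.my_project.my_model'], manifest
#     )
#     print(compiled.compiled_sql)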