dbt-selly/dbt-env/lib/python3.8/site-packages/dbt/graph/graph.py

117 lines
4.1 KiB
Python
Raw Normal View History

2022-03-22 15:13:27 +00:00
from typing import (
Set, Iterable, Iterator, Optional, NewType
)
from itertools import product
import networkx as nx # type: ignore
from dbt.exceptions import InternalException
# Identifier type for graph nodes. NewType gives type checkers a distinct name
# for node ids; at runtime the values are ordinary `str` objects.
UniqueId = NewType('UniqueId', str)
class Graph:
    """A wrapper around the networkx graph that understands SelectionCriteria
    and how they interact with the graph.
    """

    def __init__(self, graph) -> None:
        # NOTE(review): callers appear to pass a networkx directed graph (or a
        # view of one) whose nodes are UniqueIds -- confirm against callers.
        self.graph = graph

    def nodes(self) -> Set[UniqueId]:
        """Return a new set holding every node id in the graph."""
        return set(self.graph.nodes())

    def edges(self):
        """Return the underlying graph's edges, as given by `graph.edges()`."""
        return self.graph.edges()

    def __iter__(self) -> Iterator[UniqueId]:
        """Iterate over the node ids of the underlying graph."""
        return iter(self.graph.nodes())

    def ancestors(
        self, node: UniqueId, max_depth: Optional[int] = None
    ) -> Set[UniqueId]:
        """Returns all nodes having a path to `node` in `graph`.

        `max_depth`, when not None, limits the result to nodes at most that
        many edges upstream of `node`. `node` itself is never included.
        Raises InternalException if `node` is not in the graph.
        """
        if not self.graph.has_node(node):
            raise InternalException(f'Node {node} not found in the graph!')
        # Search a read-only reversed *view* instead of using the
        # nx.utils.reversed() context manager. That helper swaps the graph's
        # internal predecessor/successor dicts in place while the search runs,
        # so anyone else holding this graph sees it reversed. It was deprecated
        # in networkx 2.6 and removed in 3.0. It was a no-op for undirected
        # graphs, so those are searched as-is.
        if self.graph.is_directed():
            upstream = self.graph.reverse(copy=False)
        else:
            upstream = self.graph
        anc = nx.single_source_shortest_path_length(
            G=upstream, source=node, cutoff=max_depth
        ).keys()
        return anc - {node}

    def descendants(
        self, node: UniqueId, max_depth: Optional[int] = None
    ) -> Set[UniqueId]:
        """Returns all nodes reachable from `node` in `graph`.

        `max_depth`, when not None, limits the result to nodes at most that
        many edges downstream of `node`. `node` itself is never included.
        Raises InternalException if `node` is not in the graph.
        """
        if not self.graph.has_node(node):
            raise InternalException(f'Node {node} not found in the graph!')
        des = nx.single_source_shortest_path_length(
            G=self.graph, source=node, cutoff=max_depth
        ).keys()
        return des - {node}

    def select_childrens_parents(
        self, selected: Set[UniqueId]
    ) -> Set[UniqueId]:
        """Return `selected`, all of their descendants, and every ancestor of
        those nodes (no depth limit in either direction).
        """
        ancestors_for = self.select_children(selected) | selected
        return self.select_parents(ancestors_for) | ancestors_for

    def select_children(
        self, selected: Set[UniqueId], max_depth: Optional[int] = None
    ) -> Set[UniqueId]:
        """Return the union of `descendants(node, max_depth)` over `selected`.

        Raises InternalException if any selected node is not in the graph.
        """
        descendants: Set[UniqueId] = set()
        for node in selected:
            descendants.update(self.descendants(node, max_depth))
        return descendants

    def select_parents(
        self, selected: Set[UniqueId], max_depth: Optional[int] = None
    ) -> Set[UniqueId]:
        """Return the union of `ancestors(node, max_depth)` over `selected`.

        Raises InternalException if any selected node is not in the graph.
        """
        ancestors: Set[UniqueId] = set()
        for node in selected:
            ancestors.update(self.ancestors(node, max_depth))
        return ancestors

    def select_successors(self, selected: Set[UniqueId]) -> Set[UniqueId]:
        """Return the direct (one-edge) successors of every selected node.

        Unknown nodes are not pre-checked; whatever `graph.successors` raises
        propagates to the caller.
        """
        successors: Set[UniqueId] = set()
        for node in selected:
            successors.update(self.graph.successors(node))
        return successors

    def get_subset_graph(self, selected: Iterable[UniqueId]) -> "Graph":
        """Create and return a new graph that is a shallow copy of the graph,
        but with only the nodes in `selected`. Transitive edges across
        removed nodes are preserved as explicit new edges.

        Raises ValueError if any selected node is not in the graph.
        """
        include_nodes = set(selected)
        # Fail fast, before paying for the copy and the edge rewiring below.
        # Selected nodes are never removed from the copy, so this is the same
        # check as looking for them in the finished subset graph.
        for node in include_nodes:
            if node not in self.graph:
                raise ValueError(
                    "Couldn't find model '{}' -- does it exist or is "
                    "it disabled?".format(node)
                )
        new_graph = self.graph.copy()
        for node in self:
            if node in include_nodes:
                continue
            # Connect every current predecessor of the dropped node to every
            # current successor so that paths through it survive its removal.
            source_nodes = [x for x, _ in new_graph.in_edges(node)]
            target_nodes = [x for _, x in new_graph.out_edges(node)]
            new_edges = [
                (source, target)
                for source, target in product(source_nodes, target_nodes)
                if source != target  # never introduce self-loops
            ]
            new_graph.add_edges_from(new_edges)
            new_graph.remove_node(node)
        return Graph(new_graph)

    def subgraph(self, nodes: Iterable[UniqueId]) -> 'Graph':
        """Wrap the underlying graph's `subgraph(nodes)` in a new Graph.

        Unlike get_subset_graph, no bridging edges are added across the
        excluded nodes.
        """
        return Graph(self.graph.subgraph(nodes))

    def get_dependent_nodes(self, node: UniqueId):
        """Return every node reachable from `node`, via `nx.descendants`.

        Unlike `descendants`, this does not check for `node` first, so a
        missing node surfaces whatever `nx.descendants` raises rather than
        InternalException.
        """
        return nx.descendants(self.graph, node)