117 lines
4.1 KiB
Python
117 lines
4.1 KiB
Python
|
from typing import (
|
||
|
Set, Iterable, Iterator, Optional, NewType
|
||
|
)
|
||
|
from itertools import product
|
||
|
import networkx as nx # type: ignore
|
||
|
|
||
|
from dbt.exceptions import InternalException
|
||
|
|
||
|
UniqueId = NewType('UniqueId', str)
|
||
|
|
||
|
|
||
|
class Graph:
|
||
|
"""A wrapper around the networkx graph that understands SelectionCriteria
|
||
|
and how they interact with the graph.
|
||
|
"""
|
||
|
def __init__(self, graph):
|
||
|
self.graph = graph
|
||
|
|
||
|
def nodes(self) -> Set[UniqueId]:
|
||
|
return set(self.graph.nodes())
|
||
|
|
||
|
def edges(self):
|
||
|
return self.graph.edges()
|
||
|
|
||
|
def __iter__(self) -> Iterator[UniqueId]:
|
||
|
return iter(self.graph.nodes())
|
||
|
|
||
|
def ancestors(
|
||
|
self, node: UniqueId, max_depth: Optional[int] = None
|
||
|
) -> Set[UniqueId]:
|
||
|
"""Returns all nodes having a path to `node` in `graph`"""
|
||
|
if not self.graph.has_node(node):
|
||
|
raise InternalException(f'Node {node} not found in the graph!')
|
||
|
with nx.utils.reversed(self.graph):
|
||
|
anc = nx.single_source_shortest_path_length(G=self.graph,
|
||
|
source=node,
|
||
|
cutoff=max_depth)\
|
||
|
.keys()
|
||
|
return anc - {node}
|
||
|
|
||
|
def descendants(
|
||
|
self, node: UniqueId, max_depth: Optional[int] = None
|
||
|
) -> Set[UniqueId]:
|
||
|
"""Returns all nodes reachable from `node` in `graph`"""
|
||
|
if not self.graph.has_node(node):
|
||
|
raise InternalException(f'Node {node} not found in the graph!')
|
||
|
des = nx.single_source_shortest_path_length(G=self.graph,
|
||
|
source=node,
|
||
|
cutoff=max_depth)\
|
||
|
.keys()
|
||
|
return des - {node}
|
||
|
|
||
|
def select_childrens_parents(
|
||
|
self, selected: Set[UniqueId]
|
||
|
) -> Set[UniqueId]:
|
||
|
ancestors_for = self.select_children(selected) | selected
|
||
|
return self.select_parents(ancestors_for) | ancestors_for
|
||
|
|
||
|
def select_children(
|
||
|
self, selected: Set[UniqueId], max_depth: Optional[int] = None
|
||
|
) -> Set[UniqueId]:
|
||
|
descendants: Set[UniqueId] = set()
|
||
|
for node in selected:
|
||
|
descendants.update(self.descendants(node, max_depth))
|
||
|
return descendants
|
||
|
|
||
|
def select_parents(
|
||
|
self, selected: Set[UniqueId], max_depth: Optional[int] = None
|
||
|
) -> Set[UniqueId]:
|
||
|
ancestors: Set[UniqueId] = set()
|
||
|
for node in selected:
|
||
|
ancestors.update(self.ancestors(node, max_depth))
|
||
|
return ancestors
|
||
|
|
||
|
def select_successors(self, selected: Set[UniqueId]) -> Set[UniqueId]:
|
||
|
successors: Set[UniqueId] = set()
|
||
|
for node in selected:
|
||
|
successors.update(self.graph.successors(node))
|
||
|
return successors
|
||
|
|
||
|
def get_subset_graph(self, selected: Iterable[UniqueId]) -> "Graph":
|
||
|
"""Create and return a new graph that is a shallow copy of the graph,
|
||
|
but with only the nodes in include_nodes. Transitive edges across
|
||
|
removed nodes are preserved as explicit new edges.
|
||
|
"""
|
||
|
|
||
|
new_graph = self.graph.copy()
|
||
|
include_nodes = set(selected)
|
||
|
|
||
|
for node in self:
|
||
|
if node not in include_nodes:
|
||
|
source_nodes = [x for x, _ in new_graph.in_edges(node)]
|
||
|
target_nodes = [x for _, x in new_graph.out_edges(node)]
|
||
|
|
||
|
new_edges = product(source_nodes, target_nodes)
|
||
|
non_cyclic_new_edges = [
|
||
|
(source, target) for source, target in new_edges if source != target
|
||
|
] # removes cyclic refs
|
||
|
|
||
|
new_graph.add_edges_from(non_cyclic_new_edges)
|
||
|
new_graph.remove_node(node)
|
||
|
|
||
|
for node in include_nodes:
|
||
|
if node not in new_graph:
|
||
|
raise ValueError(
|
||
|
"Couldn't find model '{}' -- does it exist or is "
|
||
|
"it disabled?".format(node)
|
||
|
)
|
||
|
|
||
|
return Graph(new_graph)
|
||
|
|
||
|
def subgraph(self, nodes: Iterable[UniqueId]) -> 'Graph':
|
||
|
return Graph(self.graph.subgraph(nodes))
|
||
|
|
||
|
def get_dependent_nodes(self, node: UniqueId):
|
||
|
return nx.descendants(self.graph, node)
|