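"""The `docs generate` task: build dbt's catalog artifact.

Queries the adapter for catalog information about the relations in the
manifest and writes the result to catalog.json in the target directory,
alongside the static documentation site files.
"""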
import os
import shutil
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple, Set

from dbt.dataclass_schema import ValidationError

from .compile import CompileTask

from dbt.adapters.factory import get_adapter
from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.results import (
    NodeStatus, TableMetadata, CatalogTable, CatalogResults, PrimitiveDict,
    CatalogKey, StatsItem, StatsDict, ColumnMetadata, CatalogArtifact
)
from dbt.exceptions import InternalException
from dbt.include.global_project import DOCS_INDEX_FILE_PATH
from dbt.logger import GLOBAL_LOGGER as logger, print_timestamped_line
from dbt.parser.manifest import ManifestLoader
import dbt.utils
import dbt.compilation
import dbt.exceptions


CATALOG_FILENAME = 'catalog.json'


def get_stripped_prefix(source: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    """Go through source, extracting every key/value pair where the key starts
    with the given prefix.
    """
    cut = len(prefix)
    return {
        k[cut:]: v for k, v in source.items()
        if k.startswith(prefix)
    }
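
# A minimal illustration of get_stripped_prefix, with hypothetical row values:
# stripping the 'table_' prefix from a raw catalog row keeps only the matching
# keys and drops the prefix.
#
#   get_stripped_prefix({'table_name': 'users', 'column_name': 'id'}, 'table_')
#   # -> {'name': 'users'}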


def build_catalog_table(data) -> CatalogTable:
    # build the new table's metadata + stats
    metadata = TableMetadata.from_dict(get_stripped_prefix(data, 'table_'))
    stats = format_stats(get_stripped_prefix(data, 'stats:'))
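
    # columns start out empty here; Catalog.add_column() populates them from
    # the raw catalog rows.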
    return CatalogTable(
        metadata=metadata,
        stats=stats,
        columns={},
    )


# keys are database name, schema name, table name
class Catalog(Dict[CatalogKey, CatalogTable]):
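    """An in-memory catalog mapping CatalogKeys (database, schema, name) to
    CatalogTable entries, built up from the raw catalog rows returned by the
    adapter.
    """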

    def __init__(self, columns: List[PrimitiveDict]):
        super().__init__()
        for col in columns:
            self.add_column(col)

    def get_table(self, data: PrimitiveDict) -> CatalogTable:
        database = data.get('table_database')
        if database is None:
            dkey: Optional[str] = None
        else:
            dkey = str(database)

        try:
            key = CatalogKey(
                dkey,
                str(data['table_schema']),
                str(data['table_name']),
            )
        except KeyError as exc:
            raise dbt.exceptions.CompilationException(
                'Catalog information missing required key {} (got {})'
                .format(exc, data)
            )
        table: CatalogTable
        if key in self:
            table = self[key]
        else:
            table = build_catalog_table(data)
            self[key] = table
        return table

    def add_column(self, data: PrimitiveDict):
        table = self.get_table(data)
        column_data = get_stripped_prefix(data, 'column_')
        # the index should really never be that big so it's ok to end up
        # serializing this to JSON (2^53 is the max safe value there)
        column_data['index'] = int(column_data['index'])

        column = ColumnMetadata.from_dict(column_data)
        table.columns[column.name] = column

    def make_unique_id_map(
        self, manifest: Manifest
    ) -> Tuple[Dict[str, CatalogTable], Dict[str, CatalogTable]]:
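        """Split this catalog into node and source lookups, each keyed by the
        manifest unique_id of the match. A table can match at most one node,
        but any number of sources.
        """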
        nodes: Dict[str, CatalogTable] = {}
        sources: Dict[str, CatalogTable] = {}

        node_map, source_map = get_unique_id_mapping(manifest)
        table: CatalogTable
        for table in self.values():
            key = table.key()
            if key in node_map:
                unique_id = node_map[key]
                nodes[unique_id] = table.replace(unique_id=unique_id)

            unique_ids = source_map.get(table.key(), set())
            for unique_id in unique_ids:
                if unique_id in sources:
                    dbt.exceptions.raise_ambiguous_catalog_match(
                        unique_id,
                        sources[unique_id].to_dict(omit_none=True),
                        table.to_dict(omit_none=True),
                    )
                else:
                    sources[unique_id] = table.replace(unique_id=unique_id)
        return nodes, sources


def format_stats(stats: PrimitiveDict) -> StatsDict:
    """Given a dictionary following this layout:

        {
            'encoded:label': 'Encoded',
            'encoded:value': 'Yes',
            'encoded:description': 'Indicates if the column is encoded',
            'encoded:include': True,

            'size:label': 'Size',
            'size:value': 128,
            'size:description': 'Size of the table in MB',
            'size:include': True,
        }

    format_stats will convert the dict into a StatsDict with keys of 'encoded'
    and 'size'.
    """
    stats_collector: StatsDict = {}

    base_keys = {k.split(':')[0] for k in stats}
    for key in base_keys:
        dct: PrimitiveDict = {'id': key}
        for subkey in ('label', 'value', 'description', 'include'):
            dct[subkey] = stats['{}:{}'.format(key, subkey)]

        try:
            stats_item = StatsItem.from_dict(dct)
        except ValidationError:
            continue
        if stats_item.include:
            stats_collector[key] = stats_item

    # we always have a 'has_stats' field, it's never included
    has_stats = StatsItem(
        id='has_stats',
        label='Has Stats?',
        value=len(stats_collector) > 0,
        description='Indicates whether there are statistics for this table',
        include=False,
    )
    stats_collector['has_stats'] = has_stats
    return stats_collector
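
# A minimal sketch of the round trip, using the illustrative values from the
# format_stats docstring above: passing that dict through format_stats()
# yields a StatsDict whose keys are {'encoded', 'size', 'has_stats'}; the
# synthetic 'has_stats' entry reports value=True because stats were present.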


def mapping_key(node: CompileResultNode) -> CatalogKey:
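    """Build a lowercased CatalogKey for a node so that catalog rows can be
    matched case-insensitively; the database component may be None.
    """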
    dkey = dbt.utils.lowercase(node.database)
    return CatalogKey(
        dkey, node.schema.lower(), node.identifier.lower()
    )


def get_unique_id_mapping(
    manifest: Manifest
) -> Tuple[Dict[CatalogKey, str], Dict[CatalogKey, Set[str]]]:
    # A single relation could have multiple unique IDs pointing to it if a
    # source were also a node.
    node_map: Dict[CatalogKey, str] = {}
    source_map: Dict[CatalogKey, Set[str]] = {}
    for unique_id, node in manifest.nodes.items():
        key = mapping_key(node)
        node_map[key] = unique_id

    for unique_id, source in manifest.sources.items():
        key = mapping_key(source)
        if key not in source_map:
            source_map[key] = set()
        source_map[key].add(unique_id)
    return node_map, source_map


class GenerateTask(CompileTask):
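    """The task behind `dbt docs generate`: compile the project (unless
    invoked with --no-compile), copy the docs site index and assets into the
    target directory, fetch catalog data from the adapter, and write it out
    as catalog.json.
    """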

    def _get_manifest(self) -> Manifest:
        if self.manifest is None:
            raise InternalException(
                'manifest should not be None in _get_manifest'
            )
        return self.manifest

    def run(self) -> CatalogArtifact:
        compile_results = None
        if self.args.compile:
            compile_results = CompileTask.run(self)
            if any(r.status == NodeStatus.Error for r in compile_results):
                print_timestamped_line(
                    'compile failed, cannot generate docs'
                )
                return CatalogArtifact.from_results(
                    nodes={},
                    sources={},
                    generated_at=datetime.utcnow(),
                    errors=None,
                    compile_results=compile_results
                )
        else:
            self.manifest = ManifestLoader.get_full_manifest(self.config)
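
        # write the static docs site entry point next to the catalog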
        shutil.copyfile(
            DOCS_INDEX_FILE_PATH,
            os.path.join(self.config.target_path, 'index.html'))
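
        # copy each configured asset path into the target directory, replacing
        # anything already there, so the docs site can serve those files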
        for asset_path in self.config.asset_paths:
            to_asset_path = os.path.join(self.config.target_path, asset_path)

            if os.path.exists(to_asset_path):
                shutil.rmtree(to_asset_path)

            if os.path.exists(asset_path):
                shutil.copytree(
                    asset_path,
                    to_asset_path)

        if self.manifest is None:
            raise InternalException(
                'self.manifest was None in run!'
            )
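
        # query the warehouse for catalog information about every relation in
        # the manifest; exceptions are collected rather than raised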
        adapter = get_adapter(self.config)
        with adapter.connection_named('generate_catalog'):
            print_timestamped_line("Building catalog")
            catalog_table, exceptions = adapter.get_catalog(self.manifest)
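
        # convert each result row to a plain dict, coercing Decimal values so
        # they serialize cleanly to JSON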
        catalog_data: List[PrimitiveDict] = [
            dict(zip(catalog_table.column_names,
                     map(dbt.utils._coerce_decimal, row)))
            for row in catalog_table
        ]

        catalog = Catalog(catalog_data)

        errors: Optional[List[str]] = None
        if exceptions:
            errors = [str(e) for e in exceptions]

        nodes, sources = catalog.make_unique_id_map(self.manifest)
        results = self.get_catalog_results(
            nodes=nodes,
            sources=sources,
            generated_at=datetime.utcnow(),
            compile_results=compile_results,
            errors=errors,
        )

        path = os.path.join(self.config.target_path, CATALOG_FILENAME)
        results.write(path)
        if self.args.compile:
            self.write_manifest()

        if exceptions:
            logger.error(
                'dbt encountered {} failure{} while writing the catalog'
                .format(len(exceptions), (len(exceptions) != 1) * 's')
            )

        print_timestamped_line(
            'Catalog written to {}'.format(os.path.abspath(path))
        )

        return results

    def get_catalog_results(
        self,
        nodes: Dict[str, CatalogTable],
        sources: Dict[str, CatalogTable],
        generated_at: datetime,
        compile_results: Optional[Any],
        errors: Optional[List[str]]
    ) -> CatalogArtifact:
        return CatalogArtifact.from_results(
            generated_at=generated_at,
            nodes=nodes,
            sources=sources,
            compile_results=compile_results,
            errors=errors,
        )

    def interpret_results(self, results: Optional[CatalogResults]) -> bool:
        if results is None:
            return False
        if results.errors:
            return False
        compile_results = results._compile_results
        if compile_results is None:
            return True

        return super().interpret_results(compile_results)