457 lines
14 KiB
Python
457 lines
14 KiB
Python
|
from collections.abc import Hashable
|
||
|
from dataclasses import dataclass
|
||
|
from typing import (
|
||
|
Optional, TypeVar, Any, Type, Dict, Union, Iterator, Tuple, Set
|
||
|
)
|
||
|
|
||
|
from dbt.contracts.graph.compiled import CompiledNode
|
||
|
from dbt.contracts.graph.parsed import ParsedSourceDefinition, ParsedNode
|
||
|
from dbt.contracts.relation import (
|
||
|
RelationType, ComponentName, HasQuoting, FakeAPIObject, Policy, Path
|
||
|
)
|
||
|
from dbt.exceptions import InternalException
|
||
|
from dbt.node_types import NodeType
|
||
|
from dbt.utils import filter_null_values, deep_merge, classproperty
|
||
|
|
||
|
import dbt.exceptions
|
||
|
|
||
|
|
||
|
Self = TypeVar('Self', bound='BaseRelation')
|
||
|
|
||
|
|
||
|
@dataclass(frozen=True, eq=False, repr=False)
|
||
|
class BaseRelation(FakeAPIObject, Hashable):
|
||
|
path: Path
|
||
|
type: Optional[RelationType] = None
|
||
|
quote_character: str = '"'
|
||
|
include_policy: Policy = Policy()
|
||
|
quote_policy: Policy = Policy()
|
||
|
dbt_created: bool = False
|
||
|
|
||
|
def _is_exactish_match(self, field: ComponentName, value: str) -> bool:
|
||
|
if self.dbt_created and self.quote_policy.get_part(field) is False:
|
||
|
return self.path.get_lowered_part(field) == value.lower()
|
||
|
else:
|
||
|
return self.path.get_part(field) == value
|
||
|
|
||
|
@classmethod
|
||
|
def _get_field_named(cls, field_name):
|
||
|
for field, _ in cls._get_fields():
|
||
|
if field.name == field_name:
|
||
|
return field
|
||
|
# this should be unreachable
|
||
|
raise ValueError(f'BaseRelation has no {field_name} field!')
|
||
|
|
||
|
def __eq__(self, other):
|
||
|
if not isinstance(other, self.__class__):
|
||
|
return False
|
||
|
return self.to_dict(omit_none=True) == other.to_dict(omit_none=True)
|
||
|
|
||
|
@classmethod
|
||
|
def get_default_quote_policy(cls) -> Policy:
|
||
|
return cls._get_field_named('quote_policy').default
|
||
|
|
||
|
@classmethod
|
||
|
def get_default_include_policy(cls) -> Policy:
|
||
|
return cls._get_field_named('include_policy').default
|
||
|
|
||
|
def get(self, key, default=None):
|
||
|
"""Override `.get` to return a metadata object so we don't break
|
||
|
dbt_utils.
|
||
|
"""
|
||
|
if key == 'metadata':
|
||
|
return {
|
||
|
'type': self.__class__.__name__
|
||
|
}
|
||
|
return super().get(key, default)
|
||
|
|
||
|
def matches(
|
||
|
self,
|
||
|
database: Optional[str] = None,
|
||
|
schema: Optional[str] = None,
|
||
|
identifier: Optional[str] = None,
|
||
|
) -> bool:
|
||
|
search = filter_null_values({
|
||
|
ComponentName.Database: database,
|
||
|
ComponentName.Schema: schema,
|
||
|
ComponentName.Identifier: identifier
|
||
|
})
|
||
|
|
||
|
if not search:
|
||
|
# nothing was passed in
|
||
|
raise dbt.exceptions.RuntimeException(
|
||
|
"Tried to match relation, but no search path was passed!")
|
||
|
|
||
|
exact_match = True
|
||
|
approximate_match = True
|
||
|
|
||
|
for k, v in search.items():
|
||
|
if not self._is_exactish_match(k, v):
|
||
|
exact_match = False
|
||
|
|
||
|
if self.path.get_lowered_part(k) != v.lower():
|
||
|
approximate_match = False
|
||
|
|
||
|
if approximate_match and not exact_match:
|
||
|
target = self.create(
|
||
|
database=database, schema=schema, identifier=identifier
|
||
|
)
|
||
|
dbt.exceptions.approximate_relation_match(target, self)
|
||
|
|
||
|
return exact_match
|
||
|
|
||
|
def replace_path(self, **kwargs):
|
||
|
return self.replace(path=self.path.replace(**kwargs))
|
||
|
|
||
|
def quote(
|
||
|
self: Self,
|
||
|
database: Optional[bool] = None,
|
||
|
schema: Optional[bool] = None,
|
||
|
identifier: Optional[bool] = None,
|
||
|
) -> Self:
|
||
|
policy = filter_null_values({
|
||
|
ComponentName.Database: database,
|
||
|
ComponentName.Schema: schema,
|
||
|
ComponentName.Identifier: identifier
|
||
|
})
|
||
|
|
||
|
new_quote_policy = self.quote_policy.replace_dict(policy)
|
||
|
return self.replace(quote_policy=new_quote_policy)
|
||
|
|
||
|
def include(
|
||
|
self: Self,
|
||
|
database: Optional[bool] = None,
|
||
|
schema: Optional[bool] = None,
|
||
|
identifier: Optional[bool] = None,
|
||
|
) -> Self:
|
||
|
policy = filter_null_values({
|
||
|
ComponentName.Database: database,
|
||
|
ComponentName.Schema: schema,
|
||
|
ComponentName.Identifier: identifier
|
||
|
})
|
||
|
|
||
|
new_include_policy = self.include_policy.replace_dict(policy)
|
||
|
return self.replace(include_policy=new_include_policy)
|
||
|
|
||
|
def information_schema(self, view_name=None) -> 'InformationSchema':
|
||
|
# some of our data comes from jinja, where things can be `Undefined`.
|
||
|
if not isinstance(view_name, str):
|
||
|
view_name = None
|
||
|
|
||
|
# Kick the user-supplied schema out of the information schema relation
|
||
|
# Instead address this as <database>.information_schema by default
|
||
|
info_schema = InformationSchema.from_relation(self, view_name)
|
||
|
return info_schema.incorporate(path={"schema": None})
|
||
|
|
||
|
def information_schema_only(self) -> 'InformationSchema':
|
||
|
return self.information_schema()
|
||
|
|
||
|
def without_identifier(self) -> 'BaseRelation':
|
||
|
"""Return a form of this relation that only has the database and schema
|
||
|
set to included. To get the appropriately-quoted form the schema out of
|
||
|
the result (for use as part of a query), use `.render()`. To get the
|
||
|
raw database or schema name, use `.database` or `.schema`.
|
||
|
|
||
|
The hash of the returned object is the result of render().
|
||
|
"""
|
||
|
return self.include(identifier=False).replace_path(identifier=None)
|
||
|
|
||
|
def _render_iterator(
|
||
|
self
|
||
|
) -> Iterator[Tuple[Optional[ComponentName], Optional[str]]]:
|
||
|
|
||
|
for key in ComponentName:
|
||
|
path_part: Optional[str] = None
|
||
|
if self.include_policy.get_part(key):
|
||
|
path_part = self.path.get_part(key)
|
||
|
if path_part is not None and self.quote_policy.get_part(key):
|
||
|
path_part = self.quoted(path_part)
|
||
|
yield key, path_part
|
||
|
|
||
|
def render(self) -> str:
|
||
|
# if there is nothing set, this will return the empty string.
|
||
|
return '.'.join(
|
||
|
part for _, part in self._render_iterator()
|
||
|
if part is not None
|
||
|
)
|
||
|
|
||
|
def quoted(self, identifier):
|
||
|
return '{quote_char}{identifier}{quote_char}'.format(
|
||
|
quote_char=self.quote_character,
|
||
|
identifier=identifier,
|
||
|
)
|
||
|
|
||
|
@classmethod
|
||
|
def create_from_source(
|
||
|
cls: Type[Self], source: ParsedSourceDefinition, **kwargs: Any
|
||
|
) -> Self:
|
||
|
source_quoting = source.quoting.to_dict(omit_none=True)
|
||
|
source_quoting.pop('column', None)
|
||
|
quote_policy = deep_merge(
|
||
|
cls.get_default_quote_policy().to_dict(omit_none=True),
|
||
|
source_quoting,
|
||
|
kwargs.get('quote_policy', {}),
|
||
|
)
|
||
|
|
||
|
return cls.create(
|
||
|
database=source.database,
|
||
|
schema=source.schema,
|
||
|
identifier=source.identifier,
|
||
|
quote_policy=quote_policy,
|
||
|
**kwargs
|
||
|
)
|
||
|
|
||
|
@staticmethod
|
||
|
def add_ephemeral_prefix(name: str):
|
||
|
return f'__dbt__cte__{name}'
|
||
|
|
||
|
@classmethod
|
||
|
def create_ephemeral_from_node(
|
||
|
cls: Type[Self],
|
||
|
config: HasQuoting,
|
||
|
node: Union[ParsedNode, CompiledNode],
|
||
|
) -> Self:
|
||
|
# Note that ephemeral models are based on the name.
|
||
|
identifier = cls.add_ephemeral_prefix(node.name)
|
||
|
return cls.create(
|
||
|
type=cls.CTE,
|
||
|
identifier=identifier,
|
||
|
).quote(identifier=False)
|
||
|
|
||
|
@classmethod
|
||
|
def create_from_node(
|
||
|
cls: Type[Self],
|
||
|
config: HasQuoting,
|
||
|
node: Union[ParsedNode, CompiledNode],
|
||
|
quote_policy: Optional[Dict[str, bool]] = None,
|
||
|
**kwargs: Any,
|
||
|
) -> Self:
|
||
|
if quote_policy is None:
|
||
|
quote_policy = {}
|
||
|
|
||
|
quote_policy = dbt.utils.merge(config.quoting, quote_policy)
|
||
|
|
||
|
return cls.create(
|
||
|
database=node.database,
|
||
|
schema=node.schema,
|
||
|
identifier=node.alias,
|
||
|
quote_policy=quote_policy,
|
||
|
**kwargs)
|
||
|
|
||
|
@classmethod
|
||
|
def create_from(
|
||
|
cls: Type[Self],
|
||
|
config: HasQuoting,
|
||
|
node: Union[CompiledNode, ParsedNode, ParsedSourceDefinition],
|
||
|
**kwargs: Any,
|
||
|
) -> Self:
|
||
|
if node.resource_type == NodeType.Source:
|
||
|
if not isinstance(node, ParsedSourceDefinition):
|
||
|
raise InternalException(
|
||
|
'type mismatch, expected ParsedSourceDefinition but got {}'
|
||
|
.format(type(node))
|
||
|
)
|
||
|
return cls.create_from_source(node, **kwargs)
|
||
|
else:
|
||
|
if not isinstance(node, (ParsedNode, CompiledNode)):
|
||
|
raise InternalException(
|
||
|
'type mismatch, expected ParsedNode or CompiledNode but '
|
||
|
'got {}'.format(type(node))
|
||
|
)
|
||
|
return cls.create_from_node(config, node, **kwargs)
|
||
|
|
||
|
@classmethod
|
||
|
def create(
|
||
|
cls: Type[Self],
|
||
|
database: Optional[str] = None,
|
||
|
schema: Optional[str] = None,
|
||
|
identifier: Optional[str] = None,
|
||
|
type: Optional[RelationType] = None,
|
||
|
**kwargs,
|
||
|
) -> Self:
|
||
|
kwargs.update({
|
||
|
'path': {
|
||
|
'database': database,
|
||
|
'schema': schema,
|
||
|
'identifier': identifier,
|
||
|
},
|
||
|
'type': type,
|
||
|
})
|
||
|
return cls.from_dict(kwargs)
|
||
|
|
||
|
def __repr__(self) -> str:
|
||
|
return "<{} {}>".format(self.__class__.__name__, self.render())
|
||
|
|
||
|
def __hash__(self) -> int:
|
||
|
return hash(self.render())
|
||
|
|
||
|
def __str__(self) -> str:
|
||
|
return self.render()
|
||
|
|
||
|
@property
|
||
|
def database(self) -> Optional[str]:
|
||
|
return self.path.database
|
||
|
|
||
|
@property
|
||
|
def schema(self) -> Optional[str]:
|
||
|
return self.path.schema
|
||
|
|
||
|
@property
|
||
|
def identifier(self) -> Optional[str]:
|
||
|
return self.path.identifier
|
||
|
|
||
|
@property
|
||
|
def table(self) -> Optional[str]:
|
||
|
return self.path.identifier
|
||
|
|
||
|
# Here for compatibility with old Relation interface
|
||
|
@property
|
||
|
def name(self) -> Optional[str]:
|
||
|
return self.identifier
|
||
|
|
||
|
@property
|
||
|
def is_table(self) -> bool:
|
||
|
return self.type == RelationType.Table
|
||
|
|
||
|
@property
|
||
|
def is_cte(self) -> bool:
|
||
|
return self.type == RelationType.CTE
|
||
|
|
||
|
@property
|
||
|
def is_view(self) -> bool:
|
||
|
return self.type == RelationType.View
|
||
|
|
||
|
@classproperty
|
||
|
def Table(cls) -> str:
|
||
|
return str(RelationType.Table)
|
||
|
|
||
|
@classproperty
|
||
|
def CTE(cls) -> str:
|
||
|
return str(RelationType.CTE)
|
||
|
|
||
|
@classproperty
|
||
|
def View(cls) -> str:
|
||
|
return str(RelationType.View)
|
||
|
|
||
|
@classproperty
|
||
|
def External(cls) -> str:
|
||
|
return str(RelationType.External)
|
||
|
|
||
|
@classproperty
|
||
|
def get_relation_type(cls) -> Type[RelationType]:
|
||
|
return RelationType
|
||
|
|
||
|
|
||
|
Info = TypeVar('Info', bound='InformationSchema')
|
||
|
|
||
|
|
||
|
@dataclass(frozen=True, eq=False, repr=False)
|
||
|
class InformationSchema(BaseRelation):
|
||
|
information_schema_view: Optional[str] = None
|
||
|
|
||
|
def __post_init__(self):
|
||
|
if not isinstance(self.information_schema_view, (type(None), str)):
|
||
|
raise dbt.exceptions.CompilationException(
|
||
|
'Got an invalid name: {}'.format(self.information_schema_view)
|
||
|
)
|
||
|
|
||
|
@classmethod
|
||
|
def get_path(
|
||
|
cls, relation: BaseRelation, information_schema_view: Optional[str]
|
||
|
) -> Path:
|
||
|
return Path(
|
||
|
database=relation.database,
|
||
|
schema=relation.schema,
|
||
|
identifier='INFORMATION_SCHEMA',
|
||
|
)
|
||
|
|
||
|
@classmethod
|
||
|
def get_include_policy(
|
||
|
cls,
|
||
|
relation,
|
||
|
information_schema_view: Optional[str],
|
||
|
) -> Policy:
|
||
|
return relation.include_policy.replace(
|
||
|
database=relation.database is not None,
|
||
|
schema=False,
|
||
|
identifier=True,
|
||
|
)
|
||
|
|
||
|
@classmethod
|
||
|
def get_quote_policy(
|
||
|
cls,
|
||
|
relation,
|
||
|
information_schema_view: Optional[str],
|
||
|
) -> Policy:
|
||
|
return relation.quote_policy.replace(
|
||
|
identifier=False,
|
||
|
)
|
||
|
|
||
|
@classmethod
|
||
|
def from_relation(
|
||
|
cls: Type[Info],
|
||
|
relation: BaseRelation,
|
||
|
information_schema_view: Optional[str],
|
||
|
) -> Info:
|
||
|
include_policy = cls.get_include_policy(
|
||
|
relation, information_schema_view
|
||
|
)
|
||
|
quote_policy = cls.get_quote_policy(relation, information_schema_view)
|
||
|
path = cls.get_path(relation, information_schema_view)
|
||
|
return cls(
|
||
|
type=RelationType.View,
|
||
|
path=path,
|
||
|
include_policy=include_policy,
|
||
|
quote_policy=quote_policy,
|
||
|
information_schema_view=information_schema_view,
|
||
|
)
|
||
|
|
||
|
def _render_iterator(self):
|
||
|
for k, v in super()._render_iterator():
|
||
|
yield k, v
|
||
|
yield None, self.information_schema_view
|
||
|
|
||
|
|
||
|
class SchemaSearchMap(Dict[InformationSchema, Set[Optional[str]]]):
|
||
|
"""A utility class to keep track of what information_schema tables to
|
||
|
search for what schemas. The schema values are all lowercased to avoid
|
||
|
duplication.
|
||
|
"""
|
||
|
def add(self, relation: BaseRelation):
|
||
|
key = relation.information_schema_only()
|
||
|
if key not in self:
|
||
|
self[key] = set()
|
||
|
schema: Optional[str] = None
|
||
|
if relation.schema is not None:
|
||
|
schema = relation.schema.lower()
|
||
|
self[key].add(schema)
|
||
|
|
||
|
def search(
|
||
|
self
|
||
|
) -> Iterator[Tuple[InformationSchema, Optional[str]]]:
|
||
|
for information_schema_name, schemas in self.items():
|
||
|
for schema in schemas:
|
||
|
yield information_schema_name, schema
|
||
|
|
||
|
def flatten(self, allow_multiple_databases: bool = False):
|
||
|
new = self.__class__()
|
||
|
|
||
|
# make sure we don't have multiple databases if allow_multiple_databases is set to False
|
||
|
if not allow_multiple_databases:
|
||
|
seen = {r.database.lower() for r in self if r.database}
|
||
|
if len(seen) > 1:
|
||
|
dbt.exceptions.raise_compiler_error(str(seen))
|
||
|
|
||
|
for information_schema_name, schema in self.search():
|
||
|
path = {
|
||
|
'database': information_schema_name.database,
|
||
|
'schema': schema
|
||
|
}
|
||
|
new.add(information_schema_name.incorporate(
|
||
|
path=path,
|
||
|
quote_policy={'database': False},
|
||
|
include_policy={'database': False},
|
||
|
))
|
||
|
|
||
|
return new
|