# dbt-selly/dbt-env/lib/python3.8/site-packages/dbt/adapters/base/impl.py
import abc
from concurrent.futures import as_completed, Future
from contextlib import contextmanager
from datetime import datetime
from itertools import chain
from typing import (
Optional, Tuple, Callable, Iterable, Type, Dict, Any, List, Mapping,
Iterator, Union, Set
)
import agate
import pytz
from dbt.exceptions import (
raise_database_error, raise_compiler_error, invalid_type_error,
get_relation_returned_multiple_results,
InternalException, NotImplementedException, RuntimeException,
)
from dbt import flags
from dbt import deprecations
from dbt.adapters.protocol import (
AdapterConfig,
ConnectionManagerProtocol,
)
from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows
from dbt.clients.jinja import MacroGenerator
from dbt.contracts.graph.compiled import (
CompileResultNode, CompiledSeedNode
)
from dbt.contracts.graph.manifest import Manifest, MacroManifest
from dbt.contracts.graph.parsed import ParsedSeedNode
from dbt.exceptions import warn_or_error
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.utils import filter_null_values, executor
from dbt.adapters.base.connections import Connection, AdapterResponse
from dbt.adapters.base.meta import AdapterMeta, available
from dbt.adapters.base.relation import (
ComponentName, BaseRelation, InformationSchema, SchemaSearchMap
)
from dbt.adapters.base import Column as BaseColumn
from dbt.adapters.cache import RelationsCache
SeedModel = Union[ParsedSeedNode, CompiledSeedNode]
GET_CATALOG_MACRO_NAME = 'get_catalog'
FRESHNESS_MACRO_NAME = 'collect_freshness'
def _expect_row_value(key: str, row: agate.Row):
if key not in row.keys():
raise InternalException(
'Got a row without "{}" column, columns: {}'
.format(key, row.keys())
)
return row[key]
def _catalog_filter_schemas(manifest: Manifest) -> Callable[[agate.Row], bool]:
"""Return a function that takes a row and decides if the row should be
included in the catalog output.
"""
schemas = frozenset((d.lower(), s.lower())
for d, s in manifest.get_used_schemas())
def test(row: agate.Row) -> bool:
table_database = _expect_row_value('table_database', row)
table_schema = _expect_row_value('table_schema', row)
# the schema may be present but None, which is not an error and should
# be filtered out
if table_schema is None:
return False
return (table_database.lower(), table_schema.lower()) in schemas
return test
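# Illustrative sketch (hypothetical `manifest` and `catalog_table` values):
# the predicate returned above plugs directly into agate.Table.where, which
# is how _catalog_filter_table uses it further down in this module.
#
#     keep = _catalog_filter_schemas(manifest)
#     filtered = catalog_table.where(keep)  # rows outside used schemas dropped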
def _utc(
dt: Optional[datetime], source: BaseRelation, field_name: str
) -> datetime:
"""If dt has a timezone, return a new datetime that's in UTC. Otherwise,
assume the datetime is already for UTC and add the timezone.
"""
if dt is None:
raise raise_database_error(
"Expected a non-null value when querying field '{}' of table "
" {} but received value 'null' instead".format(
field_name,
source))
elif not hasattr(dt, 'tzinfo'):
raise raise_database_error(
"Expected a timestamp value when querying field '{}' of table "
"{} but received value of type '{}' instead".format(
field_name,
source,
type(dt).__name__))
elif dt.tzinfo:
return dt.astimezone(pytz.UTC)
else:
return dt.replace(tzinfo=pytz.UTC)
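# A sketch of _utc's normalization rules (`rel` is a hypothetical relation;
# it is only used in the error messages here):
#
#     eastern = pytz.timezone('US/Eastern')
#     aware = eastern.localize(datetime(2021, 1, 1, 12, 0))
#     _utc(aware, rel, 'loaded_at')
#     # -> 2021-01-01 17:00:00+00:00 (converted to UTC)
#     _utc(datetime(2021, 1, 1, 12, 0), rel, 'loaded_at')
#     # -> 2021-01-01 12:00:00+00:00 (naive values are assumed to be UTC)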
def _relation_name(rel: Optional[BaseRelation]) -> str:
if rel is None:
return 'null relation'
else:
return str(rel)
class BaseAdapter(metaclass=AdapterMeta):
"""The BaseAdapter provides an abstract base class for adapters.
Adapters must implement the following methods and macros. Some of the
methods can be safely overridden as a noop, where it makes sense
(transactions on databases that don't support them, for instance). Those
methods are marked with a (passable) in their docstrings. Check docstrings
for type information, etc.
To implement a macro, implement "${adapter_type}__${macro_name}" in the
adapter's internal project.
Methods:
- exception_handler
- date_function
- list_schemas
- drop_relation
- truncate_relation
- rename_relation
- get_columns_in_relation
- expand_column_types
- list_relations_without_caching
- is_cancelable
- create_schema
- drop_schema
- quote
- convert_text_type
- convert_number_type
- convert_boolean_type
- convert_datetime_type
- convert_date_type
- convert_time_type
Macros:
- get_catalog
"""
Relation: Type[BaseRelation] = BaseRelation
Column: Type[BaseColumn] = BaseColumn
ConnectionManager: Type[ConnectionManagerProtocol]
# A set of clobber config fields accepted by this adapter
# for use in materializations
AdapterSpecificConfigs: Type[AdapterConfig] = AdapterConfig
def __init__(self, config):
self.config = config
self.cache = RelationsCache()
self.connections = self.ConnectionManager(config)
self._macro_manifest_lazy: Optional[MacroManifest] = None
###
# Methods that pass through to the connection manager
###
def acquire_connection(self, name=None) -> Connection:
return self.connections.set_connection_name(name)
def release_connection(self) -> None:
self.connections.release()
def cleanup_connections(self) -> None:
self.connections.cleanup_all()
def clear_transaction(self) -> None:
self.connections.clear_transaction()
def commit_if_has_connection(self) -> None:
self.connections.commit_if_has_connection()
def debug_query(self) -> None:
self.execute('select 1 as id')
def nice_connection_name(self) -> str:
conn = self.connections.get_if_exists()
if conn is None or conn.name is None:
return '<None>'
return conn.name
@contextmanager
def connection_named(
self, name: str, node: Optional[CompileResultNode] = None
) -> Iterator[None]:
try:
if self.connections.query_header is not None:
self.connections.query_header.set(name, node)
self.acquire_connection(name)
yield
finally:
self.release_connection()
if self.connections.query_header is not None:
self.connections.query_header.reset()
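    # Usage sketch (hypothetical adapter instance and connection name): the
    # context manager pairs acquire_connection/release_connection and scopes
    # the query header to the named block.
    #
    #     with adapter.connection_named('list_relations'):
    #         relations = adapter.list_relations('analytics', 'dbt_prod')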
@contextmanager
def connection_for(
self, node: CompileResultNode
) -> Iterator[None]:
with self.connection_named(node.unique_id, node):
yield
@available.parse(lambda *a, **k: ('', empty_table()))
def execute(
self, sql: str, auto_begin: bool = False, fetch: bool = False
) -> Tuple[Union[str, AdapterResponse], agate.Table]:
"""Execute the given SQL. This is a thin wrapper around
ConnectionManager.execute.
:param str sql: The sql to execute.
:param bool auto_begin: If set, and dbt is not currently inside a
transaction, automatically begin one.
:param bool fetch: If set, fetch results.
:return: A tuple of the status and the results (empty if fetch=False).
:rtype: Tuple[Union[str, AdapterResponse], agate.Table]
"""
return self.connections.execute(
sql=sql,
auto_begin=auto_begin,
fetch=fetch
)
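    # Usage sketch (hypothetical adapter instance): with fetch=True the
    # second element is a populated agate.Table, otherwise it is empty.
    #
    #     response, table = adapter.execute('select 1 as id', fetch=True)
    #     # response is the status string/AdapterResponse; table holds the row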
@available.parse(lambda *a, **k: ('', empty_table()))
def get_partitions_metadata(
self, table: str
) -> Tuple[agate.Table]:
"""Obtain partitions metadata for a BigQuery partitioned table.
:param str table: a partitioned table id, in standard SQL format.
:return: a partition metadata tuple, as described in
https://cloud.google.com/bigquery/docs/creating-partitioned-tables#getting_partition_metadata_using_meta_tables.
:rtype: agate.Table
"""
return self.connections.get_partitions_metadata(
table=table
)
###
# Methods that should never be overridden
###
@classmethod
def type(cls) -> str:
"""Get the type of this adapter. Types must be class-unique and
consistent.
:return: The type name
:rtype: str
"""
return cls.ConnectionManager.TYPE
@property
def _macro_manifest(self) -> MacroManifest:
if self._macro_manifest_lazy is None:
return self.load_macro_manifest()
return self._macro_manifest_lazy
def check_macro_manifest(self) -> Optional[MacroManifest]:
"""Return the internal manifest (used for executing macros) if it's
been initialized, otherwise return None.
"""
return self._macro_manifest_lazy
def load_macro_manifest(self) -> MacroManifest:
if self._macro_manifest_lazy is None:
# avoid a circular import
from dbt.parser.manifest import ManifestLoader
manifest = ManifestLoader.load_macros(
self.config, self.connections.set_query_header
)
self._macro_manifest_lazy = manifest
return self._macro_manifest_lazy
def clear_macro_manifest(self):
if self._macro_manifest_lazy is not None:
self._macro_manifest_lazy = None
###
# Caching methods
###
def _schema_is_cached(self, database: Optional[str], schema: str) -> bool:
"""Check if the schema is cached, and by default logs if it is not."""
if flags.USE_CACHE is False:
return False
elif (database, schema) not in self.cache:
logger.debug(
'On "{}": cache miss for schema "{}.{}", this is inefficient'
.format(self.nice_connection_name(), database, schema)
)
return False
else:
return True
def _get_cache_schemas(self, manifest: Manifest) -> Set[BaseRelation]:
"""Get the set of schema relations that the cache logic needs to
populate. This means only executable nodes are included.
"""
# the cache only cares about executable nodes
return {
self.Relation.create_from(self.config, node).without_identifier()
for node in manifest.nodes.values()
if (
node.is_relational and not node.is_ephemeral_model
)
}
def _get_catalog_schemas(self, manifest: Manifest) -> SchemaSearchMap:
"""Get a mapping of each node's "information_schema" relations to a
set of all schemas expected in that information_schema.
There may be keys that are technically duplicates on the database side,
for example all of '"foo"', 'foo', '"FOO"' and 'FOO' could coexist as
databases, and values could overlap as appropriate. All values are
lowercase strings.
"""
info_schema_name_map = SchemaSearchMap()
nodes: Iterator[CompileResultNode] = chain(
manifest.nodes.values(),
manifest.sources.values(),
)
for node in nodes:
relation = self.Relation.create_from(self.config, node)
info_schema_name_map.add(relation)
# result is a map whose keys are information_schema Relations without
# identifiers that have appropriate database prefixes, and whose values
# are sets of lowercase schema names that are valid members of those
# databases
return info_schema_name_map
def _relations_cache_for_schemas(self, manifest: Manifest) -> None:
"""Populate the relations cache for the given schemas. Returns an
iterable of the schemas populated, as strings.
"""
if not flags.USE_CACHE:
return
cache_schemas = self._get_cache_schemas(manifest)
with executor(self.config) as tpe:
futures: List[Future[List[BaseRelation]]] = []
for cache_schema in cache_schemas:
fut = tpe.submit_connected(
self,
f'list_{cache_schema.database}_{cache_schema.schema}',
self.list_relations_without_caching,
cache_schema
)
futures.append(fut)
for future in as_completed(futures):
# if we can't read the relations we need to just raise anyway,
# so just call future.result() and let that raise on failure
for relation in future.result():
self.cache.add(relation)
# it's possible that there were no relations in some schemas. We want
# to insert the schemas we query into the cache's `.schemas` attribute
# so we can check it later
cache_update: Set[Tuple[Optional[str], Optional[str]]] = set()
for relation in cache_schemas:
cache_update.add((relation.database, relation.schema))
self.cache.update_schemas(cache_update)
def set_relations_cache(
self, manifest: Manifest, clear: bool = False
) -> None:
"""Run a query that gets a populated cache of the relations in the
database and set the cache on this adapter.
"""
if not flags.USE_CACHE:
return
with self.cache.lock:
if clear:
self.cache.clear()
self._relations_cache_for_schemas(manifest)
@available
def cache_added(self, relation: Optional[BaseRelation]) -> str:
"""Cache a new relation in dbt. It will show up in `list relations`."""
if relation is None:
name = self.nice_connection_name()
raise_compiler_error(
'Attempted to cache a null relation for {}'.format(name)
)
if flags.USE_CACHE:
self.cache.add(relation)
# so jinja doesn't render things
return ''
@available
def cache_dropped(self, relation: Optional[BaseRelation]) -> str:
"""Drop a relation in dbt. It will no longer show up in
`list_relations`, and any bound views will be dropped from the cache.
"""
if relation is None:
name = self.nice_connection_name()
raise_compiler_error(
'Attempted to drop a null relation for {}'.format(name)
)
if flags.USE_CACHE:
self.cache.drop(relation)
return ''
@available
def cache_renamed(
self,
from_relation: Optional[BaseRelation],
to_relation: Optional[BaseRelation],
) -> str:
"""Rename a relation in dbt. It will show up with a new name in
`list_relations`, but bound views will remain bound.
"""
if from_relation is None or to_relation is None:
name = self.nice_connection_name()
src_name = _relation_name(from_relation)
dst_name = _relation_name(to_relation)
raise_compiler_error(
'Attempted to rename {} to {} for {}'
.format(src_name, dst_name, name)
)
if flags.USE_CACHE:
self.cache.rename(from_relation, to_relation)
return ''
###
# Abstract methods for database-specific values, attributes, and types
###
@abc.abstractclassmethod
def date_function(cls) -> str:
"""Get the date function used by this adapter's database."""
raise NotImplementedException(
'`date_function` is not implemented for this adapter!')
@abc.abstractclassmethod
def is_cancelable(cls) -> bool:
raise NotImplementedException(
'`is_cancelable` is not implemented for this adapter!'
)
###
# Abstract methods about schemas
###
@abc.abstractmethod
def list_schemas(self, database: str) -> List[str]:
"""Get a list of existing schemas in database"""
raise NotImplementedException(
'`list_schemas` is not implemented for this adapter!'
)
@available.parse(lambda *a, **k: False)
def check_schema_exists(self, database: str, schema: str) -> bool:
"""Check if a schema exists.
The default implementation of this is potentially unnecessarily slow,
and adapters should implement it if there is an optimized path (and
there probably is)
"""
search = (
s.lower() for s in
self.list_schemas(database=database)
)
return schema.lower() in search
###
# Abstract methods about relations
###
@abc.abstractmethod
@available.parse_none
def drop_relation(self, relation: BaseRelation) -> None:
"""Drop the given relation.
*Implementors must call self.cache.drop() to preserve cache state!*
"""
raise NotImplementedException(
'`drop_relation` is not implemented for this adapter!'
)
@abc.abstractmethod
@available.parse_none
def truncate_relation(self, relation: BaseRelation) -> None:
"""Truncate the given relation."""
raise NotImplementedException(
'`truncate_relation` is not implemented for this adapter!'
)
@abc.abstractmethod
@available.parse_none
def rename_relation(
self, from_relation: BaseRelation, to_relation: BaseRelation
) -> None:
"""Rename the relation from from_relation to to_relation.
Implementors must call self.cache.rename() to preserve cache state.
"""
raise NotImplementedException(
'`rename_relation` is not implemented for this adapter!'
)
@abc.abstractmethod
@available.parse_list
def get_columns_in_relation(
self, relation: BaseRelation
) -> List[BaseColumn]:
"""Get a list of the columns in the given Relation. """
raise NotImplementedException(
'`get_columns_in_relation` is not implemented for this adapter!'
)
@available.deprecated('get_columns_in_relation', lambda *a, **k: [])
def get_columns_in_table(
self, schema: str, identifier: str
) -> List[BaseColumn]:
"""DEPRECATED: Get a list of the columns in the given table."""
relation = self.Relation.create(
database=self.config.credentials.database,
schema=schema,
identifier=identifier,
quote_policy=self.config.quoting
)
return self.get_columns_in_relation(relation)
@abc.abstractmethod
def expand_column_types(
self, goal: BaseRelation, current: BaseRelation
) -> None:
"""Expand the current table's types to match the goal table. (passable)
:param self.Relation goal: A relation that currently exists in the
database with columns of the desired types.
:param self.Relation current: A relation that currently exists in the
database with columns of unspecified types.
"""
raise NotImplementedException(
'`expand_column_types` is not implemented for this adapter!'
)
@abc.abstractmethod
def list_relations_without_caching(
self, schema_relation: BaseRelation
) -> List[BaseRelation]:
"""List relations in the given schema, bypassing the cache.
This is used as the underlying behavior to fill the cache.
:param schema_relation: A relation containing the database and schema
as appropriate for the underlying data warehouse
:return: The relations in schema
:rtype: List[self.Relation]
"""
raise NotImplementedException(
'`list_relations_without_caching` is not implemented for this '
'adapter!'
)
###
# Provided methods about relations
###
@available.parse_list
def get_missing_columns(
self, from_relation: BaseRelation, to_relation: BaseRelation
) -> List[BaseColumn]:
"""Returns a list of Columns in from_relation that are missing from
to_relation.
"""
if not isinstance(from_relation, self.Relation):
invalid_type_error(
method_name='get_missing_columns',
arg_name='from_relation',
got_value=from_relation,
expected_type=self.Relation)
if not isinstance(to_relation, self.Relation):
invalid_type_error(
method_name='get_missing_columns',
arg_name='to_relation',
got_value=to_relation,
expected_type=self.Relation)
from_columns = {
col.name: col for col in
self.get_columns_in_relation(from_relation)
}
to_columns = {
col.name: col for col in
self.get_columns_in_relation(to_relation)
}
missing_columns = set(from_columns.keys()) - set(to_columns.keys())
return [
col for (col_name, col) in from_columns.items()
if col_name in missing_columns
]
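    # Usage sketch (hypothetical relations; assumes the Column class exposes
    # `data_type`, as dbt's stock BaseColumn does): one way a materialization
    # can add columns that exist on the source but not yet on the target.
    #
    #     for col in adapter.get_missing_columns(tmp_relation, target_relation):
    #         adapter.execute(
    #             f'alter table {target_relation} '
    #             f'add column {adapter.quote(col.name)} {col.data_type}'
    #         )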
@available.parse_none
def valid_snapshot_target(self, relation: BaseRelation) -> None:
"""Ensure that the target relation is valid, by making sure it has the
expected columns.
:param Relation relation: The relation to check
:raises CompilationException: If the columns are
incorrect.
"""
if not isinstance(relation, self.Relation):
invalid_type_error(
method_name='valid_snapshot_target',
arg_name='relation',
got_value=relation,
expected_type=self.Relation)
columns = self.get_columns_in_relation(relation)
names = set(c.name.lower() for c in columns)
expanded_keys = ('scd_id', 'valid_from', 'valid_to')
extra = []
missing = []
for legacy in expanded_keys:
desired = 'dbt_' + legacy
if desired not in names:
missing.append(desired)
if legacy in names:
extra.append(legacy)
if missing:
if extra:
msg = (
'Snapshot target has ("{}") but not ("{}") - is it an '
'unmigrated previous version archive?'
.format('", "'.join(extra), '", "'.join(missing))
)
else:
msg = (
'Snapshot target is not a snapshot table (missing "{}")'
.format('", "'.join(missing))
)
raise_compiler_error(msg)
@available.parse_none
def expand_target_column_types(
self, from_relation: BaseRelation, to_relation: BaseRelation
) -> None:
if not isinstance(from_relation, self.Relation):
invalid_type_error(
method_name='expand_target_column_types',
arg_name='from_relation',
got_value=from_relation,
expected_type=self.Relation)
if not isinstance(to_relation, self.Relation):
invalid_type_error(
method_name='expand_target_column_types',
arg_name='to_relation',
got_value=to_relation,
expected_type=self.Relation)
self.expand_column_types(from_relation, to_relation)
def list_relations(
self, database: Optional[str], schema: str
) -> List[BaseRelation]:
if self._schema_is_cached(database, schema):
return self.cache.get_relations(database, schema)
schema_relation = self.Relation.create(
database=database,
schema=schema,
identifier='',
quote_policy=self.config.quoting
).without_identifier()
# we can't build the relations cache because we don't have a
# manifest so we can't run any operations.
relations = self.list_relations_without_caching(
schema_relation
)
logger.debug('with database={}, schema={}, relations={}'
.format(database, schema, relations))
return relations
def _make_match_kwargs(
self, database: str, schema: str, identifier: str
) -> Dict[str, str]:
quoting = self.config.quoting
if identifier is not None and quoting['identifier'] is False:
identifier = identifier.lower()
if schema is not None and quoting['schema'] is False:
schema = schema.lower()
if database is not None and quoting['database'] is False:
database = database.lower()
return filter_null_values({
'database': database,
'identifier': identifier,
'schema': schema,
})
def _make_match(
self,
relations_list: List[BaseRelation],
database: str,
schema: str,
identifier: str,
) -> List[BaseRelation]:
matches = []
search = self._make_match_kwargs(database, schema, identifier)
for relation in relations_list:
if relation.matches(**search):
matches.append(relation)
return matches
@available.parse_none
def get_relation(
self, database: str, schema: str, identifier: str
) -> Optional[BaseRelation]:
relations_list = self.list_relations(database, schema)
matches = self._make_match(relations_list, database, schema,
identifier)
if len(matches) > 1:
kwargs = {
'identifier': identifier,
'schema': schema,
'database': database,
}
get_relation_returned_multiple_results(
kwargs, matches
)
elif matches:
return matches[0]
return None
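    # Usage sketch (hypothetical names): returns a single matching relation
    # or None, so callers can branch on existence.
    #
    #     rel = adapter.get_relation('analytics', 'dbt_prod', 'orders')
    #     if rel is None:
    #         ...  # the relation does not exist (or is not visible)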
@available.deprecated('get_relation', lambda *a, **k: False)
def already_exists(self, schema: str, name: str) -> bool:
"""DEPRECATED: Return if a model already exists in the database"""
database = self.config.credentials.database
relation = self.get_relation(database, schema, name)
return relation is not None
###
# ODBC FUNCTIONS -- these should not need to change for every adapter,
# although some adapters may override them
###
@abc.abstractmethod
@available.parse_none
def create_schema(self, relation: BaseRelation):
"""Create the given schema if it does not exist."""
raise NotImplementedException(
'`create_schema` is not implemented for this adapter!'
)
@abc.abstractmethod
@available.parse_none
def drop_schema(self, relation: BaseRelation):
"""Drop the given schema (and everything in it) if it exists."""
raise NotImplementedException(
'`drop_schema` is not implemented for this adapter!'
)
@available
@abc.abstractclassmethod
def quote(cls, identifier: str) -> str:
"""Quote the given identifier, as appropriate for the database."""
raise NotImplementedException(
'`quote` is not implemented for this adapter!'
)
@available
def quote_as_configured(self, identifier: str, quote_key: str) -> str:
"""Quote or do not quote the given identifer as configured in the
project config for the quote key.
The quote key should be one of 'database' (on bigquery, 'profile'),
'identifier', or 'schema', or it will be treated as if you set `True`.
"""
try:
key = ComponentName(quote_key)
except ValueError:
return identifier
default = self.Relation.get_default_quote_policy().get_part(key)
if self.config.quoting.get(key, default):
return self.quote(identifier)
else:
return identifier
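    # Usage sketch, assuming a hypothetical quoting config of
    # {'schema': False, 'identifier': True}:
    #
    #     adapter.quote_as_configured('raw_events', 'identifier')  # quoted
    #     adapter.quote_as_configured('analytics', 'schema')       # unquoted
    #     adapter.quote_as_configured('x', 'not-a-component')      # passed through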
@available
def quote_seed_column(
self, column: str, quote_config: Optional[bool]
) -> str:
# this is the default for now
quote_columns: bool = False
if isinstance(quote_config, bool):
quote_columns = quote_config
elif quote_config is None:
deprecations.warn('column-quoting-unset')
else:
raise_compiler_error(
f'The seed configuration value of "quote_columns" has an '
f'invalid type {type(quote_config)}'
)
if quote_columns:
return self.quote(column)
else:
return column
###
# Conversions: These must be implemented by concrete implementations, for
# converting agate types into their sql equivalents.
###
@abc.abstractclassmethod
def convert_text_type(
cls, agate_table: agate.Table, col_idx: int
) -> str:
"""Return the type in the database that best maps to the agate.Text
type for the given agate table and column index.
:param agate_table: The table
:param col_idx: The index into the agate table for the column.
:return: The name of the type in the database
"""
raise NotImplementedException(
'`convert_text_type` is not implemented for this adapter!')
@abc.abstractclassmethod
def convert_number_type(
cls, agate_table: agate.Table, col_idx: int
) -> str:
"""Return the type in the database that best maps to the agate.Number
type for the given agate table and column index.
:param agate_table: The table
:param col_idx: The index into the agate table for the column.
:return: The name of the type in the database
"""
raise NotImplementedException(
'`convert_number_type` is not implemented for this adapter!')
@abc.abstractclassmethod
def convert_boolean_type(
cls, agate_table: agate.Table, col_idx: int
) -> str:
"""Return the type in the database that best maps to the agate.Boolean
type for the given agate table and column index.
:param agate_table: The table
:param col_idx: The index into the agate table for the column.
:return: The name of the type in the database
"""
raise NotImplementedException(
'`convert_boolean_type` is not implemented for this adapter!')
@abc.abstractclassmethod
def convert_datetime_type(
cls, agate_table: agate.Table, col_idx: int
) -> str:
"""Return the type in the database that best maps to the agate.DateTime
type for the given agate table and column index.
:param agate_table: The table
:param col_idx: The index into the agate table for the column.
:return: The name of the type in the database
"""
raise NotImplementedException(
'`convert_datetime_type` is not implemented for this adapter!')
@abc.abstractclassmethod
def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str:
"""Return the type in the database that best maps to the agate.Date
type for the given agate table and column index.
:param agate_table: The table
:param col_idx: The index into the agate table for the column.
:return: The name of the type in the database
"""
raise NotImplementedException(
'`convert_date_type` is not implemented for this adapter!')
@abc.abstractclassmethod
def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:
"""Return the type in the database that best maps to the
agate.TimeDelta type for the given agate table and column index.
:param agate_table: The table
:param col_idx: The index into the agate table for the column.
:return: The name of the type in the database
"""
raise NotImplementedException(
'`convert_time_type` is not implemented for this adapter!')
@available
@classmethod
def convert_type(
cls, agate_table: agate.Table, col_idx: int
) -> Optional[str]:
return cls.convert_agate_type(agate_table, col_idx)
@classmethod
def convert_agate_type(
cls, agate_table: agate.Table, col_idx: int
) -> Optional[str]:
agate_type: Type = agate_table.column_types[col_idx]
conversions: List[Tuple[Type, Callable[..., str]]] = [
(agate.Text, cls.convert_text_type),
(agate.Number, cls.convert_number_type),
(agate.Boolean, cls.convert_boolean_type),
(agate.DateTime, cls.convert_datetime_type),
(agate.Date, cls.convert_date_type),
(agate.TimeDelta, cls.convert_time_type),
]
for agate_cls, func in conversions:
if isinstance(agate_type, agate_cls):
return func(agate_table, col_idx)
return None
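    # Dispatch sketch (hypothetical table): an agate.Number column routes to
    # convert_number_type, agate.Text to convert_text_type, and so on; an
    # unrecognized agate type falls through to None.
    #
    #     table = agate.Table.from_object([{'n': 1, 's': 'a'}])
    #     cls.convert_agate_type(table, 0)  # delegates to convert_number_type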
###
# Operations involving the manifest
###
def execute_macro(
self,
macro_name: str,
manifest: Optional[Manifest] = None,
project: Optional[str] = None,
context_override: Optional[Dict[str, Any]] = None,
kwargs: Optional[Dict[str, Any]] = None,
release: bool = False,
text_only_columns: Optional[Iterable[str]] = None,
) -> agate.Table:
"""Look macro_name up in the manifest and execute its results.
:param macro_name: The name of the macro to execute.
:param manifest: The manifest to use for generating the base macro
execution context. If none is provided, use the internal manifest.
:param project: The name of the project to search in, or None for the
first match.
:param context_override: An optional dict to update() the macro
execution context.
:param kwargs: An optional dict of keyword args used to pass to the
macro.
:param release: Ignored.
"""
if release is not False:
deprecations.warn('execute-macro-release')
if kwargs is None:
kwargs = {}
if context_override is None:
context_override = {}
if manifest is None:
manifest = self._macro_manifest
macro = manifest.find_macro_by_name(
macro_name, self.config.project_name, project
)
if macro is None:
if project is None:
package_name = 'any package'
else:
package_name = 'the "{}" package'.format(project)
raise RuntimeException(
'dbt could not find a macro with the name "{}" in {}'
.format(macro_name, package_name)
)
# This causes a reference cycle, as generate_runtime_macro()
# ends up calling get_adapter, so the import has to be here.
from dbt.context.providers import generate_runtime_macro
macro_context = generate_runtime_macro(
macro=macro,
config=self.config,
manifest=manifest,
package_name=project
)
macro_context.update(context_override)
macro_function = MacroGenerator(macro, macro_context)
with self.connections.exception_handler(f'macro {macro_name}'):
result = macro_function(**kwargs)
return result
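    # Usage sketch (the kwargs shown are hypothetical): this is how
    # get_catalog and collect_freshness are invoked elsewhere in this class.
    #
    #     table = adapter.execute_macro(
    #         GET_CATALOG_MACRO_NAME,
    #         kwargs={'information_schema': info_schema, 'schemas': schemas},
    #         manifest=manifest,
    #     )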
@classmethod
def _catalog_filter_table(
cls, table: agate.Table, manifest: Manifest
) -> agate.Table:
"""Filter the table as appropriate for catalog entries. Subclasses can
override this to change filtering rules on a per-adapter basis.
"""
# force database + schema to be strings
table = table_from_rows(
table.rows,
table.column_names,
text_only_columns=['table_database', 'table_schema', 'table_name']
)
return table.where(_catalog_filter_schemas(manifest))
def _get_one_catalog(
self,
information_schema: InformationSchema,
schemas: Set[str],
manifest: Manifest,
) -> agate.Table:
kwargs = {
'information_schema': information_schema,
'schemas': schemas
}
table = self.execute_macro(
GET_CATALOG_MACRO_NAME,
kwargs=kwargs,
# pass in the full manifest so we get any local project
# overrides
manifest=manifest,
)
results = self._catalog_filter_table(table, manifest)
return results
def get_catalog(
self, manifest: Manifest
) -> Tuple[agate.Table, List[Exception]]:
schema_map = self._get_catalog_schemas(manifest)
with executor(self.config) as tpe:
futures: List[Future[agate.Table]] = []
for info, schemas in schema_map.items():
if len(schemas) == 0:
continue
name = '.'.join([
str(info.database),
'information_schema'
])
fut = tpe.submit_connected(
self, name,
self._get_one_catalog, info, schemas, manifest
)
futures.append(fut)
catalogs, exceptions = catch_as_completed(futures)
return catalogs, exceptions
def cancel_open_connections(self):
"""Cancel all open connections."""
return self.connections.cancel_open()
def calculate_freshness(
self,
source: BaseRelation,
loaded_at_field: str,
filter: Optional[str],
manifest: Optional[Manifest] = None
) -> Dict[str, Any]:
"""Calculate the freshness of sources in dbt, and return it"""
kwargs: Dict[str, Any] = {
'source': source,
'loaded_at_field': loaded_at_field,
'filter': filter,
}
# run the macro
table = self.execute_macro(
FRESHNESS_MACRO_NAME,
kwargs=kwargs,
manifest=manifest
)
# now we have a 1-row table of the maximum `loaded_at_field` value and
# the current time according to the db.
if len(table) != 1 or len(table[0]) != 2:
raise_compiler_error(
'Got an invalid result from "{}" macro: {}'.format(
FRESHNESS_MACRO_NAME, [tuple(r) for r in table]
)
)
if table[0][0] is None:
# no records in the table, so really the max_loaded_at was
# infinitely long ago. Just call it 00:00 on January 1 of year 1, UTC
max_loaded_at = datetime(1, 1, 1, 0, 0, 0, tzinfo=pytz.UTC)
else:
max_loaded_at = _utc(table[0][0], source, loaded_at_field)
snapshotted_at = _utc(table[0][1], source, loaded_at_field)
age = (snapshotted_at - max_loaded_at).total_seconds()
return {
'max_loaded_at': max_loaded_at,
'snapshotted_at': snapshotted_at,
'age': age,
}
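    # Shape of the returned dict (values hypothetical): both timestamps are
    # UTC-normalized by _utc, and 'age' is their difference in seconds.
    #
    #     {
    #         'max_loaded_at': datetime(2021, 1, 1, 0, 0, tzinfo=pytz.UTC),
    #         'snapshotted_at': datetime(2021, 1, 1, 0, 30, tzinfo=pytz.UTC),
    #         'age': 1800.0,
    #     }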
def pre_model_hook(self, config: Mapping[str, Any]) -> Any:
"""A hook for running some operation before the model materialization
runs. The hook can assume it has a connection available.
The only parameter is a configuration dictionary (the same one
available in the materialization context). It should be considered
read-only.
The pre-model hook may return anything as a context, which will be
passed to the post-model hook.
"""
pass
def post_model_hook(self, config: Mapping[str, Any], context: Any) -> None:
"""A hook for running some operation after the model materialization
runs. The hook can assume it has a connection available.
The first parameter is a configuration dictionary (the same one
available in the materialization context). It should be considered
read-only.
The second parameter is the value returned by pre_model_hook.
"""
pass
def get_compiler(self):
from dbt.compilation import Compiler
return Compiler(self.config)
# Methods used in adapter tests
def update_column_sql(
self,
dst_name: str,
dst_column: str,
clause: str,
where_clause: Optional[str] = None,
) -> str:
clause = f'update {dst_name} set {dst_column} = {clause}'
if where_clause is not None:
clause += f' where {where_clause}'
return clause
def timestamp_add_sql(
self, add_to: str, number: int = 1, interval: str = 'hour'
) -> str:
# for backwards compatibility, we're compelled to set some sort of
# default. A lot of searching has led me to believe that the
# '+ interval' syntax used in postgres/redshift is relatively common
# and might even be the SQL standard's intention.
return f"{add_to} + interval '{number} {interval}'"
def string_add_sql(
self, add_to: str, value: str, location='append',
) -> str:
if location == 'append':
return f"{add_to} || '{value}'"
elif location == 'prepend':
return f"'{value}' || {add_to}"
else:
raise RuntimeException(
f'Got an unexpected location value of "{location}"'
)
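    # Examples of the generated fragments:
    #
    #     adapter.string_add_sql('name', '_suffix')             # "name || '_suffix'"
    #     adapter.string_add_sql('name', 'prefix_', 'prepend')  # "'prefix_' || name"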
def get_rows_different_sql(
self,
relation_a: BaseRelation,
relation_b: BaseRelation,
column_names: Optional[List[str]] = None,
except_operator: str = 'EXCEPT',
) -> str:
"""Generate SQL for a query that returns a single row with a two
columns: the number of rows that are different between the two
relations and the number of mismatched rows.
"""
# This method only really exists for test reasons.
names: List[str]
if column_names is None:
columns = self.get_columns_in_relation(relation_a)
names = sorted((self.quote(c.name) for c in columns))
else:
names = sorted((self.quote(n) for n in column_names))
columns_csv = ', '.join(names)
sql = COLUMNS_EQUAL_SQL.format(
columns=columns_csv,
relation_a=str(relation_a),
relation_b=str(relation_b),
except_op=except_operator,
)
return sql
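# The COLUMNS_EQUAL_SQL template below returns exactly one row with columns
# (row_count_difference, num_mismatched); both are 0 when the two relations
# contain identical rows for the compared columns.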
COLUMNS_EQUAL_SQL = '''
with diff_count as (
SELECT
1 as id,
COUNT(*) as num_missing FROM (
(SELECT {columns} FROM {relation_a} {except_op}
SELECT {columns} FROM {relation_b})
UNION ALL
(SELECT {columns} FROM {relation_b} {except_op}
SELECT {columns} FROM {relation_a})
) as a
), table_a as (
SELECT COUNT(*) as num_rows FROM {relation_a}
), table_b as (
SELECT COUNT(*) as num_rows FROM {relation_b}
), row_count_diff as (
select
1 as id,
table_a.num_rows - table_b.num_rows as difference
from table_a, table_b
)
select
row_count_diff.difference as row_count_difference,
diff_count.num_missing as num_mismatched
from row_count_diff
join diff_count using (id)
'''.strip()
def catch_as_completed(
futures # typing: List[Future[agate.Table]]
) -> Tuple[agate.Table, List[Exception]]:
tables: List[agate.Table] = []
exceptions: List[Exception] = []
for future in as_completed(futures):
exc = future.exception()
# we want to re-raise on ctrl+c and BaseException
if exc is None:
catalog = future.result()
tables.append(catalog)
elif (
isinstance(exc, KeyboardInterrupt) or
not isinstance(exc, Exception)
):
raise exc
else:
warn_or_error(
f'Encountered an error while generating catalog: {str(exc)}'
)
# exc is not None, derives from Exception, and isn't ctrl+c
exceptions.append(exc)
return merge_tables(tables), exceptions