# dbt-selly/dbt-env/lib/python3.8/site-packages/dbt/adapters/factory.py
#
# 228 lines
# 7.5 KiB
# Python

import threading
from pathlib import Path
from importlib import import_module
from typing import Type, Dict, Any, List, Optional, Set
from dbt.exceptions import RuntimeException, InternalException
from dbt.include.global_project import (
PACKAGE_PATH as GLOBAL_PROJECT_PATH,
PROJECT_NAME as GLOBAL_PROJECT_NAME,
)
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.contracts.connection import Credentials, AdapterRequiredConfig
from dbt.adapters.protocol import (
AdapterProtocol,
AdapterConfig,
RelationProtocol,
)
from dbt.adapters.base.plugin import AdapterPlugin
# Short alias for AdapterProtocol, used in the type annotations of this module.
Adapter = AdapterProtocol
class AdapterContainer:
    """Registry of adapter plugins, adapter instances and package paths.

    ``plugins`` maps an adapter type name to its loaded plugin, ``adapters``
    maps an adapter type name to its instantiated adapter, and ``packages``
    maps a package name to its include path. ``lock`` is held while
    ``plugins`` and ``adapters`` are mutated and while ``adapters`` is
    iterated.
    """

    def __init__(self):
        self.lock = threading.Lock()
        self.adapters: Dict[str, Adapter] = {}
        self.plugins: Dict[str, AdapterPlugin] = {}
        # map package names to their include paths
        self.packages: Dict[str, Path] = {
            GLOBAL_PROJECT_NAME: Path(GLOBAL_PROJECT_PATH),
        }

    def get_plugin_by_name(self, name: str) -> AdapterPlugin:
        """Return the loaded plugin registered under ``name``.

        :raises RuntimeException: if no plugin with that name is loaded; the
            message lists the names that are.
        """
        with self.lock:
            if name in self.plugins:
                return self.plugins[name]
            names = ", ".join(self.plugins.keys())

        message = f"Invalid adapter type {name}! Must be one of {names}"
        raise RuntimeException(message)

    def get_adapter_class_by_name(self, name: str) -> Type[Adapter]:
        """Return the ``adapter`` attribute of the plugin named ``name``."""
        plugin = self.get_plugin_by_name(name)
        return plugin.adapter

    def get_relation_class_by_name(self, name: str) -> Type[RelationProtocol]:
        """Return the ``Relation`` class of the adapter named ``name``."""
        adapter = self.get_adapter_class_by_name(name)
        return adapter.Relation

    def get_config_class_by_name(
        self, name: str
    ) -> Type[AdapterConfig]:
        """Return ``AdapterSpecificConfigs`` of the adapter named ``name``."""
        adapter = self.get_adapter_class_by_name(name)
        return adapter.AdapterSpecificConfigs

    def load_plugin(self, name: str) -> Type[Credentials]:
        """Import ``dbt.adapters.<name>``, register its plugin, and return
        the plugin's credentials class.

        The plugin's dependencies are loaded recursively.

        :raises RuntimeException: if the adapter module itself cannot be
            found, or if the plugin reports a type other than ``name``.
        :raises ModuleNotFoundError: re-raised (after logging) when the
            import failed because of some other missing module.
        """
        # this doesn't need a lock: in the worst case we'll overwrite packages
        # and adapter_type entries with the same value, as they're all
        # singletons
        try:
            # mypy doesn't think modules have any attributes.
            mod: Any = import_module('.' + name, 'dbt.adapters')
        except ModuleNotFoundError as exc:
            # if we failed to import the target module in particular, inform
            # the user about it via a runtime error
            if exc.name == 'dbt.adapters.' + name:
                raise RuntimeException(f'Could not find adapter type {name}!')
            logger.info(f'Error importing adapter: {exc}')
            # otherwise, the error had to have come from some underlying
            # library. Log the stack trace.
            logger.debug('', exc_info=True)
            raise

        plugin: AdapterPlugin = mod.Plugin
        plugin_type = plugin.adapter.type()

        if plugin_type != name:
            raise RuntimeException(
                f'Expected to find adapter with type named {name}, got '
                f'adapter with type {plugin_type}'
            )

        with self.lock:
            # things do hold the lock to iterate over it so we need it to add
            self.plugins[name] = plugin

        self.packages[plugin.project_name] = Path(plugin.include_path)

        # Must run outside the ``with`` block above: the lock is not
        # reentrant and the recursive call acquires it again. A ``None``
        # dependency list is treated as empty, matching get_adapter_plugins.
        for dep in plugin.dependencies or ():
            self.load_plugin(dep)

        return plugin.credentials

    def register_adapter(self, config: AdapterRequiredConfig) -> None:
        """Instantiate and store the adapter for ``config.credentials.type``.

        Does nothing if an adapter of that type is already registered.
        """
        adapter_name = config.credentials.type
        adapter_type = self.get_adapter_class_by_name(adapter_name)

        with self.lock:
            if adapter_name in self.adapters:
                # this shouldn't really happen...
                return

            adapter: Adapter = adapter_type(config)  # type: ignore
            self.adapters[adapter_name] = adapter

    def lookup_adapter(self, adapter_name: str) -> Adapter:
        """Return the registered adapter instance for ``adapter_name``.

        :raises KeyError: if no adapter is registered under that name.
        """
        return self.adapters[adapter_name]

    def reset_adapters(self):
        """Clear the adapters. This is useful for tests, which change configs.
        """
        with self.lock:
            for adapter in self.adapters.values():
                adapter.cleanup_connections()
            self.adapters.clear()

    def cleanup_connections(self):
        """Only clean up the adapter connections list without resetting the
        actual adapters.
        """
        with self.lock:
            for adapter in self.adapters.values():
                adapter.cleanup_connections()

    def get_adapter_plugins(self, name: Optional[str]) -> List[AdapterPlugin]:
        """Return the known adapter plugins.

        If ``name`` is None, every loaded plugin is returned. Otherwise the
        named plugin comes first, followed by its dependencies, then their
        dependencies, and so on (breadth-first). Each plugin appears once,
        even when several plugins depend on it.

        :raises InternalException: if the named plugin or one of its
            (transitive) dependencies has not been loaded.
        """
        if name is None:
            return list(self.plugins.values())

        plugins: List[AdapterPlugin] = []
        # Names are marked as seen when they are queued rather than when
        # they are visited, so a dependency shared by two plugins is not
        # queued (and returned) twice.
        seen: Set[str] = {name}
        pending: List[str] = [name]
        position = 0
        # Walk the queue by index instead of re-slicing it on every pop.
        while position < len(pending):
            plugin_name = pending[position]
            position += 1
            try:
                plugin = self.plugins[plugin_name]
            except KeyError:
                raise InternalException(
                    f'No plugin found for {plugin_name}'
                ) from None
            plugins.append(plugin)
            # ``dependencies`` may be None, which means "no dependencies".
            for dep in plugin.dependencies or ():
                if dep not in seen:
                    seen.add(dep)
                    pending.append(dep)
        return plugins

    def get_adapter_package_names(self, name: Optional[str]) -> List[str]:
        """Return the project names of the plugins selected by ``name``
        (see ``get_adapter_plugins``), followed by the global project name.
        """
        package_names: List[str] = [
            p.project_name for p in self.get_adapter_plugins(name)
        ]
        package_names.append(GLOBAL_PROJECT_NAME)
        return package_names

    def get_include_paths(self, name: Optional[str]) -> List[Path]:
        """Return the include path of every package listed by
        ``get_adapter_package_names(name)``, in the same order.

        :raises InternalException: if a package has no recorded path.
        """
        paths = []
        for package_name in self.get_adapter_package_names(name):
            try:
                path = self.packages[package_name]
            except KeyError:
                # Suppress the KeyError context, as get_adapter_plugins does.
                raise InternalException(
                    f'No internal package listing found for {package_name}'
                ) from None
            paths.append(path)
        return paths

    def get_adapter_type_names(self, name: Optional[str]) -> List[str]:
        """Return ``adapter.type()`` for each plugin selected by ``name``
        (see ``get_adapter_plugins``).
        """
        return [p.adapter.type() for p in self.get_adapter_plugins(name)]
# Module-level container instance; the functions below delegate to it.
FACTORY: AdapterContainer = AdapterContainer()
def register_adapter(config: AdapterRequiredConfig) -> None:
    """Register the adapter for ``config`` with the shared ``FACTORY``."""
    container = FACTORY
    container.register_adapter(config)
def get_adapter(config: AdapterRequiredConfig):
    """Return the adapter that ``FACTORY`` holds for the credentials type
    of ``config``.
    """
    container = FACTORY
    return container.lookup_adapter(config.credentials.type)
def reset_adapters():
    """Drop every adapter held by ``FACTORY``.

    Handy in tests, where the config changes between runs.
    """
    container = FACTORY
    container.reset_adapters()
def cleanup_connections():
    """Clean up the connections of the adapters held by ``FACTORY`` while
    keeping the adapters themselves registered.
    """
    container = FACTORY
    container.cleanup_connections()
def get_adapter_class_by_name(name: str) -> Type[AdapterProtocol]:
    """Return the adapter class that ``FACTORY`` has for plugin ``name``."""
    adapter_class = FACTORY.get_adapter_class_by_name(name)
    return adapter_class
def get_config_class_by_name(name: str) -> Type[AdapterConfig]:
    """Return the config class that ``FACTORY`` has for adapter ``name``."""
    config_class = FACTORY.get_config_class_by_name(name)
    return config_class
def get_relation_class_by_name(name: str) -> Type[RelationProtocol]:
    """Return the relation class that ``FACTORY`` has for adapter ``name``."""
    relation_class = FACTORY.get_relation_class_by_name(name)
    return relation_class
def load_plugin(name: str) -> Type[Credentials]:
    """Load the adapter plugin ``name`` into ``FACTORY`` and return the
    credentials class it reports.
    """
    credentials_class = FACTORY.load_plugin(name)
    return credentials_class
def get_include_paths(name: Optional[str]) -> List[Path]:
    """Return the include paths that ``FACTORY`` reports for ``name``."""
    include_paths = FACTORY.get_include_paths(name)
    return include_paths
def get_adapter_package_names(name: Optional[str]) -> List[str]:
    """Return the package names that ``FACTORY`` reports for ``name``."""
    package_names = FACTORY.get_adapter_package_names(name)
    return package_names
def get_adapter_type_names(name: Optional[str]) -> List[str]:
    """Return the adapter type names that ``FACTORY`` reports for ``name``."""
    type_names = FACTORY.get_adapter_type_names(name)
    return type_names