228 lines
7.5 KiB
Python
228 lines
7.5 KiB
Python
|
import threading
|
||
|
from pathlib import Path
|
||
|
from importlib import import_module
|
||
|
from typing import Type, Dict, Any, List, Optional, Set
|
||
|
|
||
|
from dbt.exceptions import RuntimeException, InternalException
|
||
|
from dbt.include.global_project import (
|
||
|
PACKAGE_PATH as GLOBAL_PROJECT_PATH,
|
||
|
PROJECT_NAME as GLOBAL_PROJECT_NAME,
|
||
|
)
|
||
|
from dbt.logger import GLOBAL_LOGGER as logger
|
||
|
from dbt.contracts.connection import Credentials, AdapterRequiredConfig
|
||
|
|
||
|
|
||
|
from dbt.adapters.protocol import (
|
||
|
AdapterProtocol,
|
||
|
AdapterConfig,
|
||
|
RelationProtocol,
|
||
|
)
|
||
|
from dbt.adapters.base.plugin import AdapterPlugin
|
||
|
|
||
|
|
||
|
Adapter = AdapterProtocol
|
||
|
|
||
|
|
||
|
class AdapterContainer:
|
||
|
def __init__(self):
|
||
|
self.lock = threading.Lock()
|
||
|
self.adapters: Dict[str, Adapter] = {}
|
||
|
self.plugins: Dict[str, AdapterPlugin] = {}
|
||
|
# map package names to their include paths
|
||
|
self.packages: Dict[str, Path] = {
|
||
|
GLOBAL_PROJECT_NAME: Path(GLOBAL_PROJECT_PATH),
|
||
|
}
|
||
|
|
||
|
def get_plugin_by_name(self, name: str) -> AdapterPlugin:
|
||
|
with self.lock:
|
||
|
if name in self.plugins:
|
||
|
return self.plugins[name]
|
||
|
names = ", ".join(self.plugins.keys())
|
||
|
|
||
|
message = f"Invalid adapter type {name}! Must be one of {names}"
|
||
|
raise RuntimeException(message)
|
||
|
|
||
|
def get_adapter_class_by_name(self, name: str) -> Type[Adapter]:
|
||
|
plugin = self.get_plugin_by_name(name)
|
||
|
return plugin.adapter
|
||
|
|
||
|
def get_relation_class_by_name(self, name: str) -> Type[RelationProtocol]:
|
||
|
adapter = self.get_adapter_class_by_name(name)
|
||
|
return adapter.Relation
|
||
|
|
||
|
def get_config_class_by_name(
|
||
|
self, name: str
|
||
|
) -> Type[AdapterConfig]:
|
||
|
adapter = self.get_adapter_class_by_name(name)
|
||
|
return adapter.AdapterSpecificConfigs
|
||
|
|
||
|
def load_plugin(self, name: str) -> Type[Credentials]:
|
||
|
# this doesn't need a lock: in the worst case we'll overwrite packages
|
||
|
# and adapter_type entries with the same value, as they're all
|
||
|
# singletons
|
||
|
try:
|
||
|
# mypy doesn't think modules have any attributes.
|
||
|
mod: Any = import_module('.' + name, 'dbt.adapters')
|
||
|
except ModuleNotFoundError as exc:
|
||
|
# if we failed to import the target module in particular, inform
|
||
|
# the user about it via a runtime error
|
||
|
if exc.name == 'dbt.adapters.' + name:
|
||
|
raise RuntimeException(f'Could not find adapter type {name}!')
|
||
|
logger.info(f'Error importing adapter: {exc}')
|
||
|
# otherwise, the error had to have come from some underlying
|
||
|
# library. Log the stack trace.
|
||
|
logger.debug('', exc_info=True)
|
||
|
raise
|
||
|
plugin: AdapterPlugin = mod.Plugin
|
||
|
plugin_type = plugin.adapter.type()
|
||
|
|
||
|
if plugin_type != name:
|
||
|
raise RuntimeException(
|
||
|
f'Expected to find adapter with type named {name}, got '
|
||
|
f'adapter with type {plugin_type}'
|
||
|
)
|
||
|
|
||
|
with self.lock:
|
||
|
# things do hold the lock to iterate over it so we need it to add
|
||
|
self.plugins[name] = plugin
|
||
|
|
||
|
self.packages[plugin.project_name] = Path(plugin.include_path)
|
||
|
|
||
|
for dep in plugin.dependencies:
|
||
|
self.load_plugin(dep)
|
||
|
|
||
|
return plugin.credentials
|
||
|
|
||
|
def register_adapter(self, config: AdapterRequiredConfig) -> None:
|
||
|
adapter_name = config.credentials.type
|
||
|
adapter_type = self.get_adapter_class_by_name(adapter_name)
|
||
|
|
||
|
with self.lock:
|
||
|
if adapter_name in self.adapters:
|
||
|
# this shouldn't really happen...
|
||
|
return
|
||
|
|
||
|
adapter: Adapter = adapter_type(config) # type: ignore
|
||
|
self.adapters[adapter_name] = adapter
|
||
|
|
||
|
def lookup_adapter(self, adapter_name: str) -> Adapter:
|
||
|
return self.adapters[adapter_name]
|
||
|
|
||
|
def reset_adapters(self):
|
||
|
"""Clear the adapters. This is useful for tests, which change configs.
|
||
|
"""
|
||
|
with self.lock:
|
||
|
for adapter in self.adapters.values():
|
||
|
adapter.cleanup_connections()
|
||
|
self.adapters.clear()
|
||
|
|
||
|
def cleanup_connections(self):
|
||
|
"""Only clean up the adapter connections list without resetting the
|
||
|
actual adapters.
|
||
|
"""
|
||
|
with self.lock:
|
||
|
for adapter in self.adapters.values():
|
||
|
adapter.cleanup_connections()
|
||
|
|
||
|
def get_adapter_plugins(self, name: Optional[str]) -> List[AdapterPlugin]:
|
||
|
"""Iterate over the known adapter plugins. If a name is provided,
|
||
|
iterate in dependency order over the named plugin and its dependencies.
|
||
|
"""
|
||
|
if name is None:
|
||
|
return list(self.plugins.values())
|
||
|
|
||
|
plugins: List[AdapterPlugin] = []
|
||
|
seen: Set[str] = set()
|
||
|
plugin_names: List[str] = [name]
|
||
|
while plugin_names:
|
||
|
plugin_name = plugin_names[0]
|
||
|
plugin_names = plugin_names[1:]
|
||
|
try:
|
||
|
plugin = self.plugins[plugin_name]
|
||
|
except KeyError:
|
||
|
raise InternalException(
|
||
|
f'No plugin found for {plugin_name}'
|
||
|
) from None
|
||
|
plugins.append(plugin)
|
||
|
seen.add(plugin_name)
|
||
|
if plugin.dependencies is None:
|
||
|
continue
|
||
|
for dep in plugin.dependencies:
|
||
|
if dep not in seen:
|
||
|
plugin_names.append(dep)
|
||
|
return plugins
|
||
|
|
||
|
def get_adapter_package_names(self, name: Optional[str]) -> List[str]:
|
||
|
package_names: List[str] = [
|
||
|
p.project_name for p in self.get_adapter_plugins(name)
|
||
|
]
|
||
|
package_names.append(GLOBAL_PROJECT_NAME)
|
||
|
return package_names
|
||
|
|
||
|
def get_include_paths(self, name: Optional[str]) -> List[Path]:
|
||
|
paths = []
|
||
|
for package_name in self.get_adapter_package_names(name):
|
||
|
try:
|
||
|
path = self.packages[package_name]
|
||
|
except KeyError:
|
||
|
raise InternalException(
|
||
|
f'No internal package listing found for {package_name}'
|
||
|
)
|
||
|
paths.append(path)
|
||
|
return paths
|
||
|
|
||
|
def get_adapter_type_names(self, name: Optional[str]) -> List[str]:
|
||
|
return [p.adapter.type() for p in self.get_adapter_plugins(name)]
|
||
|
|
||
|
|
||
|
FACTORY: AdapterContainer = AdapterContainer()
|
||
|
|
||
|
|
||
|
def register_adapter(config: AdapterRequiredConfig) -> None:
|
||
|
FACTORY.register_adapter(config)
|
||
|
|
||
|
|
||
|
def get_adapter(config: AdapterRequiredConfig):
|
||
|
return FACTORY.lookup_adapter(config.credentials.type)
|
||
|
|
||
|
|
||
|
def reset_adapters():
|
||
|
"""Clear the adapters. This is useful for tests, which change configs.
|
||
|
"""
|
||
|
FACTORY.reset_adapters()
|
||
|
|
||
|
|
||
|
def cleanup_connections():
|
||
|
"""Only clean up the adapter connections list without resetting the actual
|
||
|
adapters.
|
||
|
"""
|
||
|
FACTORY.cleanup_connections()
|
||
|
|
||
|
|
||
|
def get_adapter_class_by_name(name: str) -> Type[AdapterProtocol]:
|
||
|
return FACTORY.get_adapter_class_by_name(name)
|
||
|
|
||
|
|
||
|
def get_config_class_by_name(name: str) -> Type[AdapterConfig]:
|
||
|
return FACTORY.get_config_class_by_name(name)
|
||
|
|
||
|
|
||
|
def get_relation_class_by_name(name: str) -> Type[RelationProtocol]:
|
||
|
return FACTORY.get_relation_class_by_name(name)
|
||
|
|
||
|
|
||
|
def load_plugin(name: str) -> Type[Credentials]:
|
||
|
return FACTORY.load_plugin(name)
|
||
|
|
||
|
|
||
|
def get_include_paths(name: Optional[str]) -> List[Path]:
|
||
|
return FACTORY.get_include_paths(name)
|
||
|
|
||
|
|
||
|
def get_adapter_package_names(name: Optional[str]) -> List[str]:
|
||
|
return FACTORY.get_adapter_package_names(name)
|
||
|
|
||
|
|
||
|
def get_adapter_type_names(name: Optional[str]) -> List[str]:
|
||
|
return FACTORY.get_adapter_type_names(name)
|