307 lines
11 KiB
Python
307 lines
11 KiB
Python
import abc
|
|
import os
|
|
# multiprocessing.RLock is a function returning this type
|
|
from multiprocessing.synchronize import RLock
|
|
from threading import get_ident
|
|
from typing import (
|
|
Dict, Tuple, Hashable, Optional, ContextManager, List, Union
|
|
)
|
|
|
|
import agate
|
|
|
|
import dbt.exceptions
|
|
from dbt.contracts.connection import (
|
|
Connection, Identifier, ConnectionState,
|
|
AdapterRequiredConfig, LazyHandle, AdapterResponse
|
|
)
|
|
from dbt.contracts.graph.manifest import Manifest
|
|
from dbt.adapters.base.query_headers import (
|
|
MacroQueryStringSetter,
|
|
)
|
|
from dbt.logger import GLOBAL_LOGGER as logger
|
|
from dbt import flags
|
|
|
|
|
|
class BaseConnectionManager(metaclass=abc.ABCMeta):
    """Abstract manager that tracks one connection per (process, thread).

    Methods to implement:
        - exception_handler
        - cancel_open
        - open
        - begin
        - commit
        - clear_transaction (a default implementation is provided here)
        - execute

    You must also set the 'TYPE' class attribute with a class-unique constant
    string.
    """
    TYPE: str = NotImplemented

    def __init__(self, profile: AdapterRequiredConfig):
        """Store the profile and start with no tracked connections.

        :param profile: The adapter configuration. Its ``credentials`` are
            attached to every connection created by `set_connection_name`.
        """
        self.profile = profile
        # connections keyed by the value of get_thread_identifier()
        self.thread_connections: Dict[Hashable, Connection] = {}
        # guards reads and writes of thread_connections
        self.lock: RLock = flags.MP_CONTEXT.RLock()
        # stays None until set_query_header() is called
        self.query_header: Optional[MacroQueryStringSetter] = None

    def set_query_header(self, manifest: Manifest) -> None:
        """Build the query header setter from the profile and `manifest`."""
        self.query_header = MacroQueryStringSetter(self.profile, manifest)

    @staticmethod
    def get_thread_identifier() -> Hashable:
        """Return the ``(pid, thread ident)`` key for the calling thread."""
        # note that get_ident() may be re-used, but we should never experience
        # that within a single process
        return (os.getpid(), get_ident())

    def get_thread_connection(self) -> Connection:
        """Return the connection registered for the calling thread.

        :raises dbt.exceptions.InvalidConnectionException: If no connection
            is registered for the calling thread.
        """
        key = self.get_thread_identifier()
        with self.lock:
            if key not in self.thread_connections:
                raise dbt.exceptions.InvalidConnectionException(
                    key, list(self.thread_connections)
                )
            return self.thread_connections[key]

    def set_thread_connection(self, conn: Connection) -> None:
        """Register `conn` as the calling thread's connection.

        :raises dbt.exceptions.InternalException: If the calling thread
            already has a registered connection.
        """
        key = self.get_thread_identifier()
        # hold the lock like every other accessor of thread_connections, so
        # the membership check and the insert happen as one step
        with self.lock:
            if key in self.thread_connections:
                raise dbt.exceptions.InternalException(
                    'In set_thread_connection, existing connection exists '
                    'for {}'.format(key)
                )
            self.thread_connections[key] = conn

    def get_if_exists(self) -> Optional[Connection]:
        """Return the calling thread's connection, or None if there is none."""
        key = self.get_thread_identifier()
        with self.lock:
            return self.thread_connections.get(key)

    def clear_thread_connection(self) -> None:
        """Forget the calling thread's connection, if one is registered."""
        key = self.get_thread_identifier()
        with self.lock:
            if key in self.thread_connections:
                del self.thread_connections[key]

    def clear_transaction(self) -> None:
        """Clear any existing transactions.

        Rolls back the calling thread's open transaction (if any), then
        begins and commits a fresh one.
        """
        conn = self.get_thread_connection()
        if conn is not None and conn.transaction_open:
            self._rollback(conn)
        self.begin()
        self.commit()

    def rollback_if_open(self) -> None:
        """Roll back the calling thread's connection if a transaction is open.

        Does nothing when the thread has no connection, the connection has no
        handle, or no transaction is open.
        """
        conn = self.get_if_exists()
        if conn is not None and conn.handle and conn.transaction_open:
            self._rollback(conn)

    @abc.abstractmethod
    def exception_handler(self, sql: str) -> ContextManager:
        """Create a context manager that handles exceptions caused by database
        interactions.

        :param str sql: The SQL string that the block inside the context
            manager is executing.
        :return: A context manager that handles exceptions raised by the
            underlying database.
        """
        raise dbt.exceptions.NotImplementedException(
            '`exception_handler` is not implemented for this adapter!')

    def set_connection_name(self, name: Optional[str] = None) -> Connection:
        """Name the calling thread's connection, creating it when missing.

        :param name: The connection name; 'master' is used when None.
        :return: The calling thread's connection, now carrying that name.
        :raises dbt.exceptions.CompilerException: If `name` is neither None
            nor a string.
        """
        conn_name: str
        if name is None:
            # if a name isn't specified, we'll re-use a single handle
            # named 'master'
            conn_name = 'master'
        else:
            if not isinstance(name, str):
                raise dbt.exceptions.CompilerException(
                    f'For connection name, got {name} - not a string!'
                )
            conn_name = name

        conn = self.get_if_exists()
        if conn is None:
            # first use on this thread: register a fresh, unopened connection
            conn = Connection(
                type=Identifier(self.TYPE),
                name=None,
                state=ConnectionState.INIT,
                transaction_open=False,
                handle=None,
                credentials=self.profile.credentials
            )
            self.set_thread_connection(conn)

        # already open under the requested name: nothing to change
        if conn.name == conn_name and conn.state == 'open':
            return conn

        logger.debug(
            'Acquiring new {} connection "{}".'.format(self.TYPE, conn_name))

        if conn.state == 'open':
            logger.debug(
                'Re-using an available connection from the pool (formerly {}).'
                .format(conn.name)
            )
        else:
            # not open yet: hand the connection a lazy handle built on open()
            conn.handle = LazyHandle(self.open)

        conn.name = conn_name
        return conn

    @abc.abstractmethod
    def cancel_open(self) -> Optional[List[str]]:
        """Cancel all open connections on the adapter. (passable)"""
        raise dbt.exceptions.NotImplementedException(
            '`cancel_open` is not implemented for this adapter!'
        )

    # abc.abstractclassmethod is deprecated; stack classmethod over
    # abstractmethod (abstractmethod innermost) instead.
    @classmethod
    @abc.abstractmethod
    def open(cls, connection: Connection) -> Connection:
        """Open the given connection on the adapter and return it.

        This may mutate the given connection (in particular, its state and its
        handle).

        This should be thread-safe, or hold the lock if necessary. The given
        connection should not be in either in_use or available.
        """
        raise dbt.exceptions.NotImplementedException(
            '`open` is not implemented for this adapter!'
        )

    def release(self) -> None:
        """Close the calling thread's connection, if it has one.

        If closing raises, the connection is removed from the registry and
        the exception is re-raised.
        """
        with self.lock:
            conn = self.get_if_exists()
            if conn is None:
                return

        try:
            # always close the connection. close() calls _rollback() if there
            # is an open transaction
            self.close(conn)
        except Exception:
            # if rollback or close failed, remove our busted connection
            self.clear_thread_connection()
            raise

    def cleanup_all(self) -> None:
        """Close every tracked connection and empty the registry.

        Logs, per connection, whether it had been left open or was already
        closed (or never opened) before closing it.
        """
        with self.lock:
            for connection in self.thread_connections.values():
                if connection.state not in {'closed', 'init'}:
                    logger.debug("Connection '{}' was left open."
                                 .format(connection.name))
                else:
                    logger.debug("Connection '{}' was properly closed."
                                 .format(connection.name))
                self.close(connection)

            # garbage collect these connections
            self.thread_connections.clear()

    @abc.abstractmethod
    def begin(self) -> None:
        """Begin a transaction. (passable)"""
        raise dbt.exceptions.NotImplementedException(
            '`begin` is not implemented for this adapter!'
        )

    @abc.abstractmethod
    def commit(self) -> None:
        """Commit a transaction. (passable)"""
        raise dbt.exceptions.NotImplementedException(
            '`commit` is not implemented for this adapter!'
        )

    @classmethod
    def _rollback_handle(cls, connection: Connection) -> None:
        """Perform the actual rollback operation.

        Best-effort: a failure is logged at debug level and not re-raised.
        """
        try:
            connection.handle.rollback()
        except Exception:
            logger.debug(
                'Failed to rollback {}'.format(connection.name),
                exc_info=True
            )

    @classmethod
    def _close_handle(cls, connection: Connection) -> None:
        """Perform the actual close operation."""
        # On windows, sometimes connection handles don't have a close() attr.
        if hasattr(connection.handle, 'close'):
            logger.debug(f'On {connection.name}: Close')
            connection.handle.close()
        else:
            logger.debug(f'On {connection.name}: No close available on handle')

    @classmethod
    def _rollback(cls, connection: Connection) -> None:
        """Roll back the given connection.

        :raises dbt.exceptions.CompilerException: In strict mode, if
            `connection` is not a Connection.
        :raises dbt.exceptions.InternalException: If the connection has no
            open transaction.
        """
        if flags.STRICT_MODE:
            if not isinstance(connection, Connection):
                raise dbt.exceptions.CompilerException(
                    f'In _rollback, got {connection} - not a Connection!'
                )

        if connection.transaction_open is False:
            raise dbt.exceptions.InternalException(
                f'Tried to rollback transaction on connection '
                f'"{connection.name}", but it does not have one open!'
            )

        logger.debug(f'On {connection.name}: ROLLBACK')
        cls._rollback_handle(connection)

        connection.transaction_open = False

    @classmethod
    def close(cls, connection: Connection) -> Connection:
        """Close the given connection and return it.

        Rolls back an open transaction first. A connection that is already
        closed, or was never opened, is returned unchanged.

        :raises dbt.exceptions.CompilerException: In strict mode, if
            `connection` is not a Connection.
        """
        if flags.STRICT_MODE:
            if not isinstance(connection, Connection):
                raise dbt.exceptions.CompilerException(
                    f'In close, got {connection} - not a Connection!'
                )

        # if the connection is in closed or init, there's nothing to do
        if connection.state in {ConnectionState.CLOSED, ConnectionState.INIT}:
            return connection

        if connection.transaction_open and connection.handle:
            logger.debug('On {}: ROLLBACK'.format(connection.name))
            cls._rollback_handle(connection)
        connection.transaction_open = False

        cls._close_handle(connection)
        connection.state = ConnectionState.CLOSED

        return connection

    def commit_if_has_connection(self) -> None:
        """If the calling thread has a connection, commit the current
        transaction.
        """
        connection = self.get_if_exists()
        if connection:
            self.commit()

    def _add_query_comment(self, sql: str) -> str:
        """Return `sql` passed through the query header, if one is set."""
        if self.query_header is None:
            return sql
        return self.query_header.add(sql)

    @abc.abstractmethod
    def execute(
        self, sql: str, auto_begin: bool = False, fetch: bool = False
    ) -> Tuple[Union[str, AdapterResponse], agate.Table]:
        """Execute the given SQL.

        :param str sql: The sql to execute.
        :param bool auto_begin: If set, and dbt is not currently inside a
            transaction, automatically begin one.
        :param bool fetch: If set, fetch results.
        :return: A tuple of the status and the results (empty if fetch=False).
        :rtype: Tuple[Union[str, AdapterResponse], agate.Table]
        """
        raise dbt.exceptions.NotImplementedException(
            '`execute` is not implemented for this adapter!'
        )
|