from contextlib import contextmanager

import psycopg2

import dbt.exceptions
from dbt.adapters.base import Credentials
from dbt.adapters.sql import SQLConnectionManager
from dbt.contracts.connection import AdapterResponse
from dbt.logger import GLOBAL_LOGGER as logger

from dbt.helper_types import Port
from dataclasses import dataclass
from typing import Optional


@dataclass
class PostgresCredentials(Credentials):
    """Connection settings consumed by ``PostgresConnectionManager.open``.

    ``database`` and ``schema`` are referenced below but not declared here;
    presumably they come from the ``Credentials`` base class (confirm there).
    """

    host: str
    user: str
    port: Port
    password: str  # on postgres the password is mandatory
    connect_timeout: int = 10
    role: Optional[str] = None
    search_path: Optional[str] = None
    keepalives_idle: int = 0  # 0 means to use the default value
    sslmode: Optional[str] = None
    sslcert: Optional[str] = None
    sslkey: Optional[str] = None
    sslrootcert: Optional[str] = None
    application_name: Optional[str] = 'dbt'

    # Alternate profile key -> canonical field name. Intentionally left
    # unannotated so the dataclass machinery does not treat it as a field.
    _ALIASES = {
        'dbname': 'database',
        'pass': 'password'
    }

    @property
    def type(self):
        """Return the adapter type name, ``'postgres'``."""
        return 'postgres'

    @property
    def unique_field(self):
        """Return the host, used as this credential set's unique field."""
        return self.host

    def _connection_keys(self):
        """Return the names of the credential attributes listed for display.

        The password and the ssl certificate/key paths are not included.
        """
        return ('host', 'port', 'user', 'database', 'schema', 'search_path',
                'keepalives_idle', 'sslmode')


class PostgresConnectionManager(SQLConnectionManager):
    """Connection manager that opens and cancels psycopg2 connections."""

    TYPE = 'postgres'

    @contextmanager
    def exception_handler(self, sql):
        """Translate errors raised inside the ``with`` body into dbt errors.

        :param sql: the SQL being run; only used for debug logging.
        :raises dbt.exceptions.DatabaseException: on ``psycopg2.DatabaseError``.
        :raises dbt.exceptions.RuntimeException: on any other exception (an
            existing ``RuntimeException`` is re-raised unmodified).
        """
        try:
            yield

        except psycopg2.DatabaseError as e:
            logger.debug('Postgres error: {}'.format(str(e)))

            try:
                self.rollback_if_open()
            except psycopg2.Error:
                # Best effort: a failed rollback must not mask the original
                # database error raised below.
                logger.debug("Failed to release connection!")

            raise dbt.exceptions.DatabaseException(str(e).strip()) from e

        except Exception as e:
            logger.debug("Error running SQL: {}", sql)
            logger.debug("Rolling back transaction.")
            self.rollback_if_open()
            if isinstance(e, dbt.exceptions.RuntimeException):
                # during a sql query, an internal to dbt exception was raised.
                # this sounds a lot like a signal handler and probably has
                # useful information, so raise it without modification.
                raise

            raise dbt.exceptions.RuntimeException(e) from e

    @classmethod
    def open(cls, connection):
        """Open a psycopg2 connection for ``connection`` if it is not open.

        On success ``connection.handle`` is set to the psycopg2 connection and
        ``connection.state`` to ``'open'``. On a psycopg2 error the handle is
        set to ``None`` and the state to ``'fail'``.

        :param connection: the connection object carrying ``credentials``.
        :returns: the same ``connection`` object.
        :raises dbt.exceptions.FailedToConnectException: if psycopg2 raises
            while connecting or while setting the role.
        """
        if connection.state == 'open':
            logger.debug('Connection is already open, skipping open.')
            return connection

        credentials = cls.get_credentials(connection.credentials)
        kwargs = {}
        # we don't want to pass 0 along to connect() as postgres will try to
        # call an invalid setsockopt() call (contrary to the docs).
        if credentials.keepalives_idle:
            kwargs['keepalives_idle'] = credentials.keepalives_idle

        # psycopg2 doesn't support search_path officially,
        # see https://github.com/psycopg/psycopg2/issues/465
        search_path = credentials.search_path
        if search_path is not None and search_path != '':
            # see https://postgresql.org/docs/9.5/libpq-connect.html
            kwargs['options'] = '-c search_path={}'.format(
                search_path.replace(' ', '\\ '))

        if credentials.sslmode:
            kwargs['sslmode'] = credentials.sslmode

        if credentials.sslcert is not None:
            kwargs["sslcert"] = credentials.sslcert

        if credentials.sslkey is not None:
            kwargs["sslkey"] = credentials.sslkey

        if credentials.sslrootcert is not None:
            kwargs["sslrootcert"] = credentials.sslrootcert

        if credentials.application_name:
            kwargs['application_name'] = credentials.application_name

        try:
            handle = psycopg2.connect(
                dbname=credentials.database,
                user=credentials.user,
                host=credentials.host,
                password=credentials.password,
                port=credentials.port,
                connect_timeout=credentials.connect_timeout,
                **kwargs)

            if credentials.role:
                handle.cursor().execute('set role {}'.format(credentials.role))

            connection.handle = handle
            connection.state = 'open'
        except psycopg2.Error as e:
            logger.debug("Got an error when attempting to open a postgres "
                         "connection: '{}'"
                         .format(e))

            connection.handle = None
            connection.state = 'fail'

            # Chain explicitly, matching the raises in exception_handler, so
            # the psycopg2 error is recorded as the direct cause.
            raise dbt.exceptions.FailedToConnectException(str(e)) from e

        return connection

    def cancel(self, connection):
        """Terminate the postgres backend serving ``connection``.

        Looks up the backend pid on ``connection.handle`` and runs
        ``pg_terminate_backend`` for it through ``self.add_query``. If the
        handle reports that it is already closed, this logs and returns
        without doing anything.

        :raises psycopg2.InterfaceError: if fetching the pid fails for any
            reason other than the connection being already closed.
        """
        connection_name = connection.name
        try:
            pid = connection.handle.get_backend_pid()
        except psycopg2.InterfaceError as exc:
            # if the connection is already closed, not much to cancel!
            if 'already closed' in str(exc):
                logger.debug(
                    f'Connection {connection_name} was already closed'
                )
                return
            # probably bad, re-raise it
            raise

        sql = "select pg_terminate_backend({})".format(pid)

        logger.debug("Cancelling query '{}' ({})".format(connection_name, pid))

        _, cursor = self.add_query(sql)
        res = cursor.fetchone()

        logger.debug("Cancel query '{}': {}".format(connection_name, res))

    @classmethod
    def get_credentials(cls, credentials):
        """Return ``credentials`` unchanged."""
        return credentials

    @classmethod
    def get_response(cls, cursor) -> AdapterResponse:
        """Build an ``AdapterResponse`` from a cursor's status and row count.

        ``code`` is the status message with its purely numeric words removed
        (for example ``'INSERT 0 1'`` gives ``'INSERT'``). When the cursor has
        no status message, ``code`` is the empty string.

        :param cursor: object exposing ``statusmessage`` and ``rowcount``.
        :returns: response whose ``_message`` is ``str(cursor.statusmessage)``.
        """
        status = cursor.statusmessage
        message = str(status)
        rows = cursor.rowcount
        # Test the raw attribute rather than ``message``: str(None) is the
        # string 'None', which would otherwise be reported as the code.
        status_message_parts = message.split() if status is not None else []
        status_message_strings = [
            part for part in status_message_parts if not part.isdigit()
        ]
        code = ' '.join(status_message_strings)
        return AdapterResponse(
            _message=message,
            code=code,
            rows_affected=rows
        )