121 lines
3.6 KiB
Python
121 lines
3.6 KiB
Python
|
from abc import abstractmethod
|
||
|
from datetime import datetime
|
||
|
from typing import Generic, TypeVar
|
||
|
|
||
|
import dbt.exceptions
|
||
|
from dbt.contracts.rpc import (
|
||
|
RemoteCompileResult,
|
||
|
RemoteCompileResultMixin,
|
||
|
RemoteRunResult,
|
||
|
ResultTable,
|
||
|
)
|
||
|
from dbt.logger import GLOBAL_LOGGER as logger
|
||
|
from dbt.task.compile import CompileRunner
|
||
|
from dbt.rpc.error import dbt_error, RPCException, server_error
|
||
|
|
||
|
|
||
|
# Type variable for the result object an RPC runner returns. It is bounded
# by RemoteCompileResultMixin and parameterizes GenericRPCRunner.
RPCSQLResult = TypeVar('RPCSQLResult', bound=RemoteCompileResultMixin)
|
||
|
|
||
|
|
||
|
class GenericRPCRunner(CompileRunner, Generic[RPCSQLResult]):
    """Common base for runners that handle a single node for an RPC call.

    Concrete subclasses supply ``execute`` and ``from_run_result``; both
    are annotated to return the ``RPCSQLResult`` type the subclass is
    parameterized with.
    """

    def __init__(self, config, adapter, node, node_index, num_nodes):
        """Forward every argument unchanged to ``CompileRunner.__init__``."""
        CompileRunner.__init__(
            self,
            config,
            adapter,
            node,
            node_index,
            num_nodes,
        )

    def handle_exception(self, e, ctx):
        """Log ``e`` and return the error value that corresponds to it.

        dbt exceptions are passed through ``dbt_error`` (runtime exceptions
        first get ``ctx.node`` attached), an ``RPCException`` is returned
        as-is, and anything else is passed through ``server_error``.
        """
        message = 'Got an exception: {}'.format(e)
        logger.debug(message, exc_info=True)

        # The dbt exception check must stay first so it takes precedence
        # over the RPCException check.
        if isinstance(e, dbt.exceptions.Exception):
            if isinstance(e, dbt.exceptions.RuntimeException):
                e.add_node(ctx.node)
            return dbt_error(e)

        if isinstance(e, RPCException):
            return e

        return server_error(e)

    def before_execute(self):
        """Do nothing before execution."""

    def after_execute(self, result):
        """Do nothing after execution."""

    def compile(self, manifest):
        """Compile this runner's node against ``manifest``.

        Uses the adapter's compiler with an empty extra-context dict and
        ``write=False``.
        """
        return self.adapter.get_compiler().compile_node(
            self.node, manifest, {}, write=False
        )

    @abstractmethod
    def execute(self, compiled_node, manifest) -> RPCSQLResult:
        """Produce the result for ``compiled_node``; defined by subclasses."""

    @abstractmethod
    def from_run_result(self, result, start_time, timing_info) -> RPCSQLResult:
        """Build the final result from ``result``; defined by subclasses."""

    def error_result(self, node, error, start_time, timing_info):
        """Raise ``error`` instead of building an error result."""
        raise error

    def ephemeral_result(self, node, start_time, timing_info):
        """Always raise ``NotImplementedException`` for ephemeral nodes."""
        raise dbt.exceptions.NotImplementedException(
            'cannot execute ephemeral nodes remotely!'
        )
|
||
|
|
||
|
|
||
|
class RPCCompileRunner(GenericRPCRunner[RemoteCompileResult]):
    """Runner whose result is a ``RemoteCompileResult`` for the node."""

    def execute(self, compiled_node, manifest) -> RemoteCompileResult:
        """Wrap ``compiled_node`` in a ``RemoteCompileResult``.

        ``timing`` and ``logs`` start out empty; ``generated_at`` is the
        current ``datetime.utcnow()`` value.
        """
        raw = compiled_node.raw_sql
        compiled = compiled_node.compiled_sql
        return RemoteCompileResult(
            raw_sql=raw,
            compiled_sql=compiled,
            node=compiled_node,
            timing=[],  # this will get added later
            logs=[],
            generated_at=datetime.utcnow(),
        )

    def from_run_result(
        self, result, start_time, timing_info
    ) -> RemoteCompileResult:
        """Copy ``result`` into a new ``RemoteCompileResult``.

        The SQL and node come from ``result``, ``timing`` is set to
        ``timing_info``, ``logs`` is empty, and ``generated_at`` is a
        fresh ``datetime.utcnow()`` value. ``start_time`` is not used.
        """
        raw = result.raw_sql
        compiled = result.compiled_sql
        source_node = result.node
        return RemoteCompileResult(
            raw_sql=raw,
            compiled_sql=compiled,
            node=source_node,
            timing=timing_info,
            logs=[],
            generated_at=datetime.utcnow(),
        )
|
||
|
|
||
|
|
||
|
class RPCExecuteRunner(GenericRPCRunner[RemoteRunResult]):
    """Runner that executes the compiled SQL and returns its rows."""

    def execute(self, compiled_node, manifest) -> RemoteRunResult:
        """Run the node's compiled SQL and package the fetched table.

        The adapter is called with ``fetch=True``; the second element of
        its return value is converted into a ``ResultTable`` of plain
        lists. ``timing`` and ``logs`` start out empty.
        """
        response = self.adapter.execute(
            compiled_node.compiled_sql, fetch=True
        )
        # Only the second element of the adapter's return value is used.
        _, execute_result = response

        names = [*execute_result.column_names]
        row_lists = list(map(list, execute_result))
        table = ResultTable(column_names=names, rows=row_lists)

        return RemoteRunResult(
            raw_sql=compiled_node.raw_sql,
            compiled_sql=compiled_node.compiled_sql,
            node=compiled_node,
            table=table,
            timing=[],
            logs=[],
            generated_at=datetime.utcnow(),
        )

    def from_run_result(
        self, result, start_time, timing_info
    ) -> RemoteRunResult:
        """Copy ``result`` into a new ``RemoteRunResult``.

        The SQL, node and table come from ``result``, ``timing`` is set
        to ``timing_info``, ``logs`` is empty, and ``generated_at`` is a
        fresh ``datetime.utcnow()`` value. ``start_time`` is not used.
        """
        raw = result.raw_sql
        compiled = result.compiled_sql
        source_node = result.node
        fetched = result.table
        return RemoteRunResult(
            raw_sql=raw,
            compiled_sql=compiled,
            node=source_node,
            table=fetched,
            timing=timing_info,
            logs=[],
            generated_at=datetime.utcnow(),
        )
|