214 lines
7.1 KiB
Python
214 lines
7.1 KiB
Python
import base64
|
|
import signal
|
|
import threading
|
|
from datetime import datetime
|
|
from typing import Dict, Any
|
|
|
|
from dbt import flags
|
|
from dbt.adapters.factory import get_adapter
|
|
from dbt.clients.jinja import extract_toplevel_blocks
|
|
from dbt.config.runtime import RuntimeConfig
|
|
from dbt.contracts.graph.manifest import Manifest
|
|
from dbt.contracts.graph.parsed import ParsedRPCNode
|
|
from dbt.contracts.rpc import RPCExecParameters
|
|
from dbt.contracts.rpc import RemoteExecutionResult
|
|
from dbt.exceptions import RPCKilledException, InternalException
|
|
from dbt.logger import GLOBAL_LOGGER as logger
|
|
from dbt.parser.manifest import process_node, process_macro
|
|
from dbt.parser.rpc import RPCCallParser, RPCMacroParser
|
|
from dbt.rpc.error import invalid_params
|
|
from dbt.rpc.node_runners import RPCCompileRunner, RPCExecuteRunner
|
|
from dbt.task.compile import CompileTask
|
|
from dbt.task.run import RunTask
|
|
|
|
from .base import RPCTask
|
|
|
|
|
|
def add_new_refs(
|
|
manifest: Manifest,
|
|
config: RuntimeConfig,
|
|
node: ParsedRPCNode,
|
|
macros: Dict[str, Any]
|
|
) -> None:
|
|
"""Given a new node that is not in the manifest, insert the new node
|
|
into it as if it were part of regular ref processing.
|
|
"""
|
|
if config.args.single_threaded or flags.SINGLE_THREADED_HANDLER:
|
|
manifest = manifest.deepcopy()
|
|
# it's ok for macros to silently override a local project macro name
|
|
manifest.macros.update(macros)
|
|
|
|
for macro in macros.values():
|
|
process_macro(config, manifest, macro)
|
|
|
|
# We used to do 'manifest.add_nodes({node.unique_id: node}) here, but the
|
|
# node has already been added to the Manifest by the RPCCallParser
|
|
# now that we save nodes to the Manifest instead of ParseResults.
|
|
process_node(config, manifest, node)
|
|
|
|
|
|
class RemoteRunSQLTask(RPCTask[RPCExecParameters]):
|
|
def runtime_cleanup(self, selected_uids):
|
|
"""Do some pre-run cleanup that is usually performed in Task __init__.
|
|
"""
|
|
self.run_count = 0
|
|
self.num_nodes = len(selected_uids)
|
|
self.node_results = []
|
|
self._skipped_children = {}
|
|
self._skipped_children = {}
|
|
self._raise_next_tick = None
|
|
|
|
def decode_sql(self, sql: str) -> str:
|
|
"""Base64 decode a string. This should only be used for sql in calls.
|
|
|
|
:param str sql: The base64 encoded form of the original utf-8 string
|
|
:return str: The decoded utf-8 string
|
|
"""
|
|
# JSON is defined as using "unicode", we'll go a step further and
|
|
# mandate utf-8 (though for the base64 part, it doesn't really matter!)
|
|
base64_sql_bytes = str(sql).encode('utf-8')
|
|
|
|
try:
|
|
sql_bytes = base64.b64decode(base64_sql_bytes, validate=True)
|
|
except ValueError:
|
|
self.raise_invalid_base64(sql)
|
|
|
|
return sql_bytes.decode('utf-8')
|
|
|
|
@staticmethod
|
|
def raise_invalid_base64(sql):
|
|
raise invalid_params(
|
|
data={
|
|
'message': 'invalid base64-encoded sql input',
|
|
'sql': str(sql),
|
|
}
|
|
)
|
|
|
|
def _extract_request_data(self, data):
|
|
data = self.decode_sql(data)
|
|
macro_blocks = []
|
|
data_chunks = []
|
|
for block in extract_toplevel_blocks(data):
|
|
if block.block_type_name == 'macro':
|
|
macro_blocks.append(block.full_block)
|
|
else:
|
|
data_chunks.append(block.full_block)
|
|
macros = '\n'.join(macro_blocks)
|
|
sql = ''.join(data_chunks)
|
|
return sql, macros
|
|
|
|
def _get_exec_node(self):
|
|
if self.manifest is None:
|
|
raise InternalException(
|
|
'manifest not set in _get_exec_node'
|
|
)
|
|
|
|
macro_overrides = {}
|
|
macros = self.args.macros
|
|
sql, macros = self._extract_request_data(self.args.sql)
|
|
|
|
if macros:
|
|
macro_parser = RPCMacroParser(self.config, self.manifest)
|
|
for node in macro_parser.parse_remote(macros):
|
|
macro_overrides[node.unique_id] = node
|
|
|
|
self.manifest.macros.update(macro_overrides)
|
|
rpc_parser = RPCCallParser(
|
|
project=self.config,
|
|
manifest=self.manifest,
|
|
root_project=self.config,
|
|
)
|
|
rpc_node = rpc_parser.parse_remote(sql, self.args.name)
|
|
add_new_refs(
|
|
manifest=self.manifest,
|
|
config=self.config,
|
|
node=rpc_node,
|
|
macros=macro_overrides
|
|
)
|
|
|
|
# don't write our new, weird manifest!
|
|
adapter = get_adapter(self.config)
|
|
compiler = adapter.get_compiler()
|
|
self.graph = compiler.compile(self.manifest, write=False)
|
|
# previously, this compiled the ancestors, but they are compiled at
|
|
# runtime now.
|
|
return rpc_node
|
|
|
|
def _raise_set_error(self):
|
|
if self._raise_next_tick is not None:
|
|
raise self._raise_next_tick
|
|
|
|
def _in_thread(self, node, thread_done):
|
|
runner = self.get_runner(node)
|
|
try:
|
|
self.node_results.append(runner.safe_run(self.manifest))
|
|
except Exception as exc:
|
|
logger.debug('Got exception {}'.format(exc), exc_info=True)
|
|
self._raise_next_tick = exc
|
|
finally:
|
|
thread_done.set()
|
|
|
|
def set_args(self, params: RPCExecParameters):
|
|
self.args.name = params.name
|
|
self.args.sql = params.sql
|
|
self.args.macros = params.macros
|
|
|
|
def handle_request(self) -> RemoteExecutionResult:
|
|
# we could get a ctrl+c at any time, including during parsing.
|
|
thread = None
|
|
started = datetime.utcnow()
|
|
try:
|
|
node = self._get_exec_node()
|
|
|
|
selected_uids = [node.unique_id]
|
|
self.runtime_cleanup(selected_uids)
|
|
|
|
thread_done = threading.Event()
|
|
thread = threading.Thread(target=self._in_thread,
|
|
args=(node, thread_done))
|
|
thread.start()
|
|
thread_done.wait()
|
|
except KeyboardInterrupt:
|
|
adapter = get_adapter(self.config) # type: ignore
|
|
if adapter.is_cancelable():
|
|
|
|
for conn_name in adapter.cancel_open_connections():
|
|
logger.debug('canceled query {}'.format(conn_name))
|
|
if thread:
|
|
thread.join()
|
|
else:
|
|
msg = ("The {} adapter does not support query "
|
|
"cancellation. Some queries may still be "
|
|
"running!".format(adapter.type()))
|
|
|
|
logger.debug(msg)
|
|
|
|
raise RPCKilledException(signal.SIGINT)
|
|
|
|
self._raise_set_error()
|
|
|
|
ended = datetime.utcnow()
|
|
elapsed = (ended - started).total_seconds()
|
|
return self.get_result(
|
|
results=self.node_results,
|
|
elapsed_time=elapsed,
|
|
generated_at=ended,
|
|
)
|
|
|
|
def interpret_results(self, results):
|
|
return True
|
|
|
|
|
|
class RemoteCompileTask(RemoteRunSQLTask, CompileTask):
|
|
METHOD_NAME = 'compile_sql'
|
|
|
|
def get_runner_type(self, _):
|
|
return RPCCompileRunner
|
|
|
|
|
|
class RemoteRunTask(RemoteRunSQLTask, RunTask):
|
|
METHOD_NAME = 'run_sql'
|
|
|
|
def get_runner_type(self, _):
|
|
return RPCExecuteRunner
|