dbt-selly/dbt-env/lib/python3.8/site-packages/dbt/task/rpc/sql_commands.py

214 lines
7.1 KiB
Python
Raw Normal View History

2022-03-22 15:13:27 +00:00
import base64
import signal
import threading
from datetime import datetime
from typing import Dict, Any
from dbt import flags
from dbt.adapters.factory import get_adapter
from dbt.clients.jinja import extract_toplevel_blocks
from dbt.config.runtime import RuntimeConfig
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.parsed import ParsedRPCNode
from dbt.contracts.rpc import RPCExecParameters
from dbt.contracts.rpc import RemoteExecutionResult
from dbt.exceptions import RPCKilledException, InternalException
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.parser.manifest import process_node, process_macro
from dbt.parser.rpc import RPCCallParser, RPCMacroParser
from dbt.rpc.error import invalid_params
from dbt.rpc.node_runners import RPCCompileRunner, RPCExecuteRunner
from dbt.task.compile import CompileTask
from dbt.task.run import RunTask
from .base import RPCTask
def add_new_refs(
    manifest: Manifest,
    config: RuntimeConfig,
    node: ParsedRPCNode,
    macros: Dict[str, Any]
) -> None:
    """Process a freshly parsed RPC node (plus request macros) against a manifest.

    The node is run through ref processing as if it had been discovered during
    normal parsing. When single-threaded mode is active (either
    ``config.args.single_threaded`` or ``flags.SINGLE_THREADED_HANDLER``), the
    macro/node processing is applied to a deep copy of ``manifest`` rather
    than to the object passed in.

    :param manifest: The manifest the RPC node was parsed into.
    :param config: The active runtime config.
    :param node: The RPC node produced by the RPC call parser.
    :param macros: Parsed request macros; presumably keyed by macro
        unique_id (that is how the caller in this module builds it).
    """
    use_copy = config.args.single_threaded or flags.SINGLE_THREADED_HANDLER
    target = manifest.deepcopy() if use_copy else manifest

    # Request macros are allowed to shadow a project macro of the same name.
    target.macros.update(macros)
    for macro_node in macros.values():
        process_macro(config, target, macro_node)

    # The RPCCallParser already inserted the node into the manifest (nodes are
    # saved to the Manifest rather than ParseResults), so only ref processing
    # is left to do here.
    process_node(config, target, node)
class RemoteRunSQLTask(RPCTask[RPCExecParameters]):
    """RPC task that parses, compiles and runs one ad-hoc SQL node.

    The request SQL arrives base64-encoded. Any top-level ``macro`` blocks in
    it are split out and registered as macro overrides, and the rest of the
    SQL is parsed into a single RPC node. That node is run on a worker thread
    while the request thread waits, so a KeyboardInterrupt can be handled by
    cancelling open connections.
    """

    def runtime_cleanup(self, selected_uids):
        """Reset per-run state that a Task normally sets up in ``__init__``.

        :param selected_uids: Unique ids of the nodes selected for this run;
            only its length is used.
        """
        self.run_count = 0
        self.num_nodes = len(selected_uids)
        self.node_results = []
        self._skipped_children = {}
        # Exception captured on the worker thread, re-raised by
        # _raise_set_error on the request thread.
        self._raise_next_tick = None

    def decode_sql(self, sql: str) -> str:
        """Base64-decode a string. This should only be used for sql in calls.

        :param str sql: The base64 encoded form of the original utf-8 string
        :return str: The decoded utf-8 string
        :raises: the ``invalid_params`` RPC error (via raise_invalid_base64)
            if ``sql`` is not strict base64 or does not decode to utf-8.
        """
        # JSON is defined as using "unicode", we'll go a step further and
        # mandate utf-8 (though for the base64 part, it doesn't really matter!)
        base64_sql_bytes = str(sql).encode('utf-8')
        try:
            sql_bytes = base64.b64decode(base64_sql_bytes, validate=True)
            # Decode inside the try: UnicodeDecodeError is a ValueError, so a
            # payload that is valid base64 but not valid utf-8 is reported as
            # invalid params instead of propagating as a raw
            # UnicodeDecodeError.
            decoded = sql_bytes.decode('utf-8')
        except ValueError:
            self.raise_invalid_base64(sql)
        return decoded

    @staticmethod
    def raise_invalid_base64(sql):
        """Raise the ``invalid_params`` RPC error for undecodable sql input."""
        raise invalid_params(
            data={
                'message': 'invalid base64-encoded sql input',
                'sql': str(sql),
            }
        )

    def _extract_request_data(self, data):
        """Decode the request payload and split macro blocks from the SQL.

        :param data: The base64-encoded request SQL.
        :return: A ``(sql, macros)`` tuple. ``sql`` is every non-macro
            top-level block concatenated as-is. ``macros`` is the macro
            blocks joined with newlines, or ``''`` if there were none.
        """
        data = self.decode_sql(data)
        macro_blocks = []
        data_chunks = []
        for block in extract_toplevel_blocks(data):
            if block.block_type_name == 'macro':
                macro_blocks.append(block.full_block)
            else:
                data_chunks.append(block.full_block)
        macros = '\n'.join(macro_blocks)
        sql = ''.join(data_chunks)
        return sql, macros

    def _get_exec_node(self):
        """Parse the request into an RPC node and register it in the manifest.

        Macro blocks embedded in the request SQL are parsed and merged into
        ``self.manifest.macros`` as overrides. The remaining SQL becomes the
        RPC node. The manifest is then compiled with ``write=False`` and the
        result is stored in ``self.graph``.

        :raises InternalException: if ``self.manifest`` has not been set.
        :return: The parsed RPC node.
        """
        if self.manifest is None:
            raise InternalException(
                'manifest not set in _get_exec_node'
            )

        # Only macro blocks embedded in the SQL payload are used here;
        # ``self.args.macros`` (populated by set_args) is not consulted.
        sql, macros = self._extract_request_data(self.args.sql)

        macro_overrides = {}
        if macros:
            macro_parser = RPCMacroParser(self.config, self.manifest)
            for macro_node in macro_parser.parse_remote(macros):
                macro_overrides[macro_node.unique_id] = macro_node

        self.manifest.macros.update(macro_overrides)

        rpc_parser = RPCCallParser(
            project=self.config,
            manifest=self.manifest,
            root_project=self.config,
        )
        rpc_node = rpc_parser.parse_remote(sql, self.args.name)
        add_new_refs(
            manifest=self.manifest,
            config=self.config,
            node=rpc_node,
            macros=macro_overrides
        )

        # don't write our new, weird manifest!
        adapter = get_adapter(self.config)
        compiler = adapter.get_compiler()
        self.graph = compiler.compile(self.manifest, write=False)
        # previously, this compiled the ancestors, but they are compiled at
        # runtime now.
        return rpc_node

    def _raise_set_error(self):
        """Re-raise an exception captured by the worker thread, if any."""
        if self._raise_next_tick is not None:
            raise self._raise_next_tick

    def _in_thread(self, node, thread_done):
        """Worker-thread body: run ``node`` and record its result.

        Any exception is stored in ``self._raise_next_tick`` for the request
        thread to re-raise. ``thread_done`` is always set so the waiting
        request thread wakes up.
        """
        runner = self.get_runner(node)
        try:
            self.node_results.append(runner.safe_run(self.manifest))
        except Exception as exc:
            logger.debug('Got exception {}'.format(exc), exc_info=True)
            self._raise_next_tick = exc
        finally:
            thread_done.set()

    def set_args(self, params: RPCExecParameters):
        """Copy the request parameters onto ``self.args``."""
        self.args.name = params.name
        self.args.sql = params.sql
        self.args.macros = params.macros

    def handle_request(self) -> RemoteExecutionResult:
        """Parse and run the request SQL, returning the execution result.

        :raises RPCKilledException: on KeyboardInterrupt. Open connections
            are cancelled first if the adapter supports it.
        """
        # we could get a ctrl+c at any time, including during parsing.
        thread = None
        started = datetime.utcnow()
        try:
            node = self._get_exec_node()

            selected_uids = [node.unique_id]
            self.runtime_cleanup(selected_uids)

            # Run on a worker thread and block on the event, so the request
            # thread stays free to receive the interrupt below.
            thread_done = threading.Event()
            thread = threading.Thread(target=self._in_thread,
                                      args=(node, thread_done))
            thread.start()
            thread_done.wait()
        except KeyboardInterrupt:
            adapter = get_adapter(self.config)  # type: ignore
            if adapter.is_cancelable():
                for conn_name in adapter.cancel_open_connections():
                    logger.debug('canceled query {}'.format(conn_name))
                if thread:
                    thread.join()
            else:
                msg = ("The {} adapter does not support query "
                       "cancellation. Some queries may still be "
                       "running!".format(adapter.type()))
                logger.debug(msg)
            raise RPCKilledException(signal.SIGINT)

        self._raise_set_error()

        ended = datetime.utcnow()
        elapsed = (ended - started).total_seconds()
        return self.get_result(
            results=self.node_results,
            elapsed_time=elapsed,
            generated_at=ended,
        )

    def interpret_results(self, results):
        """Always report success; ``results`` is not inspected."""
        return True
class RemoteCompileTask(RemoteRunSQLTask, CompileTask):
    """RPC SQL task exposed under the method name ``compile_sql``.

    Nodes are handled by ``RPCCompileRunner``.
    """

    METHOD_NAME = 'compile_sql'

    def get_runner_type(self, _):
        """Return ``RPCCompileRunner``; the argument is ignored."""
        runner_cls = RPCCompileRunner
        return runner_cls
class RemoteRunTask(RemoteRunSQLTask, RunTask):
    """RPC SQL task exposed under the method name ``run_sql``.

    Nodes are handled by ``RPCExecuteRunner``.
    """

    METHOD_NAME = 'run_sql'

    def get_runner_type(self, _):
        """Return ``RPCExecuteRunner``; the argument is ignored."""
        runner_cls = RPCExecuteRunner
        return runner_cls