dbt-selly/dbt-env/lib/python3.8/site-packages/dbt/rpc/task_manager.py

259 lines
8.7 KiB
Python

from copy import deepcopy
import threading
import uuid
from datetime import datetime
from typing import (
Any, Dict, Optional, List, Union, Set, Callable, Type
)
import dbt.exceptions
import dbt.flags as flags
from dbt.adapters.factory import reset_adapters, register_adapter
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.rpc import (
LastParse,
ManifestStatus,
GCSettings,
GCResult,
TaskRow,
TaskID,
)
from dbt.logger import LogMessage, list_handler
from dbt.parser.manifest import ManifestLoader
from dbt.rpc.error import dbt_error
from dbt.rpc.gc import GarbageCollector
from dbt.rpc.task_handler_protocol import TaskHandlerProtocol, TaskHandlerMap
from dbt.rpc.task_handler import set_parse_state_with
from dbt.rpc.method import (
RemoteMethod, RemoteManifestMethod, RemoteBuiltinMethod, TaskTypes,
)
# pick up our builtin methods
import dbt.rpc.builtins # noqa
# import this to make sure our timedelta encoder is registered
from dbt import helper_types # noqa
# Type alias for a plain callable RPC handler: accepts any arguments and
# returns a dict keyed by strings. Used in the return type of
# TaskManager.get_handler.
WrappedHandler = Callable[..., Dict[str, Any]]
class UnconditionalError:
    """Callable stand-in for a handler that always raises the same error.

    The given exception is converted once, at construction time, through
    ``dbt_error``; every call then raises that converted object, ignoring
    whatever arguments were passed.
    """

    def __init__(self, exception: dbt.exceptions.Exception):
        converted = dbt_error(exception)
        self.exception = converted

    def __call__(self, *args, **kwargs):
        # Arguments are accepted only so this object can be invoked like a
        # regular handler; they are never inspected.
        error = self.exception
        raise error
class ParseError(UnconditionalError):
    """Handler that always raises an ``RPCLoadException`` for a parse error.

    :param parse_error: The recorded parse failure, forwarded unchanged to
        ``dbt.exceptions.RPCLoadException``.
    """

    def __init__(self, parse_error):
        super().__init__(dbt.exceptions.RPCLoadException(parse_error))
class CurrentlyCompiling(UnconditionalError):
    """Handler that always raises ``RPCCompiling('compile in progress')``."""

    def __init__(self):
        super().__init__(
            dbt.exceptions.RPCCompiling('compile in progress')
        )
class ManifestReloader(threading.Thread):
    """Thread that re-parses the manifest for a ``TaskManager``.

    ``reload_manifest`` may be called directly to parse in the current
    thread; ``run`` is the entry point used when the thread is started.
    """

    def __init__(self, task_manager: 'TaskManager') -> None:
        super().__init__()
        self.task_manager = task_manager

    def reload_manifest(self):
        """Parse the manifest while collecting emitted log messages.

        The messages are appended to a list that ``set_parse_state_with``
        reads through the zero-argument callable passed to it.
        """
        captured: List[LogMessage] = []
        manager = self.task_manager
        # Equivalent to nesting the two context managers, outer first.
        with set_parse_state_with(manager, lambda: captured), \
                list_handler(captured):
            manager.parse_manifest()

    def run(self) -> None:
        # Any Exception is deliberately discarded so the dying thread does not
        # print a traceback to stderr. Presumably set_parse_state_with records
        # the failure on the task manager first -- confirm in task_handler.
        try:
            self.reload_manifest()
        except Exception:
            return
class TaskManager:
    """Coordinate manifest parsing, RPC method dispatch and task bookkeeping.

    The manager holds the current ``Manifest`` (``None`` until
    ``parse_manifest`` first assigns it), the registry of request handlers
    (``active_tasks``), and a ``GarbageCollector`` built over that registry.
    ``last_parse`` tracks parse progress: it starts as ``Init``, is moved to
    ``Compiling`` by ``set_parsing``, and then to ``Ready`` or ``Error`` by
    ``set_ready`` / ``set_compile_exception``.
    """

    def __init__(self, args, config, task_types: TaskTypes) -> None:
        """Store settings and immediately start the initial manifest reload.

        :param args: Parsed arguments; deep-copied before being handed to
            manifest-backed and plain remote tasks.
        :param config: Runtime config; replaced by ``reload_config``.
        :param task_types: Mapping of RPC method name to task class.
        """
        self.args = args
        self.config = config
        self.manifest: Optional[Manifest] = None
        self._task_types: TaskTypes = task_types
        self.active_tasks: TaskHandlerMap = {}
        # The GC is given this same dict object, not a copy.
        self.gc = GarbageCollector(active_tasks=self.active_tasks)
        self.last_parse: LastParse = LastParse(state=ManifestStatus.Init)
        self._lock: flags.MP_CONTEXT.Lock = flags.MP_CONTEXT.Lock()
        self._reloader: Optional[ManifestReloader] = None
        self.reload_manifest()

    def single_threaded(self):
        """Return truthy when reloads must run in the foreground."""
        return flags.SINGLE_THREADED_WEBSERVER or self.args.single_threaded

    def _reload_task_manager_thread(self, reloader: ManifestReloader):
        """This function can only be running once at a time, as it runs in the
        signal handler we replace
        """
        # compile in a thread that will fix up the task manager when it's done
        reloader.start()
        # only assign to _reloader here, to avoid calling join() before start()
        self._reloader = reloader

    def _reload_task_manager_fg(self, reloader: ManifestReloader):
        """Override for single-threaded mode to run in the foreground"""
        # just reload directly
        reloader.reload_manifest()

    def reload_manifest(self) -> bool:
        """Reload the manifest using a manifest reloader.

        :returns: False if the reload was not started because one was
            already in progress, True otherwise.
        """
        if not self.set_parsing():
            return False

        if self._reloader is not None:
            # join() the existing reloader before starting a new one
            self._reloader.join()

        reloader = ManifestReloader(self)
        if self.single_threaded():
            self._reload_task_manager_fg(reloader)
        else:
            self._reload_task_manager_thread(reloader)
        return True

    def reload_config(self):
        """Rebuild the config from ``self.args`` and re-register adapters.

        :returns: The newly built config, also stored on ``self.config``.
        """
        config = self.config.from_args(self.args)
        self.config = config
        reset_adapters()
        register_adapter(config)
        return config

    def add_request(self, request_handler: TaskHandlerProtocol):
        """Register a handler under its ``task_id``."""
        self.active_tasks[request_handler.task_id] = request_handler

    def get_request(self, task_id: TaskID) -> TaskHandlerProtocol:
        """Look up a registered handler.

        :raises dbt.exceptions.UnknownAsyncIDException: for an unknown id.
        """
        try:
            return self.active_tasks[task_id]
        except KeyError:
            # We don't recognize that ID.
            raise dbt.exceptions.UnknownAsyncIDException(task_id) from None

    def _get_manifest_callable(
        self, task: Type[RemoteManifestMethod]
    ) -> Union[UnconditionalError, RemoteManifestMethod]:
        """Build a manifest-backed task, or an error handler if the manifest
        is still compiling or failed to parse.
        """
        state = self.last_parse.state
        if state == ManifestStatus.Compiling:
            return CurrentlyCompiling()
        elif state == ManifestStatus.Error:
            return ParseError(self.last_parse.error)
        else:
            if self.manifest is None:
                raise dbt.exceptions.InternalException(
                    f'Manifest should not be None if the last parse state is '
                    f'{state}'
                )
            return task(deepcopy(self.args), self.config, self.manifest)

    def rpc_task(
        self, method_name: str
    ) -> Union[UnconditionalError, RemoteMethod]:
        """Instantiate the task registered for ``method_name``.

        Builtin methods receive this manager; manifest methods receive
        args, config and manifest; other remote methods receive args and
        config.

        :raises dbt.exceptions.InternalException: if the registered class is
            not a ``RemoteMethod`` subclass.
        """
        with self._lock:
            task = self._task_types[method_name]
            if issubclass(task, RemoteBuiltinMethod):
                return task(self)
            elif issubclass(task, RemoteManifestMethod):
                return self._get_manifest_callable(task)
            elif issubclass(task, RemoteMethod):
                return task(deepcopy(self.args), self.config)
            else:
                raise dbt.exceptions.InternalException(
                    f'Got a task with an invalid type! {task} with method '
                    f'name {method_name} has a type of {task.__class__}, '
                    f'should be a RemoteMethod'
                )

    def ready(self) -> bool:
        """Return True if the last parse finished successfully."""
        with self._lock:
            return self.last_parse.state == ManifestStatus.Ready

    def set_parsing(self) -> bool:
        """Move to the ``Compiling`` state.

        :returns: False (leaving state unchanged) if already compiling.
        """
        with self._lock:
            if self.last_parse.state == ManifestStatus.Compiling:
                return False
            self.last_parse = LastParse(state=ManifestStatus.Compiling)
        return True

    def parse_manifest(self) -> None:
        """Fully re-parse the project and store the resulting manifest."""
        self.manifest = ManifestLoader.get_full_manifest(self.config, reset=True)

    # ``logs`` was previously declared as ``logs=List[LogMessage]``, which
    # made the typing object itself the default value rather than annotating
    # the parameter. It is now a real annotation with an empty-list default.
    def set_compile_exception(
        self, exc, logs: Optional[List[LogMessage]] = None
    ) -> None:
        """Record a failed parse.

        :param exc: The parse failure; only ``str(exc)`` is stored.
        :param logs: Log messages captured during the parse (empty if
            omitted).
        :raises AssertionError: (when assertions are enabled) if the state
            is not ``Compiling``.
        """
        assert self.last_parse.state == ManifestStatus.Compiling, \
            f'invalid state {self.last_parse.state}'
        self.last_parse = LastParse(
            error={'message': str(exc)},
            state=ManifestStatus.Error,
            logs=[] if logs is None else logs,
        )

    def set_ready(self, logs: Optional[List[LogMessage]] = None) -> None:
        """Record a successful parse.

        :param logs: Log messages captured during the parse (empty if
            omitted).
        :raises AssertionError: (when assertions are enabled) if the state
            is not ``Compiling``.
        """
        assert self.last_parse.state == ManifestStatus.Compiling, \
            f'invalid state {self.last_parse.state}'
        self.last_parse = LastParse(
            state=ManifestStatus.Ready,
            logs=[] if logs is None else logs,
        )

    def methods(self) -> Set[str]:
        """Return the names of all registered RPC methods."""
        with self._lock:
            return set(self._task_types)

    def currently_compiling(self, *args, **kwargs):
        """Raise an RPC exception to trigger the error handler."""
        raise dbt_error(dbt.exceptions.RPCCompiling('compile in progress'))

    def compilation_error(self, *args, **kwargs):
        """Raise an RPC exception to trigger the error handler."""
        raise dbt_error(
            dbt.exceptions.RPCLoadException(self.last_parse.error)
        )

    def get_handler(
        self, method, http_request, json_rpc_request
    ) -> Optional[Union[WrappedHandler, RemoteMethod]]:
        """Return a task for ``method``, or None if the method is unknown.

        ``http_request`` and ``json_rpc_request`` are accepted but unused.
        """
        # get_handler triggers a GC check. TODO: does this go somewhere else?
        self.gc_as_required()

        if method not in self._task_types:
            return None
        return self.rpc_task(method)

    def task_table(self) -> List[TaskRow]:
        """Snapshot one row per active task, all stamped with the same time."""
        now = datetime.utcnow()
        with self._lock:
            return [
                task.make_task_row(now)
                for task in self.active_tasks.values()
            ]

    def gc_as_required(self) -> None:
        """Run the garbage collector's automatic collection under the lock."""
        with self._lock:
            return self.gc.collect_as_required()

    def gc_safe(
        self,
        task_ids: Optional[List[uuid.UUID]] = None,
        before: Optional[datetime] = None,
        settings: Optional[GCSettings] = None,
    ) -> GCResult:
        """Collect the selected tasks under the lock.

        Arguments are forwarded unchanged to ``GarbageCollector.collect_selected``.
        """
        with self._lock:
            return self.gc.collect_selected(
                task_ids=task_ids, before=before, settings=settings,
            )