259 lines
8.7 KiB
Python
259 lines
8.7 KiB
Python
|
from copy import deepcopy
|
||
|
import threading
|
||
|
import uuid
|
||
|
from datetime import datetime
|
||
|
from typing import (
|
||
|
Any, Dict, Optional, List, Union, Set, Callable, Type
|
||
|
)
|
||
|
|
||
|
|
||
|
import dbt.exceptions
|
||
|
import dbt.flags as flags
|
||
|
from dbt.adapters.factory import reset_adapters, register_adapter
|
||
|
from dbt.contracts.graph.manifest import Manifest
|
||
|
from dbt.contracts.rpc import (
|
||
|
LastParse,
|
||
|
ManifestStatus,
|
||
|
GCSettings,
|
||
|
GCResult,
|
||
|
TaskRow,
|
||
|
TaskID,
|
||
|
)
|
||
|
from dbt.logger import LogMessage, list_handler
|
||
|
from dbt.parser.manifest import ManifestLoader
|
||
|
from dbt.rpc.error import dbt_error
|
||
|
from dbt.rpc.gc import GarbageCollector
|
||
|
from dbt.rpc.task_handler_protocol import TaskHandlerProtocol, TaskHandlerMap
|
||
|
from dbt.rpc.task_handler import set_parse_state_with
|
||
|
from dbt.rpc.method import (
|
||
|
RemoteMethod, RemoteManifestMethod, RemoteBuiltinMethod, TaskTypes,
|
||
|
)
|
||
|
# pick up our builtin methods
|
||
|
import dbt.rpc.builtins # noqa
|
||
|
|
||
|
|
||
|
# import this to make sure our timedelta encoder is registered
|
||
|
from dbt import helper_types # noqa
|
||
|
|
||
|
|
||
|
# A handler callable: accepts any arguments and returns a string-keyed dict.
WrappedHandler = Callable[..., Dict[str, Any]]
|
||
|
|
||
|
|
||
|
class UnconditionalError:
    """A callable stand-in for a task that raises as soon as it is called.

    The exception given to the constructor is passed through ``dbt_error``
    once, up front; every call afterwards raises that stored result.
    """

    def __init__(self, exception: dbt.exceptions.Exception):
        """Convert ``exception`` with ``dbt_error`` and keep the result."""
        rpc_error = dbt_error(exception)
        self.exception = rpc_error

    def __call__(self, *args, **kwargs):
        """Ignore all arguments and raise the stored exception."""
        raise self.exception
|
||
|
|
||
|
|
||
|
class ParseError(UnconditionalError):
    """Unconditional error built from an ``RPCLoadException``."""

    def __init__(self, parse_error):
        """Wrap ``parse_error`` in an ``RPCLoadException`` for the base."""
        super().__init__(dbt.exceptions.RPCLoadException(parse_error))
|
||
|
|
||
|
|
||
|
class CurrentlyCompiling(UnconditionalError):
    """Unconditional error built from an ``RPCCompiling`` exception."""

    def __init__(self):
        """Hand a fixed 'compile in progress' exception to the base."""
        super().__init__(dbt.exceptions.RPCCompiling('compile in progress'))
|
||
|
|
||
|
|
||
|
class ManifestReloader(threading.Thread):
    """Thread that asks a ``TaskManager`` to parse its manifest.

    ``reload_manifest`` can also be called directly to do the same work in
    the calling thread.
    """

    def __init__(self, task_manager: 'TaskManager') -> None:
        """Remember the task manager whose manifest will be parsed."""
        super().__init__()
        self.task_manager = task_manager

    def reload_manifest(self):
        """Parse the manifest while collecting log messages into a list.

        The same list is handed to ``list_handler`` and, through a
        zero-argument callable, to ``set_parse_state_with``.
        """
        captured: List[LogMessage] = []
        # entered left to right and exited right to left, exactly like the
        # equivalent nested ``with`` statements
        with set_parse_state_with(self.task_manager, lambda: captured), \
                list_handler(captured):
            self.task_manager.parse_manifest()

    def run(self) -> None:
        """Thread entry point: reload the manifest, swallowing failures."""
        try:
            self.reload_manifest()
        except Exception:
            # ignore ugly thread-death error messages to stderr
            pass
|
||
|
|
||
|
|
||
|
class TaskManager:
    """Hold the manifest, its parse state and the table of active tasks.

    ``rpc_task`` / ``get_handler`` turn an RPC method name into a callable
    built from the task types given at construction time. The manifest is
    (re)parsed through a ``ManifestReloader``, either in a background thread
    or in the foreground when running single-threaded.
    """

    def __init__(self, args, config, task_types: TaskTypes) -> None:
        """Store the inputs, set up bookkeeping and start the first parse.

        :param args: parsed arguments; ``args.single_threaded`` is read and
            ``args`` is passed to ``config.from_args`` on config reload.
        :param config: the runtime configuration handed to tasks.
        :param task_types: lookup from method name to task class.
        """
        self.args = args
        self.config = config
        self.manifest: Optional[Manifest] = None
        self._task_types: TaskTypes = task_types
        self.active_tasks: TaskHandlerMap = {}
        # the collector is given the very same dict that add_request fills
        self.gc = GarbageCollector(active_tasks=self.active_tasks)
        self.last_parse: LastParse = LastParse(state=ManifestStatus.Init)
        self._lock: flags.MP_CONTEXT.Lock = flags.MP_CONTEXT.Lock()
        self._reloader: Optional[ManifestReloader] = None
        # kick off the initial manifest load as part of construction
        self.reload_manifest()

    def single_threaded(self):
        """Return a truthy value if reloads must run in the foreground."""
        return flags.SINGLE_THREADED_WEBSERVER or self.args.single_threaded

    def _reload_task_manager_thread(self, reloader: ManifestReloader):
        """This function can only be running once at a time, as it runs in
        the signal handler we replace
        """
        # compile in a thread that will fix up the task manager when it's
        # done
        reloader.start()
        # only assign to _reloader here, to avoid calling join() before
        # start()
        self._reloader = reloader

    def _reload_task_manager_fg(self, reloader: ManifestReloader):
        """Override for single-threaded mode to run in the foreground"""
        # just reload directly
        reloader.reload_manifest()

    def reload_manifest(self) -> bool:
        """Reload the manifest using a manifest reloader. Returns False if
        the reload was not started because it was already running.
        """
        if not self.set_parsing():
            return False
        if self._reloader is not None:
            # join() the existing reloader
            self._reloader.join()
        # perform the reload
        reloader = ManifestReloader(self)
        if self.single_threaded():
            self._reload_task_manager_fg(reloader)
        else:
            self._reload_task_manager_thread(reloader)
        return True

    def reload_config(self):
        """Rebuild the config from ``self.args`` and re-register adapters.

        :return: the new config, which also replaces ``self.config``.
        """
        config = self.config.from_args(self.args)
        self.config = config
        reset_adapters()
        register_adapter(config)
        return config

    def add_request(self, request_handler: TaskHandlerProtocol):
        """Record ``request_handler`` in the active tasks by its task id."""
        self.active_tasks[request_handler.task_id] = request_handler

    def get_request(self, task_id: TaskID) -> TaskHandlerProtocol:
        """Return the active task handler registered under ``task_id``.

        :raises dbt.exceptions.UnknownAsyncIDException: if no handler is
            registered under that id.
        """
        try:
            return self.active_tasks[task_id]
        except KeyError:
            # We don't recognize that ID.
            raise dbt.exceptions.UnknownAsyncIDException(task_id) from None

    def _get_manifest_callable(
        self, task: Type[RemoteManifestMethod]
    ) -> Union[UnconditionalError, RemoteManifestMethod]:
        """Build a manifest-based task, or an error callable if not ready.

        Returns ``CurrentlyCompiling`` while a parse is in progress and
        ``ParseError`` if the last parse failed; otherwise instantiates
        ``task`` with a deep copy of the args, the config and the manifest.

        :raises dbt.exceptions.InternalException: if the parse state is
            neither compiling nor error but no manifest has been set.
        """
        state = self.last_parse.state
        if state == ManifestStatus.Compiling:
            return CurrentlyCompiling()
        elif state == ManifestStatus.Error:
            return ParseError(self.last_parse.error)
        else:
            if self.manifest is None:
                raise dbt.exceptions.InternalException(
                    f'Manifest should not be None if the last parse state is '
                    f'{state}'
                )
            return task(deepcopy(self.args), self.config, self.manifest)

    def rpc_task(
        self, method_name: str
    ) -> Union[UnconditionalError, RemoteMethod]:
        """Create the callable for ``method_name`` while holding the lock.

        Builtin methods receive this task manager, manifest methods go
        through ``_get_manifest_callable``, and any other ``RemoteMethod``
        receives a deep copy of the args plus the config.

        :raises dbt.exceptions.InternalException: if the registered task is
            not a ``RemoteMethod`` subclass.
        """
        with self._lock:
            task = self._task_types[method_name]
            # order matters: the more specific base classes are tested
            # before the generic RemoteMethod fallback
            if issubclass(task, RemoteBuiltinMethod):
                return task(self)
            elif issubclass(task, RemoteManifestMethod):
                return self._get_manifest_callable(task)
            elif issubclass(task, RemoteMethod):
                return task(deepcopy(self.args), self.config)
            else:
                raise dbt.exceptions.InternalException(
                    f'Got a task with an invalid type! {task} with method '
                    f'name {method_name} has a type of {task.__class__}, '
                    f'should be a RemoteMethod'
                )

    def ready(self) -> bool:
        """Return True if the last parse state is ``Ready``."""
        with self._lock:
            return self.last_parse.state == ManifestStatus.Ready

    def set_parsing(self) -> bool:
        """Move the parse state to ``Compiling`` under the lock.

        :return: False if the state was already ``Compiling`` (nothing is
            changed in that case), True otherwise.
        """
        with self._lock:
            if self.last_parse.state == ManifestStatus.Compiling:
                return False
            self.last_parse = LastParse(state=ManifestStatus.Compiling)
            return True

    def parse_manifest(self) -> None:
        """Load a full manifest for the current config and store it."""
        self.manifest = ManifestLoader.get_full_manifest(
            self.config, reset=True
        )

    def set_compile_exception(
        self, exc, logs: Optional[List[LogMessage]] = None
    ) -> None:
        """Record a failed parse: state ``Error``, message and logs.

        :param exc: the failure; ``str(exc)`` becomes the error message.
        :param logs: log messages captured during the parse. Omitting it
            records an empty list.
        """
        assert self.last_parse.state == ManifestStatus.Compiling, \
            f'invalid state {self.last_parse.state}'
        # the default used to be the typing object ``List[LogMessage]``
        # itself; normalise a missing value to a fresh empty list instead
        self.last_parse = LastParse(
            error={'message': str(exc)},
            state=ManifestStatus.Error,
            logs=[] if logs is None else logs,
        )

    def set_ready(self, logs: Optional[List[LogMessage]] = None) -> None:
        """Record a successful parse: state ``Ready`` plus the logs.

        :param logs: log messages captured during the parse. Omitting it
            records an empty list.
        """
        assert self.last_parse.state == ManifestStatus.Compiling, \
            f'invalid state {self.last_parse.state}'
        # see set_compile_exception: never store the typing object as logs
        self.last_parse = LastParse(
            state=ManifestStatus.Ready,
            logs=[] if logs is None else logs,
        )

    def methods(self) -> Set[str]:
        """Return the set of registered method names."""
        with self._lock:
            return set(self._task_types)

    def currently_compiling(self, *args, **kwargs):
        """Raise an RPC exception to trigger the error handler."""
        raise dbt_error(dbt.exceptions.RPCCompiling('compile in progress'))

    def compilation_error(self, *args, **kwargs):
        """Raise an RPC exception to trigger the error handler."""
        raise dbt_error(
            dbt.exceptions.RPCLoadException(self.last_parse.error)
        )

    def get_handler(
        self, method, http_request, json_rpc_request
    ) -> Optional[Union[WrappedHandler, RemoteMethod]]:
        """Return the callable for ``method``, or None if it is unknown.

        ``http_request`` and ``json_rpc_request`` are accepted but not used
        here.
        """
        # get_handler triggers a GC check. TODO: does this go somewhere else?
        self.gc_as_required()

        if method not in self._task_types:
            return None

        return self.rpc_task(method)

    def task_table(self) -> List[TaskRow]:
        """Return one row per active task, all stamped with the same time."""
        now = datetime.utcnow()
        with self._lock:
            return [
                task.make_task_row(now)
                for task in self.active_tasks.values()
            ]

    def gc_as_required(self) -> None:
        """Run the collector's ``collect_as_required`` under the lock."""
        with self._lock:
            return self.gc.collect_as_required()

    def gc_safe(
        self,
        task_ids: Optional[List[uuid.UUID]] = None,
        before: Optional[datetime] = None,
        settings: Optional[GCSettings] = None,
    ) -> GCResult:
        """Run the collector's ``collect_selected`` under the lock.

        All three arguments are forwarded unchanged as keyword arguments.
        """
        with self._lock:
            return self.gc.collect_selected(
                task_ids=task_ids, before=before, settings=settings,
            )
|