97 lines
2.6 KiB
Python
97 lines
2.6 KiB
Python
|
import multiprocessing
|
||
|
from datetime import datetime
|
||
|
from typing import Optional, Union, MutableMapping
|
||
|
from typing_extensions import Protocol
|
||
|
|
||
|
import dbt.exceptions
|
||
|
from dbt.contracts.rpc import (
|
||
|
TaskHandlerState,
|
||
|
TaskID,
|
||
|
TaskTags,
|
||
|
TaskTiming,
|
||
|
TaskRow,
|
||
|
)
|
||
|
|
||
|
|
||
|
class TaskHandlerProtocol(Protocol):
|
||
|
task_id: TaskID
|
||
|
state: TaskHandlerState
|
||
|
started: Optional[datetime] = None
|
||
|
ended: Optional[datetime] = None
|
||
|
process: Optional[multiprocessing.Process] = None
|
||
|
|
||
|
@property
|
||
|
def request_id(self) -> Union[str, int]:
|
||
|
pass
|
||
|
|
||
|
@property
|
||
|
def request_source(self) -> str:
|
||
|
pass
|
||
|
|
||
|
@property
|
||
|
def timeout(self) -> Optional[float]:
|
||
|
pass
|
||
|
|
||
|
@property
|
||
|
def method(self) -> str:
|
||
|
pass
|
||
|
|
||
|
@property
|
||
|
def tags(self) -> Optional[TaskTags]:
|
||
|
pass
|
||
|
|
||
|
def _assert_started(self) -> datetime:
|
||
|
if self.started is None:
|
||
|
raise dbt.exceptions.InternalException(
|
||
|
'task handler started but start time is not set'
|
||
|
)
|
||
|
return self.started
|
||
|
|
||
|
def _assert_ended(self) -> datetime:
|
||
|
if self.ended is None:
|
||
|
raise dbt.exceptions.InternalException(
|
||
|
'task handler finished but end time is not set'
|
||
|
)
|
||
|
return self.ended
|
||
|
|
||
|
def make_task_timing(
|
||
|
self, now_time: datetime
|
||
|
) -> TaskTiming:
|
||
|
# get information about the task in a way that should not provide any
|
||
|
# conflicting information. Calculate elapsed time based on `now_time`
|
||
|
state = self.state
|
||
|
# store end/start so 'ps' output always makes sense:
|
||
|
# not started -> no start time/elapsed, running -> no end time, etc
|
||
|
end = None
|
||
|
start = None
|
||
|
elapsed = None
|
||
|
if state > TaskHandlerState.NotStarted:
|
||
|
start = self._assert_started()
|
||
|
elapsed_end = now_time
|
||
|
|
||
|
if state.finished:
|
||
|
elapsed_end = self._assert_ended()
|
||
|
end = elapsed_end
|
||
|
|
||
|
elapsed = (elapsed_end - start).total_seconds()
|
||
|
return TaskTiming(state=state, start=start, end=end, elapsed=elapsed)
|
||
|
|
||
|
def make_task_row(self, now_time: datetime) -> TaskRow:
|
||
|
timing = self.make_task_timing(now_time)
|
||
|
|
||
|
return TaskRow(
|
||
|
task_id=self.task_id,
|
||
|
request_id=self.request_id,
|
||
|
request_source=self.request_source,
|
||
|
method=self.method,
|
||
|
state=timing.state,
|
||
|
start=timing.start,
|
||
|
end=timing.end,
|
||
|
elapsed=timing.elapsed,
|
||
|
timeout=self.timeout,
|
||
|
tags=self.tags,
|
||
|
)
|
||
|
|
||
|
|
||
|
TaskHandlerMap = MutableMapping[TaskID, TaskHandlerProtocol]
|