217 lines
5.9 KiB
Python
217 lines
5.9 KiB
Python
import logbook
|
|
import logbook.queues
|
|
from jsonrpc.exceptions import JSONRPCError
|
|
from dbt.dataclass_schema import StrEnum
|
|
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timedelta
|
|
from queue import Empty
|
|
from typing import Optional, Any
|
|
|
|
from dbt.contracts.rpc import (
|
|
RemoteResult,
|
|
)
|
|
from dbt.exceptions import InternalException
|
|
from dbt.utils import restrict_to
|
|
|
|
|
|
class QueueMessageType(StrEnum):
    """Tags identifying the kind of message travelling over the RPC queue."""
    # Tag used by QueueErrorMessage (carries a JSONRPCError).
    Error = 'error'
    # Tag used by QueueResultMessage (carries a RemoteResult).
    Result = 'result'
    # Tag used by QueueTimeoutMessage (no payload).
    Timeout = 'timeout'
    # Tag used by QueueLogMessage (carries a logbook.LogRecord).
    Log = 'log'

    # The tags that stop QueueSubscriber.dispatch_until_exit: every tag
    # except Log. At this point in the class body Error/Result/Timeout are
    # still the plain strings assigned above.
    # NOTE(review): this assignment lives inside the Enum body, so
    # `terminating` is presumably turned into an enum member itself rather
    # than staying a plain frozenset; `x in QueueMessageType.terminating`
    # may then be a string-containment test - confirm against StrEnum.
    terminating = frozenset((Error, Result, Timeout))
|
|
|
|
|
|
# This class was subclassed from JsonSchemaMixin, but it
# doesn't appear to be necessary, and Mashumaro does not
# handle logbook.LogRecord
@dataclass
class QueueMessage:
    """Base class for every message exchanged over the RPC queue.

    QueueSubscriber.recv rejects anything received from the queue that is
    not an instance of this class.
    """
    # Tag telling consumers which kind of message this is.
    message_type: QueueMessageType
|
|
|
|
|
|
@dataclass
class QueueLogMessage(QueueMessage):
    """Queue message that carries a single logbook record."""
    # Re-declared to attach the metadata built by
    # restrict_to(QueueMessageType.Log); presumably this limits the accepted
    # tag to Log - confirm against dbt.utils.restrict_to.
    message_type: QueueMessageType = field(
        metadata=restrict_to(QueueMessageType.Log)
    )
    # The log record being transported.
    record: logbook.LogRecord

    @classmethod
    def from_record(cls, record: logbook.LogRecord):
        """Build a Log-tagged message wrapping ``record``.

        Instantiates ``cls`` (not a hard-coded class) so that subclasses get
        an instance of their own type, matching the other message factories
        in this module.
        """
        return cls(
            message_type=QueueMessageType.Log,
            record=record,
        )
|
|
|
|
|
|
@dataclass
class QueueErrorMessage(QueueMessage):
    """Queue message that carries a JSON-RPC error."""
    # Re-declared to attach the metadata built by
    # restrict_to(QueueMessageType.Error); presumably this limits the
    # accepted tag to Error - confirm against dbt.utils.restrict_to.
    message_type: QueueMessageType = field(
        metadata=restrict_to(QueueMessageType.Error)
    )
    # The error being transported.
    error: JSONRPCError

    @classmethod
    def from_error(cls, error: JSONRPCError):
        """Build an Error-tagged message wrapping ``error``.

        Instantiates ``cls`` (not a hard-coded class) so that subclasses get
        an instance of their own type, matching the other message factories
        in this module.
        """
        return cls(
            message_type=QueueMessageType.Error,
            error=error,
        )
|
|
|
|
|
|
@dataclass
class QueueResultMessage(QueueMessage):
    """Queue message that carries a RemoteResult."""
    # Re-declared to attach the metadata built by
    # restrict_to(QueueMessageType.Result).
    message_type: QueueMessageType = field(
        metadata=restrict_to(QueueMessageType.Result)
    )
    # The result being transported.
    result: RemoteResult

    @classmethod
    def from_result(cls, result: RemoteResult):
        """Build a Result-tagged message wrapping ``result``."""
        wrapped = cls(message_type=QueueMessageType.Result, result=result)
        return wrapped
|
|
|
|
|
|
@dataclass
class QueueTimeoutMessage(QueueMessage):
    """Payload-free queue message signalling that a wait timed out."""
    # Re-declared to attach the metadata built by
    # restrict_to(QueueMessageType.Timeout).
    message_type: QueueMessageType = field(
        metadata=restrict_to(QueueMessageType.Timeout),
    )

    @classmethod
    def create(cls):
        """Build a Timeout-tagged message."""
        timeout_tag = QueueMessageType.Timeout
        return cls(message_type=timeout_tag)
|
|
|
|
|
|
class QueueLogHandler(logbook.queues.MultiProcessingHandler):
    """Handler that puts typed QueueMessage objects on its queue.

    Besides log records (``emit``), it can also enqueue an error or a
    result message. Every put is non-blocking (``put_nowait``).
    """

    def emit(self, record: logbook.LogRecord):
        """Wrap ``record`` in a QueueLogMessage and enqueue it."""
        # trigger the cached properties here
        record.pull_information()
        log_message = QueueLogMessage.from_record(record)
        self.queue.put_nowait(log_message)

    def emit_error(self, error: JSONRPCError):
        """Wrap ``error`` in a QueueErrorMessage and enqueue it."""
        error_message = QueueErrorMessage.from_error(error)
        self.queue.put_nowait(error_message)

    def emit_result(self, result: RemoteResult):
        """Wrap ``result`` in a QueueResultMessage and enqueue it."""
        result_message = QueueResultMessage.from_result(result)
        self.queue.put_nowait(result_message)
|
|
|
|
|
|
def _next_timeout(
|
|
started: datetime,
|
|
timeout: Optional[float],
|
|
) -> Optional[float]:
|
|
if timeout is None:
|
|
return None
|
|
|
|
end = started + timedelta(seconds=timeout)
|
|
message_timeout = end - datetime.utcnow()
|
|
return message_timeout.total_seconds()
|
|
|
|
|
|
class QueueSubscriber(logbook.queues.MultiProcessingSubscriber):
    """Subscriber that reads QueueMessage objects off its queue."""

    def _recv_raw(self, timeout: Optional[float]) -> Any:
        """Fetch the next raw item from the queue.

        With ``timeout`` None this blocks until an item arrives. A negative
        timeout, or an empty queue once ``timeout`` seconds have elapsed,
        yields a fresh QueueTimeoutMessage instead of raising.
        """
        if timeout is None:
            return self.queue.get()

        if timeout < 0:
            # The deadline has already passed; do not touch the queue.
            return QueueTimeoutMessage.create()

        try:
            item = self.queue.get(block=True, timeout=timeout)
        except Empty:
            item = QueueTimeoutMessage.create()
        return item

    def recv(
        self,
        timeout: Optional[float] = None
    ) -> QueueMessage:
        """Receive one message from the queue and validate its type.

        Returns the received QueueMessage; a timeout is reported as a
        QueueTimeoutMessage rather than as None. Raises InternalException
        if the item taken from the queue is not a QueueMessage.
        """
        received = self._recv_raw(timeout)
        if isinstance(received, QueueMessage):
            return received
        raise InternalException(
            'Got invalid queue message: {}'.format(received)
        )

    def handle_message(
        self,
        timeout: Optional[float]
    ) -> QueueMessage:
        """Receive one message, dispatching it to logbook if it is a log.

        Log messages have their record passed to ``logbook.dispatch_record``;
        terminating messages are returned untouched. Any other message type
        raises InternalException.
        """
        received = self.recv(timeout)
        if isinstance(received, QueueLogMessage):
            logbook.dispatch_record(received.record)
        elif received.message_type not in QueueMessageType.terminating:
            raise InternalException(
                'Got invalid queue message type {}'.format(
                    received.message_type
                )
            )
        return received

    def dispatch_until_exit(
        self,
        started: datetime,
        timeout: Optional[float] = None
    ) -> QueueMessage:
        """Handle messages until a terminating one arrives, then return it.

        The per-message timeout is recomputed on every iteration from
        ``started`` and ``timeout``, so the overall wait is bounded by
        ``timeout`` seconds after ``started`` (unbounded when None).
        """
        while True:
            latest = self.handle_message(_next_timeout(started, timeout))
            if latest.message_type in QueueMessageType.terminating:
                return latest
|
|
|
|
|
|
# a bunch of processors to push/pop that set various rpc-related extras
|
|
class ServerContext(logbook.Processor):
    """Processor that tags records with the 'server' context by default."""

    def process(self, record):
        """Set ``extra['context']`` to 'server' unless it is already set."""
        extra = record.extra
        # the server context is the last processor in the stack, so it should
        # not overwrite a context if it's already been set.
        if extra['context']:
            return
        extra['context'] = 'server'
|
|
|
|
|
|
class HTTPRequest(logbook.Processor):
    """Processor that annotates records with details of an HTTP request."""

    def __init__(self, request):
        """Remember ``request`` and initialise the base Processor.

        :param request: object exposing ``remote_addr`` and ``method``.
        """
        self.request = request
        # Initialise the logbook.Processor base, as the other processors in
        # this module do, so the instance is fully set up.
        super().__init__()

    def process(self, record):
        """Copy the request's remote address and HTTP method into extras."""
        record.extra['addr'] = self.request.remote_addr
        record.extra['http_method'] = self.request.method
|
|
|
|
|
|
class RPCRequest(logbook.Processor):
    """Processor that annotates records with details of an RPC request."""

    def __init__(self, request):
        """Initialise the base Processor and remember ``request``.

        :param request: object exposing ``_id`` and ``method``.
        """
        super().__init__()
        self.request = request

    def process(self, record):
        """Copy the request's id and method name into the record extras."""
        rpc_request = self.request
        record.extra['request_id'] = rpc_request._id
        record.extra['method'] = rpc_request.method
|
|
|
|
|
|
class RPCResponse(logbook.Processor):
    """Processor that annotates records with details of an RPC response."""

    def __init__(self, response):
        """Initialise the base Processor and remember ``response``.

        :param response: object exposing a ``request`` attribute.
        """
        super().__init__()
        self.response = response

    def process(self, record):
        """Record a 200 response code and the originating request id."""
        record.extra['response_code'] = 200
        originating_request = self.response.request
        # the request_id could be None if the request was bad
        request_id = getattr(originating_request, '_id', None)
        record.extra['request_id'] = request_id
|
|
|
|
|
|
class RequestContext(RPCRequest):
    """RPCRequest processor that also tags records as 'request' context."""

    def process(self, record):
        """Apply the RPCRequest extras, then set the context to 'request'."""
        super().process(record)
        extra = record.extra
        extra['context'] = 'request'
|