170 lines
5.0 KiB
Python
170 lines
5.0 KiB
Python
from typing import (
|
|
Type, ClassVar, cast,
|
|
)
|
|
import re
|
|
from dataclasses import fields
|
|
from enum import Enum
|
|
from datetime import datetime
|
|
from dateutil.parser import parse
|
|
|
|
from hologram import JsonSchemaMixin, FieldEncoder, ValidationError
|
|
|
|
# type: ignore
|
|
from mashumaro import DataClassDictMixin
|
|
from mashumaro.config import (
|
|
TO_DICT_ADD_OMIT_NONE_FLAG, BaseConfig as MashBaseConfig
|
|
)
|
|
from mashumaro.types import SerializableType, SerializationStrategy
|
|
|
|
|
|
class DateTimeSerialization(SerializationStrategy):
    """Mashumaro strategy that stores ``datetime`` values as ISO 8601 strings."""

    def serialize(self, value):
        """Return ``value`` rendered by ``isoformat()``.

        A naive datetime (``tzinfo is None``) is assumed to be UTC and gets a
        trailing ``"Z"``. An aware datetime keeps whatever offset
        ``isoformat()`` already emits.
        """
        suffix = "Z" if value.tzinfo is None else ""
        return value.isoformat() + suffix

    def deserialize(self, value):
        """Return ``value`` unchanged if it is already a ``datetime``.

        Otherwise parse it as a date/time string with ``dateutil``.
        """
        if isinstance(value, datetime):
            return value
        return parse(cast(str, value))
|
|
|
|
|
|
def _replace_in_keys(dct, old, new):
    """Return a shallow copy of ``dct`` with ``old`` replaced by ``new`` in its keys.

    Only ``str`` keys are rewritten. Any other key (e.g. a non-string key in
    user-supplied input) is copied through unchanged instead of raising
    ``TypeError``, leaving it for later validation to deal with.

    Insertion order is preserved. If two keys collapse to the same new key,
    the later one wins.
    """
    return {
        (key.replace(old, new) if isinstance(key, str) else key): value
        for key, value in dct.items()
    }


# This class pulls in both JsonSchemaMixin from Hologram and
# DataClassDictMixin from our fork of Mashumaro. The 'to_dict'
# and 'from_dict' methods come from Mashumaro. Building
# jsonschemas for every class and the 'validate' method
# come from Hologram.
class dbtClassMixin(DataClassDictMixin, JsonSchemaMixin):
    """Mixin which adds methods to generate a JSON schema and
    convert to and from JSON encodable dicts with validation
    against the schema.
    """

    class Config(MashBaseConfig):
        # Mashumaro code-generation options shared by every dbt class.
        # TO_DICT_ADD_OMIT_NONE_FLAG is presumably what gives to_dict the
        # `omit_none` keyword that _local_to_dict forwards.
        code_generation_options = [
            TO_DICT_ADD_OMIT_NONE_FLAG,
        ]
        # Round-trip datetime fields through ISO 8601 strings.
        serialization_strategy = {
            datetime: DateTimeSerialization(),
        }

    # When True, the dict form uses hyphenated keys (see the
    # __post_serialize__ / __pre_deserialize__ hooks).
    _hyphenated: ClassVar[bool] = False
    # Class-level flag; subclasses such as ExtensibleDbtClassMixin set it to
    # True. Presumably consulted by Hologram when building the schema.
    ADDITIONAL_PROPERTIES: ClassVar[bool] = False

    # Called by Mashumaro's to_dict (also for nested classes) to post-process
    # the dict it produced. Hyphenated classes turn '_' into '-' in every key.
    def __post_serialize__(self, dct):
        if self._hyphenated:
            dct = _replace_in_keys(dct, '_', '-')
        return dct

    # Called by Mashumaro's from_dict before the incoming dict is converted
    # into an instance. Hyphenated classes turn '-' back into '_' in keys.
    @classmethod
    def __pre_deserialize__(cls, data):
        # `data` might not be a dict, e.g. for `query_comment`, which accepts
        # a dict or a string; only snake-case for dict values.
        if cls._hyphenated and isinstance(data, dict):
            data = _replace_in_keys(data, '-', '_')
        return data

    # This is used in the hologram._encode_field method, which calls
    # a 'to_dict' method which does not have the same parameters in
    # hologram and in mashumaro. Only `omit_none` is forwarded; any other
    # keyword arguments are ignored.
    def _local_to_dict(self, **kwargs):
        args = {}
        if 'omit_none' in kwargs:
            args['omit_none'] = kwargs['omit_none']
        return self.to_dict(**args)
|
|
|
|
|
|
class ValidatedStringMixin(str, SerializableType):
    """A ``str`` subclass whose values must match ``ValidationRegex``.

    Subclasses override ``ValidationRegex``. The ``_serialize`` and
    ``_deserialize`` hooks (from ``SerializableType``) convert to and from
    plain strings.
    """

    # Applied with re.match, i.e. anchored at the start of the value only.
    # Subclasses that need a whole-string match must end the pattern with
    # `$` or `\Z`. The empty default matches every string.
    ValidationRegex = ''

    @classmethod
    def _deserialize(cls, value: str) -> 'ValidatedStringMixin':
        """Validate ``value`` and wrap it in ``cls``.

        Raises ValidationError if ``value`` does not match ``ValidationRegex``.
        """
        cls.validate(value)
        # Build `cls`, not the base mixin, so subclasses deserialize to their
        # own type (previously every subclass came back as the base class).
        return cls(value)

    def _serialize(self) -> str:
        """Return this value as a plain ``str``."""
        return str(self)

    @classmethod
    def validate(cls, value):
        """Raise ValidationError unless ``value`` matches ``ValidationRegex``."""
        res = re.match(cls.ValidationRegex, value)

        if res is None:
            # TODO(review): the message does not say which pattern was expected.
            raise ValidationError(f"Invalid value: {value}")
|
|
|
|
|
|
# The base-class order matters and must not change: the str data type comes
# first and Enum comes last (Python requires mix-in/data types to precede the
# Enum base); SerializableType sits between them.
class StrEnum(str, SerializableType, Enum):
    """Enum whose members are strings and serialize to their value.

    Members created with ``auto()`` take their own name as their value.
    """

    # https://docs.python.org/3.6/library/enum.html#using-automatic-values
    # Overriding this hook makes auto() use the member name as the value.
    def _generate_next_value_(name, *_):
        return name

    def __str__(self):
        # Render as the raw value rather than "ClassName.MEMBER".
        return self.value

    @classmethod
    def _deserialize(cls, value: str):
        """Look up the member whose value is ``value``."""
        return cls(value)

    def _serialize(self) -> str:
        """Return the member's string value."""
        return self.value
|
|
|
|
|
|
class HyphenatedDbtClassMixin(dbtClassMixin):
    """dbtClassMixin whose dict form uses hyphenated keys (``foo-bar``)."""

    # Read by the to_dict/from_dict hooks inherited from dbtClassMixin.
    _hyphenated: ClassVar[bool] = True

    # Used by jsonschema validation (_get_fields).
    @classmethod
    def field_mapping(cls):
        """Map each underscored field name to its hyphenated spelling.

        Fields without an underscore are omitted, as are fields whose
        metadata sets a truthy ``"preserve_underscore"``.
        """
        return {
            fld.name: fld.name.replace("_", "-")
            for fld in fields(cls)
            if not fld.metadata.get("preserve_underscore") and "_" in fld.name
        }
|
|
|
|
|
|
class ExtensibleDbtClassMixin(dbtClassMixin):
    """dbtClassMixin variant that sets ADDITIONAL_PROPERTIES to True.

    NOTE(review): presumably this lets the generated JSON schema accept keys
    that are not declared fields; confirm against Hologram.
    """

    # Declared as ClassVar[bool], matching the base-class declaration.
    ADDITIONAL_PROPERTIES: ClassVar[bool] = True
|
|
|
|
|
|
# This is used by Hologram in jsonschema validation
def register_pattern(base_type: Type, pattern: str) -> None:
    """Register a string JSON schema with ``pattern`` for ``base_type``.

    ``base_type`` should be a ``typing.NewType`` whose values must always
    match the given regex pattern, so its underlying type (``__supertype__``)
    had better be a ``str``. The schema ``{"type": "string", "pattern":
    pattern}`` is attached to ``base_type`` via
    ``dbtClassMixin.register_field_encoders``.
    """

    class PatternEncoder(FieldEncoder):
        @property
        def json_schema(self):
            # Build a new dict on every access so no caller shares one.
            return dict(type="string", pattern=pattern)

    encoder = PatternEncoder()
    dbtClassMixin.register_field_encoders({base_type: encoder})
|