import functools
from typing import (
    Optional,
    Type,
    Union,
    Any,
    Dict,
    cast,
    Tuple,
    List,
    TypeVar,
    get_type_hints,
    Callable,
    Generic,
    Hashable,
    ClassVar,
)
import re
from datetime import datetime
from dataclasses import fields, is_dataclass, Field, MISSING, dataclass, asdict
from uuid import UUID
from enum import Enum
import threading
import warnings

from dateutil.parser import parse
import jsonschema

JSON_ENCODABLE_TYPES = {
    str: {"type": "string"},
    int: {"type": "integer"},
    bool: {"type": "boolean"},
    float: {"type": "number"},
    type(None): {"type": "null"},
}

JsonEncodable = Union[int, float, str, bool, None]
JsonDict = Dict[str, Any]

OPTIONAL_TYPES = ["Union", "Optional"]


class ValidationError(jsonschema.ValidationError):
    pass


class FutureValidationError(ValidationError):
    # A validation error where we haven't called str() on the inputs yet.
    def __init__(self, field: str, errors: Dict[str, Exception]):
        self.errors = errors
        self.field = field
        super().__init__("generic validation error")
        self.initialized = False

    def late_initialize(self):
        lines: List[str] = []
        for name, exc in self.errors.items():
            # do not use getattr(exc, 'message', str(exc)), it's slow!
            if hasattr(exc, "message"):
                msg = exc.message
            else:
                msg = str(exc)
            lines.append(f"{name}: {msg}")
        super().__init__(
            "Unable to decode value for '{}': No members matched:\n{}".format(
                self.field, "\n".join(lines)
            )
        )
        self.initialized = True

    def __str__(self):
        if not self.initialized:
            self.late_initialize()
        return super().__str__()


def is_enum(field_type: Any) -> bool:
    return issubclass_safe(field_type, Enum)


def issubclass_safe(klass: Any, base: Type) -> bool:
    try:
        return issubclass(klass, base)
    except TypeError:
        return False


def is_optional(field: Any) -> bool:
    if str(field).startswith("typing.Union") or str(field).startswith(
        "typing.Optional"
    ):
        for arg in field.__args__:
            if isinstance(arg, type) and issubclass(arg, type(None)):
                return True
    return False


TV = TypeVar("TV")


class FieldEncoder(Generic[TV]):
    """Base class for encoding fields to and from JSON encodable values"""

    def to_wire(self, value: TV) -> JsonEncodable:
        return value  # type: ignore

    def to_python(self, value: JsonEncodable) -> TV:
        return value  # type: ignore

    @property
    def json_schema(self) -> JsonDict:
        raise NotImplementedError()


class DateTimeFieldEncoder(FieldEncoder[datetime]):
    """Encodes datetimes to RFC3339 format"""

    def to_wire(self, value: datetime) -> str:
        out = value.isoformat()
        # Assume UTC if timezone is missing
        if value.tzinfo is None:
            return out + "Z"
        return out

    def to_python(self, value: JsonEncodable) -> datetime:
        return value if isinstance(value, datetime) else parse(cast(str, value))

    @property
    def json_schema(self) -> JsonDict:
        return {"type": "string", "format": "date-time"}


class UuidField(FieldEncoder[UUID]):
    def to_wire(self, value: UUID) -> str:
        return str(value)

    def to_python(self, value) -> UUID:
        return UUID(value)

    @property
    def json_schema(self) -> JsonDict:
        # 'format': 'uuid' is not valid in "real" JSONSchema
        return {
            "type": "string",
            "pattern": (
                "[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
            ),
        }
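
# Illustrative sketch (not part of this library): a custom FieldEncoder for
# decimal.Decimal, showing the to_wire/to_python/json_schema contract defined
# above. It would be registered via JsonSchemaMixin.register_field_encoders
# (defined below), e.g.:
#     JsonSchemaMixin.register_field_encoders({Decimal: DecimalFieldEncoder()})
from decimal import Decimal


class DecimalFieldEncoder(FieldEncoder[Decimal]):
    def to_wire(self, value: Decimal) -> str:
        # Encode as a string to avoid float rounding on the wire.
        return str(value)

    def to_python(self, value: JsonEncodable) -> Decimal:
        return Decimal(cast(str, value))

    @property
    def json_schema(self) -> JsonDict:
        # A string of digits with an optional sign and fractional part.
        return {"type": "string", "pattern": r"-?\d+(\.\d+)?"}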

_ValueEncoder = Callable[[Any, Any, bool], Any]
_ValueDecoder = Callable[[str, Any, Any], Any]

T = TypeVar("T", bound="JsonSchemaMixin")


@functools.lru_cache()
def _to_camel_case(value: str) -> str:
    if "_" in value:
        parts = value.split("_")
        return "".join(
            [parts[0]] + [part[0].upper() + part[1:] for part in parts[1:]]
        )
    else:
        return value


@dataclass
class FieldMeta:
    default: Any = None
    description: Optional[str] = None

    @property
    def as_dict(self) -> Dict:
        return {
            _to_camel_case(k): v
            for k, v in asdict(self).items()
            if v is not None
        }


@functools.lru_cache()
def _validate_schema(h_schema_cls: Hashable) -> JsonDict:
    schema_cls = cast(Type[JsonSchemaMixin], h_schema_cls)  # making mypy happy
    schema = schema_cls.json_schema()
    jsonschema.Draft7Validator.check_schema(schema)
    return schema


# A restriction is a list of (Field, str) pairs.
Restriction = List[Tuple[Field, str]]
# A restricted variant is a pair of an object that has fields with
# restrictions and those restrictions. Only JsonSchemaMixin subclasses may
# have restricted fields.
Variant = Tuple[Type[T], Optional[Restriction]]


def _get_restrictions(variant_type: Type) -> Restriction:
    """Return a list of all restrictions on the given variant of a union, in
    the form of (Field, name) pairs, where `name` is the field's name in JSON
    and the Field is the dataclass-level field. If the variant isn't a
    JsonSchemaMixin subclass, there are no restrictions.
    """
    if not issubclass_safe(variant_type, JsonSchemaMixin):
        return []
    restrictions: Restriction = []
    for field, target_name in variant_type._get_fields():
        if field.metadata and "restrict" in field.metadata:
            restrictions.append((field, target_name))
    return restrictions


def get_union_fields(field_type: Union[Any]) -> List[Variant]:
    """
    Unions have an __args__ that is all their variants (after typing's
    type-collapsing magic has run, so caveat emptor...)

    JsonSchemaMixin dataclasses have `Field`s, returned by the `_get_fields`
    method.

    This function returns a list of 2-tuples:
      - the first value is always a type
      - the second value is None if there are no restrictions, or a list of
        restrictions if there are restrictions

    The list is sorted so that unrestricted variants always come last.
    """
    union_variants: List[Variant] = []
    for variant in field_type.__args__:
        restrictions: Optional[Restriction] = _get_restrictions(variant)
        if not restrictions:
            restrictions = None
        union_variants.append((variant, restrictions))
    # put unrestricted variants last
    union_variants.sort(key=lambda f: f[1] is None)
    return union_variants


def _encode_restrictions_met(
    value: Any, restrict_fields: Optional[List[Tuple[Field, str]]]
) -> bool:
    if restrict_fields is None:
        return True
    return all(
        (
            hasattr(value, f.name)
            and getattr(value, f.name) in f.metadata["restrict"]
        )
        for f, _ in restrict_fields
    )


def _decode_restrictions_met(
    value: Any, restrict_fields: Optional[List[Tuple[Field, str]]]
) -> bool:
    if restrict_fields is None:
        return True
    return all(
        n in value and value[n] in f.metadata["restrict"]
        for f, n in restrict_fields
    )
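
# Illustrative sketch of how `restrict` metadata drives union dispatch: a
# variant whose field metadata carries a "restrict" list is only attempted
# when the discriminator value matches. Hypothetical classes, assuming the
# JsonSchemaMixin defined below:
#
#     @dataclass
#     class Cat(JsonSchemaMixin):
#         sound: str = dataclasses.field(
#             default="meow", metadata={"restrict": ["meow"]}
#         )
#
#     @dataclass
#     class Dog(JsonSchemaMixin):
#         sound: str = dataclasses.field(
#             default="woof", metadata={"restrict": ["woof"]}
#         )
#
# When decoding Union[Cat, Dog], {"sound": "woof"} only matches Dog, because
# _decode_restrictions_met() checks the JSON value against each variant's
# "restrict" list before that variant's decoder is tried.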

@dataclass
class CompleteSchema:
    schema: JsonDict
    definitions: JsonDict


_HOLOGRAM_LOCK = threading.RLock()


class JsonSchemaMixin:
    """Mixin which adds methods to generate a JSON schema and convert to and
    from JSON encodable dicts with validation against the schema
    """

    _field_encoders: ClassVar[Dict[Type, FieldEncoder]] = {
        datetime: DateTimeFieldEncoder(),
        UUID: UuidField(),
    }

    # Cache of the generated schema
    _schema: ClassVar[Optional[Dict[str, CompleteSchema]]] = None
    # Cache of field encode / decode functions
    _encode_cache: ClassVar[Optional[Dict[Any, _ValueEncoder]]] = None
    _decode_cache: ClassVar[Optional[Dict[Any, _ValueDecoder]]] = None
    _mapped_fields: ClassVar[
        Optional[Dict[Any, List[Tuple[Field, str]]]]
    ] = None
    ADDITIONAL_PROPERTIES: ClassVar[bool] = False

    @classmethod
    def field_mapping(cls) -> Dict[str, str]:
        """Defines the mapping of Python field names to JSON field names.

        The main use case is to allow JSON field names which are Python
        keywords.
        """
        return {}

    @classmethod
    def register_field_encoders(cls, field_encoders: Dict[Type, FieldEncoder]):
        """Registers additional custom field encoders. If called on the base,
        these are added globally.

        The DateTimeFieldEncoder is included by default.
        """
        if cls is not JsonSchemaMixin:
            cls._field_encoders = {**cls._field_encoders, **field_encoders}
        else:
            cls._field_encoders.update(field_encoders)

    def _local_to_dict(self, **kwargs):
        return self.to_dict(**kwargs)
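
    # Illustrative sketch: registering the DecimalFieldEncoder sketched above.
    # Calling this on the JsonSchemaMixin base applies it globally; calling it
    # on a subclass gives that subclass its own copy of the encoder mapping:
    #
    #     JsonSchemaMixin.register_field_encoders(
    #         {Decimal: DecimalFieldEncoder()}
    #     )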

    @classmethod
    def _encode_field(
        cls, field_type: Any, value: Any, omit_none: bool
    ) -> Any:
        if value is None:
            return value
        try:
            encoder = cls._encode_cache[field_type]  # type: ignore
        except (KeyError, TypeError):
            if cls._encode_cache is None:
                cls._encode_cache = {}

            field_type_name = cls._get_field_type_name(field_type)
            if field_type in cls._field_encoders:

                def encoder(ft, v, __):
                    return cls._field_encoders[ft].to_wire(v)

            elif is_enum(field_type):

                def encoder(_, v, __):
                    return v.value

            elif field_type_name in OPTIONAL_TYPES:
                # Attempt to encode the field with each union variant.
                # TODO: Find a more reliable method than this, since for
                # 'Union[List[str], Dict[str, int]]' this will just output
                # the dict keys as a list.
                encoded = None
                union_fields = get_union_fields(field_type)
                for variant, restrict_fields in union_fields:
                    if _encode_restrictions_met(value, restrict_fields):
                        try:
                            encoded = cls._encode_field(
                                variant, value, omit_none
                            )
                            break
                        except (TypeError, AttributeError):
                            continue
                if encoded is None:
                    raise TypeError(
                        "No variant of '{}' matched the type '{}'".format(
                            field_type, type(value)
                        )
                    )
                return encoded
            elif field_type_name in ("Mapping", "Dict"):

                def encoder(ft, val, o):
                    return {
                        cls._encode_field(
                            ft.__args__[0], k, o
                        ): cls._encode_field(ft.__args__[1], v, o)
                        for k, v in val.items()
                    }

            elif field_type_name == "PatternProperty":
                # TODO: is there some way to set __args__ on this so it can
                # just re-use Dict/Mapping?

                def encoder(ft, val, o):
                    return {
                        cls._encode_field(str, k, o): cls._encode_field(
                            ft.TARGET_TYPE, v, o
                        )
                        for k, v in val.items()
                    }

            elif field_type_name == "List" or (
                field_type_name == "Tuple" and ... in field_type.__args__
            ):

                def encoder(ft, val, o):
                    if not isinstance(val, (tuple, list)):
                        valtype = type(val)
                        # raise a TypeError so the union encoder will capture
                        # it
                        raise TypeError(
                            f"Invalid type, expected {field_type_name} "
                            f"but got {valtype}"
                        )
                    return [
                        cls._encode_field(ft.__args__[0], v, o) for v in val
                    ]

            elif field_type_name == "Sequence":

                def encoder(ft, val, o):
                    return [
                        cls._encode_field(ft.__args__[0], v, o) for v in val
                    ]

            elif field_type_name == "Tuple":

                def encoder(ft, val, o):
                    return [
                        cls._encode_field(ft.__args__[idx], v, o)
                        for idx, v in enumerate(val)
                    ]

            elif cls._is_json_schema_subclass(field_type):
                # Only need to validate at the top level

                def encoder(_, v, o):
                    # this calls _local_to_dict in order to support combining
                    # this code with mashumaro
                    return v._local_to_dict(omit_none=o)

            elif hasattr(field_type, "__supertype__"):  # NewType field

                def encoder(ft, v, o):
                    return cls._encode_field(ft.__supertype__, v, o)

            else:

                def encoder(_, v, __):
                    return v

            cls._encode_cache[field_type] = encoder  # type: ignore
        return encoder(field_type, value, omit_none)

    @classmethod
    def _get_fields(cls) -> List[Tuple[Field, str]]:
        if cls._mapped_fields is None:
            cls._mapped_fields = {}
        if cls.__name__ not in cls._mapped_fields:
            mapped_fields = []
            type_hints = get_type_hints(cls)
            for f in fields(cls):
                # Skip internal fields
                if f.name.startswith("_"):
                    continue
                # Note fields() doesn't resolve forward refs
                f.type = type_hints[f.name]
                mapped_fields.append(
                    (f, cls.field_mapping().get(f.name, f.name))
                )
            cls._mapped_fields[cls.__name__] = mapped_fields
        return cls._mapped_fields[cls.__name__]

    @classmethod
    def _get_field_names(cls):
        return [target_name for _, target_name in cls._get_fields()]

    def to_dict(
        self, omit_none: bool = True, validate: bool = False
    ) -> JsonDict:
        """Converts the dataclass instance to a JSON encodable dict, with
        optional JSON schema validation.

        If omit_none (default True) is specified, any items with value None
        are removed.
        """
        data = {}
        for field, target_field in self._get_fields():
            value = self._encode_field(
                field.type, getattr(self, field.name), omit_none
            )
            if omit_none and value is None:
                continue
            data[target_field] = value
        if validate:
            self.validate(data)
        return data
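
    # Illustrative sketch, assuming a hypothetical subclass with an optional
    # field:
    #
    #     @dataclass
    #     class Point(JsonSchemaMixin):
    #         x: int
    #         label: Optional[str] = None
    #
    #     Point(x=1).to_dict()                 # -> {"x": 1}
    #     Point(x=1).to_dict(omit_none=False)  # -> {"x": 1, "label": None}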

    @classmethod
    def _decode_field(
        cls, field: str, field_type: Any, value: Any, validate: bool
    ) -> Any:
        if value is None:
            return None
        decoder = None
        try:
            decoder = cls._decode_cache[field_type]  # type: ignore
        except (KeyError, TypeError):
            if (
                type(value) in JSON_ENCODABLE_TYPES
                and field_type in JSON_ENCODABLE_TYPES
            ):
                return value
            if cls._decode_cache is None:
                cls._decode_cache = {}

            # Replace any nested dictionaries with their targets
            field_type_name = cls._get_field_type_name(field_type)
            if field_type in cls._field_encoders:

                def decoder(_, ft, val):
                    return cls._field_encoders[ft].to_python(val)

            elif cls._is_json_schema_subclass(field_type):

                def decoder(_, ft, val):
                    return ft.from_dict(val, validate=validate)

            elif field_type_name in OPTIONAL_TYPES:
                # Attempt to decode the value using each decoder in turn
                union_excs = (
                    AttributeError,
                    TypeError,
                    ValueError,
                    ValidationError,
                )
                errors: Dict[str, Exception] = {}
                union_fields = get_union_fields(field_type)
                for variant, restrict_fields in union_fields:
                    if _decode_restrictions_met(value, restrict_fields):
                        try:
                            return cls._decode_field(
                                field, variant, value, True
                            )
                        except union_excs as exc:
                            errors[str(variant)] = exc
                            continue
                # none of the variants decoded, so report about all of them
                raise FutureValidationError(field, errors)
            elif field_type_name in ("Mapping", "Dict"):

                def decoder(f, ft, val):
                    return {
                        cls._decode_field(
                            f, ft.__args__[0], k, validate
                        ): cls._decode_field(f, ft.__args__[1], v, validate)
                        for k, v in val.items()
                    }

            elif field_type_name == "List" or (
                field_type_name == "Tuple" and ... in field_type.__args__
            ):
                seq_type = tuple if field_type_name == "Tuple" else list

                def decoder(f, ft, val):
                    if not isinstance(val, (tuple, list)):
                        valtype = type(val)
                        # raise a TypeError so the union decoder will capture
                        # it
                        raise TypeError(
                            f"Invalid type, expected {field_type_name} "
                            f"but got {valtype}"
                        )
                    return seq_type(
                        cls._decode_field(f, ft.__args__[0], v, validate)
                        for v in val
                    )

            # if you want to allow strings as sequences for some reason, you
            # can still use 'Sequence' to get back a list of characters...
            elif field_type_name == "Sequence":

                def decoder(f, ft, val):
                    return list(
                        cls._decode_field(f, ft.__args__[0], v, validate)
                        for v in val
                    )

            elif field_type_name == "Tuple":

                def decoder(f, ft, val):
                    return tuple(
                        cls._decode_field(f, ft.__args__[idx], v, validate)
                        for idx, v in enumerate(val)
                    )

            elif hasattr(field_type, "__supertype__"):  # NewType field

                def decoder(f, ft, val):
                    return cls._decode_field(
                        f, ft.__supertype__, val, validate
                    )

            elif is_enum(field_type):

                def decoder(_, ft, val):
                    return ft(val)

            elif field_type is Any:

                def decoder(_, __, val):
                    return val

            if decoder is None:
                raise ValidationError(
                    f"Unable to decode value for "
                    f"'{field}: {field_type_name}' (value={value})"
                )
            cls._decode_cache[field_type] = decoder
        return decoder(field, field_type, value)
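
    # Illustrative sketch: decoding the scalar 5 as
    # Union[List[int], Dict[str, int]] tries each variant, collects the
    # exception each one raises, and raises a single FutureValidationError
    # whose message is only built lazily (in __str__), since union decoding
    # raises and discards errors on hot paths:
    #
    #     cls._decode_field("n", Union[List[int], Dict[str, int]], 5, False)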

    @classmethod
    def _find_matching_validator(cls: Type[T], data: JsonDict) -> T:
        if cls is not JsonSchemaMixin:
            raise NotImplementedError
        for subclass in cls.__subclasses__():
            try:
                if is_dataclass(subclass):
                    return subclass.from_dict(data)
            except ValidationError:
                continue
        raise ValidationError("No matching validator for data.")

    @classmethod
    def from_dict(cls: Type[T], data: JsonDict, validate=True) -> T:
        """Returns a dataclass instance with all nested classes converted
        from the dict given.
        """
        if cls is JsonSchemaMixin:
            return cls._find_matching_validator(data)

        init_values: Dict[str, Any] = {}
        non_init_values: Dict[str, Any] = {}
        if validate:
            cls.validate(data)

        for field, target_field in cls._get_fields():
            values = init_values if field.init else non_init_values
            if target_field in data or (
                field.default == MISSING
                and field.default_factory == MISSING  # type: ignore
            ):
                values[field.name] = cls._decode_field(
                    field.name, field.type, data.get(target_field), validate
                )

        # Need to ignore the type error here, since mypy doesn't know that
        # subclasses are dataclasses
        instance = cls(**init_values)  # type: ignore
        for field_name, value in non_init_values.items():
            setattr(instance, field_name, value)
        return instance

    @staticmethod
    def _is_json_schema_subclass(field_type: Type) -> bool:
        return issubclass_safe(field_type, JsonSchemaMixin)

    @staticmethod
    def _has_definition(field_type: Type) -> bool:
        return (
            issubclass_safe(field_type, JsonSchemaMixin)
            and field_type.__name__ != "PatternProperty"
        )

    @classmethod
    def _get_field_meta(cls, field: Field) -> Tuple[FieldMeta, bool]:
        required = True
        field_meta = FieldMeta()
        default_value: Optional[Callable[[], Any]] = None
        if field.default is not MISSING and field.default is not None:
            # A default value was given
            default_value = field.default
        elif (
            field.default_factory is not MISSING  # type: ignore
            and field.default_factory is not None  # type: ignore
        ):
            # A default factory was given, so call it
            default_value = field.default_factory()  # type: ignore

        if default_value is not None:
            field_meta.default = cls._encode_field(
                field.type, default_value, omit_none=False
            )
            required = False
        if field.metadata is not None:
            if "description" in field.metadata:
                field_meta.description = field.metadata["description"]
        return field_meta, required

    @classmethod
    def _encode_restrictions(
        cls, restrictions: Union[List[Any], Type[Enum]]
    ) -> JsonDict:
        field_schema: JsonDict = {}
        member_types = set()
        values = []
        for member in restrictions:
            if isinstance(member, Enum):
                value = member.value
            else:
                value = member
            member_types.add(type(value))
            values.append(value)
        if len(member_types) == 1:
            member_type = member_types.pop()
            if member_type in JSON_ENCODABLE_TYPES:
                field_schema.update(JSON_ENCODABLE_TYPES[member_type])
            else:
                field_schema.update(
                    cls._field_encoders[member_type].json_schema
                )
        else:
            # hologram used to silently do nothing here, which seems worse
            raise ValidationError(
                "Invalid schema defined: Found multiple member types - "
                "{!s}".format(member_types)
            )
        field_schema["enum"] = values
        return field_schema
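
    # Illustrative sketch: _encode_restrictions turns an Enum (or a literal
    # list from `restrict` metadata) into a typed JSON Schema enum, e.g. for
    # a hypothetical enum:
    #
    #     class Color(Enum):
    #         RED = "red"
    #         BLUE = "blue"
    #
    #     cls._encode_restrictions(Color)
    #     # -> {"type": "string", "enum": ["red", "blue"]}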

    @classmethod
    def _get_schema_for_type(
        cls,
        target: Type,
        required: bool = True,
        restrictions: Optional[List[Any]] = None,
    ) -> Tuple[JsonDict, bool]:
        field_schema: JsonDict = {"type": "object"}

        type_name = cls._get_field_type_name(target)

        if target in cls._field_encoders:
            field_schema.update(cls._field_encoders[target].json_schema)
        elif restrictions:
            field_schema.update(cls._encode_restrictions(restrictions))
        # if Union[..., None] or Optional[...]
        elif type_name in OPTIONAL_TYPES:
            field_schema = {
                "oneOf": [
                    cls._get_field_schema(variant)[0]
                    for variant in target.__args__
                ]
            }
            if is_optional(target):
                required = False
        elif is_enum(target):
            field_schema.update(cls._encode_restrictions(target))
        elif type_name in ("Dict", "Mapping"):
            field_schema = {"type": "object"}
            if target.__args__[1] is not Any:
                field_schema["additionalProperties"] = cls._get_field_schema(
                    target.__args__[1]
                )[0]
        elif type_name == "PatternProperty":
            field_schema = {"type": "object"}
            field_schema["patternProperties"] = {
                ".*": cls._get_field_schema(target.TARGET_TYPE)[0]
            }
        elif type_name in ("Sequence", "List") or (
            type_name == "Tuple" and ... in target.__args__
        ):
            field_schema = {"type": "array"}
            if target.__args__[0] is not Any:
                field_schema["items"] = cls._get_field_schema(
                    target.__args__[0]
                )[0]
        elif type_name == "Tuple":
            tuple_len = len(target.__args__)
            # TODO: How do we handle Optional type within lists / tuples
            field_schema = {
                "type": "array",
                "minItems": tuple_len,
                "maxItems": tuple_len,
                "items": [
                    cls._get_field_schema(type_arg)[0]
                    for type_arg in target.__args__
                ],
            }
        elif target in JSON_ENCODABLE_TYPES:
            field_schema.update(JSON_ENCODABLE_TYPES[target])
        elif hasattr(target, "__supertype__"):  # NewType fields
            field_schema, _ = cls._get_field_schema(target.__supertype__)
        else:
            raise ValidationError(
                f"Unable to create schema for '{type_name}'"
            )
        return field_schema, required

    @classmethod
    def _get_field_schema(
        cls, field: Union[Field, Type]
    ) -> Tuple[JsonDict, bool]:
        required = True
        restrictions = None

        if isinstance(field, Field):
            field_type = field.type
            field_meta, required = cls._get_field_meta(field)
            if field.metadata is not None:
                restrictions = field.metadata.get("restrict")
        else:
            field_type = field
            field_meta = FieldMeta()

        field_type_name = cls._get_field_type_name(field_type)
        if cls._has_definition(field_type):
            field_schema: JsonDict = {
                "$ref": "#/definitions/{}".format(field_type_name)
            }
        else:
            field_schema, required = cls._get_schema_for_type(
                field_type, required=required, restrictions=restrictions
            )

        field_schema.update(field_meta.as_dict)
        return field_schema, required

    @classmethod
    def _get_field_definitions(cls, field_type: Any, definitions: JsonDict):
        field_type_name = cls._get_field_type_name(field_type)
        if field_type_name == "Tuple":
            # tuples are either like Tuple[T, ...] or Tuple[T1, T2, T3]
            for member in field_type.__args__:
                if member is not ...:
                    cls._get_field_definitions(member, definitions)
        elif field_type_name in ("Sequence", "List"):
            cls._get_field_definitions(field_type.__args__[0], definitions)
        elif field_type_name in ("Dict", "Mapping"):
            cls._get_field_definitions(field_type.__args__[1], definitions)
        elif field_type_name == "PatternProperty":
            cls._get_field_definitions(field_type.TARGET_TYPE, definitions)
        elif field_type_name in OPTIONAL_TYPES:
            for variant in field_type.__args__:
                cls._get_field_definitions(variant, definitions)
        elif cls._is_json_schema_subclass(field_type):
            # Prevent recursion from forward refs & circular type dependencies
            if field_type.__name__ not in definitions:
                definitions[field_type.__name__] = None
                definitions.update(
                    field_type._json_schema_recursive(
                        embeddable=True, definitions=definitions
                    )
                )
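
    # Illustrative sketch: a nested JsonSchemaMixin field becomes a "$ref"
    # into the shared definitions dict rather than an inline schema. A
    # hypothetical `parent: Point` field produces:
    #
    #     {"$ref": "#/definitions/Point"}
    #
    # while _get_field_definitions() recursively fills definitions["Point"].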

    @classmethod
    def all_json_schemas(cls) -> JsonDict:
        """Returns JSON schemas for all subclasses"""
        definitions = {}
        for subclass in cls.__subclasses__():
            if is_dataclass(subclass):
                definitions.update(subclass.json_schema(embeddable=True))
            else:
                definitions.update(subclass.all_json_schemas())
        return definitions

    @classmethod
    def _collect_json_schema(cls, definitions: JsonDict) -> JsonDict:
        """Return the schema dictionary and update the definitions dictionary
        for this class.
        """
        properties = {}
        required = []
        for field, target_field in cls._get_fields():
            properties[target_field], is_required = cls._get_field_schema(
                field
            )
            cls._get_field_definitions(field.type, definitions)
            if is_required:
                required.append(target_field)
        schema = {
            "type": "object",
            "required": required,
            "properties": properties,
            "additionalProperties": cls.ADDITIONAL_PROPERTIES,
        }
        if cls.__doc__:
            schema["description"] = cls.__doc__
        return schema

    @classmethod
    def _schema_defs_from_cache(cls, definitions: JsonDict) -> CompleteSchema:
        # this has to be done at the classmethod level because each subclass
        # needs its own dict, and we don't want to use metaclasses here (it
        # makes it hard for users to use metaclasses)
        if cls._schema is None:
            with _HOLOGRAM_LOCK:
                # check again, in case we were waiting for someone else to do
                # this
                if cls._schema is None:
                    cls._schema = {}

        if cls.__name__ in cls._schema:
            return cls._schema[cls.__name__]

        with _HOLOGRAM_LOCK:
            if cls.__name__ in cls._schema:
                return cls._schema[cls.__name__]

        # ok, no schema found. go build schemas
        schema = cls._collect_json_schema(definitions)
        complete_schema = CompleteSchema(
            schema=schema, definitions=definitions
        )
        # now that we've finished, write our schema in. In the worst case we
        # write over another thread's work.
        cls._schema[cls.__name__] = complete_schema
        return complete_schema

    @classmethod
    def _json_schema_recursive(
        cls, embeddable: bool, definitions: JsonDict
    ) -> JsonDict:
        schema = cls._schema_defs_from_cache(definitions)
        if embeddable:
            return {**schema.definitions, cls.__name__: schema.schema}

        return {
            **schema.schema,
            **{
                "definitions": schema.definitions,
                "$schema": "http://json-schema.org/draft-07/schema#",
            },
        }

    @classmethod
    def json_schema(cls, embeddable: bool = False) -> JsonDict:
        """Returns the JSON schema for the dataclass, along with the schema
        of any nested dataclasses within the 'definitions' field.

        Enable the embeddable flag to generate the schema in a format for
        embedding into other schemas or documents supporting JSON schema such
        as Swagger specs.
        """
        if cls is JsonSchemaMixin:
            warnings.warn(
                "Calling 'JsonSchemaMixin.json_schema' is deprecated. "
                "Use 'JsonSchemaMixin.all_json_schemas' instead",
                DeprecationWarning,
            )
            return cls.all_json_schemas()

        definitions: JsonDict = {}
        return cls._json_schema_recursive(
            embeddable=embeddable, definitions=definitions
        )
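
    # Illustrative sketch of the two output shapes, for a hypothetical Point
    # subclass:
    #
    #     Point.json_schema()
    #     # -> {"type": "object", ..., "definitions": {...},
    #     #     "$schema": "http://json-schema.org/draft-07/schema#"}
    #
    #     Point.json_schema(embeddable=True)
    #     # -> {"Point": {"type": "object", ...}}  # plus nested definitions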

    @staticmethod
    def _get_field_type_name(field_type: Any) -> str:
        try:
            return field_type.__name__
        except AttributeError:
            # The types in the 'typing' module lack the __name__ attribute
            match = re.match(r"typing\.([A-Za-z]+)", str(field_type))
            return str(field_type) if match is None else match.group(1)

    @classmethod
    def validate(cls, data: Any):
        h_cls = cast(Hashable, cls)
        schema = _validate_schema(h_cls)
        validator = jsonschema.Draft7Validator(schema)
        error = next(iter(validator.iter_errors(data)), None)
        if error is not None:
            raise ValidationError.create_from(error) from error
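

if __name__ == "__main__":
    # Minimal end-to-end sketch (hypothetical class, not part of the library):
    # round-trip a dataclass through to_dict/from_dict and print its schema.
    import json

    @dataclass
    class Point(JsonSchemaMixin):
        """A point on the plane"""

        x: int
        y: int
        label: Optional[str] = None

    p = Point(x=1, y=2)
    encoded = p.to_dict()  # {"x": 1, "y": 2}; label omitted because it's None
    assert Point.from_dict(encoded) == p
    print(json.dumps(Point.json_schema(), indent=2))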