diff --git a/.gitignore b/.gitignore index 6ca2f2542..c9e025776 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,5 @@ __pycache__/ /kafka_*/ venv /karapace/version.py +.run .python-version diff --git a/karapace/dependency.py b/karapace/dependency.py new file mode 100644 index 000000000..81e602163 --- /dev/null +++ b/karapace/dependency.py @@ -0,0 +1,57 @@ +""" +karapace - dependency + +Copyright (c) 2023 Aiven Ltd +See LICENSE for details +""" + +from karapace.schema_references import Reference +from karapace.typing import JsonData, Subject, Version +from typing import Any, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from karapace.schema_models import ValidatedTypedSchema + + +class DependencyVerifierResult: + def __init__(self, result: bool, message: Optional[str] = "") -> None: + self.result = result + self.message = message + + +class Dependency: + def __init__(self, name: str, subject: Subject, version: Version, target_schema: "ValidatedTypedSchema") -> None: + self.name = name + self.subject = subject + self.version = version + self.schema = target_schema + + def get_schema(self) -> "ValidatedTypedSchema": + return self.schema + + @staticmethod + def of(reference: Reference, target_schema: "ValidatedTypedSchema") -> "Dependency": + return Dependency(reference.name, reference.subject, reference.version, target_schema) + + def to_dict(self) -> JsonData: + return { + "name": self.name, + "subject": self.subject, + "version": self.version, + } + + def identifier(self) -> str: + return self.name + "_" + self.subject + "_" + str(self.version) + + def __hash__(self) -> int: + return hash((self.name, self.subject, self.version, self.schema)) + + def __eq__(self, other: Any) -> bool: + if other is None or not isinstance(other, Dependency): + return False + return ( + self.name == other.name + and self.subject == other.subject + and self.version == other.version + and self.schema == other.schema + ) diff --git a/karapace/errors.py b/karapace/errors.py index 1dbfd4248..4d091cbc2 100644 --- a/karapace/errors.py +++ b/karapace/errors.py @@ -2,6 +2,8 @@ Copyright (c) 2023 Aiven Ltd See LICENSE for details """ +from karapace.schema_references import Referents +from karapace.typing import Version class VersionNotFoundException(Exception): @@ -20,10 +22,18 @@ class InvalidSchema(Exception): pass +class InvalidTest(Exception): + pass + + class InvalidSchemaType(Exception): pass +class InvalidReferences(Exception): + pass + + class SchemasNotFoundException(Exception): pass @@ -44,6 +54,13 @@ class SubjectNotSoftDeletedException(Exception): pass +class ReferenceExistsException(Exception): + def __init__(self, referenced_by: Referents, version: Version): + super().__init__() + self.version = version + self.referenced_by = referenced_by + + class SubjectSoftDeletedException(Exception): pass diff --git a/karapace/in_memory_database.py b/karapace/in_memory_database.py index f04b2fb89..be54622ee 100644 --- a/karapace/in_memory_database.py +++ b/karapace/in_memory_database.py @@ -6,6 +6,7 @@ """ from dataclasses import dataclass, field from karapace.schema_models import SchemaVersion, TypedSchema +from karapace.schema_references import Reference, Referents from karapace.typing import ResolvedVersion, SchemaId, Subject from threading import Lock, RLock from typing import Dict, List, Optional, Tuple @@ -28,6 +29,7 @@ def __init__(self) -> None: self.subjects: Dict[Subject, SubjectData] = {} self.schemas: Dict[SchemaId, TypedSchema] = {} self.schema_lock_thread = RLock() + self.referenced_by: Dict[Tuple[Subject, ResolvedVersion], Referents] = {} # Content based deduplication of schemas. This is used to reduce memory # usage when the same schema is produce multiple times to the same or @@ -96,7 +98,14 @@ def get_next_version(self, *, subject: Subject) -> ResolvedVersion: return max(self.subjects[subject].schemas) + 1 def insert_schema_version( - self, *, subject: Subject, schema_id: SchemaId, version: ResolvedVersion, deleted: bool, schema: TypedSchema + self, + *, + subject: Subject, + schema_id: SchemaId, + version: ResolvedVersion, + deleted: bool, + schema: TypedSchema, + references: List[Reference], ) -> None: with self.schema_lock_thread: self.global_schema_id = max(self.global_schema_id, schema_id) @@ -119,6 +128,7 @@ def insert_schema_version( deleted=deleted, schema_id=schema_id, schema=schema, + references=references, ) if not deleted: @@ -235,3 +245,22 @@ def num_schema_versions(self) -> Tuple[int, int]: else: soft_deleted_versions += 1 return (live_versions, soft_deleted_versions) + + def insert_referenced_by(self, *, subject: Subject, version: ResolvedVersion, schema_id: SchemaId) -> None: + with self.schema_lock_thread: + referents = self.referenced_by.get((subject, version), None) + if referents: + referents.append(schema_id) + else: + self.referenced_by[(subject, version)] = [schema_id] + + def get_referenced_by(self, subject: Subject, version: ResolvedVersion) -> Optional[Referents]: + with self.schema_lock_thread: + return self.referenced_by.get((subject, version), None) + + def remove_referenced_by(self, schema_id: SchemaId, references: List[Reference]) -> None: + with self.schema_lock_thread: + for ref in references: + key = (ref.subject, ref.version) + if self.referenced_by.get(key, None) and schema_id in self.referenced_by[key]: + self.referenced_by[key].remove(schema_id) diff --git a/karapace/protobuf/compare_type_lists.py b/karapace/protobuf/compare_type_lists.py new file mode 100644 index 000000000..d8554d04d --- /dev/null +++ b/karapace/protobuf/compare_type_lists.py @@ -0,0 +1,73 @@ +""" +karapace - compare_type_lists + +Copyright (c) 2023 Aiven Ltd +See LICENSE for details +""" +from itertools import chain +from karapace.protobuf.compare_result import CompareResult, Modification +from karapace.protobuf.compare_type_storage import CompareTypes +from karapace.protobuf.enum_element import EnumElement +from karapace.protobuf.exception import IllegalStateException +from karapace.protobuf.message_element import MessageElement +from karapace.protobuf.type_element import TypeElement +from typing import List + + +def compare_type_lists( + self_types_list: List[TypeElement], + other_types_list: List[TypeElement], + result: CompareResult, + compare_types: CompareTypes, +) -> CompareResult: + self_types = {} + other_types = {} + self_indexes = {} + other_indexes = {} + + type_: TypeElement + for i, type_ in enumerate(self_types_list): + self_types[type_.name] = type_ + self_indexes[type_.name] = i + compare_types.add_self_type(compare_types.self_package_name, type_) + + for i, type_ in enumerate(other_types_list): + other_types[type_.name] = type_ + other_indexes[type_.name] = i + compare_types.add_other_type(compare_types.other_package_name, type_) + + for name in chain(self_types.keys(), other_types.keys() - self_types.keys()): + result.push_path(str(name), True) + + if self_types.get(name) is None and other_types.get(name) is not None: + if isinstance(other_types[name], MessageElement): + result.add_modification(Modification.MESSAGE_ADD) + elif isinstance(other_types[name], EnumElement): + result.add_modification(Modification.ENUM_ADD) + else: + raise IllegalStateException("Instance of element is not applicable") + elif self_types.get(name) is not None and other_types.get(name) is None: + if isinstance(self_types[name], MessageElement): + result.add_modification(Modification.MESSAGE_DROP) + elif isinstance(self_types[name], EnumElement): + result.add_modification(Modification.ENUM_DROP) + else: + raise IllegalStateException("Instance of element is not applicable") + else: + if other_indexes[name] != self_indexes[name]: + if isinstance(self_types[name], MessageElement): + # incompatible type + result.add_modification(Modification.MESSAGE_MOVE) + else: + raise IllegalStateException("Instance of element is not applicable") + else: + if isinstance(self_types[name], MessageElement) and isinstance(other_types[name], MessageElement): + self_types[name].compare(other_types[name], result, compare_types) + elif isinstance(self_types[name], EnumElement) and isinstance(other_types[name], EnumElement): + self_types[name].compare(other_types[name], result, compare_types) + else: + # incompatible type + result.add_modification(Modification.TYPE_ALTER) + result.pop_path(True) + + return result diff --git a/karapace/protobuf/compare_type_storage.py b/karapace/protobuf/compare_type_storage.py index e2a8aaada..cc1378d02 100644 --- a/karapace/protobuf/compare_type_storage.py +++ b/karapace/protobuf/compare_type_storage.py @@ -38,8 +38,9 @@ def compute_name(t: ProtoType, result_path: List[str], package_name: str, types: class CompareTypes: def __init__(self, self_package_name: str, other_package_name: str, result: CompareResult) -> None: - self.self_package_name = self_package_name - self.other_package_name = other_package_name + self.self_package_name = self_package_name or "" + self.other_package_name = other_package_name or "" + self.self_types: Dict[str, Union[TypeRecord, TypeRecordMap]] = {} self.other_types: Dict[str, Union[TypeRecord, TypeRecordMap]] = {} self.locked_messages: List["MessageElement"] = [] @@ -93,8 +94,11 @@ def self_type_short_name(self, t: ProtoType) -> Optional[str]: if name is None: raise IllegalArgumentException(f"Cannot determine message type {t}") type_record: TypeRecord = self.self_types.get(name) - if name.startswith(type_record.package_name): - return name[(len(type_record.package_name) + 1) :] + package_name = type_record.package_name + if package_name is None: + return name + if name.startswith(package_name): + return name[(len(package_name) + 1) :] return name def other_type_short_name(self, t: ProtoType) -> Optional[str]: @@ -102,8 +106,11 @@ def other_type_short_name(self, t: ProtoType) -> Optional[str]: if name is None: raise IllegalArgumentException(f"Cannot determine message type {t}") type_record: TypeRecord = self.other_types.get(name) - if name.startswith(type_record.package_name): - return name[(len(type_record.package_name) + 1) :] + package_name = type_record.package_name + if package_name is None: + return name + if name.startswith(package_name): + return name[(len(package_name) + 1) :] return name def lock_message(self, message: "MessageElement") -> bool: diff --git a/karapace/protobuf/dependency.py b/karapace/protobuf/dependency.py new file mode 100644 index 000000000..f02e864b5 --- /dev/null +++ b/karapace/protobuf/dependency.py @@ -0,0 +1,63 @@ +""" +karapace - dependency + +Copyright (c) 2023 Aiven Ltd +See LICENSE for details +""" + +from karapace.dependency import DependencyVerifierResult +from karapace.protobuf.known_dependency import DependenciesHardcoded, KnownDependency +from karapace.protobuf.one_of_element import OneOfElement +from typing import List + + +class ProtobufDependencyVerifier: + def __init__(self) -> None: + self.declared_types: List[str] = [] + self.used_types: List[str] = [] + self.import_path: List[str] = [] + + def add_declared_type(self, full_name: str) -> None: + self.declared_types.append(full_name) + + def add_used_type(self, parent: str, element_type: str) -> None: + if element_type.find("map<") == 0: + end = element_type.find(">") + virgule = element_type.find(",") + key = element_type[4:virgule] + value = element_type[virgule + 1 : end] + value = value.strip() + self.used_types.append(parent + ";" + key) + self.used_types.append(parent + ";" + value) + else: + self.used_types.append(parent + ";" + element_type) + + def add_import(self, import_name: str) -> None: + self.import_path.append(import_name) + + def verify(self) -> DependencyVerifierResult: + declared_index = set(self.declared_types) + for used_type in self.used_types: + delimiter = used_type.rfind(";") + used_type_with_scope = "" + if delimiter != -1: + used_type_with_scope = used_type[:delimiter] + "." + used_type[delimiter + 1 :] + used_type = used_type[delimiter + 1 :] + + if not ( + used_type in DependenciesHardcoded.index + or KnownDependency.index_simple.get(used_type) is not None + or KnownDependency.index.get(used_type) is not None + or used_type in declared_index + or (delimiter != -1 and used_type_with_scope in declared_index) + or "." + used_type in declared_index + ): + return DependencyVerifierResult(False, f"type {used_type} is not defined") + + return DependencyVerifierResult(True) + + +def process_one_of(verifier: ProtobufDependencyVerifier, package_name: str, parent_name: str, one_of: OneOfElement) -> None: + parent = package_name + "." + parent_name + for field in one_of.fields: + verifier.add_used_type(parent, field.element_type) diff --git a/karapace/protobuf/exception.py b/karapace/protobuf/exception.py index 1707cfdb7..f1042d6ea 100644 --- a/karapace/protobuf/exception.py +++ b/karapace/protobuf/exception.py @@ -5,10 +5,6 @@ import json -class ProtobufParserRuntimeException(Exception): - pass - - class IllegalStateException(Exception): pass @@ -29,6 +25,10 @@ class ProtobufTypeException(Error): """Generic Protobuf type error.""" +class ProtobufUnresolvedDependencyException(ProtobufException): + """a Protobuf schema has unresolved dependency""" + + class SchemaParseException(ProtobufException): """Error while parsing a Protobuf schema descriptor.""" diff --git a/karapace/protobuf/field_element.py b/karapace/protobuf/field_element.py index d78190406..5e834ea7c 100644 --- a/karapace/protobuf/field_element.py +++ b/karapace/protobuf/field_element.py @@ -146,6 +146,7 @@ def compare_message( self_type_record = types.get_self_type(self_type) other_type_record = types.get_other_type(other_type) + self_type_element: MessageElement = self_type_record.type_element other_type_element: MessageElement = other_type_record.type_element @@ -153,3 +154,9 @@ def compare_message( result.add_modification(Modification.FIELD_NAME_ALTER) else: self_type_element.compare(other_type_element, result, types) + + def __repr__(self): + return f"{self.element_type} {self.name} = {self.tag}" + + def __str__(self): + return f"{self.element_type} {self.name} = {self.tag}" diff --git a/karapace/protobuf/io.py b/karapace/protobuf/io.py index 4dd273327..cbead3913 100644 --- a/karapace/protobuf/io.py +++ b/karapace/protobuf/io.py @@ -9,13 +9,15 @@ from karapace.protobuf.protobuf_to_dict import dict_to_protobuf, protobuf_to_dict from karapace.protobuf.schema import ProtobufSchema from karapace.protobuf.type_element import TypeElement -from typing import Any, Dict, List +from multiprocessing import Process, Queue +from typing import Any, Dict, List, Optional import hashlib import importlib import importlib.util import os import subprocess +import sys def calculate_class_name(name: str) -> str: @@ -49,30 +51,80 @@ def find_message_name(schema: ProtobufSchema, indexes: List[int]) -> str: return ".".join(result) +def crawl_dependencies_(schema: ProtobufSchema, deps_list: Dict[str, Dict[str, str]]): + if schema.dependencies: + for name, dependency in schema.dependencies.items(): + crawl_dependencies_(dependency.schema.schema, deps_list) + deps_list[name] = { + "schema": str(dependency.schema.schema), + "unique_class_name": calculate_class_name(f"{dependency.version}_{dependency.name}"), + } + + +def crawl_dependencies(schema: ProtobufSchema) -> Dict[str, Dict[str, str]]: + deps_list: Dict[str, Dict[str, str]] = {} + crawl_dependencies_(schema, deps_list) + return deps_list + + +def replace_imports(string: str, deps_list: Optional[Dict[str, Dict[str, str]]]) -> str: + if deps_list is None: + return string + for key, value in deps_list.items(): + unique_class_name = value["unique_class_name"] + ".proto" + string = string.replace('"' + key + '"', f'"{unique_class_name}"') + return string + + def get_protobuf_class_instance(schema: ProtobufSchema, class_name: str, cfg: Dict) -> Any: directory = cfg["protobuf_runtime_directory"] - proto_name = calculate_class_name(str(schema)) + deps_list = crawl_dependencies(schema) + root_class_name = "" + for value in deps_list.values(): + root_class_name = root_class_name + value["unique_class_name"] + root_class_name = root_class_name + str(schema) + proto_name = calculate_class_name(root_class_name) + proto_path = f"{proto_name}.proto" - class_path = f"{directory}/{proto_name}_pb2.py" - if not os.path.isfile(proto_path): - with open(f"{directory}/{proto_name}.proto", mode="w", encoding="utf8") as proto_text: - proto_text.write(str(schema)) - - if not os.path.isfile(class_path): - subprocess.run( - [ - "protoc", - "--python_out=./", - proto_path, - ], - check=True, - cwd=directory, - ) - - spec = importlib.util.spec_from_file_location(f"{proto_name}_pb2", class_path) + work_dir = f"{directory}/{proto_name}" + if not os.path.isdir(directory): + os.mkdir(directory) + if not os.path.isdir(work_dir): + os.mkdir(work_dir) + class_path = f"{directory}/{proto_name}/{proto_name}_pb2.py" + if not os.path.exists(class_path): + with open(f"{directory}/{proto_name}/{proto_name}.proto", mode="w", encoding="utf8") as proto_text: + proto_text.write(replace_imports(str(schema), deps_list)) + + protoc_arguments = [ + "protoc", + "--python_out=./", + proto_path, + ] + for value in deps_list.values(): + proto_file_name = value["unique_class_name"] + ".proto" + dependency_path = f"{directory}/{proto_name}/{proto_file_name}" + protoc_arguments.append(proto_file_name) + with open(dependency_path, mode="w", encoding="utf8") as proto_text: + proto_text.write(replace_imports(value["schema"], deps_list)) + + if not os.path.isfile(class_path): + subprocess.run( + protoc_arguments, + check=True, + cwd=work_dir, + ) + + sys.path.append(f"./runtime/{proto_name}") + spec = importlib.util.spec_from_file_location( + f"{proto_name}_pb2", + class_path, + ) tmp_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(tmp_module) + sys.path.pop() class_to_call = getattr(tmp_module, class_name) + return class_to_call() @@ -90,6 +142,35 @@ def read_data(config: dict, writer_schema: ProtobufSchema, reader_schema: Protob return class_instance +def reader_process(queue: Queue, config: dict, writer_schema: ProtobufSchema, reader_schema: ProtobufSchema, bio: BytesIO): + try: + queue.put(protobuf_to_dict(read_data(config, writer_schema, reader_schema, bio), True)) + except Exception as e: # pylint: disable=broad-except + queue.put(e) + + +def reader_mp(config: dict, writer_schema: ProtobufSchema, reader_schema: ProtobufSchema, bio: BytesIO) -> Dict: + # Note Protobuf enum values use C++ scoping rules, + # meaning that enum values are siblings of their type, not children of it. + # Therefore, if we have two proto files with Enums which elements have the same name we will have error. + # There we use simple way of Serialization/Deserialization (SerDe) which use python Protobuf library and + # protoc compiler. + # To avoid problem with enum values for basic SerDe support we + # will isolate work with call protobuf libraries in child process. + if __name__ == "karapace.protobuf.io": + queue = Queue() + p = Process(target=reader_process, args=(queue, config, writer_schema, reader_schema, bio)) + p.start() + result = queue.get() + p.join() + if isinstance(result, Dict): + return result + if isinstance(result, Exception): + raise result + raise IllegalArgumentException() + return {"Error": "This never must be returned"} + + class ProtobufDatumReader: """Deserialize Protobuf-encoded data into a Python data structure.""" @@ -102,10 +183,44 @@ def __init__(self, config: dict, writer_schema: ProtobufSchema = None, reader_sc self._writer_schema = writer_schema self._reader_schema = reader_schema - def read(self, bio: BytesIO) -> None: + def read(self, bio: BytesIO) -> Dict: if self._reader_schema is None: self._reader_schema = self._writer_schema - return protobuf_to_dict(read_data(self.config, self._writer_schema, self._reader_schema, bio), True) + return reader_mp(self.config, self._writer_schema, self._reader_schema, bio) + + +def writer_process(queue: Queue, config: Dict, writer_schema: ProtobufSchema, message_name: str, datum: dict): + class_instance = get_protobuf_class_instance(writer_schema, message_name, config) + try: + dict_to_protobuf(class_instance, datum) + except Exception: + # pylint: disable=raise-missing-from + e = ProtobufTypeException(writer_schema, datum) + queue.put(e) + raise e + queue.put(class_instance.SerializeToString()) + + +def writer_mp(config: Dict, writer_schema: ProtobufSchema, message_name: str, datum: Dict) -> str: + # Note Protobuf enum values use C++ scoping rules, + # meaning that enum values are siblings of their type, not children of it. + # Therefore, if we have two proto files with Enums which elements have the same name we will have error. + # There we use simple way of Serialization/Deserialization (SerDe) which use python Protobuf library and + # protoc compiler. + # To avoid problem with enum values for basic SerDe support we + # will isolate work with call protobuf libraries in child process. + if __name__ == "karapace.protobuf.io": + queue = Queue() + p = Process(target=writer_process, args=(queue, config, writer_schema, message_name, datum)) + p.start() + result = queue.get() + p.join() + if isinstance(result, bytes): + return result + if isinstance(result, Exception): + raise result + raise IllegalArgumentException() + return "Error :This never must be returned" class ProtobufDatumWriter: @@ -130,12 +245,4 @@ def write_index(self, writer: BytesIO) -> None: write_indexes(writer, [self._message_index]) def write(self, datum: dict, writer: BytesIO) -> None: - class_instance = get_protobuf_class_instance(self._writer_schema, self._message_name, self.config) - - try: - dict_to_protobuf(class_instance, datum) - except Exception: - # pylint: disable=raise-missing-from - raise ProtobufTypeException(self._writer_schema, datum) - - writer.write(class_instance.SerializeToString()) + writer.write(writer_mp(self.config, self._writer_schema, self._message_name, datum)) diff --git a/karapace/protobuf/known_dependency.py b/karapace/protobuf/known_dependency.py new file mode 100644 index 000000000..d74953c17 --- /dev/null +++ b/karapace/protobuf/known_dependency.py @@ -0,0 +1,163 @@ +""" +karapace - known_dependency + +Copyright (c) 2023 Aiven Ltd +See LICENSE for details +""" + +# Support of known dependencies + +from enum import Enum +from typing import Any, Dict, Set + + +def static_init(cls: Any) -> object: + if getattr(cls, "static_init", None): + cls.static_init() + return cls + + +class KnownDependencyLocation(Enum): + ANY_LOCATION = "google/protobuf/any.proto" + API_LOCATION = "google/protobuf/api.proto" + DESCRIPTOR_LOCATION = "google/protobuf/descriptor.proto" + DURATION_LOCATION = "google/protobuf/duration.proto" + EMPTY_LOCATION = "google/protobuf/empty.proto" + FIELD_MASK_LOCATION = "google/protobuf/field_mask.proto" + SOURCE_CONTEXT_LOCATION = "google/protobuf/source_context.proto" + STRUCT_LOCATION = "google/protobuf/struct.proto" + TIMESTAMP_LOCATION = "google/protobuf/timestamp.proto" + TYPE_LOCATION = "google/protobuf/type.proto" + WRAPPER_LOCATION = "google/protobuf/wrappers.proto" + CALENDAR_PERIOD_LOCATION = "google/type/calendar_period.proto" + COLOR_LOCATION = "google/type/color.proto" + DATE_LOCATION = "google/type/date.proto" + DATETIME_LOCATION = "google/type/datetime.proto" + DAY_OF_WEEK_LOCATION = "google/type/dayofweek.proto" + DECIMAL_LOCATION = "google/type/decimal.proto" + EXPR_LOCATION = "google/type/expr.proto" + FRACTION_LOCATION = "google/type/fraction.proto" + INTERVAL_LOCATION = "google/type/interval.proto" + LATLNG_LOCATION = "google/type/latlng.proto" + MONEY_LOCATION = "google/type/money.proto" + MONTH_LOCATION = "google/type/month.proto" + PHONE_NUMBER_LOCATION = "google/type/phone_number.proto" + POSTAL_ADDRESS_LOCATION = "google/type/postal_address.proto" + QUATERNION_LOCATION = "google/type/quaternion.proto" + TIME_OF_DAY_LOCATION = "google/type/timeofday.proto" + + +@static_init +class KnownDependency: + index: Dict = dict() + index_simple: Dict = dict() + map: Dict = { + "google/protobuf/any.proto": ["google.protobuf.Any"], + "google/protobuf/api.proto": ["google.protobuf.Api", "google.protobuf.Method", "google.protobuf.Mixin"], + "google/protobuf/descriptor.proto": [ + "google.protobuf.FileDescriptorSet", + "google.protobuf.FileDescriptorProto", + "google.protobuf.DescriptorProto", + "google.protobuf.ExtensionRangeOptions", + "google.protobuf.FieldDescriptorProto", + "google.protobuf.OneofDescriptorProto", + "google.protobuf.EnumDescriptorProto", + "google.protobuf.EnumValueDescriptorProto", + "google.protobuf.ServiceDescriptorProto", + "google.protobuf.MethodDescriptorProto", + "google.protobuf.FileOptions", + "google.protobuf.MessageOptions", + "google.protobuf.FieldOptions", + "google.protobuf.OneofOptions", + "google.protobuf.EnumOptions", + "google.protobuf.EnumValueOptions", + "google.protobuf.ServiceOptions", + "google.protobuf.MethodOptions", + "google.protobuf.UninterpretedOption", + "google.protobuf.SourceCodeInfo", + "google.protobuf.GeneratedCodeInfo", + ], + "google/protobuf/duration.proto": ["google.protobuf.Duration"], + "google/protobuf/empty.proto": ["google.protobuf.Empty"], + "google/protobuf/field_mask.proto": ["google.protobuf.FieldMask"], + "google/protobuf/source_context.proto": ["google.protobuf.SourceContext"], + "google/protobuf/struct.proto": [ + "google.protobuf.Struct", + "google.protobuf.Value", + "google.protobuf.NullValue", + "google.protobuf.ListValue", + ], + "google/protobuf/timestamp.proto": ["google.protobuf.Timestamp"], + "google/protobuf/type.proto": [ + "google.protobuf.Type", + "google.protobuf.Field", + "google.protobuf.Enum", + "google.protobuf.EnumValue", + "google.protobuf.Option", + "google.protobuf.Syntax", + ], + "google/protobuf/wrappers.proto": [ + "google.protobuf.DoubleValue", + "google.protobuf.FloatValue", + "google.protobuf.Int64Value", + "google.protobuf.UInt64Value", + "google.protobuf.Int32Value", + "google.protobuf.UInt32Value", + "google.protobuf.BoolValue", + "google.protobuf.StringValue", + "google.protobuf.BytesValue", + ], + "google/type/calendar_period.proto": ["google.type.CalendarPeriod"], + "google/type/color.proto": ["google.type.Color"], + "google/type/date.proto": ["google.type.Date"], + "google/type/datetime.proto": ["google.type.DateTime", "google.type.TimeZone"], + "google/type/dayofweek.proto": ["google.type.DayOfWeek"], + "google/type/decimal.proto": ["google.type.Decimal"], + "google/type/expr.proto": ["google.type.Expr"], + "google/type/fraction.proto": ["google.type.Fraction"], + "google/type/interval.proto": ["google.type.Interval"], + "google/type/latlng.proto": ["google.type.LatLng"], + "google/type/money.proto": ["google.type.Money"], + "google/type/month.proto": ["google.type.Month"], + "google/type/phone_number.proto": ["google.type.PhoneNumber"], + "google/type/postal_address.proto": ["google.type.PostalAddress"], + "google/type/quaternion.proto": ["google.type.Quaternion"], + "google/type/timeofday.proto": ["google.type.TimeOfDay"], + "confluent/meta.proto": ["confluent.Meta"], + "confluent/type/decimal.proto": ["confluent.type.Decimal"], + } + + @classmethod + def static_init(cls) -> None: + for key, value in cls.map.items(): + for item in value: + cls.index[item] = key + cls.index["." + item] = key + dot = item.rfind(".") + cls.index_simple[item[dot + 1 :]] = key + cls.index_simple[item] = key + + +@static_init +class DependenciesHardcoded: + index: Set = set() + + @classmethod + def static_init(cls) -> None: + cls.index = { + "bool", + "bytes", + "double", + "float", + "fixed32", + "fixed64", + "int32", + "int64", + "sfixed32", + "sfixed64", + "sint32", + "sint64", + "string", + "uint32", + "uint64", + } diff --git a/karapace/protobuf/message_element.py b/karapace/protobuf/message_element.py index 6681d143f..8a2fb03f4 100644 --- a/karapace/protobuf/message_element.py +++ b/karapace/protobuf/message_element.py @@ -26,7 +26,7 @@ def __init__( location: Location, name: str, documentation: str = "", - nested_types: List[str] = None, + nested_types: List[TypeElement] = None, options: List[OptionElement] = None, reserveds: List[ReservedElement] = None, fields: List[FieldElement] = None, @@ -84,6 +84,8 @@ def to_schema(self) -> str: return "".join(result) def compare(self, other: "MessageElement", result: CompareResult, types: CompareTypes) -> None: + from karapace.protobuf.compare_type_lists import compare_type_lists + if types.lock_message(self): field: FieldElement subfield: FieldElement @@ -141,5 +143,5 @@ def compare(self, other: "MessageElement", result: CompareResult, types: Compare self_one_ofs[name].compare(other_one_ofs[name], result, types) result.pop_path() - + compare_type_lists(self.nested_types, other.nested_types, result, types) types.unlock_message(self) diff --git a/karapace/protobuf/proto_file_element.py b/karapace/protobuf/proto_file_element.py index c578f52db..eedb94986 100644 --- a/karapace/protobuf/proto_file_element.py +++ b/karapace/protobuf/proto_file_element.py @@ -4,32 +4,46 @@ """ # Ported from square/wire: # wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/parser/ProtoFileElement.kt -from itertools import chain +from karapace.dependency import Dependency from karapace.protobuf.compare_result import CompareResult, Modification from karapace.protobuf.compare_type_storage import CompareTypes -from karapace.protobuf.enum_element import EnumElement -from karapace.protobuf.exception import IllegalStateException from karapace.protobuf.location import Location -from karapace.protobuf.message_element import MessageElement from karapace.protobuf.syntax import Syntax from karapace.protobuf.type_element import TypeElement +from typing import Dict, List, Optional + + +def _collect_dependencies_types(compare_types: CompareTypes, dependencies: Optional[Dict[str, Dependency]], is_self: bool): + for dep in dependencies.values(): + types: List[TypeElement] = dep.schema.schema.proto_file_element.types + sub_deps = dep.schema.schema.dependencies + package_name = dep.schema.schema.proto_file_element.package_name + type_: TypeElement + for type_ in types: + if is_self: + compare_types.add_self_type(package_name, type_) + else: + compare_types.add_other_type(package_name, type_) + if sub_deps is None: + return + _collect_dependencies_types(compare_types, sub_deps, is_self) class ProtoFileElement: def __init__( self, location: Location, - package_name: str = None, - syntax: Syntax = None, - imports: list = None, - public_imports: list = None, - types=None, - services: list = None, - extend_declarations: list = None, - options: list = None, + package_name: Optional[str] = None, + syntax: Optional[Syntax] = None, + imports: Optional[list] = None, + public_imports: Optional[list] = None, + types: Optional[List[TypeElement]] = None, + services: Optional[list] = None, + extend_declarations: Optional[list] = None, + options: Optional[list] = None, ) -> None: if types is None: - types = [] + types = list() self.location = location self.package_name = package_name self.syntax = syntax @@ -102,63 +116,25 @@ def __eq__(self, other: "ProtoFileElement") -> bool: # type: ignore def __repr__(self) -> str: return self.to_schema() - def compare(self, other: "ProtoFileElement", result: CompareResult) -> CompareResult: + def compare( + self, + other: "ProtoFileElement", + result: CompareResult, + self_dependencies: Optional[Dict[str, Dependency]] = None, + other_dependencies: Optional[Dict[str, Dependency]] = None, + ) -> CompareResult: + from karapace.protobuf.compare_type_lists import compare_type_lists + if self.package_name != other.package_name: result.add_modification(Modification.PACKAGE_ALTER) # TODO: do we need syntax check? if self.syntax != other.syntax: result.add_modification(Modification.SYNTAX_ALTER) - self_types = {} - other_types = {} - self_indexes = {} - other_indexes = {} compare_types = CompareTypes(self.package_name, other.package_name, result) - type_: TypeElement - for i, type_ in enumerate(self.types): - self_types[type_.name] = type_ - self_indexes[type_.name] = i - package_name = self.package_name or "" - compare_types.add_self_type(package_name, type_) - - for i, type_ in enumerate(other.types): - other_types[type_.name] = type_ - other_indexes[type_.name] = i - package_name = other.package_name or "" - compare_types.add_other_type(package_name, type_) - - for name in chain(self_types.keys(), other_types.keys() - self_types.keys()): - result.push_path(str(name), True) - - if self_types.get(name) is None and other_types.get(name) is not None: - if isinstance(other_types[name], MessageElement): - result.add_modification(Modification.MESSAGE_ADD) - elif isinstance(other_types[name], EnumElement): - result.add_modification(Modification.ENUM_ADD) - else: - raise IllegalStateException("Instance of element is not applicable") - elif self_types.get(name) is not None and other_types.get(name) is None: - if isinstance(self_types[name], MessageElement): - result.add_modification(Modification.MESSAGE_DROP) - elif isinstance(self_types[name], EnumElement): - result.add_modification(Modification.ENUM_DROP) - else: - raise IllegalStateException("Instance of element is not applicable") - else: - if other_indexes[name] != self_indexes[name]: - if isinstance(self_types[name], MessageElement): - # incompatible type - result.add_modification(Modification.MESSAGE_MOVE) - else: - raise IllegalStateException("Instance of element is not applicable") - else: - if isinstance(self_types[name], MessageElement) and isinstance(other_types[name], MessageElement): - self_types[name].compare(other_types[name], result, compare_types) - elif isinstance(self_types[name], EnumElement) and isinstance(other_types[name], EnumElement): - self_types[name].compare(other_types[name], result, compare_types) - else: - # incompatible type - result.add_modification(Modification.TYPE_ALTER) - result.pop_path(True) - - return result + if self_dependencies: + _collect_dependencies_types(compare_types, self_dependencies, True) + + if other_dependencies: + _collect_dependencies_types(compare_types, other_dependencies, False) + return compare_type_lists(self.types, other.types, result, compare_types) diff --git a/karapace/protobuf/proto_parser.py b/karapace/protobuf/proto_parser.py index 8290e5641..92c679e96 100644 --- a/karapace/protobuf/proto_parser.py +++ b/karapace/protobuf/proto_parser.py @@ -28,7 +28,7 @@ from karapace.protobuf.syntax_reader import SyntaxReader from karapace.protobuf.type_element import TypeElement from karapace.protobuf.utils import MAX_TAG_VALUE -from typing import List, Union +from typing import List, Optional, Union class Context(Enum): @@ -74,13 +74,13 @@ class ProtoParser: def __init__(self, location: Location, data: str) -> None: self.location = location self.imports: List[str] = [] - self.nested_types: List[str] = [] + self.nested_types: List[TypeElement] = [] self.services: List[str] = [] self.extends_list: List[str] = [] self.options: List[str] = [] self.declaration_count = 0 - self.syntax: Union[Syntax, None] = None - self.package_name: Union[str, None] = None + self.syntax: Optional[Syntax] = None + self.package_name: Optional[str] = None self.prefix = "" self.data = data self.public_imports: List[str] = [] @@ -179,7 +179,6 @@ def read_declaration( import_string = self.reader.read_string() if import_string == "public": self.public_imports.append(self.reader.read_string()) - else: self.imports.append(import_string) self.reader.require(";") diff --git a/karapace/protobuf/schema.py b/karapace/protobuf/schema.py index d068ecd16..dd9fa66db 100644 --- a/karapace/protobuf/schema.py +++ b/karapace/protobuf/schema.py @@ -5,7 +5,9 @@ # Ported from square/wire: # wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/Schema.kt # Ported partially for required functionality. +from karapace.dependency import Dependency, DependencyVerifierResult from karapace.protobuf.compare_result import CompareResult +from karapace.protobuf.dependency import process_one_of, ProtobufDependencyVerifier from karapace.protobuf.enum_element import EnumElement from karapace.protobuf.exception import IllegalArgumentException from karapace.protobuf.location import Location @@ -13,7 +15,10 @@ from karapace.protobuf.option_element import OptionElement from karapace.protobuf.proto_file_element import ProtoFileElement from karapace.protobuf.proto_parser import ProtoParser +from karapace.protobuf.type_element import TypeElement from karapace.protobuf.utils import append_documentation, append_indented +from karapace.schema_references import Reference +from typing import Dict, List, Optional def add_slashes(text: str) -> str: @@ -104,12 +109,61 @@ def option_element_string(option: OptionElement) -> str: class ProtobufSchema: DEFAULT_LOCATION = Location.get("") - def __init__(self, schema: str) -> None: + def __init__( + self, schema: str, references: Optional[List[Reference]] = None, dependencies: Optional[Dict[str, Dependency]] = None + ) -> None: if type(schema).__name__ != "str": raise IllegalArgumentException("Non str type of schema string") self.dirty = schema self.cache_string = "" self.proto_file_element = ProtoParser.parse(self.DEFAULT_LOCATION, schema) + self.references = references + self.dependencies = dependencies + + def verify_schema_dependencies(self) -> DependencyVerifierResult: + verifier = ProtobufDependencyVerifier() + self.collect_dependencies(verifier) + return verifier.verify() + + def collect_dependencies(self, verifier: ProtobufDependencyVerifier) -> None: + if self.dependencies: + for key in self.dependencies: + self.dependencies[key].schema.schema.collect_dependencies(verifier) + + # verifier.add_import?? we have no access to own Kafka structure from this class... + # but we need data to analyse imports to avoid cyclic dependencies... + + package_name = self.proto_file_element.package_name + if package_name is None: + package_name = "" + else: + package_name = "." + package_name + for element_type in self.proto_file_element.types: + type_name = element_type.name + full_name = package_name + "." + type_name + verifier.add_declared_type(full_name) + verifier.add_declared_type(type_name) + if isinstance(element_type, MessageElement): + for one_of in element_type.one_ofs: + process_one_of(verifier, package_name, type_name, one_of) + for field in element_type.fields: + verifier.add_used_type(full_name, field.element_type) + for nested_type in element_type.nested_types: + self._process_nested_type(verifier, package_name, type_name, nested_type) + + def _process_nested_type( + self, verifier: ProtobufDependencyVerifier, package_name: str, parent_name, element_type: TypeElement + ): + verifier.add_declared_type(package_name + "." + parent_name + "." + element_type.name) + verifier.add_declared_type(parent_name + "." + element_type.name) + + if isinstance(element_type, MessageElement): + for one_of in element_type.one_ofs: + process_one_of(verifier, package_name, parent_name, one_of) + for field in element_type.fields: + verifier.add_used_type(parent_name, field.element_type) + for nested_type in element_type.nested_types: + self._process_nested_type(verifier, package_name, parent_name + "." + element_type.name, nested_type) def __str__(self) -> str: if not self.cache_string: @@ -166,4 +220,9 @@ def to_schema(self) -> str: return "".join(strings) def compare(self, other: "ProtobufSchema", result: CompareResult) -> CompareResult: - self.proto_file_element.compare(other.proto_file_element, result) + return self.proto_file_element.compare( + other.proto_file_element, + result, + self_dependencies=self.dependencies, + other_dependencies=other.dependencies, + ) diff --git a/karapace/protobuf/type_element.py b/karapace/protobuf/type_element.py index 1520beffc..e4bd20b2d 100644 --- a/karapace/protobuf/type_element.py +++ b/karapace/protobuf/type_element.py @@ -9,6 +9,8 @@ from typing import List, TYPE_CHECKING if TYPE_CHECKING: + from karapace.protobuf.compare_result import CompareResult + from karapace.protobuf.compare_type_storage import CompareTypes from karapace.protobuf.option_element import OptionElement @@ -34,3 +36,6 @@ def __repr__(self) -> str: def __str__(self) -> str: mytype = type(self) return f"{mytype}({self.to_schema()})" + + def compare(self, other: "TypeElement", result: "CompareResult", types: "CompareTypes") -> None: + pass diff --git a/karapace/schema_models.py b/karapace/schema_models.py index dd72f568f..e87ef9e53 100644 --- a/karapace/schema_models.py +++ b/karapace/schema_models.py @@ -5,22 +5,24 @@ from avro.errors import SchemaParseException from avro.schema import parse as avro_parse, Schema as AvroSchema from dataclasses import dataclass -from enum import Enum, unique from jsonschema import Draft7Validator from jsonschema.exceptions import SchemaError +from karapace.dependency import Dependency from karapace.errors import InvalidSchema from karapace.protobuf.exception import ( Error as ProtobufError, IllegalArgumentException, IllegalStateException, ProtobufException, - ProtobufParserRuntimeException, + ProtobufUnresolvedDependencyException, SchemaParseException as ProtobufSchemaParseException, ) from karapace.protobuf.schema import ProtobufSchema +from karapace.schema_references import Reference +from karapace.schema_type import SchemaType from karapace.typing import ResolvedVersion, SchemaId, Subject from karapace.utils import json_decode, json_encode, JSONDecodeError -from typing import Any, cast, Dict, NoReturn, Optional, Union +from typing import Any, cast, Dict, List, NoReturn, Optional, Union import hashlib import logging @@ -51,22 +53,24 @@ def parse_jsonschema_definition(schema_definition: str) -> Draft7Validator: return Draft7Validator(schema) -def parse_protobuf_schema_definition(schema_definition: str) -> ProtobufSchema: +def parse_protobuf_schema_definition( + schema_definition: str, + references: Optional[List[Reference]] = None, + dependencies: Optional[Dict[str, Dependency]] = None, + validate_references: bool = True, +) -> ProtobufSchema: """Parses and validates `schema_definition`. Raises: - Nothing yet. + ProtobufUnresolvedDependencyException if Protobuf dependency cannot be resolved. """ - - return ProtobufSchema(schema_definition) - - -@unique -class SchemaType(str, Enum): - AVRO = "AVRO" - JSONSCHEMA = "JSON" - PROTOBUF = "PROTOBUF" + protobuf_schema = ProtobufSchema(schema_definition, references, dependencies) + if validate_references: + result = protobuf_schema.verify_schema_dependencies() + if not result.result: + raise ProtobufUnresolvedDependencyException(f"{result.message}") + return protobuf_schema def _assert_never(no_return: NoReturn) -> NoReturn: @@ -74,18 +78,30 @@ def _assert_never(no_return: NoReturn) -> NoReturn: class TypedSchema: - def __init__(self, schema_type: SchemaType, schema_str: str): + def __init__( + self, + *, + schema_type: SchemaType, + schema_str: str, + schema: Optional[Union[Draft7Validator, AvroSchema, ProtobufSchema]] = None, + references: Optional[List[Reference]] = None, + dependencies: Optional[Dict[str, Dependency]] = None, + ): """Schema with type information Args: schema_type (SchemaType): The type of the schema schema_str (str): The original schema string + schema (Optional[Union[Draft7Validator, AvroSchema, ProtobufSchema]]): The parsed and validated schema + references (Optional[List[Dependency]]): The references of schema """ self.schema_type = schema_type - self.schema_str = TypedSchema.normalize_schema_str(schema_str, schema_type) + self.references = references + self.dependencies = dependencies + self.schema_str = TypedSchema.normalize_schema_str(schema_str, schema_type, schema) self.max_id: Optional[SchemaId] = None - self._fingerprint_cached: Optional[str] = None + self._str_cached: Optional[str] = None def to_dict(self) -> Dict[str, Any]: if self.schema_type is SchemaType.PROTOBUF: @@ -98,7 +114,13 @@ def fingerprint(self) -> str: return self._fingerprint_cached @staticmethod - def normalize_schema_str(schema_str: str, schema_type: SchemaType) -> str: + def normalize_schema_str( + schema_str: str, + schema_type: SchemaType, + schema: Optional[Union[Draft7Validator, AvroSchema, ProtobufSchema]] = None, + # references: Optional[List[Reference]] = None, + # dependencies: Optional[Dict[str, Dependency]] = None, + ) -> str: if schema_type is SchemaType.AVRO or schema_type is SchemaType.JSONSCHEMA: try: schema_str = json_encode(json_decode(schema_str), compact=True, sort_keys=True) @@ -106,23 +128,49 @@ def normalize_schema_str(schema_str: str, schema_type: SchemaType) -> str: LOG.error("Schema is not valid JSON") raise e elif schema_type == SchemaType.PROTOBUF: - try: - schema_str = str(parse_protobuf_schema_definition(schema_str)) - except InvalidSchema as e: - LOG.exception("Schema is not valid ProtoBuf definition") - raise e + if schema: + schema_str = str(schema) + else: + try: + schema_str = str(parse_protobuf_schema_definition(schema_str, None, None, False)) + except InvalidSchema as e: + LOG.exception("Schema is not valid ProtoBuf definition") + raise e + else: _assert_never(schema_type) return schema_str def __str__(self) -> str: - return self.schema_str + if self.schema_type == SchemaType.PROTOBUF: + return self.schema_str + + if self._str_cached is None: + self._str_cached = json_encode(self.to_dict()) + return self._str_cached def __repr__(self) -> str: return f"TypedSchema(type={self.schema_type}, schema={str(self)})" def __eq__(self, other: Any) -> bool: - return isinstance(other, TypedSchema) and str(self) == str(other) and self.schema_type is other.schema_type + return ( + isinstance(other, (TypedSchema)) + and self.schema_type is other.schema_type + and str(self) == str(other) + and self.references == other.references + ) + + @property + def schema(self) -> Union[Draft7Validator, AvroSchema, ProtobufSchema]: + parsed_typed_schema = parse( + schema_type=self.schema_type, + schema_str=self.schema_str, + validate_avro_names=True, + validate_avro_enum_symbols=True, + references=self.references, + dependencies=self.dependencies, + ) + return parsed_typed_schema.schema def parse( @@ -130,6 +178,8 @@ def parse( schema_str: str, validate_avro_enum_symbols: bool, validate_avro_names: bool, + references: Optional[List[Reference]] = None, + dependencies: Optional[Dict[str, Dependency]] = None, ) -> "ParsedTypedSchema": if schema_type not in [SchemaType.AVRO, SchemaType.JSONSCHEMA, SchemaType.PROTOBUF]: raise InvalidSchema(f"Unknown parser {schema_type} for {schema_str}") @@ -154,12 +204,11 @@ def parse( elif schema_type is SchemaType.PROTOBUF: try: - parsed_schema = parse_protobuf_schema_definition(schema_str) + parsed_schema = parse_protobuf_schema_definition(schema_str, references, dependencies) except ( TypeError, SchemaError, AssertionError, - ProtobufParserRuntimeException, IllegalStateException, IllegalArgumentException, ProtobufError, @@ -170,7 +219,13 @@ def parse( else: raise InvalidSchema(f"Unknown parser {schema_type} for {schema_str}") - return ParsedTypedSchema(schema_type=schema_type, schema_str=schema_str, schema=parsed_schema) + return ParsedTypedSchema( + schema_type=schema_type, + schema_str=schema_str, + schema=parsed_schema, + references=references, + dependencies=dependencies, + ) class ParsedTypedSchema(TypedSchema): @@ -192,17 +247,34 @@ class ParsedTypedSchema(TypedSchema): are considered by the current version of the SDK invalid. """ - def __init__(self, schema_type: SchemaType, schema_str: str, schema: Union[Draft7Validator, AvroSchema, ProtobufSchema]): - super().__init__(schema_type=schema_type, schema_str=schema_str) - self.schema = schema + def __init__( + self, + schema_type: SchemaType, + schema_str: str, + schema: Union[Draft7Validator, AvroSchema, ProtobufSchema], + references: Optional[List[Reference]] = None, + dependencies: Optional[Dict[str, Dependency]] = None, + ): + self._schema_cached: Optional[Union[Draft7Validator, AvroSchema, ProtobufSchema]] = schema + + super().__init__( + schema_type=schema_type, schema_str=schema_str, references=references, dependencies=dependencies, schema=schema + ) @staticmethod - def parse(schema_type: SchemaType, schema_str: str) -> "ParsedTypedSchema": + def parse( + schema_type: SchemaType, + schema_str: str, + references: Optional[List[Reference]] = None, + dependencies: Optional[Dict[str, Dependency]] = None, + ) -> "ParsedTypedSchema": return parse( schema_type=schema_type, schema_str=schema_str, validate_avro_enum_symbols=False, validate_avro_names=False, + references=references, + dependencies=dependencies, ) def __str__(self) -> str: @@ -210,6 +282,16 @@ def __str__(self) -> str: return str(self.schema) return super().__str__() + @property + def schema(self) -> Union[Draft7Validator, AvroSchema, ProtobufSchema]: + if self._schema_cached is not None: + return self._schema_cached + self._schema_cached = super().schema + return self._schema_cached + + def get_references(self) -> Optional[List[Reference]]: + return self.references + class ValidatedTypedSchema(ParsedTypedSchema): """Validated schema resource. @@ -225,17 +307,34 @@ class ValidatedTypedSchema(ParsedTypedSchema): are considered by the current version of the SDK invalid. """ - def __init__(self, schema_type: SchemaType, schema_str: str, schema: Union[Draft7Validator, AvroSchema, ProtobufSchema]): - super().__init__(schema_type=schema_type, schema_str=schema_str, schema=schema) + def __init__( + self, + schema_type: SchemaType, + schema_str: str, + schema: Union[Draft7Validator, AvroSchema, ProtobufSchema], + references: Optional[List[Reference]] = None, + dependencies: Optional[Dict[str, Dependency]] = None, + ): + super().__init__( + schema_type=schema_type, schema_str=schema_str, references=references, dependencies=dependencies, schema=schema + ) @staticmethod - def parse(schema_type: SchemaType, schema_str: str) -> "ValidatedTypedSchema": + def parse( + schema_type: SchemaType, + schema_str: str, + references: Optional[List[Reference]] = None, + dependencies: Optional[Dict[str, Dependency]] = None, + ) -> "ValidatedTypedSchema": parsed_schema = parse( schema_type=schema_type, schema_str=schema_str, validate_avro_enum_symbols=True, validate_avro_names=True, + references=references, + dependencies=dependencies, ) + return cast(ValidatedTypedSchema, parsed_schema) @@ -246,3 +345,4 @@ class SchemaVersion: deleted: bool schema_id: SchemaId schema: TypedSchema + references: Optional[List[Reference]] diff --git a/karapace/schema_reader.py b/karapace/schema_reader.py index dee44a942..b466202ff 100644 --- a/karapace/schema_reader.py +++ b/karapace/schema_reader.py @@ -4,7 +4,9 @@ Copyright (c) 2023 Aiven Ltd See LICENSE for details """ +from avro.schema import Schema as AvroSchema from contextlib import closing, ExitStack +from jsonschema.validators import Draft7Validator from kafka import KafkaConsumer, TopicPartition from kafka.admin import KafkaAdminClient, NewTopic from kafka.errors import ( @@ -17,17 +19,22 @@ ) from karapace import constants from karapace.config import Config -from karapace.errors import InvalidSchema +from karapace.dependency import Dependency +from karapace.errors import InvalidReferences, InvalidSchema from karapace.in_memory_database import InMemoryDatabase from karapace.key_format import is_key_in_canonical_format, KeyFormatter, KeyMode from karapace.master_coordinator import MasterCoordinator from karapace.offset_watcher import OffsetWatcher -from karapace.schema_models import SchemaType, TypedSchema +from karapace.protobuf.schema import ProtobufSchema +from karapace.schema_models import parse_protobuf_schema_definition, SchemaType, SchemaVersion, TypedSchema +from karapace.schema_references import Reference, Referents from karapace.statsd import StatsClient +from karapace.typing import JsonObject, ResolvedVersion, SchemaId from karapace.utils import json_decode, JSONDecodeError, KarapaceKafkaClient from threading import Event, Thread -from typing import Optional +from typing import Dict, List, Optional, Union +import json import logging import time @@ -44,7 +51,6 @@ KAFKA_CLIENT_CREATION_TIMEOUT_SECONDS = 2.0 SCHEMA_TOPIC_CREATION_TIMEOUT_SECONDS = 5.0 - # Metric names METRIC_SCHEMA_TOPIC_RECORDS_PROCESSED_COUNT = "karapace_schema_reader_records_processed" METRIC_SCHEMA_TOPIC_RECORDS_PER_KEYMODE_GAUGE = "karapace_schema_reader_records_per_keymode" @@ -116,6 +122,7 @@ def __init__( self.master_coordinator = master_coordinator self.timeout_ms = 200 self.config = config + self.database = database self.admin_client: Optional[KafkaAdminClient] = None self.topic_replication_factor = self.config["replication_factor"] @@ -381,7 +388,6 @@ def _handle_msg_config(self, key: dict, value: Optional[dict]) -> None: if self.database.find_subject(subject=subject) is None: LOG.info("Adding first version of subject: %r with no schemas", subject) self.database.insert_subject(subject=subject) - if not value: LOG.info("Deleting compatibility config completely for subject: %r", subject) self.database.delete_subject_compatibility(subject=subject) @@ -430,6 +436,9 @@ def _handle_msg_schema(self, key: dict, value: Optional[dict]) -> None: schema_id = value["id"] schema_version = value["version"] schema_deleted = value.get("deleted", False) + schema_references = value.get("references", None) + resolved_references: Optional[List[Reference]] = None + resolved_dependencies: Optional[Dict[str, Dependency]] = None try: schema_type_parsed = SchemaType(schema_type) @@ -443,21 +452,61 @@ def _handle_msg_schema(self, key: dict, value: Optional[dict]) -> None: # won't interfere with the equality. Note: This means it is possible # for the REST API to return data that is formatted differently from # what is available in the topic. + + parsed_schema: Optional[Union[Draft7Validator, AvroSchema, ProtobufSchema]] = None + resolved_dependencies: Optional[Dict[str, Dependency]] = None + if schema_type_parsed in [SchemaType.AVRO, SchemaType.JSONSCHEMA]: + try: + schema_str = json.dumps(json.loads(schema_str), sort_keys=True) + except json.JSONDecodeError: + LOG.error("Schema is not valid JSON") + return + elif schema_type_parsed == SchemaType.PROTOBUF: + try: + if schema_references: + resolved_references = [ + Reference(reference["name"], reference["subject"], reference["version"]) + for reference in schema_references + ] + resolved_dependencies = self.resolve_references(resolved_references) + parsed_schema = parse_protobuf_schema_definition( + schema_str, + resolved_references, + resolved_dependencies, + validate_references=False, + ) + schema_str = str(parsed_schema) + except InvalidSchema: + LOG.exception("Schema is not valid ProtoBuf definition") + return + except InvalidReferences: + LOG.exception("Invalid Protobuf references") + return + try: typed_schema = TypedSchema( schema_type=schema_type_parsed, schema_str=schema_str, + references=resolved_references, + dependencies=resolved_dependencies, + schema=parsed_schema, ) except (InvalidSchema, JSONDecodeError): return + self.database.insert_schema_version( subject=schema_subject, schema_id=schema_id, version=schema_version, deleted=schema_deleted, schema=typed_schema, + references=resolved_references, ) + if resolved_references: + for ref in resolved_references: + self.database.insert_referenced_by(subject=ref.subject, version=ref.version, schema_id=schema_id) + def handle_msg(self, key: dict, value: Optional[dict]) -> None: if key["keytype"] == "CONFIG": self._handle_msg_config(key, value) @@ -467,3 +516,36 @@ def handle_msg(self, key: dict, value: Optional[dict]) -> None: self._handle_msg_delete_subject(key, value) elif key["keytype"] == "NOOP": # for spec completeness pass + + def remove_referenced_by(self, schema_id: SchemaId, references: List[Reference]): + self.database.remove_referenced_by(schema_id, references) + + def get_referenced_by(self, subject: Subject, version: ResolvedVersion) -> Optional[Referents]: + return self.database.get_referenced_by(subject, version) + + def _resolve_reference(self, reference: Reference) -> Dependency: + subject_data = self.database.find_subject_schemas(subject=reference.subject, include_deleted=False) + if not subject_data: + raise InvalidReferences(f"Subject not found {reference.subject}.") + schema_version: SchemaVersion = subject_data.get(reference.version, None) + if schema_version is None: + raise InvalidReferences(f"Subject {reference.subject} has no such schema version") + schema: TypedSchema = schema_version.schema + if not schema: + raise InvalidReferences(f"No schema in {reference.subject} with version {reference.version}.") + if schema.references: + schema_dependencies = self.resolve_references(schema.references) + if schema.dependencies is None: + schema.dependencies = schema_dependencies + return Dependency.of(reference, schema) + + def resolve_references(self, references: Union[List[Reference], JsonObject]) -> Dict[str, Dependency]: + dependencies: Dict[str, Dependency] = dict() + for reference in references: + if isinstance(reference, Reference): + dependencies[reference.name] = self._resolve_reference(reference) + else: + dependencies[reference["name"]] = self._resolve_reference( + Reference(reference["name"], reference["subject"], reference["version"]) + ) + return dependencies diff --git a/karapace/schema_references.py b/karapace/schema_references.py new file mode 100644 index 000000000..50d7e978f --- /dev/null +++ b/karapace/schema_references.py @@ -0,0 +1,36 @@ +""" +karapace schema_references + +Copyright (c) 2023 Aiven Ltd +See LICENSE for details +""" + +from karapace.typing import JsonData, ResolvedVersion, Subject +from typing import Any, List + +Referents = List + + +class Reference: + def __init__(self, name: str, subject: Subject, version: ResolvedVersion): + self.name = name + self.subject = subject + self.version = version + + def to_dict(self) -> JsonData: + return { + "name": self.name, + "subject": self.subject, + "version": self.version, + } + + def __repr__(self) -> str: + return f"{{name='{self.name}', subject='{self.subject}', version={self.version}}}" + + def __hash__(self) -> int: + return hash((self.name, self.subject, self.version)) + + def __eq__(self, other: Any) -> bool: + if other is None or not isinstance(other, Reference): + return False + return self.name == other.name and self.subject == other.subject and self.version == other.version diff --git a/karapace/schema_registry.py b/karapace/schema_registry.py index 23bb6da8c..f61787320 100644 --- a/karapace/schema_registry.py +++ b/karapace/schema_registry.py @@ -6,9 +6,11 @@ from karapace.compatibility import check_compatibility, CompatibilityModes from karapace.compatibility.jsonschema.checks import is_incompatible from karapace.config import Config +from karapace.dependency import Dependency from karapace.errors import ( IncompatibleSchema, InvalidVersion, + ReferenceExistsException, SchemasNotFoundException, SchemaVersionNotSoftDeletedException, SchemaVersionSoftDeletedException, @@ -24,6 +26,7 @@ from karapace.offset_watcher import OffsetWatcher from karapace.schema_models import ParsedTypedSchema, SchemaType, SchemaVersion, TypedSchema, ValidatedTypedSchema from karapace.schema_reader import KafkaSchemaReader +from karapace.schema_references import Reference from karapace.typing import JsonObject, ResolvedVersion, Subject, Version from typing import Dict, List, Optional, Tuple, Union @@ -171,7 +174,11 @@ async def subject_delete_local(self, subject: str, permanent: bool) -> List[Reso version_list = [] if permanent: version_list = list(schema_versions) - latest_version_id = version_list[-1] + for version_id, schema_version in list(schema_versions.items()): + referenced_by = self.schema_reader.get_referenced_by(subject, schema_version.version) + if referenced_by and len(referenced_by) > 0: + raise ReferenceExistsException(referenced_by, version_id) + for version_id, schema_version in list(schema_versions.items()): LOG.info( "Permanently deleting subject '%s' version %s (schema id=%s)", @@ -180,8 +187,15 @@ async def subject_delete_local(self, subject: str, permanent: bool) -> List[Reso schema_version.schema_id, ) self.send_schema_message( - subject=subject, schema=None, schema_id=schema_version.schema_id, version=version_id, deleted=True + subject=subject, + schema=None, + schema_id=schema_version.schema_id, + version=version_id, + deleted=True, + references=schema_version.references, ) + if schema_version.references and len(schema_version.references) > 0: + self.schema_reader.remove_referenced_by(schema_version.schema_id, schema_version.references) else: try: schema_versions_live = self.subject_get(subject, include_deleted=False) @@ -190,6 +204,10 @@ async def subject_delete_local(self, subject: str, permanent: bool) -> List[Reso latest_version_id = version_list[-1] except SchemasNotFoundException: pass + + referenced_by = self.schema_reader.get_referenced_by(subject, latest_version_id) + if referenced_by and len(referenced_by) > 0: + raise ReferenceExistsException(referenced_by, latest_version_id) self.send_delete_subject_message(subject, latest_version_id) return version_list @@ -215,13 +233,20 @@ async def subject_version_delete_local(self, subject: Subject, version: Version, if permanent and not schema_version.deleted: raise SchemaVersionNotSoftDeletedException() + referenced_by = self.schema_reader.get_referenced_by(subject, resolved_version) + if referenced_by and len(referenced_by) > 0: + raise ReferenceExistsException(referenced_by, version) + self.send_schema_message( subject=subject, schema=None if permanent else schema_version.schema, schema_id=schema_version.schema_id, version=resolved_version, deleted=True, + references=schema_version.references, ) + if schema_version.references and len(schema_version.references) > 0: + self.schema_reader.remove_referenced_by(schema_version.schema_id, schema_version.references) return resolved_version def subject_get(self, subject: Subject, include_deleted: bool = False) -> Dict[ResolvedVersion, SchemaVersion]: @@ -253,18 +278,39 @@ def subject_version_get(self, subject: Subject, version: Version, *, include_del "id": schema_id, "schema": schema.schema_str, } + if schema.references is not None: + ret["references"] = [reference.to_dict() for reference in schema.references] if schema.schema_type is not SchemaType.AVRO: ret["schemaType"] = schema.schema_type # Return also compatibility information to compatibility check compatibility = self.database.get_subject_compatibility(subject=subject) if compatibility: ret["compatibility"] = compatibility + return ret + async def subject_version_referencedby_get( + self, subject: Subject, version: Version, *, include_deleted: bool = False + ) -> List: + validate_version(version) + schema_versions = self.subject_get(subject, include_deleted=include_deleted) + if not schema_versions: + raise SubjectNotFoundException() + resolved_version = _resolve_version(schema_versions=schema_versions, version=version) + schema_data: Optional[SchemaVersion] = schema_versions.get(resolved_version, None) + if not schema_data: + raise VersionNotFoundException() + referenced_by = self.schema_reader.get_referenced_by(schema_data.subject, schema_data.version) + + if referenced_by and len(referenced_by) > 0: + return list(referenced_by) + return [] + async def write_new_schema_local( self, subject: Subject, new_schema: ValidatedTypedSchema, + new_schema_references: Optional[List[Reference]], ) -> int: """Write new schema and return new id or return id of matching existing schema @@ -317,6 +363,7 @@ async def write_new_schema_local( schema_id=schema_id, version=version, deleted=False, + references=new_schema_references, ) return schema_id @@ -333,8 +380,15 @@ async def write_new_schema_local( for old_version in check_against: old_schema = all_schema_versions[old_version].schema + old_schema_dependencies: Optional[Dict[str, Dependency]] = None + old_schema_references: Optional[List[Reference]] = old_schema.references + if old_schema_references: + old_schema_dependencies = self.resolve_references(old_schema_references) parsed_old_schema = ParsedTypedSchema.parse( - schema_type=old_schema.schema_type, schema_str=old_schema.schema_str + schema_type=old_schema.schema_type, + schema_str=old_schema.schema_str, + references=old_schema_references, + dependencies=old_schema_dependencies, ) result = check_compatibility( old_schema=parsed_old_schema, @@ -369,6 +423,7 @@ async def write_new_schema_local( schema_id=schema_id, version=version, deleted=False, + references=new_schema_references, ) return schema_id @@ -392,6 +447,7 @@ def send_schema_message( schema_id: int, version: int, deleted: bool, + references: Optional[List[Reference]], ) -> None: key = {"subject": subject, "version": version, "magic": 1, "keytype": "SCHEMA"} if schema: @@ -402,6 +458,8 @@ def send_schema_message( "schema": str(schema), "deleted": deleted, } + if references: + value["references"] = [reference.to_dict() for reference in references] if schema.schema_type is not SchemaType.AVRO: value["schemaType"] = schema.schema_type else: @@ -417,6 +475,13 @@ def send_config_subject_delete_message(self, subject: Subject) -> None: key = {"subject": subject, "magic": 0, "keytype": "CONFIG"} self.producer.send_message(key=key, value=None) + def resolve_references( + self, references: Optional[Union[List[Reference], JsonObject]] + ) -> Optional[Dict[str, Dependency]]: + if references: + return self.schema_reader.resolve_references(references) + return None + def send_delete_subject_message(self, subject: Subject, version: Version) -> None: key = {"subject": subject, "magic": 0, "keytype": "DELETE_SUBJECT"} value = {"subject": subject, "version": version} diff --git a/karapace/schema_registry_apis.py b/karapace/schema_registry_apis.py index f33c4aea0..49eba0511 100644 --- a/karapace/schema_registry_apis.py +++ b/karapace/schema_registry_apis.py @@ -12,9 +12,11 @@ from karapace.config import Config from karapace.errors import ( IncompatibleSchema, + InvalidReferences, InvalidSchema, InvalidSchemaType, InvalidVersion, + ReferenceExistsException, SchemasNotFoundException, SchemaTooLargeException, SchemaVersionNotSoftDeletedException, @@ -27,10 +29,11 @@ from karapace.karapace import KarapaceBase from karapace.rapu import HTTPRequest, JSON_CONTENT_TYPE, SERVER_NAME from karapace.schema_models import ParsedTypedSchema, SchemaType, SchemaVersion, TypedSchema, ValidatedTypedSchema +from karapace.schema_references import Reference from karapace.schema_registry import KarapaceSchemaRegistry, validate_version from karapace.typing import JsonData, ResolvedVersion, SchemaId from karapace.utils import JSONDecodeError -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union import aiohttp import async_timeout @@ -56,6 +59,8 @@ class SchemaErrorCodes(Enum): INVALID_SCHEMA = 42201 INVALID_SUBJECT = 42208 SCHEMA_TOO_LARGE_ERROR_CODE = 42209 + REFERENCES_SUPPORT_NOT_IMPLEMENTED = 44302 + REFERENCE_EXISTS = 42206 NO_MASTER_ERROR = 50003 @@ -68,6 +73,14 @@ class SchemaErrorMessages(Enum): "full_transitive" ) SUBJECT_LEVEL_COMPATIBILITY_NOT_CONFIGURED_FMT = "Subject '%s' does not have subject-level compatibility configured" + REFERENCES_SUPPORT_NOT_IMPLEMENTED = "Schema references are not supported for '{schema_type}' schema type" + + +def references_list(references: Optional[Dict]) -> Optional[List[Reference]]: + _references: Optional[List[Reference]] = None + if references: + _references = [Reference(reference["name"], reference["subject"], reference["version"]) for reference in references] + return _references class KarapaceSchemaRegistryController(KarapaceBase): @@ -277,6 +290,12 @@ def _add_schema_registry_routes(self) -> None: schema_request=True, auth=self._auth, ) + self.route( + "/subjects//versions//referencedby", + callback=self.subject_version_referencedby_get, + method="GET", + schema_request=True, + ) self.route( "/subjects/", callback=self.subject_delete, @@ -344,8 +363,15 @@ async def compatibility_check( body = request.json schema_type = self._validate_schema_type(content_type=content_type, data=body) + references = self._validate_references(content_type, schema_type, body) try: - new_schema = ValidatedTypedSchema.parse(schema_type, body["schema"]) + new_schema_dependencies = self.schema_registry.resolve_references(references) + new_schema = ValidatedTypedSchema.parse( + schema_type=schema_type, + schema_str=body["schema"], + references=references, + dependencies=new_schema_dependencies, + ) except InvalidSchema: self.r( body={ @@ -368,10 +394,13 @@ async def compatibility_check( content_type=content_type, status=HTTPStatus.NOT_FOUND, ) - old_schema_type = self._validate_schema_type(content_type=content_type, data=old) try: - old_schema = ParsedTypedSchema.parse(old_schema_type, old["schema"]) + old_references = old.get("references", None) + old_dependencies = None + if old_references: + old_dependencies = self.schema_registry.resolve_references(old_references) + old_schema = ParsedTypedSchema.parse(old_schema_type, old["schema"], old_references, old_dependencies) except InvalidSchema: self.r( body={ @@ -696,6 +725,19 @@ async def subject_delete( content_type=content_type, status=HTTPStatus.NOT_FOUND, ) + + except ReferenceExistsException as arg: + self.r( + body={ + "error_code": SchemaErrorCodes.REFERENCE_EXISTS.value, + "message": ( + f"One or more references exist to the schema " + f"{{magic=1,keytype=SCHEMA,subject={subject},version={arg.version}}}." + ), + }, + content_type=content_type, + status=HTTPStatus.UNPROCESSABLE_ENTITY, + ) elif not master_url: self.no_master_error(content_type) else: @@ -787,6 +829,18 @@ async def subject_version_delete( content_type=content_type, status=HTTPStatus.NOT_FOUND, ) + except ReferenceExistsException as arg: + self.r( + body={ + "error_code": SchemaErrorCodes.REFERENCE_EXISTS.value, + "message": ( + f"One or more references exist to the schema " + f"{{magic=1,keytype=SCHEMA,subject={subject},version={arg.version}}}." + ), + }, + content_type=content_type, + status=HTTPStatus.UNPROCESSABLE_ENTITY, + ) elif not master_url: self.no_master_error(content_type) else: @@ -822,6 +876,34 @@ async def subject_version_schema_get( status=HTTPStatus.NOT_FOUND, ) + async def subject_version_referencedby_get(self, content_type, *, subject, version, user: Optional[User] = None): + self._check_authorization(user, Operation.Read, f"Subject:{subject}") + + try: + referenced_by = await self.schema_registry.subject_version_referencedby_get(subject, version) + except (SubjectNotFoundException, SchemasNotFoundException): + self.r( + body={ + "error_code": SchemaErrorCodes.SUBJECT_NOT_FOUND.value, + "message": SchemaErrorMessages.SUBJECT_NOT_FOUND_FMT.value.format(subject=subject), + }, + content_type=content_type, + status=HTTPStatus.NOT_FOUND, + ) + except VersionNotFoundException: + self.r( + body={ + "error_code": SchemaErrorCodes.VERSION_NOT_FOUND.value, + "message": f"Version {version} not found.", + }, + content_type=content_type, + status=HTTPStatus.NOT_FOUND, + ) + except InvalidVersion: + self._invalid_version(content_type, version) + + self.r(referenced_by, content_type, status=HTTPStatus.OK) + async def subject_versions_list( self, content_type: str, *, subject: str, request: HTTPRequest, user: Optional[User] = None ) -> None: @@ -862,12 +944,12 @@ def _validate_schema_request_body(self, content_type: str, body: Union[dict, Any content_type=content_type, status=HTTPStatus.BAD_REQUEST, ) - for attr in body: - if attr not in {"schema", "schemaType"}: + for field in body: + if field not in {"schema", "schemaType", "references"}: self.r( body={ "error_code": SchemaErrorCodes.HTTP_UNPROCESSABLE_ENTITY.value, - "message": f"Unrecognized field: {attr}", + "message": f"Unrecognized field: {field}", }, content_type=content_type, status=HTTPStatus.UNPROCESSABLE_ENTITY, @@ -908,6 +990,31 @@ def _validate_schema_key(self, content_type: str, body: dict) -> None: status=HTTPStatus.UNPROCESSABLE_ENTITY, ) + def _validate_references(self, content_type: str, schema_type: SchemaType, body: JsonData) -> Optional[List[Reference]]: + references = body.get("references", []) + if references and schema_type != SchemaType.PROTOBUF: + self.r( + body={ + "error_code": SchemaErrorCodes.REFERENCES_SUPPORT_NOT_IMPLEMENTED.value, + "message": SchemaErrorMessages.REFERENCES_SUPPORT_NOT_IMPLEMENTED.value.format( + schema_type=schema_type.value + ), + }, + content_type=content_type, + status=HTTPStatus.UNPROCESSABLE_ENTITY, + ) + + validated_references = [] + for reference in references: + if ["name", "subject", "version"] != sorted(reference.keys()): + raise InvalidReferences() + validated_references.append( + Reference(name=reference["name"], subject=reference["subject"], version=reference["version"]) + ) + if validated_references: + return validated_references + return None + async def subjects_schema_post( self, content_type: str, *, subject: str, request: HTTPRequest, user: Optional[User] = None ) -> None: @@ -939,10 +1046,14 @@ async def subjects_schema_post( ) schema_str = body["schema"] schema_type = self._validate_schema_type(content_type=content_type, data=body) + references = self._validate_references(content_type, schema_type, body) + new_schema_dependencies = self.schema_registry.resolve_references(references) try: # When checking if schema is already registered, allow unvalidated schema in as # there might be stored schemas that are non-compliant from the past. - new_schema = ParsedTypedSchema.parse(schema_type, schema_str) + new_schema = ParsedTypedSchema.parse( + schema_type=schema_type, schema_str=schema_str, references=references, dependencies=new_schema_dependencies + ) except InvalidSchema: self.log.exception("No proper parser found") self.r( @@ -953,6 +1064,16 @@ async def subjects_schema_post( content_type=content_type, status=HTTPStatus.INTERNAL_SERVER_ERROR, ) + except InvalidReferences: + human_error = "Provided references is not valid" + self.r( + body={ + "error_code": SchemaErrorCodes.INVALID_SCHEMA.value, + "message": f"Invalid {schema_type} references. Error: {human_error}", + }, + content_type=content_type, + status=HTTPStatus.UNPROCESSABLE_ENTITY, + ) # Match schemas based on version from latest to oldest for schema_version in sorted(subject_data.values(), key=lambda item: item.version, reverse=True): @@ -991,6 +1112,7 @@ async def subjects_schema_post( self.r(ret, content_type) else: self.log.debug("Schema %r did not match %r", schema_version, parsed_typed_schema) + self.r( body={ "error_code": SchemaErrorCodes.SCHEMA_NOT_FOUND.value, @@ -1011,15 +1133,23 @@ async def subject_post( self._validate_schema_request_body(content_type, body) schema_type = self._validate_schema_type(content_type, body) self._validate_schema_key(content_type, body) + references = self._validate_references(content_type, schema_type, body) try: - new_schema = ValidatedTypedSchema.parse(schema_type=schema_type, schema_str=body["schema"]) - except (InvalidSchema, InvalidSchemaType) as e: + resolved_dependencies = self.schema_registry.resolve_references(references) + new_schema = ValidatedTypedSchema.parse( + schema_type=schema_type, + schema_str=body["schema"], + references=references, + dependencies=resolved_dependencies, + ) + except (InvalidReferences, InvalidSchema, InvalidSchemaType) as e: self.log.warning("Invalid schema: %r", body["schema"], exc_info=True) if isinstance(e.__cause__, (SchemaParseException, JSONDecodeError)): human_error = f"{e.__cause__.args[0]}" # pylint: disable=no-member else: - human_error = "Provided schema is not valid" + from_body_schema_str = body["schema"] + human_error = f"Invalid schema {from_body_schema_str} with refs {references} of type {schema_type}" self.r( body={ "error_code": SchemaErrorCodes.INVALID_SCHEMA.value, @@ -1036,7 +1166,7 @@ async def subject_post( are_we_master, master_url = await self.schema_registry.get_master() if are_we_master: try: - schema_id = await self.schema_registry.write_new_schema_local(subject, new_schema) + schema_id = await self.schema_registry.write_new_schema_local(subject, new_schema, references) self.r( body={"id": schema_id}, content_type=content_type, @@ -1068,6 +1198,8 @@ async def subject_post( content_type=content_type, status=HTTPStatus.UNPROCESSABLE_ENTITY, ) + except Exception as xx: + raise xx elif not master_url: self.no_master_error(content_type) diff --git a/karapace/schema_type.py b/karapace/schema_type.py new file mode 100644 index 000000000..ff0bc1166 --- /dev/null +++ b/karapace/schema_type.py @@ -0,0 +1,15 @@ +""" +karapace - schema_type + +Copyright (c) 2023 Aiven Ltd +See LICENSE for details +""" + +from enum import Enum, unique + + +@unique +class SchemaType(str, Enum): + AVRO = "AVRO" + JSONSCHEMA = "JSON" + PROTOBUF = "PROTOBUF" diff --git a/karapace/serialization.py b/karapace/serialization.py index 3ee308f08..09e36b194 100644 --- a/karapace/serialization.py +++ b/karapace/serialization.py @@ -7,9 +7,11 @@ from google.protobuf.message import DecodeError from jsonschema import ValidationError from karapace.client import Client +from karapace.errors import InvalidReferences from karapace.protobuf.exception import ProtobufTypeException from karapace.protobuf.io import ProtobufDatumReader, ProtobufDatumWriter from karapace.schema_models import InvalidSchema, ParsedTypedSchema, SchemaType, TypedSchema, ValidatedTypedSchema +from karapace.schema_references import Reference from karapace.utils import json_decode, json_encode from typing import Any, Dict, Optional, Tuple from urllib.parse import quote @@ -82,9 +84,14 @@ def __init__( self.client = Client(server_uri=schema_registry_url, server_ca=server_ca, session_auth=session_auth) self.base_url = schema_registry_url - async def post_new_schema(self, subject: str, schema: ValidatedTypedSchema) -> int: + async def post_new_schema( + self, subject: str, schema: ValidatedTypedSchema, references: Optional[Reference] = None + ) -> int: if schema.schema_type is SchemaType.PROTOBUF: - payload = {"schema": str(schema), "schemaType": schema.schema_type.value} + if references: + payload = {"schema": str(schema), "schemaType": schema.schema_type.value, "references": references.json()} + else: + payload = {"schema": str(schema), "schemaType": schema.schema_type.value} else: payload = {"schema": json_encode(schema.to_dict()), "schemaType": schema.schema_type.value} result = await self.client.post(f"subjects/{quote(subject)}/versions", json=payload) @@ -114,6 +121,19 @@ async def get_schema_for_id(self, schema_id: int) -> ParsedTypedSchema: raise SchemaRetrievalError(f"Invalid result format: {json_result}") try: schema_type = SchemaType(json_result.get("schemaType", "AVRO")) + + references = json_result.get("references") + parsed_references = None + if references: + parsed_references = [] + for reference in references: + if ["name", "subject", "version"] != sorted(reference.keys()): + raise InvalidReferences() + parsed_references.append( + Reference(name=reference["name"], subject=reference["subject"], version=reference["version"]) + ) + if parsed_references: + return ParsedTypedSchema.parse(schema_type, json_result["schema"], references=parsed_references) return ParsedTypedSchema.parse(schema_type, json_result["schema"]) except InvalidSchema as e: raise SchemaRetrievalError(f"Failed to parse schema string from response: {json_result}") from e @@ -167,6 +187,7 @@ def get_subject_name(self, topic_name: str, schema: str, subject_type: str, sche async def get_schema_for_subject(self, subject: str) -> TypedSchema: assert self.registry_client, "must not call this method after the object is closed." + schema_id, schema = await self.registry_client.get_latest_schema(subject) async with self.state_lock: schema_ser = str(schema) @@ -184,6 +205,7 @@ async def get_id_for_schema(self, schema: str, subject: str, schema_type: Schema if schema_ser in self.schemas_to_ids: return self.schemas_to_ids[schema_ser] schema_id = await self.registry_client.post_new_schema(subject, schema_typed) + async with self.state_lock: self.schemas_to_ids[schema_ser] = schema_id self.ids_to_schemas[schema_id] = schema_typed diff --git a/karapace/typing.py b/karapace/typing.py index 1a82764ce..9cc3f3660 100644 --- a/karapace/typing.py +++ b/karapace/typing.py @@ -2,7 +2,7 @@ Copyright (c) 2023 Aiven Ltd See LICENSE for details """ -from typing import Dict, List, Mapping, Sequence, Union +from typing import Any, Dict, List, Mapping, Sequence, Union from typing_extensions import TypeAlias JsonArray: TypeAlias = List["JsonData"] @@ -16,8 +16,7 @@ ArgJsonData: TypeAlias = Union[JsonScalar, ArgJsonObject, ArgJsonArray] Subject: TypeAlias = str - Version: TypeAlias = Union[int, str] ResolvedVersion: TypeAlias = int - SchemaId: TypeAlias = int +Schema = Dict[str, Any] diff --git a/karapace/utils.py b/karapace/utils.py index f89059681..49ff9a451 100644 --- a/karapace/utils.py +++ b/karapace/utils.py @@ -15,6 +15,7 @@ from http import HTTPStatus from kafka.client_async import BrokerConnection, KafkaClient from karapace.typing import ArgJsonData, JsonData +from pathlib import Path from types import MappingProxyType from typing import AnyStr, cast, IO, Literal, NoReturn, overload, TypeVar @@ -152,6 +153,10 @@ def assert_never(value: NoReturn) -> NoReturn: raise RuntimeError(f"This code should never be reached, got: {value}") +def get_project_root() -> Path: + return Path(__file__).parent.parent + + class Timeout(Exception): pass diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/test_client_protobuf.py b/tests/integration/test_client_protobuf.py index fd70084b9..accca0421 100644 --- a/tests/integration/test_client_protobuf.py +++ b/tests/integration/test_client_protobuf.py @@ -14,7 +14,7 @@ async def test_remote_client_protobuf(registry_async_client): reg_cli = SchemaRegistryClient() reg_cli.client = registry_async_client subject = new_random_name("subject") - sc_id = await reg_cli.post_new_schema(subject, schema_protobuf) + sc_id = await reg_cli.post_new_schema(subject, schema_protobuf, None) assert sc_id >= 0 stored_schema = await reg_cli.get_schema_for_id(sc_id) assert stored_schema == schema_protobuf, f"stored schema {stored_schema} is not {schema_protobuf}" @@ -29,7 +29,7 @@ async def test_remote_client_protobuf2(registry_async_client): reg_cli = SchemaRegistryClient() reg_cli.client = registry_async_client subject = new_random_name("subject") - sc_id = await reg_cli.post_new_schema(subject, schema_protobuf) + sc_id = await reg_cli.post_new_schema(subject, schema_protobuf, None) assert sc_id >= 0 stored_schema = await reg_cli.get_schema_for_id(sc_id) assert stored_schema == schema_protobuf, f"stored schema {stored_schema} is not {schema_protobuf}" diff --git a/tests/integration/test_dependencies_compatibility_protobuf.py b/tests/integration/test_dependencies_compatibility_protobuf.py new file mode 100644 index 000000000..0393002a8 --- /dev/null +++ b/tests/integration/test_dependencies_compatibility_protobuf.py @@ -0,0 +1,358 @@ +""" +karapace - schema tests + +Copyright (c) 2019 Aiven Ltd +See LICENSE for details +""" +from karapace.client import Client +from karapace.protobuf.kotlin_wrapper import trim_margin +from tests.utils import create_subject_name_factory + +import pytest + + +@pytest.mark.parametrize("trail", ["", "/"]) +async def test_protobuf_schema_compatibility(registry_async_client: Client, trail: str) -> None: + subject = create_subject_name_factory(f"test_protobuf_schema_compatibility-{trail}")() + res = await registry_async_client.put(f"config/{subject}{trail}", json={"compatibility": "BACKWARD"}) + assert res.status_code == 200 + + original_dependencies = """ + |syntax = "proto3"; + |package a1; + |message container { + | message Hint { + | string hint_str = 1; + | } + |} + |""" + + evolved_dependencies = """ + |syntax = "proto3"; + |package a1; + |message container { + | message Hint { + | string hint_str = 1; + | } + |} + |""" + + original_dependencies = trim_margin(original_dependencies) + res = await registry_async_client.post( + "subjects/container1/versions", json={"schemaType": "PROTOBUF", "schema": original_dependencies} + ) + assert res.status_code == 200 + assert "id" in res.json() + + evolved_dependencies = trim_margin(evolved_dependencies) + res = await registry_async_client.post( + "subjects/container2/versions", json={"schemaType": "PROTOBUF", "schema": evolved_dependencies} + ) + assert res.status_code == 200 + assert "id" in res.json() + + original_schema = """ + |syntax = "proto3"; + |package a1; + |import "container1.proto"; + |message TestMessage { + | message Value { + | .a1.container.Hint hint = 1; + | int32 x = 2; + | } + | string test = 1; + | .a1.TestMessage.Value val = 2; + |} + |""" + + original_schema = trim_margin(original_schema) + + original_references = [{"name": "container1.proto", "subject": "container1", "version": 1}] + res = await registry_async_client.post( + f"subjects/{subject}/versions{trail}", + json={"schemaType": "PROTOBUF", "schema": original_schema, "references": original_references}, + ) + assert res.status_code == 200 + assert "id" in res.json() + + evolved_schema = """ + |syntax = "proto3"; + |package a1; + |import "container2.proto"; + |message TestMessage { + | message Value { + | .a1.container.Hint hint = 1; + | int32 x = 2; + | } + | string test = 1; + | .a1.TestMessage.Value val = 2; + |} + |""" + evolved_schema = trim_margin(evolved_schema) + evolved_references = [{"name": "container2.proto", "subject": "container2", "version": 1}] + res = await registry_async_client.post( + f"compatibility/subjects/{subject}/versions/latest{trail}", + json={"schemaType": "PROTOBUF", "schema": evolved_schema, "references": evolved_references}, + ) + assert res.status_code == 200 + assert res.json() == {"is_compatible": True} + + +@pytest.mark.parametrize("trail", ["", "/"]) +async def test_protobuf_schema_compatibility_dependencies(registry_async_client: Client, trail: str) -> None: + subject = create_subject_name_factory(f"test_protobuf_schema_compatibility-{trail}")() + + res = await registry_async_client.put(f"config/{subject}{trail}", json={"compatibility": "BACKWARD"}) + assert res.status_code == 200 + + original_dependencies = """ + |syntax = "proto3"; + |package a1; + |message container { + | message Hint { + | string hint_str = 1; + | } + |} + |""" + + evolved_dependencies = """ + |syntax = "proto3"; + |package a1; + |message container { + | message Hint { + | int32 hint_str = 1; + | } + |} + |""" + + original_dependencies = trim_margin(original_dependencies) + res = await registry_async_client.post( + "subjects/container1/versions", json={"schemaType": "PROTOBUF", "schema": original_dependencies} + ) + assert res.status_code == 200 + assert "id" in res.json() + + evolved_dependencies = trim_margin(evolved_dependencies) + res = await registry_async_client.post( + "subjects/container2/versions", json={"schemaType": "PROTOBUF", "schema": evolved_dependencies} + ) + assert res.status_code == 200 + assert "id" in res.json() + + original_schema = """ + |syntax = "proto3"; + |package a1; + |import "container1.proto"; + |message TestMessage { + | message Value { + | .a1.container.Hint hint = 1; + | int32 x = 2; + | } + | string test = 1; + | .a1.TestMessage.Value val = 2; + |} + |""" + + original_schema = trim_margin(original_schema) + + original_references = [{"name": "container1.proto", "subject": "container1", "version": 1}] + res = await registry_async_client.post( + f"subjects/{subject}/versions{trail}", + json={"schemaType": "PROTOBUF", "schema": original_schema, "references": original_references}, + ) + assert res.status_code == 200 + assert "id" in res.json() + + evolved_schema = """ + |syntax = "proto3"; + |package a1; + |import "container2.proto"; + |message TestMessage { + | message Value { + | .a1.container.Hint hint = 1; + | int32 x = 2; + | } + | string test = 1; + | .a1.TestMessage.Value val = 2; + |} + |""" + evolved_schema = trim_margin(evolved_schema) + evolved_references = [{"name": "container2.proto", "subject": "container2", "version": 1}] + res = await registry_async_client.post( + f"compatibility/subjects/{subject}/versions/latest{trail}", + json={"schemaType": "PROTOBUF", "schema": evolved_schema, "references": evolved_references}, + ) + assert res.status_code == 200 + assert res.json() == {"is_compatible": False} + + +@pytest.mark.parametrize("trail", ["", "/"]) +async def test_protobuf_schema_compatibility_dependencies1(registry_async_client: Client, trail: str) -> None: + subject = create_subject_name_factory(f"test_protobuf_schema_compatibility-{trail}")() + + res = await registry_async_client.put(f"config/{subject}{trail}", json={"compatibility": "BACKWARD"}) + assert res.status_code == 200 + + original_dependencies = """ + |syntax = "proto3"; + |package a1; + |message container { + | message H { + | string s = 1; + | } + |} + |""" + + evolved_dependencies = """ + |syntax = "proto3"; + |package a1; + |message container { + | message H { + | int32 s = 1; + | } + |} + |""" + + original_dependencies = trim_margin(original_dependencies) + res = await registry_async_client.post( + "subjects/container1/versions", json={"schemaType": "PROTOBUF", "schema": original_dependencies} + ) + assert res.status_code == 200 + assert "id" in res.json() + + evolved_dependencies = trim_margin(evolved_dependencies) + res = await registry_async_client.post( + "subjects/container2/versions", json={"schemaType": "PROTOBUF", "schema": evolved_dependencies} + ) + assert res.status_code == 200 + assert "id" in res.json() + + original_schema = """ + |syntax = "proto3"; + |package a1; + |import "container1.proto"; + |message TestMessage { + | message V { + | .a1.container.H h = 1; + | int32 x = 2; + | } + | string t = 1; + | .a1.TestMessage.V v = 2; + |} + |""" + + original_schema = trim_margin(original_schema) + + original_references = [{"name": "container1.proto", "subject": "container1", "version": 1}] + res = await registry_async_client.post( + f"subjects/{subject}/versions{trail}", + json={"schemaType": "PROTOBUF", "schema": original_schema, "references": original_references}, + ) + assert res.status_code == 200 + assert "id" in res.json() + + evolved_schema = """ + |syntax = "proto3"; + |package a1; + |import "container2.proto"; + |message TestMessage { + | message V { + | .a1.container.H h = 1; + | int32 x = 2; + | } + | string t = 1; + | .a1.TestMessage.V v = 2; + |} + |""" + evolved_schema = trim_margin(evolved_schema) + evolved_references = [{"name": "container2.proto", "subject": "container2", "version": 1}] + res = await registry_async_client.post( + f"compatibility/subjects/{subject}/versions/latest{trail}", + json={"schemaType": "PROTOBUF", "schema": evolved_schema, "references": evolved_references}, + ) + assert res.status_code == 200 + assert res.json() == {"is_compatible": False} + + +@pytest.mark.parametrize("trail", ["", "/"]) +async def test_protobuf_schema_compatibility_dependencies2(registry_async_client: Client, trail: str) -> None: + subject = create_subject_name_factory(f"test_protobuf_schema_compatibility-{trail}")() + + res = await registry_async_client.put(f"config/{subject}{trail}", json={"compatibility": "BACKWARD"}) + assert res.status_code == 200 + + original_dependencies = """ + |syntax = "proto3"; + |message container { + | message H { + | string s = 1; + | } + |} + |""" + + evolved_dependencies = """ + |syntax = "proto3"; + |message container { + | message H { + | int32 s = 1; + | } + |} + |""" + + original_dependencies = trim_margin(original_dependencies) + res = await registry_async_client.post( + "subjects/container1/versions", json={"schemaType": "PROTOBUF", "schema": original_dependencies} + ) + assert res.status_code == 200 + assert "id" in res.json() + + evolved_dependencies = trim_margin(evolved_dependencies) + res = await registry_async_client.post( + "subjects/container2/versions", json={"schemaType": "PROTOBUF", "schema": evolved_dependencies} + ) + assert res.status_code == 200 + assert "id" in res.json() + + original_schema = """ + |syntax = "proto3"; + |import "container1.proto"; + |message TestMessage { + | message V { + | .container.H h = 1; + | int32 x = 2; + | } + | string t = 1; + | .TestMessage.V v = 2; + |} + |""" + + original_schema = trim_margin(original_schema) + + original_references = [{"name": "container1.proto", "subject": "container1", "version": 1}] + res = await registry_async_client.post( + f"subjects/{subject}/versions{trail}", + json={"schemaType": "PROTOBUF", "schema": original_schema, "references": original_references}, + ) + assert res.status_code == 200 + assert "id" in res.json() + + evolved_schema = """ + |syntax = "proto3"; + |import "container2.proto"; + |message TestMessage { + | message V { + | .container.H h = 1; + | int32 x = 2; + | } + | string t = 1; + | .TestMessage.V v = 2; + |} + |""" + evolved_schema = trim_margin(evolved_schema) + evolved_references = [{"name": "container2.proto", "subject": "container2", "version": 1}] + res = await registry_async_client.post( + f"compatibility/subjects/{subject}/versions/latest{trail}", + json={"schemaType": "PROTOBUF", "schema": evolved_schema, "references": evolved_references}, + ) + assert res.status_code == 200 + assert res.json() == {"is_compatible": False} diff --git a/tests/integration/test_schema_protobuf.py b/tests/integration/test_schema_protobuf.py index 04a2dc0b0..1362fd4d9 100644 --- a/tests/integration/test_schema_protobuf.py +++ b/tests/integration/test_schema_protobuf.py @@ -4,9 +4,15 @@ Copyright (c) 2023 Aiven Ltd See LICENSE for details """ +from dataclasses import dataclass from karapace.client import Client +from karapace.errors import InvalidTest from karapace.protobuf.kotlin_wrapper import trim_margin +from karapace.schema_type import SchemaType +from karapace.typing import JsonData +from tests.base_testcase import BaseTestCase from tests.utils import create_subject_name_factory +from typing import List, Optional, Union import logging import pytest @@ -167,3 +173,862 @@ async def test_protobuf_schema_normalization(registry_async_client: Client, trai assert "id" in res.json() assert "schema" in res.json() assert evolved_id == res.json()["id"], "Check returns evolved id" + + +async def test_protobuf_schema_references(registry_async_client: Client) -> None: + customer_schema = """ + |syntax = "proto3"; + |package a1; + |import "Place.proto"; + |import "google/protobuf/duration.proto"; + |import "google/type/color.proto"; + |message Customer { + | string name = 1; + | int32 code = 2; + | Place place = 3; + | google.protobuf.Duration dur = 4; + | google.type.Color color = 5; + |} + |""" + + customer_schema = trim_margin(customer_schema) + + place_schema = """ + |syntax = "proto3"; + |package a1; + |message Place { + | string city = 1; + | int32 zone = 2; + |} + |""" + + place_schema = trim_margin(place_schema) + res = await registry_async_client.post( + "subjects/place/versions", json={"schemaType": "PROTOBUF", "schema": place_schema} + ) + assert res.status_code == 200 + + assert "id" in res.json() + + customer_references = [{"name": "Place.proto", "subject": "place", "version": 1}] + res = await registry_async_client.post( + "subjects/customer/versions", + json={"schemaType": "PROTOBUF", "schema": customer_schema, "references": customer_references}, + ) + assert res.status_code == 200 + + assert "id" in res.json() + + original_schema = """ + |syntax = "proto3"; + |package a1; + |import "Customer.proto"; + |message TestMessage { + | enum Enum { + | HIGH = 0; + | MIDDLE = 1; + | LOW = 2; + | } + | message Value { + | message Label{ + | int32 Id = 1; + | string name = 2; + | } + | Customer customer = 1; + | int32 x = 2; + | } + | string test = 1; + | .a1.TestMessage.Value val = 2; + | oneof page_info { + | option (my_option) = true; + | int32 page_number = 3; + | int32 result_per_page = 4; + | } + |} + |""" + + original_schema = trim_margin(original_schema) + references = [{"name": "Customer.proto", "subject": "customer", "version": 1}] + res = await registry_async_client.post( + "subjects/test_schema/versions", + json={"schemaType": "PROTOBUF", "schema": original_schema, "references": references}, + ) + assert res.status_code == 200 + + assert "id" in res.json() + + res = await registry_async_client.get("subjects/customer/versions/latest/referencedby", json={}) + assert res.status_code == 200 + + myjson = res.json() + referents = [3] + assert not any(x != y for x, y in zip(myjson, referents)) + + res = await registry_async_client.get("subjects/place/versions/latest/referencedby", json={}) + assert res.status_code == 200 + + res = await registry_async_client.delete("subjects/customer/versions/1") + assert res.status_code == 422 + + match_msg = "One or more references exist to the schema {magic=1,keytype=SCHEMA,subject=customer,version=1}." + myjson = res.json() + assert myjson["error_code"] == 42206 and myjson["message"] == match_msg + + res = await registry_async_client.delete("subjects/test_schema/versions/1") + assert res.status_code == 200 + res = await registry_async_client.delete("subjects/test_schema/versions/1") + myjson = res.json() + match_msg = "Subject 'test_schema' Version 1 was soft deleted.Set permanent=true to delete permanently" + + assert res.status_code == 404 + + assert myjson["error_code"] == 40406 and myjson["message"] == match_msg + res = await registry_async_client.delete("subjects/customer/versions/1") + myjson = res.json() + match_msg = "One or more references exist to the schema {magic=1,keytype=SCHEMA,subject=customer,version=1}." + + assert res.status_code == 422 + assert myjson["error_code"] == 42206 and myjson["message"] == match_msg + + res = await registry_async_client.delete("subjects/test_schema/versions/1?permanent=true") + assert res.status_code == 200 + + res = await registry_async_client.delete("subjects/customer/versions/1") + assert res.status_code == 200 + + +async def test_protobuf_schema_jjaakola_one(registry_async_client: Client) -> None: + no_ref = """ + |syntax = "proto3"; + | + |message NoReference { + | string name = 1; + |} + |""" + + no_ref = trim_margin(no_ref) + res = await registry_async_client.post("subjects/sub1/versions", json={"schemaType": "PROTOBUF", "schema": no_ref}) + assert res.status_code == 200 + assert "id" in res.json() + + with_first_ref = """ + |syntax = "proto3"; + | + |import "NoReference.proto"; + | + |message WithReference { + | string name = 1; + | NoReference ref = 2; + |}""" + + with_first_ref = trim_margin(with_first_ref) + references = [{"name": "NoReference.proto", "subject": "sub1", "version": 1}] + res = await registry_async_client.post( + "subjects/sub2/versions", + json={"schemaType": "PROTOBUF", "schema": with_first_ref, "references": references}, + ) + assert res.status_code == 200 + assert "id" in res.json() + + no_ref_second = """ + |syntax = "proto3"; + | + |message NoReferenceTwo { + | string name = 1; + |} + |""" + + no_ref_second = trim_margin(no_ref_second) + res = await registry_async_client.post( + "subjects/sub3/versions", json={"schemaType": "PROTOBUF", "schema": no_ref_second} + ) + assert res.status_code == 200 + assert "id" in res.json() + + add_new_ref_in_sub2 = """ + |syntax = "proto3"; + |import "NoReference.proto"; + |import "NoReferenceTwo.proto"; + |message WithReference { + | string name = 1; + | NoReference ref = 2; + | NoReferenceTwo refTwo = 3; + |} + |""" + + add_new_ref_in_sub2 = trim_margin(add_new_ref_in_sub2) + + references = [ + {"name": "NoReference.proto", "subject": "sub1", "version": 1}, + {"name": "NoReferenceTwo.proto", "subject": "sub3", "version": 1}, + ] + + res = await registry_async_client.post( + "subjects/sub2/versions", + json={"schemaType": "PROTOBUF", "schema": add_new_ref_in_sub2, "references": references}, + ) + assert res.status_code == 200 + assert "id" in res.json() + + +async def test_protobuf_schema_verifier(registry_async_client: Client) -> None: + customer_schema = """ + |syntax = "proto3"; + |package a1; + |message Customer { + | string name = 1; + | int32 code = 2; + |} + |""" + + customer_schema = trim_margin(customer_schema) + res = await registry_async_client.post( + "subjects/customer/versions", json={"schemaType": "PROTOBUF", "schema": customer_schema} + ) + assert res.status_code == 200 + assert "id" in res.json() + original_schema = """ + |syntax = "proto3"; + |package a1; + |import "Customer.proto"; + |message TestMessage { + | enum Enum { + | HIGH = 0; + | MIDDLE = 1; + | LOW = 2; + | } + | message Value { + | message Label{ + | int32 Id = 1; + | string name = 2; + | } + | Customer customer = 1; + | int32 x = 2; + | } + | string test = 1; + | .a1.TestMessage.Value val = 2; + | TestMessage.Value valx = 3; + | + | oneof page_info { + | option (my_option) = true; + | int32 page_number = 5; + | int32 result_per_page = 6; + | } + |} + |""" + + original_schema = trim_margin(original_schema) + references = [{"name": "Customer.proto", "subject": "customer", "version": 1}] + res = await registry_async_client.post( + "subjects/test_schema/versions", + json={"schemaType": "PROTOBUF", "schema": original_schema, "references": references}, + ) + assert res.status_code == 200 + assert "id" in res.json() + res = await registry_async_client.get("subjects/customer/versions/latest/referencedby", json={}) + assert res.status_code == 200 + myjson = res.json() + referents = [2] + assert not any(x != y for x, y in zip(myjson, referents)) + + res = await registry_async_client.delete("subjects/customer/versions/1") + assert res.status_code == 422 + match_msg = "One or more references exist to the schema {magic=1,keytype=SCHEMA,subject=customer,version=1}." + myjson = res.json() + assert myjson["error_code"] == 42206 and myjson["message"] == match_msg + + res = await registry_async_client.delete("subjects/test_schema/versions/1") + assert res.status_code == 200 + + res = await registry_async_client.delete("subjects/customer/versions/1") + assert res.status_code == 422 + + res = await registry_async_client.delete("subjects/test_schema/versions/1?permanent=true") + assert res.status_code == 200 + + res = await registry_async_client.delete("subjects/customer/versions/1") + assert res.status_code == 200 + + +@dataclass +class TestCaseSchema: + schema_type: SchemaType + schema_str: str + subject: str + references: Optional[List[JsonData]] = None + expected: int = 200 + expected_msg: str = "" + expected_error_code: Optional[int] = None + + +TestCaseSchema.__test__ = False + + +@dataclass +class TestCaseDeleteSchema: + subject: str + version: int + schema_id: int + expected: int = 200 + expected_msg: str = "" + expected_error_code: Optional[int] = None + + +TestCaseDeleteSchema.__test__ = False + + +@dataclass +class TestCaseHardDeleteSchema(TestCaseDeleteSchema): + pass + + +@dataclass +class ReferenceTestCase(BaseTestCase): + schemas: List[Union[TestCaseSchema, TestCaseDeleteSchema]] + + +# Base case +SCHEMA_NO_REF = """\ +syntax = "proto3"; + +message NoReference { + string name = 1; +} +""" + +SCHEMA_NO_REF_TWO = """\ +syntax = "proto3"; + +message NoReferenceTwo { + string name = 1; +} +""" + +SCHEMA_WITH_REF = """\ +syntax = "proto3"; + +import "NoReference.proto"; + +message WithReference { + string name = 1; + NoReference ref = 2; +} +""" + +SCHEMA_WITH_2ND_LEVEL_REF = """\ +syntax = "proto3"; + +import "WithReference.proto"; + +message With2ndLevelReference { + string name = 1; + WithReference ref = 2; +} +""" + +SCHEMA_REMOVES_REFERENCED_FIELD_INCOMPATIBLE = """\ +syntax = "proto3"; + +message WithReference { + string name = 1; +} +""" + +SCHEMA_ADDS_NEW_REFERENCE = """\ +syntax = "proto3"; + +import "NoReference.proto"; +import "NoReferenceTwo.proto"; + +message WithReference { + string name = 1; + NoReference ref = 2; + NoReferenceTwo refTwo = 3; +} +""" + +# Invalid schema +SCHEMA_INVALID_MISSING_CLOSING_BRACE = """\ +syntax = "proto3"; + +import "NoReference.proto"; + +message SchemaMissingClosingBrace { + string name = 1; + NoReference ref = 2; + +""" + +# Schema having multiple messages +SCHEMA_NO_REF_TWO_MESSAGES = """\ +syntax = "proto3"; + +message NoReferenceOne { + string nameOne = 1; +} + +message NoReferenceTwo { + string nameTwo = 1; +} +""" + +SCHEMA_WITH_REF_TO_NO_REFERENCE_TWO = """\ +syntax = "proto3"; + +import "NoReferenceTwo.proto"; + +message WithReference { + string name = 1; + NoReferenceTwo ref = 2; +} +""" + +# Nested references +SCHEMA_NO_REF_NESTED_MESSAGE = """\ +syntax = "proto3"; + +message NoReference { + message NoReferenceNested { + string nameNested = 1; + } + string name = 1; + NoReferenceNested ref = 2; +} +""" + +SCHEMA_WITH_REF_TO_NESTED = """\ +syntax = "proto3"; + +import "NoReferenceNested.proto"; + +message WithReference { + string name = 1; + NoReference.NoReferenceNested ref = 2; +} +""" + + +@pytest.mark.parametrize( + "testcase", + [ + ReferenceTestCase( + test_name="No references", + schemas=[ + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_NO_REF, + subject="nr_s1", + references=None, + expected=200, + ), + TestCaseDeleteSchema( + subject="nr_s1", + schema_id=1, + version=1, + expected=200, + ), + ], + ), + # Better error message should be given back, now it is only InvalidSchema + ReferenceTestCase( + test_name="With reference, ref schema does not exist", + schemas=[ + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_WITH_REF, + subject="wr_nonexisting_s1", + references=[{"name": "NoReference.proto", "subject": "wr_not_found", "version": 1}], + expected=422, + expected_msg=( + f"Invalid PROTOBUF schema. Error: Invalid schema {SCHEMA_WITH_REF} " + "with refs [{name='NoReference.proto', subject='wr_not_found', version=1}]" + f" of type {SchemaType.PROTOBUF}" + ), + expected_error_code=42201, + ), + ], + ), + ReferenceTestCase( + test_name="With reference, references not given", + schemas=[ + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_WITH_REF, + subject="wr_nonexisting_s1_missing_references", + references=None, + expected=422, + expected_msg=f"Invalid PROTOBUF schema. Error: Invalid schema {SCHEMA_WITH_REF} " + f"with refs None of type {SchemaType.PROTOBUF}", + expected_error_code=42201, + ), + ], + ), + ReferenceTestCase( + test_name="With reference", + schemas=[ + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_NO_REF, + subject="wr_s1", + references=None, + expected=200, + ), + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_WITH_REF, + subject="wr_s2", + references=[{"name": "NoReference.proto", "subject": "wr_s1", "version": 1}], + expected=200, + ), + TestCaseDeleteSchema( + subject="wr_s1", + schema_id=1, + version=1, + expected=422, + expected_msg=( + "One or more references exist to the schema {magic=1,keytype=SCHEMA,subject=wr_s1,version=1}." + ), + expected_error_code=42206, + ), + TestCaseDeleteSchema( + subject="wr_s2", + schema_id=2, + version=1, + expected=200, + ), + TestCaseHardDeleteSchema( + subject="wr_s2", + schema_id=2, + version=1, + expected=200, + ), + TestCaseDeleteSchema( + subject="wr_s1", + schema_id=1, + version=1, + expected=200, + ), + ], + ), + ReferenceTestCase( + test_name="With reference, remove referenced field causes incompatible schema", + schemas=[ + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_NO_REF, + subject="wr_s1_test_incompatible_change", + references=None, + expected=200, + ), + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_WITH_REF, + subject="wr_s2_test_incompatible_change", + references=[{"name": "NoReference.proto", "subject": "wr_s1_test_incompatible_change", "version": 1}], + expected=200, + ), + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_REMOVES_REFERENCED_FIELD_INCOMPATIBLE, + subject="wr_s2_test_incompatible_change", + references=None, + expected=200, + # It is erroneous assumption, there FIELD_DROP only, and it is compatible. + # expected = 200 + # expected_msg=( + # "Incompatible schema, compatibility_mode=BACKWARD " + # "Incompatible modification Modification.MESSAGE_DROP found" + # ), + ), + ], + ), + ReferenceTestCase( + test_name="With reference, add new referenced field", + schemas=[ + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_NO_REF, + subject="wr_s1_add_new_reference", + references=None, + expected=200, + ), + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_WITH_REF, + subject="wr_s2_add_new_reference", + references=[{"name": "NoReference.proto", "subject": "wr_s1_add_new_reference", "version": 1}], + expected=200, + ), + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_NO_REF_TWO, + subject="wr_s3_the_new_reference", + references=None, + expected=200, + ), + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_ADDS_NEW_REFERENCE, + subject="wr_s2_add_new_reference", + references=[ + {"name": "NoReference.proto", "subject": "wr_s1_add_new_reference", "version": 1}, + {"name": "NoReferenceTwo.proto", "subject": "wr_s3_the_new_reference", "version": 1}, + ], + expected=200, + ), + ], + ), + ReferenceTestCase( + test_name="With reference chain, nonexisting schema", + schemas=[ + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_NO_REF, + subject="wr_chain_s1", + references=None, + expected=200, + ), + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_WITH_REF, + subject="wr_chain_s2", + references=[ + {"name": "NoReference.proto", "subject": "wr_chain_s1", "version": 1}, + {"name": "NotFoundReference.proto", "subject": "wr_chain_nonexisting", "version": 1}, + ], + expected=422, + expected_msg=( + f"Invalid PROTOBUF schema. Error: Invalid schema {SCHEMA_WITH_REF} " + "with refs [{name='NoReference.proto', subject='wr_chain_s1', version=1}, " + "{name='NotFoundReference.proto', subject='wr_chain_nonexisting', version=1}] " + f"of type {SchemaType.PROTOBUF}" + ), + expected_error_code=42201, + ), + ], + ), + ReferenceTestCase( + test_name="With reference chain", + schemas=[ + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_NO_REF, + subject="wr_chain_s1", + references=None, + expected=200, + ), + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_WITH_REF, + subject="wr_chain_s2", + references=[{"name": "NoReference.proto", "subject": "wr_chain_s1", "version": 1}], + expected=200, + ), + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_WITH_2ND_LEVEL_REF, + subject="wr_chain_s3", + references=[{"name": "WithReference.proto", "subject": "wr_chain_s2", "version": 1}], + expected=200, + ), + TestCaseDeleteSchema( + subject="wr_chain_s1", + schema_id=1, + version=1, + expected=422, + ), + TestCaseDeleteSchema( + subject="wr_chain_s2", + schema_id=2, + version=1, + expected=422, + ), + TestCaseDeleteSchema( + subject="wr_chain_s3", + schema_id=3, + version=1, + expected=200, + ), + TestCaseHardDeleteSchema( + subject="wr_chain_s3", + schema_id=3, + version=1, + expected=200, + ), + TestCaseDeleteSchema( + subject="wr_chain_s2", + schema_id=2, + version=1, + expected=200, + ), + TestCaseHardDeleteSchema( + subject="wr_chain_s2", + schema_id=2, + version=1, + expected=200, + ), + TestCaseDeleteSchema( + subject="wr_chain_s1", + schema_id=1, + version=1, + expected=200, + ), + ], + ), + ReferenceTestCase( + test_name="Invalid schema missing closing brace", + schemas=[ + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_NO_REF, + subject="wr_invalid_reference_ok_schema", + references=None, + expected=200, + ), + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_INVALID_MISSING_CLOSING_BRACE, + subject="wr_invalid_missing_closing_brace", + references=[{"name": "NoReference.proto", "subject": "wr_invalid_reference_ok_schema", "version": 1}], + expected=422, + expected_msg=( + f"Invalid PROTOBUF schema. Error: Invalid schema {SCHEMA_INVALID_MISSING_CLOSING_BRACE} " + "with refs [{name='NoReference.proto', subject='wr_invalid_reference_ok_schema', version=1}] " + f"of type {SchemaType.PROTOBUF}" + ), + expected_error_code=42201, + ), + ], + ), + ReferenceTestCase( + test_name="With reference to message from schema file defining two messages", + schemas=[ + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_NO_REF_TWO_MESSAGES, + subject="wr_s1_two_messages", + references=None, + expected=200, + ), + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_WITH_REF_TO_NO_REFERENCE_TWO, + subject="wr_s2_referencing_message_two", + references=[{"name": "NoReferenceTwo.proto", "subject": "wr_s1_two_messages", "version": 1}], + expected=200, + ), + ], + ), + ReferenceTestCase( + test_name="With reference to nested message", + schemas=[ + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_NO_REF_NESTED_MESSAGE, + subject="wr_s1_with_nested_message", + references=None, + expected=200, + ), + TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_WITH_REF_TO_NESTED, + subject="wr_s2_referencing_nested_message", + references=[{"name": "NoReference.proto", "subject": "wr_s1_with_nested_message", "version": 1}], + expected=200, + ), + ], + ), + ], + ids=str, +) +async def test_references(testcase: ReferenceTestCase, registry_async_client: Client): + for testdata in testcase.schemas: + if isinstance(testdata, TestCaseSchema): + print(f"Adding new schema, subject: '{testdata.subject}'\n{testdata.schema_str}") + body = {"schemaType": testdata.schema_type, "schema": testdata.schema_str} + if testdata.references: + body["references"] = testdata.references + res = await registry_async_client.post(f"subjects/{testdata.subject}/versions", json=body) + elif isinstance(testdata, TestCaseHardDeleteSchema): + print( + f"Permanently deleting schema, subject: '{testdata.subject}, " + f"schema: {testdata.schema_id}, version: {testdata.version}' " + ) + res = await registry_async_client.delete( + f"subjects/{testdata.subject}/versions/{testdata.version}?permanent=true" + ) + elif isinstance(testdata, TestCaseDeleteSchema): + print( + f"Deleting schema, subject: '{testdata.subject}, schema: {testdata.schema_id}, version: {testdata.version}' " + ) + res = await registry_async_client.delete(f"subjects/{testdata.subject}/versions/{testdata.version}") + else: + raise InvalidTest("Unknown test case.") + + assert res.status_code == testdata.expected + if testdata.expected_msg: + assert res.json_result.get("message", None) == testdata.expected_msg + if testdata.expected_error_code: + assert res.json_result.get("error_code") == testdata.expected_error_code + if isinstance(testdata, TestCaseSchema): + if testdata.expected == 200: + schema_id = res.json().get("id") + fetch_schema_res = await registry_async_client.get(f"/schemas/ids/{schema_id}") + assert fetch_schema_res.status_code == 200 + if isinstance(testdata, TestCaseDeleteSchema): + if testdata.expected == 200: + fetch_res = await registry_async_client.get(f"/subjects/{testdata.subject}/versions/{testdata.version}") + assert fetch_res.status_code == 404 + else: + fetch_schema_res = await registry_async_client.get(f"/schemas/ids/{testdata.schema_id}") + assert fetch_schema_res.status_code == 200 + + +async def test_protobuf_error(registry_async_client: Client) -> None: + testdata = TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_NO_REF, + subject="wr_s1_test_incompatible_change", + references=None, + expected=200, + ) + print(f"Adding new schema, subject: '{testdata.subject}'\n{testdata.schema_str}") + body = {"schemaType": testdata.schema_type, "schema": testdata.schema_str} + if testdata.references: + body["references"] = testdata.references + res = await registry_async_client.post(f"subjects/{testdata.subject}/versions", json=body) + + assert res.status_code == 200 + + testdata = TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_WITH_REF, + subject="wr_s2_test_incompatible_change", + references=[{"name": "NoReference.proto", "subject": "wr_s1_test_incompatible_change", "version": 1}], + expected=200, + ) + print(f"Adding new schema, subject: '{testdata.subject}'\n{testdata.schema_str}") + body = {"schemaType": testdata.schema_type, "schema": testdata.schema_str} + if testdata.references: + body["references"] = testdata.references + res = await registry_async_client.post(f"subjects/{testdata.subject}/versions", json=body) + assert res.status_code == 200 + testdata = TestCaseSchema( + schema_type=SchemaType.PROTOBUF, + schema_str=SCHEMA_REMOVES_REFERENCED_FIELD_INCOMPATIBLE, + subject="wr_s2_test_incompatible_change", + references=None, + expected=409, + expected_msg=( + # ACTUALLY THERE NO MESSAGE_DROP!!! + "Incompatible schema, compatibility_mode=BACKWARD " + "Incompatible modification Modification.MESSAGE_DROP found" + ), + ) + print(f"Adding new schema, subject: '{testdata.subject}'\n{testdata.schema_str}") + body = {"schemaType": testdata.schema_type, "schema": testdata.schema_str} + if testdata.references: + body["references"] = testdata.references + res = await registry_async_client.post(f"subjects/{testdata.subject}/versions", json=body) + + assert res.status_code == 200 diff --git a/tests/unit/test_dependency_verifier.py b/tests/unit/test_dependency_verifier.py new file mode 100644 index 000000000..5bbf3d019 --- /dev/null +++ b/tests/unit/test_dependency_verifier.py @@ -0,0 +1,61 @@ +""" +karapace - tests_dependency_verifier + +Copyright (c) 2023 Aiven Ltd +See LICENSE for details +""" + +from karapace.protobuf.dependency import ProtobufDependencyVerifier + +import logging + +log = logging.getLogger(__name__) + + +async def test_protobuf_dependency_verifier(): + declared_types = [ + ".a1.Place", + "Place", + ".a1.Customer", + "Customer", + ".a1.TestMessage", + "TestMessage", + ".a1", + ".TestMessage", + ".Enum", + "TestMessage.Enum", + ".a1.TestMessage.Value", + "TestMessage.Value", + ".a1.TestMessage.Value.Label", + "TestMessage", + ".Value.Label", + ] + + used_types = [ + ".a1.Place;string", + ".a1.Place;int32", + ".a1.Customer;string", + ".a1.Customer;int32", + ".a1.Customer;Place", + ".a1.TestMessage;int32", + ".a1.TestMessage;int32", + ".a1.TestMessage;string", + ".a1.TestMessage;.a1.TestMessage.Value", + "TestMessage;Customer", + "TestMessage;int32", + "TestMessage.Value;int32", + "TestMessage.Value;string", + ] + + verifier = ProtobufDependencyVerifier() + for declared in declared_types: + verifier.add_declared_type(declared) + for used in used_types: + x = used.split(";") + verifier.add_used_type(x[0], x[1]) + + result = verifier.verify() + assert result.result, True + + verifier.add_used_type("TestMessage.Delta", "Tag") + assert result.result, False diff --git a/tests/unit/test_protobuf_serialization.py b/tests/unit/test_protobuf_serialization.py index 47f6388a1..29d249f3e 100644 --- a/tests/unit/test_protobuf_serialization.py +++ b/tests/unit/test_protobuf_serialization.py @@ -3,8 +3,10 @@ See LICENSE for details """ from karapace.config import read_config +from karapace.dependency import Dependency from karapace.protobuf.kotlin_wrapper import trim_margin -from karapace.schema_models import SchemaType, ValidatedTypedSchema +from karapace.schema_models import ParsedTypedSchema, SchemaType +from karapace.schema_references import Reference from karapace.serialization import ( InvalidMessageHeader, InvalidMessageSchema, @@ -35,10 +37,10 @@ async def make_ser_deser(config_path: str, mock_client) -> SchemaRegistrySeriali async def test_happy_flow(default_config_path): mock_protobuf_registry_client = Mock() schema_for_id_one_future = asyncio.Future() - schema_for_id_one_future.set_result(ValidatedTypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf))) + schema_for_id_one_future.set_result(ParsedTypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf))) mock_protobuf_registry_client.get_schema_for_id.return_value = schema_for_id_one_future get_latest_schema_future = asyncio.Future() - get_latest_schema_future.set_result((1, ValidatedTypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf)))) + get_latest_schema_future.set_result((1, ParsedTypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf)))) mock_protobuf_registry_client.get_latest_schema.return_value = get_latest_schema_future serializer = await make_ser_deser(default_config_path, mock_protobuf_registry_client) @@ -54,10 +56,165 @@ async def test_happy_flow(default_config_path): assert mock_protobuf_registry_client.method_calls == [call.get_latest_schema("top")] +async def test_happy_flow_references(default_config_path): + no_ref_schema_str = """ + |syntax = "proto3"; + | + |option java_package = "com.codingharbour.protobuf"; + |option java_outer_classname = "TestEnumOrder"; + | + |message Speed { + | Enum speed = 1; + |} + | + |enum Enum { + | HIGH = 0; + | MIDDLE = 1; + | LOW = 2; + |} + | + """ + + ref_schema_str = """ + |syntax = "proto3"; + | + |option java_package = "com.codingharbour.protobuf"; + |option java_outer_classname = "TestEnumOrder"; + |import "Speed.proto"; + | + |message Message { + | int32 query = 1; + | Speed speed = 2; + |} + | + | + """ + no_ref_schema_str = trim_margin(no_ref_schema_str) + ref_schema_str = trim_margin(ref_schema_str) + + test_objects = [ + {"query": 5, "speed": {"speed": "HIGH"}}, + {"query": 10, "speed": {"speed": "MIDDLE"}}, + ] + + references = [Reference("Speed.proto", "speed", 1)] + + no_ref_schema = ParsedTypedSchema.parse(SchemaType.PROTOBUF, no_ref_schema_str) + dep = Dependency("Speed.proto", "speed", 1, no_ref_schema) + ref_schema = ParsedTypedSchema.parse(SchemaType.PROTOBUF, ref_schema_str, references, {"Speed.proto": dep}) + + mock_protobuf_registry_client = Mock() + schema_for_id_one_future = asyncio.Future() + schema_for_id_one_future.set_result(ref_schema) + mock_protobuf_registry_client.get_schema_for_id.return_value = schema_for_id_one_future + get_latest_schema_future = asyncio.Future() + get_latest_schema_future.set_result((1, ref_schema)) + mock_protobuf_registry_client.get_latest_schema.return_value = get_latest_schema_future + + serializer = await make_ser_deser(default_config_path, mock_protobuf_registry_client) + assert len(serializer.ids_to_schemas) == 0 + schema = await serializer.get_schema_for_subject("top") + for o in test_objects: + a = await serializer.serialize(schema, o) + u = await serializer.deserialize(a) + assert o == u + assert len(serializer.ids_to_schemas) == 1 + assert 1 in serializer.ids_to_schemas + + assert mock_protobuf_registry_client.method_calls == [call.get_latest_schema("top")] + + +async def test_happy_flow_references_two(default_config_path): + no_ref_schema_str = """ + |syntax = "proto3"; + | + |option java_package = "com.serge.protobuf"; + |option java_outer_classname = "TestSpeed"; + | + |message Speed { + | Enum speed = 1; + |} + | + |enum Enum { + | HIGH = 0; + | MIDDLE = 1; + | LOW = 2; + |} + | + """ + + ref_schema_str = """ + |syntax = "proto3"; + | + |option java_package = "com.serge.protobuf"; + |option java_outer_classname = "TestQuery"; + |import "Speed.proto"; + | + |message Query { + | int32 query = 1; + | Speed speed = 2; + |} + | + """ + + ref_schema_str_two = """ + |syntax = "proto3"; + | + |option java_package = "com.serge.protobuf"; + |option java_outer_classname = "TestMessage"; + |import "Query.proto"; + | + |message Message { + | int32 index = 1; + | Query qry = 2; + |} + | + """ + + no_ref_schema_str = trim_margin(no_ref_schema_str) + ref_schema_str = trim_margin(ref_schema_str) + ref_schema_str_two = trim_margin(ref_schema_str_two) + test_objects = [ + {"index": 1, "qry": {"query": 5, "speed": {"speed": "HIGH"}}}, + {"index": 2, "qry": {"query": 10, "speed": {"speed": "HIGH"}}}, + ] + + references = [Reference("Speed.proto", "speed", 1)] + references_two = [Reference("Query.proto", "msg", 1)] + + no_ref_schema = ParsedTypedSchema.parse(SchemaType.PROTOBUF, no_ref_schema_str) + dep = Dependency("Speed.proto", "speed", 1, no_ref_schema) + ref_schema = ParsedTypedSchema.parse(SchemaType.PROTOBUF, ref_schema_str, references, {"Speed.proto": dep}) + dep_two = Dependency("Query.proto", "qry", 1, ref_schema) + ref_schema_two = ParsedTypedSchema.parse( + SchemaType.PROTOBUF, ref_schema_str_two, references_two, {"Query.proto": dep_two} + ) + + mock_protobuf_registry_client = Mock() + schema_for_id_one_future = asyncio.Future() + schema_for_id_one_future.set_result(ref_schema_two) + mock_protobuf_registry_client.get_schema_for_id.return_value = schema_for_id_one_future + get_latest_schema_future = asyncio.Future() + get_latest_schema_future.set_result((1, ref_schema_two)) + mock_protobuf_registry_client.get_latest_schema.return_value = get_latest_schema_future + + serializer = await make_ser_deser(default_config_path, mock_protobuf_registry_client) + assert len(serializer.ids_to_schemas) == 0 + schema = await serializer.get_schema_for_subject("top") + for o in test_objects: + a = await serializer.serialize(schema, o) + u = await serializer.deserialize(a) + assert o == u + assert len(serializer.ids_to_schemas) == 1 + assert 1 in serializer.ids_to_schemas + + assert mock_protobuf_registry_client.method_calls == [call.get_latest_schema("top")] + + async def test_serialization_fails(default_config_path): mock_protobuf_registry_client = Mock() get_latest_schema_future = asyncio.Future() - get_latest_schema_future.set_result((1, ValidatedTypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf)))) + get_latest_schema_future.set_result((1, ParsedTypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf)))) mock_protobuf_registry_client.get_latest_schema.return_value = get_latest_schema_future serializer = await make_ser_deser(default_config_path, mock_protobuf_registry_client)