From aef7faa922fb0e22a6e47c8a2d57df75d434b85d Mon Sep 17 00:00:00 2001 From: Robert Yokota Date: Thu, 19 Dec 2024 09:38:02 -0800 Subject: [PATCH] DGS-19409 Ensure Avro serde caches per subject (#2387) * DGS-19409 Ensure Avro serde caches per subject * Add test * Fix test --- .../GenericSerializerImpl.cs | 48 ++++----------- .../SpecificSerializerImpl.cs | 58 +++++-------------- .../BaseSerializeDeserialize.cs | 8 +++ .../SerializeDeserialize.cs | 37 ++++++++++++ 4 files changed, 70 insertions(+), 81 deletions(-) diff --git a/src/Confluent.SchemaRegistry.Serdes.Avro/GenericSerializerImpl.cs b/src/Confluent.SchemaRegistry.Serdes.Avro/GenericSerializerImpl.cs index 13d92afef..3049ff6f5 100644 --- a/src/Confluent.SchemaRegistry.Serdes.Avro/GenericSerializerImpl.cs +++ b/src/Confluent.SchemaRegistry.Serdes.Avro/GenericSerializerImpl.cs @@ -31,9 +31,10 @@ namespace Confluent.SchemaRegistry.Serdes { internal class GenericSerializerImpl : AsyncSerializer { - private Dictionary knownSchemas = new Dictionary(); - private HashSet> registeredSchemas = new HashSet>(); - private Dictionary schemaIds = new Dictionary(); + private Dictionary knownSchemas = + new Dictionary(); + private Dictionary, int> registeredSchemas = + new Dictionary, int>(); public GenericSerializerImpl( ISchemaRegistryClient schemaRegistryClient, @@ -99,12 +100,10 @@ public async Task Serialize(string topic, Headers headers, GenericRecord // something more sophisticated than the below + not allow // the misuse to keep happening without warning. if (knownSchemas.Count > schemaRegistryClient.MaxCachedSchemas || - registeredSchemas.Count > schemaRegistryClient.MaxCachedSchemas || - schemaIds.Count > schemaRegistryClient.MaxCachedSchemas) + registeredSchemas.Count > schemaRegistryClient.MaxCachedSchemas) { knownSchemas.Clear(); registeredSchemas.Clear(); - schemaIds.Clear(); } // Determine a schema string corresponding to the schema object. @@ -139,41 +138,18 @@ public async Task Serialize(string topic, Headers headers, GenericRecord { schemaId = latestSchema.Id; } - else if (!registeredSchemas.Contains(subjectSchemaPair)) + else if (!registeredSchemas.TryGetValue(subjectSchemaPair, out schemaId)) { - int newSchemaId; - // first usage: register/get schema to check compatibility - if (autoRegisterSchema) - { - newSchemaId = await schemaRegistryClient + schemaId = autoRegisterSchema + ? await schemaRegistryClient .RegisterSchemaAsync(subject, writerSchemaString, normalizeSchemas) + .ConfigureAwait(continueOnCapturedContext: false) + : await schemaRegistryClient + .GetSchemaIdAsync(subject, writerSchemaString, normalizeSchemas) .ConfigureAwait(continueOnCapturedContext: false); - } - else - { - newSchemaId = await schemaRegistryClient.GetSchemaIdAsync(subject, writerSchemaString, normalizeSchemas) - .ConfigureAwait(continueOnCapturedContext: false); - } - - if (!schemaIds.ContainsKey(writerSchemaString)) - { - schemaIds.Add(writerSchemaString, newSchemaId); - } - else if (schemaIds[writerSchemaString] != newSchemaId) - { - schemaIds.Clear(); - registeredSchemas.Clear(); - throw new KafkaException(new Error(isKey ? ErrorCode.Local_KeySerialization : ErrorCode.Local_ValueSerialization, $"Duplicate schema registration encountered: Schema ids {schemaIds[writerSchemaString]} and {newSchemaId} are associated with the same schema.")); - } - - registeredSchemas.Add(subjectSchemaPair); - schemaId = schemaIds[writerSchemaString]; - } - else - { - schemaId = schemaIds[writerSchemaString]; + registeredSchemas.Add(subjectSchemaPair, schemaId); } } finally diff --git a/src/Confluent.SchemaRegistry.Serdes.Avro/SpecificSerializerImpl.cs b/src/Confluent.SchemaRegistry.Serdes.Avro/SpecificSerializerImpl.cs index 90ebdb46f..f3effb5ba 100644 --- a/src/Confluent.SchemaRegistry.Serdes.Avro/SpecificSerializerImpl.cs +++ b/src/Confluent.SchemaRegistry.Serdes.Avro/SpecificSerializerImpl.cs @@ -35,23 +35,8 @@ internal class SerializerSchemaData { private string writerSchemaString; private global::Avro.Schema writerSchema; - - /// - /// A given schema is uniquely identified by a schema id, even when - /// registered against multiple subjects. - /// - private int? writerSchemaId; - private SpecificWriter avroWriter; - private HashSet subjectsRegistered = new HashSet(); - - public HashSet SubjectsRegistered - { - get => subjectsRegistered; - set => subjectsRegistered = value; - } - public string WriterSchemaString { get => writerSchemaString; @@ -64,12 +49,6 @@ public Avro.Schema WriterSchema set => writerSchema = value; } - public int? WriterSchemaId - { - get => writerSchemaId; - set => writerSchemaId = value; - } - public SpecificWriter AvroWriter { get => avroWriter; @@ -79,20 +58,14 @@ public SpecificWriter AvroWriter private Dictionary multiSchemaData = new Dictionary(); - - private SerializerSchemaData singleSchemaData; + private Dictionary, int> registeredSchemas = + new Dictionary, int>(); public SpecificSerializerImpl( ISchemaRegistryClient schemaRegistryClient, AvroSerializerConfig config, RuleRegistry ruleRegistry) : base(schemaRegistryClient, config, ruleRegistry) { - Type writerType = typeof(T); - if (writerType != typeof(ISpecificRecord)) - { - singleSchemaData = ExtractSchemaData(writerType); - } - if (config == null) { return; } if (config.BufferBytes != null) { this.initialBufferSize = config.BufferBytes.Value; } @@ -177,24 +150,18 @@ public async Task Serialize(string topic, Headers headers, T data, bool { try { + int schemaId; string subject; RegisteredSchema latestSchema = null; SerializerSchemaData currentSchemaData; await serdeMutex.WaitAsync().ConfigureAwait(continueOnCapturedContext: false); try { - if (singleSchemaData == null) - { - var key = data.GetType(); - if (!multiSchemaData.TryGetValue(key, out currentSchemaData)) - { - currentSchemaData = ExtractSchemaData(key); - multiSchemaData[key] = currentSchemaData; - } - } - else + var key = data != null ? data.GetType() : typeof(Null); + if (!multiSchemaData.TryGetValue(key, out currentSchemaData)) { - currentSchemaData = singleSchemaData; + currentSchemaData = ExtractSchemaData(key); + multiSchemaData[key] = currentSchemaData; } string fullname = null; @@ -204,17 +171,18 @@ public async Task Serialize(string topic, Headers headers, T data, bool } subject = GetSubjectName(topic, isKey, fullname); + var subjectSchemaPair = new KeyValuePair(subject, currentSchemaData.WriterSchemaString); latestSchema = await GetReaderSchema(subject) .ConfigureAwait(continueOnCapturedContext: false); if (latestSchema != null) { - currentSchemaData.WriterSchemaId = latestSchema.Id; + schemaId = latestSchema.Id; } - else if (!currentSchemaData.SubjectsRegistered.Contains(subject)) + else if (!registeredSchemas.TryGetValue(subjectSchemaPair, out schemaId)) { // first usage: register/get schema to check compatibility - currentSchemaData.WriterSchemaId = autoRegisterSchema + schemaId = autoRegisterSchema ? await schemaRegistryClient .RegisterSchemaAsync(subject, currentSchemaData.WriterSchemaString, normalizeSchemas) .ConfigureAwait(continueOnCapturedContext: false) @@ -222,7 +190,7 @@ public async Task Serialize(string topic, Headers headers, T data, bool .GetSchemaIdAsync(subject, currentSchemaData.WriterSchemaString, normalizeSchemas) .ConfigureAwait(continueOnCapturedContext: false); - currentSchemaData.SubjectsRegistered.Add(subject); + registeredSchemas.Add(subjectSchemaPair, schemaId); } } finally @@ -248,7 +216,7 @@ public async Task Serialize(string topic, Headers headers, T data, bool { stream.WriteByte(Constants.MagicByte); - writer.Write(IPAddress.HostToNetworkOrder(currentSchemaData.WriterSchemaId.Value)); + writer.Write(IPAddress.HostToNetworkOrder(schemaId)); currentSchemaData.AvroWriter.Write(data, new BinaryEncoder(stream)); // TODO: maybe change the ISerializer interface so that this copy isn't necessary. diff --git a/test/Confluent.SchemaRegistry.Serdes.UnitTests/BaseSerializeDeserialize.cs b/test/Confluent.SchemaRegistry.Serdes.UnitTests/BaseSerializeDeserialize.cs index 0dbd34b67..e7db25386 100644 --- a/test/Confluent.SchemaRegistry.Serdes.UnitTests/BaseSerializeDeserialize.cs +++ b/test/Confluent.SchemaRegistry.Serdes.UnitTests/BaseSerializeDeserialize.cs @@ -47,6 +47,14 @@ public BaseSerializeDeserializeTests() schemaRegistryMock.Setup(x => x.RegisterSchemaAsync(It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync( (string subject, string schema, bool normalize) => store.TryGetValue(schema, out int id) ? id : store[schema] = store.Count + 1 ); + schemaRegistryMock.Setup(x => x.GetSchemaIdAsync(It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync( + (string subject, string schema, bool normalize) => + { + return subjectStore[subject].First(x => + x.SchemaString == schema + ).Id; + } + ); schemaRegistryMock.Setup(x => x.LookupSchemaAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync( (string subject, Schema schema, bool ignoreDeleted, bool normalize) => { diff --git a/test/Confluent.SchemaRegistry.Serdes.UnitTests/SerializeDeserialize.cs b/test/Confluent.SchemaRegistry.Serdes.UnitTests/SerializeDeserialize.cs index a8ca546c2..85a859ecc 100644 --- a/test/Confluent.SchemaRegistry.Serdes.UnitTests/SerializeDeserialize.cs +++ b/test/Confluent.SchemaRegistry.Serdes.UnitTests/SerializeDeserialize.cs @@ -154,6 +154,43 @@ public void ISpecificRecord() Assert.Equal(user.favorite_number, result.favorite_number); } + [Fact] + public void ISpecificRecordStrings() + { + var schemaStr = "{\"type\":\"string\"}"; + var schema = new RegisteredSchema("topic1-value", 1, 1, schemaStr, SchemaType.Avro, null); + store[schemaStr] = 1; + subjectStore["topic1-value"] = new List { schema }; + + schema = new RegisteredSchema("topic2-value", 1, 2, schemaStr, SchemaType.Avro, null); + schema.Metadata = new Metadata(null, new Dictionary + { + { "confluent:version", "1" } + }, null); + store[schemaStr] = 2; + subjectStore["topic2-value"] = new List { schema }; + + var config = new AvroSerializerConfig + { + AutoRegisterSchemas = false, + SubjectNameStrategy = SubjectNameStrategy.Topic + }; + var serializer = new AvroSerializer(schemaRegistryClient, config); + + Headers headers = new Headers(); + var bytes = serializer.SerializeAsync("hi", new SerializationContext(MessageComponentType.Value, "topic1", headers)).Result; + Assert.Equal(1, bytes[4]); + + bytes = serializer.SerializeAsync("world", new SerializationContext(MessageComponentType.Value, "topic2", headers)).Result; + Assert.Equal(2, bytes[4]); + + bytes = serializer.SerializeAsync("hi", new SerializationContext(MessageComponentType.Value, "topic1", headers)).Result; + Assert.Equal(1, bytes[4]); + + bytes = serializer.SerializeAsync("world", new SerializationContext(MessageComponentType.Value, "topic2", headers)).Result; + Assert.Equal(2, bytes[4]); + } + [Fact] public void ISpecificRecordRecordNameStrategy() {