Skip to content

Commit

Permalink
Merge pull request #17 from OSGP/feature/FDP-2357-kod-status-en-meetb…
Browse files Browse the repository at this point in the history
…erichten

FDP-2357: AvroSerializer can handle all Avro messages
  • Loading branch information
loesimmens authored Jul 4, 2024
2 parents 6f18a53 + ab75a84 commit 388afa5
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 41 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package com.gxf.utilities.kafka.avro

import org.apache.avro.message.BinaryMessageEncoder
import org.apache.avro.specific.SpecificData
import org.apache.avro.specific.SpecificRecordBase
import org.slf4j.LoggerFactory
import java.io.IOException
import java.io.OutputStream
import kotlin.reflect.KClass

object AvroEncoder {
val encoders: HashMap<KClass<out SpecificRecordBase>, BinaryMessageEncoder<SpecificRecordBase>> = HashMap()

private val logger = LoggerFactory.getLogger(AvroEncoder::class.java)

@Throws(IOException::class)
fun encode(message: SpecificRecordBase): ByteArray {
val encoder = getEncoder(message)
val byteBuffer = encoder.encode(message)
val bytes = ByteArray(byteBuffer.remaining())
byteBuffer[bytes]
return bytes
}

@Throws(IOException::class)
fun encode(message: SpecificRecordBase, stream: OutputStream) {
val encoder = getEncoder(message)
encoder.encode(message, stream)
}

private fun getEncoder(message: SpecificRecordBase): BinaryMessageEncoder<SpecificRecordBase> {
val existingEncoder = encoders[message::class]

if(existingEncoder != null) {
return existingEncoder
}

logger.info("New encoder created for Avro schema {}", message::class)
val newEncoder = BinaryMessageEncoder<SpecificRecordBase>(SpecificData(), message.schema)
encoders[message::class] = newEncoder
return newEncoder
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,25 @@ SPDX-License-Identifier: Apache-2.0
*/
package com.gxf.utilities.kafka.avro

import org.apache.avro.message.BinaryMessageEncoder
import org.apache.avro.specific.SpecificRecord
import org.apache.avro.specific.SpecificRecordBase
import org.apache.kafka.common.errors.SerializationException
import org.apache.kafka.common.serialization.Serializer
import org.slf4j.LoggerFactory
import java.io.ByteArrayOutputStream

class AvroSerializer<T : SpecificRecord>(private val encoder: BinaryMessageEncoder<T>) : Serializer<T> {
class AvroSerializer : Serializer<SpecificRecordBase> {
companion object {
private val logger = LoggerFactory.getLogger(AvroSerializer::class.java)
}

/**
* Serializes a Byte Array to an Avro specific record
*/
override fun serialize(topic: String?, data: T): ByteArray {
override fun serialize(topic: String?, data: SpecificRecordBase): ByteArray {
try {
logger.trace("Serializing for {}", topic)
val outputStream = ByteArrayOutputStream()
encoder.encode(data, outputStream)
AvroEncoder.encode(data, outputStream)
return outputStream.toByteArray()
} catch (ex: Exception) {
throw SerializationException("Error serializing Avro message for topic: ${topic}", ex)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package com.gxf.utilities.kafka.avro

import org.apache.avro.Schema
import org.apache.avro.specific.SpecificRecordBase
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import java.io.ByteArrayOutputStream

class AvroEncoderTest {
@Test
fun testEncodersCache() {
val message1 = AvroSchema1("field no 1", "field no 2")
val message2 = AvroSchema2("message in a bottle")
val message3 = AvroSchema2("another message for you")
val message4 = AvroSchema2("encode to stream!")

AvroEncoder.encode(message1)
AvroEncoder.encode(message2)
AvroEncoder.encode(message3)
AvroEncoder.encode(message4, ByteArrayOutputStream())

assertThat(AvroEncoder.encoders).containsKeys(AvroSchema1::class)
assertThat(AvroEncoder.encoders).containsKeys(AvroSchema2::class)
assertThat(AvroEncoder.encoders.size).isEqualTo(2)
}
}

class AvroSchema1(private var field1: String, private var field2: String): SpecificRecordBase() {
override fun getSchema(): Schema = Schema.Parser()
.parse("{\"type\":\"record\",\"name\":\"AvroSchema1\",\"namespace\":\"com.alliander.gxf.utilities.kafka.avro\",\"fields\":[{\"name\":\"field1\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}, {\"name\":\"field2\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}")

override fun put(field: Int, value: Any?) {
when(field) {
0 -> {
if(value != null) {
field1 = value.toString()
}
}
1 -> {
if(value != null) {
field2 = value.toString()
}
}
else -> throw IndexOutOfBoundsException()
}
}

override fun get(field: Int): Any {
return when(field) {
0 -> field1
1 -> field2
else -> throw IndexOutOfBoundsException()
}
}
}

class AvroSchema2(private var message: String): SpecificRecordBase() {
override fun getSchema(): Schema = Schema.Parser()
.parse("{\"type\":\"record\",\"name\":\"AvroSchema2\",\"namespace\":\"com.alliander.gxf.utilities.kafka.avro\",\"fields\":[{\"name\":\"message\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}")

override fun put(field: Int, value: Any?) {
when(field) {
0 -> {
if(value != null) {
message = value.toString()
}
}
else -> throw IndexOutOfBoundsException()
}
}

override fun get(field: Int): Any {
return when(field) {
0 -> message
else -> throw IndexOutOfBoundsException()
}
}
}
2 changes: 2 additions & 0 deletions kafka-message-signing/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ dependencies {
implementation("org.springframework.kafka:spring-kafka")
implementation("org.springframework.boot:spring-boot-autoconfigure")

implementation(project(":kafka-avro"))

api(libs.avro)

testImplementation("org.junit.jupiter:junit-jupiter-api")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

package com.gxf.utilities.kafka.message.signing

import com.gxf.utilities.kafka.avro.AvroEncoder
import com.gxf.utilities.kafka.message.wrapper.SignableMessageWrapper
import org.apache.avro.message.BinaryMessageEncoder
import org.apache.avro.specific.SpecificRecordBase
import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.kafka.clients.producer.ProducerRecord
Expand Down Expand Up @@ -101,16 +101,10 @@ class MessageSigner(properties: MessageSigningProperties) {
private fun signature(message: SignableMessageWrapper<*>): ByteArray {
check(this.canSignMessages()) { "This MessageSigner is not configured for signing, it can only be used for verification" }
val oldSignature = message.getSignature()
message.setSignature(null)
val byteArray = this.toByteArray(message)
try {
message.setSignature(null)
val messageBytes: ByteArray = if (this.stripAvroHeader) {
this.stripAvroHeader(this.toByteBuffer(message))
} else {
this.toByteBuffer(message)!!.array()
}
val signingSignature = signatureInstance(signatureAlgorithm, signatureProvider, signingKey!!)
signingSignature.update(messageBytes)
return signingSignature.sign()
return signature(byteArray)
} catch (e: SignatureException) {
throw UncheckedSecurityException("Unable to sign message", e)
} finally {
Expand All @@ -135,17 +129,11 @@ class MessageSigner(properties: MessageSigningProperties) {
private fun signature(producerRecord: ProducerRecord<String, out SpecificRecordBase>): ByteArray {
check(this.canSignMessages()) { "This MessageSigner is not configured for signing, it can only be used for verification" }
val oldSignatureHeader = producerRecord.headers().lastHeader(RECORD_HEADER_KEY_SIGNATURE)
producerRecord.headers().remove(RECORD_HEADER_KEY_SIGNATURE)
val specificRecordBase = producerRecord.value()
val byteArray = this.toByteArray(specificRecordBase)
try {
producerRecord.headers().remove(RECORD_HEADER_KEY_SIGNATURE)
val specificRecordBase = producerRecord.value()
val messageBytes: ByteArray = if (this.stripAvroHeader) {
this.stripAvroHeader(this.toByteBuffer(specificRecordBase))
} else {
this.toByteBuffer(specificRecordBase).array()
}
val signingSignature = signatureInstance(signatureAlgorithm, signatureProvider, signingKey!!)
signingSignature.update(messageBytes)
return signingSignature.sign()
return signature(byteArray)
} catch (e: SignatureException) {
throw UncheckedSecurityException("Unable to sign message", e)
} finally {
Expand All @@ -155,6 +143,17 @@ class MessageSigner(properties: MessageSigningProperties) {
}
}

private fun signature(byteArray: ByteArray): ByteArray {
val messageBytes: ByteArray = if (this.stripAvroHeader) {
this.stripAvroHeader(byteArray)
} else {
byteArray
}
val signingSignature = signatureInstance(signatureAlgorithm, signatureProvider, signingKey!!)
signingSignature.update(messageBytes)
return signingSignature.sign()
}

fun canVerifyMessageSignatures(): Boolean {
return this.signingEnabled && this.verificationKey != null
}
Expand Down Expand Up @@ -183,7 +182,7 @@ class MessageSigner(properties: MessageSigningProperties) {

try {
message.setSignature(null)
if(this.verifySignatureBytes(signatureBytes, this.toByteBuffer(message))) {
if(this.verifySignatureBytes(signatureBytes, this.toByteArray(message))) {
return message.message
} else {
throw VerificationException("Verification of message signing failed")
Expand Down Expand Up @@ -221,7 +220,7 @@ class MessageSigner(properties: MessageSigningProperties) {
try {
consumerRecord.headers().remove(RECORD_HEADER_KEY_SIGNATURE)
val specificRecordBase: SpecificRecordBase = consumerRecord.value()
if(this.verifySignatureBytes(signatureBytes, this.toByteBuffer(specificRecordBase))) {
if(this.verifySignatureBytes(signatureBytes, this.toByteArray(specificRecordBase))) {
return consumerRecord
} else {
throw VerificationException("Verification of record signing failed")
Expand All @@ -232,11 +231,11 @@ class MessageSigner(properties: MessageSigningProperties) {
}

@Throws(SignatureException::class)
private fun verifySignatureBytes(signatureBytes: ByteArray, messageByteBuffer: ByteBuffer?): Boolean {
private fun verifySignatureBytes(signatureBytes: ByteArray, messageByteArray: ByteArray): Boolean {
val messageBytes: ByteArray = if (this.stripAvroHeader) {
this.stripAvroHeader(messageByteBuffer)
this.stripAvroHeader(messageByteArray)
} else {
messageByteBuffer!!.array()
messageByteArray
}
val verificationSignature = signatureInstance(signatureAlgorithm, signatureProvider, verificationKey!!)
verificationSignature.update(messageBytes)
Expand All @@ -249,28 +248,29 @@ class MessageSigner(properties: MessageSigningProperties) {
&& ((bytes[1].toInt() and 0xFF) == 0x01)
}

private fun stripAvroHeader(byteBuffer: ByteBuffer?): ByteArray {
val bytes = ByteArray(byteBuffer!!.remaining())
byteBuffer[bytes]
private fun stripAvroHeader(bytes: ByteArray): ByteArray {
if (this.hasAvroHeader(bytes)) {
return Arrays.copyOfRange(bytes, AVRO_HEADER_LENGTH, bytes.size)
}
return bytes
}

private fun toByteBuffer(message: SignableMessageWrapper<*>): ByteBuffer? {
private fun toByteArray(message: SignableMessageWrapper<*>): ByteArray {
try {
return message.toByteBuffer()
val byteBuffer = message.toByteBuffer()
val bytes = ByteArray(byteBuffer.remaining())
byteBuffer[bytes]
return bytes
} catch (e: IOException) {
throw UncheckedIOException("Unable to determine ByteBuffer for Message", e)
throw UncheckedIOException("Unable to determine bytes for message", e)
}
}

private fun toByteBuffer(message: SpecificRecordBase): ByteBuffer {
private fun toByteArray(message: SpecificRecordBase): ByteArray {
try {
return BinaryMessageEncoder<Any>(message.specificData, message.schema).encode(message)
return AvroEncoder.encode(message)
} catch (e: IOException) {
throw UncheckedIOException("Unable to determine ByteBuffer for Message", e)
throw UncheckedIOException("Unable to determine bytes for message", e)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ abstract class SignableMessageWrapper<T>(val message: T) {
* @return ByteBuffer of the whole message
*/
@Throws(IOException::class)
abstract fun toByteBuffer(): ByteBuffer?
abstract fun toByteBuffer(): ByteBuffer

/**
* @return ByteBuffer of the signature in the message
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ class MessageSignerTest {
private class TestableWrapper : SignableMessageWrapper<String>("Some test message") {
private var signature: ByteBuffer? = null

override fun toByteBuffer(): ByteBuffer? {
override fun toByteBuffer(): ByteBuffer {
return ByteBuffer.wrap(message.toByteArray(StandardCharsets.UTF_8))
}

Expand Down

0 comments on commit 388afa5

Please sign in to comment.