Skip to content

Commit

Permalink
Refactor Message class
Browse files Browse the repository at this point in the history
Add changes to ergo-core readme
  • Loading branch information
ccellado committed Feb 8, 2024
1 parent 3c061a2 commit c55fdef
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 51 deletions.
4 changes: 3 additions & 1 deletion ergo-core/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ val handshakeMessageSerialized = HandshakeSerializer.toBytes(handshakeMessage)
Serialize the message and send it.
If the message arrived successfully, start communicating with the peer node.

All communication is wrapped with Message headers, format described [here](https://docs.ergoplatform.com/dev/p2p/network/#message-format).
All communication is wrapped with message headers.
Format described [here](https://docs.ergoplatform.com/dev/p2p/network/#message-format).
[MessageBase](src/main/scala/org/ergoplatform/network/message/MessageBase.scala) interface to implement.

## Syncing with the node

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package org.ergoplatform.network.message

import org.ergoplatform.network.message.MessageConstants._

import scala.util.{Success, Try}

/**
* Trait for a ergo network message
*
* @param spec - message specification
* @param input - message being wrapped, whether in byte-array form (if from outside),
* or structured data (if formed locally)
* @tparam Content - message data type
*/
trait MessageBase[Content] {
val spec: MessageSpec[Content]
val input: Either[Array[Byte], Content]

/**
* Message data bytes
*/
lazy val dataBytes: Array[Byte] = input match {
case Left(db) => db
case Right(d) => spec.toBytes(d)
}

/**
* Structured message content
*/
lazy val data: Try[Content] = input match {
case Left(db) => spec.parseBytesTry(db)
case Right(d) => Success(d)
}

lazy val dataLength: Int = dataBytes.length

/**
* @return serialized message length in bytes
*/
def messageLength: Int = {
if (dataLength > 0) {
HeaderLength + ChecksumLength + dataLength
} else {
HeaderLength
}
}

}
43 changes: 6 additions & 37 deletions src/main/scala/org/ergoplatform/network/message/Message.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package org.ergoplatform.network.message


import akka.actor.DeadLetterSuppression
import scorex.core.network.ConnectedPeer
import scala.util.{Success, Try}
import org.ergoplatform.network.message.MessageConstants._

/**
* Wrapper for a network message, whether come from external peer or generated locally
Expand All @@ -15,38 +12,10 @@ import org.ergoplatform.network.message.MessageConstants._
* @param source - source peer, if the message is from outside
* @tparam Content - message data type
*/
case class Message[Content](spec: MessageSpec[Content],
input: Either[Array[Byte], Content],
source: Option[ConnectedPeer])
extends DeadLetterSuppression {

/**
* Message data bytes
*/
lazy val dataBytes: Array[Byte] = input match {
case Left(db) => db
case Right(d) => spec.toBytes(d)
}

/**
* Structured message content
*/
lazy val data: Try[Content] = input match {
case Left(db) => spec.parseBytesTry(db)
case Right(d) => Success(d)
}

lazy val dataLength: Int = dataBytes.length

/**
* @return serialized message length in bytes
*/
def messageLength: Int = {
if (dataLength > 0) {
HeaderLength + ChecksumLength + dataLength
} else {
HeaderLength
}
}

}
case class Message[Content](
spec: MessageSpec[Content],
input: Either[Array[Byte], Content],
source: Option[ConnectedPeer]
) extends MessageBase[Content]
with DeadLetterSuppression
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
package org.ergoplatform.network.message

import java.nio.ByteOrder

import akka.util.ByteString
import scorex.core.network.{ConnectedPeer, MaliciousBehaviorException}
import scorex.crypto.hash.Blake2b256
import scala.util.Try


class MessageSerializer(specs: Seq[MessageSpec[_]], magicBytes: Array[Byte]) {

import MessageConstants.{ChecksumLength, HeaderLength, MagicLength}

import scala.language.existentials

private implicit val byteOrder: ByteOrder = ByteOrder.BIG_ENDIAN
implicit private val byteOrder: ByteOrder = ByteOrder.BIG_ENDIAN

private val specsMap = Map(specs.map(s => s.messageCode -> s): _*)
.ensuring(m => m.size == specs.size, "Duplicate message codes")

def serialize(obj: Message[_]): ByteString = {
def serialize[A <: MessageBase[_]](obj: A): ByteString = {
val builder = ByteString.createBuilder
.putBytes(magicBytes)
.putByte(obj.spec.messageCode)
Expand All @@ -34,14 +32,17 @@ class MessageSerializer(specs: Seq[MessageSpec[_]], magicBytes: Array[Byte]) {
}

//MAGIC ++ Array(spec.messageCode) ++ Ints.toByteArray(dataLength) ++ dataWithChecksum
def deserialize(byteString: ByteString, sourceOpt: Option[ConnectedPeer]): Try[Option[Message[_]]] = Try {
def deserialize(
byteString: ByteString,
sourceOpt: Option[ConnectedPeer]
): Try[Option[Message[_]]] = Try {
if (byteString.length < HeaderLength) {
None
} else {
val it = byteString.iterator
val magic = it.getBytes(MagicLength)
val it = byteString.iterator
val magic = it.getBytes(MagicLength)
val msgCode = it.getByte
val length = it.getInt
val length = it.getInt

//peer is trying to cause buffer overflow or breaking the parsing
if (length < 0) {
Expand All @@ -53,17 +54,22 @@ class MessageSerializer(specs: Seq[MessageSpec[_]], magicBytes: Array[Byte]) {
} else {
//peer is from another network
if (!java.util.Arrays.equals(magic, magicBytes)) {
throw MaliciousBehaviorException(s"Wrong magic bytes, expected ${magicBytes.mkString}, got ${magic.mkString} in : ${byteString.utf8String}")
throw MaliciousBehaviorException(
s"Wrong magic bytes, expected ${magicBytes.mkString}, got ${magic.mkString} in : ${byteString.utf8String}"
)
}
val spec = specsMap.getOrElse(msgCode, throw new Error(s"No message handler found for $msgCode"))
val spec = specsMap
.getOrElse(msgCode, throw new Error(s"No message handler found for $msgCode"))
val msgData = if (length > 0) {
val checksum = it.getBytes(ChecksumLength)
val data = it.getBytes(length)
val digest = Blake2b256.hash(data).take(ChecksumLength)
val data = it.getBytes(length)
val digest = Blake2b256.hash(data).take(ChecksumLength)

//peer reported incorrect checksum
if (!java.util.Arrays.equals(checksum, digest)) {
throw MaliciousBehaviorException(s"Wrong checksum, expected ${digest.mkString}, got ${checksum.mkString}")
throw MaliciousBehaviorException(
s"Wrong checksum, expected ${digest.mkString}, got ${checksum.mkString}"
)
}
data
} else {
Expand Down

0 comments on commit c55fdef

Please sign in to comment.