Skip to content

Commit

Permalink
test vector
Browse files Browse the repository at this point in the history
  • Loading branch information
thomash-acinq committed Dec 9, 2022
1 parent 3ffbf6d commit 1e54913
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 25 deletions.
24 changes: 12 additions & 12 deletions eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto}
import fr.acinq.eclair.wire.protocol._
import grizzled.slf4j.Logging
import scodec.Attempt
import scodec.bits.ByteVector
import scodec.{Attempt, Codec}

import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
Expand Down Expand Up @@ -343,14 +343,14 @@ object Sphinx extends Logging {
private val payloadAndPadLength = 256
private val hopPayloadLength = 9
private val maxNumHop = 27
private val codec: Codec[FatError] = fatErrorCodec(payloadAndPadLength, hopPayloadLength, maxNumHop)
private val totalLength = 12599

def create(sharedSecret: ByteVector32, failure: FailureMessage): ByteVector = {
val failurePayload = FailureMessageCodecs.failureOnionPayload(payloadAndPadLength).encode(failure).require.toByteVector
val hopPayload = HopPayload(ErrorSource, 0 millis)
val zeroPayloads = Seq.fill(maxNumHop)(ByteVector.fill(hopPayloadLength)(0))
val zeroHmacs = (maxNumHop.to(1, -1)).map(Seq.fill(_)(ByteVector32.Zeroes))
val plainError = codec.encode(FatError(failurePayload, zeroPayloads, zeroHmacs)).require.bytes
val plainError = fatErrorCodec(totalLength, hopPayloadLength, maxNumHop).encode(FatError(failurePayload, zeroPayloads, zeroHmacs)).require.bytes
wrap(plainError, sharedSecret, hopPayload).get
}

Expand All @@ -366,32 +366,32 @@ object Sphinx extends Logging {

def wrap(errorPacket: ByteVector, sharedSecret: ByteVector32, hopPayload: HopPayload): Try[ByteVector] = Try {
val um = generateKey("um", sharedSecret)
val error = codec.decode(errorPacket.bits).require.value
val error = fatErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).decode(errorPacket.bits).require.value
val hopPayloads = hopPayloadCodec.encode(hopPayload).require.bytes +: error.hopPayloads.dropRight(1)
val hmacs = computeHmacs(Hmac256(um), error.failurePayload, hopPayloads, error.hmacs, 0) +: error.hmacs.dropRight(1).map(_.drop(1))
val newError = codec.encode(FatError(error.failurePayload, hopPayloads, hmacs)).require.bytes
val hmacs = computeHmacs(Hmac256(um), error.failurePayload, hopPayloads, error.hmacs.map(_.drop(1)), 0) +: error.hmacs.dropRight(1).map(_.drop(1))
val newError = fatErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).encode(FatError(error.failurePayload, hopPayloads, hmacs)).require.bytes
val key = generateKey("ammag", sharedSecret)
val stream = generateStream(key, newError.length.toInt)
newError xor stream
}

private def unwrap(errorPacket: ByteVector, sharedSecret: ByteVector32, minNumHop: Int): Try[(ByteVector, HopPayload)] = Try {
def unwrap(errorPacket: ByteVector, sharedSecret: ByteVector32, minNumHop: Int): Try[(ByteVector, HopPayload)] = Try {
val key = generateKey("ammag", sharedSecret)
val stream = generateStream(key, errorPacket.length.toInt)
val error = codec.decode((errorPacket xor stream).bits).require.value
val error = fatErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).decode((errorPacket xor stream).bits).require.value
val um = generateKey("um", sharedSecret)
val shiftedHmacs = error.hmacs.tail.map(ByteVector32.Zeroes +: _) :+ Seq(ByteVector32.Zeroes)
val hmacs = computeHmacs(Hmac256(um), error.failurePayload, error.hopPayloads, shiftedHmacs, minNumHop)
val hmacs = computeHmacs(Hmac256(um), error.failurePayload, error.hopPayloads, error.hmacs.tail, minNumHop)
require(hmacs == error.hmacs.head.drop(minNumHop), "Invalid HMAC")
val shiftedHopPayloads = error.hopPayloads.tail :+ ByteVector.fill(hopPayloadLength)(0)
val unwrapedError = FatError(error.failurePayload, shiftedHopPayloads, shiftedHmacs)
(codec.encode(unwrapedError).require.bytes,
(fatErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).encode(unwrapedError).require.bytes,
hopPayloadCodec.decode(error.hopPayloads.head.bits).require.value)
}

def decrypt(errorPacket: ByteVector, sharedSecrets: Seq[(ByteVector32, PublicKey)]): Either[InvalidFatErrorPacket, DecryptedFailurePacket] = {
var packet = errorPacket
var minNumHop = 1
var minNumHop = 0
val hopPayloads = ArrayBuffer.empty[(PublicKey, HopPayload)]
for ((sharedSecret, nodeId) <- sharedSecrets) {
unwrap(packet, sharedSecret, minNumHop) match {
Expand All @@ -403,7 +403,7 @@ object Sphinx extends Logging {
minNumHop += 1
hopPayloads += ((nodeId, hopPayload))
case FatError.ErrorSource =>
val failurePayload = codec.decode(unwrapedPacket.bits).require.value.failurePayload
val failurePayload = fatErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).decode(unwrapedPacket.bits).require.value.failurePayload
FailureMessageCodecs.failureOnionPayload(payloadAndPadLength).decode(failurePayload.bits) match {
case Attempt.Successful(failureMessage) =>
return Right(DecryptedFailurePacket(nodeId, failureMessage.value))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ object FatError {
.xmap(pair => pair._1 +: pair._2, seq => (seq.head, seq.tail))
}

def fatErrorCodec(payloadAndPadLength: Int = 256, hopPayloadLength: Int = 9, maxHop: Int = 27): Codec[FatError] = (
("failure_payload" | bytes(payloadAndPadLength + 4)) ::
("hop_payloads" | listOfN(provide(maxHop), bytes(hopPayloadLength)).xmap[Seq[ByteVector]](_.toSeq, _.toList)) ::
("hmacs" | hmacsCodec(maxHop))).as[FatError].complete
def fatErrorCodec(totalLength: Int, hopPayloadLength: Int, maxNumHop: Int): Codec[FatError] = {
val metadataLength = maxNumHop * hopPayloadLength + (maxNumHop * (maxNumHop + 1)) / 2 * 32
(("failure_payload" | bytes(totalLength - metadataLength)) ::
("hop_payloads" | listOfN(provide(maxNumHop), bytes(hopPayloadLength)).xmap[Seq[ByteVector]](_.toSeq, _.toList)) ::
("hmacs" | hmacsCodec(maxNumHop))).as[FatError].complete}
}
25 changes: 25 additions & 0 deletions eclair-core/src/test/resources/fat_error.json

Large diffs are not rendered by default.

19 changes: 19 additions & 0 deletions eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@ import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedRoute, BlindedRouteDe
import fr.acinq.eclair.wire.protocol
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, MilliSatoshiLong, ShortChannelId, UInt64, randomBytes, randomKey}
import org.json4s.JsonAST._
import org.json4s.jackson.JsonMethods
import org.scalatest.funsuite.AnyFunSuite
import scodec.bits._

import java.io.File
import scala.concurrent.duration.DurationInt
import scala.io.Source
import scala.util.Success

/**
Expand Down Expand Up @@ -444,6 +448,21 @@ class SphinxSpec extends AnyFunSuite {
assert(decryptionError == expected)
}

test("fat error test vector") {
val src = Source.fromFile(new File(getClass.getResource(s"/fat_error.json").getFile))
try {
val testVector = JsonMethods.parse(src.mkString).asInstanceOf[JObject].values
val hops = testVector("hops").asInstanceOf[List[Map[String, String]]]
val sharedSecrets = hops.map(hop => ByteVector32(ByteVector.fromValidHex(hop("sharedSecret"))))
val encryptedMessages = hops.map(hop => ByteVector.fromValidHex(hop("encryptedMessage")))
val nodeIds = (1 to 5).map(_ => randomKey().publicKey)
//println(FatErrorPacket.unwrap(encryptedMessages(0), sharedSecrets(0), 0))
//println(FatErrorPacket.decrypt(encryptedMessages.head, sharedSecrets.zip(nodeIds)))
} finally {
src.close()
}
}

test("create blinded route (reference test vector)") {
val alice = PrivateKey(hex"4141414141414141414141414141414141414141414141414141414141414141")
val bob = PrivateKey(hex"4242424242424242424242424242424242424242424242424242424242424242")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS
val failures = Seq(
LocalFailure(finalAmount, Nil, ChannelUnavailable(randomBytes32())),
RemoteFailure(finalAmount, Nil, Sphinx.DecryptedFailurePacket(b, FeeInsufficient(100 msat, makeChannelUpdate(ShortChannelId(2), 15 msat, 150, CltvExpiryDelta(48))))),
UnreadableRemoteFailure(finalAmount, Nil)
UnreadableRemoteFailure(finalAmount, Nil, ???)
)
val extraEdges1 = Seq(
BasicEdge(a, b, ShortChannelId(1), 10 msat, 0, CltvExpiryDelta(12)), BasicEdge(b, c, ShortChannelId(2), 15 msat, 150, CltvExpiryDelta(48)),
Expand Down Expand Up @@ -412,14 +412,14 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS
childPayFsm.expectMsgType[SendPaymentToRoute]

val (failedId1, failedRoute1) = payFsm.stateData.asInstanceOf[PaymentProgress].pending.head
childPayFsm.send(payFsm, PaymentFailed(failedId1, paymentHash, Seq(UnreadableRemoteFailure(failedRoute1.amount, failedRoute1.hops))))
childPayFsm.send(payFsm, PaymentFailed(failedId1, paymentHash, Seq(UnreadableRemoteFailure(failedRoute1.amount, failedRoute1.hops, ???))))
router.expectMsgType[RouteRequest]
router.send(payFsm, RouteResponse(Seq(Route(500000 msat, hop_ad :: hop_de :: Nil))))
childPayFsm.expectMsgType[SendPaymentToRoute]

assert(!payFsm.stateData.asInstanceOf[PaymentProgress].pending.contains(failedId1))
val (failedId2, failedRoute2) = payFsm.stateData.asInstanceOf[PaymentProgress].pending.head
val result = abortAfterFailure(f, PaymentFailed(failedId2, paymentHash, Seq(UnreadableRemoteFailure(failedRoute2.amount, failedRoute2.hops))))
val result = abortAfterFailure(f, PaymentFailed(failedId2, paymentHash, Seq(UnreadableRemoteFailure(failedRoute2.amount, failedRoute2.hops, ???))))
assert(result.failures.length >= 3)
assert(result.failures.contains(LocalFailure(finalAmount, Nil, RetryExhausted)))

Expand Down Expand Up @@ -508,7 +508,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS
childPayFsm.expectMsgType[SendPaymentToRoute]

val (failedId1, failedRoute1) :: (failedId2, failedRoute2) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toSeq
childPayFsm.send(payFsm, PaymentFailed(failedId1, paymentHash, Seq(UnreadableRemoteFailure(failedRoute1.amount, failedRoute1.hops))))
childPayFsm.send(payFsm, PaymentFailed(failedId1, paymentHash, Seq(UnreadableRemoteFailure(failedRoute1.amount, failedRoute1.hops, ???))))
router.expectMsgType[RouteRequest]

val result = abortAfterFailure(f, PaymentFailed(failedId2, paymentHash, Seq(RemoteFailure(failedRoute2.amount, failedRoute2.hops, Sphinx.DecryptedFailurePacket(e, PaymentTimeout)))))
Expand All @@ -526,7 +526,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS
childPayFsm.expectMsgType[SendPaymentToRoute]

val (failedId, failedRoute) :: (successId, successRoute) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toSeq
childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(UnreadableRemoteFailure(failedRoute.amount, failedRoute.hops))))
childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(UnreadableRemoteFailure(failedRoute.amount, failedRoute.hops, ???))))
router.expectMsgType[RouteRequest]

val result = fulfillPendingPayments(f, 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)))) // unparsable message

// we allow 2 tries, so we send a 2nd request to the router
assert(sender.expectMsgType[PaymentFailed].failures == UnreadableRemoteFailure(route.amount, route.hops) :: UnreadableRemoteFailure(route.amount, route.hops) :: Nil)
assert(sender.expectMsgType[PaymentFailed].failures == UnreadableRemoteFailure(route.amount, route.hops, ???) :: UnreadableRemoteFailure(route.amount, route.hops, ???) :: Nil)
awaitCond(nodeParams.db.payments.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Failed])) // after last attempt the payment is failed

val metrics = metricsListener.expectMsgType[PathFindingExperimentMetrics]
Expand Down Expand Up @@ -794,8 +794,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec {
(RemoteFailure(defaultAmountMsat, route_abcd, Sphinx.DecryptedFailurePacket(c, UnknownNextPeer)), Set.empty, Set(ChannelDesc(scid_cd, c, d))),
(RemoteFailure(defaultAmountMsat, route_abcd, Sphinx.DecryptedFailurePacket(b, FeeInsufficient(100 msat, update_bc))), Set.empty, Set.empty),
// unreadable remote failures -> blacklist all nodes except our direct peer and the final recipient
(UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: Nil), Set.empty, Set.empty),
(UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: ChannelHop(ShortChannelId(5656986L), d, e, null) :: Nil), Set(c, d), Set.empty)
(UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: Nil, ???), Set.empty, Set.empty),
(UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: ChannelHop(ShortChannelId(5656986L), d, e, null) :: Nil, ???), Set(c, d), Set.empty)
)

for ((failure, expectedNodes, expectedChannels) <- testCases) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
val payFSM = mockPayFSM.expectMessageType[akka.actor.ActorRef]
router.expectMessageType[RouteRequest]

val failures = RemoteFailure(outgoingAmount, Nil, Sphinx.DecryptedFailurePacket(outgoingNodeId, FinalIncorrectHtlcAmount(42 msat))) :: UnreadableRemoteFailure(outgoingAmount, Nil) :: Nil
val failures = RemoteFailure(outgoingAmount, Nil, Sphinx.DecryptedFailurePacket(outgoingNodeId, FinalIncorrectHtlcAmount(42 msat))) :: UnreadableRemoteFailure(outgoingAmount, Nil, ???) :: Nil
payFSM ! PaymentFailed(relayId, incomingMultiPart.head.add.paymentHash, failures)

incomingMultiPart.foreach { p =>
Expand Down

0 comments on commit 1e54913

Please sign in to comment.