Skip to content

Commit

Permalink
Refactor tx signing (no functional changes)
Browse files Browse the repository at this point in the history
  • Loading branch information
sstone committed Sep 9, 2024
1 parent 8370fa2 commit a8148e8
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ object LocalCommit {
fundingTxIndex: Long, remoteFundingPubKey: PublicKey, commitInput: InputInfo,
commit: CommitSig, localCommitIndex: Long, spec: CommitmentSpec, localPerCommitmentPoint: PublicKey): Either[ChannelException, LocalCommit] = {
val (localCommitTx, htlcTxs) = Commitment.makeLocalTxs(keyManager, params.channelConfig, params.channelFeatures, localCommitIndex, params.localParams, params.remoteParams, fundingTxIndex, remoteFundingPubKey, commitInput, localPerCommitmentPoint, spec)
if (!checkSig(localCommitTx, commit.signature, remoteFundingPubKey, TxOwner.Remote, params.commitmentFormat)) {
if (!localCommitTx.checkSig(commit.signature, remoteFundingPubKey, TxOwner.Remote, params.commitmentFormat)) {
return Left(InvalidCommitmentSignature(params.channelId, fundingTxId, fundingTxIndex, localCommitTx.tx))
}
val sortedHtlcTxs = htlcTxs.sortBy(_.input.outPoint.index)
Expand All @@ -238,7 +238,7 @@ object LocalCommit {
val remoteHtlcPubkey = Generators.derivePubKey(params.remoteParams.htlcBasepoint, localPerCommitmentPoint)
val htlcTxsAndRemoteSigs = sortedHtlcTxs.zip(commit.htlcSignatures).toList.map {
case (htlcTx: HtlcTx, remoteSig) =>
if (!checkSig(htlcTx, remoteSig, remoteHtlcPubkey, TxOwner.Remote, params.commitmentFormat)) {
if (!htlcTx.checkSig(remoteSig, remoteHtlcPubkey, TxOwner.Remote, params.commitmentFormat)) {
return Left(InvalidHtlcSignature(params.channelId, htlcTx.tx.txid))
}
HtlcTxAndRemoteSig(htlcTx, remoteSig)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class LocalChannelKeyManager(seed: ByteVector, chainHash: BlockHash) extends Cha
Metrics.SignTxCount.withTags(tags).increment()
KamonExt.time(Metrics.SignTxDuration.withTags(tags)) {
val privateKey = privateKeys.get(publicKey.path)
Transactions.sign(tx, privateKey.privateKey, txOwner, commitmentFormat)
tx.sign(privateKey.privateKey, txOwner, commitmentFormat)
}
}

Expand All @@ -134,7 +134,7 @@ class LocalChannelKeyManager(seed: ByteVector, chainHash: BlockHash) extends Cha
KamonExt.time(Metrics.SignTxDuration.withTags(tags)) {
val privateKey = privateKeys.get(publicKey.path)
val currentKey = Generators.derivePrivKey(privateKey.privateKey, remotePoint)
Transactions.sign(tx, currentKey, txOwner, commitmentFormat)
tx.sign(currentKey, txOwner, commitmentFormat)
}
}

Expand All @@ -154,7 +154,7 @@ class LocalChannelKeyManager(seed: ByteVector, chainHash: BlockHash) extends Cha
KamonExt.time(Metrics.SignTxDuration.withTags(tags)) {
val privateKey = privateKeys.get(publicKey.path)
val currentKey = Generators.revocationPrivKey(privateKey.privateKey, remoteSecret)
Transactions.sign(tx, currentKey, txOwner, commitmentFormat)
tx.sign(currentKey, txOwner, commitmentFormat)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,23 @@ object Transactions {
}
/** Sighash flags to use when signing the transaction. */
def sighash(txOwner: TxOwner, commitmentFormat: CommitmentFormat): Int = SIGHASH_ALL

def sign(key: PrivateKey, txOwner: TxOwner, commitmentFormat: CommitmentFormat): ByteVector64 = Transactions.sign(this, key, sighash(txOwner, commitmentFormat))

def sign(key: PrivateKey, sighashType: Int): ByteVector64 = {
// NB: the tx may have multiple inputs, we will only sign the one provided in txinfo.input. Bear in mind that the
// signature will be invalidated if other inputs are added *afterwards* and sighashType was SIGHASH_ALL.
val inputIndex = tx.txIn.zipWithIndex.find(_._1.outPoint == input.outPoint).get._2
Transactions.sign(tx, input.redeemScript, input.txOut.amount, key, sighashType, inputIndex)
}

def checkSig(sig: ByteVector64, pubKey: PublicKey, txOwner: TxOwner, commitmentFormat: CommitmentFormat): Boolean = {
val sighash = this.sighash(txOwner, commitmentFormat)
val data = Transaction.hashForSigning(tx, inputIndex = 0, input.redeemScript, sighash, input.txOut.amount, SIGVERSION_WITNESS_V0)
Crypto.verifySignature(data, sig, pubKey)
}
}

sealed trait ReplaceableTransactionWithInputInfo extends TransactionWithInputInfo {
/** Block before which the transaction must be confirmed. */
def confirmationTarget: ConfirmationTarget
Expand Down Expand Up @@ -852,15 +868,13 @@ object Transactions {
sig64
}

def sign(txinfo: TransactionWithInputInfo, key: PrivateKey, sighashType: Int): ByteVector64 = {
private def sign(txinfo: TransactionWithInputInfo, key: PrivateKey, sighashType: Int): ByteVector64 = {
// NB: the tx may have multiple inputs, we will only sign the one provided in txinfo.input. Bear in mind that the
// signature will be invalidated if other inputs are added *afterwards* and sighashType was SIGHASH_ALL.
val inputIndex = txinfo.tx.txIn.zipWithIndex.find(_._1.outPoint == txinfo.input.outPoint).get._2
sign(txinfo.tx, txinfo.input.redeemScript, txinfo.input.txOut.amount, key, sighashType, inputIndex)
}

def sign(txinfo: TransactionWithInputInfo, key: PrivateKey, txOwner: TxOwner, commitmentFormat: CommitmentFormat): ByteVector64 = sign(txinfo, key, txinfo.sighash(txOwner, commitmentFormat))

def addSigs(commitTx: CommitTx, localFundingPubkey: PublicKey, remoteFundingPubkey: PublicKey, localSig: ByteVector64, remoteSig: ByteVector64): CommitTx = {
val witness = Scripts.witness2of2(localSig, remoteSig, localFundingPubkey, remoteFundingPubkey)
commitTx.copy(tx = commitTx.tx.updateWitness(0, witness))
Expand Down Expand Up @@ -935,11 +949,4 @@ object Transactions {
// NB: we don't verify the other inputs as they should only be wallet inputs used to RBF the transaction
Try(Transaction.correctlySpends(txinfo.tx, Map(txinfo.input.outPoint -> txinfo.input.txOut), ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS))
}

def checkSig(txinfo: TransactionWithInputInfo, sig: ByteVector64, pubKey: PublicKey, txOwner: TxOwner, commitmentFormat: CommitmentFormat): Boolean = {
val sighash = txinfo.sighash(txOwner, commitmentFormat)
val data = Transaction.hashForSigning(txinfo.tx, inputIndex = 0, txinfo.input.redeemScript, sighash, txinfo.input.txOut.amount, SIGVERSION_WITNESS_V0)
Crypto.verifySignature(data, sig, pubKey)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,9 @@ trait TestVectorsSpec extends AnyFunSuite with Logging {
remotePaymentBasePoint = Remote.payment_basepoint,
localIsChannelOpener = true,
outputs = outputs)
val local_sig = Transactions.sign(tx, Local.funding_privkey, TxOwner.Local, commitmentFormat)
val local_sig = tx.sign(Local.funding_privkey, TxOwner.Local, commitmentFormat)
logger.info(s"# local_signature = ${Scripts.der(local_sig).dropRight(1).toHex}")
val remote_sig = Transactions.sign(tx, Remote.funding_privkey, TxOwner.Remote, commitmentFormat)
val remote_sig = tx.sign(Remote.funding_privkey, TxOwner.Remote, commitmentFormat)
logger.info(s"remote_signature: ${Scripts.der(remote_sig).dropRight(1).toHex}")
Transactions.addSigs(tx, Local.funding_pubkey, Remote.funding_pubkey, local_sig, remote_sig)
}
Expand Down Expand Up @@ -248,8 +248,8 @@ trait TestVectorsSpec extends AnyFunSuite with Logging {

val signedTxs = htlcTxs.collect {
case tx: HtlcSuccessTx =>
val localSig = Transactions.sign(tx, Local.htlc_privkey, TxOwner.Local, commitmentFormat)
val remoteSig = Transactions.sign(tx, Remote.htlc_privkey, TxOwner.Remote, commitmentFormat)
val localSig = tx.sign(Local.htlc_privkey, TxOwner.Local, commitmentFormat)
val remoteSig = tx.sign(Remote.htlc_privkey, TxOwner.Remote, commitmentFormat)
val htlcIndex = htlcScripts.indexOf(Script.parse(tx.input.redeemScript))
val preimage = paymentPreimages.find(p => Crypto.sha256(p) == tx.paymentHash).get
val tx1 = Transactions.addSigs(tx, localSig, remoteSig, preimage, commitmentFormat)
Expand All @@ -260,8 +260,8 @@ trait TestVectorsSpec extends AnyFunSuite with Logging {
logger.info(s"htlc_success_tx (htlc #$htlcIndex): ${tx1.tx}")
tx1
case tx: HtlcTimeoutTx =>
val localSig = Transactions.sign(tx, Local.htlc_privkey, TxOwner.Local, commitmentFormat)
val remoteSig = Transactions.sign(tx, Remote.htlc_privkey, TxOwner.Remote, commitmentFormat)
val localSig = tx.sign(Local.htlc_privkey, TxOwner.Local, commitmentFormat)
val remoteSig = tx.sign(Remote.htlc_privkey, TxOwner.Remote, commitmentFormat)
val htlcIndex = htlcScripts.indexOf(Script.parse(tx.input.redeemScript))
val tx1 = Transactions.addSigs(tx, localSig, remoteSig, commitmentFormat)
Transaction.correctlySpends(tx1.tx, Seq(commitTx.tx), ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
Expand Down
Loading

0 comments on commit a8148e8

Please sign in to comment.