diff --git a/bitcoin_client/ledger_bitcoin/__init__.py b/bitcoin_client/ledger_bitcoin/__init__.py index 4c4bd82be..5777dfb37 100644 --- a/bitcoin_client/ledger_bitcoin/__init__.py +++ b/bitcoin_client/ledger_bitcoin/__init__.py @@ -1,7 +1,7 @@ """Ledger Nano Bitcoin app client""" -from .client_base import Client, TransportClient, PartialSignature +from .client_base import Client, TransportClient, PartialSignature, MusigPubNonce, MusigPartialSignature, SignPsbtYieldedObject from .client import createClient from .common import Chain @@ -13,6 +13,9 @@ "Client", "TransportClient", "PartialSignature", + "MusigPubNonce", + "MusigPartialSignature", + "SignPsbtYieldedObject", "createClient", "Chain", "AddressType", diff --git a/bitcoin_client/ledger_bitcoin/bip0327.py b/bitcoin_client/ledger_bitcoin/bip0327.py new file mode 100644 index 000000000..8d4680791 --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/bip0327.py @@ -0,0 +1,177 @@ +# extracted from the BIP327 reference implementation: https://github.com/bitcoin/bips/blob/b3701faef2bdb98a0d7ace4eedbeefa2da4c89ed/bip-0327/reference.py + +# Only contains the key aggregation part of the library + +# The code in this source file is distributed under the BSD-3-Clause. + +# autopep8: off + +from typing import List, Optional, Tuple, NewType, NamedTuple +import hashlib + +# +# The following helper functions were copied from the BIP-340 reference implementation: +# https://github.com/bitcoin/bips/blob/master/bip-0340/reference.py +# + +p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F +n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 + +# Points are tuples of X and Y coordinates and the point at infinity is +# represented by the None keyword. +G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8) + +Point = Tuple[int, int] + +# This implementation can be sped up by storing the midstate after hashing +# tag_hash instead of rehashing it all the time. +def tagged_hash(tag: str, msg: bytes) -> bytes: + tag_hash = hashlib.sha256(tag.encode()).digest() + return hashlib.sha256(tag_hash + tag_hash + msg).digest() + +def is_infinite(P: Optional[Point]) -> bool: + return P is None + +def x(P: Point) -> int: + assert not is_infinite(P) + return P[0] + +def y(P: Point) -> int: + assert not is_infinite(P) + return P[1] + +def point_add(P1: Optional[Point], P2: Optional[Point]) -> Optional[Point]: + if P1 is None: + return P2 + if P2 is None: + return P1 + if (x(P1) == x(P2)) and (y(P1) != y(P2)): + return None + if P1 == P2: + lam = (3 * x(P1) * x(P1) * pow(2 * y(P1), p - 2, p)) % p + else: + lam = ((y(P2) - y(P1)) * pow(x(P2) - x(P1), p - 2, p)) % p + x3 = (lam * lam - x(P1) - x(P2)) % p + return (x3, (lam * (x(P1) - x3) - y(P1)) % p) + +def point_mul(P: Optional[Point], n: int) -> Optional[Point]: + R = None + for i in range(256): + if (n >> i) & 1: + R = point_add(R, P) + P = point_add(P, P) + return R + +def bytes_from_int(x: int) -> bytes: + return x.to_bytes(32, byteorder="big") + +def lift_x(b: bytes) -> Optional[Point]: + x = int_from_bytes(b) + if x >= p: + return None + y_sq = (pow(x, 3, p) + 7) % p + y = pow(y_sq, (p + 1) // 4, p) + if pow(y, 2, p) != y_sq: + return None + return (x, y if y & 1 == 0 else p-y) + +def int_from_bytes(b: bytes) -> int: + return int.from_bytes(b, byteorder="big") + +def has_even_y(P: Point) -> bool: + assert not is_infinite(P) + return y(P) % 2 == 0 + +# +# End of helper functions copied from BIP-340 reference implementation. +# + +PlainPk = NewType('PlainPk', bytes) +XonlyPk = NewType('XonlyPk', bytes) + +# There are two types of exceptions that can be raised by this implementation: +# - ValueError for indicating that an input doesn't conform to some function +# precondition (e.g. an input array is the wrong length, a serialized +# representation doesn't have the correct format). +# - InvalidContributionError for indicating that a signer (or the +# aggregator) is misbehaving in the protocol. +# +# Assertions are used to (1) satisfy the type-checking system, and (2) check for +# inconvenient events that can't happen except with negligible probability (e.g. +# output of a hash function is 0) and can't be manually triggered by any +# signer. + +# This exception is raised if a party (signer or nonce aggregator) sends invalid +# values. Actual implementations should not crash when receiving invalid +# contributions. Instead, they should hold the offending party accountable. +class InvalidContributionError(Exception): + def __init__(self, signer, contrib): + self.signer = signer + # contrib is one of "pubkey", "pubnonce", "aggnonce", or "psig". + self.contrib = contrib + +infinity = None + +def xbytes(P: Point) -> bytes: + return bytes_from_int(x(P)) + +def cbytes(P: Point) -> bytes: + a = b'\x02' if has_even_y(P) else b'\x03' + return a + xbytes(P) + +def point_negate(P: Optional[Point]) -> Optional[Point]: + if P is None: + return P + return (x(P), p - y(P)) + +def cpoint(x: bytes) -> Point: + if len(x) != 33: + raise ValueError('x is not a valid compressed point.') + P = lift_x(x[1:33]) + if P is None: + raise ValueError('x is not a valid compressed point.') + if x[0] == 2: + return P + elif x[0] == 3: + P = point_negate(P) + assert P is not None + return P + else: + raise ValueError('x is not a valid compressed point.') + +KeyAggContext = NamedTuple('KeyAggContext', [('Q', Point), + ('gacc', int), + ('tacc', int)]) + +def key_agg(pubkeys: List[PlainPk]) -> KeyAggContext: + pk2 = get_second_key(pubkeys) + u = len(pubkeys) + Q = infinity + for i in range(u): + try: + P_i = cpoint(pubkeys[i]) + except ValueError: + raise InvalidContributionError(i, "pubkey") + a_i = key_agg_coeff_internal(pubkeys, pubkeys[i], pk2) + Q = point_add(Q, point_mul(P_i, a_i)) + # Q is not the point at infinity except with negligible probability. + assert(Q is not None) + gacc = 1 + tacc = 0 + return KeyAggContext(Q, gacc, tacc) + +def hash_keys(pubkeys: List[PlainPk]) -> bytes: + return tagged_hash('KeyAgg list', b''.join(pubkeys)) + +def get_second_key(pubkeys: List[PlainPk]) -> PlainPk: + u = len(pubkeys) + for j in range(1, u): + if pubkeys[j] != pubkeys[0]: + return pubkeys[j] + return PlainPk(b'\x00'*33) + +def key_agg_coeff_internal(pubkeys: List[PlainPk], pk_: PlainPk, pk2: PlainPk) -> int: + L = hash_keys(pubkeys) + if pk_ == pk2: + return 1 + return int_from_bytes(tagged_hash('KeyAgg coefficient', L + pk_)) % n diff --git a/bitcoin_client/ledger_bitcoin/client.py b/bitcoin_client/ledger_bitcoin/client.py index 351370320..89279df36 100644 --- a/bitcoin_client/ledger_bitcoin/client.py +++ b/bitcoin_client/ledger_bitcoin/client.py @@ -3,23 +3,25 @@ import base64 from io import BytesIO, BufferedReader +from .embit import base58 from .embit.base import EmbitError from .embit.descriptor import Descriptor from .embit.networks import NETWORKS from .command_builder import BitcoinCommandBuilder, BitcoinInsType from .common import Chain, read_uint, read_varint -from .client_command import ClientCommandInterpreter -from .client_base import Client, TransportClient, PartialSignature +from .client_command import ClientCommandInterpreter, CCMD_YIELD_MUSIG_PARTIALSIGNATURE_TAG, CCMD_YIELD_MUSIG_PUBNONCE_TAG +from .client_base import Client, MusigPartialSignature, MusigPubNonce, SignPsbtYieldedObject, TransportClient, PartialSignature from .client_legacy import LegacyClient from .exception import DeviceException from .errors import UnknownDeviceError from .merkle import get_merkleized_map_commitment from .wallet import WalletPolicy, WalletType from .psbt import PSBT, normalize_psbt -from . import segwit_addr from ._serialize import deser_string +from .bip0327 import key_agg, cbytes + def parse_stream_to_map(f: BufferedReader) -> Mapping[bytes, bytes]: result = {} @@ -39,6 +41,54 @@ def parse_stream_to_map(f: BufferedReader) -> Mapping[bytes, bytes]: return result +def aggr_xpub(pubkeys: List[bytes], chain: Chain) -> str: + BIP_MUSIG_CHAINCODE = bytes.fromhex( + "868087ca02a6f974c4598924c36b57762d32cb45717167e300622c7167e38965") + # sort the pubkeys prior to aggregation + ctx = key_agg(list(sorted(pubkeys))) + compressed_pubkey = cbytes(ctx.Q) + + # Serialize according to BIP-32 + if chain == Chain.MAIN: + version = 0x0488B21E + else: + version = 0x043587CF + + return base58.encode_check(b''.join([ + version.to_bytes(4, byteorder='big'), + b'\x00', # depth + b'\x00\x00\x00\x00', # parent fingerprint + b'\x00\x00\x00\x00', # child number + BIP_MUSIG_CHAINCODE, + compressed_pubkey + ])) + + +# Given a valid descriptor, replaces each musig() (if any) with the +# corresponding synthetic xpub/tpub. +def replace_musigs(desc: str, chain: Chain) -> str: + while True: + musig_start = desc.find("musig(") + if musig_start == -1: + break + musig_end = desc.find(")", musig_start) + if musig_end == -1: + raise ValueError("Invalid descriptor template") + + key_and_origs = desc[musig_start+6:musig_end].split(",") + pubkeys = [] + for key_orig in key_and_origs: + orig_end = key_orig.find("]") + xpub = key_orig if orig_end == -1 else key_orig[orig_end+1:] + pubkeys.append(base58.decode_check(xpub)[-33:]) + + # replace with the aggregate xpub + desc = desc[:musig_start] + \ + aggr_xpub(pubkeys, chain) + desc[musig_end+1:] + + return desc + + def _make_partial_signature(pubkey_augm: bytes, signature: bytes) -> PartialSignature: if len(pubkey_augm) == 64: # tapscript spend: pubkey_augm is the concatenation of: @@ -56,6 +106,60 @@ def _make_partial_signature(pubkey_augm: bytes, signature: bytes) -> PartialSign return PartialSignature(signature=signature, pubkey=pubkey_augm) +def _decode_signpsbt_yielded_value(res: bytes) -> Tuple[int, SignPsbtYieldedObject]: + res_buffer = BytesIO(res) + input_index_or_tag = read_varint(res_buffer) + if input_index_or_tag == CCMD_YIELD_MUSIG_PUBNONCE_TAG: + input_index = read_varint(res_buffer) + pubnonce = res_buffer.read(66) + participant_pk = res_buffer.read(33) + aggregate_pubkey = res_buffer.read(33) + tapleaf_hash = res_buffer.read() + if len(tapleaf_hash) == 0: + tapleaf_hash = None + + return ( + input_index, + MusigPubNonce( + participant_pubkey=participant_pk, + aggregate_pubkey=aggregate_pubkey, + tapleaf_hash=tapleaf_hash, + pubnonce=pubnonce + ) + ) + elif input_index_or_tag == CCMD_YIELD_MUSIG_PARTIALSIGNATURE_TAG: + input_index = read_varint(res_buffer) + partial_signature = res_buffer.read(32) + participant_pk = res_buffer.read(33) + aggregate_pubkey = res_buffer.read(33) + tapleaf_hash = res_buffer.read() + if len(tapleaf_hash) == 0: + tapleaf_hash = None + + return ( + input_index, + MusigPartialSignature( + participant_pubkey=participant_pk, + aggregate_pubkey=aggregate_pubkey, + tapleaf_hash=tapleaf_hash, + partial_signature=partial_signature + ) + ) + else: + # other values follow an encoding without an explicit tag, where the + # first element is the input index. All the signature types are implemented + # by the PartialSignature type (not to be confused with the musig Partial Signature). + input_index = input_index_or_tag + + pubkey_augm_len = read_uint(res_buffer, 8) + pubkey_augm = res_buffer.read(pubkey_augm_len) + + signature = res_buffer.read() + + return((input_index, _make_partial_signature(pubkey_augm, signature))) + + + class NewClient(Client): # internal use for testing: if set to True, sign_psbt will not clone the psbt before converting to psbt version 2 _no_clone_psbt: bool = False @@ -162,7 +266,7 @@ def get_wallet_address( return result - def sign_psbt(self, psbt: Union[PSBT, bytes, str], wallet: WalletPolicy, wallet_hmac: Optional[bytes]) -> List[Tuple[int, PartialSignature]]: + def sign_psbt(self, psbt: Union[PSBT, bytes, str], wallet: WalletPolicy, wallet_hmac: Optional[bytes]) -> List[Tuple[int, SignPsbtYieldedObject]]: psbt = normalize_psbt(psbt) @@ -231,17 +335,10 @@ def sign_psbt(self, psbt: Union[PSBT, bytes, str], wallet: WalletPolicy, wallet_ if any(len(x) <= 1 for x in results): raise RuntimeError("Invalid response") - results_list: List[Tuple[int, PartialSignature]] = [] + results_list: List[Tuple[int, SignPsbtYieldedObject]] = [] for res in results: - res_buffer = BytesIO(res) - input_index = read_varint(res_buffer) - - pubkey_augm_len = read_uint(res_buffer, 8) - pubkey_augm = res_buffer.read(pubkey_augm_len) - - signature = res_buffer.read() - - results_list.append((input_index, _make_partial_signature(pubkey_augm, signature))) + input_index, obj = _decode_signpsbt_yielded_value(res) + results_list.append((input_index, obj)) return results_list @@ -273,6 +370,11 @@ def sign_message(self, message: Union[str, bytes], bip32_path: str) -> str: def _derive_address_for_policy(self, wallet: WalletPolicy, change: bool, address_index: int) -> Optional[str]: desc_str = wallet.get_descriptor(change) + + # Since embit does not support musig() in descriptors, we replace each + # occurrence with the corresponding aggregated xpub + desc_str = replace_musigs(desc_str, self.chain) + try: desc = Descriptor.from_string(desc_str) diff --git a/bitcoin_client/ledger_bitcoin/client_base.py b/bitcoin_client/ledger_bitcoin/client_base.py index 5130bf7ef..d7b9461db 100644 --- a/bitcoin_client/ledger_bitcoin/client_base.py +++ b/bitcoin_client/ledger_bitcoin/client_base.py @@ -28,7 +28,8 @@ def __init__(self, sw: int, data: bytes) -> None: class TransportClient: def __init__(self, interface: Literal['hid', 'tcp'] = "tcp", *, server: str = "127.0.0.1", port: int = 9999, path: Optional[str] = None, hid: Optional[HID] = None, debug: bool = False): - self.transport = Transport('hid', path=path, hid=hid, debug=debug) if interface == 'hid' else Transport(interface, server=server, port=port, debug=debug) + self.transport = Transport('hid', path=path, hid=hid, debug=debug) if interface == 'hid' else Transport( + interface, server=server, port=port, debug=debug) def apdu_exchange( self, cla: int, ins: int, data: bytes = b"", p1: int = 0, p2: int = 0 @@ -67,18 +68,62 @@ def print_response(sw: int, data: bytes) -> None: @dataclass(frozen=True) class PartialSignature: - """Represents a partial signature returned by sign_psbt. + """Represents a partial signature returned by sign_psbt. Such objects can be added to the PSBT. It always contains a pubkey and a signature. - The pubkey + The pubkey is a compressed 33-byte for legacy and segwit Scripts, or 32-byte x-only key for taproot. + The signature is in the format it would be pushed on the scriptSig or the witness stack, therefore of + variable length, and possibly concatenated with the SIGHASH flag byte if appropriate. - The tapleaf_hash is also filled if signing a for a tapscript. + The tapleaf_hash is also filled if signing for a tapscript. + + Note: not to be confused with 'partial signature' of protocols like MuSig2; """ pubkey: bytes signature: bytes tapleaf_hash: Optional[bytes] = None +@dataclass(frozen=True) +class MusigPubNonce: + """Represents a pubnonce returned by sign_psbt during the first round of a Musig2 signing session. + + It always contains + - the participant_pubkey, a 33-byte compressed pubkey; + - aggregate_pubkey, the 33-byte compressed pubkey key that is the aggregate of all the participant + pubkeys, with the necessary tweaks; its x-only version is the key present in the Script; + - the 66-byte pubnonce. + + The tapleaf_hash is also filled if signing for a tapscript; `None` otherwise. + """ + participant_pubkey: bytes + aggregate_pubkey: bytes + tapleaf_hash: Optional[bytes] + pubnonce: bytes + + +@dataclass(frozen=True) +class MusigPartialSignature: + """Represents a partial signature returned by sign_psbt during the second round of a Musig2 signing session. + + It always contains + - the participant_pubkey, a 33-byte compressed pubkey; + - aggregate_pubkey, the 33-byte compressed pubkey key that is the aggregate of all the participant + pubkeys, with the necessary tweaks; its x-only version is the key present in the Script; + - the partial_signature, the 32-byte partial signature for this participant. + + The tapleaf_hash is also filled if signing for a tapscript; `None` otherwise + """ + participant_pubkey: bytes + aggregate_pubkey: bytes + tapleaf_hash: Optional[bytes] + partial_signature: bytes + + +SignPsbtYieldedObject = Union[PartialSignature, + MusigPubNonce, MusigPartialSignature] + + class Client: def __init__(self, transport_client: TransportClient, chain: Chain = Chain.MAIN, debug: bool = False) -> None: self.transport_client = transport_client @@ -218,7 +263,7 @@ def get_wallet_address( raise NotImplementedError - def sign_psbt(self, psbt: Union[PSBT, bytes, str], wallet: WalletPolicy, wallet_hmac: Optional[bytes]) -> List[Tuple[int, PartialSignature]]: + def sign_psbt(self, psbt: Union[PSBT, bytes, str], wallet: WalletPolicy, wallet_hmac: Optional[bytes]) -> List[Tuple[int, SignPsbtYieldedObject]]: """Signs a PSBT using a registered wallet (or a standard wallet that does not need registration). Signature requires explicit approval from the user. diff --git a/bitcoin_client/ledger_bitcoin/client_command.py b/bitcoin_client/ledger_bitcoin/client_command.py index 9e32a56ba..8495ec1c4 100644 --- a/bitcoin_client/ledger_bitcoin/client_command.py +++ b/bitcoin_client/ledger_bitcoin/client_command.py @@ -15,6 +15,10 @@ class ClientCommandCode(IntEnum): GET_MORE_ELEMENTS = 0xA0 +CCMD_YIELD_MUSIG_PUBNONCE_TAG = 0xFFFFFFFF +CCMD_YIELD_MUSIG_PARTIALSIGNATURE_TAG = 0xFFFFFFFE + + class ClientCommand: def execute(self, request: bytes) -> bytes: raise NotImplementedError("Subclasses should implement this method.") diff --git a/bitcoin_client/ledger_bitcoin/psbt.py b/bitcoin_client/ledger_bitcoin/psbt.py index 16de47d23..956b6cca0 100644 --- a/bitcoin_client/ledger_bitcoin/psbt.py +++ b/bitcoin_client/ledger_bitcoin/psbt.py @@ -1,6 +1,8 @@ # Original version: https://github.com/bitcoin-core/HWI/blob/3fe369d0379212fae1c72729a179d133b0adc872/hwilib/key.py # Distributed under the MIT License. +# fmt: off + """ PSBT Classes and Utilities ************************** @@ -107,6 +109,9 @@ class PartiallySignedInput: PSBT_IN_TAP_BIP32_DERIVATION = 0x16 PSBT_IN_TAP_INTERNAL_KEY = 0x17 PSBT_IN_TAP_MERKLE_ROOT = 0x18 + PSBT_IN_MUSIG2_PARTICIPANT_PUBKEYS = 0x1a + PSBT_IN_MUSIG2_PUB_NONCE = 0x1b + PSBT_IN_MUSIG2_PARTIAL_SIG = 0x1c def __init__(self, version: int) -> None: self.non_witness_utxo: Optional[CTransaction] = None @@ -129,6 +134,9 @@ def __init__(self, version: int) -> None: self.tap_bip32_paths: Dict[bytes, Tuple[Set[bytes], KeyOriginInfo]] = {} self.tap_internal_key = b"" self.tap_merkle_root = b"" + self.musig2_participant_pubkeys: Dict[bytes, List[bytes]] = {} + self.musig2_pub_nonces: Dict[Tuple[bytes, bytes, Optional[bytes]], bytes] = {} + self.musig2_partial_sigs: Dict[Tuple[bytes, bytes, Optional[bytes]], bytes] = {} self.unknown: Dict[bytes, bytes] = {} self.version: int = version @@ -355,6 +363,51 @@ def deserialize(self, f: Readable) -> None: self.tap_merkle_root = deser_string(f) if len(self.tap_merkle_root) != 32: raise PSBTSerializationError("Input Taproot merkle root is not 32 bytes") + elif key_type == PartiallySignedInput.PSBT_IN_MUSIG2_PARTICIPANT_PUBKEYS: + if key in key_lookup: + raise PSBTSerializationError("Duplicate key, input Musig2 participant pubkeys already provided") + elif len(key) != 1 + 33: + raise PSBTSerializationError("Input Musig2 aggregate compressed pubkey is not 33 bytes") + + pubkeys_cat = deser_string(f) + if len(pubkeys_cat) == 0: + raise PSBTSerializationError("The list of compressed pubkeys for Musig2 cannot be empty") + if (len(pubkeys_cat) % 33) != 0: + raise PSBTSerializationError("The compressed pubkeys for Musig2 must be exactly 33 bytes long") + pubkeys = [] + for i in range(0, len(pubkeys_cat), 33): + pubkeys.append(pubkeys_cat[33*i: 33*(i+1)]) + + self.musig2_participant_pubkeys[key] = pubkeys + elif key_type == PartiallySignedInput.PSBT_IN_MUSIG2_PUB_NONCE: + if key in key_lookup: + raise PSBTSerializationError("Duplicate key, Musig2 public nonce already provided") + elif len(key) not in [1 + 33 + 33, 1 + 33 + 33 + 32]: + raise PSBTSerializationError("Invalid key length for Musig2 public nonce") + + providing_pubkey = key[1:1+33] + aggregate_pubkey = key[1+33:1+33+33] + tapleaf_hash = None if len(key) == 1 + 33 + 33 else key[1+33+33:] + + public_nonces = deser_string(f) + if len(public_nonces) != 66: + raise PSBTSerializationError("The length of the public nonces in Musig2 must be exactly 66 bytes") + + self.musig2_pub_nonces[(providing_pubkey, aggregate_pubkey, tapleaf_hash)] = public_nonces + elif key_type == PartiallySignedInput.PSBT_IN_MUSIG2_PARTIAL_SIG: + if key in key_lookup: + raise PSBTSerializationError("Duplicate key, Musig2 partial signature already provided") + elif len(key) not in [1 + 33 + 33, 1 + 33 + 33 + 32]: + raise PSBTSerializationError("Invalid key length for Musig2 partial signature") + + providing_pubkey = key[1:1+33] + aggregate_pubkey = key[1+33:1+33+33] + tapleaf_hash = None if len(key) == 1 + 33 + 33 else key[1+33+33:] + + partial_sig = deser_string(f) + if len(partial_sig) != 32: + raise PSBTSerializationError("The length of the partial signature in Musig2 must be exactly 32 bytes") + self.musig2_partial_sigs[(providing_pubkey, aggregate_pubkey, tapleaf_hash)] = partial_sig else: if key in self.unknown: raise PSBTSerializationError("Duplicate key, key for unknown value already provided") @@ -466,6 +519,20 @@ def serialize(self) -> bytes: r += ser_string(ser_compact_size(PartiallySignedInput.PSBT_IN_REQUIRED_HEIGHT_LOCKTIME)) r += ser_string(struct.pack(" None: self.redeem_script = b"" @@ -497,6 +565,9 @@ def __init__(self, version: int) -> None: self.tap_internal_key = b"" self.tap_tree = b"" self.tap_bip32_paths: Dict[bytes, Tuple[Set[bytes], KeyOriginInfo]] = {} + + self.musig2_participant_pubkeys: Dict[bytes, List[bytes]] = {} + self.unknown: Dict[bytes, bytes] = {} self.version: int = version @@ -593,6 +664,22 @@ def deserialize(self, f: Readable) -> None: for i in range(0, num_hashes): leaf_hashes.add(vs.read(32)) self.tap_bip32_paths[xonly] = (leaf_hashes, KeyOriginInfo.deserialize(vs.read())) + elif key_type == PartiallySignedOutput.PSBT_OUT_MUSIG2_PARTICIPANT_PUBKEYS: + if key in key_lookup: + raise PSBTSerializationError("Duplicate key, output Musig2 participant pubkeys already provided") + elif len(key) != 1 + 33: + raise PSBTSerializationError("Output Musig2 aggregate compressed pubkey is not 33 bytes") + + pubkeys_cat = deser_string(f) + if len(pubkeys_cat) == 0: + raise PSBTSerializationError("The list of compressed pubkeys for Musig2 cannot be empty") + if (len(pubkeys_cat) % 33) != 0: + raise PSBTSerializationError("The compressed pubkeys for Musig2 must be exactly 33 bytes long") + pubkeys = [] + for i in range(0, len(pubkeys_cat), 33): + pubkeys.append(pubkeys_cat[33*i: 33*(i+1)]) + + self.musig2_participant_pubkeys[key] = pubkeys else: if key in self.unknown: raise PSBTSerializationError("Duplicate key, key for unknown value already provided") @@ -650,6 +737,11 @@ def serialize(self) -> bytes: value += origin.serialize() r += ser_string(value) + for pk, pubkeys in self.musig2_participant_pubkeys.items(): + r += ser_string(ser_compact_size( + PartiallySignedOutput.PSBT_OUT_MUSIG2_PARTICIPANT_PUBKEYS) + pk) + r += ser_string(b''.join(pubkeys)) + for key, value in sorted(self.unknown.items()): r += ser_string(key) r += ser_string(value) diff --git a/doc/musig.md b/doc/musig.md new file mode 100644 index 000000000..f3356aa89 --- /dev/null +++ b/doc/musig.md @@ -0,0 +1,87 @@ +# MuSig2 + +The Ledger Bitcoin app supports wallet policies with `musig()` key expressions. + +MuSig2 is a 2-round multi-signature scheme compatible with the public keys and signatures used in taproot transactions. The implementation is compliant with [BIP-0327](https://github.com/bitcoin/bips/blob/master/bip-0327.mediawiki). + +## Specs + +`musig()` key expressions are supported for all taproot policies, including taproot keypaths and miniscript. + +- At most 16 keys are allowed in the musig expression; performance limitations, however, might apply in practice. +- At most 8 parallel MuSig signing sessions are supported, due to the need to persist state in the device's memory. +- Only `musig(...)/**` or `musig(...)//*` key expressions are supported; the public keys must be xpubs aggregated without any further derivation. Schemes where each pubkey is derived prior to aggregation (for example descriptors similar to `musig(xpub1/<0;1>/*,xpub2/<0;1>/*,...)`) are not supported. + +## State minimization + +This section describes implementation details that allow to minimize the amount of statefor each MuSig2 signing session, allowing secure support for multiple parallel MuSig2 on embedded device with limited storage. + +### Introduction + +BIP-0327 discusses at length the necessity to keep some state during a signing session. However, a "signing session" in BIP-0327 only refers to the production of a single signature. + +In the typical signing flow of a wallet, it's more logical to consider a _session_ at the level of an entire transaction. All transaction inputs are likely obtained from the same [descriptor containing musig()](https://github.com/bitcoin/bips/pull/1540), with the signer producing the pubnonce/signature for all the inputs at once. + +Therefore, in the flow of BIP-0327, you would expect at least _one MuSig2 signing session per input_ to be active at the same time. In the context of hardware signing device support, that's somewhat problematic: it would require to persist state for an unbounded number of signing sessions, for example for a wallet that received a large number of small UTXOs. Persistent storage is often a scarce resource in embedded signing devices, and a naive approach would likely impose a maximum limit on the number of inputs of the transactions, depending on the hardware limitations. + +This document describes an approach that is compatible with and builds on top of BIP-0327 to define a _psbt-level session_ with only a small amount of state persisted on the device. Each psbt-level session allows to manage in parallel all the MuSig2 sessions involved in signing a transaction (typically, at least one for each input). Each psbt-level session only requires 64 bytes of storage for the entire transaction, regardless of the amount of inputs. + +### Signing flow with synthetic randomness + +#### Synthetic generation of BIP-0327 state + +This section presents the core idea, while the next section makes it more precise in the context of signing devices. + +In BIP-0327, the internal state that is kept by the signing device is essentially the *secnonce*, which in turn is computed from a random number _rand'_, and optionally from other parameters of _NonceGen_ which depend on the transaction being signed. + +The core idea for state minimization is to compute a global random `rand_root`; then, for the *i*-th input and for the *j*-th `musig()` key that the device is signing for in the [wallet policy](https://github.com/bitcoin/bips/pull/1389), one defines the *rand'* in _NonceGen_ as: + +$\qquad rand_{i,j} = SHA256(rand\_root || i || j)$ + +In the concatenation, a fixed-length encoding of $i$ and $j$ is used in order to avoid collisions. That is used as the *rand'* value in the *NonceGen* algorithm for that input/KEY pair. + +The *j* parameter allows to handle wallet policies that contain more than one `musig()` key expression involving the signing device. + +#### Signing flow in detail + +This section describes the handling of the psbt-level sessions, plugging on top of the default signing flow of BIP-0327. + +We assume that the signing device handles a single psbt-level session; this can be generalized to multiple parallel psbt-level sessions, where each session computes and stores a different `rand_root`. + +In the following, a _session_ always refers to the psbt-level signing session; it contains `rand_root`, and possibly any other auxiliary data that the device wishes to save while signing is in progress. + +The term *persistent memory* refers to secure storage that is not wiped out when the device is turned off. The term *volatile memory* refers to the working memory available while the device is involved in the signing process. In Ledger signing devices, the persistent storage is flash memory, and the volatile memory is the RAM of the app. Both are contained in the Secure Element. + +**Phase 1: pubnonce generation:** A PSBT is sent to the signing device, and it does not contain any pubnonce. +- If a session already exists, it is deleted from the persistent memory. +- A new session is created in volatile memory. +- The device produces a fresh random number $rand\_root$, and saves it in the current session. +- The device generates the randomness for the $i$-th input and for the $j$-th key as: $rand_{i,j} = SHA256(rand\_root || i || j)$. +- Compute each *(secnonce, pubnonce)* as per the `NonceGen` algorithm. +- At completion (after all the pubnonces are returned), the session secret $rand\_root$ is copied into the persistent memory. + +**Phase 2: partial signature generation:** A PSBT containing all the pubnonces is sent to the device. +- *A copy of the session is stored in the volatile memory, and the session is deleted from the persistent memory*. +- For each input/musig-key pair $(i, j)$: + - Recompute the pubnonce/secnonce pair using `NonceGen` with the synthetic randomness $rand_{i,j}$ as above. + - Verify that the pubnonce contained in the PSBT matches the one synthetically recomputed. + - Continue the signing flow as per BIP-0327, generating the partial signature. + +### Security considerations +#### State reuse avoidance +Storing the session in persistent memory only at the end of Phase 1, and deleting it before beginning Phase 2 simplifies auditing and making sure that there is no reuse of state across signing sessions. + +#### Security of synthetic randomness + +Generating $rand_{i, j}$ synthetically is not a problem, since the $rand\_root$ value is kept secret and never leaves the device. This ensures that all the values produced for different $i$ and $j$ not predictable for an attacker. + +#### Malleability of the PSBT +If the optional parameters are passed to the _NonceGen_ function, they will depend on the transaction data present in the PSBT. Therefore, there is no guarantee that they will be unchanged the next time the PSBT is provided. + +However, that does not constitute a security risk, as those parameters are only used as additional sources of entropy in _NonceGen_. A malicious software wallet can't affect the _secnonce_/_pubnonce_ pairs in any predictable way. Changing any of the parameters used in _NonceGen_ would cause a failure during Phase 2, as the recomputed _pubnonce_ would not match the one in the psbt. + +### Generalization to multiple PSBT signing sessions + +The approach described above assumes that no attempt to sign a PSBT containing for a wallet policy containing `musig()` keys is initiated while a session is already in progress. + +It is possible to generalize this to an arbitrary number of parallel signing sessions. Each session could be identified by a `psbt_session_id` computed by hashing together the transaction hashes, \ No newline at end of file diff --git a/src/commands.h b/src/commands.h index 63b3b4d10..aa1b3cd2d 100644 --- a/src/commands.h +++ b/src/commands.h @@ -11,3 +11,7 @@ typedef enum { GET_MASTER_FINGERPRINT = 0x05, SIGN_MESSAGE = 0x10, } command_e; + +// Tags used when yielding different objects with the YIELD client command. +#define CCMD_YIELD_MUSIG_PUBNONCE_TAG 0xffffffff +#define CCMD_YIELD_MUSIG_PARTIALSIGNATURE_TAG 0xfffffffe \ No newline at end of file diff --git a/src/common/psbt.h b/src/common/psbt.h index a566cc135..e076127df 100644 --- a/src/common/psbt.h +++ b/src/common/psbt.h @@ -3,55 +3,58 @@ // clang-format off enum PsbtGlobalType { - PSBT_GLOBAL_UNSIGNED_TX = 0x00, - PSBT_GLOBAL_XPUB = 0x01, - PSBT_GLOBAL_TX_VERSION = 0x02, - PSBT_GLOBAL_FALLBACK_LOCKTIME = 0x03, - PSBT_GLOBAL_INPUT_COUNT = 0x04, - PSBT_GLOBAL_OUTPUT_COUNT = 0x05, - PSBT_GLOBAL_TX_MODIFIABLE = 0x06, - PSBT_GLOBAL_SIGHASH_SINGLE_INPUTS = 0x07, - PSBT_GLOBAL_VERSION = 0xFB, - PSBT_GLOBAL_PROPRIETARY = 0xFC + PSBT_GLOBAL_UNSIGNED_TX = 0x00, + PSBT_GLOBAL_XPUB = 0x01, + PSBT_GLOBAL_TX_VERSION = 0x02, + PSBT_GLOBAL_FALLBACK_LOCKTIME = 0x03, + PSBT_GLOBAL_INPUT_COUNT = 0x04, + PSBT_GLOBAL_OUTPUT_COUNT = 0x05, + PSBT_GLOBAL_TX_MODIFIABLE = 0x06, + PSBT_GLOBAL_VERSION = 0xFB, + PSBT_GLOBAL_PROPRIETARY = 0xFC }; enum PsbtInputType { - PSBT_IN_NON_WITNESS_UTXO = 0x00, - PSBT_IN_WITNESS_UTXO = 0x01, - PSBT_IN_PARTIAL_SIG = 0x02, - PSBT_IN_SIGHASH_TYPE = 0x03, - PSBT_IN_REDEEM_SCRIPT = 0x04, - PSBT_IN_WITNESS_SCRIPT = 0x05, - PSBT_IN_BIP32_DERIVATION = 0x06, - PSBT_IN_FINAL_SCRIPTSIG = 0x07, - PSBT_IN_FINAL_SCRIPTWITNESS = 0x08, - PSBT_IN_POR_COMMITMENT = 0x09, - PSBT_IN_RIPEMD160 = 0x0A, - PSBT_IN_SHA256 = 0x0B, - PSBT_IN_HASH160 = 0x0C, - PSBT_IN_HASH256 = 0x0D, - PSBT_IN_PREVIOUS_TXID = 0x0E, - PSBT_IN_OUTPUT_INDEX = 0x0F, - PSBT_IN_SEQUENCE = 0x10, - PSBT_IN_REQUIRED_TIME_LOCKTIME = 0x11, - PSBT_IN_REQUIRED_HEIGHT_LOCKTIME = 0x12, - PSBT_IN_TAP_KEY_SIG = 0x13, - PSBT_IN_TAP_SCRIPT_SIG = 0x14, - PSBT_IN_TAP_LEAF_SCRIPT = 0x15, - PSBT_IN_TAP_BIP32_DERIVATION = 0x16, - PSBT_IN_TAP_INTERNAL_KEY = 0x17, - PSBT_IN_TAP_MERKLE_ROOT = 0x18, - PSBT_IN_PROPRIETARY = 0xFC + PSBT_IN_NON_WITNESS_UTXO = 0x00, + PSBT_IN_WITNESS_UTXO = 0x01, + PSBT_IN_PARTIAL_SIG = 0x02, + PSBT_IN_SIGHASH_TYPE = 0x03, + PSBT_IN_REDEEM_SCRIPT = 0x04, + PSBT_IN_WITNESS_SCRIPT = 0x05, + PSBT_IN_BIP32_DERIVATION = 0x06, + PSBT_IN_FINAL_SCRIPTSIG = 0x07, + PSBT_IN_FINAL_SCRIPTWITNESS = 0x08, + PSBT_IN_POR_COMMITMENT = 0x09, + PSBT_IN_RIPEMD160 = 0x0A, + PSBT_IN_SHA256 = 0x0B, + PSBT_IN_HASH160 = 0x0C, + PSBT_IN_HASH256 = 0x0D, + PSBT_IN_PREVIOUS_TXID = 0x0E, + PSBT_IN_OUTPUT_INDEX = 0x0F, + PSBT_IN_SEQUENCE = 0x10, + PSBT_IN_REQUIRED_TIME_LOCKTIME = 0x11, + PSBT_IN_REQUIRED_HEIGHT_LOCKTIME = 0x12, + PSBT_IN_TAP_KEY_SIG = 0x13, + PSBT_IN_TAP_SCRIPT_SIG = 0x14, + PSBT_IN_TAP_LEAF_SCRIPT = 0x15, + PSBT_IN_TAP_BIP32_DERIVATION = 0x16, + PSBT_IN_TAP_INTERNAL_KEY = 0x17, + PSBT_IN_TAP_MERKLE_ROOT = 0x18, + PSBT_IN_MUSIG2_PARTICIPANT_PUBKEYS = 0x1A, + PSBT_IN_MUSIG2_PUB_NONCE = 0x1B, + PSBT_IN_MUSIG2_PARTIAL_SIG = 0x1C, + PSBT_IN_PROPRIETARY = 0xFC }; enum PsbtOutputType { - PSBT_OUT_REDEEM_SCRIPT = 0x00, - PSBT_OUT_WITNESS_SCRIPT = 0x01, - PSBT_OUT_BIP32_DERIVATION = 0x02, - PSBT_OUT_AMOUNT = 0x03, - PSBT_OUT_SCRIPT = 0x04, - PSBT_OUT_TAP_INTERNAL_KEY = 0x05, - PSBT_OUT_TAP_TREE = 0x06, - PSBT_OUT_TAP_BIP32_DERIVATION = 0x07, - PSBT_OUT_PROPRIETARY = 0xFC + PSBT_OUT_REDEEM_SCRIPT = 0x00, + PSBT_OUT_WITNESS_SCRIPT = 0x01, + PSBT_OUT_BIP32_DERIVATION = 0x02, + PSBT_OUT_AMOUNT = 0x03, + PSBT_OUT_SCRIPT = 0x04, + PSBT_OUT_TAP_INTERNAL_KEY = 0x05, + PSBT_OUT_TAP_TREE = 0x06, + PSBT_OUT_TAP_BIP32_DERIVATION = 0x07, + PSBT_OUT_MUSIG2_PARTICIPANT_PUBKEYS = 0x08, + PSBT_OUT_PROPRIETARY = 0xFC }; \ No newline at end of file diff --git a/src/common/wallet.c b/src/common/wallet.c index 5821ff011..d67a60267 100644 --- a/src/common/wallet.c +++ b/src/common/wallet.c @@ -424,18 +424,119 @@ int parse_policy_map_key_info(buffer_t *buffer, policy_map_key_info_t *out, int return 0; } -static int parse_placeholder(buffer_t *in_buf, int version, policy_node_key_placeholder_t *out) { +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wcomment" +// The compiler doesn't like /** inside a block comment, so we disable this warning temporarily. + +/** + * Parses a key expression, in one of the following forms: + * - Single key index: + * - @IDX/** + * - @IDX//* + * - MuSig2 aggregate key (only if is_taproot is true): + * - musig(@IDX,@IDX,...,@IDX)/** + * - musig(@IDX,@IDX,...,@IDX)//* + * where IDX is a key index. + */ +#pragma GCC diagnostic pop +static int parse_keyexpr(buffer_t *in_buf, + int version, + policy_node_keyexpr_t *out, + bool is_taproot, + buffer_t *out_buf) { char c; - if (!buffer_read_u8(in_buf, (uint8_t *) &c) || c != '@') { - return WITH_ERROR(-1, "Expected key placeholder starting with '@'"); + if (!buffer_read_u8(in_buf, (uint8_t *) &c)) { + return WITH_ERROR(-1, "Expected key expression"); } - uint32_t k; - if (parse_unsigned_decimal(in_buf, &k) == -1 || k > INT16_MAX) { - return WITH_ERROR(-1, "The key index in a placeholder must be at most 32767"); - } + if (c == '@') { + out->type = KEY_EXPRESSION_NORMAL; + + uint32_t k; + if (parse_unsigned_decimal(in_buf, &k) == -1 || k > INT16_MAX) { + return WITH_ERROR(-1, "The key index in a placeholder must be at most 32767"); + } + + out->k.key_index = (int16_t) k; + } else if (c == 'm') { + // parse a musig(key1,...,keyn) expression, where each key is a key expression + if (!consume_characters(in_buf, "usig(", 5)) { + return WITH_ERROR(-1, "Expected musig key expression"); + } + + if (!is_taproot) { + return WITH_ERROR(-1, "musig is only allows in taproot"); + } + + out->type = KEY_EXPRESSION_MUSIG; + + if (version != WALLET_POLICY_VERSION_V2) { + return WITH_ERROR(-1, "musig key expressions are only supported with version number 2"); + } + + uint16_t key_placeholders[MAX_PUBKEYS_PER_MUSIG]; + int n_musig_keys = 0; + + // parse comma-separated list of @NUM + while (true) { + if (!buffer_read_u8(in_buf, (uint8_t *) &c) || c != '@') { + return WITH_ERROR(-1, "Expected key placeholder starting with '@'"); + } + + uint32_t k; + if (parse_unsigned_decimal(in_buf, &k) == -1 || k > INT16_MAX) { + return WITH_ERROR(-1, "The key index in a placeholder must be at most 32767"); + } - out->key_index = (int16_t) k; + if (n_musig_keys >= MAX_PUBKEYS_PER_MUSIG) { + return WITH_ERROR(-1, "Too many keys in musig"); + } + + key_placeholders[n_musig_keys] = (uint16_t) k; + ++n_musig_keys; + + // the next character must be "," if there are more keys, or ')' otherwise + if (!buffer_read_u8(in_buf, (uint8_t *) &c)) { + return WITH_ERROR(-1, "Expression terminated prematurely"); + } + + if (c == ')') { + break; + } else if (c != ',') { + return WITH_ERROR(-1, "Invalid character in musig; expected ',' or ')'"); + } + } + + if (n_musig_keys < 2) { + return WITH_ERROR(-1, "musig must have at least 2 key indexes"); + } + if (n_musig_keys > MAX_PUBKEYS_PER_MUSIG) { + return WITH_ERROR(-1, "Too many keys in musig"); + } + + // allocate musig structures + + musig_aggr_key_info_t *musig_info = + (musig_aggr_key_info_t *) buffer_alloc(out_buf, sizeof(musig_info), true); + + if (musig_info == NULL) { + return WITH_ERROR(-1, "Out of memory"); + } + + uint16_t *key_indexes = + (uint16_t *) buffer_alloc(out_buf, sizeof(uint16_t) * n_musig_keys, true); + if (key_indexes == NULL) { + return WITH_ERROR(-1, "Out of memory"); + } + memcpy(key_indexes, key_placeholders, sizeof(uint16_t) * n_musig_keys); + + musig_info->n = n_musig_keys; + i_uint16(&musig_info->key_indexes, key_indexes); + + i_musig_aggr_key_info(&out->m.musig_info, musig_info); + } else { + return WITH_ERROR(-1, "Expected key placeholder starting with '@', or musig"); + } if (version == WALLET_POLICY_VERSION_V1) { // default values for compatibility with the new code @@ -448,12 +549,12 @@ static int parse_placeholder(buffer_t *in_buf, int version, policy_node_key_plac || !buffer_peek(in_buf, &next_character) // we must be able to read the next character || !(next_character == '*' || next_character == '<') // and it must be '*' or '<' ) { - return WITH_ERROR(-1, "Expected /** or //* in key placeholder"); + return WITH_ERROR(-1, "Expected /** or //* in key expression"); } if (next_character == '*') { if (!consume_characters(in_buf, "**", 2)) { - return WITH_ERROR(-1, "Expected /** or //* in key placeholder"); + return WITH_ERROR(-1, "Expected /** or //* in key expression"); } out->num_first = 0; out->num_second = 1; @@ -463,18 +564,18 @@ static int parse_placeholder(buffer_t *in_buf, int version, policy_node_key_plac out->num_first > 0x80000000u) { return WITH_ERROR( -1, - "Expected /** or //* in key placeholder, with unhardened M and N"); + "Expected /** or //* in key expression, with unhardened M and N"); } if (!consume_character(in_buf, ';')) { - return WITH_ERROR(-1, "Expected /** or //* in key placeholder"); + return WITH_ERROR(-1, "Expected /** or //* in key expression"); } if (parse_unsigned_decimal(in_buf, &out->num_second) == -1 || out->num_second > 0x80000000u) { return WITH_ERROR( -1, - "Expected /** or //* in key placeholder, with unhardened M and N"); + "Expected /** or //* in key expression, with unhardened M and N"); } if (out->num_first == out->num_second) { @@ -482,7 +583,7 @@ static int parse_placeholder(buffer_t *in_buf, int version, policy_node_key_plac } if (!consume_characters(in_buf, ">/*", 3)) { - return WITH_ERROR(-1, "Expected /** or //* in key placeholder"); + return WITH_ERROR(-1, "Expected /** or //* in key expression"); } } } else { @@ -1378,13 +1479,13 @@ static int parse_script(buffer_t *in_buf, return WITH_ERROR(-1, "Out of memory"); } - policy_node_key_placeholder_t *key_placeholder = - buffer_alloc(out_buf, sizeof(policy_node_key_placeholder_t), true); + policy_node_keyexpr_t *key_expr = + buffer_alloc(out_buf, sizeof(policy_node_keyexpr_t), true); - if (key_placeholder == NULL) { + if (key_expr == NULL) { return WITH_ERROR(-1, "Out of memory"); } - i_policy_node_key_placeholder(&node->key_placeholder, key_placeholder); + i_policy_node_keyexpr(&node->key, key_expr); if (token == TOKEN_WPKH) { if (depth > 0 && ((context_flags & CONTEXT_WITHIN_SH) == 0)) { @@ -1396,8 +1497,9 @@ static int parse_script(buffer_t *in_buf, node->base.type = token; - if (0 > parse_placeholder(in_buf, version, key_placeholder)) { - return WITH_ERROR(-1, "Couldn't parse key placeholder"); + bool is_taproot = (context_flags & CONTEXT_WITHIN_TR) != 0; + if (0 > parse_keyexpr(in_buf, version, key_expr, is_taproot, out_buf)) { + return WITH_ERROR(-1, "Couldn't parse key expression"); } if (token == TOKEN_WPKH) { @@ -1459,15 +1561,15 @@ static int parse_script(buffer_t *in_buf, return WITH_ERROR(-1, "Out of memory"); } - policy_node_key_placeholder_t *key_placeholder = - buffer_alloc(out_buf, sizeof(policy_node_key_placeholder_t), true); - if (key_placeholder == NULL) { + policy_node_keyexpr_t *key_expr = + buffer_alloc(out_buf, sizeof(policy_node_keyexpr_t), true); + if (key_expr == NULL) { return WITH_ERROR(-1, "Out of memory"); } - i_policy_node_key_placeholder(&node->key_placeholder, key_placeholder); + i_policy_node_keyexpr(&node->key, key_expr); - if (0 > parse_placeholder(in_buf, version, key_placeholder)) { - return WITH_ERROR(-1, "Couldn't parse key placeholder"); + if (0 > parse_keyexpr(in_buf, version, key_expr, true, out_buf)) { + return WITH_ERROR(-1, "Couldn't parse key expression"); } uint8_t c; @@ -1543,7 +1645,8 @@ static int parse_script(buffer_t *in_buf, return WITH_ERROR(-1, "Out of memory"); } - if ((context_flags & CONTEXT_WITHIN_TR) != 0) { + bool is_taproot = (context_flags & CONTEXT_WITHIN_TR) != 0; + if (is_taproot) { if (token != TOKEN_MULTI_A && token != TOKEN_SORTEDMULTI_A) { return WITH_ERROR( -1, @@ -1581,7 +1684,7 @@ static int parse_script(buffer_t *in_buf, // We allocate the array of key indices at the current position in the output buffer // (on success) buffer_alloc(out_buf, 0, true); // ensure alignment of current pointer - i_policy_node_key_placeholder(&node->key_placeholders, buffer_get_cur(out_buf)); + i_policy_node_keyexpr(&node->keys, buffer_get_cur(out_buf)); node->n = 0; while (true) { @@ -1596,18 +1699,17 @@ static int parse_script(buffer_t *in_buf, return WITH_ERROR(-1, "Expected ','"); } - policy_node_key_placeholder_t *key_placeholder = - (policy_node_key_placeholder_t *) buffer_alloc( - out_buf, - sizeof(policy_node_key_placeholder_t), - true); // we align this pointer, as there's padding in an array of - // structures - if (key_placeholder == NULL) { + policy_node_keyexpr_t *key_expr = (policy_node_keyexpr_t *) buffer_alloc( + out_buf, + sizeof(policy_node_keyexpr_t), + true); // we align this pointer, as there's padding in an array of + // structures + if (key_expr == NULL) { return WITH_ERROR(-1, "Out of memory"); } - if (0 > parse_placeholder(in_buf, version, key_placeholder)) { - return WITH_ERROR(-1, "Error parsing key placeholder"); + if (0 > parse_keyexpr(in_buf, version, key_expr, is_taproot, out_buf)) { + return WITH_ERROR(-1, "Error parsing key expression"); } ++node->n; diff --git a/src/common/wallet.h b/src/common/wallet.h index 5435292f7..901acb387 100644 --- a/src/common/wallet.h +++ b/src/common/wallet.h @@ -19,6 +19,14 @@ // bitcoin-core supports up to 20, but we limit to 16 as bigger pushes require special handling. #define MAX_PUBKEYS_PER_MULTISIG 16 +// The maximum number of keys supported in a musig() key expression +// It is basically unlimited in theory, but we need to set a maximum limit. +#ifdef TARGET_NANOS +#define MAX_PUBKEYS_PER_MUSIG 3 +#else +#define MAX_PUBKEYS_PER_MUSIG MAX_PUBKEYS_PER_MULTISIG +#endif + #define WALLET_POLICY_VERSION_V1 1 // the legacy version of the first release #define WALLET_POLICY_VERSION_V2 2 // the current full version @@ -280,29 +288,54 @@ typedef struct policy_node_ext_info_s { unsigned int x : 1; // the last opcode is not EQUAL, CHECKSIG, or CHECKMULTISIG } policy_node_ext_info_t; +DEFINE_REL_PTR(uint16, uint16_t) + +typedef struct { + int16_t n; // number of key indexes + rptr_uint16_t key_indexes; // pointer to an array of exactly n key indexes +} musig_aggr_key_info_t; + +DEFINE_REL_PTR(musig_aggr_key_info, musig_aggr_key_info_t) + +typedef enum { + KEY_EXPRESSION_NORMAL = 0, // a key expression with a single key + KEY_EXPRESSION_MUSIG = 1 // a key expression containing a musig() +} KeyExpressionType; + #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wcomment" // The compiler doesn't like /** inside a block comment, so we disable this warning temporarily. -/** Structure representing a key placeholder. +/** Structure representing a key expression. * In V1, it's the index of a key expression in the key informations array, which includes the final * / ** step. * In V2, it's the index of a key expression in the key informations array, plus the two * numbers a, b in the //* derivation steps; here, the xpubs in the key informations * array don't have extra derivation steps. + * In V2, musig() key expressions are also represented in this struct. */ #pragma GCC diagnostic pop + // 12 bytes typedef struct { // the following fields are only used in V2 uint32_t num_first; // NUM_a of //* uint32_t num_second; // NUM_b of //* - // common between V1 and V2 - int16_t key_index; // index of the key -} policy_node_key_placeholder_t; + KeyExpressionType type; + union { + // type == 0 + struct { + int16_t key_index; // index of the key (common between V1 and V2) + } k; + // type == 1 + struct { + rptr_musig_aggr_key_info_t musig_info; + } m; + }; +} policy_node_keyexpr_t; -DEFINE_REL_PTR(policy_node_key_placeholder, policy_node_key_placeholder_t) +DEFINE_REL_PTR(policy_node_keyexpr, policy_node_keyexpr_t) // 4 bytes typedef struct { @@ -333,7 +366,7 @@ typedef policy_node_with_script3_t policy_node_with_scripts_t; // 4 bytes typedef struct { struct policy_node_s base; - rptr_policy_node_key_placeholder_t key_placeholder; + rptr_policy_node_keyexpr_t key; } policy_node_with_key_t; // 8 bytes @@ -344,11 +377,10 @@ typedef struct { // 12 bytes typedef struct { - struct policy_node_s base; // type is TOKEN_MULTI or TOKEN_SORTEDMULTI - int16_t k; // threshold - int16_t n; // number of keys - rptr_policy_node_key_placeholder_t - key_placeholders; // pointer to array of exactly n key placeholders + struct policy_node_s base; // type is TOKEN_MULTI or TOKEN_SORTEDMULTI + int16_t k; // threshold + int16_t n; // number of keys + rptr_policy_node_keyexpr_t keys; // pointer to array of exactly n key placeholders } policy_node_multisig_t; // 8 bytes @@ -398,7 +430,7 @@ typedef struct policy_node_tree_s { typedef struct { struct policy_node_s base; - rptr_policy_node_key_placeholder_t key_placeholder; + rptr_policy_node_keyexpr_t key; rptr_policy_node_tree_t tree; // NULL if tr(KP) } policy_node_tr_t; diff --git a/src/crypto.c b/src/crypto.c index 37486988e..a540b0183 100644 --- a/src/crypto.c +++ b/src/crypto.c @@ -39,44 +39,12 @@ #include "crypto.h" -/** - * Generator for secp256k1, value 'g' defined in "Standards for Efficient Cryptography" - * (SEC2) 2.7.1. - */ -// clang-format off -static const uint8_t secp256k1_generator[] = { - 0x04, - 0x79, 0xBE, 0x66, 0x7E, 0xF9, 0xDC, 0xBB, 0xAC, 0x55, 0xA0, 0x62, 0x95, 0xCE, 0x87, 0x0B, 0x07, - 0x02, 0x9B, 0xFC, 0xDB, 0x2D, 0xCE, 0x28, 0xD9, 0x59, 0xF2, 0x81, 0x5B, 0x16, 0xF8, 0x17, 0x98, - 0x48, 0x3A, 0xDA, 0x77, 0x26, 0xA3, 0xC4, 0x65, 0x5D, 0xA4, 0xFB, 0xFC, 0x0E, 0x11, 0x08, 0xA8, - 0xFD, 0x17, 0xB4, 0x48, 0xA6, 0x85, 0x54, 0x19, 0x9C, 0x47, 0xD0, 0x8F, 0xFB, 0x10, 0xD4, 0xB8}; -// clang-format on - -/** - * Modulo for secp256k1 - */ -static const uint8_t secp256k1_p[] = { - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, 0xff, 0xff, 0xfc, 0x2f}; - -/** - * Curve order for secp256k1 - */ -static const uint8_t secp256k1_n[] = { - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, - 0xba, 0xae, 0xdc, 0xe6, 0xaf, 0x48, 0xa0, 0x3b, 0xbf, 0xd2, 0x5e, 0x8c, 0xd0, 0x36, 0x41, 0x41}; - -/** - * (p + 1)/4, used to calculate square roots in secp256k1 - */ -static const uint8_t secp256k1_sqr_exponent[] = { - 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xbf, 0xff, 0xff, 0x0c}; +#include "secp256k1.h" /* BIP0341 tags for computing the tagged hashes when tweaking public keys */ -static const uint8_t BIP0341_taptweak_tag[] = {'T', 'a', 'p', 'T', 'w', 'e', 'a', 'k'}; -static const uint8_t BIP0341_tapbranch_tag[] = {'T', 'a', 'p', 'B', 'r', 'a', 'n', 'c', 'h'}; -static const uint8_t BIP0341_tapleaf_tag[] = {'T', 'a', 'p', 'L', 'e', 'a', 'f'}; +const uint8_t BIP0341_taptweak_tag[8] = {'T', 'a', 'p', 'T', 'w', 'e', 'a', 'k'}; +const uint8_t BIP0341_tapbranch_tag[9] = {'T', 'a', 'p', 'B', 'r', 'a', 'n', 'c', 'h'}; +const uint8_t BIP0341_tapleaf_tag[7] = {'T', 'a', 'p', 'L', 'e', 'a', 'f'}; /** * Gets the point on the SECP256K1 that corresponds to kG, where G is the curve's generator point. @@ -90,7 +58,8 @@ static int secp256k1_point(const uint8_t k[static 32], uint8_t out[static 65]) { int bip32_CKDpub(const serialized_extended_pubkey_t *parent, uint32_t index, - serialized_extended_pubkey_t *child) { + serialized_extended_pubkey_t *child, + uint8_t *tweak) { PRINT_STACK_POINTER(); if (index >= BIP32_FIRST_HARDENED_CHILD) { @@ -115,6 +84,10 @@ int bip32_CKDpub(const serialized_extended_pubkey_t *parent, uint8_t *I_L = &I[0]; uint8_t *I_R = &I[32]; + if (tweak != NULL) { + memcpy(tweak, I_L, 32); + } + // fail if I_L is not smaller than the group order n, but the probability is < 1/2^128 int diff; if (CX_OK != cx_math_cmp_no_throw(I_L, secp256k1_n, 32, &diff) || diff >= 0) { @@ -432,7 +405,7 @@ void crypto_tr_tapleaf_hash_init(cx_sha256_t *hash_context) { crypto_tr_tagged_hash_init(hash_context, BIP0341_tapleaf_tag, sizeof(BIP0341_tapleaf_tag)); } -static int crypto_tr_lift_x(const uint8_t x[static 32], uint8_t out[static 65]) { +int crypto_tr_lift_x(const uint8_t x[static 32], uint8_t out[static 65]) { // save memory by reusing output buffer for intermediate results uint8_t *y = out + 1 + 32; // we use the memory for the x-coordinate of the output as a temporary variable @@ -471,13 +444,13 @@ static int crypto_tr_lift_x(const uint8_t x[static 32], uint8_t out[static 65]) // Computes a tagged hash according to BIP-340. // If data2_len > 0, then data2 must be non-NULL and the `data` and `data2` arrays are concatenated. -static void crypto_tr_tagged_hash(const uint8_t *tag, - uint16_t tag_len, - const uint8_t *data, - uint16_t data_len, - const uint8_t *data2, - uint16_t data2_len, - uint8_t out[static CX_SHA256_SIZE]) { +void crypto_tr_tagged_hash(const uint8_t *tag, + uint16_t tag_len, + const uint8_t *data, + uint16_t data_len, + const uint8_t *data2, + uint16_t data2_len, + uint8_t out[static CX_SHA256_SIZE]) { // First compute hashtag, reuse out buffer for that cx_sha256_hash(tag, tag_len, out); diff --git a/src/crypto.h b/src/crypto.h index cb8394bb5..466824caf 100644 --- a/src/crypto.h +++ b/src/crypto.h @@ -36,18 +36,22 @@ typedef struct { * * @param[in] parent * Pointer to the extended serialized pubkey of the parent. - * @param[out] index + * @param[in] index * Index of the child to derive. It MUST be not hardened, that is, strictly less than 0x80000000. * @param[out] child * Pointer to the output struct for the child's serialized pubkey. It can equal parent, which in * that case is overwritten. + * @param[out] tweak + * If not NULL, pointer to a 32-byte array that will receive the 32-byte tweak used during the + * child key derivation. * * @return 0 if success, a negative number on failure. * */ int bip32_CKDpub(const serialized_extended_pubkey_t *parent, uint32_t index, - serialized_extended_pubkey_t *child); + serialized_extended_pubkey_t *child, + uint8_t *tweak); /** * Convenience wrapper for cx_hash_no_throw to add some data to an initialized hash context. @@ -331,6 +335,11 @@ int crypto_ecdsa_sign_sha256_hash_with_key(const uint32_t bip32_path[], uint8_t out[static MAX_DER_SIG_LEN], uint32_t *info); +// Constants defined in BIP-0341 +extern const uint8_t BIP0341_taptweak_tag[8]; +extern const uint8_t BIP0341_tapbranch_tag[9]; +extern const uint8_t BIP0341_tapleaf_tag[7]; + /** * Initializes the "tagged" SHA256 hash with the given tag, as defined by BIP-0340. * @@ -343,6 +352,43 @@ int crypto_ecdsa_sign_sha256_hash_with_key(const uint32_t bip32_path[], */ void crypto_tr_tagged_hash_init(cx_sha256_t *hash_context, const uint8_t *tag, uint16_t tag_len); +/** + * Implementation of the lift_x procedure as defined by BIP-0340. + * + * @param[in] x + * Pointer to a 32-byte array. + * @param[out] out + * Pointer to an array that will received the output as an uncompressed 65-bytes pubkey. + */ +int crypto_tr_lift_x(const uint8_t x[static 32], uint8_t out[static 65]); + +/** + * A tagged hash as defined in BIP-0340. + * + * @param[in] tag + * Pointer to an array containing the tag of the tagged hash. + * @param[in] tag_len + * Length of the tag. + * @param[in] data + * Pointer to an array of data. + * @param[in] data_len + * Length of the array pointed by `data`. + * @param[in] data2 + * If NULL, ignored. If not null, a pointer to an array of data; the tagged hash for the + * concatenation of `data` and `data2` is computed. + * @param[in] data2_len + * If `data2` is NULL, ignored. Otherwise, the length the array pointed by `data2`. + * @param[out] out + * Pointer to a 32-byte array that will receive the result. + */ +void crypto_tr_tagged_hash(const uint8_t *tag, + uint16_t tag_len, + const uint8_t *data, + uint16_t data_len, + const uint8_t *data2, + uint16_t data2_len, + uint8_t out[static CX_SHA256_SIZE]); + /** * Initializes the "tagged" SHA256 hash with tag "TapLeaf", used for tapscript leaves. * diff --git a/src/handler/lib/policy.c b/src/handler/lib/policy.c index c7b7800d1..9f56f0503 100644 --- a/src/handler/lib/policy.c +++ b/src/handler/lib/policy.c @@ -5,6 +5,7 @@ #include "../lib/get_merkle_leaf_element.h" #include "../lib/get_preimage.h" #include "../../crypto.h" +#include "../../musig/musig.h" #include "../../common/base58.h" #include "../../common/bitvector.h" #include "../../common/read.h" @@ -419,7 +420,7 @@ execute_processor(policy_parser_state_t *state, policy_parser_processor_t proc, // convenience function, split from get_derived_pubkey only to improve stack usage // returns -1 on error, 0 if the returned key info has no wildcard (**), 1 if it has the wildcard -__attribute__((noinline, warn_unused_result)) static int get_extended_pubkey( +__attribute__((noinline, warn_unused_result)) int get_extended_pubkey( dispatcher_context_t *dispatcher_context, const wallet_derivation_info_t *wdi, int key_index, @@ -456,23 +457,61 @@ __attribute__((noinline, warn_unused_result)) static int get_extended_pubkey( __attribute__((warn_unused_result)) static int get_derived_pubkey( dispatcher_context_t *dispatcher_context, const wallet_derivation_info_t *wdi, - const policy_node_key_placeholder_t *key_placeholder, + const policy_node_keyexpr_t *key_expr, uint8_t out[static 33]) { PRINT_STACK_POINTER(); serialized_extended_pubkey_t ext_pubkey; - int ret = get_extended_pubkey(dispatcher_context, wdi, key_placeholder->key_index, &ext_pubkey); - if (ret < 0) { - return -1; + if (key_expr->type == KEY_EXPRESSION_NORMAL) { + if (0 > get_extended_pubkey(dispatcher_context, wdi, key_expr->k.key_index, &ext_pubkey)) { + return -1; + } + } else if (key_expr->type == KEY_EXPRESSION_MUSIG) { + const musig_aggr_key_info_t *musig_info = r_musig_aggr_key_info(&key_expr->m.musig_info); + const uint16_t *key_indexes = r_uint16(&musig_info->key_indexes); + plain_pk_t keys[MAX_PUBKEYS_PER_MUSIG]; + for (int i = 0; i < musig_info->n; i++) { + // we use ext_pubkey as a temporary variable; will overwrite later + if (0 > get_extended_pubkey(dispatcher_context, wdi, key_indexes[i], &ext_pubkey)) { + return -1; + } + memcpy(keys[i], ext_pubkey.compressed_pubkey, sizeof(ext_pubkey.compressed_pubkey)); + } + + // sort the keys in ascending order using bubble sort + for (int i = 0; i < musig_info->n; i++) { + for (int j = 0; j < musig_info->n - 1; j++) { + if (memcmp(keys[j], keys[j + 1], sizeof(plain_pk_t)) > 0) { + uint8_t tmp[sizeof(plain_pk_t)]; + memcpy(tmp, keys[j], sizeof(plain_pk_t)); + memcpy(keys[j], keys[j + 1], sizeof(plain_pk_t)); + memcpy(keys[j + 1], tmp, sizeof(plain_pk_t)); + } + } + } + + musig_keyagg_context_t musig_ctx; + musig_key_agg(keys, musig_info->n, &musig_ctx); + + // compute the aggregated extended pubkey + memset(&ext_pubkey, 0, sizeof(ext_pubkey)); + write_u32_be(ext_pubkey.version, 0, BIP32_PUBKEY_VERSION); + + ext_pubkey.compressed_pubkey[0] = (musig_ctx.Q.y[31] % 2 == 0) ? 2 : 3; + memcpy(&ext_pubkey.compressed_pubkey[1], musig_ctx.Q.x, sizeof(musig_ctx.Q.x)); + memcpy(&ext_pubkey.chain_code, BIP_MUSIG_CHAINCODE, sizeof(BIP_MUSIG_CHAINCODE)); + } else { + LEDGER_ASSERT(false, "Unreachable code"); } // we derive the // child of this pubkey // we reuse the same memory of ext_pubkey bip32_CKDpub(&ext_pubkey, - wdi->change ? key_placeholder->num_second : key_placeholder->num_first, - &ext_pubkey); - bip32_CKDpub(&ext_pubkey, wdi->address_index, &ext_pubkey); + wdi->change ? key_expr->num_second : key_expr->num_first, + &ext_pubkey, + NULL); + bip32_CKDpub(&ext_pubkey, wdi->address_index, &ext_pubkey, NULL); memcpy(out, ext_pubkey.compressed_pubkey, 33); @@ -569,11 +608,10 @@ __attribute__((warn_unused_result)) static int process_generic_node(policy_parse const policy_node_with_key_t *policy = (const policy_node_with_key_t *) node->policy_node; uint8_t compressed_pubkey[33]; - if (-1 == - get_derived_pubkey(state->dispatcher_context, - state->wdi, - r_policy_node_key_placeholder(&policy->key_placeholder), - compressed_pubkey)) { + if (-1 == get_derived_pubkey(state->dispatcher_context, + state->wdi, + r_policy_node_keyexpr(&policy->key), + compressed_pubkey)) { return -1; } @@ -591,11 +629,10 @@ __attribute__((warn_unused_result)) static int process_generic_node(policy_parse const policy_node_with_key_t *policy = (const policy_node_with_key_t *) node->policy_node; uint8_t compressed_pubkey[33]; - if (-1 == - get_derived_pubkey(state->dispatcher_context, - state->wdi, - r_policy_node_key_placeholder(&policy->key_placeholder), - compressed_pubkey)) { + if (-1 == get_derived_pubkey(state->dispatcher_context, + state->wdi, + r_policy_node_keyexpr(&policy->key), + compressed_pubkey)) { return -1; } if (!state->is_taproot) { @@ -684,7 +721,7 @@ __attribute__((warn_unused_result)) static int process_pkh_wpkh_node(policy_pars if (-1 == get_derived_pubkey(state->dispatcher_context, state->wdi, - r_policy_node_key_placeholder(&policy->key_placeholder), + r_policy_node_keyexpr(&policy->key), compressed_pubkey)) { return -1; } else if (policy->base.type == TOKEN_PKH) { @@ -811,11 +848,10 @@ __attribute__((warn_unused_result)) static int process_multi_sortedmulti_node( uint8_t compressed_pubkey[33]; if (policy->base.type == TOKEN_MULTI) { - if (-1 == - get_derived_pubkey(state->dispatcher_context, - state->wdi, - &r_policy_node_key_placeholder(&policy->key_placeholders)[i], - compressed_pubkey)) { + if (-1 == get_derived_pubkey(state->dispatcher_context, + state->wdi, + &r_policy_node_keyexpr(&policy->keys)[i], + compressed_pubkey)) { return -1; } } else { @@ -837,11 +873,10 @@ __attribute__((warn_unused_result)) static int process_multi_sortedmulti_node( for (int j = 0; j < policy->n; j++) { if (!bitvector_get(used, j)) { uint8_t cur_pubkey[33]; - if (-1 == get_derived_pubkey( - state->dispatcher_context, - state->wdi, - &r_policy_node_key_placeholder(&policy->key_placeholders)[j], - cur_pubkey)) { + if (-1 == get_derived_pubkey(state->dispatcher_context, + state->wdi, + &r_policy_node_keyexpr(&policy->keys)[j], + cur_pubkey)) { return -1; } @@ -889,11 +924,10 @@ __attribute__((warn_unused_result)) static int process_multi_a_sortedmulti_a_nod uint8_t compressed_pubkey[33]; if (policy->base.type == TOKEN_MULTI_A) { - if (-1 == - get_derived_pubkey(state->dispatcher_context, - state->wdi, - &r_policy_node_key_placeholder(&policy->key_placeholders)[i], - compressed_pubkey)) { + if (-1 == get_derived_pubkey(state->dispatcher_context, + state->wdi, + &r_policy_node_keyexpr(&policy->keys)[i], + compressed_pubkey)) { return -1; } } else { @@ -905,11 +939,10 @@ __attribute__((warn_unused_result)) static int process_multi_a_sortedmulti_a_nod for (int j = 0; j < policy->n; j++) { if (!bitvector_get(used, j)) { uint8_t cur_pubkey[33]; - if (-1 == get_derived_pubkey( - state->dispatcher_context, - state->wdi, - &r_policy_node_key_placeholder(&policy->key_placeholders)[j], - cur_pubkey)) { + if (-1 == get_derived_pubkey(state->dispatcher_context, + state->wdi, + &r_policy_node_keyexpr(&policy->keys)[j], + cur_pubkey)) { return -1; } @@ -1018,7 +1051,7 @@ int get_wallet_script(dispatcher_context_t *dispatcher_context, policy_node_with_key_t *pkh_policy = (policy_node_with_key_t *) policy; if (0 > get_derived_pubkey(dispatcher_context, wdi, - r_policy_node_key_placeholder(&pkh_policy->key_placeholder), + r_policy_node_keyexpr(&pkh_policy->key), compressed_pubkey)) { return -1; } @@ -1037,7 +1070,7 @@ int get_wallet_script(dispatcher_context_t *dispatcher_context, policy_node_with_key_t *wpkh_policy = (policy_node_with_key_t *) policy; if (0 > get_derived_pubkey(dispatcher_context, wdi, - r_policy_node_key_placeholder(&wpkh_policy->key_placeholder), + r_policy_node_keyexpr(&wpkh_policy->key), compressed_pubkey)) { return -1; } @@ -1116,7 +1149,7 @@ int get_wallet_script(dispatcher_context_t *dispatcher_context, if (0 > get_derived_pubkey(dispatcher_context, wdi, - r_policy_node_key_placeholder(&tr_policy->key_placeholder), + r_policy_node_keyexpr(&tr_policy->key), compressed_pubkey)) { return -1; } @@ -1344,17 +1377,17 @@ __attribute__((noinline)) int get_wallet_internal_script_hash( // For a standard descriptor template, return the corresponding BIP44 purpose // Otherwise, returns -1. static int get_bip44_purpose(const policy_node_t *descriptor_template) { - const policy_node_key_placeholder_t *kp = NULL; + const policy_node_keyexpr_t *kp = NULL; int purpose = -1; switch (descriptor_template->type) { case TOKEN_PKH: - kp = r_policy_node_key_placeholder( - &((const policy_node_with_key_t *) descriptor_template)->key_placeholder); + kp = + r_policy_node_keyexpr(&((const policy_node_with_key_t *) descriptor_template)->key); purpose = 44; // legacy break; case TOKEN_WPKH: - kp = r_policy_node_key_placeholder( - &((const policy_node_with_key_t *) descriptor_template)->key_placeholder); + kp = + r_policy_node_keyexpr(&((const policy_node_with_key_t *) descriptor_template)->key); purpose = 84; // native segwit break; case TOKEN_SH: { @@ -1364,8 +1397,7 @@ static int get_bip44_purpose(const policy_node_t *descriptor_template) { return -1; } - kp = r_policy_node_key_placeholder( - &((const policy_node_with_key_t *) inner)->key_placeholder); + kp = r_policy_node_keyexpr(&((const policy_node_with_key_t *) inner)->key); purpose = 49; // nested segwit break; } @@ -1375,8 +1407,7 @@ static int get_bip44_purpose(const policy_node_t *descriptor_template) { return -1; } - kp = r_policy_node_key_placeholder( - &((const policy_node_tr_t *) descriptor_template)->key_placeholder); + kp = r_policy_node_keyexpr(&((const policy_node_tr_t *) descriptor_template)->key); purpose = 86; // standard single-key P2TR break; } @@ -1384,7 +1415,12 @@ static int get_bip44_purpose(const policy_node_t *descriptor_template) { return -1; } - if (kp->key_index != 0 || kp->num_first != 0 || kp->num_second != 1) { + if (kp->type != KEY_EXPRESSION_NORMAL) { + // any key expression that is not a plain xpub is not BIP-44 compliant + return -1; + } + + if (kp->k.key_index != 0 || kp->num_first != 0 || kp->num_second != 1) { return -1; } @@ -1513,44 +1549,43 @@ bool check_wallet_hmac(const uint8_t wallet_id[static 32], const uint8_t wallet_ // make sure that the compiler gives an error if any PolicyNodeType is missed #pragma GCC diagnostic error "-Wswitch-enum" -static int get_key_placeholder_by_index_in_tree(const policy_node_tree_t *tree, - unsigned int i, - const policy_node_t **out_tapleaf_ptr, - policy_node_key_placeholder_t *out_placeholder) { +static int get_keyexpr_by_index_in_tree(const policy_node_tree_t *tree, + unsigned int i, + const policy_node_t **out_tapleaf_ptr, + policy_node_keyexpr_t **out_keyexpr) { if (tree->is_leaf) { - int ret = - get_key_placeholder_by_index(r_policy_node(&tree->script), i, NULL, out_placeholder); + int ret = get_keyexpr_by_index(r_policy_node(&tree->script), i, NULL, out_keyexpr); if (ret >= 0 && out_tapleaf_ptr != NULL && i < (unsigned) ret) { *out_tapleaf_ptr = r_policy_node(&tree->script); } return ret; } else { - int ret1 = get_key_placeholder_by_index_in_tree(r_policy_node_tree(&tree->left_tree), - i, - out_tapleaf_ptr, - out_placeholder); + int ret1 = get_keyexpr_by_index_in_tree(r_policy_node_tree(&tree->left_tree), + i, + out_tapleaf_ptr, + out_keyexpr); if (ret1 < 0) return -1; bool found = i < (unsigned int) ret1; - int ret2 = get_key_placeholder_by_index_in_tree(r_policy_node_tree(&tree->right_tree), - found ? 0 : i - ret1, - found ? NULL : out_tapleaf_ptr, - found ? NULL : out_placeholder); + int ret2 = get_keyexpr_by_index_in_tree(r_policy_node_tree(&tree->right_tree), + found ? 0 : i - ret1, + found ? NULL : out_tapleaf_ptr, + found ? NULL : out_keyexpr); if (ret2 < 0) return -1; return ret1 + ret2; } } -int get_key_placeholder_by_index(const policy_node_t *policy, - unsigned int i, - const policy_node_t **out_tapleaf_ptr, - policy_node_key_placeholder_t *out_placeholder) { - // make sure that out_placeholder is a valid pointer, if the output is not needed - policy_node_key_placeholder_t tmp; - if (out_placeholder == NULL) { - out_placeholder = &tmp; +int get_keyexpr_by_index(const policy_node_t *policy, + unsigned int i, + const policy_node_t **out_tapleaf_ptr, + policy_node_keyexpr_t **out_keyexpr) { + // make sure that out_keyexpr is a valid pointer, if the output is not needed + policy_node_keyexpr_t *tmp; + if (out_keyexpr == NULL) { + out_keyexpr = &tmp; } switch (policy->type) { @@ -1573,26 +1608,22 @@ int get_key_placeholder_by_index(const policy_node_t *policy, case TOKEN_WPKH: { if (i == 0) { policy_node_with_key_t *wpkh = (policy_node_with_key_t *) policy; - memcpy(out_placeholder, - r_policy_node_key_placeholder(&wpkh->key_placeholder), - sizeof(policy_node_key_placeholder_t)); + *out_keyexpr = r_policy_node_keyexpr(&wpkh->key); } return 1; } case TOKEN_TR: { policy_node_tr_t *tr = (policy_node_tr_t *) policy; if (i == 0) { - memcpy(out_placeholder, - r_policy_node_key_placeholder(&tr->key_placeholder), - sizeof(policy_node_key_placeholder_t)); + *out_keyexpr = r_policy_node_keyexpr(&tr->key); } if (!isnull_policy_node_tree(&tr->tree)) { - int ret_tree = get_key_placeholder_by_index_in_tree( + int ret_tree = get_keyexpr_by_index_in_tree( r_policy_node_tree(&tr->tree), i == 0 ? 0 : i - 1, i == 0 ? NULL : out_tapleaf_ptr, - i == 0 ? NULL : out_placeholder); // if i == 0, we already found it; so we - // recur with out_placeholder set to NULL + i == 0 ? NULL : out_keyexpr); // if i == 0, we already found it; so we + // recur with out_keyexpr set to NULL if (ret_tree < 0) { return -1; } @@ -1610,9 +1641,8 @@ int get_key_placeholder_by_index(const policy_node_t *policy, const policy_node_multisig_t *node = (const policy_node_multisig_t *) policy; if (i < (unsigned int) node->n) { - policy_node_key_placeholder_t *placeholders = - r_policy_node_key_placeholder(&node->key_placeholders); - memcpy(out_placeholder, &placeholders[i], sizeof(policy_node_key_placeholder_t)); + policy_node_keyexpr_t *key_expressions = r_policy_node_keyexpr(&node->keys); + *out_keyexpr = &key_expressions[i]; } return node->n; @@ -1631,11 +1661,11 @@ int get_key_placeholder_by_index(const policy_node_t *policy, case TOKEN_N: case TOKEN_L: case TOKEN_U: { - return get_key_placeholder_by_index( + return get_keyexpr_by_index( r_policy_node(&((const policy_node_with_script_t *) policy)->script), i, out_tapleaf_ptr, - out_placeholder); + out_keyexpr); } // nodes with exactly two child scripts @@ -1647,17 +1677,17 @@ int get_key_placeholder_by_index(const policy_node_t *policy, case TOKEN_OR_D: case TOKEN_OR_I: { const policy_node_with_script2_t *node = (const policy_node_with_script2_t *) policy; - int ret1 = get_key_placeholder_by_index(r_policy_node(&node->scripts[0]), - i, - out_tapleaf_ptr, - out_placeholder); + int ret1 = get_keyexpr_by_index(r_policy_node(&node->scripts[0]), + i, + out_tapleaf_ptr, + out_keyexpr); if (ret1 < 0) return -1; bool found = i < (unsigned int) ret1; - int ret2 = get_key_placeholder_by_index(r_policy_node(&node->scripts[1]), - found ? 0 : i - ret1, - found ? NULL : out_tapleaf_ptr, - found ? NULL : out_placeholder); + int ret2 = get_keyexpr_by_index(r_policy_node(&node->scripts[1]), + found ? 0 : i - ret1, + found ? NULL : out_tapleaf_ptr, + found ? NULL : out_keyexpr); if (ret2 < 0) return -1; return ret1 + ret2; @@ -1666,24 +1696,24 @@ int get_key_placeholder_by_index(const policy_node_t *policy, // nodes with exactly three child scripts case TOKEN_ANDOR: { const policy_node_with_script3_t *node = (const policy_node_with_script3_t *) policy; - int ret1 = get_key_placeholder_by_index(r_policy_node(&node->scripts[0]), - i, - out_tapleaf_ptr, - out_placeholder); + int ret1 = get_keyexpr_by_index(r_policy_node(&node->scripts[0]), + i, + out_tapleaf_ptr, + out_keyexpr); if (ret1 < 0) return -1; bool found = i < (unsigned int) ret1; - int ret2 = get_key_placeholder_by_index(r_policy_node(&node->scripts[1]), - found ? 0 : i - ret1, - found ? NULL : out_tapleaf_ptr, - found ? NULL : out_placeholder); + int ret2 = get_keyexpr_by_index(r_policy_node(&node->scripts[1]), + found ? 0 : i - ret1, + found ? NULL : out_tapleaf_ptr, + found ? NULL : out_keyexpr); if (ret2 < 0) return -1; found = i < (unsigned int) (ret1 + ret2); - int ret3 = get_key_placeholder_by_index(r_policy_node(&node->scripts[2]), - found ? 0 : i - ret1 - ret2, - found ? NULL : out_tapleaf_ptr, - found ? NULL : out_placeholder); + int ret3 = get_keyexpr_by_index(r_policy_node(&node->scripts[2]), + found ? 0 : i - ret1 - ret2, + found ? NULL : out_tapleaf_ptr, + found ? NULL : out_keyexpr); if (ret3 < 0) return -1; return ret1 + ret2 + ret3; } @@ -1699,10 +1729,10 @@ int get_key_placeholder_by_index(const policy_node_t *policy, "The script should always have exactly n child scripts"); found = i < (unsigned int) ret; - int ret_partial = get_key_placeholder_by_index(r_policy_node(&cur_child->script), - found ? 0 : i - ret, - found ? NULL : out_tapleaf_ptr, - found ? NULL : out_placeholder); + int ret_partial = get_keyexpr_by_index(r_policy_node(&cur_child->script), + found ? 0 : i - ret, + found ? NULL : out_tapleaf_ptr, + found ? NULL : out_keyexpr); if (ret_partial < 0) return -1; ret += ret_partial; @@ -1724,18 +1754,28 @@ int get_key_placeholder_by_index(const policy_node_t *policy, int count_distinct_keys_info(const policy_node_t *policy) { int ret = -1; - - int n_placeholders = get_key_placeholder_by_index(policy, 0, NULL, NULL); - if (n_placeholders < 0) { + policy_node_keyexpr_t *key_expression_ptr; + int n_key_expressions = get_keyexpr_by_index(policy, 0, NULL, NULL); + if (n_key_expressions < 0) { return -1; } - for (int cur = 0; cur < n_placeholders; ++cur) { - policy_node_key_placeholder_t placeholder; - if (0 > get_key_placeholder_by_index(policy, cur, NULL, &placeholder)) { + for (int cur = 0; cur < n_key_expressions; ++cur) { + if (0 > get_keyexpr_by_index(policy, cur, NULL, &key_expression_ptr)) { return -1; } - ret = MAX(ret, placeholder.key_index + 1); + if (key_expression_ptr->type == KEY_EXPRESSION_NORMAL) { + ret = MAX(ret, key_expression_ptr->k.key_index + 1); + } else if (key_expression_ptr->type == KEY_EXPRESSION_MUSIG) { + const musig_aggr_key_info_t *musig_info = + r_musig_aggr_key_info(&key_expression_ptr->m.musig_info); + const uint16_t *key_indexes = r_uint16(&musig_info->key_indexes); + for (int i = 0; i < musig_info->n; i++) { + ret = MAX(ret, key_indexes[i] + 1); + } + } else { + LEDGER_ASSERT(false, "Unknown key expression type"); + } } return ret; } @@ -1852,6 +1892,19 @@ static int is_taptree_miniscript_sane(const policy_node_tree_t *taptree) { return 0; } +// sort an array of uint16_t in place using bubble sort +static void sort_uint16_array(uint16_t *array, size_t n) { + for (size_t i = 0; i < n; i++) { + for (size_t j = i + 1; j < n; j++) { + if (array[i] > array[j]) { + uint16_t tmp = array[i]; + array[i] = array[j]; + array[j] = tmp; + } + } + } +} + int is_policy_sane(dispatcher_context_t *dispatcher_context, const policy_node_t *policy, int wallet_version, @@ -1909,36 +1962,109 @@ int is_policy_sane(dispatcher_context_t *dispatcher_context, } } - // check that all the key placeholders for the same xpub do indeed have different + // check that all the key expressions for the same xpub do indeed have different // derivations - int n_placeholders = get_key_placeholder_by_index(policy, 0, NULL, NULL); - if (n_placeholders < 0) { - return WITH_ERROR(-1, "Unexpected error while counting placeholders"); + int n_key_expressions = get_keyexpr_by_index(policy, 0, NULL, NULL); + if (n_key_expressions < 0) { + return WITH_ERROR(-1, "Unexpected error while counting key expressions"); } - // The following loop computationally very inefficient (quadratic in the number of - // placeholders), but more efficient solutions likely require a substantial amount of RAM - // (proportional to the number of key placeholders). Instead, this only requires stack depth + // for each MuSig key expression, checks that the key indices are all distinct + for (int i = 0; i < n_key_expressions; i++) { + policy_node_keyexpr_t *kp_i; + if (0 > get_keyexpr_by_index(policy, i, NULL, &kp_i)) { + return WITH_ERROR(-1, "Unexpected error retrieving key expressions from the policy"); + } + if (kp_i->type == KEY_EXPRESSION_MUSIG) { + const musig_aggr_key_info_t *musig_info_i = r_musig_aggr_key_info(&kp_i->m.musig_info); + const uint16_t *key_indexes_i = r_uint16(&musig_info_i->key_indexes); + + uint16_t key_indexes_i_sorted[MAX_PUBKEYS_PER_MUSIG]; + memcpy(key_indexes_i_sorted, key_indexes_i, musig_info_i->n * sizeof(uint16_t)); + + // sort the arrays + sort_uint16_array(key_indexes_i_sorted, musig_info_i->n); + for (int j = 0; j < musig_info_i->n - 1; j++) { + if (key_indexes_i_sorted[j] == key_indexes_i_sorted[j + 1]) { + return WITH_ERROR(-1, "Repeated key in musig key expression"); + } + } + } + } + + // The following loop is computationally very inefficient (quadratic in the number of + // key expressions), but more efficient solutions likely require a substantial amount of RAM + // (proportional to the number of key expressions). Instead, this only requires stack depth // proportional to the depth of the wallet policy's abstract syntax tree. - for (int i = 0; i < n_placeholders - 1; - i++) { // no point in running this for the last placeholder - policy_node_key_placeholder_t kp_i; - if (0 > get_key_placeholder_by_index(policy, i, NULL, &kp_i)) { - return WITH_ERROR(-1, "Unexpected error retrieving placeholders from the policy"); + for (int i = 0; i < n_key_expressions - 1; + i++) { // no point in running this for the last key expression + policy_node_keyexpr_t *kp_i; + if (0 > get_keyexpr_by_index(policy, i, NULL, &kp_i)) { + return WITH_ERROR(-1, "Unexpected error retrieving key expressions from the policy"); } - for (int j = i + 1; j < n_placeholders; j++) { - policy_node_key_placeholder_t kp_j; - if (0 > get_key_placeholder_by_index(policy, j, NULL, &kp_j)) { - return WITH_ERROR(-1, "Unexpected error retrieving placeholders from the policy"); + for (int j = i + 1; j < n_key_expressions; j++) { + policy_node_keyexpr_t *kp_j; + if (0 > get_keyexpr_by_index(policy, j, NULL, &kp_j)) { + return WITH_ERROR(-1, + "Unexpected error retrieving key expressions from the policy"); } - // placeholders for the same key must have disjoint derivation options - if (kp_i.key_index == kp_j.key_index) { - if (kp_i.num_first == kp_j.num_first || kp_i.num_first == kp_j.num_second || - kp_i.num_second == kp_j.num_first || kp_i.num_second == kp_j.num_second) { + if ((kp_i->type == KEY_EXPRESSION_NORMAL && kp_j->type == KEY_EXPRESSION_MUSIG) || + (kp_i->type == KEY_EXPRESSION_MUSIG && kp_j->type == KEY_EXPRESSION_NORMAL)) { + // if one is a key and the other is a musig, there's nothing else to check + continue; + } else if (kp_i->type == KEY_EXPRESSION_NORMAL && kp_j->type == KEY_EXPRESSION_NORMAL) { + // key expressions for the same key must have disjoint derivation options + if (kp_i->k.key_index == kp_j->k.key_index) { + if (kp_i->num_first == kp_j->num_first || kp_i->num_first == kp_j->num_second || + kp_i->num_second == kp_j->num_first || + kp_i->num_second == kp_j->num_second) { + return WITH_ERROR( + -1, + "Key expressions with repeated derivations in miniscript"); + } + } + } else if (kp_i->type == KEY_EXPRESSION_MUSIG && kp_j->type == KEY_EXPRESSION_MUSIG) { + const musig_aggr_key_info_t *musig_info_i = + r_musig_aggr_key_info(&kp_i->m.musig_info); + const uint16_t *key_indexes_i = r_uint16(&musig_info_i->key_indexes); + const musig_aggr_key_info_t *musig_info_j = + r_musig_aggr_key_info(&kp_j->m.musig_info); + const uint16_t *key_indexes_j = r_uint16(&musig_info_j->key_indexes); + // if two musigs have exactly the same set of keys, then the derivation options must + // be disjoint + + // make sure that there is no repeated key in the first musig + + if (musig_info_i->n != musig_info_j->n) { + continue; // cannot be the same set if the size is different + } + + uint16_t key_indexes_i_sorted[MAX_PUBKEYS_PER_MUSIG]; + uint16_t key_indexes_j_sorted[MAX_PUBKEYS_PER_MUSIG]; + memcpy(key_indexes_i_sorted, key_indexes_i, musig_info_i->n * sizeof(uint16_t)); + memcpy(key_indexes_j_sorted, key_indexes_j, musig_info_j->n * sizeof(uint16_t)); + + // sort the arrays + sort_uint16_array(key_indexes_i_sorted, musig_info_i->n); + sort_uint16_array(key_indexes_j_sorted, musig_info_j->n); + + if (memcmp(key_indexes_i_sorted, + key_indexes_j_sorted, + musig_info_i->n * sizeof(uint16_t)) != 0) { + continue; // different set of keys + } + + // same set of keys; therefore, we need to check that the derivation options are + // disjoint + if (kp_i->num_first == kp_j->num_first || kp_i->num_first == kp_j->num_second || + kp_i->num_second == kp_j->num_first || kp_i->num_second == kp_j->num_second) { return WITH_ERROR(-1, - "Key placeholders with repeated derivations in miniscript"); + "Key expressions with repeated derivations in miniscript"); } + + } else { + LEDGER_ASSERT(false, "Unexpected key expression type"); } } } diff --git a/src/handler/lib/policy.h b/src/handler/lib/policy.h index 121560ce4..06a371129 100644 --- a/src/handler/lib/policy.h +++ b/src/handler/lib/policy.h @@ -50,6 +50,29 @@ typedef struct { bool change; // whether a change address or a receive address is derived } wallet_derivation_info_t; +/** + * Computes the a derived compressed pubkey for one of the key of the wallet policy, + * for a given change/address_index combination. + * + * This function computes the extended public key (xpub) based on the provided + * BIP32 derivation path. It supports both standard BIP32 derivation and + * the derivation of Musig (multi-signature) keys. + * + * @param[in] dispatcher_context Pointer to the dispatcher content + * @param[in] wdi Pointer to a `wallet_derivation_info_t` struct with the details of the + * necessary details of the wallet policy, and the desired change/address_index pair. + * @param[in] key_index Index of the pubkey in the vector of keys of the wallet policy. + * @param[out] out Pointer to a `serialized_extended_pubkey_t` that will contain the requested + * extended pubkey. + * + * @return -1 on error, 0 if the returned key info has no wildcard (**), 1 if it has the wildcard. + */ +__attribute__((warn_unused_result)) int get_extended_pubkey( + dispatcher_context_t *dispatcher_context, + const wallet_derivation_info_t *wdi, + int key_index, + serialized_extended_pubkey_t *out); + /** * Computes the hash of a taptree, to be used as tweak for the internal key per BIP-0341; * The returned hash is the second value in the tuple returned by taproot_tree_helper in @@ -176,31 +199,31 @@ bool compute_wallet_hmac(const uint8_t wallet_id[static 32], uint8_t wallet_hmac bool check_wallet_hmac(const uint8_t wallet_id[static 32], const uint8_t wallet_hmac[static 32]); /** - * Copies the i-th placeholder (indexing from 0) of the given policy into `out_placeholder` (if not + * Copies the i-th key expression (indexing from 0) of the given policy into `out_keyexpr` (if not * null). * * @param[in] policy * Pointer to the root node of the policy * @param[in] i - * Index of the wanted placeholder. Ignored if out_placeholder is NULL. + * Index of the wanted placeholder. Ignored if out_keyexpr is NULL. * @param[out] out_tapleaf_ptr - * If not NULL, and if the i-th placeholder is in a tapleaf of the policy, receives the pointer to - * the tapleaf's script. - * @param[out] out_placeholder - * If not NULL, it is a pointer that will receive the i-th placeholder of the policy. - * @return the number of placeholders in the policy on success; -1 in case of error. + * If not NULL, and if the i-th key expression is in a tapleaf of the policy, receives the pointer + * to the tapleaf's script. + * @param[out] out_keyexpr + * If not NULL, it is a pointer that will receive a pointer to the i-th key expression of the + * policy. + * @return the number of key expressions in the policy on success; -1 in case of error. */ -__attribute__((warn_unused_result)) int get_key_placeholder_by_index( - const policy_node_t *policy, - unsigned int i, - const policy_node_t **out_tapleaf_ptr, - policy_node_key_placeholder_t *out_placeholder); +__attribute__((warn_unused_result)) int get_keyexpr_by_index(const policy_node_t *policy, + unsigned int i, + const policy_node_t **out_tapleaf_ptr, + policy_node_keyexpr_t **out_keyexpr); /** * Determines the expected number of unique keys in the provided policy's key information. - * The function calculates this by finding the maximum key index from placeholders and increments it - * by 1. For instance, if the maximum key index found in the placeholders is `n`, then the result - * would be `n + 1`. + * The function calculates this by finding the maximum key index from key expressions and increments + * it by 1. For instance, if the maximum key index found in the key expressions is `n`, then the + * result would be `n + 1`. * * @param[in] policy * Pointer to the root node of the policy diff --git a/src/handler/sign_psbt.c b/src/handler/sign_psbt.c index 4193ff4f8..91731a593 100644 --- a/src/handler/sign_psbt.c +++ b/src/handler/sign_psbt.c @@ -54,6 +54,8 @@ #include "../swap/swap_globals.h" #include "../swap/handle_swap_sign_transaction.h" +#include "../musig/musig.h" +#include "../musig/musig_sessions.h" // common info that applies to either the current input or the current output typedef struct { @@ -63,8 +65,8 @@ typedef struct { // PSBT_{IN,OUT}_BIP32_DERIVATION or // PSBT_{IN,OUT}_TAP_BIP32_DERIVATION is not the correct length. - bool placeholder_found; // Set to true if a matching placeholder is found in the input info - + bool key_expression_found; // Set to true if the input/output info in the psbt was correctly + // matched with the current key expression in the signing flow bool is_change; int address_index; @@ -104,24 +106,45 @@ typedef struct { } output_info_t; typedef struct { - policy_node_key_placeholder_t placeholder; + policy_node_keyexpr_t *key_expression_ptr; int cur_index; uint32_t fingerprint; - uint8_t key_derivation_length; + + // info about the internal key of this key expression + // used at signing time to derive the correct key uint32_t key_derivation[MAX_BIP32_PATH_STEPS]; + uint8_t key_derivation_length; + + // same as key_derivation_length for internal key + // expressions; 0 for musig, as the key derivation in + // the PSBT use the aggregate key as the root + // used to identify the correct change/address_index from the psbt + uint8_t psbt_root_key_derivation_length; + + // the root pubkey of this key expression serialized_extended_pubkey_t pubkey; + // the pubkey of the internal key of this key expression. + // same as `pubkey` for simple key expressions, but it's the actual + // internal key for musig key expressions + serialized_extended_pubkey_t internal_pubkey; bool is_tapscript; // true if signing with a BIP342 tapleaf script path spend uint8_t tapleaf_hash[32]; // only used for tapscripts -} placeholder_info_t; +} keyexpr_info_t; -// Cache for partial hashes during segwit signing (avoid quadratic hashing for segwit transactions) -typedef struct { +// Cache for partial hashes during signing (avoid quadratic hashing for segwit transactions) +typedef struct tx_hashes_s { uint8_t sha_prevouts[32]; uint8_t sha_amounts[32]; uint8_t sha_scriptpubkeys[32]; uint8_t sha_sequences[32]; uint8_t sha_outputs[32]; -} segwit_hashes_t; +} tx_hashes_t; + +// the signing state for the current transaction; it does not contain any per-input state +typedef struct signing_state_s { + tx_hashes_t tx_hashes; + musig_signing_state_t musig; +} signing_state_t; // We cache the first 2 external outputs; that's needed for the swap checks // Moreover, this helps the code for the simplified UX for transactions that @@ -371,7 +394,7 @@ static int get_amount_scriptpubkey_from_psbt( // PSBT_{IN|OUT}_{TAP}?_BIP32_DERIVATION fields. static int read_change_and_index_from_psbt_bip32_derivation( dispatcher_context_t *dc, - placeholder_info_t *placeholder_info, + const keyexpr_info_t *keyexpr_info, in_out_info_t *in_out, int psbt_key_type, buffer_t *data, @@ -413,13 +436,13 @@ static int read_change_and_index_from_psbt_bip32_derivation( return -1; } - // if this derivation path matches the internal placeholder, + // if this derivation path matches the key expression, // we use it to detect whether the current input is change or not, // and store its address index - if (fpt_der[0] == placeholder_info->fingerprint && - der_len == placeholder_info->key_derivation_length + 2) { - for (int i = 0; i < placeholder_info->key_derivation_length; i++) { - if (placeholder_info->key_derivation[i] != fpt_der[1 + i]) { + if (fpt_der[0] == keyexpr_info->fingerprint && + der_len == keyexpr_info->psbt_root_key_derivation_length + 2) { + for (int i = 0; i < keyexpr_info->psbt_root_key_derivation_length; i++) { + if (keyexpr_info->key_derivation[i] != fpt_der[1 + i]) { return 0; } } @@ -427,28 +450,31 @@ static int read_change_and_index_from_psbt_bip32_derivation( uint32_t change = fpt_der[1 + der_len - 2]; uint32_t addr_index = fpt_der[1 + der_len - 1]; - // check that we can indeed derive the same key from the current placeholder - serialized_extended_pubkey_t pubkey; - if (0 > bip32_CKDpub(&placeholder_info->pubkey, change, &pubkey)) return -1; - if (0 > bip32_CKDpub(&pubkey, addr_index, &pubkey)) return -1; - - int pk_offset = is_tap ? 1 : 0; - if (memcmp(pubkey.compressed_pubkey + pk_offset, bip32_derivation_pubkey, key_len) != 0) { - return 0; - } - - // check if the 'change' derivation step is indeed coherent with placeholder - if (change == placeholder_info->placeholder.num_first) { + // TODO: safe to remove this check? It should be, since we later re-derive + // the script independently. + // // check that we can indeed derive the same key from the current key expression + // serialized_extended_pubkey_t pubkey; + // if (0 > bip32_CKDpub(&keyexpr_info->pubkey, change, &pubkey, NULL)) return -1; + // if (0 > bip32_CKDpub(&pubkey, addr_index, &pubkey, NULL)) return -1; + + // int pk_offset = is_tap ? 1 : 0; + // if (memcmp(pubkey.compressed_pubkey + pk_offset, bip32_derivation_pubkey, key_len) != 0) + // { + // return 0; + // } + + // check if the 'change' derivation step is indeed coherent with key expression + if (change == keyexpr_info->key_expression_ptr->num_first) { in_out->is_change = false; in_out->address_index = addr_index; - } else if (change == placeholder_info->placeholder.num_second) { + } else if (change == keyexpr_info->key_expression_ptr->num_second) { in_out->is_change = true; in_out->address_index = addr_index; } else { return 0; } - in_out->placeholder_found = true; + in_out->key_expression_found = true; return 1; } return 0; @@ -465,9 +491,9 @@ static int is_in_out_internal(dispatcher_context_t *dispatcher_context, const sign_psbt_state_t *state, const in_out_info_t *in_out_info, bool is_input) { - // If we did not find any info about the pubkey associated to the placeholder we're considering, - // then it's external - if (!in_out_info->placeholder_found) { + // If we did not find any info about the pubkey associated to the key expression we're + // considering, then it's external + if (!in_out_info->key_expression_found) { return 0; } @@ -651,96 +677,182 @@ init_global_state(dispatcher_context_t *dc, sign_psbt_state_t *st) { return true; } -static bool __attribute__((noinline)) -fill_placeholder_info_if_internal(dispatcher_context_t *dc, - sign_psbt_state_t *st, - placeholder_info_t *placeholder_info) { +static bool __attribute__((noinline)) get_and_verify_key_info(dispatcher_context_t *dc, + sign_psbt_state_t *st, + uint16_t key_index, + keyexpr_info_t *keyexpr_info) { policy_map_key_info_t key_info; - { - uint8_t key_info_str[MAX_POLICY_KEY_INFO_LEN]; - int key_info_len = call_get_merkle_leaf_element(dc, - st->wallet_header.keys_info_merkle_root, - st->wallet_header.n_keys, - placeholder_info->placeholder.key_index, - key_info_str, - sizeof(key_info_str)); - - if (key_info_len < 0) { - SEND_SW(dc, SW_BAD_STATE); // should never happen - return false; - } + uint8_t key_info_str[MAX_POLICY_KEY_INFO_LEN]; - // Make a sub-buffer for the pubkey info - buffer_t key_info_buffer = buffer_create(key_info_str, key_info_len); + int key_info_len = call_get_merkle_leaf_element(dc, + st->wallet_header.keys_info_merkle_root, + st->wallet_header.n_keys, + key_index, + key_info_str, + sizeof(key_info_str)); + if (key_info_len < 0) { + return false; // should never happen + } - if (parse_policy_map_key_info(&key_info_buffer, &key_info, st->wallet_header.version) == - -1) { - SEND_SW(dc, SW_BAD_STATE); // should never happen - return false; - } + // Make a sub-buffer for the pubkey info + buffer_t key_info_buffer = buffer_create(key_info_str, key_info_len); + + if (parse_policy_map_key_info(&key_info_buffer, &key_info, st->wallet_header.version) == -1) { + return false; // should never happen + } + + keyexpr_info->key_derivation_length = key_info.master_key_derivation_len; + for (int i = 0; i < key_info.master_key_derivation_len; i++) { + keyexpr_info->key_derivation[i] = key_info.master_key_derivation[i]; } + keyexpr_info->fingerprint = read_u32_be(key_info.master_key_fingerprint, 0); + + memcpy(&keyexpr_info->pubkey, &key_info.ext_pubkey, sizeof(serialized_extended_pubkey_t)); + + // the rest of the function verifies if the key is indeed internal, if it has our fingerprint uint32_t fpr = read_u32_be(key_info.master_key_fingerprint, 0); if (fpr != st->master_key_fingerprint) { return false; } - { - // it could be a collision on the fingerprint; we verify that we can actually generate - // the same pubkey - if (0 > get_extended_pubkey_at_path(key_info.master_key_derivation, - key_info.master_key_derivation_len, - BIP32_PUBKEY_VERSION, - &placeholder_info->pubkey)) { - SEND_SW(dc, SW_BAD_STATE); - return false; - } + // it could be a collision on the fingerprint; we verify that we can actually generate + // the same pubkey + serialized_extended_pubkey_t derived_pubkey; + if (0 > get_extended_pubkey_at_path(key_info.master_key_derivation, + key_info.master_key_derivation_len, + BIP32_PUBKEY_VERSION, + &derived_pubkey)) { + return false; + } - if (memcmp(&key_info.ext_pubkey, - &placeholder_info->pubkey, - sizeof(placeholder_info->pubkey)) != 0) { - return false; + if (memcmp(&key_info.ext_pubkey, &derived_pubkey, sizeof(derived_pubkey)) != 0) { + return false; + } + + return true; +} + +static bool fill_keyexpr_info_if_internal(dispatcher_context_t *dc, + sign_psbt_state_t *st, + keyexpr_info_t *keyexpr_info) { + keyexpr_info_t tmp_keyexpr_info; + // preserve the fields that are already computed outside of this function + memcpy(&tmp_keyexpr_info, keyexpr_info, sizeof(keyexpr_info_t)); + + if (keyexpr_info->key_expression_ptr->type == KEY_EXPRESSION_NORMAL) { + bool result = get_and_verify_key_info(dc, + st, + keyexpr_info->key_expression_ptr->k.key_index, + &tmp_keyexpr_info); + if (result) { + memcpy(keyexpr_info, &tmp_keyexpr_info, sizeof(keyexpr_info_t)); + memcpy(&keyexpr_info->internal_pubkey, + &keyexpr_info->pubkey, + sizeof(serialized_extended_pubkey_t)); + keyexpr_info->psbt_root_key_derivation_length = keyexpr_info->key_derivation_length; + } + return result; + } else if (keyexpr_info->key_expression_ptr->type == KEY_EXPRESSION_MUSIG) { + // iterate through the keys of the musig() placeholder to find if a key is internal + const musig_aggr_key_info_t *musig_info = + r_musig_aggr_key_info(&keyexpr_info->key_expression_ptr->m.musig_info); + const uint16_t *key_indexes = r_uint16(&musig_info->key_indexes); + + bool has_internal_key = false; + + // collect the keys of the musig, and fill the info related to the internal key (if any) + uint8_t keys[MAX_PUBKEYS_PER_MUSIG][33]; + + LEDGER_ASSERT(musig_info->n <= MAX_PUBKEYS_PER_MUSIG, "Too many keys in musig placeholder"); + + for (int idx_in_musig = 0; idx_in_musig < musig_info->n; idx_in_musig++) { + if (get_and_verify_key_info(dc, st, key_indexes[idx_in_musig], &tmp_keyexpr_info)) { + memcpy(keyexpr_info->key_derivation, + tmp_keyexpr_info.key_derivation, + sizeof(tmp_keyexpr_info.key_derivation)); + keyexpr_info->key_derivation_length = tmp_keyexpr_info.key_derivation_length; + + // keep track of the actual internal key of this key expression + memcpy(&keyexpr_info->internal_pubkey, + &tmp_keyexpr_info.pubkey, + sizeof(serialized_extended_pubkey_t)); + + has_internal_key = true; + } + + memcpy(keys[idx_in_musig], tmp_keyexpr_info.pubkey.compressed_pubkey, 33); } - placeholder_info->key_derivation_length = key_info.master_key_derivation_len; - for (int i = 0; i < key_info.master_key_derivation_len; i++) { - placeholder_info->key_derivation[i] = key_info.master_key_derivation[i]; + if (has_internal_key) { + keyexpr_info->psbt_root_key_derivation_length = 0; + + // sort the keys in ascending order using bubble sort + for (int i = 0; i < musig_info->n; i++) { + for (int j = 0; j < musig_info->n - 1; j++) { + if (memcmp(keys[j], keys[j + 1], sizeof(plain_pk_t)) > 0) { + uint8_t tmp[sizeof(plain_pk_t)]; + memcpy(tmp, keys[j], sizeof(plain_pk_t)); + memcpy(keys[j], keys[j + 1], sizeof(plain_pk_t)); + memcpy(keys[j + 1], tmp, sizeof(plain_pk_t)); + } + } + } + + musig_keyagg_context_t musig_ctx; + musig_key_agg(keys, musig_info->n, &musig_ctx); + + // compute the aggregated extended pubkey + memset(&keyexpr_info->pubkey, 0, sizeof(keyexpr_info->pubkey)); + write_u32_be(keyexpr_info->pubkey.version, 0, BIP32_PUBKEY_VERSION); + + keyexpr_info->pubkey.compressed_pubkey[0] = (musig_ctx.Q.y[31] % 2 == 0) ? 2 : 3; + memcpy(&keyexpr_info->pubkey.compressed_pubkey[1], + musig_ctx.Q.x, + sizeof(musig_ctx.Q.x)); + memcpy(&keyexpr_info->pubkey.chain_code, + BIP_MUSIG_CHAINCODE, + sizeof(BIP_MUSIG_CHAINCODE)); + + keyexpr_info->fingerprint = + crypto_get_key_fingerprint(keyexpr_info->pubkey.compressed_pubkey); } - placeholder_info->fingerprint = read_u32_be(key_info.master_key_fingerprint, 0); + return has_internal_key; // no internal key found in musig placeholder + } else { + LEDGER_ASSERT(false, "Unreachable code"); + return false; } - - return true; } -// finds the first placeholder that corresponds to an internal key -static bool find_first_internal_key_placeholder(dispatcher_context_t *dc, - sign_psbt_state_t *st, - placeholder_info_t *placeholder_info) { - placeholder_info->cur_index = 0; +// finds the first key expression that corresponds to an internal key +static bool find_first_internal_keyexpr(dispatcher_context_t *dc, + sign_psbt_state_t *st, + keyexpr_info_t *keyexpr_info) { + keyexpr_info->cur_index = 0; // find and parse our registered key info in the wallet while (true) { - int n_key_placeholders = get_key_placeholder_by_index(st->wallet_policy_map, - placeholder_info->cur_index, - NULL, - &placeholder_info->placeholder); - if (n_key_placeholders < 0) { + int n_key_expressions = get_keyexpr_by_index(st->wallet_policy_map, + keyexpr_info->cur_index, + NULL, + &keyexpr_info->key_expression_ptr); + if (n_key_expressions < 0) { SEND_SW(dc, SW_BAD_STATE); // should never happen return false; } - if (placeholder_info->cur_index >= n_key_placeholders) { + if (keyexpr_info->cur_index >= n_key_expressions) { // all keys have been processed break; } - if (fill_placeholder_info_if_internal(dc, st, placeholder_info)) { + if (fill_keyexpr_info_if_internal(dc, st, keyexpr_info)) { return true; } // Not an internal key, move on - ++placeholder_info->cur_index; + ++keyexpr_info->cur_index; } PRINTF("No internal key found in wallet policy"); @@ -749,7 +861,7 @@ static bool find_first_internal_key_placeholder(dispatcher_context_t *dc, } typedef struct { - placeholder_info_t *placeholder_info; + keyexpr_info_t *keyexpr_info; input_info_t *input; } input_keys_callback_data_t; @@ -776,15 +888,14 @@ static void input_keys_callback(dispatcher_context_t *dc, callback_data->input->has_sighash_type = true; } else if ((key_type == PSBT_IN_BIP32_DERIVATION || key_type == PSBT_IN_TAP_BIP32_DERIVATION) && - !callback_data->input->in_out.placeholder_found) { - if (0 > - read_change_and_index_from_psbt_bip32_derivation(dc, - callback_data->placeholder_info, - &callback_data->input->in_out, - key_type, - data, - map_commitment, - i)) { + !callback_data->input->in_out.key_expression_found) { + if (0 > read_change_and_index_from_psbt_bip32_derivation(dc, + callback_data->keyexpr_info, + &callback_data->input->in_out, + key_type, + data, + map_commitment, + i)) { callback_data->input->in_out.unexpected_pubkey_error = true; } } @@ -799,18 +910,17 @@ preprocess_inputs(dispatcher_context_t *dc, memset(internal_inputs, 0, BITVECTOR_REAL_SIZE(MAX_N_INPUTS_CAN_SIGN)); - placeholder_info_t placeholder_info; - memset(&placeholder_info, 0, sizeof(placeholder_info)); + keyexpr_info_t keyexpr_info; + memset(&keyexpr_info, 0, sizeof(keyexpr_info)); - if (!find_first_internal_key_placeholder(dc, st, &placeholder_info)) return false; + if (!find_first_internal_keyexpr(dc, st, &keyexpr_info)) return false; // process each input for (unsigned int cur_input_index = 0; cur_input_index < st->n_inputs; cur_input_index++) { input_info_t input; memset(&input, 0, sizeof(input)); - input_keys_callback_data_t callback_data = {.input = &input, - .placeholder_info = &placeholder_info}; + input_keys_callback_data_t callback_data = {.input = &input, .keyexpr_info = &keyexpr_info}; int res = call_get_merkleized_map_with_callback( dc, (void *) &callback_data, @@ -1005,7 +1115,7 @@ preprocess_inputs(dispatcher_context_t *dc, } typedef struct { - placeholder_info_t *placeholder_info; + keyexpr_info_t *keyexpr_info; output_info_t *output; } output_keys_callback_data_t; @@ -1024,15 +1134,14 @@ static void output_keys_callback(dispatcher_context_t *dc, buffer_read_u8(data, &key_type); if ((key_type == PSBT_OUT_BIP32_DERIVATION || key_type == PSBT_OUT_TAP_BIP32_DERIVATION) && - !callback_data->output->in_out.placeholder_found) { - if (0 > - read_change_and_index_from_psbt_bip32_derivation(dc, - callback_data->placeholder_info, - &callback_data->output->in_out, - key_type, - data, - map_commitment, - i)) { + !callback_data->output->in_out.key_expression_found) { + if (0 > read_change_and_index_from_psbt_bip32_derivation(dc, + callback_data->keyexpr_info, + &callback_data->output->in_out, + key_type, + data, + map_commitment, + i)) { callback_data->output->in_out.unexpected_pubkey_error = true; } } @@ -1051,10 +1160,10 @@ preprocess_outputs(dispatcher_context_t *dc, LOG_PROCESSOR(__FILE__, __LINE__, __func__); - placeholder_info_t placeholder_info; - memset(&placeholder_info, 0, sizeof(placeholder_info)); + keyexpr_info_t keyexpr_info; + memset(&keyexpr_info, 0, sizeof(keyexpr_info)); - if (!find_first_internal_key_placeholder(dc, st, &placeholder_info)) return false; + if (!find_first_internal_keyexpr(dc, st, &keyexpr_info)) return false; memset(&st->outputs, 0, sizeof(st->outputs)); @@ -1067,7 +1176,7 @@ preprocess_outputs(dispatcher_context_t *dc, memset(&output, 0, sizeof(output)); output_keys_callback_data_t callback_data = {.output = &output, - .placeholder_info = &placeholder_info}; + .keyexpr_info = &keyexpr_info}; int res = call_get_merkleized_map_with_callback( dc, (void *) &callback_data, @@ -1670,7 +1779,7 @@ static bool __attribute__((noinline)) compute_sighash_legacy(dispatcher_context_ static bool __attribute__((noinline)) compute_sighash_segwitv0(dispatcher_context_t *dc, sign_psbt_state_t *st, - segwit_hashes_t *hashes, + const tx_hashes_t *hashes, input_info_t *input, unsigned int cur_input_index, uint8_t sighash[static 32]) { @@ -1855,10 +1964,10 @@ static bool __attribute__((noinline)) compute_sighash_segwitv0(dispatcher_contex static bool __attribute__((noinline)) compute_sighash_segwitv1(dispatcher_context_t *dc, sign_psbt_state_t *st, - segwit_hashes_t *hashes, + const tx_hashes_t *hashes, input_info_t *input, unsigned int cur_input_index, - placeholder_info_t *placeholder_info, + const keyexpr_info_t *keyexpr_info, uint8_t sighash[static 32]) { LOG_PROCESSOR(__FILE__, __LINE__, __func__); @@ -1893,7 +2002,7 @@ static bool __attribute__((noinline)) compute_sighash_segwitv1(dispatcher_contex } // ext_flag - uint8_t ext_flag = placeholder_info->is_tapscript ? 1 : 0; + uint8_t ext_flag = keyexpr_info->is_tapscript ? 1 : 0; // annex is not supported const uint8_t annex_present = 0; uint8_t spend_type = ext_flag * 2 + annex_present; @@ -1977,9 +2086,9 @@ static bool __attribute__((noinline)) compute_sighash_segwitv1(dispatcher_contex crypto_hash_update(&sighash_context.header, tmp, 32); } - if (placeholder_info->is_tapscript) { + if (keyexpr_info->is_tapscript) { // If spending a tapscript, append the Common Signature Message Extension per BIP-0342 - crypto_hash_update(&sighash_context.header, placeholder_info->tapleaf_hash, 32); + crypto_hash_update(&sighash_context.header, keyexpr_info->tapleaf_hash, 32); crypto_hash_update_u8(&sighash_context.header, 0x00); // key_version crypto_hash_update_u32(&sighash_context.header, 0xffffffff); // no OP_CODESEPARATOR } @@ -2034,22 +2143,22 @@ static bool __attribute__((noinline)) yield_signature(dispatcher_context_t *dc, static bool __attribute__((noinline)) sign_sighash_ecdsa_and_yield(dispatcher_context_t *dc, sign_psbt_state_t *st, - placeholder_info_t *placeholder_info, + const keyexpr_info_t *keyexpr_info, input_info_t *input, unsigned int cur_input_index, uint8_t sighash[static 32]) { LOG_PROCESSOR(__FILE__, __LINE__, __func__); uint32_t sign_path[MAX_BIP32_PATH_STEPS]; - for (int i = 0; i < placeholder_info->key_derivation_length; i++) { - sign_path[i] = placeholder_info->key_derivation[i]; + for (int i = 0; i < keyexpr_info->key_derivation_length; i++) { + sign_path[i] = keyexpr_info->key_derivation[i]; } - sign_path[placeholder_info->key_derivation_length] = - input->in_out.is_change ? placeholder_info->placeholder.num_second - : placeholder_info->placeholder.num_first; - sign_path[placeholder_info->key_derivation_length + 1] = input->in_out.address_index; + sign_path[keyexpr_info->key_derivation_length] = + input->in_out.is_change ? keyexpr_info->key_expression_ptr->num_second + : keyexpr_info->key_expression_ptr->num_first; + sign_path[keyexpr_info->key_derivation_length + 1] = input->in_out.address_index; - int sign_path_len = placeholder_info->key_derivation_length + 2; + int sign_path_len = keyexpr_info->key_derivation_length + 2; uint8_t sig[MAX_DER_SIG_LEN + 1]; // extra byte for the appended sighash-type @@ -2076,13 +2185,12 @@ sign_sighash_ecdsa_and_yield(dispatcher_context_t *dc, return true; } -static bool __attribute__((noinline)) -sign_sighash_schnorr_and_yield(dispatcher_context_t *dc, - sign_psbt_state_t *st, - placeholder_info_t *placeholder_info, - input_info_t *input, - unsigned int cur_input_index, - uint8_t sighash[static 32]) { +static bool __attribute__((noinline)) sign_sighash_schnorr_and_yield(dispatcher_context_t *dc, + sign_psbt_state_t *st, + keyexpr_info_t *keyexpr_info, + input_info_t *input, + unsigned int cur_input_index, + uint8_t sighash[static 32]) { LOG_PROCESSOR(__FILE__, __LINE__, __func__); if (st->wallet_policy_map->type != TOKEN_TR) { @@ -2110,15 +2218,15 @@ sign_sighash_schnorr_and_yield(dispatcher_context_t *dc, uint32_t sign_path[MAX_BIP32_PATH_STEPS]; - for (int i = 0; i < placeholder_info->key_derivation_length; i++) { - sign_path[i] = placeholder_info->key_derivation[i]; + for (int i = 0; i < keyexpr_info->key_derivation_length; i++) { + sign_path[i] = keyexpr_info->key_derivation[i]; } - sign_path[placeholder_info->key_derivation_length] = - input->in_out.is_change ? placeholder_info->placeholder.num_second - : placeholder_info->placeholder.num_first; - sign_path[placeholder_info->key_derivation_length + 1] = input->in_out.address_index; + sign_path[keyexpr_info->key_derivation_length] = + input->in_out.is_change ? keyexpr_info->key_expression_ptr->num_second + : keyexpr_info->key_expression_ptr->num_first; + sign_path[keyexpr_info->key_derivation_length + 1] = input->in_out.address_index; - int sign_path_len = placeholder_info->key_derivation_length + 2; + int sign_path_len = keyexpr_info->key_derivation_length + 2; if (bip32_derive_init_privkey_256(CX_CURVE_256K1, sign_path, @@ -2131,7 +2239,7 @@ sign_sighash_schnorr_and_yield(dispatcher_context_t *dc, policy_node_tr_t *policy = (policy_node_tr_t *) st->wallet_policy_map; - if (!placeholder_info->is_tapscript) { + if (!keyexpr_info->is_tapscript) { if (isnull_policy_node_tree(&policy->tree)) { // tweak as specified in BIP-86 and BIP-386 crypto_tr_tweak_seckey(seckey, (uint8_t[]){}, 0, seckey); @@ -2143,7 +2251,7 @@ sign_sighash_schnorr_and_yield(dispatcher_context_t *dc, } } else { // tapscript, we need to yield the tapleaf hash together with the pubkey - tapleaf_hash = placeholder_info->tapleaf_hash; + tapleaf_hash = keyexpr_info->tapleaf_hash; } // generate corresponding public key @@ -2200,8 +2308,454 @@ sign_sighash_schnorr_and_yield(dispatcher_context_t *dc, return true; } +static bool __attribute__((noinline)) yield_musig_data(dispatcher_context_t *dc, + sign_psbt_state_t *st, + unsigned int cur_input_index, + const uint8_t *data, + size_t data_len, + uint32_t tag, + const uint8_t participant_pk[static 33], + const uint8_t aggregate_pubkey[static 33], + const uint8_t *tapleaf_hash) { + LOG_PROCESSOR(__FILE__, __LINE__, __func__); + + if (st->protocol_version == 0) { + // Only support version 1 of the protocol + return false; + } + + // bytes: 1 5 varint data_len 33 33 0 or 32 + // CMD_YIELD + // + + // Yield signature + uint8_t cmd = CCMD_YIELD; + dc->add_to_response(&cmd, 1); + + uint8_t buf[9]; + + // Add tag + int tag_varint_len = varint_write(buf, 0, tag); + dc->add_to_response(buf, tag_varint_len); + + // Add input index + int input_index_varint_len = varint_write(buf, 0, cur_input_index); + dc->add_to_response(buf, input_index_varint_len); + + // Add data (pubnonce or partial signature) + dc->add_to_response(data, data_len); + + // Add participant public key + dc->add_to_response(participant_pk, 33); + + // Add aggregate public key + dc->add_to_response(aggregate_pubkey, 33); + + // Add tapleaf hash if provided + if (tapleaf_hash != NULL) { + dc->add_to_response(tapleaf_hash, 32); + } + + dc->finalize_response(SW_INTERRUPTED_EXECUTION); + + if (dc->process_interruption(dc) < 0) { + return false; + } + return true; +} + +static bool yield_musig_pubnonce(dispatcher_context_t *dc, + sign_psbt_state_t *st, + unsigned int cur_input_index, + const musig_pubnonce_t *pubnonce, + const uint8_t participant_pk[static 33], + const uint8_t aggregate_pubkey[static 33], + const uint8_t *tapleaf_hash) { + return yield_musig_data(dc, + st, + cur_input_index, + (const uint8_t *) pubnonce, + sizeof(musig_pubnonce_t), + CCMD_YIELD_MUSIG_PUBNONCE_TAG, + participant_pk, + aggregate_pubkey, + tapleaf_hash); +} + +static bool yield_musig_partial_signature(dispatcher_context_t *dc, + sign_psbt_state_t *st, + unsigned int cur_input_index, + const uint8_t psig[static 32], + const uint8_t participant_pk[static 33], + const uint8_t aggregate_pubkey[static 33], + const uint8_t *tapleaf_hash) { + return yield_musig_data(dc, + st, + cur_input_index, + psig, + 32, + CCMD_YIELD_MUSIG_PARTIALSIGNATURE_TAG, + participant_pk, + aggregate_pubkey, + tapleaf_hash); +} + static bool __attribute__((noinline)) -compute_segwit_hashes(dispatcher_context_t *dc, sign_psbt_state_t *st, segwit_hashes_t *hashes) { +sign_sighash_musig_and_yield(dispatcher_context_t *dc, + sign_psbt_state_t *st, + signing_state_t *signing_state, + const keyexpr_info_t *keyexpr_info, + const input_info_t *input, + unsigned int cur_input_index, + uint8_t sighash[static 32]) { + LOG_PROCESSOR(__FILE__, __LINE__, __func__); + + if (st->wallet_policy_map->type != TOKEN_TR) { + SEND_SW(dc, SW_BAD_STATE); // should never happen + return false; + } + + const policy_node_tr_t *tr_policy = (policy_node_tr_t *) st->wallet_policy_map; + + // plan: + // 1) compute aggregate pubkey + // 2) compute musig2 tweaks + // 3) compute taproot tweak (if keypath spend) + // if my pubnonce is in the psbt: + // 5) generate and yield pubnonce + // else: + // 6) generate and yield partial signature + + // 1) compute aggregate pubkey + + // TODO: we should compute the aggregate pubkey just once for the placeholder, instead of + // repeating for each input + wallet_derivation_info_t wdi = {.n_keys = st->wallet_header.n_keys, + .wallet_version = st->wallet_header.version, + .keys_merkle_root = st->wallet_header.keys_info_merkle_root, + .change = input->in_out.is_change, + .address_index = input->in_out.address_index}; + + // TODO: code duplication with policy.c::get_derived_pubkey; worth extracting a common method? + + serialized_extended_pubkey_t ext_pubkey; + + const policy_node_keyexpr_t *key_expr = keyexpr_info->key_expression_ptr; + const musig_aggr_key_info_t *musig_info = r_musig_aggr_key_info(&key_expr->m.musig_info); + const uint16_t *key_indexes = r_uint16(&musig_info->key_indexes); + plain_pk_t keys[MAX_PUBKEYS_PER_MUSIG]; + + LEDGER_ASSERT(musig_info->n <= MAX_PUBKEYS_PER_MUSIG, "Too many keys in musig key expression"); + for (int i = 0; i < musig_info->n; i++) { + // we use ext_pubkey as a temporary variable; will overwrite later + if (0 > get_extended_pubkey(dc, &wdi, key_indexes[i], &ext_pubkey)) { + return -1; + } + memcpy(keys[i], ext_pubkey.compressed_pubkey, sizeof(ext_pubkey.compressed_pubkey)); + } + + // sort the keys in ascending order using bubble sort + for (int i = 0; i < musig_info->n; i++) { + for (int j = 0; j < musig_info->n - 1; j++) { + if (memcmp(keys[j], keys[j + 1], sizeof(plain_pk_t)) > 0) { + uint8_t tmp[sizeof(plain_pk_t)]; + memcpy(tmp, keys[j], sizeof(plain_pk_t)); + memcpy(keys[j], keys[j + 1], sizeof(plain_pk_t)); + memcpy(keys[j + 1], tmp, sizeof(plain_pk_t)); + } + } + } + + musig_keyagg_context_t musig_ctx; + musig_key_agg(keys, musig_info->n, &musig_ctx); + + // compute the aggregated extended pubkey + memset(&ext_pubkey, 0, sizeof(ext_pubkey)); + write_u32_be(ext_pubkey.version, 0, BIP32_PUBKEY_VERSION); + + ext_pubkey.compressed_pubkey[0] = (musig_ctx.Q.y[31] % 2 == 0) ? 2 : 3; + memcpy(&ext_pubkey.compressed_pubkey[1], musig_ctx.Q.x, sizeof(musig_ctx.Q.x)); + memcpy(&ext_pubkey.chain_code, BIP_MUSIG_CHAINCODE, sizeof(BIP_MUSIG_CHAINCODE)); + + // 2) compute musig2 tweaks + // We always have exactly 2 BIP32 tweaks in wallet policies; if the musig is in the keypath + // spend, we also have an x-only taptweak with the taproot tree hash (or BIP-86/BIP-386 style if + // there is no taproot tree). + + uint32_t change_step = input->in_out.is_change ? keyexpr_info->key_expression_ptr->num_second + : keyexpr_info->key_expression_ptr->num_first; + uint32_t addr_index_step = input->in_out.address_index; + + // in wallet policies, we always have at least two bip32-tweaks, and we might have + // one x-only tweak per BIP-0341 (if spending from the keypath). + uint8_t tweaks[3][32]; + uint8_t *tweaks_ptrs[3] = {tweaks[0], tweaks[1], tweaks[2]}; + bool is_xonly[] = {false, false, true}; + size_t n_tweaks = 2; // might be changed to 3 below + + serialized_extended_pubkey_t agg_key_tweaked; + if (0 > bip32_CKDpub(&ext_pubkey, change_step, &agg_key_tweaked, tweaks[0])) { + SEND_SW(dc, SW_BAD_STATE); // should never happen + return false; + } + if (0 > bip32_CKDpub(&agg_key_tweaked, addr_index_step, &agg_key_tweaked, tweaks[1])) { + SEND_SW(dc, SW_BAD_STATE); // should never happen + return false; + } + + // 3) compute taproot tweak (if keypath spend) + memset(tweaks[2], 0, 32); + if (!keyexpr_info->is_tapscript) { + n_tweaks = 3; + + crypto_tr_tagged_hash( + BIP0341_taptweak_tag, + sizeof(BIP0341_taptweak_tag), + agg_key_tweaked.compressed_pubkey + 1, // xonly key, after BIP-32 tweaks + 32, + input->taptree_hash, + // BIP-86 compliant tweak if there's no taptree, otherwise use the taptree hash + isnull_policy_node_tree(&tr_policy->tree) ? 0 : 32, + tweaks[2]); + + // also apply the taptweak to agg_key_tweaked + + uint8_t parity = 0; + crypto_tr_tweak_pubkey(agg_key_tweaked.compressed_pubkey + 1, + input->taptree_hash, + isnull_policy_node_tree(&tr_policy->tree) ? 0 : 32, + &parity, + agg_key_tweaked.compressed_pubkey + 1); + agg_key_tweaked.compressed_pubkey[0] = 0x02 + parity; + } + + // we will no longer use the other fields of the extended pubkey, so we zero them for sanity + memset(agg_key_tweaked.chain_code, 0, sizeof(agg_key_tweaked.chain_code)); + memset(agg_key_tweaked.child_number, 0, sizeof(agg_key_tweaked.child_number)); + agg_key_tweaked.depth = 0; + memset(agg_key_tweaked.parent_fingerprint, 0, sizeof(agg_key_tweaked.parent_fingerprint)); + memset(agg_key_tweaked.version, 0, sizeof(agg_key_tweaked.version)); + + // Compute musig_my_psbt_id. It is the psbt key that this signer uses to find pubnonces and + // partial signatures (PSBT_IN_MUSIG2_PUB_NONCE and PSBT_IN_MUSIG2_PARTIAL_SIG fields). The + // length is either 33+33 (keypath spend), or 33+33+32 bytes (tapscript spend). It's the + // concatenation of: + // - the 33-byte compressed pubkey of this participant + // - the 33-byte aggregate compressed pubkey (after all the tweaks) + // - (tapscript only) the 32-byte tapleaf hash + uint8_t musig_my_psbt_id_key[1 + 33 + 33 + 32]; + musig_my_psbt_id_key[0] = PSBT_IN_MUSIG2_PUB_NONCE; + + uint8_t *musig_my_psbt_id = musig_my_psbt_id_key + 1; + size_t psbt_id_len = keyexpr_info->is_tapscript ? 33 + 33 + 32 : 33 + 33; + memcpy(musig_my_psbt_id, keyexpr_info->internal_pubkey.compressed_pubkey, 33); + memcpy(musig_my_psbt_id + 33, agg_key_tweaked.compressed_pubkey, 33); + if (keyexpr_info->is_tapscript) { + memcpy(musig_my_psbt_id + 33 + 33, keyexpr_info->tapleaf_hash, 32); + } + + // The psbt_session_id identifies the musig signing session for the entire (psbt, wallet_policy) + // pair, in both rounds 1 and 2 of the protocol; it is the same for all the musig placeholders + // in the policy (if more than one), and it is the same for all the inputs in the psbt. By + // making the hash depend on both the wallet policy and the transaction hashes, we make sure + // that an accidental collision is impossible, allowing for independent, parallel MuSig2 signing + // sessions for different transactions or wallet policies. + // Malicious collisions are not a concern, as they would only result in a signing failure (since + // the nonces would be incorrectly regenerated during round 2 of MuSig2). + uint8_t psbt_session_id[32]; + crypto_tr_tagged_hash( + (uint8_t[]){'P', 's', 'b', 't', 'S', 'e', 's', 's', 'i', 'o', 'n', 'I', 'd'}, + 13, + st->wallet_header.keys_info_merkle_root, // TODO: wallet policy id would be more precise + 32, + (uint8_t *) &signing_state->tx_hashes, + sizeof(signing_state->tx_hashes), + psbt_session_id); + memcpy(psbt_session_id, st->wallet_header.keys_info_merkle_root, sizeof(psbt_session_id)); + + // 4) check if my pubnonce is in the psbt + musig_pubnonce_t my_pubnonce; + if (sizeof(musig_pubnonce_t) != call_get_merkleized_map_value(dc, + &input->in_out.map, + musig_my_psbt_id_key, + 1 + psbt_id_len, + my_pubnonce.raw, + sizeof(musig_pubnonce_t))) { + /** + * Round 1 of the MuSig2 protocol + **/ + + const musig_psbt_session_t *psbt_session = + musigsession_round1_initialize(psbt_session_id, &signing_state->musig); + if (psbt_session == NULL) { + // This should never happen + PRINTF("Unexpected: failed to initialize MuSig2 round 1\n"); + SEND_SW(dc, SW_BAD_STATE); + return false; + } + + // 5) generate and yield pubnonce + + uint8_t rand_i_j[32]; + compute_rand_i_j(psbt_session, cur_input_index, keyexpr_info->cur_index, rand_i_j); + + musig_secnonce_t secnonce; + musig_pubnonce_t pubnonce; + if (0 > musig_nonce_gen(rand_i_j, + keyexpr_info->internal_pubkey.compressed_pubkey, + agg_key_tweaked.compressed_pubkey + 1, + &secnonce, + &pubnonce)) { + PRINTF("MuSig2 nonce generation failed\n"); + SEND_SW(dc, SW_BAD_STATE); // should never happen + return false; + } + + if (!yield_musig_pubnonce(dc, + st, + cur_input_index, + &pubnonce, + keyexpr_info->internal_pubkey.compressed_pubkey, + agg_key_tweaked.compressed_pubkey, + keyexpr_info->is_tapscript ? keyexpr_info->tapleaf_hash : NULL)) { + PRINTF("Failed yielding MuSig2 pubnonce\n"); + SEND_SW(dc, SW_BAD_STATE); // should never happen + return false; + } + } else { + /** + * Round 2 of the MuSig2 protocol + **/ + + const musig_psbt_session_t *psbt_session = + musigsession_round2_initialize(psbt_session_id, &signing_state->musig); + + if (psbt_session == NULL) { + // The PSBT contains a partial nonce, but we do not have the corresponding psbt + // session in storage. Either it was deleted, or the pubnonces were not real. Either + // way, we cannot continue. + PRINTF("Missing MuSig2 session\n"); + SEND_SW(dc, SW_BAD_STATE); + return false; + } + + // 6) generate and yield partial signature + + musig_pubnonce_t nonces[MAX_PUBKEYS_PER_MUSIG]; + + for (int i = 0; i < musig_info->n; i++) { + uint8_t musig_ith_psbt_id_key[1 + 33 + 33 + 32]; + uint8_t *musig_ith_psbt_id = musig_ith_psbt_id_key + 1; + // copy from musig_my_psbt_id_key, but replace the corresponding pubkey + memcpy(musig_ith_psbt_id_key, musig_my_psbt_id_key, sizeof(musig_my_psbt_id_key)); + memcpy(musig_ith_psbt_id, keys[i], sizeof(plain_pk_t)); + + // TODO: could avoid fetching again our own pubnonce + if (sizeof(musig_pubnonce_t) != + call_get_merkleized_map_value(dc, + &input->in_out.map, + musig_ith_psbt_id_key, + 1 + psbt_id_len, + nonces[i].raw, + sizeof(musig_pubnonce_t))) { + PRINTF("Missing or incorrect pubnonce for a MuSig2 cosigner\n"); + SEND_SW(dc, SW_INCORRECT_DATA); + return false; + } + } + + // compute aggregate nonce + musig_pubnonce_t aggnonce; + int res = musig_nonce_agg(nonces, musig_info->n, &aggnonce); + if (res < 0) { + PRINTF("Musig aggregation failed; disruptive signer has index %d\n", -res); + SEND_SW(dc, SW_INCORRECT_DATA); + } + + // recompute secnonce from psbt_session randomness + uint8_t rand_i_j[32]; + compute_rand_i_j(psbt_session, cur_input_index, keyexpr_info->cur_index, rand_i_j); + + musig_secnonce_t secnonce; + musig_pubnonce_t pubnonce; + + if (0 > musig_nonce_gen(rand_i_j, + keyexpr_info->internal_pubkey.compressed_pubkey, + agg_key_tweaked.compressed_pubkey + 1, + &secnonce, + &pubnonce)) { + PRINTF("MuSig2 nonce generation failed\n"); + SEND_SW(dc, SW_BAD_STATE); // should never happen + return false; + } + + // derive secret key + + cx_ecfp_private_key_t private_key = {0}; + uint8_t psig[32]; + bool err = false; + do { // block executed once, only to allow safely breaking out on error + + // derive secret key + uint32_t sign_path[MAX_BIP32_PATH_STEPS]; + + for (int i = 0; i < keyexpr_info->key_derivation_length; i++) { + sign_path[i] = keyexpr_info->key_derivation[i]; + } + int sign_path_len = keyexpr_info->key_derivation_length; + + if (bip32_derive_init_privkey_256(CX_CURVE_256K1, + sign_path, + sign_path_len, + &private_key, + NULL) != CX_OK) { + err = true; + break; + } + + // Create partial signature + musig_session_context_t musig_session_context = {.aggnonce = &aggnonce, + .n_keys = musig_info->n, + .pubkeys = keys, + .n_tweaks = n_tweaks, + .tweaks = tweaks_ptrs, + .is_xonly = is_xonly, + .msg = sighash, + .msg_len = 32}; + + if (0 > musig_sign(&secnonce, private_key.d, &musig_session_context, psig)) { + PRINTF("Musig2 signature failed\n"); + err = true; + break; + } + } while (false); + + explicit_bzero(&private_key, sizeof(private_key)); + + if (err) { + PRINTF("Partial signature generation failed\n"); + return false; + } + + if (!yield_musig_partial_signature( + dc, + st, + cur_input_index, + psig, + keyexpr_info->internal_pubkey.compressed_pubkey, + agg_key_tweaked.compressed_pubkey, + keyexpr_info->is_tapscript ? keyexpr_info->tapleaf_hash : NULL)) { + PRINTF("Failed yielding MuSig2 partial signature\n"); + SEND_SW(dc, SW_BAD_STATE); // should never happen + return false; + } + } + + return true; +} + +static bool __attribute__((noinline)) +compute_tx_hashes(dispatcher_context_t *dc, sign_psbt_state_t *st, tx_hashes_t *hashes) { { // compute sha_prevouts and sha_sequences cx_sha256_t sha_prevouts_context, sha_sequences_context; @@ -2329,8 +2883,8 @@ compute_segwit_hashes(dispatcher_context_t *dc, sign_psbt_state_t *st, segwit_ha static bool __attribute__((noinline)) sign_transaction_input(dispatcher_context_t *dc, sign_psbt_state_t *st, - segwit_hashes_t *hashes, - placeholder_info_t *placeholder_info, + signing_state_t *signing_state, + keyexpr_info_t *keyexpr_info, input_info_t *input, unsigned int cur_input_index) { LOG_PROCESSOR(__FILE__, __LINE__, __func__); @@ -2353,6 +2907,9 @@ static bool __attribute__((noinline)) sign_transaction_input(dispatcher_context_ // Sign as segwit input iff it has a witness utxo if (!input->has_witnessUtxo) { + LEDGER_ASSERT(keyexpr_info->key_expression_ptr->type == KEY_EXPRESSION_NORMAL, + "Only plain key expressions are valid for legacy inputs"); + // sign legacy P2PKH or P2SH // sign_non_witness(non_witness_utxo.vout[psbt.tx.input_[i].prevout.n].scriptPubKey, i) @@ -2376,12 +2933,7 @@ static bool __attribute__((noinline)) sign_transaction_input(dispatcher_context_ uint8_t sighash[32]; if (!compute_sighash_legacy(dc, st, input, cur_input_index, sighash)) return false; - if (!sign_sighash_ecdsa_and_yield(dc, - st, - placeholder_info, - input, - cur_input_index, - sighash)) + if (!sign_sighash_ecdsa_and_yield(dc, st, keyexpr_info, input, cur_input_index, sighash)) return false; } else { { @@ -2438,17 +2990,25 @@ static bool __attribute__((noinline)) sign_transaction_input(dispatcher_context_ int segwit_version = get_policy_segwit_version(st->wallet_policy_map); uint8_t sighash[32]; if (segwit_version == 0) { + LEDGER_ASSERT(keyexpr_info->key_expression_ptr->type == KEY_EXPRESSION_NORMAL, + "Only plain key expressions are valid for SegwitV0 inputs"); + if (!input->has_sighash_type) { // segwitv0 inputs default to SIGHASH_ALL input->sighash_type = SIGHASH_ALL; } - if (!compute_sighash_segwitv0(dc, st, hashes, input, cur_input_index, sighash)) + if (!compute_sighash_segwitv0(dc, + st, + &signing_state->tx_hashes, + input, + cur_input_index, + sighash)) return false; if (!sign_sighash_ecdsa_and_yield(dc, st, - placeholder_info, + keyexpr_info, input, cur_input_index, sighash)) @@ -2461,15 +3021,15 @@ static bool __attribute__((noinline)) sign_transaction_input(dispatcher_context_ if (!compute_sighash_segwitv1(dc, st, - hashes, + &signing_state->tx_hashes, input, cur_input_index, - placeholder_info, + keyexpr_info, sighash)) return false; policy_node_tr_t *policy = (policy_node_tr_t *) st->wallet_policy_map; - if (!placeholder_info->is_tapscript && !isnull_policy_node_tree(&policy->tree)) { + if (!keyexpr_info->is_tapscript && !isnull_policy_node_tree(&policy->tree)) { // keypath spend, we compute the taptree hash so that we find it ready // later in sign_sighash_schnorr_and_yield (which has less available stack). if (0 > compute_taptree_hash( @@ -2488,14 +3048,26 @@ static bool __attribute__((noinline)) sign_transaction_input(dispatcher_context_ } } - if (!sign_sighash_schnorr_and_yield(dc, - st, - placeholder_info, - input, - cur_input_index, - sighash)) - return false; - + if (keyexpr_info->key_expression_ptr->type == KEY_EXPRESSION_NORMAL) { + if (!sign_sighash_schnorr_and_yield(dc, + st, + keyexpr_info, + input, + cur_input_index, + sighash)) + return false; + } else if (keyexpr_info->key_expression_ptr->type == KEY_EXPRESSION_MUSIG) { + if (!sign_sighash_musig_and_yield(dc, + st, + signing_state, + keyexpr_info, + input, + cur_input_index, + sighash)) + return false; + } else { + LEDGER_ASSERT(false, "Unreachable"); + } } else { SEND_SW(dc, SW_BAD_STATE); // can't happen return false; @@ -2504,12 +3076,11 @@ static bool __attribute__((noinline)) sign_transaction_input(dispatcher_context_ return true; } -static bool __attribute__((noinline)) -fill_taproot_placeholder_info(dispatcher_context_t *dc, - sign_psbt_state_t *st, - const input_info_t *input, - const policy_node_t *tapleaf_ptr, - placeholder_info_t *placeholder_info) { +static bool __attribute__((noinline)) fill_taproot_keyexpr_info(dispatcher_context_t *dc, + sign_psbt_state_t *st, + const input_info_t *input, + const policy_node_t *tapleaf_ptr, + keyexpr_info_t *keyexpr_info) { cx_sha256_t hash_context; crypto_tr_tapleaf_hash_init(&hash_context); @@ -2547,7 +3118,7 @@ fill_taproot_placeholder_info(dispatcher_context_t *dc, &hash_context.header)) { return false; // should never happen! } - crypto_hash_digest(&hash_context.header, placeholder_info->tapleaf_hash, 32); + crypto_hash_digest(&hash_context.header, keyexpr_info->tapleaf_hash, 32); return true; } @@ -2558,54 +3129,53 @@ sign_transaction(dispatcher_context_t *dc, const uint8_t internal_inputs[static BITVECTOR_REAL_SIZE(MAX_N_INPUTS_CAN_SIGN)]) { LOG_PROCESSOR(__FILE__, __LINE__, __func__); - int placeholder_index = 0; + int key_expression_index = 0; - segwit_hashes_t hashes; + signing_state_t signing_state; // compute all the tx-wide hashes // while this is redundant for legacy transactions, we do it here in order to // avoid doing it in places that have more stack limitations - if (!compute_segwit_hashes(dc, st, &hashes)) { - // we do not send a status word, since compute_segwit_hashes already does it on failure + if (!compute_tx_hashes(dc, st, &signing_state.tx_hashes)) { + // we do not send a status word, since compute_tx_hashes already does it on failure return false; } - // Iterate over all the placeholders that correspond to keys owned by us + // Iterate over all the key expressions that correspond to keys owned by us while (true) { - placeholder_info_t placeholder_info; - memset(&placeholder_info, 0, sizeof(placeholder_info)); + keyexpr_info_t keyexpr_info; + memset(&keyexpr_info, 0, sizeof(keyexpr_info)); const policy_node_t *tapleaf_ptr = NULL; - int n_key_placeholders = get_key_placeholder_by_index(st->wallet_policy_map, - placeholder_index, - &tapleaf_ptr, - &placeholder_info.placeholder); + int n_key_expressions = get_keyexpr_by_index(st->wallet_policy_map, + key_expression_index, + &tapleaf_ptr, + &keyexpr_info.key_expression_ptr); - if (n_key_placeholders < 0) { + if (n_key_expressions < 0) { SEND_SW(dc, SW_BAD_STATE); // should never happen return false; } - if (placeholder_index >= n_key_placeholders) { - // all placeholders were processed + if (key_expression_index >= n_key_expressions) { + // all key expressions were processed break; } if (tapleaf_ptr != NULL) { - // get_key_placeholder_by_index returns the pointer to the tapleaf only if the key being + // get_keyexpr_by_index returns the pointer to the tapleaf only if the key being // spent is indeed in a tapleaf - placeholder_info.is_tapscript = true; + keyexpr_info.is_tapscript = true; } - if (fill_placeholder_info_if_internal(dc, st, &placeholder_info) == true) { + if (fill_keyexpr_info_if_internal(dc, st, &keyexpr_info) == true) { for (unsigned int i = 0; i < st->n_inputs; i++) if (bitvector_get(internal_inputs, i)) { input_info_t input; memset(&input, 0, sizeof(input)); - input_keys_callback_data_t callback_data = { - .input = &input, - .placeholder_info = &placeholder_info}; + input_keys_callback_data_t callback_data = {.input = &input, + .keyexpr_info = &keyexpr_info}; int res = call_get_merkleized_map_with_callback( dc, (void *) &callback_data, @@ -2619,14 +3189,12 @@ sign_transaction(dispatcher_context_t *dc, return false; } - if (tapleaf_ptr != NULL && !fill_taproot_placeholder_info(dc, - st, - &input, - tapleaf_ptr, - &placeholder_info)) + if (tapleaf_ptr != NULL && + !fill_taproot_keyexpr_info(dc, st, &input, tapleaf_ptr, &keyexpr_info)) { return false; + } - if (!sign_transaction_input(dc, st, &hashes, &placeholder_info, &input, i)) { + if (!sign_transaction_input(dc, st, &signing_state, &keyexpr_info, &input, i)) { // we do not send a status word, since sign_transaction_input // already does it on failure return false; @@ -2634,9 +3202,13 @@ sign_transaction(dispatcher_context_t *dc, } } - ++placeholder_index; + ++key_expression_index; } + // MuSig2: if there is an active session at the end of round 1, we move it to persistent + // storage. It is important that this is only done at the very end of the signing process. + musigsession_commit(&signing_state.musig); + return true; } @@ -2697,7 +3269,7 @@ void handler_sign_psbt(dispatcher_context_t *dc, uint8_t protocol_version) { /** SIGNING FLOW * - * For each internal placeholder, and for each internal input, sign using the + * For each internal key expression, and for each internal input, sign using the * appropriate algorithm. */ int sign_result = sign_transaction(dc, &st, internal_inputs); diff --git a/src/musig/musig.c b/src/musig/musig.c new file mode 100644 index 000000000..752e54e45 --- /dev/null +++ b/src/musig/musig.c @@ -0,0 +1,599 @@ +#include + +#include "cx_errors.h" + +#include "musig.h" + +#include "../crypto.h" +#include "../secp256k1.h" + +static const uint8_t BIP0327_keyagg_coeff_tag[] = + {'K', 'e', 'y', 'A', 'g', 'g', ' ', 'c', 'o', 'e', 'f', 'f', 'i', 'c', 'i', 'e', 'n', 't'}; +static const uint8_t BIP0327_keyagg_list_tag[] = + {'K', 'e', 'y', 'A', 'g', 'g', ' ', 'l', 'i', 's', 't'}; +static const uint8_t BIP0327_nonce_tag[] = {'M', 'u', 'S', 'i', 'g', '/', 'n', 'o', 'n', 'c', 'e'}; +static const uint8_t BIP0327_noncecoef_tag[] = + {'M', 'u', 'S', 'i', 'g', '/', 'n', 'o', 'n', 'c', 'e', 'c', 'o', 'e', 'f'}; + +static const uint8_t BIP0340_challenge_tag[] = + {'B', 'I', 'P', '0', '3', '4', '0', '/', 'c', 'h', 'a', 'l', 'l', 'e', 'n', 'g', 'e'}; + +static inline bool is_point_infinite(const point_t *P) { + return P->prefix == 0; +} + +static inline void set_point_infinite(point_t *P) { + memset(P->raw, 0, sizeof(point_t)); +} + +#define G ((const point_t *) secp256k1_generator) + +static cx_err_t point_add(const point_t *P1, const point_t *P2, point_t *out) { + if (is_point_infinite(P1)) { + memmove(out->raw, P2->raw, sizeof(point_t)); + return CX_OK; + } + if (is_point_infinite(P2)) { + memmove(out->raw, P1->raw, sizeof(point_t)); + return CX_OK; + } + if (memcmp(P1->x, P2->x, 32) == 0 && memcmp(P1->y, P2->y, 32) != 0) { + memset(out->raw, 0, sizeof(point_t)); + return CX_OK; + } + + cx_err_t res = cx_ecfp_add_point_no_throw(CX_CURVE_SECP256K1, out->raw, P1->raw, P2->raw); + if (res == CX_EC_INFINITE_POINT) { + set_point_infinite(out); + return CX_OK; + } + + return res; +} + +static cx_err_t point_mul(const point_t *P, const uint8_t scalar[static 32], point_t *out) { + if (is_point_infinite(P)) { + set_point_infinite(out); + return CX_OK; + } + point_t Q; // result + memcpy(&Q, P, sizeof(point_t)); + cx_err_t res = cx_ecfp_scalar_mult_no_throw(CX_CURVE_SECP256K1, Q.raw, scalar, 32); + if (res == CX_EC_INFINITE_POINT) { + set_point_infinite(out); + return CX_OK; + } + memcpy(out, &Q, sizeof(point_t)); + return res; +} + +// out can be equal to P +static int point_negate(const point_t *P, point_t *out) { + if (is_point_infinite(P)) { + set_point_infinite(out); + return 0; + } + memmove(out->x, P->x, 32); + + if (CX_OK != cx_math_sub_no_throw(out->y, secp256k1_p, P->y, 32)) return -1; + + out->prefix = 4; + return 0; +} + +static bool has_even_y(const point_t *P) { + LEDGER_ASSERT(!is_point_infinite(P), "has_even_y called with an infinite point"); + + return P->y[31] % 2 == 0; +} + +static int cpoint(const uint8_t x[33], point_t *out) { + crypto_tr_lift_x(&x[1], out->raw); + if (is_point_infinite(out)) { + PRINTF("Invalid compressed point\n"); + return -1; + } + if (x[0] == 2) { + return 0; + } else if (x[0] == 3) { + if (0 > point_negate(out, out)) { + return -1; + } + return 0; + } else { + PRINTF("Invalid compressed point: invalid prefix\n"); + return -1; + } +} + +static bool is_array_zero(const uint8_t buffer[], size_t buffer_len) { + uint8_t acc = 0; + for (size_t i = 0; i < buffer_len; i++) { + acc |= buffer[i]; + } + return acc == 0; +} + +int cpoint_ext(const uint8_t x[static 33], point_t *out) { + // Check if the point is at infinity (all bytes zero) + if (is_array_zero(x, 33)) { + set_point_infinite(out); + return 0; + } + + // Otherwise, handle as a regular compressed point + return cpoint(x, out); +} + +static void musig_get_second_key(const plain_pk_t pubkeys[], size_t n_keys, plain_pk_t out) { + for (size_t i = 0; i < n_keys; i++) { + if (memcmp(pubkeys[0], pubkeys[i], sizeof(plain_pk_t)) != 0) { + memcpy(out, pubkeys[i], sizeof(plain_pk_t)); + return; + } + } + memset(out, 0, sizeof(plain_pk_t)); +} + +static void musig_hash_keys(const plain_pk_t pubkeys[], size_t n_keys, uint8_t out[static 32]) { + cx_sha256_t hash_context; + crypto_tr_tagged_hash_init(&hash_context, + BIP0327_keyagg_list_tag, + sizeof(BIP0327_keyagg_list_tag)); + for (size_t i = 0; i < n_keys; i++) { + crypto_hash_update(&hash_context.header, pubkeys[i], sizeof(plain_pk_t)); + } + crypto_hash_digest(&hash_context.header, out, 32); +} + +static void musig_key_agg_coeff_internal(const plain_pk_t pubkeys[], + size_t n_keys, + const plain_pk_t pk_, + const plain_pk_t pk2, + uint8_t out[static CX_SHA256_SIZE]) { + uint8_t L[CX_SHA256_SIZE]; + musig_hash_keys(pubkeys, n_keys, L); + if (memcmp(pk_, pk2, sizeof(plain_pk_t)) == 0) { + memset(out, 0, CX_SHA256_SIZE); + out[31] = 1; + } else { + crypto_tr_tagged_hash(BIP0327_keyagg_coeff_tag, + sizeof(BIP0327_keyagg_coeff_tag), + L, + sizeof(L), + pk_, + sizeof(plain_pk_t), + out); + + // result modulo secp256k1_n + int res = cx_math_modm_no_throw(out, CX_SHA256_SIZE, secp256k1_n, sizeof(secp256k1_n)); + + LEDGER_ASSERT(res == CX_OK, "Modular reduction failed"); + } +} + +static void musig_key_agg_coeff(const plain_pk_t pubkeys[], + size_t n_keys, + const plain_pk_t pk_, + uint8_t out[static CX_SHA256_SIZE]) { + plain_pk_t pk2; + musig_get_second_key(pubkeys, n_keys, pk2); + + musig_key_agg_coeff_internal(pubkeys, n_keys, pk_, pk2, out); +} + +int musig_key_agg(const plain_pk_t pubkeys[], size_t n_keys, musig_keyagg_context_t *ctx) { + plain_pk_t pk2; + musig_get_second_key(pubkeys, n_keys, pk2); + + set_point_infinite(&ctx->Q); + for (size_t i = 0; i < n_keys; i++) { + point_t P; + + // set P := P_i + if (0 > cpoint(pubkeys[i], &P)) { + PRINTF("Invalid pubkey in musig_key_agg\n"); + return -1; + } + + uint8_t a_i[32]; + musig_key_agg_coeff_internal(pubkeys, n_keys, pubkeys[i], pk2, a_i); + + // set P := a_i * P_i + if (CX_OK != point_mul(&P, a_i, &P)) { + PRINTF("Scalar multiplication failed in musig_key_agg\n"); + return -1; + } + + point_add(&ctx->Q, &P, &ctx->Q); + } + memset(ctx->tacc, 0, sizeof(ctx->tacc)); + memset(ctx->gacc, 0, sizeof(ctx->gacc)); + ctx->gacc[31] = 1; + return 0; +} + +static void musig_nonce_hash(const uint8_t *rand, + const plain_pk_t pk, + const xonly_pk_t aggpk, + uint8_t i, + const uint8_t *msg_prefixed, + size_t msg_prefixed_len, + const uint8_t *extra_in, + size_t extra_in_len, + uint8_t out[static CX_SHA256_SIZE]) { + cx_sha256_t hash_context; + crypto_tr_tagged_hash_init(&hash_context, BIP0327_nonce_tag, sizeof(BIP0327_nonce_tag)); + + // rand + crypto_hash_update(&hash_context.header, rand, 32); + + // len(pk) + pk + crypto_hash_update_u8(&hash_context.header, sizeof(plain_pk_t)); + crypto_hash_update(&hash_context.header, pk, sizeof(plain_pk_t)); + + // len(aggpk) + aggpk + crypto_hash_update_u8(&hash_context.header, sizeof(xonly_pk_t)); + crypto_hash_update(&hash_context.header, aggpk, sizeof(xonly_pk_t)); + + // msg_prefixed + crypto_hash_update(&hash_context.header, msg_prefixed, msg_prefixed_len); + + // len(extra_in) (4 bytes) + extra_in + crypto_hash_update_u32(&hash_context.header, extra_in_len); + if (extra_in_len > 0) { + crypto_hash_update(&hash_context.header, extra_in, extra_in_len); + } + + crypto_hash_update_u8(&hash_context.header, i); + + crypto_hash_digest(&hash_context.header, out, CX_SHA256_SIZE); +} + +// same as nonce_gen_internal from the reference, removing the optional arguments sk, msg and +// extra_in, and making aggpk compulsory +int musig_nonce_gen(const uint8_t rand[32], + const plain_pk_t pk, + const xonly_pk_t aggpk, + musig_secnonce_t *secnonce, + musig_pubnonce_t *pubnonce) { + uint8_t msg[] = {0x00}; + + musig_nonce_hash(rand, pk, aggpk, 0, msg, 1, NULL, 0, secnonce->k_1); + if (CX_OK != cx_math_modm_no_throw(secnonce->k_1, 32, secp256k1_n, 32)) return -1; + musig_nonce_hash(rand, pk, aggpk, 1, msg, 1, NULL, 0, secnonce->k_2); + if (CX_OK != cx_math_modm_no_throw(secnonce->k_2, 32, secp256k1_n, 32)) return -1; + + memcpy(secnonce->pk, pk, 33); + + point_t R_s1, R_s2; + + if (CX_OK != point_mul(G, secnonce->k_1, &R_s1)) return -1; + if (CX_OK != point_mul(G, secnonce->k_2, &R_s2)) return -1; + + if (0 > crypto_get_compressed_pubkey(R_s1.raw, pubnonce->R_s1)) return -1; + if (0 > crypto_get_compressed_pubkey(R_s2.raw, pubnonce->R_s2)) return -1; + + return 0; +} + +int musig_nonce_agg(const musig_pubnonce_t pubnonces[], size_t n_keys, musig_pubnonce_t *out) { + for (size_t j = 1; j <= 2; j++) { + point_t R_j; + set_point_infinite(&R_j); + for (size_t i = 0; i < n_keys; i++) { + point_t R_ij; + if (0 > cpoint(&pubnonces[i].raw[(j - 1) * 33], &R_ij)) { + PRINTF("Musig2 nonce aggregation: invalid contribution from cosigner %d\n", i); + return -i - 1; + } + point_add(&R_j, &R_ij, &R_j); + } + + if (is_point_infinite(&R_j)) { + memset(&out->raw[(j - 1) * 33], 0, 33); + } else { + crypto_get_compressed_pubkey(R_j.raw, &out->raw[(j - 1) * 33]); + } + } + return 0; +} + +static int apply_tweak(musig_keyagg_context_t *ctx, const uint8_t tweak[static 32], bool is_xonly) { + if (tweak == NULL || ctx == NULL) { + return -1; + } + + uint8_t g[32]; + memset(g, 0, 31); + g[31] = 1; // g = 1 + + if (is_xonly && !has_even_y(&ctx->Q)) { + // g = n - 1 + if (CX_OK != cx_math_sub_no_throw(g, secp256k1_n, g, 32)) { + return -1; + }; + } + + int diff; + if (CX_OK != cx_math_cmp_no_throw(tweak, secp256k1_n, 32, &diff)) { + return -1; + } + if (diff >= 0) { + PRINTF("The tweak must be less than n\n"); + return -1; + } + + // compute Q * g (in place) + + if (point_mul(&ctx->Q, g, &ctx->Q) != CX_OK) { + return -1; + } + + point_t T; // compute T = tweak * G + if (point_mul(G, tweak, &T) != CX_OK) { + return -1; + } + + // compute the resulting tweaked point g * Q + tweak * G + point_add(&ctx->Q, &T, &ctx->Q); + if (is_point_infinite(&ctx->Q)) { + PRINTF("The result of tweaking cannot be infinity\n"); + return -1; + } + + // gacc := g * gacc % n + if (CX_OK != cx_math_multm_no_throw(ctx->gacc, g, ctx->gacc, secp256k1_n, 32)) { + return -1; + } + + // tacc := (g * tacc + t) % n + if (CX_OK != cx_math_multm_no_throw(ctx->tacc, g, ctx->tacc, secp256k1_n, 32)) { + return -1; + } + if (CX_OK != cx_math_addm_no_throw(ctx->tacc, ctx->tacc, tweak, secp256k1_n, 32)) { + return -1; + } + + return 0; +} + +static int musig_get_session_values(const musig_session_context_t *session_ctx, + point_t *Q, + uint8_t gacc[static 32], + uint8_t tacc[static 32], + uint8_t b[static 32], + point_t *R, + uint8_t e[static 32]) { + cx_sha256_t hash_context; + + // Perform key aggregation and tweaking + musig_keyagg_context_t keyagg_ctx; + musig_key_agg(session_ctx->pubkeys, session_ctx->n_keys, &keyagg_ctx); + for (size_t i = 0; i < session_ctx->n_tweaks; i++) { + if (0 > apply_tweak(&keyagg_ctx, session_ctx->tweaks[i], session_ctx->is_xonly[i])) { + return -1; + }; + } + + // Copy Q, gacc, tacc from keyagg_ctx + memcpy(Q, &keyagg_ctx.Q, sizeof(point_t)); + memcpy(gacc, keyagg_ctx.gacc, 32); + memcpy(tacc, keyagg_ctx.tacc, 32); + + // Calculate b + crypto_tr_tagged_hash_init(&hash_context, BIP0327_noncecoef_tag, sizeof(BIP0327_noncecoef_tag)); + crypto_hash_update(&hash_context.header, session_ctx->aggnonce->raw, 66); + crypto_hash_update(&hash_context.header, Q->x, 32); + crypto_hash_update(&hash_context.header, session_ctx->msg, session_ctx->msg_len); + crypto_hash_digest(&hash_context.header, b, 32); + + // Calculate R + point_t R_1, R_2; + if (0 > cpoint_ext(session_ctx->aggnonce->R_s1, &R_1)) { + return -1; + }; + if (0 > cpoint_ext(session_ctx->aggnonce->R_s2, &R_2)) { + return -1; + }; + + // R2 := b*R2 + if (point_mul(&R_2, b, &R_2) != CX_OK) { + return -1; + } + + if (CX_OK != point_add(&R_1, &R_2, R)) { + return -1; + }; + if (is_point_infinite(R)) { + memcpy(R->raw, G, sizeof(point_t)); + } + + // Calculate e + crypto_tr_tagged_hash_init(&hash_context, BIP0340_challenge_tag, sizeof(BIP0340_challenge_tag)); + crypto_hash_update(&hash_context.header, R->x, 32); + crypto_hash_update(&hash_context.header, Q->x, 32); + crypto_hash_update(&hash_context.header, session_ctx->msg, session_ctx->msg_len); + crypto_hash_digest(&hash_context.header, e, 32); + return 0; +} + +int musig_get_session_key_agg_coeff(const musig_session_context_t *session_ctx, + const point_t *P, + uint8_t out[static 32]) { + // Convert point to compressed public key + plain_pk_t pk; + crypto_get_compressed_pubkey(P->raw, pk); + + // Check if pk is in pubkeys + bool found = false; + for (size_t i = 0; i < session_ctx->n_keys; i++) { + if (memcmp(pk, session_ctx->pubkeys[i], sizeof(plain_pk_t)) == 0) { + found = true; + break; + } + } + if (!found) { + return -1; // Public key not found in the list of pubkeys + } + + musig_key_agg_coeff(session_ctx->pubkeys, session_ctx->n_keys, pk, out); + return 0; +} + +int musig_sign(musig_secnonce_t *secnonce, + const uint8_t sk[static 32], + const musig_session_context_t *session_ctx, + uint8_t psig[static 32]) { + point_t Q; + uint8_t gacc[32]; + uint8_t tacc[32]; + uint8_t b[32]; + point_t R; + uint8_t e[32]; + + int diff; + + if (0 > musig_get_session_values(session_ctx, &Q, gacc, tacc, b, &R, e)) { + return -1; + } + + uint8_t k_1[32]; + uint8_t k_2[32]; + memcpy(k_1, secnonce->k_1, 32); + memcpy(k_2, secnonce->k_2, 32); + + // paranoia: since reusing nonces is catastrophic, we make sure that they are zeroed out and + // work with a local copy instead + explicit_bzero(secnonce->k_1, sizeof(secnonce->k_1)); + explicit_bzero(secnonce->k_2, sizeof(secnonce->k_2)); + + if (CX_OK != cx_math_cmp_no_throw(k_1, secp256k1_n, 32, &diff)) { + return -1; + } + if (is_array_zero(k_1, sizeof(k_1)) || diff >= 0) { + PRINTF("first secnonce value is out of range\n"); + return -1; + } + if (CX_OK != cx_math_cmp_no_throw(k_2, secp256k1_n, 32, &diff)) { + return -1; + } + if (is_array_zero(k_2, sizeof(k_2)) || diff >= 0) { + PRINTF("second secnonce value is out of range\n"); + return -1; + } + + if (!has_even_y(&R)) { + if (CX_OK != cx_math_sub_no_throw(k_1, secp256k1_n, k_1, 32)) { + return -1; + }; + if (CX_OK != cx_math_sub_no_throw(k_2, secp256k1_n, k_2, 32)) { + return -1; + }; + } + + if (CX_OK != cx_math_cmp_no_throw(sk, secp256k1_n, 32, &diff)) { + return -1; + } + if (is_array_zero(sk, 32) || diff >= 0) { + PRINTF("secret key value is out of range\n"); + return -1; + } + + bool err = false; + + // Put together all the variables that we want to always zero out before returning. + // As an excess of safety, we put here any variable that is (directly or indirectly) derived + // from the secret during the computation of the signature + struct { + uint8_t d[32]; + point_t P; + uint8_t ead[32]; + uint8_t s[32]; + } secrets; + + do { // executed only once, to allow for an easy way to break out of the block + // P = d_ * G + if (point_mul(G, sk, &secrets.P) != CX_OK) { + err = true; + break; + } + + plain_pk_t pk; + crypto_get_compressed_pubkey(secrets.P.raw, pk); + + if (memcmp(pk, secnonce->pk, 33) != 0) { + err = true; + PRINTF("Public key does not match nonce_gen argument\n"); + break; + } + + uint8_t a[32]; + if (0 > musig_get_session_key_agg_coeff(session_ctx, &secrets.P, a)) { + err = true; + break; + } + + // g = 1 if has_even_y(Q) else n - 1 + uint8_t g[32]; + memset(g, 0, 31); + g[31] = 1; // g = 1 + if (!has_even_y(&Q)) { + // g = n - 1 + if (CX_OK != cx_math_sub_no_throw(g, secp256k1_n, g, 32)) { + err = true; + break; + }; + } + + // d_ in the reference implementation is just sk + // d = g * gacc % n + if (CX_OK != cx_math_multm_no_throw(secrets.d, g, gacc, secp256k1_n, 32)) { + err = true; + break; + } + // d = g * gacc * d_ % n + if (CX_OK != cx_math_multm_no_throw(secrets.d, secrets.d, sk, secp256k1_n, 32)) { + err = true; + break; + } + + uint8_t bk_2[32]; // b * k_2 + if (CX_OK != cx_math_multm_no_throw(bk_2, b, k_2, secp256k1_n, 32)) { + err = true; + break; + } + + // e * a * d + if (CX_OK != cx_math_multm_no_throw(secrets.ead, e, a, secp256k1_n, 32)) { + err = true; + break; + } + if (CX_OK != cx_math_multm_no_throw(secrets.ead, secrets.ead, secrets.d, secp256k1_n, 32)) { + err = true; + break; + } + + // s = k_1 + b * k_2 + e * a * d + memcpy(secrets.s, k_1, 32); + if (CX_OK != cx_math_addm_no_throw(secrets.s, secrets.s, bk_2, secp256k1_n, 32)) { + err = true; + break; + } + if (CX_OK != cx_math_addm_no_throw(secrets.s, secrets.s, secrets.ead, secp256k1_n, 32)) { + err = true; + break; + } + + memcpy(psig, secrets.s, 32); + } while (false); + + // make sure to zero out any variable derived from secrets before returning + explicit_bzero(&secrets, sizeof(secrets)); + + if (err) { + return -1; + } + + return 0; +} diff --git a/src/musig/musig.h b/src/musig/musig.h new file mode 100644 index 000000000..1632bd085 --- /dev/null +++ b/src/musig/musig.h @@ -0,0 +1,129 @@ +#pragma once + +#include +#include + +#define MUSIG_PUBNONCE_SIZE 66 + +// TODO: rename once BIP number is assigned +static uint8_t BIP_MUSIG_CHAINCODE[32] = { + 0x86, 0x80, 0x87, 0xCA, 0x02, 0xA6, 0xF9, 0x74, 0xC4, 0x59, 0x89, 0x24, 0xC3, 0x6B, 0x57, 0x76, + 0x2D, 0x32, 0xCB, 0x45, 0x71, 0x71, 0x67, 0xE3, 0x00, 0x62, 0x2C, 0x71, 0x67, 0xE3, 0x89, 0x65}; + +typedef uint8_t plain_pk_t[33]; +typedef uint8_t xonly_pk_t[32]; + +// An uncompressed pubkey, encoded as 04||x||y, where x and y are 32-byte big-endian coordinates. +// If the first byte (prefix) is 0, encodes the point at infinity. +typedef struct { + union { + uint8_t raw[65]; + struct { + uint8_t prefix; // 0 for the point at infinity, otherwise 4. + uint8_t x[32]; + uint8_t y[32]; + }; + }; +} point_t; + +typedef struct musig_keyagg_context_s { + point_t Q; + uint8_t gacc[32]; + uint8_t tacc[32]; +} musig_keyagg_context_t; + +typedef struct musig_secnonce_s { + uint8_t k_1[32]; + uint8_t k_2[32]; + uint8_t pk[33]; +} musig_secnonce_t; + +typedef struct musig_pubnonce_s { + union { + struct { + uint8_t R_s1[33]; + uint8_t R_s2[33]; + }; + uint8_t raw[66]; + }; +} musig_pubnonce_t; + +typedef struct musig_session_context_s { + musig_pubnonce_t *aggnonce; + size_t n_keys; + plain_pk_t *pubkeys; + size_t n_tweaks; + uint8_t **tweaks; + bool *is_xonly; + uint8_t *msg; + size_t msg_len; +} musig_session_context_t; + +/** + * Computes the KeyAgg Context per BIP-0327. + * + * @param[in] pubkeys + * Pointer to a list of pubkeys. + * @param[in] n_keys + * Number of pubkeys. + * @param[out] musig_keyagg_context_t + * Pointer to receive the musig KeyAgg Context. + * + * @return 0 on success, a negative number in case of error. + */ +int musig_key_agg(const plain_pk_t pubkeys[], size_t n_keys, musig_keyagg_context_t *ctx); + +/** + * Generates secret and public nonces (round 1 of MuSig per BIP-0327). + * + * @param[in] rand + * The randomness to use. + * @param[in] pk + * The 33-byte public key of the signer. + * @param[in] aggpk + * The 32-byte x-only aggregate public key. + * @param[out] secnonce + * Pointer to receive the secret nonce. + * @param[out] pubnonce + * Pointer to receive the public nonce. + * + * @return 0 on success, a negative number in case of error. + */ +int musig_nonce_gen(const uint8_t rand[32], + const plain_pk_t pk, + const xonly_pk_t aggpk, + musig_secnonce_t *secnonce, + musig_pubnonce_t *pubnonce); + +/** + * Generates the aggregate nonce (nonce_agg in the reference implementation). + * + * @param[in] rand + * A list of musig_pubnonce_t, the pubnonces of all the participants. + * @param[in] n_keys + * Number of pubkeys. + * @param[out] out + * Pointer to receive the aggregate nonce. + * + * @return 0 on success, a negative number in case of error. On error, `-i - 1` is returned if the + * nonce provided by the cosigner with index `i` is invalid, in order to allow blaming for a + * disruptive signer. + */ +int musig_nonce_agg(const musig_pubnonce_t pubnonces[], size_t n_keys, musig_pubnonce_t *out); + +/** + * Computes the partial signature (round 2 of MuSig per BIP-0327). + * + * @param[in] secnonce + * The secret nonce. + * @param[in] session_ctx + * The session context. + * @param[out] psig + * Pointer to receive the partial signature. + * + * @return 0 on success, a negative number in case of error. + */ +int musig_sign(musig_secnonce_t *secnonce, + const uint8_t *sk, + const musig_session_context_t *session_ctx, + uint8_t psig[static 32]); diff --git a/src/musig/musig_sessions.c b/src/musig/musig_sessions.c new file mode 100644 index 000000000..174e111b3 --- /dev/null +++ b/src/musig/musig_sessions.c @@ -0,0 +1,132 @@ +#include + +#include "cx.h" + +#include "musig_sessions.h" +#include "../crypto.h" + +typedef struct { + // Aligning by 4 is necessary due to platform limitations. + // Aligning by 64 further guarantees that each session occupies exactly + // a single NVRAM page, minimizing the number of writes. + __attribute__((aligned(64))) musig_psbt_session_t sessions[MAX_N_MUSIG_SESSIONS]; +} musig_persistent_storage_t; + +const musig_persistent_storage_t N_musig_storage_real; +#define N_musig_storage (*(const volatile musig_persistent_storage_t *) PIC(&N_musig_storage_real)) + +static bool is_all_zeros(const uint8_t *array, size_t size) { + for (size_t i = 0; i < size; ++i) { + if (array[i] != 0) { + return false; + } + } + return true; +} + +static bool musigsession_pop(const uint8_t psbt_session_id[static 32], musig_psbt_session_t *out) { + for (int i = 0; i < MAX_N_MUSIG_SESSIONS; i++) { + if (memcmp(psbt_session_id, (const void *) N_musig_storage.sessions[i]._id, 32) == 0) { + if (out != NULL) { + memcpy(out, + (const void *) &N_musig_storage.sessions[i], + sizeof(musig_psbt_session_t)); + } + uint8_t zeros[sizeof(musig_psbt_session_t)] = {0}; + nvm_write((void *) &N_musig_storage.sessions[i], + (void *) zeros, + sizeof(musig_psbt_session_t)); + + return true; + } + } + return false; +} + +static void musigsession_init_randomness(musig_psbt_session_t *session) { + // it is extremely important that the randomness is initialized with a cryptographically strong + // random number generator + cx_get_random_bytes(session->_rand_root, 32); +} + +static void musigsession_store(const uint8_t psbt_session_id[static 32], + const musig_psbt_session_t *session) { + // make sure that no session with the same id exists; delete it otherwise + musigsession_pop(psbt_session_id, NULL); + + int i; + for (i = 0; i < MAX_N_MUSIG_SESSIONS; i++) { + if (is_all_zeros((uint8_t *) &N_musig_storage.sessions[i], sizeof(musig_psbt_session_t))) { + break; + } + } + if (i >= MAX_N_MUSIG_SESSIONS) { + // no free slot found, delete the first by default + // TODO: should we use a LIFO structure? Could add a counter to musig_psbt_session_t + i = 0; + } + // replace the chosen slot + nvm_write((void *) &N_musig_storage.sessions[i], + (void *) session, + sizeof(musig_psbt_session_t)); +} + +void compute_rand_i_j(const musig_psbt_session_t *psbt_session, + int i, + int j, + uint8_t out[static 32]) { + cx_sha256_t hash_context; + cx_sha256_init(&hash_context); + crypto_hash_update(&hash_context.header, psbt_session->_rand_root, CX_SHA256_SIZE); + crypto_hash_update_u32(&hash_context.header, (uint32_t) i); + crypto_hash_update_u32(&hash_context.header, (uint32_t) j); + crypto_hash_digest(&hash_context.header, out, 32); +} + +const musig_psbt_session_t *musigsession_round1_initialize( + uint8_t psbt_session_id[static 32], + musig_signing_state_t *musig_signing_state) { + // if an existing session for psbt_session_id exists, delete it + if (musigsession_pop(psbt_session_id, NULL)) { + // We wouldn't expect this: probably the client sent the same psbt for + // round 1 twice, without adding the pubnonces to the psbt after the first round. + // We delete the old session and start a fresh one, but we print a + // warning if in debug mode. + PRINTF("Session with the same id already existing\n"); + } + + if (memcmp(musig_signing_state->_round1._id, psbt_session_id, 32) != 0) { + // first input/placeholder pair using this session: initialize the session + memcpy(musig_signing_state->_round1._id, psbt_session_id, 32); + musigsession_init_randomness(&musig_signing_state->_round1); + } + + return &musig_signing_state->_round1; +} + +const musig_psbt_session_t *musigsession_round2_initialize( + uint8_t psbt_session_id[static 32], + musig_signing_state_t *musig_signing_state) { + if (memcmp(musig_signing_state->_round2._id, psbt_session_id, 32) != 0) { + // get and delete the musig session from permanent storage + if (!musigsession_pop(psbt_session_id, &musig_signing_state->_round2)) { + // The PSBT contains a partial nonce, but we do not have the corresponding psbt + // session in storage. Either it was deleted, or the pubnonces were not real. Either + // way, we cannot continue. + PRINTF("Missing MuSig2 session\n"); + return NULL; + } + } + + return &musig_signing_state->_round2; +} + +void musigsession_commit(musig_signing_state_t *musig_signing_state) { + uint8_t acc = 0; + for (size_t i = 0; i < sizeof(musig_signing_state->_round1); i++) { + acc |= musig_signing_state->_round1._id[i]; + } + if (acc != 0) { + musigsession_store(musig_signing_state->_round1._id, &musig_signing_state->_round1); + } +} diff --git a/src/musig/musig_sessions.h b/src/musig/musig_sessions.h new file mode 100644 index 000000000..98bad1c0b --- /dev/null +++ b/src/musig/musig_sessions.h @@ -0,0 +1,88 @@ +#pragma once + +#include +#include "musig.h" + +/** + * This module encapsulates the logic to manage the psbt-level MuSig2 sessions. See the + * documentation in docs/musig.md for more information. + */ + +// the maximum number of musig sessions that are stored in permanent memory +#define MAX_N_MUSIG_SESSIONS 8 + +// state of a musig_psbt_session. Members are private and must not be accessed directly by any +// code outside of musig_sessions.c. +typedef struct musig_psbt_session_s { + uint8_t _id[32]; + uint8_t _rand_root[32]; +} musig_psbt_session_t; + +// volatile state for musig signing. Members are private and must not be accessed directly by any +// code outside of musig_sessions.c. +typedef struct musig_signing_state_s { + // a session created during round 1; if signing completes (and in no other case), it is moved to + // the persistent storage + musig_psbt_session_t _round1; + // a session that was removed from the persistent storage before any partial signature is + // returned during round 2. It is deleted at the end of signing, and must _never_ be used again. + musig_psbt_session_t _round2; +} musig_signing_state_t; + +/** + * Given a musig psbt session, computes the synthetic randomness for a given + * (input_index, placeholder_index) pair. + */ +void compute_rand_i_j(const musig_psbt_session_t *psbt_session, + int input_index, + int placeholder_index, + uint8_t out[static 32]); + +/** + * Handles the creation of a new musig psbt session into the volatile memory, or its retrieval (if + * the session already exists). + * It must be called when starting MuSig2 round 1 for a fixed input/placeholder pair, during the + * signing process. + * + * @param[in] psbt_session_id + * Pointer to the musig psbt session id. + * @param[in] musig_signing_state + * Pointer to the musig signing state. + * + * @return a musig_psbt_session_t on success, NULL on failure. + */ +__attribute__((warn_unused_result)) const musig_psbt_session_t *musigsession_round1_initialize( + uint8_t psbt_session_id[static 32], + musig_signing_state_t *musig_signing_state); + +/** + * Handles the retrieval of a musig psbt session from volatile memory (if it exists already) or its + * retrieval from the persistent memory otherwise. The session is guaranteed to be deleted from the + * persistent memory prior to returning. + * It must be called when starting MuSig2 round 2 for a fixed input/placeholder pair, during the + * signing process. + * + * @param[in] psbt_session_id + * Pointer to the musig psbt session id. + * @param[in] musig_signing_state + * Pointer to the musig signing state. + * + * @return a musig_psbt_session_t on success, NULL on failure. + */ +__attribute__((warn_unused_result)) const musig_psbt_session_t *musigsession_round2_initialize( + uint8_t psbt_session_id[static 32], + musig_signing_state_t *musig_signing_state); + +/** + * If a session produced in round 1 is active in volatile memory, it is stored in the persistent + * memory. + * This must be called at the end of a successful signing flow, after all the public nonces have + * been returned to the client. It must _not_ be called if any error occurs, or if the signing + * process is aborted for any reason. + * + * @param[in] psbt_session_id + * Pointer to the musig psbt session id. + * @param[in] musig_signing_state + * Pointer to the musig signing state. + */ +void musigsession_commit(musig_signing_state_t *musig_signing_state); \ No newline at end of file diff --git a/src/secp256k1.c b/src/secp256k1.c new file mode 100644 index 000000000..2ddb714a6 --- /dev/null +++ b/src/secp256k1.c @@ -0,0 +1,23 @@ +#include "secp256k1.h" + +// clang-format off +const uint8_t secp256k1_generator[65] = { + 0x04, + 0x79, 0xBE, 0x66, 0x7E, 0xF9, 0xDC, 0xBB, 0xAC, 0x55, 0xA0, 0x62, 0x95, 0xCE, 0x87, 0x0B, 0x07, + 0x02, 0x9B, 0xFC, 0xDB, 0x2D, 0xCE, 0x28, 0xD9, 0x59, 0xF2, 0x81, 0x5B, 0x16, 0xF8, 0x17, 0x98, + 0x48, 0x3A, 0xDA, 0x77, 0x26, 0xA3, 0xC4, 0x65, 0x5D, 0xA4, 0xFB, 0xFC, 0x0E, 0x11, 0x08, 0xA8, + 0xFD, 0x17, 0xB4, 0x48, 0xA6, 0x85, 0x54, 0x19, 0x9C, 0x47, 0xD0, 0x8F, 0xFB, 0x10, 0xD4, 0xB8}; + +const uint8_t secp256k1_p[32] = { + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, 0xff, 0xff, 0xfc, 0x2f}; + +const uint8_t secp256k1_n[32] = { + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, + 0xba, 0xae, 0xdc, 0xe6, 0xaf, 0x48, 0xa0, 0x3b, 0xbf, 0xd2, 0x5e, 0x8c, 0xd0, 0x36, 0x41, 0x41}; + +const uint8_t secp256k1_sqr_exponent[32] = { + 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xbf, 0xff, 0xff, 0x0c}; + +// clang-format on diff --git a/src/secp256k1.h b/src/secp256k1.h new file mode 100644 index 000000000..c6ead0f33 --- /dev/null +++ b/src/secp256k1.h @@ -0,0 +1,24 @@ +#pragma once + +#include + +/** + * Generator for secp256k1, value 'g' defined in "Standards for Efficient Cryptography" + * (SEC2) 2.7.1. + */ +extern const uint8_t secp256k1_generator[65]; + +/** + * Modulo for secp256k1 + */ +extern const uint8_t secp256k1_p[32]; + +/** + * Curve order for secp256k1 + */ +extern const uint8_t secp256k1_n[32]; + +/** + * (p + 1)/4, used to calculate square roots in secp256k1 + */ +extern const uint8_t secp256k1_sqr_exponent[32]; diff --git a/test_utils/bip0327.py b/test_utils/bip0327.py new file mode 100644 index 000000000..79149743f --- /dev/null +++ b/test_utils/bip0327.py @@ -0,0 +1,465 @@ +# from https://github.com/bitcoin/bips/blob/b3701faef2bdb98a0d7ace4eedbeefa2da4c89ed/bip-0327/reference.py +# Distributed as part of BIP-0327 under the BSD-3-Clause license + +# BIP327 reference implementation +# +# WARNING: This implementation is for demonstration purposes only and _not_ to +# be used in production environments. The code is vulnerable to timing attacks, +# for example. + +# fmt: off + +from typing import List, Optional, Tuple, NewType, NamedTuple +import hashlib +import secrets + +# +# The following helper functions were copied from the BIP-340 reference implementation: +# https://github.com/bitcoin/bips/blob/master/bip-0340/reference.py +# + +p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F +n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 + +# Points are tuples of X and Y coordinates and the point at infinity is +# represented by the None keyword. +G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8) + +Point = Tuple[int, int] + +# This implementation can be sped up by storing the midstate after hashing +# tag_hash instead of rehashing it all the time. +def tagged_hash(tag: str, msg: bytes) -> bytes: + tag_hash = hashlib.sha256(tag.encode()).digest() + return hashlib.sha256(tag_hash + tag_hash + msg).digest() + +def is_infinite(P: Optional[Point]) -> bool: + return P is None + +def x(P: Point) -> int: + assert not is_infinite(P) + return P[0] + +def y(P: Point) -> int: + assert not is_infinite(P) + return P[1] + +def point_add(P1: Optional[Point], P2: Optional[Point]) -> Optional[Point]: + if P1 is None: + return P2 + if P2 is None: + return P1 + if (x(P1) == x(P2)) and (y(P1) != y(P2)): + return None + if P1 == P2: + lam = (3 * x(P1) * x(P1) * pow(2 * y(P1), p - 2, p)) % p + else: + lam = ((y(P2) - y(P1)) * pow(x(P2) - x(P1), p - 2, p)) % p + x3 = (lam * lam - x(P1) - x(P2)) % p + return (x3, (lam * (x(P1) - x3) - y(P1)) % p) + +def point_mul(P: Optional[Point], n: int) -> Optional[Point]: + R = None + for i in range(256): + if (n >> i) & 1: + R = point_add(R, P) + P = point_add(P, P) + return R + +def bytes_from_int(x: int) -> bytes: + return x.to_bytes(32, byteorder="big") + +def lift_x(b: bytes) -> Optional[Point]: + x = int_from_bytes(b) + if x >= p: + return None + y_sq = (pow(x, 3, p) + 7) % p + y = pow(y_sq, (p + 1) // 4, p) + if pow(y, 2, p) != y_sq: + return None + return (x, y if y & 1 == 0 else p-y) + +def int_from_bytes(b: bytes) -> int: + return int.from_bytes(b, byteorder="big") + +def has_even_y(P: Point) -> bool: + assert not is_infinite(P) + return y(P) % 2 == 0 + +def schnorr_verify(msg: bytes, pubkey: bytes, sig: bytes) -> bool: + if len(msg) != 32: + raise ValueError('The message must be a 32-byte array.') + if len(pubkey) != 32: + raise ValueError('The public key must be a 32-byte array.') + if len(sig) != 64: + raise ValueError('The signature must be a 64-byte array.') + P = lift_x(pubkey) + r = int_from_bytes(sig[0:32]) + s = int_from_bytes(sig[32:64]) + if (P is None) or (r >= p) or (s >= n): + return False + e = int_from_bytes(tagged_hash("BIP0340/challenge", sig[0:32] + pubkey + msg)) % n + R = point_add(point_mul(G, s), point_mul(P, n - e)) + if (R is None) or (not has_even_y(R)) or (x(R) != r): + return False + return True + +# +# End of helper functions copied from BIP-340 reference implementation. +# + +PlainPk = NewType('PlainPk', bytes) +XonlyPk = NewType('XonlyPk', bytes) + +# There are two types of exceptions that can be raised by this implementation: +# - ValueError for indicating that an input doesn't conform to some function +# precondition (e.g. an input array is the wrong length, a serialized +# representation doesn't have the correct format). +# - InvalidContributionError for indicating that a signer (or the +# aggregator) is misbehaving in the protocol. +# +# Assertions are used to (1) satisfy the type-checking system, and (2) check for +# inconvenient events that can't happen except with negligible probability (e.g. +# output of a hash function is 0) and can't be manually triggered by any +# signer. + +# This exception is raised if a party (signer or nonce aggregator) sends invalid +# values. Actual implementations should not crash when receiving invalid +# contributions. Instead, they should hold the offending party accountable. +class InvalidContributionError(Exception): + def __init__(self, signer, contrib): + self.signer = signer + # contrib is one of "pubkey", "pubnonce", "aggnonce", or "psig". + self.contrib = contrib + +infinity = None + +def xbytes(P: Point) -> bytes: + return bytes_from_int(x(P)) + +def cbytes(P: Point) -> bytes: + a = b'\x02' if has_even_y(P) else b'\x03' + return a + xbytes(P) + +def cbytes_ext(P: Optional[Point]) -> bytes: + if is_infinite(P): + return (0).to_bytes(33, byteorder='big') + assert P is not None + return cbytes(P) + +def point_negate(P: Optional[Point]) -> Optional[Point]: + if P is None: + return P + return (x(P), p - y(P)) + +def cpoint(x: bytes) -> Point: + if len(x) != 33: + raise ValueError('x is not a valid compressed point.') + P = lift_x(x[1:33]) + if P is None: + raise ValueError('x is not a valid compressed point.') + if x[0] == 2: + return P + elif x[0] == 3: + P = point_negate(P) + assert P is not None + return P + else: + raise ValueError('x is not a valid compressed point.') + +def cpoint_ext(x: bytes) -> Optional[Point]: + if x == (0).to_bytes(33, 'big'): + return None + else: + return cpoint(x) + +# Return the plain public key corresponding to a given secret key +def individual_pk(seckey: bytes) -> PlainPk: + d0 = int_from_bytes(seckey) + if not (1 <= d0 <= n - 1): + raise ValueError('The secret key must be an integer in the range 1..n-1.') + P = point_mul(G, d0) + assert P is not None + return PlainPk(cbytes(P)) + +def key_sort(pubkeys: List[PlainPk]) -> List[PlainPk]: + pubkeys.sort() + return pubkeys + +KeyAggContext = NamedTuple('KeyAggContext', [('Q', Point), + ('gacc', int), + ('tacc', int)]) + +def get_xonly_pk(keyagg_ctx: KeyAggContext) -> XonlyPk: + Q, _, _ = keyagg_ctx + return XonlyPk(xbytes(Q)) + +def key_agg(pubkeys: List[PlainPk]) -> KeyAggContext: + pk2 = get_second_key(pubkeys) + u = len(pubkeys) + Q = infinity + for i in range(u): + try: + P_i = cpoint(pubkeys[i]) + except ValueError: + raise InvalidContributionError(i, "pubkey") + a_i = key_agg_coeff_internal(pubkeys, pubkeys[i], pk2) + Q = point_add(Q, point_mul(P_i, a_i)) + # Q is not the point at infinity except with negligible probability. + assert(Q is not None) + gacc = 1 + tacc = 0 + return KeyAggContext(Q, gacc, tacc) + +def hash_keys(pubkeys: List[PlainPk]) -> bytes: + return tagged_hash('KeyAgg list', b''.join(pubkeys)) + +def get_second_key(pubkeys: List[PlainPk]) -> PlainPk: + u = len(pubkeys) + for j in range(1, u): + if pubkeys[j] != pubkeys[0]: + return pubkeys[j] + return PlainPk(b'\x00'*33) + +def key_agg_coeff(pubkeys: List[PlainPk], pk_: PlainPk) -> int: + pk2 = get_second_key(pubkeys) + return key_agg_coeff_internal(pubkeys, pk_, pk2) + +def key_agg_coeff_internal(pubkeys: List[PlainPk], pk_: PlainPk, pk2: PlainPk) -> int: + L = hash_keys(pubkeys) + if pk_ == pk2: + return 1 + return int_from_bytes(tagged_hash('KeyAgg coefficient', L + pk_)) % n + +def apply_tweak(keyagg_ctx: KeyAggContext, tweak: bytes, is_xonly: bool) -> KeyAggContext: + if len(tweak) != 32: + raise ValueError('The tweak must be a 32-byte array.') + Q, gacc, tacc = keyagg_ctx + if is_xonly and not has_even_y(Q): + g = n - 1 + else: + g = 1 + t = int_from_bytes(tweak) + if t >= n: + raise ValueError('The tweak must be less than n.') + Q_ = point_add(point_mul(Q, g), point_mul(G, t)) + if Q_ is None: + raise ValueError('The result of tweaking cannot be infinity.') + gacc_ = g * gacc % n + tacc_ = (t + g * tacc) % n + return KeyAggContext(Q_, gacc_, tacc_) + +def bytes_xor(a: bytes, b: bytes) -> bytes: + return bytes(x ^ y for x, y in zip(a, b)) + +def nonce_hash(rand: bytes, pk: PlainPk, aggpk: XonlyPk, i: int, msg_prefixed: bytes, extra_in: bytes) -> int: + buf = b'' + buf += rand + buf += len(pk).to_bytes(1, 'big') + buf += pk + buf += len(aggpk).to_bytes(1, 'big') + buf += aggpk + buf += msg_prefixed + buf += len(extra_in).to_bytes(4, 'big') + buf += extra_in + buf += i.to_bytes(1, 'big') + return int_from_bytes(tagged_hash('MuSig/nonce', buf)) + +def nonce_gen_internal(rand_: bytes, sk: Optional[bytes], pk: PlainPk, aggpk: Optional[XonlyPk], msg: Optional[bytes], extra_in: Optional[bytes]) -> Tuple[bytearray, bytes]: + if sk is not None: + rand = bytes_xor(sk, tagged_hash('MuSig/aux', rand_)) + else: + rand = rand_ + if aggpk is None: + aggpk = XonlyPk(b'') + if msg is None: + msg_prefixed = b'\x00' + else: + msg_prefixed = b'\x01' + msg_prefixed += len(msg).to_bytes(8, 'big') + msg_prefixed += msg + if extra_in is None: + extra_in = b'' + k_1 = nonce_hash(rand, pk, aggpk, 0, msg_prefixed, extra_in) % n + k_2 = nonce_hash(rand, pk, aggpk, 1, msg_prefixed, extra_in) % n + # k_1 == 0 or k_2 == 0 cannot occur except with negligible probability. + assert k_1 != 0 + assert k_2 != 0 + R_s1 = point_mul(G, k_1) + R_s2 = point_mul(G, k_2) + assert R_s1 is not None + assert R_s2 is not None + pubnonce = cbytes(R_s1) + cbytes(R_s2) + secnonce = bytearray(bytes_from_int(k_1) + bytes_from_int(k_2) + pk) + return secnonce, pubnonce + +def nonce_gen(sk: Optional[bytes], pk: PlainPk, aggpk: Optional[XonlyPk], msg: Optional[bytes], extra_in: Optional[bytes]) -> Tuple[bytearray, bytes]: + if sk is not None and len(sk) != 32: + raise ValueError('The optional byte array sk must have length 32.') + if aggpk is not None and len(aggpk) != 32: + raise ValueError('The optional byte array aggpk must have length 32.') + rand_ = secrets.token_bytes(32) + return nonce_gen_internal(rand_, sk, pk, aggpk, msg, extra_in) + +def nonce_agg(pubnonces: List[bytes]) -> bytes: + u = len(pubnonces) + aggnonce = b'' + for j in (1, 2): + R_j = infinity + for i in range(u): + try: + R_ij = cpoint(pubnonces[i][(j-1)*33:j*33]) + except ValueError: + raise InvalidContributionError(i, "pubnonce") + R_j = point_add(R_j, R_ij) + aggnonce += cbytes_ext(R_j) + return aggnonce + +SessionContext = NamedTuple('SessionContext', [('aggnonce', bytes), + ('pubkeys', List[PlainPk]), + ('tweaks', List[bytes]), + ('is_xonly', List[bool]), + ('msg', bytes)]) + +def key_agg_and_tweak(pubkeys: List[PlainPk], tweaks: List[bytes], is_xonly: List[bool]): + if len(tweaks) != len(is_xonly): + raise ValueError('The `tweaks` and `is_xonly` arrays must have the same length.') + keyagg_ctx = key_agg(pubkeys) + v = len(tweaks) + for i in range(v): + keyagg_ctx = apply_tweak(keyagg_ctx, tweaks[i], is_xonly[i]) + return keyagg_ctx + +def get_session_values(session_ctx: SessionContext) -> Tuple[Point, int, int, int, Point, int]: + (aggnonce, pubkeys, tweaks, is_xonly, msg) = session_ctx + Q, gacc, tacc = key_agg_and_tweak(pubkeys, tweaks, is_xonly) + b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + xbytes(Q) + msg)) % n + try: + R_1 = cpoint_ext(aggnonce[0:33]) + R_2 = cpoint_ext(aggnonce[33:66]) + except ValueError: + # Nonce aggregator sent invalid nonces + raise InvalidContributionError(None, "aggnonce") + R_ = point_add(R_1, point_mul(R_2, b)) + R = R_ if not is_infinite(R_) else G + assert R is not None + e = int_from_bytes(tagged_hash('BIP0340/challenge', xbytes(R) + xbytes(Q) + msg)) % n + return (Q, gacc, tacc, b, R, e) + +def get_session_key_agg_coeff(session_ctx: SessionContext, P: Point) -> int: + (_, pubkeys, _, _, _) = session_ctx + pk = PlainPk(cbytes(P)) + if pk not in pubkeys: + raise ValueError('The signer\'s pubkey must be included in the list of pubkeys.') + return key_agg_coeff(pubkeys, pk) + +def sign(secnonce: bytearray, sk: bytes, session_ctx: SessionContext) -> bytes: + (Q, gacc, _, b, R, e) = get_session_values(session_ctx) + k_1_ = int_from_bytes(secnonce[0:32]) + k_2_ = int_from_bytes(secnonce[32:64]) + # Overwrite the secnonce argument with zeros such that subsequent calls of + # sign with the same secnonce raise a ValueError. + secnonce[:64] = bytearray(b'\x00'*64) + if not 0 < k_1_ < n: + raise ValueError('first secnonce value is out of range.') + if not 0 < k_2_ < n: + raise ValueError('second secnonce value is out of range.') + k_1 = k_1_ if has_even_y(R) else n - k_1_ + k_2 = k_2_ if has_even_y(R) else n - k_2_ + d_ = int_from_bytes(sk) + if not 0 < d_ < n: + raise ValueError('secret key value is out of range.') + P = point_mul(G, d_) + assert P is not None + pk = cbytes(P) + if not pk == secnonce[64:97]: + raise ValueError('Public key does not match nonce_gen argument') + a = get_session_key_agg_coeff(session_ctx, P) + g = 1 if has_even_y(Q) else n - 1 + d = g * gacc * d_ % n + s = (k_1 + b * k_2 + e * a * d) % n + psig = bytes_from_int(s) + R_s1 = point_mul(G, k_1_) + R_s2 = point_mul(G, k_2_) + assert R_s1 is not None + assert R_s2 is not None + pubnonce = cbytes(R_s1) + cbytes(R_s2) + # Optional correctness check. The result of signing should pass signature verification. + assert partial_sig_verify_internal(psig, pubnonce, pk, session_ctx) + return psig + +def det_nonce_hash(sk_: bytes, aggothernonce: bytes, aggpk: bytes, msg: bytes, i: int) -> int: + buf = b'' + buf += sk_ + buf += aggothernonce + buf += aggpk + buf += len(msg).to_bytes(8, 'big') + buf += msg + buf += i.to_bytes(1, 'big') + return int_from_bytes(tagged_hash('MuSig/deterministic/nonce', buf)) + +def deterministic_sign(sk: bytes, aggothernonce: bytes, pubkeys: List[PlainPk], tweaks: List[bytes], is_xonly: List[bool], msg: bytes, rand: Optional[bytes]) -> Tuple[bytes, bytes]: + if rand is not None: + sk_ = bytes_xor(sk, tagged_hash('MuSig/aux', rand)) + else: + sk_ = sk + aggpk = get_xonly_pk(key_agg_and_tweak(pubkeys, tweaks, is_xonly)) + + k_1 = det_nonce_hash(sk_, aggothernonce, aggpk, msg, 0) % n + k_2 = det_nonce_hash(sk_, aggothernonce, aggpk, msg, 1) % n + # k_1 == 0 or k_2 == 0 cannot occur except with negligible probability. + assert k_1 != 0 + assert k_2 != 0 + + R_s1 = point_mul(G, k_1) + R_s2 = point_mul(G, k_2) + assert R_s1 is not None + assert R_s2 is not None + pubnonce = cbytes(R_s1) + cbytes(R_s2) + secnonce = bytearray(bytes_from_int(k_1) + bytes_from_int(k_2) + individual_pk(sk)) + try: + aggnonce = nonce_agg([pubnonce, aggothernonce]) + except Exception: + raise InvalidContributionError(None, "aggothernonce") + session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg) + psig = sign(secnonce, sk, session_ctx) + return (pubnonce, psig) + +def partial_sig_verify(psig: bytes, pubnonces: List[bytes], pubkeys: List[PlainPk], tweaks: List[bytes], is_xonly: List[bool], msg: bytes, i: int) -> bool: + if len(pubnonces) != len(pubkeys): + raise ValueError('The `pubnonces` and `pubkeys` arrays must have the same length.') + if len(tweaks) != len(is_xonly): + raise ValueError('The `tweaks` and `is_xonly` arrays must have the same length.') + aggnonce = nonce_agg(pubnonces) + session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg) + return partial_sig_verify_internal(psig, pubnonces[i], pubkeys[i], session_ctx) + +def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, pk: bytes, session_ctx: SessionContext) -> bool: + (Q, gacc, _, b, R, e) = get_session_values(session_ctx) + s = int_from_bytes(psig) + if s >= n: + return False + R_s1 = cpoint(pubnonce[0:33]) + R_s2 = cpoint(pubnonce[33:66]) + Re_s_ = point_add(R_s1, point_mul(R_s2, b)) + Re_s = Re_s_ if has_even_y(R) else point_negate(Re_s_) + P = cpoint(pk) + if P is None: + return False + a = get_session_key_agg_coeff(session_ctx, P) + g = 1 if has_even_y(Q) else n - 1 + g_ = g * gacc % n + return point_mul(G, s) == point_add(Re_s, point_mul(P, e * a * g_ % n)) + +def partial_sig_agg(psigs: List[bytes], session_ctx: SessionContext) -> bytes: + (Q, _, tacc, _, R, e) = get_session_values(session_ctx) + s = 0 + u = len(psigs) + for i in range(u): + s_i = int_from_bytes(psigs[i]) + if s_i >= n: + raise InvalidContributionError(i, "psig") + s = (s + s_i) % n + g = 1 if has_even_y(Q) else n - 1 + s = (s + e * g * tacc) % n + return xbytes(R) + bytes_from_int(s) diff --git a/test_utils/musig2.py b/test_utils/musig2.py new file mode 100644 index 000000000..0c9130ce8 --- /dev/null +++ b/test_utils/musig2.py @@ -0,0 +1,862 @@ +""" +This module contains a complete, minimal, standalone MuSig cosigner implementation. +It is NOT a cryptographically secure implementation, and it is only to be used for +testing purposes. + +In lack of a library for wallet policies in python, a minimal version of it for +the purpose of parsing and processing tr() descriptors is implemented here, using +embit for the the final task of compiling simple miniscript descriptors to Script. + +The main objects and methods exported in this class are: + +- PsbtMusig2Cosigner: an abstract class that represents a cosigner in MuSig2. +- HotMusig2Cosigner: an implementation of PsbtMusig2Cosigner that contains a hot + extended private key. Useful for tests. +- run_musig2_test: tests a full signing cycle for a generic list of PsbtMusig2Cosigners. +""" + + +import hashlib +import hmac +from io import BytesIO +import re +from re import Match + +from dataclasses import dataclass +import secrets +import struct +from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union +from abc import ABC, abstractmethod + +import base58 + +from test_utils.taproot_sighash import SIGHASH_DEFAULT, TaprootSignatureHash + +from . import bip0327, bip0340, hash160, sha256 +from . import taproot + +from bitcoin_client.ledger_bitcoin.embit.descriptor.miniscript import Miniscript +from bitcoin_client.ledger_bitcoin.psbt import PSBT, PartiallySignedInput +from bitcoin_client.ledger_bitcoin.key import G, ExtendedKey, bytes_to_point, point_add, point_mul, point_to_bytes +from bitcoin_client.ledger_bitcoin.wallet import WalletPolicy + + +HARDENED_INDEX = 0x80000000 + + +def tapleaf_hash(script: Optional[bytes], leaf_version=b'\xC0') -> Optional[bytes]: + if script is None: + return None + return taproot.tagged_hash( + "TapLeaf", + leaf_version + taproot.ser_script(script) + ) + + +@dataclass +class PlainKeyPlaceholder: + key_index: int + num1: int + num2: int + + +@dataclass +class Musig2KeyPlaceholder: + key_indexes: List[int] + num1: int + num2: int + + +KeyPlaceholder = Union[PlainKeyPlaceholder, Musig2KeyPlaceholder] + + +def parse_placeholder(placeholder_str: str) -> KeyPlaceholder: + """Parses a placeholder string to create a KeyPlaceholder object.""" + if placeholder_str.startswith('musig'): + key_indexes_str = placeholder_str[6:placeholder_str.index( + ')/<')].split(',') + key_indexes = [int(index.strip('@')) for index in key_indexes_str] + + nums_part = placeholder_str[placeholder_str.index(')/<') + 3:-3] + num1, num2 = map(int, nums_part.split(';')) + + return Musig2KeyPlaceholder(key_indexes, num1, num2) + elif placeholder_str.startswith('@'): + parts = placeholder_str.split('/') + key_index = int(parts[0].strip('@')) + + # Remove '<' from the start and '>' from the end + nums_part = parts[1][1:-1] + num1, num2 = map(int, nums_part.split(';')) + + return PlainKeyPlaceholder(key_index, num1, num2) + else: + raise ValueError("Invalid placeholder string") + + +def extract_placeholders(desc_tmpl: str) -> List[KeyPlaceholder]: + """Extracts and parses all placeholders in a descriptor template, from left to right.""" + + pattern = r'musig\((?:@\d+,)*(?:@\d+)\)/<\d+;\d+>/\*|@\d+/<\d+;\d+>/\*' + matches = [(match.group(), match.start()) + for match in re.finditer(pattern, desc_tmpl)] + sorted_matches = sorted(matches, key=lambda x: x[1]) + return [parse_placeholder(match[0]) for match in sorted_matches] + + +def unsorted_musig(pubkeys: Iterable[bytes], version_bytes: bytes) -> Tuple[str, bip0327.KeyAggContext]: + """ + Constructs the musig2 aggregated extended public key from an unsorted list of + compressed public keys, and the version bytes. + """ + + assert all(len(pk) == 33 for pk in pubkeys) + assert len(version_bytes) == 4 + + depth = b'\x00' + fingerprint = b'\x00\x00\x00\x00' + child_number = b'\x00\x00\x00\x00' + + key_agg_ctx = bip0327.key_agg(pubkeys) + Q = key_agg_ctx.Q + compressed_pubkey = ( + b'\x02' if Q[1] % 2 == 0 else b'\x03') + bip0327.get_xonly_pk(key_agg_ctx) + chaincode = bytes.fromhex( + "868087ca02a6f974c4598924c36b57762d32cb45717167e300622c7167e38965") + ext_pubkey = version_bytes + depth + fingerprint + \ + child_number + chaincode + compressed_pubkey + return base58.b58encode_check(ext_pubkey).decode(), key_agg_ctx + + +def musig(pubkeys: Iterable[bytes], version_bytes: bytes) -> Tuple[str, bip0327.KeyAggContext]: + """ + Constructs the musig2 aggregated extended public key from a list of compressed public keys, + and the version bytes. The keys are sorted, as required by the `the musig()` key expression + in descriptors. + """ + return unsorted_musig(sorted(pubkeys), version_bytes) + + +def aggregate_musig_pubkey(keys_info: Iterable[str]) -> Tuple[str, bip0327.KeyAggContext]: + """ + Constructs the musig2 aggregated extended public key from the list of keys info + of the participating keys. + """ + + pubkeys: list[bytes] = [] + versions: Set[str] = set() + for ki in keys_info: + start = ki.find(']') + xpub = ki[start + 1:] + xpub_bytes = base58.b58decode_check(xpub) + versions.add(xpub_bytes[:4]) + pubkeys.append(xpub_bytes[-33:]) + + if len(versions) > 1: + raise ValueError( + "All the extended public keys should be from the same network") + + return musig(pubkeys, versions.pop()) + + +def derive_from_key_info(key_info: str, steps: List[int]) -> str: + start = key_info.find(']') + pk = ExtendedKey.deserialize(key_info[start + 1:]) + return pk.derive_pub_path(steps).to_string() + + +def derive_plain_descriptor(desc_tmpl: str, keys_info: List[str], is_change: bool, address_index: int): + """ + Given a wallet policy, and the change/address_index combination, computes the corresponding descriptor. + It replaces /** with /<0;1>/* + It also replaces each musig() key expression with the corresponding xpub. + The resulting descriptor can be used with descriptor libraries that do not support musig or wallet policies. + """ + + desc_tmpl = desc_tmpl.replace("/**", "/<0;1>/*") + desc_tmpl = desc_tmpl.replace("*", str(address_index)) + + # Replace each with M if is_change is False, otherwise with N + def replace_m_n(match: Match[str]): + m, n = match.groups() + return m if not is_change else n + + desc_tmpl = re.sub(r'<([^;]+);([^>]+)>', replace_m_n, desc_tmpl) + + # Replace musig(...) expressions + def replace_musig(match: Match[str]): + musig_content = match.group(1) + steps = [int(x) for x in match.group(2).split("/")] + + assert len(steps) == 2 + + key_indexes = [int(i.strip('@')) for i in musig_content.split(',')] + key_infos = [keys_info[i] for i in key_indexes] + agg_xpub = aggregate_musig_pubkey(key_infos)[0] + + return derive_from_key_info(agg_xpub, steps) + + desc_tmpl = re.sub(r'musig\(([^)]+)\)/(\d+/\d+)', replace_musig, desc_tmpl) + + # Replace @i/a/b with the i-th element in keys_info, deriving the key appropriately + # to get a plain xpub + def replace_key_index(match): + index, step1, step2 = [int(x) for x in match.group(1).split('/')] + return derive_from_key_info(keys_info[index], [step1, step2]) + + desc_tmpl = re.sub(r'@(\d+/\d+/\d+)', replace_key_index, desc_tmpl) + + return desc_tmpl + + +class Tree: + """ + Recursive structure that represents a taptree, or one of its subtrees. + It can either contain a single descriptor template (if it's a tapleaf), or exactly two child Trees. + """ + + def __init__(self, content: Union[str, Tuple['Tree', 'Tree']]): + if isinstance(content, str): + self.script = content + self.left, self.right = (None, None) + else: + self.script = None + self.left, self.right = content + + @property + def is_leaf(self) -> bool: + return self.script is not None + + def __str__(self): + if self.is_leaf: + return self.script + else: + return f'{{{str(self.left)},{str(self.right)}}}' + + def placeholders(self) -> Iterator[Tuple[KeyPlaceholder, str]]: + """ + Generates an iterator over the placeholders contained in the scripts of the tree's leaf nodes. + + Yields: + Iterator[Tuple[KeyPlaceholder, str]]: An iterator over tuples containing a KeyPlaceholder and its associated script. + """ + + if self.is_leaf: + assert self.script is not None + for placeholder in extract_placeholders(self.script): + yield (placeholder, self.script) + else: + assert self.left is not None and self.right is not None + for placeholder, script in self.left.placeholders(): + yield (placeholder, script) + for placeholder, script in self.right.placeholders(): + yield (placeholder, script) + + def get_taptree_hash(self, keys_info: List[str], is_change: bool, address_index: int) -> bytes: + if self.is_leaf: + assert self.script is not None + leaf_desc = derive_plain_descriptor( + self.script, keys_info, is_change, address_index) + + s = BytesIO(leaf_desc.encode()) + desc: Miniscript = Miniscript.read_from( + s, taproot=True) + + return tapleaf_hash(desc.compile()) + + else: + assert self.left is not None and self.right is not None + left_h = self.left.get_taptree_hash( + keys_info, is_change, address_index) + right_h = self.left.get_taptree_hash( + keys_info, is_change, address_index) + if left_h <= right_h: + return taproot.tagged_hash("TapBranch", left_h + right_h) + else: + return taproot.tagged_hash("TapBranch", right_h + left_h) + + +class TrDescriptorTemplate: + """ + Represents a descriptor template for a tr(KEY) or a tr(KEY,TREE). + This is minimal implementation in order to enable iterating over the placeholders, + and compile the corresponding leaf scripts. + """ + + def __init__(self, key: KeyPlaceholder, tree=Optional[Tree]): + self.key: KeyPlaceholder = key + self.tree: Optional[Tree] = tree + + @classmethod + def from_string(cls, input_string): + parser = cls.Parser(input_string.replace("/**", "/<0;1>/*")) + return parser.parse() + + class Parser: + def __init__(self, input): + self.input = input + self.index = 0 + self.length = len(input) + + def parse(self): + if self.input.startswith('tr('): + self.consume('tr(') + key = self.parse_keyplaceholder() + tree = None + if self.peek() == ',': + self.consume(',') + tree = self.parse_tree() + self.consume(')') + return TrDescriptorTemplate(key, tree) + else: + raise Exception( + "Syntax error: Input does not start with 'tr('") + + def parse_keyplaceholder(self): + if self.peek() == '@': + self.consume('@') + key_index = self.parse_num() + self.consume('/<') + num1 = self.parse_num() + self.consume(';') + num2 = self.parse_num() + self.consume('>/*') + return PlainKeyPlaceholder(key_index, num1, num2) + elif self.input[self.index:self.index+6] == 'musig(': + self.consume('musig(') + key_indexes = self.parse_key_indexes() + self.consume(')/<') + num1 = self.parse_num() + self.consume(';') + num2 = self.parse_num() + self.consume('>/*') + return Musig2KeyPlaceholder(key_indexes, num1, num2) + else: + raise Exception("Syntax error in key placeholder") + + def parse_tree(self) -> Tree: + if self.peek() == '{': + self.consume('{') + tree1 = self.parse_tree() + self.consume(',') + tree2 = self.parse_tree() + self.consume('}') + return Tree((tree1, tree2)) + else: + return Tree(self.parse_script()) + + def parse_script(self) -> str: + start = self.index + nesting = 0 + while self.index < self.length and (nesting > 0 or self.input[self.index] not in ('}', ',', ')')): + if self.input[self.index] == '(': + nesting = nesting + 1 + elif self.input[self.index] == ')': + nesting = nesting - 1 + + self.index += 1 + return self.input[start:self.index] + + def parse_key_indexes(self): + nums = [] + self.consume('@') + nums.append(self.parse_num()) + while self.peek() == ',': + self.consume(',@') + nums.append(self.parse_num()) + return nums + + def parse_num(self): + start = self.index + while self.index < self.length and self.input[self.index].isdigit(): + self.index += 1 + return int(self.input[start:self.index]) + + def consume(self, char): + if self.input[self.index:self.index+len(char)] == char: + self.index += len(char) + else: + raise Exception( + f"Syntax error: Expected '{char}'; rest: {self.input[self.index:]}") + + def peek(self): + return self.input[self.index] if self.index < self.length else None + + def placeholders(self) -> Iterator[Tuple[KeyPlaceholder, Optional[str]]]: + """ + Generates an iterator over the placeholders contained in the template and its tree, also + yielding the corresponding leaf script descriptor (or None for the keypath placeholder). + + Yields: + Iterator[Tuple[KeyPlaceholder, Optional[str]]]: An iterator over tuples containing a KeyPlaceholder and an optional associated script. + """ + + yield (self.key, None) + + if self.tree is not None: + for placeholder, script in self.tree.placeholders(): + yield (placeholder, script) + + def get_taptree_hash(self, is_change: bool, address_index: int) -> bytes: + if self.tree is None: + raise ValueError("There is no taptree") + return self.tree.get_taptree_hash(is_change, address_index) + + +class PsbtMusig2Cosigner(ABC): + @abstractmethod + def get_participant_pubkey(self) -> bip0327.Point: + """ + This method should returns this cosigner's public key. + """ + pass + + @abstractmethod + def generate_public_nonces(self, psbt: PSBT) -> None: + """ + This method should generate public nonces and modify the given Psbt object in-place. + It should raise an exception in case of failure. + """ + pass + + @abstractmethod + def generate_partial_signatures(self, psbt: PSBT) -> None: + """ + Receives a PSBT that contains all the participants' public nonces, and adds this participant's partial signature. + It should raise an exception in case of failure. + """ + pass + + +def find_change_and_addr_index_for_musig(input_psbt: PartiallySignedInput, placeholder: Musig2KeyPlaceholder, agg_xpub: ExtendedKey): + num1, num2 = placeholder.num1, placeholder.num2 + + agg_xpub_fingerprint = hash160(agg_xpub.pubkey)[0:4] + + # Iterate through tap key origins in the input + # TODO: this might be made more precise (e.g. use the leaf_hash from the tap_bip32_paths items) + for xonly, (_, key_origin) in input_psbt.tap_bip32_paths.items(): + der_path = key_origin.path + # Check if the fingerprint matches the expected pattern and the derivation path has the correct structure + if key_origin.fingerprint == agg_xpub_fingerprint and len(der_path) == 2 and der_path[0] < HARDENED_INDEX and der_path[1] < HARDENED_INDEX and (der_path[0] == num1 or der_path[0] == num2): + if xonly != agg_xpub.derive_pub_path(der_path).pubkey[1:]: + continue + + # Determine if the address is a change address and extract the address index + is_change = (der_path[0] == num2) + addr_index = int(der_path[1]) + return is_change, addr_index + + return None + + +def get_bip32_tweaks(ext_key: ExtendedKey, steps: List[int]) -> List[bytes]: + """ + Generate BIP32 tweaks for a series of derivation steps on an extended key. + + Args: + ext_key (ExtendedKey): The extended public key. + steps (List[int]): A list of derivation steps (must be unhardened). + + Returns: + List[bytes]: The list of additive tweaks for those derivation steps. + """ + + result = [] + + cur_pubkey = ext_key.pubkey + cur_chaincode = ext_key.chaincode + + for step in steps: + if step < 0 or step >= HARDENED_INDEX: + raise ValueError("Invalid unhardened derivation step") + + data = cur_pubkey + struct.pack(">L", step) + Ihmac = hmac.new(cur_chaincode, data, hashlib.sha512).digest() + Il = Ihmac[:32] + Ir = Ihmac[32:] + + result.append(Il) + + Il_int = int.from_bytes(Il, 'big') + child_pubkey_point = point_add(point_mul(G, Il_int), + bytes_to_point(cur_pubkey)) + child_pubkey = point_to_bytes(child_pubkey_point) + + cur_pubkey = child_pubkey + cur_chaincode = Ir + + return result + + +def process_placeholder( + wallet_policy: WalletPolicy, + psbt_input: PartiallySignedInput, + placeholder: Musig2KeyPlaceholder, + keyagg_ctx: bip0327.KeyAggContext, + agg_xpub: ExtendedKey, + tapleaf_desc: Optional[str], + desc_tmpl: TrDescriptorTemplate +) -> Optional[Tuple[List[bytes], List[bool], Optional[bytes], bytes]]: + """ + This method encapsulates all the precomputations that are done for a certain + wallet policy, psbt input and musig() placeholder that are common to both the + nonce generation and the partial signature generation flows. + + Returs a tuple containing: + - tweaks: a list of tweaks to be applied to the aggregate musig key + - is_xonly_tweak: a list of boolean of the same length of tweaks, specifying for + each of them if it's a plain tweak or an x-only tweak + - leaf_script: the compiled leaf script, or None for a taproot keypath spend + - aggpk_tweaked: the value of the aggregate pubkey after applying the tweaks + """ + res = find_change_and_addr_index_for_musig( + psbt_input, placeholder, agg_xpub) + if res is None: + return None + is_change, address_index = res + + leaf_script = None + if tapleaf_desc is not None: + leaf_desc = derive_plain_descriptor( + tapleaf_desc, wallet_policy.keys_info, is_change, address_index) + s = BytesIO(leaf_desc.encode()) + desc: Miniscript = Miniscript.read_from(s, taproot=True) + leaf_script = desc.compile() + + tweaks = [] + is_xonly_tweak = [] + + # Compute bip32 tweaks + bip32_steps = [ + placeholder.num2 if is_change else placeholder.num1, + address_index + ] + bip32_tweaks = get_bip32_tweaks(agg_xpub, bip32_steps) + for tweak in bip32_tweaks: + tweaks.append(tweak) + is_xonly_tweak.append(False) + + # aggregate key after the bip_32 derivations (but before the taptweak, if any) + der_key = agg_xpub.derive_pub_path(bip32_steps) + + # x-only tweak, only if spending the keypath + if tapleaf_desc is None: + t = der_key.pubkey[-32:] + if desc_tmpl.tree is not None: + t += desc_tmpl.get_taptree_hash(is_change, address_index) + tweaks.append(taproot.tagged_hash("TapTweak", t)) + is_xonly_tweak.append(True) + + keyagg_ctx = aggregate_musig_pubkey( + wallet_policy.keys_info[i] for i in placeholder.key_indexes)[1] + + for tweak, is_xonly in zip(tweaks, is_xonly_tweak): + keyagg_ctx = bip0327.apply_tweak(keyagg_ctx, tweak, is_xonly) + + aggpk_tweaked = bip0327.cbytes(keyagg_ctx.Q) + + return (tweaks, is_xonly_tweak, leaf_script, aggpk_tweaked) + + +class HotMusig2Cosigner(PsbtMusig2Cosigner): + """ + Implements a PsbtMusig2Cosigner for a given wallet policy and a private + that appears as one of the key in a musig() key expression. + """ + + def __init__(self, wallet_policy: WalletPolicy, privkey: str) -> None: + super().__init__() + + self.wallet_policy = wallet_policy + self.privkey = ExtendedKey.deserialize(privkey) + + assert self.privkey.to_string() == privkey + + self.musig_psbt_sessions: Dict[bytes, bytes] = {} + + assert self.privkey.is_private + + def compute_psbt_session_id(self, psbt: PSBT) -> bytes: + psbt.tx.rehash() + return sha256(psbt.tx.hash + self.wallet_policy.id) + + def get_participant_pubkey(self) -> bip0327.Point: + return bip0327.cpoint(self.privkey.pubkey) + + def generate_public_nonces(self, psbt: PSBT) -> None: + desc_tmpl = TrDescriptorTemplate.from_string( + self.wallet_policy.descriptor_template) + + psbt_session_id = self.compute_psbt_session_id(psbt) + + # root of all pseudorandomness for this psbt session + rand_seed = secrets.token_bytes(32) + + for placeholder_index, (placeholder, tapleaf_desc) in enumerate(desc_tmpl.placeholders()): + if not isinstance(placeholder, Musig2KeyPlaceholder): + continue + + agg_xpub_str, keyagg_ctx = aggregate_musig_pubkey( + self.wallet_policy.keys_info[i] for i in placeholder.key_indexes) + agg_xpub = ExtendedKey.deserialize(agg_xpub_str) + + for input_index, input in enumerate(psbt.inputs): + result = process_placeholder( + self.wallet_policy, input, placeholder, keyagg_ctx, agg_xpub, tapleaf_desc, desc_tmpl) + if result is None: + continue + + (_, _, leaf_script, aggpk_tweaked) = result + + rand_i_j = sha256( + rand_seed + + input_index.to_bytes(4, byteorder='big') + + placeholder_index.to_bytes(4, byteorder='big') + ) + + # secnonce: bytearray + # pubnonce: bytes + _, pubnonce = bip0327.nonce_gen_internal( + rand_=rand_i_j, + sk=None, + pk=self.privkey.pubkey, + aggpk=aggpk_tweaked, + msg=None, + extra_in=None + ) + + pubnonce_identifier = ( + self.privkey.pubkey, + aggpk_tweaked, + tapleaf_hash(leaf_script) + ) + + assert len(aggpk_tweaked) == 33 + + input.musig2_pub_nonces[pubnonce_identifier] = pubnonce + + self.musig_psbt_sessions[psbt_session_id] = rand_seed + + def generate_partial_signatures(self, psbt: PSBT) -> None: + desc_tmpl = TrDescriptorTemplate.from_string( + self.wallet_policy.descriptor_template) + + psbt_session_id = self.compute_psbt_session_id(psbt) + + # Get the session's randomness seed, while simultaneously deleting it from the open sessions + rand_seed = self.musig_psbt_sessions.pop(psbt_session_id, None) + + if rand_seed is None: + raise ValueError( + "No musig signing session for this psbt") + + for placeholder_index, (placeholder, tapleaf_desc) in enumerate(desc_tmpl.placeholders()): + if not isinstance(placeholder, Musig2KeyPlaceholder): + continue + + agg_xpub_str, keyagg_ctx = aggregate_musig_pubkey( + self.wallet_policy.keys_info[i] for i in placeholder.key_indexes) + agg_xpub = ExtendedKey.deserialize(agg_xpub_str) + + for input_index, input in enumerate(psbt.inputs): + result = process_placeholder( + self.wallet_policy, input, placeholder, keyagg_ctx, agg_xpub, tapleaf_desc, desc_tmpl) + if result is None: + continue + + (tweaks, is_xonly_tweak, leaf_script, aggpk_tweaked) = result + + rand_i_j = sha256( + rand_seed + + input_index.to_bytes(4, byteorder='big') + + placeholder_index.to_bytes(4, byteorder='big') + ) + + secnonce, pubnonce = bip0327.nonce_gen_internal( + rand_=rand_i_j, + sk=None, + pk=self.privkey.pubkey, + aggpk=aggpk_tweaked, + msg=None, + extra_in=None + ) + + pubkeys_in_musig: List[ExtendedKey] = [] + my_key_index_in_musig: Optional[int] = None + for i in placeholder.key_indexes: + k_i = self.wallet_policy.keys_info[i] + xpub_i = k_i[k_i.find(']') + 1:] + pubkeys_in_musig.append(ExtendedKey.deserialize(xpub_i)) + + if xpub_i == self.privkey.neutered().to_string(): + my_key_index_in_musig = i + + if my_key_index_in_musig is None: + raise ValueError("No internal key found in musig") + + # sort the keys in ascending order + pubkeys_in_musig = list( + sorted(pubkeys_in_musig, key=lambda x: x.pubkey)) + + nonces: List[bytes] = [] + for participant_key in pubkeys_in_musig: + participant_pubnonce_identifier = ( + participant_key.pubkey, + aggpk_tweaked, + tapleaf_hash(leaf_script) + ) + + if participant_key.pubkey == self.privkey.pubkey and input.musig2_pub_nonces[participant_pubnonce_identifier] != pubnonce: + raise ValueError( + f"Public nonce in psbt didn't match the expected one for cosigner {self.privkey.pubkey}") + + assert len(aggpk_tweaked) == 33 + + if participant_pubnonce_identifier in input.musig2_pub_nonces: + nonces.append( + input.musig2_pub_nonces[participant_pubnonce_identifier]) + else: + raise ValueError( + f"Missing pubnonce for pubkey {participant_key.pubkey.hex()} in psbt") + + if leaf_script is None: + sighash = TaprootSignatureHash( + txTo=psbt.tx, + spent_utxos=[ + psbt.inputs[i].witness_utxo for i in range(len(psbt.inputs))], + hash_type=input.sighash or SIGHASH_DEFAULT, + input_index=input_index, + ) + else: + sighash = TaprootSignatureHash( + txTo=psbt.tx, + spent_utxos=[ + psbt.inputs[i].witness_utxo for i in range(len(psbt.inputs))], + hash_type=input.sighash or SIGHASH_DEFAULT, + input_index=input_index, + scriptpath=True, + script=leaf_script + ) + + aggnonce = bip0327.nonce_agg(nonces) + + session_ctx = bip0327.SessionContext( + aggnonce=aggnonce, + pubkeys=[pk.pubkey for pk in pubkeys_in_musig], + tweaks=tweaks, + is_xonly=is_xonly_tweak, + msg=sighash) + + partial_sig = bip0327.sign( + secnonce, self.privkey.privkey, session_ctx) + + pubnonce_identifier = ( + self.privkey.pubkey, + aggpk_tweaked, + tapleaf_hash(leaf_script) + ) + + input.musig2_partial_sigs[pubnonce_identifier] = partial_sig + + +def run_musig2_test(wallet_policy: WalletPolicy, psbt: PSBT, cosigners: List[PsbtMusig2Cosigner], sighashes: list[bytes]): + """ + This performs the following steps: + - go through all the cosigners to let them add their pubnonce; + - go through all the cosigners to let them add their partial signature; + - aggregate the partial signatures to produce the final Schnorr signature; + - verify that the produced signature is valid for the provided sighash. + + The sighashes (one per input) are given as argument and are assument to be correct. + """ + + if len(psbt.inputs) != len(sighashes): + raise ValueError("The sighashes") + + for signer in cosigners: + signer.generate_public_nonces(psbt) + + for signer in cosigners: + signer.generate_partial_signatures(psbt) + + desc_tmpl = TrDescriptorTemplate.from_string( + wallet_policy.descriptor_template) + + for placeholder, tapleaf_desc in desc_tmpl.placeholders(): + if not isinstance(placeholder, Musig2KeyPlaceholder): + continue + + agg_xpub_str, keyagg_ctx = aggregate_musig_pubkey( + wallet_policy.keys_info[i] for i in placeholder.key_indexes) + agg_xpub = ExtendedKey.deserialize(agg_xpub_str) + + for input_index, input in enumerate(psbt.inputs): + result = process_placeholder( + wallet_policy, input, placeholder, keyagg_ctx, agg_xpub, tapleaf_desc, desc_tmpl) + + if result is None: + raise RuntimeError( + "Unexpected: processing the musig placeholder failed") + + (tweaks, is_xonly_tweak, leaf_script, aggpk_tweaked) = result + + assert len(aggpk_tweaked) == 33 + + pubkeys_in_musig: List[ExtendedKey] = [] + for i in placeholder.key_indexes: + k_i = wallet_policy.keys_info[i] + xpub_i = k_i[k_i.find(']') + 1:] + pubkeys_in_musig.append(ExtendedKey.deserialize(xpub_i)) + + # sort the keys in ascending order + pubkeys_in_musig = list( + sorted(pubkeys_in_musig, key=lambda x: x.pubkey)) + + nonces: List[bytes] = [] + for participant_key in pubkeys_in_musig: + pubnonce_identifier = ( + participant_key.pubkey, + aggpk_tweaked, + tapleaf_hash(leaf_script) + ) + + if pubnonce_identifier in input.musig2_pub_nonces: + nonces.append( + input.musig2_pub_nonces[pubnonce_identifier]) + else: + raise ValueError( + f"Missing pubnonce for pubkey {participant_key.pubkey.hex()} in psbt") + + aggnonce = bip0327.nonce_agg(nonces) + + sighash = sighashes[input_index] + + session_ctx = bip0327.SessionContext( + aggnonce=aggnonce, + pubkeys=[pk.pubkey for pk in pubkeys_in_musig], + tweaks=tweaks, + is_xonly=is_xonly_tweak, + msg=sighash) + + # collect partial signatures + psigs: List[bytes] = [] + + for participant_key in pubkeys_in_musig: + pubnonce_identifier = ( + participant_key.pubkey, + bytes(aggpk_tweaked), + tapleaf_hash(leaf_script) + ) + + if pubnonce_identifier in input.musig2_partial_sigs: + psigs.append( + input.musig2_partial_sigs[pubnonce_identifier]) + else: + raise ValueError( + f"Missing partial signature for pubkey {participant_key.pubkey.hex()} in psbt") + + sig = bip0327.partial_sig_agg(psigs, session_ctx) + + aggpk_tweaked_xonly = aggpk_tweaked[1:] + assert (bip0340.schnorr_verify(sighash, aggpk_tweaked_xonly, sig)) diff --git a/test_utils/taproot.py b/test_utils/taproot.py index 0ba25a2ae..3f84ab56e 100644 --- a/test_utils/taproot.py +++ b/test_utils/taproot.py @@ -1,14 +1,32 @@ -# from portions of BIP-0341 +# from BIP-0340 and BIP-0341 +# - https://github.com/bitcoin/bips/blob/b3701faef2bdb98a0d7ace4eedbeefa2da4c89ed/bip-0340.mediawiki # - https://github.com/bitcoin/bips/blob/b3701faef2bdb98a0d7ace4eedbeefa2da4c89ed/bip-0341.mediawiki # Distributed under the BSD-3-Clause license # fmt: off +# Set DEBUG to True to get a detailed debug output including +# intermediate values during key generation, signing, and +# verification. This is implemented via calls to the +# debug_print_vars() function. +# # If you want to print values on an individual basis, use # the pretty() function, e.g., print(pretty(foo)). import hashlib -import struct +from typing import Any, Optional, Tuple + + +DEBUG = False + +p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F +n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 +SECP256K1_ORDER = n + +# Points are tuples of X and Y coordinates and the point at infinity is +# represented by the None keyword. +G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8) +Point = Tuple[int, int] # This implementation can be sped up by storing the midstate after hashing # tag_hash instead of rehashing it all the time. @@ -16,6 +34,140 @@ def tagged_hash(tag: str, msg: bytes) -> bytes: tag_hash = hashlib.sha256(tag.encode()).digest() return hashlib.sha256(tag_hash + tag_hash + msg).digest() +def is_infinite(P: Optional[Point]) -> bool: + return P is None + +def x(P: Point) -> int: + assert not is_infinite(P) + return P[0] + +def y(P: Point) -> int: + assert not is_infinite(P) + return P[1] + +def point_add(P1: Optional[Point], P2: Optional[Point]) -> Optional[Point]: + if P1 is None: + return P2 + if P2 is None: + return P1 + if (x(P1) == x(P2)) and (y(P1) != y(P2)): + return None + if P1 == P2: + lam = (3 * x(P1) * x(P1) * pow(2 * y(P1), p - 2, p)) % p + else: + lam = ((y(P2) - y(P1)) * pow(x(P2) - x(P1), p - 2, p)) % p + x3 = (lam * lam - x(P1) - x(P2)) % p + return (x3, (lam * (x(P1) - x3) - y(P1)) % p) + +def point_mul(P: Optional[Point], n: int) -> Optional[Point]: + R = None + for i in range(256): + if (n >> i) & 1: + R = point_add(R, P) + P = point_add(P, P) + return R + +def bytes_from_int(x: int) -> bytes: + return x.to_bytes(32, byteorder="big") + +def bytes_from_point(P: Point) -> bytes: + return bytes_from_int(x(P)) + +def xor_bytes(b0: bytes, b1: bytes) -> bytes: + return bytes(x ^ y for (x, y) in zip(b0, b1)) + +def lift_x(x: int) -> Optional[Point]: + if x >= p: + return None + y_sq = (pow(x, 3, p) + 7) % p + y = pow(y_sq, (p + 1) // 4, p) + if pow(y, 2, p) != y_sq: + return None + return (x, y if y & 1 == 0 else p-y) + +def int_from_bytes(b: bytes) -> int: + return int.from_bytes(b, byteorder="big") + +def hash_sha256(b: bytes) -> bytes: + return hashlib.sha256(b).digest() + +def has_even_y(P: Point) -> bool: + assert not is_infinite(P) + return y(P) % 2 == 0 + +def pubkey_gen(seckey: bytes) -> bytes: + d0 = int_from_bytes(seckey) + if not (1 <= d0 <= n - 1): + raise ValueError('The secret key must be an integer in the range 1..n-1.') + P = point_mul(G, d0) + assert P is not None + return bytes_from_point(P) + +def schnorr_sign(msg: bytes, seckey: bytes, aux_rand: bytes) -> bytes: + d0 = int_from_bytes(seckey) + if not (1 <= d0 <= n - 1): + raise ValueError('The secret key must be an integer in the range 1..n-1.') + if len(aux_rand) != 32: + raise ValueError('aux_rand must be 32 bytes instead of %i.' % len(aux_rand)) + P = point_mul(G, d0) + assert P is not None + d = d0 if has_even_y(P) else n - d0 + t = xor_bytes(bytes_from_int(d), tagged_hash("BIP0340/aux", aux_rand)) + k0 = int_from_bytes(tagged_hash("BIP0340/nonce", t + bytes_from_point(P) + msg)) % n + if k0 == 0: + raise RuntimeError('Failure. This happens only with negligible probability.') + R = point_mul(G, k0) + assert R is not None + k = n - k0 if not has_even_y(R) else k0 + e = int_from_bytes(tagged_hash("BIP0340/challenge", bytes_from_point(R) + bytes_from_point(P) + msg)) % n + sig = bytes_from_point(R) + bytes_from_int((k + e * d) % n) + debug_print_vars() + if not schnorr_verify(msg, bytes_from_point(P), sig): + raise RuntimeError('The created signature does not pass verification.') + return sig + +def schnorr_verify(msg: bytes, pubkey: bytes, sig: bytes) -> bool: + if len(pubkey) != 32: + raise ValueError('The public key must be a 32-byte array.') + if len(sig) != 64: + raise ValueError('The signature must be a 64-byte array.') + P = lift_x(int_from_bytes(pubkey)) + r = int_from_bytes(sig[0:32]) + s = int_from_bytes(sig[32:64]) + if (P is None) or (r >= p) or (s >= n): + debug_print_vars() + return False + e = int_from_bytes(tagged_hash("BIP0340/challenge", sig[0:32] + pubkey + msg)) % n + R = point_add(point_mul(G, s), point_mul(P, n - e)) + if (R is None) or (not has_even_y(R)) or (x(R) != r): + debug_print_vars() + return False + debug_print_vars() + return True + +import inspect + +def pretty(v: Any) -> Any: + if isinstance(v, bytes): + return '0x' + v.hex() + if isinstance(v, int): + return pretty(bytes_from_int(v)) + if isinstance(v, tuple): + return tuple(map(pretty, v)) + return v + +def debug_print_vars() -> None: + if DEBUG: + current_frame = inspect.currentframe() + assert current_frame is not None + frame = current_frame.f_back + assert frame is not None + print(' Variables in function ', frame.f_code.co_name, ' at line ', frame.f_lineno, ':', sep='') + for var_name, var_val in frame.f_locals.items(): + print(' ' + var_name.rjust(11, ' '), '==', pretty(var_val)) + + +import struct def ser_compact_size(l): r = b"" @@ -48,3 +200,63 @@ def ser_string(s): def ser_script(s): return ser_string(s) + + +# BIP-0341 +def taproot_tweak_pubkey(pubkey, h): + t = int_from_bytes(tagged_hash("TapTweak", pubkey + h)) + if t >= SECP256K1_ORDER: + raise ValueError + P = lift_x(int_from_bytes(pubkey)) + if P is None: + raise ValueError + Q = point_add(P, point_mul(G, t)) + return 0 if has_even_y(Q) else 1, bytes_from_int(x(Q)) + +def taproot_tweak_seckey(seckey0, h): + seckey0 = int_from_bytes(seckey0) + P = point_mul(G, seckey0) + seckey = seckey0 if has_even_y(P) else SECP256K1_ORDER - seckey0 + t = int_from_bytes(tagged_hash("TapTweak", bytes_from_int(x(P)) + h)) + if t >= SECP256K1_ORDER: + raise ValueError + return bytes_from_int((seckey + t) % SECP256K1_ORDER) + +def taproot_tree_helper(script_tree): + if isinstance(script_tree, tuple): + leaf_version, script = script_tree + h = tagged_hash("TapLeaf", bytes([leaf_version]) + ser_script(script)) + return ([((leaf_version, script), bytes())], h) + left, left_h = taproot_tree_helper(script_tree[0]) + right, right_h = taproot_tree_helper(script_tree[1]) + ret = [(l, c + right_h) for l, c in left] + [(l, c + left_h) for l, c in right] + if right_h < left_h: + left_h, right_h = right_h, left_h + return (ret, tagged_hash("TapBranch", left_h + right_h)) + +def taproot_output_script(internal_pubkey, script_tree): + """Given a internal public key and a tree of scripts, compute the output script. + script_tree is either: + - a (leaf_version, script) tuple (leaf_version is 0xc0 for [[bip-0342.mediawiki|BIP342]] scripts) + - a list of two elements, each with the same structure as script_tree itself + - None + """ + if script_tree is None: + h = bytes() + else: + _, h = taproot_tree_helper(script_tree) + _, output_pubkey = taproot_tweak_pubkey(internal_pubkey, h) + return bytes([0x51, 0x20]) + output_pubkey + + +# Tweak without tag +def tweak_pubkey(pubkey, data: bytes): + assert len(data) == 32 + t = int_from_bytes(data) + if t >= SECP256K1_ORDER: + raise ValueError + P = lift_x(int_from_bytes(pubkey)) + if P is None: + raise ValueError + Q = point_add(P, point_mul(G, t)) + return 0 if has_even_y(Q) else 1, bytes_from_int(x(Q)) diff --git a/test_utils/taproot_sighash.py b/test_utils/taproot_sighash.py new file mode 100644 index 000000000..073be97e2 --- /dev/null +++ b/test_utils/taproot_sighash.py @@ -0,0 +1,85 @@ +# Based on code from the bitcoin's functional test framework, extracted from: +# https://github.com/bitcoin/bitcoin/blob/58446e1d92c7da46da1fc48e1eb5eefe2e0748cb/test/functional/feature_taproot.py +# +# Copyright (c) 2015-2022 The Bitcoin Core developers +# Distributed under the MIT software license, see the accompanying + + +import struct +from test_utils import sha256 +from test_utils.taproot import ser_string, tagged_hash + + +def BIP341_sha_prevouts(txTo): + return sha256(b"".join(i.prevout.serialize() for i in txTo.vin)) + + +def BIP341_sha_amounts(spent_utxos): + return sha256(b"".join(struct.pack(" None: + super().__init__() + + self.client = client + self.wallet_policy = wallet_policy + self.wallet_hmac = wallet_hmac + + self.navigator = navigator + self.testname = testname + self.instructions = instructions + + self.fingerprint = client.get_master_fingerprint() + + desc_tmpl = TrDescriptorTemplate.from_string( + wallet_policy.descriptor_template) + + self.pubkey = None + for _, (placeholder, _) in enumerate(desc_tmpl.placeholders()): + if not isinstance(placeholder, Musig2KeyPlaceholder): + continue + + for i in placeholder.key_indexes: + key_info = self.wallet_policy.keys_info[i] + if key_info[0] == "[" and key_info[1:9] == self.fingerprint.hex(): + xpub = key_info[key_info.find(']') + 1:] + self.pubkey = ExtendedKey.deserialize(xpub) + break + + if self.pubkey is not None: + break + + if self.pubkey is None: + raise ValueError("no musig with an internal key in wallet policy") + + def get_participant_pubkey(self) -> bip0327.Point: + return bip0327.cpoint(self.pubkey.pubkey) + + def generate_public_nonces(self, psbt: PSBT) -> None: + print("PSBT before nonce generation:", psbt.serialize()) + res = self.client.sign_psbt( + psbt, self.wallet_policy, self.wallet_hmac, navigator=self.navigator, testname=self.testname, instructions=self.instructions) + print("Pubnonces:", res) + for (input_index, yielded) in res: + if isinstance(yielded, MusigPubNonce): + psbt_key = ( + yielded.participant_pubkey, + yielded.aggregate_pubkey, + yielded.tapleaf_hash + ) + print("Adding pubnonce to psbt for Ledger input", input_index) + print("Key:", psbt_key) + print("Value:", yielded.pubnonce) + + assert len(yielded.aggregate_pubkey) == 33 + + psbt.inputs[input_index].musig2_pub_nonces[psbt_key] = yielded.pubnonce + + def generate_partial_signatures(self, psbt: PSBT) -> None: + print("PSBT before partial signature generation:", psbt.serialize()) + res = self.client.sign_psbt( + psbt, self.wallet_policy, self.wallet_hmac, navigator=self.navigator, testname=self.testname, instructions=self.instructions) + print("Ledger result of second round:", res) + for (input_index, yielded) in res: + if isinstance(yielded, MusigPartialSignature): + psbt_key = ( + yielded.participant_pubkey, + yielded.aggregate_pubkey, + yielded.tapleaf_hash + ) + + print("Adding partial signature to psbt for Ledger input", input_index) + print("Key:", psbt_key) + print("Value:", yielded.partial_signature) + + psbt.inputs[input_index].musig2_partial_sigs[psbt_key] = yielded.partial_signature + elif isinstance(yielded, MusigPubNonce): + raise ValueError("Expected partial signatures, got a pubnonce") + + +def test_sign_psbt_musig2_keypath(navigator: Navigator, firmware: Firmware, client: RaggerClient, test_name: str, speculos_globals: SpeculosGlobals): + cosigner_1_xpub = "[f5acc2fd/44'/1'/0']tpubDCwYjpDhUdPGP5rS3wgNg13mTrrjBuG8V9VpWbyptX6TRPbNoZVXsoVUSkCjmQ8jJycjuDKBb9eataSymXakTTaGifxR6kmVsfFehH1ZgJT" + + cosigner_2_xpriv = "tprv8gFWbQBTLFhbX3EK3cS7LmenwE3JjXbD9kN9yXfq7LcBm81RSf8vPGPqGPjZSeX41LX9ZN14St3z8YxW48aq5Yhr9pQZVAyuBthfi6quTCf" + cosigner_2_xpub = "tpubDCwYjpDhUdPGQWG6wG6hkBJuWFZEtrn7j3xwG3i8XcQabcGC53xWZm1hSXrUPFS5UvZ3QhdPSjXWNfWmFGTioARHuG5J7XguEjgg7p8PxAm" + + wallet_policy = WalletPolicy( + name="Musig for my ears", + descriptor_template="tr(musig(@0,@1)/**)", + keys_info=[cosigner_1_xpub, cosigner_2_xpub] + ) + wallet_hmac = hmac.new( + speculos_globals.wallet_registration_key, wallet_policy.id, sha256).digest() + + psbt_b64 = "cHNidP8BAIACAAAAAdF2HhQ2XCgTpd3Sel7VkS5FvESbwo1rgeuG4tBt9GICAAAAAAD9////AQAAAAAAAAAARGpCVGhpcyBpbnB1dHMgaGFzIHR3byBwdWJrZXlzIGJ1dCB5b3Ugb25seSBzZWUgb25lLiAjbXBjZ2FuZyByZXZlbmdlAAAAAAABASuf/gQAAAAAACJRIMH9/r7QY6oUg0DEUTLmcY2N6BRmriuQkp49kyg2TNbtIRaQZkYWUCCfi7xZsFr10WFcUPX3nBiNe+dC/ZMiUvaPDA0AW4+8kwAAAAADAAAAAAA=" + psbt = PSBT() + psbt.deserialize(psbt_b64) + + sighashes = [ + bytes.fromhex( + "a3aeecb6c236b4a7e72c95fa138250d449b97a75c573f8ab612356279ff64046") + ] + + signer_1 = LedgerMusig2Cosigner(client, wallet_policy, wallet_hmac, + navigator=navigator, instructions=sign_psbt_instruction_approve(firmware, save_screenshot=False, has_spend_from_wallet=True, has_feewarning=True), testname=test_name) + signer_2 = HotMusig2Cosigner(wallet_policy, cosigner_2_xpriv) + + run_musig2_test(wallet_policy, psbt, [signer_1, signer_2], sighashes) + + +def test_sign_psbt_musig2_scriptpath(navigator: Navigator, firmware: Firmware, client: RaggerClient, test_name: str, speculos_globals: SpeculosGlobals): + cosigner_1_xpub = "[f5acc2fd/44'/1'/0']tpubDCwYjpDhUdPGP5rS3wgNg13mTrrjBuG8V9VpWbyptX6TRPbNoZVXsoVUSkCjmQ8jJycjuDKBb9eataSymXakTTaGifxR6kmVsfFehH1ZgJT" + + cosigner_2_xpriv = "tprv8gFWbQBTLFhbX3EK3cS7LmenwE3JjXbD9kN9yXfq7LcBm81RSf8vPGPqGPjZSeX41LX9ZN14St3z8YxW48aq5Yhr9pQZVAyuBthfi6quTCf" + cosigner_2_xpub = ExtendedKey.deserialize( + cosigner_2_xpriv).neutered().to_string() + + wallet_policy = WalletPolicy( + name="Musig2 in the scriptpath", + descriptor_template="tr(@0/**,pk(musig(@1,@2)/**))", + keys_info=[ + "tpubD6NzVbkrYhZ4WLczPJWReQycCJdd6YVWXubbVUFnJ5KgU5MDQrD998ZJLSmaB7GVcCnJSDWprxmrGkJ6SvgQC6QAffVpqSvonXmeizXcrkN", + cosigner_1_xpub, + cosigner_2_xpub + ] + ) + wallet_hmac = hmac.new( + speculos_globals.wallet_registration_key, wallet_policy.id, sha256).digest() + + psbt_b64 = "cHNidP8BAFoCAAAAAdOnEESfpXpBe9X59Q4jxz1u9E4Wovn2bkAuuyqUUY0mAAAAAAD9////AQAAAAAAAAAAHmocTXVzaWcyLiBOb3cgZXZlbiBpbiBTY3JpcHRzLgAAAAAAAQErOTAAAAAAAAAiUSDtVR7h2JYPJC463zrCcmfKriiugHBXAcXDP1O2ptF2LyIVwethFsEeXf/x51pIczoAIsj9RoVePIBTyk/rOMW8B6uIIyCQZkYWUCCfi7xZsFr10WFcUPX3nBiNe+dC/ZMiUvaPDKzAIRaQZkYWUCCfi7xZsFr10WFcUPX3nBiNe+dC/ZMiUvaPDC0BuYMCXh1wIlpyBMdMaCFPSwOeOyvhqg+FJ+fOMoWlJsRbj7yTAAAAAAMAAAABFyDrYRbBHl3/8edaSHM6ACLI/UaFXjyAU8pP6zjFvAeriAEYILmDAl4dcCJacgTHTGghT0sDnjsr4aoPhSfnzjKFpSbEAAA=" + psbt = PSBT() + psbt.deserialize(psbt_b64) + + sighashes = [ + bytes.fromhex( + "28f86cd95c144ed4a877701ae7166867e8805b654c43d9f44da45d7b0070c313") + ] + + signer_1 = LedgerMusig2Cosigner(client, wallet_policy, wallet_hmac, + navigator=navigator, instructions=sign_psbt_instruction_approve(firmware, save_screenshot=False, has_spend_from_wallet=True), testname=test_name) + signer_2 = HotMusig2Cosigner(wallet_policy, cosigner_2_xpriv) + + run_musig2_test(wallet_policy, psbt, [signer_1, signer_2], sighashes) diff --git a/unit-tests/test_wallet.c b/unit-tests/test_wallet.c index 975545487..c9ad399e7 100644 --- a/unit-tests/test_wallet.c +++ b/unit-tests/test_wallet.c @@ -32,12 +32,31 @@ static int parse_policy(const char *descriptor_template, uint8_t *out, size_t ou // about half of the memory would be needed #define MAX_WALLET_POLICY_MEMORY_SIZE 512 -// convenience function to compactly check common assertions on a key placeholder pointer -static void check_key_placeholder(const policy_node_key_placeholder_t *ptr, - int key_index, - uint32_t num_first, - uint32_t num_second) { - assert_int_equal(ptr->key_index, key_index); +// convenience function to compactly check common assertions on a pointer to a key expression with a +// single placeholder +static void check_key_expr_plain(const policy_node_keyexpr_t *ptr, + int key_index, + uint32_t num_first, + uint32_t num_second) { + assert_int_equal(ptr->type, KEY_EXPRESSION_NORMAL); + assert_int_equal(ptr->k.key_index, key_index); + assert_int_equal(ptr->num_first, num_first); + assert_int_equal(ptr->num_second, num_second); +} + +// convenience function to compactly check assertions on a pointer to a key expression with a musig +static void check_key_expr_musig(const policy_node_keyexpr_t *ptr, + int n_musig_keys, + const uint16_t *key_indices, + uint32_t num_first, + uint32_t num_second) { + assert_int_equal(ptr->type, KEY_EXPRESSION_MUSIG); + musig_aggr_key_info_t *musig_info = r_musig_aggr_key_info(&ptr->m.musig_info); + assert_int_equal(musig_info->n, n_musig_keys); + uint16_t *musig_key_indexes = r_uint16(&musig_info->key_indexes); + for (int i = 0; i < n_musig_keys; i++) { + assert_int_equal(musig_key_indexes[i], key_indices[i]); + } assert_int_equal(ptr->num_first, num_first); assert_int_equal(ptr->num_second, num_second); } @@ -53,7 +72,7 @@ static void test_parse_policy_map_singlesig_1(void **state) { policy_node_with_key_t *node_1 = (policy_node_with_key_t *) out; assert_int_equal(node_1->base.type, TOKEN_PKH); - check_key_placeholder(r_policy_node_key_placeholder(&node_1->key_placeholder), 0, 0, 1); + check_key_expr_plain(r_policy_node_keyexpr(&node_1->key), 0, 0, 1); } static void test_parse_policy_map_singlesig_2(void **state) { @@ -71,7 +90,7 @@ static void test_parse_policy_map_singlesig_2(void **state) { policy_node_with_key_t *inner = (policy_node_with_key_t *) r_policy_node(&root->script); assert_int_equal(inner->base.type, TOKEN_WPKH); - check_key_placeholder(r_policy_node_key_placeholder(&inner->key_placeholder), 0, 0, 1); + check_key_expr_plain(r_policy_node_keyexpr(&inner->key), 0, 0, 1); } static void test_parse_policy_map_singlesig_3(void **state) { @@ -93,7 +112,7 @@ static void test_parse_policy_map_singlesig_3(void **state) { policy_node_with_key_t *inner = (policy_node_with_key_t *) r_policy_node(&mid->script); assert_int_equal(inner->base.type, TOKEN_PKH); - check_key_placeholder(r_policy_node_key_placeholder(&inner->key_placeholder), 0, 0, 1); + check_key_expr_plain(r_policy_node_keyexpr(&inner->key), 0, 0, 1); } static void test_parse_policy_map_multisig_1(void **state) { @@ -109,9 +128,9 @@ static void test_parse_policy_map_multisig_1(void **state) { assert_int_equal(node_1->base.type, TOKEN_SORTEDMULTI); assert_int_equal(node_1->k, 2); assert_int_equal(node_1->n, 3); - check_key_placeholder(&r_policy_node_key_placeholder(&node_1->key_placeholders)[0], 0, 0, 1); - check_key_placeholder(&r_policy_node_key_placeholder(&node_1->key_placeholders)[1], 1, 0, 1); - check_key_placeholder(&r_policy_node_key_placeholder(&node_1->key_placeholders)[2], 2, 0, 1); + check_key_expr_plain(&r_policy_node_keyexpr(&node_1->keys)[0], 0, 0, 1); + check_key_expr_plain(&r_policy_node_keyexpr(&node_1->keys)[1], 1, 0, 1); + check_key_expr_plain(&r_policy_node_keyexpr(&node_1->keys)[2], 2, 0, 1); } static void test_parse_policy_map_multisig_2(void **state) { @@ -132,7 +151,7 @@ static void test_parse_policy_map_multisig_2(void **state) { assert_int_equal(inner->k, 3); assert_int_equal(inner->n, 5); for (int i = 0; i < 5; i++) { - check_key_placeholder(&r_policy_node_key_placeholder(&inner->key_placeholders)[i], i, 0, 1); + check_key_expr_plain(&r_policy_node_keyexpr(&inner->keys)[i], i, 0, 1); } } @@ -158,7 +177,7 @@ static void test_parse_policy_map_multisig_3(void **state) { assert_int_equal(inner->k, 3); assert_int_equal(inner->n, 5); for (int i = 0; i < 5; i++) { - check_key_placeholder(&r_policy_node_key_placeholder(&inner->key_placeholders)[i], i, 0, 1); + check_key_expr_plain(&r_policy_node_keyexpr(&inner->keys)[i], i, 0, 1); } } @@ -175,7 +194,7 @@ static void test_parse_policy_tr(void **state) { policy_node_tr_t *root = (policy_node_tr_t *) out; assert_true(isnull_policy_node_tree(&root->tree)); - check_key_placeholder(r_policy_node_key_placeholder(&root->key_placeholder), 0, 0, 1); + check_key_expr_plain(r_policy_node_keyexpr(&root->key), 0, 0, 1); // Simple tr with a TREE that is a simple script res = parse_policy("tr(@0/**,pk(@1/**))", out, sizeof(out)); @@ -183,7 +202,7 @@ static void test_parse_policy_tr(void **state) { assert_true(res >= 0); root = (policy_node_tr_t *) out; - check_key_placeholder(r_policy_node_key_placeholder(&root->key_placeholder), 0, 0, 1); + check_key_expr_plain(r_policy_node_keyexpr(&root->key), 0, 0, 1); assert_int_equal(r_policy_node_tree(&root->tree)->is_leaf, true); @@ -191,7 +210,7 @@ static void test_parse_policy_tr(void **state) { (policy_node_with_key_t *) r_policy_node(&r_policy_node_tree(&root->tree)->script); assert_int_equal(tapscript->base.type, TOKEN_PK); - check_key_placeholder(r_policy_node_key_placeholder(&tapscript->key_placeholder), 1, 0, 1); + check_key_expr_plain(r_policy_node_keyexpr(&tapscript->key), 1, 0, 1); // Simple tr with a TREE with two tapleaves res = parse_policy("tr(@0/**,{pk(@1/**),pk(@2/<5;7>/*)})", out, sizeof(out)); @@ -199,7 +218,7 @@ static void test_parse_policy_tr(void **state) { assert_true(res >= 0); root = (policy_node_tr_t *) out; - check_key_placeholder(r_policy_node_key_placeholder(&root->key_placeholder), 0, 0, 1); + check_key_expr_plain(r_policy_node_keyexpr(&root->key), 0, 0, 1); policy_node_tree_t *taptree = r_policy_node_tree(&root->tree); @@ -212,7 +231,7 @@ static void test_parse_policy_tr(void **state) { (policy_node_with_key_t *) r_policy_node(&taptree_left->script); assert_int_equal(tapscript_left->base.type, TOKEN_PK); - check_key_placeholder(r_policy_node_key_placeholder(&tapscript_left->key_placeholder), 1, 0, 1); + check_key_expr_plain(r_policy_node_keyexpr(&tapscript_left->key), 1, 0, 1); policy_node_tree_t *taptree_right = (policy_node_tree_t *) r_policy_node_tree(&taptree->right_tree); @@ -221,10 +240,7 @@ static void test_parse_policy_tr(void **state) { (policy_node_with_key_t *) r_policy_node(&taptree_right->script); assert_int_equal(tapscript_right->base.type, TOKEN_PK); - check_key_placeholder(r_policy_node_key_placeholder(&tapscript_right->key_placeholder), - 2, - 5, - 7); + check_key_expr_plain(r_policy_node_keyexpr(&tapscript_right->key), 2, 5, 7); } static void test_parse_policy_tr_multisig(void **state) { @@ -242,9 +258,9 @@ static void test_parse_policy_tr_multisig(void **state) { policy_node_tr_t *root = (policy_node_tr_t *) out; - assert_int_equal(r_policy_node_key_placeholder(&root->key_placeholder)->key_index, 0); - assert_int_equal(r_policy_node_key_placeholder(&root->key_placeholder)->num_first, 0); - assert_int_equal(r_policy_node_key_placeholder(&root->key_placeholder)->num_second, 1); + assert_int_equal(r_policy_node_keyexpr(&root->key)->k.key_index, 0); + assert_int_equal(r_policy_node_keyexpr(&root->key)->num_first, 0); + assert_int_equal(r_policy_node_keyexpr(&root->key)->num_second, 1); policy_node_tree_t *taptree = r_policy_node_tree(&root->tree); @@ -259,14 +275,8 @@ static void test_parse_policy_tr_multisig(void **state) { assert_int_equal(tapscript_left->base.type, TOKEN_MULTI_A); assert_int_equal(tapscript_left->k, 1); assert_int_equal(tapscript_left->n, 2); - check_key_placeholder(&r_policy_node_key_placeholder(&tapscript_left->key_placeholders)[0], - 1, - 0, - 1); - check_key_placeholder(&r_policy_node_key_placeholder(&tapscript_left->key_placeholders)[1], - 2, - 0, - 1); + check_key_expr_plain(&r_policy_node_keyexpr(&tapscript_left->keys)[0], 1, 0, 1); + check_key_expr_plain(&r_policy_node_keyexpr(&tapscript_left->keys)[1], 2, 0, 1); policy_node_tree_t *taptree_right = (policy_node_tree_t *) r_policy_node_tree(&taptree->right_tree); @@ -277,18 +287,50 @@ static void test_parse_policy_tr_multisig(void **state) { assert_int_equal(tapscript_right->base.type, TOKEN_SORTEDMULTI_A); assert_int_equal(tapscript_right->k, 2); assert_int_equal(tapscript_right->n, 3); - check_key_placeholder(&r_policy_node_key_placeholder(&tapscript_right->key_placeholders)[0], - 3, - 0, - 1); - check_key_placeholder(&r_policy_node_key_placeholder(&tapscript_right->key_placeholders)[1], - 4, - 0, - 1); - check_key_placeholder(&r_policy_node_key_placeholder(&tapscript_right->key_placeholders)[2], - 5, - 0, - 1); + check_key_expr_plain(&r_policy_node_keyexpr(&tapscript_right->keys)[0], 3, 0, 1); + check_key_expr_plain(&r_policy_node_keyexpr(&tapscript_right->keys)[1], 4, 0, 1); + check_key_expr_plain(&r_policy_node_keyexpr(&tapscript_right->keys)[2], 5, 0, 1); +} + +static void test_parse_policy_tr_musig_keypath(void **state) { + (void) state; + + uint8_t out[MAX_WALLET_POLICY_MEMORY_SIZE]; + int res; + + res = parse_policy("tr(musig(@2,@0,@1)/<3;13>/*)", out, sizeof(out)); + + assert_true(res >= 0); + + policy_node_tr_t *root = (policy_node_tr_t *) out; + assert_int_equal(root->base.type, TOKEN_TR); + assert_true(isnull_policy_node_tree(&root->tree)); + + check_key_expr_musig(r_policy_node_keyexpr(&root->key), 3, (uint16_t[]){2, 0, 1}, 3, 13); +} + +static void test_parse_policy_tr_musig_scriptpath(void **state) { + (void) state; + + uint8_t out[MAX_WALLET_POLICY_MEMORY_SIZE]; + int res; + + // tr with a musig in the script path + res = parse_policy("tr(@1/**,pk(musig(@2,@0,@3)/**))", out, sizeof(out)); + + assert_true(res >= 0); + + policy_node_tr_t *root = (policy_node_tr_t *) out; + assert_int_equal(root->base.type, TOKEN_TR); + + assert_false(isnull_policy_node_tree(&root->tree)); + policy_node_tree_t *tree = r_policy_node_tree(&root->tree); + assert_true(tree->is_leaf); + + policy_node_with_key_t *script_pk = (policy_node_with_key_t *) r_policy_node(&tree->script); + assert_int_equal(script_pk->base.type, TOKEN_PK); + + check_key_expr_musig(r_policy_node_keyexpr(&script_pk->key), 3, (uint16_t[]){2, 0, 3}, 0, 1); } static void test_get_policy_segwit_version(void **state) { @@ -377,6 +419,14 @@ static void test_failures(void **state) { assert_true(0 > parse_policy("tr(@0/**,sortedmulti(2,@1,@2))", out, sizeof(out))); assert_true(0 > parse_policy("tr(@0/**,sh(pk(@0/**)))", out, sizeof(out))); assert_true(0 > parse_policy("tr(@0/**,wsh(pk(@0/**)))", out, sizeof(out))); + + // invalid usages of musig expressions + assert_true(0 > parse_policy("tr(musig(@0,@1))", out, sizeof(out))); // missing derivations + assert_true(0 > parse_policy("tr(musig()/**)", out, sizeof(out))); // empty musig + assert_true(0 > parse_policy("tr(musig(@0)/**)", out, sizeof(out))); // needs at least two keys + assert_true(0 > parse_policy("wpkh(musig(@0,@1)/**)", out, sizeof(out))); // not taproot + assert_true( + 0 > parse_policy("tr(musig(@0,musig(@1,@2))/**)", out, sizeof(out))); // can't nest musig } enum TestMode { @@ -630,6 +680,8 @@ int main() { cmocka_unit_test(test_parse_policy_map_multisig_3), cmocka_unit_test(test_parse_policy_tr), cmocka_unit_test(test_parse_policy_tr_multisig), + cmocka_unit_test(test_parse_policy_tr_musig_keypath), + cmocka_unit_test(test_parse_policy_tr_musig_scriptpath), cmocka_unit_test(test_get_policy_segwit_version), cmocka_unit_test(test_failures), cmocka_unit_test(test_miniscript_types),