diff --git a/src/handler/sign_psbt.c b/src/handler/sign_psbt.c index d90d6ed8..f8001021 100644 --- a/src/handler/sign_psbt.c +++ b/src/handler/sign_psbt.c @@ -114,6 +114,10 @@ typedef struct { int index; uint32_t fingerprint; + // we only sign for keys expressions for which we find a matching key derivation in the PSBT, + // at least for one of the inputs + bool to_sign; + // info about the internal key of this key expression // used at signing time to derive the correct key uint32_t key_derivation[MAX_BIP32_PATH_STEPS]; @@ -132,8 +136,10 @@ typedef struct { // internal key for musig key expressions serialized_extended_pubkey_t internal_pubkey; - bool is_tapscript; // true if signing with a BIP342 tapleaf script path spend - uint8_t tapleaf_hash[32]; // only used for tapscripts + bool is_tapscript; // true if signing with a BIP342 tapleaf script path spend + // only used for tapscripts + const policy_node_t *tapleaf_ptr; + uint8_t tapleaf_hash[32]; } keyexpr_info_t; // Cache for partial hashes during signing (avoid quadratic hashing for segwit transactions) @@ -398,19 +404,24 @@ static int get_amount_scriptpubkey_from_psbt( NULL); } -// Convenience function to share common logic when processing all the -// PSBT_{IN|OUT}_{TAP}?_BIP32_DERIVATION fields. +typedef struct { + uint32_t fingerprint; + size_t derivation_len; + uint32_t key_origin[MAX_BIP32_PATH_STEPS]; +} derivation_info_t; + +// Convenience function to share common logic when parsing the +// PSBT_{IN|OUT}_{TAP}?_BIP32_DERIVATION fields from inputs or outputs. // Note: This function must return -1 only on errors (causing signing to abort). -// It should return 1 if a derivation that makes sense for this input/output is found. -// It should return 0 otherwise (no match found, but continue the signing flow). +// It returns 1 if a that might match the wallet policy is found. +// It returns 0 otherwise (not a match, but continue the signing flow). static int read_change_and_index_from_psbt_bip32_derivation( dispatcher_context_t *dc, - const keyexpr_info_t *keyexpr_info, - in_out_info_t *in_out, int psbt_key_type, buffer_t *data, const merkleized_map_commitment_t *map_commitment, - int index) { + int index, + derivation_info_t *derivation_info) { uint8_t bip32_derivation_pubkey[33]; bool is_tap = psbt_key_type == PSBT_IN_TAP_BIP32_DERIVATION || @@ -423,7 +434,6 @@ static int read_change_and_index_from_psbt_bip32_derivation( || buffer_can_read(data, 1) // ...but should not be able to read more ) { PRINTF("Unexpected pubkey length\n"); - in_out->unexpected_pubkey_error = true; return -1; } @@ -447,35 +457,34 @@ static int read_change_and_index_from_psbt_bip32_derivation( return 0; } - // if this derivation path matches the internal key expression, - // we use it to detect whether the current input is change or not, - // and store its address index - if (fpt_der[0] == keyexpr_info->fingerprint && - der_len == keyexpr_info->psbt_root_key_derivation_length + 2) { - for (int i = 0; i < keyexpr_info->psbt_root_key_derivation_length; i++) { - if (keyexpr_info->key_derivation[i] != fpt_der[1 + i]) { - return 0; - } - } + derivation_info->fingerprint = fpt_der[0]; + for (int i = 0; i < der_len; i++) { + derivation_info->key_origin[i] = fpt_der[i + 1]; + } + derivation_info->derivation_len = der_len; - uint32_t change_step = fpt_der[1 + der_len - 2]; - uint32_t addr_index = fpt_der[1 + der_len - 1]; + return 1; +} - // check if the 'change' derivation step is indeed coherent with key expression - if (change_step == keyexpr_info->key_expression_ptr->num_first) { - in_out->is_change = false; - in_out->address_index = addr_index; - } else if (change_step == keyexpr_info->key_expression_ptr->num_second) { - in_out->is_change = true; - in_out->address_index = addr_index; - } else { - return 0; +bool is_keyexpr_compatible_with_derivation_info(const keyexpr_info_t *keyexpr_info, + const derivation_info_t *derivation_info) { + if (keyexpr_info->fingerprint != derivation_info->fingerprint) { + return false; + } + if (keyexpr_info->psbt_root_key_derivation_length + 2 != derivation_info->derivation_len) { + return false; + } + for (int i = 0; i < keyexpr_info->psbt_root_key_derivation_length; i++) { + if (keyexpr_info->key_derivation[i] != derivation_info->key_origin[i]) { + return false; } - - in_out->key_expression_found = true; - return 1; } - return 0; + uint32_t change_step = derivation_info->key_origin[derivation_info->derivation_len - 2]; + if (change_step != keyexpr_info->key_expression_ptr->num_first && + change_step != keyexpr_info->key_expression_ptr->num_second) { + return false; + } + return true; } /** @@ -824,7 +833,7 @@ static bool fill_keyexpr_info_if_internal(dispatcher_context_t *dc, } typedef struct { - keyexpr_info_t *keyexpr_info; + sign_psbt_state_t *state; input_info_t *input; } input_keys_callback_data_t; @@ -835,7 +844,7 @@ typedef struct { static void input_keys_callback(dispatcher_context_t *dc, input_keys_callback_data_t *callback_data, const merkleized_map_commitment_t *map_commitment, - int i, + int index, buffer_t *data) { size_t data_len = data->size - data->offset; if (data_len >= 1) { @@ -849,17 +858,39 @@ static void input_keys_callback(dispatcher_context_t *dc, callback_data->input->has_redeemScript = true; } else if (key_type == PSBT_IN_SIGHASH_TYPE) { callback_data->input->has_sighash_type = true; - } else if ((key_type == PSBT_IN_BIP32_DERIVATION || - key_type == PSBT_IN_TAP_BIP32_DERIVATION) && - !callback_data->input->in_out.key_expression_found) { - if (0 > read_change_and_index_from_psbt_bip32_derivation(dc, - callback_data->keyexpr_info, - &callback_data->input->in_out, - key_type, - data, - map_commitment, - i)) { + } else if (key_type == PSBT_IN_BIP32_DERIVATION || + key_type == PSBT_IN_TAP_BIP32_DERIVATION) { + derivation_info_t derivation_info; + int res = read_change_and_index_from_psbt_bip32_derivation(dc, + key_type, + data, + map_commitment, + index, + &derivation_info); + if (res < 0) { + // there was an error; we keep track of it so an error SW is sent later callback_data->input->in_out.unexpected_pubkey_error = true; + } else if (res == 0) { + // nothing to do + } else if (res == 1) { + in_out_info_t *in_out = &callback_data->input->in_out; + for (size_t i = 0; i < callback_data->state->n_internal_key_expressions; i++) { + keyexpr_info_t *key_expr = &callback_data->state->internal_key_expressions[i]; + if (is_keyexpr_compatible_with_derivation_info(key_expr, &derivation_info)) { + key_expr->to_sign = true; + + bool is_change = + key_expr->key_expression_ptr->num_second == + derivation_info.key_origin[derivation_info.derivation_len - 2]; + + in_out->key_expression_found = true; + in_out->is_change = is_change; + in_out->address_index = + derivation_info.key_origin[derivation_info.derivation_len - 1]; + } + } + } else { + LEDGER_ASSERT(false, "Unreachable code"); } } } @@ -873,11 +904,20 @@ static bool fill_internal_key_expressions(dispatcher_context_t *dc, sign_psbt_st // find and parse our registered key info in the wallet keyexpr_info_t keyexpr_info; + memset(&keyexpr_info, 0, sizeof(keyexpr_info_t)); while (true) { + keyexpr_info.index = cur_index; + const policy_node_t *tapleaf_ptr = NULL; int n_key_expressions = get_keyexpr_by_index(st->wallet_policy_map, cur_index, - NULL, + &tapleaf_ptr, &keyexpr_info.key_expression_ptr); + if (tapleaf_ptr != NULL) { + // get_keyexpr_by_index returns the pointer to the tapleaf only if the key being + // spent is indeed in a tapleaf + keyexpr_info.tapleaf_ptr = tapleaf_ptr; + keyexpr_info.is_tapscript = true; + } if (n_key_expressions < 0) { SEND_SW(dc, SW_BAD_STATE); // should never happen return false; @@ -901,7 +941,6 @@ static bool fill_internal_key_expressions(dispatcher_context_t *dc, sign_psbt_st &keyexpr_info, sizeof(keyexpr_info_t)); ++st->n_internal_key_expressions; - keyexpr_info.index = 0; } // Not an internal key, move on @@ -933,10 +972,7 @@ preprocess_inputs(dispatcher_context_t *dc, input_info_t input; memset(&input, 0, sizeof(input)); - input_keys_callback_data_t callback_data = { - .input = &input, - // it doesn't matter which key expression we use here - .keyexpr_info = &st->internal_key_expressions[0]}; + input_keys_callback_data_t callback_data = {.input = &input, .state = st}; int res = call_get_merkleized_map_with_callback( dc, (void *) &callback_data, @@ -1137,7 +1173,7 @@ preprocess_inputs(dispatcher_context_t *dc, } typedef struct { - keyexpr_info_t *keyexpr_info; + sign_psbt_state_t *state; output_info_t *output; } output_keys_callback_data_t; @@ -1148,7 +1184,7 @@ typedef struct { static void output_keys_callback(dispatcher_context_t *dc, output_keys_callback_data_t *callback_data, const merkleized_map_commitment_t *map_commitment, - int i, + int index, buffer_t *data) { size_t data_len = data->size - data->offset; if (data_len >= 1) { @@ -1157,14 +1193,36 @@ static void output_keys_callback(dispatcher_context_t *dc, if ((key_type == PSBT_OUT_BIP32_DERIVATION || key_type == PSBT_OUT_TAP_BIP32_DERIVATION) && !callback_data->output->in_out.key_expression_found) { - if (0 > read_change_and_index_from_psbt_bip32_derivation(dc, - callback_data->keyexpr_info, - &callback_data->output->in_out, - key_type, - data, - map_commitment, - i)) { + derivation_info_t derivation_info; + int res = read_change_and_index_from_psbt_bip32_derivation(dc, + key_type, + data, + map_commitment, + index, + &derivation_info); + if (res < 0) { + // there was an error; we keep track of it so an error SW is sent later callback_data->output->in_out.unexpected_pubkey_error = true; + } else if (res == 1) { + in_out_info_t *in_out = &callback_data->output->in_out; + for (size_t i = 0; i < callback_data->state->n_internal_key_expressions; i++) { + const keyexpr_info_t *key_expr = + &callback_data->state->internal_key_expressions[i]; + if (is_keyexpr_compatible_with_derivation_info(key_expr, &derivation_info)) { + bool is_change = + key_expr->key_expression_ptr->num_second == + derivation_info.key_origin[derivation_info.derivation_len - 2]; + + in_out->key_expression_found = true; + in_out->is_change = is_change; + in_out->address_index = + derivation_info.key_origin[derivation_info.derivation_len - 1]; + // unlike for inputs, where we want to keep track of all the key expressions + // we want to sign for, here we only care about finding the relevant info + // for this output. Therefore, we're done as soon as we have a match. + break; + } + } } } } @@ -1193,10 +1251,7 @@ preprocess_outputs(dispatcher_context_t *dc, output_info_t output; memset(&output, 0, sizeof(output)); - output_keys_callback_data_t callback_data = { - .output = &output, - // any internal key expression is good here - .keyexpr_info = &st->internal_key_expressions[0]}; + output_keys_callback_data_t callback_data = {.output = &output, .state = st}; int res = call_get_merkleized_map_with_callback( dc, (void *) &callback_data, @@ -3255,74 +3310,56 @@ sign_transaction(dispatcher_context_t *dc, } // Iterate over all the key expressions that correspond to keys owned by us - while (true) { - keyexpr_info_t keyexpr_info; - memset(&keyexpr_info, 0, sizeof(keyexpr_info)); - - const policy_node_t *tapleaf_ptr = NULL; - int n_key_expressions = get_keyexpr_by_index(st->wallet_policy_map, - key_expression_index, - &tapleaf_ptr, - &keyexpr_info.key_expression_ptr); - - if (n_key_expressions < 0) { - SEND_SW(dc, SW_BAD_STATE); // should never happen - return false; - } - - if (key_expression_index >= n_key_expressions) { - // all key expressions were processed - break; + for (size_t i_keyexpr = 0; i_keyexpr < st->n_internal_key_expressions; i_keyexpr++) { + keyexpr_info_t *keyexpr_info = &st->internal_key_expressions[i_keyexpr]; + if (!keyexpr_info->to_sign) { + continue; } - if (tapleaf_ptr != NULL) { - // get_keyexpr_by_index returns the pointer to the tapleaf only if the key being - // spent is indeed in a tapleaf - keyexpr_info.is_tapscript = true; + if (!fill_keyexpr_info_if_internal(dc, st, keyexpr_info) == true) { + continue; } - if (fill_keyexpr_info_if_internal(dc, st, &keyexpr_info) == true) { - for (unsigned int i = 0; i < st->n_inputs; i++) - if (bitvector_get(internal_inputs, i)) { - input_info_t input; - memset(&input, 0, sizeof(input)); - - input_keys_callback_data_t callback_data = {.input = &input, - .keyexpr_info = &keyexpr_info}; - int res = call_get_merkleized_map_with_callback( - dc, - (void *) &callback_data, - st->inputs_root, - st->n_inputs, - i, - (merkle_tree_elements_callback_t) input_keys_callback, - &input.in_out.map); - if (res < 0) { - SEND_SW(dc, SW_INCORRECT_DATA); - return false; - } + for (unsigned int i = 0; i < st->n_inputs; i++) { + if (bitvector_get(internal_inputs, i)) { + input_info_t input; + memset(&input, 0, sizeof(input)); - if (tapleaf_ptr != NULL && !fill_taproot_keyexpr_info(dc, - st, - &input, - tapleaf_ptr, - &keyexpr_info, - sign_psbt_cache)) { - return false; - } + input_keys_callback_data_t callback_data = {.input = &input, .state = st}; + int res = call_get_merkleized_map_with_callback( + dc, + (void *) &callback_data, + st->inputs_root, + st->n_inputs, + i, + (merkle_tree_elements_callback_t) input_keys_callback, + &input.in_out.map); + if (res < 0) { + SEND_SW(dc, SW_INCORRECT_DATA); + return false; + } + if (keyexpr_info->tapleaf_ptr != NULL && + !fill_taproot_keyexpr_info(dc, + st, + &input, + keyexpr_info->tapleaf_ptr, + keyexpr_info, + sign_psbt_cache)) { + return false; + } - if (!sign_transaction_input(dc, - st, - sign_psbt_cache, - &signing_state, - &keyexpr_info, - &input, - i)) { - // we do not send a status word, since sign_transaction_input - // already does it on failure - return false; - } + if (!sign_transaction_input(dc, + st, + sign_psbt_cache, + &signing_state, + keyexpr_info, + &input, + i)) { + // we do not send a status word, since sign_transaction_input + // already does it on failure + return false; } + } } ++key_expression_index;