diff --git a/src/common/wallet.c b/src/common/wallet.c index 217f64ae1..b47f0a8ca 100644 --- a/src/common/wallet.c +++ b/src/common/wallet.c @@ -459,7 +459,7 @@ static int parse_keyexpr(buffer_t *in_buf, return WITH_ERROR(-1, "The key index in a placeholder must be at most 32767"); } - out->key_index = (int16_t) k; + out->k.key_index = (int16_t) k; } else if (c == 'm') { // parse a musig(key1,...,keyn) expression, where each key is a key expression if (!consume_characters(in_buf, "usig(", 5)) { @@ -531,7 +531,7 @@ static int parse_keyexpr(buffer_t *in_buf, musig_info->n = n_musig_keys; i_uint16(&musig_info->key_indexes, key_indexes); - i_musig_aggr_key_info(&out->musig_info, musig_info); + i_musig_aggr_key_info(&out->m.musig_info, musig_info); } else { return WITH_ERROR(-1, "Expected key placeholder starting with '@', or musig"); } diff --git a/src/common/wallet.h b/src/common/wallet.h index fe2751e8e..58da5eb20 100644 --- a/src/common/wallet.h +++ b/src/common/wallet.h @@ -333,9 +333,14 @@ typedef struct { KeyExpressionType type; union { // type == 0 - int16_t key_index; // index of the key (common between V1 and V2) + struct { + int16_t key_index; // index of the key (common between V1 and V2) + } k; + // type == 1 - rptr_musig_aggr_key_info_t musig_info; + struct { + rptr_musig_aggr_key_info_t musig_info; + } m; }; } policy_node_keyexpr_t; diff --git a/src/handler/lib/policy.c b/src/handler/lib/policy.c index a155a743c..cb48be05c 100644 --- a/src/handler/lib/policy.c +++ b/src/handler/lib/policy.c @@ -462,7 +462,12 @@ __attribute__((warn_unused_result)) static int get_derived_pubkey( serialized_extended_pubkey_t ext_pubkey; - int ret = get_extended_pubkey(dispatcher_context, wdi, key_expr->key_index, &ext_pubkey); + if (key_expr->type != KEY_EXPRESSION_NORMAL) { + PRINTF("Not implemented\n"); // TODO + return -1; + } + + int ret = get_extended_pubkey(dispatcher_context, wdi, key_expr->k.key_index, &ext_pubkey); if (ret < 0) { return -1; } @@ -1376,7 +1381,12 @@ static int get_bip44_purpose(const policy_node_t *descriptor_template) { return -1; } - if (kp->key_index != 0 || kp->num_first != 0 || kp->num_second != 1) { + if (kp->type != KEY_EXPRESSION_NORMAL) { + // any key expression that is not a play xpub is not BIP-44 compliant + return -1; + } + + if (kp->k.key_index != 0 || kp->num_first != 0 || kp->num_second != 1) { return -1; } @@ -1508,7 +1518,7 @@ bool check_wallet_hmac(const uint8_t wallet_id[static 32], const uint8_t wallet_ static int get_keyexpr_by_index_in_tree(const policy_node_tree_t *tree, unsigned int i, const policy_node_t **out_tapleaf_ptr, - policy_node_keyexpr_t *out_keyexpr) { + policy_node_keyexpr_t **out_keyexpr) { if (tree->is_leaf) { int ret = get_keyexpr_by_index(r_policy_node(&tree->script), i, NULL, out_keyexpr); if (ret >= 0 && out_tapleaf_ptr != NULL && i < (unsigned) ret) { @@ -1534,16 +1544,12 @@ static int get_keyexpr_by_index_in_tree(const policy_node_tree_t *tree, } } -// TODO: generalize for musig. Note that this is broken for musig, as out_keyexpr -// can't be filled in for musig key expressions (as it's dynamic and contains -// relative pointers). We should probably refactor to return the pointer to the -// key expression and removing the out_keyexpr argument. int get_keyexpr_by_index(const policy_node_t *policy, unsigned int i, const policy_node_t **out_tapleaf_ptr, - policy_node_keyexpr_t *out_keyexpr) { + policy_node_keyexpr_t **out_keyexpr) { // make sure that out_keyexpr is a valid pointer, if the output is not needed - policy_node_keyexpr_t tmp; + policy_node_keyexpr_t *tmp; if (out_keyexpr == NULL) { out_keyexpr = &tmp; } @@ -1568,16 +1574,14 @@ int get_keyexpr_by_index(const policy_node_t *policy, case TOKEN_WPKH: { if (i == 0) { policy_node_with_key_t *wpkh = (policy_node_with_key_t *) policy; - memcpy(out_keyexpr, - r_policy_node_keyexpr(&wpkh->key), - sizeof(policy_node_keyexpr_t)); + *out_keyexpr = r_policy_node_keyexpr(&wpkh->key); } return 1; } case TOKEN_TR: { policy_node_tr_t *tr = (policy_node_tr_t *) policy; if (i == 0) { - memcpy(out_keyexpr, r_policy_node_keyexpr(&tr->key), sizeof(policy_node_keyexpr_t)); + *out_keyexpr = r_policy_node_keyexpr(&tr->key); } if (!isnull_policy_node_tree(&tr->tree)) { int ret_tree = get_keyexpr_by_index_in_tree( @@ -1604,7 +1608,7 @@ int get_keyexpr_by_index(const policy_node_t *policy, if (i < (unsigned int) node->n) { policy_node_keyexpr_t *key_expressions = r_policy_node_keyexpr(&node->keys); - memcpy(out_keyexpr, &key_expressions[i], sizeof(policy_node_keyexpr_t)); + *out_keyexpr = &key_expressions[i]; } return node->n; @@ -1715,16 +1719,24 @@ int get_keyexpr_by_index(const policy_node_t *policy, } int count_distinct_keys_info(const policy_node_t *policy) { - policy_node_keyexpr_t key_expression; + policy_node_keyexpr_t *key_expression_ptr; int ret = -1, cur, n_key_expressions; for (cur = 0; - cur < (n_key_expressions = get_keyexpr_by_index(policy, cur, NULL, &key_expression)); + cur < (n_key_expressions = get_keyexpr_by_index(policy, cur, NULL, &key_expression_ptr)); ++cur) { if (n_key_expressions < 0) { return -1; } - ret = MAX(ret, key_expression.key_index + 1); + if (key_expression_ptr->type == KEY_EXPRESSION_NORMAL) { + ret = MAX(ret, key_expression_ptr->k.key_index + 1); + } else if (key_expression_ptr->type == KEY_EXPRESSION_MUSIG) { + musig_aggr_key_info_t *musig_info = + r_musig_aggr_key_info(&key_expression_ptr->m.musig_info); + ret = MAX(ret, musig_info->n); + } else { + LEDGER_ASSERT(false, "Unknown key expression type"); + } } return ret; } @@ -1912,21 +1924,30 @@ int is_policy_sane(dispatcher_context_t *dispatcher_context, // proportional to the depth of the wallet policy's abstract syntax tree. for (int i = 0; i < n_key_expressions - 1; i++) { // no point in running this for the last key expression - policy_node_keyexpr_t kp_i; + policy_node_keyexpr_t *kp_i; if (0 > get_keyexpr_by_index(policy, i, NULL, &kp_i)) { return WITH_ERROR(-1, "Unexpected error retrieving key expressions from the policy"); } for (int j = i + 1; j < n_key_expressions; j++) { - policy_node_keyexpr_t kp_j; + policy_node_keyexpr_t *kp_j; if (0 > get_keyexpr_by_index(policy, j, NULL, &kp_j)) { return WITH_ERROR(-1, "Unexpected error retrieving key expressions from the policy"); } + if (kp_i->type != kp_j->type) { + // if one is a key and the other is a musig, there's nothing else to check + continue; + } + + LEDGER_ASSERT( + kp_i->type == KEY_EXPRESSION_NORMAL && kp_j->type == KEY_EXPRESSION_NORMAL, + "TODO"); + // key expressions for the same key must have disjoint derivation options - if (kp_i.key_index == kp_j.key_index) { - if (kp_i.num_first == kp_j.num_first || kp_i.num_first == kp_j.num_second || - kp_i.num_second == kp_j.num_first || kp_i.num_second == kp_j.num_second) { + if (kp_i->k.key_index == kp_j->k.key_index) { + if (kp_i->num_first == kp_j->num_first || kp_i->num_first == kp_j->num_second || + kp_i->num_second == kp_j->num_first || kp_i->num_second == kp_j->num_second) { return WITH_ERROR(-1, "Key expressions with repeated derivations in miniscript"); } diff --git a/src/handler/lib/policy.h b/src/handler/lib/policy.h index cfe6a5076..d62f3e4d0 100644 --- a/src/handler/lib/policy.h +++ b/src/handler/lib/policy.h @@ -187,13 +187,14 @@ bool check_wallet_hmac(const uint8_t wallet_id[static 32], const uint8_t wallet_ * If not NULL, and if the i-th key expression is in a tapleaf of the policy, receives the pointer * to the tapleaf's script. * @param[out] out_keyexpr - * If not NULL, it is a pointer that will receive the i-th key expression of the policy. + * If not NULL, it is a pointer that will receive a pointer to the i-th key expression of the + * policy. * @return the number of key expressions in the policy on success; -1 in case of error. */ __attribute__((warn_unused_result)) int get_keyexpr_by_index(const policy_node_t *policy, unsigned int i, const policy_node_t **out_tapleaf_ptr, - policy_node_keyexpr_t *out_keyexpr); + policy_node_keyexpr_t **out_keyexpr); /** * Determines the expected number of unique keys in the provided policy's key information. diff --git a/src/handler/sign_psbt.c b/src/handler/sign_psbt.c index 0c825d43d..bd73c6bf2 100644 --- a/src/handler/sign_psbt.c +++ b/src/handler/sign_psbt.c @@ -105,7 +105,7 @@ typedef struct { } output_info_t; typedef struct { - policy_node_keyexpr_t key_expression; + policy_node_keyexpr_t *key_expression_ptr; int cur_index; uint32_t fingerprint; uint8_t key_derivation_length; @@ -451,10 +451,10 @@ static int read_change_and_index_from_psbt_bip32_derivation( } // check if the 'change' derivation step is indeed coherent with the key expression - if (change == keyexpr_info->key_expression.num_first) { + if (change == keyexpr_info->key_expression_ptr->num_first) { in_out->is_change = false; in_out->address_index = addr_index; - } else if (change == keyexpr_info->key_expression.num_second) { + } else if (change == keyexpr_info->key_expression_ptr->num_second) { in_out->is_change = true; in_out->address_index = addr_index; } else { @@ -710,12 +710,17 @@ static bool __attribute__((noinline)) fill_keyexpr_info_if_internal(dispatcher_c policy_map_key_info_t key_info; { uint8_t key_info_str[MAX_POLICY_KEY_INFO_LEN]; - int key_info_len = call_get_merkle_leaf_element(dc, - st->wallet_header_keys_info_merkle_root, - st->wallet_header_n_keys, - keyexpr_info->key_expression.key_index, - key_info_str, - sizeof(key_info_str)); + + // TODO: generalize for musig: keyexpr_info->key_expression->key_index is wrong + LEDGER_ASSERT(keyexpr_info->key_expression_ptr->type == KEY_EXPRESSION_NORMAL, "TODO"); + + int key_info_len = + call_get_merkle_leaf_element(dc, + st->wallet_header_keys_info_merkle_root, + st->wallet_header_n_keys, + keyexpr_info->key_expression_ptr->k.key_index, + key_info_str, + sizeof(key_info_str)); if (key_info_len < 0) { SEND_SW(dc, SW_BAD_STATE); // should never happen @@ -775,7 +780,7 @@ static bool find_first_internal_keyexpr(dispatcher_context_t *dc, int n_key_expressions = get_keyexpr_by_index(st->wallet_policy_map, keyexpr_info->cur_index, NULL, - &keyexpr_info->key_expression); + &keyexpr_info->key_expression_ptr); if (n_key_expressions < 0) { SEND_SW(dc, SW_BAD_STATE); // should never happen return false; @@ -1884,9 +1889,9 @@ static bool __attribute__((noinline)) sign_sighash_ecdsa_and_yield(dispatcher_co for (int i = 0; i < keyexpr_info->key_derivation_length; i++) { sign_path[i] = keyexpr_info->key_derivation[i]; } - sign_path[keyexpr_info->key_derivation_length] = input->in_out.is_change - ? keyexpr_info->key_expression.num_second - : keyexpr_info->key_expression.num_first; + sign_path[keyexpr_info->key_derivation_length] = + input->in_out.is_change ? keyexpr_info->key_expression_ptr->num_second + : keyexpr_info->key_expression_ptr->num_first; sign_path[keyexpr_info->key_derivation_length + 1] = input->in_out.address_index; int sign_path_len = keyexpr_info->key_derivation_length + 2; @@ -1953,8 +1958,8 @@ static bool __attribute__((noinline)) sign_sighash_schnorr_and_yield(dispatcher_ sign_path[i] = keyexpr_info->key_derivation[i]; } sign_path[keyexpr_info->key_derivation_length] = - input->in_out.is_change ? keyexpr_info->key_expression.num_second - : keyexpr_info->key_expression.num_first; + input->in_out.is_change ? keyexpr_info->key_expression_ptr->num_second + : keyexpr_info->key_expression_ptr->num_first; sign_path[keyexpr_info->key_derivation_length + 1] = input->in_out.address_index; int sign_path_len = keyexpr_info->key_derivation_length + 2; @@ -2343,8 +2348,8 @@ static bool __attribute__((noinline)) fill_taproot_keyexpr_info(dispatcher_conte const input_info_t *input, const policy_node_t *tapleaf_ptr, keyexpr_info_t *keyexpr_info) { - uint32_t change = input->in_out.is_change ? keyexpr_info->key_expression.num_second - : keyexpr_info->key_expression.num_first; + uint32_t change = input->in_out.is_change ? keyexpr_info->key_expression_ptr->num_second + : keyexpr_info->key_expression_ptr->num_first; uint32_t address_index = input->in_out.address_index; cx_sha256_t hash_context; @@ -2413,7 +2418,7 @@ sign_transaction(dispatcher_context_t *dc, int n_key_expressions = get_keyexpr_by_index(st->wallet_policy_map, key_expression_index, &tapleaf_ptr, - &keyexpr_info.key_expression); + &keyexpr_info.key_expression_ptr); if (n_key_expressions < 0) { SEND_SW(dc, SW_BAD_STATE); // should never happen