diff --git a/crates/bdk/src/wallet/mod.rs b/crates/bdk/src/wallet/mod.rs index 4f0687690..33deb480a 100644 --- a/crates/bdk/src/wallet/mod.rs +++ b/crates/bdk/src/wallet/mod.rs @@ -28,6 +28,7 @@ use bdk_chain::{ Append, BlockId, ChainPosition, ConfirmationTime, ConfirmationTimeHeightAnchor, FullTxOut, IndexedTxGraph, Persist, PersistBackend, }; +use bitcoin::hashes::{sha256, Hash}; use bitcoin::secp256k1::{All, Secp256k1}; use bitcoin::sighash::{EcdsaSighashType, TapSighashType}; use bitcoin::{ @@ -39,6 +40,7 @@ use bitcoin::{constants::genesis_block, psbt}; use core::fmt; use core::ops::Deref; use descriptor::error::Error as DescriptorError; +use miniscript::descriptor::KeyMap; use miniscript::psbt::{PsbtExt, PsbtInputExt, PsbtInputSatisfier}; use bdk_chain::tx_graph::CalculateFeeError; @@ -129,20 +131,43 @@ pub struct ChangeSet { keychain::ChangeSet, >, - /// Stores the network type of the wallet. + /// Stores the `Network` type of the wallet. This is used to prevent appending changesets + /// that were created for a different network. The initial changeset must have the correct + /// network value set. pub network: Option, + + /// Stores hashes for wallet keychain descriptors. This is used to prevent appending + /// changesets that were created for a different set of descriptors. The hashes should be + /// made from the public key versions of the descriptors. The initial changeset must have the + /// correct descriptor values set. + pub descriptor_hashes: BTreeMap, } impl Append for ChangeSet { fn append(&mut self, other: Self) { - Append::append(&mut self.chain, other.chain); - Append::append(&mut self.indexed_tx_graph, other.indexed_tx_graph); - if other.network.is_some() { - debug_assert!( - self.network.is_none() || self.network == other.network, - "network type must be consistent" - ); - self.network = other.network; + // Only append change set if network and descriptors match + let external_match = self + .descriptor_hashes + .get(&KeychainKind::External) + .is_none() + || other.descriptor_hashes.get(&KeychainKind::External) + == self.descriptor_hashes.get(&KeychainKind::External); + let internal_match = self + .descriptor_hashes + .get(&KeychainKind::Internal) + .is_none() + || other.descriptor_hashes.get(&KeychainKind::Internal) + == self.descriptor_hashes.get(&KeychainKind::Internal); + + if self.network.is_none() + || other.network == self.network && external_match && internal_match + { + Append::append(&mut self.chain, other.chain); + Append::append(&mut self.indexed_tx_graph, other.indexed_tx_graph); + Append::append(&mut self.network, other.network); + Append::append(&mut self.descriptor_hashes, other.descriptor_hashes); + } else { + panic!("something didn't match"); } } @@ -316,16 +341,35 @@ impl std::error::Error for NewError where W: core::fmt::Display + core::fm /// [`load`]: Wallet::load #[derive(Debug)] pub enum LoadError { - /// There was a problem with the passed-in descriptor(s). - Descriptor(crate::descriptor::DescriptorError), /// Loading data from the persistence backend failed. Load(L), + /// There was problem with the passed-in descriptor. + Descriptor(DescriptorError), /// Wallet not initialized, persistence backend is empty. NotInitialized, - /// Data loaded from persistence is missing network type. - MissingNetwork, - /// Data loaded from persistence is missing genesis hash. - MissingGenesis, + /// The loaded genesis hash does not match what was expected. + GenesisDoesNotMatch { + /// The expected genesis block hash. + expected: BlockHash, + /// The block hash loaded from persistence. + got: Option, + }, + /// The loaded network type does not match what was expected. + NetworkDoesNotMatch { + /// The expected network type. + expected: Network, + /// The network type loaded from persistence. + got: Option, + }, + /// The loaded descriptor hash does not match what was expected. + DescriptorDoesNotMatch { + /// Keychain + keychain: KeychainKind, + /// The expected descriptor string hash. + expected: Option, + /// The descriptor string hash loaded from persistence. + got: Option, + }, } impl fmt::Display for LoadError @@ -334,13 +378,28 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - LoadError::Descriptor(e) => e.fmt(f), LoadError::Load(e) => e.fmt(f), + LoadError::Descriptor(e) => e.fmt(f), LoadError::NotInitialized => { write!(f, "wallet is not initialized, persistence backend is empty") } - LoadError::MissingNetwork => write!(f, "loaded data is missing network type"), - LoadError::MissingGenesis => write!(f, "loaded data is missing genesis hash"), + LoadError::GenesisDoesNotMatch { expected, got } => { + write!(f, "loaded genesis hash is not {}, got {:?}", expected, got) + } + LoadError::NetworkDoesNotMatch { expected, got } => { + write!(f, "loaded network is not {}, got {:?}", expected, got) + } + LoadError::DescriptorDoesNotMatch { + keychain, + expected, + got, + } => { + write!( + f, + "loaded {:?} descriptor hash is not {:?}, got {:?}", + keychain, expected, got + ) + } } } } @@ -356,28 +415,12 @@ impl std::error::Error for LoadError where L: core::fmt::Display + core::f /// [`new_or_load_with_genesis_hash`]: Wallet::new_or_load_with_genesis_hash #[derive(Debug)] pub enum NewOrLoadError { - /// There is a problem with the passed-in descriptor. - Descriptor(crate::descriptor::DescriptorError), /// Writing to the persistence backend failed. Write(W), + /// There was problem with the passed-in descriptor. + Descriptor(DescriptorError), /// Loading from the persistence backend failed. - Load(L), - /// Wallet is not initialized, persistence backend is empty. - NotInitialized, - /// The loaded genesis hash does not match what was provided. - LoadedGenesisDoesNotMatch { - /// The expected genesis block hash. - expected: BlockHash, - /// The block hash loaded from persistence. - got: Option, - }, - /// The loaded network type does not match what was provided. - LoadedNetworkDoesNotMatch { - /// The expected network type. - expected: Network, - /// The network type loaded from persistence. - got: Option, - }, + Load(LoadError), } impl fmt::Display for NewOrLoadError @@ -387,18 +430,9 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - NewOrLoadError::Descriptor(e) => e.fmt(f), NewOrLoadError::Write(e) => write!(f, "failed to write to persistence: {}", e), - NewOrLoadError::Load(e) => write!(f, "failed to load from persistence: {}", e), - NewOrLoadError::NotInitialized => { - write!(f, "wallet is not initialized, persistence backend is empty") - } - NewOrLoadError::LoadedGenesisDoesNotMatch { expected, got } => { - write!(f, "loaded genesis hash is not {}, got {:?}", expected, got) - } - NewOrLoadError::LoadedNetworkDoesNotMatch { expected, got } => { - write!(f, "loaded network type is not {}, got {:?}", expected, got) - } + NewOrLoadError::Descriptor(e) => e.fmt(f), + NewOrLoadError::Load(e) => e.fmt(f), } } } @@ -457,17 +491,40 @@ impl Wallet { let (chain, chain_changeset) = LocalChain::from_genesis_hash(genesis_hash); let mut index = KeychainTxOutIndex::::default(); - let (signers, change_signers) = - create_signers(&mut index, &secp, descriptor, change_descriptor, network) + let external_descriptor: (ExtendedDescriptor, KeyMap) = + into_wallet_descriptor_checked(descriptor, &secp, network) .map_err(NewError::Descriptor)?; + let internal_descriptor: Option<(ExtendedDescriptor, KeyMap)> = change_descriptor + .map(|d| into_wallet_descriptor_checked(d, &secp, network)) + .transpose() + .map_err(NewError::Descriptor)?; + + let external_descriptor_hash = + sha256::Hash::hash(external_descriptor.0.to_string().as_bytes()); + let internal_descriptor_hash: Option = internal_descriptor + .clone() + .map(|(d, _)| sha256::Hash::hash(d.to_string().as_bytes())); + + let (signers, change_signers) = + create_signers(&mut index, &secp, external_descriptor, internal_descriptor); + let indexed_graph = IndexedTxGraph::new(index); let mut persist = Persist::new(db); + + let mut descriptor_hashes = BTreeMap::new(); + descriptor_hashes.insert(KeychainKind::External, external_descriptor_hash); + match internal_descriptor_hash { + Some(hash) => descriptor_hashes.insert(KeychainKind::Internal, hash), + None => None, + }; + persist.stage(ChangeSet { chain: chain_changeset, indexed_tx_graph: indexed_graph.initial_changeset(), network: Some(network), + descriptor_hashes, }); persist.commit().map_err(NewError::Write)?; @@ -487,6 +544,7 @@ impl Wallet { descriptor: E, change_descriptor: Option, mut db: D, + network: Network, ) -> Result> where D: PersistBackend, @@ -495,27 +553,105 @@ impl Wallet { .load_from_persistence() .map_err(LoadError::Load)? .ok_or(LoadError::NotInitialized)?; - Self::load_from_changeset(descriptor, change_descriptor, db, changeset) + + let genesis_hash = genesis_block(network).block_hash(); + Self::load_from_changeset( + descriptor, + change_descriptor, + db, + network, + genesis_hash, + changeset, + ) } fn load_from_changeset( descriptor: E, change_descriptor: Option, db: D, + network: Network, + genesis_hash: BlockHash, changeset: ChangeSet, ) -> Result> where D: PersistBackend, { let secp = Secp256k1::new(); - let network = changeset.network.ok_or(LoadError::MissingNetwork)?; - let chain = - LocalChain::from_changeset(changeset.chain).map_err(|_| LoadError::MissingGenesis)?; + + // verify loaded changeset network matches expected network + let changeset_network = changeset.network.ok_or(LoadError::NetworkDoesNotMatch { + expected: network, + got: None, + })?; + if changeset_network != network { + return Err(LoadError::NetworkDoesNotMatch { + expected: network, + got: Some(changeset_network), + }); + } + + // verify loaded genesis hash matches expected genesis hash + let changeset_chain = LocalChain::from_changeset(changeset.chain).map_err(|_| { + LoadError::GenesisDoesNotMatch { + expected: genesis_hash, + got: None, + } + })?; + if changeset_chain.genesis_hash() != genesis_hash { + return Err(LoadError::GenesisDoesNotMatch { + expected: genesis_hash, + got: Some(changeset_chain.genesis_hash()), + }); + } + + let external_descriptor: (ExtendedDescriptor, KeyMap) = + into_wallet_descriptor_checked(descriptor, &secp, network) + .map_err(LoadError::Descriptor)?; + + let internal_descriptor: Option<(ExtendedDescriptor, KeyMap)> = change_descriptor + .map(|d| into_wallet_descriptor_checked(d, &secp, network)) + .transpose() + .map_err(LoadError::Descriptor)?; + + // verify loaded external descriptor exists and hash matches expected descriptor hash + let external_descriptor_hash = + sha256::Hash::hash(external_descriptor.0.to_string().as_bytes()); + let changeset_external_descriptor_hash = changeset + .descriptor_hashes + .get(&KeychainKind::External) + .ok_or(LoadError::DescriptorDoesNotMatch { + keychain: KeychainKind::External, + expected: Some(external_descriptor_hash), + got: None, + })?; + if *changeset_external_descriptor_hash != external_descriptor_hash { + return Err(LoadError::DescriptorDoesNotMatch { + keychain: KeychainKind::External, + expected: Some(external_descriptor_hash), + got: Some(*changeset_external_descriptor_hash), + }); + } + + // verify loaded internal descriptor hash matches expected descriptor hash + let internal_descriptor_hash: Option = internal_descriptor + .clone() + .map(|(d, _)| sha256::Hash::hash(d.to_string().as_bytes())); + let changeset_internal_descriptor_hash = changeset + .descriptor_hashes + .get(&KeychainKind::Internal) + .copied(); + if changeset_internal_descriptor_hash != internal_descriptor_hash { + return Err(LoadError::DescriptorDoesNotMatch { + keychain: KeychainKind::Internal, + expected: internal_descriptor_hash, + got: changeset_internal_descriptor_hash, + }); + } + let mut index = KeychainTxOutIndex::::default(); let (signers, change_signers) = - create_signers(&mut index, &secp, descriptor, change_descriptor, network) - .map_err(LoadError::Descriptor)?; + create_signers(&mut index, &secp, external_descriptor, internal_descriptor); let indexed_graph = IndexedTxGraph::new(index); let persist = Persist::new(db); @@ -523,7 +659,7 @@ impl Wallet { Ok(Wallet { signers, change_signers, - chain, + chain: changeset_chain, indexed_graph, persist, network, @@ -569,40 +705,21 @@ impl Wallet { where D: PersistBackend, { - let changeset = db.load_from_persistence().map_err(NewOrLoadError::Load)?; + let changeset = db + .load_from_persistence() + .map_err(|e| NewOrLoadError::Load(LoadError::Load(e)))?; match changeset { Some(changeset) => { - let wallet = - Self::load_from_changeset(descriptor, change_descriptor, db, changeset) - .map_err(|e| match e { - LoadError::Descriptor(e) => NewOrLoadError::Descriptor(e), - LoadError::Load(e) => NewOrLoadError::Load(e), - LoadError::NotInitialized => NewOrLoadError::NotInitialized, - LoadError::MissingNetwork => { - NewOrLoadError::LoadedNetworkDoesNotMatch { - expected: network, - got: None, - } - } - LoadError::MissingGenesis => { - NewOrLoadError::LoadedGenesisDoesNotMatch { - expected: genesis_hash, - got: None, - } - } - })?; - if wallet.network != network { - return Err(NewOrLoadError::LoadedNetworkDoesNotMatch { - expected: network, - got: Some(wallet.network), - }); - } - if wallet.chain.genesis_hash() != genesis_hash { - return Err(NewOrLoadError::LoadedGenesisDoesNotMatch { - expected: genesis_hash, - got: Some(wallet.chain.genesis_hash()), - }); - } + let wallet = Self::load_from_changeset( + descriptor, + change_descriptor, + db, + network, + genesis_hash, + changeset, + ) + .map_err(NewOrLoadError::Load)?; + Ok(wallet) } None => Self::new_with_genesis_hash( @@ -2366,20 +2483,18 @@ fn new_local_utxo( } } -fn create_signers( +fn create_signers( index: &mut KeychainTxOutIndex, secp: &Secp256k1, - descriptor: E, - change_descriptor: Option, - network: Network, -) -> Result<(Arc, Arc), crate::descriptor::error::Error> { - let (descriptor, keymap) = into_wallet_descriptor_checked(descriptor, secp, network)?; + descriptor: (ExtendedDescriptor, KeyMap), + change_descriptor: Option<(ExtendedDescriptor, KeyMap)>, +) -> (Arc, Arc) { + let (descriptor, keymap) = descriptor; let signers = Arc::new(SignersContainer::build(keymap, &descriptor, secp)); index.add_keychain(KeychainKind::External, descriptor); let change_signers = match change_descriptor { - Some(descriptor) => { - let (descriptor, keymap) = into_wallet_descriptor_checked(descriptor, secp, network)?; + Some((descriptor, keymap)) => { let signers = Arc::new(SignersContainer::build(keymap, &descriptor, secp)); index.add_keychain(KeychainKind::Internal, descriptor); signers @@ -2387,7 +2502,7 @@ fn create_signers( None => Arc::new(SignersContainer::new()), }; - Ok((signers, change_signers)) + (signers, change_signers) } #[macro_export] diff --git a/crates/bdk/tests/wallet.rs b/crates/bdk/tests/wallet.rs index 4fa399d86..e87c5c81b 100644 --- a/crates/bdk/tests/wallet.rs +++ b/crates/bdk/tests/wallet.rs @@ -81,7 +81,8 @@ fn load_recovers_wallet() { // recover wallet { let db = bdk_file_store::Store::open(DB_MAGIC, &file_path).expect("must recover db"); - let wallet = Wallet::load(get_test_wpkh(), None, db).expect("must recover wallet"); + let wallet = + Wallet::load(get_test_wpkh(), None, db, Network::Testnet).expect("must recover wallet"); assert_eq!(wallet.network(), Network::Testnet); assert_eq!(wallet.spk_index().keychains(), &wallet_keychains); } @@ -110,10 +111,10 @@ fn new_or_load() { assert!( matches!( err, - bdk::wallet::NewOrLoadError::LoadedNetworkDoesNotMatch { + bdk::wallet::NewOrLoadError::Load(bdk::wallet::LoadError::NetworkDoesNotMatch { got: Some(Network::Testnet), - expected: Network::Bitcoin - } + expected: Network::Bitcoin, + }) ), "err: {}", err, @@ -139,7 +140,7 @@ fn new_or_load() { assert!( matches!( err, - bdk::wallet::NewOrLoadError::LoadedGenesisDoesNotMatch { got, expected } + bdk::wallet::NewOrLoadError::Load( bdk::wallet::LoadError::GenesisDoesNotMatch { got, expected } ) if got == Some(got_blockhash) && expected == exp_blockhash ), "err: {}", @@ -147,6 +148,45 @@ fn new_or_load() { ); } + // wrong external descriptor + { + let db = + bdk_file_store::Store::open_or_create_new(DB_MAGIC, &file_path).expect("must open db"); + let err = Wallet::new_or_load(get_test_single_sig_csv(), None, db, Network::Testnet) + .expect_err("wrong genesis hash"); + assert!( + matches!( + err, + bdk::wallet::NewOrLoadError::Load( bdk::wallet::LoadError::DescriptorDoesNotMatch { keychain, got:_, expected:_ } ) + if keychain == KeychainKind::External + ), + "err: {}", + err, + ); + } + + // wrong internal descriptor + { + let db = + bdk_file_store::Store::open_or_create_new(DB_MAGIC, &file_path).expect("must open db"); + let err = Wallet::new_or_load( + get_test_wpkh(), + Some(get_test_single_sig_csv()), + db, + Network::Testnet, + ) + .expect_err("wrong genesis hash"); + assert!( + matches!( + err, + bdk::wallet::NewOrLoadError::Load( bdk::wallet::LoadError::DescriptorDoesNotMatch { keychain, got:_, expected:_ } ) + if keychain == KeychainKind::Internal + ), + "err: {}", + err, + ); + } + // all parameters match { let db = diff --git a/crates/chain/src/tx_data_traits.rs b/crates/chain/src/tx_data_traits.rs index c957a3e57..0206bde22 100644 --- a/crates/chain/src/tx_data_traits.rs +++ b/crates/chain/src/tx_data_traits.rs @@ -128,6 +128,17 @@ impl Append for Vec { } } +impl Append for Option { + // If other is Some then replace self's value with other's value, if other is None do nothing. + fn append(&mut self, other: Self) { + other.and_then(|v| self.replace(v)); + } + + fn is_empty(&self) -> bool { + self.is_none() + } +} + macro_rules! impl_append_for_tuple { ($($a:ident $b:tt)*) => { impl<$($a),*> Append for ($($a,)*) where $($a: Append),* { diff --git a/crates/chain/src/tx_graph.rs b/crates/chain/src/tx_graph.rs index f84c3a3dc..cd01950a5 100644 --- a/crates/chain/src/tx_graph.rs +++ b/crates/chain/src/tx_graph.rs @@ -559,10 +559,7 @@ impl TxGraph { } for (outpoint, txout) in changeset.txouts { - let tx_entry = self - .txs - .entry(outpoint.txid) - .or_insert_with(Default::default); + let tx_entry = self.txs.entry(outpoint.txid).or_default(); match tx_entry { (TxNodeInternal::Whole(_), _, _) => { /* do nothing since we already have full tx */ @@ -575,13 +572,13 @@ impl TxGraph { for (anchor, txid) in changeset.anchors { if self.anchors.insert((anchor.clone(), txid)) { - let (_, anchors, _) = self.txs.entry(txid).or_insert_with(Default::default); + let (_, anchors, _) = self.txs.entry(txid).or_default(); anchors.insert(anchor); } } for (txid, new_last_seen) in changeset.last_seen { - let (_, _, last_seen) = self.txs.entry(txid).or_insert_with(Default::default); + let (_, _, last_seen) = self.txs.entry(txid).or_default(); if new_last_seen > *last_seen { *last_seen = new_last_seen; }