diff --git a/masp_primitives/src/convert.rs b/masp_primitives/src/convert.rs index 82d78a36..c19fa1fa 100644 --- a/masp_primitives/src/convert.rs +++ b/masp_primitives/src/convert.rs @@ -3,7 +3,7 @@ use crate::{ pedersen_hash::{pedersen_hash, Personalization}, Node, ValueCommitment, }, - transaction::components::amount::Amount, + transaction::components::amount::{Amount, IAmt}, }; use borsh::{BorshDeserialize, BorshSerialize}; use group::{Curve, GroupEncoding}; @@ -16,7 +16,7 @@ use std::{ #[derive(Clone, Debug, PartialEq, Eq)] pub struct AllowedConversion { /// The asset type that the note represents - assets: Amount, + assets: IAmt, /// Memorize generator because it's expensive to recompute generator: jubjub::ExtendedPoint, } @@ -71,15 +71,15 @@ impl AllowedConversion { } } -impl From for Amount { - fn from(allowed_conversion: AllowedConversion) -> Amount { +impl From for IAmt { + fn from(allowed_conversion: AllowedConversion) -> IAmt { allowed_conversion.assets } } -impl From for AllowedConversion { +impl From for AllowedConversion { /// Produces an asset generator without cofactor cleared - fn from(assets: Amount) -> Self { + fn from(assets: IAmt) -> Self { let mut asset_generator = jubjub::ExtendedPoint::identity(); for (asset, value) in assets.components() { // Compute the absolute value (failing if -i64::MAX is @@ -199,11 +199,11 @@ mod tests { #[test] fn test_homomorphism() { // Left operand - let a = Amount::from_pair(zec(), 5).unwrap() - + Amount::from_pair(btc(), 6).unwrap() - + Amount::from_pair(xan(), 7).unwrap(); + let a = Amount::from_pair(zec(), 5i64).unwrap() + + Amount::from_pair(btc(), 6i64).unwrap() + + Amount::from_pair(xan(), 7i64).unwrap(); // Right operand - let b = Amount::from_pair(zec(), 2).unwrap() + Amount::from_pair(xan(), 10).unwrap(); + let b = Amount::from_pair(zec(), 2i64).unwrap() + Amount::from_pair(xan(), 10i64).unwrap(); // Test homomorphism assert_eq!( AllowedConversion::from(a.clone() + b.clone()), @@ -213,9 +213,9 @@ mod tests { #[test] fn test_serialization() { // Make conversion - let a: AllowedConversion = (Amount::from_pair(zec(), 5).unwrap() - + Amount::from_pair(btc(), 6).unwrap() - + Amount::from_pair(xan(), 7).unwrap()) + let a: AllowedConversion = (Amount::from_pair(zec(), 5i64).unwrap() + + Amount::from_pair(btc(), 6i64).unwrap() + + Amount::from_pair(xan(), 7i64).unwrap()) .into(); // Serialize conversion let mut data = Vec::new(); diff --git a/masp_primitives/src/sapling/prover.rs b/masp_primitives/src/sapling/prover.rs index 03a442fb..5357955b 100644 --- a/masp_primitives/src/sapling/prover.rs +++ b/masp_primitives/src/sapling/prover.rs @@ -8,7 +8,7 @@ use crate::{ redjubjub::{PublicKey, Signature}, Node, }, - transaction::components::{Amount, GROTH_PROOF_SIZE}, + transaction::components::{IAmt, GROTH_PROOF_SIZE}, }; use super::{Diversifier, PaymentAddress, ProofGenerationKey, Rseed}; @@ -73,7 +73,7 @@ pub trait TxProver { fn binding_sig( &self, ctx: &mut Self::SaplingProvingContext, - amount: &Amount, + amount: &IAmt, sighash: &[u8; 32], ) -> Result; } @@ -92,7 +92,7 @@ pub mod mock { redjubjub::{PublicKey, Signature}, Diversifier, Node, PaymentAddress, ProofGenerationKey, Rseed, }, - transaction::components::{Amount, GROTH_PROOF_SIZE}, + transaction::components::{IAmt, GROTH_PROOF_SIZE}, }; use super::TxProver; @@ -169,7 +169,7 @@ pub mod mock { fn binding_sig( &self, _ctx: &mut Self::SaplingProvingContext, - _value: &Amount, + _value: &IAmt, _sighash: &[u8; 32], ) -> Result { Err(()) diff --git a/masp_primitives/src/transaction.rs b/masp_primitives/src/transaction.rs index 3938ee30..ac765d3e 100644 --- a/masp_primitives/src/transaction.rs +++ b/masp_primitives/src/transaction.rs @@ -25,7 +25,7 @@ use crate::{ use self::{ components::{ - amount::Amount, + amount::{Amount, IAmt}, sapling::{ self, ConvertDescriptionV5, OutputDescriptionV5, SpendDescription, SpendDescriptionV5, }, @@ -269,7 +269,7 @@ impl TransactionData { } impl TransactionData { - pub fn sapling_value_balance(&self) -> Amount { + pub fn sapling_value_balance(&self) -> IAmt { self.sapling_bundle .as_ref() .map_or(Amount::zero(), |b| b.value_balance.clone()) @@ -355,7 +355,7 @@ impl Transaction { }) } - fn read_amount(mut reader: R) -> io::Result { + fn read_amount(mut reader: R) -> io::Result { Amount::read(&mut reader).map_err(|_| { io::Error::new( io::ErrorKind::InvalidData, diff --git a/masp_primitives/src/transaction/builder.rs b/masp_primitives/src/transaction/builder.rs index 63681b3b..ec6416ef 100644 --- a/masp_primitives/src/transaction/builder.rs +++ b/masp_primitives/src/transaction/builder.rs @@ -19,7 +19,7 @@ use crate::{ sapling::{prover::TxProver, Diversifier, Node, Note, PaymentAddress}, transaction::{ components::{ - amount::{Amount, BalanceError, MAX_MONEY}, + amount::{Amount, BalanceError, IAmt, MAX_MONEY}, sapling::{ self, builder::{SaplingBuilder, SaplingMetadata}, @@ -43,10 +43,10 @@ const DEFAULT_TX_EXPIRY_DELTA: u32 = 20; pub enum Error { /// Insufficient funds were provided to the transaction builder; the given /// additional amount is required in order to construct the transaction. - InsufficientFunds(Amount), + InsufficientFunds(IAmt), /// The transaction has inputs in excess of outputs and fees; the user must /// add a change output. - ChangeRequired(Amount), + ChangeRequired(IAmt), /// An error occurred in computing the fees for a transaction. Fee(FeeError), /// An overflow or underflow occurred when computing value balances @@ -293,13 +293,13 @@ impl Builder { } /// Returns the sum of the transparent, Sapling, and TZE value balances. - pub fn value_balance(&self) -> Result { + pub fn value_balance(&self) -> Result { let value_balances = [ self.transparent_builder.value_balance()?, self.sapling_builder.value_balance(), ]; - Ok(value_balances.into_iter().sum::()) + Ok(value_balances.into_iter().sum::()) } /// Builds a transaction from the configured spends and outputs. @@ -326,7 +326,7 @@ impl Builder { fn build_internal( self, prover: &impl TxProver, - fee: Amount, + fee: IAmt, ) -> Result<(Transaction, SaplingMetadata), Error> { let consensus_branch_id = BranchId::for_height(&self.params, self.target_height); diff --git a/masp_primitives/src/transaction/components.rs b/masp_primitives/src/transaction/components.rs index db07c9a5..65f943e5 100644 --- a/masp_primitives/src/transaction/components.rs +++ b/masp_primitives/src/transaction/components.rs @@ -4,7 +4,7 @@ pub mod amount; pub mod sapling; pub mod transparent; pub use self::{ - amount::Amount, + amount::{Amount, IAmt}, sapling::{ConvertDescription, OutputDescription, SpendDescription}, transparent::{TxIn, TxOut}, }; diff --git a/masp_primitives/src/transaction/components/amount.rs b/masp_primitives/src/transaction/components/amount.rs index 49da8501..9f0158b5 100644 --- a/masp_primitives/src/transaction/components/amount.rs +++ b/masp_primitives/src/transaction/components/amount.rs @@ -4,7 +4,6 @@ use std::cmp::Ordering; use std::collections::btree_map::Keys; use std::collections::btree_map::{IntoIter, Iter}; use std::collections::BTreeMap; -use std::convert::TryInto; use std::hash::Hash; use std::io::{Read, Write}; use std::iter::Sum; @@ -13,7 +12,7 @@ use zcash_encoding::Vector; pub const MAX_MONEY: i64 = i64::MAX; lazy_static::lazy_static! { -pub static ref DEFAULT_FEE: Amount = Amount::from_pair(zec(), 1000).unwrap(); +pub static ref DEFAULT_FEE: Amount = Amount::from_pair(zec(), 1000).unwrap(); } /// A type-safe representation of some quantity of Zcash. /// @@ -25,12 +24,25 @@ pub static ref DEFAULT_FEE: Amount = Amount::from_pair(zec(), 1000).unwrap(); /// particular, a `Transaction` containing serialized invalid Amounts will be rejected /// by the network consensus rules. /// + +pub type IAmt = Amount; + +pub type UAmt = Amount; + +pub type IAmt128 = Amount; + #[derive(Clone, Default, Debug, PartialEq, Eq, BorshSerialize, BorshDeserialize, Hash)] -pub struct Amount( - pub BTreeMap, +pub struct Amount< + Unit: Hash + Ord + BorshSerialize + BorshDeserialize, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq, + > ( + pub BTreeMap, ); -impl memuse::DynamicUsage for Amount { +impl memuse::DynamicUsage for Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd, +{ #[inline(always)] fn dynamic_usage(&self) -> usize { unimplemented!() @@ -44,20 +56,17 @@ impl memuse::DynamicUsage for Amount { } } -impl Amount { - /// Returns a zero-valued Amount. - pub fn zero() -> Self { - Amount(BTreeMap::new()) - } - +impl Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd, +{ /// Creates a non-negative Amount from an i64. /// /// Returns an error if the amount is outside the range `{0..MAX_MONEY}`. - pub fn from_nonnegative>(atype: Unit, amount: Amt) -> Result { - let amount = amount.try_into().map_err(|_| ())?; - if amount == 0 { + pub fn from_nonnegative(atype: Unit, amount: Magnitude) -> Result { + if amount == Magnitude::default() { Ok(Self::zero()) - } else if 0 <= amount && amount <= MAX_MONEY { + } else if Magnitude::default() <= amount { let mut ret = BTreeMap::new(); ret.insert(atype, amount); Ok(Amount(ret)) @@ -65,50 +74,70 @@ impl Amount Err(()) } } +} + +impl Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default, +{ /// Creates an Amount from a type convertible to i64. /// /// Returns an error if the amount is outside the range `{-MAX_MONEY..MAX_MONEY}`. - pub fn from_pair>(atype: Unit, amount: Amt) -> Result { - let amount = amount.try_into().map_err(|_| ())?; - if amount == 0 { + pub fn from_pair(atype: Unit, amount: Magnitude) -> Result { + if amount == Magnitude::default() { Ok(Self::zero()) - } else if -MAX_MONEY <= amount && amount <= MAX_MONEY { + } else { let mut ret = BTreeMap::new(); ret.insert(atype, amount); Ok(Amount(ret)) - } else { - Err(()) } } + /// Filters out everything but the given AssetType from this Amount + pub fn project(&self, index: Unit) -> Self { + let val = self.0.get(&index).copied().unwrap_or(Magnitude::default()); + Self::from_pair(index, val).unwrap() + } + + /// Get the given AssetType within this Amount + pub fn get(&self, index: &Unit) -> Magnitude { + *self.0.get(index).unwrap_or(&Magnitude::default()) + } +} + +impl Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy, +{ + /// Returns a zero-valued Amount. + pub fn zero() -> Self { + Amount(BTreeMap::new()) + } + /// Returns an iterator over the amount's non-zero asset-types - pub fn asset_types(&self) -> Keys<'_, Unit, i64> { + pub fn asset_types(&self) -> Keys<'_, Unit, Magnitude> { self.0.keys() } /// Returns an iterator over the amount's non-zero components - pub fn components(&self) -> Iter<'_, Unit, i64> { + pub fn components(&self) -> Iter<'_, Unit, Magnitude> { self.0.iter() } /// Returns an iterator over the amount's non-zero components - pub fn into_components(self) -> IntoIter { + pub fn into_components(self) -> IntoIter { self.0.into_iter() } - /// Filters out everything but the given AssetType from this Amount - pub fn project(&self, index: Unit) -> Self { - let val = self.0.get(&index).copied().unwrap_or(0); - Self::from_pair(index, val).unwrap() - } - /// Filters out the given AssetType from this Amount pub fn reject(&self, index: Unit) -> Self { - self.clone() - self.project(index) + let mut val = self.clone(); + val.0.remove(&index); + val } } -impl Amount { +impl Amount { /// Deserialize an Amount object from a list of amounts denominated by /// different assets pub fn read(reader: &mut R) -> std::io::Result { @@ -143,32 +172,30 @@ impl Amount { } } -impl From for Amount { +impl From for Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + From { fn from(atype: Unit) -> Self { let mut ret = BTreeMap::new(); - ret.insert(atype, 1); + ret.insert(atype, true.into()); Amount(ret) } } -impl PartialOrd for Amount { +impl PartialOrd for Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd, + Self: Sub +{ /// One Amount is more than or equal to another if each corresponding /// coordinate is more than the other's. fn partial_cmp(&self, other: &Self) -> Option { - let mut diff = other.clone(); - for (atype, amount) in self.components() { - let ent = diff[atype] - amount; - if ent == 0 { - diff.0.remove(atype); - } else { - diff.0.insert(atype.clone(), ent); - } - } - if diff.0.values().all(|x| *x == 0) { + let diff = other.clone() - self.clone(); + if diff.0.values().all(|x| *x == Default::default()) { Some(Ordering::Equal) - } else if diff.0.values().all(|x| *x >= 0) { + } else if diff.0.values().all(|x| *x >= Default::default()) { Some(Ordering::Less) - } else if diff.0.values().all(|x| *x <= 0) { + } else if diff.0.values().all(|x| *x <= Default::default()) { Some(Ordering::Greater) } else { None @@ -176,147 +203,225 @@ impl PartialOrd fo } } -impl Index<&Unit> for Amount { - type Output = i64; - /// Query how much of the given asset this amount contains - fn index(&self, index: &Unit) -> &Self::Output { - self.0.get(index).unwrap_or(&0) +macro_rules! impl_index { + ($struct_type:ty) => { + impl Index<&Unit> for Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize, + { + type Output = $struct_type; + /// Query how much of the given asset this amount contains + fn index(&self, index: &Unit) -> &Self::Output { + self.0.get(index).unwrap_or(&0) + } + } } } -impl MulAssign for Amount { - fn mul_assign(&mut self, rhs: i64) { +impl_index!(i64); + +impl_index!(u64); + +impl_index!(i128); + +impl MulAssign for Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + MulAssign, + Rhs: Copy, +{ + fn mul_assign(&mut self, rhs: Rhs) { for (_atype, amount) in self.0.iter_mut() { - let ent = *amount * rhs; - if -MAX_MONEY <= ent && ent <= MAX_MONEY { - *amount = ent; - } else { - panic!("multiplication should remain in range"); - } + *amount *= rhs; } } } -impl Mul for Amount { - type Output = Self; +impl Mul for Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + Mul, + Rhs: Copy, +>::Output: BorshSerialize + BorshDeserialize + Eq + PartialOrd, +{ + type Output = Amount>::Output>; - fn mul(mut self, rhs: i64) -> Self { - self *= rhs; - self + fn mul(self, rhs: Rhs) -> Self::Output { + let mut comps = BTreeMap::new(); + for (atype, amount) in self.0.iter() { + comps.insert(atype.clone(), *amount * rhs); + } + Amount(comps) } } -impl AddAssign<&Amount> - for Amount +impl AddAssign<&Amount> for Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + AddAssign, + Rhs: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + Mul, { - fn add_assign(&mut self, rhs: &Self) { + fn add_assign(&mut self, rhs: &Amount) { for (atype, amount) in rhs.components() { - let ent = self[atype] + amount; - if ent == 0 { - self.0.remove(atype); - } else if -MAX_MONEY <= ent && ent <= MAX_MONEY { - self.0.insert(atype.clone(), ent); - } else { - panic!("addition should remain in range"); - } + let mut val = self.get(atype); + val += *amount; + self.0.insert(atype.clone(), val); } } } -impl AddAssign> - for Amount +impl AddAssign> for Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + AddAssign, + Rhs: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + Mul, { - fn add_assign(&mut self, rhs: Self) { + fn add_assign(&mut self, rhs: Amount) { *self += &rhs } } -impl Add<&Amount> - for Amount +impl Add<&Amount> for Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + Add, + Rhs: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd, +>::Output: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd, { - type Output = Self; + type Output = Amount>::Output>; - fn add(mut self, rhs: &Self) -> Self { - self += rhs; - self + fn add(self, rhs: &Amount) -> Self::Output { + let mut comps = BTreeMap::new(); + for (atype, amount) in rhs.components() { + comps.insert(atype.clone(), self.get(atype)+ *amount); + } + Amount(comps) } } -impl Add> - for Amount +impl Add> for Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + Add, + Rhs: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd, +>::Output: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd, { - type Output = Self; + type Output = Amount>::Output>; - fn add(mut self, rhs: Self) -> Self { - self += &rhs; - self + fn add(self, rhs: Amount) -> Self::Output { + self + &rhs } } -impl SubAssign<&Amount> - for Amount +impl SubAssign<&Amount> for Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + SubAssign, + Rhs: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd, { - fn sub_assign(&mut self, rhs: &Self) { + fn sub_assign(&mut self, rhs: &Amount) { for (atype, amount) in rhs.components() { - let ent = self[atype] - amount; - if ent == 0 { - self.0.remove(atype); - } else if -MAX_MONEY <= ent && ent <= MAX_MONEY { - self.0.insert(atype.clone(), ent); - } else { - panic!("subtraction should remain in range"); - } + let mut val = self.get(atype); + val -= amount.clone(); + self.0.insert(atype.clone(), val); } } } -impl SubAssign> - for Amount +impl SubAssign> for Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + SubAssign, + Rhs: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd, { - fn sub_assign(&mut self, rhs: Self) { + fn sub_assign(&mut self, rhs: Amount) { *self -= &rhs } } -impl Neg for Amount { - type Output = Self; +impl Neg for Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + Neg, +::Output: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd, +{ + type Output = Amount::Output>; - fn neg(mut self) -> Self { - for (_, amount) in self.0.iter_mut() { - *amount = -*amount; + fn neg(mut self) -> Self::Output { + let mut comps = BTreeMap::new(); + for (atype, amount) in self.0.iter_mut() { + comps.insert(atype.clone(), -*amount); } - self + Amount(comps) } } -impl Sub<&Amount> - for Amount +impl Sub<&Amount> for Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + Sub, + Rhs: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy, +>::Output: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy, { - type Output = Self; + type Output = Amount>::Output>; - fn sub(mut self, rhs: &Self) -> Self { - self -= rhs; - self + fn sub(self, rhs: &Amount) -> Self::Output { + let mut comps = BTreeMap::new(); + for (atype, amount) in rhs.components() { + comps.insert(atype.clone(), self.get(atype) - amount.clone()); + } + Amount(comps) } } -impl Sub> - for Amount +impl Sub> for Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + Sub, + Rhs: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy, +>::Output: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy, { - type Output = Self; + type Output = Amount>::Output>; - fn sub(mut self, rhs: Self) -> Self { - self -= &rhs; - self + fn sub(self, rhs: Amount) -> Self::Output { + self - &rhs } } -impl Sum for Amount { +impl Sum for Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd, + Self: Add, +{ fn sum>(iter: I) -> Self { iter.fold(Self::zero(), Add::add) } } +/// Workaround for the blanket implementation of TryFrom +pub struct TryFromNt(pub X); + +impl TryFrom>> for Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy, + Output: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + TryFrom, +{ + type Error = >::Error; + + fn try_from(x: TryFromNt>) -> Result { + let mut comps = BTreeMap::new(); + for (atype, amount) in x.0.0 { + comps.insert(atype, amount.try_into()?); + } + Ok(Self(comps)) + } +} + +/// Workaround for the blanket implementation of TryFrom +pub struct FromNt(pub X); + +impl From>> for Amount where + Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone, + Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy, + Output: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + From, +{ + fn from(x: FromNt>) -> Self { + let mut comps = BTreeMap::new(); + for (atype, amount) in x.0.0 { + comps.insert(atype, amount.into()); + } + Self(comps) + } +} + /// A type for balance violations in amount addition and subtraction /// (overflow and underflow of allowed ranges) #[derive(Copy, Clone, Debug, PartialEq, Eq)] @@ -346,7 +451,7 @@ pub fn zec() -> AssetType { AssetType::new(b"ZEC").unwrap() } -pub fn default_fee() -> Amount { +pub fn default_fee() -> Amount { Amount::from_pair(zec(), 10000).unwrap() } @@ -354,23 +459,23 @@ pub fn default_fee() -> Amount { pub mod testing { use proptest::prelude::prop_compose; - use super::{Amount, MAX_MONEY}; + use super::{Amount, IAmt, MAX_MONEY}; use crate::asset_type::testing::arb_asset_type; prop_compose! { - pub fn arb_amount()(asset_type in arb_asset_type(), amt in -MAX_MONEY..MAX_MONEY) -> Amount { + pub fn arb_amount()(asset_type in arb_asset_type(), amt in -MAX_MONEY..MAX_MONEY) -> IAmt { Amount::from_pair(asset_type, amt).unwrap() } } prop_compose! { - pub fn arb_nonnegative_amount()(asset_type in arb_asset_type(), amt in 0i64..MAX_MONEY) -> Amount { + pub fn arb_nonnegative_amount()(asset_type in arb_asset_type(), amt in 0i64..MAX_MONEY) -> IAmt { Amount::from_pair(asset_type, amt).unwrap() } } prop_compose! { - pub fn arb_positive_amount()(asset_type in arb_asset_type(), amt in 1i64..MAX_MONEY) -> Amount { + pub fn arb_positive_amount()(asset_type in arb_asset_type(), amt in 1i64..MAX_MONEY) -> IAmt { Amount::from_pair(asset_type, amt).unwrap() } } diff --git a/masp_primitives/src/transaction/components/sapling.rs b/masp_primitives/src/transaction/components/sapling.rs index 2048abd9..84e25bc0 100644 --- a/masp_primitives/src/transaction/components/sapling.rs +++ b/masp_primitives/src/transaction/components/sapling.rs @@ -23,7 +23,7 @@ use crate::{ }, }; -use super::{amount::Amount, GROTH_PROOF_SIZE}; +use super::{amount::IAmt, GROTH_PROOF_SIZE}; pub type GrothProofBytes = [u8; GROTH_PROOF_SIZE]; @@ -90,7 +90,7 @@ pub struct Bundle>, pub shielded_converts: Vec>, pub shielded_outputs: Vec>, - pub value_balance: Amount, + pub value_balance: IAmt, pub authorization: A, } diff --git a/masp_primitives/src/transaction/components/sapling/builder.rs b/masp_primitives/src/transaction/components/sapling/builder.rs index 381295f0..2b22ed4b 100644 --- a/masp_primitives/src/transaction/components/sapling/builder.rs +++ b/masp_primitives/src/transaction/components/sapling/builder.rs @@ -25,7 +25,7 @@ use crate::{ transaction::{ builder::Progress, components::{ - amount::{Amount, MAX_MONEY}, + amount::{Amount, IAmt, IAmt128, TryFromNt, FromNt, MAX_MONEY}, sapling::{ fees, Authorization, Authorized, Bundle, ConvertDescription, GrothProofBytes, OutputDescription, SpendDescription, @@ -272,7 +272,7 @@ pub struct SaplingBuilder { params: P, spend_anchor: Option, target_height: BlockHeight, - value_balance: Amount, + value_balance: Amount, convert_anchor: Option, spends: Vec>, converts: Vec, @@ -303,7 +303,7 @@ impl BorshDeserialize for SaplingBui .map(|x| x.ok_or_else(|| std::io::Error::from(std::io::ErrorKind::InvalidData))) .transpose()?; let target_height = BlockHeight::deserialize(buf)?; - let value_balance: Amount = Amount::deserialize(buf)?; + let value_balance = Amount::::deserialize(buf)?; let convert_anchor: Option> = Option::<[u8; 32]>::deserialize(buf)?.map(|x| bls12_381::Scalar::from_bytes(&x).into()); let convert_anchor = convert_anchor @@ -369,9 +369,19 @@ impl SaplingBuilder { &self.outputs } + /// Returns the net value represented by the spends and outputs added to this builder, + /// or an error if the values added to this builder overflow the range of a Zcash + /// monetary amount. + fn try_value_balance(&self) -> Result { + TryFromNt(self.value_balance.clone()) + .try_into() + .map_err(|_| Error::InvalidAmount) + } + /// Returns the net value represented by the spends and outputs added to this builder. - pub fn value_balance(&self) -> Amount { - self.value_balance.clone() + pub fn value_balance(&self) -> IAmt { + self.try_value_balance() + .expect("we check this when mutating self.value_balance") } } @@ -402,7 +412,7 @@ impl SaplingBuilder

{ let alpha = jubjub::Fr::random(&mut rng); self.value_balance += - Amount::from_pair(note.asset_type, note.value).map_err(|_| Error::InvalidAmount)?; + IAmt128::from(FromNt(Amount::from_pair(note.asset_type, note.value).map_err(|_| Error::InvalidAmount)?)); self.spends.push(SpendDescriptionInfo { extsk, @@ -437,8 +447,8 @@ impl SaplingBuilder

{ self.convert_anchor = Some(merkle_path.root(node).into()) } - let allowed_amt: Amount = allowed.clone().into(); - self.value_balance += allowed_amt * value.try_into().unwrap(); + let allowed_amt: IAmt = allowed.clone().into(); + self.value_balance += IAmt128::from(FromNt(allowed_amt * (value as i64))); self.converts.push(ConvertDescriptionInfo { allowed, @@ -472,7 +482,7 @@ impl SaplingBuilder

{ )?; self.value_balance -= - Amount::from_pair(asset_type, value).map_err(|_| Error::InvalidAmount)?; + IAmt128::from(FromNt(Amount::from_pair(asset_type, value).map_err(|_| Error::InvalidAmount)?)); self.outputs.push(output); @@ -488,6 +498,7 @@ impl SaplingBuilder

{ progress_notifier: Option<&Sender>, ) -> Result>, Error> { // Record initial positions of spends and outputs + let value_balance = self.try_value_balance()?; let params = self.params; let mut indexed_spends: Vec<_> = self.spends.into_iter().enumerate().collect(); let mut indexed_converts: Vec<_> = self.converts.into_iter().enumerate().collect(); @@ -723,7 +734,7 @@ impl SaplingBuilder

{ shielded_spends, shielded_converts, shielded_outputs, - value_balance: self.value_balance, + value_balance, authorization: Unauthorized { tx_metadata }, }) }; diff --git a/masp_primitives/src/transaction/components/transparent.rs b/masp_primitives/src/transaction/components/transparent.rs index 06efcaee..f650567a 100644 --- a/masp_primitives/src/transaction/components/transparent.rs +++ b/masp_primitives/src/transaction/components/transparent.rs @@ -7,7 +7,7 @@ use std::io::{self, Read, Write}; use crate::asset_type::AssetType; use crate::transaction::TransparentAddress; -use super::amount::{Amount, BalanceError, MAX_MONEY}; +use super::amount::{Amount, BalanceError, IAmt, MAX_MONEY}; pub mod builder; pub mod fees; @@ -58,7 +58,7 @@ impl Bundle { /// transferred out of the transparent pool into shielded pools or to fees; a negative value /// means that the containing transaction has funds being transferred into the transparent pool /// from the shielded pools. - pub fn value_balance(&self) -> Result + pub fn value_balance(&self) -> Result where E: From, { @@ -72,7 +72,7 @@ impl Bundle { Err(()) } }) - .sum::>() + .sum::>() .map_err(|_| BalanceError::Overflow)?; let output_sum = self @@ -85,7 +85,7 @@ impl Bundle { Err(()) } }) - .sum::>() + .sum::>() .map_err(|_| BalanceError::Overflow)?; // Cannot panic when subtracting two positive i64 diff --git a/masp_primitives/src/transaction/components/transparent/builder.rs b/masp_primitives/src/transaction/components/transparent/builder.rs index 6c0cc642..5d191a88 100644 --- a/masp_primitives/src/transaction/components/transparent/builder.rs +++ b/masp_primitives/src/transaction/components/transparent/builder.rs @@ -6,7 +6,7 @@ use crate::{ asset_type::AssetType, transaction::{ components::{ - amount::{Amount, BalanceError, MAX_MONEY}, + amount::{Amount, BalanceError, IAmt, MAX_MONEY}, transparent::{self, fees, Authorization, Authorized, Bundle, TxIn, TxOut}, }, sighash::TransparentAuthorizingContext, @@ -133,7 +133,7 @@ impl TransparentBuilder { Ok(()) } - pub fn value_balance(&self) -> Result { + pub fn value_balance(&self) -> Result { #[cfg(feature = "transparent-inputs")] let input_sum = self .inputs @@ -145,7 +145,7 @@ impl TransparentBuilder { Err(()) } }) - .sum::>() + .sum::>() .map_err(|_| BalanceError::Overflow)?; #[cfg(not(feature = "transparent-inputs"))] @@ -161,7 +161,7 @@ impl TransparentBuilder { Err(()) } }) - .sum::>() + .sum::>() .map_err(|_| BalanceError::Overflow)?; // Cannot panic when subtracting two positive i64 diff --git a/masp_primitives/src/transaction/fees.rs b/masp_primitives/src/transaction/fees.rs index 0108e20f..84f965ee 100644 --- a/masp_primitives/src/transaction/fees.rs +++ b/masp_primitives/src/transaction/fees.rs @@ -2,7 +2,7 @@ use crate::{ consensus::{self, BlockHeight}, - transaction::components::{amount::Amount, transparent::fees as transparent}, + transaction::components::{amount::{Amount, IAmt}, transparent::fees as transparent}, }; pub mod fixed; @@ -24,5 +24,5 @@ pub trait FeeRule { transparent_outputs: &[impl transparent::OutputView], sapling_input_count: usize, sapling_output_count: usize, - ) -> Result; + ) -> Result; } diff --git a/masp_primitives/src/transaction/fees/fixed.rs b/masp_primitives/src/transaction/fees/fixed.rs index 02e3bd93..320a960e 100644 --- a/masp_primitives/src/transaction/fees/fixed.rs +++ b/masp_primitives/src/transaction/fees/fixed.rs @@ -1,7 +1,7 @@ use crate::{ consensus::{self, BlockHeight}, transaction::components::{ - amount::{Amount, DEFAULT_FEE}, + amount::{IAmt, DEFAULT_FEE}, transparent::fees as transparent, }, }; @@ -10,12 +10,12 @@ use crate::{ /// the transaction being constructed. #[derive(Clone, Debug)] pub struct FeeRule { - fixed_fee: Amount, + fixed_fee: IAmt, } impl FeeRule { /// Creates a new nonstandard fixed fee rule with the specified fixed fee. - pub fn non_standard(fixed_fee: Amount) -> Self { + pub fn non_standard(fixed_fee: IAmt) -> Self { Self { fixed_fee } } @@ -27,7 +27,7 @@ impl FeeRule { } /// Returns the fixed fee amount which which this rule was configured. - pub fn fixed_fee(&self) -> Amount { + pub fn fixed_fee(&self) -> IAmt { self.fixed_fee.clone() } } @@ -42,7 +42,7 @@ impl super::FeeRule for FeeRule { _transparent_outputs: &[impl transparent::OutputView], _sapling_input_count: usize, _sapling_output_count: usize, - ) -> Result { + ) -> Result { Ok(self.fixed_fee.clone()) } } diff --git a/masp_proofs/src/prover.rs b/masp_proofs/src/prover.rs index b61fbc1b..ca5b19b5 100644 --- a/masp_proofs/src/prover.rs +++ b/masp_proofs/src/prover.rs @@ -11,7 +11,7 @@ use masp_primitives::{ redjubjub::{PublicKey, Signature}, Diversifier, Node, PaymentAddress, ProofGenerationKey, Rseed, }, - transaction::components::{Amount, GROTH_PROOF_SIZE}, + transaction::components::{IAmt, GROTH_PROOF_SIZE}, }; use std::path::Path; @@ -247,7 +247,7 @@ impl TxProver for LocalTxProver { fn binding_sig( &self, ctx: &mut Self::SaplingProvingContext, - assets_and_values: &Amount, //&[(AssetType, i64)], + assets_and_values: &IAmt, //&[(AssetType, i64)], sighash: &[u8; 32], ) -> Result { ctx.binding_sig(assets_and_values, sighash) diff --git a/masp_proofs/src/sapling/prover.rs b/masp_proofs/src/sapling/prover.rs index ad5eaf7b..5c14fed0 100644 --- a/masp_proofs/src/sapling/prover.rs +++ b/masp_proofs/src/sapling/prover.rs @@ -13,7 +13,7 @@ use masp_primitives::{ redjubjub::{PrivateKey, PublicKey, Signature}, Diversifier, Node, Note, PaymentAddress, ProofGenerationKey, Rseed, }, - transaction::components::Amount, + transaction::components::IAmt, }; use rand_core::OsRng; use std::ops::{AddAssign, Neg}; @@ -284,7 +284,7 @@ impl SaplingProvingContext { /// and output_proof() must be completed before calling this function. pub fn binding_sig( &self, - assets_and_values: &Amount, + assets_and_values: &IAmt, sighash: &[u8; 32], ) -> Result { // Initialize secure RNG diff --git a/masp_proofs/src/sapling/verifier.rs b/masp_proofs/src/sapling/verifier.rs index 22cceafb..3e14857d 100644 --- a/masp_proofs/src/sapling/verifier.rs +++ b/masp_proofs/src/sapling/verifier.rs @@ -5,7 +5,7 @@ use bls12_381::Bls12; use group::{Curve, GroupEncoding}; use masp_primitives::{ sapling::redjubjub::{PublicKey, Signature}, - transaction::components::Amount, + transaction::components::IAmt, }; use super::masp_compute_value_balance; @@ -172,7 +172,7 @@ impl SaplingVerificationContextInner { /// have been checked before calling this function. fn final_check( &self, - value_balance: Amount, + value_balance: IAmt, sighash_value: &[u8; 32], binding_sig: Signature, binding_sig_verifier: impl FnOnce(PublicKey, [u8; 64], Signature) -> bool, diff --git a/masp_proofs/src/sapling/verifier/single.rs b/masp_proofs/src/sapling/verifier/single.rs index 2df7dfae..82a5c942 100644 --- a/masp_proofs/src/sapling/verifier/single.rs +++ b/masp_proofs/src/sapling/verifier/single.rs @@ -3,7 +3,7 @@ use bls12_381::Bls12; use masp_primitives::{ constants::{SPENDING_KEY_GENERATOR, VALUE_COMMITMENT_RANDOMNESS_GENERATOR}, sapling::redjubjub::{PublicKey, Signature}, - transaction::components::Amount, + transaction::components::IAmt, }; use super::SaplingVerificationContextInner; @@ -98,7 +98,7 @@ impl SaplingVerificationContext { /// have been checked before calling this function. pub fn final_check( &self, - value_balance: Amount, + value_balance: IAmt, sighash_value: &[u8; 32], binding_sig: Signature, ) -> bool {