Skip to content

Commit

Permalink
Moved to using checked arithmetic for Amounts.
Browse files Browse the repository at this point in the history
  • Loading branch information
murisi committed Jun 29, 2023
1 parent e7da829 commit bbed283
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 78 deletions.
3 changes: 3 additions & 0 deletions masp_primitives/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ sha2 = "0.9"
# - Metrics
memuse = "0.2.1"

# - Checked arithmetic
num-traits = "0.2.14"

# - Secret management
subtle = "2.2.3"

Expand Down
133 changes: 55 additions & 78 deletions masp_primitives/src/transaction/components/amount.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::asset_type::AssetType;
use borsh::{BorshDeserialize, BorshSerialize};
use num_traits::{CheckedAdd, CheckedMul, CheckedNeg, CheckedSub};
use std::cmp::Ordering;
use std::collections::btree_map::Keys;
use std::collections::btree_map::{IntoIter, Iter};
Expand All @@ -11,6 +12,7 @@ use std::ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub, SubAssign};
use zcash_encoding::Vector;

pub const MAX_MONEY: i64 = i64::MAX;
pub const MIN_MONEY: i64 = i64::MIN;
lazy_static::lazy_static! {
pub static ref DEFAULT_FEE: Amount<AssetType, i64> = Amount::from_pair(zec(), 1000).unwrap();
}
Expand Down Expand Up @@ -223,155 +225,133 @@ impl_index!(u64);

impl_index!(i128);

impl<Unit, Magnitude, Rhs> MulAssign<Rhs> for Amount<Unit, Magnitude> where
impl<Unit, Magnitude> MulAssign<Magnitude> for Amount<Unit, Magnitude> where
Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + MulAssign<Rhs>,
Rhs: Copy,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + CheckedMul,
{
fn mul_assign(&mut self, rhs: Rhs) {
for (_atype, amount) in self.0.iter_mut() {
*amount *= rhs;
}
fn mul_assign(&mut self, rhs: Magnitude) {
*self = self.clone() * rhs;
}
}

impl<Unit, Magnitude, Rhs> Mul<Rhs> for Amount<Unit, Magnitude> where
impl<Unit, Magnitude> Mul<Magnitude> for Amount<Unit, Magnitude> where
Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + Mul<Rhs>,
Rhs: Copy,
<Magnitude as Mul<Rhs>>::Output: BorshSerialize + BorshDeserialize + Eq + PartialOrd,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + CheckedMul,
{
type Output = Amount<Unit, <Magnitude as Mul<Rhs>>::Output>;
type Output = Amount<Unit, Magnitude>;

fn mul(self, rhs: Rhs) -> Self::Output {
fn mul(self, rhs: Magnitude) -> Self::Output {
let mut comps = BTreeMap::new();
for (atype, amount) in self.0.iter() {
comps.insert(atype.clone(), *amount * rhs);
comps.insert(atype.clone(), amount.checked_mul(&rhs).expect("overflow detected"));
}
comps.retain(|_, v| *v != Magnitude::default());
Amount(comps)
}
}

impl<Unit, Magnitude, Rhs> AddAssign<&Amount<Unit, Rhs>> for Amount<Unit, Magnitude> where
impl<Unit, Magnitude> AddAssign<&Amount<Unit, Magnitude>> for Amount<Unit, Magnitude> where
Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + AddAssign<Rhs>,
Rhs: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + Mul<Rhs>,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + CheckedAdd,
{
fn add_assign(&mut self, rhs: &Amount<Unit, Rhs>) {
for (atype, amount) in rhs.components() {
let mut val = self.get(atype);
val += *amount;
self.0.insert(atype.clone(), val);
}
fn add_assign(&mut self, rhs: &Amount<Unit, Magnitude>) {
*self = self.clone() + rhs;
}
}

impl<Unit, Magnitude, Rhs> AddAssign<Amount<Unit, Rhs>> for Amount<Unit, Magnitude> where
impl<Unit, Magnitude> AddAssign<Amount<Unit, Magnitude>> for Amount<Unit, Magnitude> where
Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + AddAssign<Rhs>,
Rhs: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + Mul<Rhs>,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + CheckedAdd,
{
fn add_assign(&mut self, rhs: Amount<Unit, Rhs>) {
fn add_assign(&mut self, rhs: Amount<Unit, Magnitude>) {
*self += &rhs
}
}

impl<Unit, Magnitude, Rhs> Add<&Amount<Unit, Rhs>> for Amount<Unit, Magnitude> where
impl<Unit, Magnitude> Add<&Amount<Unit, Magnitude>> for Amount<Unit, Magnitude> where
Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + Add<Rhs>,
Rhs: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd,
<Magnitude as Add<Rhs>>::Output: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + CheckedAdd,
{
type Output = Amount<Unit, <Magnitude as Add<Rhs>>::Output>;
type Output = Amount<Unit, Magnitude>;

fn add(self, rhs: &Amount<Unit, Rhs>) -> Self::Output {
let mut comps = BTreeMap::new();
fn add(self, rhs: &Amount<Unit, Magnitude>) -> Self::Output {
let mut comps = self.0.clone();
for (atype, amount) in rhs.components() {
comps.insert(atype.clone(), self.get(atype)+ *amount);
comps.insert(atype.clone(), self.get(atype).checked_add(amount).expect("overflow detected"));
}
comps.retain(|_, v| *v != Magnitude::default());
Amount(comps)
}
}

impl<Unit, Magnitude, Rhs> Add<Amount<Unit, Rhs>> for Amount<Unit, Magnitude> where
impl<Unit, Magnitude> Add<Amount<Unit, Magnitude>> for Amount<Unit, Magnitude> where
Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + Add<Rhs>,
Rhs: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd,
<Magnitude as Add<Rhs>>::Output: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + CheckedAdd,
{
type Output = Amount<Unit, <Magnitude as Add<Rhs>>::Output>;
type Output = Amount<Unit, Magnitude>;

fn add(self, rhs: Amount<Unit, Rhs>) -> Self::Output {
fn add(self, rhs: Amount<Unit, Magnitude>) -> Self::Output {
self + &rhs
}
}

impl<Unit, Magnitude, Rhs> SubAssign<&Amount<Unit, Rhs>> for Amount<Unit, Magnitude> where
impl<Unit, Magnitude> SubAssign<&Amount<Unit, Magnitude>> for Amount<Unit, Magnitude> where
Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + SubAssign<Rhs>,
Rhs: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + CheckedSub,
{
fn sub_assign(&mut self, rhs: &Amount<Unit, Rhs>) {
for (atype, amount) in rhs.components() {
let mut val = self.get(atype);
val -= amount.clone();
self.0.insert(atype.clone(), val);
}
fn sub_assign(&mut self, rhs: &Amount<Unit, Magnitude>) {
*self = self.clone() - rhs
}
}

impl<Unit, Magnitude, Rhs> SubAssign<Amount<Unit, Rhs>> for Amount<Unit, Magnitude> where
impl<Unit, Magnitude> SubAssign<Amount<Unit, Magnitude>> for Amount<Unit, Magnitude> where
Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + SubAssign<Rhs>,
Rhs: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + CheckedSub,
{
fn sub_assign(&mut self, rhs: Amount<Unit, Rhs>) {
fn sub_assign(&mut self, rhs: Amount<Unit, Magnitude>) {
*self -= &rhs
}
}

impl<Unit, Magnitude> Neg for Amount<Unit, Magnitude> where
Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + Neg,
<Magnitude as Neg>::Output: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + PartialOrd + CheckedNeg,
{
type Output = Amount<Unit, <Magnitude as Neg>::Output>;
type Output = Amount<Unit, Magnitude>;

fn neg(mut self) -> Self::Output {
let mut comps = BTreeMap::new();
for (atype, amount) in self.0.iter_mut() {
comps.insert(atype.clone(), -*amount);
comps.insert(atype.clone(), amount.checked_neg().expect("overflow detected"));
}
comps.retain(|_, v| *v != Magnitude::default());
Amount(comps)
}
}

impl<Unit, Magnitude, Rhs> Sub<&Amount<Unit, Rhs>> for Amount<Unit, Magnitude> where
impl<Unit, Magnitude> Sub<&Amount<Unit, Magnitude>> for Amount<Unit, Magnitude> where
Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + Sub<Rhs>,
Rhs: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy,
<Magnitude as Sub<Rhs>>::Output: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + CheckedSub,
{
type Output = Amount<Unit, <Magnitude as Sub<Rhs>>::Output>;
type Output = Amount<Unit, Magnitude>;

fn sub(self, rhs: &Amount<Unit, Rhs>) -> Self::Output {
let mut comps = BTreeMap::new();
fn sub(self, rhs: &Amount<Unit, Magnitude>) -> Self::Output {
let mut comps = self.0.clone();
for (atype, amount) in rhs.components() {
comps.insert(atype.clone(), self.get(atype) - amount.clone());
comps.insert(atype.clone(), self.get(atype).checked_sub(&amount).expect("overflow detected"));

Check failure on line 341 in masp_primitives/src/transaction/components/amount.rs

View workflow job for this annotation

GitHub Actions / Clippy (MSRV)

this expression creates a reference which is immediately dereferenced by the compiler

error: this expression creates a reference which is immediately dereferenced by the compiler --> masp_primitives/src/transaction/components/amount.rs:341:69 | 341 | comps.insert(atype.clone(), self.get(atype).checked_sub(&amount).expect("overflow detected")); | ^^^^^^^ help: change this to: `amount` | = note: `-D clippy::needless-borrow` implied by `-D warnings` = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#needless_borrow

Check failure on line 341 in masp_primitives/src/transaction/components/amount.rs

View workflow job for this annotation

GitHub Actions / Clippy (MSRV)

this expression creates a reference which is immediately dereferenced by the compiler

error: this expression creates a reference which is immediately dereferenced by the compiler --> masp_primitives/src/transaction/components/amount.rs:341:69 | 341 | comps.insert(atype.clone(), self.get(atype).checked_sub(&amount).expect("overflow detected")); | ^^^^^^^ help: change this to: `amount` | = note: `-D clippy::needless-borrow` implied by `-D warnings` = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#needless_borrow
}
comps.retain(|_, v| *v != Magnitude::default());
Amount(comps)
}
}

impl<Unit, Magnitude, Rhs> Sub<Amount<Unit, Rhs>> for Amount<Unit, Magnitude> where
impl<Unit, Magnitude> Sub<Amount<Unit, Magnitude>> for Amount<Unit, Magnitude> where
Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + Sub<Rhs>,
Rhs: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy,
<Magnitude as Sub<Rhs>>::Output: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy,
Magnitude: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + CheckedSub,
{
type Output = Amount<Unit, <Magnitude as Sub<Rhs>>::Output>;
type Output = Amount<Unit, Magnitude>;

fn sub(self, rhs: Amount<Unit, Rhs>) -> Self::Output {
fn sub(self, rhs: Amount<Unit, Magnitude>) -> Self::Output {
self - &rhs
}
}
Expand Down Expand Up @@ -483,7 +463,7 @@ pub mod testing {

#[cfg(test)]
mod tests {
use super::{zec, Amount, MAX_MONEY};
use super::{zec, Amount, MAX_MONEY, MIN_MONEY};

#[test]
fn amount_in_range() {
Expand Down Expand Up @@ -517,9 +497,6 @@ mod tests {
Amount::read(&mut neg_max_money.as_ref()).unwrap(),
Amount::from_pair(zec(), -MAX_MONEY).unwrap()
);

let neg_max_money_m1 = b"\x01\x94\xf3O\xfdd\xef\n\xc3i\x08\xfd\xdf\xec\x05hX\x06)\xc4Vq\x0f\xa1\x86\x83\x12\xa8\x7f\xbf\n\xa5\t\x00\x00\x00\x00\x00\x00\x00\x80";
assert!(Amount::read(&mut neg_max_money_m1.as_ref()).is_err());
}

#[test]
Expand All @@ -539,14 +516,14 @@ mod tests {
#[test]
#[should_panic]
fn sub_panics_on_underflow() {
let v = Amount::from_pair(zec(), -MAX_MONEY).unwrap();
let v = Amount::from_pair(zec(), MIN_MONEY).unwrap();
let _diff = v - Amount::from_pair(zec(), 1).unwrap();
}

#[test]
#[should_panic]
fn sub_assign_panics_on_underflow() {
let mut a = Amount::from_pair(zec(), -MAX_MONEY).unwrap();
let mut a = Amount::from_pair(zec(), MIN_MONEY).unwrap();
a -= Amount::from_pair(zec(), 1).unwrap();
}
}

0 comments on commit bbed283

Please sign in to comment.