Skip to content

Commit

Permalink
Merge pull request #14 from multiversx/subscription-audit-fixes-2
Browse files Browse the repository at this point in the history
subscription audit fixes 2
  • Loading branch information
psorinionut authored Dec 19, 2023
2 parents 2d48400 + 4da728f commit 545ebfb
Show file tree
Hide file tree
Showing 11 changed files with 130 additions and 137 deletions.
11 changes: 1 addition & 10 deletions farm-boosted-rewards-subscriber/tests/subscription_setup/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,14 @@ use std::{cell::RefCell, rc::Rc};
use auto_farm::common::address_to_id_mapper::AddressId;
use multiversx_sc::types::{Address, MultiValueEncoded};
use multiversx_sc_scenario::{
managed_address, managed_biguint, managed_token_id, rust_biguint,
managed_address, managed_token_id, rust_biguint,
testing_framework::{BlockchainStateWrapper, ContractObjWrapper, TxResult},
DebugApi,
};
use subscription_fee::{fees::FeesModule, service::ServiceModule, SubscriptionFee};

use crate::{USDC_TOKEN_ID, WEGLD_TOKEN_ID};

pub const MAX_USER_DEPOSITS: usize = 5;
pub const MIN_USER_DEPOSIT_VALUE: u64 = 1_000_000;
pub const MAX_PENDING_SERVICES: usize = 5;
pub const MAX_SERVICE_INFO_NO: usize = 5;

pub struct SubscriptionSetup<SubscriptionObjBuilder>
where
SubscriptionObjBuilder: 'static + Copy + Fn() -> subscription_fee::ContractObj<DebugApi>,
Expand Down Expand Up @@ -58,10 +53,6 @@ where
sc.init(
managed_token_id!(USDC_TOKEN_ID),
managed_token_id!(WEGLD_TOKEN_ID),
MAX_USER_DEPOSITS,
managed_biguint!(MIN_USER_DEPOSIT_VALUE),
MAX_PENDING_SERVICES,
MAX_SERVICE_INFO_NO,
managed_address!(pair_address),
args,
);
Expand Down
10 changes: 1 addition & 9 deletions subscriber/tests/subscription_setup/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,13 @@ use std::{cell::RefCell, rc::Rc};
use auto_farm::common::address_to_id_mapper::AddressId;
use multiversx_sc::types::{Address, MultiValueEncoded};
use multiversx_sc_scenario::{
managed_address, managed_biguint, managed_token_id, rust_biguint,
managed_address, managed_token_id, rust_biguint,
testing_framework::{BlockchainStateWrapper, ContractObjWrapper, TxResult},
DebugApi,
};
use subscription_fee::{fees::FeesModule, service::ServiceModule, SubscriptionFee};

use crate::{USDC_TOKEN_ID, WEGLD_TOKEN_ID};
pub const MAX_USER_DEPOSITS: usize = 5;
pub const MIN_USER_DEPOSIT_VALUE: u64 = 1_000_000;
pub const MAX_PENDING_SERVICES: usize = 5;
pub const MAX_SERVICE_INFO_NO: usize = 5;

pub struct SubscriptionSetup<SubscriptionObjBuilder>
where
Expand Down Expand Up @@ -58,10 +54,6 @@ where
sc.init(
managed_token_id!(USDC_TOKEN_ID),
managed_token_id!(WEGLD_TOKEN_ID),
MAX_USER_DEPOSITS,
managed_biguint!(MIN_USER_DEPOSIT_VALUE),
MAX_PENDING_SERVICES,
MAX_SERVICE_INFO_NO,
managed_address!(pair_address),
args,
)
Expand Down
21 changes: 5 additions & 16 deletions subscription-fee/src/common_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,12 @@ pub trait CommonStorageModule {
user_id: AddressId,
) -> SingleValueMapper<UniquePayments<Self::Api>>;

#[view(getMaxUserDeposits)]
#[storage_mapper("maxUserDeposits")]
fn max_user_deposits(&self) -> SingleValueMapper<usize>;
#[view(getMinTokenDepositValue)]
#[storage_mapper("minTokenDepositValue")]
fn min_token_deposit_value(&self, token_id: &TokenIdentifier) -> SingleValueMapper<BigUint>;

#[view(getMinUserDepositValue)]
#[storage_mapper("minUserDepositValue")]
fn min_user_deposit_value(&self) -> SingleValueMapper<BigUint>;

#[storage_mapper("userLastActionEpoch")]
fn user_last_action_epoch(
#[storage_mapper("userNextPaymentEpoch")]
fn user_next_payment_epoch(
&self,
user_id: AddressId,
service_id: AddressId,
Expand All @@ -46,10 +42,6 @@ pub trait CommonStorageModule {
#[storage_mapper("pendingServices")]
fn pending_services(&self) -> UnorderedSetMapper<ManagedAddress>;

#[view(getMaxPendingServices)]
#[storage_mapper("maxPendingServices")]
fn max_pending_services(&self) -> SingleValueMapper<usize>;

#[storage_mapper("pendingServiceInfo")]
fn pending_service_info(
&self,
Expand All @@ -64,9 +56,6 @@ pub trait CommonStorageModule {
service_id: AddressId,
) -> SingleValueMapper<ManagedVec<ServiceInfo<Self::Api>>>;

#[storage_mapper("maxServiceInfoNo")]
fn max_service_info_no(&self) -> SingleValueMapper<usize>;

#[view(getSubscribedUsers)]
#[storage_mapper("subscribedUsers")]
fn subscribed_users(
Expand Down
22 changes: 14 additions & 8 deletions subscription-fee/src/fees.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use auto_farm::common::unique_payments::UniquePayments;

use crate::common_storage;
use crate::pair_actions;
use crate::service::MAX_USER_DEPOSITS;

#[multiversx_sc::module]
pub trait FeesModule:
Expand All @@ -20,10 +21,14 @@ pub trait FeesModule:
}

#[only_owner]
#[endpoint(setMaxUserDeposits)]
fn set_max_user_deposits(&self, max_user_deposits: usize) {
require!(max_user_deposits > 0, "Value must be greater than o");
self.max_user_deposits().set(max_user_deposits);
#[endpoint(setMinDepositValue)]
fn set_min_deposit_value(&self, token_id: TokenIdentifier, min_token_deposit_value: BigUint) {
if min_token_deposit_value == BigUint::zero() {
self.min_token_deposit_value(&token_id).clear();
} else {
self.min_token_deposit_value(&token_id)
.set(min_token_deposit_value);
}
}

#[payable("*")]
Expand All @@ -43,9 +48,11 @@ pub trait FeesModule:
require!(payment_value_result.is_ok(), "Could not get payment value");

let payment_value = unsafe { payment_value_result.unwrap_unchecked() };
let min_user_deposit_value = self.min_user_deposit_value().get();
let min_token_deposit_value = self
.min_token_deposit_value(&payment.token_identifier)
.get();
require!(
payment_value > min_user_deposit_value,
payment_value >= min_token_deposit_value,
"Payment value is lesser than the minimum accepted"
);

Expand Down Expand Up @@ -113,9 +120,8 @@ pub trait FeesModule:
dest_mapper.update(|fees| {
fees.add_payment(payment);

let max_user_deposits = self.max_user_deposits().get();
require!(
fees.clone().into_payments().len() < max_user_deposits,
fees.clone().into_payments().len() <= MAX_USER_DEPOSITS,
"Maximum number of deposits per user reached"
);
});
Expand Down
26 changes: 0 additions & 26 deletions subscription-fee/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ pub trait SubscriptionFee:
&self,
stable_token_id: TokenIdentifier,
wegld_token_id: TokenIdentifier,
max_user_deposits: usize,
min_user_deposit_value: BigUint,
max_pending_services: usize,
max_service_info_no: usize,
price_query_address: ManagedAddress,
accepted_tokens: MultiValueEncoded<TokenIdentifier>,
) {
Expand All @@ -39,22 +35,6 @@ pub trait SubscriptionFee:
wegld_token_id.is_valid_esdt_identifier(),
"WEGLD token not valid"
);
require!(
max_user_deposits > 0,
"Max user deposits no must be greater than 0"
);
require!(
min_user_deposit_value > 0,
"Min user deposit value must be greater than 0"
);
require!(
max_pending_services > 0,
"Max pending services no must be greater than 0"
);
require!(
max_service_info_no > 0,
"Max service info no must be greater than 0"
);
require!(
self.blockchain().is_smart_contract(&price_query_address),
"Invalid price query address"
Expand All @@ -64,11 +44,5 @@ pub trait SubscriptionFee:
self.wegld_token_id().set_if_empty(wegld_token_id);
self.price_query_address().set_if_empty(price_query_address);
self.add_accepted_fees_tokens(accepted_tokens);
self.max_user_deposits().set_if_empty(max_user_deposits);
self.min_user_deposit_value()
.set_if_empty(min_user_deposit_value);
self.max_pending_services()
.set_if_empty(max_pending_services);
self.max_service_info_no().set_if_empty(max_service_info_no);
}
}
6 changes: 3 additions & 3 deletions subscription-fee/src/pair_actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub mod pair_proxy {
#[multiversx_sc::module]
pub trait PairActionsModule: crate::common_storage::CommonStorageModule {
#[only_owner]
#[endpoint(addUsdcPair)]
#[endpoint(addPairAddress)]
fn add_pair_address(&self, payment_token_id: TokenIdentifier, pair_address: ManagedAddress) {
require!(
payment_token_id.is_valid_esdt_identifier(),
Expand All @@ -32,8 +32,8 @@ pub trait PairActionsModule: crate::common_storage::CommonStorageModule {
}

#[only_owner]
#[endpoint(removeUsdcPair)]
fn remove_pair_data(&self, token_id: TokenIdentifier) {
#[endpoint(removePairAddress)]
fn remove_pair_address(&self, token_id: TokenIdentifier) {
self.pair_address_for_token(&token_id).clear();
}

Expand Down
60 changes: 23 additions & 37 deletions subscription-fee/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ use crate::common_storage;
use crate::subtract_payments::Epoch;
use crate::{fees, pair_actions};

pub const MAX_USER_DEPOSITS: usize = 20;
pub const MAX_SERVICES_LENGTH: usize = 20;

#[derive(TypeAbi, TopEncode, TopDecode, NestedEncode, NestedDecode, ManagedVecItem)]
pub struct ServiceInfo<M: ManagedTypeApi> {
pub opt_payment_token: Option<TokenIdentifier<M>>,
Expand All @@ -18,20 +21,6 @@ pub struct ServiceInfo<M: ManagedTypeApi> {
pub trait ServiceModule:
fees::FeesModule + pair_actions::PairActionsModule + common_storage::CommonStorageModule
{
#[only_owner]
#[endpoint(setMaxServiceInfoNo)]
fn set_max_service_info_no(&self, max_service_info_no: usize) {
require!(max_service_info_no > 0, "Value must be greater than o");
self.max_service_info_no().set(max_service_info_no);
}

#[only_owner]
#[endpoint(setMaxPendingServices)]
fn set_max_pending_services(&self, max_pending_services: usize) {
require!(max_pending_services > 0, "Value must be greater than o");
self.max_pending_services().set(max_pending_services);
}

/// Arguments are MultiValue3 of opt_payment_token, payment_amount and subscription_epochs
#[endpoint(registerService)]
fn register_service(
Expand All @@ -48,6 +37,7 @@ pub trait ServiceModule:
for arg in args {
let (opt_payment_token, amount, subscription_epochs) = arg.into_tuple();

require!(subscription_epochs > 0, "Subscription epochs must be > 0");
if let Some(token_id) = &opt_payment_token {
require!(
self.accepted_fees_tokens().contains(token_id),
Expand All @@ -63,13 +53,14 @@ pub trait ServiceModule:
}

self.pending_service_info(&service_address)
.update(|existing_services| existing_services.extend(services.iter()));
.update(|existing_services| {
existing_services.extend(services.iter());
require!(
existing_services.len() <= MAX_SERVICES_LENGTH,
"Maximum services length reached"
);
});
let _ = self.pending_services().insert(service_address);
let max_pending_services = self.max_pending_services().get();
require!(
self.pending_services().len() <= max_pending_services,
"Maximum number of pendind services reached"
);
}

#[endpoint(addExtraServices)]
Expand All @@ -87,6 +78,7 @@ pub trait ServiceModule:
for arg in args {
let (opt_payment_token, amount, subscription_epochs) = arg.into_tuple();

require!(subscription_epochs > 0, "Subscription epochs must be > 0");
if let Some(token_id) = &opt_payment_token {
require!(
self.accepted_fees_tokens().contains(token_id),
Expand All @@ -102,13 +94,13 @@ pub trait ServiceModule:
}

let service_info_mapper = self.service_info(existing_service_id);
service_info_mapper.update(|existing_services| existing_services.extend(services.iter()));

let max_service_info_no = self.max_service_info_no().get();
require!(
service_info_mapper.get().len() <= max_service_info_no,
"Maximum service info no reached"
);
service_info_mapper.update(|existing_services| {
existing_services.extend(services.iter());
require!(
existing_services.len() <= MAX_SERVICES_LENGTH,
"Maximum services length reached"
);
});
}

#[endpoint(unregisterService)]
Expand Down Expand Up @@ -147,16 +139,15 @@ pub trait ServiceModule:
let service_info = self.pending_service_info(&service_address).take();
self.service_info(service_id).set(&service_info);

let max_service_info_no = self.max_service_info_no().get();
require!(
self.service_info(service_id).get().len() <= max_service_info_no,
"Maximum service info no reached"
self.service_info(service_id).get().len() <= MAX_SERVICES_LENGTH,
"Maximum services lenght reached"
);

let _ = self.pending_services().swap_remove(&service_address);
}

/// subscribe with the following arguments: service_id, service index, subscription type
/// subscribe with the following arguments: service_id, service index
#[endpoint]
fn subscribe(&self, services: MultiValueEncoded<MultiValue2<AddressId, usize>>) {
let caller = self.blockchain().get_caller();
Expand Down Expand Up @@ -184,16 +175,11 @@ pub trait ServiceModule:

for service in services {
let (service_id, service_index) = service.into_tuple();
let service_options = self.service_info(service_id).get();
require!(
service_index < service_options.len(),
"Invalid service index"
);

let _ = self
.subscribed_users(service_id, service_index)
.swap_remove(&caller_id);
self.user_last_action_epoch(caller_id, service_id, service_index)
self.user_next_payment_epoch(caller_id, service_id, service_index)
.clear();
}
}
Expand Down
20 changes: 10 additions & 10 deletions subscription-fee/src/subtract_payments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,15 @@ pub trait SubtractPaymentsModule:
let service_id = self.service_id().get_id_non_zero(&caller);
let current_epoch = self.blockchain().get_block_epoch();

let last_action_mapper = self.user_last_action_epoch(user_id, service_id, service_index);
let last_action_epoch = last_action_mapper.get();
let next_payment_mapper = self.user_next_payment_epoch(user_id, service_id, service_index);
let next_payment_epoch = next_payment_mapper.get();

require!(next_payment_epoch <= current_epoch, "Cannot subtract yet");
require!(
self.subscribed_users(service_id, service_index)
.contains(&user_id),
"User is not subscribed to the service"
);

let service_info = self.service_info(service_id).get().get(service_index);

Expand All @@ -64,13 +71,6 @@ pub trait SubtractPaymentsModule:
return ScResult::Err(());
}

let next_subtract_epoch = if last_action_epoch > 0 {
last_action_epoch + subscription_epochs
} else {
current_epoch
};
require!(next_subtract_epoch <= current_epoch, "Cannot subtract yet");

let opt_user_address = self.user_id().get_address(user_id);
if opt_user_address.is_none() {
return ScResult::Err(());
Expand All @@ -88,7 +88,7 @@ pub trait SubtractPaymentsModule:
&payment.amount,
);

last_action_mapper.set(next_subtract_epoch);
next_payment_mapper.set(current_epoch + subscription_epochs);
}

subtract_result
Expand Down
Loading

0 comments on commit 545ebfb

Please sign in to comment.