Skip to content

Commit

Permalink
Better account id storage serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Kayanski committed Nov 26, 2024
1 parent 3bbd505 commit 191f9c8
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 57 deletions.
105 changes: 78 additions & 27 deletions framework/packages/abstract-std/src/objects/account/account_id.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::{fmt::Display, str::FromStr};

use cosmwasm_std::{StdError, StdResult};
use cosmwasm_std::StdResult;
use cw_storage_plus::{Key, KeyDeserialize, Prefixer, PrimaryKey};
use deser::split_first_key;

use super::{account_trace::AccountTrace, AccountSequence};
use crate::{objects::TruncatedChainId, AbstractError};
Expand Down Expand Up @@ -160,48 +161,72 @@ impl<'a> Prefixer<'a> for AccountId {

impl KeyDeserialize for &AccountId {
type Output = AccountId;
const KEY_ELEMS: u16 = 1;
const KEY_ELEMS: u16 = AccountId::KEY_ELEMS;

#[inline(always)]
fn from_vec(mut value: Vec<u8>) -> StdResult<Self::Output> {
let mut tu = value.split_off(2);
let t_len = parse_length(&value)?;
let u = tu.split_off(t_len);
fn from_vec(value: Vec<u8>) -> StdResult<Self::Output> {
let (trace, seq) = split_first_key(AccountTrace::KEY_ELEMS, value.as_ref())?;

println!("{:x?} - {:?}", trace, seq);

Ok(AccountId {
seq: AccountSequence::from_vec(u)?,
trace: AccountTrace::from_string(String::from_vec(tu)?),
seq: AccountSequence::from_vec(seq.to_vec())?,
trace: AccountTrace::from_vec(trace)?,
})
}
}

impl KeyDeserialize for AccountId {
type Output = AccountId;
const KEY_ELEMS: u16 = 1;
const KEY_ELEMS: u16 = AccountTrace::KEY_ELEMS + u32::KEY_ELEMS;

#[inline(always)]
fn from_vec(mut value: Vec<u8>) -> StdResult<Self::Output> {
let mut tu = value.split_off(2);
let t_len = parse_length(&value)?;
let u = tu.split_off(t_len);

Ok(AccountId {
seq: AccountSequence::from_vec(u)?,
trace: AccountTrace::from_string(String::from_vec(tu)?),
})
fn from_vec(value: Vec<u8>) -> StdResult<Self::Output> {
<&AccountId>::from_vec(value)
}
}

#[inline(always)]
fn parse_length(value: &[u8]) -> StdResult<usize> {
Ok(u16::from_be_bytes(
value
.try_into()
.map_err(|_| StdError::generic_err("Could not read 2 byte length"))?,
)
.into())
}
/// This was copied from cosmwasm-std
///
/// https://github.com/CosmWasm/cw-storage-plus/blob/f65cd4000a0dc1c009f3f99e23f9e10a1c256a68/src/de.rs#L173
pub(crate) mod deser {
use cosmwasm_std::{StdError, StdResult};

/// Splits the first key from the value based on the provided number of key elements.
/// The return value is ordered as (first_key, remainder).
///
pub fn split_first_key(key_elems: u16, value: &[u8]) -> StdResult<(Vec<u8>, &[u8])> {
let mut index = 0;
let mut first_key = Vec::new();

// Iterate over the sub keys
for i in 0..key_elems {
let len_slice = &value[index..index + 2];
index += 2;
let is_last_key = i == key_elems - 1;

if !is_last_key {
first_key.extend_from_slice(len_slice);
}

let subkey_len = parse_length(len_slice)?;
first_key.extend_from_slice(&value[index..index + subkey_len]);
index += subkey_len;
}

let remainder = &value[index..];
Ok((first_key, remainder))
}

fn parse_length(value: &[u8]) -> StdResult<usize> {
Ok(u16::from_be_bytes(
value
.try_into()
.map_err(|_| StdError::generic_err("Could not read 2 byte length"))?,
)
.into())
}
}
//--------------------------------------------------------------------------------------------------
// Tests
//--------------------------------------------------------------------------------------------------
Expand All @@ -226,6 +251,13 @@ mod test {
}
}

fn mock_local_key() -> AccountId {
AccountId {
seq: 54,
trace: AccountTrace::Remote(vec![]),
}
}

fn mock_keys() -> (AccountId, AccountId, AccountId) {
(
AccountId {
Expand Down Expand Up @@ -268,6 +300,25 @@ mod test {
assert_eq!(items[0], (key, 42069));
}

#[coverage_helper::test]
fn storage_key_local_works() {
let mut deps = mock_dependencies();
let key = mock_local_key();
let map: Map<&AccountId, u64> = Map::new("map");

map.save(deps.as_mut().storage, &key, &42069).unwrap();

assert_eq!(map.load(deps.as_ref().storage, &key).unwrap(), 42069);

let items = map
.range(deps.as_ref().storage, None, None, Order::Ascending)
.map(|item| item.unwrap())
.collect::<Vec<_>>();

assert_eq!(items.len(), 1);
assert_eq!(items[0], (key, 42069));
}

#[coverage_helper::test]
fn composite_key_works() {
let mut deps = mock_dependencies();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::fmt::Display;

use super::account_id::deser::split_first_key;
use cosmwasm_std::{ensure, Env, StdError, StdResult};
use cw_storage_plus::{Key, KeyDeserialize, Prefixer, PrimaryKey};

use crate::{constants::CHAIN_DELIMITER, objects::TruncatedChainId, AbstractError};

pub const MAX_TRACE_LENGTH: usize = 6;
pub const MAX_TRACE_LENGTH: u16 = 6;
pub(crate) const LOCAL: &str = "local";

/// The identifier of chain that triggered the account creation
Expand All @@ -16,14 +17,30 @@ pub enum AccountTrace {
Remote(Vec<TruncatedChainId>),
}

pub const ACCOUNT_TRACE_KEY_PLACEHOLDER: &str = "place-holder-key";

impl KeyDeserialize for &AccountTrace {
type Output = AccountTrace;
const KEY_ELEMS: u16 = 1;
const KEY_ELEMS: u16 = AccountTrace::KEY_ELEMS;

#[inline(always)]
fn from_vec(value: Vec<u8>) -> StdResult<Self::Output> {
let value = value.into_iter().filter(|b| *b > 32).collect();
Ok(AccountTrace::from_string(String::from_vec(value)?))
let mut trace = vec![];
// We parse the whole data for the MAX_TRACE_LENGTH keys
let mut value = value.as_ref();
for i in 0..MAX_TRACE_LENGTH - 1 {
let (t, remainder) = split_first_key(1, value)?;
value = remainder;
let chain = String::from_utf8(t)?;
if i == 0 && chain == "local" {
return Ok(AccountTrace::Local);
}
if chain != ACCOUNT_TRACE_KEY_PLACEHOLDER {
trace.push(TruncatedChainId::from_string(chain).unwrap())
}
}

Ok(AccountTrace::Remote(trace))
}
}

Expand All @@ -34,35 +51,27 @@ impl<'a> PrimaryKey<'a> for AccountTrace {
type SuperSuffix = Self;

fn key(&self) -> Vec<cw_storage_plus::Key> {
match self {
let mut serialization_result = match self {
AccountTrace::Local => LOCAL.key(),
AccountTrace::Remote(chain_name) => {
let len = chain_name.len();
chain_name
.iter()
.rev()
.enumerate()
.flat_map(|(s, c)| {
if s == len - 1 {
vec![c.str_ref().key()]
} else {
vec![c.str_ref().key(), CHAIN_DELIMITER.key()]
}
})
.flatten()
.collect::<Vec<Key>>()
}
AccountTrace::Remote(chain_name) => chain_name
.iter()
.flat_map(|c| c.str_ref().key())
.collect::<Vec<Key>>(),
};
for _ in serialization_result.len()..(MAX_TRACE_LENGTH as usize) {
serialization_result.extend(ACCOUNT_TRACE_KEY_PLACEHOLDER.key());
}
serialization_result
}
}

impl KeyDeserialize for AccountTrace {
type Output = AccountTrace;
const KEY_ELEMS: u16 = 1;
const KEY_ELEMS: u16 = MAX_TRACE_LENGTH;

#[inline(always)]
fn from_vec(value: Vec<u8>) -> StdResult<Self::Output> {
Ok(AccountTrace::from_string(String::from_vec(value)?))
<&AccountTrace>::from_vec(value)
}
}

Expand All @@ -80,7 +89,7 @@ impl AccountTrace {
AccountTrace::Remote(chain_trace) => {
// Ensure the trace length is limited
ensure!(
chain_trace.len() <= MAX_TRACE_LENGTH,
chain_trace.len() <= MAX_TRACE_LENGTH as usize,
AbstractError::FormattingError {
object: "chain-seq".into(),
expected: format!("between 1 and {MAX_TRACE_LENGTH}"),
Expand Down Expand Up @@ -370,7 +379,7 @@ mod test {
.map(|item| item.unwrap())
.collect::<Vec<_>>();

assert_eq!(items.len(), 3);
assert_eq!(items.len(), 2);
assert_eq!(items[0], (Addr::unchecked("jake"), 69420));
assert_eq!(items[1], (Addr::unchecked("larry"), 42069));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl KeyDeserialize for AssetEntry {

impl KeyDeserialize for &AssetEntry {
type Output = AssetEntry;
const KEY_ELEMS: u16 = 1;
const KEY_ELEMS: u16 = AssetEntry::KEY_ELEMS;

#[inline(always)]
fn from_vec(value: Vec<u8>) -> StdResult<Self::Output> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl<'a> Prefixer<'a> for &ChannelEntry {

impl KeyDeserialize for &ChannelEntry {
type Output = ChannelEntry;
const KEY_ELEMS: u16 = 1;
const KEY_ELEMS: u16 = 2;

#[inline(always)]
fn from_vec(mut value: Vec<u8>) -> StdResult<Self::Output> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl<'a> Prefixer<'a> for &ContractEntry {

impl KeyDeserialize for &ContractEntry {
type Output = ContractEntry;
const KEY_ELEMS: u16 = 1;
const KEY_ELEMS: u16 = 2;

#[inline(always)]
fn from_vec(mut value: Vec<u8>) -> StdResult<Self::Output> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl<'a> Prefixer<'a> for &DexAssetPairing {

impl KeyDeserialize for &DexAssetPairing {
type Output = DexAssetPairing;
const KEY_ELEMS: u16 = 1;
const KEY_ELEMS: u16 = 3;

#[inline(always)]
fn from_vec(value: Vec<u8>) -> StdResult<Self::Output> {
Expand Down
2 changes: 1 addition & 1 deletion framework/packages/abstract-std/src/objects/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ impl<'a> Prefixer<'a> for &ModuleInfo {

impl KeyDeserialize for &ModuleInfo {
type Output = ModuleInfo;
const KEY_ELEMS: u16 = 1;
const KEY_ELEMS: u16 = Namespace::KEY_ELEMS + String::KEY_ELEMS + ModuleVersion::KEY_ELEMS;

#[inline(always)]
fn from_vec(mut value: Vec<u8>) -> StdResult<Self::Output> {
Expand Down

0 comments on commit 191f9c8

Please sign in to comment.