diff --git a/crates/papyrus_rpc/src/v0_7/api/api_impl.rs b/crates/papyrus_rpc/src/v0_7/api/api_impl.rs index 48cd50d7f2..059a807e88 100644 --- a/crates/papyrus_rpc/src/v0_7/api/api_impl.rs +++ b/crates/papyrus_rpc/src/v0_7/api/api_impl.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use async_trait::async_trait; -use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; use jsonrpsee::core::RpcResult; use jsonrpsee::types::ErrorObjectOwned; use jsonrpsee::RpcModule; @@ -118,6 +117,7 @@ use super::{ BlockHashAndNumber, BlockId, CallRequest, + CompiledContractClass, ContinuationToken, EventFilter, EventsChunk, @@ -1439,23 +1439,31 @@ impl JsonRpcServer for JsonRpcServerImpl { &self, block_id: BlockId, class_hash: ClassHash, - ) -> RpcResult { + ) -> RpcResult { let storage_txn = self.storage_reader.begin_ro_txn().map_err(internal_server_error)?; + let state_reader = storage_txn.get_state_reader().map_err(internal_server_error)?; let block_number = get_accepted_block_number(&storage_txn, block_id)?; - let class_definition_block_number = storage_txn - .get_state_reader() - .map_err(internal_server_error)? + if let Some(class_definition_block_number) = state_reader .get_class_definition_block_number(&class_hash) .map_err(internal_server_error)? - .ok_or_else(|| ErrorObjectOwned::from(CLASS_HASH_NOT_FOUND))?; - if class_definition_block_number > block_number { - return Err(ErrorObjectOwned::from(CLASS_HASH_NOT_FOUND)); + { + if class_definition_block_number > block_number { + return Err(ErrorObjectOwned::from(CLASS_HASH_NOT_FOUND)); + } + let casm = storage_txn + .get_casm(&class_hash) + .map_err(internal_server_error)? + .ok_or_else(|| ErrorObjectOwned::from(CLASS_HASH_NOT_FOUND))?; + return Ok(CompiledContractClass::V1(casm)); } - let casm = storage_txn - .get_casm(&class_hash) + + let state_number = StateNumber::right_after_block(block_number) + .ok_or_else(|| ErrorObjectOwned::from(CLASS_HASH_NOT_FOUND))?; + let deprecated_compiled_contract_class = state_reader + .get_deprecated_class_definition_at(state_number, &class_hash) .map_err(internal_server_error)? .ok_or_else(|| ErrorObjectOwned::from(CLASS_HASH_NOT_FOUND))?; - Ok(casm) + Ok(CompiledContractClass::V0(deprecated_compiled_contract_class)) } } diff --git a/crates/papyrus_rpc/src/v0_7/api/mod.rs b/crates/papyrus_rpc/src/v0_7/api/mod.rs index 5c87a41e07..a19000eaf0 100644 --- a/crates/papyrus_rpc/src/v0_7/api/mod.rs +++ b/crates/papyrus_rpc/src/v0_7/api/mod.rs @@ -20,7 +20,10 @@ use papyrus_storage::StorageTxn; use serde::{Deserialize, Serialize}; use starknet_api::block::BlockNumber; use starknet_api::core::{ClassHash, ContractAddress, Nonce}; -use starknet_api::deprecated_contract_class::Program; +use starknet_api::deprecated_contract_class::{ + ContractClass as StarknetApiDeprecatedContractClass, + Program, +}; use starknet_api::state::{StateNumber, StorageKey}; use starknet_api::transaction::{EventKey, Fee, TransactionHash, TransactionOffsetInBlock}; use starknet_types_core::felt::Felt; @@ -260,7 +263,7 @@ pub trait JsonRpc { &self, block_id: BlockId, class_hash: ClassHash, - ) -> RpcResult; + ) -> RpcResult; } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -670,3 +673,9 @@ pub struct TransactionTraceWithHash { pub transaction_hash: TransactionHash, pub trace_root: TransactionTrace, } + +#[derive(Debug, Clone, Deserialize, Serialize, Eq, PartialEq)] +pub enum CompiledContractClass { + V0(StarknetApiDeprecatedContractClass), + V1(CasmContractClass), +} diff --git a/crates/papyrus_rpc/src/v0_7/api/test.rs b/crates/papyrus_rpc/src/v0_7/api/test.rs index 42f10cfb36..891a0388b2 100644 --- a/crates/papyrus_rpc/src/v0_7/api/test.rs +++ b/crates/papyrus_rpc/src/v0_7/api/test.rs @@ -52,6 +52,7 @@ use starknet_api::core::{ }; use starknet_api::data_availability::L1DataAvailabilityMode; use starknet_api::deprecated_contract_class::{ + ContractClass as StarknetApiDeprecatedContractClass, ContractClassAbiEntry, FunctionAbiEntry, FunctionStateMutability, @@ -191,6 +192,7 @@ use crate::test_utils::{ validate_schema, SpecFile, }; +use crate::v0_7::api::CompiledContractClass; use crate::version_config::VERSION_0_7 as VERSION; use crate::{ internal_server_error, @@ -3565,39 +3567,55 @@ async fn get_deprecated_class_state_mutability() { #[tokio::test] async fn get_compiled_contract_class() { + let casm_class_hash = ClassHash(felt!("0x1")); + let deprecated_class_hash = ClassHash(felt!("0x2")); + let invalid_class_hash = ClassHash(felt!("0x3")); + let method_name = "starknet_V0_7_getCompiledContractClass"; let (module, mut storage_writer) = get_test_rpc_server_and_storage_writer_from_params::< JsonRpcServerImpl, >(None, None, None, None, None); - let class_hash = ClassHash(felt!("0x1")); let casm_contract_class = CasmContractClass::get_test_instance(&mut get_rng()); + let deprecated_contract_class = + StarknetApiDeprecatedContractClass::get_test_instance(&mut get_rng()); storage_writer .begin_rw_txn() .unwrap() .append_state_diff( BlockNumber(0), starknet_api::state::ThinStateDiff { - declared_classes: IndexMap::from([(class_hash, CompiledClassHash::default())]), + declared_classes: IndexMap::from([(casm_class_hash, CompiledClassHash::default())]), ..Default::default() }, ) .unwrap() - .append_casm(&class_hash, &casm_contract_class) + .append_casm(&casm_class_hash, &casm_contract_class) + .unwrap() + .append_classes(BlockNumber(0), &[], &[(deprecated_class_hash, &deprecated_contract_class)]) .unwrap() .commit() .unwrap(); let res = module - .call::<_, CasmContractClass>(method_name, (BlockId::Tag(Tag::Latest), class_hash)) + .call::<_, CompiledContractClass>(method_name, (BlockId::Tag(Tag::Latest), casm_class_hash)) + .await + .unwrap(); + assert_eq!(res, CompiledContractClass::V1(casm_contract_class)); + + let res = module + .call::<_, CompiledContractClass>( + method_name, + (BlockId::Tag(Tag::Latest), deprecated_class_hash), + ) .await .unwrap(); - assert_eq!(res, casm_contract_class); + assert_eq!(res, CompiledContractClass::V0(deprecated_contract_class)); // Ask for an invalid class hash. let err = module - .call::<_, CasmContractClass>( + .call::<_, CompiledContractClass>( method_name, - (BlockId::Tag(Tag::Latest), ClassHash(felt!("0x2"))), + (BlockId::Tag(Tag::Latest), invalid_class_hash), ) .await .unwrap_err();