Skip to content

Commit

Permalink
feat: improve error handling during contract upgrade
Browse files Browse the repository at this point in the history
  • Loading branch information
igorsenych-cw committed Oct 31, 2024
1 parent 548a781 commit fdca2b3
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 29 deletions.
9 changes: 0 additions & 9 deletions contracts/Cashier.sol
Original file line number Diff line number Diff line change
Expand Up @@ -894,10 +894,6 @@ contract Cashier is
* @param newImplementation The address of the new shard implementation.
*/
function upgradeShardsTo(address newImplementation) external onlyRole(OWNER_ROLE) {
if (newImplementation == address(0)) {
revert Cashier_ShardAddressZero();
}

for (uint256 i = 0; i < _shards.length; i++) {
_shards[i].upgradeTo(newImplementation);
}
Expand All @@ -909,12 +905,7 @@ contract Cashier is
* @param newShardImplementation The address of the new shard implementation.
*/
function upgradeRootAndShardsTo(address newRootImplementation, address newShardImplementation) external {
if (newShardImplementation == address(0)) {
revert Cashier_ShardAddressZero();
}

upgradeToAndCall(newRootImplementation, "");

for (uint256 i = 0; i < _shards.length; i++) {
_shards[i].upgradeTo(newShardImplementation);
}
Expand Down
24 changes: 14 additions & 10 deletions contracts/CashierShard.sol
Original file line number Diff line number Diff line change
Expand Up @@ -333,23 +333,27 @@ contract CashierShard is CashierShardStorage, OwnableUpgradeable, UUPSUpgradeabl
newImplementation; // Suppresses a compiler warning about the unused variable.
}

// ------------------ Service functions ----------------------- //

/**
* @dev The version of the standard upgrade function without the second parameter for backward compatibility.
* @custom:oz-upgrades-unsafe-allow-reachable delegatecall
*/
function upgradeTo(address newImplementation) external {
upgradeToAndCall(newImplementation, "");
}

/**
* @dev Validates the provided shard.
* @param shard The cashier shard contract address.
*/
function _validateShardContract(address shard) internal pure {
if (shard == address(0)) {
revert CashierShard_ShardAddressZero();
}

try ICashierShard(shard).isCashierShard() {} catch {
revert CashierShard_ContractNotShard();
}
}

// ------------------ Service functions ----------------------- //

/**
* @dev The version of the standard upgrade function without the second parameter for backward compatibility.
* @custom:oz-upgrades-unsafe-allow-reachable delegatecall
*/
function upgradeTo(address newImplementation) external {
upgradeToAndCall(newImplementation, "");
}
}
5 changes: 4 additions & 1 deletion contracts/interfaces/ICashierShard.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ import { ICashierTypes } from "./ICashierTypes.sol";
* @dev Defines the custom errors used in the cashier shard contract.
*/
interface ICashierShardErrors {
/// @dev Thrown if the contract is not a shard contract.
/// @dev Thrown if the provided shard address is zero.
error CashierShard_ShardAddressZero();

/// @dev Thrown if the contract is not a cashier shard contract.
error CashierShard_ContractNotShard();

/// @dev Thrown if the caller is not an admin.
Expand Down
19 changes: 10 additions & 9 deletions test/CashierSharded.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ async function setUpFixture<T>(func: () => Promise<T>): Promise<T> {
}
}

describe("Contracts 'Cashier' and `CashierShard`", async () => {
describe.only("Contracts 'Cashier' and `CashierShard`", async () => {
const TRANSACTION_ID1 = ethers.encodeBytes32String("MOCK_TRANSACTION_ID1");
const TRANSACTION_ID2 = ethers.encodeBytes32String("MOCK_TRANSACTION_ID2");
const TRANSACTION_ID3 = ethers.encodeBytes32String("MOCK_TRANSACTION_ID3");
Expand Down Expand Up @@ -217,15 +217,16 @@ describe("Contracts 'Cashier' and `CashierShard`", async () => {
const REVERT_ERROR_IF_CASH_OUT_ACCOUNT_INAPPROPRIATE = "Cashier_CashOutAccountInappropriate";
const REVERT_ERROR_IF_CASH_OUT_STATUS_INAPPROPRIATE = "Cashier_CashOutStatusInappropriate";
const REVERT_ERROR_IF_CONTRACT_NOT_ROOT = "Cashier_ContractNotRoot";
const REVERT_ERROR_IF_CONTRACT_NOT_SHARD_ON_SHARD = "CashierShard_ContractNotShard";
const REVERT_ERROR_IF_HOOK_CALLABLE_CONTRACT_ADDRESS_ZERO = "Cashier_HookCallableContractAddressZero";
const REVERT_ERROR_IF_HOOK_CALLABLE_CONTRACT_ADDRESS_NON_ZERO = "Cashier_HookCallableContractAddressNonZero";
const REVERT_ERROR_IF_HOOK_FLAGS_ALREADY_REGISTERED = "Cashier_HookFlagsAlreadyRegistered";
const REVERT_ERROR_IF_HOOK_FLAGS_INVALID = "Cashier_HookFlagsInvalid";
const REVERT_ERROR_IF_PREMINT_RELEASE_TIME_INAPPROPRIATE = "Cashier_PremintReleaseTimeInappropriate";
const REVERT_ERROR_IF_ROOT_ADDRESS_IS_ZERO = "Cashier_RootAddressZero";
const REVERT_ERROR_IF_SHARD_ADDRESS_IS_ZERO = "Cashier_ShardAddressZero";
const REVERT_ERROR_IF_SHARD_ADDRESS_IS_ZERO_ON_ROOT = "Cashier_ShardAddressZero";
const REVERT_ERROR_IF_CONTRACT_NOT_SHARD_ON_ROOT = "Cashier_ContractNotShard";
const REVERT_ERROR_IF_SHARD_ADDRESS_IS_ZERO_ON_SHARD = "CashierShard_ShardAddressZero";
const REVERT_ERROR_IF_CONTRACT_NOT_SHARD_ON_SHARD = "CashierShard_ContractNotShard";
const REVERT_ERROR_IF_SHARD_COUNT_EXCESS = "Cashier_ShardCountExcess";
const REVERT_ERROR_IF_SHARD_REPLACEMENT_COUNT_EXCESS = "Cashier_ShardReplacementCountExcess";
const REVERT_ERROR_IF_TOKEN_ADDRESS_IS_ZERO = "Cashier_TokenAddressZero";
Expand Down Expand Up @@ -933,7 +934,7 @@ describe("Contracts 'Cashier' and `CashierShard`", async () => {
const { cashierRoot } = await setUpFixture(deployContracts);
await expect(
cashierRoot.addShards([ADDRESS_ZERO])
).to.be.revertedWithCustomError(cashierRoot, REVERT_ERROR_IF_SHARD_ADDRESS_IS_ZERO);
).to.be.revertedWithCustomError(cashierRoot, REVERT_ERROR_IF_SHARD_ADDRESS_IS_ZERO_ON_ROOT);
});

it("Is reverted if the provided contract is not a shard contract", async () => {
Expand Down Expand Up @@ -1028,7 +1029,7 @@ describe("Contracts 'Cashier' and `CashierShard`", async () => {
await proveTx(cashierRoot.addShards(shardAddresses));
await expect(
cashierRoot.replaceShards(0, [ADDRESS_ZERO])
).to.be.revertedWithCustomError(cashierRoot, REVERT_ERROR_IF_SHARD_ADDRESS_IS_ZERO);
).to.be.revertedWithCustomError(cashierRoot, REVERT_ERROR_IF_SHARD_ADDRESS_IS_ZERO_ON_ROOT);
});

it("Is reverted if the provided contract is not a shard contract", async () => {
Expand Down Expand Up @@ -1071,10 +1072,10 @@ describe("Contracts 'Cashier' and `CashierShard`", async () => {
});

it("Is reverted if the shard implementation address is zero", async () => {
const { cashierRoot } = await setUpFixture(deployAndConfigureContracts);
const { cashierRoot, cashierShards } = await setUpFixture(deployAndConfigureContracts);
await expect(
cashierRoot.upgradeShardsTo(ADDRESS_ZERO)
).to.be.revertedWithCustomError(cashierRoot, REVERT_ERROR_IF_SHARD_ADDRESS_IS_ZERO);
).to.be.revertedWithCustomError(cashierShards[0], REVERT_ERROR_IF_SHARD_ADDRESS_IS_ZERO_ON_SHARD);
});

it("Is reverted if the shard implementation address is not a shard contract", async () => {
Expand Down Expand Up @@ -1156,15 +1157,15 @@ describe("Contracts 'Cashier' and `CashierShard`", async () => {
});

it("Is reverted if the shard implementation address is zero", async () => {
const { cashierRoot } = await setUpFixture(deployAndConfigureContracts);
const { cashierRoot, cashierShards } = await setUpFixture(deployAndConfigureContracts);
const targetRootImplementationAddress = await getImplementationAddress(cashierRoot);

await expect(
cashierRoot.upgradeRootAndShardsTo(
targetRootImplementationAddress,
ADDRESS_ZERO
)
).to.be.revertedWithCustomError(cashierRoot, REVERT_ERROR_IF_SHARD_ADDRESS_IS_ZERO);
).to.be.revertedWithCustomError(cashierShards[0], REVERT_ERROR_IF_SHARD_ADDRESS_IS_ZERO_ON_SHARD);
});

it("Is reverted if the provided root implementation is not a cashier root contract", async () => {
Expand Down

0 comments on commit fdca2b3

Please sign in to comment.