Skip to content

Commit

Permalink
fix: using ChaingConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
blmalone committed Sep 25, 2024
1 parent 52518ee commit be23568
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions validation/superchain-version_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,21 @@ func checkForStandardVersions(t *testing.T, chain *ChainConfig) {
// than the versions specified in the standard config
isTestnet := (chain.Superchain == "sepolia" || chain.Superchain == "sepolia-dev-0")

versions, err := getContractVersionsFromChain(*Addresses[chain.ChainID], client)
versions, err := getContractVersionsFromChain(*Addresses[chain.ChainID], client, chain)
require.NoError(t, err)
requireStandardSemvers(t, versions, isTestnet)
requireStandardSemvers(t, versions, isTestnet, chain)

// don't perform bytecode checking for testnets
if !isTestnet {
bytecodeHashes, err := getContractBytecodeHashesFromChain(chain.ChainID, *Addresses[chain.ChainID], client)
bytecodeHashes, err := getContractBytecodeHashesFromChain(chain.ChainID, *Addresses[chain.ChainID], client, chain)
require.NoError(t, err)
requireStandardByteCodeHashes(t, bytecodeHashes)
requireStandardByteCodeHashes(t, bytecodeHashes, chain)
}
}

// getContractVersionsFromChain pulls the appropriate contract versions from chain
// using the supplied client (calling the version() method for each contract). It does this concurrently.
func getContractVersionsFromChain(list AddressList, client *ethclient.Client) (ContractVersions, error) {
func getContractVersionsFromChain(list AddressList, client *ethclient.Client, chain *ChainConfig) (ContractVersions, error) {
// Prepare a concurrency-safe object to store version information in, and
// spin up a goroutine for each contract we are checking (to speed things up).
results := new(sync.Map)
Expand All @@ -66,7 +66,7 @@ func getContractVersionsFromChain(list AddressList, client *ethclient.Client) (C

wg := new(sync.WaitGroup)

contractsToCheckVersionOf := standard.NetworkVersions["mainnet"].Releases[standard.NetworkVersions["mainnet"].StandardRelease].GetNonEmpty()
contractsToCheckVersionOf := standard.NetworkVersions[chain.Superchain].Releases[standard.NetworkVersions[chain.Superchain].StandardRelease].GetNonEmpty()

for _, contractName := range contractsToCheckVersionOf {
a, err := list.AddressFor(contractName)
Expand Down Expand Up @@ -108,7 +108,7 @@ func getContractVersionsFromChain(list AddressList, client *ethclient.Client) (C

// getContractBytecodeHashesFromChain pulls the appropriate bytecode from chain
// using the supplied client, concurrently.
func getContractBytecodeHashesFromChain(chainID uint64, list AddressList, client *ethclient.Client) (standard.L1ContractBytecodeHashes, error) {
func getContractBytecodeHashesFromChain(chainID uint64, list AddressList, client *ethclient.Client, chain *ChainConfig) (standard.L1ContractBytecodeHashes, error) {
// Prepare a concurrency-safe object to store bytecode information in, and
// spin up a goroutine for each contract we are checking (to speed things up).
results := new(sync.Map)
Expand All @@ -124,7 +124,7 @@ func getContractBytecodeHashesFromChain(chainID uint64, list AddressList, client

wg := new(sync.WaitGroup)

contractsToCheckBytecodeOf := standard.BytecodeHashes[standard.NetworkVersions["mainnet"].StandardRelease].GetNonEmpty()
contractsToCheckBytecodeOf := standard.BytecodeHashes[standard.NetworkVersions[chain.Superchain].StandardRelease].GetNonEmpty()

for _, contractName := range contractsToCheckBytecodeOf {
contractAddress, err := list.AddressFor(contractName)
Expand Down Expand Up @@ -253,8 +253,8 @@ func getBytecodeHash(ctx context.Context, chainID uint64, contractName string, t
return crypto.Keccak256Hash(bytecodeImmutableFilterer.Bytecode).Hex(), nil
}

func requireStandardSemvers(t *testing.T, versions ContractVersions, isTestnet bool) {
standardVersions := standard.NetworkVersions["mainnet"].Releases[standard.NetworkVersions["mainnet"].StandardRelease]
func requireStandardSemvers(t *testing.T, versions ContractVersions, isTestnet bool, chain *ChainConfig) {
standardVersions := standard.NetworkVersions[chain.Superchain].Releases[standard.NetworkVersions[chain.Superchain].StandardRelease]
s := reflect.ValueOf(standardVersions)
c := reflect.ValueOf(versions)
matches := checkMatchOrTestnet(s, c, isTestnet)
Expand All @@ -265,12 +265,12 @@ func requireStandardSemvers(t *testing.T, versions ContractVersions, isTestnet b
}, cmp.Ignore()))
require.Truef(t, matches,
"contract versions do not match the standard versions for the %s release \n (-removed from standard / +added to actual):\n %s",
standard.NetworkVersions["mainnet"].StandardRelease, diff)
standard.NetworkVersions[chain.Superchain].StandardRelease, diff)
}
}

func requireStandardByteCodeHashes(t *testing.T, hashes standard.L1ContractBytecodeHashes) {
standardHashes := standard.BytecodeHashes[standard.NetworkVersions["mainnet"].StandardRelease]
func requireStandardByteCodeHashes(t *testing.T, hashes standard.L1ContractBytecodeHashes, chain *ChainConfig) {
standardHashes := standard.BytecodeHashes[standard.NetworkVersions[chain.Superchain].StandardRelease]
s := reflect.ValueOf(standardHashes)
c := reflect.ValueOf(hashes)
matches := checkMatch(s, c)
Expand All @@ -279,7 +279,7 @@ func requireStandardByteCodeHashes(t *testing.T, hashes standard.L1ContractBytec
diff := cmp.Diff(standardHashes, hashes)
require.Truef(t, matches,
"contract bytecode hashes do not match the standard bytecode hashes for the %s release \n (-removed from standard / +added to actual):\n %s",
standard.NetworkVersions["mainnet"].StandardRelease, diff)
standard.NetworkVersions[chain.Superchain].StandardRelease, diff)
}
}

Expand Down

0 comments on commit be23568

Please sign in to comment.