Skip to content

Commit

Permalink
move crosschain keeper tests from types to keeper
Browse files Browse the repository at this point in the history
  • Loading branch information
gartnera committed Sep 24, 2024
1 parent f95066d commit 87f6474
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 162 deletions.
160 changes: 160 additions & 0 deletions x/crosschain/keeper/cctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,21 @@ import (
"testing"

"cosmossdk.io/math"
sdkmath "cosmossdk.io/math"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/types/query"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/zeta-chain/node/pkg/chains"
"github.com/zeta-chain/node/pkg/coin"
keepertest "github.com/zeta-chain/node/testutil/keeper"
"github.com/zeta-chain/node/testutil/sample"
"github.com/zeta-chain/node/x/crosschain/keeper"
"github.com/zeta-chain/node/x/crosschain/types"
observertypes "github.com/zeta-chain/node/x/observer/types"
)

func createNCctxWithStatus(
Expand Down Expand Up @@ -288,3 +292,159 @@ func TestKeeper_RemoveCrossChainTx(t *testing.T) {
txs = keeper.GetAllCrossChainTx(ctx)
require.Equal(t, 4, len(txs))
}

func TestCrossChainTx_AddOutbound(t *testing.T) {
t.Run("successfully get outbound tx", func(t *testing.T) {
_, ctx, _, _ := keepertest.CrosschainKeeper(t)
cctx := sample.CrossChainTx(t, "test")
hash := sample.Hash().String()

err := cctx.AddOutbound(ctx, types.MsgVoteOutbound{
ValueReceived: cctx.GetCurrentOutboundParam().Amount,
ObservedOutboundHash: hash,
ObservedOutboundBlockHeight: 10,
ObservedOutboundGasUsed: 100,
ObservedOutboundEffectiveGasPrice: sdkmath.NewInt(100),
ObservedOutboundEffectiveGasLimit: 20,
}, observertypes.BallotStatus_BallotFinalized_SuccessObservation)
require.NoError(t, err)
require.Equal(t, cctx.GetCurrentOutboundParam().Hash, hash)
require.Equal(t, cctx.GetCurrentOutboundParam().GasUsed, uint64(100))
require.Equal(t, cctx.GetCurrentOutboundParam().EffectiveGasPrice, sdkmath.NewInt(100))
require.Equal(t, cctx.GetCurrentOutboundParam().EffectiveGasLimit, uint64(20))
require.Equal(t, cctx.GetCurrentOutboundParam().ObservedExternalHeight, uint64(10))
})

t.Run("successfully get outbound tx for failed ballot without amount check", func(t *testing.T) {
_, ctx, _, _ := keepertest.CrosschainKeeper(t)
cctx := sample.CrossChainTx(t, "test")
hash := sample.Hash().String()

err := cctx.AddOutbound(ctx, types.MsgVoteOutbound{
ObservedOutboundHash: hash,
ObservedOutboundBlockHeight: 10,
ObservedOutboundGasUsed: 100,
ObservedOutboundEffectiveGasPrice: sdkmath.NewInt(100),
ObservedOutboundEffectiveGasLimit: 20,
}, observertypes.BallotStatus_BallotFinalized_FailureObservation)
require.NoError(t, err)
require.Equal(t, cctx.GetCurrentOutboundParam().Hash, hash)
require.Equal(t, cctx.GetCurrentOutboundParam().GasUsed, uint64(100))
require.Equal(t, cctx.GetCurrentOutboundParam().EffectiveGasPrice, sdkmath.NewInt(100))
require.Equal(t, cctx.GetCurrentOutboundParam().EffectiveGasLimit, uint64(20))
require.Equal(t, cctx.GetCurrentOutboundParam().ObservedExternalHeight, uint64(10))
})

t.Run("failed to get outbound tx if amount does not match value received", func(t *testing.T) {
_, ctx, _, _ := keepertest.CrosschainKeeper(t)

cctx := sample.CrossChainTx(t, "test")
hash := sample.Hash().String()

err := cctx.AddOutbound(ctx, types.MsgVoteOutbound{
ValueReceived: sdkmath.NewUint(100),
ObservedOutboundHash: hash,
ObservedOutboundBlockHeight: 10,
ObservedOutboundGasUsed: 100,
ObservedOutboundEffectiveGasPrice: sdkmath.NewInt(100),
ObservedOutboundEffectiveGasLimit: 20,
}, observertypes.BallotStatus_BallotFinalized_SuccessObservation)
require.ErrorIs(t, err, sdkerrors.ErrInvalidRequest)
})
}

func Test_NewCCTX(t *testing.T) {
t.Run("should return a cctx with correct values", func(t *testing.T) {
_, ctx, _, _ := keepertest.CrosschainKeeper(t)
senderChain := chains.Goerli
sender := sample.EthAddress()
receiverChain := chains.Goerli
receiver := sample.EthAddress()
creator := sample.AccAddress()
amount := sdkmath.NewUint(42)
message := "test"
inboundBlockHeight := uint64(420)
inboundHash := sample.Hash()
gasLimit := uint64(100)
asset := "test-asset"
eventIndex := uint64(1)
cointType := coin.CoinType_ERC20
tss := sample.Tss()
msg := types.MsgVoteInbound{
Creator: creator,
Sender: sender.String(),
SenderChainId: senderChain.ChainId,
Receiver: receiver.String(),
ReceiverChain: receiverChain.ChainId,
Amount: amount,
Message: message,
InboundHash: inboundHash.String(),
InboundBlockHeight: inboundBlockHeight,
GasLimit: gasLimit,
CoinType: cointType,
TxOrigin: sender.String(),
Asset: asset,
EventIndex: eventIndex,
ProtocolContractVersion: types.ProtocolContractVersion_V2,
}
cctx, err := types.NewCCTX(ctx, msg, tss.TssPubkey)
require.NoError(t, err)
require.Equal(t, receiver.String(), cctx.GetCurrentOutboundParam().Receiver)
require.Equal(t, receiverChain.ChainId, cctx.GetCurrentOutboundParam().ReceiverChainId)
require.Equal(t, sender.String(), cctx.GetInboundParams().Sender)
require.Equal(t, senderChain.ChainId, cctx.GetInboundParams().SenderChainId)
require.Equal(t, amount, cctx.GetInboundParams().Amount)
require.Equal(t, message, cctx.RelayedMessage)
require.Equal(t, inboundHash.String(), cctx.GetInboundParams().ObservedHash)
require.Equal(t, inboundBlockHeight, cctx.GetInboundParams().ObservedExternalHeight)
require.Equal(t, gasLimit, cctx.GetCurrentOutboundParam().GasLimit)
require.Equal(t, asset, cctx.GetInboundParams().Asset)
require.Equal(t, cointType, cctx.InboundParams.CoinType)
require.Equal(t, uint64(0), cctx.GetCurrentOutboundParam().TssNonce)
require.Equal(t, sdkmath.ZeroUint(), cctx.GetCurrentOutboundParam().Amount)
require.Equal(t, types.CctxStatus_PendingInbound, cctx.CctxStatus.Status)
require.Equal(t, false, cctx.CctxStatus.IsAbortRefunded)
require.Equal(t, types.ProtocolContractVersion_V2, cctx.ProtocolContractVersion)
})

t.Run("should return an error if the cctx is invalid", func(t *testing.T) {
_, ctx, _, _ := keepertest.CrosschainKeeper(t)
senderChain := chains.Goerli
sender := sample.EthAddress()
receiverChain := chains.Goerli
receiver := sample.EthAddress()
creator := sample.AccAddress()
amount := sdkmath.NewUint(42)
message := "test"
inboundBlockHeight := uint64(420)
inboundHash := sample.Hash()
gasLimit := uint64(100)
asset := "test-asset"
eventIndex := uint64(1)
cointType := coin.CoinType_ERC20
tss := sample.Tss()
msg := types.MsgVoteInbound{
Creator: creator,
Sender: "",
SenderChainId: senderChain.ChainId,
Receiver: receiver.String(),
ReceiverChain: receiverChain.ChainId,
Amount: amount,
Message: message,
InboundHash: inboundHash.String(),
InboundBlockHeight: inboundBlockHeight,
GasLimit: gasLimit,
CoinType: cointType,
TxOrigin: sender.String(),
Asset: asset,
EventIndex: eventIndex,
}
_, err := types.NewCCTX(ctx, msg, tss.TssPubkey)
require.ErrorContains(t, err, "sender cannot be empty")
})

t.Run("zero value for protocol contract version gives V1", func(t *testing.T) {
cctx := types.CrossChainTx{}
require.Equal(t, types.ProtocolContractVersion_V1, cctx.ProtocolContractVersion)
})
}
162 changes: 0 additions & 162 deletions x/crosschain/types/cctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,10 @@ import (
"math/rand"
"testing"

sdkmath "cosmossdk.io/math"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/stretchr/testify/require"

"github.com/zeta-chain/node/pkg/chains"
"github.com/zeta-chain/node/pkg/coin"
keepertest "github.com/zeta-chain/node/testutil/keeper"
"github.com/zeta-chain/node/testutil/sample"
"github.com/zeta-chain/node/x/crosschain/types"
observertypes "github.com/zeta-chain/node/x/observer/types"
)

func TestCrossChainTx_GetEVMRevertAddress(t *testing.T) {
Expand Down Expand Up @@ -63,102 +57,6 @@ func TestCrossChainTx_GetCCTXIndexBytes(t *testing.T) {
require.Equal(t, cctx.Index, types.GetCctxIndexFromBytes(indexBytes))
}

func Test_NewCCTX(t *testing.T) {
t.Run("should return a cctx with correct values", func(t *testing.T) {
_, ctx, _, _ := keepertest.CrosschainKeeper(t)
senderChain := chains.Goerli
sender := sample.EthAddress()
receiverChain := chains.Goerli
receiver := sample.EthAddress()
creator := sample.AccAddress()
amount := sdkmath.NewUint(42)
message := "test"
inboundBlockHeight := uint64(420)
inboundHash := sample.Hash()
gasLimit := uint64(100)
asset := "test-asset"
eventIndex := uint64(1)
cointType := coin.CoinType_ERC20
tss := sample.Tss()
msg := types.MsgVoteInbound{
Creator: creator,
Sender: sender.String(),
SenderChainId: senderChain.ChainId,
Receiver: receiver.String(),
ReceiverChain: receiverChain.ChainId,
Amount: amount,
Message: message,
InboundHash: inboundHash.String(),
InboundBlockHeight: inboundBlockHeight,
GasLimit: gasLimit,
CoinType: cointType,
TxOrigin: sender.String(),
Asset: asset,
EventIndex: eventIndex,
ProtocolContractVersion: types.ProtocolContractVersion_V2,
}
cctx, err := types.NewCCTX(ctx, msg, tss.TssPubkey)
require.NoError(t, err)
require.Equal(t, receiver.String(), cctx.GetCurrentOutboundParam().Receiver)
require.Equal(t, receiverChain.ChainId, cctx.GetCurrentOutboundParam().ReceiverChainId)
require.Equal(t, sender.String(), cctx.GetInboundParams().Sender)
require.Equal(t, senderChain.ChainId, cctx.GetInboundParams().SenderChainId)
require.Equal(t, amount, cctx.GetInboundParams().Amount)
require.Equal(t, message, cctx.RelayedMessage)
require.Equal(t, inboundHash.String(), cctx.GetInboundParams().ObservedHash)
require.Equal(t, inboundBlockHeight, cctx.GetInboundParams().ObservedExternalHeight)
require.Equal(t, gasLimit, cctx.GetCurrentOutboundParam().GasLimit)
require.Equal(t, asset, cctx.GetInboundParams().Asset)
require.Equal(t, cointType, cctx.InboundParams.CoinType)
require.Equal(t, uint64(0), cctx.GetCurrentOutboundParam().TssNonce)
require.Equal(t, sdkmath.ZeroUint(), cctx.GetCurrentOutboundParam().Amount)
require.Equal(t, types.CctxStatus_PendingInbound, cctx.CctxStatus.Status)
require.Equal(t, false, cctx.CctxStatus.IsAbortRefunded)
require.Equal(t, types.ProtocolContractVersion_V2, cctx.ProtocolContractVersion)
})

t.Run("should return an error if the cctx is invalid", func(t *testing.T) {
_, ctx, _, _ := keepertest.CrosschainKeeper(t)
senderChain := chains.Goerli
sender := sample.EthAddress()
receiverChain := chains.Goerli
receiver := sample.EthAddress()
creator := sample.AccAddress()
amount := sdkmath.NewUint(42)
message := "test"
inboundBlockHeight := uint64(420)
inboundHash := sample.Hash()
gasLimit := uint64(100)
asset := "test-asset"
eventIndex := uint64(1)
cointType := coin.CoinType_ERC20
tss := sample.Tss()
msg := types.MsgVoteInbound{
Creator: creator,
Sender: "",
SenderChainId: senderChain.ChainId,
Receiver: receiver.String(),
ReceiverChain: receiverChain.ChainId,
Amount: amount,
Message: message,
InboundHash: inboundHash.String(),
InboundBlockHeight: inboundBlockHeight,
GasLimit: gasLimit,
CoinType: cointType,
TxOrigin: sender.String(),
Asset: asset,
EventIndex: eventIndex,
}
_, err := types.NewCCTX(ctx, msg, tss.TssPubkey)
require.ErrorContains(t, err, "sender cannot be empty")
})

t.Run("zero value for protocol contract version gives V1", func(t *testing.T) {
cctx := types.CrossChainTx{}
require.Equal(t, types.ProtocolContractVersion_V1, cctx.ProtocolContractVersion)
})
}

func TestCrossChainTx_Validate(t *testing.T) {
cctx := sample.CrossChainTx(t, "foo")
cctx.InboundParams = nil
Expand Down Expand Up @@ -221,66 +119,6 @@ func TestCrossChainTx_OriginalDestinationChainID(t *testing.T) {
require.Equal(t, cctx.OutboundParams[0].ReceiverChainId, cctx.OriginalDestinationChainID())
}

func TestCrossChainTx_AddOutbound(t *testing.T) {
t.Run("successfully get outbound tx", func(t *testing.T) {
_, ctx, _, _ := keepertest.CrosschainKeeper(t)
cctx := sample.CrossChainTx(t, "test")
hash := sample.Hash().String()

err := cctx.AddOutbound(ctx, types.MsgVoteOutbound{
ValueReceived: cctx.GetCurrentOutboundParam().Amount,
ObservedOutboundHash: hash,
ObservedOutboundBlockHeight: 10,
ObservedOutboundGasUsed: 100,
ObservedOutboundEffectiveGasPrice: sdkmath.NewInt(100),
ObservedOutboundEffectiveGasLimit: 20,
}, observertypes.BallotStatus_BallotFinalized_SuccessObservation)
require.NoError(t, err)
require.Equal(t, cctx.GetCurrentOutboundParam().Hash, hash)
require.Equal(t, cctx.GetCurrentOutboundParam().GasUsed, uint64(100))
require.Equal(t, cctx.GetCurrentOutboundParam().EffectiveGasPrice, sdkmath.NewInt(100))
require.Equal(t, cctx.GetCurrentOutboundParam().EffectiveGasLimit, uint64(20))
require.Equal(t, cctx.GetCurrentOutboundParam().ObservedExternalHeight, uint64(10))
})

t.Run("successfully get outbound tx for failed ballot without amount check", func(t *testing.T) {
_, ctx, _, _ := keepertest.CrosschainKeeper(t)
cctx := sample.CrossChainTx(t, "test")
hash := sample.Hash().String()

err := cctx.AddOutbound(ctx, types.MsgVoteOutbound{
ObservedOutboundHash: hash,
ObservedOutboundBlockHeight: 10,
ObservedOutboundGasUsed: 100,
ObservedOutboundEffectiveGasPrice: sdkmath.NewInt(100),
ObservedOutboundEffectiveGasLimit: 20,
}, observertypes.BallotStatus_BallotFinalized_FailureObservation)
require.NoError(t, err)
require.Equal(t, cctx.GetCurrentOutboundParam().Hash, hash)
require.Equal(t, cctx.GetCurrentOutboundParam().GasUsed, uint64(100))
require.Equal(t, cctx.GetCurrentOutboundParam().EffectiveGasPrice, sdkmath.NewInt(100))
require.Equal(t, cctx.GetCurrentOutboundParam().EffectiveGasLimit, uint64(20))
require.Equal(t, cctx.GetCurrentOutboundParam().ObservedExternalHeight, uint64(10))
})

t.Run("failed to get outbound tx if amount does not match value received", func(t *testing.T) {
_, ctx, _, _ := keepertest.CrosschainKeeper(t)

cctx := sample.CrossChainTx(t, "test")
hash := sample.Hash().String()

err := cctx.AddOutbound(ctx, types.MsgVoteOutbound{
ValueReceived: sdkmath.NewUint(100),
ObservedOutboundHash: hash,
ObservedOutboundBlockHeight: 10,
ObservedOutboundGasUsed: 100,
ObservedOutboundEffectiveGasPrice: sdkmath.NewInt(100),
ObservedOutboundEffectiveGasLimit: 20,
}, observertypes.BallotStatus_BallotFinalized_SuccessObservation)
require.ErrorIs(t, err, sdkerrors.ErrInvalidRequest)
})
}

func Test_SetRevertOutboundValues(t *testing.T) {
t.Run("successfully set revert outbound values", func(t *testing.T) {
cctx := sample.CrossChainTx(t, "test")
Expand Down

0 comments on commit 87f6474

Please sign in to comment.