diff --git a/x/wasm/keeper/handler_plugin_test.go b/x/wasm/keeper/handler_plugin_test.go index c88a46abb..ca0a234e9 100644 --- a/x/wasm/keeper/handler_plugin_test.go +++ b/x/wasm/keeper/handler_plugin_test.go @@ -6,9 +6,11 @@ import ( wasmvm "github.com/CosmWasm/wasmvm/v2" wasmvmtypes "github.com/CosmWasm/wasmvm/v2/types" + "github.com/cosmos/gogoproto/proto" capabilitytypes "github.com/cosmos/ibc-go/modules/capability/types" clienttypes "github.com/cosmos/ibc-go/v8/modules/core/02-client/types" //nolint:staticcheck channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types" + ibcexported "github.com/cosmos/ibc-go/v8/modules/core/exported" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -240,11 +242,13 @@ func TestIBCRawPacketHandler(t *testing.T) { timeoutTimestamp uint64 data []byte } - var capturedPacket *CapturedPacket + var capturedPacketSent *CapturedPacket + var capturedAck []byte + var capturedPacketAck *CapturedPacket - capturePacketsSenderMock := &wasmtesting.MockICS4Wrapper{ + capturingICS4Mock := &wasmtesting.MockICS4Wrapper{ SendPacketFn: func(ctx sdk.Context, channelCap *capabilitytypes.Capability, sourcePort, sourceChannel string, timeoutHeight clienttypes.Height, timeoutTimestamp uint64, data []byte) (uint64, error) { - capturedPacket = &CapturedPacket{ + capturedPacketSent = &CapturedPacket{ sourcePort: sourcePort, sourceChannel: sourceChannel, timeoutHeight: timeoutHeight, @@ -253,6 +257,17 @@ func TestIBCRawPacketHandler(t *testing.T) { } return 1, nil }, + WriteAcknowledgementFn: func(ctx sdk.Context, chanCap *capabilitytypes.Capability, packet ibcexported.PacketI, acknowledgement ibcexported.Acknowledgement) error { + capturedPacketAck = &CapturedPacket{ + sourcePort: packet.GetSourcePort(), + sourceChannel: packet.GetSourceChannel(), + timeoutHeight: packet.GetTimeoutHeight().(clienttypes.Height), + timeoutTimestamp: packet.GetTimeoutTimestamp(), + data: packet.GetData(), + } + capturedAck = acknowledgement.Acknowledgement() + return nil + }, } chanKeeper := &wasmtesting.MockChannelKeeper{ GetChannelFn: func(ctx sdk.Context, srcPort, srcChan string) (channeltypes.Channel, bool) { @@ -270,19 +285,39 @@ func TestIBCRawPacketHandler(t *testing.T) { }, } contractKeeper := wasmtesting.IBCContractKeeperMock{} + // also store a packet to be acked + ackPacket := channeltypes.Packet{ + Sequence: 1, + SourcePort: "src-port", + SourceChannel: "channel-0", + DestinationPort: ibcPort, + DestinationChannel: "channel-1", + Data: []byte{}, + TimeoutHeight: clienttypes.Height{}, + TimeoutTimestamp: 1720000000000000000, + } + contractKeeper.StoreAsyncAckPacket(ctx, ackPacket) + + sendResponse := types.MsgIBCSendResponse{Sequence: 1} + ackResponse := types.MsgIBCWriteAcknowledgementResponse{} specs := map[string]struct { - srcMsg wasmvmtypes.SendPacketMsg + srcMsg wasmvmtypes.IBCMsg chanKeeper types.ChannelKeeper capKeeper types.CapabilityKeeper expPacketSent *CapturedPacket + expPacketAck *CapturedPacket + expAck []byte expErr *errorsmod.Error + expResp proto.Message }{ - "all good": { - srcMsg: wasmvmtypes.SendPacketMsg{ - ChannelID: "channel-1", - Data: []byte("myData"), - Timeout: wasmvmtypes.IBCTimeout{Block: &wasmvmtypes.IBCTimeoutBlock{Revision: 1, Height: 2}}, + "send packet, all good": { + srcMsg: wasmvmtypes.IBCMsg{ + SendPacket: &wasmvmtypes.SendPacketMsg{ + ChannelID: "channel-1", + Data: []byte("myData"), + Timeout: wasmvmtypes.IBCTimeout{Block: &wasmvmtypes.IBCTimeoutBlock{Revision: 1, Height: 2}}, + }, }, chanKeeper: chanKeeper, capKeeper: capKeeper, @@ -292,12 +327,15 @@ func TestIBCRawPacketHandler(t *testing.T) { timeoutHeight: clienttypes.Height{RevisionNumber: 1, RevisionHeight: 2}, data: []byte("myData"), }, + expResp: &sendResponse, }, - "capability not found returns error": { - srcMsg: wasmvmtypes.SendPacketMsg{ - ChannelID: "channel-1", - Data: []byte("myData"), - Timeout: wasmvmtypes.IBCTimeout{Block: &wasmvmtypes.IBCTimeoutBlock{Revision: 1, Height: 2}}, + "send packet, capability not found returns error": { + srcMsg: wasmvmtypes.IBCMsg{ + SendPacket: &wasmvmtypes.SendPacketMsg{ + ChannelID: "channel-1", + Data: []byte("myData"), + Timeout: wasmvmtypes.IBCTimeout{Block: &wasmvmtypes.IBCTimeoutBlock{Revision: 1, Height: 2}}, + }, }, chanKeeper: chanKeeper, capKeeper: wasmtesting.MockCapabilityKeeper{ @@ -307,13 +345,37 @@ func TestIBCRawPacketHandler(t *testing.T) { }, expErr: channeltypes.ErrChannelCapabilityNotFound, }, + "async ack, all good": { + srcMsg: wasmvmtypes.IBCMsg{ + WriteAcknowledgement: &wasmvmtypes.WriteAcknowledgementMsg{ + ChannelID: "channel-1", + PacketSequence: 1, + Ack: wasmvmtypes.IBCAcknowledgement{Data: []byte("myAck")}, + }, + }, + chanKeeper: chanKeeper, + capKeeper: capKeeper, + expPacketAck: &CapturedPacket{ + sourcePort: ackPacket.SourcePort, + sourceChannel: ackPacket.SourceChannel, + timeoutHeight: ackPacket.TimeoutHeight, + timeoutTimestamp: ackPacket.TimeoutTimestamp, + data: ackPacket.Data, + }, + expAck: []byte("myAck"), + expResp: &ackResponse, + }, } for name, spec := range specs { t.Run(name, func(t *testing.T) { - capturedPacket = nil + capturedPacketSent = nil + capturedAck = nil + capturedPacketAck = nil + // when - h := NewIBCRawPacketHandler(capturePacketsSenderMock, &contractKeeper, spec.chanKeeper, spec.capKeeper) - evts, data, msgResponses, gotErr := h.DispatchMsg(ctx, RandomAccountAddress(t), ibcPort, wasmvmtypes.CosmosMsg{IBC: &wasmvmtypes.IBCMsg{SendPacket: &spec.srcMsg}}) //nolint:gosec + h := NewIBCRawPacketHandler(capturingICS4Mock, &contractKeeper, spec.chanKeeper, spec.capKeeper) + evts, data, msgResponses, gotErr := h.DispatchMsg(ctx, RandomAccountAddress(t), ibcPort, wasmvmtypes.CosmosMsg{IBC: &spec.srcMsg}) //nolint:gosec + // then require.True(t, spec.expErr.Is(gotErr), "exp %v but got %#+v", spec.expErr, gotErr) if spec.expErr != nil { @@ -323,16 +385,16 @@ func TestIBCRawPacketHandler(t *testing.T) { assert.Nil(t, evts) require.NotNil(t, data) assert.Len(t, msgResponses, 1) - assert.Equal(t, "/cosmwasm.wasm.v1.MsgIBCSendResponse", msgResponses[0][0].TypeUrl) - - expMsg := types.MsgIBCSendResponse{Sequence: 1} + assert.Equal(t, "/"+proto.MessageName(spec.expResp), msgResponses[0][0].TypeUrl) - actualMsg := types.MsgIBCSendResponse{} - err := actualMsg.Unmarshal(data[0]) + // compare expected responses + expResp, err := proto.Marshal(spec.expResp) require.NoError(t, err) + assert.Equal(t, expResp, msgResponses[0][0].Value) - assert.Equal(t, expMsg, actualMsg) - assert.Equal(t, spec.expPacketSent, capturedPacket) + assert.Equal(t, spec.expPacketSent, capturedPacketSent) + assert.Equal(t, spec.expAck, capturedAck) + assert.Equal(t, spec.expPacketAck, capturedPacketAck) }) } } diff --git a/x/wasm/keeper/wasmtesting/mock_keepers.go b/x/wasm/keeper/wasmtesting/mock_keepers.go index 4e30964cb..5c5f2a7f6 100644 --- a/x/wasm/keeper/wasmtesting/mock_keepers.go +++ b/x/wasm/keeper/wasmtesting/mock_keepers.go @@ -69,7 +69,8 @@ func (m *MockChannelKeeper) SetChannel(ctx sdk.Context, portID, channelID string var _ types.ICS4Wrapper = &MockICS4Wrapper{} type MockICS4Wrapper struct { - SendPacketFn func(ctx sdk.Context, channelCap *capabilitytypes.Capability, sourcePort, sourceChannel string, timeoutHeight clienttypes.Height, timeoutTimestamp uint64, data []byte) (uint64, error) + SendPacketFn func(ctx sdk.Context, channelCap *capabilitytypes.Capability, sourcePort, sourceChannel string, timeoutHeight clienttypes.Height, timeoutTimestamp uint64, data []byte) (uint64, error) + WriteAcknowledgementFn func(ctx sdk.Context, chanCap *capabilitytypes.Capability, packet ibcexported.PacketI, acknowledgement ibcexported.Acknowledgement) error } func (m *MockICS4Wrapper) SendPacket(ctx sdk.Context, channelCap *capabilitytypes.Capability, sourcePort, sourceChannel string, timeoutHeight clienttypes.Height, timeoutTimestamp uint64, data []byte) (uint64, error) { @@ -85,8 +86,10 @@ func (m *MockICS4Wrapper) WriteAcknowledgement( packet ibcexported.PacketI, acknowledgement ibcexported.Acknowledgement, ) error { - // TODO: implement mocking - panic("not supposed to be called!") + if m.WriteAcknowledgementFn == nil { + panic("not supposed to be called!") + } + return m.WriteAcknowledgementFn(ctx, chanCap, packet, acknowledgement) } func MockChannelKeeperIterator(s []channeltypes.IdentifiedChannel) func(ctx sdk.Context, cb func(channeltypes.IdentifiedChannel) bool) {