Skip to content

Commit

Permalink
Introduce custom logic into OnRecvPacket in PFM module to charge fee …
Browse files Browse the repository at this point in the history
…to prevent a spams (#519)

All test passed.
  • Loading branch information
RustNinja authored May 22, 2024
1 parent 154f046 commit 2b9a785
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 11 deletions.
3 changes: 2 additions & 1 deletion app/keepers/keepers.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,8 @@ func (appKeepers *AppKeepers) InitNormalKeepers(
0,
routerkeeper.DefaultForwardTransferPacketTimeoutTimestamp,
routerkeeper.DefaultRefundTransferPacketTimeoutTimestamp,
appKeepers.TransferMiddlewareKeeper,
&appKeepers.IbcTransferMiddlewareKeeper,
&appKeepers.BankKeeper,
)
ratelimitMiddlewareStack := ratelimitmodule.NewIBCMiddleware(appKeepers.RatelimitKeeper, ibcMiddlewareStack)
hooksTransferMiddleware := ibc_hooks.NewIBCMiddleware(ratelimitMiddlewareStack, &appKeepers.HooksICS4Wrapper)
Expand Down
71 changes: 61 additions & 10 deletions custom/custompfm/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ import (
"github.com/cosmos/ibc-apps/middleware/packet-forward-middleware/v7/packetforward/keeper"
"github.com/cosmos/ibc-apps/middleware/packet-forward-middleware/v7/packetforward/types"
transfertypes "github.com/cosmos/ibc-go/v7/modules/apps/transfer/types"
clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types"
channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types"
porttypes "github.com/cosmos/ibc-go/v7/modules/core/05-port/types"
ibcexported "github.com/cosmos/ibc-go/v7/modules/core/exported"
ibctransfermiddlewarekeeper "github.com/notional-labs/composable/v6/x/transfermiddleware/keeper"
custombankkeeper "github.com/notional-labs/composable/v6/custom/bank/keeper"
ibctransfermiddlewarekeeper "github.com/notional-labs/composable/v6/x/ibctransfermiddleware/keeper"
)

var _ porttypes.Middleware = &IBCMiddleware{}
Expand All @@ -25,14 +27,15 @@ var _ porttypes.Middleware = &IBCMiddleware{}
// forward keeper and the underlying application.
type IBCMiddleware struct {
router.IBCMiddleware
ibcfeekeeper ibctransfermiddlewarekeeper.Keeper

app1 porttypes.IBCModule
keeper1 *keeper.Keeper

retriesOnTimeout1 uint8
forwardTimeout1 time.Duration
refundTimeout1 time.Duration
ibcfeekeeper *ibctransfermiddlewarekeeper.Keeper
bank *custombankkeeper.Keeper
}

func NewIBCMiddleware(
Expand All @@ -41,18 +44,19 @@ func NewIBCMiddleware(
retriesOnTimeout uint8,
forwardTimeout time.Duration,
refundTimeout time.Duration,
ibcfeekeeper ibctransfermiddlewarekeeper.Keeper,
ibcfeekeeper *ibctransfermiddlewarekeeper.Keeper,
bankkeeper *custombankkeeper.Keeper,
) IBCMiddleware {
return IBCMiddleware{
IBCMiddleware: router.NewIBCMiddleware(app, k, retriesOnTimeout, forwardTimeout, refundTimeout),
ibcfeekeeper: ibcfeekeeper,

//we need this bz this field is not exported in the parent struct
app1: app,
keeper1: k,
retriesOnTimeout1: retriesOnTimeout,
forwardTimeout1: forwardTimeout,
refundTimeout1: refundTimeout,
bank: bankkeeper,
}
}

Expand All @@ -63,8 +67,6 @@ func (im IBCMiddleware) OnRecvPacket(
) ibcexported.Acknowledgement {
logger := im.keeper1.Logger(ctx)

im.ibcfeekeeper.GetParams(ctx)

var data transfertypes.FungibleTokenPacketData
if err := transfertypes.ModuleCdc.UnmarshalJSON(packet.GetData(), &data); err != nil {
logger.Debug(fmt.Sprintf("packetForwardMiddleware OnRecvPacket payload is not a FungibleTokenPacketData: %s", err.Error()))
Expand Down Expand Up @@ -160,6 +162,59 @@ func (im IBCMiddleware) OnRecvPacket(
retries = im.retriesOnTimeout1
}

// im.ibcfeekeeper.Transfer()

feeAmount := sdk.NewDecFromInt(token.Amount).Mul(im.keeper1.GetFeePercentage(ctx)).RoundInt()
packetAmount := token.Amount.Sub(feeAmount)
packetCoin := sdk.NewCoin(token.Denom, packetAmount)

memo := ""

// set memo for next transfer with next from this transfer.
if metadata.Next != nil {
memoBz, err := json.Marshal(metadata.Next)
if err != nil {
im.keeper1.Logger(ctx).Error("packetForwardMiddleware error marshaling next as JSON",
"error", err,
)
// return errorsmod.Wrapf(sdkerrors.ErrJSONMarshal, err.Error())
}
memo = string(memoBz)
}

tr := transfertypes.NewMsgTransfer(
metadata.Port,
metadata.Channel,
packetCoin,
overrideReceiver,
metadata.Receiver,
clienttypes.Height{
RevisionNumber: 0,
RevisionHeight: 0,
},
uint64(ctx.BlockTime().UnixNano())+uint64(timeout.Nanoseconds()),
memo,
)

result, err := im.ibcfeekeeper.ChargeFee(ctx, tr)
if err != nil {
logger.Error("packetForwardMiddleware OnRecvPacket error charging fee", "error", err)
return newErrorAcknowledgement(fmt.Errorf("error charging fee: %w", err))
}
if result != nil {
if result.Fee.Amount.LT(token.Amount) {
token = token.SubAmount(result.Fee.Amount)
} else {
send_err := im.bank.SendCoins(ctx, result.Sender, result.Receiver, sdk.NewCoins(result.Fee))
if send_err != nil {
logger.Error("packetForwardMiddleware OnRecvPacket error sending fee", "error", send_err)
return newErrorAcknowledgement(fmt.Errorf("error charging fee: %w", send_err))
}
ack := channeltypes.NewResultAcknowledgement([]byte{byte(1)})
return ack
}
}

err = im.keeper1.ForwardTransferPacket(ctx, nil, packet, data.Sender, overrideReceiver, metadata, token, retries, timeout, []metrics.Label{}, nonrefundable)
if err != nil {
logger.Error("packetForwardMiddleware OnRecvPacket error forwarding packet", "error", err)
Expand All @@ -169,10 +224,6 @@ func (im IBCMiddleware) OnRecvPacket(
// returning nil ack will prevent WriteAcknowledgement from occurring for forwarded packet.
// This is intentional so that the acknowledgement will be written later based on the ack/timeout of the forwarded packet.
return nil

// charge_coin := sdk.NewCoin(packet.Token.Denom, sdk.ZeroInt())
// return channeltypes.NewErrorAcknowledgement(fmt.Errorf("error parsing forward metadata"))
// return im.IBCMiddleware.OnRecvPacket(ctx, packet, relayer)
}

func newErrorAcknowledgement(err error) channeltypes.Acknowledgement {
Expand Down
113 changes: 113 additions & 0 deletions x/ibctransfermiddleware/keeper/keeper.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
package keeper

import (
"encoding/json"
"fmt"
"time"

"github.com/cometbft/cometbft/libs/log"
"github.com/notional-labs/composable/v6/x/ibctransfermiddleware/types"

"github.com/cosmos/cosmos-sdk/codec"
storetypes "github.com/cosmos/cosmos-sdk/store/types"
sdk "github.com/cosmos/cosmos-sdk/types"
ibctypes "github.com/cosmos/ibc-go/v7/modules/apps/transfer/types"
)

// Keeper of the staking middleware store
Expand Down Expand Up @@ -104,3 +109,111 @@ func (k Keeper) GetChannelFeeAddress(ctx sdk.Context, targetChannelID string) st
}
return channelFee.FeeAddress
}

type BridgeFee struct {
Fee sdk.Coin
Sender sdk.AccAddress
Receiver sdk.AccAddress
}

func (k Keeper) ChargeFee(ctx sdk.Context, msg *ibctypes.MsgTransfer) (*BridgeFee, error) {
params := k.GetParams(ctx)
// charge_coin := sdk.NewCoin(msg.Token.Denom, sdk.ZeroInt())
if params.ChannelFees != nil && len(params.ChannelFees) > 0 {
channelFee := findChannelParams(params.ChannelFees, msg.SourceChannel)
if channelFee != nil {
if channelFee.MinTimeoutTimestamp > 0 {

blockTime := ctx.BlockTime()

timeoutTimeInFuture := time.Unix(0, int64(msg.TimeoutTimestamp))
if timeoutTimeInFuture.Before(blockTime) {
return nil, fmt.Errorf("incorrect timeout timestamp found during ibc transfer. timeout timestamp is in the past")
}

difference := timeoutTimeInFuture.Sub(blockTime).Nanoseconds()
if difference < channelFee.MinTimeoutTimestamp {
return nil, fmt.Errorf("incorrect timeout timestamp found during ibc transfer. too soon")
}
}
coin := findCoinByDenom(channelFee.AllowedTokens, msg.Token.Denom)
if coin == nil {
return nil, fmt.Errorf("token not allowed to be transferred in this channel")
}

minFee := coin.MinFee.Amount
priority := GetPriority(msg.Memo)
if priority != nil {
p := findPriority(coin.TxPriorityFee, *priority)
if p != nil && coin.MinFee.Denom == p.PriorityFee.Denom {
minFee = minFee.Add(p.PriorityFee.Amount)
}
}

charge := minFee
if charge.GT(msg.Token.Amount) {
charge = msg.Token.Amount
}

newAmount := msg.Token.Amount.Sub(charge)

if newAmount.IsPositive() {
percentageCharge := newAmount.QuoRaw(coin.Percentage)
newAmount = newAmount.Sub(percentageCharge)
charge = charge.Add(percentageCharge)
}

msgSender, err := sdk.AccAddressFromBech32(msg.Sender)
if err != nil {
return nil, err
}

feeAddress, err := sdk.AccAddressFromBech32(channelFee.FeeAddress)
if err != nil {
return nil, err
}

charge_coin := sdk.NewCoin(msg.Token.Denom, charge)
// send_err := k.bank.SendCoins(ctx, msgSender, feeAddress, sdk.NewCoins(charge_coin))
// if send_err != nil {
// return nil, send_err
// }
msg.Token.Amount = newAmount
return &BridgeFee{Fee: charge_coin, Sender: msgSender, Receiver: feeAddress}, nil

// if newAmount.LTE(sdk.ZeroInt()) {
// zeroTransfer := sdk.NewCoin(msg.Token.Denom, sdk.ZeroInt())
// return &zeroTransfer, nil
// }
}
}
// ret, err := k.Keeper.Transfer(goCtx, msg)
// if err == nil && ret != nil && !charge_coin.IsZero() {
// if !charge_coin.IsZero() {
// k.SetSequenceFee(ctx, ret.Sequence, charge_coin)
// }
return nil, nil
}

func GetPriority(jsonString string) *string {
var data map[string]interface{}
if err := json.Unmarshal([]byte(jsonString), &data); err != nil {
return nil
}

priority, ok := data["priority"].(string)
if !ok {
return nil
}

return &priority
}

func findPriority(priorities []*types.TxPriorityFee, priority string) *types.TxPriorityFee {
for _, p := range priorities {
if p.Priority == priority {
return p
}
}
return nil
}

0 comments on commit 2b9a785

Please sign in to comment.