Skip to content

Commit

Permalink
refactor: simplify LaunchConsumer logic (#2260)
Browse files Browse the repository at this point in the history
* simplify LaunchConsumer

* fix tests
  • Loading branch information
mpoke authored Sep 12, 2024
1 parent 74257ca commit ce6c271
Show file tree
Hide file tree
Showing 8 changed files with 257 additions and 282 deletions.
27 changes: 12 additions & 15 deletions testutil/keeper/expectations.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,18 @@ import (
func GetMocksForCreateConsumerClient(ctx sdk.Context, mocks *MockedKeepers,
expectedChainID string, expectedLatestHeight clienttypes.Height,
) []*gomock.Call {
// append MakeConsumerGenesis and CreateClient expectations
expectations := GetMocksForMakeConsumerGenesis(ctx, mocks, time.Hour)
createClientExp := mocks.MockClientKeeper.EXPECT().CreateClient(
gomock.Any(),
// Allows us to expect a match by field. These are the only two client state values
// that are dependent on parameters passed to CreateConsumerClient.
extra.StructMatcher().Field(
"ChainId", expectedChainID).Field(
"LatestHeight", expectedLatestHeight,
),
gomock.Any(),
).Return("clientID", nil).Times(1)
expectations = append(expectations, createClientExp)

return expectations
return []*gomock.Call{
mocks.MockClientKeeper.EXPECT().CreateClient(
gomock.Any(),
// Allows us to expect a match by field. These are the only two client state values
// that are dependent on parameters passed to CreateConsumerClient.
extra.StructMatcher().Field(
"ChainId", expectedChainID).Field(
"LatestHeight", expectedLatestHeight,
),
gomock.Any(),
).Return("clientID", nil).Times(1),
}
}

// GetMocksForMakeConsumerGenesis returns mock expectations needed to call MakeConsumerGenesis().
Expand Down
4 changes: 1 addition & 3 deletions testutil/keeper/unit_test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,6 @@ func SetupForDeleteConsumerChain(t *testing.T, ctx sdk.Context,
) {
t.Helper()

SetupMocksForLastBondedValidatorsExpectation(mocks.MockStakingKeeper, 1, []stakingtypes.Validator{}, 1)

expectations := GetMocksForCreateConsumerClient(ctx, &mocks,
"chainID", clienttypes.NewHeight(4, 5))
expectations = append(expectations, GetMocksForSetConsumerChain(ctx, &mocks, "chainID")...)
Expand All @@ -241,7 +239,7 @@ func SetupForDeleteConsumerChain(t *testing.T, ctx sdk.Context,
// set the chain to initialized so that we can create a consumer client
providerKeeper.SetConsumerPhase(ctx, consumerId, providertypes.CONSUMER_PHASE_INITIALIZED)

err = providerKeeper.CreateConsumerClient(ctx, consumerId)
err = providerKeeper.CreateConsumerClient(ctx, consumerId, []byte{})
require.NoError(t, err)
// set the mapping consumer ID <> client ID for the consumer chain
providerKeeper.SetConsumerClientId(ctx, consumerId, "clientID")
Expand Down
201 changes: 93 additions & 108 deletions x/ccv/provider/keeper/consumer_lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"

abci "github.com/cometbft/cometbft/abci/types"
tmtypes "github.com/cometbft/cometbft/types"

"github.com/cosmos/interchain-security/v6/x/ccv/provider/types"
Expand Down Expand Up @@ -62,6 +63,9 @@ func (k Keeper) InitializeConsumer(ctx sdk.Context, consumerId string) (time.Tim

// BeginBlockLaunchConsumers launches initialized consumers chains for which the spawn time has passed
func (k Keeper) BeginBlockLaunchConsumers(ctx sdk.Context) error {
bondedValidators := []stakingtypes.Validator{}
activeValidators := []stakingtypes.Validator{}

consumerIds, err := k.ConsumeIdsFromTimeQueue(
ctx,
types.SpawnTimeToConsumerIdsKeyPrefix(),
Expand All @@ -73,9 +77,22 @@ func (k Keeper) BeginBlockLaunchConsumers(ctx sdk.Context) error {
if err != nil {
return errorsmod.Wrapf(ccv.ErrInvalidConsumerState, "getting consumers ready to laumch: %s", err.Error())
}
if len(consumerIds) > 0 {
// get the bonded validators from the staking module
bondedValidators, err = k.GetLastBondedValidators(ctx)
if err != nil {
return fmt.Errorf("getting last bonded validators: %w", err)
}
// get the provider active validators
activeValidators, err = k.GetLastProviderConsensusActiveValidators(ctx)
if err != nil {
return fmt.Errorf("getting last provider active validators: %w", err)
}
}

for _, consumerId := range consumerIds {
cachedCtx, writeFn := ctx.CacheContext()
err = k.LaunchConsumer(cachedCtx, consumerId)
err = k.LaunchConsumer(cachedCtx, bondedValidators, activeValidators, consumerId)
if err != nil {
ctx.Logger().Error("could not launch chain",
"consumerId", consumerId,
Expand Down Expand Up @@ -162,29 +179,64 @@ func (k Keeper) ConsumeIdsFromTimeQueue(

// LaunchConsumer launches the chain with the provided consumer id by creating the consumer client and the respective
// consumer genesis file
func (k Keeper) LaunchConsumer(ctx sdk.Context, consumerId string) error {
err := k.CreateConsumerClient(ctx, consumerId)
//
// TODO add unit test for LaunchConsumer
func (k Keeper) LaunchConsumer(
ctx sdk.Context,
bondedValidators []stakingtypes.Validator,
activeValidators []stakingtypes.Validator,
consumerId string,
) error {
// compute consumer initial validator set
initialValUpdates, err := k.ComputeConsumerNextValSet(ctx, bondedValidators, activeValidators, consumerId, []types.ConsensusValidator{})
if err != nil {
return err
return fmt.Errorf("computing consumer next validator set, consumerId(%s): %w", consumerId, err)
}
if len(initialValUpdates) == 0 {
return fmt.Errorf("cannot launch consumer with no validator opted in, consumerId(%s)", consumerId)
}

consumerGenesis, found := k.GetConsumerGenesis(ctx, consumerId)
if !found {
return errorsmod.Wrapf(types.ErrNoConsumerGenesis, "consumer genesis could not be found for consumer id: %s", consumerId)
// create consumer genesis
genesisState, err := k.MakeConsumerGenesis(ctx, consumerId, initialValUpdates)
if err != nil {
return fmt.Errorf("creating consumer genesis state, consumerId(%s): %w", consumerId, err)
}
err = k.SetConsumerGenesis(ctx, consumerId, genesisState)
if err != nil {
return fmt.Errorf("setting consumer genesis state, consumerId(%s): %w", consumerId, err)
}

if len(consumerGenesis.Provider.InitialValSet) == 0 {
return errorsmod.Wrapf(types.ErrInvalidConsumerGenesis, "consumer genesis initial validator set is empty - no validators opted in consumer id: %s", consumerId)
// compute the hash of the consumer initial validator updates
updatesAsValSet, err := tmtypes.PB2TM.ValidatorUpdates(initialValUpdates)
if err != nil {
return fmt.Errorf("unable to create initial validator set from initial validator updates: %w", err)
}
valsetHash := tmtypes.NewValidatorSet(updatesAsValSet).Hash()

// create the consumer client and the genesis
err = k.CreateConsumerClient(ctx, consumerId, valsetHash)
if err != nil {
return fmt.Errorf("crating consumer client, consumerId(%s): %w", consumerId, err)
}

k.SetConsumerPhase(ctx, consumerId, types.CONSUMER_PHASE_LAUNCHED)

k.Logger(ctx).Info("consumer successfully launched",
"consumerId", consumerId,
"valset size", len(initialValUpdates),
"valsetHash", string(valsetHash),
)

return nil
}

// CreateConsumerClient will create the CCV client for the given consumer chain. The CCV channel must be built
// on top of the CCV client to ensure connection with the right consumer chain.
func (k Keeper) CreateConsumerClient(ctx sdk.Context, consumerId string) error {
func (k Keeper) CreateConsumerClient(
ctx sdk.Context,
consumerId string,
valsetHash []byte,
) error {
initializationRecord, err := k.GetConsumerInitializationParameters(ctx, consumerId)
if err != nil {
return err
Expand Down Expand Up @@ -219,20 +271,11 @@ func (k Keeper) CreateConsumerClient(ctx sdk.Context, consumerId string) error {
clientState.TrustingPeriod = trustPeriod
clientState.UnbondingPeriod = consumerUnbondingPeriod

consumerGen, validatorSetHash, err := k.MakeConsumerGenesis(ctx, consumerId)
if err != nil {
return err
}
err = k.SetConsumerGenesis(ctx, consumerId, consumerGen)
if err != nil {
return err
}

// Create consensus state
consensusState := ibctmtypes.NewConsensusState(
ctx.BlockTime(),
commitmenttypes.NewMerkleRoot([]byte(ibctmtypes.SentinelRoot)),
validatorSetHash, // use the hash of the updated initial valset
valsetHash,
)

clientID, err := k.clientKeeper.CreateClient(ctx, clientState, consensusState)
Expand All @@ -241,7 +284,7 @@ func (k Keeper) CreateConsumerClient(ctx sdk.Context, consumerId string) error {
}
k.SetConsumerClientId(ctx, consumerId, clientID)

k.Logger(ctx).Info("consumer chain launched (client created)",
k.Logger(ctx).Info("consumer client created",
"consumer id", consumerId,
"client id", clientID,
)
Expand All @@ -256,6 +299,7 @@ func (k Keeper) CreateConsumerClient(ctx sdk.Context, consumerId string) error {
sdk.NewAttribute(types.AttributeInitialHeight, initializationRecord.InitialHeight.String()),
sdk.NewAttribute(types.AttributeTrustingPeriod, clientState.TrustingPeriod.String()),
sdk.NewAttribute(types.AttributeUnbondingPeriod, clientState.UnbondingPeriod.String()),
sdk.NewAttribute(types.AttributeValsetHash, string(valsetHash)),
),
)

Expand All @@ -267,21 +311,36 @@ func (k Keeper) CreateConsumerClient(ctx sdk.Context, consumerId string) error {
func (k Keeper) MakeConsumerGenesis(
ctx sdk.Context,
consumerId string,
) (gen ccv.ConsumerGenesisState, nextValidatorsHash []byte, err error) {
initialValidatorUpdates []abci.ValidatorUpdate,
) (gen ccv.ConsumerGenesisState, err error) {
initializationRecord, err := k.GetConsumerInitializationParameters(ctx, consumerId)
if err != nil {
return gen, nil, errorsmod.Wrapf(ccv.ErrInvalidConsumerState,
"cannot retrieve initialization parameters: %s", err.Error())
}
powerShapingParameters, err := k.GetConsumerPowerShapingParameters(ctx, consumerId)
if err != nil {
return gen, nil, errorsmod.Wrapf(ccv.ErrInvalidConsumerState,
"cannot retrieve power shaping parameters: %s", err.Error())
return gen, errorsmod.Wrapf(ccv.ErrInvalidConsumerState,
"getting initialization parameters, consumerId(%s): %s", consumerId, err.Error())
}
// note that providerFeePoolAddrStr is sent to the consumer during the IBC Channel handshake;
// see HandshakeMetadata in OnChanOpenTry on the provider-side, and OnChanOpenAck on the consumer-side
consumerGenesisParams := ccv.NewParams(
true,
initializationRecord.BlocksPerDistributionTransmission,
initializationRecord.DistributionTransmissionChannel,
"", // providerFeePoolAddrStr,
initializationRecord.CcvTimeoutPeriod,
initializationRecord.TransferTimeoutPeriod,
initializationRecord.ConsumerRedistributionFraction,
initializationRecord.HistoricalEntries,
initializationRecord.UnbondingPeriod,
[]string{},
[]string{},
ccv.DefaultRetryDelayPeriod,
)

// create provider client state and consensus state for the consumer to be able
// to create a provider client

providerUnbondingPeriod, err := k.stakingKeeper.UnbondingTime(ctx)
if err != nil {
return gen, nil, errorsmod.Wrapf(types.ErrNoUnbondingTime, "unbonding time not found: %s", err)
return gen, errorsmod.Wrapf(types.ErrNoUnbondingTime, "unbonding time not found: %s", err)
}
height := clienttypes.GetSelfHeight(ctx)

Expand All @@ -293,97 +352,23 @@ func (k Keeper) MakeConsumerGenesis(
clientState.LatestHeight = height
trustPeriod, err := ccv.CalculateTrustPeriod(providerUnbondingPeriod, k.GetTrustingPeriodFraction(ctx))
if err != nil {
return gen, nil, errorsmod.Wrapf(sdkerrors.ErrInvalidHeight, "error %s calculating trusting_period for: %s", err, height)
return gen, errorsmod.Wrapf(sdkerrors.ErrInvalidHeight, "error %s calculating trusting_period for: %s", err, height)
}
clientState.TrustingPeriod = trustPeriod
clientState.UnbondingPeriod = providerUnbondingPeriod

consState, err := k.clientKeeper.GetSelfConsensusState(ctx, height)
if err != nil {
return gen, nil, errorsmod.Wrapf(clienttypes.ErrConsensusStateNotFound, "error %s getting self consensus state for: %s", err, height)
}

// get the bonded validators from the staking module
bondedValidators, err := k.GetLastBondedValidators(ctx)
if err != nil {
return gen, nil, errorsmod.Wrapf(stakingtypes.ErrNoValidatorFound, "error getting last bonded validators: %s", err)
}

minPower := int64(0)
if powerShapingParameters.Top_N > 0 {
// get the consensus active validators
// we do not want to base the power calculation for the top N
// on inactive validators, too, since the top N will be a percentage of the active set power
// otherwise, it could be that inactive validators are forced to validate
activeValidators, err := k.GetLastProviderConsensusActiveValidators(ctx)
if err != nil {
return gen, nil, errorsmod.Wrapf(stakingtypes.ErrNoValidatorFound, "error getting last active bonded validators: %s", err)
}

// in a Top-N chain, we automatically opt in all validators that belong to the top N
minPower, err = k.ComputeMinPowerInTopN(ctx, activeValidators, powerShapingParameters.Top_N)
if err != nil {
return gen, nil, err
}
// log the minimum power in top N
k.Logger(ctx).Info("minimum power in top N at consumer genesis",
"consumerId", consumerId,
"minPower", minPower,
)

// set the minimal power of validators in the top N in the store
k.SetMinimumPowerInTopN(ctx, consumerId, minPower)

err = k.OptInTopNValidators(ctx, consumerId, activeValidators, minPower)
if err != nil {
return gen, nil, fmt.Errorf("unable to opt in topN validators in MakeConsumerGenesis, consumerId(%s): %w", consumerId, err)
}
}

// need to use the bondedValidators, not activeValidators, here since the chain might be opt-in and allow inactive vals
nextValidators, err := k.ComputeNextValidators(ctx, consumerId, bondedValidators, powerShapingParameters, minPower)
if err != nil {
return gen, nil, fmt.Errorf("unable to compute the next validators in MakeConsumerGenesis, consumerId(%s): %w", consumerId, err)
}
err = k.SetConsumerValSet(ctx, consumerId, nextValidators)
if err != nil {
return gen, nil, fmt.Errorf("unable to set consumer validator set in MakeConsumerGenesis, consumerId(%s): %w", consumerId, err)
}

// get the initial updates with the latest set consumer public keys
initialUpdatesWithConsumerKeys := DiffValidators([]types.ConsensusValidator{}, nextValidators)

// Get a hash of the consumer validator set from the update with applied consumer assigned keys
updatesAsValSet, err := tmtypes.PB2TM.ValidatorUpdates(initialUpdatesWithConsumerKeys)
if err != nil {
return gen, nil, fmt.Errorf("unable to create validator set from updates computed from key assignment in MakeConsumerGenesis: %s", err)
return gen, errorsmod.Wrapf(clienttypes.ErrConsensusStateNotFound, "error %s getting self consensus state for: %s", err, height)
}
hash := tmtypes.NewValidatorSet(updatesAsValSet).Hash()

// note that providerFeePoolAddrStr is sent to the consumer during the IBC Channel handshake;
// see HandshakeMetadata in OnChanOpenTry on the provider-side, and OnChanOpenAck on the consumer-side
consumerGenesisParams := ccv.NewParams(
true,
initializationRecord.BlocksPerDistributionTransmission,
initializationRecord.DistributionTransmissionChannel,
"", // providerFeePoolAddrStr,
initializationRecord.CcvTimeoutPeriod,
initializationRecord.TransferTimeoutPeriod,
initializationRecord.ConsumerRedistributionFraction,
initializationRecord.HistoricalEntries,
initializationRecord.UnbondingPeriod,
[]string{},
[]string{},
ccv.DefaultRetryDelayPeriod,
)

gen = *ccv.NewInitialConsumerGenesisState(
clientState,
consState.(*ibctmtypes.ConsensusState),
initialUpdatesWithConsumerKeys,
initialValidatorUpdates,
consumerGenesisParams,
)
return gen, hash, nil
return gen, nil
}

// StopAndPrepareForConsumerRemoval sets the phase of the chain to stopped and prepares to get the state of the
Expand Down
Loading

0 comments on commit ce6c271

Please sign in to comment.