Skip to content

Commit

Permalink
fix desired connection pooling state usage (#138)
Browse files Browse the repository at this point in the history
  • Loading branch information
posriniv authored Nov 28, 2024
1 parent 2c288e9 commit 98f64a6
Showing 1 changed file with 69 additions and 11 deletions.
80 changes: 69 additions & 11 deletions managed/resource_cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ func (r resourceClusterType) GetSchema(_ context.Context) (tfsdk.Schema, diag.Di
Computed: true,
},
}
// remove once feature flag is enabled
// Remove once feature flag is enabled
// TODO: Think of a more scalable solution
if !fflags.IsFeatureFlagEnabled(fflags.CONNECTION_POOLING) {
delete(attributes, "desired_connection_pooling_state")
Expand Down Expand Up @@ -796,7 +796,9 @@ func getPlan(ctx context.Context, plan tfsdk.Plan, cluster *Cluster) diag.Diagno
diags.Append(plan.GetAttribute(ctx, path.Root("restore_backup_id"), &cluster.RestoreBackupID)...)
diags.Append(plan.GetAttribute(ctx, path.Root("database_track"), &cluster.DatabaseTrack)...)
diags.Append(plan.GetAttribute(ctx, path.Root("desired_state"), &cluster.DesiredState)...)
diags.Append(plan.GetAttribute(ctx, path.Root("desired_connection_pooling_state"), &cluster.DesiredConnectionPoolingState)...)
if fflags.IsFeatureFlagEnabled(fflags.CONNECTION_POOLING) {
diags.Append(plan.GetAttribute(ctx, path.Root("desired_connection_pooling_state"), &cluster.DesiredConnectionPoolingState)...)
}
diags.Append(plan.GetAttribute(ctx, path.Root("node_config"), &cluster.NodeConfig)...)
diags.Append(plan.GetAttribute(ctx, path.Root("credentials"), &cluster.Credentials)...)
diags.Append(plan.GetAttribute(ctx, path.Root("backup_schedules"), &cluster.BackupSchedules)...)
Expand All @@ -811,7 +813,9 @@ func getIDsFromState(ctx context.Context, state tfsdk.State, cluster *Cluster) {
state.GetAttribute(ctx, path.Root("project_id"), &cluster.ProjectID)
state.GetAttribute(ctx, path.Root("cluster_id"), &cluster.ClusterID)
state.GetAttribute(ctx, path.Root("desired_state"), &cluster.DesiredState)
state.GetAttribute(ctx, path.Root("desired_connection_pooling_state"), &cluster.DesiredConnectionPoolingState)
if fflags.IsFeatureFlagEnabled(fflags.CONNECTION_POOLING) {
state.GetAttribute(ctx, path.Root("desired_connection_pooling_state"), &cluster.DesiredConnectionPoolingState)
}
state.GetAttribute(ctx, path.Root("cluster_allow_list_ids"), &cluster.ClusterAllowListIDs)
state.GetAttribute(ctx, path.Root("cluster_region_info"), &cluster.ClusterRegionInfo)
state.GetAttribute(ctx, path.Root("backup_schedules"), &cluster.BackupSchedules)
Expand All @@ -835,6 +839,65 @@ func validateCredentials(credentials Credentials) bool {

}

// This function is needed to fix deserialization into TF state when connection pooling is removed from schema
func setClusterState(ctx context.Context, state *tfsdk.State, cluster *Cluster) diag.Diagnostics {
// Create temporary struct without DesiredConnectionPoolingState field
if fflags.IsFeatureFlagEnabled(fflags.CONNECTION_POOLING) {
return state.Set(ctx, cluster)
}
tempState := struct {
AccountID types.String `tfsdk:"account_id"`
ProjectID types.String `tfsdk:"project_id"`
ClusterID types.String `tfsdk:"cluster_id"`
ClusterName types.String `tfsdk:"cluster_name"`
CloudType types.String `tfsdk:"cloud_type"`
ClusterType types.String `tfsdk:"cluster_type"`
FaultTolerance types.String `tfsdk:"fault_tolerance"`
NumFaultsToTolerate types.Int64 `tfsdk:"num_faults_to_tolerate"`
ClusterRegionInfo []RegionInfo `tfsdk:"cluster_region_info"`
DatabaseTrack types.String `tfsdk:"database_track"`
DesiredState types.String `tfsdk:"desired_state"`
ClusterTier types.String `tfsdk:"cluster_tier"`
ClusterAllowListIDs []types.String `tfsdk:"cluster_allow_list_ids"`
RestoreBackupID types.String `tfsdk:"restore_backup_id"`
NodeConfig NodeConfig `tfsdk:"node_config"`
Credentials Credentials `tfsdk:"credentials"`
ClusterInfo ClusterInfo `tfsdk:"cluster_info"`
ClusterVersion types.String `tfsdk:"cluster_version"`
BackupSchedules []BackupScheduleInfo `tfsdk:"backup_schedules"`
ClusterEndpoints types.Map `tfsdk:"cluster_endpoints"`
ClusterEndpointsV2 []ClusterEndpoint `tfsdk:"endpoints"`
ClusterCertificate types.String `tfsdk:"cluster_certificate"`
CMKSpec *CMKSpec `tfsdk:"cmk_spec"`
}{
AccountID: cluster.AccountID,
ProjectID: cluster.ProjectID,
ClusterID: cluster.ClusterID,
ClusterName: cluster.ClusterName,
CloudType: cluster.CloudType,
ClusterType: cluster.ClusterType,
FaultTolerance: cluster.FaultTolerance,
NumFaultsToTolerate: cluster.NumFaultsToTolerate,
ClusterRegionInfo: cluster.ClusterRegionInfo,
DatabaseTrack: cluster.DatabaseTrack,
DesiredState: cluster.DesiredState,
ClusterTier: cluster.ClusterTier,
ClusterAllowListIDs: cluster.ClusterAllowListIDs,
RestoreBackupID: cluster.RestoreBackupID,
NodeConfig: cluster.NodeConfig,
Credentials: cluster.Credentials,
ClusterInfo: cluster.ClusterInfo,
ClusterVersion: cluster.ClusterVersion,
BackupSchedules: cluster.BackupSchedules,
ClusterEndpoints: cluster.ClusterEndpoints,
ClusterEndpointsV2: cluster.ClusterEndpointsV2,
ClusterCertificate: cluster.ClusterCertificate,
CMKSpec: cluster.CMKSpec,
}

return state.Set(ctx, &tempState)
}

func validateOnlyOneCMKSpec(plan *Cluster) error {
count := 0

Expand Down Expand Up @@ -1279,8 +1342,7 @@ func (r resourceCluster) Create(ctx context.Context, req tfsdk.CreateResourceReq
if fflags.IsFeatureFlagEnabled(fflags.CONNECTION_POOLING) && (strings.EqualFold(plan.DesiredConnectionPoolingState.Value, "Enabled") || strings.EqualFold(plan.DesiredConnectionPoolingState.Value, "Disabled")) {
cluster.DesiredConnectionPoolingState.Value = plan.DesiredConnectionPoolingState.Value
}

diags := resp.State.Set(ctx, &cluster)
diags := setClusterState(ctx, &resp.State, &cluster)
resp.Diagnostics.Append(diags...)
if resp.Diagnostics.HasError() {
return
Expand Down Expand Up @@ -1567,8 +1629,7 @@ func (r resourceCluster) Read(ctx context.Context, req tfsdk.ReadResourceRequest
if !state.RestoreBackupID.Null {
req.State.GetAttribute(ctx, path.Root("restore_backup_id"), &cluster.RestoreBackupID)
}

diags := resp.State.Set(ctx, &cluster)
diags := setClusterState(ctx, &resp.State, &cluster)
resp.Diagnostics.Append(diags...)
if resp.Diagnostics.HasError() {
return
Expand Down Expand Up @@ -1966,8 +2027,6 @@ func (r resourceCluster) Update(ctx context.Context, req tfsdk.UpdateResourceReq
}
}

tflog.Info(ctx, fmt.Sprintf("Existing Desired Connection Pooling State in State is %v", state.DesiredConnectionPoolingState.Value))

// Disable Connection Pooling if the desired state is set to 'Disabled' and it is enabled currently
if fflags.IsFeatureFlagEnabled(fflags.CONNECTION_POOLING) && !state.DesiredConnectionPoolingState.Unknown && strings.EqualFold(state.DesiredConnectionPoolingState.Value, "Enabled") && (plan.DesiredConnectionPoolingState.Unknown || strings.EqualFold(plan.DesiredConnectionPoolingState.Value, "Disabled")) {
// Disable Connection Pooling
Expand Down Expand Up @@ -2275,8 +2334,7 @@ func (r resourceCluster) Update(ctx context.Context, req tfsdk.UpdateResourceReq
if fflags.IsFeatureFlagEnabled(fflags.CONNECTION_POOLING) && (strings.EqualFold(plan.DesiredConnectionPoolingState.Value, "Enabled") || strings.EqualFold(plan.DesiredConnectionPoolingState.Value, "Disabled")) {
cluster.DesiredConnectionPoolingState.Value = plan.DesiredConnectionPoolingState.Value
}

diags := resp.State.Set(ctx, &cluster)
diags := setClusterState(ctx, &resp.State, &cluster)
resp.Diagnostics.Append(diags...)
if resp.Diagnostics.HasError() {
return
Expand Down

0 comments on commit 98f64a6

Please sign in to comment.