Skip to content

Commit

Permalink
replication: fix problem with setReplication
Browse files Browse the repository at this point in the history
When src is nil, we create a new instance
of csiReplication.ReplicationSource but
this newly created instance does not
propagate back to the caller since Go
passes pointers by value, This commit
fixes the problem.

Signed-off-by: Madhu Rajanna <[email protected]>
  • Loading branch information
Madhu-1 authored and mergify[bot] committed Jul 24, 2024
1 parent 24866a8 commit 463254d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 17 deletions.
22 changes: 11 additions & 11 deletions internal/sidecar/service/volumereplication.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (rs *ReplicationServer) EnableVolumeReplication(
Parameters: req.GetParameters(),
Secrets: data,
}
err = setReplicationSource(repReq.ReplicationSource, req.GetReplicationSource())
err = setReplicationSource(&repReq.ReplicationSource, req.GetReplicationSource())
if err != nil {
klog.Errorf("Failed to set replication source: %v", err)
return nil, status.Error(codes.Internal, err.Error())
Expand Down Expand Up @@ -102,7 +102,7 @@ func (rs *ReplicationServer) DisableVolumeReplication(
Parameters: req.GetParameters(),
Secrets: data,
}
err = setReplicationSource(repReq.ReplicationSource, req.GetReplicationSource())
err = setReplicationSource(&repReq.ReplicationSource, req.GetReplicationSource())
if err != nil {
klog.Errorf("Failed to set replication source: %v", err)
return nil, status.Error(codes.Internal, err.Error())
Expand Down Expand Up @@ -135,7 +135,7 @@ func (rs *ReplicationServer) PromoteVolume(
Force: req.GetForce(),
Secrets: data,
}
err = setReplicationSource(repReq.ReplicationSource, req.GetReplicationSource())
err = setReplicationSource(&repReq.ReplicationSource, req.GetReplicationSource())
if err != nil {
klog.Errorf("Failed to set replication source: %v", err)
return nil, status.Error(codes.Internal, err.Error())
Expand Down Expand Up @@ -168,7 +168,7 @@ func (rs *ReplicationServer) DemoteVolume(
Force: req.GetForce(),
Secrets: data,
}
err = setReplicationSource(repReq.ReplicationSource, req.GetReplicationSource())
err = setReplicationSource(&repReq.ReplicationSource, req.GetReplicationSource())
if err != nil {
klog.Errorf("Failed to set replication source: %v", err)
return nil, status.Error(codes.Internal, err.Error())
Expand Down Expand Up @@ -201,7 +201,7 @@ func (rs *ReplicationServer) ResyncVolume(
Force: req.GetForce(),
Secrets: data,
}
err = setReplicationSource(repReq.ReplicationSource, req.GetReplicationSource())
err = setReplicationSource(&repReq.ReplicationSource, req.GetReplicationSource())
if err != nil {
klog.Errorf("Failed to set replication source: %v", err)
return nil, status.Error(codes.Internal, err.Error())
Expand Down Expand Up @@ -234,7 +234,7 @@ func (rs *ReplicationServer) GetVolumeReplicationInfo(
ReplicationId: req.GetReplicationId(),
Secrets: data,
}
err = setReplicationSource(repReq.ReplicationSource, req.GetReplicationSource())
err = setReplicationSource(&repReq.ReplicationSource, req.GetReplicationSource())
if err != nil {
klog.Errorf("Failed to set replication source: %v", err)
return nil, status.Error(codes.Internal, err.Error())
Expand All @@ -259,9 +259,9 @@ func (rs *ReplicationServer) GetVolumeReplicationInfo(
}

// setReplicationSource sets the replication source for the given ReplicationSource.
func setReplicationSource(src *csiReplication.ReplicationSource, req *proto.ReplicationSource) error {
if src == nil {
src = &csiReplication.ReplicationSource{}
func setReplicationSource(src **csiReplication.ReplicationSource, req *proto.ReplicationSource) error {
if *src == nil {
*src = &csiReplication.ReplicationSource{}
}

switch {
Expand All @@ -270,12 +270,12 @@ func setReplicationSource(src *csiReplication.ReplicationSource, req *proto.Repl
case req.GetVolume() == nil && req.GetVolumeGroup() == nil:
return errors.New("either volume or volume group is required")
case req.GetVolume() != nil:
src.Type = &csiReplication.ReplicationSource_Volume{Volume: &csiReplication.ReplicationSource_VolumeSource{
(*src).Type = &csiReplication.ReplicationSource_Volume{Volume: &csiReplication.ReplicationSource_VolumeSource{
VolumeId: req.GetVolume().GetVolumeId(),
}}
return nil
case req.GetVolumeGroup() != nil:
src.Type = &csiReplication.ReplicationSource_Volumegroup{Volumegroup: &csiReplication.ReplicationSource_VolumeGroupSource{
(*src).Type = &csiReplication.ReplicationSource_Volumegroup{Volumegroup: &csiReplication.ReplicationSource_VolumeGroupSource{
VolumeGroupId: req.GetVolumeGroup().GetVolumeGroupId(),
}}
return nil
Expand Down
26 changes: 20 additions & 6 deletions internal/sidecar/service/volumereplication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,27 @@ func Test_setReplicationSource(t *testing.T) {
wantErr: true,
},
{
name: "set replication source when request is not set",
name: "set replication source when request is nil",
args: args{
src: &csiReplication.ReplicationSource{},
req: &proto.ReplicationSource{},
src: nil,
req: nil,
},
wantErr: true,
},
{
name: "set replication source is nil but request is not nil",
args: args{
src: nil,
req: &proto.ReplicationSource{
Type: &proto.ReplicationSource_Volume{
Volume: &proto.ReplicationSource_VolumeSource{
VolumeId: volID,
},
},
},
},
wantErr: false,
},
{
name: "set replication source when volume is set",
args: args{
Expand Down Expand Up @@ -81,16 +95,16 @@ func Test_setReplicationSource(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := setReplicationSource(tt.args.src, tt.args.req); (err != nil) != tt.wantErr {
if err := setReplicationSource(&tt.args.src, tt.args.req); (err != nil) != tt.wantErr {
t.Errorf("setReplicationSource() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.args.req.GetVolume() != nil {
if tt.args.req.GetVolume().GetVolumeId() != volID {
if tt.args.src.GetVolume().GetVolumeId() != volID {
t.Errorf("setReplicationSource() got = %v volumeID, expected = %v volumeID", tt.args.req.GetVolume().GetVolumeId(), volID)
}
}
if tt.args.req.GetVolumeGroup() != nil {
if tt.args.req.GetVolumeGroup().GetVolumeGroupId() != groupID {
if tt.args.src.GetVolumegroup().GetVolumeGroupId() != groupID {
t.Errorf("setReplicationSource() got = %v groupID, expected = %v volumeID", tt.args.req.GetVolumeGroup().GetVolumeGroupId(), groupID)
}
}
Expand Down

0 comments on commit 463254d

Please sign in to comment.