Skip to content

Commit

Permalink
fix: remove unnecessary share validations and cleanup mountpoint name…
Browse files Browse the repository at this point in the history
… func
  • Loading branch information
fschade committed Jun 10, 2024
1 parent ff87878 commit 4c01ace
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 105 deletions.
142 changes: 61 additions & 81 deletions internal/grpc/services/usershareprovider/usershareprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,99 +506,30 @@ func (s *service) GetReceivedShare(ctx context.Context, req *collaboration.GetRe
}

func (s *service) UpdateReceivedShare(ctx context.Context, req *collaboration.UpdateReceivedShareRequest) (*collaboration.UpdateReceivedShareResponse, error) {
if req.GetShare() == nil {
return &collaboration.UpdateReceivedShareResponse{
Status: status.NewInvalid(ctx, "updating requires a received share object"),
}, nil
}

if req.GetShare().GetShare() == nil {
return &collaboration.UpdateReceivedShareResponse{
Status: status.NewInvalid(ctx, "share missing"),
}, nil
}

if req.GetShare().GetShare().GetId() == nil {
return &collaboration.UpdateReceivedShareResponse{
Status: status.NewInvalid(ctx, "share id missing"),
}, nil
}

if req.GetShare().GetShare().GetId().GetOpaqueId() == "" {
return &collaboration.UpdateReceivedShareResponse{
Status: status.NewInvalid(ctx, "share id empty"),
}, nil
}

isStateTransitionShareAccepted := slices.Contains(req.GetUpdateMask().GetPaths(), _fieldMaskPathState) && req.GetShare().GetState() == collaboration.ShareState_SHARE_STATE_ACCEPTED
if isStateTransitionShareAccepted {
isMountPointSet := slices.Contains(req.GetUpdateMask().GetPaths(), _fieldMaskPathMountPoint) && req.GetShare().GetMountPoint().GetPath() != ""
// we calculate a valid mountpoint only if the share should be accepted and the mount point is not set explicitly
if isStateTransitionShareAccepted && !isMountPointSet {
gatewayClient, err := s.gatewaySelector.Next()
if err != nil {
return nil, err
}

receivedShare, err := gatewayClient.GetReceivedShare(ctx, &collaboration.GetReceivedShareRequest{
Ref: &collaboration.ShareReference{
Spec: &collaboration.ShareReference_Id{
Id: req.GetShare().GetShare().GetId(),
},
},
})
switch {
case err != nil:
fallthrough
case receivedShare.GetStatus().GetCode() != rpc.Code_CODE_OK:
return &collaboration.UpdateReceivedShareResponse{
Status: receivedShare.GetStatus(),
}, err
}

resourceStat, err := gatewayClient.Stat(ctx, &provider.StatRequest{
Ref: &provider.Reference{
ResourceId: receivedShare.GetShare().GetShare().GetResourceId(),
},
})
s, err := setReceivedShareMountPoint(ctx, gatewayClient, req)
switch {
case err != nil:
fallthrough
case resourceStat.GetStatus().GetCode() != rpc.Code_CODE_OK:
case s.GetCode() != rpc.Code_CODE_OK:
return &collaboration.UpdateReceivedShareResponse{
Status: resourceStat.GetStatus(),
Status: s,
}, err
}

// handle mount point related updates
{
// find a suitable mount point
var requestedMountpoint string
switch {
case slices.Contains(req.GetUpdateMask().GetPaths(), _fieldMaskPathMountPoint) && req.GetShare().GetMountPoint().GetPath() != "":
requestedMountpoint = req.GetShare().GetMountPoint().GetPath()
case receivedShare.GetShare().GetMountPoint().GetPath() != "":
requestedMountpoint = receivedShare.GetShare().GetMountPoint().GetPath()
default:
requestedMountpoint = resourceStat.GetInfo().GetName()
}

// check if the requested mount point is available and if not, find a suitable one
availableMountpoint, err := GetAvailableMountpoint(ctx, gatewayClient,
resourceStat.GetInfo().GetId(),
requestedMountpoint,
)
if err != nil {
return &collaboration.UpdateReceivedShareResponse{
Status: status.NewInternal(ctx, err.Error()),
}, nil
}

if !slices.Contains(req.GetUpdateMask().GetPaths(), _fieldMaskPathMountPoint) {
req.GetUpdateMask().Paths = append(req.GetUpdateMask().GetPaths(), _fieldMaskPathMountPoint)
}

req.GetShare().MountPoint = &provider.Reference{
Path: availableMountpoint,
}
}
}

var uid userpb.UserId
Expand Down Expand Up @@ -627,7 +558,6 @@ func GetAvailableMountpoint(ctx context.Context, gwc gateway.GatewayAPIClient, i
return "", err
}

// we need to sort the received shares by mount point in order to make things easier to evaluate.
base := filepath.Clean(name)
mount := base
existingMountpoint := ""
Expand Down Expand Up @@ -670,11 +600,7 @@ func GetAvailableMountpoint(ctx context.Context, gwc gateway.GatewayAPIClient, i
for i := 1; i <= len(mountedShares)+1; i++ {
ext := filepath.Ext(base)
name := strings.TrimSuffix(base, ext)
// be smart about .tar.(gz|bz) files
if strings.HasSuffix(name, ".tar") {
name = strings.TrimSuffix(name, ".tar")
ext = ".tar" + ext
}

mount = name + " (" + strconv.Itoa(i) + ")" + ext
if !slices.Contains(mountedShares, mount) {
return mount, nil
Expand All @@ -684,3 +610,57 @@ func GetAvailableMountpoint(ctx context.Context, gwc gateway.GatewayAPIClient, i

return mount, nil
}

func setReceivedShareMountPoint(ctx context.Context, gwc gateway.GatewayAPIClient, req *collaboration.UpdateReceivedShareRequest) (*rpc.Status, error) {
receivedShare, err := gwc.GetReceivedShare(ctx, &collaboration.GetReceivedShareRequest{
Ref: &collaboration.ShareReference{
Spec: &collaboration.ShareReference_Id{
Id: req.GetShare().GetShare().GetId(),
},
},
})
switch {
case err != nil:
fallthrough
case receivedShare.GetStatus().GetCode() != rpc.Code_CODE_OK:
return receivedShare.GetStatus(), err
}

if receivedShare.GetShare().GetMountPoint().GetPath() != "" {
return status.NewOK(ctx), nil
}

resourceStat, err := gwc.Stat(ctx, &provider.StatRequest{
Ref: &provider.Reference{
ResourceId: receivedShare.GetShare().GetShare().GetResourceId(),
},
})
switch {
case err != nil:
fallthrough
case resourceStat.GetStatus().GetCode() != rpc.Code_CODE_OK:
return resourceStat.GetStatus(), err
}

// handle mount point related updates
{
// check if the requested mount point is available and if not, find a suitable one
availableMountpoint, err := GetAvailableMountpoint(ctx, gwc,
resourceStat.GetInfo().GetId(),
resourceStat.GetInfo().GetName(),
)
if err != nil {
return status.NewInternal(ctx, err.Error()), nil
}

if !slices.Contains(req.GetUpdateMask().GetPaths(), _fieldMaskPathMountPoint) {
req.GetUpdateMask().Paths = append(req.GetUpdateMask().GetPaths(), _fieldMaskPathMountPoint)
}

req.GetShare().MountPoint = &provider.Reference{
Path: availableMountpoint,
}
}

return status.NewOK(ctx), nil
}
24 changes: 0 additions & 24 deletions internal/grpc/services/usershareprovider/usershareprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,30 +141,6 @@ var _ = Describe("user share provider service", func() {
Expect(res.GetStatus().GetCode()).To(Equal(expectedStatus.GetCode()))
Expect(res.GetStatus().GetMessage()).To(ContainSubstring(expectedStatus.GetMessage()))
},
Entry(
"no received share",
&collaborationpb.UpdateReceivedShareRequest{},
status.NewInvalid(ctx, "updating requires"),
nil,
),
Entry(
"no share",
&collaborationpb.UpdateReceivedShareRequest{
Share: &collaborationpb.ReceivedShare{},
},
status.NewInvalid(ctx, "share missing"),
nil,
),
Entry(
"no share id",
&collaborationpb.UpdateReceivedShareRequest{
Share: &collaborationpb.ReceivedShare{
Share: &collaborationpb.Share{},
},
},
status.NewInvalid(ctx, "share id missing"),
nil,
),
Entry(
"no share opaque id",
&collaborationpb.UpdateReceivedShareRequest{
Expand Down

0 comments on commit 4c01ace

Please sign in to comment.