diff --git a/internal/grpc/services/usershareprovider/usershareprovider.go b/internal/grpc/services/usershareprovider/usershareprovider.go index c52caad51b..2c1cae656c 100644 --- a/internal/grpc/services/usershareprovider/usershareprovider.go +++ b/internal/grpc/services/usershareprovider/usershareprovider.go @@ -506,24 +506,6 @@ func (s *service) GetReceivedShare(ctx context.Context, req *collaboration.GetRe } func (s *service) UpdateReceivedShare(ctx context.Context, req *collaboration.UpdateReceivedShareRequest) (*collaboration.UpdateReceivedShareResponse, error) { - if req.GetShare() == nil { - return &collaboration.UpdateReceivedShareResponse{ - Status: status.NewInvalid(ctx, "updating requires a received share object"), - }, nil - } - - if req.GetShare().GetShare() == nil { - return &collaboration.UpdateReceivedShareResponse{ - Status: status.NewInvalid(ctx, "share missing"), - }, nil - } - - if req.GetShare().GetShare().GetId() == nil { - return &collaboration.UpdateReceivedShareResponse{ - Status: status.NewInvalid(ctx, "share id missing"), - }, nil - } - if req.GetShare().GetShare().GetId().GetOpaqueId() == "" { return &collaboration.UpdateReceivedShareResponse{ Status: status.NewInvalid(ctx, "share id empty"), @@ -531,74 +513,23 @@ func (s *service) UpdateReceivedShare(ctx context.Context, req *collaboration.Up } isStateTransitionShareAccepted := slices.Contains(req.GetUpdateMask().GetPaths(), _fieldMaskPathState) && req.GetShare().GetState() == collaboration.ShareState_SHARE_STATE_ACCEPTED - if isStateTransitionShareAccepted { + isMountPointSet := slices.Contains(req.GetUpdateMask().GetPaths(), _fieldMaskPathMountPoint) && req.GetShare().GetMountPoint().GetPath() != "" + // we calculate a valid mountpoint only if the share should be accepted and the mount point is not set explicitly + if isStateTransitionShareAccepted && !isMountPointSet { gatewayClient, err := s.gatewaySelector.Next() if err != nil { return nil, err } - receivedShare, err := gatewayClient.GetReceivedShare(ctx, &collaboration.GetReceivedShareRequest{ - Ref: &collaboration.ShareReference{ - Spec: &collaboration.ShareReference_Id{ - Id: req.GetShare().GetShare().GetId(), - }, - }, - }) - switch { - case err != nil: - fallthrough - case receivedShare.GetStatus().GetCode() != rpc.Code_CODE_OK: - return &collaboration.UpdateReceivedShareResponse{ - Status: receivedShare.GetStatus(), - }, err - } - - resourceStat, err := gatewayClient.Stat(ctx, &provider.StatRequest{ - Ref: &provider.Reference{ - ResourceId: receivedShare.GetShare().GetShare().GetResourceId(), - }, - }) + s, err := setReceivedShareMountPoint(ctx, gatewayClient, req) switch { case err != nil: fallthrough - case resourceStat.GetStatus().GetCode() != rpc.Code_CODE_OK: + case s.GetCode() != rpc.Code_CODE_OK: return &collaboration.UpdateReceivedShareResponse{ - Status: resourceStat.GetStatus(), + Status: s, }, err } - - // handle mount point related updates - { - // find a suitable mount point - var requestedMountpoint string - switch { - case slices.Contains(req.GetUpdateMask().GetPaths(), _fieldMaskPathMountPoint) && req.GetShare().GetMountPoint().GetPath() != "": - requestedMountpoint = req.GetShare().GetMountPoint().GetPath() - case receivedShare.GetShare().GetMountPoint().GetPath() != "": - requestedMountpoint = receivedShare.GetShare().GetMountPoint().GetPath() - default: - requestedMountpoint = resourceStat.GetInfo().GetName() - } - - // check if the requested mount point is available and if not, find a suitable one - availableMountpoint, err := GetAvailableMountpoint(ctx, gatewayClient, - resourceStat.GetInfo().GetId(), - requestedMountpoint, - ) - if err != nil { - return &collaboration.UpdateReceivedShareResponse{ - Status: status.NewInternal(ctx, err.Error()), - }, nil - } - - if !slices.Contains(req.GetUpdateMask().GetPaths(), _fieldMaskPathMountPoint) { - req.GetUpdateMask().Paths = append(req.GetUpdateMask().GetPaths(), _fieldMaskPathMountPoint) - } - - req.GetShare().MountPoint = &provider.Reference{ - Path: availableMountpoint, - } - } } var uid userpb.UserId @@ -627,7 +558,6 @@ func GetAvailableMountpoint(ctx context.Context, gwc gateway.GatewayAPIClient, i return "", err } - // we need to sort the received shares by mount point in order to make things easier to evaluate. base := filepath.Clean(name) mount := base existingMountpoint := "" @@ -670,11 +600,7 @@ func GetAvailableMountpoint(ctx context.Context, gwc gateway.GatewayAPIClient, i for i := 1; i <= len(mountedShares)+1; i++ { ext := filepath.Ext(base) name := strings.TrimSuffix(base, ext) - // be smart about .tar.(gz|bz) files - if strings.HasSuffix(name, ".tar") { - name = strings.TrimSuffix(name, ".tar") - ext = ".tar" + ext - } + mount = name + " (" + strconv.Itoa(i) + ")" + ext if !slices.Contains(mountedShares, mount) { return mount, nil @@ -684,3 +610,57 @@ func GetAvailableMountpoint(ctx context.Context, gwc gateway.GatewayAPIClient, i return mount, nil } + +func setReceivedShareMountPoint(ctx context.Context, gwc gateway.GatewayAPIClient, req *collaboration.UpdateReceivedShareRequest) (*rpc.Status, error) { + receivedShare, err := gwc.GetReceivedShare(ctx, &collaboration.GetReceivedShareRequest{ + Ref: &collaboration.ShareReference{ + Spec: &collaboration.ShareReference_Id{ + Id: req.GetShare().GetShare().GetId(), + }, + }, + }) + switch { + case err != nil: + fallthrough + case receivedShare.GetStatus().GetCode() != rpc.Code_CODE_OK: + return receivedShare.GetStatus(), err + } + + if receivedShare.GetShare().GetMountPoint().GetPath() != "" { + return status.NewOK(ctx), nil + } + + resourceStat, err := gwc.Stat(ctx, &provider.StatRequest{ + Ref: &provider.Reference{ + ResourceId: receivedShare.GetShare().GetShare().GetResourceId(), + }, + }) + switch { + case err != nil: + fallthrough + case resourceStat.GetStatus().GetCode() != rpc.Code_CODE_OK: + return resourceStat.GetStatus(), err + } + + // handle mount point related updates + { + // check if the requested mount point is available and if not, find a suitable one + availableMountpoint, err := GetAvailableMountpoint(ctx, gwc, + resourceStat.GetInfo().GetId(), + resourceStat.GetInfo().GetName(), + ) + if err != nil { + return status.NewInternal(ctx, err.Error()), nil + } + + if !slices.Contains(req.GetUpdateMask().GetPaths(), _fieldMaskPathMountPoint) { + req.GetUpdateMask().Paths = append(req.GetUpdateMask().GetPaths(), _fieldMaskPathMountPoint) + } + + req.GetShare().MountPoint = &provider.Reference{ + Path: availableMountpoint, + } + } + + return status.NewOK(ctx), nil +} diff --git a/internal/grpc/services/usershareprovider/usershareprovider_test.go b/internal/grpc/services/usershareprovider/usershareprovider_test.go index ba4ccf8210..9adcbef778 100644 --- a/internal/grpc/services/usershareprovider/usershareprovider_test.go +++ b/internal/grpc/services/usershareprovider/usershareprovider_test.go @@ -141,30 +141,6 @@ var _ = Describe("user share provider service", func() { Expect(res.GetStatus().GetCode()).To(Equal(expectedStatus.GetCode())) Expect(res.GetStatus().GetMessage()).To(ContainSubstring(expectedStatus.GetMessage())) }, - Entry( - "no received share", - &collaborationpb.UpdateReceivedShareRequest{}, - status.NewInvalid(ctx, "updating requires"), - nil, - ), - Entry( - "no share", - &collaborationpb.UpdateReceivedShareRequest{ - Share: &collaborationpb.ReceivedShare{}, - }, - status.NewInvalid(ctx, "share missing"), - nil, - ), - Entry( - "no share id", - &collaborationpb.UpdateReceivedShareRequest{ - Share: &collaborationpb.ReceivedShare{ - Share: &collaborationpb.Share{}, - }, - }, - status.NewInvalid(ctx, "share id missing"), - nil, - ), Entry( "no share opaque id", &collaborationpb.UpdateReceivedShareRequest{