Skip to content

Commit

Permalink
Added recursive query parameter and fixed some linting errors (#2368)
Browse files Browse the repository at this point in the history
  • Loading branch information
gapra-msft authored Sep 13, 2023
1 parent 9500654 commit a5fe78e
Show file tree
Hide file tree
Showing 11 changed files with 198 additions and 232 deletions.
4 changes: 0 additions & 4 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,6 @@ jobs:
# acquire the mutex before running live tests to avoid conflicts
python ./tool_distributed_mutex.py lock "$(MUTEX_URL)"
name: 'Acquire_the_distributed_mutex'
- template: azurePipelineTemplates/run-ut.yml
parameters:
directory: 'azbfs'
coverage_name: 'azbfs'
- template: azurePipelineTemplates/run-ut.yml
parameters:
directory: 'cmd'
Expand Down
18 changes: 12 additions & 6 deletions cmd/syncProcessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"encoding/json"
"fmt"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake"
Expand Down Expand Up @@ -374,7 +375,8 @@ func (b *remoteResourceDeleter) delete(object StoredObject) error {
var err error
switch b.targetLocation {
case common.ELocation.Blob():
blobURLParts, err := blob.ParseURL(b.rootURL.String())
var blobURLParts blob.URLParts
blobURLParts, err = blob.ParseURL(b.rootURL.String())
if err != nil {
return err
}
Expand All @@ -383,7 +385,8 @@ func (b *remoteResourceDeleter) delete(object StoredObject) error {
blobClient := common.CreateBlobClient(blobURLParts.String(), b.credInfo, nil, b.clientOptions)
_, err = blobClient.Delete(b.ctx, nil)
case common.ELocation.File():
fileURLParts, err := sharefile.ParseURL(b.rootURL.String())
var fileURLParts sharefile.URLParts
fileURLParts, err = sharefile.ParseURL(b.rootURL.String())
if err != nil {
return err
}
Expand Down Expand Up @@ -412,7 +415,8 @@ func (b *remoteResourceDeleter) delete(object StoredObject) error {
}
}
case common.ELocation.BlobFS():
datalakeURLParts, err := azdatalake.ParseURL(b.rootURL.String())
var datalakeURLParts azdatalake.URLParts
datalakeURLParts, err = azdatalake.ParseURL(b.rootURL.String())
if err != nil {
return err
}
Expand Down Expand Up @@ -471,9 +475,11 @@ func (b *remoteResourceDeleter) delete(object StoredObject) error {
}
}
case common.ELocation.BlobFS():
directoryClient := common.CreateDatalakeDirectoryClient(objectURL.String(), b.credInfo, nil, b.clientOptions)
// TODO : Recursive delete
_, err = directoryClient.Delete(ctx, nil)
clientOptions := b.clientOptions
clientOptions.PerCallPolicies = append([]policy.Policy{common.NewRecursivePolicy()}, clientOptions.PerCallPolicies...)
directoryClient := common.CreateDatalakeDirectoryClient(objectURL.String(), b.credInfo, nil, clientOptions)
recursiveContext := common.WithRecursive(ctx, false)
_, err = directoryClient.Delete(recursiveContext, nil)
default:
panic("not implemented, check your code")
}
Expand Down
55 changes: 55 additions & 0 deletions common/blobFSRecursiveDeletePolicy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright © Microsoft <[email protected]>
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package common

import (
"context"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"net/http"
"strconv"
)

// CtxRecursiveKey is the context key type under which the desired value of
// the "recursive" query parameter is stored.
type CtxRecursiveKey struct{}

// WithRecursive returns a context derived from parent that carries the given
// recursive flag under a CtxRecursiveKey key.
func WithRecursive(parent context.Context, recursive bool) context.Context {
	var key CtxRecursiveKey
	return context.WithValue(parent, key, recursive)
}

// recursivePolicy is a stateless pipeline policy; all of its behaviour is
// implemented by its Do method.
type recursivePolicy struct{}

// NewRecursivePolicy returns a policy.Policy backed by a new recursivePolicy.
func NewRecursivePolicy() policy.Policy {
	return new(recursivePolicy)
}

// Do overrides the "recursive" query parameter of the outgoing request with
// the bool stored in the request context under CtxRecursiveKey.
//
// The parameter is only rewritten when it is already present on the URL; a
// request without it is forwarded untouched. When the context holds no value
// under the key, or the value is not a bool, the request is also forwarded
// unchanged rather than panicking on the type assertion.
//
// Note that rewriting goes through url.Values.Encode, which emits the query
// parameters sorted by key.
func (p *recursivePolicy) Do(req *policy.Request) (*http.Response, error) {
	// Comma-ok assertion covers both "key absent" (nil) and "wrong type".
	recursive, ok := req.Raw().Context().Value(CtxRecursiveKey{}).(bool)
	if !ok {
		return req.Next()
	}

	// Parse the query once and reuse it for both the presence check and the update.
	rawURL := req.Raw().URL
	query := rawURL.Query()
	if query.Has("recursive") {
		query.Set("recursive", strconv.FormatBool(recursive))
		rawURL.RawQuery = query.Encode()
	}
	return req.Next()
}
106 changes: 106 additions & 0 deletions common/blobFSRecursiveDeletePolicy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package common

import (
"context"
"fmt"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/stretchr/testify/assert"
"net/http"
"testing"
)

// testRecursive is a test policy that verifies the "recursive" query
// parameter of a request against an expected string value. It never calls
// req.Next, so it acts as the last policy that sees the request.
type testRecursive struct {
	recursive string
}

// Do returns an empty response and a nil error when the request URL has a
// "recursive" query parameter equal to the expected value; otherwise it
// returns an error describing the expected and actual values.
func (t testRecursive) Do(req *policy.Request) (*http.Response, error) {
	query := req.Raw().URL.Query()
	actual := query.Get("recursive")
	if query.Has("recursive") && actual == t.recursive {
		return &http.Response{}, nil
	}
	return &http.Response{}, fmt.Errorf("recursive query parameter not found or does not match expected value. expected: %s, actual: %s", t.recursive, actual)
}

func TestRecursivePolicyExpectTrue(t *testing.T) {
a := assert.New(t)
ctx := WithRecursive(context.Background(), true)
policies := []policy.Policy{NewRecursivePolicy(), testRecursive{"true"}}
p := runtime.NewPipeline("testmodule", "v0.1.0", runtime.PipelineOptions{}, &policy.ClientOptions{Transport: nil, PerCallPolicies: policies})

endpoints := []string{"https://xxxx.dfs.core.windows.net/container/path?recursive=true",
"https://xxxx.dfs.core.windows.net/container/path?recursive=true&sig=xxxxxx&snapshot=xxxxx&timeout=xxxx",
"https://xxxx.dfs.core.windows.net/container/path?sig=xxxxxx&recursive=true&snapshot=xxxxx&timeout=xxxx",
"https://xxxx.dfs.core.windows.net/container/path?sig=xxxxxx&snapshot=xxxxx&timeout=xxxx&recursive=true",
"https://xxxx.dfs.core.windows.net/container/path?recursive=false",
"https://xxxx.dfs.core.windows.net/container/path?recursive=false&sig=xxxxxx&snapshot=xxxxx&timeout=xxxx",
"https://xxxx.dfs.core.windows.net/container/path?sig=xxxxxx&recursive=false&snapshot=xxxxx&timeout=xxxx",
"https://xxxx.dfs.core.windows.net/container/path?sig=xxxxxx&snapshot=xxxxx&timeout=xxxx&recursive=false",}

for _, e := range endpoints {
req, err := runtime.NewRequest(ctx, "HEAD", e)
a.Nil(err)
_, err = p.Do(req)
a.Nil(err)
}
}

func TestRecursivePolicyExpectFalse(t *testing.T) {
a := assert.New(t)
ctx := WithRecursive(context.Background(), false)
policies := []policy.Policy{NewRecursivePolicy(), testRecursive{"false"}}
p := runtime.NewPipeline("testmodule", "v0.1.0", runtime.PipelineOptions{}, &policy.ClientOptions{Transport: nil, PerCallPolicies: policies})

endpoints := []string{"https://xxxx.dfs.core.windows.net/container/path?recursive=true",
"https://xxxx.dfs.core.windows.net/container/path?recursive=true&sig=xxxxxx&snapshot=xxxxx&timeout=xxxx",
"https://xxxx.dfs.core.windows.net/container/path?sig=xxxxxx&recursive=true&snapshot=xxxxx&timeout=xxxx",
"https://xxxx.dfs.core.windows.net/container/path?sig=xxxxxx&snapshot=xxxxx&timeout=xxxx&recursive=true",
"https://xxxx.dfs.core.windows.net/container/path?recursive=false",
"https://xxxx.dfs.core.windows.net/container/path?recursive=false&sig=xxxxxx&snapshot=xxxxx&timeout=xxxx",
"https://xxxx.dfs.core.windows.net/container/path?sig=xxxxxx&recursive=false&snapshot=xxxxx&timeout=xxxx",
"https://xxxx.dfs.core.windows.net/container/path?sig=xxxxxx&snapshot=xxxxx&timeout=xxxx&recursive=false",}

for _, e := range endpoints {
req, err := runtime.NewRequest(ctx, "HEAD", e)
a.Nil(err)
_, err = p.Do(req)
a.Nil(err)
}
}

// testEndpoint is a test policy that verifies the full URL of a request
// against an expected endpoint string. It never calls req.Next, so it acts as
// the last policy that sees the request.
type testEndpoint struct {
	endpoint string
}

// Do returns an empty response and a nil error when the request URL, rendered
// as a string, equals the expected endpoint; otherwise it returns an error
// reporting the expected and actual URLs.
func (t testEndpoint) Do(req *policy.Request) (*http.Response, error) {
	actual := req.Raw().URL.String()
	if actual == t.endpoint {
		return &http.Response{}, nil
	}
	// The comparison is on the whole URL, so the message says so (the previous
	// text was copied from testRecursive and spoke of the recursive parameter).
	return &http.Response{}, fmt.Errorf("request URL does not match expected endpoint. expected: %s, actual: %s", t.endpoint, actual)
}

// TestRecursivePolicyExpectNoChange sends requests through the recursive
// policy using a context that carries no recursive value, and expects the
// downstream testEndpoint policy to see each URL exactly as it was supplied,
// whether or not it contains a "recursive" query parameter.
func TestRecursivePolicyExpectNoChange(t *testing.T) {
	endpoints := []string{
		"https://xxxx.dfs.core.windows.net/container/path?recursive=true",
		"https://xxxx.dfs.core.windows.net/container/path?recursive=true&sig=xxxxxx&snapshot=xxxxx&timeout=xxxx",
		"https://xxxx.dfs.core.windows.net/container/path?sig=xxxxxx&recursive=true&snapshot=xxxxx&timeout=xxxx",
		"https://xxxx.dfs.core.windows.net/container/path?sig=xxxxxx&snapshot=xxxxx&timeout=xxxx&recursive=true",
		"https://xxxx.dfs.core.windows.net/container/path?recursive=false",
		"https://xxxx.dfs.core.windows.net/container/path?recursive=false&sig=xxxxxx&snapshot=xxxxx&timeout=xxxx",
		"https://xxxx.dfs.core.windows.net/container/path?sig=xxxxxx&recursive=false&snapshot=xxxxx&timeout=xxxx",
		"https://xxxx.dfs.core.windows.net/container/path?sig=xxxxxx&snapshot=xxxxx&timeout=xxxx&recursive=false",
		"https://xxxx.dfs.core.windows.net/container/path",
		"https://xxxx.dfs.core.windows.net/container/path?sig=xxxxxx&snapshot=xxxxx&timeout=xxxx",
	}

	a := assert.New(t)

	for i := range endpoints {
		expected := endpoints[i]
		// A fresh pipeline per endpoint, because the checker policy is bound to that endpoint.
		clientOptions := &policy.ClientOptions{
			Transport:       nil,
			PerCallPolicies: []policy.Policy{NewRecursivePolicy(), testEndpoint{expected}},
		}
		pl := runtime.NewPipeline("testmodule", "v0.1.0", runtime.PipelineOptions{}, clientOptions)

		req, err := runtime.NewRequest(context.Background(), "HEAD", expected)
		a.Nil(err)
		_, err = pl.Do(req)
		a.Nil(err)
	}
}
54 changes: 1 addition & 53 deletions common/credentialFactory.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,14 @@
package common

import (
gcpUtils "cloud.google.com/go/storage"
"context"
"errors"
"fmt"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
"math"
"sync"
"time"

gcpUtils "cloud.google.com/go/storage"

"github.com/Azure/go-autorest/autorest/adal"
"github.com/minio/minio-go"
"github.com/minio/minio-go/pkg/credentials"
)
Expand All @@ -57,20 +53,6 @@ func (o CredentialOpOptions) callerMessage() string {
return Iff(o.CallerID == "", o.CallerID, o.CallerID+" ")
}

// logInfo passes str, prefixed with the caller message, to the LogInfo
// callback. It does nothing when no LogInfo callback is set.
func (o CredentialOpOptions) logInfo(str string) {
	if o.LogInfo == nil {
		return
	}
	message := o.callerMessage() + str
	o.LogInfo(message)
}

// logError passes str, prefixed with the caller message, to the LogError
// callback. It does nothing when no LogError callback is set.
func (o CredentialOpOptions) logError(str string) {
	if o.LogError == nil {
		return
	}
	message := o.callerMessage() + str
	o.LogError(message)
}

// panicError uses built-in panic if no Panic is specified in CredentialOpOptions.
func (o CredentialOpOptions) panicError(err error) {
newErr := fmt.Errorf("%s%v", o.callerMessage(), err)
Expand All @@ -81,14 +63,6 @@ func (o CredentialOpOptions) panicError(err error) {
}
}

// cancel invokes the Cancel callback when one is set; otherwise it hands a
// "cancel the operations" error to panicError.
func (o CredentialOpOptions) cancel() {
	if o.Cancel == nil {
		o.panicError(errors.New("cancel the operations"))
		return
	}
	o.Cancel()
}

// GetSourceBlobCredential gets the TokenCredential based on the cred info
func GetSourceBlobCredential(credInfo CredentialInfo, options CredentialOpOptions) (azcore.TokenCredential, error) {
if credInfo.CredentialType.IsAzureOAuth() {
Expand All @@ -104,32 +78,6 @@ func GetSourceBlobCredential(credInfo CredentialInfo, options CredentialOpOption
return nil, nil
}

// refreshPolicyHalfOfExpiryWithin computes how long to wait before the next
// token refresh: half of the time remaining until the token expires, but never
// less than one second.
//
// A nil token is treated as an invalid state: the error is logged, the
// operation is cancelled through options, and the maximum representable
// duration is returned. When GlobalTestOAuthInjection.DoTokenRefreshInjection
// is set, the computed value is replaced by
// GlobalTestOAuthInjection.TokenRefreshDuration.
func refreshPolicyHalfOfExpiryWithin(token *adal.Token, options CredentialOpOptions) time.Duration {
	if token == nil {
		// Invalid state, token should not be nil, cancel the operation and stop refresh
		options.logError("invalid state, token is nil, cancel will be triggered")
		options.cancel()
		return time.Duration(math.MaxInt64)
	}

	untilExpiry := token.Expires().Sub(time.Now().UTC())

	// Floor at one second to guard against refresh flooding.
	waitDuration := time.Second
	if half := untilExpiry / 2; half >= time.Second {
		waitDuration = half
	}

	// Test hook: replace the computed wait with the injected duration.
	if GlobalTestOAuthInjection.DoTokenRefreshInjection {
		waitDuration = GlobalTestOAuthInjection.TokenRefreshDuration
	}

	options.logInfo(fmt.Sprintf("next token refresh's wait duration: %v", waitDuration))
	return waitDuration
}

// CreateS3Credential creates AWS S3 credential according to credential info.
func CreateS3Credential(ctx context.Context, credInfo CredentialInfo, options CredentialOpOptions) (*credentials.Credentials, error) {
glcm := GetLifecycleMgr()
Expand Down
2 changes: 0 additions & 2 deletions common/genericResourceURLParts.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,6 @@ func (g *GenericResourceURLParts) String() string {
default:
panic(fmt.Sprintf("%s is an invalid location for GenericResourceURLParts", g.location))
}

return ""
}

func (g *GenericResourceURLParts) URL() url.URL {
Expand Down
20 changes: 0 additions & 20 deletions ste/mgr-JobPartMgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,19 +130,6 @@ func (d *dialRateLimiter) DialContext(ctx context.Context, network, address stri
return d.dialer.DialContext(ctx, network, address)
}

// newAzcopyHTTPClientFactory returns a pipeline.Factory whose policy sends each
// request through the supplied http.Client. A failure from the client is
// wrapped with pipeline.NewError using the message "HTTP request failed"; the
// client's response is always wrapped with pipeline.NewHTTPResponse.
func newAzcopyHTTPClientFactory(pipelineHTTPClient *http.Client) pipeline.Factory {
	send := func(ctx context.Context, request pipeline.Request) (pipeline.Response, error) {
		resp, err := pipelineHTTPClient.Do(request.WithContext(ctx))
		if err != nil {
			err = pipeline.NewError(err, "HTTP request failed")
		}
		return pipeline.NewHTTPResponse(resp), err
	}

	// The next policy and the policy options are not used by this policy.
	return pipeline.FactoryFunc(func(_ pipeline.Policy, _ *pipeline.PolicyOptions) pipeline.PolicyFunc {
		return send
	})
}

func NewClientOptions(retry policy.RetryOptions, telemetry policy.TelemetryOptions, transport policy.Transporter, statsAcc *PipelineNetworkStats, log LogOptions, trailingDot *common.TrailingDotOption, from *common.Location) azcore.ClientOptions {
// Pipeline will look like
// [includeResponsePolicy, newAPIVersionPolicy (ignored), NewTelemetryPolicy, perCall, NewRetryPolicy, perRetry, NewLogPolicy, httpHeaderPolicy, bodyDownloadPolicy]
Expand Down Expand Up @@ -231,15 +218,8 @@ type jobPartMgr struct {
exclusiveDestinationMap *common.ExclusiveStringMap

pipeline pipeline.Pipeline // ordered list of Factory objects and an object implementing the HTTPSender interface
// Currently, this only sees use in ADLSG2->ADLSG2 ACL transfers. TODO: Remove it when we can reliably get/set ACLs on blob.
secondaryPipeline pipeline.Pipeline

sourceProviderPipeline pipeline.Pipeline
// TODO: Ditto
secondarySourceProviderPipeline pipeline.Pipeline

// used defensively to protect double init
atomicPipelinesInitedIndicator uint32

// numberOfTransfersDone_doNotUse represents the number of transfer of JobPartOrder
// which are either completed or failed
Expand Down
8 changes: 8 additions & 0 deletions ste/sender-pageBlobFromURL.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ package ste

import (
"context"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/pageblob"
Expand Down Expand Up @@ -152,6 +154,12 @@ func newPageRangeOptimizer(srcPageBlobClient *pageblob.Client, ctx context.Conte
return &pageRangeOptimizer{srcPageBlobClient: srcPageBlobClient, ctx: ctx}
}

// withNoRetryForBlob returns a context that contains a marker to say we don't want any retries to happen.
// Is only implemented for blob pipelines at present.
//
// In azcore's policy.RetryOptions, MaxRetries counts retries after the first
// attempt and a value of zero selects the default, so a negative value is
// used to request a single try with no retries.
func withNoRetryForBlob(ctx context.Context) context.Context {
	return runtime.WithRetryOptions(ctx, policy.RetryOptions{MaxRetries: -1})
}

func (p *pageRangeOptimizer) fetchPages() {
// don't fetch page blob list if optimizations are not desired,
// the lack of page list indicates that there's data everywhere
Expand Down
Loading

0 comments on commit a5fe78e

Please sign in to comment.