Skip to content

Commit

Permalink
fix: trim blob SAS URL read from mounted secret file
Browse files Browse the repository at this point in the history
Signed-off-by: Qingchuan Hao <[email protected]>
  • Loading branch information
mainred committed May 29, 2024
1 parent 98751d6 commit 9464e6f
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 11 deletions.
31 changes: 23 additions & 8 deletions pkg/capture/outputlocation/blob.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (bu *BlobUpload) Name() string {
}

func (bu *BlobUpload) Enabled() bool {
_, err := readBlobURL()
_, err := readBlobSASURL()
if err != nil {
bu.l.Debug("Output location is not enabled", zap.String("location", bu.Name()))
return false
Expand All @@ -44,13 +44,13 @@ func (bu *BlobUpload) Enabled() bool {

func (bu *BlobUpload) Output(srcFilePath string) error {
bu.l.Info("Upload capture file to blob.", zap.String("location", bu.Name()))
blobURL, err := readBlobURL()
blobURL, err := readBlobSASURL()
if err != nil {
bu.l.Error("Failed to read blob url", zap.Error(err))
return err
}

if err = validateBlobURL(blobURL); err != nil {
if err = validateBlobSASURL(blobURL); err != nil {
bu.l.Error("Failed to validate blob url", zap.Error(err))
return err
}
Expand Down Expand Up @@ -85,7 +85,18 @@ func (bu *BlobUpload) Output(srcFilePath string) error {
return nil
}

func readBlobURL() (string, error) {
func trimBlobSASURL(blobSASURL string) string {
// Blob SAS URL from the secret created from a file can have a newline and is surrounded by double quotes,
// so we need to trim \" and \n and trimming spaces is for unexpected spaces in the URL by customers.
// For example:
// "\"https://$storage-account-url/$container-name?$blob-sas-token\"\n"
trimedSecret := strings.Trim(blobSASURL, "\"\n")
trimedSecret = strings.TrimSpace(trimedSecret)

return trimedSecret
}

func readBlobSASURL() (string, error) {
secretPath := filepath.Join(captureConstants.CaptureOutputLocationBlobUploadSecretPath, captureConstants.CaptureOutputLocationBlobUploadSecretKey)
if runtime.GOOS == "windows" {
containerSandboxMountPoint := os.Getenv(captureConstants.ContainerSandboxMountPointEnvKey)
Expand All @@ -95,19 +106,23 @@ func readBlobURL() (string, error) {
secretPath = filepath.Join(containerSandboxMountPoint, captureConstants.CaptureOutputLocationBlobUploadSecretPath, captureConstants.CaptureOutputLocationBlobUploadSecretKey)
}
secretBytes, err := os.ReadFile(secretPath)
return string(secretBytes), err
if err != nil {
return "", fmt.Errorf("failed to read file %s: %w", secretPath, err)
}
secretStr := string(secretBytes)
return trimBlobSASURL(secretStr), nil
}

func validateBlobURL(blobURL string) error {
u, err := url.Parse(blobURL)
func validateBlobSASURL(blobSASURL string) error {
u, err := url.Parse(blobSASURL)
if err != nil {
return err
}

// Split the path into storage account container and blob
path := strings.TrimPrefix(u.Path, "/")
if path == "" {
return fmt.Errorf("invalid blob URL")
return fmt.Errorf("invalid blob SAS URL") //nolint:goerr113

Check failure on line 125 in pkg/capture/outputlocation/blob.go

View workflow job for this annotation

GitHub Actions / Lint (linux, amd64)

whyNoLint: include an explanation for nolint directive (gocritic)

Check failure on line 125 in pkg/capture/outputlocation/blob.go

View workflow job for this annotation

GitHub Actions / Lint (linux, arm64)

whyNoLint: include an explanation for nolint directive (gocritic)

Check failure on line 125 in pkg/capture/outputlocation/blob.go

View workflow job for this annotation

GitHub Actions / Lint (windows, amd64)

whyNoLint: include an explanation for nolint directive (gocritic)

Check failure on line 125 in pkg/capture/outputlocation/blob.go

View workflow job for this annotation

GitHub Actions / Lint (windows, arm64)

whyNoLint: include an explanation for nolint directive (gocritic)
}

return nil
Expand Down
44 changes: 41 additions & 3 deletions pkg/capture/outputlocation/blob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,53 @@ import (
"testing"
)

func TestValidateBlobURL(t *testing.T) {
func TestTrimBlobSASURL(t *testing.T) {
tests := []struct {
name string
inputURL string
expectedTrimmedURL string
}{
{
name: "valid input URL with sas token that have a newline and is surrounded by double quotes",
inputURL: "\"https://retina.blob.core.windows.net/container/blob?sas-token\"\n",
expectedTrimmedURL: "https://retina.blob.core.windows.net/container/blob?sas-token",
},
{
name: "valid input URL with sas token that have a newline and is surrounded by double quotes and extra spaces",
inputURL: "\"https://retina.blob.core.windows.net/container/blob?sas-token \"\n",
expectedTrimmedURL: "https://retina.blob.core.windows.net/container/blob?sas-token",
},
{
name: "valid input URL with sas token that has extra spaces",
inputURL: "https://retina.blob.core.windows.net/container/blob?sas-token ",
expectedTrimmedURL: "https://retina.blob.core.windows.net/container/blob?sas-token",
},
{
name: "valid input URL with sas token",
inputURL: "\"https://retina.blob.core.windows.net/container/blob?sas-token\"\n",
expectedTrimmedURL: "https://retina.blob.core.windows.net/container/blob?sas-token",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actualTrimmedBlobSASURL := trimBlobSASURL(tt.inputURL)
if actualTrimmedBlobSASURL != tt.expectedTrimmedURL {
t.Errorf("Expected trimmed Blob SAS URL %s, but got %s", tt.expectedTrimmedURL, actualTrimmedBlobSASURL)
}
})
}
}

func TestValidateBlobSASURL(t *testing.T) {
tests := []struct {
name string
inputURL string
expectedError error
}{
{
name: "valid input URL with sas token",
inputURL: "https://retina.blob.core.windows.net/container/blob?sp=r&st=2023-02-17T19:13:30Z&se=2023-02-18T03:13:30Z&spr=https&sv=2021-06-08&sr=c&sig=NtSxlRK5Vs4kVs1dIOfr%2FMdLKBVTA4t3uJ0gqLZ9exk%3D",
inputURL: "https://retina.blob.core.windows.net/container/blob?sas-token",
expectedError: nil,
},
{
Expand All @@ -44,7 +82,7 @@ func TestValidateBlobURL(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateBlobURL(tt.inputURL)
err := validateBlobSASURL(tt.inputURL)

if err != nil && tt.expectedError == nil {
t.Errorf("Unexpected error: %v", err)
Expand Down

0 comments on commit 9464e6f

Please sign in to comment.