Skip to content

Commit

Permalink
Add a flag --registry-mirror
Browse files Browse the repository at this point in the history
With registry mirror, e.g. --registry-mirror=mcr.microsoft.com:xxx.azurecr.io
credential provider will return credentials of xxx.azurecr.io for
mcr.microsoft.com. In this way, an image with URL prefix mcr.microsoft.com,
but actually in xxx.azurecr.io, can be successfully pulled.

Signed-off-by: Zhecheng Li <[email protected]>
  • Loading branch information
lzhecheng authored and k8s-infra-cherrypick-robot committed Oct 16, 2024
1 parent 1a58862 commit 70732ba
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 25 deletions.
9 changes: 8 additions & 1 deletion cmd/acr-credential-provider/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ import (

func main() {
rand.Seed(time.Now().UnixNano())

var RegistryMirrorStr string

command := &cobra.Command{
Use: "acr-credential-provider configFile",
Short: "Acr credential provider for Kubelet",
Expand All @@ -44,7 +47,7 @@ func main() {
os.Exit(1)
}

acrProvider, err := credentialprovider.NewAcrProviderFromConfig(args[0])
acrProvider, err := credentialprovider.NewAcrProviderFromConfig(args[0], RegistryMirrorStr)
if err != nil {
klog.Errorf("Failed to initialize ACR provider: %v", err)
os.Exit(1)
Expand All @@ -60,6 +63,10 @@ func main() {
logs.InitLogs()
defer logs.FlushLogs()

// Flags
command.Flags().StringVarP(&RegistryMirrorStr, "registry-mirror", "r", "",
"Mirror a source registry host to a target registry host, and image pull credential will be requested to the target registry host when the image is from source registry host")

if err := command.Execute(); err != nil {
os.Exit(1)
}
Expand Down
2 changes: 2 additions & 0 deletions examples/out-of-tree/credential-provider-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@ providers:
- "*.azurecr.cn"
- "*.azurecr.de"
- "*.azurecr.us"
- "mcr.microsoft.com"
args:
- /etc/kubernetes/azure.json
- --registry-mirror=mcr.microsoft.com:xxx.azurecr.io
78 changes: 58 additions & 20 deletions pkg/credentialprovider/azure_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ type CredentialProvider interface {

// acrProvider implements the credential provider interface for Azure Container Registry.
type acrProvider struct {
config *providerconfig.AzureAuthConfig
environment *azclient.Environment
credential azcore.TokenCredential
config *providerconfig.AzureAuthConfig
environment *azclient.Environment
credential azcore.TokenCredential
registryMirror map[string]string // Registry mirror relation: source registry -> target registry
}

func NewAcrProvider(config *providerconfig.AzureAuthConfig, environment *azclient.Environment, credential azcore.TokenCredential) CredentialProvider {
Expand All @@ -73,7 +74,7 @@ func NewAcrProvider(config *providerconfig.AzureAuthConfig, environment *azclien
}

// NewAcrProvider creates a new instance of the ACR provider.
func NewAcrProviderFromConfig(configFile string) (CredentialProvider, error) {
func NewAcrProviderFromConfig(configFile string, registryMirrorStr string) (CredentialProvider, error) {
if len(configFile) == 0 {
return nil, errors.New("no azure credential file is provided")
}
Expand Down Expand Up @@ -118,15 +119,16 @@ func NewAcrProviderFromConfig(configFile string) (CredentialProvider, error) {
}

return &acrProvider{
config: config,
credential: managedIdentityCredential,
environment: &envConfig,
config: config,
credential: managedIdentityCredential,
environment: &envConfig,
registryMirror: parseRegistryMirror(registryMirrorStr),
}, nil
}

func (a *acrProvider) GetCredentials(ctx context.Context, image string, _ []string) (*v1.CredentialProviderResponse, error) {
loginServer := a.parseACRLoginServerFromImage(image)
if loginServer == "" {
targetloginServer, sourceloginServer := a.parseACRLoginServerFromImage(image)
if targetloginServer == "" {
klog.V(2).Infof("image(%s) is not from ACR, return empty authentication", image)
return &v1.CredentialProviderResponse{
CacheKeyType: v1.RegistryPluginCacheKeyType,
Expand All @@ -148,16 +150,20 @@ func (a *acrProvider) GetCredentials(ctx context.Context, image string, _ []stri
}

if a.config.UseManagedIdentityExtension {
username, password, err := a.getFromACR(ctx, loginServer)
username, password, err := a.getFromACR(ctx, targetloginServer)
if err != nil {
klog.Errorf("error getting credentials from ACR for %s: %s", loginServer, err)
klog.Errorf("error getting credentials from ACR for %s: %s", targetloginServer, err)
return nil, err
}

response.Auth[loginServer] = v1.AuthConfig{
authConfig := v1.AuthConfig{
Username: username,
Password: password,
}
response.Auth[targetloginServer] = authConfig
if sourceloginServer != "" {
response.Auth[sourceloginServer] = authConfig
}
} else {
// Add our entry for each of the supported container registry URLs
for _, url := range containerRegistryUrls {
Expand Down Expand Up @@ -227,27 +233,59 @@ func (a *acrProvider) getFromACR(ctx context.Context, loginServer string) (strin
return dockerTokenLoginUsernameGUID, registryRefreshToken, nil
}

// parseACRLoginServerFromImage takes image as parameter and returns login server of it.
// Parameter `image` is expected in following format: foo.azurecr.io/bar/imageName:version
// parseACRLoginServerFromImage inputs an image URL and outputs login servers of target registry and source registry if --registry-mirror is set.
// Input is expected in following format: foo.azurecr.io/bar/imageName:version
// If the provided image is not an acr image, this function will return an empty string.
func (a *acrProvider) parseACRLoginServerFromImage(image string) string {
match := acrRE.FindAllString(image, -1)
func (a *acrProvider) parseACRLoginServerFromImage(image string) (string, string) {
targetImage, sourceRegistry := a.processImageWithRegistryMirror(image)

match := acrRE.FindAllString(targetImage, -1)
if len(match) == 1 {
return match[0]
targetRegistry := match[0]
return targetRegistry, sourceRegistry
}

// handle the custom cloud case
if a != nil && a.environment != nil {
cloudAcrSuffix := a.environment.ContainerRegistryDNSSuffix
cloudAcrSuffixLength := len(cloudAcrSuffix)
if cloudAcrSuffixLength > 0 {
customAcrSuffixIndex := strings.Index(image, cloudAcrSuffix)
customAcrSuffixIndex := strings.Index(targetImage, cloudAcrSuffix)
if customAcrSuffixIndex != -1 {
endIndex := customAcrSuffixIndex + cloudAcrSuffixLength
return image[0:endIndex]
return targetImage[0:endIndex], sourceRegistry
}
}
}

return ""
return "", ""
}

// With acrProvider registry mirror, e.g. {"mcr.microsoft.com": "abc.azurecr.io"}
// processImageWithRegistryMirror input format: "mcr.microsoft.com/bar/image:version"
// output format: "abc.azurecr.io/bar/image:version", "mcr.microsoft.com"
func (a *acrProvider) processImageWithRegistryMirror(image string) (string, string) {
for sourceRegistry, targetRegistry := range a.registryMirror {
if strings.HasPrefix(image, sourceRegistry) {
return strings.Replace(image, sourceRegistry, targetRegistry, 1), sourceRegistry
}
}
return image, ""
}

// parseRegistryMirror input format: "--registry-mirror=aaa:bbb,ccc:ddd"
// output format: map[string]string{"aaa": "bbb", "ccc": "ddd"}
func parseRegistryMirror(registryMirrorStr string) map[string]string {
registryMirror := map[string]string{}

registryMirrorStr = strings.ReplaceAll(registryMirrorStr, " ", "")
for _, mapping := range strings.Split(registryMirrorStr, ",") {
parts := strings.Split(mapping, ":")
if len(parts) != 2 {
klog.Errorf("Invalid registry mirror format: %s", mapping)
continue
}
registryMirror[parts[0]] = parts[1]
}
return registryMirror
}
96 changes: 92 additions & 4 deletions pkg/credentialprovider/azure_credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"net/http"
"net/http/httptest"
"os"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -137,7 +138,7 @@ func TestGetCredentialsConfig(t *testing.T) {
if err != nil {
t.Fatalf("Unexpected error when closing temp file: %v", err)
}
provider, err := NewAcrProviderFromConfig(configFile.Name())
provider, err := NewAcrProviderFromConfig(configFile.Name(), "")
if err != nil && !test.expectError {
t.Fatalf("Unexpected error when creating new acr provider: %v", err)
}
Expand All @@ -163,6 +164,53 @@ func TestGetCredentialsConfig(t *testing.T) {
}
}

func TestProcessImageWithMirrorMapping(t *testing.T) {
configStr := `
{
"aadClientId": "foo",
"aadClientSecret": "bar"
}`

configFile, err := os.CreateTemp(".", "config.json")
assert.Nilf(t, err, "Unexpected error when creating temp file")
defer os.Remove(configFile.Name())
_, err = configFile.WriteString(configStr)
assert.Nilf(t, err, "Unexpected error when writing to temp file")
assert.Nilf(t, configFile.Close(), "Unexpected error when closing temp file")

provider, err := NewAcrProviderFromConfig(configFile.Name(), "mcr.microsoft.com:abc.azurecr.io")
assert.Nilf(t, err, "Unexpected error when creating new acr provider")
acrProvider := provider.(*acrProvider)

testcases := []struct {
description string
image string
expectedLoginServer string
expectedLoginServerMirror string
}{
{
description: "image in registry mirror map",
image: "mcr.microsoft.com/bar/image:version",
expectedLoginServer: "abc.azurecr.io",
expectedLoginServerMirror: "mcr.microsoft.com",
},
{
description: "image not in registry mirror map",
image: "foo.azurecr.io/bar/image:version",
expectedLoginServer: "foo.azurecr.io",
expectedLoginServerMirror: "",
},
}

for _, test := range testcases {
t.Run(test.description, func(t *testing.T) {
targetloginServer, sourceloginServer := acrProvider.parseACRLoginServerFromImage(test.image)
assert.Equal(t, targetloginServer, test.expectedLoginServer)
assert.Equal(t, sourceloginServer, test.expectedLoginServerMirror)
})
}
}

func TestParseACRLoginServerFromImage(t *testing.T) {

providerInterface := NewAcrProvider(&config.AzureAuthConfig{
Expand Down Expand Up @@ -215,8 +263,48 @@ func TestParseACRLoginServerFromImage(t *testing.T) {
},
}
for _, test := range tests {
if loginServer := provider.parseACRLoginServerFromImage(test.image); loginServer != test.expected {
t.Errorf("function parseACRLoginServerFromImage returns \"%s\" for image %s, expected \"%s\"", loginServer, test.image, test.expected)
}
t.Run(test.image, func(t *testing.T) {
targetloginServer, _ := provider.parseACRLoginServerFromImage(test.image)
assert.Equal(t, targetloginServer, test.expected)
})
}
}

func TestProcessMirrorMapping(t *testing.T) {
testcases := []struct {
description string
mirrorMappingStr string
expected map[string]string
}{
{
"multiple",
"aaa:bbb,ccc:ddd",
map[string]string{
"aaa": "bbb",
"ccc": "ddd",
},
},
{
"multiple with some spaces",
"aaa: bbb, ccc:ddd",
map[string]string{
"aaa": "bbb",
"ccc": "ddd",
},
},
{
"single",
"aaa:bbb",
map[string]string{
"aaa": "bbb",
},
},
}

for _, tc := range testcases {
t.Run(tc.description, func(t *testing.T) {
result := parseRegistryMirror(tc.mirrorMappingStr)
assert.True(t, reflect.DeepEqual(result, tc.expected))
})
}
}

0 comments on commit 70732ba

Please sign in to comment.