Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions pkg/common/oras/authprovider/azure/azureidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,13 @@
authClientFactory AuthClientFactory
registryHostGetter RegistryHostGetter
getManagedIdentityToken ManagedIdentityTokenGetter
endpoints []string
}

type azureManagedIdentityAuthProviderConf struct {
Name string `json:"name"`
ClientID string `json:"clientID"`
Name string `json:"name"`
ClientID string `json:"clientID"`
Endpoints []string `json:"endpoints,omitempty"`
}

const (
Expand Down Expand Up @@ -109,6 +111,15 @@
if err != nil {
return nil, err
}

if len(conf.Endpoints) == 0 {
conf.Endpoints = []string{defaultACREndpoint}
} else {
if err := validateEndpoints(conf.Endpoints); err != nil {
return nil, re.ErrorCodeConfigInvalid.WithError(err)
}

Check warning on line 120 in pkg/common/oras/authprovider/azure/azureidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureidentity.go#L115-L120

Added lines #L115 - L120 were not covered by tests
}

// retrieve an AAD Access token
token, err := getManagedIdentityToken(context.Background(), client, azidentity.NewManagedIdentityCredential)
if err != nil {
Expand All @@ -121,6 +132,7 @@
tenantID: tenant,
authClientFactory: &defaultAuthClientFactoryImpl{}, // Concrete implementation
getManagedIdentityToken: &defaultManagedIdentityTokenGetterImpl{}, // Concrete implementation
endpoints: conf.Endpoints,

Check warning on line 135 in pkg/common/oras/authprovider/azure/azureidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureidentity.go#L135

Added line #L135 was not covered by tests
}, nil
}

Expand Down Expand Up @@ -155,6 +167,10 @@
return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider)
}

if err := validateHost(artifactHostName, d.endpoints); err != nil {
return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithError(err)
}

Check warning on line 172 in pkg/common/oras/authprovider/azure/azureidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureidentity.go#L171-L172

Added lines #L171 - L172 were not covered by tests

// need to refresh AAD token if it's expired
if time.Now().Add(time.Minute * 5).After(d.identityToken.ExpiresOn) {
newToken, err := d.getManagedIdentityToken.GetManagedIdentityToken(ctx, d.clientID)
Expand Down
2 changes: 2 additions & 0 deletions pkg/common/oras/authprovider/azure/azureidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) {
authClientFactory: mockAuthClientFactory,
registryHostGetter: mockRegistryHostGetter,
getManagedIdentityToken: mockManagedIdentityTokenGetter,
endpoints: []string{"example.azurecr.io"},
}

// Call Provide method
Expand Down Expand Up @@ -200,6 +201,7 @@ func TestMIAuthProvider_Provide_TokenRefreshFailure(t *testing.T) {
authClientFactory: mockAuthClientFactory,
registryHostGetter: mockRegistryHostGetter,
getManagedIdentityToken: mockManagedIdentityTokenGetter,
endpoints: []string{"example.azurecr.io"},
}

// Call Provide method
Expand Down
71 changes: 69 additions & 2 deletions pkg/common/oras/authprovider/azure/azureworkloadidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
"time"

azcontainerregistry "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry"
Expand Down Expand Up @@ -72,11 +74,13 @@
registryHostGetter RegistryHostGetter
getAADAccessToken AADAccessTokenGetter
reportMetrics MetricsReporter
endpoints []string
}

type azureWIAuthProviderConf struct {
Name string `json:"name"`
ClientID string `json:"clientID,omitempty"`
Name string `json:"name"`
ClientID string `json:"clientID,omitempty"`
Endpoints []string `json:"endpoints,omitempty"`
}

const (
Expand Down Expand Up @@ -113,6 +117,14 @@
}
}

if len(conf.Endpoints) == 0 {
conf.Endpoints = []string{defaultACREndpoint}
} else {
if err := validateEndpoints(conf.Endpoints); err != nil {
return nil, re.ErrorCodeConfigInvalid.WithError(err)
}

Check warning on line 125 in pkg/common/oras/authprovider/azure/azureworkloadidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureworkloadidentity.go#L123-L125

Added lines #L123 - L125 were not covered by tests
}

// retrieve an AAD Access token
token, err := defaultGetAADAccessToken(context.Background(), tenant, clientID, AADResource)
if err != nil {
Expand All @@ -127,6 +139,7 @@
registryHostGetter: &defaultRegistryHostGetterImpl{}, // Concrete implementation
getAADAccessToken: &defaultAADAccessTokenGetterImpl{}, // Concrete implementation
reportMetrics: &defaultMetricsReporterImpl{},
endpoints: conf.Endpoints,

Check warning on line 142 in pkg/common/oras/authprovider/azure/azureworkloadidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureworkloadidentity.go#L142

Added line #L142 was not covered by tests
}, nil
}

Expand Down Expand Up @@ -157,6 +170,10 @@
return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider)
}

if err := validateHost(artifactHostName, d.endpoints); err != nil {
return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithError(err)
}

Check warning on line 175 in pkg/common/oras/authprovider/azure/azureworkloadidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureworkloadidentity.go#L174-L175

Added lines #L174 - L175 were not covered by tests

// need to refresh AAD token if it's expired
if time.Now().Add(time.Minute * 5).After(d.aadToken.ExpiresOn) {
newToken, err := d.getAADAccessToken.GetAADAccessToken(ctx, d.tenantID, d.clientID, AADResource)
Expand Down Expand Up @@ -205,6 +222,56 @@
return authConfig, nil
}

// validateEndpoints checks if the endpoints are valid for auth provider.
// A valid endpoint is either a fully qualified domain name or a wildcard domain
// name folloiwing RFC 1034.
// Valid examples:
// - *.example.com
// - example.com
//
// Invalid examples:
// - *
// - example.*
// - *example.com.*
func validateEndpoints(endpoints []string) error {
for _, endpoint := range endpoints {
switch strings.Count(endpoint, "*") {
case 0:
continue
case 1:
if !strings.HasPrefix(endpoint, "*.") {
return fmt.Errorf("invalid wildcard domain name: %s, it must start with '*.'", endpoint)
}
if len(endpoint) < 3 {
return fmt.Errorf("invalid wildcard domain name: %s, it must have at least one character after '*.'", endpoint)
}
default:
return fmt.Errorf("invalid wildcard domain name: %s, it must have at most one wildcard character", endpoint)

Check warning on line 249 in pkg/common/oras/authprovider/azure/azureworkloadidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureworkloadidentity.go#L236-L249

Added lines #L236 - L249 were not covered by tests
}
}
return nil

Check warning on line 252 in pkg/common/oras/authprovider/azure/azureworkloadidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureworkloadidentity.go#L252

Added line #L252 was not covered by tests
}

// validateHost checks if the host is matching endpoints supported by the auth
// provider.
func validateHost(host string, endpoints []string) error {
for _, endpoint := range endpoints {
switch strings.Count(endpoint, "*") {
case 0:
if host == endpoint {
return nil
}
case 1:
if strings.HasSuffix(host, strings.TrimPrefix(endpoint, "*")) {
return nil
}
default:
continue

Check warning on line 269 in pkg/common/oras/authprovider/azure/azureworkloadidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureworkloadidentity.go#L264-L269

Added lines #L264 - L269 were not covered by tests
}
}
return fmt.Errorf("the artifact host %s is not in the scope of the store auth provider", host)

Check warning on line 272 in pkg/common/oras/authprovider/azure/azureworkloadidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureworkloadidentity.go#L272

Added line #L272 was not covered by tests
}

// Compare addExpiry with default ACR refresh token expiry
func getACRExpiryIfEarlier(aadExpiry time.Time) time.Time {
// set default refresh token expiry to default ACR expiry - 5 minutes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func TestWIAuthProvider_Provide_Success(t *testing.T) {
registryHostGetter: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
endpoints: []string{"example.azurecr.io"},
}

// Call Provide method
Expand Down Expand Up @@ -126,6 +127,7 @@ func TestWIAuthProvider_Provide_RefreshToken(t *testing.T) {
registryHostGetter: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
endpoints: []string{"example.azurecr.io"},
}

// Call Provide method
Expand Down Expand Up @@ -161,6 +163,7 @@ func TestWIAuthProvider_Provide_AADTokenFailure(t *testing.T) {
registryHostGetter: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
endpoints: []string{"example.azurecr.io"},
}

// Call Provide method
Expand Down Expand Up @@ -238,6 +241,7 @@ func TestWIAuthProvider_Provide_TokenRefresh_Success(t *testing.T) {
registryHostGetter: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
endpoints: []string{"example.azurecr.io"},
}

// Call Provide method
Expand Down Expand Up @@ -273,6 +277,7 @@ func TestWIAuthProvider_Provide_TokenRefreshFailure(t *testing.T) {
registryHostGetter: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
endpoints: []string{"example.azurecr.io"},
}

// Call Provide method
Expand Down
1 change: 1 addition & 0 deletions pkg/common/oras/authprovider/azure/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const (
dockerTokenLoginUsernameGUID = "00000000-0000-0000-0000-000000000000"
AADResource = "https://containerregistry.azure.net/.default"
defaultACRExpiryDuration time.Duration = 3 * time.Hour
defaultACREndpoint = ".*.azurecr.io"
)

var logOpt = logger.Option{
Expand Down
Loading