Skip to content

Commit 5a146f9

Browse files
committed
fix: use wildcard DNS records
Signed-off-by: Binbin Li <[email protected]>
1 parent 4732249 commit 5a146f9

File tree

5 files changed

+53
-55
lines changed

5 files changed

+53
-55
lines changed

pkg/common/oras/authprovider/azure/azureidentity.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import (
2020
"encoding/json"
2121
"fmt"
2222
"os"
23-
"regexp"
2423
"time"
2524

2625
re "github.com/ratify-project/ratify/errors"
@@ -68,13 +67,13 @@ type MIAuthProvider struct {
6867
authClientFactory AuthClientFactory
6968
registryHostGetter RegistryHostGetter
7069
getManagedIdentityToken ManagedIdentityTokenGetter
71-
hostPredicates []*regexp.Regexp
70+
endpoints []string
7271
}
7372

7473
type azureManagedIdentityAuthProviderConf struct {
7574
Name string `json:"name"`
7675
ClientID string `json:"clientID"`
77-
HostScope []string `json:"hostScope,omitempty"`
76+
Endpoints []string `json:"endpoints,omitempty"`
7877
}
7978

8079
const (
@@ -113,9 +112,12 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider
113112
return nil, err
114113
}
115114

116-
hostPredicates, err := parseHostScopeToPredicates(conf.HostScope)
117-
if err != nil {
118-
return nil, re.ErrorCodeConfigInvalid.WithError(err)
115+
if len(conf.Endpoints) == 0 {
116+
conf.Endpoints = []string{defaultACREndpoint}
117+
} else {
118+
if err := validateHostScope(conf.Endpoints); err != nil {
119+
return nil, re.ErrorCodeConfigInvalid.WithError(err)
120+
}
119121
}
120122

121123
// retrieve an AAD Access token
@@ -130,7 +132,7 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider
130132
tenantID: tenant,
131133
authClientFactory: &defaultAuthClientFactoryImpl{}, // Concrete implementation
132134
getManagedIdentityToken: &defaultManagedIdentityTokenGetterImpl{}, // Concrete implementation
133-
hostPredicates: hostPredicates,
135+
endpoints: conf.Endpoints,
134136
}, nil
135137
}
136138

@@ -165,7 +167,7 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider
165167
return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider)
166168
}
167169

168-
if err := validateHost(artifactHostName, d.hostPredicates); err != nil {
170+
if err := validateHost(artifactHostName, d.endpoints); err != nil {
169171
return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithError(err)
170172
}
171173

pkg/common/oras/authprovider/azure/azureidentity_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) {
172172
authClientFactory: mockAuthClientFactory,
173173
registryHostGetter: mockRegistryHostGetter,
174174
getManagedIdentityToken: mockManagedIdentityTokenGetter,
175-
hostPredicates: hostPredicates,
175+
endpoints: hostPredicates,
176176
}
177177

178178
// Call Provide method
@@ -210,7 +210,7 @@ func TestMIAuthProvider_Provide_TokenRefreshFailure(t *testing.T) {
210210
authClientFactory: mockAuthClientFactory,
211211
registryHostGetter: mockRegistryHostGetter,
212212
getManagedIdentityToken: mockManagedIdentityTokenGetter,
213-
hostPredicates: hostPredicates,
213+
endpoints: hostPredicates,
214214
}
215215

216216
// Call Provide method

pkg/common/oras/authprovider/azure/azureworkloadidentity.go

Lines changed: 35 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import (
2020
"encoding/json"
2121
"fmt"
2222
"os"
23-
"regexp"
2423
"strings"
2524
"time"
2625

@@ -75,13 +74,13 @@ type WIAuthProvider struct {
7574
registryHostGetter RegistryHostGetter
7675
getAADAccessToken AADAccessTokenGetter
7776
reportMetrics MetricsReporter
78-
hostPredicates []*regexp.Regexp
77+
endpoints []string
7978
}
8079

8180
type azureWIAuthProviderConf struct {
8281
Name string `json:"name"`
8382
ClientID string `json:"clientID,omitempty"`
84-
HostScope []string `json:"hostScope,omitempty"`
83+
Endpoints []string `json:"endpoints,omitempty"`
8584
}
8685

8786
const (
@@ -118,9 +117,12 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider
118117
}
119118
}
120119

121-
hostPredicates, err := parseHostScopeToPredicates(conf.HostScope)
122-
if err != nil {
123-
return nil, re.ErrorCodeConfigInvalid.WithError(err)
120+
if len(conf.Endpoints) == 0 {
121+
conf.Endpoints = []string{defaultACREndpoint}
122+
} else {
123+
if err := validateHostScope(conf.Endpoints); err != nil {
124+
return nil, re.ErrorCodeConfigInvalid.WithError(err)
125+
}
124126
}
125127

126128
// retrieve an AAD Access token
@@ -137,7 +139,7 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider
137139
registryHostGetter: &defaultRegistryHostGetterImpl{}, // Concrete implementation
138140
getAADAccessToken: &defaultAADAccessTokenGetterImpl{}, // Concrete implementation
139141
reportMetrics: &defaultMetricsReporterImpl{},
140-
hostPredicates: hostPredicates,
142+
endpoints: conf.Endpoints,
141143
}, nil
142144
}
143145

@@ -168,7 +170,7 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider
168170
return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider)
169171
}
170172

171-
if err := validateHost(artifactHostName, d.hostPredicates); err != nil {
173+
if err := validateHost(artifactHostName, d.endpoints); err != nil {
172174
return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithError(err)
173175
}
174176

@@ -220,30 +222,6 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider
220222
return authConfig, nil
221223
}
222224

223-
func parseHostScopeToPredicates(hostScope []string) ([]*regexp.Regexp, error) {
224-
if err := validateHostScope(hostScope); err != nil {
225-
return nil, err
226-
}
227-
228-
var predicates []*regexp.Regexp
229-
if len(hostScope) == 0 {
230-
re, err := regexp.Compile("^" + defaultHostScope + "$")
231-
if err != nil {
232-
return nil, fmt.Errorf("failed to compile default host scope regex: %w", err)
233-
}
234-
predicates = append(predicates, re)
235-
} else {
236-
for _, scope := range hostScope {
237-
re, err := regexp.Compile("^" + strings.ReplaceAll(scope, "*", ".*") + "$")
238-
if err != nil {
239-
return nil, fmt.Errorf("failed to compile host scope regex: %w", err)
240-
}
241-
predicates = append(predicates, re)
242-
}
243-
}
244-
return predicates, nil
245-
}
246-
247225
// validateHostScope checks if the host scope is valid for auth provider.
248226
// A valid host is either a fully qualified domain name or a wildcard domain
249227
// name folloiwing RFC 1034.
@@ -256,20 +234,38 @@ func parseHostScopeToPredicates(hostScope []string) ([]*regexp.Regexp, error) {
256234
// - example.*
257235
// - *example.com
258236
func validateHostScope(hostScope []string) error {
259-
pattern := regexp.MustCompile(`^(\*\.)?([^*]+\.)*[^*.]+$`)
260237
for _, scope := range hostScope {
261-
if !pattern.MatchString(scope) {
262-
return fmt.Errorf("invalid host scope %s", scope)
238+
switch strings.Count(scope, "*") {
239+
case 0:
240+
continue
241+
case 1:
242+
if !strings.HasPrefix(scope, "*.") {
243+
return fmt.Errorf("invalid wildcard domain name: %s, it must start with '*.'", scope)
244+
}
245+
if len(scope) < 3 {
246+
return fmt.Errorf("invalid wildcard domain name: %s, it must have at least one character after '*.'", scope)
247+
}
248+
default:
249+
return fmt.Errorf("invalid wildcard domain name: %s, it must have at most one wildcard character", scope)
263250
}
264251
}
265252
return nil
266253
}
267254

268255
// validateHost checks if the host is in the scope of the store auth provider.
269-
func validateHost(host string, predicates []*regexp.Regexp) error {
270-
for _, scope := range predicates {
271-
if scope.MatchString(host) {
272-
return nil
256+
func validateHost(host string, endpoints []string) error {
257+
for _, endpoint := range endpoints {
258+
switch strings.Count(endpoint, "*") {
259+
case 0:
260+
if host == endpoint {
261+
return nil
262+
}
263+
case 1:
264+
if strings.HasSuffix(host, strings.TrimPrefix(endpoint, "*")) {
265+
return nil
266+
}
267+
default:
268+
continue
273269
}
274270
}
275271
return fmt.Errorf("the artifact host %s is not in the scope of the store auth provider", host)

pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func TestWIAuthProvider_Provide_Success(t *testing.T) {
8787
registryHostGetter: mockRegistryHostGetter,
8888
getAADAccessToken: mockAADAccessTokenGetter,
8989
reportMetrics: mockMetricsReporter,
90-
hostPredicates: hostPredicates,
90+
endpoints: hostPredicates,
9191
}
9292

9393
// Call Provide method
@@ -137,7 +137,7 @@ func TestWIAuthProvider_Provide_RefreshToken(t *testing.T) {
137137
registryHostGetter: mockRegistryHostGetter,
138138
getAADAccessToken: mockAADAccessTokenGetter,
139139
reportMetrics: mockMetricsReporter,
140-
hostPredicates: hostPredicates,
140+
endpoints: hostPredicates,
141141
}
142142

143143
// Call Provide method
@@ -178,7 +178,7 @@ func TestWIAuthProvider_Provide_AADTokenFailure(t *testing.T) {
178178
registryHostGetter: mockRegistryHostGetter,
179179
getAADAccessToken: mockAADAccessTokenGetter,
180180
reportMetrics: mockMetricsReporter,
181-
hostPredicates: hostPredicates,
181+
endpoints: hostPredicates,
182182
}
183183

184184
// Call Provide method
@@ -261,7 +261,7 @@ func TestWIAuthProvider_Provide_TokenRefresh_Success(t *testing.T) {
261261
registryHostGetter: mockRegistryHostGetter,
262262
getAADAccessToken: mockAADAccessTokenGetter,
263263
reportMetrics: mockMetricsReporter,
264-
hostPredicates: hostPredicates,
264+
endpoints: hostPredicates,
265265
}
266266

267267
// Call Provide method
@@ -302,7 +302,7 @@ func TestWIAuthProvider_Provide_TokenRefreshFailure(t *testing.T) {
302302
registryHostGetter: mockRegistryHostGetter,
303303
getAADAccessToken: mockAADAccessTokenGetter,
304304
reportMetrics: mockMetricsReporter,
305-
hostPredicates: hostPredicates,
305+
endpoints: hostPredicates,
306306
}
307307

308308
// Call Provide method

pkg/common/oras/authprovider/azure/const.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ const (
2525
dockerTokenLoginUsernameGUID = "00000000-0000-0000-0000-000000000000"
2626
AADResource = "https://containerregistry.azure.net/.default"
2727
defaultACRExpiryDuration time.Duration = 3 * time.Hour
28-
defaultHostScope = ".*.azurecr.io"
28+
defaultACREndpoint = ".*.azurecr.io"
2929
)
3030

3131
var logOpt = logger.Option{

0 commit comments

Comments
 (0)