Skip to content

Commit b21a1fa

Browse files
authored
fix(internal): support internaloption.WithDefaultUniverseDomain (#2373)
1 parent ddb3a12 commit b21a1fa

File tree

4 files changed

+86
-43
lines changed

4 files changed

+86
-43
lines changed

internal/cba.go

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func getClientCertificateSourceAndEndpoint(settings *DialSettings) (cert.Source,
8080
// if settings.DefaultEndpointTemplate == "" {
8181
// return nil, "", errors.New("internaloption.WithDefaultEndpointTemplate is required if option.WithUniverseDomain is not googleapis.com")
8282
// }
83-
endpoint = strings.Replace(settings.DefaultEndpointTemplate, universeDomainPlaceholder, settings.GetUniverseDomain(), 1)
83+
endpoint = resolvedDefaultEndpoint(settings)
8484
}
8585
return clientCertSource, endpoint, nil
8686
}
@@ -164,27 +164,41 @@ func isClientCertificateEnabled() bool {
164164
// WithDefaultEndpoint("https://foo.com/bar/baz") will return "https://myhost:8080/bar/baz"
165165
func getEndpoint(settings *DialSettings, clientCertSource cert.Source) (string, error) {
166166
if settings.Endpoint == "" {
167-
mtlsMode := getMTLSMode()
168-
if mtlsMode == mTLSModeAlways || (clientCertSource != nil && mtlsMode == mTLSModeAuto) {
167+
if isMTLS(clientCertSource) {
169168
if !settings.IsUniverseDomainGDU() {
170169
return "", errUniverseNotSupportedMTLS
171170
}
172171
return settings.DefaultMTLSEndpoint, nil
173172
}
174-
return settings.DefaultEndpoint, nil
173+
return resolvedDefaultEndpoint(settings), nil
175174
}
176175
if strings.Contains(settings.Endpoint, "://") {
177176
// User passed in a full URL path, use it verbatim.
178177
return settings.Endpoint, nil
179178
}
180-
if settings.DefaultEndpoint == "" {
179+
if resolvedDefaultEndpoint(settings) == "" {
181180
// If DefaultEndpoint is not configured, use the user provided endpoint verbatim.
182181
// This allows a naked "host[:port]" URL to be used with GRPC Direct Path.
183182
return settings.Endpoint, nil
184183
}
185184

186185
// Assume user-provided endpoint is host[:port], merge it with the default endpoint.
187-
return mergeEndpoints(settings.DefaultEndpoint, settings.Endpoint)
186+
return mergeEndpoints(resolvedDefaultEndpoint(settings), settings.Endpoint)
187+
}
188+
189+
func isMTLS(clientCertSource cert.Source) bool {
190+
mtlsMode := getMTLSMode()
191+
return mtlsMode == mTLSModeAlways || (clientCertSource != nil && mtlsMode == mTLSModeAuto)
192+
}
193+
194+
// resolvedDefaultEndpoint returns the DefaultEndpointTemplate merged with the
195+
// Universe Domain if the DefaultEndpointTemplate is set, otherwise returns the
196+
// deprecated DefaultEndpoint value.
197+
func resolvedDefaultEndpoint(settings *DialSettings) string {
198+
if settings.DefaultEndpointTemplate == "" {
199+
return settings.DefaultEndpoint
200+
}
201+
return strings.Replace(settings.DefaultEndpointTemplate, universeDomainPlaceholder, settings.GetUniverseDomain(), 1)
188202
}
189203

190204
func getMTLSMode() string {

internal/cba_test.go

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,25 @@ var dummyClientCertSource = func(info *tls.CertificateRequestInfo) (*tls.Certifi
2525

2626
func TestGetEndpoint(t *testing.T) {
2727
testCases := []struct {
28-
UserEndpoint string
29-
DefaultEndpoint string
30-
Want string
31-
WantErr bool
28+
UserEndpoint string
29+
DefaultEndpoint string
30+
DefaultEndpointTemplate string
31+
Want string
32+
WantErr bool
3233
}{
3334
{
34-
DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
35-
Want: "https://foo.googleapis.com/bar/baz",
35+
DefaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz",
36+
Want: "https://foo.googleapis.com/bar/baz",
3637
},
3738
{
38-
UserEndpoint: "myhost:3999",
39-
DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
40-
Want: "https://myhost:3999/bar/baz",
39+
UserEndpoint: "myhost:3999",
40+
DefaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz",
41+
Want: "https://myhost:3999/bar/baz",
4142
},
4243
{
43-
UserEndpoint: "https://host/path/to/bar",
44-
DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
45-
Want: "https://host/path/to/bar",
44+
UserEndpoint: "https://host/path/to/bar",
45+
DefaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz",
46+
Want: "https://host/path/to/bar",
4647
},
4748
{
4849
UserEndpoint: "host:123",
@@ -63,8 +64,10 @@ func TestGetEndpoint(t *testing.T) {
6364

6465
for _, tc := range testCases {
6566
got, err := getEndpoint(&DialSettings{
66-
Endpoint: tc.UserEndpoint,
67-
DefaultEndpoint: tc.DefaultEndpoint,
67+
Endpoint: tc.UserEndpoint,
68+
DefaultEndpoint: tc.DefaultEndpoint,
69+
DefaultEndpointTemplate: tc.DefaultEndpointTemplate,
70+
DefaultUniverseDomain: "googleapis.com",
6871
}, nil)
6972
if tc.WantErr && err == nil {
7073
t.Errorf("want err, got nil err")
@@ -75,7 +78,7 @@ func TestGetEndpoint(t *testing.T) {
7578
continue
7679
}
7780
if tc.Want != got {
78-
t.Errorf("getEndpoint(%q, %q): got %v; want %v", tc.UserEndpoint, tc.DefaultEndpoint, got, tc.Want)
81+
t.Errorf("getEndpoint(%q, %q): got %v; want %v", tc.UserEndpoint, tc.DefaultEndpointTemplate, got, tc.Want)
7982
}
8083
}
8184
}
@@ -118,9 +121,10 @@ func TestGetEndpointWithClientCertSource(t *testing.T) {
118121

119122
for _, tc := range testCases {
120123
got, err := getEndpoint(&DialSettings{
121-
Endpoint: tc.UserEndpoint,
122-
DefaultEndpoint: tc.DefaultEndpoint,
123-
DefaultMTLSEndpoint: tc.DefaultMTLSEndpoint,
124+
Endpoint: tc.UserEndpoint,
125+
DefaultEndpoint: tc.DefaultEndpoint,
126+
DefaultMTLSEndpoint: tc.DefaultMTLSEndpoint,
127+
DefaultUniverseDomain: "googleapis.com",
124128
}, dummyClientCertSource)
125129
if tc.WantErr && err == nil {
126130
t.Errorf("want err, got nil err")
@@ -174,18 +178,20 @@ func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) {
174178
{
175179
"no client cert, S2A address not empty, override endpoint",
176180
&DialSettings{
177-
DefaultMTLSEndpoint: testMTLSEndpoint,
178-
DefaultEndpoint: testRegularEndpoint,
179-
Endpoint: testOverrideEndpoint,
181+
DefaultMTLSEndpoint: testMTLSEndpoint,
182+
DefaultEndpointTemplate: testEndpointTemplate,
183+
Endpoint: testOverrideEndpoint,
184+
DefaultUniverseDomain: "googleapis.com",
180185
},
181186
validConfigResp,
182187
testOverrideEndpoint,
183188
},
184189
{
185190
"no client cert, S2A address not empty, DefaultMTLSEndpoint not set",
186191
&DialSettings{
187-
DefaultMTLSEndpoint: "",
188-
DefaultEndpoint: testRegularEndpoint,
192+
DefaultMTLSEndpoint: "",
193+
DefaultEndpointTemplate: testEndpointTemplate,
194+
DefaultUniverseDomain: "googleapis.com",
189195
},
190196
validConfigResp,
191197
testRegularEndpoint,
@@ -336,6 +342,7 @@ func TestGetHTTPTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
336342
DefaultEndpoint: testRegularEndpoint,
337343
DefaultEndpointTemplate: testEndpointTemplate,
338344
DefaultMTLSEndpoint: testMTLSEndpoint,
345+
DefaultUniverseDomain: "googleapis.com",
339346
},
340347
wantEndpoint: testRegularEndpoint,
341348
},
@@ -346,6 +353,7 @@ func TestGetHTTPTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
346353
DefaultEndpointTemplate: testEndpointTemplate,
347354
DefaultMTLSEndpoint: testMTLSEndpoint,
348355
ClientCertSource: dummyClientCertSource,
356+
DefaultUniverseDomain: "googleapis.com",
349357
},
350358
wantEndpoint: testMTLSEndpoint,
351359
},
@@ -356,6 +364,7 @@ func TestGetHTTPTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
356364
DefaultEndpointTemplate: testEndpointTemplate,
357365
DefaultMTLSEndpoint: testMTLSEndpoint,
358366
UniverseDomain: testUniverseDomain,
367+
DefaultUniverseDomain: "googleapis.com",
359368
},
360369
wantEndpoint: testUniverseDomainEndpoint,
361370
},
@@ -367,6 +376,7 @@ func TestGetHTTPTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
367376
DefaultMTLSEndpoint: testMTLSEndpoint,
368377
UniverseDomain: testUniverseDomain,
369378
ClientCertSource: dummyClientCertSource,
379+
DefaultUniverseDomain: "googleapis.com",
370380
},
371381
wantEndpoint: testUniverseDomainEndpoint,
372382
wantErr: errUniverseNotSupportedMTLS,
@@ -405,6 +415,7 @@ func TestGetGRPCTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
405415
DefaultEndpoint: testRegularEndpoint,
406416
DefaultEndpointTemplate: testEndpointTemplate,
407417
DefaultMTLSEndpoint: testMTLSEndpoint,
418+
DefaultUniverseDomain: "googleapis.com",
408419
},
409420
wantEndpoint: testRegularEndpoint,
410421
},
@@ -415,6 +426,7 @@ func TestGetGRPCTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
415426
DefaultEndpointTemplate: testEndpointTemplate,
416427
DefaultMTLSEndpoint: testMTLSEndpoint,
417428
Endpoint: testOverrideEndpoint,
429+
DefaultUniverseDomain: "googleapis.com",
418430
},
419431
wantEndpoint: testOverrideEndpoint,
420432
},
@@ -425,6 +437,7 @@ func TestGetGRPCTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
425437
DefaultEndpointTemplate: testEndpointTemplate,
426438
DefaultMTLSEndpoint: testMTLSEndpoint,
427439
ClientCertSource: dummyClientCertSource,
440+
DefaultUniverseDomain: "googleapis.com",
428441
},
429442
wantEndpoint: testMTLSEndpoint,
430443
},
@@ -436,6 +449,7 @@ func TestGetGRPCTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
436449
DefaultMTLSEndpoint: testMTLSEndpoint,
437450
ClientCertSource: dummyClientCertSource,
438451
Endpoint: testOverrideEndpoint,
452+
DefaultUniverseDomain: "googleapis.com",
439453
},
440454
wantEndpoint: testOverrideEndpoint,
441455
},
@@ -446,6 +460,7 @@ func TestGetGRPCTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
446460
DefaultEndpointTemplate: testEndpointTemplate,
447461
DefaultMTLSEndpoint: testMTLSEndpoint,
448462
UniverseDomain: testUniverseDomain,
463+
DefaultUniverseDomain: "googleapis.com",
449464
},
450465
wantEndpoint: testUniverseDomainEndpoint,
451466
},
@@ -457,6 +472,7 @@ func TestGetGRPCTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
457472
DefaultMTLSEndpoint: testMTLSEndpoint,
458473
UniverseDomain: testUniverseDomain,
459474
Endpoint: testOverrideEndpoint,
475+
DefaultUniverseDomain: "googleapis.com",
460476
},
461477
wantEndpoint: testOverrideEndpoint,
462478
},
@@ -468,6 +484,7 @@ func TestGetGRPCTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
468484
DefaultMTLSEndpoint: testMTLSEndpoint,
469485
UniverseDomain: testUniverseDomain,
470486
ClientCertSource: dummyClientCertSource,
487+
DefaultUniverseDomain: "googleapis.com",
471488
},
472489
wantErr: errUniverseNotSupportedMTLS,
473490
},
@@ -480,6 +497,7 @@ func TestGetGRPCTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
480497
UniverseDomain: testUniverseDomain,
481498
ClientCertSource: dummyClientCertSource,
482499
Endpoint: testOverrideEndpoint,
500+
DefaultUniverseDomain: "googleapis.com",
483501
},
484502
wantEndpoint: testOverrideEndpoint,
485503
},

internal/creds.go

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"time"
1717

1818
"golang.org/x/oauth2"
19+
"google.golang.org/api/internal/cert"
1920
"google.golang.org/api/internal/impersonate"
2021

2122
"golang.org/x/oauth2/google"
@@ -90,11 +91,11 @@ func credentialsFromJSON(ctx context.Context, data []byte, ds *DialSettings) (*g
9091

9192
// Determine configurations for the OAuth2 transport, which is separate from the API transport.
9293
// The OAuth2 transport and endpoint will be configured for mTLS if applicable.
93-
clientCertSource, oauth2Endpoint, err := getClientCertificateSourceAndEndpoint(oauth2DialSettings(ds))
94+
clientCertSource, err := getClientCertificateSource(ds)
9495
if err != nil {
9596
return nil, err
9697
}
97-
params.TokenURL = oauth2Endpoint
98+
params.TokenURL = oAuth2Endpoint(clientCertSource)
9899
if clientCertSource != nil {
99100
tlsConfig := &tls.Config{
100101
GetClientCertificate: clientCertSource,
@@ -124,6 +125,13 @@ func credentialsFromJSON(ctx context.Context, data []byte, ds *DialSettings) (*g
124125
return cred, err
125126
}
126127

128+
func oAuth2Endpoint(clientCertSource cert.Source) string {
129+
if isMTLS(clientCertSource) {
130+
return google.MTLSTokenURL
131+
}
132+
return google.Endpoint.TokenURL
133+
}
134+
127135
func isSelfSignedJWTFlow(data []byte, ds *DialSettings) (bool, error) {
128136
// For non-GDU universe domains, token exchange is impossible and services
129137
// must support self-signed JWTs with scopes.
@@ -196,15 +204,6 @@ func impersonateCredentials(ctx context.Context, creds *google.Credentials, ds *
196204
}, nil
197205
}
198206

199-
// oauth2DialSettings returns the settings to be used by the OAuth2 transport, which is separate from the API transport.
200-
func oauth2DialSettings(ds *DialSettings) *DialSettings {
201-
var ods DialSettings
202-
ods.DefaultEndpoint = google.Endpoint.TokenURL
203-
ods.DefaultMTLSEndpoint = google.MTLSTokenURL
204-
ods.ClientCertSource = ds.ClientCertSource
205-
return &ods
206-
}
207-
208207
// customHTTPClient constructs an HTTPClient using the provided tlsConfig, to support mTLS.
209208
func customHTTPClient(tlsConfig *tls.Config) *http.Client {
210209
trans := baseTransport()

internal/settings.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,15 +163,27 @@ func (ds *DialSettings) Validate() error {
163163
return nil
164164
}
165165

166-
// UniverseDomain returns the default service domain for a given Cloud universe.
166+
// GetDefaultUniverseDomain returns the default service domain for a given Cloud
167+
// universe, as configured with internaloption.WithDefaultUniverseDomain.
167168
// The default value is "googleapis.com".
169+
func (ds *DialSettings) GetDefaultUniverseDomain() string {
170+
if ds.DefaultUniverseDomain == "" {
171+
return universeDomainDefault
172+
}
173+
return ds.DefaultUniverseDomain
174+
}
175+
176+
// GetUniverseDomain returns the default service domain for a given Cloud
177+
// universe, as configured with option.WithUniverseDomain.
178+
// The default value is the value of GetDefaultUniverseDomain, as configured
179+
// with internaloption.WithDefaultUniverseDomain.
168180
func (ds *DialSettings) GetUniverseDomain() string {
169181
if ds.UniverseDomain == "" {
170-
return universeDomainDefault
182+
return ds.GetDefaultUniverseDomain()
171183
}
172184
return ds.UniverseDomain
173185
}
174186

175187
func (ds *DialSettings) IsUniverseDomainGDU() bool {
176-
return ds.GetUniverseDomain() == universeDomainDefault
188+
return ds.GetUniverseDomain() == ds.GetDefaultUniverseDomain()
177189
}

0 commit comments

Comments
 (0)