Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion config/auth_azure_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (c AzureCliCredentials) getVisitor(ctx context.Context, cfg *Config, inner
return azureVisitor(cfg, refreshableVisitor(inner, opts...)), nil
}
management := azureReuseTokenSource(t, ts, opts...)
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken, opts...)), nil
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken, false, opts...)), nil
}

func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
Expand Down
2 changes: 1 addition & 1 deletion config/auth_azure_client_secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,6 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config
opts := cacheOptions(cfg)
inner := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, aadEndpoint, env.AzureApplicationID), opts...)
management := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, aadEndpoint, managementEndpoint), opts...)
visitor := azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken, opts...))
visitor := azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken, false, opts...))
return credentials.NewOAuthCredentialsProvider(visitor, inner.Token), nil
}
2 changes: 1 addition & 1 deletion config/auth_azure_msi.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (creden
opts := cacheOptions(cfg)
inner := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, "", env.AzureApplicationID), opts...)
management := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, "", env.AzureServiceManagementEndpoint()), opts...)
visitor := azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken, opts...))
visitor := azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken, false, opts...))
return credentials.NewOAuthCredentialsProvider(visitor, inner.Token), nil
}

Expand Down
2 changes: 1 addition & 1 deletion config/auth_gcp_google_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (c GoogleCredentials) Configure(ctx context.Context, cfg *Config) (credenti
return nil, fmt.Errorf("could not obtain OAuth2 token from JSON: %w", err)
}
logger.Infof(ctx, "Using Google Credentials")
visitor := serviceToServiceVisitor(inner, creds.TokenSource, "X-Databricks-GCP-SA-Access-Token", cacheOptions(cfg)...)
visitor := serviceToServiceVisitor(inner, creds.TokenSource, "X-Databricks-GCP-SA-Access-Token", true, cacheOptions(cfg)...)
return credentials.NewOAuthCredentialsProvider(visitor, inner.Token), nil
}

Expand Down
16 changes: 7 additions & 9 deletions config/auth_gcp_google_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,8 @@ func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (c
return nil, err
}
opts := cacheOptions(cfg)
if cfg.ConfigType() == WorkspaceConfig {
logger.Infof(ctx, "Using Google Default Application Credentials for Workspace")
visitor := refreshableVisitor(inner, opts...)
return credentials.CredentialsProviderFn(visitor), nil
}
// source for generateAccessToken
// Always attempt to create SA token source for the secondary header.
// If it fails, fall back to refreshableVisitor with a warning.
platform, err := impersonate.CredentialsTokenSource(ctx, impersonate.CredentialsConfig{
TargetPrincipal: cfg.GoogleServiceAccount,
Scopes: []string{
Expand All @@ -43,10 +39,12 @@ func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (c
},
}, c.opts...)
if err != nil {
return nil, err
logger.Warnf(ctx, "Failed to create GCP SA access token source: %v. Proceeding without SA token.", err)
visitor := refreshableVisitor(inner, opts...)
return credentials.CredentialsProviderFn(visitor), nil
}
logger.Infof(ctx, "Using Google Default Application Credentials for Accounts API")
visitor := serviceToServiceVisitor(inner, platform, "X-Databricks-GCP-SA-Access-Token", opts...)
logger.Infof(ctx, "Using Google Default Application Credentials")
visitor := serviceToServiceVisitor(inner, platform, "X-Databricks-GCP-SA-Access-Token", true, opts...)
return credentials.NewOAuthCredentialsProvider(visitor, inner.Token), nil
}

Expand Down
13 changes: 10 additions & 3 deletions config/oauth_visitors.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/databricks/databricks-sdk-go/config/experimental/auth"
"github.com/databricks/databricks-sdk-go/config/experimental/auth/authconv"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2"
)

Expand All @@ -19,9 +20,11 @@ func cacheOptions(cfg *Config) []auth.Option {
}

// serviceToServiceVisitor returns a visitor that sets the Authorization header
// to the token from the auth token source and the provided secondary header to
// the token from the secondary token source.
func serviceToServiceVisitor(primary, secondary oauth2.TokenSource, secondaryHeader string, opts ...auth.Option) func(r *http.Request) error {
// to the token from the primary token source and the provided secondary header
// to the token from the secondary token source. If secondaryOptional is true,
// a failure to get the secondary token logs a warning and skips the header
// instead of returning an error.
func serviceToServiceVisitor(primary, secondary oauth2.TokenSource, secondaryHeader string, secondaryOptional bool, opts ...auth.Option) func(r *http.Request) error {
refreshableAuth := auth.NewCachedTokenSource(authconv.AuthTokenSource(primary), opts...)
refreshableSecondary := auth.NewCachedTokenSource(authconv.AuthTokenSource(secondary), opts...)
return func(r *http.Request) error {
Expand All @@ -33,6 +36,10 @@ func serviceToServiceVisitor(primary, secondary oauth2.TokenSource, secondaryHea

cloud, err := refreshableSecondary.Token(context.Background())
if err != nil {
if secondaryOptional {
logger.Warnf(r.Context(), "Failed to get secondary token for %s header: %v. Skipping.", secondaryHeader, err)
return nil
}
return fmt.Errorf("cloud token: %w", err)
}
r.Header.Set(secondaryHeader, cloud.AccessToken)
Expand Down
62 changes: 62 additions & 0 deletions config/oauth_visitors_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package config

import (
"fmt"
"net/http"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)

Expand All @@ -31,3 +34,62 @@ func TestAzureReuseTokenSource(t *testing.T) {
assert.NoError(t, err)
assert.False(t, token.Valid())
}

type staticTokenSource struct {
token *oauth2.Token
err error
}

func (s *staticTokenSource) Token() (*oauth2.Token, error) {
return s.token, s.err
}

func TestServiceToServiceVisitorWithFallback_BothSucceed(t *testing.T) {
primary := &staticTokenSource{token: &oauth2.Token{AccessToken: "primary-token"}}
secondary := &staticTokenSource{token: &oauth2.Token{AccessToken: "secondary-token"}}
visitor := serviceToServiceVisitor(primary, secondary, "X-Secondary", true)

req, err := http.NewRequest("GET", "https://example.com", nil)
require.NoError(t, err)
err = visitor(req)
require.NoError(t, err)
assert.Equal(t, "Bearer primary-token", req.Header.Get("Authorization"))
assert.Equal(t, "secondary-token", req.Header.Get("X-Secondary"))
}

func TestServiceToServiceVisitorWithFallback_SecondaryFails_SkipsHeader(t *testing.T) {
primary := &staticTokenSource{token: &oauth2.Token{AccessToken: "primary-token"}}
secondary := &staticTokenSource{err: fmt.Errorf("secondary failed")}
visitor := serviceToServiceVisitor(primary, secondary, "X-Secondary", true)

req, err := http.NewRequest("GET", "https://example.com", nil)
require.NoError(t, err)
err = visitor(req)
require.NoError(t, err)
assert.Equal(t, "Bearer primary-token", req.Header.Get("Authorization"))
assert.Empty(t, req.Header.Get("X-Secondary"))
}

func TestServiceToServiceVisitor_SecondaryFails_NotOptional_ReturnsError(t *testing.T) {
primary := &staticTokenSource{token: &oauth2.Token{AccessToken: "primary-token"}}
secondary := &staticTokenSource{err: fmt.Errorf("secondary failed")}
visitor := serviceToServiceVisitor(primary, secondary, "X-Secondary", false)

req, err := http.NewRequest("GET", "https://example.com", nil)
require.NoError(t, err)
err = visitor(req)
require.Error(t, err)
assert.Contains(t, err.Error(), "cloud token")
}

func TestServiceToServiceVisitorWithFallback_PrimaryFails_ReturnsError(t *testing.T) {
primary := &staticTokenSource{err: fmt.Errorf("primary failed")}
secondary := &staticTokenSource{token: &oauth2.Token{AccessToken: "secondary-token"}}
visitor := serviceToServiceVisitor(primary, secondary, "X-Secondary", true)

req, err := http.NewRequest("GET", "https://example.com", nil)
require.NoError(t, err)
err = visitor(req)
require.Error(t, err)
assert.Contains(t, err.Error(), "inner token")
}
Loading