Skip to content

Commit ffe4943

Browse files
apexskieraeneasr
andauthored
fix: support allowed_cors_origins with client_secret_post (#3457)
Closes #3456 Co-authored-by: hackerman <[email protected]>
1 parent 97ac03a commit ffe4943

File tree

2 files changed

+49
-5
lines changed

2 files changed

+49
-5
lines changed

x/oauth2cors/cors.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,18 @@ func Middleware(
8181
return true
8282
}
8383

84-
username, _, ok := r.BasicAuth()
85-
if !ok || username == "" {
84+
var clientID string
85+
86+
// if the client uses client_secret_post auth it will provide its client ID in form data
87+
clientID = r.PostFormValue("client_id")
88+
89+
// if the client uses client_secret_basic auth the client ID will be the username component
90+
if clientID == "" {
91+
clientID, _, _ = r.BasicAuth()
92+
}
93+
94+
// otherwise, this may be a bearer auth request, in which case we can introspect the token
95+
if clientID == "" {
8696
token := fosite.AccessTokenFromRequest(r)
8797
if token == "" {
8898
return false
@@ -94,10 +104,10 @@ func Middleware(
94104
return false
95105
}
96106

97-
username = ar.GetClient().GetID()
107+
clientID = ar.GetClient().GetID()
98108
}
99109

100-
cl, err := reg.ClientManager().GetConcreteClient(ctx, username)
110+
cl, err := reg.ClientManager().GetConcreteClient(ctx, clientID)
101111
if err != nil {
102112
return false
103113
}

x/oauth2cors/cors_test.go

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
package oauth2cors_test
55

66
import (
7+
"bytes"
78
"context"
89
"fmt"
10+
"io"
911
"net/http"
1012
"net/http/httptest"
13+
"net/url"
1114
"testing"
1215
"time"
1316

@@ -37,6 +40,7 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) {
3740
header http.Header
3841
expectHeader http.Header
3942
method string
43+
body io.Reader
4044
}{
4145
{
4246
d: "should ignore when disabled",
@@ -55,6 +59,36 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) {
5559
header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo", "bar"))}},
5660
expectHeader: http.Header{"Vary": {"Origin"}},
5761
},
62+
{
63+
d: "should reject when post auth client exists but origin not allowed",
64+
prep: func(t *testing.T, r driver.Registry) {
65+
r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true)
66+
r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
67+
68+
// Ignore unique violations
69+
_ = r.ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foo-2", Secret: "bar", AllowedCORSOrigins: []string{"http://not-foobar.com"}})
70+
},
71+
code: http.StatusNotImplemented,
72+
header: http.Header{"Origin": {"http://foobar.com"}, "Content-Type": {"application/x-www-form-urlencoded"}},
73+
expectHeader: http.Header{"Vary": {"Origin"}},
74+
method: http.MethodPost,
75+
body: bytes.NewBufferString(url.Values{"client_id": []string{"foo-2"}}.Encode()),
76+
},
77+
{
78+
d: "should accept when post auth client exists and origin allowed",
79+
prep: func(t *testing.T, r driver.Registry) {
80+
r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true)
81+
r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
82+
83+
// Ignore unique violations
84+
_ = r.ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foo-3", Secret: "bar", AllowedCORSOrigins: []string{"http://foobar.com"}})
85+
},
86+
code: http.StatusNotImplemented,
87+
header: http.Header{"Origin": {"http://foobar.com"}, "Content-Type": {"application/x-www-form-urlencoded"}},
88+
expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://foobar.com"}, "Access-Control-Expose-Headers": []string{"Cache-Control, Expires, Last-Modified, Pragma, Content-Length, Content-Language, Content-Type"}, "Vary": []string{"Origin"}},
89+
method: http.MethodPost,
90+
body: bytes.NewBufferString(url.Values{"client_id": {"foo-3"}}.Encode()),
91+
},
5892
{
5993
d: "should reject when basic auth client exists but origin not allowed",
6094
prep: func(t *testing.T, r driver.Registry) {
@@ -237,7 +271,7 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) {
237271
if tc.method != "" {
238272
method = tc.method
239273
}
240-
req, err := http.NewRequest(method, "http://foobar.com/", nil)
274+
req, err := http.NewRequest(method, "http://foobar.com/", tc.body)
241275
require.NoError(t, err)
242276
for k := range tc.header {
243277
req.Header.Set(k, tc.header.Get(k))

0 commit comments

Comments
 (0)