Skip to content

Commit a21e945

Browse files
committed
fix: only query access tokens by hashed signature
1 parent 0b56f53 commit a21e945

File tree

4 files changed

+85
-69
lines changed

4 files changed

+85
-69
lines changed

persistence/sql/persister_nid_test.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,24 @@ import (
4040
type PersisterTestSuite struct {
4141
suite.Suite
4242
registries map[string]driver.Registry
43-
clean func(*testing.T)
4443
t1 context.Context
4544
t2 context.Context
4645
t1NID uuid.UUID
4746
t2NID uuid.UUID
4847
}
4948

50-
var _ PersisterTestSuite = PersisterTestSuite{}
49+
var _ interface {
50+
suite.SetupAllSuite
51+
suite.TearDownTestSuite
52+
} = (*PersisterTestSuite)(nil)
5153

5254
func (s *PersisterTestSuite) SetupSuite() {
5355
s.registries = map[string]driver.Registry{
5456
"memory": internal.NewRegistrySQLFromURL(s.T(), dbal.NewSQLiteTestDatabase(s.T()), true, &contextx.Default{}),
5557
}
5658

5759
if !testing.Short() {
58-
s.registries["postgres"], s.registries["mysql"], s.registries["cockroach"], s.clean = internal.ConnectDatabases(s.T(), true, &contextx.Default{})
60+
s.registries["postgres"], s.registries["mysql"], s.registries["cockroach"], _ = internal.ConnectDatabases(s.T(), true, &contextx.Default{})
5961
}
6062

6163
s.t1NID, s.t2NID = uuid.Must(uuid.NewV4()), uuid.Must(uuid.NewV4())
@@ -558,11 +560,11 @@ func (s *PersisterTestSuite) DeleteAccessTokenSession() {
558560
require.NoError(t, r.Persister().DeleteAccessTokenSession(s.t2, sig))
559561

560562
actual := persistencesql.OAuth2RequestSQL{Table: "access"}
561-
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, sig))
563+
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig)))
562564
require.Equal(t, s.t1NID, actual.NID)
563565

564566
require.NoError(t, r.Persister().DeleteAccessTokenSession(s.t1, sig))
565-
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, sig))
567+
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig)))
566568
})
567569
}
568570
}

persistence/sql/persister_oauth2.go

Lines changed: 77 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func (r OAuth2RequestSQL) TableName() string {
6767
return "hydra_oauth2_" + string(r.Table)
6868
}
6969

70-
func (p *Persister) sqlSchemaFromRequest(ctx context.Context, rawSignature string, r fosite.Requester, table tableName) (*OAuth2RequestSQL, error) {
70+
func (p *Persister) sqlSchemaFromRequest(ctx context.Context, signature string, r fosite.Requester, table tableName) (*OAuth2RequestSQL, error) {
7171
subject := ""
7272
if r.GetSession() == nil {
7373
p.l.Debugf("Got an empty session in sqlSchemaFromRequest")
@@ -101,7 +101,7 @@ func (p *Persister) sqlSchemaFromRequest(ctx context.Context, rawSignature strin
101101
return &OAuth2RequestSQL{
102102
Request: r.GetID(),
103103
ConsentChallenge: challenge,
104-
ID: p.hashSignature(ctx, rawSignature, table),
104+
ID: signature,
105105
RequestedAt: r.GetRequestedAt(),
106106
Client: r.GetClient().GetID(),
107107
Scopes: strings.Join(r.GetRequestedScopes(), "|"),
@@ -160,20 +160,6 @@ func (r *OAuth2RequestSQL) toRequest(ctx context.Context, session fosite.Session
160160
}, nil
161161
}
162162

163-
// SignatureHash hashes the signature to prevent errors where the signature is
164-
// longer than 128 characters (and thus doesn't fit into the pk).
165-
func SignatureHash(signature string) string {
166-
return fmt.Sprintf("%x", sha512.Sum384([]byte(signature)))
167-
}
168-
169-
// hashSignature prevents errors where the signature is longer than 128 characters (and thus doesn't fit into the pk).
170-
func (p *Persister) hashSignature(_ context.Context, signature string, table tableName) string {
171-
if table == sqlTableAccess {
172-
return SignatureHash(signature)
173-
}
174-
return signature
175-
}
176-
177163
func (p *Persister) ClientAssertionJWTValid(ctx context.Context, jti string) (err error) {
178164
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ClientAssertionJWTValid")
179165
defer otelx.End(span, &err)
@@ -228,7 +214,7 @@ func (p *Persister) SetClientAssertionJWTRaw(ctx context.Context, jti *oauth2.Bl
228214
return sqlcon.HandleError(p.CreateWithNetwork(ctx, jti))
229215
}
230216

231-
func (p *Persister) createSession(ctx context.Context, signature string, requester fosite.Requester, table tableName) (err error) {
217+
func (p *Persister) createSession(ctx context.Context, signature string, requester fosite.Requester, table tableName) error {
232218
req, err := p.sqlSchemaFromRequest(ctx, signature, requester, table)
233219
if err != nil {
234220
return err
@@ -242,28 +228,21 @@ func (p *Persister) createSession(ctx context.Context, signature string, request
242228
return nil
243229
}
244230

245-
func (p *Persister) findSessionBySignature(ctx context.Context, rawSignature string, session fosite.Session, table tableName) (_ fosite.Requester, err error) {
246-
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.findSessionBySignature")
247-
defer otelx.End(span, &err)
248-
231+
func (p *Persister) findSessionBySignature(ctx context.Context, signature string, session fosite.Session, table tableName) (fosite.Requester, error) {
249232
r := OAuth2RequestSQL{Table: table}
250-
251-
// We look for the signature as well as the hash of the signature here.
252-
// This is because we now always store the hash of the signature in the database,
253-
// regardless of the type of the signature. In previous versions, we only stored
254-
// the hash of the signature for JWT tokens.
255-
//
256-
// This code will be removed in a future version.
257-
err = p.QueryWithNetwork(ctx).Where("signature IN (?, ?)", rawSignature, SignatureHash(rawSignature)).First(&r)
233+
err := p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r)
258234
if errors.Is(err, sql.ErrNoRows) {
259235
return nil, errorsx.WithStack(fosite.ErrNotFound)
260-
} else if err != nil {
236+
}
237+
if err != nil {
261238
return nil, sqlcon.HandleError(err)
262-
} else if !r.Active {
239+
}
240+
if !r.Active {
263241
fr, err := r.toRequest(ctx, session, p)
264242
if err != nil {
265243
return nil, err
266-
} else if table == sqlTableCode {
244+
}
245+
if table == sqlTableCode {
267246
return fr, errorsx.WithStack(fosite.ErrInvalidatedAuthorizeCode)
268247
}
269248
return fr, errorsx.WithStack(fosite.ErrInactiveToken)
@@ -272,46 +251,35 @@ func (p *Persister) findSessionBySignature(ctx context.Context, rawSignature str
272251
return r.toRequest(ctx, session, p)
273252
}
274253

275-
func (p *Persister) deleteSessionBySignature(ctx context.Context, signature string, table tableName) (err error) {
276-
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deleteSessionBySignature")
277-
defer otelx.End(span, &err)
278-
279-
signature = p.hashSignature(ctx, signature, table)
280-
281-
// We look for the signature as well as the hash of the signature here.
282-
// This is because we now always store the hash of the signature in the database,
283-
// regardless of the type of the signature. In previous versions, we only stored
284-
// the hash of the signature for JWT tokens.
285-
//
286-
// This code will be removed in a future version.
287-
err = sqlcon.HandleError(
254+
func (p *Persister) deleteSessionBySignature(ctx context.Context, signature string, table tableName) error {
255+
err := sqlcon.HandleError(
288256
p.QueryWithNetwork(ctx).
289-
Where("signature IN (?, ?)", signature, SignatureHash(signature)).
257+
Where("signature = ?", signature).
290258
Delete(&OAuth2RequestSQL{Table: table}))
291-
292259
if errors.Is(err, sqlcon.ErrNoRows) {
293260
return errorsx.WithStack(fosite.ErrNotFound)
294-
} else if errors.Is(err, sqlcon.ErrConcurrentUpdate) {
261+
}
262+
if errors.Is(err, sqlcon.ErrConcurrentUpdate) {
295263
return errors.Wrap(fosite.ErrSerializationFailure, err.Error())
296-
} else if err != nil {
297-
return err
298264
}
299-
return nil
265+
return err
300266
}
301267

302268
func (p *Persister) deleteSessionByRequestID(ctx context.Context, id string, table tableName) (err error) {
303269
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deleteSessionByRequestID")
304270
defer otelx.End(span, &err)
305271

306-
/* #nosec G201 table is static */
307-
if err := p.QueryWithNetwork(ctx).
272+
err = p.QueryWithNetwork(ctx).
308273
Where("request_id=?", id).
309-
Delete(&OAuth2RequestSQL{Table: table}); errors.Is(err, sql.ErrNoRows) {
274+
Delete(&OAuth2RequestSQL{Table: table})
275+
if errors.Is(err, sql.ErrNoRows) {
310276
return errorsx.WithStack(fosite.ErrNotFound)
311-
} else if err := sqlcon.HandleError(err); err != nil {
277+
}
278+
if err := sqlcon.HandleError(err); err != nil {
312279
if errors.Is(err, sqlcon.ErrConcurrentUpdate) {
313280
return errors.Wrap(fosite.ErrSerializationFailure, err.Error())
314-
} else if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock?
281+
}
282+
if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock?
315283
return errors.Wrap(fosite.ErrSerializationFailure, err.Error())
316284
}
317285
return err
@@ -356,14 +324,20 @@ func (p *Persister) InvalidateAuthorizeCodeSession(ctx context.Context, signatur
356324
return sqlcon.HandleError(
357325
p.Connection(ctx).
358326
RawQuery(
359-
fmt.Sprintf("UPDATE %s SET active=false WHERE signature=? AND nid = ?", OAuth2RequestSQL{Table: sqlTableCode}.TableName()),
327+
fmt.Sprintf("UPDATE %s SET active = false WHERE signature = ? AND nid = ?", OAuth2RequestSQL{Table: sqlTableCode}.TableName()),
360328
signature,
361329
p.NetworkID(ctx),
362330
).
363331
Exec(),
364332
)
365333
}
366334

335+
// SignatureHash hashes the signature to prevent errors where the signature is
336+
// longer than 128 characters (and thus doesn't fit into the pk).
337+
func SignatureHash(signature string) string {
338+
return fmt.Sprintf("%x", sha512.Sum384([]byte(signature)))
339+
}
340+
367341
func (p *Persister) CreateAccessTokenSession(ctx context.Context, signature string, requester fosite.Requester) (err error) {
368342
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateAccessTokenSession")
369343
defer otelx.End(span, &err)
@@ -372,19 +346,62 @@ func (p *Persister) CreateAccessTokenSession(ctx context.Context, signature stri
372346
append(toEventOptions(requester), events.WithGrantType(requester.GetRequestForm().Get("grant_type")))...,
373347
)
374348

375-
return p.createSession(ctx, signature, requester, sqlTableAccess)
349+
return p.createSession(ctx, SignatureHash(signature), requester, sqlTableAccess)
376350
}
377351

378352
func (p *Persister) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) {
379353
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetAccessTokenSession")
380354
defer otelx.End(span, &err)
381-
return p.findSessionBySignature(ctx, signature, session, sqlTableAccess)
355+
356+
r := OAuth2RequestSQL{Table: sqlTableAccess}
357+
err = p.QueryWithNetwork(ctx).Where("signature = ?", SignatureHash(signature)).First(&r)
358+
if errors.Is(err, sql.ErrNoRows) {
359+
// Backwards compatibility: we previously did not always hash the
360+
// signature before inserting. In case there are still very old (but
361+
// valid) access tokens in the database, this should get them.
362+
err = p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r)
363+
if errors.Is(err, sql.ErrNoRows) {
364+
return nil, errorsx.WithStack(fosite.ErrNotFound)
365+
}
366+
}
367+
if err != nil {
368+
return nil, sqlcon.HandleError(err)
369+
}
370+
if !r.Active {
371+
fr, err := r.toRequest(ctx, session, p)
372+
if err != nil {
373+
return nil, err
374+
}
375+
return fr, errorsx.WithStack(fosite.ErrInactiveToken)
376+
}
377+
378+
return r.toRequest(ctx, session, p)
382379
}
383380

384381
func (p *Persister) DeleteAccessTokenSession(ctx context.Context, signature string) (err error) {
385382
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteAccessTokenSession")
386383
defer otelx.End(span, &err)
387-
return p.deleteSessionBySignature(ctx, signature, sqlTableAccess)
384+
385+
err = sqlcon.HandleError(
386+
p.QueryWithNetwork(ctx).
387+
Where("signature = ?", SignatureHash(signature)).
388+
Delete(&OAuth2RequestSQL{Table: sqlTableAccess}))
389+
if errors.Is(err, sqlcon.ErrNoRows) {
390+
// Backwards compatibility: we previously did not always hash the
391+
// signature before inserting. In case there are still very old (but
392+
// valid) access tokens in the database, this should get them.
393+
err = sqlcon.HandleError(
394+
p.QueryWithNetwork(ctx).
395+
Where("signature = ?", signature).
396+
Delete(&OAuth2RequestSQL{Table: sqlTableAccess}))
397+
if errors.Is(err, sqlcon.ErrNoRows) {
398+
return errorsx.WithStack(fosite.ErrNotFound)
399+
}
400+
}
401+
if errors.Is(err, sqlcon.ErrConcurrentUpdate) {
402+
return errors.Wrap(fosite.ErrSerializationFailure, err.Error())
403+
}
404+
return err
388405
}
389406

390407
func toEventOptions(requester fosite.Requester) []trace.EventOption {

x/audit_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@ func TestLogAudit(t *testing.T) {
4343
l.Logger.Out = buf
4444
LogAudit(r, tc.message, l)
4545

46-
t.Logf("%s", buf.String())
47-
4846
assert.Contains(t, buf.String(), "audience=audit")
4947
for _, expectContain := range tc.expectContains {
5048
assert.Contains(t, buf.String(), expectContain)

x/clean_sql.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
)
1111

1212
func DeleteHydraRows(t *testing.T, c *pop.Connection) {
13-
t.Logf("Deleting hydra rows in database: %s", c.Dialect.Name())
1413
for _, tb := range []string{
1514
"hydra_oauth2_access",
1615
"hydra_oauth2_refresh",
@@ -57,7 +56,7 @@ func CleanSQLPop(t *testing.T, c *pop.Connection) {
5756
"schema_migration",
5857
} {
5958
if err := c.RawQuery("DROP TABLE IF EXISTS " + tb).Exec(); err != nil {
60-
t.Logf(`Unable to clean up table "%s": %s`, tb, err)
59+
t.Fatalf(`Unable to clean up table "%s": %s`, tb, err)
6160
}
6261
}
6362
t.Logf("Successfully cleaned up database: %s", c.Dialect.Name())

0 commit comments

Comments
 (0)