Skip to content

Commit b8f2f47

Browse files
committed
fix: handle wildcards with expiration properly in the schema compiler
1 parent a0dacf6 commit b8f2f47

File tree

3 files changed

+84
-37
lines changed

3 files changed

+84
-37
lines changed

pkg/schema/definition_test.go

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1263,7 +1263,9 @@ func TestTypeSystemAccessors(t *testing.T) {
12631263
`use expiration
12641264
12651265
definition user {}
1266-
definition group {}
1266+
definition group {
1267+
relation member: user
1268+
}
12671269
12681270
caveat somecaveat(somecondition int) {
12691271
somecondition == 42
@@ -1274,6 +1276,12 @@ func TestTypeSystemAccessors(t *testing.T) {
12741276
relation caveated: user with somecaveat | group
12751277
relation expired: user with expiration | group
12761278
relation mixed: user | group with somecaveat and expiration
1279+
relation viewer: user | user:* with expiration
1280+
relation caveated_viewer: user | user:* with somecaveat
1281+
relation membered: group#member with somecaveat and expiration
1282+
relation membered_plain: group#member
1283+
relation caveated_membered: group#member with somecaveat | group#member
1284+
relation expired_membered: group#member with expiration | group#member
12771285
}`,
12781286
map[string]tsTester{
12791287
"resource": func(t *testing.T, vts *ValidatedDefinition) {
@@ -1298,6 +1306,36 @@ func TestTypeSystemAccessors(t *testing.T) {
12981306
require.True(t, traits.AllowsCaveats)
12991307
require.True(t, traits.AllowsExpiration)
13001308

1309+
traits, err = vts.PossibleTraitsForAnySubject("viewer")
1310+
require.NoError(t, err)
1311+
require.False(t, traits.AllowsCaveats)
1312+
require.True(t, traits.AllowsExpiration)
1313+
1314+
traits, err = vts.PossibleTraitsForAnySubject("caveated_viewer")
1315+
require.NoError(t, err)
1316+
require.True(t, traits.AllowsCaveats)
1317+
require.False(t, traits.AllowsExpiration)
1318+
1319+
traits, err = vts.PossibleTraitsForAnySubject("membered")
1320+
require.NoError(t, err)
1321+
require.True(t, traits.AllowsCaveats)
1322+
require.True(t, traits.AllowsExpiration)
1323+
1324+
traits, err = vts.PossibleTraitsForAnySubject("membered_plain")
1325+
require.NoError(t, err)
1326+
require.False(t, traits.AllowsCaveats)
1327+
require.False(t, traits.AllowsExpiration)
1328+
1329+
traits, err = vts.PossibleTraitsForAnySubject("caveated_membered")
1330+
require.NoError(t, err)
1331+
require.True(t, traits.AllowsCaveats)
1332+
require.False(t, traits.AllowsExpiration)
1333+
1334+
traits, err = vts.PossibleTraitsForAnySubject("expired_membered")
1335+
require.NoError(t, err)
1336+
require.False(t, traits.AllowsCaveats)
1337+
require.True(t, traits.AllowsExpiration)
1338+
13011339
_, err = vts.PossibleTraitsForAnySubject("unknown")
13021340
require.Error(t, err)
13031341
})

pkg/schemadsl/compiler/compiler_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,25 @@ func TestCompile(t *testing.T) {
10081008
),
10091009
},
10101010
},
1011+
{
1012+
"wildcard relation with expiration trait",
1013+
withTenantPrefix,
1014+
`use expiration
1015+
1016+
definition simple {
1017+
relation viewer: user:* with expiration
1018+
}`,
1019+
"",
1020+
[]SchemaDefinition{
1021+
namespace.Namespace("sometenant/simple",
1022+
namespace.MustRelation("viewer", nil,
1023+
namespace.WithExpiration(
1024+
namespace.AllowedPublicNamespace("sometenant/user"),
1025+
),
1026+
),
1027+
),
1028+
},
1029+
},
10111030
{
10121031
"duplicate use pragmas",
10131032
withTenantPrefix,

pkg/schemadsl/compiler/translator.go

Lines changed: 26 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -651,76 +651,66 @@ func translateSpecificTypeReference(tctx *translationContext, typeRefNode *dslNo
651651
return nil, typeRefNode.Errorf("%w", err)
652652
}
653653

654-
if typeRefNode.Has(dslshape.NodeSpecificReferencePredicateWildcard) {
655-
ref := &core.AllowedRelation{
656-
Namespace: nspath,
657-
RelationOrWildcard: &core.AllowedRelation_PublicWildcard_{
658-
PublicWildcard: &core.AllowedRelation_PublicWildcard{},
659-
},
660-
}
661-
662-
err = addWithCaveats(tctx, typeRefNode, ref)
663-
if err != nil {
664-
return nil, typeRefNode.Errorf("invalid caveat: %w", err)
665-
}
666-
667-
if !tctx.skipValidate {
668-
if err := ref.Validate(); err != nil {
669-
return nil, typeRefNode.Errorf("invalid type relation: %w", err)
670-
}
671-
}
672-
673-
ref.SourcePosition = getSourcePosition(typeRefNode, tctx.mapper)
674-
return ref, nil
675-
}
676-
677654
relationName := Ellipsis
678655
if typeRefNode.Has(dslshape.NodeSpecificReferencePredicateRelation) {
679656
relationName, err = typeRefNode.GetString(dslshape.NodeSpecificReferencePredicateRelation)
680657
if err != nil {
681658
return nil, typeRefNode.Errorf("invalid type relation: %w", err)
682659
}
683660
}
684-
685661
ref := &core.AllowedRelation{
686662
Namespace: nspath,
687663
RelationOrWildcard: &core.AllowedRelation_Relation{
688664
Relation: relationName,
689665
},
690666
}
691667

668+
if typeRefNode.Has(dslshape.NodeSpecificReferencePredicateWildcard) {
669+
ref.RelationOrWildcard = &core.AllowedRelation_PublicWildcard_{
670+
PublicWildcard: &core.AllowedRelation_PublicWildcard{},
671+
}
672+
}
673+
692674
// Add the caveat(s), if any.
693675
err = addWithCaveats(tctx, typeRefNode, ref)
694676
if err != nil {
695677
return nil, typeRefNode.Errorf("invalid caveat: %w", err)
696678
}
697679

698680
// Add the expiration trait, if any.
681+
err = addWithExpiration(tctx, typeRefNode, ref)
682+
if err != nil {
683+
return nil, typeRefNode.Errorf("invalid expiration: %w", err)
684+
}
685+
686+
if !tctx.skipValidate {
687+
if err := ref.Validate(); err != nil {
688+
return nil, typeRefNode.Errorf("invalid type relation: %w", err)
689+
}
690+
}
691+
692+
ref.SourcePosition = getSourcePosition(typeRefNode, tctx.mapper)
693+
return ref, nil
694+
}
695+
696+
func addWithExpiration(tctx *translationContext, typeRefNode *dslNode, ref *core.AllowedRelation) error {
699697
if traitNode, err := typeRefNode.Lookup(dslshape.NodeSpecificReferencePredicateTrait); err == nil {
700698
traitName, err := traitNode.GetString(dslshape.NodeTraitPredicateTrait)
701699
if err != nil {
702-
return nil, typeRefNode.Errorf("invalid trait: %w", err)
700+
return err
703701
}
704702

705703
if traitName != "expiration" {
706-
return nil, typeRefNode.Errorf("invalid trait: %s", traitName)
704+
return fmt.Errorf("invalid trait: %s", traitName)
707705
}
708706

709707
if !slices.Contains(tctx.allowedFlags, "expiration") {
710-
return nil, typeRefNode.Errorf("expiration trait is not allowed")
708+
return fmt.Errorf("expiration trait is not allowed")
711709
}
712710

713711
ref.RequiredExpiration = &core.ExpirationTrait{}
714712
}
715-
716-
if !tctx.skipValidate {
717-
if err := ref.Validate(); err != nil {
718-
return nil, typeRefNode.Errorf("invalid type relation: %w", err)
719-
}
720-
}
721-
722-
ref.SourcePosition = getSourcePosition(typeRefNode, tctx.mapper)
723-
return ref, nil
713+
return nil
724714
}
725715

726716
func addWithCaveats(tctx *translationContext, typeRefNode *dslNode, ref *core.AllowedRelation) error {

0 commit comments

Comments
 (0)