Skip to content

Commit a589a03

Browse files
authored
feat: Add recursive iterators (#2621)
1 parent b59dc8c commit a589a03

21 files changed

+971
-61
lines changed

pkg/query/alias.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,11 @@ func (a *Alias) Explain() Explain {
138138
SubExplain: []Explain{a.subIt.Explain()},
139139
}
140140
}
141+
142+
func (a *Alias) Subiterators() []Iterator {
143+
return []Iterator{a.subIt}
144+
}
145+
146+
func (a *Alias) ReplaceSubiterators(newSubs []Iterator) (Iterator, error) {
147+
return &Alias{relation: a.relation, subIt: newSubs[0]}, nil
148+
}

pkg/query/arrow.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,11 @@ func (a *Arrow) Explain() Explain {
138138
SubExplain: []Explain{a.left.Explain(), a.right.Explain()},
139139
}
140140
}
141+
142+
func (a *Arrow) Subiterators() []Iterator {
143+
return []Iterator{a.left, a.right}
144+
}
145+
146+
func (a *Arrow) ReplaceSubiterators(newSubs []Iterator) (Iterator, error) {
147+
return &Arrow{left: newSubs[0], right: newSubs[1]}, nil
148+
}

pkg/query/build_tree.go

Lines changed: 93 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,35 @@
11
package query
22

33
import (
4-
"errors"
54
"fmt"
65

76
core "github.com/authzed/spicedb/pkg/proto/core/v1"
87
"github.com/authzed/spicedb/pkg/schema/v2"
8+
"github.com/authzed/spicedb/pkg/spiceerrors"
99
"github.com/authzed/spicedb/pkg/tuple"
1010
)
1111

12+
type recursiveSentinelInfo struct {
13+
sentinel *RecursiveSentinel
14+
definitionName string
15+
relationName string
16+
}
17+
1218
type iteratorBuilder struct {
13-
schema *schema.Schema
14-
seen map[string]bool
15-
collectedCaveats []*core.ContextualizedCaveat // Collect caveats to combine with AND logic
19+
schema *schema.Schema
20+
building map[string]bool // Track what's currently being built (call stack)
21+
collectedCaveats []*core.ContextualizedCaveat // Collect caveats to combine with AND logic
22+
recursiveSentinels []*recursiveSentinelInfo // Track recursion points for wrapping in RecursiveIterator
1623
}
1724

1825
// BuildIteratorFromSchema takes a schema and walks the schema tree for a given definition namespace and a relationship or
1926
// permission therein. From this, it generates an iterator tree, rooted on that relationship.
2027
func BuildIteratorFromSchema(fullSchema *schema.Schema, definitionName string, relationName string) (Iterator, error) {
2128
builder := &iteratorBuilder{
22-
schema: fullSchema,
23-
seen: make(map[string]bool),
24-
collectedCaveats: make([]*core.ContextualizedCaveat, 0),
29+
schema: fullSchema,
30+
building: make(map[string]bool),
31+
collectedCaveats: make([]*core.ContextualizedCaveat, 0),
32+
recursiveSentinels: make([]*recursiveSentinelInfo, 0),
2533
}
2634
iterator, err := builder.buildIteratorFromSchemaInternal(definitionName, relationName, true)
2735
if err != nil {
@@ -33,27 +41,80 @@ func BuildIteratorFromSchema(fullSchema *schema.Schema, definitionName string, r
3341
for _, caveat := range builder.collectedCaveats {
3442
result = NewCaveatIterator(result, caveat)
3543
}
44+
45+
// Note: RecursiveIterator wrapping happens at the recursion point,
46+
// not at the top level. So we shouldn't have any sentinels left here.
47+
if len(builder.recursiveSentinels) > 0 {
48+
// This would be an error - sentinels should have been wrapped already
49+
return nil, spiceerrors.MustBugf("unwrapped sentinels remaining: %d", len(builder.recursiveSentinels))
50+
}
51+
3652
return result, nil
3753
}
3854

3955
func (b *iteratorBuilder) buildIteratorFromSchemaInternal(definitionName string, relationName string, withSubRelations bool) (Iterator, error) {
40-
id := fmt.Sprintf("%s#%s:%v", definitionName, relationName, withSubRelations)
41-
if b.seen[id] {
42-
return nil, errors.New("recursive schema iterators are as yet unsupported")
56+
id := fmt.Sprintf("%s#%s", definitionName, relationName)
57+
58+
// Check if we're currently building this (true recursion)
59+
// Check both with the same flag and opposite flag, since recursion can cross the boundary
60+
if b.building[id] {
61+
// Recursion detected - create sentinel and remember where
62+
sentinel := NewRecursiveSentinel(definitionName, relationName, withSubRelations)
63+
// Track this sentinel with its location info
64+
sentinelInfo := &recursiveSentinelInfo{
65+
sentinel: sentinel,
66+
definitionName: definitionName,
67+
relationName: relationName,
68+
}
69+
b.recursiveSentinels = append(b.recursiveSentinels, sentinelInfo)
70+
return sentinel, nil
4371
}
44-
b.seen[id] = true
72+
73+
// Mark as currently building
74+
b.building[id] = true
75+
// Track the position in the sentinels list before building
76+
sentinelsLenBefore := len(b.recursiveSentinels)
4577

4678
def, ok := b.schema.Definitions()[definitionName]
4779
if !ok {
80+
// Remove before returning error
81+
delete(b.building, id)
4882
return nil, fmt.Errorf("BuildIteratorFromSchema: couldn't find a schema definition named `%s`", definitionName)
4983
}
84+
85+
var result Iterator
86+
var err error
5087
if p, ok := def.Permissions()[relationName]; ok {
51-
return b.buildIteratorFromPermission(p)
88+
result, err = b.buildIteratorFromPermission(p)
89+
} else if r, ok := def.Relations()[relationName]; ok {
90+
result, err = b.buildIteratorFromRelation(r, withSubRelations)
91+
} else {
92+
err = fmt.Errorf("BuildIteratorFromSchema: couldn't find a relation or permission named `%s` in definition `%s`", relationName, definitionName)
93+
}
94+
95+
// Remove from building after we're done (allows reuse in other branches)
96+
delete(b.building, id)
97+
98+
if err != nil {
99+
return nil, err
52100
}
53-
if r, ok := def.Relations()[relationName]; ok {
54-
return b.buildIteratorFromRelation(r, withSubRelations)
101+
102+
// Check if any NEW sentinels were added while building this
103+
// If so, this subtree contains recursion and should be wrapped
104+
sentinelsAdded := b.recursiveSentinels[sentinelsLenBefore:]
105+
if len(sentinelsAdded) > 0 {
106+
// Extract just the sentinel objects
107+
sentinels := make([]*RecursiveSentinel, len(sentinelsAdded))
108+
for i, info := range sentinelsAdded {
109+
sentinels[i] = info.sentinel
110+
}
111+
// Wrap this subtree in RecursiveIterator
112+
result = NewRecursiveIterator(result)
113+
// Remove these sentinels from the list since we've wrapped them
114+
b.recursiveSentinels = b.recursiveSentinels[:sentinelsLenBefore]
55115
}
56-
return nil, fmt.Errorf("BuildIteratorFromSchema: couldn't find a relation or permission named `%s` in definition `%s`", relationName, definitionName)
116+
117+
return result, nil
57118
}
58119

59120
func (b *iteratorBuilder) buildIteratorFromRelation(r *schema.Relation, withSubRelations bool) (Iterator, error) {
@@ -177,12 +238,22 @@ func (b *iteratorBuilder) buildBaseRelationIterator(br *schema.BaseRelation, wit
177238
return base, nil
178239
}
179240

180-
// For relation references in schema definitions (like group#member in "relation member: user | group#member"),
181-
// we always need to resolve what the referenced relation means, even if withSubRelations=false.
182-
// The withSubRelations flag controls whether we build arrows for nested traversal, but relation
183-
// references in the schema definition itself must always be resolved.
184-
// However, we still need to prevent infinite recursion.
185-
if !withSubRelations {
241+
// Check if we need to expand subrelations
242+
// We always need to expand if withSubRelations=true (normal case)
243+
// OR if the subrelation might be recursive (same type as something we're building)
244+
needsExpansion := withSubRelations
245+
246+
if !needsExpansion {
247+
// Check if this might be a recursive subrelation
248+
// by seeing if the subrelation type matches any definition we're currently building
249+
subrelID := fmt.Sprintf("%s#%s", br.Type(), br.Subrelation())
250+
if b.building[subrelID] {
251+
// This is recursive! We need to expand to detect it
252+
needsExpansion = true
253+
}
254+
}
255+
256+
if !needsExpansion {
186257
return base, nil
187258
}
188259

@@ -191,8 +262,7 @@ func (b *iteratorBuilder) buildBaseRelationIterator(br *schema.BaseRelation, wit
191262
return nil, err
192263
}
193264

194-
// We must check the effective arrow of a subrelation if we have one and subrelations are enabled
195-
// (subrelations are disabled in cases of actual arrows)
265+
// We must check the effective arrow of a subrelation if we have one
196266
union := NewUnion()
197267
union.addSubIterator(base)
198268

pkg/query/build_tree_test.go

Lines changed: 52 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -139,25 +139,41 @@ func TestBuildTreeRecursion(t *testing.T) {
139139

140140
require := require.New(t)
141141

142-
// Create a simple schema with potential recursion using group membership
143-
userDef := testfixtures.UserNS.CloneVT()
144-
142+
// Create a proper recursive group hierarchy schema:
143+
// definition group {
144+
// relation parent: group
145+
// permission member = parent->member
146+
// }
147+
// This creates recursion: computing member arrows through parent groups,
148+
// which recursively compute their own member permission
145149
groupDef := namespace.Namespace("group",
150+
namespace.MustRelation("parent", nil,
151+
namespace.AllowedRelation("group", "..."),
152+
),
146153
namespace.MustRelation("member",
147154
namespace.Union(
148-
namespace.ComputedUserset("member"),
155+
namespace.TupleToUserset("parent", "member"),
149156
),
150157
),
151158
)
152159

153-
objectDefs := []*corev1.NamespaceDefinition{userDef, groupDef}
160+
objectDefs := []*corev1.NamespaceDefinition{groupDef}
154161
dsSchema, err := schema.BuildSchemaFromDefinitions(objectDefs, nil)
155162
require.NoError(err)
156163

157-
// This should detect recursion and return an error
158-
_, err = BuildIteratorFromSchema(dsSchema, "group", "member")
159-
require.Error(err)
160-
require.Contains(err.Error(), "recursive schema iterators are as yet unsupported")
164+
// This should detect recursion and create a RecursiveIterator
165+
// The arrow operation parent->member creates recursion: group->parent->member->parent->member...
166+
it, err := BuildIteratorFromSchema(dsSchema, "group", "member")
167+
require.NoError(err)
168+
require.NotNil(it)
169+
170+
// Verify it's wrapped in a RecursiveIterator
171+
_, isRecursive := it.(*RecursiveIterator)
172+
require.True(isRecursive, "Expected RecursiveIterator for recursive arrow operation")
173+
174+
// Verify the explain output
175+
explain := it.Explain()
176+
require.Equal("RecursiveIterator", explain.Name)
161177
}
162178

163179
func TestBuildTreeArrowOperation(t *testing.T) {
@@ -572,20 +588,16 @@ func TestBuildTreeSubrelationHandling(t *testing.T) {
572588

573589
t.Run("Base Relation with Ellipsis Subrelation", func(t *testing.T) {
574590
t.Parallel()
575-
// Test base relation with ellipsis - should return just the base relation
591+
// Test that base relations with ellipsis (group:...) work correctly with arrows
576592
groupDef := namespace.Namespace("group",
577-
namespace.MustRelation("member",
578-
namespace.Union(
579-
namespace.ComputedUserset("member"),
580-
),
581-
),
593+
namespace.MustRelation("member", nil, namespace.AllowedRelation("user", "...")),
582594
)
583595

584596
docDef := namespace.Namespace("document",
585-
namespace.MustRelation("parent", nil, namespace.AllowedRelation("document", "...")),
597+
namespace.MustRelation("parent", nil, namespace.AllowedRelation("group", "...")), // Ellipsis on group
586598
namespace.MustRelation("viewer",
587599
namespace.Union(
588-
namespace.TupleToUserset("parent", "viewer"),
600+
namespace.TupleToUserset("parent", "member"), // Arrow to group's member relation
589601
),
590602
),
591603
)
@@ -594,9 +606,21 @@ func TestBuildTreeSubrelationHandling(t *testing.T) {
594606
dsSchema, err := schema.BuildSchemaFromDefinitions(objectDefs, nil)
595607
require.NoError(err)
596608

597-
_, err = BuildIteratorFromSchema(dsSchema, "document", "viewer")
598-
require.Error(err)
599-
require.Contains(err.Error(), "recursive schema iterators are as yet unsupported", "Self-referential schema should be detected")
609+
// Should create an alias wrapping union with arrow
610+
it, err := BuildIteratorFromSchema(dsSchema, "document", "viewer")
611+
require.NoError(err)
612+
require.NotNil(it)
613+
614+
// Verify structure has arrow operation
615+
explain := it.Explain()
616+
explainStr := explain.String()
617+
require.Contains(explainStr, "Arrow", "Expected arrow operation for tuple-to-userset")
618+
619+
// Test execution doesn't crash
620+
relSeq, err := ctx.Check(it, []Object{NewObject("document", "test_doc")}, NewObject("user", "alice").WithEllipses())
621+
require.NoError(err)
622+
_, err = CollectAll(relSeq)
623+
require.NoError(err)
600624
})
601625

602626
t.Run("Base Relation with Specific Subrelation", func(t *testing.T) {
@@ -653,9 +677,14 @@ func TestBuildTreeSubrelationHandling(t *testing.T) {
653677
dsSchema, err := schema.BuildSchemaFromDefinitions(objectDefs, nil)
654678
require.NoError(err)
655679

656-
_, err = BuildIteratorFromSchema(dsSchema, "document", "viewer")
657-
require.Error(err)
658-
require.Contains(err.Error(), "recursive schema iterators are as yet unsupported", "Self-referential schema should be detected")
680+
// Should create RecursiveIterator for arrow recursion
681+
it, err := BuildIteratorFromSchema(dsSchema, "document", "viewer")
682+
require.NoError(err)
683+
require.NotNil(it)
684+
685+
// Should be wrapped in RecursiveIterator
686+
_, isRecursive := it.(*RecursiveIterator)
687+
require.True(isRecursive, "Expected RecursiveIterator for arrow recursion")
659688
})
660689

661690
t.Run("Base Relation with Missing Subrelation Definition", func(t *testing.T) {

pkg/query/caveat.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,14 @@ func (c *CaveatIterator) Explain() Explain {
194194
}
195195
}
196196

197+
func (c *CaveatIterator) Subiterators() []Iterator {
198+
return []Iterator{c.subiterator}
199+
}
200+
201+
func (c *CaveatIterator) ReplaceSubiterators(newSubs []Iterator) (Iterator, error) {
202+
return &CaveatIterator{subiterator: newSubs[0], caveat: c.caveat}, nil
203+
}
204+
197205
// buildExplainInfo creates detailed explanation information for the caveat iterator
198206
func (c *CaveatIterator) buildExplainInfo() string {
199207
if c.caveat == nil {

pkg/query/context.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,13 @@ func (t *TraceLogger) DumpTrace() string {
123123
// Context is the concrete type that contains the overall handles, and uses the executor as a strategy for continuing execution.
124124
type Context struct {
125125
context.Context
126-
Executor Executor
127-
Datastore datastore.ReadOnlyDatastore
128-
Revision datastore.Revision
129-
CaveatContext map[string]any
130-
CaveatRunner *caveats.CaveatRunner
131-
TraceLogger *TraceLogger // For debugging iterator execution
126+
Executor Executor
127+
Datastore datastore.ReadOnlyDatastore
128+
Revision datastore.Revision
129+
CaveatContext map[string]any
130+
CaveatRunner *caveats.CaveatRunner
131+
TraceLogger *TraceLogger // For debugging iterator execution
132+
MaxRecursionDepth int // Maximum depth for recursive iterators (0 = use default of 10)
132133
}
133134

134135
func (ctx *Context) TraceStep(it Iterator, step string, data ...any) {

pkg/query/datastore.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,11 @@ func (r *RelationIterator) Explain() Explain {
251251
r.base.Caveat() != "", r.base.Expiration()),
252252
}
253253
}
254+
255+
func (r *RelationIterator) Subiterators() []Iterator {
256+
return nil
257+
}
258+
259+
func (r *RelationIterator) ReplaceSubiterators(newSubs []Iterator) (Iterator, error) {
260+
return nil, spiceerrors.MustBugf("Trying to replace a leaf RelationIterator's subiterators")
261+
}

pkg/query/exclusion.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,11 @@ func (e *Exclusion) Explain() Explain {
158158
},
159159
}
160160
}
161+
162+
func (e *Exclusion) Subiterators() []Iterator {
163+
return []Iterator{e.mainSet, e.excluded}
164+
}
165+
166+
func (e *Exclusion) ReplaceSubiterators(newSubs []Iterator) (Iterator, error) {
167+
return &Exclusion{mainSet: newSubs[0], excluded: newSubs[1]}, nil
168+
}

pkg/query/fixed.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package query
22

3-
import "fmt"
3+
import (
4+
"fmt"
5+
6+
"github.com/authzed/spicedb/pkg/spiceerrors"
7+
)
48

59
// FixedIterator represents a fixed set of pre-computed paths.
610
// This is often useful for testing, but can also be used in rare situations
@@ -98,3 +102,11 @@ func (f *FixedIterator) Clone() Iterator {
98102
paths: clonedPaths,
99103
}
100104
}
105+
106+
func (f *FixedIterator) Subiterators() []Iterator {
107+
return nil
108+
}
109+
110+
func (f *FixedIterator) ReplaceSubiterators(newSubs []Iterator) (Iterator, error) {
111+
return nil, spiceerrors.MustBugf("Trying to replace a leaf FixedIterator's subiterators")
112+
}

0 commit comments

Comments
 (0)