Skip to content

Commit 72f5cb7

Browse files
committed
Add min_leaves parameter to st.recursive()
Add a new min_leaves parameter to recursive() that uses rejection sampling to ensure generated recursive structures have at least the specified number of leaf nodes. This is useful for users who want to generate larger recursive structures early in testing. Closes #4205
1 parent eaecbba commit 72f5cb7

4 files changed

Lines changed: 87 additions & 15 deletions

File tree

hypothesis-python/RELEASE.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
RELEASE_TYPE: minor
2+
3+
This release adds a ``min_leaves`` argument to :func:`~hypothesis.strategies.recursive`,
4+
which ensures that generated recursive structures have at least the specified number
5+
of leaf nodes (:issue:`4205`).

hypothesis-python/src/hypothesis/strategies/_internal/core.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1833,12 +1833,16 @@ def recursive(
18331833
base: SearchStrategy[Ex],
18341834
extend: Callable[[SearchStrategy[Any]], SearchStrategy[T]],
18351835
*,
1836+
min_leaves: int = 1,
18361837
max_leaves: int = 100,
18371838
) -> SearchStrategy[T | Ex]:
18381839
"""base: A strategy to start from.
18391840
18401841
extend: A function which takes a strategy and returns a new strategy.
18411842
1843+
min_leaves: The minimum number of elements to be drawn from base on a given
1844+
run.
1845+
18421846
max_leaves: The maximum number of elements to be drawn from base on a given
18431847
run.
18441848
@@ -1857,7 +1861,7 @@ def recursive(
18571861
18581862
"""
18591863

1860-
return RecursiveStrategy(base, extend, max_leaves)
1864+
return RecursiveStrategy(base, extend, min_leaves, max_leaves)
18611865

18621866

18631867
class PermutationStrategy(SearchStrategy):

hypothesis-python/src/hypothesis/strategies/_internal/recursive.py

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@
99
# obtain one at https://mozilla.org/MPL/2.0/.
1010

1111
import threading
12+
import warnings
1213
from contextlib import contextmanager
1314

14-
from hypothesis.errors import InvalidArgument
15-
from hypothesis.internal.reflection import get_pretty_function_description
15+
from hypothesis.errors import HypothesisWarning, InvalidArgument
16+
from hypothesis.internal.reflection import (
17+
get_pretty_function_description,
18+
is_identity_function,
19+
)
1620
from hypothesis.internal.validation import check_type
1721
from hypothesis.strategies._internal.strategies import (
1822
OneOfStrategy,
@@ -72,24 +76,36 @@ def capped(self, max_templates):
7276

7377

7478
class RecursiveStrategy(SearchStrategy):
75-
def __init__(self, base, extend, max_leaves):
79+
def __init__(self, base, extend, min_leaves, max_leaves):
7680
super().__init__()
81+
self.min_leaves = min_leaves
7782
self.max_leaves = max_leaves
7883
self.base = base
7984
self.limited_base = LimitedStrategy(base)
8085
self.extend = extend
8186

87+
if is_identity_function(extend):
88+
warnings.warn(
89+
"extend=lambda x: x is a no-op; you probably want to use a "
90+
"different extend function, or just use the base strategy directly.",
91+
HypothesisWarning,
92+
stacklevel=5,
93+
)
94+
8295
strategies = [self.limited_base, self.extend(self.limited_base)]
8396
while 2 ** (len(strategies) - 1) <= max_leaves:
8497
strategies.append(extend(OneOfStrategy(tuple(strategies))))
98+
# If min_leaves > 1, we can never draw from base directly
99+
if min_leaves > 1:
100+
strategies = strategies[1:]
85101
self.strategy = OneOfStrategy(strategies)
86102

87103
def __repr__(self) -> str:
88104
if not hasattr(self, "_cached_repr"):
89-
self._cached_repr = "recursive(%r, %s, max_leaves=%d)" % (
90-
self.base,
91-
get_pretty_function_description(self.extend),
92-
self.max_leaves,
105+
self._cached_repr = (
106+
f"recursive({self.base!r}, "
107+
f"{get_pretty_function_description(self.extend)}, "
108+
f"min_leaves={self.min_leaves}, max_leaves={self.max_leaves})"
93109
)
94110
return self._cached_repr
95111

@@ -99,20 +115,41 @@ def do_validate(self) -> None:
99115
check_strategy(extended, f"extend({self.limited_base!r})")
100116
self.limited_base.validate()
101117
extended.validate()
118+
check_type(int, self.min_leaves, "min_leaves")
102119
check_type(int, self.max_leaves, "max_leaves")
120+
if self.min_leaves <= 0:
121+
raise InvalidArgument(
122+
f"min_leaves={self.min_leaves!r} must be greater than zero"
123+
)
103124
if self.max_leaves <= 0:
104125
raise InvalidArgument(
105126
f"max_leaves={self.max_leaves!r} must be greater than zero"
106127
)
128+
if self.min_leaves > self.max_leaves:
129+
raise InvalidArgument(
130+
f"min_leaves={self.min_leaves!r} must be less than or equal to "
131+
f"max_leaves={self.max_leaves!r}"
132+
)
107133

108134
def do_draw(self, data):
109-
count = 0
135+
min_leaves_retries = 0
110136
while True:
111137
try:
112138
with self.limited_base.capped(self.max_leaves):
113-
return data.draw(self.strategy)
139+
result = data.draw(self.strategy)
140+
leaves_drawn = self.max_leaves - self.limited_base.marker
141+
if leaves_drawn < self.min_leaves:
142+
data.events[
143+
f"Draw for {self!r} had fewer than "
144+
f"min_leaves={self.min_leaves} and had to be retried"
145+
] = ""
146+
min_leaves_retries += 1
147+
if min_leaves_retries < 5:
148+
continue
149+
data.mark_invalid(f"min_leaves={self.min_leaves} unsatisfied")
150+
return result
114151
except LimitReached:
115-
if count == 0:
116-
msg = f"Draw for {self!r} exceeded max_leaves and had to be retried"
117-
data.events[msg] = ""
118-
count += 1
152+
data.events[
153+
f"Draw for {self!r} exceeded "
154+
f"max_leaves={self.max_leaves} and had to be retried"
155+
] = ""

hypothesis-python/tests/cover/test_recursive.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import pytest
1212

1313
from hypothesis import given, strategies as st
14-
from hypothesis.errors import InvalidArgument
14+
from hypothesis.errors import HypothesisWarning, InvalidArgument
1515

1616
from tests.common.debug import check_can_generate_examples, find_any, minimal
1717

@@ -83,8 +83,34 @@ def test_issue_1502_regression(s):
8383
st.recursive(st.none(), st.lists, max_leaves=-1),
8484
st.recursive(st.none(), st.lists, max_leaves=0),
8585
st.recursive(st.none(), st.lists, max_leaves=1.0),
86+
st.recursive(st.none(), st.lists, min_leaves=-1),
87+
st.recursive(st.none(), st.lists, min_leaves=0),
88+
st.recursive(st.none(), st.lists, min_leaves=1.0),
89+
st.recursive(st.none(), st.lists, min_leaves=10, max_leaves=5),
8690
],
8791
)
8892
def test_invalid_args(s):
8993
with pytest.raises(InvalidArgument):
9094
check_can_generate_examples(s)
95+
96+
97+
def count_leaves(tree):
98+
"""Count the number of leaf nodes (non-tuple values) in a tree."""
99+
if isinstance(tree, tuple):
100+
return sum(count_leaves(child) for child in tree)
101+
return 1
102+
103+
104+
@given(st.recursive(st.none(), lambda x: st.tuples(x, x), min_leaves=3, max_leaves=10))
105+
def test_respects_min_leaves(tree):
106+
assert count_leaves(tree) >= 3
107+
108+
109+
@given(st.recursive(st.none(), lambda x: st.tuples(x, x), min_leaves=5, max_leaves=5))
110+
def test_can_set_exact_leaf_count(tree):
111+
assert count_leaves(tree) == 5
112+
113+
114+
def test_identity_extend_warns():
115+
with pytest.warns(HypothesisWarning, match="extend=lambda x: x is a no-op"):
116+
st.recursive(st.none(), lambda x: x)

0 commit comments

Comments
 (0)