Skip to content

Commit 8cc1f38

Browse files
authored
Merge pull request #4639 from Zac-HD/fix-recursive-strategy
Fix and deprecate unusual uses of `st.recursive()`
2 parents 4a322b4 + 6541d3c commit 8cc1f38

7 files changed

Lines changed: 95 additions & 36 deletions

File tree

hypothesis-python/RELEASE.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
RELEASE_TYPE: patch
2+
3+
This patch fixes a bug where |st.recursive| would fail in cases where the
4+
``extend=`` function does not reference it's argument - which was assumed
5+
by the recent ``min_leaves=`` feature, because the strategy can't actually
6+
recurse otherwise. (:issue:`4638`)
7+
8+
Now, the historical behavior is working-but-deprecated, or an error if you
9+
explicitly pass ``min_leaves=``.

hypothesis-python/src/hypothesis/internal/conjecture/data.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -646,21 +646,21 @@ def __init__(
646646
self.overdraw = 0
647647
self._random = random
648648

649-
self.length = 0
650-
self.index = 0
651-
self.output = ""
652-
self.status = Status.VALID
653-
self.frozen = False
654-
self.testcounter = threadlocal.global_test_counter
649+
self.length: int = 0
650+
self.index: int = 0
651+
self.output: str = ""
652+
self.status: Status = Status.VALID
653+
self.frozen: bool = False
654+
self.testcounter: int = threadlocal.global_test_counter
655655
threadlocal.global_test_counter += 1
656656
self.start_time = time.perf_counter()
657657
self.gc_start_time = gc_cumulative_time()
658658
self.events: dict[str, str | int | float] = {}
659659
self.interesting_origin: InterestingOrigin | None = None
660660
self.draw_times: dict[str, float] = {}
661661
self._stateful_run_times: dict[str, float] = defaultdict(float)
662-
self.max_depth = 0
663-
self.has_discards = False
662+
self.max_depth: int = 0
663+
self.has_discards: bool = False
664664

665665
self.provider: PrimitiveProvider = (
666666
provider(self, **provider_kw) if isinstance(provider, type) else provider
@@ -683,9 +683,8 @@ def __init__(
683683
# examples for reporting purposes.
684684
self.__spans: Spans | None = None
685685

686-
# We want the top level span to have depth 0, so we start
687-
# at -1.
688-
self.depth = -1
686+
# We want the top level span to have depth 0, so we start at -1.
687+
self.depth: int = -1
689688
self.__span_record = SpanRecord()
690689

691690
# Slice indices for discrete reportable parts that which-parts-matter can

hypothesis-python/src/hypothesis/internal/conjecture/engine.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -911,10 +911,14 @@ def debug_data(self, data: ConjectureData | ConjectureResult) -> None:
911911
status = repr(data.status)
912912
if data.status == Status.INTERESTING:
913913
status = f"{status} ({data.interesting_origin!r})"
914+
elif data.status == Status.INVALID and isinstance(data, ConjectureData):
915+
assert isinstance(data, ConjectureData) # mypy is silly
916+
status = f"{status} ({data.events.get('invalid because', '?')})"
914917

918+
newline_tab = "\n\t"
915919
self.debug(
916-
f"{len(data.choices)} choices {data.choices} -> {status}"
917-
f"{', ' + data.output if data.output else ''}"
920+
f"{len(data.choices)} choices -> {status}\n\t{data.choices}"
921+
f"{newline_tab + data.output if data.output else ''}"
918922
)
919923

920924
def observe_for_provider(self) -> AbstractContextManager:

hypothesis-python/src/hypothesis/strategies/_internal/core.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1857,18 +1857,16 @@ def recursive(
18571857
base: SearchStrategy[Ex],
18581858
extend: Callable[[SearchStrategy[Any]], SearchStrategy[T]],
18591859
*,
1860-
min_leaves: int = 1,
1860+
min_leaves: int | None = None,
18611861
max_leaves: int = 100,
18621862
) -> SearchStrategy[T | Ex]:
18631863
"""base: A strategy to start from.
18641864
18651865
extend: A function which takes a strategy and returns a new strategy.
18661866
1867-
min_leaves: The minimum number of elements to be drawn from base on a given
1868-
run.
1867+
min_leaves: The minimum number of elements to be drawn from base on a given run.
18691868
1870-
max_leaves: The maximum number of elements to be drawn from base on a given
1871-
run.
1869+
max_leaves: The maximum number of elements to be drawn from base on a given run.
18721870
18731871
This returns a strategy ``S`` such that ``S = extend(base | S)``. That is,
18741872
values may be drawn from base, or from any strategy reachable by mixing
@@ -1882,9 +1880,7 @@ def recursive(
18821880
Examples from this strategy shrink by trying to reduce the amount of
18831881
recursion and by shrinking according to the shrinking behaviour of base
18841882
and the result of extend.
1885-
18861883
"""
1887-
18881884
return RecursiveStrategy(base, extend, min_leaves, max_leaves)
18891885

18901886

hypothesis-python/src/hypothesis/strategies/_internal/recursive.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010

1111
import threading
1212
import warnings
13+
from collections.abc import Callable
1314
from contextlib import contextmanager
1415

1516
from hypothesis.errors import HypothesisWarning, InvalidArgument
1617
from hypothesis.internal.reflection import (
1718
get_pretty_function_description,
19+
is_first_param_referenced_in_function,
1820
is_identity_function,
1921
)
2022
from hypothesis.internal.validation import check_type
@@ -23,6 +25,7 @@
2325
SearchStrategy,
2426
check_strategy,
2527
)
28+
from hypothesis.utils.deprecation import note_deprecation
2629

2730

2831
class LimitReached(BaseException):
@@ -76,27 +79,25 @@ def capped(self, max_templates):
7679

7780

7881
class RecursiveStrategy(SearchStrategy):
79-
def __init__(self, base, extend, min_leaves, max_leaves):
82+
def __init__(
83+
self,
84+
base: SearchStrategy,
85+
extend: Callable[[SearchStrategy], SearchStrategy],
86+
min_leaves: int | None,
87+
max_leaves: int,
88+
):
8089
super().__init__()
8190
self.min_leaves = min_leaves
8291
self.max_leaves = max_leaves
8392
self.base = base
8493
self.limited_base = LimitedStrategy(base)
8594
self.extend = extend
8695

87-
if is_identity_function(extend):
88-
warnings.warn(
89-
"extend=lambda x: x is a no-op; you probably want to use a "
90-
"different extend function, or just use the base strategy directly.",
91-
HypothesisWarning,
92-
stacklevel=5,
93-
)
94-
9596
strategies = [self.limited_base, self.extend(self.limited_base)]
9697
while 2 ** (len(strategies) - 1) <= max_leaves:
9798
strategies.append(extend(OneOfStrategy(tuple(strategies))))
9899
# If min_leaves > 1, we can never draw from base directly
99-
if min_leaves > 1:
100+
if min_leaves is not None and min_leaves > 1:
100101
strategies = strategies[1:]
101102
self.strategy = OneOfStrategy(strategies)
102103

@@ -115,17 +116,42 @@ def do_validate(self) -> None:
115116
check_strategy(extended, f"extend({self.limited_base!r})")
116117
self.limited_base.validate()
117118
extended.validate()
118-
check_type(int, self.min_leaves, "min_leaves")
119+
120+
if is_identity_function(self.extend):
121+
warnings.warn(
122+
"extend=lambda x: x is a no-op; you probably want to use a "
123+
"different extend function, or just use the base strategy directly.",
124+
HypothesisWarning,
125+
stacklevel=5,
126+
)
127+
128+
if not is_first_param_referenced_in_function(self.extend):
129+
msg = (
130+
f"extend={get_pretty_function_description(self.extend)} doesn't use "
131+
"it's argument, and thus can't actually recurse!"
132+
)
133+
if self.min_leaves is None:
134+
note_deprecation(
135+
msg,
136+
since="RELEASEDAY",
137+
has_codemod=False,
138+
stacklevel=1,
139+
)
140+
else:
141+
raise InvalidArgument(msg)
142+
143+
if self.min_leaves is not None:
144+
check_type(int, self.min_leaves, "min_leaves")
119145
check_type(int, self.max_leaves, "max_leaves")
120-
if self.min_leaves <= 0:
146+
if self.min_leaves is not None and self.min_leaves <= 0:
121147
raise InvalidArgument(
122148
f"min_leaves={self.min_leaves!r} must be greater than zero"
123149
)
124150
if self.max_leaves <= 0:
125151
raise InvalidArgument(
126152
f"max_leaves={self.max_leaves!r} must be greater than zero"
127153
)
128-
if self.min_leaves > self.max_leaves:
154+
if (self.min_leaves or 1) > self.max_leaves:
129155
raise InvalidArgument(
130156
f"min_leaves={self.min_leaves!r} must be less than or equal to "
131157
f"max_leaves={self.max_leaves!r}"
@@ -138,7 +164,7 @@ def do_draw(self, data):
138164
with self.limited_base.capped(self.max_leaves):
139165
result = data.draw(self.strategy)
140166
leaves_drawn = self.max_leaves - self.limited_base.marker
141-
if leaves_drawn < self.min_leaves:
167+
if self.min_leaves and leaves_drawn < self.min_leaves:
142168
data.events[
143169
f"Draw for {self!r} had fewer than "
144170
f"min_leaves={self.min_leaves} and had to be retried"

hypothesis-python/tests/conjecture/test_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,7 @@ def f(data):
596596
runner.run()
597597

598598
out, _ = capsys.readouterr()
599-
assert re.match(r"\d+ choices \(.*\) -> ", out)
599+
assert re.match(r"\d+ choices -> ", out)
600600
assert "INTERESTING" in out
601601

602602

hypothesis-python/tests/cover/test_recursive.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515

1616
from tests.common.debug import (
1717
assert_all_examples,
18+
assert_no_examples,
1819
check_can_generate_examples,
1920
find_any,
2021
minimal,
2122
)
23+
from tests.common.utils import checks_deprecated_behaviour
2224

2325

2426
@given(st.recursive(st.booleans(), st.lists, max_leaves=10))
@@ -80,6 +82,28 @@ def test_issue_1502_regression(s):
8082
pass
8183

8284

85+
def test_recursive_can_generate_varied_structures():
86+
values = st.recursive(st.none(), st.lists)
87+
88+
find_any(values, lambda x: x is None)
89+
find_any(values, lambda x: isinstance(x, list))
90+
find_any(
91+
values, lambda x: isinstance(x, list) and any(isinstance(y, list) for y in x)
92+
)
93+
94+
95+
@checks_deprecated_behaviour
96+
def test_recursive_can_generate_varied_structures_without_using_leaves():
97+
values = st.recursive(st.none(), lambda _: st.lists(st.none()))
98+
99+
find_any(values, lambda x: x is None)
100+
find_any(values, lambda x: isinstance(x, list))
101+
# The bad `extend` function means we can't actually recurse!
102+
assert_no_examples(
103+
values, lambda x: isinstance(x, list) and any(isinstance(y, list) for y in x)
104+
)
105+
106+
83107
@pytest.mark.parametrize(
84108
"s",
85109
[
@@ -92,6 +116,7 @@ def test_issue_1502_regression(s):
92116
st.recursive(st.none(), st.lists, min_leaves=0),
93117
st.recursive(st.none(), st.lists, min_leaves=1.0),
94118
st.recursive(st.none(), st.lists, min_leaves=10, max_leaves=5),
119+
st.recursive(st.none(), lambda _: st.lists(st.none()), min_leaves=1),
95120
],
96121
)
97122
def test_invalid_args(s):
@@ -129,4 +154,4 @@ def test_can_set_exact_leaf_count(tree):
129154

130155
def test_identity_extend_warns():
131156
with pytest.warns(HypothesisWarning, match="extend=lambda x: x is a no-op"):
132-
st.recursive(st.none(), lambda x: x)
157+
st.recursive(st.none(), lambda x: x).validate()

0 commit comments

Comments
 (0)