Skip to content

Commit b8f7d50

Browse files
committed
Filter-rewriting for floats
1 parent 9204825 commit b8f7d50

File tree

4 files changed

+168
-25
lines changed

4 files changed

+168
-25
lines changed

hypothesis-python/RELEASE.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
RELEASE_TYPE: patch
2+
3+
This release automatically rewrites some simple filters, such as
4+
``floats().filter(lambda x: x >= 10)`` to the more efficient
5+
``floats(min_value=10)``, based on the AST of the predicate.
6+
7+
We continue to recommend using the efficient form directly wherever
8+
possible, but this should be useful for e.g. :pypi:`pandera` "``Checks``"
9+
where you already have a simple predicate and translating manually
10+
is really annoying. See :issue:`2701` for details.

hypothesis-python/src/hypothesis/internal/filtering.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from typing import Any, Callable, Dict, NamedTuple, Optional, TypeVar
3333

3434
from hypothesis.internal.compat import ceil, floor
35+
from hypothesis.internal.floats import next_down, next_up
3536
from hypothesis.internal.reflection import extract_lambda_source
3637

3738
Ex = TypeVar("Ex")
@@ -274,3 +275,26 @@ def get_integer_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
274275

275276
kwargs = {k: v for k, v in kwargs.items() if k in {"min_value", "max_value"}}
276277
return ConstructivePredicate(kwargs, predicate)
278+
279+
280+
def get_float_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
    """Convert the numeric bounds extracted from ``predicate`` into float bounds.

    The kwargs and residual predicate come from
    ``get_numeric_predicate_bounds``.  Any ``min_value`` / ``max_value`` found
    there is converted to a ``float`` and, where necessary, nudged inwards by
    one representable float so that the closed interval
    ``[min_value, max_value]`` contains only floats which satisfy the original
    (possibly exclusive, possibly non-float) bound.

    Returns a ``ConstructivePredicate`` whose kwargs contain at most the keys
    ``min_value`` and ``max_value``.
    """
    kwargs, predicate = get_numeric_predicate_bounds(predicate)  # type: ignore

    if "min_value" in kwargs:
        exact_min = kwargs["min_value"]
        lower = float(exact_min)
        # float() rounds to the nearest representable value.  If it rounded
        # *down* (lower < exact_min) then ``lower`` itself violates the bound
        # and the first admissible float is the next one up.  If it rounded
        # *up*, ``lower`` is already the smallest float above the bound and
        # must be kept, whether or not the bound is exclusive.
        if exact_min > lower or (
            exact_min == lower and kwargs.get("exclude_min", False)
        ):
            lower = next_up(lower)
        kwargs["min_value"] = lower

    if "max_value" in kwargs:
        exact_max = kwargs["max_value"]
        upper = float(exact_max)
        # Mirror image of the above: only step down when the conversion
        # rounded *up* past the bound, or when an exactly-representable
        # bound is exclusive.
        if exact_max < upper or (
            exact_max == upper and kwargs.get("exclude_max", False)
        ):
            upper = next_down(upper)
        kwargs["max_value"] = upper

    # The exclude_* flags have been folded into the bounds themselves, so only
    # the two bound keys are passed on.
    kwargs = {k: v for k, v in kwargs.items() if k in {"min_value", "max_value"}}
    return ConstructivePredicate(kwargs, predicate)

hypothesis-python/src/hypothesis/strategies/_internal/numbers.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
from hypothesis.errors import InvalidArgument
1919
from hypothesis.internal.conjecture import floats as flt, utils as d
2020
from hypothesis.internal.conjecture.utils import calc_label_from_name
21-
from hypothesis.internal.filtering import get_integer_predicate_bounds
21+
from hypothesis.internal.filtering import (
22+
get_float_predicate_bounds,
23+
get_integer_predicate_bounds,
24+
)
2225
from hypothesis.internal.floats import (
2326
float_of,
2427
int_to_float,
@@ -293,6 +296,40 @@ def do_draw(self, data):
293296
data.stop_example() # (FLOAT_STRATEGY_DO_DRAW_LABEL)
294297
return result
295298

299+
def filter(self, condition):
    """Apply ``condition``, narrowing this strategy's bounds where possible.

    Bounds recognised by ``get_float_predicate_bounds`` are intersected with
    the strategy's own ``min_value`` / ``max_value``.  If the intersection is
    empty, ``nothing()`` is returned; if nothing changes, ``self`` is
    returned unchanged.  Any part of the predicate that was not converted to
    bounds (``pred``) is applied as an ordinary filter on top.
    """
    kwargs, pred = get_float_predicate_bounds(condition)
    if not kwargs:
        return super().filter(pred)
    min_bound = max(kwargs.get("min_value", -math.inf), self.min_value)
    max_bound = min(kwargs.get("max_value", math.inf), self.max_value)

    # Adjustments for allow_subnormal=False, if any need to be made: a bound
    # strictly between zero and +/-smallest_nonzero_magnitude is moved to the
    # nearest value inside the interval that the strategy can represent.
    if -self.smallest_nonzero_magnitude < min_bound < 0:
        min_bound = -0.0
    elif 0 < min_bound < self.smallest_nonzero_magnitude:
        min_bound = self.smallest_nonzero_magnitude
    if -self.smallest_nonzero_magnitude < max_bound < 0:
        max_bound = -self.smallest_nonzero_magnitude
    elif 0 < max_bound < self.smallest_nonzero_magnitude:
        max_bound = 0.0

    if min_bound > max_bound:
        return nothing()
    # kwargs is non-empty here, so the predicate contains an ordering or
    # equality comparison against a number.  NaN fails every such comparison,
    # so a NaN-producing strategy must be rebuilt even when the extracted
    # bounds are infinite (e.g. ``x <= inf``) and leave the range unchanged.
    if min_bound > self.min_value or self.max_value > max_bound or self.allow_nan:
        # Rebinding ``self`` is deliberate: the zero-argument super() call
        # below reads the current value of the first argument, so the residual
        # filter is applied to the narrowed strategy.
        self = type(self)(
            min_value=min_bound,
            max_value=max_bound,
            allow_nan=False,
            smallest_nonzero_magnitude=self.smallest_nonzero_magnitude,
        )
    if pred is None:
        return self
    return super().filter(pred)
332+
296333

297334
@cacheable
298335
@defines_strategy(force_reusable_values=True)
@@ -509,11 +546,11 @@ def floats(
509546
min_value = float("-inf")
510547
if max_value is None:
511548
max_value = float("inf")
512-
assert isinstance(min_value, float)
513-
assert isinstance(max_value, float)
514549
if not allow_infinity:
515550
min_value = max(min_value, next_up(float("-inf")))
516551
max_value = min(max_value, next_down(float("inf")))
552+
assert isinstance(min_value, float)
553+
assert isinstance(max_value, float)
517554
smallest_nonzero_magnitude = (
518555
SMALLEST_SUBNORMAL if allow_subnormal else smallest_normal
519556
)

hypothesis-python/tests/cover/test_filter_rewriting.py

Lines changed: 94 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,18 @@
1111
import decimal
1212
import math
1313
import operator
14+
from fractions import Fraction
1415
from functools import partial
16+
from sys import float_info
1517

1618
import pytest
1719

1820
from hypothesis import given, strategies as st
1921
from hypothesis.errors import HypothesisWarning, Unsatisfiable
22+
from hypothesis.internal.floats import next_down, next_up
2023
from hypothesis.internal.reflection import get_pretty_function_description
2124
from hypothesis.strategies._internal.lazy import LazyStrategy, unwrap_strategies
22-
from hypothesis.strategies._internal.numbers import IntegersStrategy
25+
from hypothesis.strategies._internal.numbers import FloatStrategy, IntegersStrategy
2326
from hypothesis.strategies._internal.strategies import FilteredStrategy
2427

2528
from tests.common.utils import fails_with
@@ -87,20 +90,81 @@ def test_filter_rewriting(data, strategy, predicate, start, end):
8790

8891

8992
@pytest.mark.parametrize(
90-
"s",
93+
"strategy, predicate, min_value, max_value",
9194
[
92-
st.integers(1, 5).filter(partial(operator.lt, 6)),
93-
st.integers(1, 5).filter(partial(operator.eq, 3.5)),
94-
st.integers(1, 5).filter(partial(operator.eq, "can't compare to strings")),
95-
st.integers(1, 5).filter(partial(operator.ge, 0)),
96-
st.integers(1, 5).filter(partial(operator.lt, math.inf)),
97-
st.integers(1, 5).filter(partial(operator.gt, -math.inf)),
95+
# Floats with integer bounds
96+
(st.floats(1, 5), partial(operator.lt, 3), next_up(3.0), 5), # 3 < x
97+
(st.floats(1, 5), partial(operator.le, 3), 3, 5), # lambda x: 3 <= x
98+
(st.floats(1, 5), partial(operator.eq, 3), 3, 3), # lambda x: 3 == x
99+
(st.floats(1, 5), partial(operator.ge, 3), 1, 3), # lambda x: 3 >= x
100+
(st.floats(1, 5), partial(operator.gt, 3), 1, next_down(3.0)), # 3 > x
101+
# Floats with non-integer bounds
102+
(st.floats(1, 5), partial(operator.lt, 3.5), next_up(3.5), 5),
103+
(st.floats(1, 5), partial(operator.le, 3.5), 3.5, 5),
104+
(st.floats(1, 5), partial(operator.ge, 3.5), 1, 3.5),
105+
(st.floats(1, 5), partial(operator.gt, 3.5), 1, next_down(3.5)),
106+
(st.floats(1, 5), partial(operator.lt, -math.inf), 1, 5),
107+
(st.floats(1, 5), partial(operator.gt, math.inf), 1, 5),
108+
# Floats with only one bound
109+
(st.floats(min_value=1), partial(operator.lt, 3), next_up(3.0), math.inf),
110+
(st.floats(min_value=1), partial(operator.le, 3), 3, math.inf),
111+
(st.floats(max_value=5), partial(operator.ge, 3), -math.inf, 3),
112+
(st.floats(max_value=5), partial(operator.gt, 3), -math.inf, next_down(3.0)),
113+
# Unbounded floats
114+
(st.floats(), partial(operator.lt, 3), next_up(3.0), math.inf),
115+
(st.floats(), partial(operator.le, 3), 3, math.inf),
116+
(st.floats(), partial(operator.eq, 3), 3, 3),
117+
(st.floats(), partial(operator.ge, 3), -math.inf, 3),
118+
(st.floats(), partial(operator.gt, 3), -math.inf, next_down(3.0)),
119+
# Simple lambdas
120+
(st.floats(), lambda x: x < 3, -math.inf, next_down(3.0)),
121+
(st.floats(), lambda x: x <= 3, -math.inf, 3),
122+
(st.floats(), lambda x: x == 3, 3, 3),
123+
(st.floats(), lambda x: x >= 3, 3, math.inf),
124+
(st.floats(), lambda x: x > 3, next_up(3.0), math.inf),
125+
# Simple lambdas, reverse comparison
126+
(st.floats(), lambda x: 3 > x, -math.inf, next_down(3.0)),
127+
(st.floats(), lambda x: 3 >= x, -math.inf, 3),
128+
(st.floats(), lambda x: 3 == x, 3, 3),
129+
(st.floats(), lambda x: 3 <= x, 3, math.inf),
130+
(st.floats(), lambda x: 3 < x, next_up(3.0), math.inf),
131+
# More complicated lambdas
132+
(st.floats(), lambda x: 0 < x < 5, next_up(0.0), next_down(5.0)),
133+
(st.floats(), lambda x: 0 < x >= 1, 1, math.inf),
134+
(st.floats(), lambda x: 1 > x <= 0, -math.inf, 0),
135+
(st.floats(), lambda x: x > 0 and x > 0, next_up(0.0), math.inf),
136+
(st.floats(), lambda x: x < 1 and x < 1, -math.inf, next_down(1.0)),
137+
(st.floats(), lambda x: x > 1 and x > 0, next_up(1.0), math.inf),
138+
(st.floats(), lambda x: x < 1 and x < 2, -math.inf, next_down(1.0)),
98139
],
140+
ids=get_pretty_function_description,
99141
)
100-
@fails_with(Unsatisfiable)
101142
@given(data=st.data())
102-
def test_rewrite_unsatisfiable_filter(data, s):
103-
data.draw(s)
143+
def test_filter_rewriting_floats(data, strategy, predicate, min_value, max_value):
144+
s = strategy.filter(predicate)
145+
assert isinstance(s, LazyStrategy)
146+
assert isinstance(s.wrapped_strategy, FloatStrategy)
147+
assert s.wrapped_strategy.min_value == min_value
148+
assert s.wrapped_strategy.max_value == max_value
149+
value = data.draw(s)
150+
assert predicate(value)
151+
152+
153+
@pytest.mark.parametrize(
    "pred",
    [
        partial(operator.lt, 6),
        partial(operator.eq, Fraction(10, 3)),
        partial(operator.eq, "can't compare to strings"),
        partial(operator.ge, 0),
        partial(operator.lt, math.inf),
        partial(operator.gt, -math.inf),
    ],
)
@pytest.mark.parametrize("s", [st.integers(1, 5), st.floats(1, 5)])
@fails_with(Unsatisfiable)
def test_rewrite_unsatisfiable_filter(s, pred):
    """Each predicate excludes every value in [1, 5], so drawing must fail."""
    filtered = s.filter(pred)
    # Drawing an example is what triggers the Unsatisfiable error that the
    # fails_with decorator expects.
    filtered.example()
104168

105169

106170
@given(st.integers(0, 2).filter(partial(operator.ne, 1)))
@@ -115,14 +179,8 @@ def test_rewriting_does_not_compare_decimal_snan():
115179
s.example()
116180

117181

118-
@pytest.mark.parametrize(
119-
"strategy, lo, hi",
120-
[
121-
(st.integers(0, 1), -1, 2),
122-
],
123-
ids=repr,
124-
)
125-
def test_applying_noop_filter_returns_self(strategy, lo, hi):
182+
@pytest.mark.parametrize("strategy", [st.integers(0, 1), st.floats(0, 1)], ids=repr)
183+
def test_applying_noop_filter_returns_self(strategy):
126184
s = strategy.wrapped_strategy
127185
s2 = s.filter(partial(operator.le, -1)).filter(partial(operator.ge, 2))
128186
assert s is s2
@@ -135,6 +193,7 @@ def mod2(x):
135193
Y = 2**20
136194

137195

196+
@pytest.mark.parametrize("s", [st.integers(1, 5), st.floats(1, 5)])
138197
@given(
139198
data=st.data(),
140199
predicates=st.permutations(
@@ -149,9 +208,8 @@ def mod2(x):
149208
]
150209
),
151210
)
152-
def test_rewrite_filter_chains_with_some_unhandled(data, predicates):
211+
def test_rewrite_filter_chains_with_some_unhandled(data, predicates, s):
153212
# Set up our strategy
154-
s = st.integers(1, 5)
155213
for p in predicates:
156214
s = s.filter(p)
157215

@@ -163,7 +221,7 @@ def test_rewrite_filter_chains_with_some_unhandled(data, predicates):
163221
# No matter the order of the filters, we get the same resulting structure
164222
unwrapped = s.wrapped_strategy
165223
assert isinstance(unwrapped, FilteredStrategy)
166-
assert isinstance(unwrapped.filtered_strategy, IntegersStrategy)
224+
assert isinstance(unwrapped.filtered_strategy, (IntegersStrategy, FloatStrategy))
167225
for pred in unwrapped.flat_conditions:
168226
assert pred is mod2 or pred.__name__ == "<lambda>"
169227

@@ -246,3 +304,17 @@ def test_bumps_min_size_and_filters_for_content_str_methods(method):
246304
fs = s.filter(method)
247305
assert fs.filtered_strategy.min_size == 1
248306
assert fs.flat_conditions == (method,)
307+
308+
309+
@pytest.mark.parametrize(
    "op, attr, value, expected",
    [
        (operator.lt, "min_value", -float_info.min / 2, 0),
        (operator.lt, "min_value", float_info.min / 2, float_info.min),
        (operator.gt, "max_value", float_info.min / 2, 0),
        (operator.gt, "max_value", -float_info.min / 2, -float_info.min),
    ],
)
def test_filter_floats_can_skip_subnormals(op, attr, value, expected):
    """A subnormal bound on floats(allow_subnormal=False) is moved to `expected`."""
    # partial(op, value) evaluates ``op(value, x)``, i.e. ``value < x`` for
    # operator.lt and ``value > x`` for operator.gt.
    predicate = partial(op, value)
    rewritten = st.floats(allow_subnormal=False).filter(predicate)
    actual_bound = getattr(rewritten.wrapped_strategy, attr)
    assert actual_bound == expected

0 commit comments

Comments
 (0)