Skip to content

Commit e48216b

Browse files
Copilotev-br
andcommitted
Refactor: use is_complex flag instead of nested try-except in test_unary
Co-authored-by: ev-br <[email protected]>
1 parent 7293ac8 commit e48216b

File tree

1 file changed

+14
-19
lines changed

1 file changed

+14
-19
lines changed

array_api_tests/test_special_cases.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,7 @@ class UnaryCase(Case):
708708
cond: UnaryCheck
709709
check_result: UnaryResultCheck
710710
raw_case: Optional[str] = field(default=None)
711+
is_complex: bool = field(default=False)
711712

712713

713714
r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)")
@@ -901,6 +902,7 @@ def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]:
901902
result_expr=result_expr,
902903
check_result=check_result,
903904
raw_case=case_str,
905+
is_complex=True,
904906
)
905907
cases.append(case)
906908
except ParseError as e:
@@ -1487,28 +1489,21 @@ def test_unary(func_name, func, case):
14871489
# single example test case without using hypothesis.
14881490
filterwarnings('ignore', category=NonInteractiveExampleWarning)
14891491

1490-
# Determine if this is a complex case by checking the strategy
1491-
# Try to generate an example to see if it's complex
1492-
try:
1493-
in_value = case.cond_from_dtype(xp.float64).example()
1494-
except Exception:
1495-
# If float64 fails, try complex128
1496-
try:
1497-
in_value = case.cond_from_dtype(xp.complex128).example()
1498-
except Exception:
1499-
# Fallback to float64
1500-
in_value = case.cond_from_dtype(xp.float64).example()
1492+
# Use the is_complex flag to determine the appropriate dtype
1493+
if case.is_complex:
1494+
dtype = xp.complex128
1495+
in_value = case.cond_from_dtype(dtype).example()
1496+
else:
1497+
dtype = xp.float64
1498+
in_value = case.cond_from_dtype(dtype).example()
1499+
1500+
# Create array and compute result based on dtype
1501+
x = xp.asarray(in_value, dtype=dtype)
1502+
out = func(x)
15011503

1502-
# Determine appropriate dtype based on input value type
1503-
if isinstance(in_value, complex):
1504-
dtype = xp.complex128
1505-
x = xp.asarray(in_value, dtype=dtype)
1506-
out = func(x)
1504+
if case.is_complex:
15071505
out_value = complex(out)
15081506
else:
1509-
dtype = xp.float64
1510-
x = xp.asarray(in_value, dtype=dtype)
1511-
out = func(x)
15121507
out_value = float(out)
15131508

15141509
assert case.check_result(in_value, out_value), (

0 commit comments

Comments
 (0)