Skip to content

Commit 7293ac8

Browse files
authored
Merge pull request #7 from ev-br/copilot/fix-complex-special-case-parsing
Fix complex π expression parsing in special case tests
2 parents 80615aa + c0fcd55 commit 7293ac8

File tree

1 file changed

+45
-10
lines changed

1 file changed

+45
-10
lines changed

array_api_tests/test_special_cases.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -502,10 +502,13 @@ def parse_complex_value(value_str: str) -> complex:
502502
(nan+nanj)
503503
>>> parse_complex_value('0 + NaN j')
504504
nanj
505+
>>> parse_complex_value('+0 + πj/2')
506+
1.5707963267948966j
507+
>>> parse_complex_value('+infinity + 3πj/4')
508+
(inf+2.356194490192345j)
505509
506-
Handles both "0j" and "0 j" formats with optional spaces.
510+
Handles formats: "A + Bj", "A + B j", "A + πj/N", "A + Nπj/M"
507511
"""
508-
# Handle the format like "+0 + 0j" or "NaN + NaN j"
509512
m = r_complex_value.match(value_str)
510513
if m is None:
511514
raise ParseError(value_str)
@@ -517,7 +520,16 @@ def parse_complex_value(value_str: str) -> complex:
517520

518521
# Parse imaginary part with its sign
519522
imag_sign = m.group(3)
520-
imag_val_str = m.group(4)
523+
# Group 4 is πj form (e.g., "πj/2"), group 5 is plain form (e.g., "NaN")
524+
if m.group(4): # πj form
525+
imag_val_str_raw = m.group(4)
526+
# Remove 'j' to get coefficient: "πj/2" -> "π/2"
527+
imag_val_str = imag_val_str_raw.replace('j', '')
528+
else: # plain form
529+
imag_val_str_raw = m.group(5)
530+
# Strip trailing 'j' if present: "0j" -> "0"
531+
imag_val_str = imag_val_str_raw[:-1] if imag_val_str_raw.endswith('j') else imag_val_str_raw
532+
521533
imag_val = parse_value(imag_sign + imag_val_str)
522534

523535
return complex(real_val, imag_val)
@@ -583,16 +595,29 @@ def complex_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[complex]:
583595
return complex_cond, expr, complex_from_dtype
584596

585597

598+
def _check_component_with_tolerance(actual: float, expected: float, allow_any_sign: bool) -> bool:
599+
"""
600+
Helper to check if actual matches expected, with optional sign flexibility and tolerance.
601+
"""
602+
if allow_any_sign and not math.isnan(expected):
603+
return abs(actual) == abs(expected) or math.isclose(abs(actual), abs(expected), abs_tol=0.01)
604+
elif not math.isnan(expected):
605+
check_fn = make_strict_eq(expected) if expected == 0 or math.isinf(expected) else make_rough_eq(expected)
606+
return check_fn(actual)
607+
else:
608+
return math.isnan(actual)
609+
610+
586611
def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], str]:
587612
"""
588613
Parses a complex result string to return a checker and expression.
589614
590615
Handles cases like:
591616
- "``+0 + 0j``" - exact complex value
592-
- "``0 + NaN j`` (sign of the real component is unspecified)" - allow any sign for real
593-
- "``NaN + NaN j``" - both parts NaN
617+
- "``0 + NaN j`` (sign of the real component is unspecified)"
618+
- "``+0 + πj/2``" - with π expressions (uses approximate equality)
594619
"""
595-
# Check for unspecified sign note
620+
# Check for unspecified sign notes
596621
unspecified_real_sign = "sign of the real component is unspecified" in result_str
597622
unspecified_imag_sign = "sign of the imaginary component is unspecified" in result_str
598623

@@ -601,13 +626,22 @@ def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], st
601626
m = re.search(r"``([^`]+)``", result_str)
602627
if m:
603628
value_str = m.group(1)
629+
# Check if the value contains π expressions (for approximate comparison)
630+
has_pi = 'π' in value_str
631+
604632
try:
605633
expected = parse_complex_value(value_str)
606634
except ParseError:
607635
raise ParseError(result_str)
608636

609-
# Create checker based on whether signs are unspecified
610-
if unspecified_real_sign and not math.isnan(expected.real):
637+
# Create checker based on whether signs are unspecified and whether π is involved
638+
if has_pi:
639+
# Use approximate equality for both real and imaginary parts if they involve π
640+
def check_result(z: complex) -> bool:
641+
real_match = _check_component_with_tolerance(z.real, expected.real, unspecified_real_sign)
642+
imag_match = _check_component_with_tolerance(z.imag, expected.imag, unspecified_imag_sign)
643+
return real_match and imag_match
644+
elif unspecified_real_sign and not math.isnan(expected.real):
611645
# Allow any sign for real part
612646
def check_result(z: complex) -> bool:
613647
imag_check = make_strict_eq(expected.imag)
@@ -693,9 +727,10 @@ class UnaryCase(Case):
693727
r"For complex floating-point operands, let ``a = real\(x_i\)``, ``b = imag\(x_i\)``"
694728
)
695729
r_complex_case = re.compile(r"If ``a`` is (.+) and ``b`` is (.+), the result is (.+)")
696-
# Matches complex values like "+0 + 0j", "NaN + NaN j", "infinity + NaN j"
730+
# Matches complex values like "+0 + 0j", "NaN + NaN j", "infinity + NaN j", "πj/2", "3πj/4"
731+
# Two formats: 1) πj/N expressions where j is part of the coefficient, 2) plain values followed by j
697732
r_complex_value = re.compile(
698-
r"([+-]?)([^\s]+)\s*([+-])\s*([^\s]+)\s*j"
733+
r"([+-]?)([^\s]+)\s*([+-])\s*(?:(\d*πj(?:/\d+)?)|([^\s]+))\s*j?"
699734
)
700735

701736

0 commit comments

Comments
 (0)