@@ -708,6 +708,7 @@ class UnaryCase(Case):
708708 cond : UnaryCheck
709709 check_result : UnaryResultCheck
710710 raw_case : Optional [str ] = field (default = None )
711+ is_complex : bool = field (default = False )
711712
712713
713714r_unary_case = re .compile ("If ``x_i`` is (.+), the result is (.+)" )
@@ -901,6 +902,7 @@ def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]:
901902 result_expr = result_expr ,
902903 check_result = check_result ,
903904 raw_case = case_str ,
905+ is_complex = True ,
904906 )
905907 cases .append (case )
906908 except ParseError as e :
@@ -1487,28 +1489,21 @@ def test_unary(func_name, func, case):
14871489 # single example test case without using hypothesis.
14881490 filterwarnings ('ignore' , category = NonInteractiveExampleWarning )
14891491
1490- # Determine if this is a complex case by checking the strategy
1491- # Try to generate an example to see if it's complex
1492- try :
1493- in_value = case .cond_from_dtype (xp . float64 ).example ()
1494- except Exception :
1495- # If float64 fails, try complex128
1496- try :
1497- in_value = case . cond_from_dtype ( xp . complex128 ). example ()
1498- except Exception :
1499- # Fallback to float64
1500- in_value = case . cond_from_dtype ( xp . float64 ). example ( )
1492+ # Use the is_complex flag to determine the appropriate dtype
1493+ if case . is_complex :
1494+ dtype = xp . complex128
1495+ in_value = case .cond_from_dtype (dtype ).example ()
1496+ else :
1497+ dtype = xp . float64
1498+ in_value = case . cond_from_dtype ( dtype ). example ()
1499+
1500+ # Create array and compute result based on dtype
1501+ x = xp . asarray ( in_value , dtype = dtype )
1502+ out = func ( x )
15011503
1502- # Determine appropriate dtype based on input value type
1503- if isinstance (in_value , complex ):
1504- dtype = xp .complex128
1505- x = xp .asarray (in_value , dtype = dtype )
1506- out = func (x )
1504+ if case .is_complex :
15071505 out_value = complex (out )
15081506 else :
1509- dtype = xp .float64
1510- x = xp .asarray (in_value , dtype = dtype )
1511- out = func (x )
15121507 out_value = float (out )
15131508
15141509 assert case .check_result (in_value , out_value ), (
0 commit comments