Skip to content

Commit d15dc6b

Browse files
committed
Extend test by a bit, fix docstring
1 parent d454379 commit d15dc6b

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

test/test_torch.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7490,9 +7490,12 @@ def test_sobolengine_fast_forward_scrambled(self):
74907490
self.test_sobolengine_fast_forward(scramble=True)
74917491

74927492
def test_sobolengine_default_dtype(self):
7493+
engine = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123456)
7494+
# Check that default dtype is correctly handled
7495+
self.assertEqual(engine.draw(n=5).dtype, torch.float32)
74937496
with set_default_dtype(torch.float64):
74947497
engine = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123456)
7495-
# Check that default dtype is correctly handled
7498+
# Check that default dtype is correctly handled (when set to float64)
74967499
self.assertEqual(engine.draw(n=5).dtype, torch.float64)
74977500
# Check that explicitly passed dtype is adhered to
74987501
self.assertEqual(engine.draw(n=5, dtype=torch.float32).dtype, torch.float32)

torch/quasirandom.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def draw_base2(self, m: int, out: Optional[torch.Tensor] = None,
119119
out (Tensor, optional): The output tensor
120120
dtype (:class:`torch.dtype`, optional): the desired data type of the
121121
returned tensor.
122-
Default: ``torch.float32``
122+
Default: ``None``
123123
"""
124124
n = 2 ** m
125125
total_n = self.num_generated + n

0 commit comments

Comments
 (0)