Skip to content

Fix GPSampler crash when torch default device is CUDA#6418

Merged
kAIto47802 merged 1 commit intooptuna:masterfrom
VedantMadane:fix-gpsampler-cuda-device
Jan 23, 2026
Merged

Fix GPSampler crash when torch default device is CUDA#6418
kAIto47802 merged 1 commit intooptuna:masterfrom
VedantMadane:fix-gpsampler-cuda-device

Conversation

@VedantMadane
Copy link
Contributor

Summary

Fix GPSampler crash when torch.set_default_device("cuda") is set globally.

When users set torch.set_default_device("cuda"), GPSampler would crash because torch tensors created internally would be placed on CUDA while numpy arrays remain on CPU, causing conversion errors like:

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Changes

  • Wrapped the sample_relative computation in a torch.device("cpu") context manager to force all tensor operations to use CPU
  • Extracted the main computation into _sample_relative_impl for cleaner separation
  • Added regression test that verifies GPSampler works with CUDA default device (skipped if CUDA not available)

Test plan

  • Added regression test test_gpsampler_with_cuda_default_device
  • Existing tests should pass

Fixes #6113

cc @kAIto47802 (who suggested this approach in the issue comments)

@not522
Copy link
Member

not522 commented Jan 19, 2026

@y0z @kAIto47802 Could you review this PR?

Copy link
Member

@y0z y0z left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@y0z y0z removed their assignment Jan 23, 2026
Copy link
Collaborator

@kAIto47802 kAIto47802 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR. LGTM!

@kAIto47802 kAIto47802 merged commit 599e87a into optuna:master Jan 23, 2026
14 checks passed
@y0z y0z added this to the v4.8.0 milestone Feb 10, 2026
@not522 not522 added the bug Issue/PR about behavior that is broken. Not for typos/examples/CI/test but for Optuna itself. label Feb 13, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Issue/PR about behavior that is broken. Not for typos/examples/CI/test but for Optuna itself.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

GPSampler crashes when torch default device is cuda

4 participants