
Conversation

@Rakshit-gen
Contributor

Fixes #7733

When using lr_scaling_method='sqrt' with dynamic batching, the scale_lr function failed with a TypeError because torch.sqrt expects a Tensor but receives a Python float from the batch_size / base_batch_size division.

Changed torch.sqrt to math.sqrt, which handles Python floats correctly.

This fixes the issue where training would fail with:
TypeError: sqrt(): argument 'input' (position 1) must be Tensor, not float
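For reference, a minimal standalone sketch of the before/after behavior. The function body below is illustrative and assumes a simple linear/sqrt scaling rule; it is not the actual DeepSpeed scale_lr implementation:

```python
import math

def scale_lr(base_lr: float, batch_size: int, base_batch_size: int,
             lr_scaling_method: str = "linear") -> float:
    """Illustrative sketch of LR scaling for dynamic batching (not DeepSpeed's code).

    ratio is a plain Python float, so torch.sqrt(ratio) raises
    TypeError: sqrt(): argument 'input' (position 1) must be Tensor, not float.
    math.sqrt operates on floats directly.
    """
    ratio = batch_size / base_batch_size  # Python float, not a torch.Tensor
    if lr_scaling_method == "linear":
        return base_lr * ratio
    if lr_scaling_method == "sqrt":
        return base_lr * math.sqrt(ratio)  # was torch.sqrt(ratio) before the fix
    return base_lr
```

math.sqrt returns a plain float, so downstream code that multiplies the scaled value into the learning rate continues to work unchanged.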

@Rakshit-gen
Contributor Author

@tjruwase can you review this PR?

@sfc-gh-truwase
Collaborator

@Rakshit-gen this PR is already approved to merge. The current blockers are:

  1. Formatting error: follow this guide to fix it.
  2. Incomplete DCO requirement. See the image below and here for more details.
[image attachment]

Rakshit-gen force-pushed the fix-scale-lr-sqrt-bug branch from 54e90fa to da6ec60 on December 19, 2025 at 14:33
Rakshit-gen and others added 4 commits on December 19, 2025 at 20:25
…or sqrt method

When using lr_scaling_method='sqrt' with dynamic batching, the scale_lr
function was failing with TypeError because torch.sqrt expects a Tensor
but receives a Python float from batch_size/base_batch_size division.

Changed torch.sqrt to math.sqrt which correctly handles Python floats.

This fixes the issue where training would fail with:
TypeError: sqrt(): argument 'input' (position 1) must be Tensor, not float

Signed-off-by: Rakshit-gen <[email protected]>
…deepspeedai#7727)

_**What this PR does**_

- This PR fixes an occasional deadlock/hang when using DeepSpeed Async I/O (AIO) for NVMe swap-in/swap-out.
- The hang occurs inside aio_handle.wait(), where training can stall indefinitely.

_**Reproduction**_

[ds_config.json](https://github.com/user-attachments/files/24179010/ds_config.json)

[finetune_zero3.py](https://github.com/user-attachments/files/24179011/finetune_zero3.py)

Steps
1. Replace {NVME_PATH} in ds_config.json with a valid NVMe mount path on your cluster (a minimal sketch of such a config follows these steps).
2. Build/install DeepSpeed with AIO enabled: `DS_BUILD_AIO=1 pip install --no-build-isolation .`
3. Run: `CUDA_VISIBLE_DEVICES=0 deepspeed finetune_zero3.py`
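For readers who do not want to download the attachment, here is a minimal sketch of what such a ZeRO-3 NVMe-offload config could look like, written as a Python dict and dumped to ds_config.json. The keys are standard DeepSpeed options, but all values, paths, and batch sizes are illustrative assumptions, not the contents of the attached file:

```python
import json

# Hypothetical ZeRO-3 NVMe-offload config; values are illustrative only.
ds_config = {
    "train_batch_size": 8,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "nvme",
            "nvme_path": "/path/to/nvme",  # replace with your NVMe mount ({NVME_PATH})
            "pin_memory": True,
        },
        "offload_optimizer": {
            "device": "nvme",
            "nvme_path": "/path/to/nvme",
        },
    },
    # AIO settings drive the async I/O worker threads involved in the hang described below.
    "aio": {
        "block_size": 1048576,
        "queue_depth": 8,
        "single_submit": False,
        "overlap_events": True,
    },
}

with open("ds_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)
```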

_**Fix:**_
Release the Python GIL while aio_handle.wait() is blocking by adding a
pybind11 call guard (py::gil_scoped_release) to the wait() binding.

_**Why this is needed (root cause)**_
Two threads are involved:

- Python main thread: calls aio_handle.wait() and blocks until all async
I/O operations complete.
- AIO worker thread(s): perform the actual file I/O in the background.

In some cases, after an I/O operation completes, the worker thread
triggers cleanup of PyTorch tensors (e.g., decref / refcount updates for
Python-backed objects). That cleanup path may require acquiring the
Python GIL.

**Before this PR:**

- The Python main thread enters aio_handle.wait() while still holding
the GIL.
- wait() blocks, waiting for the worker thread(s) to finish.
- A worker thread completes an I/O op and reaches a cleanup path that
attempts to acquire the GIL.
- The worker thread cannot acquire the GIL because it is held by the
Python thread blocked in wait().
- Result: the Python thread is waiting for the worker, and the worker is
waiting for the GIL → deadlock (illustrated by the sketch below).
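The pybind11 change itself is a one-line call guard on the wait() binding. As a language-agnostic illustration of the deadlock structure described above, here is a small Python sketch in which a threading.Lock stands in for the GIL; all names and the structure are illustrative, not DeepSpeed code:

```python
import threading

gil = threading.Lock()       # stands in for the Python GIL
io_done = threading.Event()  # stands in for "all async I/O completed"

def aio_worker():
    # ... perform file I/O ...
    # Cleanup of Python-backed tensors needs the "GIL".
    with gil:                # blocks until the main thread releases the "GIL"
        pass                 # decref / refcount updates would happen here
    io_done.set()            # only now can wait() return

def wait(release_gil: bool = True):
    # Before the fix: the caller kept holding the "GIL" while blocking here,
    # so the worker could never acquire it and io_done was never set.
    if release_gil:
        gil.release()        # what py::gil_scoped_release does around wait()
    io_done.wait()           # block until the worker finishes
    if release_gil:
        gil.acquire()        # "GIL" is re-acquired when wait() returns

gil.acquire()                # main thread "holds the GIL" when calling wait()
threading.Thread(target=aio_worker).start()
wait(release_gil=True)       # with release_gil=False this would deadlock
gil.release()
print("I/O complete, no deadlock")
```

Calling wait(release_gil=False) reproduces the hang: the main thread blocks on io_done while the worker blocks on the lock, mirroring the pre-fix behavior.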

Signed-off-by: Rakshit-gen <[email protected]>
Rakshit-gen force-pushed the fix-scale-lr-sqrt-bug branch from 0681c5c to 994da47 on December 19, 2025 at 14:55
sfc-gh-truwase enabled auto-merge (squash) on December 19, 2025 at 15:05
sfc-gh-truwase merged commit d7d4eeb into deepspeedai:master on Dec 19, 2025
11 checks passed


Closes: [BUG] scale_lr fails for lr_scaling_method="sqrt" due to torch.sqrt on Python float