Conversation

@kwen2501 (Collaborator) commented Jan 25, 2025

Stack from ghstack (oldest at bottom):

This PR implements a small UI improvement over #133603.

It prepares an NCCL memory allocator in torch's C++ layer and pybinds it out, so that users can use it directly.

UI:

```
pool = torch.cuda.MemPool(backend.mem_allocator)
with torch.cuda.use_mem_pool(pool):
    tensor = torch.arange(1024 * 1024 * 2, device=device)
```

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot bot commented Jan 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/145675

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 37 Pending

As of commit d310c91 with merge base d79c6f4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot added labels on Jan 25, 2025: oncall: distributed (add this issue/PR to distributed oncall triage queue), release notes: distributed (c10d) (release notes category)
kwen2501 added a commit that referenced this pull request Jan 25, 2025
ghstack-source-id: 321d023
Pull Request resolved: #145675
```
# Use NCCL memory allocator
allocator = c10d.nccl_mem_allocator
pool = torch.cuda.MemPool(allocator.allocator())
```
Contributor commented:

should we make this generic? (for *ccl vendors)

@kwen2501 (Collaborator, Author) replied Jan 27, 2025:

Yeah, we should. Maybe later :) (More framework design needed)

@wconstab (Contributor) commented:

The API looks nice. I have two high-level questions about the direction:

  1. Do we have any input from other vendors who want to support similar functionality? It would be nice if the allocator creation and backend registration were members of the 'backend' object accessed through the process-group API, so that any backend can easily override the behaviors. cc @c-p-i-o @xw285cornell @sujoysaraswati

  2. It seems possible to either use NVLS automatically (via compile/inductor) or manually by calling this API. Is this API also useful for inductor binaries? (I don't see why not..) cc @yifuwang
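The first question asks for the allocator hook to hang off the backend object so any *ccl vendor can override it. A hedged sketch of what such a generic hook could look like, in pure Python; none of these classes (`Backend`, `NCCLBackend`, `XCCLBackend`, `make_pool`) are real torch APIs, they only illustrate the override pattern under discussion.

```python
# Hypothetical sketch of a vendor-generic `mem_allocator` hook on the
# backend object, as suggested in the review. Illustrative names only.
class Backend:
    """Base process-group backend: vendors override mem_allocator."""
    @property
    def mem_allocator(self):
        raise NotImplementedError(f"{type(self).__name__} has no allocator")

class NCCLBackend(Backend):
    @property
    def mem_allocator(self):
        # Real code would return the pybind-exported NCCL allocator.
        return lambda nbytes: bytearray(nbytes)

class XCCLBackend(Backend):
    @property
    def mem_allocator(self):
        # Another vendor plugs in its own registered allocator here.
        return lambda nbytes: bytearray(nbytes)

def make_pool(backend):
    # Mirrors `torch.cuda.MemPool(backend.mem_allocator)` from the PR UI.
    return backend.mem_allocator

alloc = make_pool(NCCLBackend())
print(len(alloc(256)))  # → 256
```

With this shape, callers never name a vendor directly; swapping `NCCLBackend` for `XCCLBackend` changes nothing at the call site.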

@syed-ahmed (Collaborator) left a comment:

LGTM!

This PR implements a small UI improvement over #133603.

It prepares an NCCL memory allocator in torch's C++ layer and pybinds it out, so that users can use it directly.

UI:
```
pool = torch.cuda.MemPool(dist.nccl_mem_allocator)
with torch.cuda.use_mem_pool(pool):
    tensor = torch.arange(1024 * 1024 * 2, device=device)
```

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Jan 28, 2025
ghstack-source-id: 76788cc
Pull Request resolved: #145675
@wconstab (Contributor) left a comment:

Thanks for the changes! This seems good to me. It might be enough just as it is. If we eventually want a smoother API for getting an allocator, I think we could build that strictly in the Python layer later on and use these APIs under the hood.
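The idea of building a smoother API strictly in the Python layer could look like the following sketch: a single convenience helper on top of the low-level bindings. `BACKEND_ALLOCATORS`, `SimplePool`, and `make_comm_pool` are hypothetical names invented here; the real low-level hook would be the pybind-exported C++ allocator.

```python
# Hypothetical Python-layer sugar of the kind suggested above: resolve a
# backend's allocator and hand back a ready pool in one call.

# Stand-in registry for the per-backend allocator bindings.
BACKEND_ALLOCATORS = {
    "nccl": lambda nbytes: bytearray(nbytes),
}

class SimplePool:
    """Minimal pool wrapper around a raw allocator callable."""
    def __init__(self, allocator):
        self._allocator = allocator

    def alloc(self, nbytes):
        return self._allocator(nbytes)

def make_comm_pool(backend_name="nccl"):
    """One-call convenience: resolve the backend allocator, wrap a pool."""
    try:
        allocator = BACKEND_ALLOCATORS[backend_name]
    except KeyError:
        raise ValueError(f"no registered allocator for {backend_name!r}")
    return SimplePool(allocator)

pool = make_comm_pool("nccl")
print(len(pool.alloc(64)))  # → 64
```

The point of the layering is that no new C++ surface is needed: the Python helper composes the already-exported bindings.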

@kwen2501 (Collaborator, Author):

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 28, 2025
@kwen2501 kwen2501 added the module: nccl Problems related to nccl support label Jan 28, 2025
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@pytorch-bot pytorch-bot bot temporarily deployed to upload-benchmark-results January 28, 2025 22:17 Inactive
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

@pytorch-bot pytorch-bot bot temporarily deployed to upload-benchmark-results January 29, 2025 19:45 Inactive
@ZainRizvi (Contributor) commented:

@pytorchbot revert -c ghfirst -m "Sorry but this still fails internally. See D68866823 for details"

@pytorchmergebot (Collaborator):

@pytorchbot successfully started a revert job.

@pytorchmergebot (Collaborator):

@kwen2501 your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Jan 30, 2025
This reverts commit 18a7a04.

Reverted #145675 on behalf of https://github.com/ZainRizvi due to: Sorry but this still fails internally. See D68866823 for details.
@pytorch-bot pytorch-bot bot temporarily deployed to upload-benchmark-results January 30, 2025 16:34 Inactive
kwen2501 added a commit that referenced this pull request Jan 30, 2025
ghstack-source-id: 9ca714d
Pull Request resolved: #145675
@kwen2501 (Collaborator, Author):

@ZainRizvi Argh, sorry, I forgot to add an #else. Fixed now and trying to reland again. Sorry.

@kwen2501 (Collaborator, Author):

@pytorchbot merge -f "Adding #else to fix internal compilation error"

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.


@pytorch-bot pytorch-bot bot temporarily deployed to upload-benchmark-results January 30, 2025 18:42 Inactive
@github-actions github-actions bot deleted the gh/kwen2501/119/head branch March 2, 2025 02:10

Labels

- ci-no-td: Do not run TD on this PR
- ciflow/trunk: Trigger trunk jobs on your pull request
- Merged
- module: nccl: Problems related to nccl support
- oncall: distributed: Add this issue/PR to distributed oncall triage queue
- release notes: distributed (c10d): release notes category
- Reverted


6 participants