Skip to content

Conversation

@tugsbayasgalan
Copy link
Contributor

@tugsbayasgalan tugsbayasgalan commented Oct 15, 2024

Stack from ghstack (oldest at bottom):

When we insert cojstants into unlifted graph, we need to detach them if they require grad BUT when we detach we need to preserve the original aliasing information.

Differential Revision: D64406859

When we insert cojstants into unlifted graph, we need to detach them if they require grad.

Differential Revision: [D64406859](https://our.internmc.facebook.com/intern/diff/D64406859/)

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Oct 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137997

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a9a5da7 with merge base f173623 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D64406859

tugsbayasgalan added a commit that referenced this pull request Oct 15, 2024
When we insert cojstants into unlifted graph, we need to detach them if they require grad.

Differential Revision: [D64406859](https://our.internmc.facebook.com/intern/diff/D64406859/)

ghstack-source-id: 248115756
Pull Request resolved: #137997
Copy link
Contributor

@avikchaudhuri avikchaudhuri left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test please

…ort"

When we insert cojstants into unlifted graph, we need to detach them if they require grad.

Differential Revision: [D64406859](https://our.internmc.facebook.com/intern/diff/D64406859/)

[ghstack-poisoned]
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D64406859

tugsbayasgalan added a commit that referenced this pull request Oct 15, 2024
Pull Request resolved: #137997

When we insert cojstants into unlifted graph, we need to detach them if they require grad.

Differential Revision: [D64406859](https://our.internmc.facebook.com/intern/diff/D64406859/)
ghstack-source-id: 248221115
…ort"

When we insert cojstants into unlifted graph, we need to detach them if they require grad.

Differential Revision: [D64406859](https://our.internmc.facebook.com/intern/diff/D64406859/)

[ghstack-poisoned]
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D64406859

Copy link
Contributor

@avikchaudhuri avikchaudhuri left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rebase and check CI before landing...

if isinstance(value, torch.Tensor):
if value.requires_grad:
warnings.warn(
f"A model attribute `{const_name}` requires gradient. "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no ending period

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sg

if value in original_tensor_to_detached_tensor:
value = original_tensor_to_detached_tensor[value]
else:
detached_value = value.detach()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe this code should be refactored into a util?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is bit hard to make it a util because this is only relevant in this specific location.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 16, 2024
…ort"


When we insert cojstants into unlifted graph, we need to detach them if they require grad BUT when we detach we need to preserve the original aliasing information.

Differential Revision: [D64406859](https://our.internmc.facebook.com/intern/diff/D64406859/)

[ghstack-poisoned]
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D64406859

…ort"


When we insert cojstants into unlifted graph, we need to detach them if they require grad BUT when we detach we need to preserve the original aliasing information.

Differential Revision: [D64406859](https://our.internmc.facebook.com/intern/diff/D64406859/)

[ghstack-poisoned]
tugsbayasgalan added a commit that referenced this pull request Oct 16, 2024
Pull Request resolved: #137997

When we insert cojstants into unlifted graph, we need to detach them if they require grad.
ghstack-source-id: b218468

Differential Revision: [D64406859](https://our.internmc.facebook.com/intern/diff/D64406859/)
@tugsbayasgalan
Copy link
Contributor Author

@tugsbayasgalan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions github-actions bot deleted the gh/tugsbayasgalan/256/head branch November 17, 2024 02:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants