-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[pytree] add another simplified pytree module torch.pytree
#148180
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: gh/XuehaiPan/249/base
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148180
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit cddc222 with merge base 3d7a8b7 ( UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Differences between `torch.pytree` and `torch.utils.pytree`:
1. APIs in `torch.utils.pytree` have a `tree_` prefix:
```python
leaves, treespec = torch.utils.pytree.tree_flatten(tree)
new_tree = torch.utils.pytree.tree_map(func, tree)
leaevs, treespec = torch.pytree.flatten(tree)
new_tree = torch.pytree.map(func, tree)
```
2. The argument order of `unflatten` is reversed for better `functools.partial` support:
```python
tree = torch.utils.pytree.tree_unflatten(leaves, treespec)
tree = torch.pytree.unflatten(treespec, leaves)
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
```
This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.
ghstack-source-id: 0bf5096
Pull Request resolved: #148180
| ] | ||
|
|
||
|
|
||
| def unflatten(treespec: PyTreeSpec, leaves: Iterable[_Any]) -> PyTree: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use a TypeVar to dispatch iterable typing to the stub. No reason to erase the typing info here in case we improve typing in the future.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PyTree is an alias of typing.Any which is not a generic type.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that might change in the futre though. Good to not erase typing if unnecessary and make it flexible to future refactors
vmoens
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The change is localized and well documented + motivated, I'm happy with it
Thanks @XuehaiPan
Differences between `torch.pytree` and `torch.utils.pytree`:
1. APIs in `torch.utils.pytree` have a `tree_` prefix:
```python
leaves, treespec = torch.utils.pytree.tree_flatten(tree)
new_tree = torch.utils.pytree.tree_map(func, tree)
leaevs, treespec = torch.pytree.flatten(tree)
new_tree = torch.pytree.map(func, tree)
```
2. The argument order of `unflatten` is reversed for better `functools.partial` support:
```python
tree = torch.utils.pytree.tree_unflatten(leaves, treespec)
tree = torch.pytree.unflatten(treespec, leaves)
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
```
This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.
ghstack-source-id: adff8bf
Pull Request resolved: pytorch#148180
Differences between `torch.pytree` and `torch.utils.pytree`:
1. APIs in `torch.utils.pytree` have a `tree_` prefix:
```python
leaves, treespec = torch.utils.pytree.tree_flatten(tree)
new_tree = torch.utils.pytree.tree_map(func, tree)
leaevs, treespec = torch.pytree.flatten(tree)
new_tree = torch.pytree.map(func, tree)
```
2. The argument order of `unflatten` is reversed for better `functools.partial` support:
```python
tree = torch.utils.pytree.tree_unflatten(leaves, treespec)
tree = torch.pytree.unflatten(treespec, leaves)
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
```
This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.
ghstack-source-id: 7590f64
Pull Request resolved: pytorch#148180
Differences between `torch.pytree` and `torch.utils.pytree`:
1. APIs in `torch.utils.pytree` have a `tree_` prefix:
```python
leaves, treespec = torch.utils.pytree.tree_flatten(tree)
new_tree = torch.utils.pytree.tree_map(func, tree)
leaevs, treespec = torch.pytree.flatten(tree)
new_tree = torch.pytree.map(func, tree)
```
2. The argument order of `unflatten` is reversed for better `functools.partial` support:
```python
tree = torch.utils.pytree.tree_unflatten(leaves, treespec)
tree = torch.pytree.unflatten(treespec, leaves)
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
```
This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.
ghstack-source-id: 736abee
Pull Request resolved: pytorch#148180
Differences between `torch.pytree` and `torch.utils.pytree`:
1. APIs in `torch.utils.pytree` have a `tree_` prefix:
```python
leaves, treespec = torch.utils.pytree.tree_flatten(tree)
new_tree = torch.utils.pytree.tree_map(func, tree)
leaevs, treespec = torch.pytree.flatten(tree)
new_tree = torch.pytree.map(func, tree)
```
2. The argument order of `unflatten` is reversed for better `functools.partial` support:
```python
tree = torch.utils.pytree.tree_unflatten(leaves, treespec)
tree = torch.pytree.unflatten(treespec, leaves)
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
```
This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.
ghstack-source-id: 6347364
Pull Request resolved: pytorch#148180
Differences between `torch.pytree` and `torch.utils.pytree`:
1. APIs in `torch.utils.pytree` have a `tree_` prefix:
```python
leaves, treespec = torch.utils.pytree.tree_flatten(tree)
new_tree = torch.utils.pytree.tree_map(func, tree)
leaevs, treespec = torch.pytree.flatten(tree)
new_tree = torch.pytree.map(func, tree)
```
2. The argument order of `unflatten` is reversed for better `functools.partial` support:
```python
tree = torch.utils.pytree.tree_unflatten(leaves, treespec)
tree = torch.pytree.unflatten(treespec, leaves)
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
```
This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.
ghstack-source-id: cec3d0f
Pull Request resolved: pytorch#148180
Differences between `torch.pytree` and `torch.utils.pytree`:
1. APIs in `torch.utils.pytree` have a `tree_` prefix:
```python
leaves, treespec = torch.utils.pytree.tree_flatten(tree)
new_tree = torch.utils.pytree.tree_map(func, tree)
leaevs, treespec = torch.pytree.flatten(tree)
new_tree = torch.pytree.map(func, tree)
```
2. The argument order of `unflatten` is reversed for better `functools.partial` support:
```python
tree = torch.utils.pytree.tree_unflatten(leaves, treespec)
tree = torch.pytree.unflatten(treespec, leaves)
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
```
This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.
ghstack-source-id: d4ebb29
Pull Request resolved: pytorch#148180
Differences between `torch.pytree` and `torch.utils.pytree`:
1. APIs in `torch.utils.pytree` have a `tree_` prefix:
```python
leaves, treespec = torch.utils.pytree.tree_flatten(tree)
new_tree = torch.utils.pytree.tree_map(func, tree)
leaevs, treespec = torch.pytree.flatten(tree)
new_tree = torch.pytree.map(func, tree)
```
2. The argument order of `unflatten` is reversed for better `functools.partial` support:
```python
tree = torch.utils.pytree.tree_unflatten(leaves, treespec)
tree = torch.pytree.unflatten(treespec, leaves)
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
```
This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.
ghstack-source-id: 7edb772
Pull Request resolved: pytorch#148180
Differences between `torch.pytree` and `torch.utils.pytree`:
1. APIs in `torch.utils.pytree` have a `tree_` prefix:
```python
leaves, treespec = torch.utils.pytree.tree_flatten(tree)
new_tree = torch.utils.pytree.tree_map(func, tree)
leaevs, treespec = torch.pytree.flatten(tree)
new_tree = torch.pytree.map(func, tree)
```
2. The argument order of `unflatten` is reversed for better `functools.partial` support:
```python
tree = torch.utils.pytree.tree_unflatten(leaves, treespec)
tree = torch.pytree.unflatten(treespec, leaves)
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
```
This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.
ghstack-source-id: 61b699f
Pull Request resolved: #148180
Differences between `torch.pytree` and `torch.utils.pytree`:
1. APIs in `torch.utils.pytree` have a `tree_` prefix:
```python
leaves, treespec = torch.utils.pytree.tree_flatten(tree)
new_tree = torch.utils.pytree.tree_map(func, tree)
leaevs, treespec = torch.pytree.flatten(tree)
new_tree = torch.pytree.map(func, tree)
```
2. The argument order of `unflatten` is reversed for better `functools.partial` support:
```python
tree = torch.utils.pytree.tree_unflatten(leaves, treespec)
tree = torch.pytree.unflatten(treespec, leaves)
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
```
This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.
ghstack-source-id: 354c1b6
Pull Request resolved: #148180
Differences between `torch.pytree` and `torch.utils.pytree`:
1. APIs in `torch.utils.pytree` have a `tree_` prefix:
```python
leaves, treespec = torch.utils.pytree.tree_flatten(tree)
new_tree = torch.utils.pytree.tree_map(func, tree)
leaevs, treespec = torch.pytree.flatten(tree)
new_tree = torch.pytree.map(func, tree)
```
2. The argument order of `unflatten` is reversed for better `functools.partial` support:
```python
tree = torch.utils.pytree.tree_unflatten(leaves, treespec)
tree = torch.pytree.unflatten(treespec, leaves)
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
```
This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.
ghstack-source-id: 3257605
Pull Request resolved: #148180
Differences between `torch.pytree` and `torch.utils.pytree`:
1. APIs in `torch.utils.pytree` have a `tree_` prefix:
```python
leaves, treespec = torch.utils.pytree.tree_flatten(tree)
new_tree = torch.utils.pytree.tree_map(func, tree)
leaevs, treespec = torch.pytree.flatten(tree)
new_tree = torch.pytree.map(func, tree)
```
2. The argument order of `unflatten` is reversed for better `functools.partial` support:
```python
tree = torch.utils.pytree.tree_unflatten(leaves, treespec)
tree = torch.pytree.unflatten(treespec, leaves)
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
```
This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.
ghstack-source-id: dcde1bc
Pull Request resolved: #148180
Differences between `torch.pytree` and `torch.utils.pytree`:
1. APIs in `torch.utils.pytree` have a `tree_` prefix:
```python
leaves, treespec = torch.utils.pytree.tree_flatten(tree)
new_tree = torch.utils.pytree.tree_map(func, tree)
leaevs, treespec = torch.pytree.flatten(tree)
new_tree = torch.pytree.map(func, tree)
```
2. The argument order of `unflatten` is reversed for better `functools.partial` support:
```python
tree = torch.utils.pytree.tree_unflatten(leaves, treespec)
tree = torch.pytree.unflatten(treespec, leaves)
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
```
This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.
ghstack-source-id: 8a607ad
Pull Request resolved: #148180
Differences between `torch.pytree` and `torch.utils.pytree`:
1. APIs in `torch.utils.pytree` have a `tree_` prefix:
```python
leaves, treespec = torch.utils.pytree.tree_flatten(tree)
new_tree = torch.utils.pytree.tree_map(func, tree)
leaevs, treespec = torch.pytree.flatten(tree)
new_tree = torch.pytree.map(func, tree)
```
2. The argument order of `unflatten` is reversed for better `functools.partial` support:
```python
tree = torch.utils.pytree.tree_unflatten(leaves, treespec)
tree = torch.pytree.unflatten(treespec, leaves)
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
```
This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.
ghstack-source-id: e06cbf8
Pull Request resolved: #148180
Differences between `torch.pytree` and `torch.utils.pytree`:
1. APIs in `torch.utils.pytree` have a `tree_` prefix:
```python
leaves, treespec = torch.utils.pytree.tree_flatten(tree)
new_tree = torch.utils.pytree.tree_map(func, tree)
leaevs, treespec = torch.pytree.flatten(tree)
new_tree = torch.pytree.map(func, tree)
```
2. The argument order of `unflatten` is reversed for better `functools.partial` support:
```python
tree = torch.utils.pytree.tree_unflatten(leaves, treespec)
tree = torch.pytree.unflatten(treespec, leaves)
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
```
This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.
ghstack-source-id: fc77164
Pull Request resolved: #148180
Differences between `torch.pytree` and `torch.utils.pytree`:
1. APIs in `torch.utils.pytree` have a `tree_` prefix:
```python
leaves, treespec = torch.utils.pytree.tree_flatten(tree)
new_tree = torch.utils.pytree.tree_map(func, tree)
leaevs, treespec = torch.pytree.flatten(tree)
new_tree = torch.pytree.map(func, tree)
```
2. The argument order of `unflatten` is reversed for better `functools.partial` support:
```python
tree = torch.utils.pytree.tree_unflatten(leaves, treespec)
tree = torch.pytree.unflatten(treespec, leaves)
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
```
This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.
ghstack-source-id: fc77164
Pull Request resolved: #148180
Differences between `torch.pytree` and `torch.utils.pytree`:
1. APIs in `torch.utils.pytree` have a `tree_` prefix:
```python
leaves, treespec = torch.utils.pytree.tree_flatten(tree)
new_tree = torch.utils.pytree.tree_map(func, tree)
leaevs, treespec = torch.pytree.flatten(tree)
new_tree = torch.pytree.map(func, tree)
```
2. The argument order of `unflatten` is reversed for better `functools.partial` support:
```python
tree = torch.utils.pytree.tree_unflatten(leaves, treespec)
tree = torch.pytree.unflatten(treespec, leaves)
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
```
This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.
ghstack-source-id: 6b2df26
Pull Request resolved: #148180
Stack from ghstack (oldest at bottom):
torch.pytree#148180torch.utils.pytree#137400tree_*functions accept both Python and C++PyTreeSpec#152624Differences between
torch.pytreeandtorch.utils.pytree:APIs in
torch.utils.pytreehave atree_prefix:This is similar to the JAX pytree API:
jax.tree_util.tree_*vs.jax.tree.*.The argument order of
unflattenis reversed for betterfunctools.partialsupport:This is also aligned with the JAX pytree API:
jax.tree.unflatten(treedef, leaves).Because we are adding a completely new module, there are no BC issues.
cc @zou3519