
torch.utils._pytree -> stable #65761

@zou3519

🚀 Feature

PyTorch should use a performant and stable version of _pytree.

Motivation

torch.utils._pytree (heavily inspired by JAX pytrees) is something I cooked up in a non-performant way when prototyping the initial non-composable version of torch.vmap. Since then, it's found many legitimate use cases in PyTorch internals:

The last use case means we want to expose torch.utils._pytree to users (or provide a stable mechanism that they can use)!
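For readers new to pytrees, the core idea fits in a few lines. A registry maps container types to a flatten function (node to children plus context) and an unflatten function (the inverse), and tree_map flattens a structure, applies a function to the leaves, and rebuilds it. This is a minimal, hypothetical sketch: `register_node`, the registry, and the spec format are illustrative names, not the actual torch.utils._pytree code or API. The registration hook stands in for the kind of stable, user-facing mechanism mentioned above.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical registry (not torch's API): container type -> (flatten_fn, unflatten_fn).
# flatten_fn(node) -> (children, context); unflatten_fn(children, context) -> node.
FlattenFn = Callable[[Any], Tuple[List[Any], Any]]
UnflattenFn = Callable[[List[Any], Any], Any]
_REGISTRY: Dict[type, Tuple[FlattenFn, UnflattenFn]] = {}


def register_node(typ: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn) -> None:
    """Teach the sketch how to take apart and rebuild a container type."""
    _REGISTRY[typ] = (flatten_fn, unflatten_fn)


register_node(list, lambda xs: (list(xs), None), lambda ch, _: list(ch))
register_node(tuple, lambda xs: (list(xs), None), lambda ch, _: tuple(ch))
register_node(dict, lambda d: (list(d.values()), list(d.keys())),
              lambda ch, keys: dict(zip(keys, ch)))


def tree_flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Return (leaves, spec). A spec of None marks a leaf."""
    entry = _REGISTRY.get(type(tree))  # exact type match: unregistered subclasses are leaves
    if entry is None:
        return [tree], None
    children, context = entry[0](tree)
    leaves: List[Any] = []
    child_specs = []
    for child in children:
        child_leaves, child_spec = tree_flatten(child)
        leaves.extend(child_leaves)
        # Remember how many leaves each child owns so unflatten can slice them back out.
        child_specs.append((child_spec, len(child_leaves)))
    return leaves, (type(tree), context, child_specs)


def tree_unflatten(leaves: List[Any], spec: Any) -> Any:
    """Inverse of tree_flatten: rebuild the structure described by spec."""
    if spec is None:
        return leaves[0]
    typ, context, child_specs = spec
    children, start = [], 0
    for child_spec, n in child_specs:
        children.append(tree_unflatten(leaves[start:start + n], child_spec))
        start += n
    return _REGISTRY[typ][1](children, context)


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply fn to every leaf while keeping the container structure."""
    leaves, spec = tree_flatten(tree)
    return tree_unflatten([fn(x) for x in leaves], spec)


@dataclass
class Point:  # hypothetical user-defined container
    x: Any
    y: Any


# A user registers their own type once; tree_map then recurses into it.
register_node(Point, lambda p: ([p.x, p.y], None), lambda ch, _: Point(*ch))
```

With this sketch, `tree_map(lambda v: v * 2, {"a": [1, 2], "b": (3,)})` returns `{"a": [2, 4], "b": (6,)}`, and `tree_map(lambda v: v + 1, [Point(1, 2), 3])` returns `[Point(x=2, y=3), 4]`. Every call rebuilds the spec and slices the leaf list recursively in pure Python, which is plausibly the kind of overhead the performance concerns below are about.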

There is really only one main blocker to exposing torch.utils._pytree: performance.

  • tree_map is slow; it can probably be made faster.
  • Some of the eager compilation work shows that torch.utils._pytree is slow, especially when the input is a flat list.
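For the flat-list case, one common mitigation is a fast path that skips spec construction entirely when the input is a plain list whose elements are all leaves, and falls back to the general flatten/map/unflatten path otherwise. This is a hypothetical sketch: `tree_map`, `_flatten`, `_unflatten`, and the tuple-based spec format are illustrative names, it only recognizes `list`, `tuple`, and `dict`, and it is not how torch.utils._pytree or JAX implement this.

```python
from typing import Any, Callable, Iterator, List

# Hypothetical: the only container types this sketch recurses into.
_CONTAINERS = (list, tuple, dict)


def _flatten(tree: Any, leaves: List[Any]) -> Any:
    """Append the leaves of `tree` to `leaves`; return a spec of its shape (None = leaf)."""
    t = type(tree)
    if t is list or t is tuple:
        return (t, [_flatten(x, leaves) for x in tree])
    if t is dict:
        return (dict, list(tree.keys()), [_flatten(v, leaves) for v in tree.values()])
    leaves.append(tree)
    return None


def _unflatten(spec: Any, it: Iterator[Any]) -> Any:
    """Rebuild a tree from `spec`, pulling leaves from `it` in order."""
    if spec is None:
        return next(it)
    if spec[0] is dict:
        _, keys, children = spec
        return dict(zip(keys, (_unflatten(c, it) for c in children)))
    t, children = spec
    return t(_unflatten(c, it) for c in children)


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    # Fast path: a plain list of leaves needs no spec at all, just one comprehension.
    if type(tree) is list and all(type(x) not in _CONTAINERS for x in tree):
        return [fn(x) for x in tree]
    # General path: flatten to (leaves, spec), map the leaves, rebuild.
    leaves: List[Any] = []
    spec = _flatten(tree, leaves)
    return _unflatten(spec, iter([fn(x) for x in leaves]))
```

Both paths return the same result, for example `[2, 3, 4]` for `tree_map(lambda x: x + 1, [1, 2, 3])`. The fast path only helps if the `all(...)` leaf check is cheaper than building and walking a spec, so in a real implementation that trade-off would need to be measured.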

Pitch

torch.utils._pytree is heavily inspired by JAX's pytrees. JAX's pytrees are implemented in C++ (with Python bindings) for performance reasons. We should either:

  • delete our pytree and take a dependency on JAX (or find out whether they'd be interested in splitting pytree off into its own library), or
  • maintain our own implementation of pytree and improve its performance.

cc @ezyang @gchanan @kadeng @msaroufim @XuehaiPan @zou3519 @albanD @Chillee

Metadata

Labels

  • better-engineering: Relatively self-contained tasks for better engineering contributors
  • feature: A request for a proper, new feature.
  • high priority
  • module: pytree
  • needs design: We want to add this feature but we need to figure out how first
  • triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module