🚀 The feature, motivation and pitch
Consumer: Optimizer._init_group
The main driving motivation for this API is Optimizer._init_group. It is a function that rearranges the input parameters into lists, which are then multiplexed into downstream foreach operators based on their dtype and device.
Note that in a typical training loop, apart from the first initialization, this step is a no-op: the inputs are already mapped to the desired operators. If something about the inputs changes (one of them suddenly no longer has a grad, a dtype changes, the order of the inputs changes, an input appears or disappears), we actually want to recompile the graph.
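For concreteness, here is a minimal sketch (not the actual `_init_group` implementation, and with a hypothetical helper name) of the kind of bucketing it performs:

```python
from collections import defaultdict

def group_params_by_device_and_dtype(params):
    # Hypothetical sketch: collect params that have grads into buckets
    # keyed by (device, dtype) so each bucket can feed a foreach op.
    buckets = defaultdict(list)
    for p in params:
        if p.grad is None:
            continue  # params without grads are skipped entirely
        buckets[(p.device, p.dtype)].append(p)
    return buckets
```

On every step after the first, the buckets come out identical, so re-running the grouping is pure overhead; but if a param's dtype, device, or grad-ness changes, the grouping (and the downstream compiled graphs) must change with it.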
Requirements
In particular, we need to guard on `_init_group` itself and not on any downstream compiled artifact, even though `_init_group` will not be captured by the graph (we will retrace the original code object when recompiling). This ensures that `_init_group` is re-executed whenever any of the guarded inputs change.
Concretely, we are:
- guarding a function execution that is not captured by the graph,
- asking the user to specify the guarded inputs,
- and, since we cannot inspect the contents of the marked function, requiring that the user uphold the contract: no other mutations, inputs, or side effects apart from the guarded ones (a possible decorator shape is sketched after this list).
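As a rough illustration of that contract, here is a hypothetical (not proposed verbatim) shape for the decorator: it leaves runtime behavior untouched and only records the guard spec for the compiler to pick up.

```python
import functools

def init_values_once(guards):
    # Hypothetical sketch: the decorator does not change runtime behavior;
    # it only records which inputs the compiler should guard on, so the
    # (untraced) function is re-executed whenever any of them changes.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)

        # Dynamo-side machinery would read this attribute when it sees
        # the call and install the corresponding guards.
        wrapper._init_values_once_guards = tuple(guards)
        return wrapper

    return decorator
```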
API
```python
@init_values_once(
    guards=["my_guarded_param"]
)
def _my_func(..., my_guarded_param=..., ...):
    ...
```

```python
@init_values_once(
    guards=["self.state", "self.param_groups", "my_guarded_param"]
)
def _init_group(self, ..., my_guarded_param=..., ...):
    ...
```

Implementation Details
- We can guard on arbitrary inputs like `self.state` and `self.param_groups` using `pytree.flatten`. (Actually, we may also need to guard on `DICT_KEYS`. I am unsure if other container types need similar guards, but I suppose in this case we need to duck-type inputs by their dtype and tensor metas.)
- We can use a similar mechanism to `DisableContext`, but with extra guard metadata? What's the advantage of this compared with just installing an attribute on the function itself?
- We use `inspect.signature` to map guard parameter names to args/kwargs (see the sketch after this list).
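A minimal sketch of the name-to-value mapping and flattening, with a hypothetical `resolve_guard_values` helper; the actual guard installation would live inside dynamo:

```python
import inspect

import torch.utils._pytree as pytree

def resolve_guard_values(fn, guards, args, kwargs):
    # Hypothetical helper: map guard specs like "my_guarded_param" or
    # "self.state" to the concrete objects passed at call time.
    bound = inspect.signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()

    values = []
    for spec in guards:
        name, *attrs = spec.split(".")  # "self.state" -> "self", ["state"]
        obj = bound.arguments[name]
        for attr in attrs:
            obj = getattr(obj, attr)
        values.append(obj)

    # Flatten containers so per-leaf guards can be installed (tensor-meta
    # guards for tensors, equality guards otherwise). Dict key sets would
    # additionally need a DICT_KEYS-style guard on the tree spec.
    leaves, treespec = pytree.tree_flatten(values)
    return leaves, treespec
```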
Questions
- We can't guard on the outputs of the function for the purposes of this API. However, we can (and should) guard on them as inputs to downstream graphs, as is natural to do: downstream graphs depending on them recompile if they change, and may not recompile if they do not.
- However, since we cannot map input dependencies to outputs, we have to guard the entire function output on the inputs, which may be stricter than required in some cases.
- The number of installed guards is quite high: for 1000 tensors, we need to check 1000 guards. Can we make this more efficient somehow? If we expect the guards to pass (the happy case), a hashing scheme over the guarded metadata could help (see the sketch after this list).
- Regardless, we have to read and check all of the tensor metas at runtime. If we could instead use something like nested tensor, such problems (and other perf problems to do with high cardinality, in both dynamo and inductor) might go away.
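A minimal sketch of the hashing idea, with hypothetical helper names: precompute a single fingerprint over all guarded tensor metas, compare only that on the fast path, and fall back to per-tensor guard evaluation on a mismatch. Note that this still reads all of the metas at runtime, per the last bullet.

```python
import torch

def meta_fingerprint(tensors):
    # Hypothetical fast-path guard: fold the metadata of every guarded
    # tensor into a single hash recorded at compile time.
    return hash(tuple(
        (t.device, t.dtype, tuple(t.shape), t.requires_grad, t.grad is not None)
        for t in tensors
    ))

class HashedMetaGuard:
    def __init__(self, tensors):
        self.expected = meta_fingerprint(tensors)

    def check(self, tensors):
        # One comparison in the happy case; a mismatch would fall back to
        # the usual per-tensor guard evaluation (and possible recompile).
        return meta_fingerprint(tensors) == self.expected

params = [torch.randn(4, 4, requires_grad=True) for _ in range(3)]
guard = HashedMetaGuard(params)
assert guard.check(params)  # happy path: a single hash comparison
```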
Outdated:
- Do we need to ensure that downstream is guarded on the return values of the `@init_values_once` function? It's unclear whether these guards will derive from guarding on the specified inputs, or whether we need to guard them explicitly.
  - I believe that we don't need to, but it is better if we do (less likely to screw up). Any of the mutated values, the return values, and the inputs which determine the return values can be guarded upon, to equal effect. Mutated values have a reflexive relationship with guards, but output-guard dependence on inputs is asymmetric.
Additional context
See discussion here: #110709 (comment)
cc:
@ezyang @voznesenskym @jansel from dynamo / core compiler
@mlazos @janeyx99 from optim
cc @vincentqb @jbschlosser @albanD @janeyx99 @crcrpar @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng