[RFC] Deferred compute in imperative interface to unify imperative and symbolic interface #16376
Description
A new deferred computation (DC) argument to the imperative MXNet APIs is
proposed. If enabled, memory allocation and computation are deferred as long as
possible. Users can export the computational graph recorded during deferred
computation, which enables hybridization support.
Arrays for which DC is enabled are called lazy. Other arrays are called
normal. Inplace operations on lazy arrays are unsupported.
Storage allocation and computation for lazy arrays is deferred until their
results are required by conversion to numpy or use as input to an operator
creating a normal array. Accessing attributes such as shape can also trigger
computation if the attribute can't be inferred.
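The laziness described above can be illustrated with a minimal pure-Python sketch. This is not the actual MXNet implementation; the class and function names here are hypothetical stand-ins for a DC-enabled array whose computation runs only when a concrete value is needed:

```python
class LazyArray:
    """Hypothetical sketch of a DC-enabled ("lazy") array.

    The operation is recorded as a thunk; storage allocation and
    computation happen only when the concrete value is required,
    e.g. on conversion (standing in for conversion to numpy)."""

    def __init__(self, thunk):
        self._thunk = thunk   # recorded, not yet executed
        self._value = None    # storage not yet allocated

    def _force(self):
        if self._value is None:
            self._value = self._thunk()  # compute on first demand
        return self._value

    def tolist(self):
        # Conversion to a concrete value triggers the deferred compute.
        return list(self._force())


def lazy_add(a, b):
    # Record the operation instead of executing it eagerly.
    return LazyArray(lambda: [x + y for x, y in zip(a, b)])


c = lazy_add([1, 2], [3, 4])   # nothing computed yet
print(c.tolist())              # triggers compute: [4, 6]
```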
Update: The proposed implementation in #17530 differs slightly from the API previously described in this RFC. Thus I deleted the API docs in this RFC. Please refer to the PR. For example, a global state is used to enable / disable deferred compute, instead of introducing a new invocation API MXImperativeDeferredInvokeEx.
FAQ
How about Autograd, NDArray.autograd_entry_ and AGInfo?
Autograd inside deferred computation (DC) mode can be supported.
Relation of Autograd and DC: While autograd’s RecordOp provides recording
functionality similar to deferred computation, the autograd graph is not
the same as the computational graph: NDArray::Detach() serves to detach a node
from the autograd graph by deleting NDArray.entry_, though the NodeEntry is
still required for reconstructing the computational history of how this NDArray
came to be.
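The distinction between the two graphs can be sketched in a few lines of hypothetical Python (the attribute names below are illustrative, not MXNet's actual fields): detaching drops only the autograd link, while the DC history survives.

```python
class Array:
    """Hypothetical sketch: an array keeps two separate graph links."""

    def __init__(self, autograd_entry=None, dc_entry=None):
        self.autograd_entry = autograd_entry  # link into the autograd graph
        self.dc_entry = dc_entry              # link into the DC computational graph

    def detach(self):
        # Analogue of NDArray::Detach(): drop only the autograd link.
        # The DC record of how the array was computed is preserved,
        # so the computational graph can still be exported.
        self.autograd_entry = None
        return self


a = Array(autograd_entry="mul_node", dc_entry="mul_node")
a.detach()
assert a.autograd_entry is None   # gradients no longer flow through a
assert a.dc_entry == "mul_node"   # DC history still available for export
```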
Are reqs like kInPlace supported?
No. For now only kWriteTo is supported in DC mode.
The plan is to replace inplace operations with kWriteTo operations, writing to
a new (lazy) array. The framework should be smart enough to decide when to reuse
memory and when not. It shouldn’t be required for users to specify that they
want an inplace operation.
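The planned rewrite of inplace updates into kWriteTo operations can be sketched as follows. All names are hypothetical; in particular, whether the framework later reuses the input's buffer is an internal optimization the user never requests:

```python
def write_to(op, *inputs):
    """Sketch of kWriteTo semantics: always produce a fresh output array."""
    return op(*inputs)


def add(a, b):
    return [x + y for x, y in zip(a, b)]


a = [1, 2, 3]
b = [10, 10, 10]

# What a user might write as an inplace update (a += b) is recorded
# under DC as a kWriteTo op producing a new lazy array instead:
a2 = write_to(add, a, b)

assert a == [1, 2, 3]         # original array untouched
assert a2 == [11, 12, 13]     # result written to a new array
```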
How is context attribute handled, specifically context changes?
Cross-device copy must be represented as an operator (CrossDeviceCopyOp), which
requires special handling in the graph executor.
How is incomplete shape information handled?
The shape property triggers computation if it is accessed and the shape can't be inferred completely.
Users can access static_shape if they want to avoid triggering computation.
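The shape vs. static_shape behavior can be sketched in pure Python (hypothetical names; -1 marks an unknown dimension, as in MXNet shape conventions):

```python
class LazyArray:
    """Hypothetical sketch of shape vs. static_shape on a lazy array."""

    def __init__(self, thunk, static_shape):
        self._thunk = thunk
        self._static_shape = static_shape  # may contain -1 (unknown dims)
        self._value = None

    @property
    def static_shape(self):
        # Never triggers computation; unknown dims stay -1.
        return self._static_shape

    @property
    def shape(self):
        if -1 not in self._static_shape:
            return self._static_shape  # fully inferred, no compute needed
        # Shape could not be inferred completely: compute to discover it.
        if self._value is None:
            self._value = self._thunk()
        return (len(self._value),)


arr = LazyArray(lambda: [0, 1, 2], static_shape=(-1,))
assert arr.static_shape == (-1,)   # no computation triggered
assert arr.shape == (3,)           # accessing shape triggers computation
```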
Python (Gluon)
Based on DC, hybridization in Gluon is simplified:
Instead of implementing def hybrid_forward(self, F, x, ...) in HybridBlock,
users can opt to implement def forward(self, x, ...) in HybridBlock.
Hybridization based on DC works by the HybridBlock performing the following
steps (if it is not called by a parent block being hybridized):
- keeping a reference to the input arrays and a reference to the parameter
arrays to pass them to MXNDArrayGetDeferredComputeSymbol;
- enabling deferred compute mode;
- running forward;
- exporting to symbol and creating a CachedOp; running the CachedOp.
An (internal) global context variable tracks whether hybridization is ongoing.
If it is set to False and a Block that is to be hybridized is called, the
variable is set to True and the Block goes through all 4 steps outlined above;
finally the variable is set back to False after the export-to-Symbol step is
finished.
Usage example
class Net(nn.HybridBlock):
    def forward(self, x, ...):
        ...

Hybridizing gluon.Blocks?
DC could be used to support hybridizing Block if all logic can be traced. A
separate effort may add logic to detect these cases and add hybridization
support based on DC. For now we rely on users to signify hybridization support
by subclassing HybridBlock.
Parameter Shape Inference
For HybridBlock making use of DC for hybridization, we require users to
implement HybridBlock.infer_shape to infer the parameters' shapes given the
inputs.
Currently, if HybridBlock.infer_shape is not implemented, backward shape
inference is used to infer the shape of parameters. However backward shape
inference is not supported in all cases (cf #14253,
#14983 (comment))
and relying on it for parameter shape inference is brittle. Thus for consistency
and simplicity we require infer_shape method implementation when using
hybridization based on DC.
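What such forward shape inference looks like can be sketched in plain Python (hypothetical class and method signature, not Gluon's exact API): given the input shape, the layer sets its parameter shape directly instead of relying on backward inference.

```python
class Dense:
    """Sketch of a layer whose parameter shape depends on the input shape."""

    def __init__(self, units):
        self.units = units
        self.weight_shape = None  # unknown until an input is seen

    def infer_shape(self, input_shape):
        # Forward shape inference from the input, instead of brittle
        # backward inference: weight is (units, in_features).
        n, in_features = input_shape
        self.weight_shape = (self.units, in_features)


layer = Dense(units=10)
layer.infer_shape((32, 64))           # batch of 32 samples, 64 features
assert layer.weight_shape == (10, 64)
```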