Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

while loop fails with hybridization  #18575

@eric-haibin-lin

Description

@eric-haibin-lin
import mxnet as mx

class MyBlock(mx.gluon.HybridBlock):
    """Minimal HybridBlock that runs ``F.contrib.while_loop`` over ``loop_nds``.

    This is the reduced reproducer: the block is hybridized and called inside
    ``mx.autograd.record()``, which is where the reported MXNetError surfaces.
    ``free_nds`` is accepted to keep the call signature of the original
    reproducer but is not used by the loop.
    """

    def __init__(self):
        super().__init__()

    def hybrid_forward(self, F, free_nds, loop_nds):
        def keep_going(x):
            # Truthiness of x[0] decides whether another iteration runs.
            return x[0]

        def body(s):
            # s[0] serves both as this step's output and as the next loop state.
            return (s[0], s[0])

        step_outputs, _final_state = F.contrib.while_loop(
            cond=keep_going,
            func=body,
            loop_vars=loop_nds,
            max_iterations=2)
        return step_outputs

# Build the block, initialize its parameters, and switch to hybridized
# (symbolic) execution -- the mode in which the failure is reported.
net = MyBlock()
net.initialize()
net.hybridize()

free_nds = [mx.nd.ones((1,)) for _ in range(2)]
loop_nds = [mx.nd.ones((1,))]

# Give every input a gradient buffer so autograd records through it.
for arr in [*free_nds, *loop_nds]:
    arr.attach_grad()

# The forward pass under autograd.record() is the call that raises.
with mx.autograd.record():
    result = net(free_nds, loop_nds)

print(result)
mx.nd.waitall()

python3.7 test.py

Traceback (most recent call last):
  File "test.py", line 51, in <module>
    result = net(free_nds, loop_nds)
  File "/home/ec2-user/cached_executor/python/mxnet/gluon/block.py", line 1324, in __call__
    return super().__call__(x, *args)
  File "/home/ec2-user/cached_executor/python/mxnet/gluon/block.py", line 705, in __call__
    out = self.forward(*args)
  File "/home/ec2-user/cached_executor/python/mxnet/gluon/block.py", line 1369, in forward
    return self._call_cached_op(x, *args)
  File "/home/ec2-user/cached_executor/python/mxnet/gluon/block.py", line 1090, in _call_cached_op
    out = self._cached_op(*cargs)
  File "mxnet/cython/ndarray.pyx", line 177, in mxnet._cy3.ndarray.CachedOp.__call__
  File "mxnet/cython/./base.pyi", line 41, in mxnet._cy3.ndarray.CALL
mxnet.base.MXNetError: Traceback (most recent call last):
  File "../src/imperative/imperative.cc", line 217
MXNetError: Check failed: AGInfo::IsNone(*output): Assigning to NDArrays that are already in a computational graph will cause undefined behavior when evaluating gradients. Please call backward first to clear the graph or do this out side of a record section. Also note that you cannot use inplace operations like +=, *=, relu(x, out=x), y[idx]=x, etc inside a record section._cachedop

=======

original version

import mxnet as mx
from mxnet.base import _as_list

class MyBlock(mx.gluon.HybridBlock):
    """HybridBlock wrapping ``F.contrib.while_loop`` with free and loop inputs.

    ``free_nds`` are closed over by the loop condition/body; ``loop_nds`` are
    the loop-carried variables. The block returns a flat list: the per-step
    outputs (sliced to the first ``n_steps`` entries along axis 0, times 2)
    followed by the final loop variables (times 3).
    """

    def __init__(self):
        super().__init__()

    def hybrid_forward(self, F, free_nds, loop_nds):
        n_steps = 5
        max_iterations = 5

        def step(loop, free):
            # The unpacking doubles as an arity check: exactly one loop var
            # and exactly two free vars. The free vars are otherwise unused.
            (s, ), (_a, _b) = loop, free
            # Per the while_loop contract: (step_output, new_loop_vars).
            return (s, s)

        def cond(loop_vars, _):
            # Product of the boolean mask is truthy only when every element
            # of the first loop var is below 1e35.
            return (loop_vars[0] < 1e35).prod()

        # while_loop passes loop vars positionally; repack them as a tuple
        # and bind the free inputs via closure.
        outputs, final_loop_nds = F.contrib.while_loop(
            cond=lambda *_loop_vars: cond(_loop_vars, free_nds),
            func=lambda *_loop_vars: step(_loop_vars, free_nds),
            loop_vars=loop_nds,
            max_iterations=max_iterations,
        )

        outputs = _as_list(outputs)
        final_loop_nds = _as_list(final_loop_nds)

        if n_steps == 0:
            outputs = []
        else:
            # Step outputs are stacked along axis 0; keep the first n_steps.
            outputs = [x.slice_axis(axis=0, begin=0, end=n_steps) for x in outputs]
        loop_result_sym = [x * 2 for x in outputs] + [x * 3 for x in final_loop_nds]
        return loop_result_sym

# Instantiate, initialize and hybridize so the forward pass runs through
# the cached (symbolic) graph where the failure was observed.
net = MyBlock()
net.initialize()
net.hybridize()

# Shapes of the closed-over (free) inputs and of the loop-carried inputs.
free_var_shapes = [(1,), (1,)]
loop_var_shapes = [(1,)]

free_nds = list(map(mx.nd.ones, free_var_shapes))
loop_nds = list(map(mx.nd.ones, loop_var_shapes))

# Attach gradient buffers to every input before recording.
for arr in [*free_nds, *loop_nds]:
    arr.attach_grad()

with mx.autograd.record():
    result = net(free_nds, loop_nds)

print(result)
mx.nd.waitall()

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions