Problem
The main blocker to using `torch.distributions` in the JIT is the `broadcast_all()` function, which is used in every distribution's `__init__()` method. Currently `broadcast_all()` calls `torch._C._infer_size()` under the hood, which results in a JIT error:

```
RuntimeError: expected int at position 0, but got: Tensor
```
An expensive workaround is to replace the non-jittable `torch._C._infer_size()` invocation with a `torch.sum(*values).size()`, which is jittable. Roughly:

```diff
- broadcast_shape = torch.Size()
- for i in tensor_idxs:
-     broadcast_shape = torch._C._infer_size(values[i].shape, broadcast_shape)
+ broadcast_shape = sum(values).shape  # expensive workaround
```

Proposed solution
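For illustration, the pairwise shape-inference rule that `torch._C._infer_size()` applies (and that the loop above folds over all inputs) can be sketched in pure Python. The helper names `infer_size` and `broadcast_shape` below are hypothetical, chosen for this sketch, and are not PyTorch APIs:

```python
def infer_size(a, b):
    """Return the NumPy-style broadcast of two shapes, or raise if incompatible."""
    result = []
    # Walk both shapes from the trailing dimension, padding the shorter one with 1s.
    for i in range(max(len(a), len(b))):
        da = a[-1 - i] if i < len(a) else 1
        db = b[-1 - i] if i < len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        result.append(max(da, db))
    return tuple(reversed(result))

def broadcast_shape(*shapes):
    """Fold infer_size over all shapes, mirroring the loop in broadcast_all()."""
    out = ()
    for s in shapes:
        out = infer_size(s, out)
    return out
```

The `sum(values).shape` trick computes the same result because adding the tensors materializes a tensor of exactly the broadcast shape, at the cost of actually performing the additions.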
@soumith and @neerajprad suggested implementing a C++ version of broadcast_all(), which would have the added benefit of speeding up non-jitted code.
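A minimal sketch of what such a C++ helper might look like, assuming ATen's `at::infer_size()` from `<ATen/ExpandUtils.h>`; the signature and structure here are hypothetical, not the actual implementation, and building it requires libtorch:

```cpp
#include <vector>
#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>

// Hypothetical sketch: fold the pairwise shape-inference rule over all
// inputs, then expand every tensor to the resulting common shape.
std::vector<at::Tensor> broadcast_all(at::TensorList values) {
  std::vector<int64_t> shape;
  for (const auto& t : values) {
    shape = at::infer_size(t.sizes(), shape);
  }
  std::vector<at::Tensor> out;
  out.reserve(values.size());
  for (const auto& t : values) {
    out.push_back(t.expand(shape));
  }
  return out;
}
```

Because `expand()` returns a view rather than copying data, this avoids the materialized additions of the `sum(values)` workaround.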