[JIT] Support torch.distributions.utils.broadcast_all() #10041

@fritzo

Description

Problem

The main blocker to using torch.distributions in the JIT is the broadcast_all() function, which is used in every distribution's __init__() method. Currently broadcast_all() calls torch._C._infer_size() under the hood, which triggers a JIT error:

RuntimeError: expected int at position 0, but got: Tensor

An expensive workaround is to replace the non-jittable torch._C._infer_size() invocation with the shape of the broadcasted sum, sum(values).shape, which is jittable. Roughly:

```diff
- broadcast_shape = torch.Size()
- for i in tensor_idxs:
-     broadcast_shape = torch._C._infer_size(values[i].shape, broadcast_shape)
+ broadcast_shape = sum(values).shape    # expensive workaround
```
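For reference, the broadcasting rule that torch._C._infer_size implements can be sketched in pure Python. This is an illustration of the semantics only, not PyTorch's actual implementation:

```python
def infer_size(shape_a, shape_b):
    """Return the broadcast of two shapes as a tuple.

    Shapes are aligned from the right; at each position the sizes must
    be equal, or one of them must be 1 (a missing dimension counts as 1).
    """
    result = []
    for i in range(max(len(shape_a), len(shape_b))):
        a = shape_a[-1 - i] if i < len(shape_a) else 1
        b = shape_b[-1 - i] if i < len(shape_b) else 1
        if a != b and a != 1 and b != 1:
            raise ValueError(
                f"shapes {shape_a} and {shape_b} are not broadcastable"
            )
        result.append(max(a, b))
    return tuple(reversed(result))
```

For example, infer_size((3, 1), (4,)) yields (3, 4), while infer_size((2,), (3,)) raises an error. The workaround above gets the same shape as a side effect of actually materializing sum(values), which is why it is expensive.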

Proposed solution

@soumith and @neerajprad suggested implementing a C++ version of broadcast_all(), which would have the added benefit of speeding up non-jitted code.
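Whatever language the eventual implementation is in, the shape-inference part of broadcast_all() amounts to folding a pairwise broadcast rule over all input shapes. A hedged Python sketch of that fold (function names here are illustrative, not an existing API):

```python
from functools import reduce


def _infer_size(shape_a, shape_b):
    # Pairwise broadcast: align from the right; sizes must match or be 1.
    result = []
    for i in range(max(len(shape_a), len(shape_b))):
        a = shape_a[-1 - i] if i < len(shape_a) else 1
        b = shape_b[-1 - i] if i < len(shape_b) else 1
        if a != b and a != 1 and b != 1:
            raise ValueError(
                f"shapes {shape_a} and {shape_b} are not broadcastable"
            )
        result.append(max(a, b))
    return tuple(reversed(result))


def broadcast_shape(*shapes):
    # Fold the pairwise rule over all shapes; scalars contribute ().
    return reduce(_infer_size, shapes, ())
```

A C++ version would perform this fold once over the raw size vectors and then expand each input to the result, avoiding both the per-pair Python round trips and the materialized sum of the workaround.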
