
torch.assert operator / async assertion for cuda tensors #36853

@ppwwyyxx

Description

🚀 Feature

Launch a non-blocking "assert" op for cpu/cuda tensors.

Motivation

It's common to use assertions to check pre-conditions or catch unexpected inputs, as a form of defensive programming. For example, sometimes we want to assert that all elements of a tensor are positive, or that they are all finite.

However, doing assert in python has the following problems in pytorch:

  • when dealing with cuda tensors, using python's assert (t>0).all() or assert torch.isfinite(t).all() will wait for the result of the cuda kernel and most, if not all, of its preceding kernel launches, thus potentially causing a significant slowdown.
    It would be good to have a way to execute an assert as an async call.
  • it does not work nicely with tracing or FX, because it is treated as control flow rather than computation
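The synchronization cost in the first bullet can be illustrated with a pure-Python sketch (no torch here; a thread pool stands in for a CUDA stream, and `expensive_kernel` is a made-up placeholder for a kernel launch):

```python
import concurrent.futures

# Hypothetical stand-in for a CUDA stream: work is submitted and runs
# asynchronously; calling .result() forces a synchronization.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

def expensive_kernel(xs):
    # Placeholder for a GPU kernel launch.
    return [x * 2 for x in xs]

data = [1.0, 2.0, 3.0]
future = executor.submit(expensive_kernel, data)

# Blocking style: python's `assert` needs the concrete value, so the
# host thread waits for the "kernel" to finish before moving on.
assert all(v > 0 for v in future.result())

# Async style: enqueue the check itself onto the "stream"; the host
# thread returns immediately and the failure (if any) surfaces later,
# which is the behavior the proposed op would give for cuda tensors.
check = executor.submit(lambda f: all(v > 0 for v in f.result()), future)
# ... host keeps enqueueing more work here ...
assert check.result()  # only synchronize when we eventually need to
executor.shutdown()
```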

Pitch

x = some_tensor()
(x>0).all().assert("I got a bad data x")
y = torch.log(x)
torch.Assert(y.all(), "message")

(assert is a reserved keyword. TensorFlow uses tf.Assert. We could also use torch.assert_, since the op can be used in in-place style: y.all().assert_())
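A minimal pure-Python sketch of the proposed operator's semantics (the name `torch_assert` and the list stand-in for a tensor are illustrative only; the real op would enqueue a device-side check and return without synchronizing):

```python
def torch_assert(condition, message):
    """Raise AssertionError(message) if `condition` is falsy.

    In the proposal this would launch a non-blocking device-side check;
    this CPU sketch just evaluates eagerly to show the intended API shape.
    """
    if not condition:
        raise AssertionError(message)

x = [1.0, 2.0, 3.0]                          # stand-in for some_tensor()
torch_assert(all(v > 0 for v in x), "I got a bad data x")
```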

cc @ngimel

Metadata

Labels

  • enhancement: Not as big of a feature, but technically not a bug. Should be easy to fix
  • module: cuda: Related to torch.cuda, and CUDA support in general
  • triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
