[numpy] add torch.argwhere #64257

Closed
kshitij12345 wants to merge 26 commits into pytorch:master from kshitij12345:develop/numpy/argwhere

Conversation

Collaborator

@kshitij12345 kshitij12345 commented Aug 31, 2021

Adds torch.argwhere as an alias to torch.nonzero

Currently, torch.nonzero already provides functionality equivalent to np.argwhere.

From NumPy docs,

np.argwhere(a) is almost the same as np.transpose(np.nonzero(a)), but produces a result of the correct shape for a 0D array.
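The NumPy relationship quoted above can be checked directly. The snippet below is a small NumPy-only illustration (the array values are arbitrary examples chosen for this sketch):

```python
import numpy as np

a = np.array([[1, 0, 3],
              [0, 5, 0]])

# np.argwhere returns one (row, col) index pair per nonzero element,
# which matches transposing the per-axis index arrays from np.nonzero.
pairs = np.argwhere(a)
equivalent = np.transpose(np.nonzero(a))
# pairs and equivalent are both [[0, 0], [0, 2], [1, 1]]

# For a 0-d input, argwhere still produces a correctly shaped 2-D result:
# one row per nonzero element, zero index columns.
scalar_result = np.argwhere(np.array(5))
# scalar_result.shape is (1, 0)
```

This 0-d corner case is the "correct shape for a 0D array" difference the NumPy docs call out.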

@facebook-github-bot facebook-github-bot added oncall: jit Add this issue/PR to JIT oncall triage queue cla signed labels Aug 31, 2021
Contributor

facebook-github-bot commented Aug 31, 2021

🔗 Helpful links

💊 CI failures summary and remediations

As of commit f4924e2 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.


codecov bot commented Sep 5, 2021

Codecov Report

Merging #64257 (10e7ee9) into master (feefc94) will increase coverage by 0.00%.
The diff coverage is 100.00%.

❗ Current head 10e7ee9 differs from pull request most recent head f4924e2. Consider uploading reports for the commit f4924e2 to get more accurate results

@@           Coverage Diff           @@
##           master   #64257   +/-   ##
=======================================
  Coverage   66.37%   66.38%           
=======================================
  Files         739      739           
  Lines       94299    94310   +11     
=======================================
+ Hits        62595    62608   +13     
+ Misses      31704    31702    -2     

@kshitij12345 kshitij12345 marked this pull request as ready for review September 6, 2021 11:08
@kshitij12345 kshitij12345 requested a review from ezyang as a code owner September 6, 2021 11:08
@kshitij12345 kshitij12345 removed the request for review from ezyang September 6, 2021 11:08
@kshitij12345
Collaborator Author

Gentle ping @mruberry @saketh-are

Contributor

@saketh-are saketh-are left a comment


Thanks for implementing this @kshitij12345. I left a minor comment regarding the documentation of argwhere's output format.

r"""
argwhere(input) -> LongTensor

Return the indices of array elements that are non-zero, grouped by element.
Contributor


I found the "grouped by element" language from numpy's documentation to be pretty informal and hard to understand. Looks like I'm also not the first one: https://stackoverflow.com/questions/52400354/how-to-understand-the-np-argwhere-function/52400485.

It seems that "grouped by element" is stated in numpy.argwhere's documentation in reference to numpy.nonzero's behavior (https://numpy.org/doc/stable/reference/generated/numpy.nonzero.html), which groups by dimension instead.

For our purposes it seems clearer to reuse/refer to the language from torch.nonzero's documentation of its as_tuple=False case: https://pytorch.org/docs/stable/generated/torch.nonzero.html.
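To make the distinction concrete, here is a small NumPy-only sketch of the two output formats under discussion (the array is an arbitrary example; per this PR, torch.nonzero with as_tuple=False produces the same per-element layout as np.argwhere):

```python
import numpy as np

a = np.array([[1, 0],
              [0, 7]])

# "Grouped by dimension" (np.nonzero, like torch.nonzero(as_tuple=True)):
# one 1-D index array per axis.
rows, cols = np.nonzero(a)      # rows = [0, 1], cols = [0, 1]

# "Grouped by element" (np.argwhere, like torch.nonzero(as_tuple=False)):
# a 2-D result with one row of coordinates per nonzero element.
coords = np.argwhere(a)         # [[0, 0], [1, 1]]
```

The "grouped by element" wording only makes sense as a contrast with the per-dimension layout, which is why reusing torch.nonzero's as_tuple=False description reads more clearly.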

Collaborator Author


That makes sense. Thanks!

@kshitij12345
Collaborator Author

@saketh-are should be ready for another review. Thanks :)

Contributor

@saketh-are saketh-are left a comment


Hey @kshitij12345, thanks for making those changes! I wanted to confirm one last thing with you before we move on to landing this. Given that this PR relies on torch.nonzero's correctness, would you be able to review the completeness of torch.nonzero's forward tests and make a small comment here about them? I also left a few minor comments inline.

Collaborator

@mruberry mruberry left a comment


This looks really good @kshitij12345, I just made a few small comments inline for your review. I'll let @saketh-are merge this when it's ready.

Looking forward, we should consider deprecating nonzero in favor of argwhere. nonzero is popular so this will likely count as a "difficult deprecation," and we should prepare an RFC including it and other difficult deprecations (like max and min) soon to gauge the community's interest in NumPy consistency vs. disruptions.


pytorch-probot bot commented Oct 6, 2021

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/kshitij12345/pytorch/blob/f4924e2b90cbe92f9609948728fe8a8fb374b177/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default

Workflows Labels (bold enabled) Status
Triggered Workflows
linux-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/xla ✅ triggered
linux-vulkan-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-dynamic ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile ✅ triggered
linux-xenial-py3.6-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers ✅ triggered
linux-xenial-py3.6-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx ✅ triggered
linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/win ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped
docker-builds ciflow/all 🚫 skipped
libtorch-linux-xenial-cuda10.2-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow 🚫 skipped
linux-xenial-py3-clang5-mobile-code-analysis ciflow/all, ciflow/linux, ciflow/mobile 🚫 skipped
parallelnative-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped
periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped

You can add a comment to the PR and tag @pytorchbot with the following commands:
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and trigger the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.

@facebook-github-bot
Contributor

@saketh-are has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@kshitij12345
Collaborator Author

@saketh-are gentle ping :)

@facebook-github-bot
Contributor

@saketh-are has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@saketh-are merged this pull request in 462f333.

@saketh-are
Contributor

Reopening due to f29e522

@saketh-are saketh-are reopened this Oct 22, 2021
Collaborator Author

kshitij12345 commented Oct 23, 2021

Failures look to be about cov operator 🤔

2021-10-21T22:26:27.4373280Z ======================================================================
2021-10-21T22:26:27.4374506Z FAIL [0.006s]: test_get_torch_func_signature_exhaustive_cov_cpu_float32 (__main__.TestOperatorSignaturesCPU)
2021-10-21T22:26:27.4376387Z ----------------------------------------------------------------------
2021-10-21T22:26:27.4377230Z Traceback (most recent call last):
2021-10-21T22:26:27.4378131Z   File "test_fx.py", line 3254, in test_get_torch_func_signature_exhaustive
2021-10-21T22:26:27.4379128Z     op(*bound_args.args, **bound_args.kwargs)
2021-10-21T22:26:27.4380675Z   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_methods_invocations.py", line 666, in __call__
2021-10-21T22:26:27.4381865Z     return self.op(*args, **kwargs)
2021-10-21T22:26:27.4382921Z RuntimeError: cov(): weights sum to zero, can't be normalized
2021-10-21T22:26:27.4383544Z 
2021-10-21T22:26:27.4384334Z During handling of the above exception, another exception occurred:
2021-10-21T22:26:27.4385023Z 
2021-10-21T22:26:27.4385591Z Traceback (most recent call last):
2021-10-21T22:26:27.4387054Z   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 371, in instantiated_test
2021-10-21T22:26:27.4388278Z     result = test(self, **param_kwargs)
2021-10-21T22:26:27.4389686Z   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 737, in test_wrapper
2021-10-21T22:26:27.4390838Z     return test(*args, **kwargs)
2021-10-21T22:26:27.4392182Z   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 884, in only_fn
2021-10-21T22:26:27.4393277Z     return fn(slf, *args, **kwargs)
2021-10-21T22:26:27.4394363Z   File "test_fx.py", line 3262, in test_get_torch_func_signature_exhaustive
2021-10-21T22:26:27.4395523Z     assert op.name in known_no_schema or "nn.functional" in op.name or "_masked." in op.name
2021-10-21T22:26:27.4396484Z AssertionError
2021-10-21T22:26:27.4396898Z 
2021-10-21T22:26:27.4397755Z ---------------------------------------------------------------------

@kshitij12345
Collaborator Author

The error was fixed in #67039

@saketh-are PTAL :)

@facebook-github-bot
Contributor

@saketh-are has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


Labels

cla signed Merged oncall: jit Add this issue/PR to JIT oncall triage queue open source
