
Conversation

@vishwakftw
Contributor

@vishwakftw vishwakftw commented Dec 15, 2018

Changelog:

  • Implement `triu` and `tril` for batches of 2D tensors.
  • Remove the TH/THC binding for `tril`.
  • Fix the CUDA implementation.
  • Update docstrings for `tril` and `triu`.
  • Remove mask-based `triu` and `tril` in cholesky forward and backward.
  • Remove batched `tril` in `torch.distributions.utils`.

Test plan:

  • Add tests for tril and triu for CPU and CUDA.

Fixes #15016, fixes #15226, and closes #14071.

Acknowledgements:

  • Thanks to @t-vi whose implementation I used as a reference.
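For reference, the batched semantics being implemented can be sketched in plain Python. This is a hypothetical reference implementation (the names `tril_2d`, `triu_2d`, and `batched_tril` are illustrative, not the actual ATen code): `tril`/`triu` are applied independently to each 2D slice of the batch.

```python
def tril_2d(mat, k=0):
    # Keep entries on and below the k-th diagonal; zero the rest.
    return [[v if c - r <= k else 0 for c, v in enumerate(row)]
            for r, row in enumerate(mat)]

def triu_2d(mat, k=0):
    # Keep entries on and above the k-th diagonal; zero the rest.
    return [[v if c - r >= k else 0 for c, v in enumerate(row)]
            for r, row in enumerate(mat)]

def batched_tril(batch, k=0):
    # Batched semantics: apply tril independently to every 2D slice.
    return [tril_2d(m, k) for m in batch]
```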

@vishwakftw
Contributor Author

I've tried to debug this error, but to no avail. Posting it below:

Dec 15 07:42:30 /var/lib/jenkins/workspace/aten/src/ATen/cuda/CUDAApplyUtils.cuh(331): error: no instance of overloaded function "at::native::BatchTensorTriOp<T, upper>::operator() [with T=uint8_t, upper=false]" matches the argument list
Dec 15 07:42:30             argument types are: (uint8_t, uint8_t)
Dec 15 07:42:30             object type is: const at::native::BatchTensorTriOp<uint8_t, false>
Dec 15 07:42:30           detected during:
Dec 15 07:42:30             instantiation of "void at::cuda::<unnamed>::ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, 0, Offset>::apply(at::cuda::detail::TensorInfo<scalar1, IndexType> &, at::cuda::detail::TensorInfo<scalar2, IndexType> &, const Op &, int, IndexType, Offset, Offset) [with Op=at::native::BatchTensorTriOp<uint8_t, false>, scalar1=uint8_t, scalar2=uint8_t, IndexType=unsigned int, ADims=1, BDims=1, Offset=const unsigned int]" 
Dec 15 07:42:30 (310): here
Dec 15 07:42:30             instantiation of "void at::cuda::<unnamed>::ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, remaining_steps, Offsets...>::apply(at::cuda::detail::TensorInfo<scalar1, IndexType> &, at::cuda::detail::TensorInfo<scalar2, IndexType> &, const Op &, int, IndexType, Offsets..., Offsets...) [with Op=at::native::BatchTensorTriOp<uint8_t, false>, scalar1=uint8_t, scalar2=uint8_t, IndexType=unsigned int, ADims=1, BDims=1, remaining_steps=1, Offsets=<>]" 
Dec 15 07:42:30 (369): here
Dec 15 07:42:30             instantiation of "void at::cuda::<unnamed>::kernelPointwiseApply2<Op,scalar1,scalar2,IndexType,ADims,BDims,step>(at::cuda::detail::TensorInfo<scalar1, IndexType>, at::cuda::detail::TensorInfo<scalar2, IndexType>, IndexType, Op) [with Op=at::native::BatchTensorTriOp<uint8_t, false>, scalar1=uint8_t, scalar2=uint8_t, IndexType=unsigned int, ADims=1, BDims=1, step=1]" 
Dec 15 07:42:30 (888): here
Dec 15 07:42:30             instantiation of "__nv_bool at::cuda::CUDA_tensor_apply2<scalar1,scalar2,step,Op>(at::Tensor, at::Tensor, Op, at::cuda::TensorArgType, at::cuda::TensorArgType) [with scalar1=uint8_t, scalar2=uint8_t, step=1, Op=at::native::BatchTensorTriOp<uint8_t, false>]" 
Dec 15 07:42:30 (940): here
Dec 15 07:42:30             instantiation of "__nv_bool at::cuda::CUDA_tensor_apply2<scalar1,scalar2,Op>(at::Tensor, at::Tensor, Op, at::cuda::TensorArgType, at::cuda::TensorArgType) [with scalar1=uint8_t, scalar2=uint8_t, Op=at::native::BatchTensorTriOp<uint8_t, false>]" 
Dec 15 07:42:30 /var/lib/jenkins/workspace/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu(430): here
Dec 15 07:42:30             instantiation of "void at::native::apply_triu_tril<scalar_t,inplace,upper>(at::Tensor &, const at::Tensor &, int64_t) [with scalar_t=uint8_t, inplace=true, upper=false]" 
Dec 15 07:42:30 /var/lib/jenkins/workspace/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu(437): here

@zou3519
Contributor

zou3519 commented Dec 17, 2018

@vishwakftw do you get the same error message on a local build?

@vishwakftw
Contributor Author

Yes, I do.

@vishwakftw
Contributor Author

For additional context about the design: the CUDA implementation mirrors the THC implementation, where a custom op (specifically TensorTriOp, defined in THC/THCTensorMathPairwise.cu) is instantiated and passed to THC_pointwiseApplyN.
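That functor-plus-pointwise-apply pattern can be sketched in Python (a hypothetical analogue; `TriOp` and `pointwise_apply` are illustrative stand-ins, not the real CUDA code): the op carries the diagonal offset and direction, and the apply routine invokes it at every index.

```python
class TriOp:
    # Hypothetical analogue of TensorTriOp: a small functor that, given
    # an element's (row, col) position, keeps or zeroes its value.
    def __init__(self, k, upper):
        self.k = k
        self.upper = upper

    def __call__(self, r, c, v):
        keep = (c - r >= self.k) if self.upper else (c - r <= self.k)
        return v if keep else 0

def pointwise_apply(mat, op):
    # Crude stand-in for THC_pointwiseApplyN: call op at every index.
    return [[op(r, c, v) for c, v in enumerate(row)]
            for r, row in enumerate(mat)]
```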

@vishwakftw vishwakftw changed the title [WIP] Batched upper triangular, lower triangular [ready for review] Batched upper triangular, lower triangular Dec 19, 2018
@vishwakftw
Contributor Author

vishwakftw commented Dec 21, 2018

@zou3519 is this good to go?

Also, should I remove the TH/THC implementations as well?

…timize CPU implementation

- The thrust implementation seemed to be incredibly slow
- The CPU implementation was bottlenecked by a clone() op
- Add test cases for non-square matrices, and an addition based test
@zou3519
Contributor

zou3519 commented Dec 26, 2018

@vishwakftw I'll take a look later today

@vishwakftw
Contributor Author

Thank you, much obliged.

@zou3519
Contributor

zou3519 commented Dec 26, 2018

No, thank you for the contribution :)

@zou3519
Contributor

zou3519 commented Dec 26, 2018

Distributions tests seem to be failing

Contributor

@zou3519 zou3519 left a comment


CPU code looks fine, still reading the CUDA kernel

- Add edge cases
- Remove redundant tests
- Fix for non-contiguous case
- Pop batch_tril in torch.distributions.utils
- Remove tril in TH only (triu cannot be removed from TH due to a requirement in THTensorLapack; neither tril nor triu can be removed from THC due to dependencies in THCTensorMathMagma.cu)
@vishwakftw
Contributor Author

CUDA tests are failing with this error: RuntimeError: cuda runtime error (9) : invalid configuration argument at /var/lib/jenkins/workspace/aten/src/THC/THCTensorMathCompareT.cuh:69

@vishwakftw
Contributor Author

@ngimel @zou3519 I have made changes as recommended. Could you please take a look?

@mrshenli mrshenli self-requested a review January 8, 2019 17:53
Collaborator

@ngimel ngimel left a comment


I did not check the CPU part. I requested small changes, but this is generally good now.

- Make the contiguous check weaker
- Grid dimension computation simplification
- Make (batch) contiguous checks more conservative
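The grid-dimension computation mentioned above can be sketched as follows. This is a hypothetical simplification in Python (`ceil_div` and `grid_dims` are illustrative names, not the actual kernel-launch code): elements within one matrix are covered by the grid's x-dimension, while the batch index is mapped to the y-dimension.

```python
def ceil_div(a, b):
    # Round-up integer division, the usual grid-size idiom.
    return (a + b - 1) // b

def grid_dims(elems_per_matrix, n_batches, block_size=256):
    # x-dimension covers elements within one matrix; the batch index is
    # mapped to the grid's y-dimension (whose hardware limit is 65535).
    return (ceil_div(elems_per_matrix, block_size), n_batches)
```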
Contributor

@mrshenli mrshenli left a comment


Thank you for taking care of this. Great work!

Contributor

@facebook-github-bot facebook-github-bot left a comment


@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Jan 10, 2019
Summary:
Changelog:

- Implements `triu` and `tril` for batches of 2D tensors.
- Remove TH/THC binding for `tril`
- Fix CUDA implementation
- Update docstrings for tril and triu.
- Remove mask-based `triu` and `tril` in cholesky forward and backward.
- Remove batched tril in torch.distributions.utils
Pull Request resolved: pytorch/pytorch#15257

Differential Revision: D13613888

Pulled By: mrshenli

fbshipit-source-id: 0949a05b9b8e974c1acfaf02a6284848ec5cc1c4
@zou3519 zou3519 added this to the 1.0.1 milestone Jan 10, 2019
@vishwakftw vishwakftw deleted the batched-tril-triu branch January 18, 2019 04:02
vishwakftw added a commit to vishwakftw/pytorch that referenced this pull request Jan 18, 2019
@soumith soumith added the cherry-picked This PR was cherry-picked onto a release branch from master label Jan 18, 2019
soumith pushed a commit that referenced this pull request Jan 18, 2019
soumith pushed a commit that referenced this pull request Jan 29, 2019
@yongheng1991

Hi,
Thanks very much for your contribution.
Have you considered the case where the batch size is larger than 65536? I got an error when I tried that. I think it might be because `magma_int_t` is used to define `batch_size` in

magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount");

and `magma_int_t` is a 32-bit integer by default.
Best

@vishwakftw
Contributor Author

Hi Yongheng, thanks for the message. Are you facing issues in triu / tril or in triangular_solve?

@yongheng1991

yongheng1991 commented May 2, 2019

Yes. I am using it in a batch-wise eigenvalue decomposition. There is an error when the batch size is larger than 65535.
Here is part of the code:

        auto tmp_gxu = at::triu(gx.transpose(1, 2), 1);
        gx = gx.triu_().add_(tmp_gxu);

And this is the CUDA error:

RuntimeError: CUDA error: invalid configuration argument (triu_tril_cuda_template at /opt/conda/conda-bld/pytorch_1556653099582/work/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu:709)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x7fb0c0dd7dc5 in /home/zhao/anaconda3/lib/python3.6/site-packages/torch/lib/libc10.so)
frame #1: at::Tensor& at::native::triu_tril_cuda_template(at::Tensor&, at::Tensor const&, long, char const*) + 0x2ce (0x7fb0c67bec0e in /home/zhao/anaconda3/lib/python3.6/site-packages/torch/lib/libcaffe2_gpu.so)
frame #2: at::native::triu_cuda_out(at::Tensor&, at::Tensor const&, long) + 0xc6 (0x7fb0c67b27a6 in /home/zhao/anaconda3/lib/python3.6/site-packages/torch/lib/libcaffe2_gpu.so)
frame #3: at::CUDAType::triu_out(at::Tensor&, at::Tensor const&, long) const + 0xd6 (0x7fb0c53fefa6 in /home/zhao/anaconda3/lib/python3.6/site-packages/torch/lib/libcaffe2_gpu.so)
frame #4: at::native::triu(at::Tensor const&, long) + 0x6a (0x7fb0c163d94a in /home/zhao/anaconda3/lib/python3.6/site-packages/torch/lib/libcaffe2.so)
frame #5: at::TypeDefault::triu(at::Tensor const&, long) const + 0x5d (0x7fb0c1a749cd in /home/zhao/anaconda3/lib/python3.6/site-packages/torch/lib/libcaffe2.so)
frame #6: torch::autograd::VariableType::triu(at::Tensor const&, long) const + 0x44f (0x7fb0be7a0bbf in /home/zhao/anaconda3/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
frame #7: batch_symeig_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, bool) + 0x75c (0x7fb0aa68052c in /home/zhao/anaconda3/lib/python3.6/site-packages/torch_autograd_solver-0.0.0-py3.6-linux-x86_64.egg/torch_autograd_solver_aten.cpython-36m-x86_64-linux-gnu.so)
frame #8: + 0xeea4 (0x7fb0aa685ea4 in /home/zhao/anaconda3/lib/python3.6/site-packages/torch_autograd_solver-0.0.0-py3.6-linux-x86_64.egg/torch_autograd_solver_aten.cpython-36m-x86_64-linux-gnu.so)
frame #9: + 0xf13e (0x7fb0aa68613e in /home/zhao/anaconda3/lib/python3.6/site-packages/torch_autograd_solver-0.0.0-py3.6-linux-x86_64.egg/torch_autograd_solver_aten.cpython-36m-x86_64-linux-gnu.so)
frame #10: + 0x13585 (0x7fb0aa68a585 in /home/zhao/anaconda3/lib/python3.6/site-packages/torch_autograd_solver-0.0.0-py3.6-linux-x86_64.egg/torch_autograd_solver_aten.cpython-36m-x86_64-linux-gnu.so)

frame #18: THPFunction_do_backward(THPFunction*, _object*) + 0xf6 (0x7fb0f005d8a6 in /home/zhao/anaconda3/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #22: torch::autograd::PyFunction::legacy_apply(std::vector<torch::autograd::Variable, std::allocator<torch::autograd::Variable> > const&) + 0xdf (0x7fb0f005dc4f in /home/zhao/anaconda3/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #23: torch::autograd::PyFunction::apply(std::vector<torch::autograd::Variable, std::allocator<torch::autograd::Variable> >&&) + 0x837 (0x7fb0f005fa57 in /home/zhao/anaconda3/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #24: + 0x307622 (0x7fb0be40e622 in /home/zhao/anaconda3/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
frame #25: torch::autograd::Engine::evaluate_function(torch::autograd::FunctionTask&) + 0x385 (0x7fb0be407745 in /home/zhao/anaconda3/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
frame #26: torch::autograd::Engine::thread_main(torch::autograd::GraphTask*) + 0xc0 (0x7fb0be409740 in /home/zhao/anaconda3/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
frame #27: torch::autograd::Engine::thread_init(int) + 0x2b0 (0x7fb0be4069e0 in /home/zhao/anaconda3/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
frame #28: torch::autograd::python::PythonEngine::thread_init(int) + 0x2a (0x7fb0f0059d1a in /home/zhao/anaconda3/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #29: + 0xb8678 (0x7fb0bf4d0678 in /home/zhao/anaconda3/lib/python3.6/site-packages/torch/lib/../../../../libstdc++.so.6)
frame #30: + 0x76ba (0x7fb0ff7786ba in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #31: clone + 0x6d (0x7fb0ff4ae41d in /lib/x86_64-linux-gnu/libc.so.6)

@vishwakftw
Contributor Author

I think I know the reason for this: the batches for triu and tril use the y-dimension of the CUDA grid, whose limit is 65535. I will try to fix this within the next week. For now, you could try mini-batching in a loop. Sorry about the inconvenience.
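The suggested mini-batching workaround amounts to splitting the batch into chunks below the grid-y limit and concatenating the results. A hypothetical sketch in plain Python (`chunked_apply` is an illustrative helper, not a PyTorch API):

```python
def chunked_apply(batch, fn, max_chunk=65535):
    # Process the batch in chunks no larger than max_chunk, so each
    # launch stays under the grid's y-dimension limit, then stitch the
    # per-chunk results back together.
    out = []
    for start in range(0, len(batch), max_chunk):
        out.extend(fn(batch[start:start + max_chunk]))
    return out
```

With torch, `fn` would be e.g. `lambda t: list(torch.triu(torch.stack(t)))` applied per chunk.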

@yongheng1991

Thanks very much. Glad to help

@vishwakftw
Contributor Author

@yongheng1991 Please take a look at #21067 which adds support for batch sizes > 65535.

Successfully merging this pull request may close these issues:

- torch.tril and torch.triu produce incorrect results with device='cuda'
- torch.tril does not support 0-sized dims
- Batched triu, tril

9 participants