
Conversation

@dylanbespalko (Contributor) commented Oct 9, 2019

In-tree changes to pytorch to support complex numbers are being submitted here.
Out-of-tree support for complex numbers is here: pytorch-cpu-strided-complex extension

Changes so far:

  • Renamed references to the variable "I" that could be confused with the "I" macro defined in complex.h. I did this to avoid confusing CI failure messages, since complex.h is now included by more source files.
    • aten/src/ATen/native/cpu/Loops.h (Renamed I to INDEX)
    • aten/src/ATen/native/cuda/Loops.cuh (Renamed I to INDEX)
    • aten/src/ATen/core/ivalue_inl.h (Renamed I to INDEX)
    • c10/util/Array.h (Renamed I to INDEX)
    • c10/util/C++17.h (Renamed I to INDEX)
    • c10/util/Metaprogramming.h (Renamed I to INDEX)
    • c10/util/SmallVector.h (custom renaming)
  • Added complex support of Linear Algebra Ops.
    • SVD needed to be modified to support mixed data types
    • Example: U(std::complex<double>), S(double), V(std::complex<double>); see the usage sketch after this list.
    • See before and after benchmark below (No observable change in performance).
  • Added complex support of Reduce Ops.
    • var/std computations could have been faster if it were possible to interpret a std::complex<double> Tensor as a double Tensor.
  • Added complex derivative support for autograd functionality.
    • Derivatives are the same as those defined by the NumPy-based autograd library for real(), imag(), conj(), and angle(). These functions only affect complex numbers.
    • The derivative of abs() has not been modified, so as not to interfere with existing code.
    • Autograd defines abs() for complex numbers and fabs() for real numbers. I will look into this further down the road.
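For illustration, here is a minimal usage sketch of the mixed-dtype SVD described above. This is not code from the PR: the function and argument names are hypothetical, and it assumes a complex CPU tensor created through the out-of-tree pytorch-cpu-strided-complex extension linked above.

#include <ATen/ATen.h>
#include <tuple>

// Hedged sketch: for a std::complex<double> input, U and V keep the complex
// dtype while the singular values S come back as plain doubles.
void svd_mixed_dtype_example(const at::Tensor& a /* complex<double>, M x N */) {
  at::Tensor u, s, v;
  std::tie(u, s, v) = at::svd(a);  // u: complex<double>, s: double, v: complex<double>
}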

PyTorch/Caffe2 Operator Micro-benchmarks Before Changes

Tag : short

Benchmarking PyTorch: svd
Mode: Eager
Name: svd_M512_N512
Input: M: 512, N: 512
Forward Execution Time (us) : 162339.425
Forward Execution Time (us) : 162517.479
Forward Execution Time (us) : 162847.775


PyTorch/Caffe2 Operator Micro-benchmarks After Changes

Tag : short

Benchmarking PyTorch: svd
Mode: Eager
Name: svd_M512_N512
Input: M: 512, N: 512
Forward Execution Time (us) : 162032.117
Forward Execution Time (us) : 161943.484
Forward Execution Time (us) : 162513.786

@pytorchbot added the module: cuda, module: docs, module: internals, and module: operators labels on Oct 9, 2019
@pytorchbot added the module: cpu label on Oct 10, 2019
@dylanbespalko (Contributor Author)

@ezyang,

I have added complex number support for linear algebra ops and reduce ops. There are a lot of CI failures after pulling changes from master; I am mainly seeing JIT failures in test_cpp and test_cpp_cuda. Can you verify that those errors are not related to my changes?

I don't mind waiting on this PR if you need to fix some other problems.

Dylan

@ezyang (Contributor) commented Oct 11, 2019

@pytorchbot rebase this please

@ezyang (Contributor) commented Oct 11, 2019

Windows failure is real:

17:09:37 C:\Program Files (x86)\Windows Kits\10\include\10.0.17763.0\ucrt\complex.h(91): error: more than one instance of overloaded function "norm" has "C" linkage
17:09:37 
17:09:37 C:\Program Files (x86)\Windows Kits\10\include\10.0.17763.0\ucrt\complex.h(116): error: more than one instance of overloaded function "normf" has "C" linkage
17:09:37 
17:09:37 C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/aten/src\ATen/native/cpu/zmath.h(15): error: namespace "std" has no member "complex"
17:09:37 
17:09:37 C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/aten/src\ATen/native/cpu/zmath.h(15): error: expected a ">"
17:09:37 
17:09:37 C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/aten/src\ATen/native/cpu/zmath.h(15): error: expected a ";"
17:09:37 
17:09:37 C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/aten/src\ATen/native/cpu/zmath.h(20): error: namespace "std" has no member "complex"
17:09:37 
17:09:37 C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/aten/src\ATen/native/cpu/zmath.h(20): error: expected a ">"
17:09:37 
17:09:37 C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/aten/src\ATen/native/cpu/zmath.h(20): error: expected a ";"
17:09:37 
17:09:37 C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/aten/src\ATen/native/cpu/zmath.h(30): error: namespace "std" has no member "complex"
17:09:37 
17:09:37 C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/aten/src\ATen/native/cpu/zmath.h(30): error: type name is not allowed
17:09:37 
17:09:37 C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/aten/src\ATen/native/cpu/zmath.h(30): error: function template "VALUE_TYPE at::native::<unnamed>::zabs<SCALAR_TYPE,VALUE_TYPE>(SCALAR_TYPE)" is not an entity that can be explicitly specialized
17:09:37 
17:09:37 C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/aten/src\ATen/native/cpu/zmath.h(30): error: expected a ";"
17:09:37 
17:09:37 12 errors detected in the compilation of "C:/Windows/TEMP/tmpxft_00000fb8_00000000-10_ReduceOpsKernel.cpp1.ii".
17:09:37 -- Removing C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/caffe2/CMakeFiles/torch.dir/__/aten/src/ATen/native/cuda/./torch_generated_ReduceOpsKernel.cu.obj
17:09:37 "C:/Program Files/CMake/bin/cmake.exe" -E remove C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/caffe2/CMakeFiles/torch.dir/__/aten/src/ATen/native/cuda/./torch_generated_ReduceOpsKernel.cu.obj
17:09:37 CMake Error at torch_generated_ReduceOpsKernel.cu.obj.Release.cmake:281 (message):
17:09:37   Error generating file
17:09:37   C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/caffe2/CMakeFiles/torch.dir/__/aten/src/ATen/native/cuda/./torch_generated_ReduceOpsKernel.cu.obj
17:09:37 

auto m = self.size(-2);
auto n = self.size(-1);
auto k = std::min(m, n);
int64_t lrwork = jobz == 'N' ? 5*k : k*std::max(5*k+7,2*std::max(m,n)+2*k+1);
Contributor

Where did this formula come from?

Contributor Author

Sorry for the confusion. I meant to take these settings from the LAPACK docs, but I think I copied the wrong setting while searching other sites for the reasoning behind it. The code has been updated with the setting for LAPACK 3.8.

The rwork setting in LAPACK is the minimum; increasing it by 2x, 10x, and 100x did not impact performance in a measurable way.

scalar_t wkopt;
lapackGeqrf<scalar_t>(m, n, self_data, m, tau_data, &wkopt, lwork, &info);
lwork = static_cast<int>(wkopt);
lwork = static_cast<int>(zabs<scalar_t, value_t>(wkopt));
Contributor

I thought about it a while and couldn't figure it out. Why is the zabs call here necessary?

Contributor Author

wkopt cannot be cast directly to the lwork data type without first converting it to a real number. wkopt is a positive real number stored in a std::complex variable, so zabs() and real_impl() should both do the job, but real_impl() is more performant.

I have updated the template utility functions in zmath.h to support:

  • complex return type: when performing tensor operations that maintain a fixed dtype.
  • real return type: when performing low-level C++ operations.
All of these functions are no-ops for other data types; a minimal sketch of the two flavours follows below.
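A minimal sketch of those two flavours (this mirrors the zabs/real_impl naming from zmath.h discussed in this thread, but is not the exact PR code; only the std::complex<double> specializations are shown):

#include <complex>

// Generic case: for real scalar types both helpers are no-ops.
template <typename SCALAR_TYPE, typename VALUE_TYPE = SCALAR_TYPE>
inline VALUE_TYPE zabs(SCALAR_TYPE z) {
  return z;
}

template <typename SCALAR_TYPE, typename VALUE_TYPE = SCALAR_TYPE>
inline VALUE_TYPE real_impl(SCALAR_TYPE z) {
  return z;
}

// Complex case: zabs returns the magnitude, real_impl just drops the
// imaginary part (cheaper when the value is known to be real, e.g. wkopt).
template <>
inline double zabs<std::complex<double>, double>(std::complex<double> z) {
  return std::abs(z);
}

template <>
inline double real_impl<std::complex<double>, double>(std::complex<double> z) {
  return z.real();
}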

@dylanbespalko (Contributor Author)

@ezyang,

I have fixed the Windows CI failure by renaming a couple of variables called norm to norm_; they conflicted with C declarations of a function called norm that does not belong to any namespace. All the CI looks good now.

} \
}()

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
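For context, a hedged usage sketch of the new dispatch macro, written as it would appear inside an ATen CPU kernel source file (this is not part of the diff above; the kernel name, the TensorIterator argument, and the two extra scalar types are placeholders following the existing AT_DISPATCH_* pattern):

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>

// Hypothetical element-wise add kernel: the macro dispatches on iter.dtype()
// over all standard types plus the complex types and the two extra scalar
// types passed in, binding scalar_t inside the lambda.
static void add_kernel_example(at::TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
      at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "add_cpu", [&] {
        at::native::cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
          return a + b;
        });
      });
}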
Contributor

@gchanan, can you remind me what the rules are for adding dispatch macros here?

@facebook-github-bot (Contributor) left a comment

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@dylanbespalko (Contributor Author)

@ezyang,

Are you able to merge this PR? Thanks.

Comment on lines 1113 to 1125
auto mn = std::min(m, n);
auto mx = std::max(m, n);
int64_t lrwork; // These settings are based on LAPACK 3.8.
if (jobz == 'N'){
lrwork = 5*mn;
}else if (mx > 10*mn){
lrwork = 5*mn*mn + 5*mn;
} else {
lrwork = std::max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn);
}
Tensor rwork = at::empty({std::max(int64_t(1), lrwork)}, at::kInt);
auto rwork_data = rwork.data_ptr<int>();
Tensor iwork = at::empty({8*mn}, at::kInt);
Contributor

Could you please include spaces before and after *?
Couple more comments:

  1. For LAPACK 3.6, lrwork needs to be at least 7 * mn. Could we do that so that users with older LAPACK versions are not at a disadvantage?
  2. Also, in the case mx > 10 * mn: how did you obtain 10?

Contributor Author

  1. I have made the changes and updated the comment to say it supports LAPACK 3.6 (see the sketch below).
  2. I had to make the assumption that mx >> mn implies mx > 10 * mn. I don't support torch.rand() with complex numbers, so I don't really have a good way to test this right now.

Once legacy::cpu::_th_random_ is ported from TH, I should be able to test this better. Also, rwork only impacts complex-number svd calls.
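For reference, a minimal sketch of one way to honour the LAPACK 3.6 minimum while keeping the sizes computed above (this is not the merged diff; it simply clamps lrwork against the 7 * mn floor mentioned in the review, using the variable names from the snippet above):

// Never let the real workspace drop below the 7 * mn minimum documented for
// LAPACK 3.6, whatever the newer formulas produced.
lrwork = std::max<int64_t>(lrwork, 7 * mn);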

Contributor

Just one more concern: if rwork is required for complex valued matrices / tensors, why allocate it for the general case?

Can't we special case as follows?

Tensor rwork;
int* rwork_data = nullptr;
if (at::isComplex(self)) {
  rwork = at::empty(.., at::kInt);
  rwork_data = rwork.data_ptr<int>();
}

Contributor Author

Good point! I will update the code.

Contributor Author

Ok, I've made the changes. Have a good weekend.

Contributor

Thank you @dylanbespalko, you have a good one too.

@ezyang (Contributor) commented Oct 17, 2019

Are you able to merge this PR? Thanks.

Sorry, it got stuck in land limbo. Do you want to handle Vishwak's comments?

@dylanbespalko (Contributor Author)

Are you able to merge this PR? Thanks.

Sorry, it got stuck in land limbo. Do you want to handle Vishwak's comments?

@ezyang,

Vishwak's concerns are valid; however, these values only impact the complex-number SVD functions, because rwork is not used for real numbers. If you can merge my changes, I can definitely address them in the next PR in a few days. That would help me out a lot.

Dylan.

@vishwakftw (Contributor)

@dylanbespalko, thanks for your work. Would it be possible for you to address the formatting changes? Also, this needs a rebase.

@dylanbespalko (Contributor Author)

@vishwakftw,

Yes, I can address these issues. I guess there are merge conflicts anyway. I'm busy with other work today, but I can have this fixed tomorrow.

@facebook-github-bot (Contributor) left a comment

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@dylanbespalko (Contributor Author)

@ezyang

Can you merge these changes?

@facebook-github-bot (Contributor) left a comment

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@dylanbespalko (Contributor Author)

@ezyang,

I pulled changes from master and re-ran the CI. I don't think these failures are from my code. The previous CI run was passing everything except pr/pytorch-win-ws2016-cuda9-cudnn7-py3, and that machine is now timing out after 2+ hours.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Oct 24, 2019
Pull Request resolved: pytorch/pytorch#27653

Differential Revision: D17907886

Pulled By: ezyang

fbshipit-source-id: a88b6d0427591ec1fba09e97c880f535c5d0e513
@facebook-github-bot (Contributor)

@ezyang merged this pull request in f8b758b.

thiagocrepaldi pushed a commit to thiagocrepaldi/pytorch that referenced this pull request Feb 4, 2020
auto mn = std::min(m, n);
Tensor iwork = at::empty({8*mn}, at::kInt);
auto iwork_data = iwork.data_ptr<int>();
Tensor rwork;
Contributor

@dylanbespalko, why is this required? Did you have tests to verify the SVD behavior for complex?

Contributor

I am working on adding svd for complex on CUDA; we may need to add something similar there as well, since MAGMA follows LAPACK.
