You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Update on "[ONNX] Enable _jit_pass_onnx_fold_if only when dynamic_axes is None (#50582)"
Fixes pytorch/vision#3251 (PR #49410 triggers the torchvision test build failure on three tests: test_faster_rcnn, test_mask_rcnn, test_keypoint_rcnn).
The offending PR passes PyTorch unit tests because there is a gap between the torchvision and PyTorch tests when we merge them — we use different test APIs on the two sides, which causes some discrepancy.
This PR bridges the gap for the above three tests, and disables the _jit_pass_onnx_fold_if pass until it gets fixed.
Allow _jit_pass_onnx_fold_if only when dynamic_axes is None.
Differential Revision: [D26050886](https://our.internmc.facebook.com/intern/diff/D26050886)
[ghstack-poisoned]
Copy file name to clipboardExpand all lines: aten/src/ATen/native/BatchLinearAlgebra.cpp
+58-6Lines changed: 58 additions & 6 deletions
Original file line number
Diff line number
Diff line change
@@ -82,6 +82,22 @@ extern "C" void ssyevd_(char *jobz, char *uplo, int *n, float *a, int *lda, floa
82
82
// geev
83
83
extern "C" void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info);
84
84
extern "C" void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info);
85
+
extern "C" void cgeev_(char *jobvl, char *jobvr, int *n,
86
+
std::complex<float> *a, int *lda,
87
+
std::complex<float> *w,
88
+
std::complex<float> *vl, int *ldvl,
89
+
std::complex<float> *vr, int *ldvr,
90
+
std::complex<float> *work, int *lwork,
91
+
float *rwork,
92
+
int *info);
93
+
extern "C" void zgeev_(char *jobvl, char *jobvr, int *n,
94
+
std::complex<double> *a, int *lda,
95
+
std::complex<double> *w,
96
+
std::complex<double> *vl, int *ldvl,
97
+
std::complex<double> *vr, int *ldvr,
98
+
std::complex<double> *work, int *lwork,
99
+
double *rwork,
100
+
int *info);
85
101
86
102
// gesdd
87
103
extern "C" void zgesdd_(char *jobz, int *m, int *n, std::complex<double> *a, int *lda,
@@ -307,14 +323,44 @@ template<> void lapackSyevd<float>(char jobz, char uplo, int n, float *a, int ld
template<> void lapackEig<double>(char jobvl, char jobvr, int n, double *a, int lda, double *wr, double *wi, double* vl, int ldvl, double *vr, int ldvr, double *work, int lwork, int *info) {
326
+
template<> void lapackEig<double>(char jobvl, char jobvr, int n, double *a, int lda, double *w, double* vl, int ldvl, double *vr, int ldvr, double *work, int lwork, double *rwork, int *info) {
327
+
// lapack [sd]geev wants to separate output arrays: wr and wi for the real
template<> void lapackEig<float>(char jobvl, char jobvr, int n, float *a, int lda, float *wr, float *wi, float* vl, int ldvl, float *vr, int ldvr, float *work, int lwork, int *info) {
335
+
template<> void lapackEig<float>(char jobvl, char jobvr, int n, float *a, int lda, float *w, float* vl, int ldvl, float *vr, int ldvr, float *work, int lwork, float *rwork, int *info) {
336
+
// lapack [sd]geev wants to separate output arrays: wr and wi for the real
template<> void lapackEig<c10::complex<double>, double>(char jobvl, char jobvr, int n, c10::complex<double> *a, int lda, c10::complex<double> *w, c10::complex<double> *vl, int ldvl, c10::complex<double> *vr, int ldvr, c10::complex<double> *work, int lwork, double *rwork, int *info) {
template<> void lapackEig<c10::complex<float>, float>(char jobvl, char jobvr, int n, c10::complex<float> *a, int lda, c10::complex<float> *w, c10::complex<float> *vl, int ldvl, c10::complex<float> *vr, int ldvr, c10::complex<float> *work, int lwork, float *rwork, int *info) {
template<> void lapackSvd<c10::complex<double>, double>(char jobz, int m, int n, c10::complex<double> *a, int lda,
319
365
double *s, c10::complex<double> *u, int ldu, c10::complex<double> *vt, int ldvt, c10::complex<double> *work, int lwork, double *rwork, int *iwork, int *info) {
Copy file name to clipboardExpand all lines: aten/src/ATen/native/BatchLinearAlgebra.h
+2-2Lines changed: 2 additions & 2 deletions
Original file line number
Diff line number
Diff line change
@@ -14,8 +14,8 @@ namespace at { namespace native {
14
14
// Define per-batch functions to be used in the implementation of batched
15
15
// linear algebra operations
16
16
17
-
template<class scalar_t>
18
-
void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *wr, scalar_t *wi, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info);
17
+
template<class scalar_t, class value_t = scalar_t>
18
+
void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info);
19
19
20
20
template<class scalar_t>
21
21
void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);
// If not Hermitian use singular value decomposition, else use eigenvalue decomposition
149
149
if (!hermitian) {
150
-
// until https://github.com/pytorch/pytorch/issues/45821 is resolved
151
-
// svd() returns conjugated V for complex-valued input
152
-
Tensor U, S, V_conj;
150
+
Tensor U, S, V;
153
151
// TODO: replace input.svd with linalg_svd
154
-
std::tie(U, S, V_conj) = input.svd();
152
+
// using linalg_svd breaks pytorch/xla, see https://github.com/pytorch/xla/issues/2755
153
+
std::tie(U, S, V) = input.svd();
155
154
Tensor max_val = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1); // singular values are sorted in descending order
// TODO: replace V_conj.conj() -> V once https://github.com/pytorch/pytorch/issues/45821 is resolved
0 commit comments