Skip to content

Commit df0e8a7

Browse files
committed
Registering operators in their files.
1 parent da80ce1 commit df0e8a7

File tree

4 files changed

+22
-7
lines changed

4 files changed

+22
-7
lines changed

torchvision/csrc/cpu/deform_conv2d_kernel.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@
6666
// modified from
6767
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp
6868

69+
#include <torch/script.h>
70+
6971
#include "deform_conv2d_kernel.h"
7072

7173
namespace {
@@ -1137,3 +1139,8 @@ deform_conv2d_backward_cpu(
11371139
return std::make_tuple(
11381140
grad_input, grad_weight, grad_offset, grad_mask, grad_bias);
11391141
}
1142+
1143+
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
1144+
m.impl("deform_conv2d", deform_conv2d_forward_cpu);
1145+
m.impl("_deform_conv2d_backward", deform_conv2d_backward_cpu);
1146+
}

torchvision/csrc/cuda/deform_conv2d_kernel.cu

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
#include <ATen/cuda/CUDAContext.h>
7171
#include <c10/cuda/CUDAGuard.h>
7272
#include <THC/THCAtomics.cuh>
73+
#include <torch/script.h>
7374

7475
#include "cuda_helpers.h"
7576
#include "deform_conv2d_kernel.h"
@@ -1188,3 +1189,8 @@ deform_conv2d_backward_cuda(
11881189
return std::make_tuple(
11891190
grad_input, grad_weight, grad_offset, grad_mask, grad_bias);
11901191
}
1192+
1193+
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
1194+
m.impl("deform_conv2d", deform_conv2d_forward_cuda);
1195+
m.impl("_deform_conv2d_backward", deform_conv2d_backward_cuda);
1196+
}

torchvision/csrc/deform_conv2d.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ at::Tensor deform_conv2d_autocast(
7474
use_mask)
7575
.to(input.scalar_type());
7676
}
77+
78+
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
79+
m.impl("deform_conv2d", deform_conv2d_autocast);
80+
}
7781
#endif
7882

7983
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -361,3 +365,8 @@ deform_conv2d_backward_autograd(
361365

362366
return std::make_tuple(result[0], result[1], result[2], result[3], result[4]);
363367
}
368+
369+
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
370+
m.impl("deform_conv2d", deform_conv2d_autograd);
371+
m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
372+
}

torchvision/csrc/vision.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,6 @@ TORCH_LIBRARY(torchvision, m) {
6262
}
6363

6464
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
65-
m.impl("deform_conv2d", deform_conv2d_forward_cpu);
66-
m.impl("_deform_conv2d_backward", deform_conv2d_backward_cpu);
6765
m.impl("nms", nms_cpu);
6866
m.impl("ps_roi_align", PSROIAlign_forward_cpu);
6967
m.impl("_ps_roi_align_backward", PSROIAlign_backward_cpu);
@@ -78,8 +76,6 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
7876
// TODO: Place this in a hypothetical separate torchvision_cuda library
7977
#if defined(WITH_CUDA) || defined(WITH_HIP)
8078
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
81-
m.impl("deform_conv2d", deform_conv2d_forward_cuda);
82-
m.impl("_deform_conv2d_backward", deform_conv2d_backward_cuda);
8379
m.impl("nms", nms_cuda);
8480
m.impl("ps_roi_align", PSROIAlign_forward_cuda);
8581
m.impl("_ps_roi_align_backward", PSROIAlign_backward_cuda);
@@ -95,7 +91,6 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
9591
// Autocast only needs to wrap forward pass ops.
9692
#if defined(WITH_CUDA) || defined(WITH_HIP)
9793
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
98-
m.impl("deform_conv2d", deform_conv2d_autocast);
9994
m.impl("nms", nms_autocast);
10095
m.impl("ps_roi_align", PSROIAlign_autocast);
10196
m.impl("ps_roi_pool", PSROIPool_autocast);
@@ -105,8 +100,6 @@ TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
105100
#endif
106101

107102
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
108-
m.impl("deform_conv2d", deform_conv2d_autograd);
109-
m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
110103
m.impl("ps_roi_align", PSROIAlign_autograd);
111104
m.impl("_ps_roi_align_backward", PSROIAlign_backward_autograd);
112105
m.impl("ps_roi_pool", PSROIPool_autograd);

0 commit comments

Comments
 (0)