This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 5b99b25

DickJC123 authored and piiswrong committed
1D conv with cudnn (#9184)
* 1D conv/deconv handling by cudnn, with tests.
* Fix python3 test issue.
* Fix lint issues.
* Fixed CI and doc.
1 parent 4aff838 commit 5b99b25

File tree

9 files changed: +529 / -476 lines changed

src/operator/nn/convolution-inl.h

Lines changed: 4 additions & 4 deletions
@@ -67,13 +67,13 @@ struct ConvolutionParam : public dmlc::Parameter<ConvolutionParam> {
   bool cudnn_off;
   dmlc::optional<int> layout;
   DMLC_DECLARE_PARAMETER(ConvolutionParam) {
-    DMLC_DECLARE_FIELD(kernel).describe("Convolution kernel size: (h, w) or (d, h, w)");
+    DMLC_DECLARE_FIELD(kernel).describe("Convolution kernel size: (w,), (h, w) or (d, h, w)");
     DMLC_DECLARE_FIELD(stride).set_default(TShape())
-    .describe("Convolution stride: (h, w) or (d, h, w). Defaults to 1 for each dimension.");
+    .describe("Convolution stride: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.");
     DMLC_DECLARE_FIELD(dilate).set_default(TShape())
-    .describe("Convolution dilate: (h, w) or (d, h, w). Defaults to 1 for each dimension.");
+    .describe("Convolution dilate: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.");
     DMLC_DECLARE_FIELD(pad).set_default(TShape())
-    .describe("Zero pad for convolution: (h, w) or (d, h, w). Defaults to no padding.");
+    .describe("Zero pad for convolution: (w,), (h, w) or (d, h, w). Defaults to no padding.");
     DMLC_DECLARE_FIELD(num_filter).set_range(1, 100000)
     .describe("Convolution filter(channel) number");
     DMLC_DECLARE_FIELD(num_group).set_default(1)

src/operator/nn/convolution.cu

Lines changed: 0 additions & 7 deletions
@@ -41,13 +41,6 @@ Operator* CreateOp<gpu>(ConvolutionParam param, int dtype,
                         std::vector<TShape> *out_shape,
                         Context ctx) {
   Operator *op = NULL;
-  // If 1D convolution, use MXNet implementation
-  if (param.kernel.ndim() == 1) {
-    MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-      op = new ConvolutionOp<gpu, DType>(param);
-    })
-    return op;
-  }

   // depth wise conv
   if (param.num_filter == param.num_group &&
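The fallback removed above is justified because a 1D convolution is just a 2D convolution over an "image" of height 1, which cuDNN already handles. As a sanity check of that equivalence (a standalone sketch, not MXNet code; conv1d and conv2d_h1 are hypothetical helpers):

#include <cassert>
#include <cstdio>
#include <vector>

// Naive "valid" 1D convolution (cross-correlation, as deep-learning frameworks define it).
std::vector<float> conv1d(const std::vector<float> &x, const std::vector<float> &k) {
  std::vector<float> y(x.size() - k.size() + 1, 0.f);
  for (size_t i = 0; i < y.size(); ++i)
    for (size_t j = 0; j < k.size(); ++j)
      y[i] += x[i + j] * k[j];
  return y;
}

// The same signal viewed as a 1 x W image, convolved with a (1, kw) kernel.
std::vector<float> conv2d_h1(const std::vector<float> &x, const std::vector<float> &k) {
  const size_t W = x.size(), KW = k.size(), OW = W - KW + 1;
  std::vector<float> y(OW, 0.f);     // output image is 1 x OW
  for (size_t ow = 0; ow < OW; ++ow)
    for (size_t kw = 0; kw < KW; ++kw)
      y[ow] += x[ow + kw] * k[kw];   // kh loop is trivial because KH == 1
  return y;
}

int main() {
  std::vector<float> x = {1, 2, 3, 4, 5, 6, 7};
  std::vector<float> k = {1, 0, -1};
  assert(conv1d(x, k) == conv2d_h1(x, k));   // identical outputs
  for (float v : conv1d(x, k)) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}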

src/operator/nn/cudnn/cudnn_convolution-inl.h

Lines changed: 122 additions & 151 deletions
Large diffs are not rendered by default.
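Since this diff is not rendered, here is a hedged sketch of the general approach (an assumption about how the 1D case can be mapped to cuDNN's 2D API, not the code in this file): data of shape (N, C, W) is described as an NCHW tensor with H = 1, a width-3 kernel becomes (1, 3), and pad/stride/dilation get a leading 0/1/1 for the dummy height. Assumes cuDNN >= 6 for the 9-argument 2D convolution descriptor; error checking omitted.

#include <cudnn.h>
#include <cstdio>

int main() {
  cudnnHandle_t handle;
  cudnnCreate(&handle);

  // Input: batch 2, 8 channels, length 31, expressed as NCHW with H = 1.
  cudnnTensorDescriptor_t x_desc;
  cudnnCreateTensorDescriptor(&x_desc);
  cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 2, 8, 1, 31);

  // 16 filters over 8 channels; the 1D kernel of width 3 becomes (1, 3).
  cudnnFilterDescriptor_t w_desc;
  cudnnCreateFilterDescriptor(&w_desc);
  cudnnSetFilter4dDescriptor(w_desc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 16, 8, 1, 3);

  // pad/stride/dilation promoted to 2D with a leading 0/1/1 for the dummy height.
  cudnnConvolutionDescriptor_t conv_desc;
  cudnnCreateConvolutionDescriptor(&conv_desc);
  cudnnSetConvolution2dDescriptor(conv_desc, /*pad_h=*/0, /*pad_w=*/1,
                                  /*stride_h=*/1, /*stride_w=*/1,
                                  /*dilation_h=*/1, /*dilation_w=*/1,
                                  CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT);

  int n, c, h, w;
  cudnnGetConvolution2dForwardOutputDim(conv_desc, x_desc, w_desc, &n, &c, &h, &w);
  std::printf("output: n=%d c=%d h=%d w=%d\n", n, c, h, w);  // expect 2 x 16 x 1 x 31

  cudnnDestroyConvolutionDescriptor(conv_desc);
  cudnnDestroyFilterDescriptor(w_desc);
  cudnnDestroyTensorDescriptor(x_desc);
  cudnnDestroy(handle);
  return 0;
}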

src/operator/nn/cudnn/cudnn_deconvolution-inl.h

Lines changed: 119 additions & 142 deletions
Large diffs are not rendered by default.

src/operator/nn/deconvolution-inl.h

Lines changed: 97 additions & 64 deletions
@@ -63,28 +63,28 @@ struct DeconvolutionParam : public dmlc::Parameter<DeconvolutionParam> {
   bool cudnn_off;
   dmlc::optional<int> layout;
   DMLC_DECLARE_PARAMETER(DeconvolutionParam) {
-    DMLC_DECLARE_FIELD(kernel).describe("Deconvolution kernel size: (h, w) or (d, h, w). "
+    DMLC_DECLARE_FIELD(kernel).describe("Deconvolution kernel size: (w,), (h, w) or (d, h, w). "
       "This is same as the kernel size used for the corresponding convolution");
     DMLC_DECLARE_FIELD(stride).set_default(TShape())
-      .describe("The stride used for the corresponding convolution: (h, w) or (d, h, w). "
+      .describe("The stride used for the corresponding convolution: (w,), (h, w) or (d, h, w). "
       "Defaults to 1 for each dimension.");
     DMLC_DECLARE_FIELD(dilate).set_default(TShape())
-      .describe("Dilation factor for each dimension of the input: (h, w) or (d, h, w). "
+      .describe("Dilation factor for each dimension of the input: (w,), (h, w) or (d, h, w). "
      "Defaults to 1 for each dimension.");
     DMLC_DECLARE_FIELD(pad).set_default(TShape())
      .describe("The amount of implicit zero padding added during convolution for each "
      "dimension of the input: "
-     "(h, w) or (d, h, w). "
+     "(w,), (h, w) or (d, h, w). "
      "``(kernel-1)/2`` is usually a good choice. "
      "If `target_shape` is set, "
      "`pad` will be ignored and a padding that will generate the target shape "
      "will be used. Defaults to no padding.");
     DMLC_DECLARE_FIELD(adj).set_default(TShape())
-      .describe("Adjustment for output shape: (h, w) or (d, h, w). "
+      .describe("Adjustment for output shape: (w,), (h, w) or (d, h, w). "
      "If `target_shape` is set, "
      "`adj` will be ignored and computed accordingly.");
     DMLC_DECLARE_FIELD(target_shape).set_default(TShape())
-      .describe("Shape of the output tensor: (h, w) or (d, h, w).");
+      .describe("Shape of the output tensor: (w,), (h, w) or (d, h, w).");
     DMLC_DECLARE_FIELD(num_filter).set_range(1, 100000)
       .describe("Number of output filters.");
     DMLC_DECLARE_FIELD(num_group).set_default(1)

@@ -211,27 +211,38 @@ class DeconvolutionOp : public Operator {
     using namespace mshadow;
     using namespace mshadow::expr;

-    if (param_.kernel.ndim() != 2) {
-      LOG(FATAL) << "If not using CUDNN only 2D-Deconvolution is supported";
+    if (param_.kernel.ndim() > 2) {
+      LOG(FATAL) << "If not using CUDNN, only 1D or 2D Deconvolution is supported";
     }

     CHECK_EQ(req[deconv::kOut], kWriteTo);
     size_t expected = param_.no_bias ? 2 : 3;
     CHECK_EQ(in_data.size(), expected);
     CHECK_EQ(out_data.size(), 1U);
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4, DType> data = in_data[deconv::kData].get<xpu, 4, DType>(s);
-    Tensor<xpu, 4, DType> out = out_data[deconv::kOut].get<xpu, 4, DType>(s);
-
+    auto in_data_shape = in_data[deconv::kData].shape_;
+    Tensor<xpu, 4, DType> data = TBlobTo4DTensor(in_data[deconv::kData], s);
+    Tensor<xpu, 4, DType> out = TBlobTo4DTensor(out_data[deconv::kOut], s);
     index_t o_pad[2], o_adj[2];
-    TShape dshape = {static_cast<nnvm::dim_t>(data.size(2)),
-                     static_cast<nnvm::dim_t>(data.size(3))};
-    param_.InferPad(dshape, o_pad, o_adj);
+    if (param_.kernel.ndim() == 2) {
+      param_.InferPad(TShape({in_data_shape[2], in_data_shape[3]}), o_pad, o_adj);
+    } else {
+      index_t o_pad_1D[1], o_adj_1D[1];
+      param_.InferPad({in_data_shape[2]}, o_pad_1D, o_adj_1D);
+      o_pad[0] = 0;
+      o_pad[1] = o_pad_1D[0];
+      o_adj[0] = 0;
+      o_adj[1] = o_adj_1D[0];
+    }
+    auto stride = param_.kernel.ndim() == 2 ? param_.stride : TShape({1, param_.stride[0]});
+    auto dilate = param_.kernel.ndim() == 2 ? param_.dilate : TShape({1, param_.dilate[0]});
+    auto kernel = param_.kernel.ndim() == 2 ? param_.kernel : TShape({1, param_.kernel[0]});
+    auto kernel_size = kernel.Size();

     Shape<3> wmat_shape =
         Shape3(param_.num_group,
                data.shape_[1] / param_.num_group,
-               param_.num_filter / param_.num_group * param_.kernel[0] * param_.kernel[1]);
+               param_.num_filter / param_.num_group * kernel_size);
     Tensor<xpu, 3, DType> wmat =
         in_data[deconv::kWeight].get_with_shape<xpu, 3, DType>(wmat_shape, s);
 #if defined(__CUDACC__)

@@ -256,21 +267,21 @@ class DeconvolutionOp : public Operator {
       temp_dst = reshape(swapaxis<1, 0>(data.Slice(i, i + step)), temp_dst.shape_);
       if (o_pad[0] == 0 && o_pad[1] == 0) {
         temp_col = unpack_patch2col(out.Slice(i, i + step),
-                                    param_.kernel[0],
-                                    param_.kernel[1],
-                                    param_.stride[0],
-                                    param_.stride[1],
-                                    param_.dilate[0],
-                                    param_.dilate[1]);
+                                    kernel[0],
+                                    kernel[1],
+                                    stride[0],
+                                    stride[1],
+                                    dilate[0],
+                                    dilate[1]);
       } else {
         temp_col = unpack_patch2col(pad(out.Slice(i, i + step),
                                         o_pad[0], o_pad[1]),
-                                    param_.kernel[0],
-                                    param_.kernel[1],
-                                    param_.stride[0],
-                                    param_.stride[1],
-                                    param_.dilate[0],
-                                    param_.dilate[1]);
+                                    kernel[0],
+                                    kernel[1],
+                                    stride[0],
+                                    stride[1],
+                                    dilate[0],
+                                    dilate[1]);
       }
       const index_t gstride = temp_col.size(0) / param_.num_group;
       for (uint32_t gid = 0; gid < param_.num_group; ++gid) {

@@ -283,24 +294,24 @@ class DeconvolutionOp : public Operator {
       if (o_pad[0] == 0 && o_pad[1] == 0) {
         out.Slice(i, i + step) = pack_col2patch(temp_col,
                                                 out.Slice(i, i + step).shape_,
-                                                param_.kernel[0],
-                                                param_.kernel[1],
-                                                param_.stride[0],
-                                                param_.stride[1],
-                                                param_.dilate[0],
-                                                param_.dilate[1]);
+                                                kernel[0],
+                                                kernel[1],
+                                                stride[0],
+                                                stride[1],
+                                                dilate[0],
+                                                dilate[1]);
       } else {
         Shape<4> pshape = out.Slice(i, i + step).shape_;
         pshape[2] += 2 * o_pad[0];
         pshape[3] += 2 * o_pad[1];
         out.Slice(i, i + step) = crop(pack_col2patch(temp_col,
                                                      pshape,
-                                                     param_.kernel[0],
-                                                     param_.kernel[1],
-                                                     param_.stride[0],
-                                                     param_.stride[1],
-                                                     param_.dilate[0],
-                                                     param_.dilate[1]),
+                                                     kernel[0],
+                                                     kernel[1],
+                                                     stride[0],
+                                                     stride[1],
+                                                     dilate[0],
+                                                     dilate[1]),
                                       out[i][0].shape_);
       }
     }

@@ -328,13 +339,31 @@ class DeconvolutionOp : public Operator {
     CHECK_EQ(in_data[deconv::kWeight].CheckContiguous(), true);
     // get data
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4, DType> data = in_data[deconv::kData].get<xpu, 4, DType>(s);
-    Tensor<xpu, 4, DType> grad = out_grad[deconv::kOut].get<xpu, 4, DType>(s);
-    Tensor<xpu, 4, DType> gdata = in_grad[deconv::kData].get<xpu, 4, DType>(s);
+    auto in_data_shape = in_data[deconv::kData].shape_;
+    Tensor<xpu, 4, DType> data = TBlobTo4DTensor(in_data[deconv::kData], s);
+    Tensor<xpu, 4, DType> grad = TBlobTo4DTensor(out_grad[deconv::kOut], s);
+    Tensor<xpu, 4, DType> gdata = TBlobTo4DTensor(in_grad[deconv::kData], s);
+
+    index_t o_pad[2], o_adj[2];
+    if (param_.kernel.ndim() == 2) {
+      param_.InferPad(TShape({in_data_shape[2], in_data_shape[3]}), o_pad, o_adj);
+    } else {
+      index_t o_pad_1D[1], o_adj_1D[1];
+      param_.InferPad({in_data_shape[2]}, o_pad_1D, o_adj_1D);
+      o_pad[0] = 0;
+      o_pad[1] = o_pad_1D[0];
+      o_adj[0] = 0;
+      o_adj[1] = o_adj_1D[0];
+    }
+    auto stride = param_.kernel.ndim() == 2 ? param_.stride : TShape({1, param_.stride[0]});
+    auto dilate = param_.kernel.ndim() == 2 ? param_.dilate : TShape({1, param_.dilate[0]});
+    auto kernel = param_.kernel.ndim() == 2 ? param_.kernel : TShape({1, param_.kernel[0]});
+    auto kernel_size = kernel.Size();
+
     Shape<3> wmat_shape =
         Shape3(param_.num_group,
                data.shape_[1] / param_.num_group,
-               param_.num_filter / param_.num_group * param_.kernel[0] * param_.kernel[1]);
+               param_.num_filter / param_.num_group * kernel_size);
     Tensor<xpu, 3, DType> wmat =
         in_data[deconv::kWeight].get_with_shape<xpu, 3, DType>(wmat_shape, s);
     Tensor<xpu, 3, DType> gwmat =

@@ -343,10 +372,6 @@ class DeconvolutionOp : public Operator {
     CHECK_EQ(s->blas_handle_ownership_, Stream<xpu>::OwnHandle)
         << "Must init CuBLAS handle in stream";
 #endif
-    index_t o_pad[2], o_adj[2];
-    TShape dshape = {static_cast<nnvm::dim_t>(data.size(2)),
-                     static_cast<nnvm::dim_t>(data.size(3))};
-    param_.InferPad(dshape, o_pad, o_adj);

     const index_t nbatch = data.size(0);
     Tensor<xpu, 1, DType> workspace =

@@ -366,20 +391,20 @@ class DeconvolutionOp : public Operator {
       temp_dst = reshape(swapaxis<1, 0>(data.Slice(i, i + step)), temp_dst.shape_);
       if (o_pad[0] == 0 && o_pad[1] == 0) {
         temp_col = unpack_patch2col(grad.Slice(i, i + step),
-                                    param_.kernel[0],
-                                    param_.kernel[1],
-                                    param_.stride[0],
-                                    param_.stride[1],
-                                    param_.dilate[0],
-                                    param_.dilate[1]);
+                                    kernel[0],
+                                    kernel[1],
+                                    stride[0],
+                                    stride[1],
+                                    dilate[0],
+                                    dilate[1]);
       } else {
         temp_col = unpack_patch2col(pad(grad.Slice(i, i + step), o_pad[0], o_pad[1]),
-                                    param_.kernel[0],
-                                    param_.kernel[1],
-                                    param_.stride[0],
-                                    param_.stride[1],
-                                    param_.dilate[0],
-                                    param_.dilate[1]);
+                                    kernel[0],
+                                    kernel[1],
+                                    stride[0],
+                                    stride[1],
+                                    dilate[0],
+                                    dilate[1]);
       }
       const index_t gstride = temp_col.size(0) / param_.num_group;
       for (uint32_t gid = 0; gid < param_.num_group; ++gid) {

@@ -422,9 +447,8 @@ class DeconvolutionOp : public Operator {
  private:
   inline index_t InitTemp(const mshadow::Shape<4> &ishape,
                           const mshadow::Shape<4> &oshape) {
-    const int ksize_y = param_.kernel[0];
-    const int ksize_x = param_.kernel[1];
-    shape_colunit_ = mshadow::Shape2(ishape[1] * ksize_y * ksize_x,
+    const int ksize = param_.kernel.Size();
+    shape_colunit_ = mshadow::Shape2(ishape[1] * ksize,
                                      oshape[2] * oshape[3]);
     shape_dstunit_ = mshadow::Shape3(param_.num_group,
                                      oshape[1] / param_.num_group,

@@ -449,6 +473,15 @@ class DeconvolutionOp : public Operator {
     return required_size;
   }

+  inline Tensor<xpu, 4, DType> TBlobTo4DTensor(const TBlob &tb, Stream<xpu> *s) {
+    using namespace mshadow;
+    if (param_.kernel.ndim() == 2)
+      return tb.get<xpu, 4, DType>(s);
+    else
+      return tb.get_with_shape<xpu, 4, DType>(
+          Shape4(tb.shape_[0], tb.shape_[1], 1, tb.shape_[2]), s);
+  }
+
   DeconvolutionParam param_;
   mshadow::Shape<2> shape_colunit_;
   mshadow::Shape<3> shape_dstunit_;

@@ -505,8 +538,8 @@ class DeconvolutionProp : public OperatorProperty {
                  std::vector<TShape> *out_shape,
                  std::vector<TShape> *aux_shape) const override {
 #if MXNET_USE_CUDNN == 0
-    if (param_.kernel.ndim() != 2) {
-      LOG(FATAL) << "If not using CUDNN only 2D-Deconvolution is supported";
+    if (param_.kernel.ndim() > 2) {
+      LOG(FATAL) << "If not using CUDNN, only 1D or 2D Deconvolution is supported";
       return false;
     }
 #endif // CUDNN
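The new TBlobTo4DTensor helper above turns a 1D layout (N, C, W) into (N, C, 1, W) via get_with_shape, so the existing 4D code paths can be reused. A standalone check of the indexing identity this relies on (offset3d/offset4d are illustrative helpers, not MXNet code): both views address the same flat offsets in a contiguous buffer, so no data movement is involved.

#include <cassert>
#include <cstddef>

// Flat offset of (n, c, w) in a contiguous (N, C, W) buffer.
size_t offset3d(size_t n, size_t c, size_t w, size_t C, size_t W) {
  return (n * C + c) * W + w;
}

// Flat offset of (n, c, h, w) in a contiguous (N, C, H, W) buffer.
size_t offset4d(size_t n, size_t c, size_t h, size_t w, size_t C, size_t H, size_t W) {
  return ((n * C + c) * H + h) * W + w;
}

int main() {
  const size_t N = 4, C = 3, W = 17;
  for (size_t n = 0; n < N; ++n)
    for (size_t c = 0; c < C; ++c)
      for (size_t w = 0; w < W; ++w)
        // (n, c, w) in the 3D view and (n, c, 0, w) in the 4D view with H = 1
        // land on exactly the same element.
        assert(offset3d(n, c, w, C, W) == offset4d(n, c, 0, w, C, 1, W));
  return 0;
}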

src/operator/nn/deconvolution.cc

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ MXNET_REGISTER_OP_PROPERTY(Deconvolution, DeconvolutionProp)
 .add_argument("bias", "NDArray-or-Symbol", "Bias added to the result after the deconvolution "
               "operation.")
 .add_arguments(DeconvolutionParam::__FIELDS__())
-.describe("Computes 2D transposed convolution (aka fractionally strided convolution) of the "
+.describe("Computes 1D or 2D transposed convolution (aka fractionally strided convolution) of the "
           "input tensor. This operation can be seen as the gradient of Convolution operation with "
           "respect to its input. Convolution usually reduces the size of the input. Transposed "
           "convolution works the other way, going from a smaller input to a larger output while "

src/operator/nn/deconvolution.cu

Lines changed: 1 addition & 7 deletions
@@ -38,13 +38,7 @@ Operator* CreateOp<gpu>(DeconvolutionParam param, int dtype,
                         Context ctx) {
   // Logic here parallels that in Convolution.cu
   Operator *op = NULL;
-  // If 1D deconvolution, use MXNet implementation
-  if (param.kernel.ndim() == 1) {
-    MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-      op = new DeconvolutionOp<gpu, DType>(param);
-    })
-    return op;
-  }
+
 #if MXNET_USE_CUDNN == 1
   // On fp16-I/O instances, use fp32 compute (i.e. pseudo-fp16).
   int compute_type = (dtype == mshadow::kFloat16) ? mshadow::kFloat32 : dtype;

0 commit comments
