Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit f9f7416

Browse files
apeforest authored and anirudh2290 committed
[MXNET-623] Fixing an integer overflow bug in large NDArray (#11742)
* Fix integer overflow when the array size is too large * Update issue templates * Update issue templates * Remove files added by mistake * Fix compilation error after type index_t changed to int64_t * Explicitly specify type in std::max template to avoid platform dependent compilation error * Add nightly test for large array * Update submodule mshadow * Fix compilation warning * Fix compilation warning * Change index variable type to size_t * Fix integer overflow when the array size is too large * Update issue templates * Remove files added by mistake * Fix compilation error after type index_t changed to int64_t * Explicitly specify type in std::max template to avoid platform dependent compilation error * Add nightly test for large array * [MXNET-531] NeuralStyle Example for Scala (#11621) * add initial neuralstyle and test coverage * Add two more tests and README * kill comments * patch on memory leaks fix * fix formatting issues * remove redundant files * disable the Gan example for now * add ignore method * add new download scheme to match the changes * Update submodule mshadow * Fix compilation warning * Fix compilation warning * Change index variable type to size_t * Change temp_size type from size_t to index_t * Fix lint error * Fix compilation error in GPU * Fix compilation error on GPU * Fix compilation error in cpp-package * Fix unit test in GPU * Change correct type for nnvmGraph * update mshadow submodule to local repo to verify * update mshadow submodule * change some data type to size_t * change unit test style * fix lint * fix compilation error in Windows * fix compilation error in Windows * use forked submodule to verify * temporarily update submodule to verify the fix * update mshadow submodule to use remote * add test to nightly test script
1 parent e93af41 commit f9f7416

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+153
-115
lines changed

3rdparty/mshadow

src/c_api/c_api_function.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ std::vector<nnvm::NodeEntry> Gradient(
5555
g->inputs = out_grads;
5656

5757
std::vector<nnvm::NodeEntry> ret;
58-
for (index_t i = 0; i < g->num_outputs(); ++i) {
58+
for (uint32_t i = 0; i < g->num_outputs(); ++i) {
5959
ret.emplace_back(nnvm::NodeEntry{g, i, 0});
6060
}
6161

src/executor/graph_executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1308,7 +1308,7 @@ void GraphExecutor::ExecuteMonCallback(size_t nid) {
13081308
}
13091309
}
13101310
CHECK_EQ(opnode.exec->out_array.size(), output_names.size());
1311-
for (index_t i = 0; i < opnode.exec->out_array.size(); ++i) {
1311+
for (size_t i = 0; i < opnode.exec->out_array.size(); ++i) {
13121312
NDArray *cpy = new NDArray(opnode.exec->out_array[i]);
13131313
std::string name = inode.source->attrs.name + "_" + output_names[i];
13141314
this->monitor_callback_(name.c_str(), reinterpret_cast<void*>(cpy));

src/io/image_iter_common.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class ImageLabelMap {
4242
* \param label_width predefined label_width
4343
*/
4444
explicit ImageLabelMap(const char *path_imglist,
45-
mshadow::index_t label_width,
45+
index_t label_width,
4646
bool silent) {
4747
this->label_width = label_width;
4848
image_index_.clear();
@@ -58,7 +58,7 @@ class ImageLabelMap {
5858
// skip space
5959
while (isspace(*p) && p != end) ++p;
6060
image_index_.push_back(static_cast<size_t>(atol(p)));
61-
for (size_t i = 0; i < label_width; ++i) {
61+
for (index_t i = 0; i < label_width; ++i) {
6262
// skip till space
6363
while (!isspace(*p) && p != end) ++p;
6464
// skip space
@@ -171,7 +171,7 @@ struct ImageRecParserParam : public dmlc::Parameter<ImageRecParserParam> {
171171
// Batch parameters
172172
struct BatchParam : public dmlc::Parameter<BatchParam> {
173173
/*! \brief batch size */
174-
index_t batch_size;
174+
uint32_t batch_size;
175175
/*! \brief use round robin to handle overflow batch */
176176
bool round_batch;
177177
// declare parameters

src/io/iter_image_recordio_2.cc

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class ImageRecordIOParser2 {
7575
cv::Mat TJimdecode(cv::Mat buf, int color);
7676
#endif
7777
#endif
78-
inline unsigned ParseChunk(DType* data_dptr, real_t* label_dptr, const unsigned current_size,
78+
inline size_t ParseChunk(DType* data_dptr, real_t* label_dptr, const size_t current_size,
7979
dmlc::InputSplit::Blob * chunk);
8080
inline void CreateMeanImg(void);
8181

@@ -104,10 +104,10 @@ class ImageRecordIOParser2 {
104104
/*! \brief temp space */
105105
mshadow::TensorContainer<cpu, 3> img_;
106106
/*! \brief internal instance order */
107-
std::vector<std::pair<unsigned, unsigned> > inst_order_;
108-
unsigned inst_index_;
107+
std::vector<std::pair<size_t, size_t> > inst_order_;
108+
size_t inst_index_;
109109
/*! \brief internal counter tracking number of already parsed entries */
110-
unsigned n_parsed_;
110+
size_t n_parsed_;
111111
/*! \brief overflow marker */
112112
bool overflow;
113113
/*! \brief unit size */
@@ -200,7 +200,7 @@ inline void ImageRecordIOParser2<DType>::Init(
200200
"larger chunk size";
201201
}
202202
// 1.1 ratio is for a bit more shuffle parts to avoid boundary issue
203-
unsigned num_shuffle_parts =
203+
size_t num_shuffle_parts =
204204
std::ceil(source_->GetTotalSize() * 1.1 /
205205
(param_.num_parts * (param_.shuffle_chunk_size << 20UL)));
206206

@@ -262,7 +262,7 @@ inline bool ImageRecordIOParser2<DType>::ParseNext(DataBatch *out) {
262262
}
263263
CHECK(source_ != nullptr);
264264
dmlc::InputSplit::Blob chunk;
265-
unsigned current_size = 0;
265+
size_t current_size = 0;
266266
out->index.resize(batch_param_.batch_size);
267267

268268
// InitBatch
@@ -295,7 +295,7 @@ inline bool ImageRecordIOParser2<DType>::ParseNext(DataBatch *out) {
295295

296296
while (current_size < batch_param_.batch_size) {
297297
// int n_to_copy;
298-
unsigned n_to_out = 0;
298+
size_t n_to_out = 0;
299299
if (n_parsed_ == 0) {
300300
if (source_->NextBatch(&chunk, batch_param_.batch_size)) {
301301
inst_order_.clear();
@@ -328,15 +328,16 @@ inline bool ImageRecordIOParser2<DType>::ParseNext(DataBatch *out) {
328328
n_to_out = 0;
329329
}
330330
} else {
331-
int n_to_copy = std::min(n_parsed_, batch_param_.batch_size - current_size);
331+
size_t n_to_copy = std::min(n_parsed_,
332+
static_cast<size_t>(batch_param_.batch_size) - current_size);
332333
n_parsed_ -= n_to_copy;
333334
// Copy
334335
#pragma omp parallel for num_threads(param_.preprocess_threads)
335-
for (int i = 0; i < n_to_copy; ++i) {
336+
for (int i = 0; i < static_cast<int>(n_to_copy); ++i) {
336337
omp_exc_.Run([&] {
337-
std::pair<unsigned, unsigned> place = inst_order_[inst_index_ + i];
338+
std::pair<size_t, size_t> place = inst_order_[inst_index_ + i];
338339
const DataInst& batch = temp_[place.first][place.second];
339-
for (unsigned j = 0; j < batch.data.size(); ++j) {
340+
for (size_t j = 0; j < batch.data.size(); ++j) {
340341
CHECK_EQ(unit_size_[j], batch.data[j].Size());
341342
MSHADOW_TYPE_SWITCH(out->data[j].data().type_flag_, dtype, {
342343
mshadow::Copy(
@@ -482,18 +483,18 @@ cv::Mat ImageRecordIOParser2<DType>::TJimdecode(cv::Mat image, int color) {
482483

483484
// Returns the number of images that are put into output
484485
template<typename DType>
485-
inline unsigned ImageRecordIOParser2<DType>::ParseChunk(DType* data_dptr, real_t* label_dptr,
486-
const unsigned current_size, dmlc::InputSplit::Blob * chunk) {
486+
inline size_t ImageRecordIOParser2<DType>::ParseChunk(DType* data_dptr, real_t* label_dptr,
487+
const size_t current_size, dmlc::InputSplit::Blob * chunk) {
487488
temp_.resize(param_.preprocess_threads);
488489
#if MXNET_USE_OPENCV
489490
// save opencv out
490491
dmlc::RecordIOChunkReader reader(*chunk, 0, 1);
491-
unsigned gl_idx = current_size;
492+
size_t gl_idx = current_size;
492493
#pragma omp parallel num_threads(param_.preprocess_threads)
493494
{
494495
omp_exc_.Run([&] {
495496
CHECK(omp_get_num_threads() == param_.preprocess_threads);
496-
unsigned int tid = omp_get_thread_num();
497+
int tid = omp_get_thread_num();
497498
// dmlc::RecordIOChunkReader reader(*chunk, tid, param_.preprocess_threads);
498499
ImageRecordIO rec;
499500
dmlc::InputSplit::Blob blob;
@@ -502,7 +503,7 @@ inline unsigned ImageRecordIOParser2<DType>::ParseChunk(DType* data_dptr, real_t
502503
out_tmp.Clear();
503504
while (true) {
504505
bool reader_has_data;
505-
unsigned idx;
506+
size_t idx;
506507
#pragma omp critical
507508
{
508509
reader_has_data = reader.NextRecord(&blob);
@@ -567,7 +568,7 @@ inline unsigned ImageRecordIOParser2<DType>::ParseChunk(DType* data_dptr, real_t
567568
data = mshadow::Tensor<cpu, 3, DType>(data_dptr + idx*unit_size_[0],
568569
mshadow::Shape3(n_channels, res.rows, res.cols));
569570
} else {
570-
out_tmp.Push(static_cast<unsigned>(rec.image_index()),
571+
out_tmp.Push(static_cast<size_t>(rec.image_index()),
571572
mshadow::Shape3(n_channels, res.rows, res.cols),
572573
mshadow::Shape1(param_.label_width));
573574
data = out_tmp.data().Back();
@@ -612,7 +613,7 @@ inline unsigned ImageRecordIOParser2<DType>::ParseChunk(DType* data_dptr, real_t
612613
});
613614
}
614615
omp_exc_.Rethrow();
615-
return (std::min(batch_param_.batch_size, gl_idx) - current_size);
616+
return (std::min(static_cast<size_t>(batch_param_.batch_size), gl_idx) - current_size);
616617
#else
617618
LOG(FATAL) << "Opencv is needed for image decoding and augmenting.";
618619
return 0;
@@ -633,8 +634,8 @@ inline void ImageRecordIOParser2<DType>::CreateMeanImg(void) {
633634
inst_order_.clear();
634635
// Parse chunk w/o putting anything in out
635636
ParseChunk(nullptr, nullptr, batch_param_.batch_size, &chunk);
636-
for (unsigned i = 0; i < inst_order_.size(); ++i) {
637-
std::pair<unsigned, unsigned> place = inst_order_[i];
637+
for (size_t i = 0; i < inst_order_.size(); ++i) {
638+
std::pair<size_t, size_t> place = inst_order_[i];
638639
mshadow::Tensor<cpu, 3> outimg =
639640
temp_[place.first][place.second].data[0].template get<cpu, 3, real_t>();
640641
if (imcnt == 0) {

src/ndarray/ndarray.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2105,10 +2105,10 @@ void Imdecode(NDArray *ret, NDArray mean, size_t index,
21052105
if (mean.is_none()) {
21062106
MSHADOW_TYPE_SWITCH(buff.dtype(), DType, {
21072107
mshadow::Tensor<cpu, 4, DType> tensor = buff.data().get<cpu, 4, DType>();
2108-
for (index_t i = 0; i < y1-y0; i++) {
2108+
for (size_t i = 0; i < y1-y0; i++) {
21092109
uchar* im_data = res.ptr<uchar>(y0+i) + res.channels()*x0;
2110-
for (index_t j = 0; j < x1-x0; j++) {
2111-
for (index_t k = 0; k < n_channels; k++) {
2110+
for (size_t j = 0; j < x1-x0; j++) {
2111+
for (size_t k = 0; k < n_channels; k++) {
21122112
tensor[0][k][i][j] = DType(im_data[k]); // NOLINT(*)
21132113
}
21142114
im_data += res.channels();
@@ -2125,10 +2125,10 @@ void Imdecode(NDArray *ret, NDArray mean, size_t index,
21252125
MSHADOW_TYPE_SWITCH(buff.dtype(), DType, {
21262126
mshadow::Tensor<cpu, 4, DType> tensor = buff.data().get<cpu, 4, DType>();
21272127
mshadow::Tensor<cpu, 3, DType> tmean = mean.data().get<cpu, 3, DType>();
2128-
for (index_t i = 0; i < y1-y0; i++) {
2128+
for (size_t i = 0; i < y1-y0; i++) {
21292129
uchar* im_data = res.ptr<uchar>(y0+i) + res.channels()*x0;
2130-
for (index_t j = 0; j < x1-x0; j++) {
2131-
for (index_t k = 0; k < n_channels; k++) {
2130+
for (size_t j = 0; j < x1-x0; j++) {
2131+
for (size_t k = 0; k < n_channels; k++) {
21322132
tensor[0][k][i][j] = DType(im_data[k]) - tmean[k][i][j]; // NOLINT(*)
21332133
}
21342134
im_data += res.channels();

src/ndarray/ndarray_function.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ void ElementwiseSumRspImpl(mshadow::Stream<cpu>* s,
9292
auto out_value_cur_row = out_values[irow];
9393
const auto offset = row_idx_ptr - nd_indices_start;
9494
auto nd_value_cur_row = nd_values[offset];
95-
for (size_t j = 0; j < nd_value_cur_row.shape_[0]; ++j) {
95+
for (index_t j = 0; j < nd_value_cur_row.shape_[0]; ++j) {
9696
out_value_cur_row[j] += nd_value_cur_row[j];
9797
}
9898
++irow;

src/operator/batch_norm_v1-inl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,14 +286,14 @@ class BatchNormV1Prop : public OperatorProperty {
286286
// For other input types, these parameters have the same type as input
287287
// NOTE: This requirement is from cuDNN (v. 4 and 5)
288288
int dtype_param = (dtype == kFloat16) ? kFloat32 : dtype;
289-
for (index_t i = 1; i < in_type->size(); ++i) {
289+
for (size_t i = 1; i < in_type->size(); ++i) {
290290
if ((*in_type)[i] == -1) {
291291
(*in_type)[i] = dtype_param;
292292
} else {
293293
UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, ListArguments()[i]);
294294
}
295295
}
296-
for (index_t i = 0; i < aux_type->size(); ++i) {
296+
for (size_t i = 0; i < aux_type->size(); ++i) {
297297
if ((*aux_type)[i] != -1) {
298298
UNIFORM_TYPE_CHECK((*aux_type)[i], dtype_param, ListArguments()[i]);
299299
}

src/operator/bilinear_sampler.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ __global__ void BilinearSamplerForwardKernel(const int i_c, const int i_h,
5151
int h = (index / o_w) % o_h;
5252
int c = (index / o_w / o_h) % o_c;
5353
int n = index / o_w / o_h / o_c;
54-
index_t out_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
55-
index_t grid_index = n * o_h * o_w * 2 + h * o_w + w;
54+
int out_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
55+
int grid_index = n * o_h * o_w * 2 + h * o_w + w;
5656
DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 2;
5757
DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2;
5858
int top_left_y = static_cast<int>(floor(y_real));
@@ -96,16 +96,16 @@ __global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h,
9696
int n = index / o_w / o_h;
9797
DType top_left_y_gw = 0.0;
9898
DType top_left_x_gw = 0.0;
99-
index_t grid_src_index = n * o_h * o_w * 2 + h * o_w + w;
99+
int grid_src_index = n * o_h * o_w * 2 + h * o_w + w;
100100
DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h - 1) / 2;
101101
DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2;
102102

103103
int top_left_y = static_cast<int>(floor(y_real));
104104
int top_left_x = static_cast<int>(floor(x_real));
105105
DType top_left_y_w = 1.0 - (y_real - top_left_y);
106106
DType top_left_x_w = 1.0 - (x_real - top_left_x);
107-
for (index_t c = 0; c < o_c; ++c) {
108-
index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
107+
for (int c = 0; c < o_c; ++c) {
108+
int grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
109109
int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x;
110110
// calc 4 vertex value in input data
111111
DType top_left_v = 0;

src/operator/channel_op_common.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ inline void concatenate_helper(const std::vector<mshadow::Tensor<xpu, dim, DType
4444
mshadow::Tensor<xpu, dim, DType> out = *output;
4545
size_t size = input.size();
4646
index_t begin = 0;
47-
for (index_t i = 0; i < size; ++i) {
47+
for (size_t i = 0; i < size; ++i) {
4848
index_t end = begin + input[i].size(cdim);
4949
Assign(slice<cdim>(out, begin, end), req, input[i]);
5050
begin = end;
@@ -79,7 +79,7 @@ void split_helper(const mshadow::Tensor<xpu, dim, DType> &input,
7979
std::vector<mshadow::Tensor<xpu, dim, DType> > out = *output;
8080
size_t size = out.size();
8181
index_t begin = 0;
82-
for (index_t i = 0; i < size; ++i) {
82+
for (size_t i = 0; i < size; ++i) {
8383
index_t end = begin + out[i].size(cdim);
8484
Assign(out[i], req[i], slice<cdim>(input, begin, end));
8585
begin = end;

0 commit comments

Comments
 (0)