Skip to content

Commit bcf933b

Browse files
committed
Adapt the data interface from CXXNET authored by @tqchen & @antinucleon
1 parent 76dfaaf commit bcf933b

File tree

2 files changed

+110
-0
lines changed

2 files changed

+110
-0
lines changed

include/caffe/data/data.hpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// Copyright 2014 BVLC and contributors.
2+
/*
3+
* Adapted from cxxnet
4+
*/
5+
#ifndef CAFFE_UTIL_DATA_H_
6+
#define CAFFE_UTIL_DATA_H_
7+
8+
#include <vector>
9+
#include "mshadow/tensor.h"
10+
#include "caffe/proto/caffe.pb.h"
11+
12+
#include "caffe/blob.hpp"
13+
14+
namespace caffe {
15+
using std::vector;
16+
17+
template<typename Dtype>
18+
class DataIterator {
19+
public:
20+
virtual ~DataIterator() {}
21+
virtual void Init() = 0;
22+
virtual void BeforeFirst() = 0;
23+
virtual bool Next() = 0;
24+
virtual const Dtype& Value() const = 0;
25+
};
26+
27+
template<typename Dtype>
28+
class DataInstance {
29+
public:
30+
float label;
31+
uint32_t index;
32+
Blob<Dtype> data;
33+
};
34+
35+
TensorShape
36+
37+
template<typename Dtype>
38+
class DataBatch {
39+
public:
40+
DataBatch(): labels(), indices(), batch_size(0) {
41+
}
42+
43+
inline void AllocSpace(mshadow::Shape<4> shape, const size_t batch_size) {
44+
data = Blob::NewBlob(shape);
45+
labele.resize(batch_size);
46+
indices.resize(batch_size);
47+
this->batch_size = batch_size;
48+
}
49+
50+
inline void FreeSpace() {
51+
}
52+
53+
inline void CopyFrom(const DataBatch& src) {
54+
CHECK_EQ(batch_size, src.batch_size);
55+
labels = src.labels;
56+
indices = src.indices;
57+
data = src.data;
58+
}
59+
60+
public:
61+
vector<float> labels;
62+
vector<uint32_t> indices;
63+
size_t batch_size;
64+
Blob<Dtype> data;
65+
};
66+
67+
template <typename Dtype>
68+
DataIterator<DataBatch>* GetDataIterator(const DataIteratorParameter& param);
69+
70+
} // namespace caffe
71+
72+
#endif // CAFFE_UTIL_DATA_H_

src/caffe/data/data.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright 2014 BVLC and contributors.
2+
3+
#include <string>
4+
5+
#include "caffe/data/data.hpp"
6+
#include "caffe/proto/caffe.pb.h"
7+
8+
namespace caffe {
9+
using std::string;
10+
11+
template <typename Dtype>
12+
DataIterator<DataBatch>* GetDataIterator(const DataIteratorParameter& param) {
13+
const string& name = param.name();
14+
const DataIteratorParameter_DataIteratorType& type = param.type();
15+
switch (type) {
16+
case DataIteratorParameter_DataIteratorType_HDF5:
17+
return new HDF5DataIterator<Dtype>(param);
18+
case DataIteratorParameter_DataIteratorType_IMAGE:
19+
return new ImageDataIterator<Dtype>(param);
20+
case DataIteratorParameter_DataIteratorType_LEVELDB:
21+
return new LeveldbDataIterator<Dtype>(param);
22+
case DataIteratorParameter_DataIteratorType_MEMORY:
23+
return new MemoryDataIterator<Dtype>(param);
24+
case DataIteratorParameter_DataIteratorType_WINDOW:
25+
return new WindowDataIterator<Dtype>(param);
26+
default:
27+
LOG(FATAL) << "DataIterator " << name << " has unknown type " << type;
28+
}
29+
return (DataIterator<Dtype>*)(NULL);
30+
}
31+
32+
template
33+
DataIterator<float>* GetDataIterator(const DataIteratorParameter& param);
34+
template
35+
DataIterator<double>* GetDataIterator(const DataIteratorParameter& param);
36+
37+
38+
} // namespace caffe

0 commit comments

Comments
 (0)