Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit 4cbc17a

Browse files
committed
Add tests for generic image datasets and models
1 parent 149835a commit 4cbc17a

File tree

5 files changed

+850
-82
lines changed

5 files changed

+850
-82
lines changed
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
3+
"""
4+
Functions for creating temporary LMDBs
5+
Used in test_views
6+
"""
7+
8+
import os
9+
import sys
10+
import time
11+
import argparse
12+
from collections import defaultdict
13+
from cStringIO import StringIO
14+
15+
import numpy as np
16+
import PIL.Image
17+
import lmdb
18+
19+
try:
20+
import caffe_pb2
21+
except ImportError:
22+
# See issue #32
23+
from caffe.proto import caffe_pb2
24+
25+
26+
# Default width and height (in pixels) used by create_lmdbs() when the caller
# does not supply image_width / image_height
IMAGE_SIZE = 10
27+
# Number of training images generated by create_lmdbs() when image_count is None
TRAIN_IMAGE_COUNT = 100
28+
# Number of validation images generated by create_lmdbs(); not overridable
# through its arguments
VAL_IMAGE_COUNT = 20
29+
30+
31+
def _gradient_image(xx, yy, xslope, yslope, image_width, image_height):
    """
    Returns a float image whose intensity is a linear ramp around 127.5

    Arguments:
    xx -- per-pixel x coordinates
    yy -- per-pixel y coordinates
    xslope -- steepness of the ramp along x
    yslope -- steepness of the ramp along y
    image_width -- width used to scale and center the ramp
    image_height -- height used to scale and center the ramp
    """
    a = xslope * 255 / image_width
    b = yslope * 255 / image_height
    # NOTE: with integer dimensions "/ 2" is an integer division in Python 2
    return a * (xx - image_width/2) + b * (yy - image_height/2) + 127.5


def create_lmdbs(folder, image_width=None, image_height=None, image_count=None):
    """
    Creates LMDBs for generic inference
    Returns the filename for a test image

    Creates these files in "folder":
        train_images/
        train_labels/
        train_mean.png
        train_mean.binaryproto
        val_images/
        val_labels/
        val_mean.png
        val_mean.binaryproto
        test.png

    Each image is stored as a PNG-encoded single-channel Datum; the matching
    label Datum holds the two floats (xslope, yslope) used to draw it.

    Arguments:
    folder -- directory in which the files are created

    Keyword arguments:
    image_width -- width of every image (defaults to IMAGE_SIZE)
    image_height -- height of every image (defaults to IMAGE_SIZE)
    image_count -- number of training images (defaults to TRAIN_IMAGE_COUNT);
        the validation set always contains VAL_IMAGE_COUNT images
    """
    if image_width is None:
        image_width = IMAGE_SIZE
    if image_height is None:
        image_height = IMAGE_SIZE

    if image_count is None:
        train_image_count = TRAIN_IMAGE_COUNT
    else:
        train_image_count = image_count
    val_image_count = VAL_IMAGE_COUNT

    # Used to calculate the gradients later
    yy, xx = np.mgrid[:image_height, :image_width].astype('float')

    # Number of records written per transaction
    write_batch_size = 10

    # The loop variable is deliberately not called "image_count" so that the
    # function argument is not overwritten
    for phase, phase_count in [
            ('train', train_image_count),
            ('val', val_image_count)]:
        image_db = lmdb.open(os.path.join(folder, '%s_images' % phase),
                             map_size=1024**4,  # 1TB
                             map_async=True,
                             max_dbs=0)
        label_db = lmdb.open(os.path.join(folder, '%s_labels' % phase),
                             map_size=1024**4,  # 1TB
                             map_async=True,
                             max_dbs=0)

        # Accumulated in float so the mean is computed before any rounding
        image_sum = np.zeros((image_height, image_width), 'float64')

        try:
            image_txn = image_db.begin(write=True)
            label_txn = label_db.begin(write=True)

            for i in xrange(phase_count):
                xslope, yslope = np.random.random_sample(2) - 0.5
                image = _gradient_image(xx, yy, xslope, yslope,
                                        image_width, image_height)

                image_sum += image
                image = image.astype('uint8')

                pil_img = PIL.Image.fromarray(image)

                # create image Datum
                image_datum = caffe_pb2.Datum()
                image_datum.height = image.shape[0]
                image_datum.width = image.shape[1]
                image_datum.channels = 1
                s = StringIO()
                pil_img.save(s, format='PNG')
                image_datum.data = s.getvalue()
                image_datum.encoded = True
                image_txn.put(str(i), image_datum.SerializeToString())

                # create label Datum
                label_datum = caffe_pb2.Datum()
                label_datum.channels, label_datum.height, label_datum.width = 1, 1, 2
                label_datum.float_data.extend(np.array([xslope, yslope]).flat)
                label_txn.put(str(i), label_datum.SerializeToString())

                if ((i + 1) % write_batch_size) == 0:
                    image_txn.commit()
                    image_txn = image_db.begin(write=True)
                    label_txn.commit()
                    label_txn = label_db.begin(write=True)

            # Commit the last (possibly partial) batch; without this, records
            # written after the final full batch would be discarded on close
            image_txn.commit()
            label_txn.commit()
        finally:
            # close databases
            image_db.close()
            label_db.close()

        # save mean (the divisor is clamped so an empty phase yields zeros
        # instead of dividing by zero)
        mean_image = (image_sum / max(phase_count, 1)).astype('uint8')
        _save_mean(mean_image, os.path.join(folder, '%s_mean.png' % phase))
        _save_mean(mean_image, os.path.join(folder, '%s_mean.binaryproto' % phase))

    # create test image
    # The network should be able to easily produce two numbers >1
    test_image = _gradient_image(xx, yy, 0.5, 0.5, image_width, image_height)
    test_image = test_image.astype('uint8')
    pil_img = PIL.Image.fromarray(test_image)
    test_image_filename = os.path.join(folder, 'test.png')
    pil_img.save(test_image_filename)

    return test_image_filename
133+
134+
def _save_mean(mean, filename):
135+
"""
136+
Saves mean to file
137+
138+
Arguments:
139+
mean -- the mean as an np.ndarray
140+
filename -- the location to save the image
141+
"""
142+
if filename.endswith('.binaryproto'):
143+
blob = caffe_pb2.BlobProto()
144+
blob.num = 1
145+
blob.channels = 1
146+
blob.height, blob.width = mean.shape
147+
blob.data.extend(mean.astype(float).flat)
148+
with open(filename, 'w') as outfile:
149+
outfile.write(blob.SerializeToString())
150+
151+
elif filename.endswith(('.jpg', '.jpeg', '.png')):
152+
image = PIL.Image.fromarray(mean)
153+
image.save(filename)
154+
else:
155+
raise ValueError('unrecognized file extension')
156+
157+
158+
if __name__ == '__main__':
    # Command-line entry point: generate the LMDBs inside a brand-new folder
    cli = argparse.ArgumentParser(description='Create-LMDB tool - DIGITS')

    ### Positional arguments

    cli.add_argument('folder', help='Where to save the images')

    ### Optional arguments

    # Options left out on the command line stay None, which lets
    # create_lmdbs() apply its own defaults
    for short_flag, long_flag, description in [
            ('-x', '--image_width', 'Width of the images'),
            ('-y', '--image_height', 'Height of the images'),
            ('-c', '--image_count', 'How many images')]:
        cli.add_argument(short_flag, long_flag, type=int, help=description)

    options = cli.parse_args()
    target_folder = options.folder

    # Never write into a location that already exists
    if os.path.exists(target_folder):
        print('ERROR: Folder already exists')
        sys.exit(1)
    os.makedirs(target_folder)

    print('Creating images at "%s" ...' % target_folder)

    began_at = time.time()

    create_lmdbs(
        target_folder,
        image_width=options.image_width,
        image_height=options.image_height,
        image_count=options.image_count,
    )

    elapsed = time.time() - began_at
    print('Done after %s seconds' % (elapsed,))
198+

0 commit comments

Comments
 (0)