Skip to content

Commit c0597b1

Browse files
authored
Merge pull request BVLC#5477 from lukeyeager/bvlc/test-draw-net
[pycaffe] Fix draw_net() and add a test
2 parents 5ad9e53 + 179dafd commit c0597b1

File tree

4 files changed

+39
-3
lines changed

4 files changed

+39
-3
lines changed

python/caffe/draw.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,11 @@ def get_layer_label(layer, rankdir):
104104
pooling_types_dict[layer.pooling_param.pool],
105105
layer.type,
106106
separator,
107-
layer.pooling_param.kernel_size[0] if len(layer.pooling_param.kernel_size._values) else 1,
107+
layer.pooling_param.kernel_size,
108108
separator,
109-
layer.pooling_param.stride[0] if len(layer.pooling_param.stride._values) else 1,
109+
layer.pooling_param.stride,
110110
separator,
111-
layer.pooling_param.pad[0] if len(layer.pooling_param.pad._values) else 0)
111+
layer.pooling_param.pad)
112112
else:
113113
node_label = '"%s%s(%s)"' % (layer.name, separator, layer.type)
114114
return node_label

python/caffe/test/test_draw.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import os
2+
import unittest
3+
4+
from google import protobuf
5+
6+
import caffe.draw
7+
from caffe.proto import caffe_pb2
8+
9+
def getFilenames():
10+
"""Yields files in the source tree which are Net prototxts."""
11+
result = []
12+
13+
root_dir = os.path.abspath(os.path.join(
14+
os.path.dirname(__file__), '..', '..', '..'))
15+
assert os.path.exists(root_dir)
16+
17+
for dirname in ('models', 'examples'):
18+
dirname = os.path.join(root_dir, dirname)
19+
assert os.path.exists(dirname)
20+
for cwd, _, filenames in os.walk(dirname):
21+
for filename in filenames:
22+
filename = os.path.join(cwd, filename)
23+
if filename.endswith('.prototxt') and 'solver' not in filename:
24+
yield os.path.join(dirname, filename)
25+
26+
27+
class TestDraw(unittest.TestCase):
28+
def test_draw_net(self):
29+
for filename in getFilenames():
30+
net = caffe_pb2.NetParameter()
31+
with open(filename) as infile:
32+
protobuf.text_format.Merge(infile.read(), net)
33+
caffe.draw.draw_net(net, 'LR')

scripts/travis/install-deps.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ source $BASEDIR/defaults.sh
88
apt-get -y update
99
apt-get install -y --no-install-recommends \
1010
build-essential \
11+
graphviz \
1112
libboost-filesystem-dev \
1213
libboost-python-dev \
1314
libboost-system-dev \
@@ -31,6 +32,7 @@ if ! $WITH_PYTHON3 ; then
3132
python-dev \
3233
python-numpy \
3334
python-protobuf \
35+
python-pydot \
3436
python-skimage
3537
else
3638
# Python3

scripts/travis/install-python-deps.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ if ! $WITH_PYTHON3 ; then
1111
else
1212
# Python3
1313
pip install --pre protobuf==3.0.0b3
14+
pip install pydot
1415
fi

0 commit comments

Comments
 (0)