Skip to content

Commit d98b1c4

Browse files
kazhou authored and facebook-github-bot committed
[pruner] add cuda tests for pruner (#61993)
Summary: Pull Request resolved: #61993 Repeating `test_pruner` unit tests for Linear and Conv2d models with device = 'cuda' to confirm pruner will work on GPU - set device to cuda - move model to device - assert that module.weight.device is cuda ghstack-source-id: 134554382 Test Plan: `buck test mode/dev-nosan //caffe2/test:ao -- TestBasePruner` https://pxl.cl/1Md9c Reviewed By: jerryzh168 Differential Revision: D29829293 fbshipit-source-id: 1f7250e45695d0ad634d0bb7582a34fd1324e765
1 parent b39b28c commit d98b1c4

File tree

1 file changed

+110
-152
lines changed

1 file changed

+110
-152
lines changed

test/ao/sparsity/test_pruner.py

Lines changed: 110 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111

1212
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
1313

14+
DEVICES = {"cpu", "cuda" if torch.cuda.is_available() else "cpu"}
15+
16+
1417
class Linear(nn.Module):
1518
def __init__(self):
1619
super().__init__()
@@ -118,13 +121,45 @@ def update_mask(self, layer, **kwargs):
118121

119122

120123
class TestBasePruner(TestCase):
121-
def test_constructor(self):
122-
# Cannot instantiate the base
124+
def _check_pruner_prepared(self, model, pruner, device):
125+
for g in pruner.module_groups:
126+
module = g['module']
127+
assert module.weight.device == device
128+
# Check mask exists
129+
assert hasattr(module, 'mask')
130+
# Check parametrization exists and is correct
131+
assert parametrize.is_parametrized(module)
132+
assert hasattr(module, "parametrizations")
133+
# Assume that this is the 1st/only parametrization
134+
assert type(module.parametrizations.weight[0]) == PruningParametrization
135+
136+
def _check_pruner_converted(self, model, pruner, device):
137+
for g in pruner.module_groups:
138+
module = g['module']
139+
assert module.weight.device == device
140+
assert not hasattr(module, "parametrizations")
141+
assert not hasattr(module, 'mask')
142+
143+
def _check_pruner_valid_before_step(self, model, pruner, device):
144+
for g in pruner.module_groups:
145+
module = g['module']
146+
assert module.weight.device == device
147+
assert module.parametrizations.weight[0].pruned_outputs == set()
148+
149+
def _check_pruner_valid_after_step(self, model, pruner, pruned_set, device):
150+
for g in pruner.module_groups:
151+
module = g['module']
152+
assert module.weight.device == device
153+
assert module.parametrizations.weight[0].pruned_outputs == pruned_set
154+
155+
def _test_constructor_on_device(self, model, device):
123156
self.assertRaisesRegex(TypeError, 'with abstract methods update_mask',
124157
BasePruner)
125-
# Can instantiate the model with no configs
126-
model = Linear()
158+
model = model.to(device)
127159
pruner = SimplePruner(model, None, None)
160+
for g in pruner.module_groups:
161+
module = g['module']
162+
assert module.weight.device == device
128163
assert len(pruner.module_groups) == 2
129164
pruner.step()
130165
# Can instantiate the model with configs
@@ -134,184 +169,107 @@ def test_constructor(self):
134169
assert 'test' in pruner.module_groups[0]
135170
assert pruner.module_groups[0]['test'] == 3
136171

137-
def test_prepare_linear(self):
172+
def test_constructor(self):
138173
model = Linear()
174+
for device in DEVICES:
175+
self._test_constructor_on_device(model, torch.device(device))
176+
177+
def _test_prepare_linear_on_device(self, model, device):
178+
model = model.to(device)
139179
x = torch.ones(128, 16)
140180
pruner = SimplePruner(model, None, None)
141181
pruner.prepare()
142-
for g in pruner.module_groups:
143-
module = g['module']
144-
# Check mask exists
145-
assert hasattr(module, 'mask')
146-
# Check parametrization exists and is correct
147-
assert parametrize.is_parametrized(module)
148-
assert hasattr(module, "parametrizations")
149-
# Assume that this is the 1st/only parametrization
150-
assert type(module.parametrizations.weight[0]) == PruningParametrization
182+
self._check_pruner_prepared(model, pruner, device)
151183
assert model(x).shape == (128, 16)
152184

153-
def test_prepare_conv2d(self):
154-
model = Conv2d()
185+
def test_prepare_linear(self):
186+
models = [Linear(), LinearB()] # without and with bias
187+
for device in DEVICES:
188+
for model in models:
189+
self._test_prepare_linear_on_device(model, torch.device(device))
190+
191+
def _test_prepare_conv2d_on_device(self, model, device):
192+
model = model.to(device)
155193
x = torch.ones((1, 1, 28, 28))
156194
pruner = SimplePruner(model, None, None)
157195
pruner.prepare()
158-
for g in pruner.module_groups:
159-
module = g['module']
160-
# Check mask exists
161-
assert hasattr(module, 'mask')
162-
# Check parametrization exists and is correct
163-
assert parametrize.is_parametrized(module)
164-
assert hasattr(module, "parametrizations")
165-
# Assume that this is the 1st/only parametrization
166-
assert type(module.parametrizations.weight[0]) == PruningParametrization
196+
self._check_pruner_prepared(model, pruner, device)
167197
assert model(x).shape == (1, 64, 24, 24)
168198

169-
def test_prepare_linear_bias(self):
170-
model = LinearB()
171-
x = torch.ones(128, 16)
172-
pruner = SimplePruner(model, None, None)
173-
pruner.prepare()
174-
for g in pruner.module_groups:
175-
module = g['module']
176-
# Check mask exists
177-
assert hasattr(module, 'mask')
178-
# Check parametrization exists and is correct
179-
assert parametrize.is_parametrized(module)
180-
assert hasattr(module, "parametrizations")
181-
# Assume that this is the 1st/only parametrization
182-
assert type(module.parametrizations.weight[0]) == PruningParametrization
183-
assert model(x).shape == (128, 16)
199+
def test_prepare_conv2d(self):
200+
model = Conv2d()
201+
for device in DEVICES:
202+
self._test_prepare_conv2d_on_device(model, torch.device(device))
184203

185-
def test_convert_linear(self):
186-
model = Linear()
204+
def _test_convert_linear_on_device(self, model, device):
205+
model = model.to(device)
187206
x = torch.ones(128, 16)
188207
pruner = SimplePruner(model, None, None)
189208
pruner.prepare()
190209
pruner.convert()
191-
for g in pruner.module_groups:
192-
module = g['module']
193-
assert not hasattr(module, "parametrizations")
194-
assert not hasattr(module, 'mask')
210+
self._check_pruner_converted(model, pruner, device)
195211
assert model(x).shape == (128, 16)
196212

197-
def test_convert_linear_bias(self):
198-
model = LinearB()
199-
x = torch.ones(128, 16)
200-
pruner = SimplePruner(model, None, None)
201-
pruner.prepare()
202-
pruner.convert()
203-
for g in pruner.module_groups:
204-
module = g['module']
205-
assert not hasattr(module, "parametrizations")
206-
assert not hasattr(module, 'mask')
207-
assert model(x).shape == (128, 16)
213+
def test_convert_linear(self):
214+
models = [Linear(), LinearB()] # without and with bias
215+
for device in DEVICES:
216+
for model in models:
217+
self._test_convert_linear_on_device(model, torch.device(device))
208218

209-
def test_convert_conv2d(self):
210-
model = Conv2d()
219+
def _test_convert_conv2d_on_device(self, model, device):
220+
model = model.to(device)
211221
x = torch.ones((1, 1, 28, 28))
212222
pruner = SimplePruner(model, None, None)
213223
pruner.prepare()
214224
pruner.convert()
215-
for g in pruner.module_groups:
216-
module = g['module']
217-
assert not hasattr(module, "parametrizations")
218-
assert not hasattr(module, 'mask')
225+
self._check_pruner_converted(model, pruner, device)
219226
assert model(x).shape == (1, 64, 24, 24)
220227

221-
def test_step_linear(self):
222-
model = Linear()
223-
x = torch.ones(16, 16)
224-
pruner = SimplePruner(model, None, None)
225-
pruner.prepare()
226-
pruner.enable_mask_update = True
227-
for g in pruner.module_groups:
228-
# Before step
229-
module = g['module']
230-
assert module.parametrizations.weight[0].pruned_outputs == set()
231-
pruner.step()
232-
for g in pruner.module_groups:
233-
# After step
234-
module = g['module']
235-
assert module.parametrizations.weight[0].pruned_outputs == set({1})
236-
assert not (False in (model(x)[:, 1] == 0))
237-
238-
model = MultipleLinear()
239-
x = torch.ones(7, 7)
240-
pruner = MultiplePruner(model, None, None)
241-
pruner.prepare()
242-
pruner.enable_mask_update = True
243-
for g in pruner.module_groups:
244-
# Before step
245-
module = g['module']
246-
assert module.parametrizations.weight[0].pruned_outputs == set()
247-
pruner.step()
248-
for g in pruner.module_groups:
249-
# After step
250-
module = g['module']
251-
assert module.parametrizations.weight[0].pruned_outputs == set({1, 2})
252-
assert not (False in (model(x)[:, 1] == 0))
253-
assert not (False in (model(x)[:, 2] == 0))
254-
255-
def test_step_conv2d(self):
228+
def test_convert_conv2d(self):
256229
model = Conv2d()
230+
for device in DEVICES:
231+
self._test_convert_conv2d_on_device(model, torch.device(device))
232+
233+
def _test_step_linear_on_device(self, model, is_basic, device):
234+
model = model.to(device)
235+
if is_basic:
236+
x = torch.ones(16, 16)
237+
pruner = SimplePruner(model, None, None)
238+
pruner.prepare()
239+
pruner.enable_mask_update = True
240+
self._check_pruner_valid_before_step(model, pruner, device)
241+
pruner.step()
242+
self._check_pruner_valid_after_step(model, pruner, {1}, device)
243+
else:
244+
x = torch.ones(7, 7)
245+
pruner = MultiplePruner(model, None, None)
246+
pruner.prepare()
247+
pruner.enable_mask_update = True
248+
self._check_pruner_valid_before_step(model, pruner, device)
249+
pruner.step()
250+
self._check_pruner_valid_after_step(model, pruner, {1, 2}, device)
251+
252+
def test_step_linear(self):
253+
basic_models = [Linear(), LinearB()]
254+
complex_models = [MultipleLinear(), MultipleLinearB(), MultipleLinearMixed()]
255+
for device in DEVICES:
256+
for model in basic_models:
257+
self._test_step_linear_on_device(model, True, torch.device(device))
258+
for model in complex_models:
259+
self._test_step_linear_on_device(model, False, torch.device(device))
260+
261+
def _test_step_conv2d_on_device(self, model, device):
262+
model = model.to(device)
257263
x = torch.ones((1, 1, 28, 28))
258264
pruner = SimplePruner(model, None, None)
259265
pruner.prepare()
260266
pruner.enable_mask_update = True
261-
for g in pruner.module_groups:
262-
# Before step
263-
module = g['module']
264-
assert module.parametrizations.weight[0].pruned_outputs == set()
267+
self._check_pruner_valid_before_step(model, pruner, device)
265268
pruner.step()
266-
for g in pruner.module_groups:
267-
# After step
268-
module = g['module']
269-
assert module.parametrizations.weight[0].pruned_outputs == set({1})
270-
assert not (False in (model(x)[:, 1, :, :] == 0))
269+
self._check_pruner_valid_after_step(model, pruner, {1}, device)
271270
assert model(x).shape == (1, 64, 24, 24)
272271

273-
def test_step_linear_bias(self):
274-
model = LinearB()
275-
x = torch.ones(16, 16)
276-
pruner = SimplePruner(model, None, None)
277-
pruner.prepare()
278-
pruner.enable_mask_update = True
279-
for g in pruner.module_groups:
280-
# Before step
281-
module = g['module']
282-
assert module.parametrizations.weight[0].pruned_outputs == set()
283-
pruner.step()
284-
for g in pruner.module_groups:
285-
# After step
286-
module = g['module']
287-
assert module.parametrizations.weight[0].pruned_outputs == set({1})
288-
289-
model = MultipleLinearB()
290-
x = torch.ones(7, 7)
291-
pruner = MultiplePruner(model, None, None)
292-
pruner.prepare()
293-
pruner.enable_mask_update = True
294-
for g in pruner.module_groups:
295-
# Before step
296-
module = g['module']
297-
assert module.parametrizations.weight[0].pruned_outputs == set()
298-
pruner.step()
299-
for g in pruner.module_groups:
300-
# After step
301-
module = g['module']
302-
assert module.parametrizations.weight[0].pruned_outputs == set({1, 2})
303-
304-
model = MultipleLinearMixed()
305-
x = torch.ones(7, 7)
306-
pruner = MultiplePruner(model, None, None)
307-
pruner.prepare()
308-
pruner.enable_mask_update = True
309-
for g in pruner.module_groups:
310-
# Before step
311-
module = g['module']
312-
assert module.parametrizations.weight[0].pruned_outputs == set()
313-
pruner.step()
314-
for g in pruner.module_groups:
315-
# After step
316-
module = g['module']
317-
assert module.parametrizations.weight[0].pruned_outputs == set({1, 2})
272+
def test_step_conv2d(self):
273+
model = Conv2d()
274+
for device in DEVICES:
275+
self._test_step_conv2d_on_device(model, torch.device(device))

0 commit comments

Comments (0)