1111
1212logging .basicConfig (format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' , level = logging .INFO )
1313
14+ DEVICES = {"cpu" , "cuda" if torch .cuda .is_available () else "cpu" }
15+
16+
1417class Linear (nn .Module ):
1518 def __init__ (self ):
1619 super ().__init__ ()
@@ -118,13 +121,45 @@ def update_mask(self, layer, **kwargs):
118121
119122
120123class TestBasePruner (TestCase ):
121- def test_constructor (self ):
122- # Cannot instantiate the base
124+ def _check_pruner_prepared (self , model , pruner , device ):
125+ for g in pruner .module_groups :
126+ module = g ['module' ]
127+ assert module .weight .device == device
128+ # Check mask exists
129+ assert hasattr (module , 'mask' )
130+ # Check parametrization exists and is correct
131+ assert parametrize .is_parametrized (module )
132+ assert hasattr (module , "parametrizations" )
133+ # Assume that this is the 1st/only parametrization
134+ assert type (module .parametrizations .weight [0 ]) == PruningParametrization
135+
136+ def _check_pruner_converted (self , model , pruner , device ):
137+ for g in pruner .module_groups :
138+ module = g ['module' ]
139+ assert module .weight .device == device
140+ assert not hasattr (module , "parametrizations" )
141+ assert not hasattr (module , 'mask' )
142+
143+ def _check_pruner_valid_before_step (self , model , pruner , device ):
144+ for g in pruner .module_groups :
145+ module = g ['module' ]
146+ assert module .weight .device == device
147+ assert module .parametrizations .weight [0 ].pruned_outputs == set ()
148+
149+ def _check_pruner_valid_after_step (self , model , pruner , pruned_set , device ):
150+ for g in pruner .module_groups :
151+ module = g ['module' ]
152+ assert module .weight .device == device
153+ assert module .parametrizations .weight [0 ].pruned_outputs == pruned_set
154+
155+ def _test_constructor_on_device (self , model , device ):
123156 self .assertRaisesRegex (TypeError , 'with abstract methods update_mask' ,
124157 BasePruner )
125- # Can instantiate the model with no configs
126- model = Linear ()
158+ model = model .to (device )
127159 pruner = SimplePruner (model , None , None )
160+ for g in pruner .module_groups :
161+ module = g ['module' ]
162+ assert module .weight .device == device
128163 assert len (pruner .module_groups ) == 2
129164 pruner .step ()
130165 # Can instantiate the model with configs
@@ -134,184 +169,107 @@ def test_constructor(self):
134169 assert 'test' in pruner .module_groups [0 ]
135170 assert pruner .module_groups [0 ]['test' ] == 3
136171
137- def test_prepare_linear (self ):
172+ def test_constructor (self ):
138173 model = Linear ()
174+ for device in DEVICES :
175+ self ._test_constructor_on_device (model , torch .device (device ))
176+
177+ def _test_prepare_linear_on_device (self , model , device ):
178+ model = model .to (device )
139179 x = torch .ones (128 , 16 )
140180 pruner = SimplePruner (model , None , None )
141181 pruner .prepare ()
142- for g in pruner .module_groups :
143- module = g ['module' ]
144- # Check mask exists
145- assert hasattr (module , 'mask' )
146- # Check parametrization exists and is correct
147- assert parametrize .is_parametrized (module )
148- assert hasattr (module , "parametrizations" )
149- # Assume that this is the 1st/only parametrization
150- assert type (module .parametrizations .weight [0 ]) == PruningParametrization
182+ self ._check_pruner_prepared (model , pruner , device )
151183 assert model (x ).shape == (128 , 16 )
152184
153- def test_prepare_conv2d (self ):
154- model = Conv2d ()
185+ def test_prepare_linear (self ):
186+ models = [Linear (), LinearB ()] # without and with bias
187+ for device in DEVICES :
188+ for model in models :
189+ self ._test_prepare_linear_on_device (model , torch .device (device ))
190+
191+ def _test_prepare_conv2d_on_device (self , model , device ):
192+ model = model .to (device )
155193 x = torch .ones ((1 , 1 , 28 , 28 ))
156194 pruner = SimplePruner (model , None , None )
157195 pruner .prepare ()
158- for g in pruner .module_groups :
159- module = g ['module' ]
160- # Check mask exists
161- assert hasattr (module , 'mask' )
162- # Check parametrization exists and is correct
163- assert parametrize .is_parametrized (module )
164- assert hasattr (module , "parametrizations" )
165- # Assume that this is the 1st/only parametrization
166- assert type (module .parametrizations .weight [0 ]) == PruningParametrization
196+ self ._check_pruner_prepared (model , pruner , device )
167197 assert model (x ).shape == (1 , 64 , 24 , 24 )
168198
169- def test_prepare_linear_bias (self ):
170- model = LinearB ()
171- x = torch .ones (128 , 16 )
172- pruner = SimplePruner (model , None , None )
173- pruner .prepare ()
174- for g in pruner .module_groups :
175- module = g ['module' ]
176- # Check mask exists
177- assert hasattr (module , 'mask' )
178- # Check parametrization exists and is correct
179- assert parametrize .is_parametrized (module )
180- assert hasattr (module , "parametrizations" )
181- # Assume that this is the 1st/only parametrization
182- assert type (module .parametrizations .weight [0 ]) == PruningParametrization
183- assert model (x ).shape == (128 , 16 )
199+ def test_prepare_conv2d (self ):
200+ model = Conv2d ()
201+ for device in DEVICES :
202+ self ._test_prepare_conv2d_on_device (model , torch .device (device ))
184203
185- def test_convert_linear (self ):
186- model = Linear ( )
204+ def _test_convert_linear_on_device (self , model , device ):
205+ model = model . to ( device )
187206 x = torch .ones (128 , 16 )
188207 pruner = SimplePruner (model , None , None )
189208 pruner .prepare ()
190209 pruner .convert ()
191- for g in pruner .module_groups :
192- module = g ['module' ]
193- assert not hasattr (module , "parametrizations" )
194- assert not hasattr (module , 'mask' )
210+ self ._check_pruner_converted (model , pruner , device )
195211 assert model (x ).shape == (128 , 16 )
196212
197- def test_convert_linear_bias (self ):
198- model = LinearB ()
199- x = torch .ones (128 , 16 )
200- pruner = SimplePruner (model , None , None )
201- pruner .prepare ()
202- pruner .convert ()
203- for g in pruner .module_groups :
204- module = g ['module' ]
205- assert not hasattr (module , "parametrizations" )
206- assert not hasattr (module , 'mask' )
207- assert model (x ).shape == (128 , 16 )
213+ def test_convert_linear (self ):
214+ models = [Linear (), LinearB ()] # without and with bias
215+ for device in DEVICES :
216+ for model in models :
217+ self ._test_convert_linear_on_device (model , torch .device (device ))
208218
209- def test_convert_conv2d (self ):
210- model = Conv2d ( )
219+ def _test_convert_conv2d_on_device (self , model , device ):
220+ model = model . to ( device )
211221 x = torch .ones ((1 , 1 , 28 , 28 ))
212222 pruner = SimplePruner (model , None , None )
213223 pruner .prepare ()
214224 pruner .convert ()
215- for g in pruner .module_groups :
216- module = g ['module' ]
217- assert not hasattr (module , "parametrizations" )
218- assert not hasattr (module , 'mask' )
225+ self ._check_pruner_converted (model , pruner , device )
219226 assert model (x ).shape == (1 , 64 , 24 , 24 )
220227
221- def test_step_linear (self ):
222- model = Linear ()
223- x = torch .ones (16 , 16 )
224- pruner = SimplePruner (model , None , None )
225- pruner .prepare ()
226- pruner .enable_mask_update = True
227- for g in pruner .module_groups :
228- # Before step
229- module = g ['module' ]
230- assert module .parametrizations .weight [0 ].pruned_outputs == set ()
231- pruner .step ()
232- for g in pruner .module_groups :
233- # After step
234- module = g ['module' ]
235- assert module .parametrizations .weight [0 ].pruned_outputs == set ({1 })
236- assert not (False in (model (x )[:, 1 ] == 0 ))
237-
238- model = MultipleLinear ()
239- x = torch .ones (7 , 7 )
240- pruner = MultiplePruner (model , None , None )
241- pruner .prepare ()
242- pruner .enable_mask_update = True
243- for g in pruner .module_groups :
244- # Before step
245- module = g ['module' ]
246- assert module .parametrizations .weight [0 ].pruned_outputs == set ()
247- pruner .step ()
248- for g in pruner .module_groups :
249- # After step
250- module = g ['module' ]
251- assert module .parametrizations .weight [0 ].pruned_outputs == set ({1 , 2 })
252- assert not (False in (model (x )[:, 1 ] == 0 ))
253- assert not (False in (model (x )[:, 2 ] == 0 ))
254-
255- def test_step_conv2d (self ):
228+ def test_convert_conv2d (self ):
256229 model = Conv2d ()
230+ for device in DEVICES :
231+ self ._test_convert_conv2d_on_device (model , torch .device (device ))
232+
233+ def _test_step_linear_on_device (self , model , is_basic , device ):
234+ model = model .to (device )
235+ if is_basic :
236+ x = torch .ones (16 , 16 )
237+ pruner = SimplePruner (model , None , None )
238+ pruner .prepare ()
239+ pruner .enable_mask_update = True
240+ self ._check_pruner_valid_before_step (model , pruner , device )
241+ pruner .step ()
242+ self ._check_pruner_valid_after_step (model , pruner , {1 }, device )
243+ else :
244+ x = torch .ones (7 , 7 )
245+ pruner = MultiplePruner (model , None , None )
246+ pruner .prepare ()
247+ pruner .enable_mask_update = True
248+ self ._check_pruner_valid_before_step (model , pruner , device )
249+ pruner .step ()
250+ self ._check_pruner_valid_after_step (model , pruner , {1 , 2 }, device )
251+
252+ def test_step_linear (self ):
253+ basic_models = [Linear (), LinearB ()]
254+ complex_models = [MultipleLinear (), MultipleLinearB (), MultipleLinearMixed ()]
255+ for device in DEVICES :
256+ for model in basic_models :
257+ self ._test_step_linear_on_device (model , True , torch .device (device ))
258+ for model in complex_models :
259+ self ._test_step_linear_on_device (model , False , torch .device (device ))
260+
261+ def _test_step_conv2d_on_device (self , model , device ):
262+ model = model .to (device )
257263 x = torch .ones ((1 , 1 , 28 , 28 ))
258264 pruner = SimplePruner (model , None , None )
259265 pruner .prepare ()
260266 pruner .enable_mask_update = True
261- for g in pruner .module_groups :
262- # Before step
263- module = g ['module' ]
264- assert module .parametrizations .weight [0 ].pruned_outputs == set ()
267+ self ._check_pruner_valid_before_step (model , pruner , device )
265268 pruner .step ()
266- for g in pruner .module_groups :
267- # After step
268- module = g ['module' ]
269- assert module .parametrizations .weight [0 ].pruned_outputs == set ({1 })
270- assert not (False in (model (x )[:, 1 , :, :] == 0 ))
269+ self ._check_pruner_valid_after_step (model , pruner , {1 }, device )
271270 assert model (x ).shape == (1 , 64 , 24 , 24 )
272271
273- def test_step_linear_bias (self ):
274- model = LinearB ()
275- x = torch .ones (16 , 16 )
276- pruner = SimplePruner (model , None , None )
277- pruner .prepare ()
278- pruner .enable_mask_update = True
279- for g in pruner .module_groups :
280- # Before step
281- module = g ['module' ]
282- assert module .parametrizations .weight [0 ].pruned_outputs == set ()
283- pruner .step ()
284- for g in pruner .module_groups :
285- # After step
286- module = g ['module' ]
287- assert module .parametrizations .weight [0 ].pruned_outputs == set ({1 })
288-
289- model = MultipleLinearB ()
290- x = torch .ones (7 , 7 )
291- pruner = MultiplePruner (model , None , None )
292- pruner .prepare ()
293- pruner .enable_mask_update = True
294- for g in pruner .module_groups :
295- # Before step
296- module = g ['module' ]
297- assert module .parametrizations .weight [0 ].pruned_outputs == set ()
298- pruner .step ()
299- for g in pruner .module_groups :
300- # After step
301- module = g ['module' ]
302- assert module .parametrizations .weight [0 ].pruned_outputs == set ({1 , 2 })
303-
304- model = MultipleLinearMixed ()
305- x = torch .ones (7 , 7 )
306- pruner = MultiplePruner (model , None , None )
307- pruner .prepare ()
308- pruner .enable_mask_update = True
309- for g in pruner .module_groups :
310- # Before step
311- module = g ['module' ]
312- assert module .parametrizations .weight [0 ].pruned_outputs == set ()
313- pruner .step ()
314- for g in pruner .module_groups :
315- # After step
316- module = g ['module' ]
317- assert module .parametrizations .weight [0 ].pruned_outputs == set ({1 , 2 })
272+ def test_step_conv2d (self ):
273+ model = Conv2d ()
274+ for device in DEVICES :
275+ self ._test_step_conv2d_on_device (model , torch .device (device ))
0 commit comments