@@ -73,6 +73,20 @@ def __init__(
             norm_layer=norm_layer, activation_layer=activation_layer)


+def _make_stem(
+    stem_width: int,
+    norm_layer: Callable[..., nn.Module],
+    activation: Callable[..., nn.Module],
+    stem_type: Callable[..., nn.Module] = SimpleStemIN,
+) -> nn.Module:
+    return stem_type(
+        3,  # width_in
+        stem_width,
+        norm_layer,
+        activation,
+    )
+
+
 class VanillaBlock(nn.Sequential):
     """Vanilla block: [3x3 conv, BN, Relu] x2."""

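The new `_make_stem` helper just factors out what `RegNet.__init__` used to build inline. A minimal sketch of exercising it on its own, assuming this module's imports and the stride-2 `SimpleStemIN` default:

```python
from functools import partial

import torch
from torch import nn

# Hypothetical standalone use of the helper added above; the norm/activation
# choices mirror the defaults that _regnet() passes in later in this diff.
stem = _make_stem(
    stem_width=32,
    norm_layer=partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1),
    activation=nn.ReLU,
)
out = stem(torch.randn(1, 3, 224, 224))
print(out.shape)  # stride-2 3x3 stem halves the resolution: (1, 32, 112, 112)
```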
@@ -201,9 +215,6 @@ def __init__(
         )
         self.activation = activation_layer(inplace=True)

-        # The projection and transform happen in parallel,
-        # and activation is not counted with respect to depth
-
     def forward(self, x: Tensor) -> Tensor:
         if self.proj_block:
             x = self.bn(self.proj(x)) + self.f(x)
@@ -288,6 +299,7 @@ def __init__(
         bottleneck_multiplier: float = 1.0,
         use_se: bool = True,
         se_ratio: float = 0.25,
+        **kwargs: Any,
     ) -> None:
         if w_a < 0 or w_0 <= 0 or w_m <= 1 or w_0 % 8 != 0:
             raise ValueError("Invalid RegNet settings")
@@ -377,83 +389,79 @@ def _adjust_widths_groups_compatibilty(
         return stage_widths, group_widths_min


-class RegNet(nn.Module):
-    def __init__(
-        self,
-        block_params: BlockParams,
-        num_classes: int = 1000,
-        stem_width: int = 32,
-        stem_type: Optional[Callable[..., nn.Module]] = None,
-        block_type: Optional[Callable[..., nn.Module]] = None,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
-        activation: Optional[Callable[..., nn.Module]] = None,
-    ) -> None:
-        super().__init__()
-
-        if stem_type is None:
-            stem_type = SimpleStemIN
-        if norm_layer is None:
-            norm_layer = nn.BatchNorm2d
-        if block_type is None:
-            block_type = ResBottleneckBlock
-        if activation is None:
-            activation = nn.ReLU
-
-        # Ad hoc stem
-        self.stem = stem_type(
-            3,  # width_in
-            stem_width,
-            norm_layer,
-            activation,
+def _make_blocks(
+    stem_width: int,
+    params: BlockParams,
+    norm_layer: Callable[..., nn.Module],
+    activation: Callable[..., nn.Module],
+    block_type: Callable[..., nn.Module] = ResBottleneckBlock,
+) -> Tuple[nn.Sequential, int]:
+    current_width = stem_width
+
+    blocks = []
+    for i, (
+        width_out,
+        stride,
+        depth,
+        group_width,
+        bottleneck_multiplier,
+    ) in enumerate(params.get_expanded_params()):
+        blocks.append(
+            (
+                f"block{i+1}",
+                AnyStage(
+                    current_width,
+                    width_out,
+                    stride,
+                    depth,
+                    block_type,
+                    norm_layer,
+                    activation,
+                    group_width,
+                    bottleneck_multiplier,
+                    params.se_ratio,
+                    stage_index=i + 1,
+                ),
+            )
         )

-        current_width = stem_width
+        current_width = width_out
+    return (nn.Sequential(OrderedDict(blocks)), current_width)

-        blocks = []
-        for i, (
-            width_out,
-            stride,
-            depth,
-            group_width,
-            bottleneck_multiplier,
-        ) in enumerate(block_params.get_expanded_params()):
-            blocks.append(
-                (
-                    f"block{i+1}",
-                    AnyStage(
-                        current_width,
-                        width_out,
-                        stride,
-                        depth,
-                        block_type,
-                        norm_layer,
-                        activation,
-                        group_width,
-                        bottleneck_multiplier,
-                        block_params.se_ratio,
-                        stage_index=i + 1,
-                    ),
-                )
-            )

-            current_width = width_out
+class Classifier(nn.Module):
+    def __init__(self, in_channels: int, num_classes: int = 1000) -> None:
+        super().__init__()
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(in_features=in_channels, out_features=num_classes)

-        self.trunk_output = nn.Sequential(OrderedDict(blocks))
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.avgpool(x)
+        x = x.flatten(start_dim=1)
+        x = self.fc(x)
+        return x

-        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = nn.Linear(in_features=current_width, out_features=num_classes)
+
+class RegNet(nn.Module):
+    def __init__(
+        self,
+        stem: nn.Module,
+        blocks: nn.Module,
+        classifier: nn.Module,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__()
+        self.stem = stem
+        self.blocks = blocks
+        self.classifier = classifier

         # Init weights and good to go
         self.reset_parameters()

     def forward(self, x: Tensor) -> Tensor:
         x = self.stem(x)
-        x = self.trunk_output(x)
-
-        x = self.avgpool(x)
-        x = x.flatten(start_dim=1)
-        x = self.fc(x)
-
+        x = self.blocks(x)
+        x = self.classifier(x)
         return x

     def reset_parameters(self) -> None:
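The pooling-and-linear head that used to live on `RegNet` itself is now a standalone `Classifier`. A quick shape sketch (the channel width 440 below is an illustrative value, not a specific config):

```python
import torch

# Hypothetical sanity check of the new Classifier head: global average pool,
# flatten, then a linear projection to the class logits.
head = Classifier(in_channels=440, num_classes=1000)
logits = head(torch.randn(2, 440, 7, 7))
print(logits.shape)  # torch.Size([2, 1000])
```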
@@ -472,7 +480,15 @@ def reset_parameters(self) -> None:


 def _regnet(arch: str, block_params: BlockParams, pretrained: bool, progress: bool, **kwargs: Any) -> RegNet:
-    model = RegNet(block_params, norm_layer=partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1), **kwargs)
+    norm_layer = kwargs["norm_layer"] if "norm_layer" in kwargs else partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1)
+    activation = kwargs["activation"] if "activation" in kwargs else nn.ReLU
+    num_classes = kwargs["num_classes"] if "num_classes" in kwargs else 1000
+
+    stem_width = 32
+    stem = _make_stem(stem_width, norm_layer=norm_layer, activation=activation)
+    blocks, out_channels = _make_blocks(stem_width, params=block_params, norm_layer=norm_layer, activation=activation)
+    classifier = Classifier(out_channels, num_classes)
+    model = RegNet(stem, blocks, classifier)
     if pretrained:
         if arch not in model_urls:
             raise ValueError(f"No checkpoint is available for model type {arch}")
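Taken together, model construction becomes plain composition of `_make_stem`, `_make_blocks`, and `Classifier`, mirroring what the rewritten `_regnet` does. A minimal sketch with illustrative `BlockParams` values (in the style of the small RegNet-Y builders, not taken from a specific released config):

```python
from functools import partial
from torch import nn

norm_layer = partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1)

# Illustrative block settings; the real configs live in the regnet_* factories.
params = BlockParams(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8)

stem = _make_stem(32, norm_layer=norm_layer, activation=nn.ReLU)
blocks, out_channels = _make_blocks(32, params=params, norm_layer=norm_layer, activation=nn.ReLU)
model = RegNet(stem, blocks, Classifier(out_channels, num_classes=1000))
```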