@@ -597,133 +597,28 @@ def assert_fn(output: torch.Tensor):
597597 loss .backward ()
598598
599599 @skip_if_lt_x_gpu (1 )
600- def test_dataclass_input_output (self ):
601- from unittest .mock import patch
602-
603- from torch .distributed ._composable_state import _get_module_state
604-
600+ def test_dataclass_input (self ):
605601 @dataclasses .dataclass
606602 class Input :
607603 x : torch .Tensor
608- y : torch .Tensor
609-
610- @dataclasses .dataclass
611- class Output :
612- x : torch .Tensor
613- y : torch .Tensor
614-
615- @dataclasses .dataclass
616- class Scale :
617- factor : torch .Tensor
618604
619605 class Model (nn .Module ):
620606 def __init__ (self , * args , ** kwargs ) -> None :
621607 super ().__init__ (* args , ** kwargs )
622608 self ._layer = nn .Linear (10 , 10 )
623609
624- def forward (self , input : Input , * , scale : Scale | None = None ):
625- x = self ._layer (input .x )
626- y = self ._layer (input .y )
627- if scale is not None :
628- x = x * scale .factor
629- y = y * scale .factor
630- return Output (x = x , y = y )
631-
632- class TensorModel (nn .Module ):
633- def __init__ (self , * args , ** kwargs ) -> None :
634- super ().__init__ (* args , ** kwargs )
635- self ._layer = nn .Linear (10 , 10 )
610+ def forward (self , input : Input ):
611+ return self ._layer (input .x )
636612
637- def forward (self , x : torch .Tensor , * , scale : torch .Tensor | None = None ):
638- out = self ._layer (x )
639- if scale is not None :
640- out = out * scale
641- return out
642-
643- # Test with different MixedPrecisionPolicy configurations
644- mp_policies = [
645- MixedPrecisionPolicy (
646- param_dtype = torch .bfloat16 ,
647- reduce_dtype = torch .bfloat16 ,
648- ),
649- MixedPrecisionPolicy (
650- param_dtype = torch .bfloat16 ,
651- reduce_dtype = torch .float32 ,
652- ),
653- ]
654-
655- for mp_policy in mp_policies :
656- # Test with normal torch.Tensor as arg
657- tensor_model = TensorModel ()
658- fully_shard (tensor_model , mp_policy = mp_policy )
659- fsdp_state = _get_module_state (tensor_model )
660- x = torch .randn (10 , 10 , device = device_type , requires_grad = True )
661- with patch .object (
662- fsdp_state , "_pre_backward" , wraps = fsdp_state ._pre_backward
663- ) as mock_pre_backward :
664- loss = tensor_model (x ).sum ()
665- loss .backward ()
666- mock_pre_backward .assert_called ()
667-
668- # Test with normal torch.Tensor as both arg and kwarg
669- tensor_model .zero_grad ()
670- x = torch .randn (10 , 10 , device = device_type , requires_grad = True )
671- scale = torch .randn (10 , 10 , device = device_type , requires_grad = True )
672- with patch .object (
673- fsdp_state , "_pre_backward" , wraps = fsdp_state ._pre_backward
674- ) as mock_pre_backward :
675- loss = tensor_model (x , scale = scale ).sum ()
676- loss .backward ()
677- mock_pre_backward .assert_called ()
678-
679- # Test with dataclass as positional arg only
680- model = nn .Sequential (* [Model (), Model ()])
681- inp = Input (
682- x = torch .randn (10 , 10 , device = device_type , requires_grad = True ),
683- y = torch .randn (10 , 10 , device = device_type , requires_grad = True ),
684- )
613+ mp_policy = MixedPrecisionPolicy (
614+ torch .bfloat16 , torch .bfloat16 , torch .bfloat16 , True
615+ )
616+ model = Model ()
617+ inp = Input (torch .randn (2 , 10 ).to (device_type ))
685618
686- for layer in model :
687- fully_shard (layer , mp_policy = mp_policy , reshard_after_forward = True )
688- fully_shard (model , mp_policy = mp_policy , reshard_after_forward = True )
689-
690- # Patch _pre_backward on all FSDP states
691- layer0_state = _get_module_state (model [0 ])
692- layer1_state = _get_module_state (model [1 ])
693- root_state = _get_module_state (model )
694- with (
695- patch .object (
696- layer0_state , "_pre_backward" , wraps = layer0_state ._pre_backward
697- ) as layer0_mock ,
698- patch .object (
699- layer1_state , "_pre_backward" , wraps = layer1_state ._pre_backward
700- ) as layer1_mock ,
701- patch .object (
702- root_state , "_pre_backward" , wraps = root_state ._pre_backward
703- ) as root_mock ,
704- ):
705- output = model (inp )
706- loss = output .x .sum () + output .y .sum ()
707- loss .backward ()
708- layer0_mock .assert_called ()
709- layer1_mock .assert_called ()
710- root_mock .assert_called ()
711-
712- # Test with dataclass as both positional arg and kwarg
713- inp = Input (
714- x = torch .randn (10 , 10 , device = device_type , requires_grad = True ),
715- y = torch .randn (10 , 10 , device = device_type , requires_grad = True ),
716- )
717- scale = Scale (
718- factor = torch .randn (10 , 10 , device = device_type , requires_grad = True )
719- )
720- with patch .object (
721- layer0_state , "_pre_backward" , wraps = layer0_state ._pre_backward
722- ) as layer0_mock :
723- output = model [0 ](inp , scale = scale )
724- loss = output .x .sum () + output .y .sum ()
725- loss .backward ()
726- layer0_mock .assert_called ()
619+ fully_shard (model , mp_policy = mp_policy )
620+ loss = model (inp ).sum ()
621+ loss .backward ()
727622
728623
729624if __name__ == "__main__" :