@@ -638,120 +638,85 @@ def set_classy_state(self, state):
         # Set up pytorch module in train vs eval mode, update optimizer.
         self._set_model_train_mode()

-    def eval_step(self, use_gpu, local_variables=None):
-        if local_variables is None:
-            local_variables = {}
-
+    def eval_step(self, use_gpu):
         self.last_batch = None

         # Process next sample
         sample = next(self.get_data_iterator())
-        local_variables["sample"] = sample

-        assert (
-            isinstance(local_variables["sample"], dict)
-            and "input" in local_variables["sample"]
-            and "target" in local_variables["sample"]
-        ), (
+        assert isinstance(sample, dict) and "input" in sample and "target" in sample, (
             f"Returned sample [{sample}] is not a map with 'input' and"
             + "'target' keys"
         )

         # Copy sample to GPU
-        local_variables["target"] = local_variables["sample"]["target"]
+        target = sample["target"]
         if use_gpu:
-            for key, value in local_variables["sample"].items():
-                local_variables["sample"][key] = recursive_copy_to_gpu(
-                    value, non_blocking=True
-                )
+            for key, value in sample.items():
+                sample[key] = recursive_copy_to_gpu(value, non_blocking=True)

         with torch.no_grad():
-            local_variables["output"] = self.model(local_variables["sample"]["input"])
+            output = self.model(sample["input"])

-            local_variables["local_loss"] = self.compute_loss(
-                local_variables["output"], local_variables["sample"]
-            )
+            local_loss = self.compute_loss(output, sample)

-            local_variables["loss"] = local_variables["local_loss"].detach().clone()
-            local_variables["loss"] = all_reduce_mean(local_variables["loss"])
+            loss = local_loss.detach().clone()
+            loss = all_reduce_mean(loss)

-            self.losses.append(
-                local_variables["loss"].data.cpu().item()
-                * local_variables["target"].size(0)
-            )
+            self.losses.append(loss.data.cpu().item() * target.size(0))

-            self.update_meters(local_variables["output"], local_variables["sample"])
+            self.update_meters(output, sample)

         # Move some data to the task so hooks get a chance to access it
         self.last_batch = LastBatchInfo(
-            loss=local_variables["loss"],
-            output=local_variables["output"],
-            target=local_variables["target"],
-            sample=local_variables["sample"],
+            loss=loss, output=output, target=target, sample=sample
         )

-    def train_step(self, use_gpu, local_variables=None):
+    def train_step(self, use_gpu):
         """Train step to be executed in train loop

         Args:
             use_gpu: if true, execute training on GPU
-            local_variables: Dict containing intermediate values
-                in train_step for access by hooks
         """

-        if local_variables is None:
-            local_variables = {}
-
         self.last_batch = None

         # Process next sample
         sample = next(self.get_data_iterator())
-        local_variables["sample"] = sample

-        assert (
-            isinstance(local_variables["sample"], dict)
-            and "input" in local_variables["sample"]
-            and "target" in local_variables["sample"]
-        ), (
+        assert isinstance(sample, dict) and "input" in sample and "target" in sample, (
             f"Returned sample [{sample}] is not a map with 'input' and"
             + "'target' keys"
         )

         # Copy sample to GPU
-        local_variables["target"] = local_variables["sample"]["target"]
+        target = sample["target"]
         if use_gpu:
-            for key, value in local_variables["sample"].items():
-                local_variables["sample"][key] = recursive_copy_to_gpu(
-                    value, non_blocking=True
-                )
+            for key, value in sample.items():
+                sample[key] = recursive_copy_to_gpu(value, non_blocking=True)

         with torch.enable_grad():
             # Forward pass
-            local_variables["output"] = self.model(local_variables["sample"]["input"])
+            output = self.model(sample["input"])

-            local_variables["local_loss"] = self.compute_loss(
-                local_variables["output"], local_variables["sample"]
-            )
+            local_loss = self.compute_loss(output, sample)

-            local_variables["loss"] = local_variables["local_loss"].detach().clone()
-            local_variables["loss"] = all_reduce_mean(local_variables["loss"])
+            loss = local_loss.detach().clone()
+            loss = all_reduce_mean(loss)

-            self.losses.append(
-                local_variables["loss"].data.cpu().item()
-                * local_variables["target"].size(0)
-            )
+            self.losses.append(loss.data.cpu().item() * target.size(0))

-            self.update_meters(local_variables["output"], local_variables["sample"])
+            self.update_meters(output, sample)

             # Run backwards pass / update optimizer
             if self.amp_opt_level is not None:
                 self.optimizer.zero_grad()
                 with apex.amp.scale_loss(
-                    local_variables["local_loss"], self.optimizer.optimizer
+                    local_loss, self.optimizer.optimizer
                 ) as scaled_loss:
                     scaled_loss.backward()
             else:
-                self.optimizer.backward(local_variables["local_loss"])
+                self.optimizer.backward(local_loss)

             self.optimizer.update_schedule_on_step(self.where)
             self.optimizer.step()
@@ -760,10 +725,7 @@ def train_step(self, use_gpu, local_variables=None):

         # Move some data to the task so hooks get a chance to access it
         self.last_batch = LastBatchInfo(
-            loss=local_variables["loss"],
-            output=local_variables["output"],
-            target=local_variables["target"],
-            sample=local_variables["sample"],
+            loss=loss, output=output, target=target, sample=sample
         )

     def compute_loss(self, model_output, sample):
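For downstream code, the practical effect of this change is that hooks no longer read intermediate values out of a shared local_variables dict; the per-batch data is instead published on the task as task.last_batch, a LastBatchInfo carrying loss, output, target, and sample. The snippet below is a minimal sketch of a consumer of that attribute; the SimpleLossLogger class and its on_step(task) method are hypothetical names used only for illustration, not the library's actual hook interface, and the only attributes assumed are the ones set in the diff above.

class SimpleLossLogger:
    """Hypothetical hook-style consumer of task.last_batch (illustration only)."""

    def on_step(self, task):
        batch = task.last_batch
        if batch is None:
            # train_step/eval_step reset last_batch to None before fetching the
            # next sample, so skip if a step has not completed yet.
            return
        # batch.loss was detached and averaged across replicas via
        # all_reduce_mean; batch.target.size(0) is the local batch size.
        print(f"loss={batch.loss.item():.4f} batch_size={batch.target.size(0)}")

Keeping the batch on the task rather than threading a mutable dict through every step keeps the train_step and eval_step signatures small and gives hooks a single, well-defined access point.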