@@ -506,9 +506,7 @@ def _ident(x: T) -> T:
506506     return x
507507
508508
509- def extract_tensor_metadata_for_cache_key(
510-     device_map: Dict[torch.device, torch.device], t: Tensor
511- ) -> TensorMetadata:
509+ def extract_tensor_metadata_for_cache_key(t: Tensor) -> TensorMetadata:
512510     """
513511     Extracts the tensor metadata and removes fields of the TensorMetadata
514512     that are not needed for caching
@@ -517,32 +515,19 @@ def extract_tensor_metadata_for_cache_key(
517515     if not hasattr(t, "_is_inductor_static"):
518516         meta = dataclasses.replace(meta, storage_offset=0, storage_bytes=None)
519517
520- # The pickle implementation avoids serializing the same object more than once.
521- # That behavior means the byte stream we create to hash will vary if, for example,
522- # we see two tensor objects with the same device, but the torch.device object is
523- # actually the same object vs. merely equivalent. We want to produce the same hash
524- # value in either situation, so we memoize the device objects and always reference
525- # the same object for a given device. It's possible other metadata fields deserve
526- # the same treatment, but so far we've only observed this issue with the device.
527-     if meta.device not in device_map:
528-         device_map[meta.device] = meta.device
529-     meta = dataclasses.replace(meta, device=device_map[meta.device])
530-
531518     return meta
532519
533520
534- def _reduce_fake_tensor(
535-     device_map: Dict[torch.device, torch.device], t: Tensor
536- ) -> Tuple[Callable[[T], T], Tuple[TensorMetadata]]:
521+ def _reduce_fake_tensor(t: Tensor) -> Tuple[Callable[[T], T], Tuple[TensorMetadata]]:
537522 """
538523 See FxGraphCachePickler. Custom reducer to pickle FakeTensors.
539524 """
540-     metadata = extract_tensor_metadata_for_cache_key(device_map, t)
525+     metadata = extract_tensor_metadata_for_cache_key(t)
541526     return (_ident, (metadata,))
542527
543528
544529 def _reduce_tensor(
545-     device_map: Dict[torch.device, torch.device], t: Tensor
530+     t: Tensor,
546531 ) -> Tuple[Callable[[T], T], Tuple[TensorMetadataAndValues]]:
547532 """
548533 See FxGraphCachePickler. Custom reducer to pickle Tensors.
@@ -570,7 +555,7 @@ def _reduce_tensor(
570555         f"FX graph cache handling of a large constant took {elapsed:.1}s. Please file an issue."
571556 )
572557
573-     metadata = extract_tensor_metadata_for_cache_key(device_map, t)
558+     metadata = extract_tensor_metadata_for_cache_key(t)
574559     return (_ident, (TensorMetadataAndValues(metadata, values),))
575560
576561
@@ -600,13 +585,9 @@ class FxGraphCachePickler(pickle.Pickler):
600585 data that allow us to compute a stable, but safe hash.
601586 """
602587
603-     # See extract_tensor_metadata_for_cache_key. Whenever we extract metadata during
604-     # pickling, we make sure devices always reference the same torch.device object.
605-     _device_map: Dict[torch.device, torch.device] = {}
606-
607588     dispatch_table = copyreg.dispatch_table.copy()
608-     dispatch_table[FakeTensor] = functools.partial(_reduce_fake_tensor, _device_map)
609-     dispatch_table[torch.Tensor] = functools.partial(_reduce_tensor, _device_map)
589+     dispatch_table[FakeTensor] = _reduce_fake_tensor
590+     dispatch_table[torch.Tensor] = _reduce_tensor
610591     dispatch_table[torch.SymInt] = _reduce_symint
611592     dispatch_table[
612593         torch.fx.experimental._backward_state.BackwardState
@@ -648,7 +629,7 @@ def debug_lines(cls, inp: FxGraphHashDetails) -> List[str]:
648629
649630     def get_str(obj: Any) -> str:
650631         if isinstance(obj, torch.Tensor):
651-             return str(extract_tensor_metadata_for_cache_key(cls._device_map, obj))
632+             return str(extract_tensor_metadata_for_cache_key(obj))
652633         elif isinstance(obj, bytes):
653634             return "<bytes>"
654635         elif type(obj) in cls.dispatch_table:
0 commit comments