Commit ad93357

masnesral authored and pytorchmergebot committed
[fx graph cache] FxGraphPickler: Remove hack to stabilize device string hashes (#138681)
Summary: With the fast pickling mode, we no longer need the custom hack for replacing device strings in tensors. The hack was previously needed because, for example, two "cuda" strings pickle differently depending on whether they are the same object or merely equal.

Test Plan: The new test fails with fast mode commented out, but succeeds when it is enabled: `python test/inductor/test_codecache.py -k test_stable_strings`

Pull Request resolved: #138681
Approved by: https://github.com/oulgen
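As background (not part of the diff), here is a minimal sketch of the pickle behavior in question; the `dumps` helper is a hypothetical stand-in, not PyTorch code. The default pickler memoizes objects by id, so a repeated object is encoded as a memo reference, while `fast` mode disables the memo and serializes every occurrence in full:

```python
import io
import pickle

s1 = "string"
s2 = "strin"
s2 += "g"  # equal to s1 but a distinct object
assert s1 == s2 and s1 is not s2

def dumps(obj: object, fast: bool) -> bytes:
    # Illustrative helper: pickle with or without fast mode.
    buf = io.BytesIO()
    pickler = pickle.Pickler(buf)
    pickler.fast = fast  # fast mode disables the pickler's id-based memo
    pickler.dump(obj)
    return buf.getvalue()

# Default mode: the second occurrence of s1 becomes a memo reference, so
# the byte streams (and any hash of them) differ even though values match.
print(dumps([s1, s1], fast=False) == dumps([s1, s2], fast=False))  # False
# Fast mode: every occurrence is serialized in full, so equal values
# produce identical bytes regardless of object identity.
print(dumps([s1, s1], fast=True) == dumps([s1, s2], fast=True))    # True
```

This is the property the change relies on: with fast mode enabled on the cache pickler, equal-but-distinct device strings hash identically, so the `_device_map` memoization removed below is no longer needed.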
1 parent 3b0f393 · commit ad93357

3 files changed: 26 additions & 35 deletions

test/inductor/test_codecache.py

Lines changed: 16 additions & 0 deletions
@@ -835,6 +835,22 @@ def uuid(self) -> Optional[Union[bytes, str]]:
                 FxGraphCachePickler.dumps(details3),
             )
 
+    def test_stable_strings(self):
+        """
+        Test that objects containing identical strings pickle the same
+        even if they are not the same id.
+        """
+        s1 = "string"
+        s2 = "strin"
+        s2 += "g"
+
+        self.assertNotEqual(id(s1), id(s2))
+
+        self.assertEqual(
+            FxGraphCachePickler.dumps([s1, s1]),
+            FxGraphCachePickler.dumps([s1, s2]),
+        )
+
     def test_get_hash_for_files(self):
         """
         Test the get_hash_for_files helper.

torch/_functorch/_aot_autograd/autograd_cache.py

Lines changed: 2 additions & 8 deletions
@@ -251,14 +251,8 @@ def _reduce_tensor(tensor):
     """
     Reduce the tensor to a stable key for caching.
    """
-    return (
-        _ident,
-        (
-            extract_tensor_metadata_for_cache_key(
-                FxGraphCachePickler._device_map, tensor
-            ),
-        ),
-    )
+    metadata = extract_tensor_metadata_for_cache_key(tensor)
+    return (_ident, (metadata,))
 
 
 class AOTAutogradCachePickler(FxGraphCachePickler):

torch/_inductor/codecache.py

Lines changed: 8 additions & 27 deletions
@@ -506,9 +506,7 @@ def _ident(x: T) -> T:
     return x
 
 
-def extract_tensor_metadata_for_cache_key(
-    device_map: Dict[torch.device, torch.device], t: Tensor
-) -> TensorMetadata:
+def extract_tensor_metadata_for_cache_key(t: Tensor) -> TensorMetadata:
     """
     Extracts the tensor metadata and removes fields of the TensorMetadata
     that are not needed for caching
@@ -517,32 +515,19 @@ def extract_tensor_metadata_for_cache_key(
     if not hasattr(t, "_is_inductor_static"):
         meta = dataclasses.replace(meta, storage_offset=0, storage_bytes=None)
 
-    # The pickle implementation avoids serializing the same object more than once.
-    # That behavior means the byte stream we create to hash will vary if, for example,
-    # we see two tensor objects with the same device, but the torch.device object is
-    # actually the same object vs. merely equivalent. We want to produce the same hash
-    # value in either situation, so we memoize the device objects and always reference
-    # the same object for a given device. It's possible other metadata fields deserve
-    # the same treatment, but so far we've only observed this issue with the device.
-    if meta.device not in device_map:
-        device_map[meta.device] = meta.device
-    meta = dataclasses.replace(meta, device=device_map[meta.device])
-
     return meta
 
 
-def _reduce_fake_tensor(
-    device_map: Dict[torch.device, torch.device], t: Tensor
-) -> Tuple[Callable[[T], T], Tuple[TensorMetadata]]:
+def _reduce_fake_tensor(t: Tensor) -> Tuple[Callable[[T], T], Tuple[TensorMetadata]]:
     """
     See FxGraphCachePickler. Custom reducer to pickle FakeTensors.
     """
-    metadata = extract_tensor_metadata_for_cache_key(device_map, t)
+    metadata = extract_tensor_metadata_for_cache_key(t)
     return (_ident, (metadata,))
 
 
 def _reduce_tensor(
-    device_map: Dict[torch.device, torch.device], t: Tensor
+    t: Tensor,
 ) -> Tuple[Callable[[T], T], Tuple[TensorMetadataAndValues]]:
     """
     See FxGraphCachePickler. Custom reducer to pickle Tensors.
@@ -570,7 +555,7 @@ def _reduce_tensor(
         f"FX graph cache handling of a large constant took {elapsed:.1}s. Please file an issue."
     )
 
-    metadata = extract_tensor_metadata_for_cache_key(device_map, t)
+    metadata = extract_tensor_metadata_for_cache_key(t)
     return (_ident, (TensorMetadataAndValues(metadata, values),))
 
 
@@ -600,13 +585,9 @@ class FxGraphCachePickler(pickle.Pickler):
     data that allow us to compute a stable, but safe hash.
     """
 
-    # See extract_tensor_metadata_for_cache_key. Whenever we extract metadata during
-    # pickling, we make sure devices always reference the same torch.device object.
-    _device_map: Dict[torch.device, torch.device] = {}
-
     dispatch_table = copyreg.dispatch_table.copy()
-    dispatch_table[FakeTensor] = functools.partial(_reduce_fake_tensor, _device_map)
-    dispatch_table[torch.Tensor] = functools.partial(_reduce_tensor, _device_map)
+    dispatch_table[FakeTensor] = _reduce_fake_tensor
+    dispatch_table[torch.Tensor] = _reduce_tensor
     dispatch_table[torch.SymInt] = _reduce_symint
     dispatch_table[
         torch.fx.experimental._backward_state.BackwardState
@@ -648,7 +629,7 @@ def debug_lines(cls, inp: FxGraphHashDetails) -> List[str]:
 
         def get_str(obj: Any) -> str:
             if isinstance(obj, torch.Tensor):
-                return str(extract_tensor_metadata_for_cache_key(cls._device_map, obj))
+                return str(extract_tensor_metadata_for_cache_key(obj))
             elif isinstance(obj, bytes):
                 return "<bytes>"
             elif type(obj) in cls.dispatch_table:
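For readers unfamiliar with the `dispatch_table` mechanism the diff simplifies, here is a minimal self-contained sketch; `Point`, `_reduce_point`, and `StableHashPickler` are made-up stand-ins for illustration, not PyTorch code:

```python
import copyreg
import io
import pickle

class Point:
    """Hypothetical stand-in for a type (like FakeTensor) whose fields we
    normalize to stable metadata before hashing."""
    def __init__(self, x: int, y: int) -> None:
        self.x, self.y = x, y

def _ident(x):
    return x

def _reduce_point(p: Point):
    # copyreg-style reducer: a (callable, args) pair. Only the pickled
    # bytes matter for hashing, so reducing to plain values is enough.
    return (_ident, ((p.x, p.y),))

class StableHashPickler(pickle.Pickler):
    # Route Point through the custom reducer, mirroring how the diff maps
    # FakeTensor / torch.Tensor to _reduce_fake_tensor / _reduce_tensor.
    dispatch_table = copyreg.dispatch_table.copy()
    dispatch_table[Point] = _reduce_point

    @classmethod
    def dumps(cls, obj) -> bytes:
        with io.BytesIO() as stream:
            pickler = cls(stream)
            pickler.fast = True  # no memo: equal values yield equal bytes
            pickler.dump(obj)
            return stream.getvalue()

# Equal-but-distinct objects serialize to identical bytes, so a hash of
# the stream is stable across processes and object identities.
assert StableHashPickler.dumps(Point(1, 2)) == StableHashPickler.dumps(Point(1, 2))
```

The design point mirrors the diff: once fast mode guarantees that equal values pickle to equal bytes, the reducers can be plain functions registered directly in `dispatch_table`, with no shared `device_map` state threaded through `functools.partial`.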
