[DTensor] Computed DTensorSpec hash lazily (#114322)

This is a forward fix for https://github.com/pytorch/pytorch/issues/113781.

We compute the hash lazily so that we do not try to hash `SymInt`s (for the stride) during Dynamo tracing.
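
The fix follows the standard lazy-cached-hash pattern: compute the hash on first use, cache it, and refresh the cache when a hashed attribute changes. Below is a minimal, self-contained sketch of that pattern; `_LazySpec` and its fields are hypothetical illustrations, not the actual `DTensorSpec` code, and unlike `DTensorSpec` it clears the cached hash on attribute change rather than recomputing it eagerly in `__setattr__`. Because nothing is hashed until `__hash__` is first called, merely constructing the spec no longer forces a hash of the (possibly symbolic) stride.
```
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass
class _LazySpec:
    # Hypothetical stand-in for a spec whose hash is computed lazily.
    mesh: str
    placements: Tuple[str, ...]
    tensor_meta: Optional[Tuple[int, ...]] = None  # e.g. a stride tuple

    def __post_init__(self):
        # Do not hash here: the fields may hold values (such as symbolic
        # ints under tracing) that should not be hashed at construction time.
        self._hash: Optional[int] = None

    def __setattr__(self, attr: str, value: Any):
        super().__setattr__(attr, value)
        # Drop the cached hash whenever a hashed attribute changes.
        if attr in ("mesh", "placements", "tensor_meta"):
            super().__setattr__("_hash", None)

    def _hash_impl(self) -> int:
        return hash((self.mesh, self.placements, self.tensor_meta))

    def __hash__(self) -> int:
        # Compute on first use, then reuse the cached value.
        if self._hash is None:
            self._hash = self._hash_impl()
        return self._hash


spec = _LazySpec("mesh2d", ("Shard(0)", "Replicate()"), tensor_meta=(4, 1))
h1 = hash(spec)             # first use: hash computed and cached
spec.tensor_meta = (8, 1)   # cache invalidated via __setattr__
h2 = hash(spec)             # recomputed with the new stride
```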

Tested via:
```
python test/distributed/_tensor/test_dtensor_compile.py -k test_2d_fsdp_tp_ac_compile
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114322
Approved by: https://github.com/wanchaol
ghstack dependencies: #113919, #113924, #114134, #113925, #113930, #114141, #113915, #114140
Author: Andrew Gu
Date: 2023-11-21 16:11:32 -08:00
Committed by: PyTorch MergeBot
Parent: c5ddfa79b3
Commit: e7326ec295


```
@@ -388,7 +388,7 @@ class DTensorSpec:
     def __post_init__(self):
         if not isinstance(self.placements, tuple):
             self.placements = tuple(self.placements)
-        self._hash = self._hash_impl()
+        self._hash: Optional[int] = None
 
     def __setattr__(self, attr: str, value: Any):
         super().__setattr__(attr, value)
@@ -397,7 +397,7 @@ class DTensorSpec:
         if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"):
             self._hash = self._hash_impl()
 
-    def _hash_impl(self):
+    def _hash_impl(self) -> int:
         # hashing and equality check for DTensorSpec are used to cache the sharding
         # propagation results. We only need to consider the mesh, placements, shape
         # dtype and stride.
@@ -416,9 +416,12 @@ class DTensorSpec:
         return hash((self.mesh, self.placements))
 
     def __hash__(self) -> int:
-        # We eagerly cache the spec to avoid recomputing the hash upon each
+        # We lazily cache the spec to avoid recomputing the hash upon each
         # use, where we make sure to update the hash when the `tensor_meta`
-        # changes by overriding `__setattr__`.
+        # changes by overriding `__setattr__`. This must be lazy so that Dynamo
+        # does not try to hash non-singleton `SymInt`s for the stride.
+        if self._hash is None:
+            self._hash = self._hash_impl()
         return self._hash
 
     def __eq__(self, __o: object) -> bool:
```