mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
[DTensor] Computed DTensorSpec hash lazily (#114322)
This is a forward fix for https://github.com/pytorch/pytorch/issues/113781. We lazily compute the hash so that we do not try to compute the hash on `SymInt`s (for the stride) during Dynamo tracing. Tested via: ``` python test/distributed/_tensor/test_dtensor_compile.py -k test_2d_fsdp_tp_ac_compile ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114322 Approved by: https://github.com/wanchaol ghstack dependencies: #113919, #113924, #114134, #113925, #113930, #114141, #113915, #114140
This commit is contained in:
committed by
PyTorch MergeBot
parent
c5ddfa79b3
commit
e7326ec295
@ -388,7 +388,7 @@ class DTensorSpec:
|
||||
def __post_init__(self):
|
||||
if not isinstance(self.placements, tuple):
|
||||
self.placements = tuple(self.placements)
|
||||
self._hash = self._hash_impl()
|
||||
self._hash: Optional[int] = None
|
||||
|
||||
def __setattr__(self, attr: str, value: Any):
|
||||
super().__setattr__(attr, value)
|
||||
@ -397,7 +397,7 @@ class DTensorSpec:
|
||||
if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"):
|
||||
self._hash = self._hash_impl()
|
||||
|
||||
def _hash_impl(self):
|
||||
def _hash_impl(self) -> int:
|
||||
# hashing and equality check for DTensorSpec are used to cache the sharding
|
||||
# propagation results. We only need to consider the mesh, placements, shape
|
||||
# dtype and stride.
|
||||
@ -416,9 +416,12 @@ class DTensorSpec:
|
||||
return hash((self.mesh, self.placements))
|
||||
|
||||
def __hash__(self) -> int:
|
||||
# We eagerly cache the spec to avoid recomputing the hash upon each
|
||||
# We lazily cache the spec to avoid recomputing the hash upon each
|
||||
# use, where we make sure to update the hash when the `tensor_meta`
|
||||
# changes by overriding `__setattr__`.
|
||||
# changes by overriding `__setattr__`. This must be lazy so that Dynamo
|
||||
# does not try to hash non-singleton `SymInt`s for the stride.
|
||||
if self._hash is None:
|
||||
self._hash = self._hash_impl()
|
||||
return self._hash
|
||||
|
||||
def __eq__(self, __o: object) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user