Compare commits


1 Commit

Commit fba29cea7b: [DTensor] Assert DTensorSpec has valid placements (#158133)
Summary:
This helped identify buggy sharding rules during debugging, so why not
check it in.

Test Plan:
contbuild & OSS CI

Rollback Plan:

Differential Revision: D78929245
Date: 2025-07-24 15:25:02 -07:00
2 changed files with 5 additions and 1 deletion


@@ -343,7 +343,7 @@ def forward(self, b_parametrizations_buffer_original0, x):
         x = torch.randn(64, 32, requires_grad=True)
         spec = DTensorSpec(
             mesh,
-            (Replicate(), Shard(0)),
+            (Replicate(),),
             tensor_meta=TensorMeta(
                 shape=torch.Size([128, 32]), stride=(32, 1), dtype=x.dtype
             ),


@@ -32,6 +32,10 @@ class DTensorSpec:
     def __post_init__(self) -> None:
         if not isinstance(self.placements, tuple):
            self.placements = tuple(self.placements)
+        if not len(self.placements) == self.mesh.ndim:
+            raise ValueError(
+                f"DTensorSpec requires one placement per mesh dim (mesh.ndim={self.mesh.ndim}), got {self.placements=}"
+            )
         self._hash: Optional[int] = None
 
     def __setattr__(self, attr: str, value: Any) -> None:
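
For context, a minimal sketch of how the new check behaves. This is an illustration, not part of the diff: the import path for DTensorSpec and TensorMeta (torch.distributed.tensor._dtensor_spec) is private and may move between PyTorch releases, and the 1-D CPU mesh below assumes a single-rank distributed process group (e.g. launched via torchrun --nproc-per-node=1).

# Illustrative sketch only; assumes private import paths that may differ by
# version and a single-rank process group for init_device_mesh.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta

mesh = init_device_mesh("cpu", (1,))  # 1-D mesh, so mesh.ndim == 1
meta = TensorMeta(shape=torch.Size([128, 32]), stride=(32, 1), dtype=torch.float32)

# One placement per mesh dim: accepted, as before.
DTensorSpec(mesh, (Replicate(),), tensor_meta=meta)

# Two placements on a 1-D mesh: now raises ValueError from __post_init__.
try:
    DTensorSpec(mesh, (Replicate(), Shard(0)), tensor_meta=meta)
except ValueError as err:
    print(err)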