[dtensor] add test for local_map decorator (#127752)

**Summary**
This PR is a follow-up of #126924 to address reviewer's comments:
1) add a test case to show the use of `local_map` as a function decorator.
2) simplify the logic of handling different data types of `out_placements`.
3) correct variable naming in test cases to match math formulas.

**Test**
see #126924

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127752
Approved by: https://github.com/wanchaol
This commit is contained in:
Xilun Wu
2024-08-26 16:19:38 -07:00
committed by PyTorch MergeBot
parent 8de0d7690c
commit 0159ebb654
2 changed files with 75 additions and 56 deletions

View File

@ -194,13 +194,17 @@ def local_map(
flat_out, out_spec = pytree.tree_flatten(out)
flat_dist_out = []
for idx, out in enumerate(flat_out):
spec = (
out_placements[idx]
if isinstance(out_placements, tuple)
else out_placements
)
out_placements_tuple = (
out_placements
if isinstance(out_placements, tuple)
else (out_placements,)
)
assert len(flat_out) == len(out_placements_tuple), (
"local_map requires one PlacementType be provided for each output value,"
f" received {len(out_placements_tuple)} out_placements but"
f" {len(flat_out)} is expected!"
)
for out, spec in zip(flat_out, out_placements_tuple):
if isinstance(out, torch.Tensor):
assert not isinstance(
out, DTensor