mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dtensor] add test for local_map decorator (#127752)
**Summary** This PR is a follow-up of #126924 to address reviewer's comments: 1) add a test case to show the use of `local_map` as a function decorator. 2) simplify the logic of handling different data types of `out_placements`. 3) correct variable naming in test cases to match math formulas. **Test** see #126924 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127752 Approved by: https://github.com/wanchaol
This commit is contained in:
committed by
PyTorch MergeBot
parent
8de0d7690c
commit
0159ebb654
@ -194,13 +194,17 @@ def local_map(
|
||||
flat_out, out_spec = pytree.tree_flatten(out)
|
||||
|
||||
flat_dist_out = []
|
||||
for idx, out in enumerate(flat_out):
|
||||
spec = (
|
||||
out_placements[idx]
|
||||
if isinstance(out_placements, tuple)
|
||||
else out_placements
|
||||
)
|
||||
|
||||
out_placements_tuple = (
|
||||
out_placements
|
||||
if isinstance(out_placements, tuple)
|
||||
else (out_placements,)
|
||||
)
|
||||
assert len(flat_out) == len(out_placements_tuple), (
|
||||
"local_map requires one PlacementType be provided for each output value,"
|
||||
f" received {len(out_placements_tuple)} out_placements but"
|
||||
f" {len(flat_out)} is expected!"
|
||||
)
|
||||
for out, spec in zip(flat_out, out_placements_tuple):
|
||||
if isinstance(out, torch.Tensor):
|
||||
assert not isinstance(
|
||||
out, DTensor
|
||||
|
||||
Reference in New Issue
Block a user