|
f8d379d29e
|
[DTensor] Introduce DebugMode (#162665)
Introduce a lightweight TorchDispatchMode for understanding the magic behind DTensor.
- Tracks redistribution, see `redistribute_input(input_idx, from_placement, to_placement)`
- Optionally tracks torch-level functions, via `__torch_function__`
- Optionally tracks FakeTensor operations, which was needed for propagating tensor meta as a step of sharding propagation
- Optionally tracks real tensor operations, including functional c10d op, and regular ops
- Calls are shown in the hierarchical structure!
- shorthand representation
- dt: DTesnor, ft: FakeTensor, t: Tensor
- DM(2, 2) == DeviceMesh(shape = [2, 2])
- [R, P, S(0)] == Placement[Replicate, Partial, Shard(0)]
- f32[8,8] == float32 with shape[8, 8]
```
debug_mode = DTensorDebugMode(record_faketensor=False, record_realtensor=True)
with debug_mode:
torch.mm(x_dtensor, y_dtensor)
print(debug_mode.debug_string())
```
produces:
```
torch.mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
aten::mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
redistribute_input(1, [S(0)], [R])
_c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0)
_c10d_functional::wait_tensor(t: f32[8, 32])
aten::mm(t: f32[1, 8], t: f32[8, 32])
```
Another example, for torch.einsum
```
torch.functional.einsum(bld,dnh->blnh, dt: f32[16, 6, 8][P, R], dt: f32[8, 4, 4][R, P])
aten::unsqueeze(dt: f32[16, 6, 8][P, R], 3)
aten::unsqueeze(t: f32[16, 6, 8], 3)
aten::unsqueeze(dt: f32[16, 6, 8, 1][P, R], 4)
aten::unsqueeze(t: f32[16, 6, 8, 1], 4)
aten::permute(dt: f32[16, 6, 8, 1, 1][P, R], [0, 1, 3, 4, 2])
aten::permute(t: f32[16, 6, 8, 1, 1], [0, 1, 3, 4, 2])
aten::unsqueeze(dt: f32[8, 4, 4][R, P], 3)
aten::unsqueeze(t: f32[8, 4, 4], 3)
aten::unsqueeze(dt: f32[8, 4, 4, 1][R, P], 4)
aten::unsqueeze(t: f32[8, 4, 4, 1], 4)
aten::permute(dt: f32[8, 4, 4, 1, 1][R, P], [3, 4, 1, 2, 0])
aten::permute(t: f32[8, 4, 4, 1, 1], [3, 4, 1, 2, 0])
aten::permute(dt: f32[16, 6, 1, 1, 8][P, R], [0, 1, 4, 2, 3])
aten::permute(t: f32[16, 6, 1, 1, 8], [0, 1, 4, 2, 3])
aten::view(dt: f32[16, 6, 8, 1, 1][P, R], [1, 96, 8])
aten::view(t: f32[16, 6, 8, 1, 1], [1, 96, 8])
aten::permute(dt: f32[1, 1, 4, 4, 8][R, P], [4, 2, 3, 0, 1])
aten::permute(t: f32[1, 1, 4, 4, 8], [4, 2, 3, 0, 1])
aten::view(dt: f32[8, 4, 4, 1, 1][R, P], [1, 8, 16])
aten::view(t: f32[8, 4, 4, 1, 1], [1, 8, 16])
aten::bmm(dt: f32[1, 96, 8][P, R], dt: f32[1, 8, 16][R, P])
redistribute_input(0, [P, R], [S(2), S(2)])
aten::chunk(t: f32[1, 96, 8], 4, 2)
aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]'])
_c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 2)
aten::clone(t: f32[1, 96, 1])
redistribute_input(1, [R, P], [S(1), S(1)])
aten::chunk(t: f32[1, 8, 16], 4, 1)
aten::clone(t: f32[1, 2, 16])
aten::chunk(t: f32[1, 2, 16], 2, 1)
aten::cat(['t: f32[1, 1, 16]', 't: f32[1, 1, 16]'])
_c10d_functional::reduce_scatter_tensor(t: f32[2, 1, 16], sum, 2, 3)
_c10d_functional::wait_tensor(t: f32[1, 1, 16])
aten::bmm(t: f32[1, 96, 1], t: f32[1, 1, 16])
aten::view(dt: f32[1, 96, 16][P, P], [16, 6, 1, 4, 4])
aten::view(t: f32[1, 96, 16], [16, 6, 1, 4, 4])
aten::permute(dt: f32[16, 6, 1, 4, 4][P, P], [0, 1, 3, 4, 2])
aten::permute(t: f32[16, 6, 1, 4, 4], [0, 1, 3, 4, 2])
aten::view(dt: f32[16, 6, 4, 4, 1][P, P], [16, 6, 4, 4])
aten::view(t: f32[16, 6, 4, 4, 1], [16, 6, 4, 4])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162665
Approved by: https://github.com/ezyang
|
2025-09-16 07:30:05 +00:00 |
|