[DTensor] Support matmul in inference_mode (#142197)

Fixes #142190.

The fix adds a `decompose_handler` entry for `aten.matmul`, mirroring the existing handling of `aten.linear`. Under `inference_mode`, the autograd layer that would normally decompose composite ops is skipped, so `aten.matmul` reaches DTensor's dispatcher undecomposed. With the handler in place, `aten.matmul` is decomposed into `aten.mm`, which has a sharding strategy registered with DTensor.
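For context, `decompose_handler` lives in DTensor's dispatcher and simply asks the op for its registered decomposition. A minimal sketch of the idea (the exact body in `torch/distributed/tensor/_dispatch.py` may differ):

import torch

def decompose_handler(op_call: torch._ops.OpOverload, args, kwargs):
    # Re-dispatch to the op's registered decomposition (for
    # CompositeImplicitAutograd ops this runs the composite kernel).
    # aten.matmul thereby bottoms out in aten.mm / aten.bmm, for which
    # DTensor has sharding strategies.
    result = op_call.decompose(*args, **kwargs)
    if result is NotImplemented:
        raise RuntimeError(f"no decomposition registered for {op_call}")
    return result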

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142197
Approved by: https://github.com/XilunWu, https://github.com/wz337
Author: Ke Wen
Date: 2024-12-05 17:22:56 -08:00
Committed by: PyTorch MergeBot
Parent: 02c509669a
Commit: 8bdcdae733
2 changed files with 20 additions and 0 deletions


@@ -120,6 +120,25 @@ class DistMatrixOpsTest(DTensorTestBase):
         for spec in shard_specs_comb:
             test_placement_comb([spec[0]], [spec[1]])
 
+    @with_comms
+    def test_matmul(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        dim = 128
+        x = torch.randn(8, dim)
+        A = torch.randn(dim, dim)
+        y = torch.matmul(x, A)
+
+        # Prepare DTensors
+        dx = distribute_tensor(x, device_mesh, [Replicate()])
+        dA = distribute_tensor(A, device_mesh, [Shard(0)])
+
+        # Use `inference_mode` to test DTensor's capability of decomposing
+        # `matmul` op
+        with torch.inference_mode():
+            dy = torch.matmul(dx, dA)
+
+        self.assertEqual(y, dy.full_tensor())
+
     @with_comms
     def test_t(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
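Outside the test harness, the fixed behavior can be exercised with a standalone script; a sketch (not part of this diff), assuming a gloo process group on CPU launched via torchrun:

# repro_matmul_inference.py -- sketch of the scenario from issue #142190.
# Launch with e.g.: torchrun --nproc_per_node=2 repro_matmul_inference.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

dx = distribute_tensor(torch.randn(8, 128), mesh, [Replicate()])
dA = distribute_tensor(torch.randn(128, 128), mesh, [Shard(0)])

# Before this change, aten.matmul hit DTensor's dispatcher undecomposed
# under inference_mode and failed for lack of a sharding strategy.
with torch.inference_mode():
    dy = torch.matmul(dx, dA)
print(dy.full_tensor().shape)  # torch.Size([8, 128])

dist.destroy_process_group()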


@@ -137,6 +137,7 @@ class OpDispatcher:
         }
         self._custom_op_handlers = {
             aten.linear.default: decompose_handler,
+            aten.matmul.default: decompose_handler,
             aten.is_same_size.default: is_same_size_handler,
             aten.convolution.default: convolution_handler,
             aten.convolution_backward.default: convolution_backward_handler,
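As a sanity check of why routing `aten.matmul` through `decompose_handler` is enough, `OpOverload.decompose()` can be called directly; a sketch (single process, plain tensors, no DTensor involved):

import torch

a, b = torch.randn(4, 3), torch.randn(3, 5)
# aten.matmul is a composite op; .decompose() runs its registered
# decomposition, which for 2-D inputs bottoms out in aten.mm --
# the op DTensor already has a sharding strategy for.
out = torch.ops.aten.matmul.default.decompose(a, b)
assert torch.allclose(out, a @ b)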