[DTensor] Support matmul in inference_mode (#142197)
Fixes #142190. The solution is to add a `decompose_handler` for `aten.matmul`, similar to how we handle `aten.linear`. With the decomposition, `aten.matmul` becomes `aten.mm`, which has a sharding strategy registered with DTensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142197
Approved by: https://github.com/XilunWu, https://github.com/wz337
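For background, the decompose handler does not implement `matmul` itself; it asks the op to decompose into core ATen ops (`aten.mm`/`aten.bmm` plus views) that DTensor already has sharding strategies for. Below is a minimal sketch of the idea, assuming `OpOverload.decompose` behaves as in current PyTorch; it is not a verbatim copy of the DTensor source:

```python
from typing import Dict, Tuple

import torch
from torch._ops import OpOverload


def decompose_handler(op_call: OpOverload, args: Tuple, kwargs: Dict):
    # Re-dispatch the op through its CompositeImplicitAutograd
    # decomposition; for aten.matmul this yields aten.mm / aten.bmm
    # (plus views), which DTensor knows how to shard.
    out = op_call.decompose(*args, **kwargs)
    if out is NotImplemented:
        raise RuntimeError(f"no decomposition available for {op_call}")
    return out
```

This path matters mainly under `inference_mode`: outside it, autograd typically decomposes `matmul` into `mm`/`bmm` before DTensor's dispatcher ever sees the op.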
@@ -120,6 +120,25 @@ class DistMatrixOpsTest(DTensorTestBase):
         for spec in shard_specs_comb:
             test_placement_comb([spec[0]], [spec[1]])
 
+    @with_comms
+    def test_matmul(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        dim = 128
+        x = torch.randn(8, dim)
+        A = torch.randn(dim, dim)
+        y = torch.matmul(x, A)
+
+        # Prepare DTensors
+        dx = distribute_tensor(x, device_mesh, [Replicate()])
+        dA = distribute_tensor(A, device_mesh, [Shard(0)])
+
+        # Use `inference_mode` to test DTensor's capability of decomposing
+        # `matmul` op
+        with torch.inference_mode():
+            dy = torch.matmul(dx, dA)
+
+        self.assertEqual(y, dy.full_tensor())
+
     @with_comms
     def test_t(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
@@ -137,6 +137,7 @@ class OpDispatcher:
         }
         self._custom_op_handlers = {
             aten.linear.default: decompose_handler,
+            aten.matmul.default: decompose_handler,
             aten.is_same_size.default: is_same_size_handler,
             aten.convolution.default: convolution_handler,
             aten.convolution_backward.default: convolution_backward_handler,
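With `aten.matmul.default` routed through `decompose_handler`, a sharded `matmul` under `inference_mode` dispatches cleanly. A self-contained sketch of the new behavior, assuming the public `torch.distributed.tensor` API (PyTorch ≥ 2.5) and a single-rank gloo process group on CPU; the address, port, and tensor sizes are arbitrary placeholder values:

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

# Single-process "distributed" setup so the sketch runs standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = DeviceMesh("cpu", [0])
x = torch.randn(8, 128)
A = torch.randn(128, 128)

dx = distribute_tensor(x, mesh, [Replicate()])  # replicate the input
dA = distribute_tensor(A, mesh, [Shard(0)])     # shard the weight by rows

# Before this change, this call failed under inference_mode because
# aten.matmul had no DTensor sharding strategy; it now decomposes to
# aten.mm, which does.
with torch.inference_mode():
    dy = torch.matmul(dx, dA)

torch.testing.assert_close(dy.full_tensor(), torch.matmul(x, A))
dist.destroy_process_group()
```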