[DTensor] Fix adding Partial(sum) with Scalar

Sherlock Huang
2025-09-18 13:38:21 -07:00
parent 607489f3d0
commit b937dc720f
2 changed files with 37 additions and 1 deletion
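The scenario being fixed: summing a sharded DTensor leaves a result with a Partial(sum) placement, and adding a plain Python scalar to that result previously did not go through the scalar-wrapping path added here. A minimal sketch of the repro, mirroring the new test below; the 2-rank CPU mesh is illustrative (e.g. launched via torchrun --nproc-per-node=2) and not part of the commit:

import torch
from torch.distributed.tensor import distribute_tensor, init_device_mesh, Shard

mesh = init_device_mesh("cpu", (2,))         # illustrative 2-rank CPU mesh
x = torch.arange(10, dtype=torch.float)
dt = distribute_tensor(x, mesh, [Shard(0)])  # shard dim 0 across the mesh

partial = dt.sum()        # full reduction leaves a Partial(sum) placement
out = partial + 1         # DTensor + Python scalar: the case this commit fixes
print(out.full_tensor())  # tensor(46.) -- matches torch.arange(10).sum() + 1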

View File

@@ -7,7 +7,7 @@ import warnings
 import torch
 import torch.distributed as dist
 import torch.testing._internal.common_methods_invocations as common_ops
-from torch.distributed.tensor import DTensor, init_device_mesh
+from torch.distributed.tensor import distribute_tensor, DTensor, init_device_mesh, Shard
 from torch.overrides import resolve_name
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
@@ -649,6 +649,17 @@ class TestDTensorOps(DTensorOpTestBase):
                 else:
                     print(f"xfail('{opinfo.name}'),")
 
+    def test_partial_sum_add_scalar(self):
+        self.mesh = init_device_mesh(DEVICE_TYPE, (self.world_size,))
+        tensor = torch.arange(10, dtype=torch.float, device=DEVICE_TYPE)
+        dt = distribute_tensor(tensor, self.mesh, [Shard(0)])
+
+        def func(t):
+            return t.sum() + 1
+
+        self.assertEqual(func(dt).full_tensor(), func(tensor))
+
     def test_one_hot(self):
         ops = [op for op in op_db if op.name == "nn.functional.one_hot"]
         assert len(ops) == 1

View File

@@ -49,6 +49,22 @@ def is_same_size_handler(
     return lhs.shape == rhs.shape
 
 
+def add_handler_predicate(
+    op_call: torch._ops.OpOverload, args: tuple[object, ...], kwargs: dict[str, object]
+) -> bool:
+    return not isinstance(args[1], torch.Tensor)
+
+
+def add_scalar_handler(
+    op_call: torch._ops.OpOverload,
+    args: tuple[object, ...],
+    kwargs: dict[str, object],
+) -> object:
+    assert add_handler_predicate(op_call, args, kwargs)
+    new_args = (args[0], torch.tensor(args[1]))
+    return op_call(*new_args, **kwargs)
+
+
 def found_inf_reduce_handler(
     op_call: torch._ops.OpOverload,
     args: tuple[object, ...],
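The two new functions split detection from rewriting: add_handler_predicate fires only when the second operand of aten.add.Tensor is not already a tensor, and add_scalar_handler re-issues the same op with the Python scalar wrapped in a 0-dim tensor, so the add sees two tensor operands. A standalone sketch of that rewrite using plain tensors (no process group needed; the real handler receives DTensor arguments from the dispatcher):

import torch

aten = torch.ops.aten

def add_handler_predicate(op_call, args, kwargs) -> bool:
    # Intercept only `tensor + scalar`; tensor-tensor adds take the regular path.
    return not isinstance(args[1], torch.Tensor)

def add_scalar_handler(op_call, args, kwargs):
    assert add_handler_predicate(op_call, args, kwargs)
    # Wrap the scalar in a 0-dim tensor and re-dispatch the same overload.
    new_args = (args[0], torch.tensor(args[1]))
    return op_call(*new_args, **kwargs)

t = torch.tensor(45.0)
print(add_scalar_handler(aten.add.Tensor, (t, 1), {}))  # tensor(46.)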
@@ -125,6 +141,10 @@ class OpDispatcher:
             aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler,
         }
 
+        self._custom_op_handlers_with_condition = {
+            aten.add.Tensor: (add_handler_predicate, add_scalar_handler),
+        }
+
         # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor)
         # as implicitly replicated or we throw error to user.
         # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave
@@ -152,6 +172,11 @@
         if op_call in self._custom_op_handlers:
             return self._custom_op_handlers[op_call](op_call, args, kwargs)  # type: ignore[operator]
 
+        if op_call in self._custom_op_handlers_with_condition:
+            condition, handler = self._custom_op_handlers_with_condition[op_call]
+            if condition(op_call, args, kwargs):
+                return handler(op_call, args, kwargs)
+
         # extract local tensor and sharding infos to a OpInfo
         op_info = self.unwrap_to_op_info(op_call, args, kwargs)
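On the dispatch side, the new _custom_op_handlers_with_condition table mirrors the existing _custom_op_handlers, but each entry carries a predicate so the custom path is taken only for the calls it applies to (here, aten.add.Tensor with a scalar right-hand side); everything else falls through to the usual sharding propagation. A minimal sketch of that lookup pattern in isolation, with stub predicate and handler standing in for the real ones:

import torch

aten = torch.ops.aten

# op overload -> (predicate, handler); the handler runs only if the predicate accepts the call
handlers_with_condition = {
    aten.add.Tensor: (
        lambda op, args, kwargs: not isinstance(args[1], torch.Tensor),
        lambda op, args, kwargs: op(args[0], torch.tensor(args[1]), **kwargs),
    ),
}

def dispatch(op_call, args, kwargs):
    if op_call in handlers_with_condition:
        condition, handler = handlers_with_condition[op_call]
        if condition(op_call, args, kwargs):
            return handler(op_call, args, kwargs)
    # fall through to the regular path (sharding propagation in the real dispatcher)
    return op_call(*args, **kwargs)

print(dispatch(aten.add.Tensor, (torch.ones(3), 2), {}))              # scalar rhs: handled path
print(dispatch(aten.add.Tensor, (torch.ones(3), torch.ones(3)), {}))  # tensor rhs: fall-through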