Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-18 17:45:09 +08:00
[DTensor] Fix adding Partial(sum) with Scalar
@@ -7,7 +7,7 @@ import warnings
 import torch
 import torch.distributed as dist
 import torch.testing._internal.common_methods_invocations as common_ops
-from torch.distributed.tensor import DTensor, init_device_mesh
+from torch.distributed.tensor import distribute_tensor, DTensor, init_device_mesh, Shard
 from torch.overrides import resolve_name
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
@@ -649,6 +649,17 @@ class TestDTensorOps(DTensorOpTestBase):
                 else:
                     print(f"xfail('{opinfo.name}'),")

+    def test_partial_sum_add_scalar(self):
+        self.mesh = init_device_mesh(DEVICE_TYPE, (self.world_size,))
+
+        tensor = torch.arange(10, dtype=torch.float, device=DEVICE_TYPE)
+        dt = distribute_tensor(tensor, self.mesh, [Shard(0)])
+
+        def func(t):
+            return t.sum() + 1
+
+        self.assertEqual(func(dt).full_tensor(), func(tensor))
+
     def test_one_hot(self):
         ops = [op for op in op_db if op.name == "nn.functional.one_hot"]
         assert len(ops) == 1
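The new test exercises exactly the case named in the commit title: shard a tensor, reduce it so the result carries a Partial(sum) placement, then add a plain Python scalar. Below is a rough standalone sketch of the same flow; the gloo backend, CPU device type, mesh size, and torchrun launcher are illustrative assumptions, not part of the commit.

    # Run with e.g.: torchrun --nproc-per-node 2 repro.py
    import torch
    import torch.distributed as dist
    from torch.distributed.tensor import distribute_tensor, init_device_mesh, Shard

    dist.init_process_group("gloo")  # torchrun supplies rank/world-size env vars
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))

    tensor = torch.arange(10, dtype=torch.float)
    dt = distribute_tensor(tensor, mesh, [Shard(0)])  # shard dim 0 across ranks

    # dt.sum() yields a DTensor with a Partial(sum) placement; adding a bare
    # Python scalar to it is the case this commit fixes.
    out = dt.sum() + 1
    assert torch.equal(out.full_tensor(), tensor.sum() + 1)

    dist.destroy_process_group()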
@@ -49,6 +49,22 @@ def is_same_size_handler(
     return lhs.shape == rhs.shape


+def add_handler_predicate(
+    op_call: torch._ops.OpOverload, args: tuple[object, ...], kwargs: dict[str, object]
+) -> bool:
+    return not isinstance(args[1], torch.Tensor)
+
+
+def add_scalar_handler(
+    op_call: torch._ops.OpOverload,
+    args: tuple[object, ...],
+    kwargs: dict[str, object],
+) -> object:
+    assert add_handler_predicate(op_call, args, kwargs)
+    new_args = (args[0], torch.tensor(args[1]))
+    return op_call(*new_args, **kwargs)
+
+
 def found_inf_reduce_handler(
     op_call: torch._ops.OpOverload,
     args: tuple[object, ...],
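The predicate accepts the call only when the second operand of aten.add.Tensor is not a tensor (i.e. a bare Python number), and the handler re-issues the op with that number wrapped by torch.tensor, which produces a 0-dim tensor. A quick illustration of that wrapping on ordinary tensors:

    import torch

    # torch.tensor on a bare Python number yields a 0-dim tensor, so the
    # rewritten add sees two Tensor operands instead of (Tensor, Scalar).
    print(torch.tensor(1).dim())             # 0
    print(torch.ones(3) + torch.tensor(1))   # tensor([2., 2., 2.])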
@@ -125,6 +141,10 @@ class OpDispatcher:
             aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler,
         }

+        self._custom_op_handlers_with_condition = {
+            aten.add.Tensor: (add_handler_predicate, add_scalar_handler),
+        }
+
         # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor)
         # as implicitly replicated or we throw error to user.
         # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave
@@ -152,6 +172,11 @@ class OpDispatcher:
         if op_call in self._custom_op_handlers:
             return self._custom_op_handlers[op_call](op_call, args, kwargs)  # type: ignore[operator]

+        if op_call in self._custom_op_handlers_with_condition:
+            condition, handler = self._custom_op_handlers_with_condition[op_call]
+            if condition(op_call, args, kwargs):
+                return handler(op_call, args, kwargs)
+
         # extract local tensor and sharding infos to a OpInfo
         op_info = self.unwrap_to_op_info(op_call, args, kwargs)
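Dispatch therefore consults the unconditional handler table first, then the new conditional table, and falls through to the normal path only if neither claims the call. The toy model below mirrors that lookup order with plain dictionaries and hypothetical names; it is only a sketch and does not involve any DTensor machinery.

    import torch

    def scalar_rhs(op, args, kwargs):
        # Mirrors add_handler_predicate: trigger only for a non-Tensor rhs.
        return not isinstance(args[1], torch.Tensor)

    def wrap_scalar(op, args, kwargs):
        # Mirrors add_scalar_handler: wrap the scalar and re-issue the op.
        return op(args[0], torch.tensor(args[1]), **kwargs)

    unconditional_handlers = {}  # stand-in for self._custom_op_handlers
    conditional_handlers = {torch.ops.aten.add.Tensor: (scalar_rhs, wrap_scalar)}

    def dispatch(op, args, kwargs):
        if op in unconditional_handlers:
            return unconditional_handlers[op](op, args, kwargs)
        if op in conditional_handlers:
            predicate, handler = conditional_handlers[op]
            if predicate(op, args, kwargs):
                return handler(op, args, kwargs)
        return op(*args, **kwargs)  # stand-in for the normal dispatch path

    print(dispatch(torch.ops.aten.add.Tensor, (torch.ones(3), 1), {}))  # tensor([2., 2., 2.])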