Compare commits

...

8 Commits

Author SHA1 Message Date
7adcd945e0 Update on "[DTensor][Partial] fixing adding scalar to Partial"
**Summary:** This fixes adding a scalar to a Partial DTensor, as reported in https://github.com/pytorch/pytorch/issues/149768 and https://github.com/pytorch/pytorch/issues/163193. The fix is applied when checking whether the output should keep a Partial placement: if one of the arguments is a scalar and the op is aten.add.Tensor, we redistribute the Partial input to Replicate.

**Test Cases**
1. pytest test/distributed/tensor/test_pointwise_ops.py -k test_add_partial_scalar
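
A minimal repro sketch of the reported failure mode (not part of this PR's diff; assumes a 4-rank gloo run, e.g. `torchrun --nproc-per-node 4 repro.py`):

```python
# Minimal repro sketch, assuming 4 ranks launched via torchrun.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (4,))
rank = dist.get_rank()

# Each rank holds `rank`; under Partial("sum") the logical value is 0+1+2+3 = 6.
dt = DTensor.from_local(
    torch.tensor([rank]), device_mesh=mesh, placements=[Partial("sum")]
)

# Before this fix, the scalar leaked into every local shard, so the logical
# result was 6 + 4*1 = 10. With the fix, the Partial input is replicated
# first, giving the correct 6 + 1 = 7 on every rank.
res = dt + 1
print(res.full_tensor())  # tensor([7])

dist.destroy_process_group()
```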




cc wanchaol tianyu-l wz337 XilunWu d4l3k pragupta SherlockNoMad H-Huang awgu fegin fduwjj wconstab msaroufim dcci

[ghstack-poisoned]
2025-11-17 14:24:46 -08:00
0e0fa940ba Update base for Update on "[DTensor][Partial] fixing adding scalar to Partial"
**Summary:** This fixes adding a scalar to a Partial DTensor, as reported in https://github.com/pytorch/pytorch/issues/149768 and https://github.com/pytorch/pytorch/issues/163193. The fix is applied when checking whether the output should keep a Partial placement: if one of the arguments is a scalar and the op is aten.add.Tensor, we redistribute the Partial input to Replicate.

**Test Cases**
1. pytest test/distributed/tensor/test_pointwise_ops.py -k test_add_partial_scalar




cc wanchaol tianyu-l wz337 XilunWu d4l3k pragupta SherlockNoMad H-Huang awgu fegin fduwjj wconstab msaroufim dcci

[ghstack-poisoned]
2025-11-17 14:24:46 -08:00
09a0f58faf Update on "[DTensor][Partial] fixing adding scalar to Partial"
**Summary:** This fixes adding a scalar to a Partial DTensor, as reported in https://github.com/pytorch/pytorch/issues/149768 and https://github.com/pytorch/pytorch/issues/163193. We accomplish this by adding a function that checks whether a Partial placement is supported before adding it to the output_spec. For the scalar case specifically, we check whether the op is aten.add.Tensor and, if so, whether arg_schemas contains a scalar argument; if it does, the function returns False to force replication. We don't need this check for aten.mul.Tensor, since multiplication works with the reduction ops of all existing Partial placements. However, new Partial placements could be added in the future whose reduce op requires redistribution; in that case we replicate and warn the user that we are doing so.

**Test Cases**
1. pytest test/distributed/tensor/test_pointwise_ops.py -k test_add_partial_scalar
2. pytest test/distributed/tensor/test_pointwise_ops.py -k test_unverified_custom_partial
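
The reasoning behind the safe-op distinction in one line of arithmetic: a Partial("sum") tensor defers its reduction, so an op may keep the Partial placement only if it commutes with that sum. Scalar multiplication does; scalar addition does not. A plain-tensor sketch (illustrative only, no distributed setup needed):

```python
import torch

# 4 local shards whose pending sum is the logical value: 0+1+2+3 = 6
shards = [torch.tensor([float(r)]) for r in range(4)]

# mul commutes with the deferred sum: sum(x_i * 2) == (sum x_i) * 2 == 12
assert sum(s * 2 for s in shards).item() == 12.0

# add does not: sum(x_i + 1) == (sum x_i) + 4 == 10, but the user meant 6 + 1 = 7
assert sum(s + 1 for s in shards).item() == 10.0
```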




cc wanchaol tianyu-l wz337 XilunWu d4l3k pragupta SherlockNoMad H-Huang awgu fegin fduwjj wconstab msaroufim dcci

[ghstack-poisoned]
2025-11-14 13:49:08 -08:00
aa9b4f836e Update base for Update on "[DTensor][Partial] fixing adding scalar to Partial"
**Summary:** This fixes adding a scalar to a Partial DTensor, as reported in https://github.com/pytorch/pytorch/issues/149768 and https://github.com/pytorch/pytorch/issues/163193. We accomplish this by adding a function that checks whether a Partial placement is supported before adding it to the output_spec. For the scalar case specifically, we check whether the op is aten.add.Tensor and, if so, whether arg_schemas contains a scalar argument; if it does, the function returns False to force replication. We don't need this check for aten.mul.Tensor, since multiplication works with the reduction ops of all existing Partial placements. However, new Partial placements could be added in the future whose reduce op requires redistribution; in that case we replicate and warn the user that we are doing so.

**Test Cases**
1. pytest test/distributed/tensor/test_pointwise_ops.py -k test_add_partial_scalar
2. pytest test/distributed/tensor/test_pointwise_ops.py -k test_unverified_custom_partial




cc wanchaol tianyu-l wz337 XilunWu d4l3k pragupta SherlockNoMad H-Huang awgu fegin fduwjj wconstab msaroufim dcci

[ghstack-poisoned]
2025-11-14 13:49:08 -08:00
2387cfb3bf Update on "[DTensor][Partial] fixing adding scalar to Partial"
**Summary:** This fixes adding a scalar to a Partial DTensor, as reported in https://github.com/pytorch/pytorch/issues/149768 and https://github.com/pytorch/pytorch/issues/163193. We accomplish this by adding a function that checks whether a Partial placement is supported before adding it to the output_spec. For the scalar case specifically, we check whether the op is aten.add.Tensor and, if so, whether arg_schemas contains a scalar argument; if it does, the function returns False to force replication. We don't need this check for aten.mul.Tensor, since multiplication works with the reduction ops of all existing Partial placements. However, new Partial placements could be added in the future whose reduce op requires redistribution; in that case we replicate and warn the user that we are doing so.

**Test Cases**
1. pytest test/distributed/tensor/test_pointwise_ops.py -k test_add_partial_scalar
2. pytest test/distributed/tensor/test_pointwise_ops.py -k test_unverified_custom_partial




cc wanchaol tianyu-l wz337 XilunWu d4l3k pragupta SherlockNoMad H-Huang awgu fegin fduwjj wconstab msaroufim dcci

[ghstack-poisoned]
2025-11-14 13:31:24 -08:00
b92b703b6c Update on "[DTensor][Partial] fixing adding scalar to Partial"
**Summary:** This fixes adding a scalar to a Partial DTensor, as reported in https://github.com/pytorch/pytorch/issues/149768 and https://github.com/pytorch/pytorch/issues/163193. We accomplish this by adding a function that checks whether a Partial placement is supported before adding it to the output_spec. For the scalar case specifically, we check whether the op is aten.add.Tensor and, if so, whether arg_schemas contains a scalar argument; if it does, the function returns False to force replication. We don't need this check for aten.mul.Tensor, since multiplication works with the reduction ops of all existing Partial placements. However, new Partial placements could be added in the future whose reduce op requires redistribution; in that case we replicate and warn the user that we are doing so.

**Test Cases**
1. pytest test/distributed/tensor/test_pointwise_ops.py -k test_add_partial_scalar
2. pytest test/distributed/tensor/test_pointwise_ops.py -k test_unverified_custom_partial




cc wanchaol tianyu-l wz337 XilunWu d4l3k pragupta SherlockNoMad H-Huang awgu fegin fduwjj wconstab msaroufim dcci

[ghstack-poisoned]
2025-11-14 13:25:44 -08:00
48b8b9539d Update on "[DTensor][Partial] fixing adding scalar to Partial"
**Summary:** This fixes adding a scalar to a Partial DTensor, as reported in https://github.com/pytorch/pytorch/issues/149768 and https://github.com/pytorch/pytorch/issues/163193. We accomplish this by adding a function that checks whether a Partial placement is supported before adding it to the output_spec. For the scalar case specifically, we check whether the op is aten.add.Tensor and, if so, whether arg_schemas contains a scalar argument; if it does, the function returns False to force replication. We don't need this check for aten.mul.Tensor, since multiplication works with the reduction ops of all existing Partial placements. However, new Partial placements could be added in the future whose reduce op requires redistribution; in that case we replicate and warn the user that we are doing so.

**Test Cases**
1. pytest test/distributed/tensor/test_pointwise_ops.py -k test_add_partial_scalar
2. pytest test/distributed/tensor/test_pointwise_ops.py -k test_unverified_custom_partial




cc wanchaol tianyu-l wz337 XilunWu d4l3k pragupta SherlockNoMad H-Huang awgu fegin fduwjj wconstab msaroufim dcci

[ghstack-poisoned]
2025-11-14 10:55:55 -08:00
47e80120ac [DTensor][Partial] fixing adding scalar to Partial
[ghstack-poisoned]
2025-11-13 22:47:15 -08:00
2 changed files with 69 additions and 4 deletions

View File: test/distributed/tensor/test_pointwise_ops.py

@@ -350,6 +350,53 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
         ):
             partial_dt.clamp_(max=10)
 
+    def test_add_partial_scalar(self):
+        mesh = self.build_device_mesh()
+        rank = self.rank
+        local_tensor = torch.tensor([rank])
+        dt = DTensor.from_local(
+            local_tensor, device_mesh=mesh, placements=[Partial("sum")]
+        )
+        res = dt + 1
+        self.assertEqual(res, 7)
+        self.assertTrue(res._spec.placements[0].is_replicate())
+
+        local_tensor = torch.tensor([1.0, 1.0, 7.0, 7.0])
+        dt = distribute_tensor(local_tensor, mesh, [Shard(0)])
+        norm = dt.norm()
+        norm = norm + 1
+        self.assertEqual(norm, 11)
+        self.assertTrue(norm._spec.placements[0].is_replicate())
+
+    def test_mult_div_scalar(self):
+        aten = torch.ops.aten
+        mesh = self.build_device_mesh()
+        rank = self.rank
+        local_tensor = torch.tensor([rank])
+        dt = DTensor.from_local(
+            local_tensor, device_mesh=mesh, placements=[Partial("sum")]
+        )
+        res = aten.mul.Scalar(dt, 2)
+        self.assertEqual(res, 2 * rank)
+        self.assertTrue(res._spec.placements[0].is_partial())
+        res = res.redistribute(dt.device_mesh, placements=[Replicate()])
+        self.assertEqual(res, 12)
+
+        res = aten.div.Scalar(dt, 2)
+        self.assertEqual(res, rank / 2)
+        self.assertTrue(res._spec.placements[0].is_partial())
+        res = res.redistribute(dt.device_mesh, placements=[Replicate()])
+        self.assertEqual(res, 3)
+
 
 if __name__ == "__main__":
     run_tests()

View File: torch/distributed/tensor/_ops/_pointwise_ops.py

@@ -25,6 +25,7 @@ from torch.distributed.tensor.placement_types import (
     Replicate,
     Shard,
 )
+from torch.types import _Number
 from torch.utils._typing_utils import not_none
@@ -465,6 +466,7 @@ def pointwise_strategy(op_schema: OpSchema, linearity: int = -1) -> OpStrategy:
             f"no strategy to follow for {op_schema}!"
         )
     return common_pointwise_strategy(
+        op_schema.op,
         op_schema.args_schema,
         followed_strategy,
         followed_strategy_index,
@@ -489,6 +491,7 @@ def linear_pointwise_strategy(op_schema: OpSchema) -> StrategyType:
 def common_pointwise_strategy(
+    op,
     args_schema: Sequence[object],
     followed_strategy: OpStrategy,
     followed_strategy_index: int,
@@ -530,11 +533,25 @@ def common_pointwise_strategy(
                 new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim
                 out_placements.append(Shard(new_shard_dim))
             elif isinstance(placement, Partial):
+                # list of ops that support linearity with partial placements
+                safe_partial_ops = [
+                    aten.div.Scalar,
+                    aten.div_.Scalar,
+                    aten.mul.Scalar,
+                    aten.mul_.Scalar,
+                    aten.mul.Tensor,
+                    aten.mul_.Tensor,
+                ]
                 # note that only partial-sum and partial-avg are supported for linearity
-                partial_supports_linearity = placement.is_partial(
-                    "sum"
-                ) or placement.is_partial("avg")
-                if linearity > 0 and partial_supports_linearity:
+                partial_supports_linearity = (
+                    placement.is_partial("sum") or placement.is_partial("avg")
+                ) and (
+                    op in safe_partial_ops
+                    or not any(isinstance(arg, _Number) for arg in args_schema)
+                )
+                if linearity >= 0 and partial_supports_linearity:
                     # propagate the partial placement
                     out_placements.append(placement)
                 else:
@@ -748,6 +765,7 @@ def list_pointwise_strategy(
             for arg_strategy in args_strategies
         ]
         pointwise_strategy: OpStrategy = common_pointwise_strategy(
+            op_schema.op,
             args_schema,
             child_strtgy,
             linearity,
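
Distilled, the new gate lets a Partial placement propagate through a pointwise op only when its reduction is sum or avg and either the op is on the verified safe list or no scalar argument is present. A standalone sketch under those assumptions (hypothetical helper name; the real check is inlined in `common_pointwise_strategy` as shown in the hunk above, and `torch.types._Number` is approximated here with builtin numeric types):

```python
# Standalone sketch of the gating decision (hypothetical helper; not the
# actual function added by this PR).
def partial_can_propagate(op, args_schema, placement, safe_partial_ops) -> bool:
    # only partial-sum and partial-avg reductions are linear enough to propagate
    linear_reduce = placement.is_partial("sum") or placement.is_partial("avg")
    # scalar args are only safe for ops verified to commute with the reduction
    has_scalar_arg = any(isinstance(a, (int, float, bool)) for a in args_schema)
    return linear_reduce and (op in safe_partial_ops or not has_scalar_arg)
```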