[DTensor] Registers sharding rule for rms_norm (#159692)

Reduces collective calls in the forward pass from 2 to 1

In #158716 I added the sharding rule for the backward pass but not for the forward pass, since the forward op was not being dispatched at the time. After #159324 the forward op should be dispatched properly, hence I am adding the forward rule now.
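For context, a minimal sketch of how the saving can be observed (illustrative only, not part of this PR; it assumes a torchrun launch, a CUDA build where torch.nn.RMSNorm is available, and the public DTensor import paths of recent PyTorch):

import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor
from torch.distributed.tensor.debug import CommDebugMode

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

# Normalize over every dim while sharding dim 0, so the norm cannot be
# computed purely locally -- the case where the forward pass previously
# went through the composite decomposition and issued 2 collectives.
x = torch.rand(20, 5, 10, device="cuda")
# elementwise_affine=False: no weight, so there are no parameters to distribute
rms_norm = torch.nn.RMSNorm(x.shape, elementwise_affine=False, device="cuda")
x_dist = distribute_tensor(x, mesh, [Shard(0)])

comm_mode = CommDebugMode()
with comm_mode:
    y_dist = rms_norm(x_dist)

# With the forward sharding rule registered, at most 1 collective is needed.
print(f"forward collectives: {comm_mode.get_total_counts()}")

The saving applies to inputs sharded along a normalized dim; inputs sharded only along outer (batch-like) dims already needed no forward collectives.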

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159692
Approved by: https://github.com/tianyu-l
AaronWang04
2025-08-12 21:05:24 +00:00
committed by PyTorch MergeBot
parent 5a9c4cfce4
commit b4596895b9
2 changed files with 103 additions and 140 deletions


@@ -271,14 +271,22 @@ class DistMathOpsTest(DTensorTestBase):
norm_shape_idx_list = list(range(x.ndim))
shard_dims = [-1, 0, 1, 2]
elementwise_affine_list = [False, True]
# Test RMSNorm as well if CUDA
norm_types = [torch.nn.LayerNorm]
if self.device_type == "cuda" and hasattr(torch.nn, "RMSNorm"):
norm_types.append(torch.nn.RMSNorm)
test_config_list = list(
itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list)
itertools.product(
norm_types, shard_dims, norm_shape_idx_list, elementwise_affine_list
)
)
# normalized shape is a torch.Size object
for shard_dim, norm_idx, elementwise_affine in test_config_list:
for norm_type, shard_dim, norm_idx, elementwise_affine in test_config_list:
normalized_shape = x.shape[norm_idx:]
layer_norm = torch.nn.LayerNorm(
layer_norm = norm_type(
normalized_shape,
elementwise_affine=elementwise_affine,
device=self.device_type,
@@ -287,6 +295,7 @@ class DistMathOpsTest(DTensorTestBase):
def _replicate_fn(name, module, device_mesh):
for name, param in module.named_parameters():
# RMSNorm only has weight, LayerNorm has both weight and bias
if name in ["weight", "bias"]:
param_dist = torch.nn.Parameter(
distribute_tensor(param, device_mesh, [Replicate()])
@@ -307,7 +316,7 @@ class DistMathOpsTest(DTensorTestBase):
self.assertLessEqual(
comm_mode.get_total_counts(),
1, # TODO: This should be 0!
f"comm count={comm_mode.get_total_counts()}, "
f"comm count={comm_mode.get_total_counts()}, norm_type={norm_type.__name__}, "
f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
)
@@ -329,12 +338,20 @@ class DistMathOpsTest(DTensorTestBase):
norm_shape_idx_list = list(range(3))
shard_dims = [0, 1, 2]
elementwise_affine_list = [False, True]
# Test both LayerNorm and RMSNorm (if CUDA)
norm_types = [torch.nn.LayerNorm]
if self.device_type == "cuda" and hasattr(torch.nn, "RMSNorm"):
norm_types.append(torch.nn.RMSNorm)
test_config_list = list(
itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list)
itertools.product(
norm_types, shard_dims, norm_shape_idx_list, elementwise_affine_list
)
)
# normalized shape is a torch.Size object
for shard_dim, norm_idx, elementwise_affine in test_config_list:
for norm_type, shard_dim, norm_idx, elementwise_affine in test_config_list:
x = torch.rand(
batch,
sentence_length,
@@ -343,7 +360,7 @@ class DistMathOpsTest(DTensorTestBase):
requires_grad=True,
)
normalized_shape = x.shape[norm_idx:]
layer_norm = torch.nn.LayerNorm(
layer_norm = norm_type(
normalized_shape,
elementwise_affine=elementwise_affine,
device=self.device_type,
@@ -364,9 +381,11 @@ class DistMathOpsTest(DTensorTestBase):
self.assertEqual(
layer_norm_local.weight, layer_norm_dist.weight.full_tensor()
)
self.assertEqual(
layer_norm_local.bias, layer_norm_dist.bias.full_tensor()
)
# RMSNorm doesn't have bias
if hasattr(layer_norm_local, "bias"):
self.assertEqual(
layer_norm_local.bias, layer_norm_dist.bias.full_tensor()
)
x_local = x.detach().clone().requires_grad_(True)
x_dist = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
@@ -384,7 +403,7 @@ class DistMathOpsTest(DTensorTestBase):
self.assertEqual(
sum(comm_mode.comm_module_counts["Global"]["forward"].values()),
expected_fwd_comm,
f"comm count={comm_mode.get_total_counts()}, "
f"comm count={comm_mode.get_total_counts()}, norm_type={norm_type.__name__}, "
f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
)
@@ -398,7 +417,7 @@ class DistMathOpsTest(DTensorTestBase):
self.assertEqual(
sum(comm_mode.comm_module_counts["Global"]["backward"].values()),
expected_bwd_comm,
f"comm count={comm_mode.get_total_counts()}, "
f"comm count={comm_mode.get_total_counts()}, norm_type={norm_type.__name__}, "
f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
)
@@ -412,18 +431,22 @@ class DistMathOpsTest(DTensorTestBase):
is_tensor_partial(layer_norm_dist.weight.grad._spec),
needs_reduction,
)
self.assertEqual(
is_tensor_partial(layer_norm_dist.bias.grad._spec),
needs_reduction,
)
# RMSNorm doesn't have bias
if hasattr(layer_norm_dist, "bias"):
self.assertEqual(
is_tensor_partial(layer_norm_dist.bias.grad._spec),
needs_reduction,
)
self.assertEqual(
layer_norm_local.weight.grad,
layer_norm_dist.weight.grad.full_tensor(),
)
self.assertEqual(
layer_norm_local.bias.grad,
layer_norm_dist.bias.grad.full_tensor(),
)
# RMSNorm doesn't have bias
if hasattr(layer_norm_local, "bias"):
self.assertEqual(
layer_norm_local.bias.grad,
layer_norm_dist.bias.grad.full_tensor(),
)
self.assertEqual(x_local.grad, x_dist.grad.full_tensor())
@@ -432,8 +455,14 @@ class DistMathOpsTest(DTensorTestBase):
device_mesh = self.build_device_mesh()
batch, seq_len, embedding_dim, vocab_size = 8, 8, 10, 32
# Test both LayerNorm and RMSNorm (if CUDA)
norm_types = [torch.nn.LayerNorm]
if self.device_type == "cuda" and hasattr(torch.nn, "RMSNorm"):
norm_types.append(torch.nn.RMSNorm)
# build our subtest configurations and filter out invalid ones
class SubTest(NamedTuple):
norm_type: type
multidim_norm: bool
elementwise_affine: bool
emb_req_grad: bool
@@ -443,19 +472,24 @@ class DistMathOpsTest(DTensorTestBase):
subtest_fails = {}
valid_filter = ( # noqa: E731
lambda cfg: (
not (cfg.ln_req_grad and not cfg.elementwise_affine) and any(cfg[2:])
not (cfg.ln_req_grad and not cfg.elementwise_affine) and any(cfg[3:])
)
)
subtest_cfgs = list(
filter(
valid_filter,
[SubTest(*cfg) for cfg in itertools.product(*(((False, True),) * 5))],
[
SubTest(norm_type, *cfg)
for norm_type in norm_types
for cfg in itertools.product(*(((False, True),) * 5))
],
)
)
for subtest_cfg in subtest_cfgs:
try:
(
norm_type,
multidim_norm,
elementwise_affine,
emb_req_grad,
@@ -473,7 +507,7 @@ class DistMathOpsTest(DTensorTestBase):
self.preln_embeddings = torch.nn.Embedding(
vocab_size, embedding_dim
)
self.layer_norm = torch.nn.LayerNorm(
self.layer_norm = norm_type(
normalized_shape, elementwise_affine=elementwise_affine
)
self.postln_linear = torch.nn.Linear(
@@ -572,104 +606,6 @@ class DistMathOpsTest(DTensorTestBase):
f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}"
)
@with_comms
def test_rms_norm_bwd(self):
device_mesh = self.build_device_mesh()
# NLP example from pytorch docs
batch, sentence_length, embedding_dim = 20, 5, 10
norm_shape_idx_list = list(range(3))
shard_dims = [0] # non-first dimensional sharding is not supported
elementwise_affine_list = [False, True]
test_config_list = list(
itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list)
)
# normalized shape is a torch.Size object
for shard_dim, norm_idx, elementwise_affine in test_config_list:
x = torch.rand(
batch,
sentence_length,
embedding_dim,
device=self.device_type,
requires_grad=True,
)
normalized_shape = x.shape[norm_idx:]
rms_norm = torch.nn.RMSNorm(
normalized_shape,
elementwise_affine=elementwise_affine,
device=self.device_type,
)
rms_norm_local = copy.deepcopy(rms_norm).to(self.device_type)
def _replicate_fn(name, module, device_mesh):
for name, param in module.named_parameters():
if name == "weight":
param_dist = torch.nn.Parameter(
distribute_tensor(param, device_mesh, [Replicate()])
)
module.register_parameter(name, param_dist)
rms_norm_dist = distribute_module(rms_norm, device_mesh, _replicate_fn)
if elementwise_affine:
self.assertEqual(
rms_norm_local.weight, rms_norm_dist.weight.full_tensor()
)
x_local = x.detach().clone().requires_grad_(True)
x_dist = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
self.assertEqual(x_local, x_dist.full_tensor())
y_local = rms_norm_local(x_local)
# make sure that backward rms norm does not introduce extra collectives
comm_mode = CommDebugMode()
with comm_mode:
y_dist = rms_norm_dist(x_dist)
y_dist.sum().backward()
# TODO: the forward pass sharding strategy is generated from the composite decomposition, hence 1 more collective than layer_norm
# see: https://github.com/pytorch/pytorch/pull/158716#issuecomment-3096012679
expected_fwd_comm = 0 if shard_dim < norm_idx else 2
self.assertEqual(
sum(comm_mode.comm_module_counts["Global"]["forward"].values()),
expected_fwd_comm,
f"comm count={comm_mode.get_total_counts()}, "
f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
)
self.assertEqual(y_local, y_dist.full_tensor())
# backward step
y_local.sum().backward()
expected_bwd_comm = 0 if shard_dim < norm_idx else 1
self.assertEqual(
sum(comm_mode.comm_module_counts["Global"]["backward"].values()),
expected_bwd_comm,
f"comm count={comm_mode.get_total_counts()}, "
f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
)
if elementwise_affine:
# if input is sharded on any outer dimension, the gradient of weight
# should be Partial
dim_map = x_dist._spec.dim_map
outer_dims = range(norm_idx)
needs_reduction = any(dim_map[d] >= 0 for d in outer_dims)
self.assertEqual(
is_tensor_partial(rms_norm_dist.weight.grad._spec),
needs_reduction,
)
self.assertEqual(
rms_norm_local.weight.grad,
rms_norm_dist.weight.grad.full_tensor(),
)
self.assertEqual(x_local.grad, x_dist.grad.full_tensor())
@with_comms
def test_topk(self):
device_mesh = self.build_device_mesh()


@@ -818,27 +818,38 @@ def nll_loss_backward_strategy(op_schema: OpSchema) -> OpStrategy:
return grad_in_strategy
@register_op_strategy(
[aten.native_layer_norm.default],
schema_info=RuntimeSchemaInfo(1),
)
def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy:
def _common_norm_forward_strategy(
op_schema: OpSchema,
rms_norm: bool = False,
) -> OpStrategy:
"""Common forward strategy logic for layer_norm and rms_norm."""
mesh = op_schema.get_mesh_from_args()
# args must be: input, normalized_shape, weight, bias, eps
# for None weight and bias, their corresponding objects will
# be None as well. layer_norm_strategy returns one OpStrategy
# for the triple return values (out, mean, rstd).
assert len(op_schema.args_schema) == 5
(
input_strategy,
normalized_shape,
weight_strategy,
bias_strategy,
_,
) = op_schema.args_schema
if not rms_norm:
# layer_norm args: input, normalized_shape, weight, bias, eps
# for None weight and bias, their corresponding objects will
# be None as well. layer_norm_strategy returns one OpStrategy
# for the triple return values (out, mean, rstd).
assert len(op_schema.args_schema) == 5
(
input_strategy,
normalized_shape,
weight_strategy,
bias_strategy,
_,
) = op_schema.args_schema
else:
# rms_norm args: input, normalized_shape, weight, eps
assert len(op_schema.args_schema) == 4
(
input_strategy,
normalized_shape,
weight_strategy,
_,
) = op_schema.args_schema
bias_strategy = None
# the current layer norm implementation requires that all
# the current norm implementation requires that all
# input DTensor's sharding must be in form of OpStrategy
assert isinstance(input_strategy, OpStrategy)
assert isinstance(normalized_shape, (int, Sequence, torch.Size))
@@ -847,7 +858,7 @@ def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy:
input_ndim = input_strategy.ndim
axis = input_ndim - len(normalized_size)
# we use OpStrategy because the output (out, mean, rstd)
# we use OpStrategy because the output values (out, mean, rstd)
# should have the same placements
output_strategy = OpStrategy([])
for idx, input_placement_strategy in enumerate(input_strategy.strategies):
@@ -915,6 +926,22 @@ def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy:
return output_strategy
@register_op_strategy(
[aten.native_layer_norm.default],
schema_info=RuntimeSchemaInfo(1),
)
def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy:
return _common_norm_forward_strategy(op_schema)
@register_op_strategy(
[aten._fused_rms_norm.default],
schema_info=RuntimeSchemaInfo(1),
)
def fused_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
return _common_norm_forward_strategy(op_schema, rms_norm=True)
def _common_norm_backward_strategy(
op_schema: OpSchema,
rms_norm: bool = False,