[DTensor] Registers sharding rule for rms_norm (#159692)
Reduces collective calls in the forward pass from 2 to 1. In #158716 I added the sharding rule for the backward pass but did not add it for the forward pass, since the forward pass did not get dispatched. After #159324 the forward pass is dispatched properly, hence I am adding the rule now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159692
Approved by: https://github.com/tianyu-l
Committed by: PyTorch MergeBot
Parent: 5a9c4cfce4
Commit: b4596895b9
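
Note: the collective-count claim in the commit message can be observed at the user level with CommDebugMode, the same tool the updated test below relies on. The following is a minimal sketch, not part of the PR; it assumes a CUDA build that provides torch.nn.RMSNorm and a multi-rank launch (e.g. torchrun --nproc-per-node=4), and the helper name _replicate_fn simply mirrors the test code in this diff.

import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_module, distribute_tensor
from torch.distributed.tensor.debug import CommDebugMode

# One mesh dimension spanning all ranks of the launch (assumes torchrun sets WORLD_SIZE).
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
torch.manual_seed(0)

# Normalize over the last dim; replicate the affine weight on every rank,
# mirroring the _replicate_fn used by the parameterized test in this PR.
rms_norm = torch.nn.RMSNorm(10, device="cuda")

def _replicate_fn(name, module, device_mesh):
    for pname, param in module.named_parameters():
        module.register_parameter(
            pname,
            torch.nn.Parameter(distribute_tensor(param, device_mesh, [Replicate()])),
        )

rms_norm_dist = distribute_module(rms_norm, mesh, _replicate_fn)

x = torch.rand(20, 5, 10, device="cuda")
x_dist = distribute_tensor(x, mesh, [Shard(2)])  # shard the normalized dim

comm_mode = CommDebugMode()
with comm_mode:
    y_dist = rms_norm_dist(x_dist)

# With the forward sharding rule registered, this sharding should need a single
# forward collective (previously 2); sharding only outer dims should need 0.
print(sum(comm_mode.comm_module_counts["Global"]["forward"].values()))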
@@ -271,14 +271,22 @@ class DistMathOpsTest(DTensorTestBase):
         norm_shape_idx_list = list(range(x.ndim))
         shard_dims = [-1, 0, 1, 2]
         elementwise_affine_list = [False, True]
+
+        # Test RMSNorm as well if CUDA
+        norm_types = [torch.nn.LayerNorm]
+        if self.device_type == "cuda" and hasattr(torch.nn, "RMSNorm"):
+            norm_types.append(torch.nn.RMSNorm)
+
         test_config_list = list(
-            itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list)
+            itertools.product(
+                norm_types, shard_dims, norm_shape_idx_list, elementwise_affine_list
+            )
         )
 
         # normalized shape is a torch.Size object
-        for shard_dim, norm_idx, elementwise_affine in test_config_list:
+        for norm_type, shard_dim, norm_idx, elementwise_affine in test_config_list:
             normalized_shape = x.shape[norm_idx:]
-            layer_norm = torch.nn.LayerNorm(
+            layer_norm = norm_type(
                 normalized_shape,
                 elementwise_affine=elementwise_affine,
                 device=self.device_type,
@@ -287,6 +295,7 @@ class DistMathOpsTest(DTensorTestBase):
 
             def _replicate_fn(name, module, device_mesh):
                 for name, param in module.named_parameters():
+                    # RMSNorm only has weight, LayerNorm has both weight and bias
                     if name in ["weight", "bias"]:
                         param_dist = torch.nn.Parameter(
                             distribute_tensor(param, device_mesh, [Replicate()])
@@ -307,7 +316,7 @@ class DistMathOpsTest(DTensorTestBase):
             self.assertLessEqual(
                 comm_mode.get_total_counts(),
                 1,  # TODO: This should be 0!
-                f"comm count={comm_mode.get_total_counts()}, "
+                f"comm count={comm_mode.get_total_counts()}, norm_type={norm_type.__name__}, "
                 f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
             )
 
@@ -329,12 +338,20 @@ class DistMathOpsTest(DTensorTestBase):
         norm_shape_idx_list = list(range(3))
         shard_dims = [0, 1, 2]
         elementwise_affine_list = [False, True]
+
+        # Test both LayerNorm and RMSNorm (if CUDA)
+        norm_types = [torch.nn.LayerNorm]
+        if self.device_type == "cuda" and hasattr(torch.nn, "RMSNorm"):
+            norm_types.append(torch.nn.RMSNorm)
+
         test_config_list = list(
-            itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list)
+            itertools.product(
+                norm_types, shard_dims, norm_shape_idx_list, elementwise_affine_list
+            )
         )
 
         # normalized shape is a torch.Size object
-        for shard_dim, norm_idx, elementwise_affine in test_config_list:
+        for norm_type, shard_dim, norm_idx, elementwise_affine in test_config_list:
             x = torch.rand(
                 batch,
                 sentence_length,
@@ -343,7 +360,7 @@ class DistMathOpsTest(DTensorTestBase):
                 requires_grad=True,
             )
             normalized_shape = x.shape[norm_idx:]
-            layer_norm = torch.nn.LayerNorm(
+            layer_norm = norm_type(
                 normalized_shape,
                 elementwise_affine=elementwise_affine,
                 device=self.device_type,
@@ -364,9 +381,11 @@ class DistMathOpsTest(DTensorTestBase):
                 self.assertEqual(
                     layer_norm_local.weight, layer_norm_dist.weight.full_tensor()
                 )
-                self.assertEqual(
-                    layer_norm_local.bias, layer_norm_dist.bias.full_tensor()
-                )
+                # RMSNorm doesn't have bias
+                if hasattr(layer_norm_local, "bias"):
+                    self.assertEqual(
+                        layer_norm_local.bias, layer_norm_dist.bias.full_tensor()
+                    )
 
             x_local = x.detach().clone().requires_grad_(True)
             x_dist = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
@@ -384,7 +403,7 @@ class DistMathOpsTest(DTensorTestBase):
             self.assertEqual(
                 sum(comm_mode.comm_module_counts["Global"]["forward"].values()),
                 expected_fwd_comm,
-                f"comm count={comm_mode.get_total_counts()}, "
+                f"comm count={comm_mode.get_total_counts()}, norm_type={norm_type.__name__}, "
                 f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
             )
 
@@ -398,7 +417,7 @@ class DistMathOpsTest(DTensorTestBase):
             self.assertEqual(
                 sum(comm_mode.comm_module_counts["Global"]["backward"].values()),
                 expected_bwd_comm,
-                f"comm count={comm_mode.get_total_counts()}, "
+                f"comm count={comm_mode.get_total_counts()}, norm_type={norm_type.__name__}, "
                 f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
             )
 
@@ -412,18 +431,22 @@ class DistMathOpsTest(DTensorTestBase):
                     is_tensor_partial(layer_norm_dist.weight.grad._spec),
                     needs_reduction,
                 )
-                self.assertEqual(
-                    is_tensor_partial(layer_norm_dist.bias.grad._spec),
-                    needs_reduction,
-                )
+                # RMSNorm doesn't have bias
+                if hasattr(layer_norm_dist, "bias"):
+                    self.assertEqual(
+                        is_tensor_partial(layer_norm_dist.bias.grad._spec),
+                        needs_reduction,
+                    )
                 self.assertEqual(
                     layer_norm_local.weight.grad,
                     layer_norm_dist.weight.grad.full_tensor(),
                 )
-                self.assertEqual(
-                    layer_norm_local.bias.grad,
-                    layer_norm_dist.bias.grad.full_tensor(),
-                )
+                # RMSNorm doesn't have bias
+                if hasattr(layer_norm_local, "bias"):
+                    self.assertEqual(
+                        layer_norm_local.bias.grad,
+                        layer_norm_dist.bias.grad.full_tensor(),
+                    )
 
             self.assertEqual(x_local.grad, x_dist.grad.full_tensor())
 
@@ -432,8 +455,14 @@ class DistMathOpsTest(DTensorTestBase):
         device_mesh = self.build_device_mesh()
         batch, seq_len, embedding_dim, vocab_size = 8, 8, 10, 32
 
+        # Test both LayerNorm and RMSNorm (if CUDA)
+        norm_types = [torch.nn.LayerNorm]
+        if self.device_type == "cuda" and hasattr(torch.nn, "RMSNorm"):
+            norm_types.append(torch.nn.RMSNorm)
+
         # build our subtest configurations and filter out invalid ones
         class SubTest(NamedTuple):
+            norm_type: type
             multidim_norm: bool
             elementwise_affine: bool
             emb_req_grad: bool
@@ -443,19 +472,24 @@ class DistMathOpsTest(DTensorTestBase):
         subtest_fails = {}
         valid_filter = (  # noqa: E731
             lambda cfg: (
-                not (cfg.ln_req_grad and not cfg.elementwise_affine) and any(cfg[2:])
+                not (cfg.ln_req_grad and not cfg.elementwise_affine) and any(cfg[3:])
             )
         )
         subtest_cfgs = list(
             filter(
                 valid_filter,
-                [SubTest(*cfg) for cfg in itertools.product(*(((False, True),) * 5))],
+                [
+                    SubTest(norm_type, *cfg)
+                    for norm_type in norm_types
+                    for cfg in itertools.product(*(((False, True),) * 5))
+                ],
             )
         )
 
         for subtest_cfg in subtest_cfgs:
             try:
                 (
+                    norm_type,
                     multidim_norm,
                     elementwise_affine,
                     emb_req_grad,
@@ -473,7 +507,7 @@ class DistMathOpsTest(DTensorTestBase):
                 self.preln_embeddings = torch.nn.Embedding(
                     vocab_size, embedding_dim
                 )
-                self.layer_norm = torch.nn.LayerNorm(
+                self.layer_norm = norm_type(
                     normalized_shape, elementwise_affine=elementwise_affine
                 )
                 self.postln_linear = torch.nn.Linear(
@@ -572,104 +606,6 @@ class DistMathOpsTest(DTensorTestBase):
             f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}"
         )
 
-    @with_comms
-    def test_rms_norm_bwd(self):
-        device_mesh = self.build_device_mesh()
-
-        # NLP example from pytorch docs
-        batch, sentence_length, embedding_dim = 20, 5, 10
-        norm_shape_idx_list = list(range(3))
-        shard_dims = [0]  # non-first dimensional sharding is not supported
-        elementwise_affine_list = [False, True]
-        test_config_list = list(
-            itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list)
-        )
-
-        # normalized shape is a torch.Size object
-        for shard_dim, norm_idx, elementwise_affine in test_config_list:
-            x = torch.rand(
-                batch,
-                sentence_length,
-                embedding_dim,
-                device=self.device_type,
-                requires_grad=True,
-            )
-            normalized_shape = x.shape[norm_idx:]
-            rms_norm = torch.nn.RMSNorm(
-                normalized_shape,
-                elementwise_affine=elementwise_affine,
-                device=self.device_type,
-            )
-            rms_norm_local = copy.deepcopy(rms_norm).to(self.device_type)
-
-            def _replicate_fn(name, module, device_mesh):
-                for name, param in module.named_parameters():
-                    if name == "weight":
-                        param_dist = torch.nn.Parameter(
-                            distribute_tensor(param, device_mesh, [Replicate()])
-                        )
-                        module.register_parameter(name, param_dist)
-
-            rms_norm_dist = distribute_module(rms_norm, device_mesh, _replicate_fn)
-
-            if elementwise_affine:
-                self.assertEqual(
-                    rms_norm_local.weight, rms_norm_dist.weight.full_tensor()
-                )
-
-            x_local = x.detach().clone().requires_grad_(True)
-            x_dist = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
-            self.assertEqual(x_local, x_dist.full_tensor())
-
-            y_local = rms_norm_local(x_local)
-            # make sure that backward rms norm does not introduce extra collectives
-            comm_mode = CommDebugMode()
-            with comm_mode:
-                y_dist = rms_norm_dist(x_dist)
-                y_dist.sum().backward()
-
-            # TODO: forward pass is sharding strategy is generated from composite, hence 1 more collective than layer_norm
-            # see: https://github.com/pytorch/pytorch/pull/158716#issuecomment-3096012679
-            expected_fwd_comm = 0 if shard_dim < norm_idx else 2
-
-            self.assertEqual(
-                sum(comm_mode.comm_module_counts["Global"]["forward"].values()),
-                expected_fwd_comm,
-                f"comm count={comm_mode.get_total_counts()}, "
-                f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
-            )
-
-            self.assertEqual(y_local, y_dist.full_tensor())
-
-            # backward step
-            y_local.sum().backward()
-
-            expected_bwd_comm = 0 if shard_dim < norm_idx else 1
-
-            self.assertEqual(
-                sum(comm_mode.comm_module_counts["Global"]["backward"].values()),
-                expected_bwd_comm,
-                f"comm count={comm_mode.get_total_counts()}, "
-                f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
-            )
-
-            if elementwise_affine:
-                # if input is sharded on any outer dimension, the gradient of weight
-                # should be Partial
-                dim_map = x_dist._spec.dim_map
-                outer_dims = range(norm_idx)
-                needs_reduction = any(dim_map[d] >= 0 for d in outer_dims)
-                self.assertEqual(
-                    is_tensor_partial(rms_norm_dist.weight.grad._spec),
-                    needs_reduction,
-                )
-                self.assertEqual(
-                    rms_norm_local.weight.grad,
-                    rms_norm_dist.weight.grad.full_tensor(),
-                )
-
-            self.assertEqual(x_local.grad, x_dist.grad.full_tensor())
-
     @with_comms
     def test_topk(self):
         device_mesh = self.build_device_mesh()
@@ -818,27 +818,38 @@ def nll_loss_backward_strategy(op_schema: OpSchema) -> OpStrategy:
     return grad_in_strategy
 
 
-@register_op_strategy(
-    [aten.native_layer_norm.default],
-    schema_info=RuntimeSchemaInfo(1),
-)
-def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy:
+def _common_norm_forward_strategy(
+    op_schema: OpSchema,
+    rms_norm: bool = False,
+) -> OpStrategy:
+    """Common forward strategy logic for layer_norm and rms_norm."""
     mesh = op_schema.get_mesh_from_args()
 
-    # args must be: input, normalized_shape, weight, bias, eps
-    # for None weight and bias, their corresponding objects will
-    # be None as well. layer_norm_strategy returns one OpStrategy
-    # for the triple return values (out, mean, rstd).
-    assert len(op_schema.args_schema) == 5
-    (
-        input_strategy,
-        normalized_shape,
-        weight_strategy,
-        bias_strategy,
-        _,
-    ) = op_schema.args_schema
+    if not rms_norm:
+        # layer_norm args: input, normalized_shape, weight, bias, eps
+        # for None weight and bias, their corresponding objects will
+        # be None as well. layer_norm_strategy returns one OpStrategy
+        # for the triple return values (out, mean, rstd).
+        assert len(op_schema.args_schema) == 5
+        (
+            input_strategy,
+            normalized_shape,
+            weight_strategy,
+            bias_strategy,
+            _,
+        ) = op_schema.args_schema
+    else:
+        # rms_norm args: input, normalized_shape, weight, eps
+        assert len(op_schema.args_schema) == 4
+        (
+            input_strategy,
+            normalized_shape,
+            weight_strategy,
+            _,
+        ) = op_schema.args_schema
+        bias_strategy = None
 
-    # the current layer norm implementation requires that all
+    # the current norm implementation requires that all
     # input DTensor's sharding must be in form of OpStrategy
     assert isinstance(input_strategy, OpStrategy)
     assert isinstance(normalized_shape, (int, Sequence, torch.Size))
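
The arity asserts in _common_norm_forward_strategy above track the two ATen schemas: native_layer_norm takes (input, normalized_shape, weight, bias, eps) while the fused rms_norm op drops the bias. As a quick sketch (assuming a build that defines aten::_fused_rms_norm, which this PR relies on), the schemas can be inspected directly; the argument lists in the comments simply restate the ones from the code above.

import torch

# Print the operator schemas that the arity asserts rely on.
print(torch.ops.aten.native_layer_norm.default._schema)  # 5 args: input, normalized_shape, weight, bias, eps
print(torch.ops.aten._fused_rms_norm.default._schema)    # 4 args: input, normalized_shape, weight, eps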
@@ -847,7 +858,7 @@ def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy:
     input_ndim = input_strategy.ndim
     axis = input_ndim - len(normalized_size)
 
-    # we use OpStrategy because the output (out, mean, rstd)
+    # we use OpStrategy because the output values (out, mean, rstd)
     # should have the same placements
     output_strategy = OpStrategy([])
     for idx, input_placement_strategy in enumerate(input_strategy.strategies):
@@ -915,6 +926,22 @@ def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy:
     return output_strategy
 
 
+@register_op_strategy(
+    [aten.native_layer_norm.default],
+    schema_info=RuntimeSchemaInfo(1),
+)
+def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy:
+    return _common_norm_forward_strategy(op_schema)
+
+
+@register_op_strategy(
+    [aten._fused_rms_norm.default],
+    schema_info=RuntimeSchemaInfo(1),
+)
+def fused_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
+    return _common_norm_forward_strategy(op_schema, rms_norm=True)
+
+
 def _common_norm_backward_strategy(
     op_schema: OpSchema,
     rms_norm: bool = False,