Compare commits

1 Commit

SHA1        Message                      Date
128e3fd598  Update [ghstack-poisoned]    2025-09-15 22:39:48 +00:00

@@ -4616,6 +4616,57 @@ def error_inputs_native_layer_norm(opinfo, device, **kwargs):
    )
    yield ErrorInput(s4, error_regex=err_msg4)


def sample_inputs_native_layer_norm_backward(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as input shape, normalized_shape, eps
    cases: tuple[tuple[int], tuple[int], float] = (  # type: ignore[assignment]
        ((1, 2, 3), (1, 2, 3), 0.5),
        ((2, 2, 3), (2, 3), -0.5),
        ((1,), (1,), 1e-5),
        ((1, 2), (2,), 1e-5),
        ((0, 1), (1,), 1e-5),
    )

    for input_shape, normalized_shape, eps in cases:
        # Create input, weight, and bias tensors
        input_tensor = make_arg(input_shape)
        weight = make_arg(normalized_shape)
        bias = make_arg(normalized_shape)

        # Compute forward to get mean and rstd
        output, mean, rstd = torch.native_layer_norm(
            input_tensor, normalized_shape, weight, bias, eps
        )

        # Create grad_out tensor
        grad_out = make_arg(input_shape)

        # Test different combinations of weight and bias
        # Full case with weight and bias
        yield SampleInput(
            grad_out,
            args=(input_tensor, normalized_shape, mean, rstd, weight, bias, [True, True, True]),
        )

        # Case with weight but no bias
        yield SampleInput(
            grad_out,
            args=(input_tensor, normalized_shape, mean, rstd, weight, None, [True, True, False]),
        )

        # Case with bias but no weight
        yield SampleInput(
            grad_out,
            args=(input_tensor, normalized_shape, mean, rstd, None, bias, [True, False, True]),
        )

        # Case with neither weight nor bias
        yield SampleInput(
            grad_out,
            args=(input_tensor, normalized_shape, mean, rstd, None, None, [True, False, False]),
        )


def error_inputs_rms_norm(opinfo, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
    input_shape = (1, 2, 3)
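For orientation, the samples above follow the positional layout of the aten schema native_layer_norm_backward(grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask). A minimal sketch of the call these samples exercise (shapes and values here are illustrative, not taken from the commit):

# Illustrative sketch only (not part of this commit): direct call to the op
# using the same positional layout the sample inputs are built for.
import torch

x = torch.randn(2, 2, 3)
normalized_shape = (2, 3)
w = torch.randn(normalized_shape)
b = torch.randn(normalized_shape)
out, mean, rstd = torch.native_layer_norm(x, normalized_shape, w, b, 1e-5)
grad_out = torch.randn_like(out)
grad_in, grad_w, grad_b = torch.ops.aten.native_layer_norm_backward(
    grad_out, x, normalized_shape, mean, rstd, w, b, [True, True, True]
)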
@@ -11546,7 +11597,6 @@ def reference_mse_loss(input, target, reduction="mean"):
def reference_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int], weight=None, bias=None, eps=1e-5):
    return reference_native_layer_norm(inp, normalized_shape, weight, bias, eps)[0]


def reference_native_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int], weight, bias, eps):
    feature_size = np.prod(normalized_shape)
    inp_view = inp.reshape(-1, feature_size)  # type: ignore[call-overload]
@@ -11563,7 +11613,6 @@ def reference_native_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int],
    stat_shape = inp.shape[:axis] + (1,) * len(normalized_shape)
    return Y.reshape(*inp.shape), mean.reshape(stat_shape), (1.0 / np.sqrt(var + eps)).reshape(stat_shape)


def reference_rms_norm(inp: npt.NDArray, normalized_shape: tuple[int], weight=None, eps=None):
    if eps is None:
        eps = torch.finfo(numpy_to_torch_dtype(inp.dtype)).eps
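The commit itself adds no NumPy reference for the backward op. For readers who want to sanity-check the new samples by hand, a rough sketch is given below; the function name reference_native_layer_norm_backward is hypothetical, it reuses the reshape(-1, feature_size) layout of reference_native_layer_norm above, and it applies the standard layer-norm gradient formula rather than anything taken from the PR.

# Sketch only (not part of this commit): NumPy reference for the backward pass.
import numpy as np

def reference_native_layer_norm_backward(grad_out, inp, normalized_shape, mean, rstd, weight=None):
    feature_size = int(np.prod(normalized_shape))
    x = inp.reshape(-1, feature_size)
    go = grad_out.reshape(-1, feature_size)
    x_hat = (x - mean.reshape(-1, 1)) * rstd.reshape(-1, 1)

    # Gradient flowing into x_hat; weight broadcasts over the flattened features
    g = go * weight.reshape(1, -1) if weight is not None else go
    grad_input = rstd.reshape(-1, 1) * (
        g
        - g.mean(axis=1, keepdims=True)
        - x_hat * (g * x_hat).mean(axis=1, keepdims=True)
    )
    # Per-feature reductions over the leading (batch) dimension
    grad_weight = (go * x_hat).sum(axis=0).reshape(normalized_shape)
    grad_bias = go.sum(axis=0).reshape(normalized_shape)
    return grad_input.reshape(inp.shape), grad_weight, grad_bias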
@@ -14916,6 +14965,27 @@ op_db: list[OpInfo] = [
               DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-03, rtol=5e-03)}),
                            "TestDecomp", "test_comprehensive", device_type="cpu"),
           )),
    OpInfo('native_layer_norm_backward',
           aten_name='native_layer_norm_backward',
           op=torch.ops.aten.native_layer_norm_backward,
           dtypes=floating_types_and(torch.half, torch.bfloat16),
           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
           supports_out=False,
           assert_jit_shape_analysis=True,
           supports_fwgrad_bwgrad=True,
           sample_inputs_func=sample_inputs_native_layer_norm_backward,
           skips=(
               # IndexError: tuple index out of range
               # DecorateInfo(unittest.skip('Skipped!'), 'TestFwdGradients', 'test_forward_mode_AD'),
               # Tests fail when weight=None and bias is defined
               # https://github.com/pytorch/pytorch/issues/79705
               # DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'),
               # JIT test also tries to compute double backward, which fails
               # DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
               # DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
               DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-03, rtol=5e-03)}),
                            "TestDecomp", "test_comprehensive", device_type="cpu"),
           )),
    OpInfo('native_batch_norm',
           aten_name='native_batch_norm',
           dtypes=floating_types_and(torch.float16, torch.bfloat16),
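To see what the new entry drives, here is a rough sketch of how the OpInfo machinery consumes sample_inputs_func; it is not from the commit and assumes the entry has landed in op_db. Each yielded SampleInput is unpacked as op(sample.input, *sample.args, **sample.kwargs), which is why grad_out is passed as the SampleInput's input and everything else rides in args.

# Sketch only: exercising the new OpInfo entry by hand, assuming it is in op_db.
import torch
from torch.testing._internal.common_methods_invocations import op_db

opinfo = next(o for o in op_db if o.name == 'native_layer_norm_backward')
for sample in opinfo.sample_inputs('cpu', torch.float32):
    # Returns grad_input, grad_weight, grad_bias (subject to the output_mask in args)
    result = opinfo.op(sample.input, *sample.args, **sample.kwargs)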