pytorch/test/lazy/test_ts_opinfo.py
Edward Z. Yang de8d81275a Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)
This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager, because it avoids a decomposition. You can still force the
decomposition by putting it in the dispatch table, but if eager mode
wouldn't have decomposed (because it went to the fused kernel), we
now preserve the fused call by default.

This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel.

Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939
Approved by: https://github.com/bdhirsh
2025-10-11 01:03:55 +00:00
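
For illustration, a minimal sketch of the bitwise check the first paragraph implies, assuming a PyTorch build that exposes torch.nn.functional.rms_norm and the aot_eager compile backend (the exact configuration the PR targets may differ):

import torch
import torch.nn.functional as F


def f(x, weight):
    # With the fused kernel preserved (not decomposed), the compiled output
    # should match eager bit-for-bit.
    return F.rms_norm(x, (x.shape[-1],), weight=weight)


x = torch.randn(4, 8)
w = torch.randn(8)
eager_out = f(x, w)
compiled_out = torch.compile(f, backend="aot_eager")(x, w)
print(torch.equal(eager_out, compiled_out))  # expected: True (exact, no tolerance)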


# Owner(s): ["oncall: jit"]

import functools
import itertools
import os
from collections.abc import Sequence
from pathlib import Path
from unittest import skip

import yaml

import torch
import torch._lazy
import torch._lazy.config
import torch._lazy.ir_cache
import torch._lazy.metrics
import torch._lazy.ts_backend
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.jit_utils import JitTestCase


torch._lazy.ts_backend.init()


def get_test_device():
    return "cuda" if "LTC_TS_CUDA" in os.environ else "cpu"


def remove_suffixes(l):
    return [x.split(".")[0] for x in l]


def init_lists():
    path_to_script = Path(os.path.abspath(os.path.dirname(__file__)))
    TS_NATIVE_FUNCTIONS_PATH = (
        path_to_script.parent.parent / "aten/src/ATen/native/ts_native_functions.yaml"
    )
    with open(TS_NATIVE_FUNCTIONS_PATH) as f:
        yaml_ts = yaml.load(f, yaml.SafeLoader)
    LAZY_OPS_LIST = set(
        remove_suffixes(
            itertools.chain(
                yaml_ts["full_codegen"], yaml_ts["supported"], yaml_ts["autograd"]
            )
        )
    )
    HAS_SYMINT_SUFFIX = yaml_ts["symint"]
    FALLBACK_LIST = {"clamp"}
    SKIP_RUNTIME_ERROR_LIST = {
        "index_select",  # Empty output_sizes is not supported
        "clone",  # is clone decomposed?
        # General ASAN failure related to generating bool values.
        # https://github.com/pytorch/pytorch/issues/74519
        # https://github.com/pytorch/pytorch/issues/63034
        "nonzero",  # ASAN failure (paste: P501906539)
        "all",  # ASAN failure
        "any",  # ASAN failure
        "logdet",  # ASAN failure
    }
    SKIP_INCORRECT_RESULTS_LIST = {
        "squeeze",  # Value out of range
        "t",  # Value out of range
        "transpose",  # Value out of range
        "bernoulli",  # incorrect results
        "pow",  # incorrect results
        "addcdiv",  # incorrect results (on CI not locally?)
    }
    # The following ops all show up directly in ts_native_functions.yaml,
    # but run functionalized versions of the composite kernels in core.
    # This means that we don't expect the ops to show up directly in the LTC metrics.
    FUNCTIONAL_DECOMPOSE_LIST = {
        "diag_embed",
        "block_diag",
        "new_empty_strided",
        "narrow_copy",
        "pixel_shuffle",
        "pixel_unshuffle",
        "select_backward",
        "_trilinear",
        "linalg_inv_ex",
        "linalg_pinv.atol_rtol_tensor",
        "logsumexp",
        "svd",
    }
    # For some ops, we don't support all variants. Here we use formatted_name
    # to uniquely identify the variant.
    SKIP_VARIANT_LIST = {"norm_nuc", "min_reduction_with_dim"}

    return (
        LAZY_OPS_LIST,
        FALLBACK_LIST,
        SKIP_RUNTIME_ERROR_LIST,
        SKIP_INCORRECT_RESULTS_LIST,
        FUNCTIONAL_DECOMPOSE_LIST,
        HAS_SYMINT_SUFFIX,
        SKIP_VARIANT_LIST,
    )


(
    LAZY_OPS_LIST,
    FALLBACK_LIST,
    SKIP_RUNTIME_ERROR_LIST,
    SKIP_INCORRECT_RESULTS_LIST,
    FUNCTIONAL_DECOMPOSE_LIST,
    HAS_SYMINT_SUFFIX,
    SKIP_VARIANT_LIST,
) = init_lists()


torch.manual_seed(42)


def clone_move(t):
    dev = "lazy"
    copy_t = t.detach().clone().requires_grad_(True).to(device=dev)
    return copy_t


class TestLazyTensor(JitTestCase):
    @skip("Disable until autograd supports symints")
    def testConvolutionBackward(self):
        test_device = get_test_device()
        inp = torch.rand(1, 3, 128, 128, device=test_device, requires_grad=True)
        inp_copy = clone_move(inp)
        grad = torch.rand(1, 32, 121, 121, device=test_device)  # no requires_grad
        grad_copy = clone_move(grad)
        weight = torch.rand(32, 3, 8, 8, device=test_device, requires_grad=True)
        weight_copy = clone_move(weight)
        bias = torch.rand(32, device=test_device, requires_grad=True)
        bias_copy = clone_move(bias)

        # run eager
        conv_out = torch.nn.functional.conv2d(inp, weight, bias)
        (inp_grad, weight_grad, bias_grad) = torch.autograd.grad(
            [conv_out], [inp, weight, bias], [grad]
        )
        # run lazy
        conv_copy_out = torch.nn.functional.conv2d(inp_copy, weight_copy, bias_copy)
        (inp_copy_grad, weight_copy_grad, bias_copy_grad) = torch.autograd.grad(
            [conv_copy_out], [inp_copy, weight_copy, bias_copy], [grad_copy]
        )
        # check numerics
        torch.testing.assert_close(bias_copy_grad.cpu(), bias_grad.cpu())
        torch.testing.assert_close(weight_copy_grad.cpu(), weight_grad.cpu())
        torch.testing.assert_close(inp_copy_grad.cpu(), inp_grad.cpu())

    def test_view_mark_step_preserved(self):
        test_device = get_test_device()
        inp = torch.rand(4, device=test_device)
        inp_lazy = clone_move(inp)

        def foo(x, *, mark_step):
            y = x.view(2, 2)
            y.add_(1)
            z = x + x  # noqa: F841

            if mark_step:
                torch._lazy.mark_step()

            # y and x should continue to be aliased after the mark_step call.
            y.add_(1)
            return x

        out_ref = foo(inp, mark_step=False)
        out = foo(inp_lazy, mark_step=True)
        # out will have some pending mutations, which will be synced by the .cpu() call.
        torch.testing.assert_close(out_ref.cpu(), out.cpu())

    def test_tensor_ctr(self):
        test_device = get_test_device()
        inp = torch.tensor([[1, 2, 3, 4, 5]], device=test_device)
        inp_lazy = torch.tensor([[1, 2, 3, 4, 5]], device="lazy")

        def foo(x):
            # Calling a view op to ensure that functionalization wrapping occurs.
            return x.view(-1)

        out_ref = foo(inp)
        out = foo(inp_lazy)
        torch.testing.assert_close(out_ref.cpu(), out.cpu())


class TestLazyOpInfo(TestCase):
    @ops(
        [
            op
            for op in op_db
            if op.name in LAZY_OPS_LIST
            and op.name not in SKIP_RUNTIME_ERROR_LIST
            and op.name not in FUNCTIONAL_DECOMPOSE_LIST
            and op.formatted_name not in SKIP_VARIANT_LIST
        ],
        allowed_dtypes=(torch.float,),
    )
    def test_dispatched_to_lazy(self, device, dtype, op):
        def get_name(op):  # noqa: F841
            l = [op.name]
            if op.variant_test_name != "":
                l.append(op.variant_test_name)
            return ".".join(l)

        global HAS_SYMINT_SUFFIX, FALLBACK_LIST
        samples = op.sample_inputs("lazy", dtype, requires_grad=False)
        sample = next(iter(samples))
        args = [sample.input] + list(sample.args)
        kwargs = sample.kwargs
        torch._lazy.mark_step()
        torch._lazy.wait_device_ops()
        torch._lazy.metrics.reset()

        op(*args, **kwargs)

        torch._lazy.mark_step()
        torch._lazy.wait_device_ops()
        prefix = "aten" if op.name in FALLBACK_LIST else "lazy"
        symint_suffix = "_symint" if op.name in HAS_SYMINT_SUFFIX else ""
        metrics = remove_suffixes(torch._lazy.metrics.counter_names())
        cands = [f"{prefix}::{op.name}{symint_suffix}"]
        # check aliases
        for alias in op.aliases:
            cands.append(f"{prefix}::{alias.name}{symint_suffix}")

        self.assertTrue(
            any(c in metrics for c in cands),
            f"none of {cands} found in {metrics}",
        )

    @ops(
        [
            op
            for op in op_db
            if op.name in LAZY_OPS_LIST
            and op.name not in SKIP_RUNTIME_ERROR_LIST | SKIP_INCORRECT_RESULTS_LIST
        ],
        allowed_dtypes=(torch.float,),
    )  # noqa: B950
    def test_correctness(self, device, dtype, op):
        test_device = get_test_device()

        def clone_to_device(input, dev):
            if isinstance(input, torch.Tensor):
                return input.detach().clone().to(device=dev)
            if isinstance(input, Sequence) and not isinstance(input, str):
                return tuple(map(functools.partial(clone_to_device, dev=dev), input))
            return input

        def assert_allclose_rec(t):
            a, b = t
            self.assertEqual(type(a), type(b))
            if isinstance(a, torch.Tensor):
                self.assertTrue(
                    torch.allclose(clone_to_device(a, test_device), b, atol=1e-4)
                )
            if isinstance(a, Sequence):
                # A bare map() is lazy in Python 3 and would silently skip the
                # recursive checks, so iterate explicitly.
                for pair in zip(a, b):
                    assert_allclose_rec(pair)

        samples = op.sample_inputs("lazy", dtype, requires_grad=False)
        for sample in samples:
            # Need to run mark step so that all random ops are computed in the right order
            torch._lazy.mark_step()

            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs
            copy_args = clone_to_device(args, test_device)

            r_exp = op(*copy_args, **kwargs)
            r_actual = op(*args, **kwargs)
            torch._lazy.mark_step()
            assert_allclose_rec((r_actual, r_exp))

    @ops(
        [
            op
            for op in op_db
            if op.name in LAZY_OPS_LIST
            and op.name not in SKIP_RUNTIME_ERROR_LIST | SKIP_INCORRECT_RESULTS_LIST
        ],
        allowed_dtypes=(torch.float,),
    )  # noqa: B950
    def test_correctness_with_reusing_ir(self, device, dtype, op):
        torch._lazy.config.set_reuse_ir(True)
        test_device = get_test_device()

        def clone_to_device(input, dev):
            if isinstance(input, torch.Tensor):
                return input.detach().clone().to(device=dev)
            if isinstance(input, Sequence) and not isinstance(input, str):
                return tuple(map(functools.partial(clone_to_device, dev=dev), input))
            return input

        def assert_allclose_rec(t):
            a, b = t
            self.assertEqual(type(a), type(b))
            if isinstance(a, torch.Tensor):
                self.assertTrue(
                    torch.allclose(clone_to_device(a, test_device), b, atol=1e-4)
                )
            if isinstance(a, Sequence):
                # A bare map() is lazy in Python 3 and would silently skip the
                # recursive checks, so iterate explicitly.
                for pair in zip(a, b):
                    assert_allclose_rec(pair)

        samples = op.sample_inputs("lazy", dtype, requires_grad=False)
        for sample in samples:
            # Need to run mark step so that all random ops are computed in the right order
            torch._lazy.mark_step()

            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs
            copy_args = clone_to_device(args, test_device)

            r_exp = op(*copy_args, **kwargs)
            r_actual = op(*args, **kwargs)
            torch._lazy.mark_step()
            assert_allclose_rec((r_actual, r_exp))

        torch._lazy.ir_cache.reset()
        torch._lazy.config.set_reuse_ir(False)


# TODO: after we move to master, add Lazy as a new Device here:
# https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_device_type.py#L532
instantiate_device_type_tests(TestLazyOpInfo, globals(), only_for="cpu")


class TestLazyDynamicOps(TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        # Setup the dynamic shape mode
        cls.old_ssa_mode = torch._C._lazy._get_symbolic_shape_mode()
        torch._C._lazy._set_symbolic_shape_mode(True)
        return super().setUpClass()

    @classmethod
    def tearDownClass(cls) -> None:
        torch._C._lazy._set_symbolic_shape_mode(cls.old_ssa_mode)
        return super().tearDownClass()

    def test_nonzero_dynamic(self):
        # Test that nonzero gives upper bounds sizes when symbolic shape mode is enabled
        test_device = get_test_device()
        x1 = torch.tensor(
            [[0, 1.0, 2.0], [3.0, 0, 0]], device=test_device, requires_grad=True
        )
        x1_lazy = clone_move(x1)
        x2_lazy = torch.nonzero(x1_lazy)

        # FIXME: Add bindings to get upper bounds
        # self.assertEqual(tuple(x2_lazy.size()), (6, 2))

        # We should still be able to instantiate it and get the actual result
        x2_eager = x2_lazy.cpu()
        self.assertEqual(tuple(x2_eager.size()), (3, 2))

    def test_adaptiveavgpool3d_dynamic(self):
        # Test that adaptive_avg_pool3d gives correct shapes with lazy backend
        img_cpu = torch.zeros([2, 3, 4, 5, 6], device="cpu")
        out_cpu = torch.nn.AdaptiveAvgPool3d(2).to(device="cpu")(img_cpu)

        test_device = get_test_device()
        img_lazy = torch.zeros([2, 3, 4, 5, 6], device=test_device)
        out_lazy = torch.nn.AdaptiveAvgPool3d(2).to(test_device)(img_lazy)

        self.assertEqual(out_cpu.shape, out_lazy.shape)


if __name__ == "__main__":
    run_tests()
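
Note: get_test_device() above selects the eager reference device from the LTC_TS_CUDA environment variable, so these tests run against CPU unless it is set; a plausible invocation for CUDA would be something like LTC_TS_CUDA=1 python test/lazy/test_ts_opinfo.py.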