# pytorch/torch/testing/_internal/hop_db.py
# mypy: ignore-errors
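"""OpInfo database for higher-order operators (HOPs).

Each HOP gets an entry in ``hop_db`` below, which the generic HOP tests
(e.g. ``TestHOP``) use to exercise it under torch.compile, export, autograd,
and friends. See the comment above
``FIXME_hop_that_doesnt_have_opinfo_test_allowlist`` for the policy on adding
new HOPs.
"""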
import functools
import unittest
import torch
from functorch.experimental.control_flow import map
from torch.nn.attention.flex_attention import _create_empty_block_mask, flex_attention
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import onlyCUDA
from torch.testing._internal.common_dtype import all_types_and, custom_types
from torch.testing._internal.opinfo.core import DecorateInfo, OpInfo, SampleInput
from torch._higher_order_ops.invoke_subgraph import mark_compile_region
from torch._higher_order_ops import InvokeQuant, invoke_quant_packed
def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(
make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
)
yield SampleInput(
[make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)],
args=(make_arg(1, low=0.1, high=2), make_arg(1, low=0.1, high=2)),
)
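# inner_f and the map variants below exercise the map HOP at increasing
# nesting depths (simple, nested, triple_nested), all mapping over xs.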
def inner_f(x, y0, y1):
return [x[0].cos().add_(1.0) * y0, (x[1] + y1.sin()).cos_().view(x[1].size())]
def simple_map(xs, y0, y1):
def f(x, y0, y1):
return inner_f(x, y0, y1)
return map(f, xs, y0, y1)
def nested_map(xs, y0, y1):
def f1(xx, y0, y1):
def f2(x, y0, y1):
return inner_f(x, y0, y1)
return map(f2, xx, y0, y1)
return map(f1, xs, y0, y1)
def triple_nested_map(xs, y0, y1):
def f0(xs, y0, y1):
def f1(xx, y0, y1):
def f2(x, y0, y1):
return inner_f(x, y0, y1)
return map(f2, xx, y0, y1)
return map(f1, xs, y0, y1)
return map(f0, xs, y0, y1)
# PLEASE DON'T ADD ANYTHING NEW TO THIS LIST;
# instead, add an OpInfo for your HOP (a commented sketch of what an entry
# looks like follows this list). The OpInfo lets us run automated tests that
# check your HOP works correctly with PyTorch.
#
# Your new HOP may fail some automated testing. That's OK. If you don't
# care about certain features (like torch.export), it's fine to xfail those
# failing tests. It is less fine to xfail a more critical check (like checking
# if torch.compile works with your HOP, or if your HOP has a docstring).
# If you don't know if a test is fine to xfail, please ask.
#
# There are legitimate reasons why something cannot be added to this list
# (e.g. it uses executorch which is not in PyTorch). If that's the case then
# please leave a comment.
FIXME_hop_that_doesnt_have_opinfo_test_allowlist = [
"custom_function_call",
"autograd_function_apply",
"run_and_save_rng_state",
"run_with_rng_state",
"graphsafe_run_with_rng_state",
"out_dtype",
"trace_wrapped",
    "tag_activation_checkpoint",
    "executorch_call_delegate",
    "wrap",
    "wrap_with_set_grad_enabled",
    "auto_functionalized_v2",
    "associative_scan",
    "flat_apply",  # is WIP, doesn't pass any of the tests yet
    "wrap_with_autocast",
    "wrap_activation_checkpoint",
    "run_const_graph",
    "auto_functionalized",
"map", # T183144629
"map_impl",
"with_effects",
"strict_mode",
"_export_tracepoint",
"call_torchbind",
"triton_kernel_wrapper_mutation",
"triton_kernel_wrapper_functional",
"hints_wrapper",
"dynamo_bypassing_wrapper", # TODO(soulitzer)
"foreach_map",
"aoti_call_delegate",
]
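# A minimal, commented sketch of what a new HOP's entry tends to look like.
# "my_new_hop", simple_my_new_hop, and sample_inputs_my_new_hop are
# hypothetical names used only for illustration; copy one of the real entries
# in hop_db below as a starting point.
#
#     OpInfo(
#         name="my_new_hop",
#         variant_test_name="simple",
#         op=simple_my_new_hop,
#         sample_inputs_func=sample_inputs_my_new_hop,
#         dtypes=all_types_and(torch.bool, torch.half),
#         supports_out=False,
#         supports_autograd=False,
#         # xfail only the features you knowingly don't support, e.g.:
#         # skips=(
#         #     DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
#         # ),
#     ),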
torch.library.define(
"testlib::mutating_custom_op",
"(Tensor(a!) x, Tensor(b!) z) -> (Tensor, Tensor, Tensor)",
tags=torch.Tag.pt2_compliant_tag,
)
@torch.library.impl("testlib::mutating_custom_op", "cpu")
def foo_impl_cpu(x, z):
x.add_(5)
z.add_(5)
return x, z, x + z
@torch.library.impl("testlib::mutating_custom_op", "cuda")
def foo_impl_cuda(x, z):
x.add_(5)
z.add_(5)
return x, z, x + z
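# Fake/meta kernel: mirrors the real op's output structure without performing
# the in-place mutations, so the op can be traced with fake tensors.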
@torch.library.register_fake("testlib::mutating_custom_op")
def foo_impl_abstract(x, z):
return x, z, x + z
def sample_inputs_cond(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(
make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
)
yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2))
def simple_cond(x):
return torch.cond(x.sum() > 2, lambda x: (x.cos(),), lambda x: (x.sin(),), [x])
def sample_inputs_invoke_subgraph(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(
make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
)
yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2))
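# mark_compile_region makes calls to this function trace as the
# invoke_subgraph HOP under torch.compile.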
@mark_compile_region
def fn_for_invoke_subgraph(x):
return torch.sin(x)
def simple_invoke_subgraph(x):
return fn_for_invoke_subgraph(x)
def sample_inputs_auto_functionalize(opinfo, device, dtype, requires_grad, **kwargs):
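    # The custom op mutates its inputs in place, so samples are built with
    # requires_grad=False regardless of the requested requires_grad.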
make_arg = functools.partial(
make_tensor, device=device, dtype=dtype, requires_grad=False
)
yield SampleInput(
make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)
)
def simple_auto_functionalize(x, z):
return torch.ops.testlib.mutating_custom_op(x, z)
def sample_inputs_flex_attention(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(
make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
)
def score_mod(score, b, h, m, n):
return score + h
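    # q/k/v are (batch, heads, seq_len, head_dim) = (2, 2, 128, 8); the block
    # mask comes from _create_empty_block_mask, the default (no masking) mask.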
q, k, v = (make_arg(2, 2, 128, 8, low=0.1, high=2) for _ in range(3))
block_mask = _create_empty_block_mask(q, k)
yield SampleInput(q, k, v, score_mod, block_mask)
def sample_inputs_while_loop(opinfo, device, dtype, requires_grad, **kwargs):
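    # while_loop does not support autograd, so samples are built with
    # requires_grad=False.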
make_arg = functools.partial(
make_tensor, device=device, dtype=dtype, requires_grad=False
)
yield SampleInput(
torch.tensor(3),
make_arg(2, 3, 4, low=0.1, high=2),
)
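# Decrements iter_t and applies cos to x on each iteration until iter_t
# reaches 0.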
def simple_while_loop(iter_t, x):
def cond_fn(iter_t, x):
return iter_t > 0
def body_fn(iter_t, x):
return iter_t - 1, x.cos()
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (iter_t, x))
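# Same loop shape, but through while_loop_stack_output, which (as the name
# suggests) also stacks each iteration's outputs; the trailing tuple() is the
# empty set of additional inputs.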
def simple_while_loop_stack_output(iter_t, x):
def cond_fn(iter_t, x):
return iter_t > 0
def body_fn(iter_t, x):
return iter_t - 1, x.cos()
return torch._higher_order_ops.while_loop_stack_output(cond_fn, body_fn, (iter_t, x), tuple())
def sample_inputs_local_map_hop(opinfo, device, dtype, requires_grad, **kwargs):
# TODO: once HOPs support DTensor inputs, we should also test DTensors
make_arg = functools.partial(
make_tensor, device=device, dtype=dtype, requires_grad=False
)
yield SampleInput(
make_arg(2, 3, 4, low=0.1, high=2),
make_arg(2, 3, 4, low=0.1, high=2),
)
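# Symbolically traces a simple elementwise body, attaches the local_map
# placement metadata the HOP expects via gm.meta, and calls the HOP directly.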
def simple_local_map_hop(inp1, inp2):
def body_gm(inp1, inp2):
return inp1.cos() + inp2.sin()
gm = torch.fx.symbolic_trace(body_gm)
assert torch.distributed.is_available()
from torch.distributed.tensor.placement_types import Replicate
gm.meta["local_map_kwargs"] = {
"in_placements": (Replicate(), Replicate(), Replicate()),
"out_placements": ((Replicate(), Replicate(), Replicate()),)
}
# TODO: Dynamo would rewrite this op differently
return torch._higher_order_ops.local_map_hop(gm, inp1, inp2)
def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(
make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
)
yield SampleInput(
make_arg(2, 2, low=0.1, high=2),
make_arg(2, 2, 2, low=0.1, high=2),
)
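# combine_fn returns (next_carry, per-step output): the next carry is
# carry @ x + x, and the emitted output is a clone of the previous carry.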
def simple_scan(init, xs):
def combine_fn(carry, x):
result = carry @ x + x
return result, carry.clone()
return torch._higher_order_ops.scan(combine_fn, init, xs)
quant_tracer = InvokeQuant()
def simple_invoke_quant(x):
def fn(x, y):
return (torch.sin(x) * y,)
return quant_tracer(fn, x, x)[0] * 2.
def simple_invoke_quant_packed(x):
def fn(x):
return (torch.sin(x),)
return invoke_quant_packed(fn, x)[0] * 2.
hop_db = [
OpInfo(
name="scan",
variant_test_name="simple",
op=simple_scan,
sample_inputs_func=sample_inputs_scan,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
supports_autograd=False,
# "torch.compile with aot_autograd does not currently support double backward."
supports_gradgrad=False,
),
OpInfo(
name="invoke_subgraph",
variant_test_name="simple",
op=simple_invoke_subgraph,
sample_inputs_func=sample_inputs_invoke_subgraph,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
supports_autograd=True,
# "torch.compile with aot_autograd does not currently support double backward."
supports_gradgrad=False,
),
OpInfo(
name="map",
variant_test_name="simple",
op=simple_map,
sample_inputs_func=sample_inputs_map,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
),
OpInfo(
name="map",
variant_test_name="nested",
op=nested_map,
sample_inputs_func=sample_inputs_map,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
),
OpInfo(
name="map",
variant_test_name="triple_nested",
op=triple_nested_map,
sample_inputs_func=sample_inputs_map,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
),
OpInfo(
name="cond",
variant_test_name="simple",
op=simple_cond,
sample_inputs_func=sample_inputs_cond,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
supports_autograd=True,
# "torch.compile with aot_autograd does not currently support double backward."
supports_gradgrad=False,
),
OpInfo(
name="invoke_quant",
variant_test_name="simple",
op=simple_invoke_quant,
sample_inputs_func=sample_inputs_invoke_subgraph,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
supports_autograd=True,
        skips=(
DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
DecorateInfo(
unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
),
DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
),
# "torch.compile with aot_autograd does not currently support double backward."
supports_gradgrad=False,
),
OpInfo(
name="invoke_quant_packed",
variant_test_name="simple",
op=simple_invoke_quant_packed,
sample_inputs_func=sample_inputs_invoke_subgraph,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
supports_autograd=True,
# "torch.compile with aot_autograd does not currently support double backward."
supports_gradgrad=False,
),
OpInfo(
name="while_loop",
variant_test_name="simple",
op=simple_while_loop,
sample_inputs_func=sample_inputs_while_loop,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
supports_autograd=False,
),
OpInfo(
name="while_loop_stack_output",
variant_test_name="simple",
op=simple_while_loop_stack_output,
sample_inputs_func=sample_inputs_while_loop,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
supports_autograd=False,
),
OpInfo(
name="auto_functionalize",
variant_test_name="simple",
op=simple_auto_functionalize,
sample_inputs_func=sample_inputs_auto_functionalize,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
supports_autograd=False,
),
OpInfo(
name="flex_attention",
variant_test_name="simple",
op=flex_attention,
sample_inputs_func=sample_inputs_flex_attention,
dtypes=custom_types(torch.float16, torch.float32),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
skips=(
DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
DecorateInfo(
unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
),
DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
),
decorators=[onlyCUDA],
),
OpInfo(
name="flex_attention_backward",
variant_test_name="simple",
op=flex_attention,
sample_inputs_func=sample_inputs_flex_attention,
dtypes=custom_types(torch.float16, torch.float32),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
skips=(
DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
DecorateInfo(
unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
),
DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
),
decorators=[onlyCUDA],
),
OpInfo(
name="local_map_hop",
variant_test_name="simple",
op=simple_local_map_hop,
sample_inputs_func=sample_inputs_local_map_hop,
dtypes=custom_types(torch.float16, torch.float32),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
skips=(
DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
DecorateInfo(
unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
),
DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
),
        decorators=[
            onlyCUDA,
            unittest.skipIf(
                not torch.distributed.is_available(), "requires distributed build"
            ),
        ],
),
]
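# A minimal sketch of how this database is typically consumed from a test
# file via the standard OpInfo machinery. The class and test names here are
# illustrative only, not tests that exist in this repo:
#
#     from torch.testing._internal.common_device_type import (
#         instantiate_device_type_tests,
#         ops,
#     )
#     from torch.testing._internal.common_utils import TestCase, run_tests
#
#     class TestMyHOP(TestCase):
#         @ops(hop_db, allowed_dtypes=(torch.float,))
#         def test_eager_vs_compile(self, device, dtype, op):
#             for sample in op.sample_inputs(device, dtype):
#                 expected = op.op(sample.input, *sample.args, **sample.kwargs)
#                 actual = torch.compile(op.op)(
#                     sample.input, *sample.args, **sample.kwargs
#                 )
#                 # compare expected/actual with self.assertEqual(...)
#
#     instantiate_device_type_tests(TestMyHOP, globals())
#
#     if __name__ == "__main__":
#         run_tests()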