# mypy: ignore-errors
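"""OpInfo database for higher-order operators (HOPs).

This module defines sample-input generators and simple wrapper functions for
several HOPs (map, cond, while_loop, scan, invoke_subgraph, flex_attention,
auto_functionalize, ...) and collects OpInfo entries for them in ``hop_db`` so
the automated HOP tests can exercise them. HOPs that do not yet have an OpInfo
are listed in ``FIXME_hop_that_doesnt_have_opinfo_test_allowlist``.
"""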

import functools
import unittest

import torch
from functorch.experimental.control_flow import map
from torch.nn.attention.flex_attention import _create_empty_block_mask, flex_attention
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import onlyCUDA
from torch.testing._internal.common_dtype import all_types_and, custom_types
from torch.testing._internal.opinfo.core import DecorateInfo, OpInfo, SampleInput
from torch._higher_order_ops.invoke_subgraph import mark_compile_region
from torch._higher_order_ops import InvokeQuant, invoke_quant_packed


def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    yield SampleInput(
        [make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)],
        args=(make_arg(1, low=0.1, high=2), make_arg(1, low=0.1, high=2)),
    )
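

# Shared body for the map variants below: it mixes out-of-place ops with
# in-place mutation (add_, cos_) and a view, so the map tests exercise
# mutation and aliasing inside the mapped function.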
def inner_f(x, y0, y1):
    return [x[0].cos().add_(1.0) * y0, (x[1] + y1.sin()).cos_().view(x[1].size())]


def simple_map(xs, y0, y1):
    def f(x, y0, y1):
        return inner_f(x, y0, y1)

    return map(f, xs, y0, y1)


def nested_map(xs, y0, y1):
    def f1(xx, y0, y1):
        def f2(x, y0, y1):
            return inner_f(x, y0, y1)

        return map(f2, xx, y0, y1)

    return map(f1, xs, y0, y1)


def triple_nested_map(xs, y0, y1):
    def f0(xs, y0, y1):
        def f1(xx, y0, y1):
            def f2(x, y0, y1):
                return inner_f(x, y0, y1)

            return map(f2, xx, y0, y1)

        return map(f1, xs, y0, y1)

    return map(f0, xs, y0, y1)


# PLEASE DON'T ADD ANYTHING NEW TO THIS LIST,
# and do add an OpInfo for your HOP.
# The OpInfo lets us do automated testing for the HOP to check that
# your HOP will work correctly with PyTorch!
#
# Your new HOP may fail some automated testing. That's OK. If you don't
# care about certain features (like torch.export), it's fine to xfail those
# failing tests. It is less fine to xfail a more critical check (like checking
# if torch.compile works with your HOP, or if your HOP has a docstring).
# If you don't know if a test is fine to xfail, please ask.
#
# There are legitimate reasons why something cannot be added to this list
# (e.g. it uses executorch which is not in PyTorch). If that's the case then
# please leave a comment.
FIXME_hop_that_doesnt_have_opinfo_test_allowlist = [
    "custom_function_call",
    "autograd_function_apply",
    "run_and_save_rng_state",
    "run_with_rng_state",
    "graphsafe_run_with_rng_state",
    "out_dtype",
    "trace_wrapped",
    "tag_activation_checkpoint",
    "executorch_call_delegate",
    "wrap",
    "wrap_with_set_grad_enabled",
    "auto_functionalized_v2",
    "associative_scan",
    "flat_apply",  # is WIP, doesn't pass any of the tests yet
    "wrap_with_autocast",
    "wrap_activation_checkpoint",
    "run_const_graph",
    "auto_functionalized",
    "map",  # T183144629
    "map_impl",
    "with_effects",
    "strict_mode",
    "_export_tracepoint",
    "call_torchbind",
    "triton_kernel_wrapper_mutation",
    "triton_kernel_wrapper_functional",
    "hints_wrapper",
    "dynamo_bypassing_wrapper",  # TODO(soulitzer)
    "foreach_map",
    "aoti_call_delegate",
]
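
# A minimal sketch of what an OpInfo entry for a new HOP might look like
# ("my_hop", simple_my_hop, and sample_inputs_my_hop are hypothetical names);
# see the real entries in hop_db below for complete examples:
#
#     OpInfo(
#         name="my_hop",
#         variant_test_name="simple",
#         op=simple_my_hop,
#         sample_inputs_func=sample_inputs_my_hop,
#         dtypes=all_types_and(torch.bool, torch.half),
#         supports_out=False,
#         # xfail only the features you knowingly don't support, e.g. export:
#         skips=(
#             DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
#         ),
#     )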
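
# A mutating custom op (tagged pt2_compliant) used by the auto_functionalize
# sample below; it mutates both of its inputs in place and returns them along
# with a freshly computed tensor.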
torch.library.define(
    "testlib::mutating_custom_op",
    "(Tensor(a!) x, Tensor(b!) z) -> (Tensor, Tensor, Tensor)",
    tags=torch.Tag.pt2_compliant_tag,
)


@torch.library.impl("testlib::mutating_custom_op", "cpu")
def foo_impl_cpu(x, z):
    x.add_(5)
    z.add_(5)
    return x, z, x + z


@torch.library.impl("testlib::mutating_custom_op", "cuda")
def foo_impl_cuda(x, z):
    x.add_(5)
    z.add_(5)
    return x, z, x + z


@torch.library.register_fake("testlib::mutating_custom_op")
def foo_impl_abstract(x, z):
    return x, z, x + z


def sample_inputs_cond(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2))
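

# torch.cond(pred, true_fn, false_fn, operands): both branches must return
# outputs with matching structure, here a one-element tuple.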
def simple_cond(x):
    return torch.cond(x.sum() > 2, lambda x: (x.cos(),), lambda x: (x.sin(),), [x])


def sample_inputs_invoke_subgraph(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2))
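

# Calls to a function decorated with mark_compile_region are traced as
# invoke_subgraph HOPs, which is what the invoke_subgraph OpInfo below
# exercises.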
@mark_compile_region
def fn_for_invoke_subgraph(x):
    return torch.sin(x)


def simple_invoke_subgraph(x):
    return fn_for_invoke_subgraph(x)


def sample_inputs_auto_functionalize(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=False
    )
    yield SampleInput(
        make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)
    )


def simple_auto_functionalize(x, z):
    return torch.ops.testlib.mutating_custom_op(x, z)
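

# For flex_attention, score_mod receives (score, batch_idx, head_idx, q_idx,
# kv_idx); the one below simply biases each score by the head index.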
def sample_inputs_flex_attention(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )

    def score_mod(score, b, h, m, n):
        return score + h

    q, k, v = (make_arg(2, 2, 128, 8, low=0.1, high=2) for _ in range(3))
    block_mask = _create_empty_block_mask(q, k)
    yield SampleInput(q, k, v, score_mod, block_mask)


def sample_inputs_while_loop(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=False
    )
    yield SampleInput(
        torch.tensor(3),
        make_arg(2, 3, 4, low=0.1, high=2),
    )
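

# while_loop(cond_fn, body_fn, carried_inputs): cond_fn must return a boolean
# scalar tensor and body_fn must return a new carry with the same structure as
# carried_inputs. while_loop_stack_output is the variant that additionally
# stacks outputs across iterations.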
def simple_while_loop(iter_t, x):
    def cond_fn(iter_t, x):
        return iter_t > 0

    def body_fn(iter_t, x):
        return iter_t - 1, x.cos()

    return torch._higher_order_ops.while_loop(cond_fn, body_fn, (iter_t, x))


def simple_while_loop_stack_output(iter_t, x):
    def cond_fn(iter_t, x):
        return iter_t > 0

    def body_fn(iter_t, x):
        return iter_t - 1, x.cos()

    return torch._higher_order_ops.while_loop_stack_output(
        cond_fn, body_fn, (iter_t, x), tuple()
    )


def sample_inputs_local_map_hop(opinfo, device, dtype, requires_grad, **kwargs):
    # TODO: once HOPs support DTensor inputs, we should also test DTensors
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=False
    )
    yield SampleInput(
        make_arg(2, 3, 4, low=0.1, high=2),
        make_arg(2, 3, 4, low=0.1, high=2),
    )
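

# local_map_hop takes an fx.GraphModule whose meta carries the local_map
# placement kwargs; the body below is symbolically traced and annotated with
# Replicate placements, so this sample requires a distributed build.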
def simple_local_map_hop(inp1, inp2):
    def body_gm(inp1, inp2):
        return inp1.cos() + inp2.sin()

    gm = torch.fx.symbolic_trace(body_gm)

    assert torch.distributed.is_available()
    from torch.distributed.tensor.placement_types import Replicate

    gm.meta["local_map_kwargs"] = {
        "in_placements": (Replicate(), Replicate(), Replicate()),
        "out_placements": ((Replicate(), Replicate(), Replicate()),),
    }

    # TODO: Dynamo would rewrite this op differently
    return torch._higher_order_ops.local_map_hop(gm, inp1, inp2)


def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    yield SampleInput(
        make_arg(2, 2, low=0.1, high=2),
        make_arg(2, 2, 2, low=0.1, high=2),
    )
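

# scan(combine_fn, init, xs): combine_fn maps (carry, x) to (next_carry, y);
# scan returns the final carry together with the stacked ys.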
def simple_scan(init, xs):
    def combine_fn(carry, x):
        result = carry @ x + x
        return result, carry.clone()

    return torch._higher_order_ops.scan(combine_fn, init, xs)


quant_tracer = InvokeQuant()
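

# An InvokeQuant instance is called as quant_tracer(fn, *args) to run fn
# through the invoke_quant HOP; invoke_quant_packed is the packed counterpart
# used below.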
def simple_invoke_quant(x):
    def fn(x, y):
        return (torch.sin(x) * y,)

    return quant_tracer(fn, x, x)[0] * 2.0


def simple_invoke_quant_packed(x):
    def fn(x):
        return (torch.sin(x),)

    return invoke_quant_packed(fn, x)[0] * 2.0
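

# Database of OpInfo entries for HOPs, consumed by the automated HOP tests
# (e.g. the TestHOP suites referenced in the skips below).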
hop_db = [
    OpInfo(
        name="scan",
        variant_test_name="simple",
        op=simple_scan,
        sample_inputs_func=sample_inputs_scan,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        supports_autograd=False,
        # "torch.compile with aot_autograd does not currently support double backward."
        supports_gradgrad=False,
    ),
    OpInfo(
        name="invoke_subgraph",
        variant_test_name="simple",
        op=simple_invoke_subgraph,
        sample_inputs_func=sample_inputs_invoke_subgraph,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        supports_autograd=True,
        # "torch.compile with aot_autograd does not currently support double backward."
        supports_gradgrad=False,
    ),
    OpInfo(
        name="map",
        variant_test_name="simple",
        op=simple_map,
        sample_inputs_func=sample_inputs_map,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
    ),
    OpInfo(
        name="map",
        variant_test_name="nested",
        op=nested_map,
        sample_inputs_func=sample_inputs_map,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
    ),
    OpInfo(
        name="map",
        variant_test_name="triple_nested",
        op=triple_nested_map,
        sample_inputs_func=sample_inputs_map,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
    ),
    OpInfo(
        name="cond",
        variant_test_name="simple",
        op=simple_cond,
        sample_inputs_func=sample_inputs_cond,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        supports_autograd=True,
        # "torch.compile with aot_autograd does not currently support double backward."
        supports_gradgrad=False,
    ),
    OpInfo(
        name="invoke_quant",
        variant_test_name="simple",
        op=simple_invoke_quant,
        sample_inputs_func=sample_inputs_invoke_subgraph,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        supports_autograd=True,
        skips=(
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
            DecorateInfo(
                unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
            ),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
        ),
        # "torch.compile with aot_autograd does not currently support double backward."
        supports_gradgrad=False,
    ),
    OpInfo(
        name="invoke_quant_packed",
        variant_test_name="simple",
        op=simple_invoke_quant_packed,
        sample_inputs_func=sample_inputs_invoke_subgraph,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        supports_autograd=True,
        # "torch.compile with aot_autograd does not currently support double backward."
        supports_gradgrad=False,
    ),
    OpInfo(
        name="while_loop",
        variant_test_name="simple",
        op=simple_while_loop,
        sample_inputs_func=sample_inputs_while_loop,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        supports_autograd=False,
    ),
    OpInfo(
        name="while_loop_stack_output",
        variant_test_name="simple",
        op=simple_while_loop_stack_output,
        sample_inputs_func=sample_inputs_while_loop,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        supports_autograd=False,
    ),
    OpInfo(
        name="auto_functionalize",
        variant_test_name="simple",
        op=simple_auto_functionalize,
        sample_inputs_func=sample_inputs_auto_functionalize,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        supports_autograd=False,
    ),
    OpInfo(
        name="flex_attention",
        variant_test_name="simple",
        op=flex_attention,
        sample_inputs_func=sample_inputs_flex_attention,
        dtypes=custom_types(torch.float16, torch.float32),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        skips=(
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
            DecorateInfo(
                unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
            ),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
        ),
        decorators=[onlyCUDA],
    ),
    OpInfo(
        name="flex_attention_backward",
        variant_test_name="simple",
        op=flex_attention,
        sample_inputs_func=sample_inputs_flex_attention,
        dtypes=custom_types(torch.float16, torch.float32),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        skips=(
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
            DecorateInfo(
                unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
            ),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
        ),
        decorators=[onlyCUDA],
    ),
    OpInfo(
        name="local_map_hop",
        variant_test_name="simple",
        op=simple_local_map_hop,
        sample_inputs_func=sample_inputs_local_map_hop,
        dtypes=custom_types(torch.float16, torch.float32),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        skips=(
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
            DecorateInfo(
                unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
            ),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
        ),
        decorators=[
            onlyCUDA,
            unittest.skipIf(
                not torch.distributed.is_available(), "requires distributed build"
            ),
        ],
    ),
]