[HOO] add hints_wrapper to support passing context hints (#132860)

Fixes #126393

The implementation is based on the feedback in https://github.com/pytorch/pytorch/pull/121639#issuecomment-2223948842.

Hints are passed via the hints kwarg of the hints_wrapper op; nested hints_wrapper calls are also supported.
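
A minimal usage sketch, distilled from the tests in this PR (the body function and the "inner_body" hint key are illustrative; any dict with str keys and int/float/bool/str values is accepted):

import torch
from torch._higher_order_ops.hints_wrap import hints_wrapper

def fn(x, y):
    def body_fn(x, y):
        return torch.relu(x) + y

    # the hints dict is recorded as the `hints` kwarg of the captured
    # torch.ops.higher_order.hints_wrapper node
    return hints_wrapper(body_fn, (x, y), {}, hints={"inner_body": True})

out = torch.compile(fn, backend="eager")(torch.randn(2, 4), torch.ones(4))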

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132860
Approved by: https://github.com/ydwu4, https://github.com/zou3519
Wuxun Zhang
2024-08-26 18:21:20 +00:00
committed by PyTorch MergeBot
parent 1ccc8f0200
commit 1d231ff8ba
6 changed files with 435 additions and 1 deletion

@@ -23,6 +23,7 @@ from torch._dynamo.testing import (
    normalize_gm,
)
from torch._dynamo.utils import counters, ifdynstaticdefault
from torch._higher_order_ops.hints_wrap import hints_wrapper
from torch._higher_order_ops.wrap import wrap
from torch.testing._internal.common_utils import (
    munge_exc,
@@ -2502,6 +2503,139 @@ def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytre
        ):
            torch.compile(fn, backend="eager")(pred, pytree_in)

    def test_hints_wrapper(self):
        def ref_fn(x, y):
            x = x + y
            x = torch.relu(x)
            x = x + y
            return torch.abs(x)

        def fn_with_hints(x, y):
            x = x + y

            def inner_body_fn(x, y):
                x = torch.relu(x)
                x = x + y
                return x

            def outer_body_fn(x, y):
                x = hints_wrapper(inner_body_fn, (x, y), {}, hints={"inner_body": True})
                x = torch.abs(x)
                return x

            res = hints_wrapper(outer_body_fn, (x, y), {}, hints={"outer_body": True})
            return res

        backend = EagerAndRecordGraphs()
        cnt = CompileCounterWithBackend(backend)

        x = torch.randn(2, 4)
        y = torch.ones(4)

        eager_res = fn_with_hints(x, y)
        compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y)
        ref_res = ref_fn(x, y)
        self.assertEqual(eager_res, ref_res)
        self.assertEqual(compiled_res, ref_res)
        self.assertEqual(len(cnt.graphs), 1)

        # Dynamic shapes produce a slightly different graph.
        if check_dynamic_shape_capture():
            return

        graph = backend.graphs[0]
        self.assertExpectedInline(
            normalize_gm(graph.print_readable(print_output=False)),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[2, 4]", L_y_: "f32[4]"):
        l_x_ = L_x_
        l_y_ = L_y_

        x: "f32[2, 4]" = l_x_ + l_y_; l_x_ = None

        hints_wrapper_body_1 = self.hints_wrapper_body_1
        hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_1, (x, l_y_), {}, hints = {'outer_body': True}); hints_wrapper_body_1 = x = l_y_ = None
        res: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
        return (res,)

    class hints_wrapper_body_1(torch.nn.Module):
        def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
            hints_wrapper_body_0 = self.hints_wrapper_body_0
            hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_0, (x, l_y_), {}, hints = {'inner_body': True}); hints_wrapper_body_0 = x = l_y_ = None
            x_1: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
            x_2: "f32[2, 4]" = torch.abs(x_1); x_1 = None
            return (x_2,)

        class hints_wrapper_body_0(torch.nn.Module):
            def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
                x_1: "f32[2, 4]" = torch.relu(x); x = None
                x_2: "f32[2, 4]" = x_1 + l_y_; x_1 = l_y_ = None
                return (x_2,)
""",
        )
    def test_hints_wrapper_no_hints(self):
        def fn_with_hints(x, y):
            def outer_body_fn(x, y):
                x = torch.add(x, y)
                return x

            res = hints_wrapper(outer_body_fn, (x, y), {})
            return res

        backend = EagerAndRecordGraphs()
        cnt = CompileCounterWithBackend(backend)

        x = torch.randn(2, 4)
        y = torch.ones(4)

        msg = "hints_wrapper - key hints not provided"
        with self.assertRaisesRegex(RuntimeError, msg):
            compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y)
    def test_hints_wrapper_incorrect_type(self):
        def fn_with_hints(x, y):
            def outer_body_fn(x, y):
                x = torch.add(x, y)
                return x

            res = hints_wrapper(outer_body_fn, (x, y), {}, hints={"test": (True,)})
            return res

        backend = EagerAndRecordGraphs()
        cnt = CompileCounterWithBackend(backend)

        x = torch.randn(2, 4)
        y = torch.ones(4)

        msg = r"hints must be a dict containing int, float, bool or str value,"
        with self.assertRaisesRegex(RuntimeError, msg):
            compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y)
    def test_hints_wrapper_pytree_inputs(self):
        def fn_with_hints(x, y):
            def outer_body_fn(x):
                res = torch.add(x[0], x[1]["test"])
                return res

            res = hints_wrapper(
                outer_body_fn, ((x, {"test": y}),), {}, hints={"test": True}
            )
            return res

        backend = EagerAndRecordGraphs()
        cnt = CompileCounterWithBackend(backend)

        x = torch.randn(2, 4)
        y = torch.ones(4)

        msg = r"args must be a tuple of tensors, ints, floats, or bools,"
        with self.assertRaisesRegex(RuntimeError, msg):
            fn_with_hints(x, y)


class HigherOrderOpVmapGuardTests(LoggingTestCase):
    @make_logging_test(recompiles=True)

@@ -20,6 +20,7 @@ from functorch.experimental.control_flow import cond, map
from torch import Tensor
from torch._decomp import get_decompositions
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import normalize_gm
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
from torch._export.utils import (
    get_buffer,
@@ -28,6 +29,7 @@ from torch._export.utils import (
    is_param,
    register_dataclass_as_pytree_node,
)
from torch._higher_order_ops.hints_wrap import hints_wrapper
from torch._inductor.compile_fx import split_const_gm
from torch._subclasses import FakeTensorMode
from torch.export import Dim, export, unflatten
@@ -7028,6 +7030,69 @@ def forward(self, x, y):
            ],
        )

    @testing.expectedFailureNonStrict
    @testing.expectedFailureTrainingIRToRunDecompNonStrict  # unbacked symint not tracked?
    @testing.expectedFailureSerDer  # T195866111
    def test_hints_wrapper(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                x = x + y

                def inner_body_fn(x, y):
                    x = torch.relu(x)
                    x = x + y
                    return x

                def outer_body_fn(x, y):
                    x = hints_wrapper(
                        inner_body_fn, (x, y), {}, hints={"inner_body": True}
                    )
                    x = torch.abs(x)
                    return x

                res = hints_wrapper(
                    outer_body_fn, (x, y), {}, hints={"outer_body": True}
                )
                return res

        x = torch.randn(2, 4)
        y = torch.ones(4)

        ep = export(M(), (x, y))
        export_res = ep.module()(x, y)
        ref_res = M()(x, y)
        self.assertEqual(export_res, ref_res)
        self.assertExpectedInline(
            normalize_gm(ep.graph_module.print_readable(print_output=False)),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, x: "f32[2, 4]", y: "f32[4]"):
        add: "f32[2, 4]" = torch.ops.aten.add.Tensor(x, y); x = None

        hints_wrapper_body_graph_0 = self.hints_wrapper_body_graph_0
        hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_graph_0, (add, y), {}, hints = {'outer_body': True}); hints_wrapper_body_graph_0 = add = y = None
        getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
        return (getitem,)

    class hints_wrapper_body_graph_0(torch.nn.Module):
        def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"):
            hints_wrapper_body_graph_0 = self.hints_wrapper_body_graph_0
            hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_graph_0, (arg0_1, arg1_1), {}, hints = {'inner_body': True}); hints_wrapper_body_graph_0 = arg0_1 = arg1_1 = None
            getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
            abs_1: "f32[2, 4]" = torch.ops.aten.abs.default(getitem); getitem = None
            return (abs_1,)

        class hints_wrapper_body_graph_0(torch.nn.Module):
            def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"):
                relu: "f32[2, 4]" = torch.ops.aten.relu.default(arg0_1); arg0_1 = None
                add: "f32[2, 4]" = torch.ops.aten.add.Tensor(relu, arg1_1); relu = arg1_1 = None
                return (add,)
""",
        )

@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestOneOffModelExportResult(TestCase):

@@ -24,7 +24,12 @@ from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils import _pytree as pytree
from .. import variables
from ..exc import UncapturedHigherOrderOpError, unimplemented, Unsupported
from ..exc import (
    IncorrectUsage,
    UncapturedHigherOrderOpError,
    unimplemented,
    Unsupported,
)
from ..source import AttrSource
from ..utils import proxy_args_kwargs
from .dicts import ConstDictVariable
@@ -578,6 +583,8 @@ class TorchHigherOrderOperatorVariable(VariableTracker):
            return OutDtypeHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "wrap":
            return WrapHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "hints_wrapper":
            return HintsWrapperHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "flex_attention":
            return FlexAttentionHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ in (
@@ -1431,6 +1438,80 @@ class WrapWithSetGradEnabledHigherOrderVariable(TorchHigherOrderOperatorVariable
        )

class HintsWrapperHigherOrderVariable(TorchHigherOrderOperatorVariable):
    @raise_hard_error_if_graph_break(
        reason="Hints_wrapper doesn't work unless it is captured completely with torch.compile."
    )
    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        _check_supported_callable_arg(tx, args[0], "body_fn")

        # inputs
        if len(args) != 3:
            unimplemented(
                f"Expected 3 arguments but got {len(args)}.\n"
                f"Usage: hints_wrapper(body_fn, args, kwargs, hints).\n"
                f"kwargs are required to be provided explicitly."
            )

        if not isinstance(args[1], (ListVariable, TupleVariable)):
            unimplemented(
                f"Expected a tuple but got {args[1].python_type()}",
            )
        operands = args[1].unpack_var_sequence(tx)

        if not isinstance(args[2], ConstDictVariable):
            unimplemented(
                f"Expected a dict but got {args[2].python_type()}",
            )

        if "hints" not in kwargs:
            raise IncorrectUsage("hints_wrapper - key hints not provided")

        (
            (body_r, treespec),
            body_graph,
            body_lifted_freevars,
        ) = speculate_subgraph(
            tx,
            args[0],  # function
            operands,
            args[2].as_python_constant(),
            "hints_wrapper",
            source_target=self.value,
            should_flatten_outputs=True,
        )

        body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph)
        body_name = add_subgraph(
            tx,
            "hints_wrapper_body",
            body_gmod,
        )

        body_node = make_attr(tx, body_name)

        # Since we call `speculate_subgraph` with `set_subgraph_inputs="automatic"`,
        # all the arguments are lifted.
        lifted_args = tuple(arg for arg in body_lifted_freevars.keys())
        p_args = (body_node, lifted_args, {})

        p_kwargs = {}
        # add hints into p_kwargs
        p_kwargs["hints"] = kwargs["hints"].as_python_constant()

        flat_example_value = pytree.tree_map_only(
            torch.fx.Proxy,
            lambda a: a.node.meta["example_value"],
            body_r.as_proxy(),
        )

        return _call_function_and_unflatten_output(
            tx, self.value, p_args, p_kwargs, flat_example_value, treespec
        )


class OutDtypeHigherOrderVariable(TorchHigherOrderOperatorVariable):
    def call_function(
        self,

@@ -3,6 +3,7 @@ from torch._higher_order_ops.flex_attention import (
    flex_attention,
    flex_attention_backward,
)
from torch._higher_order_ops.hints_wrap import hints_wrapper
from torch._higher_order_ops.while_loop import while_loop
@@ -11,4 +12,5 @@ __all__ = [
    "while_loop",
    "flex_attention",
    "flex_attention_backward",
    "hints_wrapper",
]

@@ -0,0 +1,151 @@
# mypy: allow-untyped-defs
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import (
    _has_potential_branch_input_alias,
    _has_potential_branch_input_mutation,
    autograd_not_implemented,
    reenter_make_fx,
    unique_graph_id,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree

# used for wrapping a function/op with context hints
class HintsWrapper(HigherOrderOperator):
    def __init__(self):
        super().__init__("hints_wrapper")

    def __call__(self, body_fn, args, kwargs, hints):
        r"""
        Call implementation of hints_wrapper

        Args:
            body_fn (Callable): A callable function that is within the scope
                that is being traced.
            args (Tuple of torch.Tensor/int/float/bool): A tuple of inputs to
                body_fn.
            kwargs (dict): Keyword arguments to the body_fn.
            hints (dict): A dict of context hints which could be passed to
                backend compiler.
        """
        if not isinstance(args, tuple):
            raise RuntimeError(f"args must be a tuple, got {type(args)}")

        if not all(isinstance(t, (torch.Tensor, int, float, bool)) for t in args):
            raise RuntimeError(
                "args must be a tuple of tensors, ints, floats, or bools, got "
                f"{args}"
            )

        if not isinstance(kwargs, dict):
            raise RuntimeError(f"kwargs must be a dict, got {type(kwargs)}")

        if len(kwargs) > 0:
            raise RuntimeError(
                f"kwargs except for hints are not supported, got {kwargs}"
            )

        if not isinstance(hints, dict):
            raise RuntimeError(f"hints must be a dict, got {type(hints)}")

        for k, v in hints.items():
            if not isinstance(k, str):
                raise RuntimeError(f"hints key must be a str, got {k}.")

            if not isinstance(v, (int, float, bool, str)):
                raise RuntimeError(
                    "hints must be a dict containing int, float, bool or str "
                    f"value, got value {v} for key {k}."
                )

        return super().__call__(body_fn, args, kwargs, hints)

hints_wrapper = HintsWrapper()

@hints_wrapper.py_impl(DispatchKey.CompositeExplicitAutograd)
def hints_wrapper_dense(body_fn, args, kwargs, hints):
    return body_fn(*args, **kwargs)


hints_wrapper.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(hints_wrapper, deferred_error=True)
)


@hints_wrapper.py_impl(FakeTensorMode)
def hints_wrapper_fake_tensor_mode(mode, body_func, args, kwargs, hints):
    flat_args = pytree.tree_leaves(args)
    with mode:
        return body_func(*flat_args, **kwargs)


@hints_wrapper.py_functionalize_impl
def hints_wrapper_functionalize(ctx, body_fn, args, kwargs, hints):
    unwrapped_args = ctx.unwrap_tensors(args)
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    unwrapped_hints = ctx.unwrap_tensors(hints)
    with ctx.redispatch_to_next():
        functional_body_fn = ctx.functionalize(body_fn)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        if _has_potential_branch_input_mutation(
            functional_body_fn, unwrapped_args, pre_dispatch=pre_dispatch
        ):
            raise UnsupportedAliasMutationException(
                "body_fn of hints_wrapper might be modifying the input!"
            )
        if _has_potential_branch_input_alias(
            functional_body_fn, unwrapped_args, pre_dispatch=pre_dispatch
        ):
            raise UnsupportedAliasMutationException(
                "body_fn of hints_wrapper might be aliasing the input!"
            )
        outputs = hints_wrapper(
            functional_body_fn,
            unwrapped_args,
            unwrapped_kwargs,
            unwrapped_hints,
        )
        return ctx.wrap_tensors(outputs)


def trace_hints_wrapper(proxy_mode, hints_wrapper, body_fn, args, kwargs, hints):
    flat_args = tuple(pytree.tree_leaves(args))
    body_graph = reenter_make_fx(body_fn)(*flat_args, **kwargs)

    _, body_graph_name = unique_graph_id(proxy_mode, prefix="hints_wrapper_body_graph")
    proxy_mode.tracer.root.register_module(body_graph_name, body_graph)

    new_args: tuple = (body_graph, flat_args, {})

    # merge hints into kwargs
    new_kwargs = {}
    new_kwargs["hints"] = hints

    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, new_args)
    proxy_kwargs = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, new_kwargs)

    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", hints_wrapper, proxy_args, proxy_kwargs, name="hints_wrapper"
    )
    out = body_fn(*flat_args, **kwargs)
    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


@hints_wrapper.py_impl(ProxyTorchDispatchMode)
def inner(proxy_mode, body_fn, args, kwargs, hints):
    if proxy_mode.enable_tracing:
        return trace_hints_wrapper(
            proxy_mode, hints_wrapper, body_fn, args, kwargs, hints
        )
    else:
        return hints_wrapper(body_fn, args, kwargs, hints)
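
As a hedged sketch of how these hints might be consumed (the backend name is illustrative, not part of this PR): since trace_hints_wrapper stores the hints as the "hints" kwarg of each call_function node, a custom torch.compile backend can read them straight off the captured FX graph:

import torch

def inspect_hints_backend(gm: torch.fx.GraphModule, example_inputs):
    for node in gm.graph.nodes:
        if (
            node.op == "call_function"
            and node.target is torch.ops.higher_order.hints_wrapper
        ):
            # e.g. {'outer_body': True} for the graphs in the tests above
            print(node.kwargs["hints"])
    return gm.forward

# usage: torch.compile(fn, backend=inspect_hints_backend)(x, y)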

@@ -61,6 +61,7 @@ hop_that_doesnt_have_opinfo_test_allowlist = [
    "call_torchbind",
    "triton_kernel_wrapper_mutation",
    "triton_kernel_wrapper_functional",
    "hints_wrapper",
]

torch.library.define(