[HOO] add hints_wrapper to support passing context hints (#132860)
Fixes #126393. The implementation follows the feedback in https://github.com/pytorch/pytorch/pull/121639#issuecomment-2223948842. Hints are passed via the hints keyword argument of the hints_wrapper op, and nested hints_wrapper calls are supported.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132860
Approved by: https://github.com/ydwu4, https://github.com/zou3519
Committed by: PyTorch MergeBot
Parent: 1ccc8f0200
Commit: 1d231ff8ba
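For reference, a minimal usage sketch of the new op, adapted from the tests added in this PR (the function and variable names here are illustrative, not part of the commit):

import torch
from torch._higher_order_ops.hints_wrap import hints_wrapper

def fn_with_hints(x, y):
    x = x + y

    def inner_body(x, y):
        # body functions receive the tensors passed in the args tuple
        return torch.relu(x) + y

    def outer_body(x, y):
        # hints_wrapper calls can be nested; each level carries its own hints dict
        x = hints_wrapper(inner_body, (x, y), {}, hints={"inner_body": True})
        return torch.abs(x)

    # signature: hints_wrapper(body_fn, args, kwargs, hints=...); hint values must be
    # int/float/bool/str and are carried as kwargs on the captured hints_wrapper node
    return hints_wrapper(outer_body, (x, y), {}, hints={"outer_body": True})

x, y = torch.randn(2, 4), torch.ones(4)
eager_out = fn_with_hints(x, y)  # eager path just runs the body functions
compiled_out = torch.compile(fn_with_hints, backend="eager")(x, y)  # hints show up in the traced graph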
@@ -23,6 +23,7 @@ from torch._dynamo.testing import (
    normalize_gm,
)
from torch._dynamo.utils import counters, ifdynstaticdefault
from torch._higher_order_ops.hints_wrap import hints_wrapper
from torch._higher_order_ops.wrap import wrap
from torch.testing._internal.common_utils import (
    munge_exc,
@@ -2502,6 +2503,139 @@ def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytre
        ):
            torch.compile(fn, backend="eager")(pred, pytree_in)

    def test_hints_wrapper(self):
        def ref_fn(x, y):
            x = x + y
            x = torch.relu(x)
            x = x + y
            return torch.abs(x)

        def fn_with_hints(x, y):
            x = x + y

            def inner_body_fn(x, y):
                x = torch.relu(x)
                x = x + y
                return x

            def outer_body_fn(x, y):
                x = hints_wrapper(inner_body_fn, (x, y), {}, hints={"inner_body": True})
                x = torch.abs(x)
                return x

            res = hints_wrapper(outer_body_fn, (x, y), {}, hints={"outer_body": True})
            return res

        backend = EagerAndRecordGraphs()
        cnt = CompileCounterWithBackend(backend)

        x = torch.randn(2, 4)
        y = torch.ones(4)

        eager_res = fn_with_hints(x, y)
        compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y)
        ref_res = ref_fn(x, y)
        self.assertEqual(eager_res, ref_res)
        self.assertEqual(compiled_res, ref_res)
        self.assertEqual(len(cnt.graphs), 1)

        # Dynamic shapes produce a slightly different graph.
        if check_dynamic_shape_capture():
            return

        graph = backend.graphs[0]
        self.assertExpectedInline(
            normalize_gm(graph.print_readable(print_output=False)),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[2, 4]", L_y_: "f32[4]"):
        l_x_ = L_x_
        l_y_ = L_y_

        x: "f32[2, 4]" = l_x_ + l_y_; l_x_ = None

        hints_wrapper_body_1 = self.hints_wrapper_body_1
        hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_1, (x, l_y_), {}, hints = {'outer_body': True}); hints_wrapper_body_1 = x = l_y_ = None
        res: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
        return (res,)

    class hints_wrapper_body_1(torch.nn.Module):
        def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
            hints_wrapper_body_0 = self.hints_wrapper_body_0
            hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_0, (x, l_y_), {}, hints = {'inner_body': True}); hints_wrapper_body_0 = x = l_y_ = None
            x_1: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None

            x_2: "f32[2, 4]" = torch.abs(x_1); x_1 = None
            return (x_2,)

        class hints_wrapper_body_0(torch.nn.Module):
            def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
                x_1: "f32[2, 4]" = torch.relu(x); x = None

                x_2: "f32[2, 4]" = x_1 + l_y_; x_1 = l_y_ = None
                return (x_2,)
""",
        )

    def test_hints_wrapper_no_hints(self):
        def fn_with_hints(x, y):
            def outer_body_fn(x, y):
                x = torch.add(x, y)
                return x

            res = hints_wrapper(outer_body_fn, (x, y), {})
            return res

        backend = EagerAndRecordGraphs()
        cnt = CompileCounterWithBackend(backend)

        x = torch.randn(2, 4)
        y = torch.ones(4)

        msg = "hints_wrapper - key hints not provided"
        with self.assertRaisesRegex(RuntimeError, msg):
            compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y)

    def test_hints_wrapper_incorrect_type(self):
        def fn_with_hints(x, y):
            def outer_body_fn(x, y):
                x = torch.add(x, y)
                return x

            res = hints_wrapper(outer_body_fn, (x, y), {}, hints={"test": (True,)})
            return res

        backend = EagerAndRecordGraphs()
        cnt = CompileCounterWithBackend(backend)

        x = torch.randn(2, 4)
        y = torch.ones(4)

        msg = r"hints must be a dict containing int, float, bool or str value,"
        with self.assertRaisesRegex(RuntimeError, msg):
            compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y)

    def test_hints_wrapper_pytree_inputs(self):
        def fn_with_hints(x, y):
            def outer_body_fn(x):
                res = torch.add(x[0], x[1]["test"])
                return res

            res = hints_wrapper(
                outer_body_fn, ((x, {"test": y}),), {}, hints={"test": True}
            )
            return res

        backend = EagerAndRecordGraphs()
        cnt = CompileCounterWithBackend(backend)

        x = torch.randn(2, 4)
        y = torch.ones(4)

        msg = r"args must be a tuple of tensors, ints, floats, or bools,"
        with self.assertRaisesRegex(RuntimeError, msg):
            fn_with_hints(x, y)


class HigherOrderOpVmapGuardTests(LoggingTestCase):
    @make_logging_test(recompiles=True)

@@ -20,6 +20,7 @@ from functorch.experimental.control_flow import cond, map
from torch import Tensor
from torch._decomp import get_decompositions
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import normalize_gm
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
from torch._export.utils import (
    get_buffer,
@@ -28,6 +29,7 @@ from torch._export.utils import (
    is_param,
    register_dataclass_as_pytree_node,
)
from torch._higher_order_ops.hints_wrap import hints_wrapper
from torch._inductor.compile_fx import split_const_gm
from torch._subclasses import FakeTensorMode
from torch.export import Dim, export, unflatten
@@ -7028,6 +7030,69 @@ def forward(self, x, y):
            ],
        )

    @testing.expectedFailureNonStrict
    @testing.expectedFailureTrainingIRToRunDecompNonStrict  # unbacked symint not tracked?
    @testing.expectedFailureSerDer  # T195866111
    def test_hints_wrapper(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                x = x + y

                def inner_body_fn(x, y):
                    x = torch.relu(x)
                    x = x + y
                    return x

                def outer_body_fn(x, y):
                    x = hints_wrapper(
                        inner_body_fn, (x, y), {}, hints={"inner_body": True}
                    )
                    x = torch.abs(x)
                    return x

                res = hints_wrapper(
                    outer_body_fn, (x, y), {}, hints={"outer_body": True}
                )
                return res

        x = torch.randn(2, 4)
        y = torch.ones(4)

        ep = export(M(), (x, y))
        export_res = ep.module()(x, y)
        ref_res = M()(x, y)
        self.assertEqual(export_res, ref_res)
        self.assertExpectedInline(
            normalize_gm(ep.graph_module.print_readable(print_output=False)),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, x: "f32[2, 4]", y: "f32[4]"):
        add: "f32[2, 4]" = torch.ops.aten.add.Tensor(x, y); x = None

        hints_wrapper_body_graph_0 = self.hints_wrapper_body_graph_0
        hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_graph_0, (add, y), {}, hints = {'outer_body': True}); hints_wrapper_body_graph_0 = add = y = None
        getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
        return (getitem,)

    class hints_wrapper_body_graph_0(torch.nn.Module):
        def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"):
            hints_wrapper_body_graph_0 = self.hints_wrapper_body_graph_0
            hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_graph_0, (arg0_1, arg1_1), {}, hints = {'inner_body': True}); hints_wrapper_body_graph_0 = arg0_1 = arg1_1 = None
            getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
            abs_1: "f32[2, 4]" = torch.ops.aten.abs.default(getitem); getitem = None
            return (abs_1,)

        class hints_wrapper_body_graph_0(torch.nn.Module):
            def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"):
                relu: "f32[2, 4]" = torch.ops.aten.relu.default(arg0_1); arg0_1 = None
                add: "f32[2, 4]" = torch.ops.aten.add.Tensor(relu, arg1_1); relu = arg1_1 = None
                return (add,)
""",
        )


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestOneOffModelExportResult(TestCase):

@@ -24,7 +24,12 @@ from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils import _pytree as pytree

from .. import variables
from ..exc import UncapturedHigherOrderOpError, unimplemented, Unsupported
from ..exc import (
    IncorrectUsage,
    UncapturedHigherOrderOpError,
    unimplemented,
    Unsupported,
)
from ..source import AttrSource
from ..utils import proxy_args_kwargs
from .dicts import ConstDictVariable
@@ -578,6 +583,8 @@ class TorchHigherOrderOperatorVariable(VariableTracker):
            return OutDtypeHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "wrap":
            return WrapHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "hints_wrapper":
            return HintsWrapperHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "flex_attention":
            return FlexAttentionHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ in (
@@ -1431,6 +1438,80 @@ class WrapWithSetGradEnabledHigherOrderVariable(TorchHigherOrderOperatorVariable
        )


class HintsWrapperHigherOrderVariable(TorchHigherOrderOperatorVariable):
    @raise_hard_error_if_graph_break(
        reason="Hints_wrapper doesn't work unless it is captured completely with torch.compile."
    )
    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        _check_supported_callable_arg(tx, args[0], "body_fn")

        # inputs
        if len(args) != 3:
            unimplemented(
                f"Expected 3 arguments but got {len(args)}.\n"
                f"Usage: hints_wrapper(body_fn, args, kwargs, hints).\n"
                f"kwargs required to be provided explicitly."
            )

        if not isinstance(args[1], (ListVariable, TupleVariable)):
            unimplemented(
                f"Expected a tuple but got {args[1].python_type()}",
            )
        operands = args[1].unpack_var_sequence(tx)

        if not isinstance(args[2], ConstDictVariable):
            unimplemented(
                f"Expected a dict but got {args[2].python_type()}",
            )

        if "hints" not in kwargs:
            raise IncorrectUsage("hints_wrapper - key hints not provided")

        (
            (body_r, treespec),
            body_graph,
            body_lifted_freevars,
        ) = speculate_subgraph(
            tx,
            args[0],  # function
            operands,
            args[2].as_python_constant(),
            "hints_wrapper",
            source_target=self.value,
            should_flatten_outputs=True,
        )

        body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph)
        body_name = add_subgraph(
            tx,
            "hints_wrapper_body",
            body_gmod,
        )

        body_node = make_attr(tx, body_name)

        # Since, we call `speculate_subgraph` with `set_subgraph_inputs="automatic`,
        # all the arguments are lifted.
        lifted_args = tuple(arg for arg in body_lifted_freevars.keys())
        p_args = (body_node, lifted_args, {})

        p_kwargs = {}
        # add hints into p_kwargs
        p_kwargs["hints"] = kwargs["hints"].as_python_constant()

        flat_example_value = pytree.tree_map_only(
            torch.fx.Proxy,
            lambda a: a.node.meta["example_value"],
            body_r.as_proxy(),
        )

        return _call_function_and_unflatten_output(
            tx, self.value, p_args, p_kwargs, flat_example_value, treespec
        )


class OutDtypeHigherOrderVariable(TorchHigherOrderOperatorVariable):
    def call_function(
        self,

@@ -3,6 +3,7 @@ from torch._higher_order_ops.flex_attention import (
    flex_attention,
    flex_attention_backward,
)
from torch._higher_order_ops.hints_wrap import hints_wrapper
from torch._higher_order_ops.while_loop import while_loop


@@ -11,4 +12,5 @@ __all__ = [
    "while_loop",
    "flex_attention",
    "flex_attention_backward",
    "hints_wrapper",
]

torch/_higher_order_ops/hints_wrap.py (new file, 151 lines)
@@ -0,0 +1,151 @@
# mypy: allow-untyped-defs
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import (
    _has_potential_branch_input_alias,
    _has_potential_branch_input_mutation,
    autograd_not_implemented,
    reenter_make_fx,
    unique_graph_id,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree


# used for wrapping a function/op with context hints
class HintsWrapper(HigherOrderOperator):
    def __init__(self):
        super().__init__("hints_wrapper")

    def __call__(self, body_fn, args, kwargs, hints):
        r"""
        Call implementation of hints_wrapper

        Args:
            body_fn (Callable): A callable function that is within the scope
                that is being traced.

            args (Tuple of torch.Tensor/int/float/bool): A tuple of inputs to
                body_fn.

            kwargs (dict): Keyword argument to the body_fn.

            hints (dict): A dict of context hints which could be passed to
                backend compiler.
        """
        if not isinstance(args, tuple):
            raise RuntimeError(f"args must be a tuple, got {type(args)}")

        if not all(isinstance(t, (torch.Tensor, int, float, bool)) for t in args):
            raise RuntimeError(
                "args must be a tuple of tensors, ints, floats, or bools, got "
                f"{args}"
            )

        if not isinstance(kwargs, dict):
            raise RuntimeError(f"kwargs must be a dict, got {type(kwargs)}")

        if len(kwargs) > 0:
            raise RuntimeError(
                f"kwargs except for hints are not supported, got {kwargs}"
            )

        if not isinstance(hints, dict):
            raise RuntimeError(f"hints must be a dict, got {type(hints)}")

        for k, v in hints.items():
            if not isinstance(k, str):
                raise RuntimeError(f"hints key must be a str, got {k}.")

            if not isinstance(v, (int, float, bool, str)):
                raise RuntimeError(
                    "hints must be a dict containing int, float, bool or str "
                    f"value, got value {v} for key {k}."
                )

        return super().__call__(body_fn, args, kwargs, hints)


hints_wrapper = HintsWrapper()


@hints_wrapper.py_impl(DispatchKey.CompositeExplicitAutograd)
def hints_wrapper_dense(body_fn, args, kwargs, hints):
    return body_fn(*args, **kwargs)


hints_wrapper.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(hints_wrapper, deferred_error=True)
)


@hints_wrapper.py_impl(FakeTensorMode)
def hints_wrapper_fake_tensor_mode(mode, body_func, args, kwargs, hints):
    flat_args = pytree.tree_leaves(args)
    with mode:
        return body_func(*flat_args, **kwargs)


@hints_wrapper.py_functionalize_impl
def hints_wrapper_functionalize(ctx, body_fn, args, kwargs, hints):
    unwrapped_args = ctx.unwrap_tensors(args)
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    unwrapped_hints = ctx.unwrap_tensors(hints)
    with ctx.redispatch_to_next():
        functional_body_fn = ctx.functionalize(body_fn)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        if _has_potential_branch_input_mutation(
            functional_body_fn, unwrapped_args, pre_dispatch=pre_dispatch
        ):
            raise UnsupportedAliasMutationException(
                "body_fn of hints_wrapper might be modifying the input!"
            )
        if _has_potential_branch_input_alias(
            functional_body_fn, unwrapped_args, pre_dispatch=pre_dispatch
        ):
            raise UnsupportedAliasMutationException(
                "body_fn of hints_wrapper might be aliasing the input!"
            )
        outputs = hints_wrapper(
            functional_body_fn,
            unwrapped_args,
            unwrapped_kwargs,
            unwrapped_hints,
        )
        return ctx.wrap_tensors(outputs)


def trace_hints_wrapper(proxy_mode, hints_wrapper, body_fn, args, kwargs, hints):
    flat_args = tuple(pytree.tree_leaves(args))
    body_graph = reenter_make_fx(body_fn)(*flat_args, **kwargs)

    _, body_graph_name = unique_graph_id(proxy_mode, prefix="hints_wrapper_body_graph")
    proxy_mode.tracer.root.register_module(body_graph_name, body_graph)

    new_args: tuple = (body_graph, flat_args, {})
    # merge hints into kwargs
    new_kwargs = {}
    new_kwargs["hints"] = hints

    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, new_args)
    proxy_kwargs = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, new_kwargs)

    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", hints_wrapper, proxy_args, proxy_kwargs, name="hints_wrapper"
    )

    out = body_fn(*flat_args, **kwargs)
    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


@hints_wrapper.py_impl(ProxyTorchDispatchMode)
def inner(proxy_mode, body_fn, args, kwargs, hints):
    if proxy_mode.enable_tracing:
        return trace_hints_wrapper(
            proxy_mode, hints_wrapper, body_fn, args, kwargs, hints
        )
    else:
        return hints_wrapper(body_fn, args, kwargs, hints)
@@ -61,6 +61,7 @@ hop_that_doesnt_have_opinfo_test_allowlist = [
    "call_torchbind",
    "triton_kernel_wrapper_mutation",
    "triton_kernel_wrapper_functional",
    "hints_wrapper",
]

torch.library.define(