pytorch/test/higher_order_ops/test_with_effects.py
FFFrog 6fc0ad22f0 Using the latest torch.library.register_fake API instead of torch.library.impl_abstract (#158839)
As the title states.

`torch.library.impl_abstract` has been deprecated since PyTorch 2.4, so switch to the new `torch.library.register_fake` API.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158839
Approved by: https://github.com/jingsh, https://github.com/zou3519
ghstack dependencies: #158838
2025-07-25 02:37:30 +00:00
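For context, a minimal sketch of the migration this commit performs. The `mylib::my_op` name, schema, and `my_op_fake` function below are illustrative placeholders, not taken from this file:

```python
import torch

# Define a hypothetical custom op for illustration only.
torch.library.define("mylib::my_op", "(Tensor x) -> Tensor")

# Old style (deprecated since PyTorch 2.4): register the abstract/meta kernel
# via torch.library.impl_abstract:
#
# @torch.library.impl_abstract("mylib::my_op")
# def my_op_abstract(x):
#     return torch.empty_like(x)

# New style: the equivalent registration through torch.library.register_fake.
@torch.library.register_fake("mylib::my_op")
def my_op_fake(x):
    return torch.empty_like(x)
```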


# Owner(s): ["module: functorch"]
# ruff: noqa: F841
# flake8: noqa: B950
import unittest
from collections import deque
from functools import partial
from typing import TYPE_CHECKING
import torch
import torch._dynamo
import torch._functorch
import torch._inductor
import torch._inductor.decomposition
from functorch.compile import (
aot_function,
default_decompositions,
min_cut_rematerialization_partition,
nop,
)
from torch._functorch.aot_autograd import aot_export_module
from torch._higher_order_ops.effects import with_effects
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import (
_get_torch_cuda_version,
SM70OrLater,
SM80OrLater,
)
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
from torch.testing._internal.common_utils import (
IS_WINDOWS,
run_tests,
skipIfTorchDynamo,
TEST_CUDA,
TEST_WITH_ROCM,
TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
if TYPE_CHECKING:
from torch.utils.hooks import RemovableHandle
from torch.testing._internal.two_tensor import TwoTensor
def extract_graph(fx_g, _, graph_cell):
graph_cell[0] = fx_g
return fx_g
def get_fw_bw_graph(
f, inps, partitioner=min_cut_rematerialization_partition, dynamic=False
):
fw_graph_cell = [None]
bw_graph_cell = [None]
requires_grad = False
def fn_req_grad(t):
nonlocal requires_grad
requires_grad = requires_grad or t.requires_grad
return t
torch.utils._pytree.tree_map_only(torch.Tensor, fn_req_grad, inps)
out = aot_function(
f,
fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
bw_compiler=(
partial(extract_graph, graph_cell=bw_graph_cell) if requires_grad else nop
),
partition_fn=partitioner,
decompositions=default_decompositions,
dynamic=dynamic,
)(*inps)
if requires_grad:
out.sum().backward()
return (fw_graph_cell[0], bw_graph_cell[0])
def make_inputs_non_leaves(inps):
return torch.utils._pytree.tree_map_only(torch.Tensor, lambda t: t.add(1), inps)
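# Note on the higher-order op exercised below: functionalization rewrites a call to a
# side-effectful op into torch.ops.higher_order.with_effects(token, op, *args). An opaque
# "effect token" is threaded through the graph to enforce ordering; element [0] of the
# result is the updated token and the remaining elements are the wrapped op's outputs,
# as the expected graphs in these tests show.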
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestWithEffects(TestCase):
def setUp(self):
init_torchbind_implementations()
def test_print(self):
class M(torch.nn.Module):
def forward(self, x):
torch.ops.aten._print("moo")
res = x + x
torch.ops.aten._print("moo")
return (res,)
inputs = (torch.randn(3),)
# Without functionalization, print should just appear in the graph directly
gm = make_fx(M())(*inputs)
FileCheck().check_count("torch.ops.aten._print.default", 2, exactly=True).run(
gm.code
)
# With functionalization, it should appear wrapped with with_effects()
gm, gs = aot_export_module(M(), inputs, trace_joint=False)
self.assertExpectedInline(
str(gm.code).strip(),
"""\
def forward(self, arg0_1, arg1_1):
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'moo'); arg0_1 = None
getitem = with_effects[0]; with_effects = None
add = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
getitem_2 = with_effects_1[0]; with_effects_1 = None
return (getitem_2, add)""",
)
self.assertEqual(len(gs.input_tokens), 1)
self.assertEqual(len(gs.output_tokens), 1)
with torch._functorch.config.patch(unlift_effect_tokens=True):
gm, gs = aot_export_module(M(), inputs, trace_joint=False)
self.assertExpectedInline(
str(gm.code).strip(),
"""\
def forward(self, arg1_1):
_make_token_default = torch.ops.prims._make_token.default()
with_effects = torch.ops.higher_order.with_effects(_make_token_default, torch.ops.aten._print.default, 'moo'); _make_token_default = None
getitem = with_effects[0]; with_effects = None
add = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
getitem_2 = with_effects_1[0]; with_effects_1 = None
_sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem_2]); getitem_2 = _sink_tokens_default = None
return [add]""", # noqa: B950
)
def test_torchbind_custom_op(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
def forward(self, x):
return (x + torch.ops._TorchScriptTesting.takes_foo(self.attr, x),)
with enable_torchbind_tracing():
gm, gs = aot_export_module(M(), (torch.ones(2, 3),), trace_joint=False)
self.assertExpectedInline(
str(gm.code).strip(),
"""\
def forward(self, arg0_1, arg1_1):
_torchbind_obj0 = self._torchbind_obj0
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops._TorchScriptTesting.takes_foo.default, _torchbind_obj0, arg1_1); arg0_1 = _torchbind_obj0 = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
add = torch.ops.aten.add.Tensor(arg1_1, getitem_1); arg1_1 = getitem_1 = None
return (getitem, add)""", # noqa: B950
)
self.assertEqual(len(gs.input_tokens), 1)
self.assertEqual(len(gs.output_tokens), 1)
def test_print_with_buffer_mutations(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.buf = torch.nn.Buffer(torch.ones(3))
def forward(self, x):
torch.ops.aten._print("moo")
res = x + x
self.buf.add_(res)
res = self.buf + x
torch.ops.aten._print("moo")
return (res,)
inputs = (torch.randn(3),)
# With functionalization, it should appear wrapped with with_effects()
gm, gs = aot_export_module(M(), inputs, trace_joint=False)
self.assertExpectedInline(
str(gm.code).strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1):
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'moo'); arg0_1 = None
getitem = with_effects[0]; with_effects = None
add = torch.ops.aten.add.Tensor(arg2_1, arg2_1)
add_1 = torch.ops.aten.add.Tensor(arg1_1, add); arg1_1 = add = None
add_2 = torch.ops.aten.add.Tensor(add_1, arg2_1); arg2_1 = None
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
getitem_2 = with_effects_1[0]; with_effects_1 = None
return (getitem_2, add_1, add_2)""",
)
self.assertEqual(len(gs.input_tokens), 1)
self.assertEqual(len(gs.output_tokens), 1)
self.assertEqual(len(gs.buffers_to_mutate), 1)
def test_print_with_input_mutations(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
torch.ops.aten._print("moo")
res = x + x
x.add_(res)
res = x + x
torch.ops.aten._print("moo")
return (res,)
inputs = (torch.randn(3),)
# With functionalization, it should appear wrapped with with_effects()
gm, gs = aot_export_module(M(), inputs, trace_joint=False)
self.assertEqual(len(gs.input_tokens), 1)
self.assertEqual(len(gs.output_tokens), 1)
self.assertEqual(len(gs.user_inputs_to_mutate), 1)
def test_alias_op(self):
def f(token, x):
token, out = with_effects(token, torch.ops.aten.absolute_.default, x)
return token, out
with self.assertRaisesRegex(
AssertionError, r"Ops with aliasing is not supported"
):
make_fx(f)(torch.tensor([]), torch.tensor(4))
def test_compile_aot_eager(self):
def f(x):
torch.ops.aten._print("moo")
res = x + x
torch.ops.aten._print("moo")
return res
inputs = (torch.randn(2, 3),)
res = torch.compile(f, backend="aot_eager")(*inputs)
self.assertTrue(torch.allclose(res, f(*inputs)))
@unittest.skipIf(IS_WINDOWS, "triton")
@unittest.skipIf(not SM70OrLater, "triton")
def test_compile_inductor(self):
def f(x):
torch.ops.aten._print("moo")
res = x + x
torch.ops.aten._print("moo")
return res
inputs = (torch.randn(2, 3),)
res = torch.compile(f, backend="inductor")(*inputs)
self.assertTrue(torch.allclose(res, f(*inputs)))
@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
@skipIfNoDynamoSupport
def test_compile_inductor_external_op_return_none(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
torch.library.define(
"mylib::inplace_add",
"(Tensor input, Tensor(a!) output) -> ()",
lib=lib,
)
def inplace_add(input: torch.Tensor, output: torch.Tensor) -> None:
assert input.device == output.device
output.add_(input)
lib.impl("inplace_add", inplace_add, "CompositeExplicitAutograd")
def f(x):
out = torch.empty(3)
out = torch.zeros_like(out)
torch.ops.mylib.inplace_add(x, out)
return out
inputs = (torch.randn(3),)
res = torch.compile(f, backend="inductor")(*inputs)
self.assertTrue(torch.allclose(res, f(*inputs)))
def test_compile_aot_eager_requires_grad(self):
def f(x):
torch.ops.aten._print("moo")
res = x + x
torch.ops.aten._print("moo")
return res
inputs = (torch.randn(2, 3, requires_grad=True),)
res = torch.compile(f, backend="aot_eager")(*inputs)
self.assertTrue(torch.allclose(res, f(*inputs)))
res.sum().backward()
@unittest.skipIf(IS_WINDOWS, "triton")
@unittest.skipIf(TEST_WITH_ROCM, "triton")
@unittest.skipIf(not SM80OrLater, "triton")
@unittest.skipIf(_get_torch_cuda_version() >= (11, 7), "triton")
@unittest.skipIf(not TEST_CUDA, "triton")
@skipIfNoDynamoSupport
def test_register_effectful_custom_op(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch.library.define(
"mylib::record_scalar_tensor",
"(Tensor x, str prefix) -> ()",
lib=lib,
)
# Dict shared with the hooks below, storing recorded tensors keyed by prefix.
recorded_dict = {}
# PyTorch custom op implementation
@torch.library.impl(
"mylib::record_scalar_tensor",
"CompositeExplicitAutograd",
lib=lib,
)
def record_scalar_tensor(x, prefix):
recorded_dict[prefix] = x.clone()
return
# Meta function of the custom op
@torch.library.register_fake(
"mylib::record_scalar_tensor",
lib=lib,
)
def record_scalar_tensor_meta(x, prefix):
return
from torch._higher_order_ops.effects import (
_EffectType,
_register_effectful_op,
)
_register_effectful_op(
torch.ops.mylib.record_scalar_tensor.default, _EffectType.ORDERED
)
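# Registering the op as an ORDERED effect is what makes AOTAutograd thread an
# effect token through every call to it, i.e. wrap each call in
# torch.ops.higher_order.with_effects instead of emitting a plain call.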
my_config = {}
my_config["MockModule"] = "mean"
my_config["MockModule.linear"] = "mean"
my_config["MockModule.relu"] = "mean"
class MyLinear(torch.nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.weight = torch.nn.Parameter(
torch.randn(out_features, in_features), requires_grad=True
)
self.bias = torch.nn.Parameter(
torch.randn(out_features), requires_grad=True
)
def forward(self, x):
return torch.nn.functional.linear(x, self.weight, self.bias)
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = MyLinear(10, 10)
self.register_buffer(
"buf0", torch.randn(10, 10, requires_grad=True)
)
def forward(self, x):
return torch.nn.functional.relu(self.linear(x) + self.buf0)
def forward_hook(
module: torch.nn.Module,
inputs: torch.Tensor,
output: torch.Tensor,
prefix: str,
aggregate_method: str,
) -> torch.Tensor:
if aggregate_method == "mean":
torch.ops.mylib.record_scalar_tensor(output.mean(), prefix)
elif aggregate_method == "max":
torch.ops.mylib.record_scalar_tensor(output.max(), prefix)
else:
# demo purpose: fall back to sum for any other aggregation method
torch.ops.mylib.record_scalar_tensor(output.sum(), prefix)
return output
def add_hooks(module, config):
handles: list[RemovableHandle] = []
q = deque([(module.__class__.__name__, module)])
while q:
name, m = q.pop()
children = [(name + "." + n, y) for (n, y) in m.named_children()]
q.extend(children)
aggregate_method = config.get(name, "mean")
prefix = name + ":" + aggregate_method
handle = m.register_forward_hook(
partial(
forward_hook,
prefix=prefix,
aggregate_method=aggregate_method,
)
)
if handle:
handles.append(handle)
return handles
x = torch.randn(10, 10, device="cuda")
mod = MockModule().to("cuda")
add_hooks(mod, my_config)
opt_mod = torch.compile(backend="inductor")(mod)
y = opt_mod(x)
self.assertTrue(torch.allclose(y, mod(x)))
# Ensure it works well with backward
y.sum().backward()
# Ensure the grad exists
self.assertTrue(isinstance(opt_mod.linear.weight.grad, torch.Tensor))
self.assertEqual(len(recorded_dict), 2)
self.assertTrue("MockModule.linear:mean" in recorded_dict)
self.assertTrue("MockModule:mean" in recorded_dict)
@skipIfNoDynamoSupport
def test_effectful_custom_op_with_subclasses(self):
with torch.library._scoped_library("_mylib", "FRAGMENT") as lib:
lib.define("zoo(Tensor x) -> Tensor")
lib.define("zoo2(Tensor x) -> Tensor")
d = {"fw": 0, "bw": 0}
def reset_counter():
d["fw"] = 0
d["bw"] = 0
def assert_counter(fw, bw):
self.assertEqual(d["fw"], fw)
self.assertEqual(d["bw"], bw)
def foo_impl(a):
d["fw"] = d["fw"] + 1
return 2 * a.clone()
def foo_meta(a):
return a.clone()
def foo2_impl(x):
d["bw"] = d["bw"] + 1
return x.clone()
def foo2_meta(a):
return a.clone()
for backend in ["CPU", "CUDA"]:
lib.impl("zoo", foo_impl, backend)
lib.impl("zoo2", foo2_impl, backend)
lib.impl("zoo", foo_meta, "Meta")
lib.impl("zoo2", foo2_meta, "Meta")
def foo_bwd(ctx, grad):
torch.ops._mylib.zoo2(grad)
return grad.clone()
torch.library.register_autograd("_mylib::zoo", foo_bwd, lib=lib)
from torch._higher_order_ops.effects import (
_EffectType,
_register_effectful_op,
)
_register_effectful_op(torch.ops._mylib.zoo.default, _EffectType.ORDERED)
_register_effectful_op(torch.ops._mylib.zoo2.default, _EffectType.ORDERED)
def fn(x, y):
return torch.ops._mylib.zoo(x) + y
def ins_sc():
return (
TwoTensor(
torch.tensor([1.0, 2.0, 3.0]), torch.tensor([1.0, 2.0, 3.0])
),
torch.tensor([4.0, 5.0, 6.0]),
)
def ins_dense():
return torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])
for i, (ins_fn, expected_fw_count) in enumerate(
zip([ins_sc, ins_dense], [2, 1])
):
reset_counter()
ref_out = fn(*ins_fn())
assert_counter(expected_fw_count, 0)
compiled_fn = torch.compile(fn, backend="aot_eager")
out = compiled_fn(*ins_fn())
reset_counter()
out = compiled_fn(*ins_fn())
assert_counter(expected_fw_count, 0)
self.assertEqual(ref_out, out)
def ins_dense_req_grad():
return (
torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
)
def ins_sc_req_grad():
return (
TwoTensor(
torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
),
TwoTensor(
torch.tensor([7.0, 8.0, 9.0], requires_grad=True),
torch.tensor([10.0, 11.0, 12.0], requires_grad=True),
),
)
for i, (
ins_fn_req_grad,
(
expected_fw_count,
expected_fw_count_after_bw,
expected_bw_count_after_bw,
),
) in enumerate(
zip([ins_dense_req_grad, ins_sc_req_grad], [(1, 1, 1), (2, 2, 2)])
):
ref_ins = ins_fn_req_grad()
reset_counter()
ref_out = fn(*ref_ins)
assert_counter(expected_fw_count, 0)
ref_out.sum().backward()
assert_counter(expected_fw_count_after_bw, expected_bw_count_after_bw)
compiled_fn = torch.compile(fn, fullgraph=True)
ins = ins_fn_req_grad()
out = compiled_fn(*ins)
reset_counter()
out = compiled_fn(*ins)
assert_counter(expected_fw_count, 0)
self.assertEqual(ref_out, out)
out.sum().backward()
assert_counter(expected_fw_count_after_bw, expected_bw_count_after_bw)
self.assertEqual(ref_ins[1].grad, ins[1].grad)
self.assertEqual(ref_ins[0].grad, ins[0].grad)
fw_graph, bw_graph = get_fw_bw_graph(fn, ins_sc_req_grad())
self.assertExpectedInline(
fw_graph.code.strip(),
"""\
def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5):
with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.zoo.default, primals_2); primals_1 = primals_2 = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._mylib.zoo.default, primals_3); getitem = primals_3 = None
getitem_2 = with_effects_1[0]
getitem_3 = with_effects_1[1]; with_effects_1 = None
add = torch.ops.aten.add.Tensor(getitem_1, primals_4); getitem_1 = primals_4 = None
add_1 = torch.ops.aten.add.Tensor(getitem_3, primals_5); getitem_3 = primals_5 = None
return (getitem_2, add, add_1)""",
)
self.assertExpectedInline(
bw_graph.code.strip(),
"""\
def forward(self, tangents_1, tangents_2, tangents_token):
with_effects_2 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.zoo2.default, tangents_1); tangents_token = None
getitem_4 = with_effects_2[0]; with_effects_2 = None
with_effects_3 = torch.ops.higher_order.with_effects(getitem_4, torch.ops._mylib.zoo2.default, tangents_2); getitem_4 = None
getitem_6 = with_effects_3[0]; with_effects_3 = None
clone = torch.ops.aten.clone.default(tangents_1)
clone_1 = torch.ops.aten.clone.default(tangents_2)
return (clone, clone_1, tangents_1, tangents_2, getitem_6)""",
)
def test_effects_and_input_mutation_return(self):
def fn(a, b):
torch.ops.aten._print("effect")
return torch.sin(a, out=b)
inp = [torch.randn(3, 3), torch.ones(3, 3)]
ref_out = fn(*inp)
out = torch.compile(fn, fullgraph=True)(*inp)
self.assertEqual(ref_out, out)
fw_graph, bw_graph = get_fw_bw_graph(fn, inp)
self.assertExpectedInline(
fw_graph.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1):
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'effect'); arg0_1 = None
getitem = with_effects[0]; with_effects = None
sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
return (getitem, sin, sin)""",
)
def test_effects_and_input_output_view_simple(self):
def fn(a):
return a.view(-1)
inp = [torch.ones(2, 2, requires_grad=False).add(1)]
ref_out = fn(*inp)
out = torch.compile(fn, fullgraph=True)(*inp)
self.assertEqual(ref_out, out)
inp = [torch.ones(2, 2, requires_grad=True).add(1)]
ref_out = fn(*inp)
out = torch.compile(fn, fullgraph=True)(*inp)
self.assertEqual(ref_out, out)
fw_graph, bw_graph = get_fw_bw_graph(fn, inp)
self.assertExpectedInline(
fw_graph.code.strip(),
"""\
def forward(self, arg0_1):
view = torch.ops.aten.view.default(arg0_1, [-1]); arg0_1 = None
return (view,)""",
)
def test_effects_and_aliased_outputs(self):
def fn(a):
b = a.mul(2)
torch.ops.aten._print("effect")
c = b.view(-1)
return b, c
f_compiled = aot_function(fn, nop)
for req_grad in [True, False]:
inp = torch.ones(3, requires_grad=req_grad)
out_ref = fn(inp)
out_test = f_compiled(inp)
self.assertEqual(out_ref[0], out_test[0])
self.assertEqual(out_ref[1], out_test[1])
# Try mutating one of the outputs, which is aliased.
out_ref[0].mul_(3)
out_test[0].mul_(3)
# Assert that the aliasing relationship was preserved
self.assertEqual(out_ref[0], out_test[0])
self.assertEqual(out_ref[1], out_test[1])
def test_effects_and_input_mutation_is_output(self):
def fn(a):
a.mul_(2)
torch.ops.aten._print("effect")
return a
inp = make_inputs_non_leaves([torch.ones(3, 3, requires_grad=True)])
ref_out = fn(*inp)
out = torch.compile(fn, backend="aot_eager", fullgraph=True)(*inp)
self.assertEqual(ref_out, out)
inp = [torch.ones(3, 3, requires_grad=False)]
ref_out = fn(*inp)
out = torch.compile(fn, backend="aot_eager", fullgraph=True)(*inp)
self.assertEqual(ref_out, out)
fw_graph, bw_graph = get_fw_bw_graph(fn, inp)
self.assertExpectedInline(
fw_graph.code.strip(),
"""\
def forward(self, arg0_1, arg1_1):
mul = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'effect'); arg0_1 = None
getitem = with_effects[0]; with_effects = None
return (getitem, mul, mul)""",
)
@skipIfTorchDynamo()
def test_effectful_op_in_backward(self):
with torch.library._scoped_library("_mylib", "FRAGMENT") as lib:
lib.define("foo(Tensor x) -> Tensor")
def foo_impl(a):
return a.clone()
def foo_bwd(ctx, grad):
return torch.ops._mylib.foo(grad)
for backend in ["CPU", "CUDA", "Meta"]:
lib.impl("foo", foo_impl, backend)
torch.library.register_autograd("_mylib::foo", foo_bwd, lib=lib)
from torch._higher_order_ops.effects import (
_deregister_effectful_op,
_EffectType,
_register_effectful_op,
)
_register_effectful_op(torch.ops._mylib.foo.default, _EffectType.ORDERED)
try:
def fn(x, y):
return torch.ops._mylib.foo(x) + y
def ins_dense_req_grad():
return (
torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
)
def ins_sc_req_grad():
return (
TwoTensor(
torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
),
torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
)
for i, ins_fn in enumerate([ins_dense_req_grad, ins_sc_req_grad]):
ref_ins = ins_fn()
ref_out = fn(*ref_ins)
ref_out.sum().backward()
compiled_fn = torch.compile(fn, backend="inductor", fullgraph=True)
ins = ins_fn()
out = compiled_fn(*ins)
self.assertEqual(ref_out, out)
out.sum().backward()
self.assertEqual(ref_ins[1].grad, ins[1].grad)
self.assertEqual(ref_ins[0].grad, ins[0].grad)
fw_graph, bw_graph = get_fw_bw_graph(fn, ins)
if i == 0:
self.assertExpectedInline(
fw_graph.code.strip(),
"""\
def forward(self, primals_1, primals_2, primals_3):
with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.foo.default, primals_2); primals_1 = primals_2 = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
add = torch.ops.aten.add.Tensor(getitem_1, primals_3); getitem_1 = primals_3 = None
return (getitem, add)""",
)
self.assertExpectedInline(
bw_graph.code.strip(),
"""\
def forward(self, tangents_1, tangents_token):
with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.foo.default, tangents_1); tangents_token = None
getitem_2 = with_effects_1[0]
getitem_3 = with_effects_1[1]; with_effects_1 = None
return (getitem_3, tangents_1, getitem_2)""",
)
elif i == 1:
self.assertExpectedInline(
fw_graph.code.strip(),
"""\
def forward(self, primals_1, primals_2, primals_3, primals_4):
with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.foo.default, primals_2); primals_1 = primals_2 = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._mylib.foo.default, primals_3); getitem = primals_3 = None
getitem_2 = with_effects_1[0]
getitem_3 = with_effects_1[1]; with_effects_1 = None
add = torch.ops.aten.add.Tensor(getitem_1, primals_4); getitem_1 = None
add_1 = torch.ops.aten.add.Tensor(getitem_3, primals_4); getitem_3 = primals_4 = None
return (getitem_2, add, add_1)""",
)
self.assertExpectedInline(
bw_graph.code.strip(),
"""\
def forward(self, tangents_1, tangents_2, tangents_token):
with_effects_2 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.foo.default, tangents_1); tangents_token = None
getitem_4 = with_effects_2[0]
getitem_5 = with_effects_2[1]; with_effects_2 = None
with_effects_3 = torch.ops.higher_order.with_effects(getitem_4, torch.ops._mylib.foo.default, tangents_2); getitem_4 = None
getitem_6 = with_effects_3[0]
getitem_7 = with_effects_3[1]; with_effects_3 = None
return (getitem_5, getitem_7, tangents_1, tangents_2, getitem_6)""",
)
else:
raise NotImplementedError
finally:
_deregister_effectful_op(torch.ops._mylib.foo.default)
@skipIfNoDynamoSupport
def test_regular_effectful_op_only_in_backward(self):
from torch._higher_order_ops.effects import (
_deregister_effectful_op,
_EffectType,
_register_effectful_op,
)
_register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
try:
def fn(x):
return x.sin()
def inps_fn():
return (torch.tensor([1.0, 2.0, 3.0], requires_grad=True),)
torch.compile(fn, backend="inductor", fullgraph=True)(*inps_fn())
fw_graph, bw_graph = get_fw_bw_graph(fn, inps_fn())
self.assertExpectedInline(
fw_graph.code.strip(),
"""\
def forward(self, primals_1):
sin = torch.ops.aten.sin.default(primals_1)
return (sin, primals_1)""",
)
self.assertExpectedInline(
bw_graph.code.strip(),
"""\
def forward(self, primals_1, tangents_1, tangents_token):
with_effects = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, primals_1); tangents_token = primals_1 = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_1); tangents_1 = getitem_1 = None
return (mul, getitem)""",
)
def inps_fn_sc():
return (
TwoTensor(
torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
),
)
torch.compile(fn, backend="inductor", fullgraph=True)(*inps_fn_sc())
fw_graph, bw_graph = get_fw_bw_graph(fn, inps_fn_sc())
self.assertExpectedInline(
fw_graph.code.strip(),
"""\
def forward(self, primals_1, primals_2):
sin = torch.ops.aten.sin.default(primals_1)
sin_1 = torch.ops.aten.sin.default(primals_2)
return (sin, sin_1, primals_1, primals_2)""",
)
self.assertExpectedInline(
bw_graph.code.strip(),
"""\
def forward(self, primals_1, primals_2, tangents_1, tangents_2, tangents_token):
with_effects = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, primals_1); tangents_token = primals_1 = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten.cos.default, primals_2); getitem = primals_2 = None
getitem_2 = with_effects_1[0]
getitem_3 = with_effects_1[1]; with_effects_1 = None
mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_1); tangents_1 = getitem_1 = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_2, getitem_3); tangents_2 = getitem_3 = None
return (mul, mul_1, getitem_2)""",
)
finally:
_deregister_effectful_op(torch.ops.aten.cos.default)
@skipIfNoDynamoSupport
def test_regular_effectful_op_in_forward_and_backward(self):
from torch._higher_order_ops.effects import (
_deregister_effectful_op,
_EffectType,
_register_effectful_op,
)
_register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
try:
def fn(x):
x = x.cos()
return x.sin()
inps = (torch.tensor([1.0, 2.0, 3.0], requires_grad=True),)
torch.compile(fn, backend="inductor", fullgraph=True)(*inps)
fw_graph, bw_graph = get_fw_bw_graph(fn, inps)
self.assertExpectedInline(
fw_graph.code.strip(),
"""\
def forward(self, primals_1, primals_2):
with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops.aten.cos.default, primals_2); primals_1 = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
sin = torch.ops.aten.sin.default(getitem_1)
return (getitem, sin, primals_2, getitem_1)""",
)
self.assertExpectedInline(
bw_graph.code.strip(),
"""\
def forward(self, primals_2, getitem_1, tangents_1, tangents_token):
with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, getitem_1); tangents_token = getitem_1 = None
getitem_2 = with_effects_1[0]
getitem_3 = with_effects_1[1]; with_effects_1 = None
mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_3); tangents_1 = getitem_3 = None
sin_1 = torch.ops.aten.sin.default(primals_2); primals_2 = None
neg = torch.ops.aten.neg.default(sin_1); sin_1 = None
mul_1 = torch.ops.aten.mul.Tensor(mul, neg); mul = neg = None
return (mul_1, getitem_2)""",
)
finally:
_deregister_effectful_op(torch.ops.aten.cos.default)
if __name__ == "__main__":
run_tests()