Compare commits

..

2 Commits

Author SHA1 Message Date
43f24e9876  add partition meta  2025-11-12 16:00:43 -08:00
4407acbb20  Add test for unbacked symint expression; Add backend node meta to invoke subgraph  2025-11-12 16:00:43 -08:00
46 changed files with 931 additions and 1800 deletions

View File

@ -1681,13 +1681,14 @@ class GraphModule(torch.nn.Module):
wrap_body_0 = self.wrap_body_0
tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = True); wrap_body_0 = l_x_ = None
getitem: "f32[4, 4]" = tag_activation_checkpoint[0]; tag_activation_checkpoint = None
return (getitem,)
getitem: "f32[4, 4]" = tag_activation_checkpoint[0]
getitem_1: "f32[4, 4]" = tag_activation_checkpoint[1]; tag_activation_checkpoint = None
return (getitem, getitem_1)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[4, 4]"):
y: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (y,)
return (y, y)
""",
)
@ -1797,9 +1798,9 @@ class GraphModule(torch.nn.Module):
out: "f32[4, 4]" = l_x_.sin()
sin_1: "f32[4, 4]" = torch.sin(o)
cos: "f32[4, 4]" = torch.cos(sin_1)
sin_2: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (cos, sin_2, matmul, o, out, sin_1)
child: "f32[4, 4]" = torch.cos(sin_1)
child_1: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (child, child_1, matmul, o, out, sin_1)
""",
)
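The expected strings in hunks like this are regenerated by compiling the function under test with a graph-recording backend and printing the normalized module. A minimal sketch of that workflow, assuming the torch._dynamo.testing helpers imported elsewhere in this compare; fn here is a stand-in, not the test's actual function:

import torch
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm

def fn(x):
    # Return the same tensor twice, mirroring the duplicate-output case the
    # updated expected graph above now covers.
    y = torch.sin(x)
    return y, y

backend = EagerAndRecordGraphs()
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
opt_fn(torch.randn(4, 4))
# This is the normalized text the assertExpectedInline blocks are compared against.
print(normalize_gm(backend.graphs[0].print_readable(print_output=False)))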

View File

@ -1,11 +1,9 @@
# Owner(s): ["module: dynamo"]
import copy
import functools
import inspect
import os
import pickle
import unittest
from contextlib import contextmanager
from unittest.mock import patch
@ -15,16 +13,13 @@ import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.aot_compile import AOTCompiledModel, ModelInput, SerializableCallable
from torch._dynamo.aot_compile import ModelInput, SerializableCallable
from torch._dynamo.exc import PackageError, Unsupported
from torch._dynamo.package import DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.fx._graph_pickler import GraphPickler
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
TEST_CUDA,
)
from torch.testing._internal.common_utils import instantiate_parametrized_tests
MY_LAMBDA = lambda x: x + 1 # noqa: E731
@ -604,92 +599,6 @@ from user code:
actual = compiled_fn(*inputs)
self.assertEqual(expected, actual)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti(self):
with torch.device("cuda"):
from torch._dynamo.hooks import Hooks
def fn(x, y):
return x + y
def make_inputs():
return (torch.randn(3, 4), torch.randn(3, 4))
compiled_fn = torch._dynamo.aot_compile.aot_compile_fullgraph(
fn,
(make_inputs(), {}),
Hooks(),
torch._TorchCompileAOTInductorWrapper(None, None, None),
)
test_inputs = make_inputs()
expected = fn(*test_inputs)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti_module(self):
with torch.device("cuda"):
from torch._dynamo.hooks import Hooks
mod = SimpleLinearModule()
def make_inputs():
return (torch.randn(4, 3),)
compiled_mod = torch._dynamo.aot_compile.aot_compile_module(
mod,
[ModelInput(make_inputs(), {}, [])],
Hooks(),
torch._TorchCompileAOTInductorWrapper(None, None, None),
)
def get_grads(m: torch.nn.Module):
return {name: p.grad for name, p in m.named_parameters()}
original_mod = copy.deepcopy(mod)
test_inputs = make_inputs()
expected = mod(*test_inputs)
expected.sum().backward()
expected_grads = get_grads(mod)
actual = compiled_mod(*test_inputs)
self.assertEqual(expected, actual)
serialized = compiled_mod.serialize()
compiled_fn = AOTCompiledModel.deserialize(original_mod, serialized)
actual = compiled_fn(*test_inputs)
actual.sum().backward()
self.assertEqual(get_grads(original_mod), expected_grads)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti_torch_compile(self):
with torch.device("cuda"):
def fn(x, y):
return x + y
def make_inputs():
return (torch.randn(3, 4), torch.randn(3, 4))
compiled_fn = torch.compile(
fn, fullgraph=True, options={"use_aoti": True}
).aot_compile((make_inputs(), {}))
test_inputs = make_inputs()
expected = fn(*test_inputs)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(*test_inputs)
self.assertEqual(compiled_fn._artifacts.backend_name, "aotinductor")
self.assertEqual(expected, actual)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
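The tests removed above only covered the AOTInductor ("use_aoti") path; the plain save/load round trip they were built on is still exercised by the remaining tests. A sketch of that flow, assuming aot_compile, save_compiled_function and load_compiled_function keep the signatures shown above; the file path is a placeholder:

import torch

def fn(x, y):
    return x + y

example_inputs = (torch.randn(3, 4), torch.randn(3, 4))

# Ahead-of-time compile against example (args, kwargs), without AOTInductor options.
compiled_fn = torch.compile(fn, fullgraph=True).aot_compile((example_inputs, {}))

compiled_fn.save_compiled_function("/tmp/fn_aot.bin")  # placeholder path
with open("/tmp/fn_aot.bin", "rb") as f:
    loaded_fn = torch.compiler.load_compiled_function(f)

assert torch.equal(loaded_fn(*example_inputs), fn(*example_inputs))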

View File

@ -222,13 +222,13 @@ class GraphModule(torch.nn.Module):
matmul: "f32[3, 3]" = l_x_ @ l_y_
sin: "f32[3, 3]" = matmul.sin(); matmul = None
cos: "f32[3, 3]" = sin.cos(); sin = None
child: "f32[3, 3]" = sin.cos(); sin = None
add: "f32[3, 3]" = l_x_ + l_y_
sub: "f32[3, 3]" = l_x_ - l_y_
child_1: "f32[3, 3]" = l_x_ + l_y_
child_2: "f32[3, 3]" = l_x_ - l_y_
matmul_1: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
return (cos, add, sub, matmul_1)
child_3: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
return (child, child_1, child_2, child_3)
""", # noqa: B950
)
self.assertExpectedInline(

View File

@ -249,7 +249,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
# when testing with dynamic shape, symbols are lifted as input
arg_count = ifdynstaticdefault(2, 3)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 1)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count)
def test_return_captured_vars(self):
freevar1 = torch.randn(3)
@ -267,7 +267,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
# be the input.
# when testing with dynamic shape, a symbol is lifted as input
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 1)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 4)
def test_return_captured_var_used_multiple_times(self):
freevar = torch.randn(3)
@ -282,7 +282,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
x = torch.randn(3)
# when testing with dynamic shape, a symbol is lifted as input
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 2)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 3)
def test_capture_untracked_global(self):
def f(x):
@ -762,15 +762,15 @@ class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, u0, c); wrap_body_0 = s77 = l_x_ = u0 = c = None
getitem: "f32[s77]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
child: "f32[s77]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)
class wrap_body_0(torch.nn.Module):
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
sin: "f32[s77]" = l_x_.sin(); l_x_ = None
sin_1: "f32[u0, 1]" = c.sin(); c = None
return (sin, sin_1)
child: "f32[s77]" = l_x_.sin(); l_x_ = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
""",
)
else:
@ -801,15 +801,15 @@ class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, u0, c); wrap_body_0 = l_x_ = u0 = c = None
getitem: "f32[3]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
child: "f32[3]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
sin: "f32[3]" = l_x_.sin(); l_x_ = None
sin_1: "f32[u0, 1]" = c.sin(); c = None
return (sin, sin_1)
child: "f32[3]" = l_x_.sin(); l_x_ = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
""",
)
@ -922,16 +922,16 @@ class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, size, c); wrap_body_0 = l_x_ = size = c = None
getitem: "f32[3]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
child: "f32[3]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
sin: "f32[3]" = l_x_.sin(); l_x_ = None
add: "f32[3]" = sin + size; sin = size = None
sin_1: "f32[u0, 1]" = c.sin(); c = None
return (add, sin_1)
child: "f32[3]" = sin + size; sin = size = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
""",
)
@ -2458,10 +2458,10 @@ class GraphModule(torch.nn.Module):
class wrap_body_0(torch.nn.Module):
def forward(self, l_arg1_0_: "f32[3]", l_arg2_0_: "f32[3]"):
add: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None
child: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None
add_1: "f32[3]" = l_arg2_0_ + 1; l_arg2_0_ = None
return (add, add_1)
child_1: "f32[3]" = l_arg2_0_ + 1; l_arg2_0_ = None
return (child, child_1)
""",
)
@ -2655,9 +2655,9 @@ class GraphModule(torch.nn.Module):
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[2, 3]"):
sin: "f32[2, 3]" = l_x_.sin()
cos: "f32[2, 3]" = l_x_.cos(); l_x_ = None
return (sin, cos)
child: "f32[2, 3]" = l_x_.sin()
child_1: "f32[2, 3]" = l_x_.cos(); l_x_ = None
return (child, child_1)
""",
)
@ -2687,13 +2687,13 @@ class GraphModule(torch.nn.Module):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
getitem: "f32[3]" = wrap[0]; wrap = None
return (getitem,)
value: "f32[3]" = wrap[0]; wrap = None
return (value,)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]"):
neg: "f32[3]" = -l_x_; l_x_ = None
return (neg,)
child: "f32[3]" = -l_x_; l_x_ = None
return (child,)
""",
)
@ -3318,17 +3318,17 @@ class GraphModule(torch.nn.Module):
hints_wrapper_body_1 = self.hints_wrapper_body_1
hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_1, (x, l_y_), {}, hints = {'outer_body': True}); hints_wrapper_body_1 = x = l_y_ = None
getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
return (getitem,)
res: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
return (res,)
class hints_wrapper_body_1(torch.nn.Module):
def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
hints_wrapper_body_0 = self.hints_wrapper_body_0
hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_0, (x, l_y_), {}, hints = {'inner_body': True}); hints_wrapper_body_0 = x = l_y_ = None
getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
x_1: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
x_1: "f32[2, 4]" = torch.abs(getitem); getitem = None
return (x_1,)
x_2: "f32[2, 4]" = torch.abs(x_1); x_1 = None
return (x_2,)
class hints_wrapper_body_0(torch.nn.Module):
def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):

View File

@ -10,6 +10,10 @@ import torch.utils.checkpoint
from torch._dynamo.backends.common import aot_autograd
from torch._functorch._aot_autograd.autograd_cache import BundledCompiledForward
from torch._guards import detect_fake_mode
from torch._higher_order_ops.invoke_subgraph import (
NestedCompileBackend,
NestedCompileRegionOptions,
)
from torch._inductor.output_code import RegionalOutputCode
from torch._inductor.test_case import run_tests
from torch._inductor.utils import run_fw_bw_and_get_code
@ -468,6 +472,86 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
# flex in forward and flex_backward in backward
self.assertEqual(len(codes), 2)
@parametrize("serialize", [True, False])
def test_invoke_subgraph_regional_compile(self, serialize):
call_test_partitioner_ct = 0
original_default_partitioner = torch._functorch.partitioners.default_partition
def test_partitioner(
*args, **kwargs
) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
nonlocal call_test_partitioner_ct
call_test_partitioner_ct += 1
return original_default_partitioner(*args, **kwargs)
# pyrefly: ignore [not-iterable]
if serialize:
# Callable cannot be serialized
torch._functorch.partitioners.default_partition = test_partitioner
partitioner = "default_partition"
else:
partitioner = test_partitioner
backend = NestedCompileRegionOptions(
backend=NestedCompileBackend.INDUCTOR,
inductor_configs={
"max_autotune": True,
"triton.cudagraphs": False,
},
partitioner=partitioner,
)
@torch.compiler.nested_compile_region(backend_options=backend)
def gn_with_backend(x):
return torch.sin(x)
@torch.compiler.nested_compile_region
def gn_without_backend(x):
return torch.cos(x)
def fn(x):
return gn_with_backend(x) + gn_without_backend(x)
backend = aot_eager_regional_inductor(serialize=serialize)
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
import torch._inductor.config as inductor_config
# Hook to verify options
original_compile = torch._inductor.standalone_compile
captured_options = []
def verify_options(*args, **kwargs):
options = kwargs.get("options", {})
captured_options.append(options)
# Verify config is set as expected from explicit options
assert inductor_config.max_autotune, "max_autotune should be True"
assert not inductor_config.triton.cudagraphs, (
"triton.cudagraphs should be False"
)
return original_compile(*args, **kwargs)
torch._inductor.standalone_compile = verify_options
try:
x = torch.randn(8, 8, requires_grad=True)
# opt_fn(x)
res, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
self.assertEqual(len(codes), 2)
self.assertTrue("repeated_subgraph0" in codes[0])
self.assertTrue("repeated_subgraph1" not in codes[0])
self.assertTrue("repeated_subgraph0" in codes[1])
self.assertTrue("repeated_subgraph1" not in codes[1])
self.assertEqual(call_test_partitioner_ct, 1)
true_res = fn(x)
self.assertEqual(res, true_res)
finally:
torch._inductor.standalone_compile = original_compile
torch._functorch.partitioners.default_partition = (
original_default_partitioner
)
@skipIfTorchDynamo("Not a suitable dynamo wrapped test")
class TestRegionalOutputCode(torch._inductor.test_case.TestCase):
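Condensed from the new test above, this is the user-facing shape of the per-region backend options. A sketch only: whether the region really lowers through standalone Inductor depends on the outer backend, which the test drives via aot_eager_regional_inductor; gn and fn are stand-ins:

import torch
from torch._higher_order_ops.invoke_subgraph import (
    NestedCompileBackend,
    NestedCompileRegionOptions,
)

options = NestedCompileRegionOptions(
    backend=NestedCompileBackend.INDUCTOR,
    # Explicit Inductor configs for this region only.
    inductor_configs={"max_autotune": True, "triton.cudagraphs": False},
    # A string is used instead of a callable so the choice stays serializable.
    partitioner="default_partition",
)

@torch.compiler.nested_compile_region(backend_options=options)
def gn(x):
    return torch.sin(x)

def fn(x):
    return gn(x) + torch.cos(x)

opt_fn = torch.compile(fn, fullgraph=True)
opt_fn(torch.randn(8, 8, requires_grad=True))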

View File

@ -21,6 +21,10 @@ from torch._dynamo.testing import (
InductorAndRecordGraphs,
normalize_gm,
)
from torch._higher_order_ops.invoke_subgraph import (
NestedCompileBackend,
NestedCompileRegionOptions,
)
from torch._higher_order_ops.schema import find_hop_schema
from torch._inductor import config as inductor_config
from torch._inductor.pattern_matcher import (
@ -899,14 +903,14 @@ class GraphModule(torch.nn.Module):
class subgraph_0(torch.nn.Module):
def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"):
mul: "f32[8]" = torch.mul(l_x_, l_y_); l_x_ = l_y_ = None
mul_1: "f32[8]" = mul * 2; mul = None
return (mul_1,)
child: "f32[8]" = mul * 2; mul = None
return (child,)
class subgraph_1(torch.nn.Module):
def forward(self, a: "f32[8]", l_y_: "f32[8]"):
mul: "f32[8]" = torch.mul(a, l_y_); a = l_y_ = None
mul_1: "f32[8]" = mul * 3; mul = None
return (mul_1,)
child: "f32[8]" = mul * 3; mul = None
return (child,)
""",
)
@ -983,20 +987,20 @@ class GraphModule(torch.nn.Module):
subgraph_0 = self.subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = None
getitem: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
x: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
subgraph_1 = self.subgraph_0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', getitem, l_y_); subgraph_1 = getitem = None
getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', x, l_y_); subgraph_1 = x = None
x_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
subgraph_2 = self.subgraph_0
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_2, 'subgraph_0', getitem_1, l_y_); subgraph_2 = getitem_1 = None
getitem_2: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_2, 'subgraph_0', x_1, l_y_); subgraph_2 = x_1 = None
x_2: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
subgraph_3 = self.subgraph_0
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_3, 'subgraph_0', getitem_2, l_y_); subgraph_3 = getitem_2 = None
getitem_3: "f32[8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_3, 'subgraph_0', x_2, l_y_); subgraph_3 = x_2 = None
x_3: "f32[8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
subgraph_4 = self.subgraph_0
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_4, 'subgraph_0', getitem_3, l_y_); subgraph_4 = getitem_3 = l_y_ = None
getitem_4: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
return (getitem_4,)
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_4, 'subgraph_0', x_3, l_y_); subgraph_4 = x_3 = l_y_ = None
x_4: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
return (x_4,)
class subgraph_0(torch.nn.Module):
def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"):
@ -1495,9 +1499,9 @@ class GraphModule(torch.nn.Module):
class subgraph_0(torch.nn.Module):
def forward(self, l_x_: "f32[8, 8]"):
mul: "f32[8, 8]" = l_x_ * 2
mul_1: "f32[8, 8]" = l_x_ * 3; l_x_ = None
return (mul, mul_1)
child: "f32[8, 8]" = l_x_ * 2
child_1: "f32[8, 8]" = l_x_ * 3; l_x_ = None
return (child, child_1)
""",
)
@ -1556,6 +1560,101 @@ class GraphModule(torch.nn.Module):
res = opt_fn(x)
self.assertEqual(ref, res)
def test_unbacked_expr(self):
@nested_compile_region
def gn(x):
return x + 1
def fn(c):
d = torch.concat([c, c], dim=0)
d = gn(d)
return d
c = torch.randn((64, 32))
torch._dynamo.decorators.mark_unbacked(c, 0)
ref = fn(c)
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
res = opt_fn(c)
self.assertEqual(ref, res)
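Stripped of the test harness, the unbacked-expression case above boils down to the following sketch; the comment states the intent of mark_unbacked rather than guaranteed behavior:

import torch

@torch.compiler.nested_compile_region
def gn(x):
    return x + 1

def fn(c):
    d = torch.concat([c, c], dim=0)
    return gn(d)

c = torch.randn(64, 32)
# Treat dim 0 as an unbacked SymInt so the compiled code cannot specialize on 64.
torch._dynamo.decorators.mark_unbacked(c, 0)
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
assert torch.allclose(opt_fn(c), fn(c))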
def test_grad_accumulation(self):
mod1 = torch.nn.Linear(8, 8)
mod2 = torch.nn.Linear(8, 8)
mod3 = torch.nn.Linear(8, 8)
@nested_compile_region
def gn(x):
return mod1(x) - mod2(x)
def fn(c):
d = gn(c) - mod3(c)
return d * 2
c = torch.randn((8, 8), requires_grad=True)
backend = AotEagerAndRecordGraphs()
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
res = opt_fn(c)
res.sum().backward()
# fw_add_nodes = backend.fw_graphs[0].graph.find_nodes(op="call_function", target = torch.ops.aten.add.Tensor)
# The gradient addition node for mod3 is not in the subgraph.
bw_add_nodes = backend.bw_graphs[0].graph.find_nodes(
op="call_function", target=torch.ops.aten.add.Tensor
)
self.assertEqual(len(bw_add_nodes), 1)
subgraph_node = backend.bw_graphs[0].graph.find_nodes(op="get_attr")[0]
subgraph_name = subgraph_node.target
# The gradient addition node between mod1 and mod2 will be in the subgraph
bw_add_nodes = getattr(backend.bw_graphs[0], subgraph_name).graph.find_nodes(
op="call_function", target=torch.ops.aten.add.Tensor
)
self.assertEqual(len(bw_add_nodes), 1)
def test_backend_parameter(self):
backend = NestedCompileRegionOptions(NestedCompileBackend.INDUCTOR)
# Test that backend parameter is properly set in node.meta
@nested_compile_region(backend_options=backend)
def gn_with_backend(x):
return torch.sin(x)
@nested_compile_region
def gn_without_backend(x):
return torch.cos(x)
def fn(x):
return gn_with_backend(x) + gn_without_backend(x)
backend = EagerAndRecordGraphs()
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
x = torch.randn(8, 8, requires_grad=False)
opt_fn(x)
# Check that we captured the graph
self.assertEqual(len(backend.graphs), 1)
graph = backend.graphs[0]
# Find invoke_subgraph nodes and check their backend metadata
invoke_subgraph_nodes = [
node
for node in graph.graph.nodes
if node.op == "call_function"
and node.target == torch.ops.higher_order.invoke_subgraph
]
# We should have 2 invoke_subgraph calls
self.assertEqual(len(invoke_subgraph_nodes), 2)
# First invoke_subgraph (gn_with_backend) should have backend
self.assertIn("custom", invoke_subgraph_nodes[0].meta)
# Second invoke_subgraph (gn_without_backend) should have custom=None or no custom
backend_value = invoke_subgraph_nodes[1].meta.get("custom", None)
self.assertIsNone(backend_value)
def test_complex(self):
# Observed in Wan2.1
@nested_compile_region
@ -2504,107 +2603,6 @@ class GraphModule(torch.nn.Module):
self.assertEqual(f(x, other), f_compile(x, other))
self.assertTrue(called)
def test_udf_output(self):
class Foo:
def __init__(self, a, b):
self.a = a
self.b = b
@nested_compile_region
def gn(x, y):
a = torch.sin(x)
b = torch.cos(y)
return Foo(a, b)
def fn(x, y):
foo1 = gn(x, y)
foo2 = gn(foo1.a, y)
return foo1.b + foo2.a # + foo2.b
backend = AotEagerAndRecordGraphs()
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
x = torch.randn(8, 8, requires_grad=True)
y = torch.randn(8, 8, requires_grad=True)
x_clone = x.detach().clone().requires_grad_(True)
y_clone = y.detach().clone().requires_grad_(True)
ref = fn(x, y)
res = opt_fn(x_clone, y_clone)
ref.sum().backward()
res.sum().backward()
self.assertEqual(ref, res)
self.assertEqual(x.grad, x_clone.grad)
if not TEST_WITH_CROSSREF:
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[8, 8]", L_y_: "f32[8, 8]"):
l_x_ = L_x_
l_y_ = L_y_
subgraph_0 = self.subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = None
getitem: "f32[8, 8]" = invoke_subgraph[0]
getitem_1: "f32[8, 8]" = invoke_subgraph[1]; invoke_subgraph = None
subgraph_1 = self.subgraph_0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', getitem, l_y_); subgraph_1 = getitem = l_y_ = None
getitem_2: "f32[8, 8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
add: "f32[8, 8]" = getitem_1 + getitem_2; getitem_1 = getitem_2 = None
return (add,)
class subgraph_0(torch.nn.Module):
def forward(self, l_x_: "f32[8, 8]", l_y_: "f32[8, 8]"):
a: "f32[8, 8]" = torch.sin(l_x_); l_x_ = None
b: "f32[8, 8]" = torch.cos(l_y_); l_y_ = None
return (a, b)
""",
)
# High priority - grads are wrong
@unittest.expectedFailure
def test_grad_accuracy_check(self):
class Foo:
def __init__(self, a, b):
self.a = a
self.b = b
@nested_compile_region
def gn(x):
a = torch.sin(x)
b = torch.cos(x)
return (a, b)
def fn(x):
foo1 = gn(x)
foo2 = gn(foo1[0])
return foo1[1] + foo2[0] + foo2[1]
backend = AotEagerAndRecordGraphs()
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
x = torch.randn(8, 8, requires_grad=True)
x_clone = x.detach().clone().requires_grad_(True)
x.grad = None
x_clone.grad = None
ref = fn(x)
res = opt_fn(x_clone)
ref.sum().backward()
res.sum().backward()
self.assertEqual(ref, res)
self.assertEqual(x.grad, x_clone.grad)
@skipIfTorchDynamo("Not a torch._dynamo test")
@parameterized_class(

View File

@ -286,31 +286,47 @@ class GraphModule(torch.nn.Module):
l_self_modules_wo_parameters_weight_ = L_self_modules_wo_parameters_weight_
l_self_modules_w1_parameters_weight_ = L_self_modules_w1_parameters_weight_
l_self_modules_w2_parameters_weight_ = L_self_modules_w2_parameters_weight_
q: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wq_parameters_weight_, None); l_self_modules_wq_parameters_weight_ = None
k: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wk_parameters_weight_, None); l_self_modules_wk_parameters_weight_ = None
v: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wv_parameters_weight_, None); l_self_modules_wv_parameters_weight_ = None
unflatten: "f32[8, 16, 16, 6]" = q.unflatten(-1, (16, -1)); q = None
q_1: "f32[8, 16, 16, 6]" = unflatten.permute(0, 2, 1, 3); unflatten = None
unflatten_1: "f32[8, 16, 16, 6]" = k.unflatten(-1, (16, -1)); k = None
k_1: "f32[8, 16, 16, 6]" = unflatten_1.permute(0, 2, 1, 3); unflatten_1 = None
unflatten_2: "f32[8, 16, 16, 6]" = v.unflatten(-1, (16, -1)); v = None
v_1: "f32[8, 16, 16, 6]" = unflatten_2.permute(0, 2, 1, 3); unflatten_2 = None
subgraph_0 = self.subgraph_0
local_map_hop = torch.ops.higher_order.local_map_hop(subgraph_0, q_1, k_1, v_1); subgraph_0 = q_1 = k_1 = v_1 = None
getitem: "f32[8, 16, 16, 6]" = local_map_hop[0]; local_map_hop = None
permute_3: "f32[8, 16, 16, 6]" = getitem.permute(0, 2, 1, 3); getitem = None
o: "f32[8, 16, 96]" = permute_3.flatten(-2); permute_3 = None
o_1: "f32[8, 16, 96]" = torch._C._nn.linear(o, l_self_modules_wo_parameters_weight_, None); o = l_self_modules_wo_parameters_weight_ = None
o0: "f32[8, 16, 96]" = o_1 + l_x_; o_1 = l_x_ = None
o_2: "f32[8, 16, 384]" = torch._C._nn.linear(o0, l_self_modules_w1_parameters_weight_, None); l_self_modules_w1_parameters_weight_ = None
o_3: "f32[8, 16, 384]" = torch.nn.functional.relu(o_2); o_2 = None
o_4: "f32[8, 16, 96]" = torch._C._nn.linear(o_3, l_self_modules_w2_parameters_weight_, None); o_3 = l_self_modules_w2_parameters_weight_ = None
o_5: "f32[8, 16, 96]" = o0 + o_4; o0 = o_4 = None
return (o_5,)
o: "f32[8, 16, 16, 6]" = local_map_hop[0]; local_map_hop = None
permute_3: "f32[8, 16, 16, 6]" = o.permute(0, 2, 1, 3); o = None
o_1: "f32[8, 16, 96]" = permute_3.flatten(-2); permute_3 = None
o_2: "f32[8, 16, 96]" = torch._C._nn.linear(o_1, l_self_modules_wo_parameters_weight_, None); o_1 = l_self_modules_wo_parameters_weight_ = None
o0: "f32[8, 16, 96]" = o_2 + l_x_; o_2 = l_x_ = None
o_3: "f32[8, 16, 384]" = torch._C._nn.linear(o0, l_self_modules_w1_parameters_weight_, None); l_self_modules_w1_parameters_weight_ = None
o_4: "f32[8, 16, 384]" = torch.nn.functional.relu(o_3); o_3 = None
o_5: "f32[8, 16, 96]" = torch._C._nn.linear(o_4, l_self_modules_w2_parameters_weight_, None); o_4 = l_self_modules_w2_parameters_weight_ = None
o_6: "f32[8, 16, 96]" = o0 + o_5; o0 = o_5 = None
return (o_6,)
class subgraph_0(torch.nn.Module):
def forward(self, q_1: "f32[1, 2, 4, 6]", k_1: "f32[1, 2, 16, 6]", v_1: "f32[1, 2, 16, 6]"):
out: "f32[1, 2, 4, 6]" = torch._C._nn.scaled_dot_product_attention(query = q_1, key = k_1, value = v_1, is_causal = False); q_1 = k_1 = v_1 = None
return (out,)""",
return (out,)
""",
ignore_empty_lines=True,
)

View File

@ -18,16 +18,15 @@ from functorch.compile import (
nop,
)
from torch._functorch.aot_autograd import aot_export_module
from torch._higher_order_ops.effects import (
_EffectType,
_get_effect,
_register_effectful_op,
with_effects,
)
from torch._higher_order_ops.effects import with_effects
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import SM70OrLater, SM80OrLater
from torch.testing._internal.common_cuda import (
_get_torch_cuda_version,
SM70OrLater,
SM80OrLater,
)
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
from torch.testing._internal.common_utils import (
IS_WINDOWS,
@ -301,6 +300,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
@unittest.skipIf(IS_WINDOWS, "triton")
@unittest.skipIf(TEST_WITH_ROCM, "triton")
@unittest.skipIf(not SM80OrLater, "triton")
@unittest.skipIf(_get_torch_cuda_version() >= (11, 7), "triton")
@unittest.skipIf(not TEST_CUDA, "triton")
@skipIfNoDynamoSupport
def test_register_effectful_custom_op(self):
@ -308,23 +308,41 @@ def forward(self, arg0_1, arg1_1, arg2_1):
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch.library.define(
"mylib::record_scalar_tensor",
"(Tensor x, str prefix) -> ()",
lib=lib,
)
# global variable to store the recorded tensor and prefix.
recorded_dict = {}
# Pytorch custom op implementation
@torch.library.custom_op("mylib::record_scalar_tensor", mutates_args=())
def record_scalar_tensor(x: torch.Tensor, prefix: str) -> None:
# Pytorch custorm op implementation
@torch.library.impl(
"mylib::record_scalar_tensor",
"CompositeExplicitAutograd",
lib=lib,
)
def record_scalar_tensor(x, prefix):
recorded_dict[prefix] = x.clone()
return
# Meta function of the custom op
@record_scalar_tensor.register_fake
@torch.library.register_fake(
"mylib::record_scalar_tensor",
lib=lib,
)
def record_scalar_tensor_meta(x, prefix):
return
record_scalar_tensor.register_effect(_EffectType.ORDERED)
from torch._higher_order_ops.effects import (
_EffectType,
_register_effectful_op,
)
self.assertEqual(_get_effect(record_scalar_tensor), _EffectType.ORDERED)
_register_effectful_op(
torch.ops.mylib.record_scalar_tensor.default, _EffectType.ORDERED
)
my_config = {}
my_config["MockModule"] = "mean"
@ -451,13 +469,14 @@ def forward(self, arg0_1, arg1_1, arg2_1):
torch.library.register_autograd("_mylib::zoo", foo_bwd, lib=lib)
torch.library._register_effectful_op(
torch.ops._mylib.zoo.default, _EffectType.ORDERED
)
torch.library._register_effectful_op(
torch.ops._mylib.zoo2.default, _EffectType.ORDERED
from torch._higher_order_ops.effects import (
_EffectType,
_register_effectful_op,
)
_register_effectful_op(torch.ops._mylib.zoo.default, _EffectType.ORDERED)
_register_effectful_op(torch.ops._mylib.zoo2.default, _EffectType.ORDERED)
def fn(x, y):
return torch.ops._mylib.zoo(x) + y
@ -668,13 +687,13 @@ def forward(self, arg0_1, arg1_1):
torch.library.register_autograd("_mylib::foo", foo_bwd, lib=lib)
handle = _register_effectful_op(
torch.ops._mylib.foo.default, _EffectType.ORDERED
)
self.assertEqual(
_get_effect(torch.ops._mylib.foo.default), _EffectType.ORDERED
from torch._higher_order_ops.effects import (
_deregister_effectful_op,
_EffectType,
_register_effectful_op,
)
_register_effectful_op(torch.ops._mylib.foo.default, _EffectType.ORDERED)
try:
def fn(x, y):
@ -760,13 +779,17 @@ def forward(self, tangents_1, tangents_2, tangents_token):
else:
raise NotImplementedError
finally:
handle.destroy()
self.assertEqual(_get_effect(torch.ops._mylib.foo.default), None)
_deregister_effectful_op(torch.ops._mylib.foo.default)
@skipIfNoDynamoSupport
def test_regular_effectful_op_only_in_backward(self):
handle = _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
from torch._higher_order_ops.effects import (
_deregister_effectful_op,
_EffectType,
_register_effectful_op,
)
_register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
try:
def fn(x):
@ -829,11 +852,17 @@ def forward(self, primals_1, primals_2, tangents_1, tangents_2, tangents_token):
return (mul, mul_1, getitem_2)""",
)
finally:
handle.destroy()
_deregister_effectful_op(torch.ops.aten.cos.default)
@skipIfNoDynamoSupport
def test_regular_effectful_op_in_forward_and_backward(self):
handle = _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
from torch._higher_order_ops.effects import (
_deregister_effectful_op,
_EffectType,
_register_effectful_op,
)
_register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
try:
def fn(x):
@ -868,7 +897,7 @@ def forward(self, primals_2, getitem_1, tangents_1, tangents_token):
return (mul_1, getitem_2)""",
)
finally:
handle.destroy()
_deregister_effectful_op(torch.ops.aten.cos.default)
if __name__ == "__main__":
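With the handle-based API gone in this version, the updated tests settle on explicit register/deregister calls wrapped in try/finally. A sketch of that shape, mirroring the tests above; aten.cos and the aot_eager backend are just convenient choices:

import torch
from torch._higher_order_ops.effects import (
    _deregister_effectful_op,
    _EffectType,
    _register_effectful_op,
)

# Mark aten.cos as ORDERED so tracing threads an effect token through it.
_register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
try:
    def fn(x):
        return torch.cos(x) * 2

    torch.compile(fn, backend="aot_eager", fullgraph=True)(torch.randn(4))
finally:
    # Always undo the registration so unrelated code is not affected.
    _deregister_effectful_op(torch.ops.aten.cos.default)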

View File

@ -136,59 +136,12 @@ class TestStandaloneInductor(TestCase):
mod_opt = inductor.compile(mod, inp)
self.assertEqual(mod(*inp), mod_opt(*inp))
@mock.patch.dict(os.environ, {"TORCHINDUCTOR_DEBUG_COMPILE": "1"})
def test_inductor_generate_debug_compile(self):
cpp_code = """
int main(){
return 0;
}
"""
_, source_path = write(
cpp_code,
"cpp",
)
build_option = CppOptions()
cpp_builder = CppBuilder(
name="test_compile",
sources=source_path,
output_dir=os.path.dirname(source_path),
BuildOption=build_option,
)
cpp_builder.build()
binary_path = cpp_builder.get_target_file_path()
"""
When we turn on generate debug compile.
On Windows, it should create a [module_name].pdb file. It helps debug by WinDBG.
On Linux, it should create some debug sections in binary file.
"""
def check_linux_debug_section(module_path: str):
check_cmd = shlex.split(f"readelf -S {module_path}")
output = safe_command_output(check_cmd)
has_debug_sym = ".debug_info" in output
self.assertEqual(has_debug_sym, True)
def check_windows_pdb_exist(module_path: str):
file_name_no_ext = os.path.splitext(module_path)[0]
file_name_pdb = f"{file_name_no_ext}.pdb"
has_pdb_file = os.path.exists(file_name_pdb)
self.assertEqual(has_pdb_file, True)
if _IS_WINDOWS:
check_windows_pdb_exist(binary_path)
elif _IS_MACOS:
pass # macOS: unclear whether this should work.
else:
check_linux_debug_section(binary_path)
@mock.patch.dict(os.environ, {"TORCHINDUCTOR_DEBUG_SYMBOL": "1"})
def test_inductor_generate_debug_symbol(self):
cpp_code = """
int main(){
return 0;
}
int main(){
return 0;
}
"""
_, source_path = write(

View File

@ -4902,21 +4902,6 @@ class CommonTemplate:
(torch.randn(2, 4, 6, 6),),
)
@skip_if_gpu_halide
@xfail_if_mps
@config.patch(combo_kernels=True)
def test_combo_kernel_cpu(self):
def fn(x):
return aten._adaptive_avg_pool2d(x, (6, 6)), aten._adaptive_avg_pool2d(
x + 1, (2, 5)
)
self.common(
fn,
(torch.randn(2, 4, 16, 16),),
check_lowp=False,
)
@xfail_if_mps # Non-divisible input sizes are not implemented on MPS device
def test_adaptive_avg_pool2d2(self):
# Big kernel size, use fallback

View File

@ -90,7 +90,7 @@ class TestOpaqueObject(TestCase):
# This is not accurate since the queue could have tensors that are
# not rank 1
ctx = torch._custom_op.impl.get_ctx()
u0 = ctx.new_dynamic_size()
u0 = ctx.create_unbacked_symint()
return torch.empty(u0)
self.lib._register_fake("queue_pop", pop_impl_fake)
@ -107,7 +107,8 @@ class TestOpaqueObject(TestCase):
@size_impl.register_fake
def size_impl_fake(q: torch._C.ScriptObject) -> int:
ctx = torch._custom_op.impl.get_ctx()
u0 = ctx.new_dynamic_size()
u0 = ctx.create_unbacked_symint()
torch._check_is_size(u0)
return u0
super().setUp()
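For context, the fake implementations above follow this pattern: the fake kernel cannot see the real queue, so it mints a fresh unbacked SymInt for every data-dependent size. A condensed sketch; get_ctx() is only meaningful while a fake kernel is running, and registration happens via the test's self.lib._register_fake as shown above:

import torch

def queue_pop_fake(q: torch._C.ScriptObject) -> torch.Tensor:
    ctx = torch._custom_op.impl.get_ctx()
    u0 = ctx.create_unbacked_symint()  # unknown, data-dependent length
    return torch.empty(u0)

def queue_size_fake(q: torch._C.ScriptObject) -> int:
    ctx = torch._custom_op.impl.get_ctx()
    u0 = ctx.create_unbacked_symint()
    torch._check_is_size(u0)  # constrain the unbacked value to be a valid size
    return u0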

View File

@ -1,22 +1,12 @@
# Owner(s): ["module: custom-operators"]
import random
from contextlib import ExitStack
import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import AotEagerAndRecordGraphs
from torch._functorch.aot_autograd import (
aot_compile_joint_with_descriptors,
aot_export_joint_with_descriptors,
aot_export_module,
)
from torch._library.effects import EffectType
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.opaque_object import register_opaque_type
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -51,21 +41,11 @@ class OpaqueQueue:
class RNGState:
def __init__(self, seed):
self.seed = seed
self.rng = random.Random(self.seed)
class Counter:
def __init__(self, start):
self.counter = torch.tensor(start)
def increment_counter(self):
self.counter += 1
self.rng = random.Random(seed)
register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
register_opaque_type(RNGState, "_TestOpaqueObject_RNGState")
register_opaque_type(Counter, "_TestOpaqueObject_Counter")
class TestOpaqueObject(TestCase):
@ -145,20 +125,6 @@ class TestOpaqueObject(TestCase):
def noisy_inject_fake(x: torch.Tensor, obj: RNGState) -> torch.Tensor:
return torch.empty_like(x)
@torch.library.custom_op(
"_TestOpaqueObject::increment_counter",
mutates_args=["prev"],
)
def increment_counter_impl(c: Counter, prev: torch.Tensor) -> torch.Tensor:
assert isinstance(c, Counter)
prev.copy_(c.counter)
c.increment_counter()
return c.counter
@increment_counter_impl.register_fake
def increment_counter_fake(c: Counter, prev: torch.Tensor) -> torch.Tensor:
return torch.empty_like(prev)
super().setUp()
def tearDown(self):
@ -267,235 +233,6 @@ def forward(self, arg0_1, arg1_1):
):
make_fx(f, tracing_mode=make_fx_tracing_mode)(RNGState(0), torch.ones(3))
def test_aot_export(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, rng_state, x):
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
x = x * x
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
x = x + x
return (x,)
mod = Model()
rng = RNGState(0)
x = torch.ones(2, 3)
fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
fake_rng = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, rng)
fake_x = fake_mode.from_tensor(x)
gm = aot_export_module(mod, (fake_rng, fake_x), trace_joint=False)[0]
# By default we don't register ops containing PyObjs as being effectful
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, arg0_1, arg1_1):
noisy_inject = torch.ops._TestOpaqueObject.noisy_inject.default(arg1_1, arg0_1); arg1_1 = None
mul = torch.ops.aten.mul.Tensor(noisy_inject, noisy_inject); noisy_inject = None
noisy_inject_1 = torch.ops._TestOpaqueObject.noisy_inject.default(mul, arg0_1); mul = arg0_1 = None
add = torch.ops.aten.add.Tensor(noisy_inject_1, noisy_inject_1); noisy_inject_1 = None
return (add,)""", # noqa: B950
)
torch.library._register_effectful_op(
"_TestOpaqueObject::noisy_inject", EffectType.ORDERED
)
try:
gm = aot_export_module(mod, (rng, fake_x), trace_joint=False)[0]
# inputs: token, rng, x
# return: token, res
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1):
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops._TestOpaqueObject.noisy_inject.default, arg2_1, arg1_1); arg0_1 = arg2_1 = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
mul = torch.ops.aten.mul.Tensor(getitem_1, getitem_1); getitem_1 = None
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TestOpaqueObject.noisy_inject.default, mul, arg1_1); getitem = mul = arg1_1 = None
getitem_2 = with_effects_1[0]
getitem_3 = with_effects_1[1]; with_effects_1 = None
add = torch.ops.aten.add.Tensor(getitem_3, getitem_3); getitem_3 = None
return (getitem_2, add)""", # noqa: B950
)
finally:
torch.library._register_effectful_op(
"_TestOpaqueObject::noisy_inject", None
)
def test_compile(self):
def foo(rng_state, x):
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
x = x * x
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
x = x + x
return x
rng = RNGState(0)
x = torch.ones(2, 3)
res = torch.compile(foo, fullgraph=True, backend="inductor")(rng, x)
self.assertFalse(torch.allclose(res, x * x + x))
backend = AotEagerAndRecordGraphs()
torch.compile(foo, fullgraph=True, backend=backend)(rng, x)
self.assertExpectedInline(
backend.graphs[0].code.strip(),
"""\
def forward(self, L_x_ : torch.Tensor, L_rng_state_ : __main___RNGState):
l_x_ = L_x_
l_rng_state_ = L_rng_state_
x = torch.ops._TestOpaqueObject.noisy_inject(l_x_, l_rng_state_); l_x_ = None
x_1 = x * x; x = None
x_2 = torch.ops._TestOpaqueObject.noisy_inject(x_1, l_rng_state_); x_1 = l_rng_state_ = None
x_3 = x_2 + x_2; x_2 = None
return (x_3,)""", # noqa: B950
)
self.assertExpectedInline(
backend.fw_graphs[0].code.strip(),
"""\
def forward(self, arg0_1, arg1_1):
noisy_inject = torch.ops._TestOpaqueObject.noisy_inject.default(arg0_1, arg1_1); arg0_1 = None
mul = torch.ops.aten.mul.Tensor(noisy_inject, noisy_inject); noisy_inject = None
noisy_inject_1 = torch.ops._TestOpaqueObject.noisy_inject.default(mul, arg1_1); mul = arg1_1 = None
add = torch.ops.aten.add.Tensor(noisy_inject_1, noisy_inject_1); noisy_inject_1 = None
return (add,)""", # noqa: B950
)
def test_compile_intermediate(self):
counter = Counter(0)
def foo(x, y):
z = torch.ops._TestOpaqueObject.increment_counter(counter, y)
x = x * z
z = torch.ops._TestOpaqueObject.increment_counter(counter, y)
x = x + z
return x, counter
inp = (torch.tensor(1), torch.tensor(0))
backend = AotEagerAndRecordGraphs()
opt_f = torch.compile(foo, fullgraph=True, backend=backend)
res = opt_f(*inp)
self.assertEqual(res[0], torch.tensor(3))
self.assertEqual(res[1].counter, torch.tensor(2))
res = opt_f(*inp)
self.assertEqual(res[0], torch.tensor(7))
self.assertEqual(res[1].counter, torch.tensor(4))
# counter is automatically lifted as an input
# Even though we returned counter in the eager code, it does not get
# returned in the graph because dynamo does not detect that the object
# is mutated.
self.assertExpectedInline(
backend.fw_graphs[0].code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1):
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops._TestOpaqueObject.increment_counter.default, c = arg1_1, _prev_base_index = 0, _all_bases = [arg0_1])
getitem = auto_functionalized_v2[0]
getitem_1 = auto_functionalized_v2[1]; auto_functionalized_v2 = None
mul = torch.ops.aten.mul.Tensor(arg2_1, getitem); arg2_1 = getitem = None
auto_functionalized_v2_1 = torch.ops.higher_order.auto_functionalized_v2(torch.ops._TestOpaqueObject.increment_counter.default, c = arg1_1, _prev_base_index = 0, _all_bases = [getitem_1]); arg1_1 = getitem_1 = None
getitem_2 = auto_functionalized_v2_1[0]
getitem_3 = auto_functionalized_v2_1[1]; auto_functionalized_v2_1 = None
add = torch.ops.aten.add.Tensor(mul, getitem_2); mul = getitem_2 = None
copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_3); arg0_1 = getitem_3 = copy_ = None
return (add,)""", # noqa: B950
)
def test_compile_attribute(self):
counter = Counter(0)
def foo(counter, x):
x = x * x
counter.increment_counter()
return x
with self.assertRaisesRegex(
RuntimeError, "Attempted to access attributes/methods on an OpaqueObject"
):
torch.compile(foo)(counter, torch.ones(2, 3))
def bar(counter, x):
x = x * x
x += counter.counter
return x
with self.assertRaisesRegex(
RuntimeError, "Attempted to access attributes/methods on an OpaqueObject"
):
torch.compile(bar)(counter, torch.ones(2, 3))
def test_export_joint(self):
class Moo(torch.nn.Module):
def forward(self, x, y):
return x * y
register_opaque_type(Moo, "_TestOpaqueObject_Moo")
torch.library.define(
"_TestOpaqueObject::module_mul",
"(_TestOpaqueObject_Moo a, Tensor b, SymInt c) -> Tensor",
tags=torch.Tag.pt2_compliant_tag,
lib=self.lib,
)
@torch.library.impl(
"_TestOpaqueObject::module_mul", "CompositeExplicitAutograd", lib=self.lib
)
def module_mul_impl(m: Moo, a: torch.Tensor, b: int) -> torch.Tensor:
assert isinstance(m, Moo)
return m(a, b)
@torch.library.register_fake("_TestOpaqueObject::module_mul", lib=self.lib)
def module_mul_fake(m: Moo, a: torch.Tensor, b: int) -> torch.Tensor:
return torch.empty_like(a)
def module_mul_setup_context(ctx, inputs, output):
m, a, b = inputs
ctx.b = b
def module_mul_backward(ctx, grad) -> torch.Tensor:
return None, grad * ctx.b, None
torch.library.register_autograd(
"_TestOpaqueObject::module_mul",
module_mul_backward,
setup_context=module_mul_setup_context,
lib=self.lib,
)
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.moo = Moo()
def forward(self, x, y):
b = y.item()
return torch.ops._TestOpaqueObject.module_mul(self.moo, x, b)
inp = (torch.randn(3, requires_grad=True), torch.tensor(4))
with ExitStack() as stack:
with FakeTensorMode(shape_env=ShapeEnv()):
joint = aot_export_joint_with_descriptors(stack, M(), inp)
self.assertExpectedInline(
joint.graph_module.code.strip(),
"""\
def forward(self, primals, tangents):
primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
_local_scalar_dense = torch.ops.aten._local_scalar_dense.default(primals_2); primals_2 = None
_opaque_obj0 = self._opaque_obj0
module_mul = torch.ops._TestOpaqueObject.module_mul.default(_opaque_obj0, primals_1, _local_scalar_dense); _opaque_obj0 = primals_1 = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, _local_scalar_dense); tangents_1 = _local_scalar_dense = None
return pytree.tree_unflatten([module_mul, mul_1, None], self._out_spec)""", # noqa: B950
)
compiled_fn = aot_compile_joint_with_descriptors(joint)
self.assertEqual(compiled_fn(*inp), M()(*inp))
instantiate_parametrized_tests(TestOpaqueObject)
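What this file keeps is the custom-op route for opaque objects; the attribute-access and torch.compile paths above are dropped. A sketch of the remaining pattern with hypothetical names (the _Sketch namespace and this noisy_inject analogue are illustrative, not part of the test suite):

import random

import torch
from torch._library.opaque_object import register_opaque_type

class RNGState:
    def __init__(self, seed):
        self.rng = random.Random(seed)

# Register the Python type as opaque so custom-op schemas can accept it.
register_opaque_type(RNGState, "_Sketch_RNGState")

@torch.library.custom_op("_Sketch::noisy_inject", mutates_args=())
def noisy_inject(x: torch.Tensor, obj: RNGState) -> torch.Tensor:
    # The real implementation may freely use the opaque object's state.
    return x + obj.rng.random()

@noisy_inject.register_fake
def _(x: torch.Tensor, obj: RNGState) -> torch.Tensor:
    # Under tracing the object is opaque, so only shape/dtype can be produced.
    return torch.empty_like(x)

out = noisy_inject(torch.ones(2, 3), RNGState(0))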

View File

@ -796,27 +796,6 @@ def forward(self, x_1):
self._test(f, [torch.randn(1, 10), torch.zeros(1, dtype=torch.long)])
@unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
def test_T244632748(self):
class TestModule(torch.nn.Module):
def forward(self, x):
return x + (x.shape[0] * 2)
mod = TestModule()
sample = torch.randn((5, 5)).to("cuda")
dim0 = torch.export.Dim.DYNAMIC(max=100)
dynamic_shapes = {"x": (dim0, torch.export.Dim.STATIC)}
ep = torch.export.export(mod, (sample,), dynamic_shapes=dynamic_shapes)
gm = ep.module()
symint = list(gm.graph.nodes)[3].meta["val"]
list(gm.graph.nodes)[3].replace_all_uses_with(symint)
gm.graph.eliminate_dead_code()
inductor_fx = torch._inductor.aot_compile(
gm, (sample,), options={"fx_wrapper": True, "compile_threads": 1}
)
class TestGenericProxyTensorReal(TestGenericProxyTensor):
tracing_mode = "real"

View File

@ -2439,35 +2439,6 @@ class _TorchCompileInductorWrapper:
reset_cudagraph_trees()
class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper):
compiler_name = "aotinductor"
def __init__(self, mode, options, dynamic):
super().__init__(mode, options, dynamic)
self.apply_options({"cpp_wrapper": True})
self.apply_options({"aot_inductor.package": True})
def __call__(self, model_, inputs_):
from contextlib import nullcontext
from unittest import mock
from torch._guards import detect_fake_mode
from torch._inductor.virtualized import V
fake_mode = detect_fake_mode(inputs_)
ctx = (
mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
if fake_mode
else nullcontext()
)
with (
V.set_aot_compilation(True),
ctx,
torch._inductor.config.patch("enable_autograd_for_aot", True),
):
return super().__call__(model_, inputs_)
class _TorchCompileWrapper:
def __init__(self, backend, mode, options, dynamic):
from torch._dynamo.backends.registry import lookup_backend
@ -2701,10 +2672,8 @@ def compile(
backend = bisect_backend
guard_filter_fn = None
use_aoti = False
if options and isinstance(options, dict):
guard_filter_fn = options.pop("guard_filter_fn", None)
use_aoti = options.pop("use_aoti", False)
if torch.compiler.is_exporting():
warnings.warn(
@ -2731,10 +2700,7 @@ def compile(
return export_wrapped_fn
if backend == "inductor":
if use_aoti:
backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic)
else:
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
else:
backend = _TorchCompileWrapper(backend, mode, options, dynamic)

View File

@ -53,7 +53,6 @@ class CompileArtifacts:
argdefs: Optional[tuple[Any, ...]]
source_info: "SourceInfo"
device_type: str
backend_name: str
system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)
def check_compatibility(self) -> None:
@ -274,7 +273,6 @@ def aot_compile_fullgraph(
argdefs=fn.__defaults__,
source_info=source_info,
device_type=device_type,
backend_name=getattr(backend, "compiler_name", "unknown"),
)
aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)

View File

@ -3657,15 +3657,5 @@
"Explanation": "Encountered triton kernel unsupported feature: {msg}",
"Hints": []
}
],
"GB0362": [
{
"Gb_type": "Attempted to access attributes/methods on an OpaqueObject",
"Context": "value={self.value}, attr={name}",
"Explanation": "Attribute/method access of OpaqueObjects is not supported.",
"Hints": [
"Use custom operators instead of direct attribute/method access."
]
}
]
}

View File

@ -56,7 +56,6 @@ from torch._guards import (
tracing,
TracingContext,
)
from torch._library.opaque_object import is_opaque_type
from torch._subclasses.fake_tensor import FakeTensor
from torch._utils_internal import signpost_event
from torch.export.dynamic_shapes import _ConstraintTarget
@ -2606,8 +2605,6 @@ class OutputGraph(OutputGraphCommon):
fake_attr_val,
)
continue
if is_opaque_type(type(node.meta["grapharg"].example)):
continue
fake = (
arg.fake_tensor if arg.fake_tensor is not None else arg.example
)

View File

@ -58,7 +58,6 @@ from torch._dynamo.utils import (
from torch._guards import TracingContext
from torch._higher_order_ops.flat_apply import flat_apply
from torch._higher_order_ops.torchbind import call_torchbind
from torch._library.opaque_object import is_opaque_type
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
from torch._subclasses.meta_utils import is_sparse_any, safe_grad
@ -1453,32 +1452,27 @@ class VariableBuilder:
source=self.source,
)
if is_opaque_type(type(value)):
self.install_guards(GuardBuilder.TYPE_MATCH)
elif not hasattr(value, "__obj_flatten__"):
# This exists to allow a smoother transition.
# The implications are:
# The script objects won't be tracked as proxies.
# Methods on these objects won't show up in the graph.
# The original script object might be mutated.
# This exists to allow a smoother transition.
# The implications are:
# The script objects won't be tracked as proxies.
# Methods on these objects won't show up in the graph.
# The original script object might be mutated.
if not hasattr(value, "__obj_flatten__"):
return self.wrap_user_defined(value)
else:
# Install the guards on the fully qualified name of the script object
LazyVariableTracker.realize_all(
VariableBuilder(
self.tx, ScriptObjectQualifiedNameSource(self.source)
)(
value._type().qualified_name() # type: ignore[attr-defined]
)
# Install the guards on the fully qualified name of the script object
LazyVariableTracker.realize_all(
VariableBuilder(self.tx, ScriptObjectQualifiedNameSource(self.source))(
value._type().qualified_name() # type: ignore[attr-defined]
)
# Install the guards on the content of the script object by setting the source
# to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
LazyVariableTracker.realize_all(
VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
value.__obj_flatten__()
)
)
# Install the guards on the content of the script object by setting the source
# to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
LazyVariableTracker.realize_all(
VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
value.__obj_flatten__()
)
)
fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj(
self.tx.output.fake_mode, value

File diff suppressed because it is too large

View File

@ -25,7 +25,6 @@ from typing_extensions import ParamSpec
import torch
from torch._guards import Source
from torch._library.opaque_object import is_opaque_type, OpaqueTypeStr
from torch.fx.proxy import Proxy
from .. import graph_break_hints
@ -62,7 +61,7 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
@classmethod
def is_matching_cls(cls, user_cls: type) -> bool:
return issubclass(user_cls, torch.ScriptObject) or is_opaque_type(user_cls)
return issubclass(user_cls, torch.ScriptObject)
@staticmethod
def create(proxy: Proxy, value: Any, **options: Any) -> "TorchScriptObjectVariable":
@ -81,16 +80,6 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
"Dynamo cannot safely trace script object due to graph break."
)
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if getattr(self.value, "script_class_name", "") == OpaqueTypeStr:
unimplemented(
gb_type="Attempted to access attributes/methods on an OpaqueObject",
context=f"value={self.value}, attr={name}",
explanation="Attribute/method access of OpaqueObjects is not supported.",
hints=[
"Use custom operators instead of direct attribute/method access.",
],
)
from torch._higher_order_ops.torchbind import call_torchbind
from ..source import AttrSource

View File

@ -24,7 +24,6 @@ from torch._export.passes.lift_constants_pass import ConstantAttrMap
from torch._export.utils import _fakify_params_buffers
from torch._guards import Source
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.opaque_object import is_opaque_type
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.export import Constraint
from torch.export.dynamic_shapes import (
@ -947,9 +946,7 @@ def _fakify_script_objects(
try:
for obj, fqns in constant_attrs.items():
if torch._library.fake_class_registry._is_script_object(
obj
) or is_opaque_type(obj):
if torch._library.fake_class_registry._is_script_object(obj):
fake_script_obj = _maybe_fakify_obj(obj)
for fqn in fqns:
cur_mod, attr = _leaf_mod_and_attr(mod, fqn)

View File

@ -511,7 +511,6 @@ class GenericAOTAutogradResult(Generic[TForward, TBackward]):
).post_compile(
compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
)
compiled_fw_func._boxed_call = True
disable_amp = torch._C._is_any_autocast_enabled()
if needs_autograd:

View File

@ -8,7 +8,6 @@ from typing import Any, Optional
import torch
import torch.utils._pytree as pytree
from torch._guards import detect_fake_mode
from torch._library.opaque_object import is_opaque_type
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info
from torch.fx.experimental.symbolic_shapes import ShapeEnv
@ -47,7 +46,7 @@ def process_inputs(
hint=x,
source=source,
)
if isinstance(x, torch.ScriptObject) or is_opaque_type(type(x)):
if isinstance(x, torch.ScriptObject):
return torch._library.fake_class_registry.maybe_to_fake_obj(
fake_mode, x
)

View File

@ -779,13 +779,49 @@ def run_joint_graph_passes_on_hops(
# TODO: invoke_subgraph should track which of its inputs are static indices
# so it can propagate them to the partitioner (and use in cudagraphs)
static_lifetime_input_indices: list[int] = []
partition_fn: Callable[
..., tuple[torch.fx.GraphModule, torch.fx.GraphModule]
] = aot_config.partition_fn
used_hop_custom_partition = False
# Use hop specific partitioner_fn
if (
fw_hop_node.target == torch._higher_order_ops.invoke_subgraph
and "custom" in fw_hop_node.meta
and "partitioner" in fw_hop_node.meta["custom"]
):
hop_partition_fn = fw_hop_node.meta["custom"]["partitioner"]
if callable(hop_partition_fn):
partition_fn = hop_partition_fn # pyrefly: ignore [bad-assignment]
used_hop_custom_partition = True
else:
assert isinstance(hop_partition_fn, str)
match hop_partition_fn:
case "default_partition":
partition_fn = torch._functorch.partitioners.default_partition
case "min_cut_rematerialization_partition":
partition_fn = torch._functorch.partitioners.min_cut_rematerialization_partition
case _:
raise ValueError(
f"Unknown HOP partitioner config: {hop_partition_fn}"
)
# Step 2) and 3) - Run joint graph passes and partitioner
new_fw_hop_gm, new_bw_hop_gm = aot_config.partition_fn(
joint_hop_gm,
[],
num_fwd_outputs=num_fw_outputs,
static_lifetime_input_indices=static_lifetime_input_indices,
)
try:
new_fw_hop_gm, new_bw_hop_gm = partition_fn(
joint_hop_gm,
[],
num_fwd_outputs=num_fw_outputs,
static_lifetime_input_indices=static_lifetime_input_indices,
)
except Exception as e:
if used_hop_custom_partition:
raise RuntimeError(
f"Error in custom partition function for invoke_subgraph node {fw_hop_node.name}: {e}"
) from e
else:
raise
# Save the new forward and backward graph modules
new_hop_graphs[identifier].new_fw_hop_gm = new_fw_hop_gm
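The string branch above exists because a callable partition function cannot be serialized with the graph (see the "Callable cannot be serialized" comment in the new test). From the user side, node.meta["custom"]["partitioner"] can be populated either way, roughly as in this sketch reusing the options class from this compare:

import torch
from torch._functorch.partitioners import min_cut_rematerialization_partition
from torch._higher_order_ops.invoke_subgraph import (
    NestedCompileBackend,
    NestedCompileRegionOptions,
)

# Serializable: resolved by name in the match statement above.
by_name = NestedCompileRegionOptions(
    backend=NestedCompileBackend.INDUCTOR,
    partitioner="min_cut_rematerialization_partition",
)

# Callable: used directly, but cannot round-trip through serialization.
by_callable = NestedCompileRegionOptions(
    backend=NestedCompileBackend.INDUCTOR,
    partitioner=min_cut_rematerialization_partition,
)

@torch.compiler.nested_compile_region(backend_options=by_name)
def region(x):
    return torch.sin(x)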

View File

@ -534,7 +534,6 @@ def create_aot_state(
stack.enter_context(autograd_fallback_mode("error"))
from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj
from torch._library.opaque_object import is_opaque_type
# Tracing may mutate the state of the fake script object,
# so we need to duplicate the fake script objects so that subsequent tracing
@ -542,7 +541,7 @@ def create_aot_state(
def _dup_fake_script_obj(fake_flat_args):
return [
maybe_to_fake_obj(detect_fake_mode(fake_flat_args), arg.real_obj)
if isinstance(arg, FakeScriptObject) or is_opaque_type(type(arg))
if isinstance(arg, FakeScriptObject)
else arg
for arg in fake_flat_args
]

View File

@ -1,13 +1,13 @@
# mypy: allow-untyped-defs
from enum import Enum
from typing import Any, Optional, Union
from weakref import WeakKeyDictionary
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.torchbind import call_torchbind
from torch._library.custom_ops import CustomOpDef
from torch._library.effects import EffectType
from torch._library.utils import RegistrationHandle
from torch._library.fake_class_registry import FakeScriptObject
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
@ -17,50 +17,39 @@ from torch.fx.experimental.proxy_tensor import (
)
_op_identifier = Union[
str,
"torch._ops.OpOverload",
"torch._library.custom_ops.CustomOpDef",
"torch._ops.HigherOrderOperator",
]
OpType = Union["torch._ops.HigherOrderOperator", "torch._ops.OpOverload"]
_EffectType = EffectType
class _EffectType(Enum):
ORDERED = "Ordered"
def _get_op_qualname(op: _op_identifier) -> str:
"""Convert an op identifier to a qualified string key."""
if isinstance(op, torch._ops.OpOverload):
return op._name
elif isinstance(op, torch._ops.HigherOrderOperator):
return f"{op.namespace}::{op.name()}"
elif isinstance(op, CustomOpDef):
return op._qualname
elif isinstance(op, str):
return op
raise ValueError(f"Invalid operator input {op}")
OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload]
def _register_effectful_op(
op: _op_identifier, effect: Optional[EffectType]
) -> RegistrationHandle:
qualname = _get_op_qualname(op)
entry = torch._library.simple_registry.singleton.find(qualname)
handle = entry.effect.register(effect)
return handle
SIDE_EFFECTS = WeakKeyDictionary[OpType, _EffectType](
[
(torch.ops.aten._print.default, _EffectType.ORDERED),
(torch.ops.aten._async_error.default, _EffectType.ORDERED),
(call_torchbind, _EffectType.ORDERED),
]
)
def _get_effect(op: _op_identifier) -> Optional[_EffectType]:
qualname = _get_op_qualname(op)
entry = torch._library.simple_registry.singleton.find(qualname)
return entry.effect.effect
def _register_effectful_op(op: OpType, effect: _EffectType):
assert isinstance(
op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
) and not has_aliasing(op)
if op in SIDE_EFFECTS and SIDE_EFFECTS[op] != effect:
raise RuntimeError(
f"Already registered effect type {SIDE_EFFECTS[op]} to op {op}, "
f"trying to register a different effect type {effect}."
)
SIDE_EFFECTS[op] = effect
_register_effectful_op("aten::_print", _EffectType.ORDERED)
_register_effectful_op("aten::_async_error", _EffectType.ORDERED)
_register_effectful_op("profiler::_record_function_exit._RecordFunction", None)
_register_effectful_op(call_torchbind, _EffectType.ORDERED)
def _deregister_effectful_op(op: OpType):
if op not in SIDE_EFFECTS:
raise RuntimeError(f"Op {op} is not registered as effectful")
del SIDE_EFFECTS[op]
class WithEffects(HigherOrderOperator):
@ -89,7 +78,7 @@ class WithEffects(HigherOrderOperator):
) -> tuple[Any, ...]:
assert isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload))
assert not has_aliasing(op), "Ops with aliasing are not supported"
assert has_effects(op)
assert has_effects(op, args, kwargs)
assert isinstance(kwargs, dict)
return super().__call__(token, op, *args, **kwargs)
@ -100,7 +89,7 @@ with_effects = WithEffects()
def has_aliasing(op: OpType):
# NOT FOR PUBLIC USE
if isinstance(op, torch._ops.HigherOrderOperator):
return not _get_effect(op)
return op not in SIDE_EFFECTS
for arg in op._schema.arguments:
if arg.alias_info is not None:
@ -111,7 +100,7 @@ def has_aliasing(op: OpType):
return False
def has_effects(op) -> bool:
def has_effects(op, args, kwargs) -> bool:
# Skip over the profiler's RecordFunction as they should not show up in the graph
_skip_ops = {torch.ops.profiler._record_function_exit._RecordFunction}
if op in _skip_ops:
@ -120,10 +109,31 @@ def has_effects(op) -> bool:
return (
isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload))
and not has_aliasing(op)
and _get_effect(op) is not None
and get_effect_key(op, args, kwargs) is not None
)
def get_effect_key(op, args, kwargs) -> Optional[_EffectType]:
if op in SIDE_EFFECTS:
return SIDE_EFFECTS[op]
for arg in args:
if isinstance(arg, (torch.ScriptObject, FakeScriptObject)):
# Add it to the table so that next time we see the same op we don't
# have to parse through the args again
SIDE_EFFECTS[op] = _EffectType.ORDERED
return _EffectType.ORDERED
for arg in kwargs.values():
if isinstance(arg, (torch.ScriptObject, FakeScriptObject)):
# Add it to the table so that next time we see the same op we don't
# have to parse through the args again
SIDE_EFFECTS[op] = _EffectType.ORDERED
return _EffectType.ORDERED
return None
def new_token_tensor() -> torch.Tensor:
return torch.tensor([])
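As a hedged illustration of the lookup implemented by get_effect_key above (internal API; this assumes the SIDE_EFFECTS-table variant shown in this hunk is the one installed):
import torch
from torch._higher_order_ops.effects import _EffectType, get_effect_key
# aten::_print is pre-registered as ORDERED, so the table hit wins immediately.
assert get_effect_key(torch.ops.aten._print.default, ("hello",), {}) is _EffectType.ORDERED
# A plain tensor op with no ScriptObject arguments reports no effect.
assert get_effect_key(torch.ops.aten.add.Tensor, (torch.ones(2), torch.ones(2)), {}) is None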
@ -228,7 +238,7 @@ def handle_effects(
# Get a token. We can't do `tokens.get(op, torch.tensor([]))` because
# this will create an empty tensor during proxy mode tracing if the token
# doesn't exist. But the tokens should always exist during proxy mode tracing.
key = _get_effect(op)
key = get_effect_key(op, args, kwargs)
assert key is not None
if key not in tokens:
assert allow_token_discovery, (

View File

@ -1,9 +1,11 @@
# mypy: allow-untyped-defs
import contextlib
import enum
from collections.abc import Callable
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import Any, Optional, Union
import torch
import torch.utils._pytree as pytree
@ -36,10 +38,6 @@ from torch.fx.graph_module import GraphModule
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
if TYPE_CHECKING:
from collections.abc import Callable
invoke_subgraph_counter = 0
@ -53,6 +51,31 @@ class OutputMetadata:
indexes_with_no_grad: set[int] = field(default_factory=set)
class NestedCompileBackend(enum.Enum):
INDUCTOR = "inductor"
DEFAULT = "default"
@dataclass
class NestedCompileRegionOptions:
# If DEFAULT, does nothing and inherits the torch.compile backend
# If "inductor", adds {"compile_with_inductor": {"inductor_configs": config}} to the HOP node's "custom" meta
# If "custom" already has "compile_with_inductor", this config overrides it
backend: NestedCompileBackend = NestedCompileBackend.DEFAULT
# If backend == "inductor", the configs
inductor_configs: Optional[dict[str, Any]] = None
# If not None, adds a "partitioner" entry to the HOP node meta.
# If a Callable, the callable is assigned directly, but note that it cannot be pickled.
# If a str, the options are "default_partition" and "min_cut_rematerialization_partition".
# The HOP joint graph will be partitioned using the corresponding functions in
# torch/_functorch/partitioners.py
partitioner: Optional[Callable | str] = None
# TODO: add decomposition function
class InvokeSubgraphHOP(HigherOrderOperator):
def __init__(self) -> None:
# Invoke subgraph does not have any state, it is just a wrapper over a
@ -153,7 +176,9 @@ def invoke_subgraph_placeholder(func, *args, **kwargs):
return func(*args, **kwargs)
def mark_compile_region(fn=None):
def mark_compile_region(
fn=None, backend_options: Optional[NestedCompileRegionOptions] = None
):
"""
This wrapper instructs torch.compile to compile the wrapped region once and
reuse the compiled artifact, instead of the usual way of aggressively
@ -161,6 +186,10 @@ def mark_compile_region(fn=None):
Under the hood, it tells TorchDynamo to use InvokeSubgraph HOP for the
region. For PyTorch eager, this is a no-op.
Args:
fn: The function to wrap
backend_options: Optional options controlling how the subgraph region is compiled
"""
def wrap(func):
@ -172,6 +201,7 @@ def mark_compile_region(fn=None):
return invoke_subgraph_placeholder(inner_func, *args, **kwargs)
inner.__marked_compile_region_fn__ = func # type: ignore[attr-defined]
func.__marked_compile_region_backend__ = backend_options # type: ignore[attr-defined]
return inner
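Putting the new options together, a possible usage sketch follows. It is hedged: it assumes mark_compile_region keeps its usual decorator tail (returning the wrapper when a function is passed) and that inductor_configs accepts standard inductor config keys such as max_autotune.
import torch
from torch._higher_order_ops.invoke_subgraph import (
    NestedCompileBackend,
    NestedCompileRegionOptions,
    mark_compile_region,
)
opts = NestedCompileRegionOptions(
    backend=NestedCompileBackend.INDUCTOR,
    inductor_configs={"max_autotune": True},  # assumed standard inductor config key
)
def layer(x):
    return torch.sin(x) @ x
# Per the hunk above, the options land on the function as
# __marked_compile_region_backend__ and flow into the HOP node meta during tracing.
layer = mark_compile_region(layer, backend_options=opts)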

View File

@ -2122,10 +2122,6 @@ class PythonWrapperCodegen(CodeGen):
output.writeline(f"{name} = {val}")
def add_torchbind_input(name, value):
if value is None:
output.writeline(f"{name} = None")
return
import pickle
assert isinstance(value, torch.ScriptObject)

View File

@ -91,7 +91,6 @@ from torch._inductor.utils import (
tensor_is_aligned,
)
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.opaque_object import is_opaque_type
from torch._logging import trace_structured
from torch._utils_internal import compile_time_strobelight_meta
from torch.fx import GraphModule
@ -1640,9 +1639,7 @@ class _InProcessFxCompile(FxCompile):
# pyrefly: ignore [unbound-name]
(str, list, torch.fx.GraphModule),
), type(compiled_fn)
return CompiledAOTI(
filename=compiled_fn, device_type=graph.device_type
)
return CompiledAOTI(compiled_fn)
# TODO: Hoist this above V.aot_compilation
# pyrefly: ignore [unbound-name]
@ -2715,7 +2712,7 @@ def _compile_fx_main(
or torch._guards.TracingContext(fake_mode)
)
if V.aot_compilation and not config.enable_autograd_for_aot:
if V.aot_compilation:
from .utils import is_valid_aoti_model_name
is_valid_aoti_model_name()
@ -2748,9 +2745,7 @@ def _compile_fx_main(
node.meta["val"] = fake_mode.from_tensor(
target, static_shapes=True
)
elif isinstance(target, torch.ScriptObject) or is_opaque_type(
type(target)
):
elif isinstance(target, torch.ScriptObject):
node.meta["val"] = (
torch._library.fake_class_registry.maybe_to_fake_obj(
fake_mode, target

View File

@ -1193,8 +1193,6 @@ autotune_lookup_table: dict[str, dict[str, Any]] = {}
file_lock_timeout: int = int(os.environ.get("TORCHINDUCTOR_FILE_LOCK_TIMEOUT", "600"))
enable_autograd_for_aot: bool = False
def get_worker_log_path() -> Optional[str]:
log_loc = None

View File

@ -883,12 +883,11 @@ def _get_optimization_cflags(
should_use_optimized_flags = not (
config.aot_inductor.debug_compile
or os.environ.get("TORCHINDUCTOR_DEBUG_COMPILE", "0") == "1"
or os.environ.get("TORCHINDUCTOR_DEBUG_SYMBOL", "0") == "1"
)
should_add_debug_symbol_flags = (
config.aot_inductor.debug_compile
or config.aot_inductor.debug_symbols
or os.environ.get("TORCHINDUCTOR_DEBUG_COMPILE", "0") == "1"
or os.environ.get("TORCHINDUCTOR_DEBUG_SYMBOL", "0") == "1"
)
if should_use_optimized_flags:

View File

@ -9242,9 +9242,12 @@ class EffectfulKernel(FallbackKernel):
unbacked_bindings=unbacked_bindings,
)
from torch._higher_order_ops.effects import _get_effect
from torch._higher_order_ops.effects import get_effect_key
effect_type = _get_effect(kernel)
uncovered_args = [
a.value if isinstance(a, TorchBindObject) else a for a in tensor_args
]
effect_type = get_effect_key(kernel, (*nontensor_args, *uncovered_args), kwargs)
assert effect_type is not None
self.effect_type = effect_type
self.prev_effect_buffer = V.graph.effectful_ops.get(effect_type, None)
@ -9295,10 +9298,6 @@ class TorchBindObject(NonTensorObj):
def get_buf_bytes(self) -> int:
# Returns the sum of all tensors in the flattened object
real_script_obj = self.get_real_obj()
if real_script_obj is None:
return 0
assert hasattr(real_script_obj, "__obj_flatten__")
flat_dict = dict(real_script_obj.__obj_flatten__())
flat_elems = pytree.tree_flatten(flat_dict)[0]

View File

@ -26,7 +26,6 @@ import torch.utils._pytree as pytree
from torch._dynamo.utils import counters
from torch._higher_order_ops.associative_scan import associative_scan_op
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.utils import get_layout_constraint_tag
from torch._prims_common import ( # pyrefly: ignore # deprecated; pyrefly: ignore [deprecated]
canonicalize_dim,
@ -2705,8 +2704,6 @@ def require_channels_last(_, *args, **kwargs):
def constrain_to_fake_tensor(arg, fake_arg):
if isinstance(fake_arg, FakeScriptObject):
return arg
if isinstance(arg, ir.IRNode):
meta_stride_expr = [
s.node.expr if isinstance(s, torch.SymInt) else s for s in fake_arg.stride()
@ -7456,9 +7453,9 @@ def _sink_tokens(tokens):
def with_effects(token, op, *args, **kwargs):
result = ir.EffectfulKernel.create(op, *args, **kwargs)
from torch._higher_order_ops.effects import _get_effect
from torch._higher_order_ops.effects import get_effect_key
effect_type = _get_effect(op)
effect_type = get_effect_key(op, args, kwargs)
assert effect_type is not None
effectful_kernel = V.graph.effectful_ops[effect_type]

View File

@ -773,83 +773,9 @@ class CompiledAOTI(OutputCode):
"""
filename: Union[str, list[Union[str, Weights]], torch.fx.GraphModule]
device_type: str
current_callable: Optional[Callable[..., Any]] = None
_cached_files: dict[str, bytes] = dataclasses.field(default_factory=dict)
def __post_init__(self):
if not config.aot_inductor.link_libtorch:
return
if (
torch._inductor.cpp_builder._IS_MACOS
or torch._inductor.cpp_builder._IS_WINDOWS
):
return
if config.aot_inductor.cross_target_platform == "windows":
return
if config.aot_inductor.package_cpp_only:
return
if isinstance(self.filename, list):
current_callable = next(
fn for fn in self.filename if isinstance(fn, str) and fn.endswith(".so")
)
else:
current_callable = self.filename
if isinstance(current_callable, torch.fx.GraphModule):
self.current_callable = current_callable
return
if self.device_type.startswith("cuda"):
current_callable = (
torch._C._aoti.AOTIModelContainerRunnerCuda( # type: ignore[call-arg]
current_callable,
1,
self.device_type,
"",
True,
).run # type: ignore[attr-defined]
) # type: ignore[attr-defined]
elif self.device_type == "cpu":
current_callable = (
torch._C._aoti.AOTIModelContainerRunnerCpu( # type: ignore[call-arg]
current_callable, 1
).run # type: ignore[attr-defined]
) # type: ignore[attr-defined]
else:
raise RuntimeError(f"unsupported device type {self.device_type}")
self.current_callable = current_callable
self._boxed_call = True
for file in self._cached_files:
if not os.path.exists(file):
with open(file, "wb") as f:
f.write(self._cached_files[file])
def __call__(self, inputs: Sequence[Any]) -> Any:
if self.current_callable is None:
raise RuntimeError("AOTInductor compiled so is not loaded")
return self.current_callable(inputs)
def prepare_for_serialization(self) -> None:
self.current_callable = None
self._cached_files = {}
filenames: list[str] = []
if isinstance(self.filename, list):
filenames = self.filename # type: ignore[assignment]
elif isinstance(self.filename, str):
filenames = [self.filename]
for name in filenames:
with open(name, "rb") as f:
self._cached_files[name] = f.read()
def __getstate__(self):
state = self.__dict__.copy()
state["current_callable"] = None
return state
raise NotImplementedError("NYI")
def post_compile(
self,
@ -857,8 +783,10 @@ class CompiledAOTI(OutputCode):
constants: CompiledFxGraphConstants,
graph_kwargs: _CompileFxKwargs,
) -> None:
if self.current_callable is None:
self.__post_init__()
pass
def prepare_for_serialization(self) -> None:
pass
def set_triton_bundle(self, triton_bundle: Any) -> None:
pass

View File

@ -6107,15 +6107,14 @@ class Scheduler:
If config.benchmark_fusion is False, always return True.
Otherwise, return True if fusion brings a speedup.
"""
if not config.benchmark_combo_kernel:
return True
subkernel_nodes = nodes
device = subkernel_nodes[0].get_device()
# don't support benchmark fusion for CPU C++ backend right now.
if device is None or (device.type == "cpu" and config.cpu_backend != "triton"):
return False
if not config.benchmark_combo_kernel:
return True
from triton.compiler.errors import CompilationError

View File

@ -13,7 +13,6 @@ from torch.types import _dtype
from torch.utils._exposed_in import exposed_in
from . import autograd, utils
from .effects import EffectType
device_types_t = Optional[Union[str, Sequence[str]]]
@ -472,9 +471,6 @@ class CustomOpDef:
self._abstract_fn = fn
return fn
def register_effect(self, effect: Optional[EffectType]) -> None:
self._lib._register_effectful_op(self._qualname, effect)
def register_torch_dispatch(
self, torch_dispatch_class: Any, fn: Optional[Callable] = None, /
) -> Callable:

View File

@ -1,68 +0,0 @@
from enum import Enum
from typing import Optional
import torch
class EffectType(Enum):
ORDERED = "Ordered"
from torch._library.utils import RegistrationHandle
class EffectHolder:
"""A holder where one can register an effect impl to."""
def __init__(self, qualname: str):
self.qualname: str = qualname
self._set_default_effect()
def _set_default_effect(self) -> None:
self._effect: Optional[EffectType] = None
# If the op contains a ScriptObject input, we want to mark it as having effects
namespace, opname = torch._library.utils.parse_namespace(self.qualname)
split = opname.split(".")
if len(split) > 1:
assert len(split) == 2, (
f"Tried to split {opname} based on '.' but found more than 1 '.'"
)
opname, overload = split
else:
overload = ""
if namespace == "higher_order":
return
opname = f"{namespace}::{opname}"
if torch._C._get_operation_overload(opname, overload) is not None:
# Since we call this when destroying the library, sometimes the
# schema will be gone already at that time.
schema = torch._C._get_schema(opname, overload)
for arg in schema.arguments:
if isinstance(arg.type, torch.ClassType):
self._effect = EffectType.ORDERED
return
@property
def effect(self) -> Optional[EffectType]:
return self._effect
@effect.setter
def effect(self, _):
raise RuntimeError("Unable to directly set kernel.")
def register(self, effect: Optional[EffectType]) -> RegistrationHandle:
"""Register an effect
Returns a RegistrationHandle that one can use to de-register this
effect.
"""
self._effect = effect
def deregister_effect():
self._set_default_effect()
handle = RegistrationHandle(deregister_effect)
return handle

View File

@ -1,7 +1,6 @@
from collections.abc import Callable
from typing import Any, Optional
from .effects import EffectHolder
from .fake_impl import FakeImplHolder
from .utils import RegistrationHandle
@ -52,8 +51,6 @@ class SimpleOperatorEntry:
GenericTorchDispatchRuleHolder(qualname)
)
self.effect: EffectHolder = EffectHolder(qualname)
# For compatibility reasons. We can delete this soon.
@property
def abstract_impl(self) -> FakeImplHolder:

View File

@ -1023,7 +1023,6 @@ class TorchBindOpOverload(OpOverload[_P, _T]):
DispatchKey.BackendSelect,
DispatchKey.PythonTLSSnapshot,
DispatchKey.PythonDispatcher,
DispatchKey.Functionalize,
]
def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
@ -1047,23 +1046,17 @@ class TorchBindOpOverload(OpOverload[_P, _T]):
def _register_as_effectful_op_temporarily(self):
from torch._higher_order_ops.effects import (
_EffectType,
_get_effect,
_register_effectful_op,
SIDE_EFFECTS,
)
try:
# We don't want to register the effect if there already exists a
# registration, especially if the registration is None (explicitly
# no effect)
register_tmp_effect = _get_effect(self) is None
handle = None
if register_tmp_effect:
handle = _register_effectful_op(self, _EffectType.ORDERED)
if self not in SIDE_EFFECTS:
_register_effectful_op(self, _EffectType.ORDERED)
yield
finally:
if register_tmp_effect:
assert handle is not None
handle.destroy()
if self in SIDE_EFFECTS:
del SIDE_EFFECTS[self]
# Use positional-only argument to avoid naming collision with aten ops arguments
# that are named "self". This way, all the aten ops can be called by kwargs.

View File

@ -11,7 +11,7 @@ import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._C import _functionalization_reapply_views_tls as _reapply_views
from torch._ops import _get_dispatch_mode_pre_dispatch, TorchBindOpOverload
from torch._ops import _get_dispatch_mode_pre_dispatch
from torch._subclasses.meta_utils import is_sparse_any
from torch.utils._python_dispatch import (
_detect_infra_mode,
@ -471,7 +471,7 @@ class FunctionalTensorMode(TorchDispatchMode):
from torch._higher_order_ops.effects import handle_effects, has_effects
if has_effects(func):
if has_effects(func, args, kwargs):
assert not torch._C._dispatch_has_kernel_for_dispatch_key(
func.name(), torch._C.DispatchKey.Functionalize
)
@ -504,81 +504,65 @@ class FunctionalTensorMode(TorchDispatchMode):
- FunctionalTensor._extra_dispatch_keys
)
if isinstance(func, TorchBindOpOverload):
# When the function is a TorchBindOpOverload, meaning some of the
# inputs are FakeScriptObjects, we need to skip the C++ dispatcher and
# dispatch in Python, because the C++ dispatcher checks the schema
# and cannot recognize FakeScriptObject.
ctx = PythonFunctionalizeAPI()
fully_unwrapped_args = ctx.unwrap_tensors(args)
fully_unwrapped_kwargs = ctx.unwrap_tensors(
kwargs # pyrefly: ignore[bad-argument-type]
)
outs_unwrapped = func(
*fully_unwrapped_args,
**fully_unwrapped_kwargs,
)
outs_wrapped = ctx.wrap_tensors(outs_unwrapped)
else:
# All we want to do here is reuse the existing C++ functionalization logic.
# This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
try:
# By default for python functionalization (for AOTAutograd), we reapply views.
old_apply_views = torch._functionalize_enable_reapply_views(True) # type: ignore[attr-defined]
# All we want to do here is reuse the existing C++ functionalization logic.
# This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
try:
# By default for python functionalization (for AOTAutograd), we reapply views.
old_apply_views = torch._functionalize_enable_reapply_views(True) # type: ignore[attr-defined]
# Sometimes these functions cannot be directly dispatched to functionalize key
# because args are sometimes not functional tensors for some reason?
if func in FunctionalTensor.metadata_fns:
outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
outs_wrapped = pytree.tree_map_only(
torch.Tensor, wrap, outs_unwrapped
)
else:
# Note: [Functionalization View Replay Annotation]
# When functionalization encounters a mutation, it handles aliases by lazily regenerating the aliases
# at the first time they are next used.
# This is a problem when plumbing user annotations during tracing. We want the view ops from view replay
# to have the same annotation that the user specified on the original views. But view replay in
# functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)),
# so when we regenerate views before calling into second_op, those views will end up getting the metadata
# for second_op!
#
# Instead, we need to remember the node metadata from the original views, and ensure that this node metadata
# is globally set when we lazily perform view replay.
# The globally set metadata will be used to populate the fx node created for the replayed operation.
if m := torch._C._get_dispatch_mode(
torch._C._TorchDispatchModeKey.PROXY
):
for a in pytree.tree_leaves([args, kwargs]):
if not isinstance(a, FunctionalTensor):
continue
curr_node = m.tracer.tensor_tracker[
torch._from_functional_tensor(a.elem)
].proxy.node
with fx_traceback.set_current_replay_node(curr_node):
torch._sync(a)
# Sometimes these functions cannot be directly dispatched to functionalize key
# because args are sometimes not functional tensors for some reason?
if func in FunctionalTensor.metadata_fns:
outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
outs_wrapped = pytree.tree_map_only(
torch.Tensor, wrap, outs_unwrapped
)
else:
# Note: [Functionalization View Replay Annotation]
# When functionalization encounters a mutation, it handles aliases by lazily regenerating the aliases
# at the first time they are next used.
# This is a problem when plumbing user annotations during tracing. We want the view ops from view replay
# to have the same annotation that the user specified on the original views. But view replay in
# functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)),
# so when we regenerate views before calling into second_op, those views will end up getting the metadata
# for second_op!
#
# Instead, we need to remember the node metadata from the original views, and ensure that this node metadata
# is globally set when we lazily perform view replay.
# The globally set metadata will be used to populate the fx node created for the replayed operation.
if m := torch._C._get_dispatch_mode(
torch._C._TorchDispatchModeKey.PROXY
):
for a in pytree.tree_leaves([args, kwargs]):
if not isinstance(a, FunctionalTensor):
continue
curr_node = m.tracer.tensor_tracker[
torch._from_functional_tensor(a.elem)
].proxy.node
with fx_traceback.set_current_replay_node(curr_node):
torch._sync(a)
# When we dispatch to the C++ functionalization kernel, we might need to jump back to the
# PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
# FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
# from the TLS in order to avoid infinite looping, but this would prevent us from coming
# back to PreDispatch later
outs_unwrapped = func._op_dk(
torch._C.DispatchKey.Functionalize,
*args_unwrapped,
**kwargs_unwrapped,
)
# When we dispatch to the C++ functionalization kernel, we might need to jump back to the
# PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
# FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
# from the TLS in order to avoid infinite looping, but this would prevent us from coming
# back to PreDispatch later
outs_unwrapped = func._op_dk(
torch._C.DispatchKey.Functionalize,
*args_unwrapped,
**kwargs_unwrapped,
)
if self.export:
if func is torch.ops.aten.dropout.default:
torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined]
outs_wrapped = pytree.tree_map_only(
torch.Tensor, wrap, outs_unwrapped
)
finally:
torch._disable_functionalization()
torch._functionalize_enable_reapply_views(old_apply_views) # type: ignore[attr-defined]
if self.export:
if func is torch.ops.aten.dropout.default:
torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined]
outs_wrapped = pytree.tree_map_only(
torch.Tensor, wrap, outs_unwrapped
)
finally:
torch._disable_functionalization()
torch._functionalize_enable_reapply_views(old_apply_views) # type: ignore[attr-defined]
is_included = torch._C._dispatch_tls_is_dispatch_key_included(
torch._C.DispatchKey.Functionalize

View File

@ -1,14 +1,21 @@
# mypy: allow-untyped-defs
import io
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec
import torch
from torch._higher_order_ops.invoke_subgraph import NestedCompileRegionOptions
from . import config
try:
from typing import LiteralString
except ImportError:
from typing_extensions import LiteralString
if TYPE_CHECKING:
from ._cache import CacheInfo
@ -617,7 +624,9 @@ def skip_guard_on_globals_unsafe(guard_entries):
return [not entry.is_global for entry in guard_entries]
def nested_compile_region(fn=None):
def nested_compile_region(
fn=None, backend_options: Optional[NestedCompileRegionOptions] = None
):
"""
Tells **``torch.compile``** that the marked set of operations forms a nested
compile region (which is often repeated in the full model) whose code can be
@ -626,8 +635,8 @@ def nested_compile_region(fn=None):
During **``torch.compile``** tracing, the compiler applies *hierarchical
compilation* with ``nested_compile_region``: it emits optimized code for the
marked region the first time it is encountered and re-emits (or stamps
out) the previously compiled code on every subsequent invocation. This can
marked region the first time it is encountered and re-emits (or "stamps
out") the previously compiled code on every subsequent invocation. This can
substantially reduce overall compile time for deeply-stacked,
structurally-identical components such as the transformer layers of a
large-language-model (LLM).
@ -641,13 +650,17 @@ def nested_compile_region(fn=None):
to reuse, it will transparently re-compile the region. Using it is
therefore *safe*: correctness is always preserved, and you pay the extra
compilation cost only when required.
Args:
fn: The function to wrap
backend_options: Optional options controlling how the subgraph region is compiled.
"""
from torch._higher_order_ops.invoke_subgraph import (
mark_compile_region as _mark_compile_region,
)
return _mark_compile_region(fn)
return _mark_compile_region(fn, backend_options=backend_options)
def load_compiled_function(file: io.IOBase) -> Callable[..., Any]:
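A possible call pattern for the extended public API, using the string partitioner option added earlier in this compare. This is a sketch, not a definitive usage; the names outside this diff are standard torch.compile APIs.
import torch
from torch._higher_order_ops.invoke_subgraph import NestedCompileRegionOptions
opts = NestedCompileRegionOptions(partitioner="min_cut_rematerialization_partition")
def block(x):
    return torch.cos(torch.sin(x))
# Mark the region; the options are forwarded to the invoke_subgraph HOP.
block = torch.compiler.nested_compile_region(block, backend_options=opts)
@torch.compile
def model(x):
    return block(x) + block(x)
model(torch.randn(4, 4))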

View File

@ -66,12 +66,6 @@ void initAOTIRunnerBindings(PyObject* module) {
int,
const std::string&,
const std::string&>())
.def(py::init<
const std::string&,
int,
const std::string&,
const std::string&,
const bool>())
.def(
"run",
&AOTIModelContainerRunnerCuda::run,

View File

@ -18,7 +18,6 @@ import torch
import torch.utils._pytree as pytree
from torch._C import ScriptObject # type: ignore[attr-defined]
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.opaque_object import is_opaque_type
from ._compatibility import compatibility
from ._lazy_graph_module import _make_graph_module
@ -422,10 +421,8 @@ class Tracer(TracerBase):
# a get_attr to retrieve that tensor. Otherwise, we'll store away the
# tensor value into a special attribute on the Module s.t. we can
# retrieve it with a get_attr.
if isinstance(a, _constant_attribute_types) or is_opaque_type(type(a)):
qualname: Optional[str] = self.tensor_attrs.get(
a
) # pyrefly: ignore[no-matching-overload]
if isinstance(a, _constant_attribute_types):
qualname: Optional[str] = self.tensor_attrs.get(a)
# Tensor was not found in the Module hierarchy, stow it away in a
# special attribute and set the qualname to refer to that
@ -436,17 +433,13 @@ class Tracer(TracerBase):
base_name = "_torchbind_obj"
elif isinstance(a, pytree.TreeSpec):
base_name = "_tree_spec_constant"
elif is_opaque_type(type(a)):
base_name = "_opaque_obj"
else:
raise RuntimeError(
f"cannot create constant arg for {a} of type {type(a)}."
)
qualname = self.get_fresh_qualname(base_name)
assert isinstance(qualname, str)
self.tensor_attrs[a] = ( # pyrefly: ignore[unsupported-operation]
qualname
)
self.tensor_attrs[a] = qualname
setattr(self.root, qualname, a)
return self.create_node("get_attr", qualname, (), {})

View File

@ -84,7 +84,7 @@ if TYPE_CHECKING:
from torch._ops import OpOverload
from torch.fx._symbolic_trace import PHBase
from torch.types import BoolLikeType, FloatLikeType, IntLikeType
from torch.types import IntLikeType
__all__ = [
"PythonKeyTracer",
@ -458,7 +458,7 @@ def _sympy_handlers() -> dict[type[sympy.Expr], Callable[..., Any]]:
def _build_proxy_for_sym_expr(
tracer: _ProxyTracer, expr: sympy.Expr, out: PySymType | None = None
) -> IntLikeType | FloatLikeType | BoolLikeType | None:
) -> PySymType | None:
"""
Decompose `expr` and look for the pieces as inputs. If `out` is provided
then that will be the resulting SymNode (and `out.expr` must be the same as
@ -532,13 +532,6 @@ def _build_proxy_for_sym_expr(
assert not out
return value.value
if isinstance(expr, (int, float, bool)):
return expr
if expr.is_Integer:
return int(expr)
if expr.is_Float:
return float(expr)
args = []
for arg in expr.args:
if (arg_value := _build_proxy_for_sym_expr(tracer, arg)) is None:

View File

@ -19,7 +19,6 @@ from torch._library.custom_ops import (
CustomOpDef,
device_types_t,
)
from torch._library.effects import EffectType
from torch._library.infer_schema import infer_schema # noqa: F401
from torch._library.triton import triton_op, wrap_triton
from torch._ops import OpOverload
@ -399,22 +398,6 @@ class Library:
self.m.fallback(dispatch_key, fn, with_keyset)
def _register_effectful_op(self, op_name: str, effect: Optional[EffectType]):
"""
Registers an effect for an operator. This is used to mark an op whose
side effects are not capturable by its schema.
Args:
op_name: operator name (along with the overload) or OpOverload object.
effect: The effect of the op.
"""
from torch._higher_order_ops.effects import (
_register_effectful_op as hoo_register_effect,
)
handle = hoo_register_effect(op_name, effect)
self._registration_handles.append(handle)
def _destroy(self):
if self.m is not None:
self.m.reset()
@ -1082,44 +1065,6 @@ def register_fake(
return register(func)
def _register_effectful_op(
op: _op_identifier,
effect: Optional[EffectType],
*,
lib: Optional[Library] = None,
) -> None:
r"""
To specify that an operator has side-effects, we must register an effect
type for the operator. This will prevent graph passes in torch.compile from
reordering operations with the same effect type.
Args:
op_name: Operator name (along with the overload) or OpOverload object.
effect: Effect type to register. None means the operator is not effectful.
"""
if not isinstance(
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
):
raise ValueError(
f"register_effectful_op({op}): got unexpected type for op: {type(op)}"
)
if isinstance(op, torch._ops.OpOverload):
op = op._name
opdef = _maybe_get_opdef(op)
if opdef is not None:
opdef.register_effect(effect)
assert isinstance(op, str)
namespace, _ = torch._library.utils.parse_namespace(op)
if lib is None:
use_lib = Library(namespace, "FRAGMENT")
_keep_alive.append(use_lib)
else:
use_lib = lib
use_lib._register_effectful_op(op, effect)
def register_autograd(
op: _op_identifier,
backward: Callable,

View File

@ -37,7 +37,7 @@ import functools
import traceback
import weakref
from collections.abc import Callable
from typing import Any, Optional, TYPE_CHECKING # noqa: F401
from typing import Any, TYPE_CHECKING
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode