Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-11-14 22:25:03 +08:00

Compare commits
2 Commits
ciflow/tru...sy_invoke_

| Author | SHA1 | Date |
|---|---|---|
|  | 43f24e9876 |  |
|  | 4407acbb20 |  |
@@ -1681,13 +1681,14 @@ class GraphModule(torch.nn.Module):

wrap_body_0 = self.wrap_body_0
tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = True); wrap_body_0 = l_x_ = None
getitem: "f32[4, 4]" = tag_activation_checkpoint[0]; tag_activation_checkpoint = None
return (getitem,)
getitem: "f32[4, 4]" = tag_activation_checkpoint[0]
getitem_1: "f32[4, 4]" = tag_activation_checkpoint[1]; tag_activation_checkpoint = None
return (getitem, getitem_1)

class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[4, 4]"):
y: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (y,)
return (y, y)
""",
)

@@ -1797,9 +1798,9 @@ class GraphModule(torch.nn.Module):
out: "f32[4, 4]" = l_x_.sin()

sin_1: "f32[4, 4]" = torch.sin(o)
cos: "f32[4, 4]" = torch.cos(sin_1)
sin_2: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (cos, sin_2, matmul, o, out, sin_1)
child: "f32[4, 4]" = torch.cos(sin_1)
child_1: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (child, child_1, matmul, o, out, sin_1)
""",
)

@@ -1,11 +1,9 @@
# Owner(s): ["module: dynamo"]

import copy
import functools
import inspect
import os
import pickle
import unittest
from contextlib import contextmanager
from unittest.mock import patch

@@ -15,16 +13,13 @@ import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.aot_compile import AOTCompiledModel, ModelInput, SerializableCallable
from torch._dynamo.aot_compile import ModelInput, SerializableCallable
from torch._dynamo.exc import PackageError, Unsupported
from torch._dynamo.package import DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.fx._graph_pickler import GraphPickler
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
TEST_CUDA,
)
from torch.testing._internal.common_utils import instantiate_parametrized_tests


MY_LAMBDA = lambda x: x + 1 # noqa: E731
@@ -604,92 +599,6 @@ from user code:
actual = compiled_fn(*inputs)
self.assertEqual(expected, actual)

@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti(self):
with torch.device("cuda"):
from torch._dynamo.hooks import Hooks

def fn(x, y):
return x + y

def make_inputs():
return (torch.randn(3, 4), torch.randn(3, 4))

compiled_fn = torch._dynamo.aot_compile.aot_compile_fullgraph(
fn,
(make_inputs(), {}),
Hooks(),
torch._TorchCompileAOTInductorWrapper(None, None, None),
)

test_inputs = make_inputs()
expected = fn(*test_inputs)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)

@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti_module(self):
with torch.device("cuda"):
from torch._dynamo.hooks import Hooks

mod = SimpleLinearModule()

def make_inputs():
return (torch.randn(4, 3),)

compiled_mod = torch._dynamo.aot_compile.aot_compile_module(
mod,
[ModelInput(make_inputs(), {}, [])],
Hooks(),
torch._TorchCompileAOTInductorWrapper(None, None, None),
)

def get_grads(m: torch.nn.Module):
return {name: p.grad for name, p in m.named_parameters()}

original_mod = copy.deepcopy(mod)
test_inputs = make_inputs()
expected = mod(*test_inputs)
expected.sum().backward()
expected_grads = get_grads(mod)

actual = compiled_mod(*test_inputs)
self.assertEqual(expected, actual)
serialized = compiled_mod.serialize()
compiled_fn = AOTCompiledModel.deserialize(original_mod, serialized)
actual = compiled_fn(*test_inputs)
actual.sum().backward()
self.assertEqual(get_grads(original_mod), expected_grads)

@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti_torch_compile(self):
with torch.device("cuda"):

def fn(x, y):
return x + y

def make_inputs():
return (torch.randn(3, 4), torch.randn(3, 4))

compiled_fn = torch.compile(
fn, fullgraph=True, options={"use_aoti": True}
).aot_compile((make_inputs(), {}))
test_inputs = make_inputs()
expected = fn(*test_inputs)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(*test_inputs)
self.assertEqual(compiled_fn._artifacts.backend_name, "aotinductor")
self.assertEqual(expected, actual)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

@@ -222,13 +222,13 @@ class GraphModule(torch.nn.Module):

matmul: "f32[3, 3]" = l_x_ @ l_y_
sin: "f32[3, 3]" = matmul.sin(); matmul = None
cos: "f32[3, 3]" = sin.cos(); sin = None
child: "f32[3, 3]" = sin.cos(); sin = None

add: "f32[3, 3]" = l_x_ + l_y_
sub: "f32[3, 3]" = l_x_ - l_y_
child_1: "f32[3, 3]" = l_x_ + l_y_
child_2: "f32[3, 3]" = l_x_ - l_y_

matmul_1: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
return (cos, add, sub, matmul_1)
child_3: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
return (child, child_1, child_2, child_3)
""", # noqa: B950
)
self.assertExpectedInline(

@@ -249,7 +249,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):

# when testing with dynamic shape, symbols are lifted as input
arg_count = ifdynstaticdefault(2, 3)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 1)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count)

def test_return_captured_vars(self):
freevar1 = torch.randn(3)
@@ -267,7 +267,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
# be the input.
# when testing with dynamic shape, a symbol is lifted as input
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 1)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 4)

def test_return_captured_var_used_multiple_times(self):
freevar = torch.randn(3)
@@ -282,7 +282,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
x = torch.randn(3)
# when testing with dynamic shape, a symbol is lifted as input
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 2)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 3)

def test_capture_untracked_global(self):
def f(x):
@@ -762,15 +762,15 @@ class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, u0, c); wrap_body_0 = s77 = l_x_ = u0 = c = None
getitem: "f32[s77]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
child: "f32[s77]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)

class wrap_body_0(torch.nn.Module):
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
sin: "f32[s77]" = l_x_.sin(); l_x_ = None
sin_1: "f32[u0, 1]" = c.sin(); c = None
return (sin, sin_1)
child: "f32[s77]" = l_x_.sin(); l_x_ = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
""",
)
else:
@@ -801,15 +801,15 @@ class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, u0, c); wrap_body_0 = l_x_ = u0 = c = None
getitem: "f32[3]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
child: "f32[3]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)

class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
sin: "f32[3]" = l_x_.sin(); l_x_ = None
sin_1: "f32[u0, 1]" = c.sin(); c = None
return (sin, sin_1)
child: "f32[3]" = l_x_.sin(); l_x_ = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
""",
)

@@ -922,16 +922,16 @@ class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, size, c); wrap_body_0 = l_x_ = size = c = None
getitem: "f32[3]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
child: "f32[3]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)

class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
sin: "f32[3]" = l_x_.sin(); l_x_ = None
add: "f32[3]" = sin + size; sin = size = None
sin_1: "f32[u0, 1]" = c.sin(); c = None
return (add, sin_1)
child: "f32[3]" = sin + size; sin = size = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
""",
)

@@ -2458,10 +2458,10 @@ class GraphModule(torch.nn.Module):

class wrap_body_0(torch.nn.Module):
def forward(self, l_arg1_0_: "f32[3]", l_arg2_0_: "f32[3]"):
add: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None
child: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None

add_1: "f32[3]" = l_arg2_0_ + 1; l_arg2_0_ = None
return (add, add_1)
child_1: "f32[3]" = l_arg2_0_ + 1; l_arg2_0_ = None
return (child, child_1)
""",
)

@@ -2655,9 +2655,9 @@ class GraphModule(torch.nn.Module):

class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[2, 3]"):
sin: "f32[2, 3]" = l_x_.sin()
cos: "f32[2, 3]" = l_x_.cos(); l_x_ = None
return (sin, cos)
child: "f32[2, 3]" = l_x_.sin()
child_1: "f32[2, 3]" = l_x_.cos(); l_x_ = None
return (child, child_1)
""",
)

@@ -2687,13 +2687,13 @@ class GraphModule(torch.nn.Module):

wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
getitem: "f32[3]" = wrap[0]; wrap = None
return (getitem,)
value: "f32[3]" = wrap[0]; wrap = None
return (value,)

class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]"):
neg: "f32[3]" = -l_x_; l_x_ = None
return (neg,)
child: "f32[3]" = -l_x_; l_x_ = None
return (child,)
""",
)

@@ -3318,17 +3318,17 @@ class GraphModule(torch.nn.Module):

hints_wrapper_body_1 = self.hints_wrapper_body_1
hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_1, (x, l_y_), {}, hints = {'outer_body': True}); hints_wrapper_body_1 = x = l_y_ = None
getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
return (getitem,)
res: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
return (res,)

class hints_wrapper_body_1(torch.nn.Module):
def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
hints_wrapper_body_0 = self.hints_wrapper_body_0
hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_0, (x, l_y_), {}, hints = {'inner_body': True}); hints_wrapper_body_0 = x = l_y_ = None
getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
x_1: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None

x_1: "f32[2, 4]" = torch.abs(getitem); getitem = None
return (x_1,)
x_2: "f32[2, 4]" = torch.abs(x_1); x_1 = None
return (x_2,)

class hints_wrapper_body_0(torch.nn.Module):
def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):

@@ -10,6 +10,10 @@ import torch.utils.checkpoint
from torch._dynamo.backends.common import aot_autograd
from torch._functorch._aot_autograd.autograd_cache import BundledCompiledForward
from torch._guards import detect_fake_mode
from torch._higher_order_ops.invoke_subgraph import (
NestedCompileBackend,
NestedCompileRegionOptions,
)
from torch._inductor.output_code import RegionalOutputCode
from torch._inductor.test_case import run_tests
from torch._inductor.utils import run_fw_bw_and_get_code
@@ -468,6 +472,86 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
# flex in forward and flex_backward in backward
self.assertEqual(len(codes), 2)

@parametrize("serialize", [True, False])
def test_invoke_subgraph_regional_compile(self, serialize):
call_test_partitioner_ct = 0
original_default_partitioner = torch._functorch.partitioners.default_partition

def test_partitioner(
*args, **kwargs
) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
nonlocal call_test_partitioner_ct
call_test_partitioner_ct += 1
return original_default_partitioner(*args, **kwargs)

# pyrefly: ignore [not-iterable]
if serialize:
# Callable cannot be serialized
torch._functorch.partitioners.default_partition = test_partitioner
partitioner = "default_partition"
else:
partitioner = test_partitioner
backend = NestedCompileRegionOptions(
backend=NestedCompileBackend.INDUCTOR,
inductor_configs={
"max_autotune": True,
"triton.cudagraphs": False,
},
partitioner=partitioner,
)

@torch.compiler.nested_compile_region(backend_options=backend)
def gn_with_backend(x):
return torch.sin(x)

@torch.compiler.nested_compile_region
def gn_without_backend(x):
return torch.cos(x)

def fn(x):
return gn_with_backend(x) + gn_without_backend(x)

backend = aot_eager_regional_inductor(serialize=serialize)
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)

import torch._inductor.config as inductor_config

# Hook to verify options
original_compile = torch._inductor.standalone_compile
captured_options = []

def verify_options(*args, **kwargs):
options = kwargs.get("options", {})
captured_options.append(options)

# Verify config is set as expected from explicit options
assert inductor_config.max_autotune, "max_autotune should be True"
assert not inductor_config.triton.cudagraphs, (
"triton.cudagraphs should be False"
)

return original_compile(*args, **kwargs)

torch._inductor.standalone_compile = verify_options

try:
x = torch.randn(8, 8, requires_grad=True)
# opt_fn(x)
res, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
self.assertEqual(len(codes), 2)
self.assertTrue("repeated_subgraph0" in codes[0])
self.assertTrue("repeated_subgraph1" not in codes[0])
self.assertTrue("repeated_subgraph0" in codes[1])
self.assertTrue("repeated_subgraph1" not in codes[1])
self.assertEqual(call_test_partitioner_ct, 1)
true_res = fn(x)
self.assertEqual(res, true_res)
finally:
torch._inductor.standalone_compile = original_compile
torch._functorch.partitioners.default_partition = (
original_default_partitioner
)


@skipIfTorchDynamo("Not a suitable dynamo wrapped test")
class TestRegionalOutputCode(torch._inductor.test_case.TestCase):

@ -21,6 +21,10 @@ from torch._dynamo.testing import (
|
||||
InductorAndRecordGraphs,
|
||||
normalize_gm,
|
||||
)
|
||||
from torch._higher_order_ops.invoke_subgraph import (
|
||||
NestedCompileBackend,
|
||||
NestedCompileRegionOptions,
|
||||
)
|
||||
from torch._higher_order_ops.schema import find_hop_schema
|
||||
from torch._inductor import config as inductor_config
|
||||
from torch._inductor.pattern_matcher import (
|
||||
@ -899,14 +903,14 @@ class GraphModule(torch.nn.Module):
|
||||
class subgraph_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"):
|
||||
mul: "f32[8]" = torch.mul(l_x_, l_y_); l_x_ = l_y_ = None
|
||||
mul_1: "f32[8]" = mul * 2; mul = None
|
||||
return (mul_1,)
|
||||
child: "f32[8]" = mul * 2; mul = None
|
||||
return (child,)
|
||||
|
||||
class subgraph_1(torch.nn.Module):
|
||||
def forward(self, a: "f32[8]", l_y_: "f32[8]"):
|
||||
mul: "f32[8]" = torch.mul(a, l_y_); a = l_y_ = None
|
||||
mul_1: "f32[8]" = mul * 3; mul = None
|
||||
return (mul_1,)
|
||||
child: "f32[8]" = mul * 3; mul = None
|
||||
return (child,)
|
||||
""",
|
||||
)
|
||||
|
||||
@ -983,20 +987,20 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
subgraph_0 = self.subgraph_0
|
||||
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = None
|
||||
getitem: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
|
||||
x: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
|
||||
subgraph_1 = self.subgraph_0
|
||||
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', getitem, l_y_); subgraph_1 = getitem = None
|
||||
getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
||||
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', x, l_y_); subgraph_1 = x = None
|
||||
x_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
||||
subgraph_2 = self.subgraph_0
|
||||
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_2, 'subgraph_0', getitem_1, l_y_); subgraph_2 = getitem_1 = None
|
||||
getitem_2: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
|
||||
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_2, 'subgraph_0', x_1, l_y_); subgraph_2 = x_1 = None
|
||||
x_2: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
|
||||
subgraph_3 = self.subgraph_0
|
||||
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_3, 'subgraph_0', getitem_2, l_y_); subgraph_3 = getitem_2 = None
|
||||
getitem_3: "f32[8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
|
||||
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_3, 'subgraph_0', x_2, l_y_); subgraph_3 = x_2 = None
|
||||
x_3: "f32[8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
|
||||
subgraph_4 = self.subgraph_0
|
||||
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_4, 'subgraph_0', getitem_3, l_y_); subgraph_4 = getitem_3 = l_y_ = None
|
||||
getitem_4: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
|
||||
return (getitem_4,)
|
||||
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_4, 'subgraph_0', x_3, l_y_); subgraph_4 = x_3 = l_y_ = None
|
||||
x_4: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
|
||||
return (x_4,)
|
||||
|
||||
class subgraph_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"):
|
||||
@ -1495,9 +1499,9 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
class subgraph_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[8, 8]"):
|
||||
mul: "f32[8, 8]" = l_x_ * 2
|
||||
mul_1: "f32[8, 8]" = l_x_ * 3; l_x_ = None
|
||||
return (mul, mul_1)
|
||||
child: "f32[8, 8]" = l_x_ * 2
|
||||
child_1: "f32[8, 8]" = l_x_ * 3; l_x_ = None
|
||||
return (child, child_1)
|
||||
""",
|
||||
)
|
||||
|
||||
@ -1556,6 +1560,101 @@ class GraphModule(torch.nn.Module):
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_unbacked_expr(self):
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return x + 1
|
||||
|
||||
def fn(c):
|
||||
d = torch.concat([c, c], dim=0)
|
||||
d = gn(d)
|
||||
return d
|
||||
|
||||
c = torch.randn((64, 32))
|
||||
torch._dynamo.decorators.mark_unbacked(c, 0)
|
||||
|
||||
ref = fn(c)
|
||||
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
|
||||
res = opt_fn(c)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_grad_accumulation(self):
|
||||
mod1 = torch.nn.Linear(8, 8)
|
||||
mod2 = torch.nn.Linear(8, 8)
|
||||
mod3 = torch.nn.Linear(8, 8)
|
||||
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return mod1(x) - mod2(x)
|
||||
|
||||
def fn(c):
|
||||
d = gn(c) - mod3(c)
|
||||
return d * 2
|
||||
|
||||
c = torch.randn((8, 8), requires_grad=True)
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
|
||||
res = opt_fn(c)
|
||||
res.sum().backward()
|
||||
|
||||
# fw_add_nodes = backend.fw_graphs[0].graph.find_nodes(op="call_function", target = torch.ops.aten.add.Tensor)
|
||||
# The gradient addition node for mod3 is not in the subgraph.
|
||||
bw_add_nodes = backend.bw_graphs[0].graph.find_nodes(
|
||||
op="call_function", target=torch.ops.aten.add.Tensor
|
||||
)
|
||||
self.assertEqual(len(bw_add_nodes), 1)
|
||||
subgraph_node = backend.bw_graphs[0].graph.find_nodes(op="get_attr")[0]
|
||||
subgraph_name = subgraph_node.target
|
||||
# The gradient addition node between mod1 and mod2 will be in the subgraph
|
||||
bw_add_nodes = getattr(backend.bw_graphs[0], subgraph_name).graph.find_nodes(
|
||||
op="call_function", target=torch.ops.aten.add.Tensor
|
||||
)
|
||||
self.assertEqual(len(bw_add_nodes), 1)
|
||||
|
||||
def test_backend_parameter(self):
|
||||
backend = NestedCompileRegionOptions(NestedCompileBackend.INDUCTOR)
|
||||
|
||||
# Test that backend parameter is properly set in node.meta
|
||||
@nested_compile_region(backend_options=backend)
|
||||
def gn_with_backend(x):
|
||||
return torch.sin(x)
|
||||
|
||||
@nested_compile_region
|
||||
def gn_without_backend(x):
|
||||
return torch.cos(x)
|
||||
|
||||
def fn(x):
|
||||
return gn_with_backend(x) + gn_without_backend(x)
|
||||
|
||||
backend = EagerAndRecordGraphs()
|
||||
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
|
||||
|
||||
x = torch.randn(8, 8, requires_grad=False)
|
||||
opt_fn(x)
|
||||
|
||||
# Check that we captured the graph
|
||||
self.assertEqual(len(backend.graphs), 1)
|
||||
graph = backend.graphs[0]
|
||||
|
||||
# Find invoke_subgraph nodes and check their backend metadata
|
||||
invoke_subgraph_nodes = [
|
||||
node
|
||||
for node in graph.graph.nodes
|
||||
if node.op == "call_function"
|
||||
and node.target == torch.ops.higher_order.invoke_subgraph
|
||||
]
|
||||
|
||||
# We should have 2 invoke_subgraph calls
|
||||
self.assertEqual(len(invoke_subgraph_nodes), 2)
|
||||
|
||||
# First invoke_subgraph (gn_with_backend) should have backend
|
||||
self.assertIn("custom", invoke_subgraph_nodes[0].meta)
|
||||
|
||||
# Second invoke_subgraph (gn_without_backend) should have custom=None or no custom
|
||||
backend_value = invoke_subgraph_nodes[1].meta.get("custom", None)
|
||||
self.assertIsNone(backend_value)
|
||||
|
||||
def test_complex(self):
|
||||
# Observed in Wan2.1
|
||||
@nested_compile_region
|
||||
@ -2504,107 +2603,6 @@ class GraphModule(torch.nn.Module):
|
||||
self.assertEqual(f(x, other), f_compile(x, other))
|
||||
self.assertTrue(called)
|
||||
|
||||
def test_udf_output(self):
|
||||
class Foo:
|
||||
def __init__(self, a, b):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
@nested_compile_region
|
||||
def gn(x, y):
|
||||
a = torch.sin(x)
|
||||
b = torch.cos(y)
|
||||
return Foo(a, b)
|
||||
|
||||
def fn(x, y):
|
||||
foo1 = gn(x, y)
|
||||
foo2 = gn(foo1.a, y)
|
||||
return foo1.b + foo2.a # + foo2.b
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
|
||||
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
|
||||
|
||||
x = torch.randn(8, 8, requires_grad=True)
|
||||
y = torch.randn(8, 8, requires_grad=True)
|
||||
x_clone = x.detach().clone().requires_grad_(True)
|
||||
y_clone = y.detach().clone().requires_grad_(True)
|
||||
|
||||
ref = fn(x, y)
|
||||
res = opt_fn(x_clone, y_clone)
|
||||
|
||||
ref.sum().backward()
|
||||
res.sum().backward()
|
||||
|
||||
self.assertEqual(ref, res)
|
||||
self.assertEqual(x.grad, x_clone.grad)
|
||||
|
||||
if not TEST_WITH_CROSSREF:
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[8, 8]", L_y_: "f32[8, 8]"):
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
|
||||
subgraph_0 = self.subgraph_0
|
||||
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = None
|
||||
getitem: "f32[8, 8]" = invoke_subgraph[0]
|
||||
getitem_1: "f32[8, 8]" = invoke_subgraph[1]; invoke_subgraph = None
|
||||
subgraph_1 = self.subgraph_0
|
||||
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', getitem, l_y_); subgraph_1 = getitem = l_y_ = None
|
||||
getitem_2: "f32[8, 8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
||||
|
||||
add: "f32[8, 8]" = getitem_1 + getitem_2; getitem_1 = getitem_2 = None
|
||||
return (add,)
|
||||
|
||||
class subgraph_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[8, 8]", l_y_: "f32[8, 8]"):
|
||||
a: "f32[8, 8]" = torch.sin(l_x_); l_x_ = None
|
||||
|
||||
b: "f32[8, 8]" = torch.cos(l_y_); l_y_ = None
|
||||
return (a, b)
|
||||
""",
|
||||
)
|
||||
|
||||
# High priority - grads are wrong
|
||||
@unittest.expectedFailure
|
||||
def test_grad_accuracy_check(self):
|
||||
class Foo:
|
||||
def __init__(self, a, b):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
a = torch.sin(x)
|
||||
b = torch.cos(x)
|
||||
return (a, b)
|
||||
|
||||
def fn(x):
|
||||
foo1 = gn(x)
|
||||
foo2 = gn(foo1[0])
|
||||
return foo1[1] + foo2[0] + foo2[1]
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
|
||||
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
|
||||
|
||||
x = torch.randn(8, 8, requires_grad=True)
|
||||
x_clone = x.detach().clone().requires_grad_(True)
|
||||
x.grad = None
|
||||
x_clone.grad = None
|
||||
|
||||
ref = fn(x)
|
||||
res = opt_fn(x_clone)
|
||||
|
||||
ref.sum().backward()
|
||||
res.sum().backward()
|
||||
|
||||
self.assertEqual(ref, res)
|
||||
self.assertEqual(x.grad, x_clone.grad)
|
||||
|
||||
|
||||
@skipIfTorchDynamo("Not a torch._dynamo test")
|
||||
@parameterized_class(
|
||||
|
||||
@ -286,31 +286,47 @@ class GraphModule(torch.nn.Module):
|
||||
l_self_modules_wo_parameters_weight_ = L_self_modules_wo_parameters_weight_
|
||||
l_self_modules_w1_parameters_weight_ = L_self_modules_w1_parameters_weight_
|
||||
l_self_modules_w2_parameters_weight_ = L_self_modules_w2_parameters_weight_
|
||||
|
||||
q: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wq_parameters_weight_, None); l_self_modules_wq_parameters_weight_ = None
|
||||
|
||||
k: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wk_parameters_weight_, None); l_self_modules_wk_parameters_weight_ = None
|
||||
|
||||
v: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wv_parameters_weight_, None); l_self_modules_wv_parameters_weight_ = None
|
||||
|
||||
unflatten: "f32[8, 16, 16, 6]" = q.unflatten(-1, (16, -1)); q = None
|
||||
q_1: "f32[8, 16, 16, 6]" = unflatten.permute(0, 2, 1, 3); unflatten = None
|
||||
|
||||
unflatten_1: "f32[8, 16, 16, 6]" = k.unflatten(-1, (16, -1)); k = None
|
||||
k_1: "f32[8, 16, 16, 6]" = unflatten_1.permute(0, 2, 1, 3); unflatten_1 = None
|
||||
|
||||
unflatten_2: "f32[8, 16, 16, 6]" = v.unflatten(-1, (16, -1)); v = None
|
||||
v_1: "f32[8, 16, 16, 6]" = unflatten_2.permute(0, 2, 1, 3); unflatten_2 = None
|
||||
|
||||
subgraph_0 = self.subgraph_0
|
||||
local_map_hop = torch.ops.higher_order.local_map_hop(subgraph_0, q_1, k_1, v_1); subgraph_0 = q_1 = k_1 = v_1 = None
|
||||
getitem: "f32[8, 16, 16, 6]" = local_map_hop[0]; local_map_hop = None
|
||||
permute_3: "f32[8, 16, 16, 6]" = getitem.permute(0, 2, 1, 3); getitem = None
|
||||
o: "f32[8, 16, 96]" = permute_3.flatten(-2); permute_3 = None
|
||||
o_1: "f32[8, 16, 96]" = torch._C._nn.linear(o, l_self_modules_wo_parameters_weight_, None); o = l_self_modules_wo_parameters_weight_ = None
|
||||
o0: "f32[8, 16, 96]" = o_1 + l_x_; o_1 = l_x_ = None
|
||||
o_2: "f32[8, 16, 384]" = torch._C._nn.linear(o0, l_self_modules_w1_parameters_weight_, None); l_self_modules_w1_parameters_weight_ = None
|
||||
o_3: "f32[8, 16, 384]" = torch.nn.functional.relu(o_2); o_2 = None
|
||||
o_4: "f32[8, 16, 96]" = torch._C._nn.linear(o_3, l_self_modules_w2_parameters_weight_, None); o_3 = l_self_modules_w2_parameters_weight_ = None
|
||||
o_5: "f32[8, 16, 96]" = o0 + o_4; o0 = o_4 = None
|
||||
return (o_5,)
|
||||
o: "f32[8, 16, 16, 6]" = local_map_hop[0]; local_map_hop = None
|
||||
|
||||
permute_3: "f32[8, 16, 16, 6]" = o.permute(0, 2, 1, 3); o = None
|
||||
o_1: "f32[8, 16, 96]" = permute_3.flatten(-2); permute_3 = None
|
||||
|
||||
o_2: "f32[8, 16, 96]" = torch._C._nn.linear(o_1, l_self_modules_wo_parameters_weight_, None); o_1 = l_self_modules_wo_parameters_weight_ = None
|
||||
|
||||
o0: "f32[8, 16, 96]" = o_2 + l_x_; o_2 = l_x_ = None
|
||||
|
||||
o_3: "f32[8, 16, 384]" = torch._C._nn.linear(o0, l_self_modules_w1_parameters_weight_, None); l_self_modules_w1_parameters_weight_ = None
|
||||
|
||||
o_4: "f32[8, 16, 384]" = torch.nn.functional.relu(o_3); o_3 = None
|
||||
|
||||
o_5: "f32[8, 16, 96]" = torch._C._nn.linear(o_4, l_self_modules_w2_parameters_weight_, None); o_4 = l_self_modules_w2_parameters_weight_ = None
|
||||
|
||||
o_6: "f32[8, 16, 96]" = o0 + o_5; o0 = o_5 = None
|
||||
return (o_6,)
|
||||
|
||||
class subgraph_0(torch.nn.Module):
|
||||
def forward(self, q_1: "f32[1, 2, 4, 6]", k_1: "f32[1, 2, 16, 6]", v_1: "f32[1, 2, 16, 6]"):
|
||||
out: "f32[1, 2, 4, 6]" = torch._C._nn.scaled_dot_product_attention(query = q_1, key = k_1, value = v_1, is_causal = False); q_1 = k_1 = v_1 = None
|
||||
return (out,)""",
|
||||
return (out,)
|
||||
""",
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
|
||||
@ -18,16 +18,15 @@ from functorch.compile import (
|
||||
nop,
|
||||
)
|
||||
from torch._functorch.aot_autograd import aot_export_module
|
||||
from torch._higher_order_ops.effects import (
|
||||
_EffectType,
|
||||
_get_effect,
|
||||
_register_effectful_op,
|
||||
with_effects,
|
||||
)
|
||||
from torch._higher_order_ops.effects import with_effects
|
||||
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_cuda import SM70OrLater, SM80OrLater
|
||||
from torch.testing._internal.common_cuda import (
|
||||
_get_torch_cuda_version,
|
||||
SM70OrLater,
|
||||
SM80OrLater,
|
||||
)
|
||||
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
@ -301,6 +300,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
@unittest.skipIf(IS_WINDOWS, "triton")
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "triton")
|
||||
@unittest.skipIf(not SM80OrLater, "triton")
|
||||
@unittest.skipIf(_get_torch_cuda_version() >= (11, 7), "triton")
|
||||
@unittest.skipIf(not TEST_CUDA, "triton")
|
||||
@skipIfNoDynamoSupport
|
||||
def test_register_effectful_custom_op(self):
|
||||
@ -308,23 +308,41 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
torch._dynamo.config.capture_scalar_outputs = True
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
|
||||
torch.library.define(
|
||||
"mylib::record_scalar_tensor",
|
||||
"(Tensor x, str prefix) -> ()",
|
||||
lib=lib,
|
||||
)
|
||||
|
||||
# global variable to store the recorded tensor and prefix.
|
||||
recorded_dict = {}
|
||||
|
||||
# Pytorch custom op implementation
|
||||
@torch.library.custom_op("mylib::record_scalar_tensor", mutates_args=())
|
||||
def record_scalar_tensor(x: torch.Tensor, prefix: str) -> None:
|
||||
# PyTorch custom op implementation
|
||||
@torch.library.impl(
|
||||
"mylib::record_scalar_tensor",
|
||||
"CompositeExplicitAutograd",
|
||||
lib=lib,
|
||||
)
|
||||
def record_scalar_tensor(x, prefix):
|
||||
recorded_dict[prefix] = x.clone()
|
||||
return
|
||||
|
||||
# Meta function of the custom op
|
||||
@record_scalar_tensor.register_fake
|
||||
@torch.library.register_fake(
|
||||
"mylib::record_scalar_tensor",
|
||||
lib=lib,
|
||||
)
|
||||
def record_scalar_tensor_meta(x, prefix):
|
||||
return
|
||||
|
||||
record_scalar_tensor.register_effect(_EffectType.ORDERED)
|
||||
from torch._higher_order_ops.effects import (
|
||||
_EffectType,
|
||||
_register_effectful_op,
|
||||
)
|
||||
|
||||
self.assertEqual(_get_effect(record_scalar_tensor), _EffectType.ORDERED)
|
||||
_register_effectful_op(
|
||||
torch.ops.mylib.record_scalar_tensor.default, _EffectType.ORDERED
|
||||
)
|
||||
|
||||
my_config = {}
|
||||
my_config["MockModule"] = "mean"
|
||||
@ -451,13 +469,14 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
|
||||
torch.library.register_autograd("_mylib::zoo", foo_bwd, lib=lib)
|
||||
|
||||
torch.library._register_effectful_op(
|
||||
torch.ops._mylib.zoo.default, _EffectType.ORDERED
|
||||
)
|
||||
torch.library._register_effectful_op(
|
||||
torch.ops._mylib.zoo2.default, _EffectType.ORDERED
|
||||
from torch._higher_order_ops.effects import (
|
||||
_EffectType,
|
||||
_register_effectful_op,
|
||||
)
|
||||
|
||||
_register_effectful_op(torch.ops._mylib.zoo.default, _EffectType.ORDERED)
|
||||
_register_effectful_op(torch.ops._mylib.zoo2.default, _EffectType.ORDERED)
|
||||
|
||||
def fn(x, y):
|
||||
return torch.ops._mylib.zoo(x) + y
|
||||
|
||||
@ -668,13 +687,13 @@ def forward(self, arg0_1, arg1_1):
|
||||
|
||||
torch.library.register_autograd("_mylib::foo", foo_bwd, lib=lib)
|
||||
|
||||
handle = _register_effectful_op(
|
||||
torch.ops._mylib.foo.default, _EffectType.ORDERED
|
||||
)
|
||||
self.assertEqual(
|
||||
_get_effect(torch.ops._mylib.foo.default), _EffectType.ORDERED
|
||||
from torch._higher_order_ops.effects import (
|
||||
_deregister_effectful_op,
|
||||
_EffectType,
|
||||
_register_effectful_op,
|
||||
)
|
||||
|
||||
_register_effectful_op(torch.ops._mylib.foo.default, _EffectType.ORDERED)
|
||||
try:
|
||||
|
||||
def fn(x, y):
|
||||
@ -760,13 +779,17 @@ def forward(self, tangents_1, tangents_2, tangents_token):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
finally:
|
||||
handle.destroy()
|
||||
|
||||
self.assertEqual(_get_effect(torch.ops._mylib.foo.default), None)
|
||||
_deregister_effectful_op(torch.ops._mylib.foo.default)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
def test_regular_effectful_op_only_in_backward(self):
|
||||
handle = _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
|
||||
from torch._higher_order_ops.effects import (
|
||||
_deregister_effectful_op,
|
||||
_EffectType,
|
||||
_register_effectful_op,
|
||||
)
|
||||
|
||||
_register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
|
||||
try:
|
||||
|
||||
def fn(x):
|
||||
@ -829,11 +852,17 @@ def forward(self, primals_1, primals_2, tangents_1, tangents_2, tangents_token):
|
||||
return (mul, mul_1, getitem_2)""",
|
||||
)
|
||||
finally:
|
||||
handle.destroy()
|
||||
_deregister_effectful_op(torch.ops.aten.cos.default)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
def test_regular_effectful_op_in_forward_and_backward(self):
|
||||
handle = _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
|
||||
from torch._higher_order_ops.effects import (
|
||||
_deregister_effectful_op,
|
||||
_EffectType,
|
||||
_register_effectful_op,
|
||||
)
|
||||
|
||||
_register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
|
||||
try:
|
||||
|
||||
def fn(x):
|
||||
@ -868,7 +897,7 @@ def forward(self, primals_2, getitem_1, tangents_1, tangents_token):
|
||||
return (mul_1, getitem_2)""",
|
||||
)
|
||||
finally:
|
||||
handle.destroy()
|
||||
_deregister_effectful_op(torch.ops.aten.cos.default)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -136,59 +136,12 @@ class TestStandaloneInductor(TestCase):
|
||||
mod_opt = inductor.compile(mod, inp)
|
||||
self.assertEqual(mod(*inp), mod_opt(*inp))
|
||||
|
||||
@mock.patch.dict(os.environ, {"TORCHINDUCTOR_DEBUG_COMPILE": "1"})
|
||||
def test_inductor_generate_debug_compile(self):
|
||||
cpp_code = """
|
||||
int main(){
|
||||
return 0;
|
||||
}
|
||||
"""
|
||||
|
||||
_, source_path = write(
|
||||
cpp_code,
|
||||
"cpp",
|
||||
)
|
||||
build_option = CppOptions()
|
||||
cpp_builder = CppBuilder(
|
||||
name="test_compile",
|
||||
sources=source_path,
|
||||
output_dir=os.path.dirname(source_path),
|
||||
BuildOption=build_option,
|
||||
)
|
||||
cpp_builder.build()
|
||||
binary_path = cpp_builder.get_target_file_path()
|
||||
|
||||
"""
|
||||
When debug compile is turned on:
On Windows, it should create a [module_name].pdb file, which helps debugging with WinDBG.
On Linux, it should create debug sections in the binary file.
|
||||
"""
|
||||
|
||||
def check_linux_debug_section(module_path: str):
|
||||
check_cmd = shlex.split(f"readelf -S {module_path}")
|
||||
output = safe_command_output(check_cmd)
|
||||
has_debug_sym = ".debug_info" in output
|
||||
self.assertEqual(has_debug_sym, True)
|
||||
|
||||
def check_windows_pdb_exist(module_path: str):
|
||||
file_name_no_ext = os.path.splitext(module_path)[0]
|
||||
file_name_pdb = f"{file_name_no_ext}.pdb"
|
||||
has_pdb_file = os.path.exists(file_name_pdb)
|
||||
self.assertEqual(has_pdb_file, True)
|
||||
|
||||
if _IS_WINDOWS:
|
||||
check_windows_pdb_exist(binary_path)
|
||||
elif _IS_MACOS:
|
||||
pass # macOS: unclear whether this is expected to work.
|
||||
else:
|
||||
check_linux_debug_section(binary_path)
|
||||
|
||||
@mock.patch.dict(os.environ, {"TORCHINDUCTOR_DEBUG_SYMBOL": "1"})
|
||||
def test_inductor_generate_debug_symbol(self):
|
||||
cpp_code = """
|
||||
int main(){
|
||||
return 0;
|
||||
}
|
||||
int main(){
|
||||
return 0;
|
||||
}
|
||||
"""
|
||||
|
||||
_, source_path = write(
|
||||
|
||||
@ -4902,21 +4902,6 @@ class CommonTemplate:
|
||||
(torch.randn(2, 4, 6, 6),),
|
||||
)
|
||||
|
||||
@skip_if_gpu_halide
|
||||
@xfail_if_mps
|
||||
@config.patch(combo_kernels=True)
|
||||
def test_combo_kernel_cpu(self):
|
||||
def fn(x):
|
||||
return aten._adaptive_avg_pool2d(x, (6, 6)), aten._adaptive_avg_pool2d(
|
||||
x + 1, (2, 5)
|
||||
)
|
||||
|
||||
self.common(
|
||||
fn,
|
||||
(torch.randn(2, 4, 16, 16),),
|
||||
check_lowp=False,
|
||||
)
|
||||
|
||||
@xfail_if_mps # Non-divisible input sizes are not implemented on MPS device
|
||||
def test_adaptive_avg_pool2d2(self):
|
||||
# Big kernel size, use fallback
|
||||
|
||||
@ -90,7 +90,7 @@ class TestOpaqueObject(TestCase):
|
||||
# This is not accurate since the queue could have tensors that are
|
||||
# not rank 1
|
||||
ctx = torch._custom_op.impl.get_ctx()
|
||||
u0 = ctx.new_dynamic_size()
|
||||
u0 = ctx.create_unbacked_symint()
|
||||
return torch.empty(u0)
|
||||
|
||||
self.lib._register_fake("queue_pop", pop_impl_fake)
|
||||
@ -107,7 +107,8 @@ class TestOpaqueObject(TestCase):
|
||||
@size_impl.register_fake
|
||||
def size_impl_fake(q: torch._C.ScriptObject) -> int:
|
||||
ctx = torch._custom_op.impl.get_ctx()
|
||||
u0 = ctx.new_dynamic_size()
|
||||
u0 = ctx.create_unbacked_symint()
|
||||
torch._check_is_size(u0)
|
||||
return u0
|
||||
|
||||
super().setUp()
|
||||
|
||||
@ -1,22 +1,12 @@
|
||||
# Owner(s): ["module: custom-operators"]
|
||||
|
||||
import random
|
||||
from contextlib import ExitStack
|
||||
|
||||
import torch
|
||||
from torch._dynamo.test_case import run_tests, TestCase
|
||||
from torch._dynamo.testing import AotEagerAndRecordGraphs
|
||||
from torch._functorch.aot_autograd import (
|
||||
aot_compile_joint_with_descriptors,
|
||||
aot_export_joint_with_descriptors,
|
||||
aot_export_module,
|
||||
)
|
||||
from torch._library.effects import EffectType
|
||||
from torch._library.fake_class_registry import FakeScriptObject
|
||||
from torch._library.opaque_object import register_opaque_type
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
@ -51,21 +41,11 @@ class OpaqueQueue:
|
||||
|
||||
class RNGState:
|
||||
def __init__(self, seed):
|
||||
self.seed = seed
|
||||
self.rng = random.Random(self.seed)
|
||||
|
||||
|
||||
class Counter:
|
||||
def __init__(self, start):
|
||||
self.counter = torch.tensor(start)
|
||||
|
||||
def increment_counter(self):
|
||||
self.counter += 1
|
||||
self.rng = random.Random(seed)
|
||||
|
||||
|
||||
register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
|
||||
register_opaque_type(RNGState, "_TestOpaqueObject_RNGState")
|
||||
register_opaque_type(Counter, "_TestOpaqueObject_Counter")
|
||||
|
||||
|
||||
class TestOpaqueObject(TestCase):
|
||||
@ -145,20 +125,6 @@ class TestOpaqueObject(TestCase):
|
||||
def noisy_inject_fake(x: torch.Tensor, obj: RNGState) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
@torch.library.custom_op(
|
||||
"_TestOpaqueObject::increment_counter",
|
||||
mutates_args=["prev"],
|
||||
)
|
||||
def increment_counter_impl(c: Counter, prev: torch.Tensor) -> torch.Tensor:
|
||||
assert isinstance(c, Counter)
|
||||
prev.copy_(c.counter)
|
||||
c.increment_counter()
|
||||
return c.counter
|
||||
|
||||
@increment_counter_impl.register_fake
|
||||
def increment_counter_fake(c: Counter, prev: torch.Tensor) -> torch.Tensor:
|
||||
return torch.empty_like(prev)
|
||||
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
@ -267,235 +233,6 @@ def forward(self, arg0_1, arg1_1):
|
||||
):
|
||||
make_fx(f, tracing_mode=make_fx_tracing_mode)(RNGState(0), torch.ones(3))
|
||||
|
||||
def test_aot_export(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, rng_state, x):
|
||||
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
|
||||
x = x * x
|
||||
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
|
||||
x = x + x
|
||||
return (x,)
|
||||
|
||||
mod = Model()
|
||||
rng = RNGState(0)
|
||||
x = torch.ones(2, 3)
|
||||
|
||||
fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
|
||||
fake_rng = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, rng)
|
||||
fake_x = fake_mode.from_tensor(x)
|
||||
gm = aot_export_module(mod, (fake_rng, fake_x), trace_joint=False)[0]
|
||||
|
||||
# By default we don't register ops containing PyObjs as being effectful
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1):
|
||||
noisy_inject = torch.ops._TestOpaqueObject.noisy_inject.default(arg1_1, arg0_1); arg1_1 = None
|
||||
mul = torch.ops.aten.mul.Tensor(noisy_inject, noisy_inject); noisy_inject = None
|
||||
noisy_inject_1 = torch.ops._TestOpaqueObject.noisy_inject.default(mul, arg0_1); mul = arg0_1 = None
|
||||
add = torch.ops.aten.add.Tensor(noisy_inject_1, noisy_inject_1); noisy_inject_1 = None
|
||||
return (add,)""", # noqa: B950
|
||||
)
|
||||
|
||||
torch.library._register_effectful_op(
|
||||
"_TestOpaqueObject::noisy_inject", EffectType.ORDERED
|
||||
)
|
||||
try:
|
||||
gm = aot_export_module(mod, (rng, fake_x), trace_joint=False)[0]
|
||||
# inputs: token, rng, x
|
||||
# return: token, res
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops._TestOpaqueObject.noisy_inject.default, arg2_1, arg1_1); arg0_1 = arg2_1 = None
|
||||
getitem = with_effects[0]
|
||||
getitem_1 = with_effects[1]; with_effects = None
|
||||
mul = torch.ops.aten.mul.Tensor(getitem_1, getitem_1); getitem_1 = None
|
||||
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TestOpaqueObject.noisy_inject.default, mul, arg1_1); getitem = mul = arg1_1 = None
|
||||
getitem_2 = with_effects_1[0]
|
||||
getitem_3 = with_effects_1[1]; with_effects_1 = None
|
||||
add = torch.ops.aten.add.Tensor(getitem_3, getitem_3); getitem_3 = None
|
||||
return (getitem_2, add)""", # noqa: B950
|
||||
)
|
||||
finally:
|
||||
torch.library._register_effectful_op(
|
||||
"_TestOpaqueObject::noisy_inject", None
|
||||
)
|
||||
|
||||
def test_compile(self):
|
||||
def foo(rng_state, x):
|
||||
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
|
||||
x = x * x
|
||||
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
|
||||
x = x + x
|
||||
return x
|
||||
|
||||
rng = RNGState(0)
|
||||
x = torch.ones(2, 3)
|
||||
|
||||
res = torch.compile(foo, fullgraph=True, backend="inductor")(rng, x)
|
||||
self.assertFalse(torch.allclose(res, x * x + x))
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
torch.compile(foo, fullgraph=True, backend=backend)(rng, x)
|
||||
self.assertExpectedInline(
|
||||
backend.graphs[0].code.strip(),
|
||||
"""\
|
||||
def forward(self, L_x_ : torch.Tensor, L_rng_state_ : __main___RNGState):
|
||||
l_x_ = L_x_
|
||||
l_rng_state_ = L_rng_state_
|
||||
x = torch.ops._TestOpaqueObject.noisy_inject(l_x_, l_rng_state_); l_x_ = None
|
||||
x_1 = x * x; x = None
|
||||
x_2 = torch.ops._TestOpaqueObject.noisy_inject(x_1, l_rng_state_); x_1 = l_rng_state_ = None
|
||||
x_3 = x_2 + x_2; x_2 = None
|
||||
return (x_3,)""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
backend.fw_graphs[0].code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1):
|
||||
noisy_inject = torch.ops._TestOpaqueObject.noisy_inject.default(arg0_1, arg1_1); arg0_1 = None
|
||||
mul = torch.ops.aten.mul.Tensor(noisy_inject, noisy_inject); noisy_inject = None
|
||||
noisy_inject_1 = torch.ops._TestOpaqueObject.noisy_inject.default(mul, arg1_1); mul = arg1_1 = None
|
||||
add = torch.ops.aten.add.Tensor(noisy_inject_1, noisy_inject_1); noisy_inject_1 = None
|
||||
return (add,)""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_compile_intermediate(self):
|
||||
counter = Counter(0)
|
||||
|
||||
def foo(x, y):
|
||||
z = torch.ops._TestOpaqueObject.increment_counter(counter, y)
|
||||
x = x * z
|
||||
z = torch.ops._TestOpaqueObject.increment_counter(counter, y)
|
||||
x = x + z
|
||||
return x, counter
|
||||
|
||||
inp = (torch.tensor(1), torch.tensor(0))
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_f = torch.compile(foo, fullgraph=True, backend=backend)
|
||||
res = opt_f(*inp)
|
||||
self.assertEqual(res[0], torch.tensor(3))
|
||||
self.assertEqual(res[1].counter, torch.tensor(2))
|
||||
|
||||
res = opt_f(*inp)
|
||||
self.assertEqual(res[0], torch.tensor(7))
|
||||
self.assertEqual(res[1].counter, torch.tensor(4))
|
||||
|
||||
# counter is automatically lifted as an input
|
||||
# Even though we returned counter in the eager code, it does not get
|
||||
# returned in the graph because dynamo does not detect that the object
|
||||
# is mutated.
|
||||
self.assertExpectedInline(
|
||||
backend.fw_graphs[0].code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops._TestOpaqueObject.increment_counter.default, c = arg1_1, _prev_base_index = 0, _all_bases = [arg0_1])
|
||||
getitem = auto_functionalized_v2[0]
|
||||
getitem_1 = auto_functionalized_v2[1]; auto_functionalized_v2 = None
|
||||
mul = torch.ops.aten.mul.Tensor(arg2_1, getitem); arg2_1 = getitem = None
|
||||
auto_functionalized_v2_1 = torch.ops.higher_order.auto_functionalized_v2(torch.ops._TestOpaqueObject.increment_counter.default, c = arg1_1, _prev_base_index = 0, _all_bases = [getitem_1]); arg1_1 = getitem_1 = None
|
||||
getitem_2 = auto_functionalized_v2_1[0]
|
||||
getitem_3 = auto_functionalized_v2_1[1]; auto_functionalized_v2_1 = None
|
||||
add = torch.ops.aten.add.Tensor(mul, getitem_2); mul = getitem_2 = None
|
||||
copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_3); arg0_1 = getitem_3 = copy_ = None
|
||||
return (add,)""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_compile_attribute(self):
|
||||
counter = Counter(0)
|
||||
|
||||
def foo(counter, x):
|
||||
x = x * x
|
||||
counter.increment_counter()
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Attempted to access attributes/methods on an OpaqueObject"
|
||||
):
|
||||
torch.compile(foo)(counter, torch.ones(2, 3))
|
||||
|
||||
def bar(counter, x):
|
||||
x = x * x
|
||||
x += counter.counter
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Attempted to access attributes/methods on an OpaqueObject"
|
||||
):
|
||||
torch.compile(bar)(counter, torch.ones(2, 3))
|
||||
|
||||
def test_export_joint(self):
|
||||
class Moo(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
return x * y
|
||||
|
||||
register_opaque_type(Moo, "_TestOpaqueObject_Moo")
|
||||
|
||||
torch.library.define(
|
||||
"_TestOpaqueObject::module_mul",
|
||||
"(_TestOpaqueObject_Moo a, Tensor b, SymInt c) -> Tensor",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=self.lib,
|
||||
)
|
||||
|
||||
@torch.library.impl(
|
||||
"_TestOpaqueObject::module_mul", "CompositeExplicitAutograd", lib=self.lib
|
||||
)
|
||||
def module_mul_impl(m: Moo, a: torch.Tensor, b: int) -> torch.Tensor:
|
||||
assert isinstance(m, Moo)
|
||||
return m(a, b)
|
||||
|
||||
@torch.library.register_fake("_TestOpaqueObject::module_mul", lib=self.lib)
|
||||
def module_mul_fake(m: Moo, a: torch.Tensor, b: int) -> torch.Tensor:
|
||||
return torch.empty_like(a)
|
||||
|
||||
def module_mul_setup_context(ctx, inputs, output):
|
||||
m, a, b = inputs
|
||||
ctx.b = b
|
||||
|
||||
def module_mul_backward(ctx, grad) -> torch.Tensor:
|
||||
return None, grad * ctx.b, None
|
||||
|
||||
torch.library.register_autograd(
|
||||
"_TestOpaqueObject::module_mul",
|
||||
module_mul_backward,
|
||||
setup_context=module_mul_setup_context,
|
||||
lib=self.lib,
|
||||
)
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.moo = Moo()
|
||||
|
||||
def forward(self, x, y):
|
||||
b = y.item()
|
||||
return torch.ops._TestOpaqueObject.module_mul(self.moo, x, b)
|
||||
|
||||
inp = (torch.randn(3, requires_grad=True), torch.tensor(4))
|
||||
with ExitStack() as stack:
|
||||
with FakeTensorMode(shape_env=ShapeEnv()):
|
||||
joint = aot_export_joint_with_descriptors(stack, M(), inp)
|
||||
self.assertExpectedInline(
|
||||
joint.graph_module.code.strip(),
|
||||
"""\
|
||||
def forward(self, primals, tangents):
|
||||
primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
|
||||
_local_scalar_dense = torch.ops.aten._local_scalar_dense.default(primals_2); primals_2 = None
|
||||
_opaque_obj0 = self._opaque_obj0
|
||||
module_mul = torch.ops._TestOpaqueObject.module_mul.default(_opaque_obj0, primals_1, _local_scalar_dense); _opaque_obj0 = primals_1 = None
|
||||
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, _local_scalar_dense); tangents_1 = _local_scalar_dense = None
|
||||
return pytree.tree_unflatten([module_mul, mul_1, None], self._out_spec)""", # noqa: B950
|
||||
)
|
||||
compiled_fn = aot_compile_joint_with_descriptors(joint)
|
||||
|
||||
self.assertEqual(compiled_fn(*inp), M()(*inp))
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestOpaqueObject)
|
||||
|
||||
|
||||
@ -796,27 +796,6 @@ def forward(self, x_1):
|
||||
|
||||
self._test(f, [torch.randn(1, 10), torch.zeros(1, dtype=torch.long)])
|
||||
|
||||
@unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
|
||||
def test_T244632748(self):
|
||||
class TestModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x + (x.shape[0] * 2)
|
||||
|
||||
mod = TestModule()
|
||||
sample = torch.randn((5, 5)).to("cuda")
|
||||
dim0 = torch.export.Dim.DYNAMIC(max=100)
|
||||
dynamic_shapes = {"x": (dim0, torch.export.Dim.STATIC)}
|
||||
ep = torch.export.export(mod, (sample,), dynamic_shapes=dynamic_shapes)
|
||||
gm = ep.module()
|
||||
symint = list(gm.graph.nodes)[3].meta["val"]
|
||||
list(gm.graph.nodes)[3].replace_all_uses_with(symint)
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
||||
inductor_fx = torch._inductor.aot_compile(
|
||||
gm, (sample,), options={"fx_wrapper": True, "compile_threads": 1}
|
||||
)
|
||||
|
||||
|
||||
class TestGenericProxyTensorReal(TestGenericProxyTensor):
|
||||
tracing_mode = "real"
|
||||
|
||||
|
||||
@ -2439,35 +2439,6 @@ class _TorchCompileInductorWrapper:
|
||||
reset_cudagraph_trees()
|
||||
|
||||
|
||||
class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper):
|
||||
compiler_name = "aotinductor"
|
||||
|
||||
def __init__(self, mode, options, dynamic):
|
||||
super().__init__(mode, options, dynamic)
|
||||
self.apply_options({"cpp_wrapper": True})
|
||||
self.apply_options({"aot_inductor.package": True})
|
||||
|
||||
def __call__(self, model_, inputs_):
|
||||
from contextlib import nullcontext
|
||||
from unittest import mock
|
||||
|
||||
from torch._guards import detect_fake_mode
|
||||
from torch._inductor.virtualized import V
|
||||
|
||||
fake_mode = detect_fake_mode(inputs_)
|
||||
ctx = (
|
||||
mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
|
||||
if fake_mode
|
||||
else nullcontext()
|
||||
)
|
||||
with (
|
||||
V.set_aot_compilation(True),
|
||||
ctx,
|
||||
torch._inductor.config.patch("enable_autograd_for_aot", True),
|
||||
):
|
||||
return super().__call__(model_, inputs_)
|
||||
|
||||
|
||||
class _TorchCompileWrapper:
|
||||
def __init__(self, backend, mode, options, dynamic):
|
||||
from torch._dynamo.backends.registry import lookup_backend
|
||||
@ -2701,10 +2672,8 @@ def compile(
|
||||
backend = bisect_backend
|
||||
|
||||
guard_filter_fn = None
|
||||
use_aoti = False
|
||||
if options and isinstance(options, dict):
|
||||
guard_filter_fn = options.pop("guard_filter_fn", None)
|
||||
use_aoti = options.pop("use_aoti", False)
|
||||
|
||||
if torch.compiler.is_exporting():
|
||||
warnings.warn(
|
||||
@ -2731,10 +2700,7 @@ def compile(
|
||||
return export_wrapped_fn
|
||||
|
||||
if backend == "inductor":
|
||||
if use_aoti:
|
||||
backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic)
|
||||
else:
|
||||
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
|
||||
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
|
||||
else:
|
||||
backend = _TorchCompileWrapper(backend, mode, options, dynamic)
|
||||
|
||||
@ -53,7 +53,6 @@ class CompileArtifacts:
    argdefs: Optional[tuple[Any, ...]]
    source_info: "SourceInfo"
    device_type: str
    backend_name: str
    system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)

    def check_compatibility(self) -> None:
@ -274,7 +273,6 @@ def aot_compile_fullgraph(
        argdefs=fn.__defaults__,
        source_info=source_info,
        device_type=device_type,
        backend_name=getattr(backend, "compiler_name", "unknown"),
    )
    aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)

@ -3657,15 +3657,5 @@
            "Explanation": "Encountered triton kernel unsupported feature: {msg}",
            "Hints": []
        }
    ],
    "GB0362": [
        {
            "Gb_type": "Attempted to access attributes/methods on an OpaqueObject",
            "Context": "value={self.value}, attr={name}",
            "Explanation": "Attribute/method access of OpaqueObjects is not supported.",
            "Hints": [
                "Use custom operators instead of direct attribute/method access."
            ]
        }
    ]
}

@ -56,7 +56,6 @@ from torch._guards import (
    tracing,
    TracingContext,
)
from torch._library.opaque_object import is_opaque_type
from torch._subclasses.fake_tensor import FakeTensor
from torch._utils_internal import signpost_event
from torch.export.dynamic_shapes import _ConstraintTarget
@ -2606,8 +2605,6 @@ class OutputGraph(OutputGraphCommon):
                        fake_attr_val,
                    )
                    continue
                if is_opaque_type(type(node.meta["grapharg"].example)):
                    continue
                fake = (
                    arg.fake_tensor if arg.fake_tensor is not None else arg.example
                )

@ -58,7 +58,6 @@ from torch._dynamo.utils import (
from torch._guards import TracingContext
from torch._higher_order_ops.flat_apply import flat_apply
from torch._higher_order_ops.torchbind import call_torchbind
from torch._library.opaque_object import is_opaque_type
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
from torch._subclasses.meta_utils import is_sparse_any, safe_grad
@ -1453,32 +1452,27 @@ class VariableBuilder:
                source=self.source,
            )

        if is_opaque_type(type(value)):
            self.install_guards(GuardBuilder.TYPE_MATCH)

        elif not hasattr(value, "__obj_flatten__"):
            # This exists to allow a smoother transition.
            # The implications are:
            # The script objects won't be tracked as proxies.
            # Methods on these objects won't show up in the graph.
            # The original script object might be mutated.
        # This exists to allow a smoother transition.
        # The implications are:
        # The script objects won't be tracked as proxies.
        # Methods on these objects won't show up in the graph.
        # The original script object might be mutated.
        if not hasattr(value, "__obj_flatten__"):
            return self.wrap_user_defined(value)
        else:
            # Install the guards on the fully qualified name of the script object
            LazyVariableTracker.realize_all(
                VariableBuilder(
                    self.tx, ScriptObjectQualifiedNameSource(self.source)
                )(
                    value._type().qualified_name()  # type: ignore[attr-defined]
                )

        # Install the guards on the fully qualified name of the script object
        LazyVariableTracker.realize_all(
            VariableBuilder(self.tx, ScriptObjectQualifiedNameSource(self.source))(
                value._type().qualified_name()  # type: ignore[attr-defined]
            )
            # Install the guards on the content of the script object by setting the source
            # to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
            LazyVariableTracker.realize_all(
                VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
                    value.__obj_flatten__()
                )
            )
        # Install the guards on the content of the script object by setting the source
        # to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
        LazyVariableTracker.realize_all(
            VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
                value.__obj_flatten__()
            )
        )

        fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj(
            self.tx.output.fake_mode, value

File diff suppressed because it is too large

@ -25,7 +25,6 @@ from typing_extensions import ParamSpec

import torch
from torch._guards import Source
from torch._library.opaque_object import is_opaque_type, OpaqueTypeStr
from torch.fx.proxy import Proxy

from .. import graph_break_hints
@ -62,7 +61,7 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):

    @classmethod
    def is_matching_cls(cls, user_cls: type) -> bool:
        return issubclass(user_cls, torch.ScriptObject) or is_opaque_type(user_cls)
        return issubclass(user_cls, torch.ScriptObject)

    @staticmethod
    def create(proxy: Proxy, value: Any, **options: Any) -> "TorchScriptObjectVariable":
@ -81,16 +80,6 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
            "Dynamo cannot safely trace script object due to graph break."
        )
    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        if getattr(self.value, "script_class_name", "") == OpaqueTypeStr:
            unimplemented(
                gb_type="Attempted to access attributes/methods on an OpaqueObject",
                context=f"value={self.value}, attr={name}",
                explanation="Attribute/method access of OpaqueObjects is not supported.",
                hints=[
                    "Use custom operators instead of direct attribute/method access.",
                ],
            )

        from torch._higher_order_ops.torchbind import call_torchbind

        from ..source import AttrSource
@ -24,7 +24,6 @@ from torch._export.passes.lift_constants_pass import ConstantAttrMap
from torch._export.utils import _fakify_params_buffers
from torch._guards import Source
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.opaque_object import is_opaque_type
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.export import Constraint
from torch.export.dynamic_shapes import (
@ -947,9 +946,7 @@ def _fakify_script_objects(

    try:
        for obj, fqns in constant_attrs.items():
            if torch._library.fake_class_registry._is_script_object(
                obj
            ) or is_opaque_type(obj):
            if torch._library.fake_class_registry._is_script_object(obj):
                fake_script_obj = _maybe_fakify_obj(obj)
                for fqn in fqns:
                    cur_mod, attr = _leaf_mod_and_attr(mod, fqn)

@ -511,7 +511,6 @@ class GenericAOTAutogradResult(Generic[TForward, TBackward]):
        ).post_compile(
            compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
        )
        compiled_fw_func._boxed_call = True
        disable_amp = torch._C._is_any_autocast_enabled()

        if needs_autograd:

@ -8,7 +8,6 @@ from typing import Any, Optional
import torch
import torch.utils._pytree as pytree
from torch._guards import detect_fake_mode
from torch._library.opaque_object import is_opaque_type
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info
from torch.fx.experimental.symbolic_shapes import ShapeEnv
@ -47,7 +46,7 @@ def process_inputs(
                hint=x,
                source=source,
            )
        if isinstance(x, torch.ScriptObject) or is_opaque_type(type(x)):
        if isinstance(x, torch.ScriptObject):
            return torch._library.fake_class_registry.maybe_to_fake_obj(
                fake_mode, x
            )
@ -779,13 +779,49 @@ def run_joint_graph_passes_on_hops(
        # TODO: invoke_subgraph should track which of its inputs static indices
        # so it can propagate them to the partitioner (and use in cudagraphs)
        static_lifetime_input_indices: list[int] = []

        partition_fn: Callable[
            ..., tuple[torch.fx.GraphModule, torch.fx.GraphModule]
        ] = aot_config.partition_fn

        used_hop_custom_partition = False
        # Use hop specific partitioner_fn
        if (
            fw_hop_node.target == torch._higher_order_ops.invoke_subgraph
            and "custom" in fw_hop_node.meta
            and "partitioner" in fw_hop_node.meta["custom"]
        ):
            hop_partition_fn = fw_hop_node.meta["custom"]["partitioner"]
            if callable(hop_partition_fn):
                partition_fn = hop_partition_fn  # pyrefly: ignore [bad-assignment]
                used_hop_custom_partition = True
            else:
                assert isinstance(hop_partition_fn, str)
                match hop_partition_fn:
                    case "default_partition":
                        partition_fn = torch._functorch.partitioners.default_partition
                    case "min_cut_rematerialization_partition":
                        partition_fn = torch._functorch.partitioners.min_cut_rematerialization_partition
                    case _:
                        raise ValueError(
                            f"Unknown HOP partitioner config: {hop_partition_fn}"
                        )

        # Step 2) and 3) - Run joint graph passes and partitioner
        new_fw_hop_gm, new_bw_hop_gm = aot_config.partition_fn(
            joint_hop_gm,
            [],
            num_fwd_outputs=num_fw_outputs,
            static_lifetime_input_indices=static_lifetime_input_indices,
        )
        try:
            new_fw_hop_gm, new_bw_hop_gm = partition_fn(
                joint_hop_gm,
                [],
                num_fwd_outputs=num_fw_outputs,
                static_lifetime_input_indices=static_lifetime_input_indices,
            )
        except Exception as e:
            if used_hop_custom_partition:
                raise RuntimeError(
                    f"Error in custom partition function for invoke_subgraph node {fw_hop_node.name}: {e}"
                ) from e
            else:
                raise

        # Save the new forward and backward graph modules
        new_hop_graphs[identifier].new_fw_hop_gm = new_fw_hop_gm
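Editor's note: a hedged sketch of the string-to-partitioner resolution the hunk above performs for invoke_subgraph nodes. It only restates the match statement in generic form and assumes the two partitioner names listed there; the helper name is hypothetical.

from torch._functorch import partitioners

def resolve_hop_partitioner(spec, fallback):
    # spec plays the role of fw_hop_node.meta["custom"]["partitioner"]:
    # a callable is used as-is, a known string maps onto torch._functorch.partitioners,
    # anything else is rejected the same way the hunk above rejects it.
    if spec is None:
        return fallback
    if callable(spec):
        return spec
    table = {
        "default_partition": partitioners.default_partition,
        "min_cut_rematerialization_partition": partitioners.min_cut_rematerialization_partition,
    }
    if spec not in table:
        raise ValueError(f"Unknown HOP partitioner config: {spec}")
    return table[spec]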
@ -534,7 +534,6 @@ def create_aot_state(
        stack.enter_context(autograd_fallback_mode("error"))

    from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj
    from torch._library.opaque_object import is_opaque_type

    # Tracing may mutate the states the fake script object,
    # so we need to duplicate the fake script objects so that subsequent tracing
@ -542,7 +541,7 @@ def create_aot_state(
    def _dup_fake_script_obj(fake_flat_args):
        return [
            maybe_to_fake_obj(detect_fake_mode(fake_flat_args), arg.real_obj)
            if isinstance(arg, FakeScriptObject) or is_opaque_type(type(arg))
            if isinstance(arg, FakeScriptObject)
            else arg
            for arg in fake_flat_args
        ]
@ -1,13 +1,13 @@
# mypy: allow-untyped-defs
from enum import Enum
from typing import Any, Optional, Union
from weakref import WeakKeyDictionary

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.torchbind import call_torchbind
from torch._library.custom_ops import CustomOpDef
from torch._library.effects import EffectType
from torch._library.utils import RegistrationHandle
from torch._library.fake_class_registry import FakeScriptObject
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
@ -17,50 +17,39 @@ from torch.fx.experimental.proxy_tensor import (
)


_op_identifier = Union[
    str,
    "torch._ops.OpOverload",
    "torch._library.custom_ops.CustomOpDef",
    "torch._ops.HigherOrderOperator",
]
OpType = Union["torch._ops.HigherOrderOperator", "torch._ops.OpOverload"]

_EffectType = EffectType
class _EffectType(Enum):
    ORDERED = "Ordered"


def _get_op_qualname(op: _op_identifier) -> str:
    """Convert an op identifier to a qualified string key."""
    if isinstance(op, torch._ops.OpOverload):
        return op._name
    elif isinstance(op, torch._ops.HigherOrderOperator):
        return f"{op.namespace}::{op.name()}"
    elif isinstance(op, CustomOpDef):
        return op._qualname
    elif isinstance(op, str):
        return op

    raise ValueError(f"Invalid operator input {op}")
OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload]


def _register_effectful_op(
    op: _op_identifier, effect: Optional[EffectType]
) -> RegistrationHandle:
    qualname = _get_op_qualname(op)
    entry = torch._library.simple_registry.singleton.find(qualname)
    handle = entry.effect.register(effect)
    return handle
SIDE_EFFECTS = WeakKeyDictionary[OpType, _EffectType](
    [
        (torch.ops.aten._print.default, _EffectType.ORDERED),
        (torch.ops.aten._async_error.default, _EffectType.ORDERED),
        (call_torchbind, _EffectType.ORDERED),
    ]
)


def _get_effect(op: _op_identifier) -> Optional[_EffectType]:
    qualname = _get_op_qualname(op)
    entry = torch._library.simple_registry.singleton.find(qualname)
    return entry.effect.effect
def _register_effectful_op(op: OpType, effect: _EffectType):
    assert isinstance(
        op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
    ) and not has_aliasing(op)
    if op in SIDE_EFFECTS and SIDE_EFFECTS[op] != effect:
        raise RuntimeError(
            f"Already registered effect type {SIDE_EFFECTS[op]} to op {op}, "
            f"trying to register a different effect type {effect}."
        )
    SIDE_EFFECTS[op] = effect


_register_effectful_op("aten::_print", _EffectType.ORDERED)
_register_effectful_op("aten::_async_error", _EffectType.ORDERED)
_register_effectful_op("profiler::_record_function_exit._RecordFunction", None)
_register_effectful_op(call_torchbind, _EffectType.ORDERED)
def _deregister_effectful_op(op: OpType):
    if op not in SIDE_EFFECTS:
        raise RuntimeError(f"Op {op} is not registered as effectful")

    del SIDE_EFFECTS[op]


class WithEffects(HigherOrderOperator):
@ -89,7 +78,7 @@ class WithEffects(HigherOrderOperator):
    ) -> tuple[Any, ...]:
        assert isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload))
        assert not has_aliasing(op), "Ops with aliasing is not supported"
        assert has_effects(op)
        assert has_effects(op, args, kwargs)
        assert isinstance(kwargs, dict)
        return super().__call__(token, op, *args, **kwargs)

@ -100,7 +89,7 @@ with_effects = WithEffects()

def has_aliasing(op: OpType):
    # NOT FOR PUBLIC USE
    if isinstance(op, torch._ops.HigherOrderOperator):
        return not _get_effect(op)
        return op not in SIDE_EFFECTS

    for arg in op._schema.arguments:
        if arg.alias_info is not None:
@ -111,7 +100,7 @@ def has_aliasing(op: OpType):
    return False


def has_effects(op) -> bool:
def has_effects(op, args, kwargs) -> bool:
    # Skip over the profiler's RecordFunction as they should not show up in the graph
    _skip_ops = {torch.ops.profiler._record_function_exit._RecordFunction}
    if op in _skip_ops:
@ -120,10 +109,31 @@ def has_effects(op) -> bool:
    return (
        isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload))
        and not has_aliasing(op)
        and _get_effect(op) is not None
        and get_effect_key(op, args, kwargs) is not None
    )


def get_effect_key(op, args, kwargs) -> Optional[_EffectType]:
    if op in SIDE_EFFECTS:
        return SIDE_EFFECTS[op]

    for arg in args:
        if isinstance(arg, (torch.ScriptObject, FakeScriptObject)):
            # Add it to the table so that next time we see the same op we don't
            # have to parse through the args again
            SIDE_EFFECTS[op] = _EffectType.ORDERED
            return _EffectType.ORDERED

    for arg in kwargs.values():
        if isinstance(arg, (torch.ScriptObject, FakeScriptObject)):
            # Add it to the table so that next time we see the same op we don't
            # have to parse through the args again
            SIDE_EFFECTS[op] = _EffectType.ORDERED
            return _EffectType.ORDERED

    return None


def new_token_tensor() -> torch.Tensor:
    return torch.tensor([])

@ -228,7 +238,7 @@ def handle_effects(
    # Get a token. We can't do `tokens.get(op, torch.tensor([]))` because
    # this will create an empty tensor during proxy mode tracing if the token
    # doesn't exist. But the tokens should always exist during proxy mode tracing.
    key = _get_effect(op)
    key = get_effect_key(op, args, kwargs)
    assert key is not None
    if key not in tokens:
        assert allow_token_discovery, (
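Editor's note: a small hedged sketch of the lookup convention on the get_effect_key side of the hunk above — an op is treated as effectful either because it is already in SIDE_EFFECTS or because one of its arguments is a (fake) script object, in which case it is memoized as ORDERED. This only restates the diff; it is not a statement about the released torch API, and the helper below is generic illustration code.

def infer_effect(op, args, kwargs, side_effects, script_object_types):
    # Mirror of get_effect_key above, with the torch-specific types injected.
    if op in side_effects:
        return side_effects[op]
    if any(isinstance(a, script_object_types) for a in (*args, *kwargs.values())):
        side_effects[op] = "Ordered"  # memoize so later calls skip the argument scan
        return "Ordered"
    return None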
@ -1,9 +1,11 @@
# mypy: allow-untyped-defs

import contextlib
import enum
from collections.abc import Callable
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import Any, Optional, Union

import torch
import torch.utils._pytree as pytree
@ -36,10 +38,6 @@ from torch.fx.graph_module import GraphModule
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts


if TYPE_CHECKING:
    from collections.abc import Callable


invoke_subgraph_counter = 0


@ -53,6 +51,31 @@ class OutputMetadata:
    indexes_with_no_grad: set[int] = field(default_factory=set)


class NestedCompileBackend(enum.Enum):
    INDUCTOR = "inductor"
    DEFAULT = "default"


@dataclass
class NestedCompileRegionOptions:
    # If default, does nothing and inherits the torch.compile backend
    # If "inductor", will add {"compile_with_inductor": {"inductor_configs": config}} to HOP node meta "custom"
    # If "custom" already has "compile_with_inductor", this config will override
    backend: NestedCompileBackend = NestedCompileBackend.DEFAULT

    # If backend == "inductor", the configs
    inductor_configs: Optional[dict[str, Any]] = None

    # If not None, add "partitioner" to HOP node meta.
    # If Callable, directly assign the callable, but the callable cannot be pickled.
    # If str, the options are "default_partition" and "min_cut_rematerialization_partition".
    # The HOP joint graph will be partitioned using the corresponding functions in
    # torch/_functorch/partitioners.py
    partitioner: Optional[Callable | str] = None

    # TODO: add decomposition function


class InvokeSubgraphHOP(HigherOrderOperator):
    def __init__(self) -> None:
        # Invoke subgraph does not have any state, it is just a wrapper over a
@ -153,7 +176,9 @@ def invoke_subgraph_placeholder(func, *args, **kwargs):
    return func(*args, **kwargs)


def mark_compile_region(fn=None):
def mark_compile_region(
    fn=None, backend_options: Optional[NestedCompileRegionOptions] = None
):
    """
    This wrapper instructs torch.compile to compile the wrapped region once and
    reuse the compiled artifact, instead of the usual way of aggressively
@ -161,6 +186,10 @@ def mark_compile_region(fn=None):

    Under the hood, it tells TorchDynamo to use InvokeSubgraph HOP for the
    region. For PyTorch eager, this is a no-op.

    Args:
        fn: The function to wrap
        backend_options: Optional backend options to use for compiling the subgraph
    """

    def wrap(func):
@ -172,6 +201,7 @@ def mark_compile_region(fn=None):
            return invoke_subgraph_placeholder(inner_func, *args, **kwargs)

        inner.__marked_compile_region_fn__ = func  # type: ignore[attr-defined]
        func.__marked_compile_region_backend__ = backend_options  # type: ignore[attr-defined]

        return inner
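Editor's note: a hedged usage sketch of the backend_options parameter added to mark_compile_region above. Class and field names are taken from this diff (NestedCompileRegionOptions, NestedCompileBackend, inductor_configs); the decorator-factory form assumes the existing fn=None convention of mark_compile_region, and the config values are illustrative.

import torch
from torch._higher_order_ops.invoke_subgraph import (
    NestedCompileBackend,
    NestedCompileRegionOptions,
    mark_compile_region,
)

opts = NestedCompileRegionOptions(
    backend=NestedCompileBackend.INDUCTOR,    # compile this region with inductor
    inductor_configs={"max_autotune": True},  # illustrative config values
)

@mark_compile_region(backend_options=opts)
def layer(x):
    return torch.relu(x @ x)

# Under torch.compile the wrapped region becomes an invoke_subgraph HOP whose
# node meta carries the options above; in eager mode the wrapper is a no-op.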
@ -2122,10 +2122,6 @@ class PythonWrapperCodegen(CodeGen):
            output.writeline(f"{name} = {val}")

        def add_torchbind_input(name, value):
            if value is None:
                output.writeline(f"{name} = None")
                return

            import pickle

            assert isinstance(value, torch.ScriptObject)

@ -91,7 +91,6 @@ from torch._inductor.utils import (
    tensor_is_aligned,
)
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.opaque_object import is_opaque_type
from torch._logging import trace_structured
from torch._utils_internal import compile_time_strobelight_meta
from torch.fx import GraphModule
@ -1640,9 +1639,7 @@ class _InProcessFxCompile(FxCompile):
                    # pyrefly: ignore [unbound-name]
                    (str, list, torch.fx.GraphModule),
                ), type(compiled_fn)
                return CompiledAOTI(
                    filename=compiled_fn, device_type=graph.device_type
                )
                return CompiledAOTI(compiled_fn)

            # TODO: Hoist this above V.aot_compilation
            # pyrefly: ignore [unbound-name]
@ -2715,7 +2712,7 @@ def _compile_fx_main(
        or torch._guards.TracingContext(fake_mode)
    )

    if V.aot_compilation and not config.enable_autograd_for_aot:
    if V.aot_compilation:
        from .utils import is_valid_aoti_model_name

        is_valid_aoti_model_name()
@ -2748,9 +2745,7 @@ def _compile_fx_main(
                    node.meta["val"] = fake_mode.from_tensor(
                        target, static_shapes=True
                    )
                elif isinstance(target, torch.ScriptObject) or is_opaque_type(
                    type(target)
                ):
                elif isinstance(target, torch.ScriptObject):
                    node.meta["val"] = (
                        torch._library.fake_class_registry.maybe_to_fake_obj(
                            fake_mode, target

@ -1193,8 +1193,6 @@ autotune_lookup_table: dict[str, dict[str, Any]] = {}

file_lock_timeout: int = int(os.environ.get("TORCHINDUCTOR_FILE_LOCK_TIMEOUT", "600"))

enable_autograd_for_aot: bool = False


def get_worker_log_path() -> Optional[str]:
    log_loc = None
@ -883,12 +883,11 @@ def _get_optimization_cflags(

    should_use_optimized_flags = not (
        config.aot_inductor.debug_compile
        or os.environ.get("TORCHINDUCTOR_DEBUG_COMPILE", "0") == "1"
        or os.environ.get("TORCHINDUCTOR_DEBUG_SYMBOL", "0") == "1"
    )
    should_add_debug_symbol_flags = (
        config.aot_inductor.debug_compile
        or config.aot_inductor.debug_symbols
        or os.environ.get("TORCHINDUCTOR_DEBUG_COMPILE", "0") == "1"
        or os.environ.get("TORCHINDUCTOR_DEBUG_SYMBOL", "0") == "1"
    )
    if should_use_optimized_flags:

@ -9242,9 +9242,12 @@ class EffectfulKernel(FallbackKernel):
            unbacked_bindings=unbacked_bindings,
        )

        from torch._higher_order_ops.effects import _get_effect
        from torch._higher_order_ops.effects import get_effect_key

        effect_type = _get_effect(kernel)
        uncovered_args = [
            a.value if isinstance(a, TorchBindObject) else a for a in tensor_args
        ]
        effect_type = get_effect_key(kernel, (*nontensor_args, *uncovered_args), kwargs)
        assert effect_type is not None
        self.effect_type = effect_type
        self.prev_effect_buffer = V.graph.effectful_ops.get(effect_type, None)
@ -9295,10 +9298,6 @@ class TorchBindObject(NonTensorObj):
    def get_buf_bytes(self) -> int:
        # Returns the sum of all tensors in the flattened object
        real_script_obj = self.get_real_obj()

        if real_script_obj is None:
            return 0

        assert hasattr(real_script_obj, "__obj_flatten__")
        flat_dict = dict(real_script_obj.__obj_flatten__())
        flat_elems = pytree.tree_flatten(flat_dict)[0]

@ -26,7 +26,6 @@ import torch.utils._pytree as pytree
from torch._dynamo.utils import counters
from torch._higher_order_ops.associative_scan import associative_scan_op
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.utils import get_layout_constraint_tag
from torch._prims_common import (  # pyrefly: ignore # deprecated; pyrefly: ignore [deprecated]
    canonicalize_dim,
@ -2705,8 +2704,6 @@ def require_channels_last(_, *args, **kwargs):


def constrain_to_fake_tensor(arg, fake_arg):
    if isinstance(fake_arg, FakeScriptObject):
        return arg
    if isinstance(arg, ir.IRNode):
        meta_stride_expr = [
            s.node.expr if isinstance(s, torch.SymInt) else s for s in fake_arg.stride()
@ -7456,9 +7453,9 @@ def _sink_tokens(tokens):
def with_effects(token, op, *args, **kwargs):
    result = ir.EffectfulKernel.create(op, *args, **kwargs)

    from torch._higher_order_ops.effects import _get_effect
    from torch._higher_order_ops.effects import get_effect_key

    effect_type = _get_effect(op)
    effect_type = get_effect_key(op, args, kwargs)
    assert effect_type is not None
    effectful_kernel = V.graph.effectful_ops[effect_type]
@ -773,83 +773,9 @@ class CompiledAOTI(OutputCode):
    """

    filename: Union[str, list[Union[str, Weights]], torch.fx.GraphModule]
    device_type: str
    current_callable: Optional[Callable[..., Any]] = None
    _cached_files: dict[str, bytes] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        if not config.aot_inductor.link_libtorch:
            return

        if (
            torch._inductor.cpp_builder._IS_MACOS
            or torch._inductor.cpp_builder._IS_WINDOWS
        ):
            return

        if config.aot_inductor.cross_target_platform == "windows":
            return

        if config.aot_inductor.package_cpp_only:
            return

        if isinstance(self.filename, list):
            current_callable = next(
                fn for fn in self.filename if isinstance(fn, str) and fn.endswith(".so")
            )
        else:
            current_callable = self.filename

        if isinstance(current_callable, torch.fx.GraphModule):
            self.current_callable = current_callable
            return

        if self.device_type.startswith("cuda"):
            current_callable = (
                torch._C._aoti.AOTIModelContainerRunnerCuda(  # type: ignore[call-arg]
                    current_callable,
                    1,
                    self.device_type,
                    "",
                    True,
                ).run  # type: ignore[attr-defined]
            )  # type: ignore[attr-defined]
        elif self.device_type == "cpu":
            current_callable = (
                torch._C._aoti.AOTIModelContainerRunnerCpu(  # type: ignore[call-arg]
                    current_callable, 1
                ).run  # type: ignore[attr-defined]
            )  # type: ignore[attr-defined]
        else:
            raise RuntimeError(f"unsupported device type {self.device_type}")
        self.current_callable = current_callable
        self._boxed_call = True
        for file in self._cached_files:
            if not os.path.exists(file):
                with open(file, "wb") as f:
                    f.write(self._cached_files[file])

    def __call__(self, inputs: Sequence[Any]) -> Any:
        if self.current_callable is None:
            raise RuntimeError("AOTInductor compiled so is not loaded")
        return self.current_callable(inputs)

    def prepare_for_serialization(self) -> None:
        self.current_callable = None
        self._cached_files = {}
        filenames: list[str] = []
        if isinstance(self.filename, list):
            filenames = self.filename  # type: ignore[assignment]
        elif isinstance(self.filename, str):
            filenames = [self.filename]
        for name in filenames:
            with open(name, "rb") as f:
                self._cached_files[name] = f.read()

    def __getstate__(self):
        state = self.__dict__.copy()
        state["current_callable"] = None
        return state
        raise NotImplementedError("NYI")

    def post_compile(
        self,
@ -857,8 +783,10 @@ class CompiledAOTI(OutputCode):
        constants: CompiledFxGraphConstants,
        graph_kwargs: _CompileFxKwargs,
    ) -> None:
        if self.current_callable is None:
            self.__post_init__()
        pass

    def prepare_for_serialization(self) -> None:
        pass

    def set_triton_bundle(self, triton_bundle: Any) -> None:
        pass
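Editor's note: the removed __post_init__ above shows, in passing, how a compiled AOTInductor shared library was turned into a callable. A hedged sketch of just the CPU branch of that removed code; torch._C._aoti is a private interface, its constructors differ between the two sides of this comparison, and the helper name here is hypothetical.

import torch

def load_aoti_cpu_callable(so_path: str):
    # As in the removed lines: wrap the compiled .so in a model container runner
    # and expose its boxed .run method (it takes a list of input tensors).
    runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)
    return runner.run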
@ -6107,15 +6107,14 @@ class Scheduler:
        If config.benchmark_fusion is False, always return True.
        Otherwise, return True if fusion can brings speedup.
        """
        if not config.benchmark_combo_kernel:
            return True

        subkernel_nodes = nodes
        device = subkernel_nodes[0].get_device()

        # don't support benchmark fusion for CPU C++ backend right now.
        if device is None or (device.type == "cpu" and config.cpu_backend != "triton"):
            return False

        if not config.benchmark_combo_kernel:
            return True

        from triton.compiler.errors import CompilationError

@ -13,7 +13,6 @@ from torch.types import _dtype
from torch.utils._exposed_in import exposed_in

from . import autograd, utils
from .effects import EffectType


device_types_t = Optional[Union[str, Sequence[str]]]
@ -472,9 +471,6 @@ class CustomOpDef:
        self._abstract_fn = fn
        return fn

    def register_effect(self, effect: Optional[EffectType]) -> None:
        self._lib._register_effectful_op(self._qualname, effect)

    def register_torch_dispatch(
        self, torch_dispatch_class: Any, fn: Optional[Callable] = None, /
    ) -> Callable:
@ -1,68 +0,0 @@
from enum import Enum
from typing import Optional

import torch


class EffectType(Enum):
    ORDERED = "Ordered"


from torch._library.utils import RegistrationHandle


class EffectHolder:
    """A holder where one can register an effect impl to."""

    def __init__(self, qualname: str):
        self.qualname: str = qualname
        self._set_default_effect()

    def _set_default_effect(self) -> None:
        self._effect: Optional[EffectType] = None

        # If the op contains a ScriptObject input, we want to mark it as having effects
        namespace, opname = torch._library.utils.parse_namespace(self.qualname)
        split = opname.split(".")
        if len(split) > 1:
            assert len(split) == 2, (
                f"Tried to split {opname} based on '.' but found more than 1 '.'"
            )
            opname, overload = split
        else:
            overload = ""

        if namespace == "higher_order":
            return

        opname = f"{namespace}::{opname}"
        if torch._C._get_operation_overload(opname, overload) is not None:
            # Since we call this when destroying the library, sometimes the
            # schema will be gone already at that time.
            schema = torch._C._get_schema(opname, overload)
            for arg in schema.arguments:
                if isinstance(arg.type, torch.ClassType):
                    self._effect = EffectType.ORDERED
                    return

    @property
    def effect(self) -> Optional[EffectType]:
        return self._effect

    @effect.setter
    def effect(self, _):
        raise RuntimeError("Unable to directly set kernel.")

    def register(self, effect: Optional[EffectType]) -> RegistrationHandle:
        """Register an effect

        Returns a RegistrationHandle that one can use to de-register this
        effect.
        """
        self._effect = effect

        def deregister_effect():
            self._set_default_effect()

        handle = RegistrationHandle(deregister_effect)
        return handle
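Editor's note: the deleted EffectHolder above follows a register-and-return-a-handle pattern. A hedged, self-contained sketch of that pattern in generic Python (not the torch._library classes themselves):

from typing import Callable, Optional

class RegistrationHandle:
    """Runs the stored teardown when destroy() is called."""
    def __init__(self, on_destroy: Callable[[], None]) -> None:
        self._on_destroy = on_destroy
    def destroy(self) -> None:
        self._on_destroy()

class Holder:
    def __init__(self) -> None:
        self._value: Optional[str] = None
    def register(self, value: Optional[str]) -> RegistrationHandle:
        self._value = value
        return RegistrationHandle(lambda: setattr(self, "_value", None))

holder = Holder()
handle = holder.register("Ordered")   # temporarily mark as effectful
handle.destroy()                      # restore the default on teardown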
@ -1,7 +1,6 @@
from collections.abc import Callable
from typing import Any, Optional

from .effects import EffectHolder
from .fake_impl import FakeImplHolder
from .utils import RegistrationHandle

@ -52,8 +51,6 @@ class SimpleOperatorEntry:
            GenericTorchDispatchRuleHolder(qualname)
        )

        self.effect: EffectHolder = EffectHolder(qualname)

    # For compatibility reasons. We can delete this soon.
    @property
    def abstract_impl(self) -> FakeImplHolder:

@ -1023,7 +1023,6 @@ class TorchBindOpOverload(OpOverload[_P, _T]):
        DispatchKey.BackendSelect,
        DispatchKey.PythonTLSSnapshot,
        DispatchKey.PythonDispatcher,
        DispatchKey.Functionalize,
    ]

    def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
@ -1047,23 +1046,17 @@ class TorchBindOpOverload(OpOverload[_P, _T]):
    def _register_as_effectful_op_temporarily(self):
        from torch._higher_order_ops.effects import (
            _EffectType,
            _get_effect,
            _register_effectful_op,
            SIDE_EFFECTS,
        )

        try:
            # We don't want to register the effect if there already exists a
            # registration, especially if the registration is None (explicitly
            # no effect)
            register_tmp_effect = _get_effect(self) is None
            handle = None
            if register_tmp_effect:
                handle = _register_effectful_op(self, _EffectType.ORDERED)
            if self not in SIDE_EFFECTS:
                _register_effectful_op(self, _EffectType.ORDERED)
            yield
        finally:
            if register_tmp_effect:
                assert handle is not None
                handle.destroy()
            if self in SIDE_EFFECTS:
                del SIDE_EFFECTS[self]

    # Use positional-only argument to avoid naming collision with aten ops arguments
    # that are named "self". This way, all the aten ops can be called by kwargs.
@ -11,7 +11,7 @@ import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._C import _functionalization_reapply_views_tls as _reapply_views
from torch._ops import _get_dispatch_mode_pre_dispatch, TorchBindOpOverload
from torch._ops import _get_dispatch_mode_pre_dispatch
from torch._subclasses.meta_utils import is_sparse_any
from torch.utils._python_dispatch import (
    _detect_infra_mode,
@ -471,7 +471,7 @@ class FunctionalTensorMode(TorchDispatchMode):

        from torch._higher_order_ops.effects import handle_effects, has_effects

        if has_effects(func):
        if has_effects(func, args, kwargs):
            assert not torch._C._dispatch_has_kernel_for_dispatch_key(
                func.name(), torch._C.DispatchKey.Functionalize
            )
@ -504,81 +504,65 @@ class FunctionalTensorMode(TorchDispatchMode):
                - FunctionalTensor._extra_dispatch_keys
            )

            if isinstance(func, TorchBindOpOverload):
                # When the function is a TorchBindOpOverload, meaning some of the
                # inputs are FakeScriptObjects, we need to skip c++ dispatcher and
                # dispatch in python because C++ dispatcher will check the schema
                # and cannot recognize FakeScriptObject.
                ctx = PythonFunctionalizeAPI()
                fully_unwrapped_args = ctx.unwrap_tensors(args)
                fully_unwrapped_kwargs = ctx.unwrap_tensors(
                    kwargs  # pyrefly: ignore[bad-argument-type]
                )
                outs_unwrapped = func(
                    *fully_unwrapped_args,
                    **fully_unwrapped_kwargs,
                )
                outs_wrapped = ctx.wrap_tensors(outs_unwrapped)
            else:
                # All we want to do here is reuse the existing C++ functionalization logic.
                # This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
                with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
                    try:
                        # By default for python functionalization (for AOTAutograd), we reapply views.
                        old_apply_views = torch._functionalize_enable_reapply_views(True)  # type: ignore[attr-defined]
            # All we want to do here is reuse the existing C++ functionalization logic.
            # This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
            with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
                try:
                    # By default for python functionalization (for AOTAutograd), we reapply views.
                    old_apply_views = torch._functionalize_enable_reapply_views(True)  # type: ignore[attr-defined]

                        # Sometimes these functions cannot be directly dispatched to functionalize key
                        # because args are sometimes not functional tensors for some reason?
                        if func in FunctionalTensor.metadata_fns:
                            outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
                            outs_wrapped = pytree.tree_map_only(
                                torch.Tensor, wrap, outs_unwrapped
                            )
                        else:
                            # Note: [Functionalization View Replay Annotation]
                            # When functionalization encounters a mutation, it handles aliases by lazily regenerating the aliases
                            # at the first time they are next used.
                            # This is a problem when plumbing user annotations during tracing. We want the view ops from view replay
                            # to have the same annotation that the user specified on the original views. But view replay in
                            # functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)),
                            # so when we regenerate views before calling into second_op, those views will end up getting the metadata
                            # for second_op!
                            #
                            # Instead, we need to remember the node metadata from the original views, and ensure that this node metadata
                            # is globally set when we lazily perform view replay.
                            # The globally set metadata will be used to populate the fx node created for the replayed operation.
                            if m := torch._C._get_dispatch_mode(
                                torch._C._TorchDispatchModeKey.PROXY
                            ):
                                for a in pytree.tree_leaves([args, kwargs]):
                                    if not isinstance(a, FunctionalTensor):
                                        continue
                                    curr_node = m.tracer.tensor_tracker[
                                        torch._from_functional_tensor(a.elem)
                                    ].proxy.node
                                    with fx_traceback.set_current_replay_node(curr_node):
                                        torch._sync(a)
                    # Sometimes these functions cannot be directly dispatched to functionalize key
                    # because args are sometimes not functional tensors for some reason?
                    if func in FunctionalTensor.metadata_fns:
                        outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
                        outs_wrapped = pytree.tree_map_only(
                            torch.Tensor, wrap, outs_unwrapped
                        )
                    else:
                        # Note: [Functionalization View Replay Annotation]
                        # When functionalization encounters a mutation, it handles aliases by lazily regenerating the aliases
                        # at the first time they are next used.
                        # This is a problem when plumbing user annotations during tracing. We want the view ops from view replay
                        # to have the same annotation that the user specified on the original views. But view replay in
                        # functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)),
                        # so when we regenerate views before calling into second_op, those views will end up getting the metadata
                        # for second_op!
                        #
                        # Instead, we need to remember the node metadata from the original views, and ensure that this node metadata
                        # is globally set when we lazily perform view replay.
                        # The globally set metadata will be used to populate the fx node created for the replayed operation.
                        if m := torch._C._get_dispatch_mode(
                            torch._C._TorchDispatchModeKey.PROXY
                        ):
                            for a in pytree.tree_leaves([args, kwargs]):
                                if not isinstance(a, FunctionalTensor):
                                    continue
                                curr_node = m.tracer.tensor_tracker[
                                    torch._from_functional_tensor(a.elem)
                                ].proxy.node
                                with fx_traceback.set_current_replay_node(curr_node):
                                    torch._sync(a)

                            # When we dispatch to the C++ functionalization kernel, we might need to jump back to the
                            # PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
                            # FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
                            # from the TLS in order to avoid infinite looping, but this would prevent us from coming
                            # back to PreDispatch later
                            outs_unwrapped = func._op_dk(
                                torch._C.DispatchKey.Functionalize,
                                *args_unwrapped,
                                **kwargs_unwrapped,
                            )
                        # When we dispatch to the C++ functionalization kernel, we might need to jump back to the
                        # PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
                        # FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
                        # from the TLS in order to avoid infinite looping, but this would prevent us from coming
                        # back to PreDispatch later
                        outs_unwrapped = func._op_dk(
                            torch._C.DispatchKey.Functionalize,
                            *args_unwrapped,
                            **kwargs_unwrapped,
                        )

                            if self.export:
                                if func is torch.ops.aten.dropout.default:
                                    torch._freeze_functional_tensor(outs_unwrapped)  # type: ignore[attr-defined]
                            outs_wrapped = pytree.tree_map_only(
                                torch.Tensor, wrap, outs_unwrapped
                            )
                    finally:
                        torch._disable_functionalization()
                        torch._functionalize_enable_reapply_views(old_apply_views)  # type: ignore[attr-defined]
                        if self.export:
                            if func is torch.ops.aten.dropout.default:
                                torch._freeze_functional_tensor(outs_unwrapped)  # type: ignore[attr-defined]
                        outs_wrapped = pytree.tree_map_only(
                            torch.Tensor, wrap, outs_unwrapped
                        )
                finally:
                    torch._disable_functionalization()
                    torch._functionalize_enable_reapply_views(old_apply_views)  # type: ignore[attr-defined]

            is_included = torch._C._dispatch_tls_is_dispatch_key_included(
                torch._C.DispatchKey.Functionalize
@ -1,14 +1,21 @@
# mypy: allow-untyped-defs
import io
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec

import torch
from torch._higher_order_ops.invoke_subgraph import NestedCompileRegionOptions

from . import config


try:
    from typing import LiteralString
except ImportError:
    from typing_extensions import LiteralString

if TYPE_CHECKING:
    from ._cache import CacheInfo

@ -617,7 +624,9 @@ def skip_guard_on_globals_unsafe(guard_entries):
    return [not entry.is_global for entry in guard_entries]


def nested_compile_region(fn=None):
def nested_compile_region(
    fn=None, backend_options: Optional[NestedCompileRegionOptions] = None
):
    """
    Tells **``torch.compile``** that the marked set of operations forms a nested
    compile region (which is often repeated in the full model) whose code can be
@ -626,8 +635,8 @@ def nested_compile_region(fn=None):

    During **``torch.compile``** tracing, the compiler applies *hierarchical
    compilation* with ``nested_compile_region``: it emits optimized code for the
    marked region the first time it is encountered and re-emits (or “stamps
    out”) the previously compiled code on every subsequent invocation. This can
    marked region the first time it is encountered and re-emits (or "stamps
    out") the previously compiled code on every subsequent invocation. This can
    substantially reduce overall compile time for deeply-stacked,
    structurally-identical components such as the transformer layers of a
    large-language-model (LLM).
@ -641,13 +650,17 @@ def nested_compile_region(fn=None):
    to reuse, it will transparently re-compile the region. Using it is
    therefore *safe*: correctness is always preserved, and you pay the extra
    compilation cost only when required.

    Args:
        fn: The function to wrap
        backend_options: Optional backend options to use for compiling the subgraph.
    """

    from torch._higher_order_ops.invoke_subgraph import (
        mark_compile_region as _mark_compile_region,
    )

    return _mark_compile_region(fn)
    return _mark_compile_region(fn, backend_options=backend_options)


def load_compiled_function(file: io.IOBase) -> Callable[..., Any]:
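Editor's note: a hedged usage sketch of the signature change above; nested_compile_region now forwards backend_options to mark_compile_region. The NestedCompileRegionOptions fields are those added elsewhere in this diff, the partitioner string is one of the two names the partitioner hook accepts, and the decorator-factory form assumes the existing fn=None convention.

import torch
from torch._higher_order_ops.invoke_subgraph import NestedCompileRegionOptions

opts = NestedCompileRegionOptions(partitioner="min_cut_rematerialization_partition")

@torch.compiler.nested_compile_region(backend_options=opts)
def block(x):
    return torch.nn.functional.gelu(x @ x)

@torch.compile
def model(x):
    # The marked region is compiled once and stamped out for both calls.
    return block(block(x))

out = model(torch.randn(8, 8))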
@ -66,12 +66,6 @@ void initAOTIRunnerBindings(PyObject* module) {
          int,
          const std::string&,
          const std::string&>())
      .def(py::init<
          const std::string&,
          int,
          const std::string&,
          const std::string&,
          const bool>())
      .def(
          "run",
          &AOTIModelContainerRunnerCuda::run,

@ -18,7 +18,6 @@ import torch
import torch.utils._pytree as pytree
from torch._C import ScriptObject  # type: ignore[attr-defined]
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.opaque_object import is_opaque_type

from ._compatibility import compatibility
from ._lazy_graph_module import _make_graph_module
@ -422,10 +421,8 @@ class Tracer(TracerBase):
        # a get_attr to retrieve that tensor. Otherwise, we'll store away the
        # tensor value into a special attribute on the Module s.t. we can
        # retrieve it with a get_attr.
        if isinstance(a, _constant_attribute_types) or is_opaque_type(type(a)):
            qualname: Optional[str] = self.tensor_attrs.get(
                a
            )  # pyrefly: ignore[no-matching-overload]
        if isinstance(a, _constant_attribute_types):
            qualname: Optional[str] = self.tensor_attrs.get(a)

            # Tensor was not found in the Module hierarchy, stow it away in a
            # special attribute and set the qualname to refer to that
@ -436,17 +433,13 @@ class Tracer(TracerBase):
                    base_name = "_torchbind_obj"
                elif isinstance(a, pytree.TreeSpec):
                    base_name = "_tree_spec_constant"
                elif is_opaque_type(type(a)):
                    base_name = "_opaque_obj"
                else:
                    raise RuntimeError(
                        f"cannot create constant arg for {a} of type {type(a)}."
                    )
                qualname = self.get_fresh_qualname(base_name)
                assert isinstance(qualname, str)
                self.tensor_attrs[a] = (  # pyrefly: ignore[unsupported-operation]
                    qualname
                )
                self.tensor_attrs[a] = qualname
                setattr(self.root, qualname, a)

            return self.create_node("get_attr", qualname, (), {})

@ -84,7 +84,7 @@ if TYPE_CHECKING:

    from torch._ops import OpOverload
    from torch.fx._symbolic_trace import PHBase
    from torch.types import BoolLikeType, FloatLikeType, IntLikeType
    from torch.types import IntLikeType

__all__ = [
    "PythonKeyTracer",
@ -458,7 +458,7 @@ def _sympy_handlers() -> dict[type[sympy.Expr], Callable[..., Any]]:

def _build_proxy_for_sym_expr(
    tracer: _ProxyTracer, expr: sympy.Expr, out: PySymType | None = None
) -> IntLikeType | FloatLikeType | BoolLikeType | None:
) -> PySymType | None:
    """
    Decompose `expr` and look for the pieces as inputs. If `out` is provided
    then that will be the resulting SymNode (and `out.expr` must be the same as
@ -532,13 +532,6 @@ def _build_proxy_for_sym_expr(
        assert not out
        return value.value

    if isinstance(expr, (int, float, bool)):
        return expr
    if expr.is_Integer:
        return int(expr)
    if expr.is_Float:
        return float(expr)

    args = []
    for arg in expr.args:
        if (arg_value := _build_proxy_for_sym_expr(tracer, arg)) is None:

@ -19,7 +19,6 @@ from torch._library.custom_ops import (
    CustomOpDef,
    device_types_t,
)
from torch._library.effects import EffectType
from torch._library.infer_schema import infer_schema  # noqa: F401
from torch._library.triton import triton_op, wrap_triton
from torch._ops import OpOverload
@ -399,22 +398,6 @@ class Library:

        self.m.fallback(dispatch_key, fn, with_keyset)

    def _register_effectful_op(self, op_name: str, effect: Optional[EffectType]):
        """
        Registers an effect to an operator. This is used to register an op that
        has side effects that is not capturable by the schema.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            effect: The effect of the op.
        """
        from torch._higher_order_ops.effects import (
            _register_effectful_op as hoo_register_effect,
        )

        handle = hoo_register_effect(op_name, effect)
        self._registration_handles.append(handle)

    def _destroy(self):
        if self.m is not None:
            self.m.reset()
@ -1082,44 +1065,6 @@ def register_fake(
    return register(func)


def _register_effectful_op(
    op: _op_identifier,
    effect: Optional[EffectType],
    *,
    lib: Optional[Library] = None,
) -> None:
    r"""
    To specify that an operator has side-effects, we must register an effect
    type for the operator. This will prevent graph passes in torch.compile from
    reordering operations with the same effect type.

    Args:
        op_name: Operator name (along with the overload) or OpOverload object.
        effect: Effect type to register. None means the operator is not effectful.
    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(
            f"register_effectful_op({op}): got unexpected type for op: {type(op)}"
        )

    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        opdef.register_effect(effect)
    assert isinstance(op, str)

    namespace, _ = torch._library.utils.parse_namespace(op)
    if lib is None:
        use_lib = Library(namespace, "FRAGMENT")
        _keep_alive.append(use_lib)
    else:
        use_lib = lib
    use_lib._register_effectful_op(op, effect)
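Editor's note: a hedged sketch of the effect-registration flow visible on the longer (removed) side of the hunks above — a custom op gets an EffectType attached through the CustomOpDef.register_effect / Library._register_effectful_op plumbing shown in this diff. Whether these names exist depends on which side of the comparison you are on, so treat the last two lines as illustrative only; the op name is hypothetical.

import torch
from torch.library import custom_op

@custom_op("mylib::log_value", mutates_args=())
def log_value(x: torch.Tensor) -> torch.Tensor:
    print("value:", x.sum().item())  # side effect not captured by the schema
    return x.clone()

@log_value.register_fake
def _(x):
    return torch.empty_like(x)

# As in the removed lines: mark the op as ORDERED so torch.compile will not
# reorder it relative to other ordered-effect ops.
from torch._library.effects import EffectType  # module deleted on one side of this diff
log_value.register_effect(EffectType.ORDERED)  # method removed on one side of this diff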
@ -37,7 +37,7 @@ import functools
import traceback
import weakref
from collections.abc import Callable
from typing import Any, Optional, TYPE_CHECKING  # noqa: F401
from typing import Any, TYPE_CHECKING

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode