Compare commits

...

6 Commits

Author SHA1 Message Date
eqy
3dc9878a70 Update cuDNN version to 9.10.2.21 2025-11-12 18:26:53 -08:00
8919f69362 [Inductor][2/2] Decouple flags for optimization and debug symbols (#167575)
Summary:
What: Decouple the debug-compile flag (unoptimized build) from the debug-symbols flag (optimized build with symbols)
Why: Reduce confusion around naming and usage

Test Plan: Unit test & CI

Differential Revision: D86683526

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167575
Approved by: https://github.com/jansel, https://github.com/hl475
2025-11-13 00:59:15 +00:00
19c867873a [opaque obj] Add attribute support (#167230)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167230
Approved by: https://github.com/zou3519
ghstack dependencies: #163284, #163714, #163936
2025-11-13 00:35:20 +00:00
e3dadb1d36 [opaque obj] torch.compile support (#163936)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163936
Approved by: https://github.com/zou3519
ghstack dependencies: #163284, #163714
2025-11-13 00:35:20 +00:00
c9b09a31e8 [opaque obj] Allow non-effectful scriptobjs (#163714)
Fixes functionalization so that we can run ops using ScriptObjects w/o needing effects. Previously we would run into an error when running functionalization on the TorchBindOpOverloads.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163714
Approved by: https://github.com/zou3519
ghstack dependencies: #163284
2025-11-13 00:35:20 +00:00
35571fe94b [effects] Add register_effectful_op (#163284)
Refactored register_effectful_op to return a handle, matching how fake kernels are registered. This makes it easier to deregister effects (see the sketch after the commit list).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163284
Approved by: https://github.com/zou3519
2025-11-13 00:35:20 +00:00
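
A minimal sketch of the handle-based pattern described in the last commit above, using the private torch._higher_order_ops.effects helpers exactly as they appear in the test diff below (internal API, subject to change):

import torch
from torch._higher_order_ops.effects import (
    _EffectType,
    _get_effect,
    _register_effectful_op,
)

# Registering now returns a RegistrationHandle instead of mutating a global table.
handle = _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
assert _get_effect(torch.ops.aten.cos.default) == _EffectType.ORDERED

# Deregistration goes through the handle rather than a separate
# _deregister_effectful_op(...) call.
handle.destroy()
assert _get_effect(torch.ops.aten.cos.default) is None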
26 changed files with 696 additions and 219 deletions

View File

@ -129,7 +129,7 @@ function install_129 {
}
function install_128 {
CUDNN_VERSION=9.8.0.87
CUDNN_VERSION=9.10.2.21
echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
# install CUDA 12.8.1 in the same container
install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux

View File

@ -18,15 +18,16 @@ from functorch.compile import (
nop,
)
from torch._functorch.aot_autograd import aot_export_module
from torch._higher_order_ops.effects import with_effects
from torch._higher_order_ops.effects import (
_EffectType,
_get_effect,
_register_effectful_op,
with_effects,
)
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import (
_get_torch_cuda_version,
SM70OrLater,
SM80OrLater,
)
from torch.testing._internal.common_cuda import SM70OrLater, SM80OrLater
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
from torch.testing._internal.common_utils import (
IS_WINDOWS,
@ -300,7 +301,6 @@ def forward(self, arg0_1, arg1_1, arg2_1):
@unittest.skipIf(IS_WINDOWS, "triton")
@unittest.skipIf(TEST_WITH_ROCM, "triton")
@unittest.skipIf(not SM80OrLater, "triton")
@unittest.skipIf(_get_torch_cuda_version() >= (11, 7), "triton")
@unittest.skipIf(not TEST_CUDA, "triton")
@skipIfNoDynamoSupport
def test_register_effectful_custom_op(self):
@ -308,41 +308,23 @@ def forward(self, arg0_1, arg1_1, arg2_1):
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch.library.define(
"mylib::record_scalar_tensor",
"(Tensor x, str prefix) -> ()",
lib=lib,
)
# global variable to store the recorded tensor and prefix.
recorded_dict = {}
# Pytorch custorm op implementation
@torch.library.impl(
"mylib::record_scalar_tensor",
"CompositeExplicitAutograd",
lib=lib,
)
def record_scalar_tensor(x, prefix):
# Pytorch custom op implementation
@torch.library.custom_op("mylib::record_scalar_tensor", mutates_args=())
def record_scalar_tensor(x: torch.Tensor, prefix: str) -> None:
recorded_dict[prefix] = x.clone()
return
# Meta function of the custom op
@torch.library.register_fake(
"mylib::record_scalar_tensor",
lib=lib,
)
@record_scalar_tensor.register_fake
def record_scalar_tensor_meta(x, prefix):
return
from torch._higher_order_ops.effects import (
_EffectType,
_register_effectful_op,
)
record_scalar_tensor.register_effect(_EffectType.ORDERED)
_register_effectful_op(
torch.ops.mylib.record_scalar_tensor.default, _EffectType.ORDERED
)
self.assertEqual(_get_effect(record_scalar_tensor), _EffectType.ORDERED)
my_config = {}
my_config["MockModule"] = "mean"
@ -469,13 +451,12 @@ def forward(self, arg0_1, arg1_1, arg2_1):
torch.library.register_autograd("_mylib::zoo", foo_bwd, lib=lib)
from torch._higher_order_ops.effects import (
_EffectType,
_register_effectful_op,
torch.library._register_effectful_op(
torch.ops._mylib.zoo.default, _EffectType.ORDERED
)
torch.library._register_effectful_op(
torch.ops._mylib.zoo2.default, _EffectType.ORDERED
)
_register_effectful_op(torch.ops._mylib.zoo.default, _EffectType.ORDERED)
_register_effectful_op(torch.ops._mylib.zoo2.default, _EffectType.ORDERED)
def fn(x, y):
return torch.ops._mylib.zoo(x) + y
@ -687,13 +668,13 @@ def forward(self, arg0_1, arg1_1):
torch.library.register_autograd("_mylib::foo", foo_bwd, lib=lib)
from torch._higher_order_ops.effects import (
_deregister_effectful_op,
_EffectType,
_register_effectful_op,
handle = _register_effectful_op(
torch.ops._mylib.foo.default, _EffectType.ORDERED
)
self.assertEqual(
_get_effect(torch.ops._mylib.foo.default), _EffectType.ORDERED
)
_register_effectful_op(torch.ops._mylib.foo.default, _EffectType.ORDERED)
try:
def fn(x, y):
@ -779,17 +760,13 @@ def forward(self, tangents_1, tangents_2, tangents_token):
else:
raise NotImplementedError
finally:
_deregister_effectful_op(torch.ops._mylib.foo.default)
handle.destroy()
self.assertEqual(_get_effect(torch.ops._mylib.foo.default), None)
@skipIfNoDynamoSupport
def test_regular_effectful_op_only_in_backward(self):
from torch._higher_order_ops.effects import (
_deregister_effectful_op,
_EffectType,
_register_effectful_op,
)
_register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
handle = _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
try:
def fn(x):
@ -852,17 +829,11 @@ def forward(self, primals_1, primals_2, tangents_1, tangents_2, tangents_token):
return (mul, mul_1, getitem_2)""",
)
finally:
_deregister_effectful_op(torch.ops.aten.cos.default)
handle.destroy()
@skipIfNoDynamoSupport
def test_regular_effectful_op_in_forward_and_backward(self):
from torch._higher_order_ops.effects import (
_deregister_effectful_op,
_EffectType,
_register_effectful_op,
)
_register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
handle = _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
try:
def fn(x):
@ -897,7 +868,7 @@ def forward(self, primals_2, getitem_1, tangents_1, tangents_token):
return (mul_1, getitem_2)""",
)
finally:
_deregister_effectful_op(torch.ops.aten.cos.default)
handle.destroy()
if __name__ == "__main__":

View File

@ -136,12 +136,59 @@ class TestStandaloneInductor(TestCase):
mod_opt = inductor.compile(mod, inp)
self.assertEqual(mod(*inp), mod_opt(*inp))
@mock.patch.dict(os.environ, {"TORCHINDUCTOR_DEBUG_COMPILE": "1"})
def test_inductor_generate_debug_compile(self):
cpp_code = """
int main(){
return 0;
}
"""
_, source_path = write(
cpp_code,
"cpp",
)
build_option = CppOptions()
cpp_builder = CppBuilder(
name="test_compile",
sources=source_path,
output_dir=os.path.dirname(source_path),
BuildOption=build_option,
)
cpp_builder.build()
binary_path = cpp_builder.get_target_file_path()
"""
When debug compile is enabled:
On Windows, it should create a [module_name].pdb file, which helps debugging with WinDBG.
On Linux, it should add debug sections to the binary file.
"""
def check_linux_debug_section(module_path: str):
check_cmd = shlex.split(f"readelf -S {module_path}")
output = safe_command_output(check_cmd)
has_debug_sym = ".debug_info" in output
self.assertEqual(has_debug_sym, True)
def check_windows_pdb_exist(module_path: str):
file_name_no_ext = os.path.splitext(module_path)[0]
file_name_pdb = f"{file_name_no_ext}.pdb"
has_pdb_file = os.path.exists(file_name_pdb)
self.assertEqual(has_pdb_file, True)
if _IS_WINDOWS:
check_windows_pdb_exist(binary_path)
elif _IS_MACOS:
pass  # macOS: unclear whether this check applies, so skip it.
else:
check_linux_debug_section(binary_path)
@mock.patch.dict(os.environ, {"TORCHINDUCTOR_DEBUG_SYMBOL": "1"})
def test_inductor_generate_debug_symbol(self):
cpp_code = """
int main(){
return 0;
}
int main(){
return 0;
}
"""
_, source_path = write(

View File

@ -90,7 +90,7 @@ class TestOpaqueObject(TestCase):
# This is not accurate since the queue could have tensors that are
# not rank 1
ctx = torch._custom_op.impl.get_ctx()
u0 = ctx.create_unbacked_symint()
u0 = ctx.new_dynamic_size()
return torch.empty(u0)
self.lib._register_fake("queue_pop", pop_impl_fake)
@ -107,8 +107,7 @@ class TestOpaqueObject(TestCase):
@size_impl.register_fake
def size_impl_fake(q: torch._C.ScriptObject) -> int:
ctx = torch._custom_op.impl.get_ctx()
u0 = ctx.create_unbacked_symint()
torch._check_is_size(u0)
u0 = ctx.new_dynamic_size()
return u0
super().setUp()
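
The hunks above swap ctx.create_unbacked_symint() for ctx.new_dynamic_size() in the fake impls. A hedged sketch of that pattern for an op with a data-dependent output shape; the op name mylib::nonzero_vals is made up for illustration:

import torch

@torch.library.custom_op("mylib::nonzero_vals", mutates_args=())
def nonzero_vals(x: torch.Tensor) -> torch.Tensor:
    # Eager: keep only the nonzero entries, so the output length is data-dependent.
    return x[x != 0]

@nonzero_vals.register_fake
def _(x):
    ctx = torch.library.get_ctx()
    # Allocate a fresh dynamic size for the unknown output length
    # (this replaces the older create_unbacked_symint() spelling).
    u0 = ctx.new_dynamic_size()
    return x.new_empty(u0)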

View File

@ -1,12 +1,22 @@
# Owner(s): ["module: custom-operators"]
import random
from contextlib import ExitStack
import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import AotEagerAndRecordGraphs
from torch._functorch.aot_autograd import (
aot_compile_joint_with_descriptors,
aot_export_joint_with_descriptors,
aot_export_module,
)
from torch._library.effects import EffectType
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.opaque_object import register_opaque_type
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -41,11 +51,21 @@ class OpaqueQueue:
class RNGState:
def __init__(self, seed):
self.rng = random.Random(seed)
self.seed = seed
self.rng = random.Random(self.seed)
class Counter:
def __init__(self, start):
self.counter = torch.tensor(start)
def increment_counter(self):
self.counter += 1
register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
register_opaque_type(RNGState, "_TestOpaqueObject_RNGState")
register_opaque_type(Counter, "_TestOpaqueObject_Counter")
class TestOpaqueObject(TestCase):
@ -125,6 +145,20 @@ class TestOpaqueObject(TestCase):
def noisy_inject_fake(x: torch.Tensor, obj: RNGState) -> torch.Tensor:
return torch.empty_like(x)
@torch.library.custom_op(
"_TestOpaqueObject::increment_counter",
mutates_args=["prev"],
)
def increment_counter_impl(c: Counter, prev: torch.Tensor) -> torch.Tensor:
assert isinstance(c, Counter)
prev.copy_(c.counter)
c.increment_counter()
return c.counter
@increment_counter_impl.register_fake
def increment_counter_fake(c: Counter, prev: torch.Tensor) -> torch.Tensor:
return torch.empty_like(prev)
super().setUp()
def tearDown(self):
@ -233,6 +267,235 @@ def forward(self, arg0_1, arg1_1):
):
make_fx(f, tracing_mode=make_fx_tracing_mode)(RNGState(0), torch.ones(3))
def test_aot_export(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, rng_state, x):
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
x = x * x
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
x = x + x
return (x,)
mod = Model()
rng = RNGState(0)
x = torch.ones(2, 3)
fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
fake_rng = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, rng)
fake_x = fake_mode.from_tensor(x)
gm = aot_export_module(mod, (fake_rng, fake_x), trace_joint=False)[0]
# By default we don't register ops containing PyObjs as being effectful
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, arg0_1, arg1_1):
noisy_inject = torch.ops._TestOpaqueObject.noisy_inject.default(arg1_1, arg0_1); arg1_1 = None
mul = torch.ops.aten.mul.Tensor(noisy_inject, noisy_inject); noisy_inject = None
noisy_inject_1 = torch.ops._TestOpaqueObject.noisy_inject.default(mul, arg0_1); mul = arg0_1 = None
add = torch.ops.aten.add.Tensor(noisy_inject_1, noisy_inject_1); noisy_inject_1 = None
return (add,)""", # noqa: B950
)
torch.library._register_effectful_op(
"_TestOpaqueObject::noisy_inject", EffectType.ORDERED
)
try:
gm = aot_export_module(mod, (rng, fake_x), trace_joint=False)[0]
# inputs: token, rng, x
# return: token, res
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1):
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops._TestOpaqueObject.noisy_inject.default, arg2_1, arg1_1); arg0_1 = arg2_1 = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
mul = torch.ops.aten.mul.Tensor(getitem_1, getitem_1); getitem_1 = None
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TestOpaqueObject.noisy_inject.default, mul, arg1_1); getitem = mul = arg1_1 = None
getitem_2 = with_effects_1[0]
getitem_3 = with_effects_1[1]; with_effects_1 = None
add = torch.ops.aten.add.Tensor(getitem_3, getitem_3); getitem_3 = None
return (getitem_2, add)""", # noqa: B950
)
finally:
torch.library._register_effectful_op(
"_TestOpaqueObject::noisy_inject", None
)
def test_compile(self):
def foo(rng_state, x):
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
x = x * x
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
x = x + x
return x
rng = RNGState(0)
x = torch.ones(2, 3)
res = torch.compile(foo, fullgraph=True, backend="inductor")(rng, x)
self.assertFalse(torch.allclose(res, x * x + x))
backend = AotEagerAndRecordGraphs()
torch.compile(foo, fullgraph=True, backend=backend)(rng, x)
self.assertExpectedInline(
backend.graphs[0].code.strip(),
"""\
def forward(self, L_x_ : torch.Tensor, L_rng_state_ : __main___RNGState):
l_x_ = L_x_
l_rng_state_ = L_rng_state_
x = torch.ops._TestOpaqueObject.noisy_inject(l_x_, l_rng_state_); l_x_ = None
x_1 = x * x; x = None
x_2 = torch.ops._TestOpaqueObject.noisy_inject(x_1, l_rng_state_); x_1 = l_rng_state_ = None
x_3 = x_2 + x_2; x_2 = None
return (x_3,)""", # noqa: B950
)
self.assertExpectedInline(
backend.fw_graphs[0].code.strip(),
"""\
def forward(self, arg0_1, arg1_1):
noisy_inject = torch.ops._TestOpaqueObject.noisy_inject.default(arg0_1, arg1_1); arg0_1 = None
mul = torch.ops.aten.mul.Tensor(noisy_inject, noisy_inject); noisy_inject = None
noisy_inject_1 = torch.ops._TestOpaqueObject.noisy_inject.default(mul, arg1_1); mul = arg1_1 = None
add = torch.ops.aten.add.Tensor(noisy_inject_1, noisy_inject_1); noisy_inject_1 = None
return (add,)""", # noqa: B950
)
def test_compile_intermediate(self):
counter = Counter(0)
def foo(x, y):
z = torch.ops._TestOpaqueObject.increment_counter(counter, y)
x = x * z
z = torch.ops._TestOpaqueObject.increment_counter(counter, y)
x = x + z
return x, counter
inp = (torch.tensor(1), torch.tensor(0))
backend = AotEagerAndRecordGraphs()
opt_f = torch.compile(foo, fullgraph=True, backend=backend)
res = opt_f(*inp)
self.assertEqual(res[0], torch.tensor(3))
self.assertEqual(res[1].counter, torch.tensor(2))
res = opt_f(*inp)
self.assertEqual(res[0], torch.tensor(7))
self.assertEqual(res[1].counter, torch.tensor(4))
# counter is automatically lifted as an input
# Even though we returned counter in the eager code, it does not get
# returned in the graph because dynamo does not detect that the object
# is mutated.
self.assertExpectedInline(
backend.fw_graphs[0].code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1):
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops._TestOpaqueObject.increment_counter.default, c = arg1_1, _prev_base_index = 0, _all_bases = [arg0_1])
getitem = auto_functionalized_v2[0]
getitem_1 = auto_functionalized_v2[1]; auto_functionalized_v2 = None
mul = torch.ops.aten.mul.Tensor(arg2_1, getitem); arg2_1 = getitem = None
auto_functionalized_v2_1 = torch.ops.higher_order.auto_functionalized_v2(torch.ops._TestOpaqueObject.increment_counter.default, c = arg1_1, _prev_base_index = 0, _all_bases = [getitem_1]); arg1_1 = getitem_1 = None
getitem_2 = auto_functionalized_v2_1[0]
getitem_3 = auto_functionalized_v2_1[1]; auto_functionalized_v2_1 = None
add = torch.ops.aten.add.Tensor(mul, getitem_2); mul = getitem_2 = None
copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_3); arg0_1 = getitem_3 = copy_ = None
return (add,)""", # noqa: B950
)
def test_compile_attribute(self):
counter = Counter(0)
def foo(counter, x):
x = x * x
counter.increment_counter()
return x
with self.assertRaisesRegex(
RuntimeError, "Attempted to access attributes/methods on an OpaqueObject"
):
torch.compile(foo)(counter, torch.ones(2, 3))
def bar(counter, x):
x = x * x
x += counter.counter
return x
with self.assertRaisesRegex(
RuntimeError, "Attempted to access attributes/methods on an OpaqueObject"
):
torch.compile(bar)(counter, torch.ones(2, 3))
def test_export_joint(self):
class Moo(torch.nn.Module):
def forward(self, x, y):
return x * y
register_opaque_type(Moo, "_TestOpaqueObject_Moo")
torch.library.define(
"_TestOpaqueObject::module_mul",
"(_TestOpaqueObject_Moo a, Tensor b, SymInt c) -> Tensor",
tags=torch.Tag.pt2_compliant_tag,
lib=self.lib,
)
@torch.library.impl(
"_TestOpaqueObject::module_mul", "CompositeExplicitAutograd", lib=self.lib
)
def module_mul_impl(m: Moo, a: torch.Tensor, b: int) -> torch.Tensor:
assert isinstance(m, Moo)
return m(a, b)
@torch.library.register_fake("_TestOpaqueObject::module_mul", lib=self.lib)
def module_mul_fake(m: Moo, a: torch.Tensor, b: int) -> torch.Tensor:
return torch.empty_like(a)
def module_mul_setup_context(ctx, inputs, output):
m, a, b = inputs
ctx.b = b
def module_mul_backward(ctx, grad) -> torch.Tensor:
return None, grad * ctx.b, None
torch.library.register_autograd(
"_TestOpaqueObject::module_mul",
module_mul_backward,
setup_context=module_mul_setup_context,
lib=self.lib,
)
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.moo = Moo()
def forward(self, x, y):
b = y.item()
return torch.ops._TestOpaqueObject.module_mul(self.moo, x, b)
inp = (torch.randn(3, requires_grad=True), torch.tensor(4))
with ExitStack() as stack:
with FakeTensorMode(shape_env=ShapeEnv()):
joint = aot_export_joint_with_descriptors(stack, M(), inp)
self.assertExpectedInline(
joint.graph_module.code.strip(),
"""\
def forward(self, primals, tangents):
primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
_local_scalar_dense = torch.ops.aten._local_scalar_dense.default(primals_2); primals_2 = None
_opaque_obj0 = self._opaque_obj0
module_mul = torch.ops._TestOpaqueObject.module_mul.default(_opaque_obj0, primals_1, _local_scalar_dense); _opaque_obj0 = primals_1 = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, _local_scalar_dense); tangents_1 = _local_scalar_dense = None
return pytree.tree_unflatten([module_mul, mul_1, None], self._out_spec)""", # noqa: B950
)
compiled_fn = aot_compile_joint_with_descriptors(joint)
self.assertEqual(compiled_fn(*inp), M()(*inp))
instantiate_parametrized_tests(TestOpaqueObject)
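
A compact sketch of the opaque-object flow exercised by the tests above: register a plain Python class as an opaque type, route all access to it through a custom op, and opt that op into ordered effects. The class, namespace, and op names here are illustrative; torch.library._register_effectful_op is the private entry point used in test_aot_export:

import random

import torch
from torch._library.effects import EffectType
from torch._library.opaque_object import register_opaque_type

class RNGState:
    def __init__(self, seed):
        self.seed = seed
        self.rng = random.Random(self.seed)

# Make the plain Python class usable as a custom-op argument type.
register_opaque_type(RNGState, "_Example_RNGState")

@torch.library.custom_op("_Example::noisy_inject", mutates_args=())
def noisy_inject(x: torch.Tensor, obj: RNGState) -> torch.Tensor:
    return x + obj.rng.random()

@noisy_inject.register_fake
def _(x, obj):
    return torch.empty_like(x)

# Ops taking opaque objects are not treated as effectful by default;
# opt in explicitly if call ordering must be preserved under torch.compile.
torch.library._register_effectful_op("_Example::noisy_inject", EffectType.ORDERED)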

View File

@ -3657,5 +3657,15 @@
"Explanation": "Encountered triton kernel unsupported feature: {msg}",
"Hints": []
}
],
"GB0362": [
{
"Gb_type": "Attempted to access attributes/methods on an OpaqueObject",
"Context": "value={self.value}, attr={name}",
"Explanation": "Attribute/method access of OpaqueObjects is not supported.",
"Hints": [
"Use custom operators instead of direct attribute/method access."
]
}
]
}

View File

@ -56,6 +56,7 @@ from torch._guards import (
tracing,
TracingContext,
)
from torch._library.opaque_object import is_opaque_type
from torch._subclasses.fake_tensor import FakeTensor
from torch._utils_internal import signpost_event
from torch.export.dynamic_shapes import _ConstraintTarget
@ -2605,6 +2606,8 @@ class OutputGraph(OutputGraphCommon):
fake_attr_val,
)
continue
if is_opaque_type(type(node.meta["grapharg"].example)):
continue
fake = (
arg.fake_tensor if arg.fake_tensor is not None else arg.example
)

View File

@ -58,6 +58,7 @@ from torch._dynamo.utils import (
from torch._guards import TracingContext
from torch._higher_order_ops.flat_apply import flat_apply
from torch._higher_order_ops.torchbind import call_torchbind
from torch._library.opaque_object import is_opaque_type
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
from torch._subclasses.meta_utils import is_sparse_any, safe_grad
@ -1452,27 +1453,32 @@ class VariableBuilder:
source=self.source,
)
# This exists to allow a smoother transition.
# The implications are:
# The script objects won't be tracked as proxies.
# Methods on these objects won't show up in the graph.
# The original script object might be mutated.
if not hasattr(value, "__obj_flatten__"):
return self.wrap_user_defined(value)
if is_opaque_type(type(value)):
self.install_guards(GuardBuilder.TYPE_MATCH)
# Install the guards on the fully qualified name of the script object
LazyVariableTracker.realize_all(
VariableBuilder(self.tx, ScriptObjectQualifiedNameSource(self.source))(
value._type().qualified_name() # type: ignore[attr-defined]
elif not hasattr(value, "__obj_flatten__"):
# This exists to allow a smoother transition.
# The implications are:
# The script objects won't be tracked as proxies.
# Methods on these objects won't show up in the graph.
# The original script object might be mutated.
return self.wrap_user_defined(value)
else:
# Install the guards on the fully qualified name of the script object
LazyVariableTracker.realize_all(
VariableBuilder(
self.tx, ScriptObjectQualifiedNameSource(self.source)
)(
value._type().qualified_name() # type: ignore[attr-defined]
)
)
)
# Install the guards on the content of the script object by setting the source
# to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
LazyVariableTracker.realize_all(
VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
value.__obj_flatten__()
# Install the guards on the content of the script object by setting the source
# to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
LazyVariableTracker.realize_all(
VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
value.__obj_flatten__()
)
)
)
fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj(
self.tx.output.fake_mode, value

View File

@ -25,6 +25,7 @@ from typing_extensions import ParamSpec
import torch
from torch._guards import Source
from torch._library.opaque_object import is_opaque_type, OpaqueTypeStr
from torch.fx.proxy import Proxy
from .. import graph_break_hints
@ -61,7 +62,7 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
@classmethod
def is_matching_cls(cls, user_cls: type) -> bool:
return issubclass(user_cls, torch.ScriptObject)
return issubclass(user_cls, torch.ScriptObject) or is_opaque_type(user_cls)
@staticmethod
def create(proxy: Proxy, value: Any, **options: Any) -> "TorchScriptObjectVariable":
@ -80,6 +81,16 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
"Dynamo cannot safely trace script object due to graph break."
)
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if getattr(self.value, "script_class_name", "") == OpaqueTypeStr:
unimplemented(
gb_type="Attempted to access attributes/methods on an OpaqueObject",
context=f"value={self.value}, attr={name}",
explanation="Attribute/method access of OpaqueObjects is not supported.",
hints=[
"Use custom operators instead of direct attribute/method access.",
],
)
from torch._higher_order_ops.torchbind import call_torchbind
from ..source import AttrSource

View File

@ -24,6 +24,7 @@ from torch._export.passes.lift_constants_pass import ConstantAttrMap
from torch._export.utils import _fakify_params_buffers
from torch._guards import Source
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.opaque_object import is_opaque_type
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.export import Constraint
from torch.export.dynamic_shapes import (
@ -946,7 +947,9 @@ def _fakify_script_objects(
try:
for obj, fqns in constant_attrs.items():
if torch._library.fake_class_registry._is_script_object(obj):
if torch._library.fake_class_registry._is_script_object(
obj
) or is_opaque_type(obj):
fake_script_obj = _maybe_fakify_obj(obj)
for fqn in fqns:
cur_mod, attr = _leaf_mod_and_attr(mod, fqn)

View File

@ -8,6 +8,7 @@ from typing import Any, Optional
import torch
import torch.utils._pytree as pytree
from torch._guards import detect_fake_mode
from torch._library.opaque_object import is_opaque_type
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info
from torch.fx.experimental.symbolic_shapes import ShapeEnv
@ -46,7 +47,7 @@ def process_inputs(
hint=x,
source=source,
)
if isinstance(x, torch.ScriptObject):
if isinstance(x, torch.ScriptObject) or is_opaque_type(type(x)):
return torch._library.fake_class_registry.maybe_to_fake_obj(
fake_mode, x
)

View File

@ -534,6 +534,7 @@ def create_aot_state(
stack.enter_context(autograd_fallback_mode("error"))
from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj
from torch._library.opaque_object import is_opaque_type
# Tracing may mutate the states the fake script object,
# so we need to duplicate the fake script objects so that subsequent tracing
@ -541,7 +542,7 @@ def create_aot_state(
def _dup_fake_script_obj(fake_flat_args):
return [
maybe_to_fake_obj(detect_fake_mode(fake_flat_args), arg.real_obj)
if isinstance(arg, FakeScriptObject)
if isinstance(arg, FakeScriptObject) or is_opaque_type(type(arg))
else arg
for arg in fake_flat_args
]

View File

@ -1,13 +1,13 @@
# mypy: allow-untyped-defs
from enum import Enum
from typing import Any, Optional, Union
from weakref import WeakKeyDictionary
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.torchbind import call_torchbind
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.custom_ops import CustomOpDef
from torch._library.effects import EffectType
from torch._library.utils import RegistrationHandle
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
@ -17,39 +17,50 @@ from torch.fx.experimental.proxy_tensor import (
)
class _EffectType(Enum):
ORDERED = "Ordered"
_op_identifier = Union[
str,
"torch._ops.OpOverload",
"torch._library.custom_ops.CustomOpDef",
"torch._ops.HigherOrderOperator",
]
OpType = Union["torch._ops.HigherOrderOperator", "torch._ops.OpOverload"]
_EffectType = EffectType
OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload]
def _get_op_qualname(op: _op_identifier) -> str:
"""Convert an op identifier to a qualified string key."""
if isinstance(op, torch._ops.OpOverload):
return op._name
elif isinstance(op, torch._ops.HigherOrderOperator):
return f"{op.namespace}::{op.name()}"
elif isinstance(op, CustomOpDef):
return op._qualname
elif isinstance(op, str):
return op
raise ValueError(f"Invalid operator input {op}")
SIDE_EFFECTS = WeakKeyDictionary[OpType, _EffectType](
[
(torch.ops.aten._print.default, _EffectType.ORDERED),
(torch.ops.aten._async_error.default, _EffectType.ORDERED),
(call_torchbind, _EffectType.ORDERED),
]
)
def _register_effectful_op(
op: _op_identifier, effect: Optional[EffectType]
) -> RegistrationHandle:
qualname = _get_op_qualname(op)
entry = torch._library.simple_registry.singleton.find(qualname)
handle = entry.effect.register(effect)
return handle
def _register_effectful_op(op: OpType, effect: _EffectType):
assert isinstance(
op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
) and not has_aliasing(op)
if op in SIDE_EFFECTS and SIDE_EFFECTS[op] != effect:
raise RuntimeError(
f"Already registered effect type {SIDE_EFFECTS[op]} to op {op}, "
f"trying to register a different effect type {effect}."
)
SIDE_EFFECTS[op] = effect
def _get_effect(op: _op_identifier) -> Optional[_EffectType]:
qualname = _get_op_qualname(op)
entry = torch._library.simple_registry.singleton.find(qualname)
return entry.effect.effect
def _deregister_effectful_op(op: OpType):
if op not in SIDE_EFFECTS:
raise RuntimeError(f"Op {op} is not registered as effectful")
del SIDE_EFFECTS[op]
_register_effectful_op("aten::_print", _EffectType.ORDERED)
_register_effectful_op("aten::_async_error", _EffectType.ORDERED)
_register_effectful_op("profiler::_record_function_exit._RecordFunction", None)
_register_effectful_op(call_torchbind, _EffectType.ORDERED)
class WithEffects(HigherOrderOperator):
@ -78,7 +89,7 @@ class WithEffects(HigherOrderOperator):
) -> tuple[Any, ...]:
assert isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload))
assert not has_aliasing(op), "Ops with aliasing is not supported"
assert has_effects(op, args, kwargs)
assert has_effects(op)
assert isinstance(kwargs, dict)
return super().__call__(token, op, *args, **kwargs)
@ -89,7 +100,7 @@ with_effects = WithEffects()
def has_aliasing(op: OpType):
# NOT FOR PUBLIC USE
if isinstance(op, torch._ops.HigherOrderOperator):
return op not in SIDE_EFFECTS
return not _get_effect(op)
for arg in op._schema.arguments:
if arg.alias_info is not None:
@ -100,7 +111,7 @@ def has_aliasing(op: OpType):
return False
def has_effects(op, args, kwargs) -> bool:
def has_effects(op) -> bool:
# Skip over the profiler's RecordFunction as they should not show up in the graph
_skip_ops = {torch.ops.profiler._record_function_exit._RecordFunction}
if op in _skip_ops:
@ -109,31 +120,10 @@ def has_effects(op, args, kwargs) -> bool:
return (
isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload))
and not has_aliasing(op)
and get_effect_key(op, args, kwargs) is not None
and _get_effect(op) is not None
)
def get_effect_key(op, args, kwargs) -> Optional[_EffectType]:
if op in SIDE_EFFECTS:
return SIDE_EFFECTS[op]
for arg in args:
if isinstance(arg, (torch.ScriptObject, FakeScriptObject)):
# Add it to the table so that next time we see the same op we don't
# have to parse through the args again
SIDE_EFFECTS[op] = _EffectType.ORDERED
return _EffectType.ORDERED
for arg in kwargs.values():
if isinstance(arg, (torch.ScriptObject, FakeScriptObject)):
# Add it to the table so that next time we see the same op we don't
# have to parse through the args again
SIDE_EFFECTS[op] = _EffectType.ORDERED
return _EffectType.ORDERED
return None
def new_token_tensor() -> torch.Tensor:
return torch.tensor([])
@ -238,7 +228,7 @@ def handle_effects(
# Get a token. We can't do `tokens.get(op, torch.tensor([]))` because
# this will create an empty tensor during proxy mode tracing if the token
# doesn't exist. But the tokens should always exist during proxy mode tracing.
key = get_effect_key(op, args, kwargs)
key = _get_effect(op)
assert key is not None
if key not in tokens:
assert allow_token_discovery, (
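
To illustrate the new _op_identifier handling above: the helpers now accept several spellings of the same operator and normalize them with _get_op_qualname. A hedged sketch, assuming the module-level "aten::_print" registration shown above is in effect:

import torch
from torch._higher_order_ops.effects import _get_effect

print(_get_effect("aten::_print"))                 # qualified string -> EffectType.ORDERED
print(_get_effect(torch.ops.aten._print.default))  # OpOverload form
# CustomOpDef objects returned by @torch.library.custom_op and
# HigherOrderOperators (e.g. call_torchbind) are accepted as well.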

View File

@ -2122,6 +2122,10 @@ class PythonWrapperCodegen(CodeGen):
output.writeline(f"{name} = {val}")
def add_torchbind_input(name, value):
if value is None:
output.writeline(f"{name} = None")
return
import pickle
assert isinstance(value, torch.ScriptObject)

View File

@ -91,6 +91,7 @@ from torch._inductor.utils import (
tensor_is_aligned,
)
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.opaque_object import is_opaque_type
from torch._logging import trace_structured
from torch._utils_internal import compile_time_strobelight_meta
from torch.fx import GraphModule
@ -2747,7 +2748,9 @@ def _compile_fx_main(
node.meta["val"] = fake_mode.from_tensor(
target, static_shapes=True
)
elif isinstance(target, torch.ScriptObject):
elif isinstance(target, torch.ScriptObject) or is_opaque_type(
type(target)
):
node.meta["val"] = (
torch._library.fake_class_registry.maybe_to_fake_obj(
fake_mode, target

View File

@ -883,11 +883,12 @@ def _get_optimization_cflags(
should_use_optimized_flags = not (
config.aot_inductor.debug_compile
or os.environ.get("TORCHINDUCTOR_DEBUG_SYMBOL", "0") == "1"
or os.environ.get("TORCHINDUCTOR_DEBUG_COMPILE", "0") == "1"
)
should_add_debug_symbol_flags = (
config.aot_inductor.debug_compile
or config.aot_inductor.debug_symbols
or os.environ.get("TORCHINDUCTOR_DEBUG_COMPILE", "0") == "1"
or os.environ.get("TORCHINDUCTOR_DEBUG_SYMBOL", "0") == "1"
)
if should_use_optimized_flags:
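
Reading the updated flag logic above: TORCHINDUCTOR_DEBUG_COMPILE selects an unoptimized build and also enables debug symbols, while TORCHINDUCTOR_DEBUG_SYMBOL keeps optimizations and only adds symbols. A rough sketch of toggling the two modes; the variables must be set before Inductor computes its compile flags:

import os

# Unoptimized build plus debug symbols:
os.environ["TORCHINDUCTOR_DEBUG_COMPILE"] = "1"

# Or: optimized build that still emits debug symbols
# (a .pdb on Windows, .debug_info sections on Linux):
# os.environ["TORCHINDUCTOR_DEBUG_SYMBOL"] = "1"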

View File

@ -9242,12 +9242,9 @@ class EffectfulKernel(FallbackKernel):
unbacked_bindings=unbacked_bindings,
)
from torch._higher_order_ops.effects import get_effect_key
from torch._higher_order_ops.effects import _get_effect
uncovered_args = [
a.value if isinstance(a, TorchBindObject) else a for a in tensor_args
]
effect_type = get_effect_key(kernel, (*nontensor_args, *uncovered_args), kwargs)
effect_type = _get_effect(kernel)
assert effect_type is not None
self.effect_type = effect_type
self.prev_effect_buffer = V.graph.effectful_ops.get(effect_type, None)
@ -9298,6 +9295,10 @@ class TorchBindObject(NonTensorObj):
def get_buf_bytes(self) -> int:
# Returns the sum of all tensors in the flattened object
real_script_obj = self.get_real_obj()
if real_script_obj is None:
return 0
assert hasattr(real_script_obj, "__obj_flatten__")
flat_dict = dict(real_script_obj.__obj_flatten__())
flat_elems = pytree.tree_flatten(flat_dict)[0]

View File

@ -26,6 +26,7 @@ import torch.utils._pytree as pytree
from torch._dynamo.utils import counters
from torch._higher_order_ops.associative_scan import associative_scan_op
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.utils import get_layout_constraint_tag
from torch._prims_common import ( # pyrefly: ignore # deprecated; pyrefly: ignore [deprecated]
canonicalize_dim,
@ -2704,6 +2705,8 @@ def require_channels_last(_, *args, **kwargs):
def constrain_to_fake_tensor(arg, fake_arg):
if isinstance(fake_arg, FakeScriptObject):
return arg
if isinstance(arg, ir.IRNode):
meta_stride_expr = [
s.node.expr if isinstance(s, torch.SymInt) else s for s in fake_arg.stride()
@ -7453,9 +7456,9 @@ def _sink_tokens(tokens):
def with_effects(token, op, *args, **kwargs):
result = ir.EffectfulKernel.create(op, *args, **kwargs)
from torch._higher_order_ops.effects import get_effect_key
from torch._higher_order_ops.effects import _get_effect
effect_type = get_effect_key(op, args, kwargs)
effect_type = _get_effect(op)
assert effect_type is not None
effectful_kernel = V.graph.effectful_ops[effect_type]

View File

@ -13,6 +13,7 @@ from torch.types import _dtype
from torch.utils._exposed_in import exposed_in
from . import autograd, utils
from .effects import EffectType
device_types_t = Optional[Union[str, Sequence[str]]]
@ -471,6 +472,9 @@ class CustomOpDef:
self._abstract_fn = fn
return fn
def register_effect(self, effect: Optional[EffectType]) -> None:
self._lib._register_effectful_op(self._qualname, effect)
def register_torch_dispatch(
self, torch_dispatch_class: Any, fn: Optional[Callable] = None, /
) -> Callable:
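
A short usage sketch for the CustomOpDef.register_effect hook added above; passing None explicitly marks the op as non-effectful. The op name is illustrative:

import torch
from torch._library.effects import EffectType

@torch.library.custom_op("mylib::log_tensor", mutates_args=())
def log_tensor(x: torch.Tensor) -> None:
    print(x.sum().item())  # side effect invisible to the schema

@log_tensor.register_fake
def _(x):
    return None

log_tensor.register_effect(EffectType.ORDERED)  # preserve ordering under compile
# log_tensor.register_effect(None)              # or explicitly opt out again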

68
torch/_library/effects.py Normal file
View File

@ -0,0 +1,68 @@
from enum import Enum
from typing import Optional
import torch
class EffectType(Enum):
ORDERED = "Ordered"
from torch._library.utils import RegistrationHandle
class EffectHolder:
"""A holder where one can register an effect impl to."""
def __init__(self, qualname: str):
self.qualname: str = qualname
self._set_default_effect()
def _set_default_effect(self) -> None:
self._effect: Optional[EffectType] = None
# If the op contains a ScriptObject input, we want to mark it as having effects
namespace, opname = torch._library.utils.parse_namespace(self.qualname)
split = opname.split(".")
if len(split) > 1:
assert len(split) == 2, (
f"Tried to split {opname} based on '.' but found more than 1 '.'"
)
opname, overload = split
else:
overload = ""
if namespace == "higher_order":
return
opname = f"{namespace}::{opname}"
if torch._C._get_operation_overload(opname, overload) is not None:
# Since we call this when destroying the library, sometimes the
# schema will be gone already at that time.
schema = torch._C._get_schema(opname, overload)
for arg in schema.arguments:
if isinstance(arg.type, torch.ClassType):
self._effect = EffectType.ORDERED
return
@property
def effect(self) -> Optional[EffectType]:
return self._effect
@effect.setter
def effect(self, _):
raise RuntimeError("Unable to directly set kernel.")
def register(self, effect: Optional[EffectType]) -> RegistrationHandle:
"""Register an effect
Returns a RegistrationHandle that one can use to de-register this
effect.
"""
self._effect = effect
def deregister_effect():
self._set_default_effect()
handle = RegistrationHandle(deregister_effect)
return handle
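
How a registration lands in the holder above, mirroring the call chain in the surrounding diffs (entry.effect.register is what _register_effectful_op ultimately calls). This goes through the private simple_registry and is only a sketch:

import torch
from torch._library.effects import EffectType

entry = torch._library.simple_registry.singleton.find("aten::cos")
handle = entry.effect.register(EffectType.ORDERED)
assert entry.effect.effect == EffectType.ORDERED

handle.destroy()  # restores the schema-derived default (None for aten::cos)
assert entry.effect.effect is None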

View File

@ -1,6 +1,7 @@
from collections.abc import Callable
from typing import Any, Optional
from .effects import EffectHolder
from .fake_impl import FakeImplHolder
from .utils import RegistrationHandle
@ -51,6 +52,8 @@ class SimpleOperatorEntry:
GenericTorchDispatchRuleHolder(qualname)
)
self.effect: EffectHolder = EffectHolder(qualname)
# For compatibility reasons. We can delete this soon.
@property
def abstract_impl(self) -> FakeImplHolder:

View File

@ -1023,6 +1023,7 @@ class TorchBindOpOverload(OpOverload[_P, _T]):
DispatchKey.BackendSelect,
DispatchKey.PythonTLSSnapshot,
DispatchKey.PythonDispatcher,
DispatchKey.Functionalize,
]
def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
@ -1046,17 +1047,23 @@ class TorchBindOpOverload(OpOverload[_P, _T]):
def _register_as_effectful_op_temporarily(self):
from torch._higher_order_ops.effects import (
_EffectType,
_get_effect,
_register_effectful_op,
SIDE_EFFECTS,
)
try:
if self not in SIDE_EFFECTS:
_register_effectful_op(self, _EffectType.ORDERED)
# We don't want to register the effect if there already exists a
# registration, especially if the registration is None (explicitly
# no effect)
register_tmp_effect = _get_effect(self) is None
handle = None
if register_tmp_effect:
handle = _register_effectful_op(self, _EffectType.ORDERED)
yield
finally:
if self in SIDE_EFFECTS:
del SIDE_EFFECTS[self]
if register_tmp_effect:
assert handle is not None
handle.destroy()
# Use positional-only argument to avoid naming collision with aten ops arguments
# that are named "self". This way, all the aten ops can be called by kwargs.

View File

@ -11,7 +11,7 @@ import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._C import _functionalization_reapply_views_tls as _reapply_views
from torch._ops import _get_dispatch_mode_pre_dispatch
from torch._ops import _get_dispatch_mode_pre_dispatch, TorchBindOpOverload
from torch._subclasses.meta_utils import is_sparse_any
from torch.utils._python_dispatch import (
_detect_infra_mode,
@ -471,7 +471,7 @@ class FunctionalTensorMode(TorchDispatchMode):
from torch._higher_order_ops.effects import handle_effects, has_effects
if has_effects(func, args, kwargs):
if has_effects(func):
assert not torch._C._dispatch_has_kernel_for_dispatch_key(
func.name(), torch._C.DispatchKey.Functionalize
)
@ -504,65 +504,81 @@ class FunctionalTensorMode(TorchDispatchMode):
- FunctionalTensor._extra_dispatch_keys
)
# All we want to do here is reuse the existing C++ functionalization logic.
# This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
try:
# By default for python functionalization (for AOTAutograd), we reapply views.
old_apply_views = torch._functionalize_enable_reapply_views(True) # type: ignore[attr-defined]
if isinstance(func, TorchBindOpOverload):
# When the function is a TorchBindOpOverload, meaning some of the
# inputs are FakeScriptObjects, we need to skip c++ dispatcher and
# dispatch in python because C++ dispatcher will check the schema
# and cannot recognize FakeScriptObject.
ctx = PythonFunctionalizeAPI()
fully_unwrapped_args = ctx.unwrap_tensors(args)
fully_unwrapped_kwargs = ctx.unwrap_tensors(
kwargs # pyrefly: ignore[bad-argument-type]
)
outs_unwrapped = func(
*fully_unwrapped_args,
**fully_unwrapped_kwargs,
)
outs_wrapped = ctx.wrap_tensors(outs_unwrapped)
else:
# All we want to do here is reuse the existing C++ functionalization logic.
# This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
try:
# By default for python functionalization (for AOTAutograd), we reapply views.
old_apply_views = torch._functionalize_enable_reapply_views(True) # type: ignore[attr-defined]
# Sometimes these functions cannot be directly dispatched to functionalize key
# because args are sometimes not functional tensors for some reason?
if func in FunctionalTensor.metadata_fns:
outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
outs_wrapped = pytree.tree_map_only(
torch.Tensor, wrap, outs_unwrapped
)
else:
# Note: [Functionalization View Replay Annotation]
# When functionalization encounters a mutation, it handles aliases by lazily regenerating the aliases
# at the first time they are next used.
# This is a problem when plumbing user annotations during tracing. We want the view ops from view replay
# to have the same annotation that the user specified on the original views. But view replay in
# functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)),
# so when we regenerate views before calling into second_op, those views will end up getting the metadata
# for second_op!
#
# Instead, we need to remember the node metadata from the original views, and ensure that this node metadata
# is globally set when we lazily perform view replay.
# The globally set metadata will be used to populate the fx node created for the replayed operation.
if m := torch._C._get_dispatch_mode(
torch._C._TorchDispatchModeKey.PROXY
):
for a in pytree.tree_leaves([args, kwargs]):
if not isinstance(a, FunctionalTensor):
continue
curr_node = m.tracer.tensor_tracker[
torch._from_functional_tensor(a.elem)
].proxy.node
with fx_traceback.set_current_replay_node(curr_node):
torch._sync(a)
# Sometimes these functions cannot be directly dispatched to functionalize key
# because args are sometimes not functional tensors for some reason?
if func in FunctionalTensor.metadata_fns:
outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
outs_wrapped = pytree.tree_map_only(
torch.Tensor, wrap, outs_unwrapped
)
else:
# Note: [Functionalization View Replay Annotation]
# When functionalization encounters a mutation, it handles aliases by lazily regenerating the aliases
# at the first time they are next used.
# This is a problem when plumbing user annotations during tracing. We want the view ops from view replay
# to have the same annotation that the user specified on the original views. But view replay in
# functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)),
# so when we regenerate views before calling into second_op, those views will end up getting the metadata
# for second_op!
#
# Instead, we need to remember the node metadata from the original views, and ensure that this node metadata
# is globally set when we lazily perform view replay.
# The globally set metadata will be used to populate the fx node created for the replayed operation.
if m := torch._C._get_dispatch_mode(
torch._C._TorchDispatchModeKey.PROXY
):
for a in pytree.tree_leaves([args, kwargs]):
if not isinstance(a, FunctionalTensor):
continue
curr_node = m.tracer.tensor_tracker[
torch._from_functional_tensor(a.elem)
].proxy.node
with fx_traceback.set_current_replay_node(curr_node):
torch._sync(a)
# When we dispatch to the C++ functionalization kernel, we might need to jump back to the
# PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
# FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
# from the TLS in order to avoid infinite looping, but this would prevent us from coming
# back to PreDispatch later
outs_unwrapped = func._op_dk(
torch._C.DispatchKey.Functionalize,
*args_unwrapped,
**kwargs_unwrapped,
)
# When we dispatch to the C++ functionalization kernel, we might need to jump back to the
# PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
# FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
# from the TLS in order to avoid infinite looping, but this would prevent us from coming
# back to PreDispatch later
outs_unwrapped = func._op_dk(
torch._C.DispatchKey.Functionalize,
*args_unwrapped,
**kwargs_unwrapped,
)
if self.export:
if func is torch.ops.aten.dropout.default:
torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined]
outs_wrapped = pytree.tree_map_only(
torch.Tensor, wrap, outs_unwrapped
)
finally:
torch._disable_functionalization()
torch._functionalize_enable_reapply_views(old_apply_views) # type: ignore[attr-defined]
if self.export:
if func is torch.ops.aten.dropout.default:
torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined]
outs_wrapped = pytree.tree_map_only(
torch.Tensor, wrap, outs_unwrapped
)
finally:
torch._disable_functionalization()
torch._functionalize_enable_reapply_views(old_apply_views) # type: ignore[attr-defined]
is_included = torch._C._dispatch_tls_is_dispatch_key_included(
torch._C.DispatchKey.Functionalize

View File

@ -18,6 +18,7 @@ import torch
import torch.utils._pytree as pytree
from torch._C import ScriptObject # type: ignore[attr-defined]
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.opaque_object import is_opaque_type
from ._compatibility import compatibility
from ._lazy_graph_module import _make_graph_module
@ -421,8 +422,10 @@ class Tracer(TracerBase):
# a get_attr to retrieve that tensor. Otherwise, we'll store away the
# tensor value into a special attribute on the Module s.t. we can
# retrieve it with a get_attr.
if isinstance(a, _constant_attribute_types):
qualname: Optional[str] = self.tensor_attrs.get(a)
if isinstance(a, _constant_attribute_types) or is_opaque_type(type(a)):
qualname: Optional[str] = self.tensor_attrs.get(
a
) # pyrefly: ignore[no-matching-overload]
# Tensor was not found in the Module hierarchy, stow it away in a
# special attribute and set the qualname to refer to that
@ -433,13 +436,17 @@ class Tracer(TracerBase):
base_name = "_torchbind_obj"
elif isinstance(a, pytree.TreeSpec):
base_name = "_tree_spec_constant"
elif is_opaque_type(type(a)):
base_name = "_opaque_obj"
else:
raise RuntimeError(
f"cannot create constant arg for {a} of type {type(a)}."
)
qualname = self.get_fresh_qualname(base_name)
assert isinstance(qualname, str)
self.tensor_attrs[a] = qualname
self.tensor_attrs[a] = ( # pyrefly: ignore[unsupported-operation]
qualname
)
setattr(self.root, qualname, a)
return self.create_node("get_attr", qualname, (), {})

View File

@ -19,6 +19,7 @@ from torch._library.custom_ops import (
CustomOpDef,
device_types_t,
)
from torch._library.effects import EffectType
from torch._library.infer_schema import infer_schema # noqa: F401
from torch._library.triton import triton_op, wrap_triton
from torch._ops import OpOverload
@ -398,6 +399,22 @@ class Library:
self.m.fallback(dispatch_key, fn, with_keyset)
def _register_effectful_op(self, op_name: str, effect: Optional[EffectType]):
"""
Registers an effect to an operator. This is used to register an op that
has side effects that are not capturable by the schema.
Args:
op_name: operator name (along with the overload) or OpOverload object.
effect: The effect of the op.
"""
from torch._higher_order_ops.effects import (
_register_effectful_op as hoo_register_effect,
)
handle = hoo_register_effect(op_name, effect)
self._registration_handles.append(handle)
def _destroy(self):
if self.m is not None:
self.m.reset()
@ -1065,6 +1082,44 @@ def register_fake(
return register(func)
def _register_effectful_op(
op: _op_identifier,
effect: Optional[EffectType],
*,
lib: Optional[Library] = None,
) -> None:
r"""
To specify that an operator has side-effects, we must register an effect
type for the operator. This will prevent graph passes in torch.compile from
reordering operations with the same effect type.
Args:
op: Operator name (along with the overload), OpOverload, or CustomOpDef object.
effect: Effect type to register. None means the operator is not effectful.
"""
if not isinstance(
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
):
raise ValueError(
f"register_effectful_op({op}): got unexpected type for op: {type(op)}"
)
if isinstance(op, torch._ops.OpOverload):
op = op._name
opdef = _maybe_get_opdef(op)
if opdef is not None:
opdef.register_effect(effect)
assert isinstance(op, str)
namespace, _ = torch._library.utils.parse_namespace(op)
if lib is None:
use_lib = Library(namespace, "FRAGMENT")
_keep_alive.append(use_lib)
else:
use_lib = lib
use_lib._register_effectful_op(op, effect)
def register_autograd(
op: _op_identifier,
backward: Callable,
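
Putting the entry point above to use (still underscore-prefixed, hence private); the op name mirrors the opaque-object test earlier in this compare:

import torch
from torch._library.effects import EffectType

# Mark an existing op as having an ordered side effect ...
torch.library._register_effectful_op("_TestOpaqueObject::noisy_inject", EffectType.ORDERED)

# ... and later clear the registration by passing None.
torch.library._register_effectful_op("_TestOpaqueObject::noisy_inject", None)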

View File

@ -37,7 +37,7 @@ import functools
import traceback
import weakref
from collections.abc import Callable
from typing import Any, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING # noqa: F401
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode