Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-14 14:15:07 +08:00)

Compare commits: update_sub ... ciflow/tru (6 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 3dc9878a70 | |
| | 8919f69362 | |
| | 19c867873a | |
| | e3dadb1d36 | |
| | c9b09a31e8 | |
| | 35571fe94b | |
@ -129,7 +129,7 @@ function install_129 {
|
||||
}
|
||||
|
||||
function install_128 {
|
||||
CUDNN_VERSION=9.8.0.87
|
||||
CUDNN_VERSION=9.10.2.21
|
||||
echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
|
||||
# install CUDA 12.8.1 in the same container
|
||||
install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux
|
||||
|
||||
@ -18,15 +18,16 @@ from functorch.compile import (
|
||||
nop,
|
||||
)
|
||||
from torch._functorch.aot_autograd import aot_export_module
|
||||
from torch._higher_order_ops.effects import with_effects
|
||||
from torch._higher_order_ops.effects import (
|
||||
_EffectType,
|
||||
_get_effect,
|
||||
_register_effectful_op,
|
||||
with_effects,
|
||||
)
|
||||
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_cuda import (
|
||||
_get_torch_cuda_version,
|
||||
SM70OrLater,
|
||||
SM80OrLater,
|
||||
)
|
||||
from torch.testing._internal.common_cuda import SM70OrLater, SM80OrLater
|
||||
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
@ -300,7 +301,6 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
@unittest.skipIf(IS_WINDOWS, "triton")
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "triton")
|
||||
@unittest.skipIf(not SM80OrLater, "triton")
|
||||
@unittest.skipIf(_get_torch_cuda_version() >= (11, 7), "triton")
|
||||
@unittest.skipIf(not TEST_CUDA, "triton")
|
||||
@skipIfNoDynamoSupport
|
||||
def test_register_effectful_custom_op(self):
|
||||
@ -308,41 +308,23 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
torch._dynamo.config.capture_scalar_outputs = True
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
|
||||
torch.library.define(
|
||||
"mylib::record_scalar_tensor",
|
||||
"(Tensor x, str prefix) -> ()",
|
||||
lib=lib,
|
||||
)
|
||||
|
||||
# global variable to store the recorded tensor and prefix.
|
||||
recorded_dict = {}
|
||||
|
||||
# PyTorch custom op implementation
|
||||
@torch.library.impl(
|
||||
"mylib::record_scalar_tensor",
|
||||
"CompositeExplicitAutograd",
|
||||
lib=lib,
|
||||
)
|
||||
def record_scalar_tensor(x, prefix):
|
||||
# PyTorch custom op implementation
|
||||
@torch.library.custom_op("mylib::record_scalar_tensor", mutates_args=())
|
||||
def record_scalar_tensor(x: torch.Tensor, prefix: str) -> None:
|
||||
recorded_dict[prefix] = x.clone()
|
||||
return
|
||||
|
||||
# Meta function of the custom op
|
||||
@torch.library.register_fake(
|
||||
"mylib::record_scalar_tensor",
|
||||
lib=lib,
|
||||
)
|
||||
@record_scalar_tensor.register_fake
|
||||
def record_scalar_tensor_meta(x, prefix):
|
||||
return
|
||||
|
||||
from torch._higher_order_ops.effects import (
|
||||
_EffectType,
|
||||
_register_effectful_op,
|
||||
)
|
||||
record_scalar_tensor.register_effect(_EffectType.ORDERED)
|
||||
|
||||
_register_effectful_op(
|
||||
torch.ops.mylib.record_scalar_tensor.default, _EffectType.ORDERED
|
||||
)
|
||||
self.assertEqual(_get_effect(record_scalar_tensor), _EffectType.ORDERED)
|
||||
|
||||
my_config = {}
|
||||
my_config["MockModule"] = "mean"
|
||||
@ -469,13 +451,12 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
|
||||
torch.library.register_autograd("_mylib::zoo", foo_bwd, lib=lib)
|
||||
|
||||
from torch._higher_order_ops.effects import (
|
||||
_EffectType,
|
||||
_register_effectful_op,
|
||||
torch.library._register_effectful_op(
|
||||
torch.ops._mylib.zoo.default, _EffectType.ORDERED
|
||||
)
|
||||
torch.library._register_effectful_op(
|
||||
torch.ops._mylib.zoo2.default, _EffectType.ORDERED
|
||||
)
|
||||
|
||||
_register_effectful_op(torch.ops._mylib.zoo.default, _EffectType.ORDERED)
|
||||
_register_effectful_op(torch.ops._mylib.zoo2.default, _EffectType.ORDERED)
|
||||
|
||||
def fn(x, y):
|
||||
return torch.ops._mylib.zoo(x) + y
|
||||
@ -687,13 +668,13 @@ def forward(self, arg0_1, arg1_1):
|
||||
|
||||
torch.library.register_autograd("_mylib::foo", foo_bwd, lib=lib)
|
||||
|
||||
from torch._higher_order_ops.effects import (
|
||||
_deregister_effectful_op,
|
||||
_EffectType,
|
||||
_register_effectful_op,
|
||||
handle = _register_effectful_op(
|
||||
torch.ops._mylib.foo.default, _EffectType.ORDERED
|
||||
)
|
||||
self.assertEqual(
|
||||
_get_effect(torch.ops._mylib.foo.default), _EffectType.ORDERED
|
||||
)
|
||||
|
||||
_register_effectful_op(torch.ops._mylib.foo.default, _EffectType.ORDERED)
|
||||
try:
|
||||
|
||||
def fn(x, y):
|
||||
@ -779,17 +760,13 @@ def forward(self, tangents_1, tangents_2, tangents_token):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
finally:
|
||||
_deregister_effectful_op(torch.ops._mylib.foo.default)
|
||||
handle.destroy()
|
||||
|
||||
self.assertEqual(_get_effect(torch.ops._mylib.foo.default), None)
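A minimal sketch (not part of this diff) of the handle-based lifecycle these tests now follow; it assumes only the private helpers already imported in this test file:

```python
import torch
from torch._higher_order_ops.effects import (
    _EffectType,
    _get_effect,
    _register_effectful_op,
)

# Registering now returns a RegistrationHandle instead of requiring a later
# _deregister_effectful_op call.
handle = _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
assert _get_effect(torch.ops.aten.cos.default) is _EffectType.ORDERED

# Destroying the handle restores the default (no effect for aten.cos).
handle.destroy()
assert _get_effect(torch.ops.aten.cos.default) is None
```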
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
def test_regular_effectful_op_only_in_backward(self):
|
||||
from torch._higher_order_ops.effects import (
|
||||
_deregister_effectful_op,
|
||||
_EffectType,
|
||||
_register_effectful_op,
|
||||
)
|
||||
|
||||
_register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
|
||||
handle = _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
|
||||
try:
|
||||
|
||||
def fn(x):
|
||||
@ -852,17 +829,11 @@ def forward(self, primals_1, primals_2, tangents_1, tangents_2, tangents_token):
|
||||
return (mul, mul_1, getitem_2)""",
|
||||
)
|
||||
finally:
|
||||
_deregister_effectful_op(torch.ops.aten.cos.default)
|
||||
handle.destroy()
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
def test_regular_effectful_op_in_forward_and_backward(self):
|
||||
from torch._higher_order_ops.effects import (
|
||||
_deregister_effectful_op,
|
||||
_EffectType,
|
||||
_register_effectful_op,
|
||||
)
|
||||
|
||||
_register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
|
||||
handle = _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
|
||||
try:
|
||||
|
||||
def fn(x):
|
||||
@ -897,7 +868,7 @@ def forward(self, primals_2, getitem_1, tangents_1, tangents_token):
|
||||
return (mul_1, getitem_2)""",
|
||||
)
|
||||
finally:
|
||||
_deregister_effectful_op(torch.ops.aten.cos.default)
|
||||
handle.destroy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -136,12 +136,59 @@ class TestStandaloneInductor(TestCase):
|
||||
mod_opt = inductor.compile(mod, inp)
|
||||
self.assertEqual(mod(*inp), mod_opt(*inp))
|
||||
|
||||
@mock.patch.dict(os.environ, {"TORCHINDUCTOR_DEBUG_COMPILE": "1"})
|
||||
def test_inductor_generate_debug_compile(self):
|
||||
cpp_code = """
|
||||
int main(){
|
||||
return 0;
|
||||
}
|
||||
"""
|
||||
|
||||
_, source_path = write(
|
||||
cpp_code,
|
||||
"cpp",
|
||||
)
|
||||
build_option = CppOptions()
|
||||
cpp_builder = CppBuilder(
|
||||
name="test_compile",
|
||||
sources=source_path,
|
||||
output_dir=os.path.dirname(source_path),
|
||||
BuildOption=build_option,
|
||||
)
|
||||
cpp_builder.build()
|
||||
binary_path = cpp_builder.get_target_file_path()
|
||||
|
||||
"""
|
||||
When debug compile is turned on:
On Windows, it should create a [module_name].pdb file, which helps debugging with WinDBG.
On Linux, it should create debug sections in the binary file.
|
||||
"""
|
||||
|
||||
def check_linux_debug_section(module_path: str):
|
||||
check_cmd = shlex.split(f"readelf -S {module_path}")
|
||||
output = safe_command_output(check_cmd)
|
||||
has_debug_sym = ".debug_info" in output
|
||||
self.assertEqual(has_debug_sym, True)
|
||||
|
||||
def check_windows_pdb_exist(module_path: str):
|
||||
file_name_no_ext = os.path.splitext(module_path)[0]
|
||||
file_name_pdb = f"{file_name_no_ext}.pdb"
|
||||
has_pdb_file = os.path.exists(file_name_pdb)
|
||||
self.assertEqual(has_pdb_file, True)
|
||||
|
||||
if _IS_WINDOWS:
|
||||
check_windows_pdb_exist(binary_path)
|
||||
elif _IS_MACOS:
|
||||
pass  # macOS: not sure whether this should work.
|
||||
else:
|
||||
check_linux_debug_section(binary_path)
|
||||
|
||||
@mock.patch.dict(os.environ, {"TORCHINDUCTOR_DEBUG_SYMBOL": "1"})
|
||||
def test_inductor_generate_debug_symbol(self):
|
||||
cpp_code = """
|
||||
int main(){
|
||||
return 0;
|
||||
}
|
||||
int main(){
|
||||
return 0;
|
||||
}
|
||||
"""
|
||||
|
||||
_, source_path = write(
|
||||
|
||||
@ -90,7 +90,7 @@ class TestOpaqueObject(TestCase):
|
||||
# This is not accurate since the queue could have tensors that are
|
||||
# not rank 1
|
||||
ctx = torch._custom_op.impl.get_ctx()
|
||||
u0 = ctx.create_unbacked_symint()
|
||||
u0 = ctx.new_dynamic_size()
|
||||
return torch.empty(u0)
|
||||
|
||||
self.lib._register_fake("queue_pop", pop_impl_fake)
|
||||
@ -107,8 +107,7 @@ class TestOpaqueObject(TestCase):
|
||||
@size_impl.register_fake
|
||||
def size_impl_fake(q: torch._C.ScriptObject) -> int:
|
||||
ctx = torch._custom_op.impl.get_ctx()
|
||||
u0 = ctx.create_unbacked_symint()
|
||||
torch._check_is_size(u0)
|
||||
u0 = ctx.new_dynamic_size()
|
||||
return u0
|
||||
|
||||
super().setUp()
|
||||
|
||||
@ -1,12 +1,22 @@
|
||||
# Owner(s): ["module: custom-operators"]
|
||||
|
||||
import random
|
||||
from contextlib import ExitStack
|
||||
|
||||
import torch
|
||||
from torch._dynamo.test_case import run_tests, TestCase
|
||||
from torch._dynamo.testing import AotEagerAndRecordGraphs
|
||||
from torch._functorch.aot_autograd import (
|
||||
aot_compile_joint_with_descriptors,
|
||||
aot_export_joint_with_descriptors,
|
||||
aot_export_module,
|
||||
)
|
||||
from torch._library.effects import EffectType
|
||||
from torch._library.fake_class_registry import FakeScriptObject
|
||||
from torch._library.opaque_object import register_opaque_type
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
@ -41,11 +51,21 @@ class OpaqueQueue:
|
||||
|
||||
class RNGState:
|
||||
def __init__(self, seed):
|
||||
self.rng = random.Random(seed)
|
||||
self.seed = seed
|
||||
self.rng = random.Random(self.seed)
|
||||
|
||||
|
||||
class Counter:
|
||||
def __init__(self, start):
|
||||
self.counter = torch.tensor(start)
|
||||
|
||||
def increment_counter(self):
|
||||
self.counter += 1
|
||||
|
||||
|
||||
register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
|
||||
register_opaque_type(RNGState, "_TestOpaqueObject_RNGState")
|
||||
register_opaque_type(Counter, "_TestOpaqueObject_Counter")
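For reference, a self-contained sketch of the opaque-object pattern these tests rely on; the `NoiseSource` class and `_Sketch` namespace are illustrative, not from this PR:

```python
import random

import torch
from torch._library.opaque_object import register_opaque_type


class NoiseSource:
    def __init__(self, seed: int) -> None:
        self.rng = random.Random(seed)


# Register the plain Python class as an opaque type so it can appear in
# custom-op schemas and be passed through torch.compile/export.
register_opaque_type(NoiseSource, "_Sketch_NoiseSource")


@torch.library.custom_op("_Sketch::add_noise", mutates_args=())
def add_noise(x: torch.Tensor, src: NoiseSource) -> torch.Tensor:
    return x + src.rng.random()


@add_noise.register_fake
def _(x: torch.Tensor, src: NoiseSource) -> torch.Tensor:
    return torch.empty_like(x)
```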
|
||||
|
||||
|
||||
class TestOpaqueObject(TestCase):
|
||||
@ -125,6 +145,20 @@ class TestOpaqueObject(TestCase):
|
||||
def noisy_inject_fake(x: torch.Tensor, obj: RNGState) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
@torch.library.custom_op(
|
||||
"_TestOpaqueObject::increment_counter",
|
||||
mutates_args=["prev"],
|
||||
)
|
||||
def increment_counter_impl(c: Counter, prev: torch.Tensor) -> torch.Tensor:
|
||||
assert isinstance(c, Counter)
|
||||
prev.copy_(c.counter)
|
||||
c.increment_counter()
|
||||
return c.counter
|
||||
|
||||
@increment_counter_impl.register_fake
|
||||
def increment_counter_fake(c: Counter, prev: torch.Tensor) -> torch.Tensor:
|
||||
return torch.empty_like(prev)
|
||||
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
@ -233,6 +267,235 @@ def forward(self, arg0_1, arg1_1):
|
||||
):
|
||||
make_fx(f, tracing_mode=make_fx_tracing_mode)(RNGState(0), torch.ones(3))
|
||||
|
||||
def test_aot_export(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, rng_state, x):
|
||||
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
|
||||
x = x * x
|
||||
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
|
||||
x = x + x
|
||||
return (x,)
|
||||
|
||||
mod = Model()
|
||||
rng = RNGState(0)
|
||||
x = torch.ones(2, 3)
|
||||
|
||||
fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
|
||||
fake_rng = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, rng)
|
||||
fake_x = fake_mode.from_tensor(x)
|
||||
gm = aot_export_module(mod, (fake_rng, fake_x), trace_joint=False)[0]
|
||||
|
||||
# By default we don't register ops containing PyObjs as being effectful
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1):
|
||||
noisy_inject = torch.ops._TestOpaqueObject.noisy_inject.default(arg1_1, arg0_1); arg1_1 = None
|
||||
mul = torch.ops.aten.mul.Tensor(noisy_inject, noisy_inject); noisy_inject = None
|
||||
noisy_inject_1 = torch.ops._TestOpaqueObject.noisy_inject.default(mul, arg0_1); mul = arg0_1 = None
|
||||
add = torch.ops.aten.add.Tensor(noisy_inject_1, noisy_inject_1); noisy_inject_1 = None
|
||||
return (add,)""", # noqa: B950
|
||||
)
|
||||
|
||||
torch.library._register_effectful_op(
|
||||
"_TestOpaqueObject::noisy_inject", EffectType.ORDERED
|
||||
)
|
||||
try:
|
||||
gm = aot_export_module(mod, (rng, fake_x), trace_joint=False)[0]
|
||||
# inputs: token, rng, x
|
||||
# return: token, res
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops._TestOpaqueObject.noisy_inject.default, arg2_1, arg1_1); arg0_1 = arg2_1 = None
|
||||
getitem = with_effects[0]
|
||||
getitem_1 = with_effects[1]; with_effects = None
|
||||
mul = torch.ops.aten.mul.Tensor(getitem_1, getitem_1); getitem_1 = None
|
||||
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TestOpaqueObject.noisy_inject.default, mul, arg1_1); getitem = mul = arg1_1 = None
|
||||
getitem_2 = with_effects_1[0]
|
||||
getitem_3 = with_effects_1[1]; with_effects_1 = None
|
||||
add = torch.ops.aten.add.Tensor(getitem_3, getitem_3); getitem_3 = None
|
||||
return (getitem_2, add)""", # noqa: B950
|
||||
)
|
||||
finally:
|
||||
torch.library._register_effectful_op(
|
||||
"_TestOpaqueObject::noisy_inject", None
|
||||
)
|
||||
|
||||
def test_compile(self):
|
||||
def foo(rng_state, x):
|
||||
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
|
||||
x = x * x
|
||||
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
|
||||
x = x + x
|
||||
return x
|
||||
|
||||
rng = RNGState(0)
|
||||
x = torch.ones(2, 3)
|
||||
|
||||
res = torch.compile(foo, fullgraph=True, backend="inductor")(rng, x)
|
||||
self.assertFalse(torch.allclose(res, x * x + x))
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
torch.compile(foo, fullgraph=True, backend=backend)(rng, x)
|
||||
self.assertExpectedInline(
|
||||
backend.graphs[0].code.strip(),
|
||||
"""\
|
||||
def forward(self, L_x_ : torch.Tensor, L_rng_state_ : __main___RNGState):
|
||||
l_x_ = L_x_
|
||||
l_rng_state_ = L_rng_state_
|
||||
x = torch.ops._TestOpaqueObject.noisy_inject(l_x_, l_rng_state_); l_x_ = None
|
||||
x_1 = x * x; x = None
|
||||
x_2 = torch.ops._TestOpaqueObject.noisy_inject(x_1, l_rng_state_); x_1 = l_rng_state_ = None
|
||||
x_3 = x_2 + x_2; x_2 = None
|
||||
return (x_3,)""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
backend.fw_graphs[0].code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1):
|
||||
noisy_inject = torch.ops._TestOpaqueObject.noisy_inject.default(arg0_1, arg1_1); arg0_1 = None
|
||||
mul = torch.ops.aten.mul.Tensor(noisy_inject, noisy_inject); noisy_inject = None
|
||||
noisy_inject_1 = torch.ops._TestOpaqueObject.noisy_inject.default(mul, arg1_1); mul = arg1_1 = None
|
||||
add = torch.ops.aten.add.Tensor(noisy_inject_1, noisy_inject_1); noisy_inject_1 = None
|
||||
return (add,)""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_compile_intermediate(self):
|
||||
counter = Counter(0)
|
||||
|
||||
def foo(x, y):
|
||||
z = torch.ops._TestOpaqueObject.increment_counter(counter, y)
|
||||
x = x * z
|
||||
z = torch.ops._TestOpaqueObject.increment_counter(counter, y)
|
||||
x = x + z
|
||||
return x, counter
|
||||
|
||||
inp = (torch.tensor(1), torch.tensor(0))
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_f = torch.compile(foo, fullgraph=True, backend=backend)
|
||||
res = opt_f(*inp)
|
||||
self.assertEqual(res[0], torch.tensor(3))
|
||||
self.assertEqual(res[1].counter, torch.tensor(2))
|
||||
|
||||
res = opt_f(*inp)
|
||||
self.assertEqual(res[0], torch.tensor(7))
|
||||
self.assertEqual(res[1].counter, torch.tensor(4))
|
||||
|
||||
# counter is automatically lifted as an input
|
||||
# Even though we returned counter in the eager code, it does not get
|
||||
# returned in the graph because dynamo does not detect that the object
|
||||
# is mutated.
|
||||
self.assertExpectedInline(
|
||||
backend.fw_graphs[0].code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops._TestOpaqueObject.increment_counter.default, c = arg1_1, _prev_base_index = 0, _all_bases = [arg0_1])
|
||||
getitem = auto_functionalized_v2[0]
|
||||
getitem_1 = auto_functionalized_v2[1]; auto_functionalized_v2 = None
|
||||
mul = torch.ops.aten.mul.Tensor(arg2_1, getitem); arg2_1 = getitem = None
|
||||
auto_functionalized_v2_1 = torch.ops.higher_order.auto_functionalized_v2(torch.ops._TestOpaqueObject.increment_counter.default, c = arg1_1, _prev_base_index = 0, _all_bases = [getitem_1]); arg1_1 = getitem_1 = None
|
||||
getitem_2 = auto_functionalized_v2_1[0]
|
||||
getitem_3 = auto_functionalized_v2_1[1]; auto_functionalized_v2_1 = None
|
||||
add = torch.ops.aten.add.Tensor(mul, getitem_2); mul = getitem_2 = None
|
||||
copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_3); arg0_1 = getitem_3 = copy_ = None
|
||||
return (add,)""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_compile_attribute(self):
|
||||
counter = Counter(0)
|
||||
|
||||
def foo(counter, x):
|
||||
x = x * x
|
||||
counter.increment_counter()
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Attempted to access attributes/methods on an OpaqueObject"
|
||||
):
|
||||
torch.compile(foo)(counter, torch.ones(2, 3))
|
||||
|
||||
def bar(counter, x):
|
||||
x = x * x
|
||||
x += counter.counter
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Attempted to access attributes/methods on an OpaqueObject"
|
||||
):
|
||||
torch.compile(bar)(counter, torch.ones(2, 3))
|
||||
|
||||
def test_export_joint(self):
|
||||
class Moo(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
return x * y
|
||||
|
||||
register_opaque_type(Moo, "_TestOpaqueObject_Moo")
|
||||
|
||||
torch.library.define(
|
||||
"_TestOpaqueObject::module_mul",
|
||||
"(_TestOpaqueObject_Moo a, Tensor b, SymInt c) -> Tensor",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=self.lib,
|
||||
)
|
||||
|
||||
@torch.library.impl(
|
||||
"_TestOpaqueObject::module_mul", "CompositeExplicitAutograd", lib=self.lib
|
||||
)
|
||||
def module_mul_impl(m: Moo, a: torch.Tensor, b: int) -> torch.Tensor:
|
||||
assert isinstance(m, Moo)
|
||||
return m(a, b)
|
||||
|
||||
@torch.library.register_fake("_TestOpaqueObject::module_mul", lib=self.lib)
|
||||
def module_mul_fake(m: Moo, a: torch.Tensor, b: int) -> torch.Tensor:
|
||||
return torch.empty_like(a)
|
||||
|
||||
def module_mul_setup_context(ctx, inputs, output):
|
||||
m, a, b = inputs
|
||||
ctx.b = b
|
||||
|
||||
def module_mul_backward(ctx, grad) -> torch.Tensor:
|
||||
return None, grad * ctx.b, None
|
||||
|
||||
torch.library.register_autograd(
|
||||
"_TestOpaqueObject::module_mul",
|
||||
module_mul_backward,
|
||||
setup_context=module_mul_setup_context,
|
||||
lib=self.lib,
|
||||
)
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.moo = Moo()
|
||||
|
||||
def forward(self, x, y):
|
||||
b = y.item()
|
||||
return torch.ops._TestOpaqueObject.module_mul(self.moo, x, b)
|
||||
|
||||
inp = (torch.randn(3, requires_grad=True), torch.tensor(4))
|
||||
with ExitStack() as stack:
|
||||
with FakeTensorMode(shape_env=ShapeEnv()):
|
||||
joint = aot_export_joint_with_descriptors(stack, M(), inp)
|
||||
self.assertExpectedInline(
|
||||
joint.graph_module.code.strip(),
|
||||
"""\
|
||||
def forward(self, primals, tangents):
|
||||
primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
|
||||
_local_scalar_dense = torch.ops.aten._local_scalar_dense.default(primals_2); primals_2 = None
|
||||
_opaque_obj0 = self._opaque_obj0
|
||||
module_mul = torch.ops._TestOpaqueObject.module_mul.default(_opaque_obj0, primals_1, _local_scalar_dense); _opaque_obj0 = primals_1 = None
|
||||
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, _local_scalar_dense); tangents_1 = _local_scalar_dense = None
|
||||
return pytree.tree_unflatten([module_mul, mul_1, None], self._out_spec)""", # noqa: B950
|
||||
)
|
||||
compiled_fn = aot_compile_joint_with_descriptors(joint)
|
||||
|
||||
self.assertEqual(compiled_fn(*inp), M()(*inp))
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestOpaqueObject)
|
||||
|
||||
|
||||
@ -3657,5 +3657,15 @@
|
||||
"Explanation": "Encountered triton kernel unsupported feature: {msg}",
|
||||
"Hints": []
|
||||
}
|
||||
],
|
||||
"GB0362": [
|
||||
{
|
||||
"Gb_type": "Attempted to access attributes/methods on an OpaqueObject",
|
||||
"Context": "value={self.value}, attr={name}",
|
||||
"Explanation": "Attribute/method access of OpaqueObjects is not supported.",
|
||||
"Hints": [
|
||||
"Use custom operators instead of direct attribute/method access."
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -56,6 +56,7 @@ from torch._guards import (
|
||||
tracing,
|
||||
TracingContext,
|
||||
)
|
||||
from torch._library.opaque_object import is_opaque_type
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
from torch._utils_internal import signpost_event
|
||||
from torch.export.dynamic_shapes import _ConstraintTarget
|
||||
@ -2605,6 +2606,8 @@ class OutputGraph(OutputGraphCommon):
|
||||
fake_attr_val,
|
||||
)
|
||||
continue
|
||||
if is_opaque_type(type(node.meta["grapharg"].example)):
|
||||
continue
|
||||
fake = (
|
||||
arg.fake_tensor if arg.fake_tensor is not None else arg.example
|
||||
)
|
||||
|
||||
@ -58,6 +58,7 @@ from torch._dynamo.utils import (
|
||||
from torch._guards import TracingContext
|
||||
from torch._higher_order_ops.flat_apply import flat_apply
|
||||
from torch._higher_order_ops.torchbind import call_torchbind
|
||||
from torch._library.opaque_object import is_opaque_type
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
|
||||
from torch._subclasses.meta_utils import is_sparse_any, safe_grad
|
||||
@ -1452,27 +1453,32 @@ class VariableBuilder:
|
||||
source=self.source,
|
||||
)
|
||||
|
||||
# This exists to allow a smoother transition.
|
||||
# The implications are:
|
||||
# The script objects won't be tracked as proxies.
|
||||
# Methods on these objects won't show up in the graph.
|
||||
# The original script object might be mutated.
|
||||
if not hasattr(value, "__obj_flatten__"):
|
||||
return self.wrap_user_defined(value)
|
||||
if is_opaque_type(type(value)):
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
|
||||
# Install the guards on the fully qualified name of the script object
|
||||
LazyVariableTracker.realize_all(
|
||||
VariableBuilder(self.tx, ScriptObjectQualifiedNameSource(self.source))(
|
||||
value._type().qualified_name() # type: ignore[attr-defined]
|
||||
elif not hasattr(value, "__obj_flatten__"):
|
||||
# This exists to allow a smoother transition.
|
||||
# The implications are:
|
||||
# The script objects won't be tracked as proxies.
|
||||
# Methods on these objects won't show up in the graph.
|
||||
# The original script object might be mutated.
|
||||
return self.wrap_user_defined(value)
|
||||
else:
|
||||
# Install the guards on the fully qualified name of the script object
|
||||
LazyVariableTracker.realize_all(
|
||||
VariableBuilder(
|
||||
self.tx, ScriptObjectQualifiedNameSource(self.source)
|
||||
)(
|
||||
value._type().qualified_name() # type: ignore[attr-defined]
|
||||
)
|
||||
)
|
||||
)
|
||||
# Install the guards on the content of the script object by setting the source
|
||||
# to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
|
||||
LazyVariableTracker.realize_all(
|
||||
VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
|
||||
value.__obj_flatten__()
|
||||
# Install the guards on the content of the script object by setting the source
|
||||
# to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
|
||||
LazyVariableTracker.realize_all(
|
||||
VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
|
||||
value.__obj_flatten__()
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj(
|
||||
self.tx.output.fake_mode, value
|
||||
|
||||
@ -25,6 +25,7 @@ from typing_extensions import ParamSpec
|
||||
|
||||
import torch
|
||||
from torch._guards import Source
|
||||
from torch._library.opaque_object import is_opaque_type, OpaqueTypeStr
|
||||
from torch.fx.proxy import Proxy
|
||||
|
||||
from .. import graph_break_hints
|
||||
@ -61,7 +62,7 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
|
||||
|
||||
@classmethod
|
||||
def is_matching_cls(cls, user_cls: type) -> bool:
|
||||
return issubclass(user_cls, torch.ScriptObject)
|
||||
return issubclass(user_cls, torch.ScriptObject) or is_opaque_type(user_cls)
|
||||
|
||||
@staticmethod
|
||||
def create(proxy: Proxy, value: Any, **options: Any) -> "TorchScriptObjectVariable":
|
||||
@ -80,6 +81,16 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
|
||||
"Dynamo cannot safely trace script object due to graph break."
|
||||
)
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
||||
if getattr(self.value, "script_class_name", "") == OpaqueTypeStr:
|
||||
unimplemented(
|
||||
gb_type="Attempted to access attributes/methods on an OpaqueObject",
|
||||
context=f"value={self.value}, attr={name}",
|
||||
explanation="Attribute/method access of OpaqueObjects is not supported.",
|
||||
hints=[
|
||||
"Use custom operators instead of direct attribute/method access.",
|
||||
],
|
||||
)
|
||||
|
||||
from torch._higher_order_ops.torchbind import call_torchbind
|
||||
|
||||
from ..source import AttrSource
|
||||
|
||||
@ -24,6 +24,7 @@ from torch._export.passes.lift_constants_pass import ConstantAttrMap
|
||||
from torch._export.utils import _fakify_params_buffers
|
||||
from torch._guards import Source
|
||||
from torch._library.fake_class_registry import FakeScriptObject
|
||||
from torch._library.opaque_object import is_opaque_type
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.export import Constraint
|
||||
from torch.export.dynamic_shapes import (
|
||||
@ -946,7 +947,9 @@ def _fakify_script_objects(
|
||||
|
||||
try:
|
||||
for obj, fqns in constant_attrs.items():
|
||||
if torch._library.fake_class_registry._is_script_object(obj):
|
||||
if torch._library.fake_class_registry._is_script_object(
|
||||
obj
|
||||
) or is_opaque_type(obj):
|
||||
fake_script_obj = _maybe_fakify_obj(obj)
|
||||
for fqn in fqns:
|
||||
cur_mod, attr = _leaf_mod_and_attr(mod, fqn)
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Any, Optional
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._guards import detect_fake_mode
|
||||
from torch._library.opaque_object import is_opaque_type
|
||||
from torch._subclasses import FakeTensor, FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
@ -46,7 +47,7 @@ def process_inputs(
|
||||
hint=x,
|
||||
source=source,
|
||||
)
|
||||
if isinstance(x, torch.ScriptObject):
|
||||
if isinstance(x, torch.ScriptObject) or is_opaque_type(type(x)):
|
||||
return torch._library.fake_class_registry.maybe_to_fake_obj(
|
||||
fake_mode, x
|
||||
)
|
||||
|
||||
@ -534,6 +534,7 @@ def create_aot_state(
|
||||
stack.enter_context(autograd_fallback_mode("error"))
|
||||
|
||||
from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj
|
||||
from torch._library.opaque_object import is_opaque_type
|
||||
|
||||
# Tracing may mutate the state of the fake script object,
|
||||
# so we need to duplicate the fake script objects so that subsequent tracing
|
||||
@ -541,7 +542,7 @@ def create_aot_state(
|
||||
def _dup_fake_script_obj(fake_flat_args):
|
||||
return [
|
||||
maybe_to_fake_obj(detect_fake_mode(fake_flat_args), arg.real_obj)
|
||||
if isinstance(arg, FakeScriptObject)
|
||||
if isinstance(arg, FakeScriptObject) or is_opaque_type(type(arg))
|
||||
else arg
|
||||
for arg in fake_flat_args
|
||||
]
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._C import DispatchKey
|
||||
from torch._higher_order_ops.torchbind import call_torchbind
|
||||
from torch._library.fake_class_registry import FakeScriptObject
|
||||
from torch._library.custom_ops import CustomOpDef
|
||||
from torch._library.effects import EffectType
|
||||
from torch._library.utils import RegistrationHandle
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
@ -17,39 +17,50 @@ from torch.fx.experimental.proxy_tensor import (
|
||||
)
|
||||
|
||||
|
||||
class _EffectType(Enum):
|
||||
ORDERED = "Ordered"
|
||||
_op_identifier = Union[
|
||||
str,
|
||||
"torch._ops.OpOverload",
|
||||
"torch._library.custom_ops.CustomOpDef",
|
||||
"torch._ops.HigherOrderOperator",
|
||||
]
|
||||
OpType = Union["torch._ops.HigherOrderOperator", "torch._ops.OpOverload"]
|
||||
|
||||
_EffectType = EffectType
|
||||
|
||||
|
||||
OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload]
|
||||
def _get_op_qualname(op: _op_identifier) -> str:
|
||||
"""Convert an op identifier to a qualified string key."""
|
||||
if isinstance(op, torch._ops.OpOverload):
|
||||
return op._name
|
||||
elif isinstance(op, torch._ops.HigherOrderOperator):
|
||||
return f"{op.namespace}::{op.name()}"
|
||||
elif isinstance(op, CustomOpDef):
|
||||
return op._qualname
|
||||
elif isinstance(op, str):
|
||||
return op
|
||||
|
||||
raise ValueError(f"Invalid operator input {op}")
|
||||
|
||||
|
||||
SIDE_EFFECTS = WeakKeyDictionary[OpType, _EffectType](
|
||||
[
|
||||
(torch.ops.aten._print.default, _EffectType.ORDERED),
|
||||
(torch.ops.aten._async_error.default, _EffectType.ORDERED),
|
||||
(call_torchbind, _EffectType.ORDERED),
|
||||
]
|
||||
)
|
||||
def _register_effectful_op(
|
||||
op: _op_identifier, effect: Optional[EffectType]
|
||||
) -> RegistrationHandle:
|
||||
qualname = _get_op_qualname(op)
|
||||
entry = torch._library.simple_registry.singleton.find(qualname)
|
||||
handle = entry.effect.register(effect)
|
||||
return handle
|
||||
|
||||
|
||||
def _register_effectful_op(op: OpType, effect: _EffectType):
|
||||
assert isinstance(
|
||||
op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
|
||||
) and not has_aliasing(op)
|
||||
if op in SIDE_EFFECTS and SIDE_EFFECTS[op] != effect:
|
||||
raise RuntimeError(
|
||||
f"Already registered effect type {SIDE_EFFECTS[op]} to op {op}, "
|
||||
f"trying to register a different effect type {effect}."
|
||||
)
|
||||
SIDE_EFFECTS[op] = effect
|
||||
def _get_effect(op: _op_identifier) -> Optional[_EffectType]:
|
||||
qualname = _get_op_qualname(op)
|
||||
entry = torch._library.simple_registry.singleton.find(qualname)
|
||||
return entry.effect.effect
|
||||
|
||||
|
||||
def _deregister_effectful_op(op: OpType):
|
||||
if op not in SIDE_EFFECTS:
|
||||
raise RuntimeError(f"Op {op} is not registered as effectful")
|
||||
|
||||
del SIDE_EFFECTS[op]
|
||||
_register_effectful_op("aten::_print", _EffectType.ORDERED)
|
||||
_register_effectful_op("aten::_async_error", _EffectType.ORDERED)
|
||||
_register_effectful_op("profiler::_record_function_exit._RecordFunction", None)
|
||||
_register_effectful_op(call_torchbind, _EffectType.ORDERED)
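Illustration (not in the diff) of how the different identifier kinds resolve against the registry populated by the registrations above:

```python
import torch
from torch._higher_order_ops.effects import _EffectType, _get_effect

# String qualnames and OpOverloads resolve to the same registry entry.
assert _get_effect("aten::_print") is _EffectType.ORDERED
assert _get_effect(torch.ops.aten._print.default) is _EffectType.ORDERED

# Registering None (as done for the profiler exit op above) marks an op as
# explicitly effect-free.
assert _get_effect(torch.ops.profiler._record_function_exit._RecordFunction) is None
```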
|
||||
|
||||
|
||||
class WithEffects(HigherOrderOperator):
|
||||
@ -78,7 +89,7 @@ class WithEffects(HigherOrderOperator):
|
||||
) -> tuple[Any, ...]:
|
||||
assert isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload))
|
||||
assert not has_aliasing(op), "Ops with aliasing is not supported"
|
||||
assert has_effects(op, args, kwargs)
|
||||
assert has_effects(op)
|
||||
assert isinstance(kwargs, dict)
|
||||
return super().__call__(token, op, *args, **kwargs)
|
||||
|
||||
@ -89,7 +100,7 @@ with_effects = WithEffects()
|
||||
def has_aliasing(op: OpType):
|
||||
# NOT FOR PUBLIC USE
|
||||
if isinstance(op, torch._ops.HigherOrderOperator):
|
||||
return op not in SIDE_EFFECTS
|
||||
return not _get_effect(op)
|
||||
|
||||
for arg in op._schema.arguments:
|
||||
if arg.alias_info is not None:
|
||||
@ -100,7 +111,7 @@ def has_aliasing(op: OpType):
|
||||
return False
|
||||
|
||||
|
||||
def has_effects(op, args, kwargs) -> bool:
|
||||
def has_effects(op) -> bool:
|
||||
# Skip over the profiler's RecordFunction as they should not show up in the graph
|
||||
_skip_ops = {torch.ops.profiler._record_function_exit._RecordFunction}
|
||||
if op in _skip_ops:
|
||||
@ -109,31 +120,10 @@ def has_effects(op, args, kwargs) -> bool:
|
||||
return (
|
||||
isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload))
|
||||
and not has_aliasing(op)
|
||||
and get_effect_key(op, args, kwargs) is not None
|
||||
and _get_effect(op) is not None
|
||||
)
|
||||
|
||||
|
||||
def get_effect_key(op, args, kwargs) -> Optional[_EffectType]:
|
||||
if op in SIDE_EFFECTS:
|
||||
return SIDE_EFFECTS[op]
|
||||
|
||||
for arg in args:
|
||||
if isinstance(arg, (torch.ScriptObject, FakeScriptObject)):
|
||||
# Add it to the table so that next time we see the same op we don't
|
||||
# have to parse through the args again
|
||||
SIDE_EFFECTS[op] = _EffectType.ORDERED
|
||||
return _EffectType.ORDERED
|
||||
|
||||
for arg in kwargs.values():
|
||||
if isinstance(arg, (torch.ScriptObject, FakeScriptObject)):
|
||||
# Add it to the table so that next time we see the same op we don't
|
||||
# have to parse through the args again
|
||||
SIDE_EFFECTS[op] = _EffectType.ORDERED
|
||||
return _EffectType.ORDERED
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def new_token_tensor() -> torch.Tensor:
|
||||
return torch.tensor([])
|
||||
|
||||
@ -238,7 +228,7 @@ def handle_effects(
|
||||
# Get a token. We can't do `tokens.get(op, torch.tensor([]))` because
|
||||
# this will create an empty tensor during proxy mode tracing if the token
|
||||
# doesn't exist. But the tokens should always exist during proxy mode tracing.
|
||||
key = get_effect_key(op, args, kwargs)
|
||||
key = _get_effect(op)
|
||||
assert key is not None
|
||||
if key not in tokens:
|
||||
assert allow_token_discovery, (
|
||||
|
||||
@ -2122,6 +2122,10 @@ class PythonWrapperCodegen(CodeGen):
|
||||
output.writeline(f"{name} = {val}")
|
||||
|
||||
def add_torchbind_input(name, value):
|
||||
if value is None:
|
||||
output.writeline(f"{name} = None")
|
||||
return
|
||||
|
||||
import pickle
|
||||
|
||||
assert isinstance(value, torch.ScriptObject)
|
||||
|
||||
@ -91,6 +91,7 @@ from torch._inductor.utils import (
|
||||
tensor_is_aligned,
|
||||
)
|
||||
from torch._library.fake_class_registry import FakeScriptObject
|
||||
from torch._library.opaque_object import is_opaque_type
|
||||
from torch._logging import trace_structured
|
||||
from torch._utils_internal import compile_time_strobelight_meta
|
||||
from torch.fx import GraphModule
|
||||
@ -2747,7 +2748,9 @@ def _compile_fx_main(
|
||||
node.meta["val"] = fake_mode.from_tensor(
|
||||
target, static_shapes=True
|
||||
)
|
||||
elif isinstance(target, torch.ScriptObject):
|
||||
elif isinstance(target, torch.ScriptObject) or is_opaque_type(
|
||||
type(target)
|
||||
):
|
||||
node.meta["val"] = (
|
||||
torch._library.fake_class_registry.maybe_to_fake_obj(
|
||||
fake_mode, target
|
||||
|
||||
@ -883,11 +883,12 @@ def _get_optimization_cflags(
|
||||
|
||||
should_use_optimized_flags = not (
|
||||
config.aot_inductor.debug_compile
|
||||
or os.environ.get("TORCHINDUCTOR_DEBUG_SYMBOL", "0") == "1"
|
||||
or os.environ.get("TORCHINDUCTOR_DEBUG_COMPILE", "0") == "1"
|
||||
)
|
||||
should_add_debug_symbol_flags = (
|
||||
config.aot_inductor.debug_compile
|
||||
or config.aot_inductor.debug_symbols
|
||||
or os.environ.get("TORCHINDUCTOR_DEBUG_COMPILE", "0") == "1"
|
||||
or os.environ.get("TORCHINDUCTOR_DEBUG_SYMBOL", "0") == "1"
|
||||
)
|
||||
if should_use_optimized_flags:
|
||||
|
||||
@ -9242,12 +9242,9 @@ class EffectfulKernel(FallbackKernel):
|
||||
unbacked_bindings=unbacked_bindings,
|
||||
)
|
||||
|
||||
from torch._higher_order_ops.effects import get_effect_key
|
||||
from torch._higher_order_ops.effects import _get_effect
|
||||
|
||||
uncovered_args = [
|
||||
a.value if isinstance(a, TorchBindObject) else a for a in tensor_args
|
||||
]
|
||||
effect_type = get_effect_key(kernel, (*nontensor_args, *uncovered_args), kwargs)
|
||||
effect_type = _get_effect(kernel)
|
||||
assert effect_type is not None
|
||||
self.effect_type = effect_type
|
||||
self.prev_effect_buffer = V.graph.effectful_ops.get(effect_type, None)
|
||||
@ -9298,6 +9295,10 @@ class TorchBindObject(NonTensorObj):
|
||||
def get_buf_bytes(self) -> int:
|
||||
# Returns the sum of all tensors in the flattened object
|
||||
real_script_obj = self.get_real_obj()
|
||||
|
||||
if real_script_obj is None:
|
||||
return 0
|
||||
|
||||
assert hasattr(real_script_obj, "__obj_flatten__")
|
||||
flat_dict = dict(real_script_obj.__obj_flatten__())
|
||||
flat_elems = pytree.tree_flatten(flat_dict)[0]
|
||||
|
||||
@ -26,6 +26,7 @@ import torch.utils._pytree as pytree
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._higher_order_ops.associative_scan import associative_scan_op
|
||||
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
|
||||
from torch._library.fake_class_registry import FakeScriptObject
|
||||
from torch._library.utils import get_layout_constraint_tag
|
||||
from torch._prims_common import ( # pyrefly: ignore # deprecated; pyrefly: ignore [deprecated]
|
||||
canonicalize_dim,
|
||||
@ -2704,6 +2705,8 @@ def require_channels_last(_, *args, **kwargs):
|
||||
|
||||
|
||||
def constrain_to_fake_tensor(arg, fake_arg):
|
||||
if isinstance(fake_arg, FakeScriptObject):
|
||||
return arg
|
||||
if isinstance(arg, ir.IRNode):
|
||||
meta_stride_expr = [
|
||||
s.node.expr if isinstance(s, torch.SymInt) else s for s in fake_arg.stride()
|
||||
@ -7453,9 +7456,9 @@ def _sink_tokens(tokens):
|
||||
def with_effects(token, op, *args, **kwargs):
|
||||
result = ir.EffectfulKernel.create(op, *args, **kwargs)
|
||||
|
||||
from torch._higher_order_ops.effects import get_effect_key
|
||||
from torch._higher_order_ops.effects import _get_effect
|
||||
|
||||
effect_type = get_effect_key(op, args, kwargs)
|
||||
effect_type = _get_effect(op)
|
||||
assert effect_type is not None
|
||||
effectful_kernel = V.graph.effectful_ops[effect_type]
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from torch.types import _dtype
|
||||
from torch.utils._exposed_in import exposed_in
|
||||
|
||||
from . import autograd, utils
|
||||
from .effects import EffectType
|
||||
|
||||
|
||||
device_types_t = Optional[Union[str, Sequence[str]]]
|
||||
@ -471,6 +472,9 @@ class CustomOpDef:
|
||||
self._abstract_fn = fn
|
||||
return fn
|
||||
|
||||
def register_effect(self, effect: Optional[EffectType]) -> None:
|
||||
self._lib._register_effectful_op(self._qualname, effect)
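A short usage sketch of the new `register_effect` method (op and names are illustrative, not from this diff):

```python
import torch
from torch._library.effects import EffectType


@torch.library.custom_op("mylib::log_norm", mutates_args=())
def log_norm(x: torch.Tensor) -> None:
    print(float(x.norm()))


@log_norm.register_fake
def _(x: torch.Tensor) -> None:
    return None


# Mark the op as ordered-effectful so torch.compile keeps it and preserves
# its ordering relative to other ordered-effect ops.
log_norm.register_effect(EffectType.ORDERED)
```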
|
||||
|
||||
def register_torch_dispatch(
|
||||
self, torch_dispatch_class: Any, fn: Optional[Callable] = None, /
|
||||
) -> Callable:
|
||||
|
||||
torch/_library/effects.py (new file, 68 lines)
@ -0,0 +1,68 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class EffectType(Enum):
|
||||
ORDERED = "Ordered"
|
||||
|
||||
|
||||
from torch._library.utils import RegistrationHandle
|
||||
|
||||
|
||||
class EffectHolder:
|
||||
"""A holder where one can register an effect impl to."""
|
||||
|
||||
def __init__(self, qualname: str):
|
||||
self.qualname: str = qualname
|
||||
self._set_default_effect()
|
||||
|
||||
def _set_default_effect(self) -> None:
|
||||
self._effect: Optional[EffectType] = None
|
||||
|
||||
# If the op contains a ScriptObject input, we want to mark it as having effects
|
||||
namespace, opname = torch._library.utils.parse_namespace(self.qualname)
|
||||
split = opname.split(".")
|
||||
if len(split) > 1:
|
||||
assert len(split) == 2, (
|
||||
f"Tried to split {opname} based on '.' but found more than 1 '.'"
|
||||
)
|
||||
opname, overload = split
|
||||
else:
|
||||
overload = ""
|
||||
|
||||
if namespace == "higher_order":
|
||||
return
|
||||
|
||||
opname = f"{namespace}::{opname}"
|
||||
if torch._C._get_operation_overload(opname, overload) is not None:
|
||||
# Since we call this when destroying the library, sometimes the
|
||||
# schema will be gone already at that time.
|
||||
schema = torch._C._get_schema(opname, overload)
|
||||
for arg in schema.arguments:
|
||||
if isinstance(arg.type, torch.ClassType):
|
||||
self._effect = EffectType.ORDERED
|
||||
return
|
||||
|
||||
@property
|
||||
def effect(self) -> Optional[EffectType]:
|
||||
return self._effect
|
||||
|
||||
@effect.setter
|
||||
def effect(self, _):
|
||||
raise RuntimeError("Unable to directly set kernel.")
|
||||
|
||||
def register(self, effect: Optional[EffectType]) -> RegistrationHandle:
|
||||
"""Register an effect
|
||||
|
||||
Returns a RegistrationHandle that one can use to de-register this
|
||||
effect.
|
||||
"""
|
||||
self._effect = effect
|
||||
|
||||
def deregister_effect():
|
||||
self._set_default_effect()
|
||||
|
||||
handle = RegistrationHandle(deregister_effect)
|
||||
return handle
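A sketch of how this holder is reached through the simple registry, mirroring `_register_effectful_op` in `effects.py`; `aten::sin` is just a convenient stand-in op with no default effect:

```python
from torch._library.effects import EffectType
from torch._library.simple_registry import singleton

holder = singleton.find("aten::sin").effect    # EffectHolder for this qualname
handle = holder.register(EffectType.ORDERED)   # mark as ordered-effectful
assert holder.effect is EffectType.ORDERED

handle.destroy()                               # deregister -> back to the default
assert holder.effect is None                   # aten::sin has no ScriptObject inputs
```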
|
||||
@ -1,6 +1,7 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
from .effects import EffectHolder
|
||||
from .fake_impl import FakeImplHolder
|
||||
from .utils import RegistrationHandle
|
||||
|
||||
@ -51,6 +52,8 @@ class SimpleOperatorEntry:
|
||||
GenericTorchDispatchRuleHolder(qualname)
|
||||
)
|
||||
|
||||
self.effect: EffectHolder = EffectHolder(qualname)
|
||||
|
||||
# For compatibility reasons. We can delete this soon.
|
||||
@property
|
||||
def abstract_impl(self) -> FakeImplHolder:
|
||||
|
||||
@ -1023,6 +1023,7 @@ class TorchBindOpOverload(OpOverload[_P, _T]):
|
||||
DispatchKey.BackendSelect,
|
||||
DispatchKey.PythonTLSSnapshot,
|
||||
DispatchKey.PythonDispatcher,
|
||||
DispatchKey.Functionalize,
|
||||
]
|
||||
|
||||
def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
|
||||
@ -1046,17 +1047,23 @@ class TorchBindOpOverload(OpOverload[_P, _T]):
|
||||
def _register_as_effectful_op_temporarily(self):
|
||||
from torch._higher_order_ops.effects import (
|
||||
_EffectType,
|
||||
_get_effect,
|
||||
_register_effectful_op,
|
||||
SIDE_EFFECTS,
|
||||
)
|
||||
|
||||
try:
|
||||
if self not in SIDE_EFFECTS:
|
||||
_register_effectful_op(self, _EffectType.ORDERED)
|
||||
# We don't want to register the effect if there already exists a
|
||||
# registration, especially if the registration is None (explicitly
|
||||
# no effect)
|
||||
register_tmp_effect = _get_effect(self) is None
|
||||
handle = None
|
||||
if register_tmp_effect:
|
||||
handle = _register_effectful_op(self, _EffectType.ORDERED)
|
||||
yield
|
||||
finally:
|
||||
if self in SIDE_EFFECTS:
|
||||
del SIDE_EFFECTS[self]
|
||||
if register_tmp_effect:
|
||||
assert handle is not None
|
||||
handle.destroy()
|
||||
|
||||
# Use positional-only argument to avoid naming collision with aten ops arguments
|
||||
# that are named "self". This way, all the aten ops can be called by kwargs.
|
||||
|
||||
@ -11,7 +11,7 @@ import torch
|
||||
import torch.fx.traceback as fx_traceback
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._C import _functionalization_reapply_views_tls as _reapply_views
|
||||
from torch._ops import _get_dispatch_mode_pre_dispatch
|
||||
from torch._ops import _get_dispatch_mode_pre_dispatch, TorchBindOpOverload
|
||||
from torch._subclasses.meta_utils import is_sparse_any
|
||||
from torch.utils._python_dispatch import (
|
||||
_detect_infra_mode,
|
||||
@ -471,7 +471,7 @@ class FunctionalTensorMode(TorchDispatchMode):
|
||||
|
||||
from torch._higher_order_ops.effects import handle_effects, has_effects
|
||||
|
||||
if has_effects(func, args, kwargs):
|
||||
if has_effects(func):
|
||||
assert not torch._C._dispatch_has_kernel_for_dispatch_key(
|
||||
func.name(), torch._C.DispatchKey.Functionalize
|
||||
)
|
||||
@ -504,65 +504,81 @@ class FunctionalTensorMode(TorchDispatchMode):
|
||||
- FunctionalTensor._extra_dispatch_keys
|
||||
)
|
||||
|
||||
# All we want to do here is reuse the existing C++ functionalization logic.
|
||||
# This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
|
||||
with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
|
||||
try:
|
||||
# By default for python functionalization (for AOTAutograd), we reapply views.
|
||||
old_apply_views = torch._functionalize_enable_reapply_views(True) # type: ignore[attr-defined]
|
||||
if isinstance(func, TorchBindOpOverload):
|
||||
# When the function is a TorchBindOpOverload, meaning some of the
|
||||
# inputs are FakeScriptObjects, we need to skip c++ dispatcher and
|
||||
# dispatch in python because C++ dispatcher will check the schema
|
||||
# and cannot recognize FakeScriptObject.
|
||||
ctx = PythonFunctionalizeAPI()
|
||||
fully_unwrapped_args = ctx.unwrap_tensors(args)
|
||||
fully_unwrapped_kwargs = ctx.unwrap_tensors(
|
||||
kwargs # pyrefly: ignore[bad-argument-type]
|
||||
)
|
||||
outs_unwrapped = func(
|
||||
*fully_unwrapped_args,
|
||||
**fully_unwrapped_kwargs,
|
||||
)
|
||||
outs_wrapped = ctx.wrap_tensors(outs_unwrapped)
|
||||
else:
|
||||
# All we want to do here is reuse the existing C++ functionalization logic.
|
||||
# This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
|
||||
with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
|
||||
try:
|
||||
# By default for python functionalization (for AOTAutograd), we reapply views.
|
||||
old_apply_views = torch._functionalize_enable_reapply_views(True) # type: ignore[attr-defined]
|
||||
|
||||
# Sometimes these functions cannot be directly dispatched to functionalize key
|
||||
# because args are sometimes not functional tensors for some reason?
|
||||
if func in FunctionalTensor.metadata_fns:
|
||||
outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
|
||||
outs_wrapped = pytree.tree_map_only(
|
||||
torch.Tensor, wrap, outs_unwrapped
|
||||
)
|
||||
else:
|
||||
# Note: [Functionalization View Replay Annotation]
|
||||
# When functionalization encounters a mutation, it handles aliases by lazily regenerating the aliases
|
||||
# at the first time they are next used.
|
||||
# This is a problem when plumbing user annotations during tracing. We want the view ops from view replay
|
||||
# to have the same annotation that the user specified on the original views. But view replay in
|
||||
# functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)),
|
||||
# so when we regenerate views before calling into second_op, those views will end up getting the metadata
|
||||
# for second_op!
|
||||
#
|
||||
# Instead, we need to remember the node metadata from the original views, and ensure that this node metadata
|
||||
# is globally set when we lazily perform view replay.
|
||||
# The globally set metadata will be used to populate the fx node created for the replayed operation.
|
||||
if m := torch._C._get_dispatch_mode(
|
||||
torch._C._TorchDispatchModeKey.PROXY
|
||||
):
|
||||
for a in pytree.tree_leaves([args, kwargs]):
|
||||
if not isinstance(a, FunctionalTensor):
|
||||
continue
|
||||
curr_node = m.tracer.tensor_tracker[
|
||||
torch._from_functional_tensor(a.elem)
|
||||
].proxy.node
|
||||
with fx_traceback.set_current_replay_node(curr_node):
|
||||
torch._sync(a)
|
||||
# Sometimes these functions cannot be directly dispatched to functionalize key
|
||||
# because args are sometimes not functional tensors for some reason?
|
||||
if func in FunctionalTensor.metadata_fns:
|
||||
outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
|
||||
outs_wrapped = pytree.tree_map_only(
|
||||
torch.Tensor, wrap, outs_unwrapped
|
||||
)
|
||||
else:
|
||||
# Note: [Functionalization View Replay Annotation]
|
||||
# When functionalization encounters a mutation, it handles aliases by lazily regenerating the aliases
|
||||
# at the first time they are next used.
|
||||
# This is a problem when plumbing user annotations during tracing. We want the view ops from view replay
|
||||
# to have the same annotation that the user specified on the original views. But view replay in
|
||||
# functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)),
|
||||
# so when we regenerate views before calling into second_op, those views will end up getting the metadata
|
||||
# for second_op!
|
||||
#
|
||||
# Instead, we need to remember the node metadata from the original views, and ensure that this node metadata
|
||||
# is globally set when we lazily perform view replay.
|
||||
# The globally set metadata will be used to populate the fx node created for the replayed operation.
|
||||
if m := torch._C._get_dispatch_mode(
|
||||
torch._C._TorchDispatchModeKey.PROXY
|
||||
):
|
||||
for a in pytree.tree_leaves([args, kwargs]):
|
||||
if not isinstance(a, FunctionalTensor):
|
||||
continue
|
||||
curr_node = m.tracer.tensor_tracker[
|
||||
torch._from_functional_tensor(a.elem)
|
||||
].proxy.node
|
||||
with fx_traceback.set_current_replay_node(curr_node):
|
||||
torch._sync(a)
|
||||
|
||||
# When we dispatch to the C++ functionalization kernel, we might need to jump back to the
|
||||
# PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
|
||||
# FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
|
||||
# from the TLS in order to avoid infinite looping, but this would prevent us from coming
|
||||
# back to PreDispatch later
|
||||
outs_unwrapped = func._op_dk(
|
||||
torch._C.DispatchKey.Functionalize,
|
||||
*args_unwrapped,
|
||||
**kwargs_unwrapped,
|
||||
)
|
||||
# When we dispatch to the C++ functionalization kernel, we might need to jump back to the
|
||||
# PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
|
||||
# FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
|
||||
# from the TLS in order to avoid infinite looping, but this would prevent us from coming
|
||||
# back to PreDispatch later
|
||||
outs_unwrapped = func._op_dk(
|
||||
torch._C.DispatchKey.Functionalize,
|
||||
*args_unwrapped,
|
||||
**kwargs_unwrapped,
|
||||
)
|
||||
|
||||
if self.export:
|
||||
if func is torch.ops.aten.dropout.default:
|
||||
torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined]
|
||||
outs_wrapped = pytree.tree_map_only(
|
||||
torch.Tensor, wrap, outs_unwrapped
|
||||
)
|
||||
finally:
|
||||
torch._disable_functionalization()
|
||||
torch._functionalize_enable_reapply_views(old_apply_views) # type: ignore[attr-defined]
|
||||
if self.export:
|
||||
if func is torch.ops.aten.dropout.default:
|
||||
torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined]
|
||||
outs_wrapped = pytree.tree_map_only(
|
||||
torch.Tensor, wrap, outs_unwrapped
|
||||
)
|
||||
finally:
|
||||
torch._disable_functionalization()
|
||||
torch._functionalize_enable_reapply_views(old_apply_views) # type: ignore[attr-defined]
|
||||
|
||||
is_included = torch._C._dispatch_tls_is_dispatch_key_included(
|
||||
torch._C.DispatchKey.Functionalize
|
||||
|
||||
@ -18,6 +18,7 @@ import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._C import ScriptObject # type: ignore[attr-defined]
|
||||
from torch._library.fake_class_registry import FakeScriptObject
|
||||
from torch._library.opaque_object import is_opaque_type
|
||||
|
||||
from ._compatibility import compatibility
|
||||
from ._lazy_graph_module import _make_graph_module
|
||||
@ -421,8 +422,10 @@ class Tracer(TracerBase):
|
||||
# a get_attr to retrieve that tensor. Otherwise, we'll store away the
|
||||
# tensor value into a special attribute on the Module s.t. we can
|
||||
# retrieve it with a get_attr.
|
||||
if isinstance(a, _constant_attribute_types):
|
||||
qualname: Optional[str] = self.tensor_attrs.get(a)
|
||||
if isinstance(a, _constant_attribute_types) or is_opaque_type(type(a)):
|
||||
qualname: Optional[str] = self.tensor_attrs.get(
|
||||
a
|
||||
) # pyrefly: ignore[no-matching-overload]
|
||||
|
||||
# Tensor was not found in the Module hierarchy, stow it away in a
|
||||
# special attribute and set the qualname to refer to that
|
||||
@ -433,13 +436,17 @@ class Tracer(TracerBase):
|
||||
base_name = "_torchbind_obj"
|
||||
elif isinstance(a, pytree.TreeSpec):
|
||||
base_name = "_tree_spec_constant"
|
||||
elif is_opaque_type(type(a)):
|
||||
base_name = "_opaque_obj"
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"cannot create constant arg for {a} of type {type(a)}."
|
||||
)
|
||||
qualname = self.get_fresh_qualname(base_name)
|
||||
assert isinstance(qualname, str)
|
||||
self.tensor_attrs[a] = qualname
|
||||
self.tensor_attrs[a] = ( # pyrefly: ignore[unsupported-operation]
|
||||
qualname
|
||||
)
|
||||
setattr(self.root, qualname, a)
|
||||
|
||||
return self.create_node("get_attr", qualname, (), {})
|
||||
|
||||
@ -19,6 +19,7 @@ from torch._library.custom_ops import (
|
||||
CustomOpDef,
|
||||
device_types_t,
|
||||
)
|
||||
from torch._library.effects import EffectType
|
||||
from torch._library.infer_schema import infer_schema # noqa: F401
|
||||
from torch._library.triton import triton_op, wrap_triton
|
||||
from torch._ops import OpOverload
|
||||
@ -398,6 +399,22 @@ class Library:
|
||||
|
||||
self.m.fallback(dispatch_key, fn, with_keyset)
|
||||
|
||||
def _register_effectful_op(self, op_name: str, effect: Optional[EffectType]):
|
||||
"""
|
||||
Registers an effect to an operator. This is used to register an op that
|
||||
has side effects that are not capturable by the schema.
|
||||
|
||||
Args:
|
||||
op_name: operator name (along with the overload) or OpOverload object.
|
||||
effect: The effect of the op.
|
||||
"""
|
||||
from torch._higher_order_ops.effects import (
|
||||
_register_effectful_op as hoo_register_effect,
|
||||
)
|
||||
|
||||
handle = hoo_register_effect(op_name, effect)
|
||||
self._registration_handles.append(handle)
|
||||
|
||||
def _destroy(self):
|
||||
if self.m is not None:
|
||||
self.m.reset()
|
||||
@ -1065,6 +1082,44 @@ def register_fake(
|
||||
return register(func)
|
||||
|
||||
|
||||
def _register_effectful_op(
|
||||
op: _op_identifier,
|
||||
effect: Optional[EffectType],
|
||||
*,
|
||||
lib: Optional[Library] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
To specify that an operator has side-effects, we must register an effect
|
||||
type for the operator. This will prevent graph passes in torch.compile from
|
||||
reordering operations with the same effect type.
|
||||
|
||||
Args:
|
||||
op: Operator name (along with the overload), OpOverload, or CustomOpDef object.
|
||||
effect: Effect type to register. None means the operator is not effectful.
|
||||
"""
|
||||
if not isinstance(
|
||||
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
|
||||
):
|
||||
raise ValueError(
|
||||
f"register_effectful_op({op}): got unexpected type for op: {type(op)}"
|
||||
)
|
||||
|
||||
if isinstance(op, torch._ops.OpOverload):
|
||||
op = op._name
|
||||
opdef = _maybe_get_opdef(op)
|
||||
if opdef is not None:
|
||||
opdef.register_effect(effect)
|
||||
assert isinstance(op, str)
|
||||
|
||||
namespace, _ = torch._library.utils.parse_namespace(op)
|
||||
if lib is None:
|
||||
use_lib = Library(namespace, "FRAGMENT")
|
||||
_keep_alive.append(use_lib)
|
||||
else:
|
||||
use_lib = lib
|
||||
use_lib._register_effectful_op(op, effect)
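A self-contained usage sketch of this wrapper; `aten::sin` is only a stand-in for whichever op you want to mark, and the FRAGMENT-library path shown above is assumed to apply:

```python
import torch
from torch._higher_order_ops.effects import _get_effect
from torch._library.effects import EffectType

# Register by qualified name; a FRAGMENT library for the namespace is created
# internally and kept alive.
torch.library._register_effectful_op("aten::sin", EffectType.ORDERED)
assert _get_effect(torch.ops.aten.sin.default) is EffectType.ORDERED

# Passing None marks the op as explicitly effect-free again.
torch.library._register_effectful_op("aten::sin", None)
assert _get_effect(torch.ops.aten.sin.default) is None
```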
|
||||
|
||||
|
||||
def register_autograd(
|
||||
op: _op_identifier,
|
||||
backward: Callable,
|
||||
|
||||
@ -37,7 +37,7 @@ import functools
|
||||
import traceback
|
||||
import weakref
|
||||
from collections.abc import Callable
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing import Any, Optional, TYPE_CHECKING # noqa: F401
|
||||
|
||||
import torch
|
||||
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
|
||||
|
||||