Compare commits

...

3 Commits

7d88ccc110 [Dynamo] Support the torch._C.DisableTorchFunction ctx manager
ghstack-source-id: c7853c71e9f8187ba7ccbe02c24940336aa6453e
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149491
2025-03-19 15:52:23 -07:00
314405f223 [Dynamo] add support for torch._C._is_torch_function_all_disabled
ghstack-source-id: 8361faa2ddf79a29b93b461a9b94b0e34004f101
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149490
2025-03-19 12:46:24 -07:00
bbe4c6e034 [Dynamo] Refactor DisableTorchFunction ctx manager
ghstack-source-id: f2913a8654380203897ac1d69821527e430be98a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149489
2025-03-19 12:46:24 -07:00
5 changed files with 129 additions and 15 deletions
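Taken together, these three commits let Dynamo trace the torch._C.DisableTorchFunction context manager and evaluate torch._C._is_torch_function_all_disabled at trace time instead of graph-breaking on them. Below is a minimal sketch of the user-visible behavior, adapted from the tests added in this stack; it assumes a PyTorch build that contains these commits, and the "eager" backend and input shape are arbitrary choices for the example:

    import torch

    @torch.compile(backend="eager")
    def fn(x):
        with torch._C.DisableTorchFunction():
            # While the context is active, all __torch_function__ handling
            # (subclasses and modes) is treated as disabled.
            disabled_inside = torch._C._is_torch_function_all_disabled()
        # On exit, Dynamo restores the previous torch-function state.
        disabled_outside = torch._C._is_torch_function_all_disabled()
        return disabled_inside, disabled_outside, torch.add(x, 1.0)

    inside, outside, _ = fn(torch.ones(2, 2))
    assert inside and not outside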

View File

@@ -207,6 +207,18 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(_len_torch_function_stack(), 0)
 
+    def test_is_torch_function_all_disabled(self):
+        @torch.compile(fullgraph=True)
+        def fn(x):
+            return (
+                torch._C._is_torch_function_all_disabled(),
+                torch.add(x, 1.0),
+            )
+
+        input = torch.ones(2, 2)
+        res, _ = fn(input)
+        self.assertFalse(res)
+
     def test_error_empty_stack_pop_torch_function_mode(self):
         @torch.compile(fullgraph=True)
         def fn(x):

View File

@@ -522,6 +522,57 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
         res, _ = fn(input)
         self.assertFalse(res)
 
+    def test_disable_all_torch_function(self):
+        @torch.compile(backend="eager")
+        def fn(x):
+            with torch._C.DisableTorchFunction():
+                torch._dynamo.graph_break()
+                return (
+                    torch._C._is_torch_function_enabled(),
+                    torch._C._is_torch_function_all_disabled(),
+                    torch.add(x, 1.0),
+                )
+
+        input = torch.ones(2, 2)
+        res1, res2, _ = fn(input)
+        self.assertFalse(res1)
+        self.assertTrue(res2)
+
+    def test_disable_all_torch_function_restore_values(self):
+        @torch.compile(backend="eager")
+        def fn(x):
+            with torch._C.DisableTorchFunction():
+                x = torch._C._is_torch_function_all_disabled()
+            return (
+                x,
+                torch._C._is_torch_function_all_disabled(),
+                torch.add(x, 1.0),
+            )
+
+        input = torch.ones(2, 2)
+        res1, res2, _ = fn(input)
+        self.assertTrue(res1)
+        self.assertFalse(res2)
+
+    def test_disable_all_torch_function_restore_values_graph_break(self):
+        @torch.compile(backend="eager")
+        def fn(x):
+            with torch._C.DisableTorchFunction():
+                torch._dynamo.graph_break()
+                x = torch._C._is_torch_function_all_disabled()
+            return (
+                x,
+                torch._C._is_torch_function_all_disabled(),
+                torch.add(x, 1.0),
+            )
+
+        input = torch.ones(2, 2)
+        res1, res2, _ = fn(input)
+        self.assertTrue(res1)
+        self.assertFalse(res2)
+
     def test_torch_function_state_nested(self):
         @torch.compile(backend="eager")
         def fn(x):

View File

@@ -678,6 +678,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
         "torch._C._is_multithreading_enabled",
         "torch._C._is_torch_function_enabled",
         "torch._C._is_torch_function_mode_enabled",
+        "torch._C._is_torch_function_all_disabled",
         "torch._C._is_tracing",
         "torch._C._is_view_replay_enabled",
         "torch._C._is_xnnpack_enabled",

View File

@@ -674,29 +674,62 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):
     @staticmethod
     def create(tx: "InstructionTranslator", **kwargs):
         var = TorchFunctionDisableVariable(
-            target_values=[False],
-            initial_values=[tx.output.torch_function_enabled],
+            target_values=[],
+            initial_values=[],
             **kwargs,
         )
-        # mlazos: I think this is here to make sure we don't reinvoke on clone()
-        var._call_func(tx, [False])
         var.set_cleanup_hook(tx)
         return var
 
-    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
+    def __init__(
+        self, target_values, initial_values=None, only_subclass=True, **kwargs
+    ) -> None:
+        assert len(target_values) == 0
+        assert len(initial_values) == 0
+        from ..symbolic_convert import InstructionTranslator
+
+        tx = InstructionTranslator.current_tx()
+        self.only_subclass = only_subclass
+        self.initial_torch_function_subclass_enabled = (
+            tx.symbolic_torch_function_state.torch_function_subclass_enabled
+        )
+        self.initial_torch_function_mode_enabled = (
+            tx.symbolic_torch_function_state.torch_function_mode_enabled
+        )
+
         super().__init__(
             target_values=target_values, initial_values=initial_values, **kwargs
         )
         install_guard(self._guards_singleton)
 
     def enter(self, tx):
         return variables.ConstantVariable.create(None)
 
+    def set_cleanup_hook(self, tx: "InstructionTranslator", fn=None):
+        if fn is None:
+
+            def fn():
+                tx.symbolic_torch_function_state.torch_function_subclass_enabled = (
+                    self.initial_torch_function_subclass_enabled
+                )
+                if not self.only_subclass:
+                    tx.symbolic_torch_function_state.torch_function_mode_enabled = (
+                        self.initial_torch_function_subclass_enabled
+                    )
+
+        self.state.cleanup_fn = fn
+        tx.output.add_cleanup_hook(self.state.cleanup)
+
     def _call_func(self, tx: "InstructionTranslator", values):
-        assert len(values) == 1
-        tx.symbolic_torch_function_state.torch_function_subclass_enabled = values[0]
-        tx.symbolic_torch_function_state.torch_function_mode_enabled = values[0]
-        tx.output.set_torch_function_state(values[0])
+        assert len(values) == 0
+        tx.symbolic_torch_function_state.torch_function_subclass_enabled = False
+        if not self.only_subclass:
+            tx.symbolic_torch_function_state.torch_function_mode_enabled = False
+        tx.output.set_torch_function_state(False)
 
     def module_name(self):
         return "torch._C"
 
     def fn_name(self):
-        return "DisableTorchFunctionSubclass"
+        if self.only_subclass:
+            return "DisableTorchFunctionSubclass"
+        return "DisableTorchFunction"
 
 
 class DeterministicAlgorithmsVariable(ContextWrappingVariable):
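The only_subclass flag introduced above mirrors the documented split between the two eager context managers: DisableTorchFunctionSubclass suppresses only subclass __torch_function__ overrides, while DisableTorchFunction also suppresses torch function modes, which is why the mode-enabled bookkeeping is touched only when only_subclass is false. A small eager-only illustration of that split; CountingMode is a hypothetical helper written just for this sketch:

    import torch
    from torch.overrides import TorchFunctionMode

    class CountingMode(TorchFunctionMode):
        # Hypothetical mode that counts intercepted torch function calls.
        def __init__(self):
            super().__init__()
            self.count = 0

        def __torch_function__(self, func, types, args=(), kwargs=None):
            self.count += 1
            return func(*args, **(kwargs or {}))

    x = torch.ones(2)
    with CountingMode() as mode:
        with torch._C.DisableTorchFunctionSubclass():
            torch.add(x, 1.0)  # subclass dispatch is off, but the mode still fires
        seen = mode.count
        with torch._C.DisableTorchFunction():
            torch.add(x, 1.0)  # modes are bypassed as well
        assert mode.count == seen  # the second add was not intercepted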

View File

@@ -105,6 +105,7 @@ supported_ctx_manager_classes = dict.fromkeys(
         torch.autograd.profiler.profile,
         torch.autograd.profiler.record_function,
         torch._C.DisableTorchFunctionSubclass,
+        torch._C.DisableTorchFunction,
         torch._functorch.vmap.vmap_increment_nesting,
         torch._functorch.eager_transforms.grad_increment_nesting,
         torch._functorch.eager_transforms.jvp_increment_nesting,
@@ -342,9 +343,14 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
         ):
             warning_once(log, "Profiler function %s will be ignored", self.value)
             return ProfilerContextVariable()
-        elif self.value is torch._C.DisableTorchFunctionSubclass:
+        elif (
+            self.value is torch._C.DisableTorchFunctionSubclass
+            or self.value is torch._C.DisableTorchFunction
+        ):
             assert not (args or kwargs)
-            return TorchFunctionDisableVariable.create(tx)
+            return TorchFunctionDisableVariable.create(
+                tx, only_subclass=self.value is torch._C.DisableTorchFunctionSubclass
+            )
         elif self.value is torch._functorch.vmap.vmap_increment_nesting:
             assert len(args) == 2
             return VmapIncrementNestingCtxManagerVariable.create(
@@ -596,7 +602,18 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
         @register(torch._C._is_torch_function_enabled)
         def handle_is_torch_function_enabled(self, tx):
             install_guard(TorchFunctionDisableVariable._guards_singleton)
-            return ConstantVariable.create(tx.output.torch_function_enabled)
+            # see comment on SymbolicTorchFunctionState class as to why
+            # this is not a bug
+            return ConstantVariable.create(
+                tx.symbolic_torch_function_state.torch_function_subclass_enabled
+            )
+
+        @register(torch._C._is_torch_function_all_disabled)
+        def handle_is_torch_function_all_disabled(self, tx):
+            install_guard(TorchFunctionDisableVariable._guards_singleton)
+            return ConstantVariable.create(
+                not tx.symbolic_torch_function_state.torch_function_mode_enabled
+            )
 
         @register(
             torch.overrides.has_torch_function,