[ca] Support TorchDispatchMode via pass through (#156516)

The CA initial trace just proxies nodes without dispatching any ops, so we should hide it from ambient TorchDispatchModes.
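
As a rough illustration of the pass-through (not the actual implementation, which uses a C++ `StashTorchDispatchStackGuard` as shown in the diff below), the idea amounts to the following sketch. It assumes the private helpers `_disable_current_modes` and `_get_current_dispatch_mode_stack` from `torch.utils._python_dispatch`; the wrapper name is hypothetical:

```python
import contextlib

from torch.utils._python_dispatch import (
    _disable_current_modes,  # private helper: pops ambient dispatch modes, restores them on exit
    _get_current_dispatch_mode_stack,
)


@contextlib.contextmanager
def hide_dispatch_modes_from_initial_trace():
    # Stash ambient TorchDispatchModes so CA's proxy-only initial trace never
    # dispatches into them, then restore the stack so the compiled backward
    # (which runs real ops) is still observed by the user's modes.
    with _disable_current_modes():
        assert len(_get_current_dispatch_mode_stack()) == 0
        yield
```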

Differences from the eager autograd engine:
- For function mode, CA additionally disables and re-enables multithreading, so extra `_set_multithreading_enabled` calls are visible to the mode.
- For dispatch mode (see the sketch after this list):
  - AccumulateGrad doesn't go down the stealing path (the compile-time refcount is inaccurate), so the grad `detach` ops show up as `copy_` instead.
  - Since the initial trace always uses dynamic shapes and we filter out sizes, there is one `aten.empty.memory_format` call for each `mark_dynamic`'d scalar.
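
For a concrete picture of what a dispatch mode observes, here is a hedged sketch that logs the ops dispatched during an eager vs. a compiled-autograd backward. The `LoggingMode` class and `backward_ops` harness are illustrative (not part of this PR); `torch._dynamo.compiled_autograd._enable(lambda gm: gm)` is the same entry point the new `test_torch_dispatch_mode` test below uses:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class LoggingMode(TorchDispatchMode):
    # Records the overload name of every op dispatched while the mode is active.
    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(str(func.__name__))
        return func(*args, **(kwargs or {}))


def backward_ops(use_compiled_autograd):
    x = torch.randn(2, 2, requires_grad=True)
    loss = (x * 2).sum()  # forward runs outside the mode; only backward ops are logged
    mode = LoggingMode()
    with mode:
        if use_compiled_autograd:
            # The mode is hidden from CA's proxy-only initial trace, but it still
            # sees the real ops executed when the (here uncompiled) graph runs.
            with torch._dynamo.compiled_autograd._enable(lambda gm: gm):
                loss.backward()
        else:
            loss.backward()
    return mode.ops


eager_ops = backward_ops(use_compiled_autograd=False)
ca_ops = backward_ops(use_compiled_autograd=True)
# Expect ca_ops to contain copy_.default where eager takes the AccumulateGrad
# detach/stealing path, plus empty.memory_format calls from dynamic sizes.
print(eager_ops)
print(ca_ops)
```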

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156516
Approved by: https://github.com/jansel
ghstack dependencies: #156374, #156509
Author: Simon Fan
Date: 2025-06-21 11:25:11 -07:00
Committed by: PyTorch MergeBot
Parent: 5f2f343e1e
Commit: c06c2569ee
2 changed files with 112 additions and 6 deletions


@@ -32,6 +32,7 @@ from torch._inductor import config as inductor_config
from torch._inductor.test_case import run_tests, TestCase
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.overrides import BaseTorchFunctionMode
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
@@ -46,6 +47,7 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.hop_db import hop_db
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU
from torch.testing._internal.logging_utils import logs_to_string
from torch.utils._python_dispatch import TorchDispatchMode
# note: these tests are not run on windows due to inductor_utils.HAS_CPU
@@ -4873,6 +4875,112 @@ class CompiledAutograd1(torch.nn.Module):
            count=[1, 3],
        )

    def test_torch_function_mode(self):
        called_funcs = []

        class LoggingTorchFunctionMode(BaseTorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                called_funcs.append(str(func.__name__))
                return super().__torch_function__(func, types, args, kwargs)

        class MyLoss(torch.autograd.Function):
            @staticmethod
            def forward(ctx, out):
                ctx.save_for_backward(out)
                return out.sum()

            @staticmethod
            def backward(ctx, grad_output):
                (saved,) = ctx.saved_tensors
                return torch.ones_like(saved) * grad_output

        x = torch.randn(2, 2, requires_grad=True)
        y = torch.randn(2, 2)
        z = torch.randn(2, 2)

        def fwd(x, y, z):
            out = x * y * z
            loss = MyLoss.apply(out)
            return loss

        with LoggingTorchFunctionMode():
            called_funcs.append("Forward")
            loss = fwd(x, y, z)
            called_funcs.append("Backward")
            with torch._dynamo.compiled_autograd._enable(torch.compile):
                loss.backward()

        self.assertExpectedInline(
            "\n".join(called_funcs),
            """\
Forward
mul
mul
sum
Backward
_set_multithreading_enabled
backward
_set_multithreading_enabled""",
        )  # noqa: B950
    def test_torch_dispatch_mode(self):
        called_funcs = []

        class LoggingTorchDispatchMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                called_funcs.append(str(func.__name__))
                return func(*args, **kwargs)

        class MyLoss(torch.autograd.Function):
            @staticmethod
            def forward(ctx, out):
                ctx.save_for_backward(out)
                return out.sum()

            @staticmethod
            def backward(ctx, grad_output):
                (saved,) = ctx.saved_tensors
                return torch.ones_like(saved) * grad_output

        x = torch.randn(2, 2, requires_grad=True)
        y = torch.randn(2, 2)
        z = torch.randn(2, 2)

        def fwd(x, y, z):
            out = x * y * z
            loss = MyLoss.apply(out)
            return loss

        with LoggingTorchDispatchMode():
            called_funcs.append("Forward")
            loss = fwd(x, y, z)
            called_funcs.append("Backward")
            with torch._dynamo.compiled_autograd._enable(lambda gm: gm):
                loss.backward()

        self.assertExpectedInline(
            "\n".join(called_funcs),
            """\
Forward
mul.Tensor
mul.Tensor
sum.default
Backward
ones_like.default
empty.memory_format
empty.memory_format
empty.memory_format
empty.memory_format
empty.memory_format
empty.memory_format
ones_like.default
mul.Tensor
mul.Tensor
mul.Tensor
new_empty_strided.default
copy_.default""",
        )  # noqa: B950
def load_test_module(name):
    testdir = Path(__file__).absolute().parent.parent
@@ -5062,8 +5170,6 @@ xfail_by_backend = {
"test_reentrant_with_callbacks_depth_1", # queue_callback
"test_current_graph_task_execution_order", # nodes are already freed by the time dynamo traces the lifted hook
"test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd
"test_current_node", # TorchDispatchMode not yet implemented for compiled autograd
"test_nested_checkpoint_set_early_stop_no_recompution_needed", # TorchDispatchMode not yet implemented
"test_post_accumulate_grad_hook_ordering", # accuracy error
"test_current_graph_task_id", # autograd state already cleared once dynamo is called
"test_custom_function_forward_mode_forward_is_no_op", # forward AD
@@ -5156,6 +5262,7 @@ xfail_divergence_from_eager = {
"test_vjp_call_compiled_backward_fn", # different functorch error
"test_vmap_call_compiled_backward_fn", # different functorch error
"test_accumulate_grad", # always out of place add for compiled autograd
"test_current_node", # slightly different dispatched ops
}
skipped_tests = set()


@@ -1205,10 +1205,6 @@ static variable_list compiled_autograd(
    const GraphTask& graph_task,
    bool accumulate_grad,
    const edge_list& output_edges) {
  TORCH_CHECK_NOT_IMPLEMENTED(
      c10::impl::TorchDispatchModeTLS::stack_len() == 0,
      "TorchDispatchMode not yet implemented for compiled autograd. " +
          TURN_OFF_COMPILED_AUTOGRAD_MSG());
  static std::mutex mtx;
  LockGuardWithErrorLogs lock_guard(mtx);
  pybind11::gil_scoped_acquire gil;
@@ -1222,6 +1218,8 @@ static variable_list compiled_autograd(
  THPObjectPtr packed_inputs;
  CacheNode* cache = nullptr;
  try {
    torch_dispatch_mode::StashTorchDispatchStackGuard stash_stack_guard;
    TORCH_INTERNAL_ASSERT(c10::impl::TorchDispatchModeTLS::stack_len() == 0);
    cache = _compiled_autograd_impl(
        graph_root,
        graph_task,
@@ -1233,6 +1231,7 @@ static variable_list compiled_autograd(
        &hooks,
        &packed_inputs,
        active_rstate);
    TORCH_INTERNAL_ASSERT(c10::impl::TorchDispatchModeTLS::stack_len() == 0);
  } catch (const c10::NotImplementedError& e) {
    TORCH_CHECK_NOT_IMPLEMENTED(
        false, std::string(e.what()) + " " + TURN_OFF_COMPILED_AUTOGRAD_MSG());