[ca] Support TorchDispatchMode via pass through (#156516)
The CA initial trace just proxies nodes without dispatching any ops, so we should hide it from ambient TorchDispatchModes.

In terms of differences with the eager autograd engine:
- For function mode, CA additionally disables/re-enables `_set_multithreading_enabled`.
- For dispatch mode:
  - Accumulate grad doesn't go down the stealing path (inaccurate compile-time refcount), so the grad `detach` ops are `copy_` instead.
  - Since we always initial trace with dynamic shapes, and we filter out sizes, there's 1 aten.empty.memory_format for each mark_dynamic'd scalar.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156516
Approved by: https://github.com/jansel
ghstack dependencies: #156374, #156509
Committed by: PyTorch MergeBot
Parent: 5f2f343e1e
Commit: c06c2569ee
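To illustrate the behavior described in the commit message, here is a minimal sketch (closely mirroring the tests added below, not an exact reproduction) of an ambient TorchDispatchMode around a compiled-autograd backward: the initial compiled-autograd trace only proxies nodes, so the mode should only observe ops that actually dispatch, such as those executed by the generated backward.

import torch
from torch.utils._python_dispatch import TorchDispatchMode


class LoggingMode(TorchDispatchMode):
    # Records every ATen op that actually dispatches while the mode is active.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)
        return func(*args, **(kwargs or {}))


x = torch.randn(2, 2, requires_grad=True)

with LoggingMode():
    loss = (x * x).sum()
    # The compiled-autograd initial trace is hidden from the ambient mode;
    # only ops run by the generated backward (mul, accumulation copy_,
    # dynamic-shape empty.memory_format, etc.) are expected to be logged here.
    with torch._dynamo.compiled_autograd._enable(lambda gm: gm):
        loss.backward()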
@@ -32,6 +32,7 @@ from torch._inductor import config as inductor_config
 from torch._inductor.test_case import run_tests, TestCase
 from torch.nn.attention.flex_attention import flex_attention
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.overrides import BaseTorchFunctionMode
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     ops,
@@ -46,6 +47,7 @@ from torch.testing._internal.common_utils import (
 from torch.testing._internal.hop_db import hop_db
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU
 from torch.testing._internal.logging_utils import logs_to_string
+from torch.utils._python_dispatch import TorchDispatchMode


 # note: these tests are not run on windows due to inductor_utils.HAS_CPU
@@ -4873,6 +4875,112 @@ class CompiledAutograd1(torch.nn.Module):
             count=[1, 3],
         )
+
+    def test_torch_function_mode(self):
+        called_funcs = []
+
+        class LoggingTorchFunctionMode(BaseTorchFunctionMode):
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                called_funcs.append(str(func.__name__))
+                return super().__torch_function__(func, types, args, kwargs)
+
+        class MyLoss(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, out):
+                ctx.save_for_backward(out)
+                return out.sum()
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                (saved,) = ctx.saved_tensors
+                return torch.ones_like(saved) * grad_output
+
+        x = torch.randn(2, 2, requires_grad=True)
+        y = torch.randn(2, 2)
+        z = torch.randn(2, 2)
+
+        def fwd(x, y, z):
+            out = x * y * z
+            loss = MyLoss.apply(out)
+            return loss
+
+        with LoggingTorchFunctionMode():
+            called_funcs.append("Forward")
+            loss = fwd(x, y, z)
+            called_funcs.append("Backward")
+            with torch._dynamo.compiled_autograd._enable(torch.compile):
+                loss.backward()
+
+        self.assertExpectedInline(
+            "\n".join(called_funcs),
+            """\
+Forward
+mul
+mul
+sum
+Backward
+_set_multithreading_enabled
+backward
+_set_multithreading_enabled""",
+        )  # noqa: B950
+
+    def test_torch_dispatch_mode(self):
+        called_funcs = []
+
+        class LoggingTorchDispatchMode(TorchDispatchMode):
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                called_funcs.append(str(func.__name__))
+                return func(*args, **kwargs)
+
+        class MyLoss(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, out):
+                ctx.save_for_backward(out)
+                return out.sum()
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                (saved,) = ctx.saved_tensors
+                return torch.ones_like(saved) * grad_output
+
+        x = torch.randn(2, 2, requires_grad=True)
+        y = torch.randn(2, 2)
+        z = torch.randn(2, 2)
+
+        def fwd(x, y, z):
+            out = x * y * z
+            loss = MyLoss.apply(out)
+            return loss
+
+        with LoggingTorchDispatchMode():
+            called_funcs.append("Forward")
+            loss = fwd(x, y, z)
+            called_funcs.append("Backward")
+            with torch._dynamo.compiled_autograd._enable(lambda gm: gm):
+                loss.backward()
+
+        self.assertExpectedInline(
+            "\n".join(called_funcs),
+            """\
+Forward
+mul.Tensor
+mul.Tensor
+sum.default
+Backward
+ones_like.default
+empty.memory_format
+empty.memory_format
+empty.memory_format
+empty.memory_format
+empty.memory_format
+empty.memory_format
+ones_like.default
+mul.Tensor
+mul.Tensor
+mul.Tensor
+new_empty_strided.default
+copy_.default""",
+        )  # noqa: B950


 def load_test_module(name):
     testdir = Path(__file__).absolute().parent.parent
@@ -5062,8 +5170,6 @@ xfail_by_backend = {
     "test_reentrant_with_callbacks_depth_1",  # queue_callback
     "test_current_graph_task_execution_order",  # nodes are already freed by the time dynamo traces the lifted hook
     "test_autograd_inplace_views_cross_dtype",  # view_fn not supported by compiled autograd
-    "test_current_node",  # TorchDispatchMode not yet implemented for compiled autograd
-    "test_nested_checkpoint_set_early_stop_no_recompution_needed",  # TorchDispatchMode not yet implemented
     "test_post_accumulate_grad_hook_ordering",  # accuracy error
     "test_current_graph_task_id",  # autograd state already cleared once dynamo is called
     "test_custom_function_forward_mode_forward_is_no_op",  # forward AD
@@ -5156,6 +5262,7 @@ xfail_divergence_from_eager = {
     "test_vjp_call_compiled_backward_fn",  # different functorch error
     "test_vmap_call_compiled_backward_fn",  # different functorch error
     "test_accumulate_grad",  # always out of place add for compiled autograd
+    "test_current_node",  # slightly different dispatched ops
 }

 skipped_tests = set()
@@ -1205,10 +1205,6 @@ static variable_list compiled_autograd(
     const GraphTask& graph_task,
     bool accumulate_grad,
     const edge_list& output_edges) {
-  TORCH_CHECK_NOT_IMPLEMENTED(
-      c10::impl::TorchDispatchModeTLS::stack_len() == 0,
-      "TorchDispatchMode not yet implemented for compiled autograd. " +
-          TURN_OFF_COMPILED_AUTOGRAD_MSG());
   static std::mutex mtx;
   LockGuardWithErrorLogs lock_guard(mtx);
   pybind11::gil_scoped_acquire gil;
@@ -1222,6 +1218,8 @@ static variable_list compiled_autograd(
   THPObjectPtr packed_inputs;
   CacheNode* cache = nullptr;
   try {
+    torch_dispatch_mode::StashTorchDispatchStackGuard stash_stack_guard;
+    TORCH_INTERNAL_ASSERT(c10::impl::TorchDispatchModeTLS::stack_len() == 0);
     cache = _compiled_autograd_impl(
         graph_root,
         graph_task,
@@ -1233,6 +1231,7 @@ static variable_list compiled_autograd(
         &hooks,
         &packed_inputs,
         active_rstate);
+    TORCH_INTERNAL_ASSERT(c10::impl::TorchDispatchModeTLS::stack_len() == 0);
   } catch (const c10::NotImplementedError& e) {
     TORCH_CHECK_NOT_IMPLEMENTED(
         false, std::string(e.what()) + " " + TURN_OFF_COMPILED_AUTOGRAD_MSG());
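The pass-through itself is the stash-and-restore around the initial trace shown above: `StashTorchDispatchStackGuard` empties the TorchDispatchMode TLS stack on construction and restores it when it goes out of scope, and the `TORCH_INTERNAL_ASSERT`s check that the stack stays empty for the duration of the trace. As a rough Python sketch of that control flow (the `mode_stack` list here is a hypothetical stand-in for the C++ thread-local stack, not a real API):

from contextlib import contextmanager

# Hypothetical stand-in for c10::impl::TorchDispatchModeTLS; the real stack
# lives in C++ thread-local state, not in a Python list.
mode_stack = []


@contextmanager
def stash_dispatch_modes():
    # Stash: the compiled-autograd initial trace runs with no ambient modes,
    # so the nodes it proxies are never dispatched into user TorchDispatchModes.
    stashed, mode_stack[:] = list(mode_stack), []
    assert not mode_stack  # mirrors the TORCH_INTERNAL_ASSERT above
    try:
        yield
    finally:
        # Restore: user modes again see everything that dispatches afterwards,
        # i.e. the actual compiled backward.
        mode_stack[:] = stashed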