Compare commits

...

6 Commits

SHA1 Message Date
3d732de4c8 more wip 2025-03-03 13:36:11 -08:00
19a27343a0 wip [ca] single step 2025-02-28 19:25:39 -08:00
cabc560762 [ca] side-effect free initial trace: RAII PyCompilerInterface
ghstack-source-id: c6dbc9480b928aaa8147e20de73bf668d547040d
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147891
2025-02-26 20:17:18 -08:00
663565c84a [ca] side-effect free initial trace: compiled_args
ghstack-source-id: 4e3fa1d948411730c6e834ea4d996ac8c6d72bcb
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147804
2025-02-25 19:57:55 -08:00
95f317d9f8 [ca] side-effect free initial trace: GraphTask
ghstack-source-id: 1e13d1cd892800b1647482b42846a0b4c55912b3
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147796
2025-02-25 19:57:55 -08:00
643da4854f wip [ca] unpack hooks
ghstack-source-id: c2b4c7eedbb69db8e803949894f9f0dda29a5171
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147242
2025-02-25 19:57:54 -08:00
36 changed files with 930 additions and 255 deletions

hack.py (new file, 118 lines)
View File

@ -0,0 +1,118 @@
import torch
from torch._inductor.ir import NoneAsConstantBuffer
import torch.nn as nn
import torch.nn.functional as F
import depyf

depyf.install()


def fn(loss):
    gm = None
    args = None

    def noop(_gm):
        nonlocal gm
        gm = _gm

        def _noop(*_args, **_kwargs):
            assert not _kwargs
            nonlocal args
            args = _args
            return []

        return _noop

    with torch._dynamo.compiled_autograd._enable(noop):
        loss.backward()
    return gm, args


result = torch._dynamo.compiled_autograd.Op("FunctionalCompiledAutograd", fn, is_custom_function=False)
setattr(torch._dynamo.compiled_autograd.ops, "FunctionalCompiledAutograd", torch._dynamo.allow_in_graph(result))

x = torch.randn(64, 3)
t = torch.randn(64, 1)
model = nn.Linear(3, 1)

torch._dynamo.config.compiled_autograd = True
torch._dynamo.config.do_not_emit_runtime_asserts = True


@torch.compile(backend="eager")
def train(model, x, t):
    y = model(x)
    loss = F.mse_loss(y, t)
    gm, args = torch._dynamo.compiled_autograd.ops.FunctionalCompiledAutograd(loss)
    gm(*args)
    return ()


# with torch._dynamo.compiled_autograd._enable(noop):
train(model, x, t)

for p in model.parameters():
    assert p.grad is not None
"""
# this kinda works, but not ideal
===== __compiled_fn_1 =====
/home/xmfan/core/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_model_parameters_weight_: "f32[1, 3][3, 1]cpu", L_model_parameters_bias_: "f32[1][1]cpu", L_x_: "f32[64, 3][3, 1]cpu", L_t_: "f32[64, 1][1, 1]cpu"):
l_model_parameters_weight_ = L_model_parameters_weight_
l_model_parameters_bias_ = L_model_parameters_bias_
l_x_ = L_x_
l_t_ = L_t_
# File: /home/xmfan/core/a/pytorch/hack.py:44 in train, code: y = model(x)
y: "f32[64, 1][1, 1]cpu" = torch._C._nn.linear(l_x_, l_model_parameters_weight_, l_model_parameters_bias_); l_x_ = l_model_parameters_weight_ = l_model_parameters_bias_ = None
# File: /home/xmfan/core/a/pytorch/hack.py:45 in train, code: loss = F.mse_loss(y, t)
loss: "f32[][]cpu" = torch.nn.functional.mse_loss(y, l_t_); y = l_t_ = None
# File: /home/xmfan/core/a/pytorch/hack.py:46 in train, code: gm, args = torch._dynamo.compiled_autograd.ops.FunctionalCompiledAutograd(loss)
functional_compiled_autograd = torch__dynamo_compiled_autograd_ops_FunctionalCompiledAutograd(loss); loss = None
getitem = functional_compiled_autograd[1]; functional_compiled_autograd = None
getitem_1 = getitem[0]; getitem = None
getitem_8: "f32[][]cpu" = getitem_1[0]
getitem_9: "f32[64, 1][1, 1]cpu" = getitem_1[1]
getitem_10: "f32[64, 1][1, 1]cpu" = getitem_1[2]
getitem_11: "f32[64, 3][3, 1]cpu" = getitem_1[3]; getitem_1 = None
# File: <eval_with_key>.0:11 in forward, code: validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], True)]); getitem = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_8], [((None, None, device(type='cpu'), 6, 0, None), [], True)]); getitem_8 = None
getitem_15: "f32[][]cpu" = validate_outputs[0]; validate_outputs = None
# File: <eval_with_key>.0:13 in forward, code: mse_loss_backward0 = torch__dynamo_compiled_autograd_ops_MseLossBackward0([getitem_6], [True, False], 1, getitem_1, getitem_2); getitem_6 = getitem_1 = getitem_2 = None
mse_loss_backward0 = torch__dynamo_compiled_autograd_ops_MseLossBackward0([getitem_15], [True, False], 1, getitem_9, getitem_10); getitem_15 = getitem_9 = getitem_10 = None
getitem_17: "f32[64, 1][1, 1]cpu" = mse_loss_backward0[0]; mse_loss_backward0 = None
# File: <eval_with_key>.0:16 in forward, code: validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_7, getitem_8], [((None, None, device(type='cpu'), 6, 0, None), [64, 1], True), None]); getitem_7 = getitem_8 = None
validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_17, None], [((None, None, device(type='cpu'), 6, 0, None), [64, 1], True), None]); getitem_17 = None
getitem_19: "f32[64, 1][1, 1]cpu" = validate_outputs_1[0]; validate_outputs_1 = None
# File: <eval_with_key>.0:18 in forward, code: addmm_backward0 = torch__dynamo_compiled_autograd_ops_AddmmBackward0([getitem_9], [True, False, True], 1, 1, getitem_3, 0, [64, 3], [], None, 0, [3, 1], [1, 3]); getitem_9 = getitem_3 = None
addmm_backward0 = torch__dynamo_compiled_autograd_ops_AddmmBackward0([getitem_19], [True, False, True], 1, 1, getitem_11, 0, [64, 3], [], None, 0, [3, 1], [1, 3]); getitem_19 = getitem_11 = None
getitem_22: "f32[64, 1][1, 1]cpu" = addmm_backward0[0]
getitem_23: "f32[3, 1][1, 3]cpu" = addmm_backward0[2]; addmm_backward0 = None
# File: <eval_with_key>.0:22 in forward, code: validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_11, getitem_12, getitem_13], [((None, None, device(type='cpu'), 6, 0, None), [1], True), None, ((None, None, device(type='cpu'), 6, 0, None), [3, 1], True)]); getitem_11 = getitem_12 = getitem_13 = None
validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_22, None, getitem_23], [((None, None, device(type='cpu'), 6, 0, None), [1], True), None, ((None, None, device(type='cpu'), 6, 0, None), [3, 1], True)]); getitem_22 = getitem_23 = None
getitem_26: "f32[1][1]cpu" = validate_outputs_2[0]
getitem_27: "f32[3, 1][1, 3]cpu" = validate_outputs_2[2]; validate_outputs_2 = None
# File: /home/xmfan/core/a/pytorch/torch/_dynamo/polyfills/__init__.py:80 in accumulate_grad, code: new_grad = torch.clone(new_grad)
new_grad: "f32[1][1]cpu" = torch.clone(getitem_26); getitem_26 = new_grad = None
# File: <eval_with_key>.0:26 in forward, code: tbackward0 = torch__dynamo_compiled_autograd_ops_TBackward0([getitem_16], [True]); getitem_16 = None
tbackward0 = torch__dynamo_compiled_autograd_ops_TBackward0([getitem_27], [True]); getitem_27 = None
getitem_29: "f32[1, 3][3, 1]cpu" = tbackward0[0]; tbackward0 = None
# File: <eval_with_key>.0:28 in forward, code: validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_17], [((None, None, device(type='cpu'), 6, 0, None), [1, 3], True)]); getitem_17 = None
validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_29], [((None, None, device(type='cpu'), 6, 0, None), [1, 3], True)]); getitem_29 = None
getitem_31: "f32[1, 3][3, 1]cpu" = validate_outputs_3[0]; validate_outputs_3 = None
# File: /home/xmfan/core/a/pytorch/torch/_dynamo/polyfills/__init__.py:80 in accumulate_grad, code: new_grad = torch.clone(new_grad)
new_grad_1: "f32[1, 3][3, 1]cpu" = torch.clone(getitem_31); getitem_31 = new_grad_1 = None
return ()
"""

View File

@ -10,6 +10,7 @@ import os
import re
import subprocess
import sys
import tempfile
import unittest
from copy import deepcopy
from importlib.machinery import SourceFileLoader
@ -48,11 +49,17 @@ from torch.testing._internal.logging_utils import logs_to_string
# note: these tests are not run on windows due to inductor_utils.HAS_CPU
def make_compiler_fn(fullgraph=True, dynamic=True, backend="inductor"):
assert backend in ["inductor", "aot_eager"]
def make_compiler_fn(
fullgraph=True, dynamic=True, backend="inductor", gm_hook=lambda gm: None
):
assert backend in ["inductor", "aot_eager", "ca_eager"]
def _compiler_fn(gm):
"""Same as torch.compile() but counts number of compiles"""
gm_hook(gm)
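# new "ca_eager" backend: skip compilation entirely and run the captured CA graph eagerly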
if backend == "ca_eager":
return gm
def _inner_compiler(gm_, example_inputs_):
counters["compiled_autograd"]["compiles"] += 1
@ -112,7 +119,10 @@ class TestCompiledAutograd(TestCase):
torch.manual_seed(123)
expected = list(fn())
torch.manual_seed(123)
with compiled_autograd._enable(compiler_fn):
with compiled_autograd._enable(compiler_fn), mock.patch(
"torch._functorch.aot_autograd.AOT_COUNTER",
new_callable=itertools.count,
):
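# patching AOT_COUNTER to a fresh itertools.count keeps aot_id numbering deterministic between the expected and actual runs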
opt_fn = torch.compile(fn) if compile_fn else fn
actual = list(opt_fn())
self.assertEqual(expected, actual)
@ -915,7 +925,8 @@ main()
inputs=[param, activ],
sizes=(),
scalars=(),
hooks=(),
hooks=[],
packed_inputs=[],
)
finally:
handle.remove()
@ -3322,7 +3333,10 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
graphs.append(gm)
return inner_compiler_fn(gm)
with compiled_autograd._enable(compiler_fn):
with compiled_autograd._enable(compiler_fn), mock.patch(
"torch._functorch.aot_autograd.AOT_COUNTER",
new_callable=itertools.count,
):
res = fn(x)
res.sum().backward()
@ -3336,7 +3350,7 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
graph_code,
"""\
class CompiledAutograd0(torch.nn.Module):
def forward(self, inputs, sizes, scalars, hooks):
def forward(self, inputs, sizes, scalars, hooks, packed_data):
getitem = inputs[0]
getitem_1 = inputs[1]
getitem_2 = inputs[2]
@ -3511,9 +3525,7 @@ class CompiledAutograd0(torch.nn.Module):
fn, count=2, compiler_fn=make_compiler_fn(backend="aot_eager")
)
@unittest.expectedFailure
def test_saved_tensor_unpack_hook_ordering(self):
# not the correct behaviour, I'm just preventing this from changing silently
def f(x, y):
return x * y
@ -3531,8 +3543,6 @@ class CompiledAutograd0(torch.nn.Module):
return x
def tensor_hook(_):
# in eager, tensor_hook is fired before unpack_hook
# but in compiled autograd, tensor_hook is lifted whereas unpack_hook is not
self.assertEqual(unpack_count, 0)
x = torch.ones(4, requires_grad=True)
@ -3544,21 +3554,252 @@ class CompiledAutograd0(torch.nn.Module):
self.assertEqual(pack_count, 1)
self.assertEqual(unpack_count, 0)
loss = out_test.sum()
loss.register_hook(tensor_hook)
loss.register_hook(
tensor_hook
) # scheduled to fire before any saved activations
loss.backward()
self.assertEqual(pack_count, 1)
self.assertEqual(unpack_count, 1)
def test_reentrant_checkpointing(self):
def fn(x):
y = x.sin()
z = y.cos()
return (y * z).sum()
@parametrize("reentrant", (True, False))
def test_checkpointing_simple(self, reentrant):
def fn():
def _fn(x):
y = x.sin()
z = y.cos()
return (y * z).sum()
inp = torch.rand(10, 10, requires_grad=True)
out = torch.utils.checkpoint.checkpoint(fn, inp, use_reentrant=True)
with torch._dynamo.compiled_autograd._enable(torch.compile):
inp = torch.rand(10, 10, requires_grad=True)
out = torch.utils.checkpoint.checkpoint(_fn, inp, use_reentrant=reentrant)
out.backward()
yield inp.grad
if reentrant:
self.check_output_and_recompiles(
fn, count=[1, 3], compiler_fn=make_compiler_fn(fullgraph=False)
)
else:
# dynamo issues, just run the CA graph directly for now
def check(gm):
graph_code = normalize_gm(gm.print_readable(print_output=False))
self.assertExpectedInline(
graph_code,
"""\
class CompiledAutograd0(torch.nn.Module):
def forward(self, inputs, sizes, scalars, hooks, packed_data):
getitem = inputs[0]
getitem_1 = inputs[1]; inputs = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem = None
getitem_2 = validate_outputs[0]; validate_outputs = None
sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_2], [True], [10, 10]); getitem_2 = None
getitem_3 = sum_backward0[0]; sum_backward0 = None
validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_3], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_3 = None
getitem_4 = validate_outputs_1[0]; validate_outputs_1 = None
getitem_5 = hooks[0]
getitem_6 = packed_data[0]
getitem_7 = hooks[1]
getitem_8 = packed_data[1]
call_hook = torch__dynamo_external_utils_call_hook(getitem_5, getitem_6, hook_type = 'unpack_hook'); getitem_5 = getitem_6 = None
call_hook_1 = torch__dynamo_external_utils_call_hook(getitem_7, getitem_8, hook_type = 'unpack_hook'); getitem_7 = getitem_8 = None
mul_backward0 = torch__dynamo_compiled_autograd_ops_MulBackward0([getitem_4], [True, True], call_hook, 6, call_hook_1, 6); getitem_4 = call_hook = call_hook_1 = None
getitem_9 = mul_backward0[0]
getitem_10 = mul_backward0[1]; mul_backward0 = None
validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_9, getitem_10], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False), ((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_9 = getitem_10 = None
getitem_11 = validate_outputs_2[0]
getitem_12 = validate_outputs_2[1]; validate_outputs_2 = None
getitem_13 = hooks[2]
getitem_14 = packed_data[2]
call_hook_2 = torch__dynamo_external_utils_call_hook(getitem_13, getitem_14, hook_type = 'unpack_hook'); getitem_13 = getitem_14 = None
cos_backward0 = torch__dynamo_compiled_autograd_ops_CosBackward0([getitem_12], [True], call_hook_2); getitem_12 = call_hook_2 = None
getitem_15 = cos_backward0[0]; cos_backward0 = None
validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_15], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_15 = None
getitem_16 = validate_outputs_3[0]; validate_outputs_3 = None
add = torch.add(getitem_11, getitem_16); getitem_11 = getitem_16 = None
getitem_17 = hooks[3]; hooks = None
getitem_18 = packed_data[3]; packed_data = None
call_hook_3 = torch__dynamo_external_utils_call_hook(getitem_17, getitem_18, hook_type = 'unpack_hook'); getitem_17 = getitem_18 = None
sin_backward0 = torch__dynamo_compiled_autograd_ops_SinBackward0([add], [True], call_hook_3); add = call_hook_3 = None
getitem_19 = sin_backward0[0]; sin_backward0 = None
validate_outputs_4 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_19], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_19 = None
getitem_20 = validate_outputs_4[0]; validate_outputs_4 = None
accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_20); getitem_1 = getitem_20 = accumulate_grad_ = None
_exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
return []
""", # noqa: B950
)
self.check_output_and_recompiles(
fn,
count=[1, 0],
compiler_fn=make_compiler_fn(backend="ca_eager", gm_hook=check),
)
@unittest.skipIf(not HAS_CUDA, "requires cuda")
def test_cpu_offloading(self):
def fn():
def pack(x):
return x.cpu()
def unpack(x):
return x.cuda()
class MyMatMul(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return torch.matmul(x, x)
@staticmethod
def backward(ctx, grad_out):
(x,) = ctx.saved_tensors
return grad_out * x
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
for i in [10, 100, 10, 20, 30]:
x = torch.randn(i, requires_grad=True).cuda()
MyMatMul.apply(x).sum().backward()
yield x.grad
i = 0
def check(gm):
nonlocal i
if i == 0:
i += 1
return
graph_code = normalize_gm(gm.print_readable(print_output=False))
self.assertExpectedInline(
graph_code,
"""\
class CompiledAutograd1(torch.nn.Module):
def forward(self, inputs, sizes, scalars, hooks, packed_data):
getitem = inputs[0]
getitem_1 = inputs[1]; inputs = None
getitem_2 = sizes[0]; getitem_2 = None
getitem_3 = sizes[1]
getitem_4 = sizes[2]; sizes = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cuda', index=0), 6, 0, None), [], False)]); getitem = None
getitem_5 = validate_outputs[0]; validate_outputs = None
sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_5], [True], []); getitem_5 = None
getitem_6 = sum_backward0[0]; sum_backward0 = None
validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_6], [((None, None, device(type='cuda', index=0), 6, 0, None), [], False)]); getitem_6 = None
getitem_7 = validate_outputs_1[0]; validate_outputs_1 = None
getitem_8 = hooks[0]
getitem_9 = packed_data[0]; packed_data = None
getitem_10 = hooks[1]; hooks = None
call_hook = torch__dynamo_external_utils_call_hook(getitem_8, getitem_9, hook_type = 'unpack_hook'); getitem_8 = getitem_9 = None
call_backward = torch__dynamo_external_utils_call_backward(getitem_10, (call_hook,), getitem_7); getitem_10 = call_hook = getitem_7 = None
getitem_12 = call_backward[0]; call_backward = None
validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_12], [((None, None, device(type='cuda', index=0), 6, 0, None), [getitem_3], False)]); getitem_12 = getitem_3 = None
getitem_13 = validate_outputs_2[0]; validate_outputs_2 = None
to_copy_backward0 = torch__dynamo_compiled_autograd_ops_ToCopyBackward0([getitem_13], [True], (None, None, device(type='cpu'), 6, 0, None)); getitem_13 = None
getitem_14 = to_copy_backward0[0]; to_copy_backward0 = None
validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_14], [((None, None, device(type='cpu'), 6, 0, None), [getitem_4], False)]); getitem_14 = getitem_4 = None
getitem_15 = validate_outputs_3[0]; validate_outputs_3 = None
accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_15); getitem_1 = getitem_15 = accumulate_grad_ = None
_exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
return []
""", # noqa: B950
)
self.check_output_and_recompiles(
fn, count=2, compiler_fn=make_compiler_fn(gm_hook=check)
)
@skipIfWindows(msg="temp dir not compatible")
def test_disk_offloading(self):
with tempfile.TemporaryDirectory() as d:
def fn():
pack_count = 0
def pack(x):
nonlocal pack_count
path = f"{d}/{pack_count}.pt"
torch.save(x, path)
return path
def unpack(path):
x = torch.load(path)
return x
class MyMatMul(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return torch.matmul(x, x)
@staticmethod
def backward(ctx, grad_out):
(x,) = ctx.saved_tensors
return grad_out * x
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
for i in [10, 100, 10, 20, 30]:
x = torch.randn(i, requires_grad=True)
MyMatMul.apply(x).sum().backward()
yield x.grad
i = 0
def check(gm):
nonlocal i
if i == 0:
i += 1
return
graph_code = normalize_gm(gm.print_readable(print_output=False))
self.assertExpectedInline(
graph_code,
"""\
class CompiledAutograd1(torch.nn.Module):
def forward(self, inputs, sizes, scalars, hooks, packed_data):
getitem = inputs[0]
getitem_1 = inputs[1]; inputs = None
getitem_2 = sizes[0]; getitem_2 = None
getitem_3 = sizes[1]; sizes = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem = None
getitem_4 = validate_outputs[0]; validate_outputs = None
sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_4], [True], []); getitem_4 = None
getitem_5 = sum_backward0[0]; sum_backward0 = None
validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_5], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem_5 = None
getitem_6 = validate_outputs_1[0]; validate_outputs_1 = None
getitem_7 = hooks[0]
getitem_8 = packed_data[0]; packed_data = None
getitem_9 = hooks[1]; hooks = None
call_hook = torch__dynamo_external_utils_call_hook(getitem_7, getitem_8, hook_type = 'unpack_hook'); getitem_7 = getitem_8 = None
call_backward = torch__dynamo_external_utils_call_backward(getitem_9, (call_hook,), getitem_6); getitem_9 = call_hook = getitem_6 = None
getitem_11 = call_backward[0]; call_backward = None
validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_11], [((None, None, device(type='cpu'), 6, 0, None), [getitem_3], False)]); getitem_11 = getitem_3 = None
getitem_12 = validate_outputs_2[0]; validate_outputs_2 = None
accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_12); getitem_1 = getitem_12 = accumulate_grad_ = None
_exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
return []
""", # noqa: B950
)
# 1 graph break on torch.load -> 2 dynamo graphs
self.check_output_and_recompiles(
fn,
count=[2, 4],
compiler_fn=make_compiler_fn(fullgraph=False, gm_hook=check),
)
@skipIfWindows(msg="node name demangling inconsistent on windows")
def test_backward_hook_relative_ordering_partial(self):
@ -3617,7 +3858,7 @@ class CompiledAutograd0(torch.nn.Module):
self.check_output_and_recompiles(fn)
def test_sac(self):
def test_checkpointing_sac(self):
# circular import
from torch.utils.checkpoint import (
checkpoint,
@ -3666,7 +3907,9 @@ class CompiledAutograd0(torch.nn.Module):
yield model.layer4.weight.grad
yield model.layer4.bias.grad
self.check_output_and_recompiles(fn)
self.check_output_and_recompiles(
fn, count=[1, 5], compiler_fn=make_compiler_fn(fullgraph=False)
)
def load_test_module(name):
@ -3754,6 +3997,26 @@ known_graph_breaks_tests = {
"test_deep_reentrant", # reentrant .backward
"test_reentrant_priority", # reentrant .backward
"test_simple_reentrant", # reentrant .backward
"test_checkpoint_detects_non_determinism", # unpack hook in skip files
"test_checkpoint_valid_reset_on_error", # unpack hook in skip files
"test_checkpointing_non_reentrant_autocast_cpu", # unpack hook in skip files
"test_checkpointing_non_reentrant_autocast_gpu", # unpack hook in skip files
"test_checkpointing_without_reentrant_arbitrary_input_output", # unpack hook in skip files
"test_checkpointing_without_reentrant_correct_grad", # unpack hook in skip files
"test_checkpointing_without_reentrant_custom_function_works", # unpack hook in skip files
"test_checkpointing_without_reentrant_dataparallel", # _get_device_index in skip files
"test_checkpointing_without_reentrant_detached_tensor_use_reentrant_True", # reentrant .backward
"test_checkpointing_without_reentrant_parameter_used_in_an_out", # unpack hook in skip files
"test_checkpointing_without_reentrant_with_context_fn", # unpack hook in skip files
"test_save_on_cpu_and_checkpoint", # unpack hook in skip files
"test_saved_tensor_hooks_custom_error_propagation", # CustomError
"test_access_saved_tensor_twice_without_recomputation_works", # unpack hook in skip files
"test_saved_tensor_hooks_extra_enter_during_bw_no_leak", # ctx in skip files
"test_saved_tensor_hooks_extra_exit_during_bw_no_crash", # ctx in skip files
"test_checkpointing", # reentrant .backward
"test_checkpointing_without_reentrant_input_requires_grad_False", # reentrant .backward
"test_checkpointing_without_reentrant_input_requires_grad_True", # reentrant .backward
"test_checkpointing_without_reentrant_memory_savings", # reentrant .backward
}
test_contexts = {
@ -3764,9 +4027,7 @@ test_contexts = {
}
# These groups of tests aren't supported yet
known_failures_re = re.compile(
r"^test_(sparse|profiler|gradcheck|checkpoint|named_tensor)"
)
known_failures_re = re.compile(r"^test_(sparse|profiler|gradcheck|named_tensor)")
# Bugs needing investigation:
skipped_tests = {
@ -3837,7 +4098,7 @@ known_failing_tests = {
# IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors)
"test_grad_nonleaf_register_hook",
"test_backward_twice_without_saved_values", # https://github.com/pytorch/pytorch/issues/129938
# Category: Dynamo
# Category: Dynamo (pass when directly running CA graph)
"test_accumulate_grad_tensor_reference", # Out of bounds: frame_state_entry.stride[i] is None
"test_custom_function_exception", # torch.no_grad(), torch._dynamo.exc.Unsupported: missing: WITH_EXCEPT_START
"test_to_sparse_backward", # Out of bounds: frame_state_entry.stride[i] is None
@ -3849,7 +4110,20 @@ known_failing_tests = {
"test_return_duplicate", # gradient batching rule not implemented for aten::sym_size.int
"test_return_duplicate_inplace", # gradient batching rule not implemented for aten::sym_size.int
"test_setitem", # CopySlices accuracy error
# Category: Inductor
"test_save_on_cpu_and_checkpoint", # https://github.com/pytorch/pytorch/issues/147565
"test_checkpoint_detects_non_determinism", # different error
"test_checkpointing_non_reentrant_autocast_cpu", # saved != recompute
"test_checkpointing_non_reentrant_autocast_gpu", # saved != recompute
"test_checkpointing_without_reentrant_saved_object_identity", # same as https://github.com/pytorch/pytorch/issues/136193
"test_saved_variable_packing_unpacking_did_not_save_original_with_hooks", # register_hooks multiple times
"test_saved_variable_saved_original_inplace_detach", # RuntimeError not raised
"test_access_saved_tensor_twice_without_recomputation_works", # saved != recompute
"test_checkpointing_without_reentrant_dataparallel", # https://github.com/pytorch/pytorch/issues/127115
"test_checkpointing", # takes very very long
"test_checkpointing_without_reentrant_input_requires_grad_False", # takes very very long
"test_checkpointing_without_reentrant_input_requires_grad_True", # takes very very long
"test_checkpointing_without_reentrant_memory_savings", # takes very very long
# Category: Inductor (pass on backend="aot_eager")
"test_input_buffer_accum", # does not support sparse_grad=True: https://github.com/pytorch/pytorch/issues/120267
"test_graph_save_on_cpu", # does not support pin_memory: https://github.com/pytorch/pytorch/issues/134173
# Category: FakeTensor
@ -3861,6 +4135,7 @@ known_failing_tests = {
"test_invalid_gradients", # can't give autograd error due to inaccurate output metadata of lifted backward
"test_autograd_node_isinstance", # backward ctx is a fake cls and not directly a Node instance
"test_backward_hook_relative_ordering", # compiled autograd collects breadth first, and module backward hook not supported
"test_checkpointing_without_reentrant_custom_function_works", # ctx.saved_tensors are cached by CA
# Category: Subclasses
"test_dtensor_basic",
"test_dtensor_contiguous_dtensor_noncontiguous_local_as_tangent",

View File

@ -67,7 +67,7 @@ struct TORCH_API ${op} : public ${superclass} {
${release_variables}
}
${will_release_variables}
void compiled_args(CompiledNodeArgs& args) override;
void compiled_args(CompiledNodeArgs& args) const override;
variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override;
${saved_variables}
${saved_list_sizes}
@ -127,7 +127,7 @@ variable_list ${op}::apply(variable_list&& grads) {
return ${op}_apply_functional(std::move(grads), needs_input_grad${,apply_functional_args});
}
void ${op}::compiled_args(CompiledNodeArgs& args) {
void ${op}::compiled_args(CompiledNodeArgs& args) const {
${compiled_args}
}
variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) {

View File

@ -134,7 +134,7 @@ class Op:
ops = OpNamespace()
_graph_placeholders = ["inputs", "sizes", "scalars", "hooks"]
_graph_placeholders = ["inputs", "sizes", "scalars", "hooks", "packed_data"]
_impure_targets = OrderedSet(
[
call_hook,
@ -164,12 +164,16 @@ class AutogradCompilerInstance:
self.compiler_fn = compiler_fn
self.stack = contextlib.ExitStack()
self.close = self.stack.close
self.shape_env = ShapeEnv()
self.fake_tensor_mode = FakeTensorMode(
allow_fallback_kernels=True,
allow_non_fake_inputs=True,
shape_env=self.shape_env,
)
if ctx := torch._guards.TracingContext.try_get():
self.shape_env = ctx.fake_mode.shape_env
self.fake_tensor_mode = ctx.fake_mode
else:
self.shape_env = ShapeEnv()
self.fake_tensor_mode = FakeTensorMode(
allow_fallback_kernels=True,
allow_non_fake_inputs=True,
shape_env=self.shape_env,
)
self.fx_tracer = PythonKeyTracer()
self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic")
self.hooks_proxy: Optional[Proxy] = None
@ -206,7 +210,13 @@ class AutogradCompilerInstance:
self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
self.fx_tracer.tensor_attrs = {}
self.symnode_proxy_lookup = {}
args_proxy, self.sizes_proxy, self.scalars_proxy, self.hooks_proxy = (
(
args_proxy,
self.sizes_proxy,
self.scalars_proxy,
self.hooks_proxy,
self.packed_data_proxy,
) = (
self.fx_tracer.create_proxy("placeholder", name, (), {})
for name in _graph_placeholders
)
@ -268,7 +278,12 @@ class AutogradCompilerInstance:
self.stack.enter_context(
torch.fx.experimental.symbolic_shapes._suppress_guards(env)
)
return str(CompileContext.current_compile_id()), inputs, sizes, scalars
return (
str(CompileContext.current_compile_id()),
inputs,
sizes,
scalars,
)
def log_compile_reasons(
self,
@ -567,6 +582,19 @@ class AutogradCompilerInstance:
kwargs,
)
def unpack_hook(self, hook_id, data_id):
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id] # type: ignore[index]
data = self.packed_data_proxy[data_id] # type: ignore[index]
proxy = self.proxy_call_hook(
hook,
data,
hook_type="unpack_hook",
)
out = self.allocate_dummy()
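# the unpack result is modeled as an opaque dummy tensor bound to the call_hook proxy; the real tensor only exists at runtime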
self.bind_objects_to_proxies([out], [proxy])
return out
def tensor_pre_hook(self, inputs, hook_id, i: int):
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id] # type: ignore[index]
@ -706,6 +734,9 @@ class AutogradCompilerInstance:
after = len(self.fx_tracer.graph.nodes)
verbose_log.debug("DCE removed %d nodes", before - after)
def create_graph_module(self, id):
return GraphModule(self.fx_tracer.root, self.fx_tracer.graph, id)
def end_capture(self, outputs):
self.fx_tracer.create_proxy(
"call_function",
@ -745,6 +776,7 @@ class AutogradCompilerInstance:
).print_readable(print_output=False),
)
self.rename_aot_dispatcher_nodes()
self.delay_unpack_hook_nodes()
self.reorder_tensor_pre_hook_nodes()
self.reorder_pre_hook_nodes_to_schedule_asap()
self.reorder_accumulate_grad_nodes()
@ -763,9 +795,7 @@ class AutogradCompilerInstance:
# should prevent these ops from going into the CA graph.
self.dce()
graph = GraphModule(
self.fx_tracer.root, self.fx_tracer.graph, f"CompiledAutograd{self.id}"
)
graph = self.create_graph_module(f"CompiledAutograd{self.id}")
set_locals_to_steal(graph, ["inputs"])
lazy_graph_code = lazy_format_graph_code(
"Compiled autograd graph",
@ -781,17 +811,20 @@ class AutogradCompilerInstance:
payload_fn=lambda: graph.print_readable(print_output=False),
)
def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks):
global in_compiled_autograd_region
try:
in_compiled_autograd_region = True
for i in runtime_inputs_to_move:
inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)
def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks, packed_inputs):
# prior = torch._C._dynamo.eval_frame.set_eval_frame(None)
return compiled_fn(inputs, sizes, scalars, hooks, packed_inputs)
# torch._C._dynamo.eval_frame.set_eval_frame(prior)
# global in_compiled_autograd_region
# try:
# in_compiled_autograd_region = True
# for i in runtime_inputs_to_move:
# inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)
with _disable(), make_compile_context(self.id):
return compiled_fn(inputs, sizes, scalars, hooks)
finally:
in_compiled_autograd_region = False
# with _disable(), make_compile_context(self.id):
# return compiled_fn(inputs, sizes, scalars, hooks, packed_inputs)
# finally:
# in_compiled_autograd_region = False
get_chromium_event_logger().log_event_end(
"compiled_autograd",
@ -938,6 +971,19 @@ class AutogradCompilerInstance:
if getitem_node is not None:
arg.append(getitem_node)
def delay_unpack_hook_nodes(self):
"""
We can delay unpack hooks until they are needed, even later than in the eager autograd engine.
"""
for node in self.fx_tracer.graph.find_nodes(
op="call_function", target=call_hook
):
if node.kwargs.get("hook_type", None) != "unpack_hook":
continue
first_user = min(node.users)
first_user.prepend(node)
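# Illustrative before/after (node names invented): if the traced graph is
#     unpack = call_hook(hooks[0], packed_data[0], hook_type='unpack_hook')
#     ...unrelated nodes...
#     mul_backward0 = ops.MulBackward0([grad], [True, True], unpack, 6)
# then first_user is mul_backward0, and first_user.prepend(unpack) moves the
# call_hook to sit directly above it, so the saved tensor is only
# rematerialized (e.g. loaded back from CPU or disk) right before its
# consumer runs. min(node.users) is well-defined because fx.Node compares
# by position in the graph.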
def reorder_tensor_pre_hook_nodes(self):
"""
Usage of AOTAutograd causes all the tensor_pre_hook nodes to get pushed

View File

@ -80,8 +80,9 @@ def accumulate_grad(x, new_grad):
new_grad = torch.clone(new_grad)
if x.grad is None:
x.grad = new_grad
else:
x.grad.add_(new_grad)
# problem here. can't trace???
# else:
# x.grad.add_(new_grad)
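# (with the else branch commented out, a second backward into an already-populated
# .grad drops the new gradient instead of accumulating; fine for the single-step
# repro in hack.py, wrong in general)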
def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]):

View File

@ -466,6 +466,7 @@ class CompileEventLogger:
chromium_log = get_chromium_event_logger()
top_event = chromium_log.get_outermost_event()
if top_event is None:
return
raise RuntimeError(
"No toplevel event active. Please only call this function within a metrics context/dynamo_timed."
)

View File

@ -2426,6 +2426,7 @@ def _wrap_fx_proxy(
# This handles wrapping of the output of an op traced into the graph
def handle_traced_output(example_value, tx, proxy, options, subclass_type, target_cls):
# HERE
import torch._functorch.vmap
import torch._subclasses.fake_tensor
import torch._utils
@ -2480,6 +2481,69 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
return SizeVariable(sizes, **options)
elif isinstance(example_value, (tuple, list)):
set_example_value(proxy.node, example_value)
if (
len(example_value) == 2
and example_value[1]
and len(example_value[1]) == 5
and [type(val) for val in example_value[1]]
== [list, tuple, list, tuple, tuple]
and isinstance(example_value[0], torch.fx.graph_module.GraphModule)
and "locals_to_steal" in example_value[0].meta
):
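# structurally matched the (GraphModule, (inputs, sizes, scalars, hooks, packed_data))
# pair returned by the FunctionalCompiledAutograd op; wrap it into VariableTrackers by hand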
gm, (inputs, sizes, scalars, hooks, packed_data) = example_value
# gm_proxy = proxy.tracer.create_proxy(
# kind="call_function",
# target=operator.getitem,
# args=(proxy, 0),
# kwargs={},
# ) # mmmmm should we use the proxy vs just by value? it is const
# then need to guard here?
gm_var = SourcelessGraphModuleVariable(gm)
args_proxy = proxy.tracer.create_proxy(
kind="call_function",
target=operator.getitem,
args=(proxy, 1),
kwargs={},
)
inputs_proxy = proxy.tracer.create_proxy(
kind="call_function",
target=operator.getitem,
args=(args_proxy, 0),
kwargs={},
)
inputs_var = []
for i, inp in enumerate(inputs):
inputs_proxy_i = proxy.tracer.create_proxy(
kind="call_function",
target=operator.getitem,
args=(inputs_proxy, i),
kwargs={},
)
inputs_var.append(
wrap_fx_proxy(
tx=tx,
proxy=inputs_proxy_i,
example_value=inp,
subclass_type=None, # uhhh how to tell this?
source=None, # won't have any?
**options,
)
)
# need to actually pass args in here
args_var = TupleVariable(
[
ListVariable(inputs_var),
TupleVariable([]),
ListVariable([]),
TupleVariable([]),
TupleVariable([]),
]
)
out = TupleVariable([gm_var, args_var])
breakpoint()
return out
unpacked = []
for i, val in enumerate(example_value):
if val is None:
@ -2508,7 +2572,14 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
# use the same options object as parent
options_i = options
# if isinstance(val, torch.fx.graph_module.GraphModule):
# print("creating ")
# unpacked.append(SourcelessGraphModuleVariable(val))
# unpacked.append(SourcelessBuilder.create(self.tx, v for v in example_value))
# WARNING: this assumes the same target_cls as this tuple/list call
# I think we need to create a new Variable to properly model the output of the functional CA call
# it's wrong that we even end up here, target_cls is TensorVariable
unpacked.append(
wrap_fx_proxy_cls(
target_cls=target_cls,

View File

@ -1227,6 +1227,8 @@ class SkipFunctionVariable(VariableTracker):
]
# also warn on it because most users won't see the graph break message
torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
# not supposed to get here
breakpoint()
if self.value.__qualname__ == "allow_in_graph":
explanation = (
"Found an allow_in_graph decorator to a function which "

View File

@ -577,6 +577,9 @@ class TensorVariable(VariableTracker):
from .builder import SourcelessBuilder, VariableBuilder
from .torch_function import can_dispatch_torch_function, dispatch_torch_function
# if name == "backward":
# breakpoint()
if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops():
unimplemented(f"Illegal method invocation {name} in strict mode")

View File

@ -657,6 +657,8 @@ def _create_aot_dispatcher_function(
needs_autograd = any(
x.requires_grad for x in fake_flat_args if isinstance(x, Tensor)
)
# only for these fullgraphs
needs_autograd = False
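# WIP: force the inference path; the backward is expected to be handled by the
# FunctionalCompiledAutograd op instead of AOTAutograd's autograd wrapper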
with enable_python_dispatcher():
# Patch set_rng_state as set_rng_state with fake tensors is
@ -1193,6 +1195,7 @@ def aot_module_simplified(
# convention. This should get fixed...
# NB: GraphModule/nn.Module rely on the non-boxed calling convention here
def forward(*runtime_args: tuple[Any]):
print("at runtime")
full_args = []
full_args.extend(params_flat)
full_args.extend(runtime_args)

View File

@ -284,7 +284,7 @@ struct CppNode : public Node {
void set_ctx_grad_fn(const std::shared_ptr<Node>& node);
void save_variables_to_ctx();
void compiled_args(CompiledNodeArgs& args) override {
void compiled_args(CompiledNodeArgs& args) const override {
// although neither of the 2 methods below has uniqueness guarantees
// it is unlikely for them to collide at the same time
args.collect(static_cast<uint64_t>(typeid(T).hash_code()));

View File

@ -1330,10 +1330,11 @@ auto Engine::execute(
TORCH_CHECK(
!create_graph, "compiled_autograd does not support create_graph");
_thread_check.release();
TORCH_CHECK(
!AnomalyMode::is_enabled(),
"compiled_autograd does not support AnomalyMode")
// TORCH_CHECK(
// !AnomalyMode::is_enabled(),
// "compiled_autograd does not support AnomalyMode")
GraphTaskGuard guard(graph_task);
CheckpointValidGuard cpvguard(graph_task);
return (*compiled_autograd)(
graph_root, *graph_task, accumulate_grad, outputs);
}

View File

@ -138,7 +138,7 @@ struct TORCH_API Engine {
// see [Note: Compiled Autograd]
typedef variable_list (*compiled_autograd_fn)(
const std::shared_ptr<Node>& graph_root,
GraphTask& graph_task,
const GraphTask& graph_task,
bool accumulate_grad,
const edge_list& outputs);
static void set_compiled_autograd(compiled_autograd_fn fn);

View File

@ -545,8 +545,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
return tensor_pre_hooks_;
}
virtual std::unique_ptr<PostAccumulateGradHook>&
tensor_post_acc_grad_hooks() noexcept {
virtual std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks()
const noexcept {
static std::unique_ptr<PostAccumulateGradHook> empty = nullptr;
return empty;
}
@ -593,7 +593,7 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
// 2) Collect node information for specialization and caching
// Implementations in subclasses should call args.collect() with all node
// attrs. These functions are only called during backward.
virtual void compiled_args(CompiledNodeArgs& args) {
virtual void compiled_args(CompiledNodeArgs& args) const {
throw std::runtime_error(
std::string("compiled_args not implemented: ") + name());
}

View File

@ -22,7 +22,8 @@ struct TORCH_API FunctionPreHook {
virtual ~FunctionPreHook() = default;
virtual variable_list operator()(const variable_list& grads) = 0;
// only implemented for python hooks, registers hook with compiled autograd
virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
virtual void compiled_args(
torch::dynamo::autograd::CompiledNodeArgs& args) const {
throw std::runtime_error(
std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
typeid(*this).name());
@ -35,7 +36,8 @@ struct TORCH_API FunctionPostHook {
const variable_list& outputs /* grad_inputs */,
const variable_list& inputs /* grad_outputs */) = 0;
// only implemented for python hooks, registers hook with compiled autograd
virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
virtual void compiled_args(
torch::dynamo::autograd::CompiledNodeArgs& args) const {
throw std::runtime_error(
std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
typeid(*this).name());
@ -47,7 +49,8 @@ struct TORCH_API PostAccumulateGradHook {
virtual void operator()(const Variable& tensor) = 0;
// only implemented for python hooks on nodes, registers hook with compiled
// autograd
virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
virtual void compiled_args(
torch::dynamo::autograd::CompiledNodeArgs& args) const {
throw std::runtime_error(
std::string("not yet implemented for compiled autograd: ") +
typeid(*this).name());

View File

@ -66,12 +66,12 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
return variable_list();
}
void AccumulateGrad::compiled_args(CompiledNodeArgs& args) {
void AccumulateGrad::compiled_args(CompiledNodeArgs& args) const {
if (args.cond(variable.defined() && variable.requires_grad())) {
args.collect(variable);
args.collect(variable.grad());
}
auto& hook = tensor_post_acc_grad_hooks();
const auto& hook = tensor_post_acc_grad_hooks();
if (hook != nullptr) {
hook->compiled_args(args);
}

View File

@ -50,8 +50,8 @@ struct TORCH_API AccumulateGrad : public Node {
return impl::hooks(variable);
}
std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks() noexcept
override {
std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks()
const noexcept override {
// NB: Since the AccumulateGrad Node is only a weak ref from the Tensor,
// it can be destroyed even though the Tensor is still alive (contrary
// to all other Nodes). So we must lazily read the Tensor hooks here.
@ -262,7 +262,7 @@ struct TORCH_API AccumulateGrad : public Node {
}
}
void compiled_args(CompiledNodeArgs& args) override;
void compiled_args(CompiledNodeArgs& args) const override;
variable_list apply_with_saved(
const variable_list& inputs,
SwapSavedVariables& saved) override;

View File

@ -12,11 +12,15 @@
namespace torch::autograd {
auto Error::apply(variable_list&& inputs) -> variable_list {
variable_list Error::apply(variable_list&& inputs) {
return static_cast<const Error*>(this)->apply(std::move(inputs));
}
variable_list Error::apply(variable_list&& inputs) const {
throw std::runtime_error(msg);
}
void Error::compiled_args(CompiledNodeArgs& args) {
void Error::compiled_args(CompiledNodeArgs& args) const {
// throw the error during collect; the graph won't get compiled
apply(variable_list());
}
@ -66,7 +70,7 @@ auto Identity::apply(variable_list&& grads) -> variable_list {
return std::move(grads);
}
void GraphRoot::compiled_args(CompiledNodeArgs& args) {
void GraphRoot::compiled_args(CompiledNodeArgs& args) const {
args.collect(outputs);
}
variable_list GraphRoot::apply_with_saved(

View File

@ -18,8 +18,9 @@ struct TORCH_API Error : public Node {
Error(std::string msg) : msg(std::move(msg)) {}
variable_list apply(variable_list&& inputs) override;
variable_list apply(variable_list&& inputs) const;
void compiled_args(CompiledNodeArgs& args) override;
void compiled_args(CompiledNodeArgs& args) const override;
variable_list apply_with_saved(
const variable_list& inputs,
SwapSavedVariables& saved) override;
@ -51,6 +52,7 @@ struct TORCH_API DelayedError : public Node {
}
variable_list apply(variable_list&& inputs) override;
variable_list apply(variable_list&& inputs) const;
std::string msg;
};
@ -61,6 +63,7 @@ struct TORCH_API UndefinedGrad : public Node {
}
variable_list apply(variable_list&& inputs) override;
variable_list apply(variable_list&& inputs) const;
};
struct TORCH_API UndefinedGradBackward : public Node {
@ -69,8 +72,9 @@ struct TORCH_API UndefinedGradBackward : public Node {
UndefinedGradBackward() = default;
variable_list apply(variable_list&& inputs) override;
variable_list apply(variable_list&& inputs) const;
void compiled_args(CompiledNodeArgs& args) override {}
void compiled_args(CompiledNodeArgs& args) const override {}
variable_list apply_with_saved(
const variable_list& inputs,
SwapSavedVariables& saved) override {
@ -93,7 +97,7 @@ struct TORCH_API GraphRoot : public Node {
return outputs;
}
void compiled_args(CompiledNodeArgs& args) override;
void compiled_args(CompiledNodeArgs& args) const override;
variable_list apply_with_saved(
const variable_list& inputs,
SwapSavedVariables& saved) override;

View File

@ -60,7 +60,7 @@ auto CopyBackwards::apply(variable_list&& grads) -> variable_list {
src_options);
}
void CopyBackwards::compiled_args(CompiledNodeArgs& args) {
void CopyBackwards::compiled_args(CompiledNodeArgs& args) const {
args.collect(src_options);
}
@ -235,7 +235,7 @@ void CopySlices::release_variables() {
fn = nullptr;
}
void CopySlices::compiled_args(CompiledNodeArgs& args) {
void CopySlices::compiled_args(CompiledNodeArgs& args) const {
TORCH_CHECK(!view_fn, "view_fn not supported by compiled autograd")
TORCH_INTERNAL_ASSERT((bool)fn);
args.collect(base);

View File

@ -15,7 +15,7 @@ namespace torch::autograd {
struct TORCH_API CopyBackwards : public Node {
variable_list apply(variable_list&& grads) override;
void compiled_args(CompiledNodeArgs& args) override;
void compiled_args(CompiledNodeArgs& args) const override;
variable_list apply_with_saved(
const variable_list& inputs,
SwapSavedVariables& saved) override;
@ -168,7 +168,7 @@ struct TORCH_API CopySlices : public Node {
variable_list apply(variable_list&& inputs) override;
void release_variables() override;
void compiled_args(CompiledNodeArgs& args) override;
void compiled_args(CompiledNodeArgs& args) const override;
variable_list apply_with_saved(
const variable_list& inputs,
SwapSavedVariables& saved) override;

View File

@ -34,6 +34,7 @@
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_dtypes.h>
#include <autograd/function.h>
#include <functional>
#include <memory>
#include <stdexcept>
@ -185,9 +186,9 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
return to_variable_list(r.get(), is_variable_input);
}
auto PyNode::defer_to_dynamo(
auto PyNode::apply_with_saved_impl(
const variable_list& inputs,
const std::optional<PyObject*>& compiler) -> variable_list {
const SwapSavedVariables& saved) -> variable_list {
pybind11::gil_scoped_acquire gil;
at::OptionalDeviceGuard _device_guard;
THPFunction* py_fn = (THPFunction*)obj;
@ -235,24 +236,24 @@ auto PyNode::defer_to_dynamo(
}
THPObjectPtr saved_tensors(unpack_saved_variables(
py_fn, [](const Variable& var) { return THPVariable_Wrap(var); }));
TORCH_INTERNAL_ASSERT(
_backward_idx.has_value(),
"indices should already be set by compiled_args, called before apply_with_saved");
auto [bwd_idx, maybe_bwd_state_idx] = saved.retrieve_pynode_objs(this);
PyObject* backward_state_idx = Py_None;
if (_backward_state_idx.has_value()) {
backward_state_idx = THPUtils_packInt64(_backward_state_idx.value());
if (maybe_bwd_state_idx.has_value()) {
backward_state_idx = THPUtils_packUInt64(maybe_bwd_state_idx.value());
// this might be simplifiable now that we no longer inline
Py_CLEAR(py_fn->compiled_autograd_backward_state);
}
THPObjectPtr r(PyObject_CallMethod(
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
compiler.value(),
saved.get_py_compiler(),
"proxy_call_backward",
"OOOiOO",
pyInputs.get(),
fwdInputMetadatas.get(),
saved_tensors.get(),
*_backward_idx,
bwd_idx,
obj,
backward_state_idx));
@ -301,7 +302,7 @@ bool PyNode::is_aot_backward() const {
return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id");
}
void PyNode::compiled_args(CompiledNodeArgs& args) {
void PyNode::compiled_args(CompiledNodeArgs& args) const {
static PyObject* method_name =
PyUnicode_InternFromString("_compiled_autograd_key");
THPObjectPtr pykey(PyObject_CallMethodObjArgs(obj, method_name, nullptr));
@ -346,14 +347,15 @@ void PyNode::compiled_args(CompiledNodeArgs& args) {
args.collect(f->input_info);
Py_INCREF(obj);
_backward_idx = args.add_backward(c10::SafePyObject(obj, getPyInterpreter()));
c10::SafePyObject backward_obj(obj, getPyInterpreter());
std::optional<c10::SafePyObject> backward_state_obj;
PyObject* bw_state = f->compiled_autograd_backward_state;
if (args.cond(bw_state != nullptr)) {
Py_INCREF(bw_state);
_backward_state_idx = args.add_backward_state(
c10::SafePyObject(bw_state, getPyInterpreter()));
backward_state_obj = c10::SafePyObject(bw_state, getPyInterpreter());
}
args.collect_pynode_objs(
this, std::move(backward_obj), std::move(backward_state_obj));
}
variable_list PyNode::apply_with_saved(
@ -368,8 +370,7 @@ variable_list PyNode::apply_with_saved(
saved.before(f->output_info);
saved.before(f->input_info);
f->compiled_autograd_tracing = true;
variable_list result =
defer_to_dynamo(variable_list(inputs), saved.get_py_compiler());
variable_list result = apply_with_saved_impl(variable_list(inputs), saved);
f->compiled_autograd_tracing = false;
saved.after(f->compiled_autograd_symints);
saved.after(f->saved_variables);

View File

@ -35,9 +35,9 @@ struct PyNode : public Node {
const std::vector<bool>& is_variable_input);
variable_list apply(variable_list&& inputs) override;
variable_list defer_to_dynamo(
variable_list apply_with_saved_impl(
const variable_list& inputs,
const std::optional<PyObject*>& compiler);
const SwapSavedVariables& saved);
void release_variables() override;
std::string name() const override;
@ -45,7 +45,7 @@ struct PyNode : public Node {
bool is_aot_backward() const override;
void compiled_args(CompiledNodeArgs& args) override;
void compiled_args(CompiledNodeArgs& args) const override;
variable_list apply_with_saved(
const variable_list& inputs,
SwapSavedVariables& saved) override;
@ -53,13 +53,6 @@ struct PyNode : public Node {
// THPFunction this Function is wrapping. Owning!
PyObject* obj;
// The AutogradCompilerCall::hooks idx corresponding to this node's backward
std::optional<int> _backward_idx;
// The AutogradCompilerCall::hooks idx corresponding to this node's
// backward_state
std::optional<int> _backward_state_idx;
// NOLINTNEXTLINE(bugprone-exception-escape)
~PyNode() override {
// Can't use THPObjectPtr as a field in this class; destructor won't take

View File

@ -176,7 +176,7 @@ auto PyFunctionPostHook::operator()(
return unwrap_variables(PyTuple_GetItem(tup.get(), 0));
}
void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) {
void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) const {
PyObject *key = nullptr, *value = nullptr;
Py_ssize_t pos = 0;
Py_BEGIN_CRITICAL_SECTION(dict);
@ -189,7 +189,7 @@ void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) {
Py_END_CRITICAL_SECTION();
}
void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) {
void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) const {
PyObject *key = nullptr, *value = nullptr;
Py_ssize_t pos = 0;
Py_BEGIN_CRITICAL_SECTION(dict);
@ -200,7 +200,7 @@ void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) {
Py_END_CRITICAL_SECTION();
}
void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) {
void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) const {
PyObject *key = nullptr, *value = nullptr;
Py_ssize_t pos = 0;
Py_BEGIN_CRITICAL_SECTION(dict);
@ -237,7 +237,7 @@ auto PyFunctionTensorPostAccGradHooks::operator()(const Variable& tensor)
}
void PyFunctionTensorPostAccGradHooks::compiled_args(
torch::dynamo::autograd::CompiledNodeArgs& args) {
torch::dynamo::autograd::CompiledNodeArgs& args) const {
PyObject *key = nullptr, *value = nullptr;
Py_ssize_t pos = 0;
Py_BEGIN_CRITICAL_SECTION(dict);

View File

@ -14,7 +14,8 @@ struct PyFunctionTensorPreHook : public FunctionPreHook {
PyFunctionTensorPreHook(PyObject* dict, size_t value_idx);
~PyFunctionTensorPreHook() override;
variable_list operator()(const variable_list& values) override;
void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
void compiled_args(
torch::dynamo::autograd::CompiledNodeArgs& args) const override;
PyObject* dict;
size_t value_idx;
};
@ -23,7 +24,8 @@ struct PyFunctionPreHook : public FunctionPreHook {
PyFunctionPreHook(PyObject* dict);
~PyFunctionPreHook() override;
variable_list operator()(const variable_list& values) override;
void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
void compiled_args(
torch::dynamo::autograd::CompiledNodeArgs& args) const override;
PyObject* dict;
};
@ -33,7 +35,8 @@ struct PyFunctionPostHook : public FunctionPostHook {
variable_list operator()(
const variable_list& outputs,
const variable_list& inputs) override;
void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
void compiled_args(
torch::dynamo::autograd::CompiledNodeArgs& args) const override;
PyObject* dict;
};
@ -45,7 +48,8 @@ struct PyFunctionTensorPostAccGradHooks : public PostAccumulateGradHook {
PyFunctionTensorPostAccGradHooks(PyObject* dict);
~PyFunctionTensorPostAccGradHooks() override;
void operator()(const Variable& tensor) override;
void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
void compiled_args(
torch::dynamo::autograd::CompiledNodeArgs& args) const override;
void apply_with_saved(
Variable& tensor,
torch::dynamo::autograd::SwapSavedVariables& saved) override;

View File

@ -46,6 +46,15 @@ at::Tensor PySavedVariableHooks::call_unpack_hook() {
// unpack_hook_ will be manually decrefed when the saved variable is released
}
std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
PySavedVariableHooks::retrieve_unpack_hook_data() const {
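  // SafePyObject takes ownership of a reference, so INCREF first; this hooks
  // object keeps its own references to unpack_hook_ and data_ alive independently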
Py_INCREF(unpack_hook_);
Py_INCREF(data_);
return std::make_pair(
c10::SafePyObject(unpack_hook_, getPyInterpreter()),
c10::SafePyObject(data_, getPyInterpreter()));
}
// NOLINTNEXTLINE(bugprone-exception-escape)
PySavedVariableHooks::~PySavedVariableHooks() {
// If python is already dead, leak the wrapped python objects

View File

@ -1,6 +1,7 @@
#pragma once
#include <ATen/ATen.h>
#include <c10/core/SafePyObject.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/python_variable.h>
@ -17,6 +18,8 @@ struct PySavedVariableHooks : public SavedVariableHooks {
void call_pack_hook(const at::Tensor& tensor) override;
at::Tensor call_unpack_hook() override;
~PySavedVariableHooks() override;
std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
retrieve_unpack_hook_data() const override;
private:
PyObject* pack_hook_;

View File

@ -59,6 +59,7 @@ SavedVariable::SavedVariable(
if (maybe_hooks && !variable.unsafeGetTensorImpl()->is_wrapped_number()) {
save_metadata(variable);
set_hooks_and_pack_data(std::move(maybe_hooks), variable);
TORCH_INTERNAL_ASSERT(!data_.defined());
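    // with pack hooks installed, the tensor now lives in the pack hook's
    // payload, so the SavedVariable must not also hold it in data_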
return;
}
@ -134,9 +135,14 @@ Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
// We want grad_fn here to provide the most helpful debug message to the user
// if versions don't match
auto grad_fn = is_inplace_on_view_ ? weak_grad_fn_.lock()
: !hooks_ ? saved_original_ ? data_.grad_fn() : nullptr
: grad_fn_;
std::shared_ptr<Node> grad_fn;
if (is_inplace_on_view_) {
grad_fn = weak_grad_fn_.lock();
} else if (!hooks_) {
grad_fn = saved_original_ ? data_.grad_fn() : nullptr;
} else {
grad_fn = grad_fn_;
}
if (!is_leaf_ && !grad_fn) {
// This issue was introduced when we added logic to save the original

View File

@ -1,5 +1,6 @@
#pragma once
#include <c10/core/SafePyObject.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/forward_grad.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>
@ -53,6 +54,15 @@ class TORCH_API SavedVariable {
return (bool)hooks_;
}
// Used by compiled autograd
std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
retrieve_unpack_hook_data() const {
if (!hooks_) {
return std::nullopt;
}
return hooks_->retrieve_unpack_hook_data();
}
private:
// This field contains either:
// 1. the variable to save

View File

@ -1,6 +1,7 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <c10/core/SafePyObject.h>
namespace torch::autograd {
@ -8,6 +9,11 @@ struct TORCH_API SavedVariableHooks {
virtual void call_pack_hook(const at::Tensor& tensor) = 0;
virtual at::Tensor call_unpack_hook() = 0;
virtual ~SavedVariableHooks() = default;
virtual std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
retrieve_unpack_hook_data() const {
throw std::runtime_error(
"Compiled Autograd only supports python saved tensor hooks ");
}
};
} // namespace torch::autograd

View File

@ -27,7 +27,7 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook {
return fn_(outputs, inputs);
}
void compiled_args(CompiledNodeArgs& args) override {}
void compiled_args(CompiledNodeArgs& args) const override {}
protected:
std::function<variable_list(const variable_list&, const variable_list&)> fn_;

View File

@ -183,21 +183,23 @@ Reducer::Reducer(
#endif
// Hook to execute after the gradient accumulator has executed.
hooks_.emplace_back(
grad_accumulator->add_post_hook(
std::make_unique<torch::autograd::utils::LambdaPostHook>(
[this, variable_index](
const torch::autograd::variable_list& outputs,
const torch::autograd::variable_list& /* unused */) {
grad_accumulator->add_post_hook(std::make_unique<
torch::autograd::utils::
LambdaPostHook>(
[this, variable_index](
const torch::autograd::variable_list& outputs,
const torch::autograd::variable_list& /* unused */) {
#ifndef _WIN32
this->rpc_context_.set(
ThreadLocalDistAutogradContext::getContextPtr());
this->rpc_context_.set(
ThreadLocalDistAutogradContext::getContextPtr());
#endif
this->autograd_hook(variable_index);
return outputs;
},
[=](torch::autograd::CompiledNodeArgs& args) {
// Make post_hook an noop if compiled_autograds is enabled.
})),
this->autograd_hook(variable_index);
return outputs;
},
[=](torch::autograd::CompiledNodeArgs& args) {
TORCH_INTERNAL_ASSERT(
false,
"Compiled autograd is not compatible with C++ DDP Reducer, please use torch._dynamo.config.optimize_ddp=\"python_reducer\".");
})),
grad_accumulator);
// Map raw function pointer to parameter index.

View File

@ -3,20 +3,22 @@
namespace torch::dynamo::autograd {
std::unique_ptr<PyCompilerInterface> kPyCompilerInterface;
std::unique_ptr<PyCompilerInterface> kActivePyCompilerInterface;
const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface() {
TORCH_INTERNAL_ASSERT(kPyCompilerInterface != nullptr);
return kPyCompilerInterface;
TORCH_INTERNAL_ASSERT(kActivePyCompilerInterface != nullptr);
return kActivePyCompilerInterface;
}
void setPyCompilerInterface(std::unique_ptr<PyCompilerInterface>&& impl) {
TORCH_INTERNAL_ASSERT(impl != nullptr);
kPyCompilerInterface = std::move(impl);
PyCompilerGuard::PyCompilerGuard(std::unique_ptr<PyCompilerInterface>&& impl) {
TORCH_INTERNAL_ASSERT(
kActivePyCompilerInterface == nullptr && impl != nullptr);
kActivePyCompilerInterface = std::move(impl);
}
void resetPyCompilerInterface() {
kPyCompilerInterface.reset();
PyCompilerGuard::~PyCompilerGuard() {
TORCH_INTERNAL_ASSERT(kActivePyCompilerInterface != nullptr);
kActivePyCompilerInterface.reset();
}
std::vector<std::optional<InputMetadata>> get_input_metadata(

View File

@ -17,6 +17,84 @@
namespace torch::dynamo::autograd {
using namespace torch::autograd;
// This is a layer of indirection for calling methods on the Python
// AutogradCompilerInstance (referred to as the "py_compiler") from
// libtorch_cpu (where Python is not available).
// A PyCompilerInterfaceImpl in libtorch_python subclasses it and
// overrides the methods to do the actual calls back to Python.
struct TORCH_API PyCompilerInterface {
PyCompilerInterface() = default;
PyCompilerInterface(const PyCompilerInterface&) = delete;
PyCompilerInterface& operator=(const PyCompilerInterface&) = delete;
PyCompilerInterface(PyCompilerInterface&&) = delete;
PyCompilerInterface& operator=(PyCompilerInterface&&) = delete;
virtual ~PyCompilerInterface() = default;
// Invokes py_compiler.bind_function
virtual std::string bind_function(
PyObject* py_compiler,
const std::string& fn_name,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
functional_apply_t fn,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::vector<at::TypePtr> packed_args_schema,
bool is_custom_function = false,
bool is_traceable = true) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
// Invokes py_compiler.method_name(fn_name, inputs, packed_args,
// output_metadata)
virtual variable_list call_function(
PyObject* py_compiler,
const char* method_name,
const std::string& fn_name,
const variable_list& inputs,
const ivalue_list& packed_args,
const c10::IValue& output_metadata) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
virtual variable_list call_copy_slices_prologue(
PyObject* py_compiler,
const variable_list& inputs,
const at::TensorGeometry& base,
const at::TensorGeometry& view) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
virtual variable_list call_copy_slices_epilogue(
PyObject* py_compiler,
const std::vector<bool>& needs_input_grad,
const at::Tensor& result,
const variable_list& res,
const at::Tensor& grad_slice) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
virtual at::Tensor call_unpack(
PyObject* py_compiler,
std::optional<size_t> hook_id,
size_t hook_input_id) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
};
TORCH_API const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface();
struct TORCH_API PyCompilerGuard {
explicit PyCompilerGuard(std::unique_ptr<PyCompilerInterface>&& impl);
PyCompilerGuard(const PyCompilerGuard&) = delete;
PyCompilerGuard& operator=(const PyCompilerGuard&) = delete;
PyCompilerGuard(PyCompilerGuard&&) = delete;
PyCompilerGuard& operator=(PyCompilerGuard&&) = delete;
~PyCompilerGuard();
};
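A sketch of the intended RAII usage (NoopPyCompiler is illustrative; the real implementation is the PyCompilerInterfaceImpl installed in compiled_autograd.cpp further down):

// Every method keeps its asserting default; enough to show guard scoping.
struct NoopPyCompiler : PyCompilerInterface {};

void trace_once() {
  PyCompilerGuard guard(std::make_unique<NoopPyCompiler>());
  // While the guard is alive, getPyCompilerInterface() returns this
  // implementation, including for calls originating in libtorch_cpu.
}  // ~PyCompilerGuard resets the active interface, even if tracing throws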
// Including torch/csrc/autograd/engine.h breaks BC by somehow introducing
// symbol resolution issues. Instead of requiring downstream users to include
// engine.h to access collect_input_metadata, we provide it here (with a
// different name to avoid ambiguous symbols).
TORCH_API std::vector<std::optional<InputMetadata>> get_input_metadata(
const edge_list& edges);
struct SizeInput {
// Note: int value is still needed when dynamic to pass as an arg
enum DynType : uint8_t { STATIC = 0, DYNAMIC = 1 };
@ -154,9 +232,14 @@ struct TensorArgs {
}
TensorArg& lookup(const SavedVariable& sv) {
auto it = _saved_variables.find(&sv);
TORCH_INTERNAL_ASSERT(it != _saved_variables.end());
return *it->second;
if (auto it = _saved_variables.find(&sv); it != _saved_variables.end()) {
// unpacked before graph
return *it->second;
}
// unpacked in graph
auto it2 = _saved_variables_proxies.find(&sv);
TORCH_INTERNAL_ASSERT(it2 != _saved_variables_proxies.end());
return *it2->second;
}
TensorArg& add(const at::Tensor& tensor) {
@ -164,9 +247,7 @@ struct TensorArgs {
}
TensorArg& add(const SavedVariable& sv, const std::shared_ptr<Node>& node) {
// TODO(jansel): Here we unpack the SavedVariable exactly once. This might
// fire SavedTensor hooks. In the future we should try to put saved tensor
// hooks into the graph.
// no unpack hooks in this codepath
at::Tensor tensor = sv.unpack(node);
TensorArg& arg = add(tensor);
_saved_variables.emplace(&sv, &arg);
@ -185,6 +266,7 @@ struct TensorArgs {
// Every TensorArg from this is actually owned by _args (or _undefined) and
// that's why we have an un-owned pointer here.
std::unordered_map<const SavedVariable*, TensorArg*> _saved_variables;
std::unordered_map<const SavedVariable*, TensorArg*> _saved_variables_proxies;
TensorArg _undefined;
uint32_t _next_id = 1; // id=0 used by _undefined
};
@ -245,6 +327,11 @@ struct AutogradCompilerCall {
return hooks.size() - 1;
}
size_t emplace_packed_input(c10::SafePyObject&& input) {
packed_inputs.emplace_back(std::move(input));
return packed_inputs.size() - 1;
}
void set_active_node_call_idx(size_t node_call_idx) {
active_node_call_idx = node_call_idx;
}
@ -255,10 +342,16 @@ struct AutogradCompilerCall {
LiftedIValueArgs lifted_ivalue_args;
std::vector<int64_t> dyn_size_inputs;
std::vector<c10::SafePyObject> hooks;
std::vector<c10::SafePyObject> packed_inputs;
NodeCalls node_calls;
SizeInput::DynType default_dyn_type;
// NodeCall id of each size, only when verbose logging is enabled
std::vector<uint32_t> size_input_origins;
std::unordered_map<const SavedVariable*, std::pair<size_t, size_t>>
sv_to_hooks;
// pynode -> backward and backward state idx
std::unordered_map<const Node*, std::pair<size_t, std::optional<size_t>>>
pynode_objs;
};
class CompiledNodeArgs {
@ -285,8 +378,19 @@ class CompiledNodeArgs {
collect(_compiler.tensor_args.add(t));
}
void collect(const SavedVariable& sv, bool is_output) {
collect(
_compiler.tensor_args.add(sv, is_output ? _node_call.node : nullptr));
if (auto hook_data = sv.retrieve_unpack_hook_data();
hook_data.has_value()) {
// hooks, unpack in graph
auto& [hook, packed_input] = hook_data.value();
size_t hook_id = _compiler.emplace_hook(std::move(hook));
// rely on dynamo to dedup packed tensors from unpacked tensors
size_t input_id = _compiler.emplace_packed_input(std::move(packed_input));
_compiler.sv_to_hooks.emplace(&sv, std::make_pair(hook_id, input_id));
} else {
// no hooks, unpack now
collect(
_compiler.tensor_args.add(sv, is_output ? _node_call.node : nullptr));
}
}
void collect(const c10::SymInt& t) {
_compiler.add_size_input(t);
@ -518,12 +622,17 @@ class CompiledNodeArgs {
typeid(*node), _specialization_key, _specialization_key_size);
}
size_t add_backward(c10::SafePyObject&& obj) {
return _compiler.emplace_hook(std::move(obj));
}
size_t add_backward_state(c10::SafePyObject&& obj) {
return _compiler.emplace_hook(std::move(obj));
void collect_pynode_objs(
const Node* pynode,
c10::SafePyObject&& bwd,
std::optional<c10::SafePyObject>&& bwd_state) {
size_t bwd_idx = _compiler.emplace_hook(std::move(bwd));
std::optional<size_t> bwd_state_idx;
if (auto state = std::move(bwd_state); state.has_value()) {
bwd_state_idx = _compiler.emplace_hook(std::move(state.value()));
}
_compiler.pynode_objs.emplace(
pynode, std::make_pair(bwd_idx, bwd_state_idx));
}
void add_tensor_pre_hook(c10::SafePyObject&& obj, int index) {
@ -642,6 +751,13 @@ class SwapSavedVariables {
// cache-miss. It swaps any 'lifted' inputs (tensors, symints) to proxy nodes,
// allows tracing to happen, then swaps them back afterwards.
public:
std::pair<size_t, std::optional<size_t>> retrieve_pynode_objs(
Node* pynode) const {
auto it = compiler.pynode_objs.find(pynode);
TORCH_INTERNAL_ASSERT(it != compiler.pynode_objs.end());
return it->second;
}
void before(at::Tensor& t) {
TensorArg& arg = compiler.tensor_args.lookup(t);
stashed_tensors.save(&t, std::move(t));
@ -655,13 +771,26 @@ class SwapSavedVariables {
}
void before(SavedVariable& t) {
TensorArg& arg = compiler.tensor_args.lookup(t);
stashed_variables.save(&t, std::move(t));
if (arg.defined()) {
if (auto it = compiler.sv_to_hooks.find(&t);
it != compiler.sv_to_hooks.end()) {
const auto& pyinterface =
torch::dynamo::autograd::getPyCompilerInterface();
auto proxy_tensor = pyinterface->call_unpack(
get_py_compiler(), it->second.first, it->second.second);
stashed_variables.save(&t, std::move(t));
bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
t = SavedVariable(arg.proxy_tensor, false);
t = SavedVariable(proxy_tensor, false);
at::SavedTensorDefaultHooks::set_tracing(prior);
} else {
// no hooks, was already unpacked
TensorArg& arg = compiler.tensor_args.lookup(t);
stashed_variables.save(&t, std::move(t));
if (arg.defined()) {
bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
t = SavedVariable(arg.proxy_tensor, false);
at::SavedTensorDefaultHooks::set_tracing(prior);
}
}
}
void after(SavedVariable& t) {
@ -834,7 +963,7 @@ class SwapSavedVariables {
const NodeCall& n)
: compiler(c), state(s), py_compiler(p), curr_node_call(n) {}
PyObject* get_py_compiler() {
PyObject* get_py_compiler() const {
return py_compiler;
}
@ -1370,73 +1499,6 @@ struct PackedArgs {
int64_t idx = 0;
};
// This is a layer of indirection for calling methods on the Python
// AutogradCompilerInstance (referred to as the "py_compiler") from
// libtorch_cpu (where Python is not available).
// A PyCompilerInterfaceImpl in libtorch_python subclasses it and
// overrides the methods to do the actual calls back to Python.
struct TORCH_API PyCompilerInterface {
PyCompilerInterface() = default;
PyCompilerInterface(const PyCompilerInterface&) = delete;
PyCompilerInterface& operator=(const PyCompilerInterface&) = delete;
PyCompilerInterface(PyCompilerInterface&&) = delete;
PyCompilerInterface& operator=(PyCompilerInterface&&) = delete;
virtual ~PyCompilerInterface() = default;
// Invokes py_compiler.bind_function
virtual std::string bind_function(
PyObject* py_compiler,
const std::string& fn_name,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
functional_apply_t fn,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::vector<at::TypePtr> packed_args_schema,
bool is_custom_function = false,
bool is_traceable = true) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
// Invokes py_compiler.method_name(fn_name, inputs, packed_args,
// output_metadata)
virtual variable_list call_function(
PyObject* py_compiler,
const char* method_name,
const std::string& fn_name,
const variable_list& inputs,
const ivalue_list& packed_args,
const c10::IValue& output_metadata) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
virtual variable_list call_copy_slices_prologue(
PyObject* py_compiler,
const variable_list& inputs,
const at::TensorGeometry& base,
const at::TensorGeometry& view) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
virtual variable_list call_copy_slices_epilogue(
PyObject* py_compiler,
const std::vector<bool>& needs_input_grad,
const at::Tensor& result,
const variable_list& res,
const at::Tensor& grad_slice) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
};
TORCH_API const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface();
TORCH_API void setPyCompilerInterface(
std::unique_ptr<PyCompilerInterface>&& impl);
TORCH_API void resetPyCompilerInterface();
// including torch/csrc/autograd/engine.h breaks BC by somehow introducing
// symbol resolution issues. Instead requiring downstream users to include
// engine.h to access collect_input_metadata, we provide it here (with a
// different name to avoid ambigous symbols...)
TORCH_API std::vector<std::optional<InputMetadata>> get_input_metadata(
const edge_list& edges);
} // namespace torch::dynamo::autograd
template <>

View File

@ -11,6 +11,7 @@
#include <sstream>
#include <string>
#include <vector>
#include <c10/core/impl/TorchDispatchModeTLS.h>
/*
[Note: Compiled Autograd]
@ -52,6 +53,12 @@ Notes:
namespace torch::dynamo::autograd {
using c10::SymInt;
namespace {
PyObject* the_autograd_compiler = nullptr;
int default_dyn_type_int = 0;
PyObject* python_verbose_logger = nullptr;
} // namespace
// List[Optional[Tensor]] in Python can't be directly parsed into a
// List[Tensor], so we need to do this conversion manually.
static std::vector<at::Tensor> toTensorList(
@ -203,6 +210,16 @@ struct PyCompilerInterfaceImpl : PyCompilerInterface {
auto output = py::cast<std::vector<std::optional<at::Tensor>>>(stuff);
return toTensorList(output);
}
at::Tensor call_unpack(
PyObject* py_compiler,
std::optional<size_t> hook_id,
size_t hook_input_id) override {
py::handle handle(py_compiler);
py::object proxy = handle.attr("unpack_hook")(hook_id, hook_input_id);
auto tmp = py::cast<std::optional<at::Tensor>>(proxy);
TORCH_INTERNAL_ASSERT(tmp.has_value());
return tmp.value();
}
};
static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) {
@ -213,7 +230,7 @@ static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) {
return pyinput;
}
static PyObject* convert_hook_list(std::vector<c10::SafePyObject>& inputs) {
static PyObject* convert_pyobj_list(std::vector<c10::SafePyObject>& inputs) {
// in-place: consumes the input objects
PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(inputs.size()));
for (const auto i : c10::irange(inputs.size())) {
@ -257,9 +274,6 @@ static variable_list validate_outputs(
return new_outputs;
}
// snapshot of python verbose logging toggle
static PyObject* python_verbose_logger = nullptr;
struct PythonLogger {
PythonLogger() = delete;
explicit PythonLogger(PyObject* logger) : logger_(logger) {
@ -540,8 +554,6 @@ struct InputBuffers : public std::unordered_map<Node*, InputBuffer> {
}
};
static PyObject* the_autograd_compiler = nullptr;
static int default_dyn_type_int = 0;
static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args);
static PyObject* clear_cache(PyObject* dummy, PyObject* args) {
@ -654,7 +666,7 @@ static PyObject* wrap_string_list(const std::vector<std::string>& strs) {
return pystrs;
}
std::string unwrap_string(PyObject* pystr) {
static std::string unwrap_string(PyObject* pystr) {
TORCH_INTERNAL_ASSERT(PyUnicode_Check(pystr));
const char* str = PyUnicode_AsUTF8(pystr);
TORCH_INTERNAL_ASSERT(str != nullptr);
@ -790,14 +802,18 @@ static SizeInput::DynType get_default_dyn_type() {
// Only call this function while holding GIL
static CacheNode* _compiled_autograd_impl(
const std::shared_ptr<Node>& graph_root,
GraphTask& graph_task,
const GraphTask& graph_task,
bool accumulate_grad,
const edge_list& output_edges,
THPObjectPtr* graph_arg_inputs,
THPObjectPtr* graph_arg_sizes,
THPObjectPtr* graph_arg_ivalue_args,
THPObjectPtr* graph_arg_hooks) {
std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
THPObjectPtr* graph_arg_hooks,
THPObjectPtr* graph_arg_packed_inputs) {
const std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
std::unordered_map<Node*, int> visited_dependencies;
visited_dependencies.reserve(dependencies.size());
std::vector<std::shared_ptr<Node>> worklist{graph_root};
AutogradCompilerCall compiler_call(get_default_dyn_type());
@ -809,8 +825,8 @@ static CacheNode* _compiled_autograd_impl(
}
const bool check_exec_info = !graph_task.exec_info_.empty();
CacheNode* cache = CacheNode::root();
std::vector<NodeCall*> calls;
calls.reserve(
std::vector<NodeCall*> ordered_calls;
ordered_calls.reserve(
check_exec_info ? graph_task.exec_info_.size() : dependencies.size() + 1);
int i = 0;
@ -820,7 +836,7 @@ static CacheNode* _compiled_autograd_impl(
std::shared_ptr<Node> fn = std::move(worklist.back());
worklist.pop_back();
NodeCall& call = compiler_call.node_calls.lookup(fn);
calls.emplace_back(&call);
ordered_calls.emplace_back(&call);
{ // update cache and gather args into `compiler_call`
CompiledNodeArgs node_args(compiler_call, call);
@ -829,6 +845,7 @@ static CacheNode* _compiled_autograd_impl(
}
node_args.collect(call);
if (node_args.cond(call.needed)) {
std::cout << "collecting " << fn->name() << std::endl;
fn->compiled_args(node_args);
node_args.collect(call.node->next_edges());
}
@ -861,9 +878,9 @@ static CacheNode* _compiled_autograd_impl(
}
}
auto it = dependencies.find(edge.function.get());
TORCH_INTERNAL_ASSERT(it != dependencies.end());
if (--it->second == 0) {
dependencies.erase(it);
int count = ++visited_dependencies[it->first];
TORCH_INTERNAL_ASSERT(count <= it->second);
if (count == it->second) {
worklist.emplace_back(edge.function);
}
}
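Because graph_task is now const, the traversal counts visits in a side table instead of decrementing dependencies_. A simplified sketch of the pattern (the real loop above also consults exec_info_ and asserts count <= it->second):

// A node becomes ready once its visit count reaches its recorded dependency
// count; `dependencies` itself is never mutated, so the same GraphTask can
// still drive an eager fallback afterwards.
std::unordered_map<Node*, int> visited;
visited.reserve(dependencies.size());
for (const auto& edge : fn->next_edges()) {
  auto it = dependencies.find(edge.function.get());
  if (it == dependencies.end()) continue; // not part of this backward
  if (++visited[it->first] == it->second) {
    worklist.emplace_back(edge.function);
  }
}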
@ -876,8 +893,8 @@ static CacheNode* _compiled_autograd_impl(
TORCH_INTERNAL_ASSERT(!vlogger.has_value() || compile_reason.has_value());
ClosingTHPObjectPtr py_compiler(
check(PyObject_CallNoArgs((the_autograd_compiler))));
setPyCompilerInterface(std::make_unique<PyCompilerInterfaceImpl>());
PyCompilerGuard py_compiler_guard(
std::make_unique<PyCompilerInterfaceImpl>());
TraceState state = call_begin_capture(
py_compiler,
@ -887,8 +904,8 @@ static CacheNode* _compiled_autograd_impl(
std::move(compile_reason));
InputBuffers input_buffers;
for (size_t i = 0; i < calls.size(); i++) {
NodeCall& call = *calls[i];
for (size_t i = 0; i < ordered_calls.size(); i++) {
NodeCall& call = *ordered_calls[i];
std::string _node_name = call.node->name();
THPObjectPtr node_name(PyUnicode_FromString(_node_name.data()));
@ -948,6 +965,7 @@ static CacheNode* _compiled_autograd_impl(
}
SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call);
std::cout << "apply_with_saved " << call.node->name() << std::endl;
variable_list outputs = call.node->apply_with_saved(inputs, saved);
saved.debug_asserts();
saved.before(call.node->next_edges());
@ -1024,7 +1042,6 @@ static CacheNode* _compiled_autograd_impl(
}
}
resetPyCompilerInterface();
PyObject* res = check(call_end_capture(py_compiler, state.outputs));
TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple");
TORCH_CHECK(
@ -1043,7 +1060,8 @@ static CacheNode* _compiled_autograd_impl(
// TODO(jansel): clear grads we will overwrite below
if (!graph_task.keep_graph_) {
for (auto& call : calls) {
for (auto& call : ordered_calls) {
// Once we release variables, we can no longer fall back to eager autograd
call->node->release_variables();
}
}
@ -1052,7 +1070,8 @@ static CacheNode* _compiled_autograd_impl(
*graph_arg_sizes = wrap_int_list(compiler_call.dyn_size_inputs);
*graph_arg_ivalue_args =
wrap_lifted_ivalue_args(compiler_call.lifted_ivalue_args.args);
*graph_arg_hooks = convert_hook_list(compiler_call.hooks);
*graph_arg_hooks = convert_pyobj_list(compiler_call.hooks);
*graph_arg_packed_inputs = convert_pyobj_list(compiler_call.packed_inputs);
return cache;
}
@ -1078,12 +1097,20 @@ struct LockGuardWithErrorLogs {
static variable_list compiled_autograd(
const std::shared_ptr<Node>& graph_root,
GraphTask& graph_task,
const GraphTask& graph_task,
bool accumulate_grad,
const edge_list& output_edges) {
TORCH_CHECK(
c10::impl::TorchDispatchModeTLS::stack_len() == 0,
"TorchDispatchMode not yet implemented for compiled autograd")
static std::mutex mtx;
LockGuardWithErrorLogs lock_guard(mtx);
pybind11::gil_scoped_acquire gil;
@ -1093,6 +1120,7 @@ static variable_list compiled_autograd(
THPObjectPtr sizes;
THPObjectPtr ivalue_args;
THPObjectPtr hooks;
THPObjectPtr packed_inputs;
CacheNode* cache = _compiled_autograd_impl(
graph_root,
graph_task,
@ -1101,7 +1129,8 @@ static variable_list compiled_autograd(
&inputs,
&sizes,
&ivalue_args,
&hooks);
&hooks,
&packed_inputs);
THPObjectPtr pyresult(check(PyObject_CallFunctionObjArgs(
cache->runtime_wrapper.get(),
@ -1110,9 +1139,25 @@ static variable_list compiled_autograd(
sizes.get(),
ivalue_args.get(),
hooks.get(),
packed_inputs.get(),
NULL)));
variable_list outputs = THPVariable_UnpackList(pyresult);
TORCH_INTERNAL_ASSERT(outputs.size() == output_edges.size());
cache->clear();
return outputs;
}

View File

@ -361,6 +361,6 @@ def tensorify_python_scalars(
metrics_context.set("tensorify_float_success", True, overwrite=True)
raise TensorifyScalarRestartAnalysis
graph_code_log.debug(
"%s", lazy_format_graph_code("tensorify_python_scalars", gm, colored=True)
)
# graph_code_log.debug(
# "%s", lazy_format_graph_code("tensorify_python_scalars", gm, colored=True)
# )