mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-02 14:34:54 +08:00

Compare commits (6 commits): ciflow/tru ... xmfan/sing

| Author | SHA1 | Date |
|---|---|---|
| | 3d732de4c8 | |
| | 19a27343a0 | |
| | cabc560762 | |
| | 663565c84a | |
| | 95f317d9f8 | |
| | 643da4854f | |
hack.py — 118 lines (new file)
@@ -0,0 +1,118 @@
import torch
from torch._inductor.ir import NoneAsConstantBuffer
import torch.nn as nn
import torch.nn.functional as F
import depyf
depyf.install()


def fn(loss):
    gm = None
    args = None

    def noop(_gm):
        nonlocal gm
        gm = _gm
        def _noop(*_args, **_kwargs):
            assert not _kwargs
            nonlocal args
            args = _args
            return []
        return _noop

    with torch._dynamo.compiled_autograd._enable(noop):
        loss.backward()

    return gm, args


result = torch._dynamo.compiled_autograd.Op("FunctionalCompiledAutograd", fn, is_custom_function=False)
setattr(torch._dynamo.compiled_autograd.ops, "FunctionalCompiledAutograd", torch._dynamo.allow_in_graph(result))


x = torch.randn(64, 3)
t = torch.randn(64, 1)

model = nn.Linear(3, 1)

torch._dynamo.config.compiled_autograd = True
torch._dynamo.config.do_not_emit_runtime_asserts = True

@torch.compile(backend="eager")
def train(model, x, t):
    y = model(x)
    loss = F.mse_loss(y, t)
    gm, args = torch._dynamo.compiled_autograd.ops.FunctionalCompiledAutograd(loss)
    gm(*args)
    return ()

# with torch._dynamo.compiled_autograd._enable(noop):
train(model, x, t)

for p in model.parameters():
    assert p.grad is not None


"""
# this kinda works, but not ideal
===== __compiled_fn_1 =====
/home/xmfan/core/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_model_parameters_weight_: "f32[1, 3][3, 1]cpu", L_model_parameters_bias_: "f32[1][1]cpu", L_x_: "f32[64, 3][3, 1]cpu", L_t_: "f32[64, 1][1, 1]cpu"):
        l_model_parameters_weight_ = L_model_parameters_weight_
        l_model_parameters_bias_ = L_model_parameters_bias_
        l_x_ = L_x_
        l_t_ = L_t_

        # File: /home/xmfan/core/a/pytorch/hack.py:44 in train, code: y = model(x)
        y: "f32[64, 1][1, 1]cpu" = torch._C._nn.linear(l_x_, l_model_parameters_weight_, l_model_parameters_bias_); l_x_ = l_model_parameters_weight_ = l_model_parameters_bias_ = None

        # File: /home/xmfan/core/a/pytorch/hack.py:45 in train, code: loss = F.mse_loss(y, t)
        loss: "f32[][]cpu" = torch.nn.functional.mse_loss(y, l_t_); y = l_t_ = None

        # File: /home/xmfan/core/a/pytorch/hack.py:46 in train, code: gm, args = torch._dynamo.compiled_autograd.ops.FunctionalCompiledAutograd(loss)
        functional_compiled_autograd = torch__dynamo_compiled_autograd_ops_FunctionalCompiledAutograd(loss); loss = None
        getitem = functional_compiled_autograd[1]; functional_compiled_autograd = None
        getitem_1 = getitem[0]; getitem = None
        getitem_8: "f32[][]cpu" = getitem_1[0]
        getitem_9: "f32[64, 1][1, 1]cpu" = getitem_1[1]
        getitem_10: "f32[64, 1][1, 1]cpu" = getitem_1[2]
        getitem_11: "f32[64, 3][3, 1]cpu" = getitem_1[3]; getitem_1 = None

        # File: <eval_with_key>.0:11 in forward, code: validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], True)]); getitem = None
        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_8], [((None, None, device(type='cpu'), 6, 0, None), [], True)]); getitem_8 = None
        getitem_15: "f32[][]cpu" = validate_outputs[0]; validate_outputs = None

        # File: <eval_with_key>.0:13 in forward, code: mse_loss_backward0 = torch__dynamo_compiled_autograd_ops_MseLossBackward0([getitem_6], [True, False], 1, getitem_1, getitem_2); getitem_6 = getitem_1 = getitem_2 = None
        mse_loss_backward0 = torch__dynamo_compiled_autograd_ops_MseLossBackward0([getitem_15], [True, False], 1, getitem_9, getitem_10); getitem_15 = getitem_9 = getitem_10 = None
        getitem_17: "f32[64, 1][1, 1]cpu" = mse_loss_backward0[0]; mse_loss_backward0 = None

        # File: <eval_with_key>.0:16 in forward, code: validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_7, getitem_8], [((None, None, device(type='cpu'), 6, 0, None), [64, 1], True), None]); getitem_7 = getitem_8 = None
        validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_17, None], [((None, None, device(type='cpu'), 6, 0, None), [64, 1], True), None]); getitem_17 = None
        getitem_19: "f32[64, 1][1, 1]cpu" = validate_outputs_1[0]; validate_outputs_1 = None

        # File: <eval_with_key>.0:18 in forward, code: addmm_backward0 = torch__dynamo_compiled_autograd_ops_AddmmBackward0([getitem_9], [True, False, True], 1, 1, getitem_3, 0, [64, 3], [], None, 0, [3, 1], [1, 3]); getitem_9 = getitem_3 = None
        addmm_backward0 = torch__dynamo_compiled_autograd_ops_AddmmBackward0([getitem_19], [True, False, True], 1, 1, getitem_11, 0, [64, 3], [], None, 0, [3, 1], [1, 3]); getitem_19 = getitem_11 = None
        getitem_22: "f32[64, 1][1, 1]cpu" = addmm_backward0[0]
        getitem_23: "f32[3, 1][1, 3]cpu" = addmm_backward0[2]; addmm_backward0 = None

        # File: <eval_with_key>.0:22 in forward, code: validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_11, getitem_12, getitem_13], [((None, None, device(type='cpu'), 6, 0, None), [1], True), None, ((None, None, device(type='cpu'), 6, 0, None), [3, 1], True)]); getitem_11 = getitem_12 = getitem_13 = None
        validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_22, None, getitem_23], [((None, None, device(type='cpu'), 6, 0, None), [1], True), None, ((None, None, device(type='cpu'), 6, 0, None), [3, 1], True)]); getitem_22 = getitem_23 = None
        getitem_26: "f32[1][1]cpu" = validate_outputs_2[0]
        getitem_27: "f32[3, 1][1, 3]cpu" = validate_outputs_2[2]; validate_outputs_2 = None

        # File: /home/xmfan/core/a/pytorch/torch/_dynamo/polyfills/__init__.py:80 in accumulate_grad, code: new_grad = torch.clone(new_grad)
        new_grad: "f32[1][1]cpu" = torch.clone(getitem_26); getitem_26 = new_grad = None

        # File: <eval_with_key>.0:26 in forward, code: tbackward0 = torch__dynamo_compiled_autograd_ops_TBackward0([getitem_16], [True]); getitem_16 = None
        tbackward0 = torch__dynamo_compiled_autograd_ops_TBackward0([getitem_27], [True]); getitem_27 = None
        getitem_29: "f32[1, 3][3, 1]cpu" = tbackward0[0]; tbackward0 = None

        # File: <eval_with_key>.0:28 in forward, code: validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_17], [((None, None, device(type='cpu'), 6, 0, None), [1, 3], True)]); getitem_17 = None
        validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_29], [((None, None, device(type='cpu'), 6, 0, None), [1, 3], True)]); getitem_29 = None
        getitem_31: "f32[1, 3][3, 1]cpu" = validate_outputs_3[0]; validate_outputs_3 = None

        # File: /home/xmfan/core/a/pytorch/torch/_dynamo/polyfills/__init__.py:80 in accumulate_grad, code: new_grad = torch.clone(new_grad)
        new_grad_1: "f32[1, 3][3, 1]cpu" = torch.clone(getitem_31); getitem_31 = new_grad_1 = None
        return ()
"""
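Aside: the load-bearing trick in hack.py is `torch._dynamo.allow_in_graph`, which makes Dynamo record a call to a Python callable as a single opaque node instead of tracing into it; that is how `FunctionalCompiledAutograd` lands in `__compiled_fn_1` as one call. A minimal standalone sketch of just that mechanism (our illustration, not part of the diff):

```python
import torch
import torch._dynamo

def opaque_scale(x):
    # Dynamo does not trace into this body; the call shows up as a
    # single call_function node in the captured FX graph.
    return x * 2

opaque_scale = torch._dynamo.allow_in_graph(opaque_scale)

@torch.compile(backend="eager")
def g(x):
    return opaque_scale(x) + 1

print(g(torch.ones(3)))  # tensor([3., 3., 3.])
```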
@@ -10,6 +10,7 @@ import os
import re
import subprocess
import sys
import tempfile
import unittest
from copy import deepcopy
from importlib.machinery import SourceFileLoader

@@ -48,11 +49,17 @@ from torch.testing._internal.logging_utils import logs_to_string
# note: these tests are not run on windows due to inductor_utils.HAS_CPU


def make_compiler_fn(fullgraph=True, dynamic=True, backend="inductor"):
    assert backend in ["inductor", "aot_eager"]
def make_compiler_fn(
    fullgraph=True, dynamic=True, backend="inductor", gm_hook=lambda gm: None
):
    assert backend in ["inductor", "aot_eager", "ca_eager"]

    def _compiler_fn(gm):
        """Same as torch.compile() but counts number of compiles"""
        gm_hook(gm)

        if backend == "ca_eager":
            return gm

        def _inner_compiler(gm_, example_inputs_):
            counters["compiled_autograd"]["compiles"] += 1
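For readability, a hypothetical use of the new `gm_hook` and `"ca_eager"` options from the signature above: run the captured compiled-autograd graph as plain Python (no Inductor compile) while printing it first.

```python
# Sketch only; make_compiler_fn is the test helper defined above.
compiler_fn = make_compiler_fn(backend="ca_eager", gm_hook=lambda gm: print(gm.graph))
```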
@@ -112,7 +119,10 @@ class TestCompiledAutograd(TestCase):
        torch.manual_seed(123)
        expected = list(fn())
        torch.manual_seed(123)
        with compiled_autograd._enable(compiler_fn):
        with compiled_autograd._enable(compiler_fn), mock.patch(
            "torch._functorch.aot_autograd.AOT_COUNTER",
            new_callable=itertools.count,
        ):
            opt_fn = torch.compile(fn) if compile_fn else fn
            actual = list(opt_fn())
        self.assertEqual(expected, actual)
@@ -915,7 +925,8 @@ main()
                inputs=[param, activ],
                sizes=(),
                scalars=(),
                hooks=(),
                hooks=[],
                packed_inputs=[],
            )
        finally:
            handle.remove()
@@ -3322,7 +3333,10 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
            graphs.append(gm)
            return inner_compiler_fn(gm)

        with compiled_autograd._enable(compiler_fn):
        with compiled_autograd._enable(compiler_fn), mock.patch(
            "torch._functorch.aot_autograd.AOT_COUNTER",
            new_callable=itertools.count,
        ):
            res = fn(x)
            res.sum().backward()

@@ -3336,7 +3350,7 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
            graph_code,
            """\
class CompiledAutograd0(torch.nn.Module):
    def forward(self, inputs, sizes, scalars, hooks):
    def forward(self, inputs, sizes, scalars, hooks, packed_data):
        getitem = inputs[0]
        getitem_1 = inputs[1]
        getitem_2 = inputs[2]
@@ -3511,9 +3525,7 @@ class CompiledAutograd0(torch.nn.Module):
            fn, count=2, compiler_fn=make_compiler_fn(backend="aot_eager")
        )

    @unittest.expectedFailure
    def test_saved_tensor_unpack_hook_ordering(self):
        # not the correct behaviour, I'm just preventing this from changing silently
        def f(x, y):
            return x * y
@@ -3531,8 +3543,6 @@ class CompiledAutograd0(torch.nn.Module):
            return x

        def tensor_hook(_):
            # in eager, tensor_hook is fired before unpack_hook
            # but in compiled autograd, tensor_hook is lifted whereas unpack_hook is not
            self.assertEqual(unpack_count, 0)

        x = torch.ones(4, requires_grad=True)

@@ -3544,21 +3554,252 @@ class CompiledAutograd0(torch.nn.Module):
        self.assertEqual(pack_count, 1)
        self.assertEqual(unpack_count, 0)
        loss = out_test.sum()
        loss.register_hook(tensor_hook)
        loss.register_hook(
            tensor_hook
        )  # scheduled to fire before any saved activations
        loss.backward()
        self.assertEqual(pack_count, 1)
        self.assertEqual(unpack_count, 1)

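The comments removed above describe the ordering at stake: in eager, a gradient hook on the loss fires before any unpack hook runs, while compiled autograd lifts tensor hooks but does not lift unpack hooks. A minimal eager repro of that ordering (our illustration, not part of the diff):

```python
import torch

order = []

def pack(x):
    return x

def unpack(x):
    order.append("unpack")
    return x

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    a = torch.ones(3, requires_grad=True)
    out = a * a  # mul saves its operands, triggering pack

loss = out.sum()
loss.register_hook(lambda g: order.append("tensor_hook"))
loss.backward()
# The loss-gradient hook runs before MulBackward unpacks anything:
print(order)  # ['tensor_hook', 'unpack', 'unpack']
```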
    def test_reentrant_checkpointing(self):
        def fn(x):
            y = x.sin()
            z = y.cos()
            return (y * z).sum()
    @parametrize("reentrant", (True, False))
    def test_checkpointing_simple(self, reentrant):
        def fn():
            def _fn(x):
                y = x.sin()
                z = y.cos()
                return (y * z).sum()

        inp = torch.rand(10, 10, requires_grad=True)
        out = torch.utils.checkpoint.checkpoint(fn, inp, use_reentrant=True)
        with torch._dynamo.compiled_autograd._enable(torch.compile):
            inp = torch.rand(10, 10, requires_grad=True)
            out = torch.utils.checkpoint.checkpoint(_fn, inp, use_reentrant=reentrant)
            out.backward()
            yield inp.grad

        if reentrant:
            self.check_output_and_recompiles(
                fn, count=[1, 3], compiler_fn=make_compiler_fn(fullgraph=False)
            )
        else:
            # dynamo issues, just run the CA graph directly for now
            def check(gm):
                graph_code = normalize_gm(gm.print_readable(print_output=False))
                self.assertExpectedInline(
                    graph_code,
                    """\
class CompiledAutograd0(torch.nn.Module):
    def forward(self, inputs, sizes, scalars, hooks, packed_data):
        getitem = inputs[0]
        getitem_1 = inputs[1]; inputs = None

        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem = None
        getitem_2 = validate_outputs[0]; validate_outputs = None

        sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_2], [True], [10, 10]); getitem_2 = None
        getitem_3 = sum_backward0[0]; sum_backward0 = None
        validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_3], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_3 = None
        getitem_4 = validate_outputs_1[0]; validate_outputs_1 = None

        getitem_5 = hooks[0]
        getitem_6 = packed_data[0]
        getitem_7 = hooks[1]
        getitem_8 = packed_data[1]
        call_hook = torch__dynamo_external_utils_call_hook(getitem_5, getitem_6, hook_type = 'unpack_hook'); getitem_5 = getitem_6 = None
        call_hook_1 = torch__dynamo_external_utils_call_hook(getitem_7, getitem_8, hook_type = 'unpack_hook'); getitem_7 = getitem_8 = None
        mul_backward0 = torch__dynamo_compiled_autograd_ops_MulBackward0([getitem_4], [True, True], call_hook, 6, call_hook_1, 6); getitem_4 = call_hook = call_hook_1 = None
        getitem_9 = mul_backward0[0]
        getitem_10 = mul_backward0[1]; mul_backward0 = None
        validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_9, getitem_10], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False), ((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_9 = getitem_10 = None
        getitem_11 = validate_outputs_2[0]
        getitem_12 = validate_outputs_2[1]; validate_outputs_2 = None

        getitem_13 = hooks[2]
        getitem_14 = packed_data[2]
        call_hook_2 = torch__dynamo_external_utils_call_hook(getitem_13, getitem_14, hook_type = 'unpack_hook'); getitem_13 = getitem_14 = None
        cos_backward0 = torch__dynamo_compiled_autograd_ops_CosBackward0([getitem_12], [True], call_hook_2); getitem_12 = call_hook_2 = None
        getitem_15 = cos_backward0[0]; cos_backward0 = None
        validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_15], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_15 = None
        getitem_16 = validate_outputs_3[0]; validate_outputs_3 = None
        add = torch.add(getitem_11, getitem_16); getitem_11 = getitem_16 = None

        getitem_17 = hooks[3]; hooks = None
        getitem_18 = packed_data[3]; packed_data = None
        call_hook_3 = torch__dynamo_external_utils_call_hook(getitem_17, getitem_18, hook_type = 'unpack_hook'); getitem_17 = getitem_18 = None
        sin_backward0 = torch__dynamo_compiled_autograd_ops_SinBackward0([add], [True], call_hook_3); add = call_hook_3 = None
        getitem_19 = sin_backward0[0]; sin_backward0 = None
        validate_outputs_4 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_19], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_19 = None
        getitem_20 = validate_outputs_4[0]; validate_outputs_4 = None

        accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_20); getitem_1 = getitem_20 = accumulate_grad_ = None
        _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
        return []
""", # noqa: B950
                )

            self.check_output_and_recompiles(
                fn,
                count=[1, 0],
                compiler_fn=make_compiler_fn(backend="ca_eager", gm_hook=check),
            )

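For context on what the parametrized test exercises: `torch.utils.checkpoint` discards intermediate activations and recomputes the wrapped function during backward. A tiny eager sketch (our illustration) that counts the recompute:

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = 0

def f(x):
    global calls
    calls += 1
    return x.sin().cos()

x = torch.randn(8, requires_grad=True)
out = checkpoint(f, x, use_reentrant=False)
out.sum().backward()
print(calls)  # 2: once in forward, once recomputed during backward
```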
    @unittest.skipIf(not HAS_CUDA, "requires cuda")
    def test_cpu_offloading(self):
        def fn():
            def pack(x):
                return x.cpu()

            def unpack(x):
                return x.cuda()

            class MyMatMul(torch.autograd.Function):
                @staticmethod
                def forward(ctx, x):
                    ctx.save_for_backward(x)
                    return torch.matmul(x, x)

                @staticmethod
                def backward(ctx, grad_out):
                    (x,) = ctx.saved_tensors
                    return grad_out * x

            with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
                for i in [10, 100, 10, 20, 30]:
                    x = torch.randn(i, requires_grad=True).cuda()
                    MyMatMul.apply(x).sum().backward()
                    yield x.grad

        i = 0

        def check(gm):
            nonlocal i
            if i == 0:
                i += 1
                return

            graph_code = normalize_gm(gm.print_readable(print_output=False))
            self.assertExpectedInline(
                graph_code,
                """\
class CompiledAutograd1(torch.nn.Module):
    def forward(self, inputs, sizes, scalars, hooks, packed_data):
        getitem = inputs[0]
        getitem_1 = inputs[1]; inputs = None
        getitem_2 = sizes[0]; getitem_2 = None
        getitem_3 = sizes[1]
        getitem_4 = sizes[2]; sizes = None

        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cuda', index=0), 6, 0, None), [], False)]); getitem = None
        getitem_5 = validate_outputs[0]; validate_outputs = None

        sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_5], [True], []); getitem_5 = None
        getitem_6 = sum_backward0[0]; sum_backward0 = None
        validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_6], [((None, None, device(type='cuda', index=0), 6, 0, None), [], False)]); getitem_6 = None
        getitem_7 = validate_outputs_1[0]; validate_outputs_1 = None

        getitem_8 = hooks[0]
        getitem_9 = packed_data[0]; packed_data = None
        getitem_10 = hooks[1]; hooks = None
        call_hook = torch__dynamo_external_utils_call_hook(getitem_8, getitem_9, hook_type = 'unpack_hook'); getitem_8 = getitem_9 = None
        call_backward = torch__dynamo_external_utils_call_backward(getitem_10, (call_hook,), getitem_7); getitem_10 = call_hook = getitem_7 = None
        getitem_12 = call_backward[0]; call_backward = None
        validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_12], [((None, None, device(type='cuda', index=0), 6, 0, None), [getitem_3], False)]); getitem_12 = getitem_3 = None
        getitem_13 = validate_outputs_2[0]; validate_outputs_2 = None

        to_copy_backward0 = torch__dynamo_compiled_autograd_ops_ToCopyBackward0([getitem_13], [True], (None, None, device(type='cpu'), 6, 0, None)); getitem_13 = None
        getitem_14 = to_copy_backward0[0]; to_copy_backward0 = None
        validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_14], [((None, None, device(type='cpu'), 6, 0, None), [getitem_4], False)]); getitem_14 = getitem_4 = None
        getitem_15 = validate_outputs_3[0]; validate_outputs_3 = None

        accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_15); getitem_1 = getitem_15 = accumulate_grad_ = None
        _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
        return []
""", # noqa: B950
            )

        self.check_output_and_recompiles(
            fn, count=2, compiler_fn=make_compiler_fn(gm_hook=check)
        )

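The pack/unpack pair above is the standard saved-tensors-hooks offloading recipe; outside the test harness it reads the same. A sketch (our illustration, requires a CUDA device):

```python
import torch

# Offload saved activations to CPU during forward, restore on demand in backward.
def pack(t):
    return t.cpu()

def unpack(t):
    return t.cuda()

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    x = torch.randn(32, 32, device="cuda", requires_grad=True)
    y = (x @ x).sum()

y.backward()  # unpack runs here, moving the saved matmul operand back to GPU
```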
    @skipIfWindows(msg="temp dir not compatible")
    def test_disk_offloading(self):
        with tempfile.TemporaryDirectory() as d:

            def fn():
                pack_count = 0

                def pack(x):
                    nonlocal pack_count
                    path = f"{d}/{pack_count}.pt"
                    torch.save(x, path)
                    return path

                def unpack(path):
                    x = torch.load(path)
                    return x

                class MyMatMul(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, x):
                        ctx.save_for_backward(x)
                        return torch.matmul(x, x)

                    @staticmethod
                    def backward(ctx, grad_out):
                        (x,) = ctx.saved_tensors
                        return grad_out * x

                with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
                    for i in [10, 100, 10, 20, 30]:
                        x = torch.randn(i, requires_grad=True)
                        MyMatMul.apply(x).sum().backward()
                        yield x.grad

            i = 0

            def check(gm):
                nonlocal i
                if i == 0:
                    i += 1
                    return

                graph_code = normalize_gm(gm.print_readable(print_output=False))
                self.assertExpectedInline(
                    graph_code,
                    """\
class CompiledAutograd1(torch.nn.Module):
    def forward(self, inputs, sizes, scalars, hooks, packed_data):
        getitem = inputs[0]
        getitem_1 = inputs[1]; inputs = None
        getitem_2 = sizes[0]; getitem_2 = None
        getitem_3 = sizes[1]; sizes = None

        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem = None
        getitem_4 = validate_outputs[0]; validate_outputs = None

        sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_4], [True], []); getitem_4 = None
        getitem_5 = sum_backward0[0]; sum_backward0 = None
        validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_5], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem_5 = None
        getitem_6 = validate_outputs_1[0]; validate_outputs_1 = None

        getitem_7 = hooks[0]
        getitem_8 = packed_data[0]; packed_data = None
        getitem_9 = hooks[1]; hooks = None
        call_hook = torch__dynamo_external_utils_call_hook(getitem_7, getitem_8, hook_type = 'unpack_hook'); getitem_7 = getitem_8 = None
        call_backward = torch__dynamo_external_utils_call_backward(getitem_9, (call_hook,), getitem_6); getitem_9 = call_hook = getitem_6 = None
        getitem_11 = call_backward[0]; call_backward = None
        validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_11], [((None, None, device(type='cpu'), 6, 0, None), [getitem_3], False)]); getitem_11 = getitem_3 = None
        getitem_12 = validate_outputs_2[0]; validate_outputs_2 = None

        accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_12); getitem_1 = getitem_12 = accumulate_grad_ = None
        _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
        return []
""", # noqa: B950
                )

            # 1 graph break on torch.load -> 2 dynamo graphs
            self.check_output_and_recompiles(
                fn,
                count=[2, 4],
                compiler_fn=make_compiler_fn(fullgraph=False, gm_hook=check),
            )

    @skipIfWindows(msg="node name demangling inconsistent on windows")
    def test_backward_hook_relative_ordering_partial(self):

@@ -3617,7 +3858,7 @@ class CompiledAutograd0(torch.nn.Module):

        self.check_output_and_recompiles(fn)

    def test_sac(self):
    def test_checkpointing_sac(self):
        # circular import
        from torch.utils.checkpoint import (
            checkpoint,

@@ -3666,7 +3907,9 @@ class CompiledAutograd0(torch.nn.Module):
        yield model.layer4.weight.grad
        yield model.layer4.bias.grad

        self.check_output_and_recompiles(fn)
        self.check_output_and_recompiles(
            fn, count=[1, 5], compiler_fn=make_compiler_fn(fullgraph=False)
        )


def load_test_module(name):

@@ -3754,6 +3997,26 @@ known_graph_breaks_tests = {
    "test_deep_reentrant", # reentrant .backward
    "test_reentrant_priority", # reentrant .backward
    "test_simple_reentrant", # reentrant .backward
    "test_checkpoint_detects_non_determinism", # unpack hook in skip files
    "test_checkpoint_valid_reset_on_error", # unpack hook in skip files
    "test_checkpointing_non_reentrant_autocast_cpu", # unpack hook in skip files
    "test_checkpointing_non_reentrant_autocast_gpu", # unpack hook in skip files
    "test_checkpointing_without_reentrant_arbitrary_input_output", # unpack hook in skip files
    "test_checkpointing_without_reentrant_correct_grad", # unpack hook in skip files
    "test_checkpointing_without_reentrant_custom_function_works", # unpack hook in skip files
    "test_checkpointing_without_reentrant_dataparallel", # _get_device_index in skip files
    "test_checkpointing_without_reentrant_detached_tensor_use_reentrant_True", # reentrant .backward
    "test_checkpointing_without_reentrant_parameter_used_in_an_out", # unpack hook in skip files
    "test_checkpointing_without_reentrant_with_context_fn", # unpack hook in skip files
    "test_save_on_cpu_and_checkpoint", # unpack hook in skip files
    "test_saved_tensor_hooks_custom_error_propagation", # CustomError
    "test_access_saved_tensor_twice_without_recomputation_works", # unpack hook in skip files
    "test_saved_tensor_hooks_extra_enter_during_bw_no_leak", # ctx in skip files
    "test_saved_tensor_hooks_extra_exit_during_bw_no_crash", # ctx in skip files
    "test_checkpointing", # reentrant .backward
    "test_checkpointing_without_reentrant_input_requires_grad_False", # reentrant .backward
    "test_checkpointing_without_reentrant_input_requires_grad_True", # reentrant .backward
    "test_checkpointing_without_reentrant_memory_savings", # reentrant .backward
}

test_contexts = {

@@ -3764,9 +4027,7 @@ test_contexts = {
}

# These groups of tests aren't supported yet
known_failures_re = re.compile(
    r"^test_(sparse|profiler|gradcheck|checkpoint|named_tensor)"
)
known_failures_re = re.compile(r"^test_(sparse|profiler|gradcheck|named_tensor)")

# Bugs needing investigation:
skipped_tests = {
@@ -3837,7 +4098,7 @@ known_failing_tests = {
    # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors)
    "test_grad_nonleaf_register_hook",
    "test_backward_twice_without_saved_values", # https://github.com/pytorch/pytorch/issues/129938
    # Category: Dynamo
    # Category: Dynamo (pass when directly running CA graph)
    "test_accumulate_grad_tensor_reference", # Out of bounds: frame_state_entry.stride[i] is None
    "test_custom_function_exception", # torch.no_grad(), torch._dynamo.exc.Unsupported: missing: WITH_EXCEPT_START
    "test_to_sparse_backward", # Out of bounds: frame_state_entry.stride[i] is None

@@ -3849,7 +4110,20 @@ known_failing_tests = {
    "test_return_duplicate", # gradient batching rule not implemented for aten::sym_size.int
    "test_return_duplicate_inplace", # gradient batching rule not implemented for aten::sym_size.int
    "test_setitem", # CopySlices accuracy error
    # Category: Inductor
    "test_save_on_cpu_and_checkpoint", # https://github.com/pytorch/pytorch/issues/147565
    "test_checkpoint_detects_non_determinism", # different error
    "test_checkpointing_non_reentrant_autocast_cpu", # saved != recompute
    "test_checkpointing_non_reentrant_autocast_gpu", # saved != recompute
    "test_checkpointing_without_reentrant_saved_object_identity", # same as https://github.com/pytorch/pytorch/issues/136193
    "test_saved_variable_packing_unpacking_did_not_save_original_with_hooks", # register_hooks multiple times
    "test_saved_variable_saved_original_inplace_detach", # RuntimeError not raised
    "test_access_saved_tensor_twice_without_recomputation_works", # saved != recompute
    "test_checkpointing_without_reentrant_dataparallel", # https://github.com/pytorch/pytorch/issues/127115
    "test_checkpointing", # takes very very long
    "test_checkpointing_without_reentrant_input_requires_grad_False", # takes very very long
    "test_checkpointing_without_reentrant_input_requires_grad_True", # takes very very long
    "test_checkpointing_without_reentrant_memory_savings", # takes very very long
    # Category: Inductor (pass on backend="aot_eager")
    "test_input_buffer_accum", # does not support sparse_grad=True: https://github.com/pytorch/pytorch/issues/120267
    "test_graph_save_on_cpu", # does not support pin_memory: https://github.com/pytorch/pytorch/issues/134173
    # Category: FakeTensor

@@ -3861,6 +4135,7 @@ known_failing_tests = {
    "test_invalid_gradients", # can't give autograd error due to inaccurate output metadata of lifted backward
    "test_autograd_node_isinstance", # backward ctx is a fake cls and not directly a Node instance
    "test_backward_hook_relative_ordering", # compiled autograd collects breadth first, and module backward hook not supported
    "test_checkpointing_without_reentrant_custom_function_works", # ctx.saved_tensors are cached by CA
    # Category: Subclasses
    "test_dtensor_basic",
    "test_dtensor_contiguous_dtensor_noncontiguous_local_as_tangent",

@@ -67,7 +67,7 @@ struct TORCH_API ${op} : public ${superclass} {
    ${release_variables}
  }
  ${will_release_variables}
  void compiled_args(CompiledNodeArgs& args) override;
  void compiled_args(CompiledNodeArgs& args) const override;
  variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override;
  ${saved_variables}
  ${saved_list_sizes}

@@ -127,7 +127,7 @@ variable_list ${op}::apply(variable_list&& grads) {
  return ${op}_apply_functional(std::move(grads), needs_input_grad${,apply_functional_args});
}

void ${op}::compiled_args(CompiledNodeArgs& args) {
void ${op}::compiled_args(CompiledNodeArgs& args) const {
  ${compiled_args}
}
variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) {
@@ -134,7 +134,7 @@ class Op:
ops = OpNamespace()


_graph_placeholders = ["inputs", "sizes", "scalars", "hooks"]
_graph_placeholders = ["inputs", "sizes", "scalars", "hooks", "packed_data"]
_impure_targets = OrderedSet(
    [
        call_hook,

@@ -164,12 +164,16 @@ class AutogradCompilerInstance:
        self.compiler_fn = compiler_fn
        self.stack = contextlib.ExitStack()
        self.close = self.stack.close
        self.shape_env = ShapeEnv()
        self.fake_tensor_mode = FakeTensorMode(
            allow_fallback_kernels=True,
            allow_non_fake_inputs=True,
            shape_env=self.shape_env,
        )
        if ctx := torch._guards.TracingContext.try_get():
            self.shape_env = ctx.fake_mode.shape_env
            self.fake_tensor_mode = ctx.fake_mode
        else:
            self.shape_env = ShapeEnv()
            self.fake_tensor_mode = FakeTensorMode(
                allow_fallback_kernels=True,
                allow_non_fake_inputs=True,
                shape_env=self.shape_env,
            )
        self.fx_tracer = PythonKeyTracer()
        self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic")
        self.hooks_proxy: Optional[Proxy] = None

@@ -206,7 +210,13 @@ class AutogradCompilerInstance:
        self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
        self.fx_tracer.tensor_attrs = {}
        self.symnode_proxy_lookup = {}
        args_proxy, self.sizes_proxy, self.scalars_proxy, self.hooks_proxy = (
        (
            args_proxy,
            self.sizes_proxy,
            self.scalars_proxy,
            self.hooks_proxy,
            self.packed_data_proxy,
        ) = (
            self.fx_tracer.create_proxy("placeholder", name, (), {})
            for name in _graph_placeholders
        )
@@ -268,7 +278,12 @@ class AutogradCompilerInstance:
        self.stack.enter_context(
            torch.fx.experimental.symbolic_shapes._suppress_guards(env)
        )
        return str(CompileContext.current_compile_id()), inputs, sizes, scalars
        return (
            str(CompileContext.current_compile_id()),
            inputs,
            sizes,
            scalars,
        )

    def log_compile_reasons(
        self,
@@ -567,6 +582,19 @@ class AutogradCompilerInstance:
            kwargs,
        )

    def unpack_hook(self, hook_id, data_id):
        assert self.hooks_proxy is not None
        hook = self.hooks_proxy[hook_id] # type: ignore[index]
        data = self.packed_data_proxy[data_id] # type: ignore[index]
        proxy = self.proxy_call_hook(
            hook,
            data,
            hook_type="unpack_hook",
        )
        out = self.allocate_dummy()
        self.bind_objects_to_proxies([out], [proxy])
        return out

    def tensor_pre_hook(self, inputs, hook_id, i: int):
        assert self.hooks_proxy is not None
        hook = self.hooks_proxy[hook_id] # type: ignore[index]
@@ -706,6 +734,9 @@ class AutogradCompilerInstance:
        after = len(self.fx_tracer.graph.nodes)
        verbose_log.debug("DCE removed %d nodes", before - after)

    def create_graph_module(self, id):
        return GraphModule(self.fx_tracer.root, self.fx_tracer.graph, id)

    def end_capture(self, outputs):
        self.fx_tracer.create_proxy(
            "call_function",

@@ -745,6 +776,7 @@ class AutogradCompilerInstance:
            ).print_readable(print_output=False),
        )
        self.rename_aot_dispatcher_nodes()
        self.delay_unpack_hook_nodes()
        self.reorder_tensor_pre_hook_nodes()
        self.reorder_pre_hook_nodes_to_schedule_asap()
        self.reorder_accumulate_grad_nodes()
@@ -763,9 +795,7 @@ class AutogradCompilerInstance:
        # should prevent these ops from going into the CA graph.
        self.dce()

        graph = GraphModule(
            self.fx_tracer.root, self.fx_tracer.graph, f"CompiledAutograd{self.id}"
        )
        graph = self.create_graph_module(f"CompiledAutograd{self.id}")
        set_locals_to_steal(graph, ["inputs"])
        lazy_graph_code = lazy_format_graph_code(
            "Compiled autograd graph",

@@ -781,17 +811,20 @@ class AutogradCompilerInstance:
            payload_fn=lambda: graph.print_readable(print_output=False),
        )

        def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks):
            global in_compiled_autograd_region
            try:
                in_compiled_autograd_region = True
                for i in runtime_inputs_to_move:
                    inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)
        def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks, packed_inputs):
            # prior = torch._C._dynamo.eval_frame.set_eval_frame(None)
            return compiled_fn(inputs, sizes, scalars, hooks, packed_inputs)
            # torch._C._dynamo.eval_frame.set_eval_frame(prior)
            # global in_compiled_autograd_region
            # try:
            #     in_compiled_autograd_region = True
            #     for i in runtime_inputs_to_move:
            #         inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)

                with _disable(), make_compile_context(self.id):
                    return compiled_fn(inputs, sizes, scalars, hooks)
            finally:
                in_compiled_autograd_region = False
            # with _disable(), make_compile_context(self.id):
            #     return compiled_fn(inputs, sizes, scalars, hooks, packed_inputs)
            # finally:
            #     in_compiled_autograd_region = False

        get_chromium_event_logger().log_event_end(
            "compiled_autograd",
@@ -938,6 +971,19 @@ class AutogradCompilerInstance:
            if getitem_node is not None:
                arg.append(getitem_node)

    def delay_unpack_hook_nodes(self):
        """
        We can delay unpack hooks until they are needed, even later than in the eager autograd engine.
        """
        for node in self.fx_tracer.graph.find_nodes(
            op="call_function", target=call_hook
        ):
            if node.kwargs.get("hook_type", None) != "unpack_hook":
                continue

            first_user = min(node.users)
            first_user.prepend(node)

    def reorder_tensor_pre_hook_nodes(self):
        """
        Usage of AOTAutograd causes all the tensor_pre_hook nodes to get pushed

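`delay_unpack_hook_nodes` is a pure FX reorder: each unpack-hook call is moved to sit immediately before its first user, so unpacking happens as late as the data flow allows. A toy illustration of `Node.prepend` on a standalone graph (our illustration, not from the diff):

```python
import operator
import torch.fx as fx

g = fx.Graph()
x = g.placeholder("x")
a = g.call_function(operator.neg, (x,))    # stand-in for the unpack hook
b = g.call_function(operator.abs, (x,))    # unrelated work
c = g.call_function(operator.add, (a, b))  # first (and only) user of `a`

c.prepend(a)  # move `a` to just before `c`, i.e. as late as its first use allows
print(g)      # node order is now x, b, a, c
```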
@@ -80,8 +80,9 @@ def accumulate_grad(x, new_grad):
    new_grad = torch.clone(new_grad)
    if x.grad is None:
        x.grad = new_grad
    else:
        x.grad.add_(new_grad)
    # problem here. can't trace???
    # else:
    #     x.grad.add_(new_grad)


def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]):

@@ -466,6 +466,7 @@ class CompileEventLogger:
        chromium_log = get_chromium_event_logger()
        top_event = chromium_log.get_outermost_event()
        if top_event is None:
            return
            raise RuntimeError(
                "No toplevel event active. Please only call this function within a metrics context/dynamo_timed."
            )

@@ -2426,6 +2426,7 @@ def _wrap_fx_proxy(

# This handles wrapping of the output of an op traced into the graph
def handle_traced_output(example_value, tx, proxy, options, subclass_type, target_cls):
    # HERE
    import torch._functorch.vmap
    import torch._subclasses.fake_tensor
    import torch._utils

@@ -2480,6 +2481,69 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
        return SizeVariable(sizes, **options)
    elif isinstance(example_value, (tuple, list)):
        set_example_value(proxy.node, example_value)
        if (
            len(example_value) == 2
            and example_value[1]
            and len(example_value[1]) == 5
            and [type(val) for val in example_value[1]]
            == [list, tuple, list, tuple, tuple]
            and isinstance(example_value[0], torch.fx.graph_module.GraphModule)
            and "locals_to_steal" in example_value[0].meta
        ):
            gm, (inputs, sizes, scalars, hooks, packed_data) = example_value
            # gm_proxy = proxy.tracer.create_proxy(
            #     kind="call_function",
            #     target=operator.getitem,
            #     args=(proxy, 0),
            #     kwargs={},
            # ) # mmmmm should we use the proxy vs just by value? it is const
            # then need to guard here?
            gm_var = SourcelessGraphModuleVariable(gm)
            args_proxy = proxy.tracer.create_proxy(
                kind="call_function",
                target=operator.getitem,
                args=(proxy, 1),
                kwargs={},
            )
            inputs_proxy = proxy.tracer.create_proxy(
                kind="call_function",
                target=operator.getitem,
                args=(args_proxy, 0),
                kwargs={},
            )
            inputs_var = []
            for i, inp in enumerate(inputs):
                inputs_proxy_i = proxy.tracer.create_proxy(
                    kind="call_function",
                    target=operator.getitem,
                    args=(inputs_proxy, i),
                    kwargs={},
                )
                inputs_var.append(
                    wrap_fx_proxy(
                        tx=tx,
                        proxy=inputs_proxy_i,
                        example_value=inp,
                        subclass_type=None, # uhhh how to tell this?
                        source=None, # won't have any?
                        **options,
                    )
                )

            # need to actually pass args in here
            args_var = TupleVariable(
                [
                    ListVariable(inputs_var),
                    TupleVariable([]),
                    ListVariable([]),
                    TupleVariable([]),
                    TupleVariable([]),
                ]
            )
            out = TupleVariable([gm_var, args_var])
            breakpoint()
            return out

        unpacked = []
        for i, val in enumerate(example_value):
            if val is None:

@@ -2508,7 +2572,14 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
                # use the same options object as parent
                options_i = options

            # if isinstance(val, torch.fx.graph_module.GraphModule):
            #     print("creating ")
            #     unpacked.append(SourcelessGraphModuleVariable(val))
            # unpacked.append(SourcelessBuilder.create(self.tx, v for v in example_value))

            # WARNING: this assumes the same target_cls as this tuple/list call
            # I think we need to create a new Variable to properly model the output of the functional CA call
            # it's wrong that we even end up here, target_cls is TensorVariable
            unpacked.append(
                wrap_fx_proxy_cls(
                    target_cls=target_cls,

@@ -1227,6 +1227,8 @@ class SkipFunctionVariable(VariableTracker):
            ]
            # also warn on it because most users won't see the graph break message
            torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
            # not supposed to get here
            breakpoint()
        if self.value.__qualname__ == "allow_in_graph":
            explanation = (
                "Found an allow_in_graph decorator to a function which "

@@ -577,6 +577,9 @@ class TensorVariable(VariableTracker):
        from .builder import SourcelessBuilder, VariableBuilder
        from .torch_function import can_dispatch_torch_function, dispatch_torch_function

        # if name == "backward":
        #     breakpoint()

        if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops():
            unimplemented(f"Illegal method invocation {name} in strict mode")


@@ -657,6 +657,8 @@ def _create_aot_dispatcher_function(
    needs_autograd = any(
        x.requires_grad for x in fake_flat_args if isinstance(x, Tensor)
    )
    # only for these fullgraphs
    needs_autograd = False

    with enable_python_dispatcher():
        # Patch set_rng_state as set_rng_state with fake tensors is

@@ -1193,6 +1195,7 @@ def aot_module_simplified(
    # convention. This should get fixed...
    # NB: GraphModule/nn.Module rely on the non-boxed calling convention here
    def forward(*runtime_args: tuple[Any]):
        print("at runtime")
        full_args = []
        full_args.extend(params_flat)
        full_args.extend(runtime_args)

@@ -284,7 +284,7 @@ struct CppNode : public Node {
  void set_ctx_grad_fn(const std::shared_ptr<Node>& node);
  void save_variables_to_ctx();

  void compiled_args(CompiledNodeArgs& args) override {
  void compiled_args(CompiledNodeArgs& args) const override {
    // although neither of the 2 methods below have uniqueness guarantees
    // it is unlikely for them to collide at the same time
    args.collect(static_cast<uint64_t>(typeid(T).hash_code()));

@@ -1330,10 +1330,11 @@ auto Engine::execute(
    TORCH_CHECK(
        !create_graph, "compiled_autograd does not support create_graph");
    _thread_check.release();
    TORCH_CHECK(
        !AnomalyMode::is_enabled(),
        "compiled_autograd does not support AnomalyMode")
    // TORCH_CHECK(
    //     !AnomalyMode::is_enabled(),
    //     "compiled_autograd does not support AnomalyMode")
    GraphTaskGuard guard(graph_task);
    CheckpointValidGuard cpvguard(graph_task);
    return (*compiled_autograd)(
        graph_root, *graph_task, accumulate_grad, outputs);
  }

@@ -138,7 +138,7 @@ struct TORCH_API Engine {
  // see [Note: Compiled Autograd]
  typedef variable_list (*compiled_autograd_fn)(
      const std::shared_ptr<Node>& graph_root,
      GraphTask& graph_task,
      const GraphTask& graph_task,
      bool accumulate_grad,
      const edge_list& outputs);
  static void set_compiled_autograd(compiled_autograd_fn fn);

@@ -545,8 +545,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
    return tensor_pre_hooks_;
  }

  virtual std::unique_ptr<PostAccumulateGradHook>&
  tensor_post_acc_grad_hooks() noexcept {
  virtual std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks()
      const noexcept {
    static std::unique_ptr<PostAccumulateGradHook> empty = nullptr;
    return empty;
  }
@@ -593,7 +593,7 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
  // 2) Collect node information for specialization and caching
  // Implementations in subclasses should call args.collect() with all node
  // attrs. These functions are only called during backward.
  virtual void compiled_args(CompiledNodeArgs& args) {
  virtual void compiled_args(CompiledNodeArgs& args) const {
    throw std::runtime_error(
        std::string("compiled_args not implemented: ") + name());
  }

@@ -22,7 +22,8 @@ struct TORCH_API FunctionPreHook {
  virtual ~FunctionPreHook() = default;
  virtual variable_list operator()(const variable_list& grads) = 0;
  // only implemented for python hooks, registers hook with compiled autograd
  virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
  virtual void compiled_args(
      torch::dynamo::autograd::CompiledNodeArgs& args) const {
    throw std::runtime_error(
        std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
        typeid(*this).name());

@@ -35,7 +36,8 @@ struct TORCH_API FunctionPostHook {
      const variable_list& outputs /* grad_inputs */,
      const variable_list& inputs /* grad_outputs */) = 0;
  // only implemented for python hooks, registers hook with compiled autograd
  virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
  virtual void compiled_args(
      torch::dynamo::autograd::CompiledNodeArgs& args) const {
    throw std::runtime_error(
        std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
        typeid(*this).name());

@@ -47,7 +49,8 @@ struct TORCH_API PostAccumulateGradHook {
  virtual void operator()(const Variable& tensor) = 0;
  // only implemented for python hooks on nodes, registers hook with compiled
  // autograd
  virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
  virtual void compiled_args(
      torch::dynamo::autograd::CompiledNodeArgs& args) const {
    throw std::runtime_error(
        std::string("not yet implemented for compiled autograd: ") +
        typeid(*this).name());

@@ -66,12 +66,12 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
  return variable_list();
}

void AccumulateGrad::compiled_args(CompiledNodeArgs& args) {
void AccumulateGrad::compiled_args(CompiledNodeArgs& args) const {
  if (args.cond(variable.defined() && variable.requires_grad())) {
    args.collect(variable);
    args.collect(variable.grad());
  }
  auto& hook = tensor_post_acc_grad_hooks();
  const auto& hook = tensor_post_acc_grad_hooks();
  if (hook != nullptr) {
    hook->compiled_args(args);
  }

@@ -50,8 +50,8 @@ struct TORCH_API AccumulateGrad : public Node {
    return impl::hooks(variable);
  }

  std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks() noexcept
      override {
  std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks()
      const noexcept override {
    // NB: Since the AccumulateGrad Node is only a weak ref from the Tensor,
    // it can be destroyed even though the Tensor is still alive (contrary
    // to all other Nodes). So we must lazily read the Tensor hooks here.

@@ -262,7 +262,7 @@ struct TORCH_API AccumulateGrad : public Node {
    }
  }

  void compiled_args(CompiledNodeArgs& args) override;
  void compiled_args(CompiledNodeArgs& args) const override;
  variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) override;

@@ -12,11 +12,15 @@

namespace torch::autograd {

auto Error::apply(variable_list&& inputs) -> variable_list {
variable_list Error::apply(variable_list&& inputs) {
  return static_cast<const Error*>(this)->apply(std::move(inputs));
}

variable_list Error::apply(variable_list&& inputs) const {
  throw std::runtime_error(msg);
}

void Error::compiled_args(CompiledNodeArgs& args) {
void Error::compiled_args(CompiledNodeArgs& args) const {
  // throw the error during collect, the graph won't get compiled
  apply(variable_list());
}

@@ -66,7 +70,7 @@ auto Identity::apply(variable_list&& grads) -> variable_list {
  return std::move(grads);
}

void GraphRoot::compiled_args(CompiledNodeArgs& args) {
void GraphRoot::compiled_args(CompiledNodeArgs& args) const {
  args.collect(outputs);
}
variable_list GraphRoot::apply_with_saved(

@@ -18,8 +18,9 @@ struct TORCH_API Error : public Node {
  Error(std::string msg) : msg(std::move(msg)) {}

  variable_list apply(variable_list&& inputs) override;
  variable_list apply(variable_list&& inputs) const;

  void compiled_args(CompiledNodeArgs& args) override;
  void compiled_args(CompiledNodeArgs& args) const override;
  variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) override;

@@ -51,6 +52,7 @@ struct TORCH_API DelayedError : public Node {
  }

  variable_list apply(variable_list&& inputs) override;
  variable_list apply(variable_list&& inputs) const;

  std::string msg;
};

@@ -61,6 +63,7 @@ struct TORCH_API UndefinedGrad : public Node {
  }

  variable_list apply(variable_list&& inputs) override;
  variable_list apply(variable_list&& inputs) const;
};

struct TORCH_API UndefinedGradBackward : public Node {

@@ -69,8 +72,9 @@ struct TORCH_API UndefinedGradBackward : public Node {
  UndefinedGradBackward() = default;

  variable_list apply(variable_list&& inputs) override;
  variable_list apply(variable_list&& inputs) const;

  void compiled_args(CompiledNodeArgs& args) override {}
  void compiled_args(CompiledNodeArgs& args) const override {}
  variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) override {

@@ -93,7 +97,7 @@ struct TORCH_API GraphRoot : public Node {
    return outputs;
  }

  void compiled_args(CompiledNodeArgs& args) override;
  void compiled_args(CompiledNodeArgs& args) const override;
  variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) override;

@@ -60,7 +60,7 @@ auto CopyBackwards::apply(variable_list&& grads) -> variable_list {
      src_options);
}

void CopyBackwards::compiled_args(CompiledNodeArgs& args) {
void CopyBackwards::compiled_args(CompiledNodeArgs& args) const {
  args.collect(src_options);
}

@@ -235,7 +235,7 @@ void CopySlices::release_variables() {
  fn = nullptr;
}

void CopySlices::compiled_args(CompiledNodeArgs& args) {
void CopySlices::compiled_args(CompiledNodeArgs& args) const {
  TORCH_CHECK(!view_fn, "view_fn not supported by compiled autograd")
  TORCH_INTERNAL_ASSERT((bool)fn);
  args.collect(base);

@@ -15,7 +15,7 @@ namespace torch::autograd {

struct TORCH_API CopyBackwards : public Node {
  variable_list apply(variable_list&& grads) override;
  void compiled_args(CompiledNodeArgs& args) override;
  void compiled_args(CompiledNodeArgs& args) const override;
  variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) override;

@@ -168,7 +168,7 @@ struct TORCH_API CopySlices : public Node {

  variable_list apply(variable_list&& inputs) override;
  void release_variables() override;
  void compiled_args(CompiledNodeArgs& args) override;
  void compiled_args(CompiledNodeArgs& args) const override;
  variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) override;

@@ -34,6 +34,7 @@
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_dtypes.h>

#include <autograd/function.h>
#include <functional>
#include <memory>
#include <stdexcept>

@@ -185,9 +186,9 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
  return to_variable_list(r.get(), is_variable_input);
}

auto PyNode::defer_to_dynamo(
auto PyNode::apply_with_saved_impl(
    const variable_list& inputs,
    const std::optional<PyObject*>& compiler) -> variable_list {
    const SwapSavedVariables& saved) -> variable_list {
  pybind11::gil_scoped_acquire gil;
  at::OptionalDeviceGuard _device_guard;
  THPFunction* py_fn = (THPFunction*)obj;

@@ -235,24 +236,24 @@ auto PyNode::defer_to_dynamo(
  }
  THPObjectPtr saved_tensors(unpack_saved_variables(
      py_fn, [](const Variable& var) { return THPVariable_Wrap(var); }));
  TORCH_INTERNAL_ASSERT(
      _backward_idx.has_value(),
      "indices should already be set by compiled_args, called before apply_with_saved");

  auto [bwd_idx, maybe_bwd_state_idx] = saved.retrieve_pynode_objs(this);

  PyObject* backward_state_idx = Py_None;
  if (_backward_state_idx.has_value()) {
    backward_state_idx = THPUtils_packInt64(_backward_state_idx.value());
  if (maybe_bwd_state_idx.has_value()) {
    backward_state_idx = THPUtils_packUInt64(maybe_bwd_state_idx.value());
    // this might be simplifiable now that we no longer inline
    Py_CLEAR(py_fn->compiled_autograd_backward_state);
  }
  THPObjectPtr r(PyObject_CallMethod(
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      compiler.value(),
      saved.get_py_compiler(),
      "proxy_call_backward",
      "OOOiOO",
      pyInputs.get(),
      fwdInputMetadatas.get(),
      saved_tensors.get(),
      *_backward_idx,
      bwd_idx,
      obj,
      backward_state_idx));

@@ -301,7 +302,7 @@ bool PyNode::is_aot_backward() const {
  return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id");
}

void PyNode::compiled_args(CompiledNodeArgs& args) {
void PyNode::compiled_args(CompiledNodeArgs& args) const {
  static PyObject* method_name =
      PyUnicode_InternFromString("_compiled_autograd_key");
  THPObjectPtr pykey(PyObject_CallMethodObjArgs(obj, method_name, nullptr));

@@ -346,14 +347,15 @@ void PyNode::compiled_args(CompiledNodeArgs& args) {
  args.collect(f->input_info);

  Py_INCREF(obj);
  _backward_idx = args.add_backward(c10::SafePyObject(obj, getPyInterpreter()));

  c10::SafePyObject backward_obj(obj, getPyInterpreter());
  std::optional<c10::SafePyObject> backward_state_obj;
  PyObject* bw_state = f->compiled_autograd_backward_state;
  if (args.cond(bw_state != nullptr)) {
    Py_INCREF(bw_state);
    _backward_state_idx = args.add_backward_state(
        c10::SafePyObject(bw_state, getPyInterpreter()));
    backward_state_obj = c10::SafePyObject(bw_state, getPyInterpreter());
  }
  args.collect_pynode_objs(
      this, std::move(backward_obj), std::move(backward_state_obj));
}

variable_list PyNode::apply_with_saved(
@@ -368,8 +370,7 @@ variable_list PyNode::apply_with_saved(
  saved.before(f->output_info);
  saved.before(f->input_info);
  f->compiled_autograd_tracing = true;
  variable_list result =
      defer_to_dynamo(variable_list(inputs), saved.get_py_compiler());
  variable_list result = apply_with_saved_impl(variable_list(inputs), saved);
  f->compiled_autograd_tracing = false;
  saved.after(f->compiled_autograd_symints);
  saved.after(f->saved_variables);

@@ -35,9 +35,9 @@ struct PyNode : public Node {
       const std::vector<bool>& is_variable_input);
 
   variable_list apply(variable_list&& inputs) override;
-  variable_list defer_to_dynamo(
+  variable_list apply_with_saved_impl(
       const variable_list& inputs,
-      const std::optional<PyObject*>& compiler);
+      const SwapSavedVariables& saved);
 
   void release_variables() override;
   std::string name() const override;
@@ -45,7 +45,7 @@ struct PyNode : public Node {
 
   bool is_aot_backward() const override;
 
-  void compiled_args(CompiledNodeArgs& args) override;
+  void compiled_args(CompiledNodeArgs& args) const override;
   variable_list apply_with_saved(
       const variable_list& inputs,
       SwapSavedVariables& saved) override;
@@ -53,13 +53,6 @@ struct PyNode : public Node {
   // THPFunction this Function is wrapping. Owning!
   PyObject* obj;
 
-  // The AutogradCompilerCall::hooks idx corresponding to this node's backward
-  std::optional<int> _backward_idx;
-
-  // The AutogradCompilerCall::hooks idx corresponding to this node's
-  // backward_state
-  std::optional<int> _backward_state_idx;
-
   // NOLINTNEXTLINE(bugprone-exception-escape)
   ~PyNode() override {
     // Can't use THPObjectPtr as a field in this class; destructor won't take
@@ -176,7 +176,7 @@ auto PyFunctionPostHook::operator()(
   return unwrap_variables(PyTuple_GetItem(tup.get(), 0));
 }
 
-void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) {
+void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) const {
   PyObject *key = nullptr, *value = nullptr;
   Py_ssize_t pos = 0;
   Py_BEGIN_CRITICAL_SECTION(dict);
@@ -189,7 +189,7 @@ void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) {
   Py_END_CRITICAL_SECTION();
 }
 
-void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) {
+void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) const {
   PyObject *key = nullptr, *value = nullptr;
   Py_ssize_t pos = 0;
   Py_BEGIN_CRITICAL_SECTION(dict);
@@ -200,7 +200,7 @@ void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) {
   Py_END_CRITICAL_SECTION();
 }
 
-void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) {
+void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) const {
   PyObject *key = nullptr, *value = nullptr;
   Py_ssize_t pos = 0;
   Py_BEGIN_CRITICAL_SECTION(dict);
@@ -237,7 +237,7 @@ auto PyFunctionTensorPostAccGradHooks::operator()(const Variable& tensor)
 }
 
 void PyFunctionTensorPostAccGradHooks::compiled_args(
-    torch::dynamo::autograd::CompiledNodeArgs& args) {
+    torch::dynamo::autograd::CompiledNodeArgs& args) const {
   PyObject *key = nullptr, *value = nullptr;
   Py_ssize_t pos = 0;
   Py_BEGIN_CRITICAL_SECTION(dict);
@@ -14,7 +14,8 @@ struct PyFunctionTensorPreHook : public FunctionPreHook {
   PyFunctionTensorPreHook(PyObject* dict, size_t value_idx);
   ~PyFunctionTensorPreHook() override;
   variable_list operator()(const variable_list& values) override;
-  void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
+  void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const override;
   PyObject* dict;
   size_t value_idx;
 };
@@ -23,7 +24,8 @@ struct PyFunctionPreHook : public FunctionPreHook {
   PyFunctionPreHook(PyObject* dict);
   ~PyFunctionPreHook() override;
   variable_list operator()(const variable_list& values) override;
-  void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
+  void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const override;
   PyObject* dict;
 };
@@ -33,7 +35,8 @@ struct PyFunctionPostHook : public FunctionPostHook {
   variable_list operator()(
       const variable_list& outputs,
       const variable_list& inputs) override;
-  void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
+  void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const override;
   PyObject* dict;
 };
@@ -45,7 +48,8 @@ struct PyFunctionTensorPostAccGradHooks : public PostAccumulateGradHook {
   PyFunctionTensorPostAccGradHooks(PyObject* dict);
   ~PyFunctionTensorPostAccGradHooks() override;
   void operator()(const Variable& tensor) override;
-  void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
+  void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const override;
   void apply_with_saved(
       Variable& tensor,
       torch::dynamo::autograd::SwapSavedVariables& saved) override;
@@ -46,6 +46,15 @@ at::Tensor PySavedVariableHooks::call_unpack_hook() {
   // unpack_hook_ will be manually decrefed when the saved variable is released
 }
 
+std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
+PySavedVariableHooks::retrieve_unpack_hook_data() const {
+  Py_INCREF(unpack_hook_);
+  Py_INCREF(data_);
+  return std::make_pair(
+      c10::SafePyObject(unpack_hook_, getPyInterpreter()),
+      c10::SafePyObject(data_, getPyInterpreter()));
+}
+
 // NOLINTNEXTLINE(bugprone-exception-escape)
 PySavedVariableHooks::~PySavedVariableHooks() {
   // If python is already dead, leak the wrapped python objects
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <ATen/ATen.h>
+#include <c10/core/SafePyObject.h>
 #include <pybind11/pybind11.h>
 #include <torch/csrc/Export.h>
 #include <torch/csrc/autograd/python_variable.h>
@@ -17,6 +18,8 @@ struct PySavedVariableHooks : public SavedVariableHooks {
   void call_pack_hook(const at::Tensor& tensor) override;
   at::Tensor call_unpack_hook() override;
   ~PySavedVariableHooks() override;
+  std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
+  retrieve_unpack_hook_data() const override;
 
  private:
   PyObject* pack_hook_;
@@ -59,6 +59,7 @@ SavedVariable::SavedVariable(
   if (maybe_hooks && !variable.unsafeGetTensorImpl()->is_wrapped_number()) {
     save_metadata(variable);
     set_hooks_and_pack_data(std::move(maybe_hooks), variable);
+    TORCH_INTERNAL_ASSERT(!data_.defined());
     return;
   }
@@ -134,9 +135,14 @@ Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
   // We want grad_fn here to provide the most helpful debug message to the user
   // if versions don't match
 
-  auto grad_fn = is_inplace_on_view_ ? weak_grad_fn_.lock()
-      : !hooks_ ? saved_original_ ? data_.grad_fn() : nullptr
-                : grad_fn_;
+  std::shared_ptr<Node> grad_fn;
+  if (is_inplace_on_view_) {
+    grad_fn = weak_grad_fn_.lock();
+  } else if (!hooks_) {
+    grad_fn = saved_original_ ? data_.grad_fn() : nullptr;
+  } else {
+    grad_fn = grad_fn_;
+  }
 
   if (!is_leaf_ && !grad_fn) {
     // This issue was introduced when we added logic to save the original
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <c10/core/SafePyObject.h>
 #include <torch/csrc/Export.h>
 #include <torch/csrc/autograd/forward_grad.h>
 #include <torch/csrc/autograd/saved_variable_hooks.h>
@@ -53,6 +54,15 @@ class TORCH_API SavedVariable {
     return (bool)hooks_;
   }
 
+  // Used by compiled autograd
+  std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
+  retrieve_unpack_hook_data() const {
+    if (!hooks_) {
+      return std::nullopt;
+    }
+    return hooks_->retrieve_unpack_hook_data();
+  }
+
  private:
   // This field contains either:
   // 1. the variable to save
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <ATen/core/Tensor.h>
+#include <c10/core/SafePyObject.h>
 
 namespace torch::autograd {
@@ -8,6 +9,11 @@ struct TORCH_API SavedVariableHooks {
   virtual void call_pack_hook(const at::Tensor& tensor) = 0;
   virtual at::Tensor call_unpack_hook() = 0;
   virtual ~SavedVariableHooks() = default;
+  virtual std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
+  retrieve_unpack_hook_data() const {
+    throw std::runtime_error(
+        "Compiled Autograd only supports python saved tensor hooks ");
+  }
 };
 
 } // namespace torch::autograd
@@ -27,7 +27,7 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook {
     return fn_(outputs, inputs);
   }
 
-  void compiled_args(CompiledNodeArgs& args) override {}
+  void compiled_args(CompiledNodeArgs& args) const override {}
 
  protected:
   std::function<variable_list(const variable_list&, const variable_list&)> fn_;
@@ -183,21 +183,23 @@ Reducer::Reducer(
 #endif
     // Hook to execute after the gradient accumulator has executed.
     hooks_.emplace_back(
-        grad_accumulator->add_post_hook(
-            std::make_unique<torch::autograd::utils::LambdaPostHook>(
-                [this, variable_index](
-                    const torch::autograd::variable_list& outputs,
-                    const torch::autograd::variable_list& /* unused */) {
+        grad_accumulator->add_post_hook(std::make_unique<
+                                        torch::autograd::utils::
+                                            LambdaPostHook>(
+            [this, variable_index](
+                const torch::autograd::variable_list& outputs,
+                const torch::autograd::variable_list& /* unused */) {
 #ifndef _WIN32
-              this->rpc_context_.set(
-                  ThreadLocalDistAutogradContext::getContextPtr());
+            this->rpc_context_.set(
+                ThreadLocalDistAutogradContext::getContextPtr());
 #endif
-              this->autograd_hook(variable_index);
-              return outputs;
-            },
-            [=](torch::autograd::CompiledNodeArgs& args) {
-              // Make post_hook an noop if compiled_autograds is enabled.
-            })),
+            this->autograd_hook(variable_index);
+            return outputs;
+          },
+          [=](torch::autograd::CompiledNodeArgs& args) {
+            TORCH_INTERNAL_ASSERT(
+                "Compiled autograd is not compatible with C++ DDP Reducer, please use torch._dynamo.config.optimize_ddp=\"python_reducer\".");
+          })),
         grad_accumulator);
 
     // Map raw function pointer to parameter index.
@@ -3,20 +3,22 @@
 
 namespace torch::dynamo::autograd {
 
-std::unique_ptr<PyCompilerInterface> kPyCompilerInterface;
+std::unique_ptr<PyCompilerInterface> kActivePyCompilerInterface;
 
 const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface() {
-  TORCH_INTERNAL_ASSERT(kPyCompilerInterface != nullptr);
-  return kPyCompilerInterface;
+  TORCH_INTERNAL_ASSERT(kActivePyCompilerInterface != nullptr);
+  return kActivePyCompilerInterface;
 }
 
-void setPyCompilerInterface(std::unique_ptr<PyCompilerInterface>&& impl) {
-  TORCH_INTERNAL_ASSERT(impl != nullptr);
-  kPyCompilerInterface = std::move(impl);
+PyCompilerGuard::PyCompilerGuard(std::unique_ptr<PyCompilerInterface>&& impl) {
+  TORCH_INTERNAL_ASSERT(
+      kActivePyCompilerInterface == nullptr && impl != nullptr);
+  kActivePyCompilerInterface = std::move(impl);
 }
 
-void resetPyCompilerInterface() {
-  kPyCompilerInterface.reset();
+PyCompilerGuard::~PyCompilerGuard() {
+  TORCH_INTERNAL_ASSERT(kActivePyCompilerInterface != nullptr);
+  kActivePyCompilerInterface.reset();
 }
 
 std::vector<std::optional<InputMetadata>> get_input_metadata(
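The set/reset free-function pair becomes an RAII guard here, so the active interface pointer is cleared even if tracing throws mid-capture. A minimal sketch of the same guard shape (Iface and Guard are hypothetical stand-ins for PyCompilerInterface and PyCompilerGuard):

#include <cassert>
#include <memory>

struct Iface {  // hypothetical stand-in for PyCompilerInterface
  virtual ~Iface() = default;
};

std::unique_ptr<Iface> kActive;  // stand-in for kActivePyCompilerInterface

struct Guard {
  explicit Guard(std::unique_ptr<Iface>&& impl) {
    // Refuse nesting and null installs, mirroring the asserts above.
    assert(kActive == nullptr && impl != nullptr);
    kActive = std::move(impl);
  }
  Guard(const Guard&) = delete;
  Guard& operator=(const Guard&) = delete;
  ~Guard() {  // runs on scope exit, including during stack unwinding
    assert(kActive != nullptr);
    kActive.reset();
  }
};

int main() {
  {
    Guard g(std::make_unique<Iface>());
    assert(kActive != nullptr);  // installed for the duration of the scope
  }
  assert(kActive == nullptr);  // cleared automatically on scope exit
}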
@@ -17,6 +17,84 @@
 namespace torch::dynamo::autograd {
 using namespace torch::autograd;
 
+// This is a layer of indirection for calling methods on the Python
+// AutogradCompilerInstance (referred to as the "py_compiler") from
+// libtorch_cpu (where Python is not available).
+// A PyCompilerInterfaceImpl in libtorch_python subclasses it and
+// overrides the methods to do the actual calls back to Python.
+struct TORCH_API PyCompilerInterface {
+  PyCompilerInterface() = default;
+  PyCompilerInterface(const PyCompilerInterface&) = delete;
+  PyCompilerInterface& operator=(const PyCompilerInterface&) = delete;
+  PyCompilerInterface(PyCompilerInterface&&) = delete;
+  PyCompilerInterface& operator=(PyCompilerInterface&&) = delete;
+  virtual ~PyCompilerInterface() = default;
+
+  // Invokes py_compiler.bind_function
+  virtual std::string bind_function(
+      PyObject* py_compiler,
+      const std::string& fn_name,
+      // NOLINTNEXTLINE(performance-unnecessary-value-param)
+      functional_apply_t fn,
+      // NOLINTNEXTLINE(performance-unnecessary-value-param)
+      std::vector<at::TypePtr> packed_args_schema,
+      bool is_custom_function = false,
+      bool is_traceable = true) {
+    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
+  }
+
+  // Invokes py_compiler.method_name(fn_name, inputs, packed_args,
+  // output_metadata)
+  virtual variable_list call_function(
+      PyObject* py_compiler,
+      const char* method_name,
+      const std::string& fn_name,
+      const variable_list& inputs,
+      const ivalue_list& packed_args,
+      const c10::IValue& output_metadata) {
+    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
+  }
+  virtual variable_list call_copy_slices_prologue(
+      PyObject* py_compiler,
+      const variable_list& inputs,
+      const at::TensorGeometry& base,
+      const at::TensorGeometry& view) {
+    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
+  }
+  virtual variable_list call_copy_slices_epilogue(
+      PyObject* py_compiler,
+      const std::vector<bool>& needs_input_grad,
+      const at::Tensor& result,
+      const variable_list& res,
+      const at::Tensor& grad_slice) {
+    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
+  }
+  virtual at::Tensor call_unpack(
+      PyObject* py_compiler,
+      std::optional<size_t> hook_id,
+      size_t hook_input_id) {
+    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
+  }
+};
+
+TORCH_API const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface();
+struct TORCH_API PyCompilerGuard {
+  explicit PyCompilerGuard(std::unique_ptr<PyCompilerInterface>&& impl);
+  PyCompilerGuard(const PyCompilerGuard&) = delete;
+  PyCompilerGuard& operator=(const PyCompilerGuard&) = delete;
+  PyCompilerGuard(PyCompilerGuard&&) = delete;
+  PyCompilerGuard& operator=(PyCompilerGuard&&) = delete;
+
+  ~PyCompilerGuard();
+};
+
+// including torch/csrc/autograd/engine.h breaks BC by somehow introducing
+// symbol resolution issues. Instead requiring downstream users to include
+// engine.h to access collect_input_metadata, we provide it here (with a
+// different name to avoid ambigous symbols...)
+TORCH_API std::vector<std::optional<InputMetadata>> get_input_metadata(
+    const edge_list& edges);
+
 struct SizeInput {
   // Note: int value is still needed when dynamic to pass as an arg
   enum DynType : uint8_t { STATIC = 0, DYNAMIC = 1 };
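Note that the interface deliberately uses failing default bodies instead of pure virtuals: libtorch_cpu can construct and pass the base type around, while only the Python-side subclass overrides the calls it actually supports. A small sketch of that pattern (names are illustrative, not the real types):

#include <cassert>
#include <stdexcept>
#include <string>

// The base lives where the implementation language (here: Python) is not
// available; every method fails loudly unless a subclass overrides it.
struct CompilerIface {
  virtual ~CompilerIface() = default;
  virtual std::string bind_function(const std::string& fn_name) {
    throw std::logic_error("Needs to be overridden");
  }
  virtual int call_unpack(size_t hook_id, size_t input_id) {
    throw std::logic_error("Needs to be overridden");
  }
};

// The "python side" subclass overrides just the calls it supports.
struct CompilerImpl : CompilerIface {
  int call_unpack(size_t hook_id, size_t input_id) override {
    return static_cast<int>(hook_id + input_id);  // placeholder body
  }
};

int main() {
  CompilerImpl impl;
  CompilerIface& iface = impl;
  assert(iface.call_unpack(2, 3) == 5);  // dispatches to the override
  try {
    iface.bind_function("fn");  // not overridden: default body throws
    assert(false);
  } catch (const std::logic_error&) {
  }
}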
@@ -154,9 +232,14 @@ struct TensorArgs {
   }
 
   TensorArg& lookup(const SavedVariable& sv) {
-    auto it = _saved_variables.find(&sv);
-    TORCH_INTERNAL_ASSERT(it != _saved_variables.end());
-    return *it->second;
+    if (auto it = _saved_variables.find(&sv); it != _saved_variables.end()) {
+      // unpacked before graph
+      return *it->second;
+    }
+    // unpacked in graph
+    auto it2 = _saved_variables_proxies.find(&sv);
+    TORCH_INTERNAL_ASSERT(it2 != _saved_variables_proxies.end());
+    return *it2->second;
   }
 
   TensorArg& add(const at::Tensor& tensor) {
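lookup() now consults two side tables: saved variables unpacked eagerly before the graph, and those that will be unpacked inside the graph through a proxy. A compact sketch of the two-tier lookup with toy value types:

#include <cassert>
#include <unordered_map>

int* lookup(
    const std::unordered_map<const void*, int*>& eager,    // unpacked before graph
    const std::unordered_map<const void*, int*>& proxies,  // unpacked in graph
    const void* key) {
  // C++17 if-with-initializer keeps `it` scoped to the fast path.
  if (auto it = eager.find(key); it != eager.end()) {
    return it->second;
  }
  auto it2 = proxies.find(key);
  assert(it2 != proxies.end());  // must be in exactly one of the two tables
  return it2->second;
}

int main() {
  int a = 1, b = 2;
  int k1 = 0, k2 = 0;  // addresses used as identities
  std::unordered_map<const void*, int*> eager{{&k1, &a}};
  std::unordered_map<const void*, int*> proxies{{&k2, &b}};
  assert(*lookup(eager, proxies, &k1) == 1);
  assert(*lookup(eager, proxies, &k2) == 2);
}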
@@ -164,9 +247,7 @@ struct TensorArgs {
   }
 
   TensorArg& add(const SavedVariable& sv, const std::shared_ptr<Node>& node) {
-    // TODO(jansel): Here we unpack the SavedVariable exactly once. This might
-    // fire SavedTensor hooks. In the future we should try to put saved tensor
-    // hooks into the graph.
+    // no unpack hooks in this codepath
    at::Tensor tensor = sv.unpack(node);
     TensorArg& arg = add(tensor);
     _saved_variables.emplace(&sv, &arg);
@@ -185,6 +266,7 @@ struct TensorArgs {
   // Every TensorArg from this is actually owned by _args (or _undefined) and
   // that's why we have an un-owned pointer here.
   std::unordered_map<const SavedVariable*, TensorArg*> _saved_variables;
+  std::unordered_map<const SavedVariable*, TensorArg*> _saved_variables_proxies;
   TensorArg _undefined;
   uint32_t _next_id = 1; // id=0 used by _undefined
 };
@@ -245,6 +327,11 @@ struct AutogradCompilerCall {
     return hooks.size() - 1;
   }
 
+  size_t emplace_packed_input(c10::SafePyObject&& input) {
+    packed_inputs.emplace_back(std::move(input));
+    return packed_inputs.size() - 1;
+  }
+
   void set_active_node_call_idx(size_t node_call_idx) {
     active_node_call_idx = node_call_idx;
   }
@@ -255,10 +342,16 @@ struct AutogradCompilerCall {
   LiftedIValueArgs lifted_ivalue_args;
   std::vector<int64_t> dyn_size_inputs;
   std::vector<c10::SafePyObject> hooks;
+  std::vector<c10::SafePyObject> packed_inputs;
   NodeCalls node_calls;
   SizeInput::DynType default_dyn_type;
   // NodeCall id of each size, only when verbose logging is enabled
   std::vector<uint32_t> size_input_origins;
+  std::unordered_map<const SavedVariable*, std::pair<size_t, size_t>>
+      sv_to_hooks;
+  // pynode -> backward and backward state idx
+  std::unordered_map<const Node*, std::pair<size_t, std::optional<size_t>>>
+      pynode_objs;
 };
 
 class CompiledNodeArgs {
@@ -285,8 +378,19 @@ class CompiledNodeArgs {
     collect(_compiler.tensor_args.add(t));
   }
   void collect(const SavedVariable& sv, bool is_output) {
-    collect(
-        _compiler.tensor_args.add(sv, is_output ? _node_call.node : nullptr));
+    if (auto hook_data = sv.retrieve_unpack_hook_data();
+        hook_data.has_value()) {
+      // hooks, unpack in graph
+      auto& [hook, packed_input] = hook_data.value();
+      size_t hook_id = _compiler.emplace_hook(std::move(hook));
+      // rely on dynamo to dedup packed tensors from unpacked tensors
+      size_t input_id = _compiler.emplace_packed_input(std::move(packed_input));
+      _compiler.sv_to_hooks.emplace(&sv, std::make_pair(hook_id, input_id));
+    } else {
+      // no hooks, unpack now
+      collect(
+          _compiler.tensor_args.add(sv, is_output ? _node_call.node : nullptr));
+    }
   }
   void collect(const c10::SymInt& t) {
     _compiler.add_size_input(t);
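When a SavedVariable carries an unpack hook, collect() no longer unpacks it eagerly; it lifts the hook and its packed payload into flat vectors and records their indices in sv_to_hooks, deferring the actual unpack into the traced graph. A standalone sketch of that lifting step (string payloads stand in for the SafePyObjects):

#include <cassert>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct Compiler {  // stand-in for AutogradCompilerCall
  std::vector<std::string> hooks;
  std::vector<std::string> packed_inputs;
  std::unordered_map<const void*, std::pair<size_t, size_t>> sv_to_hooks;
};

// Mirrors collect(const SavedVariable&, ...): hook present -> lift and
// defer; no hook -> the eager unpack path would run instead.
void collect(
    Compiler& c,
    const void* sv,
    std::optional<std::pair<std::string, std::string>> hook_data) {
  if (hook_data.has_value()) {
    auto& [hook, packed_input] = *hook_data;
    size_t hook_id = c.hooks.size();
    c.hooks.emplace_back(std::move(hook));
    size_t input_id = c.packed_inputs.size();
    c.packed_inputs.emplace_back(std::move(packed_input));
    c.sv_to_hooks.emplace(sv, std::make_pair(hook_id, input_id));
  }
  // else: unpack now (eager path, elided)
}

int main() {
  Compiler c;
  int sv = 0;
  collect(c, &sv, std::make_pair(std::string("unpack_hook"), std::string("payload")));
  assert((c.sv_to_hooks.at(&sv) == std::pair<size_t, size_t>(0, 0)));
}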
@@ -518,12 +622,17 @@ class CompiledNodeArgs {
         typeid(*node), _specialization_key, _specialization_key_size);
   }
 
-  size_t add_backward(c10::SafePyObject&& obj) {
-    return _compiler.emplace_hook(std::move(obj));
-  }
-
-  size_t add_backward_state(c10::SafePyObject&& obj) {
-    return _compiler.emplace_hook(std::move(obj));
+  void collect_pynode_objs(
+      const Node* pynode,
+      c10::SafePyObject&& bwd,
+      std::optional<c10::SafePyObject>&& bwd_state) {
+    size_t bwd_idx = _compiler.emplace_hook(std::move(bwd));
+    std::optional<size_t> bwd_state_idx;
+    if (auto state = std::move(bwd_state); state.has_value()) {
+      bwd_state_idx = _compiler.emplace_hook(std::move(state.value()));
+    }
+    _compiler.pynode_objs.emplace(
+        pynode, std::make_pair(bwd_idx, bwd_state_idx));
   }
 
   void add_tensor_pre_hook(c10::SafePyObject&& obj, int index) {
@@ -642,6 +751,13 @@ class SwapSavedVariables {
   // cache-miss. It swaps any 'lifted' inputs (tensors, symints) to proxy nodes,
   // allows tracing to happen, then swaps them back afterwards.
  public:
+  std::pair<size_t, std::optional<size_t>> retrieve_pynode_objs(
+      Node* pynode) const {
+    auto it = compiler.pynode_objs.find(pynode);
+    TORCH_INTERNAL_ASSERT(it != compiler.pynode_objs.end());
+    return it->second;
+  }
+
   void before(at::Tensor& t) {
     TensorArg& arg = compiler.tensor_args.lookup(t);
     stashed_tensors.save(&t, std::move(t));
@@ -655,13 +771,26 @@ class SwapSavedVariables {
   }
 
   void before(SavedVariable& t) {
-    TensorArg& arg = compiler.tensor_args.lookup(t);
-    stashed_variables.save(&t, std::move(t));
-    if (arg.defined()) {
+    if (auto it = compiler.sv_to_hooks.find(&t);
+        it != compiler.sv_to_hooks.end()) {
+      const auto& pyinterface =
+          torch::dynamo::autograd::getPyCompilerInterface();
+      auto proxy_tensor = pyinterface->call_unpack(
+          get_py_compiler(), it->second.first, it->second.second);
+      stashed_variables.save(&t, std::move(t));
       bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
-      TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
-      t = SavedVariable(arg.proxy_tensor, false);
+      t = SavedVariable(proxy_tensor, false);
       at::SavedTensorDefaultHooks::set_tracing(prior);
+    } else {
+      // no hooks, was already unpacked
+      TensorArg& arg = compiler.tensor_args.lookup(t);
+      stashed_variables.save(&t, std::move(t));
+      if (arg.defined()) {
+        bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
+        TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
+        t = SavedVariable(arg.proxy_tensor, false);
+        at::SavedTensorDefaultHooks::set_tracing(prior);
+      }
+    }
   }
   void after(SavedVariable& t) {
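before(SavedVariable&) follows a stash-and-swap discipline: the real object is moved into a stash keyed by its address, a proxy takes its place for the duration of tracing, and the matching after() restores it. A toy version of the stash/swap/restore cycle:

#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>

std::unordered_map<std::string*, std::string> stash;

void before(std::string& t, const std::string& proxy) {
  stash.emplace(&t, std::move(t));  // keep the real value, keyed by address
  t = proxy;                        // trace against the proxy instead
}

void after(std::string& t) {
  auto it = stash.find(&t);
  assert(it != stash.end());
  t = std::move(it->second);  // swap the real value back
  stash.erase(it);
}

int main() {
  std::string saved = "real tensor";
  before(saved, "proxy tensor");
  assert(saved == "proxy tensor");  // tracing sees only the proxy
  after(saved);
  assert(saved == "real tensor");   // original restored afterwards
}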
@@ -834,7 +963,7 @@ class SwapSavedVariables {
       const NodeCall& n)
       : compiler(c), state(s), py_compiler(p), curr_node_call(n) {}
 
-  PyObject* get_py_compiler() {
+  PyObject* get_py_compiler() const {
     return py_compiler;
   }
@@ -1370,73 +1499,6 @@ struct PackedArgs {
   int64_t idx = 0;
 };
 
-// This is a layer of indirection for calling methods on the Python
-// AutogradCompilerInstance (referred to as the "py_compiler") from
-// libtorch_cpu (where Python is not available).
-// A PyCompilerInterfaceImpl in libtorch_python subclasses it and
-// overrides the methods to do the actual calls back to Python.
-struct TORCH_API PyCompilerInterface {
-  PyCompilerInterface() = default;
-  PyCompilerInterface(const PyCompilerInterface&) = delete;
-  PyCompilerInterface& operator=(const PyCompilerInterface&) = delete;
-  PyCompilerInterface(PyCompilerInterface&&) = delete;
-  PyCompilerInterface& operator=(PyCompilerInterface&&) = delete;
-  virtual ~PyCompilerInterface() = default;
-
-  // Invokes py_compiler.bind_function
-  virtual std::string bind_function(
-      PyObject* py_compiler,
-      const std::string& fn_name,
-      // NOLINTNEXTLINE(performance-unnecessary-value-param)
-      functional_apply_t fn,
-      // NOLINTNEXTLINE(performance-unnecessary-value-param)
-      std::vector<at::TypePtr> packed_args_schema,
-      bool is_custom_function = false,
-      bool is_traceable = true) {
-    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
-  }
-
-  // Invokes py_compiler.method_name(fn_name, inputs, packed_args,
-  // output_metadata)
-  virtual variable_list call_function(
-      PyObject* py_compiler,
-      const char* method_name,
-      const std::string& fn_name,
-      const variable_list& inputs,
-      const ivalue_list& packed_args,
-      const c10::IValue& output_metadata) {
-    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
-  }
-
-  virtual variable_list call_copy_slices_prologue(
-      PyObject* py_compiler,
-      const variable_list& inputs,
-      const at::TensorGeometry& base,
-      const at::TensorGeometry& view) {
-    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
-  }
-  virtual variable_list call_copy_slices_epilogue(
-      PyObject* py_compiler,
-      const std::vector<bool>& needs_input_grad,
-      const at::Tensor& result,
-      const variable_list& res,
-      const at::Tensor& grad_slice) {
-    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
-  }
-};
-
-TORCH_API const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface();
-TORCH_API void setPyCompilerInterface(
-    std::unique_ptr<PyCompilerInterface>&& impl);
-TORCH_API void resetPyCompilerInterface();
-
-// including torch/csrc/autograd/engine.h breaks BC by somehow introducing
-// symbol resolution issues. Instead requiring downstream users to include
-// engine.h to access collect_input_metadata, we provide it here (with a
-// different name to avoid ambigous symbols...)
-TORCH_API std::vector<std::optional<InputMetadata>> get_input_metadata(
-    const edge_list& edges);
-
 } // namespace torch::dynamo::autograd
 
 template <>
@@ -11,6 +11,7 @@
 #include <sstream>
 #include <string>
 #include <vector>
+#include <c10/core/impl/TorchDispatchModeTLS.h>
 
 /*
 [Note: Compiled Autograd]
@@ -52,6 +53,12 @@ Notes:
 namespace torch::dynamo::autograd {
 using c10::SymInt;
 
+namespace {
+PyObject* the_autograd_compiler = nullptr;
+int default_dyn_type_int = 0;
+PyObject* python_verbose_logger = nullptr;
+} // namespace
+
 // List[Optional[Tensor]] in Python can't be directly parsed into a
 // List[Tensor], so we need to do this conversion manually.
 static std::vector<at::Tensor> toTensorList(
@@ -203,6 +210,16 @@ struct PyCompilerInterfaceImpl : PyCompilerInterface {
     auto output = py::cast<std::vector<std::optional<at::Tensor>>>(stuff);
     return toTensorList(output);
   }
+  at::Tensor call_unpack(
+      PyObject* py_compiler,
+      std::optional<size_t> hook_id,
+      size_t hook_input_id) override {
+    py::handle handle(py_compiler);
+    py::object proxy = handle.attr("unpack_hook")(hook_id, hook_input_id);
+    auto tmp = py::cast<std::optional<at::Tensor>>(proxy);
+    TORCH_INTERNAL_ASSERT(tmp.has_value());
+    return tmp.value();
+  }
 };
 
 static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) {
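call_unpack crosses back into Python through pybind11: wrap the raw PyObject*, call a method attribute, and cast the result with a has_value check. A minimal embedded-interpreter sketch of that call shape (the Compiler object here is invented for the example; it is not the real py_compiler):

#include <pybind11/embed.h>
#include <pybind11/stl.h>  // enables py::cast to/from std::optional

#include <cassert>
#include <optional>

namespace py = pybind11;

int main() {
  py::scoped_interpreter guard{};  // start an embedded CPython

  // A hypothetical py_compiler exposing unpack_hook(hook_id, input_id).
  py::object py_compiler = py::eval(
      "type('Compiler', (), {'unpack_hook': lambda self, h, i: h + i})()");

  // Same shape as call_unpack: attr lookup, call, cast, assert on the value.
  py::object proxy =
      py_compiler.attr("unpack_hook")(std::optional<size_t>(2), size_t(3));
  auto tmp = py::cast<std::optional<int>>(proxy);
  assert(tmp.has_value() && tmp.value() == 5);
}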
@@ -213,7 +230,7 @@ static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) {
   return pyinput;
 }
 
-static PyObject* convert_hook_list(std::vector<c10::SafePyObject>& inputs) {
+static PyObject* convert_pyobj_list(std::vector<c10::SafePyObject>& inputs) {
   // inplace, consumes the input hooks
   PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(inputs.size()));
   for (const auto i : c10::irange(inputs.size())) {
@@ -257,9 +274,6 @@ static variable_list validate_outputs(
   return new_outputs;
 }
 
-// snapshot of python verbose logging toggle
-static PyObject* python_verbose_logger = nullptr;
-
 struct PythonLogger {
   PythonLogger() = delete;
   explicit PythonLogger(PyObject* logger) : logger_(logger) {
@@ -540,8 +554,6 @@ struct InputBuffers : public std::unordered_map<Node*, InputBuffer> {
   }
 };
 
-static PyObject* the_autograd_compiler = nullptr;
-static int default_dyn_type_int = 0;
 static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args);
 
 static PyObject* clear_cache(PyObject* dummy, PyObject* args) {
@@ -654,7 +666,7 @@ static PyObject* wrap_string_list(const std::vector<std::string>& strs) {
   return pystrs;
 }
 
-std::string unwrap_string(PyObject* pystr) {
+static std::string unwrap_string(PyObject* pystr) {
   TORCH_INTERNAL_ASSERT(PyUnicode_Check(pystr));
   const char* str = PyUnicode_AsUTF8(pystr);
   TORCH_INTERNAL_ASSERT(str != nullptr);
@@ -790,14 +802,18 @@ static SizeInput::DynType get_default_dyn_type() {
 // Only call this function while holding GIL
 static CacheNode* _compiled_autograd_impl(
     const std::shared_ptr<Node>& graph_root,
-    GraphTask& graph_task,
+    const GraphTask& graph_task,
     bool accumulate_grad,
     const edge_list& output_edges,
     THPObjectPtr* graph_arg_inputs,
     THPObjectPtr* graph_arg_sizes,
     THPObjectPtr* graph_arg_ivalue_args,
-    THPObjectPtr* graph_arg_hooks) {
-  std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
+    THPObjectPtr* graph_arg_hooks,
+    THPObjectPtr* graph_arg_packed_inputs) {
+  const std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
+  std::unordered_map<Node*, int> visited_dependencies;
+  visited_dependencies.reserve(dependencies.size());
 
   std::vector<std::shared_ptr<Node>> worklist{graph_root};
   AutogradCompilerCall compiler_call(get_default_dyn_type());
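Because graph_task is now const, the traversal below can no longer consume dependencies_ by decrementing and erasing in place; instead it counts visits in its own visited_dependencies scratch map and enqueues a node once the count reaches the unchanged dependency total. A self-contained sketch of that non-destructive topological walk:

#include <cassert>
#include <unordered_map>
#include <vector>

// Nodes 0..n-1; edges[u] lists nodes that depend on u.
std::vector<int> topo_order(
    const std::vector<std::vector<int>>& edges,
    const std::unordered_map<int, int>& dependencies,  // stays const
    int root) {
  std::unordered_map<int, int> visited;  // per-call scratch counts
  std::vector<int> order;
  std::vector<int> worklist{root};
  while (!worklist.empty()) {
    int fn = worklist.back();
    worklist.pop_back();
    order.push_back(fn);
    for (int next : edges[fn]) {
      auto it = dependencies.find(next);
      assert(it != dependencies.end());
      int count = ++visited[next];
      assert(count <= it->second);
      if (count == it->second) {  // all producers visited: ready to emit
        worklist.push_back(next);
      }
    }
  }
  return order;
}

int main() {
  // 0 -> 1, 0 -> 2, 1 -> 2 (node 2 has two incoming dependencies)
  std::vector<std::vector<int>> edges{{1, 2}, {2}, {}};
  std::unordered_map<int, int> deps{{1, 1}, {2, 2}};
  auto order = topo_order(edges, deps, 0);
  assert(order.size() == 3 && order.back() == 2);
}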
@@ -809,8 +825,8 @@ static CacheNode* _compiled_autograd_impl(
   }
   const bool check_exec_info = !graph_task.exec_info_.empty();
   CacheNode* cache = CacheNode::root();
-  std::vector<NodeCall*> calls;
-  calls.reserve(
+  std::vector<NodeCall*> ordered_calls;
+  ordered_calls.reserve(
       check_exec_info ? graph_task.exec_info_.size() : dependencies.size() + 1);
 
   int i = 0;
@@ -820,7 +836,7 @@ static CacheNode* _compiled_autograd_impl(
     std::shared_ptr<Node> fn = std::move(worklist.back());
     worklist.pop_back();
     NodeCall& call = compiler_call.node_calls.lookup(fn);
-    calls.emplace_back(&call);
+    ordered_calls.emplace_back(&call);
 
     { // update cache and gather args into `compiler_call`
       CompiledNodeArgs node_args(compiler_call, call);
@@ -829,6 +845,7 @@ static CacheNode* _compiled_autograd_impl(
       }
       node_args.collect(call);
       if (node_args.cond(call.needed)) {
+        std::cout << "collecting " << fn->name() << std::endl;
         fn->compiled_args(node_args);
         node_args.collect(call.node->next_edges());
       }
@@ -861,9 +878,9 @@ static CacheNode* _compiled_autograd_impl(
         }
       }
       auto it = dependencies.find(edge.function.get());
       TORCH_INTERNAL_ASSERT(it != dependencies.end());
-      if (--it->second == 0) {
-        dependencies.erase(it);
+      int count = ++visited_dependencies[it->first];
+      TORCH_INTERNAL_ASSERT(count <= it->second);
+      if (count == it->second) {
         worklist.emplace_back(edge.function);
       }
     }
@@ -876,8 +893,8 @@ static CacheNode* _compiled_autograd_impl(
   TORCH_INTERNAL_ASSERT(!vlogger.has_value() || compile_reason.has_value());
   ClosingTHPObjectPtr py_compiler(
       check(PyObject_CallNoArgs((the_autograd_compiler))));
 
-  setPyCompilerInterface(std::make_unique<PyCompilerInterfaceImpl>());
+  PyCompilerGuard py_compiler_guard(
+      std::make_unique<PyCompilerInterfaceImpl>());
 
   TraceState state = call_begin_capture(
       py_compiler,
@@ -887,8 +904,8 @@ static CacheNode* _compiled_autograd_impl(
       std::move(compile_reason));
   InputBuffers input_buffers;
 
-  for (size_t i = 0; i < calls.size(); i++) {
-    NodeCall& call = *calls[i];
+  for (size_t i = 0; i < ordered_calls.size(); i++) {
+    NodeCall& call = *ordered_calls[i];
 
     std::string _node_name = call.node->name();
     THPObjectPtr node_name(PyUnicode_FromString(_node_name.data()));
@@ -948,6 +965,7 @@ static CacheNode* _compiled_autograd_impl(
     }
 
     SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call);
+    std::cout << "apply_with_saved " << call.node->name() << std::endl;
     variable_list outputs = call.node->apply_with_saved(inputs, saved);
     saved.debug_asserts();
     saved.before(call.node->next_edges());
@@ -1024,7 +1042,6 @@ static CacheNode* _compiled_autograd_impl(
     }
   }
 
-  resetPyCompilerInterface();
   PyObject* res = check(call_end_capture(py_compiler, state.outputs));
   TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple");
   TORCH_CHECK(
@@ -1043,7 +1060,8 @@ static CacheNode* _compiled_autograd_impl(
 
   // TODO(jansel): clear grads we will overwrite below
   if (!graph_task.keep_graph_) {
-    for (auto& call : calls) {
+    for (auto& call : ordered_calls) {
       // Once we release variables, we can no longer fallback to eager autograd
       call->node->release_variables();
     }
   }
@@ -1052,7 +1070,8 @@ static CacheNode* _compiled_autograd_impl(
   *graph_arg_sizes = wrap_int_list(compiler_call.dyn_size_inputs);
   *graph_arg_ivalue_args =
       wrap_lifted_ivalue_args(compiler_call.lifted_ivalue_args.args);
-  *graph_arg_hooks = convert_hook_list(compiler_call.hooks);
+  *graph_arg_hooks = convert_pyobj_list(compiler_call.hooks);
+  *graph_arg_packed_inputs = convert_pyobj_list(compiler_call.packed_inputs);
   return cache;
 }
@@ -1078,12 +1097,20 @@ struct LockGuardWithErrorLogs {
 
 static variable_list compiled_autograd(
     const std::shared_ptr<Node>& graph_root,
-    GraphTask& graph_task,
+    const GraphTask& graph_task,
     bool accumulate_grad,
     const edge_list& output_edges) {
-  TORCH_CHECK(
-      c10::impl::TorchDispatchModeTLS::stack_len() == 0,
-      "TorchDispatchMode not yet implemented for compiled autograd")
+  // std::cout << "called compiled autograd" << std::endl;
+  // std::vector<std::optional<std::shared_ptr<c10::impl::PyObject_TorchDispatchMode>>> infra_modes;
+
+  // infra_modes.emplace_back(c10::impl::TorchDispatchModeTLS::unset_mode(c10::impl::TorchDispatchModeKey::FAKE));
+  // infra_modes.emplace_back(c10::impl::TorchDispatchModeTLS::unset_mode(c10::impl::TorchDispatchModeKey::PROXY));
+  // infra_modes.emplace_back(c10::impl::TorchDispatchModeTLS::unset_mode(c10::impl::TorchDispatchModeKey::FUNCTIONAL));
+
+  // TORCH_INTERNAL_ASSERT(c10::impl::TorchDispatchModeTLS::stack_len() == 0);
+  // TORCH_CHECK(
+  //     c10::impl::TorchDispatchModeTLS::stack_len() == 0,
+  //     "TorchDispatchMode not yet implemented for compiled autograd")
   static std::mutex mtx;
   LockGuardWithErrorLogs lock_guard(mtx);
   pybind11::gil_scoped_acquire gil;
@@ -1093,6 +1120,7 @@ static variable_list compiled_autograd(
   THPObjectPtr sizes;
   THPObjectPtr ivalue_args;
   THPObjectPtr hooks;
+  THPObjectPtr packed_inputs;
   CacheNode* cache = _compiled_autograd_impl(
       graph_root,
       graph_task,
@@ -1101,7 +1129,8 @@ static variable_list compiled_autograd(
       &inputs,
       &sizes,
       &ivalue_args,
-      &hooks);
+      &hooks,
+      &packed_inputs);
 
   THPObjectPtr pyresult(check(PyObject_CallFunctionObjArgs(
       cache->runtime_wrapper.get(),
@@ -1110,9 +1139,25 @@ static variable_list compiled_autograd(
       sizes.get(),
       ivalue_args.get(),
       hooks.get(),
+      packed_inputs.get(),
       NULL)));
   variable_list outputs = THPVariable_UnpackList(pyresult);
   TORCH_INTERNAL_ASSERT(outputs.size() == output_edges.size());
   cache->clear();
+
+  // if (infra_modes[0].has_value()) {
+  //   std::cout << "setting FAKE" << std::endl;
+  //   c10::impl::TorchDispatchModeTLS::set_mode(infra_modes[0].value(), c10::impl::TorchDispatchModeKey::FAKE);
+  // }
+  // if (infra_modes[1].has_value()) {
+  //   std::cout << "setting PROXY" << std::endl;
+  //   c10::impl::TorchDispatchModeTLS::set_mode(infra_modes[1].value(), c10::impl::TorchDispatchModeKey::PROXY);
+  // }
+  // if (infra_modes[2].has_value()) {
+  //   std::cout << "setting FUNCTIONAL" << std::endl;
+  //   c10::impl::TorchDispatchModeTLS::set_mode(infra_modes[2].value(), c10::impl::TorchDispatchModeKey::FUNCTIONAL);
+  // }
   return outputs;
 }
@@ -361,6 +361,6 @@ def tensorify_python_scalars(
             metrics_context.set("tensorify_float_success", True, overwrite=True)
             raise TensorifyScalarRestartAnalysis
 
-    graph_code_log.debug(
-        "%s", lazy_format_graph_code("tensorify_python_scalars", gm, colored=True)
-    )
+    # graph_code_log.debug(
+    #     "%s", lazy_format_graph_code("tensorify_python_scalars", gm, colored=True)
+    # )