Mirror of https://github.com/pytorch/pytorch.git
[ca] trace saved variable unpacking (#147242)
## Before

Previously, compiled autograd (CA) would always unpack all saved variables stored in the autograd graph before executing it. This meant that we couldn't capture unpack hooks as part of the CA graph, and they would fire out of order with respect to other backward hooks. For memory-saving APIs built on top of saved tensor hooks, such as non-reentrant checkpointing and offloading, we couldn't achieve any savings: all activations would be recomputed/loaded and live at the same time, making the hooks a no-op.

## After

We add unpack hooks into the CA graph so that they can be executed progressively. The Python hook and the hook input themselves are wrapped by non-traceable code, so CA polyfills the wrapping as:

```python
# pseudocode
class SavedVariable:
    def unpack(self):
        if self.hook:
            return self.hook(self.packed_data)
        else:
            return self.packed_data

# This approach won't directly work when we add support for
# forward AD or double backward.
```

When directly executing the CA graph (without torch.compile-ing it) under checkpointing/offloading, the memory profile is expected to stay the same as with the eager autograd engine. If an AOT backward is in the autograd graph, the memory profile is expected to be better than the eager autograd engine, since we can now delay unpacking saved activations until the AOT backward's execution. All tests pass when running the CA graph directly; the remaining issues are in Dynamo.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147242
Approved by: https://github.com/jansel
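For illustration, here is a minimal sketch of the offloading pattern this change targets, modeled on the CPU-offloading test added in this diff. The shapes are illustrative, and using the private `torch._dynamo.compiled_autograd._enable` helper directly (as the tests do) is an assumption for demonstration, not part of the commit:

```python
# Minimal sketch (assumes a CUDA build; shapes are illustrative).
# After this change, the unpack hook is traced into the CA graph and fires
# lazily, right before the backward node that needs the activation,
# instead of all saved tensors being unpacked up front.
import torch

def pack(x):
    return x.cpu()    # offload the saved activation to host memory

def unpack(x):
    return x.cuda()   # reload it only when backward actually needs it

x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    loss = torch.matmul(x, x).sum()

# private test/demo entry point, mirroring the tests below
with torch._dynamo.compiled_autograd._enable(torch.compile):
    loss.backward()
```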
Committed by: PyTorch MergeBot
Parent: 08f4c1a233
Commit: 0a2da008f8
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -10,6 +10,7 @@ import os
 import re
 import subprocess
 import sys
+import tempfile
 import unittest
 from copy import deepcopy
 from importlib.machinery import SourceFileLoader
@@ -48,11 +49,17 @@ from torch.testing._internal.logging_utils import logs_to_string
 # note: these tests are not run on windows due to inductor_utils.HAS_CPU


-def make_compiler_fn(fullgraph=True, dynamic=True, backend="inductor"):
-    assert backend in ["inductor", "aot_eager"]
+def make_compiler_fn(
+    fullgraph=True, dynamic=True, backend="inductor", gm_hook=lambda gm: None
+):
+    assert backend in ["inductor", "aot_eager", "ca_eager"]

     def _compiler_fn(gm):
         """Same as torch.compile() but counts number of compiles"""
+        gm_hook(gm)
+
+        if backend == "ca_eager":
+            return gm

         def _inner_compiler(gm_, example_inputs_):
             counters["compiled_autograd"]["compiles"] += 1
@@ -112,7 +119,10 @@ class TestCompiledAutograd(TestCase):
         torch.manual_seed(123)
         expected = list(fn())
         torch.manual_seed(123)
-        with compiled_autograd._enable(compiler_fn):
+        with compiled_autograd._enable(compiler_fn), mock.patch(
+            "torch._functorch.aot_autograd.AOT_COUNTER",
+            new_callable=itertools.count,
+        ):
             opt_fn = torch.compile(fn) if compile_fn else fn
             actual = list(opt_fn())
             self.assertEqual(expected, actual)
@@ -915,7 +925,8 @@ main()
                 inputs=[param, activ],
                 sizes=(),
                 scalars=(),
-                hooks=(),
+                hooks=[],
+                packed_inputs=[],
             )
         finally:
             handle.remove()
@@ -3322,7 +3333,10 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
             graphs.append(gm)
             return inner_compiler_fn(gm)

-        with compiled_autograd._enable(compiler_fn):
+        with compiled_autograd._enable(compiler_fn), mock.patch(
+            "torch._functorch.aot_autograd.AOT_COUNTER",
+            new_callable=itertools.count,
+        ):
             res = fn(x)
             res.sum().backward()

@@ -3336,7 +3350,7 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
             graph_code,
             """\
 class CompiledAutograd0(torch.nn.Module):
-    def forward(self, inputs, sizes, scalars, hooks):
+    def forward(self, inputs, sizes, scalars, hooks, packed_data):
         getitem = inputs[0]
         getitem_1 = inputs[1]
         getitem_2 = inputs[2]
@@ -3511,9 +3525,7 @@ class CompiledAutograd0(torch.nn.Module):
             fn, count=2, compiler_fn=make_compiler_fn(backend="aot_eager")
         )

-    @unittest.expectedFailure
     def test_saved_tensor_unpack_hook_ordering(self):
-        # not the correct behaviour, I'm just preventing this from changing silently
         def f(x, y):
             return x * y

@@ -3531,8 +3543,6 @@ class CompiledAutograd0(torch.nn.Module):
             return x

         def tensor_hook(_):
-            # in eager, tensor_hook is fired before unpack_hook
-            # but in compiled autograd, tensor_hook is lifted whereas unpack_hook is not
             self.assertEqual(unpack_count, 0)

         x = torch.ones(4, requires_grad=True)
@@ -3544,21 +3554,252 @@ class CompiledAutograd0(torch.nn.Module):
         self.assertEqual(pack_count, 1)
         self.assertEqual(unpack_count, 0)
         loss = out_test.sum()
-        loss.register_hook(tensor_hook)
+        loss.register_hook(
+            tensor_hook
+        )  # scheduled to fire before any saved activations
         loss.backward()
         self.assertEqual(pack_count, 1)
         self.assertEqual(unpack_count, 1)

-    def test_reentrant_checkpointing(self):
-        def fn(x):
-            y = x.sin()
-            z = y.cos()
-            return (y * z).sum()
+    @parametrize("reentrant", (True, False))
+    def test_checkpointing_simple(self, reentrant):
+        def fn():
+            def _fn(x):
+                y = x.sin()
+                z = y.cos()
+                return (y * z).sum()

-        inp = torch.rand(10, 10, requires_grad=True)
-        out = torch.utils.checkpoint.checkpoint(fn, inp, use_reentrant=True)
-        with torch._dynamo.compiled_autograd._enable(torch.compile):
-            out.backward()
+            inp = torch.rand(10, 10, requires_grad=True)
+            out = torch.utils.checkpoint.checkpoint(_fn, inp, use_reentrant=reentrant)
+            out.backward()
+            yield inp.grad
+
+        if reentrant:
+            self.check_output_and_recompiles(
+                fn, count=[1, 3], compiler_fn=make_compiler_fn(fullgraph=False)
+            )
+        else:
+            # dynamo issues, just run the CA graph directly for now
+            def check(gm):
+                graph_code = normalize_gm(gm.print_readable(print_output=False))
+                self.assertExpectedInline(
+                    graph_code,
+                    """\
+class CompiledAutograd0(torch.nn.Module):
+    def forward(self, inputs, sizes, scalars, hooks, packed_data):
+        getitem = inputs[0]
+        getitem_1 = inputs[1]; inputs = None
+
+        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem = None
+        getitem_2 = validate_outputs[0]; validate_outputs = None
+
+        sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_2], [True], [10, 10]); getitem_2 = None
+        getitem_3 = sum_backward0[0]; sum_backward0 = None
+        validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_3], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_3 = None
+        getitem_4 = validate_outputs_1[0]; validate_outputs_1 = None
+
+        getitem_5 = hooks[0]
+        getitem_6 = packed_data[0]
+        getitem_7 = hooks[1]
+        getitem_8 = packed_data[1]
+        call_hook = torch__dynamo_external_utils_call_hook(getitem_5, getitem_6, hook_type = 'unpack_hook'); getitem_5 = getitem_6 = None
+        call_hook_1 = torch__dynamo_external_utils_call_hook(getitem_7, getitem_8, hook_type = 'unpack_hook'); getitem_7 = getitem_8 = None
+        mul_backward0 = torch__dynamo_compiled_autograd_ops_MulBackward0([getitem_4], [True, True], call_hook, 6, call_hook_1, 6); getitem_4 = call_hook = call_hook_1 = None
+        getitem_9 = mul_backward0[0]
+        getitem_10 = mul_backward0[1]; mul_backward0 = None
+        validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_9, getitem_10], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False), ((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_9 = getitem_10 = None
+        getitem_11 = validate_outputs_2[0]
+        getitem_12 = validate_outputs_2[1]; validate_outputs_2 = None
+
+        getitem_13 = hooks[2]
+        getitem_14 = packed_data[2]
+        call_hook_2 = torch__dynamo_external_utils_call_hook(getitem_13, getitem_14, hook_type = 'unpack_hook'); getitem_13 = getitem_14 = None
+        cos_backward0 = torch__dynamo_compiled_autograd_ops_CosBackward0([getitem_12], [True], call_hook_2); getitem_12 = call_hook_2 = None
+        getitem_15 = cos_backward0[0]; cos_backward0 = None
+        validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_15], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_15 = None
+        getitem_16 = validate_outputs_3[0]; validate_outputs_3 = None
+        add = torch.add(getitem_11, getitem_16); getitem_11 = getitem_16 = None
+
+        getitem_17 = hooks[3]; hooks = None
+        getitem_18 = packed_data[3]; packed_data = None
+        call_hook_3 = torch__dynamo_external_utils_call_hook(getitem_17, getitem_18, hook_type = 'unpack_hook'); getitem_17 = getitem_18 = None
+        sin_backward0 = torch__dynamo_compiled_autograd_ops_SinBackward0([add], [True], call_hook_3); add = call_hook_3 = None
+        getitem_19 = sin_backward0[0]; sin_backward0 = None
+        validate_outputs_4 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_19], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_19 = None
+        getitem_20 = validate_outputs_4[0]; validate_outputs_4 = None
+
+        accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_20); getitem_1 = getitem_20 = accumulate_grad_ = None
+        _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
+        return []
+""",  # noqa: B950
+                )
+
+            self.check_output_and_recompiles(
+                fn,
+                count=[1, 0],
+                compiler_fn=make_compiler_fn(backend="ca_eager", gm_hook=check),
+            )
+
+    @unittest.skipIf(not HAS_CUDA, "requires cuda")
+    def test_cpu_offloading(self):
+        def fn():
+            def pack(x):
+                return x.cpu()
+
+            def unpack(x):
+                return x.cuda()
+
+            class MyMatMul(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, x):
+                    ctx.save_for_backward(x)
+                    return torch.matmul(x, x)
+
+                @staticmethod
+                def backward(ctx, grad_out):
+                    (x,) = ctx.saved_tensors
+                    return grad_out * x
+
+            with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
+                for i in [10, 100, 10, 20, 30]:
+                    x = torch.randn(i, requires_grad=True).cuda()
+                    MyMatMul.apply(x).sum().backward()
+                    yield x.grad
+
+        i = 0
+
+        def check(gm):
+            nonlocal i
+            if i == 0:
+                i += 1
+                return
+
+            graph_code = normalize_gm(gm.print_readable(print_output=False))
+            self.assertExpectedInline(
+                graph_code,
+                """\
+class CompiledAutograd1(torch.nn.Module):
+    def forward(self, inputs, sizes, scalars, hooks, packed_data):
+        getitem = inputs[0]
+        getitem_1 = inputs[1]; inputs = None
+        getitem_2 = sizes[0]; getitem_2 = None
+        getitem_3 = sizes[1]
+        getitem_4 = sizes[2]; sizes = None
+
+        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cuda', index=0), 6, 0, None), [], False)]); getitem = None
+        getitem_5 = validate_outputs[0]; validate_outputs = None
+
+        sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_5], [True], []); getitem_5 = None
+        getitem_6 = sum_backward0[0]; sum_backward0 = None
+        validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_6], [((None, None, device(type='cuda', index=0), 6, 0, None), [], False)]); getitem_6 = None
+        getitem_7 = validate_outputs_1[0]; validate_outputs_1 = None
+
+        getitem_8 = hooks[0]
+        getitem_9 = packed_data[0]; packed_data = None
+        getitem_10 = hooks[1]; hooks = None
+        call_hook = torch__dynamo_external_utils_call_hook(getitem_8, getitem_9, hook_type = 'unpack_hook'); getitem_8 = getitem_9 = None
+        call_backward = torch__dynamo_external_utils_call_backward(getitem_10, (call_hook,), getitem_7); getitem_10 = call_hook = getitem_7 = None
+        getitem_12 = call_backward[0]; call_backward = None
+        validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_12], [((None, None, device(type='cuda', index=0), 6, 0, None), [getitem_3], False)]); getitem_12 = getitem_3 = None
+        getitem_13 = validate_outputs_2[0]; validate_outputs_2 = None
+
+        to_copy_backward0 = torch__dynamo_compiled_autograd_ops_ToCopyBackward0([getitem_13], [True], (None, None, device(type='cpu'), 6, 0, None)); getitem_13 = None
+        getitem_14 = to_copy_backward0[0]; to_copy_backward0 = None
+        validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_14], [((None, None, device(type='cpu'), 6, 0, None), [getitem_4], False)]); getitem_14 = getitem_4 = None
+        getitem_15 = validate_outputs_3[0]; validate_outputs_3 = None
+
+        accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_15); getitem_1 = getitem_15 = accumulate_grad_ = None
+        _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
+        return []
+""",  # noqa: B950
+            )
+
+        self.check_output_and_recompiles(
+            fn, count=2, compiler_fn=make_compiler_fn(gm_hook=check)
+        )
+
+    @skipIfWindows(msg="temp dir not compatible")
+    def test_disk_offloading(self):
+        with tempfile.TemporaryDirectory() as d:
+
+            def fn():
+                pack_count = 0
+
+                def pack(x):
+                    nonlocal pack_count
+                    path = f"{d}/{pack_count}.pt"
+                    torch.save(x, path)
+                    return path
+
+                def unpack(path):
+                    x = torch.load(path)
+                    return x
+
+                class MyMatMul(torch.autograd.Function):
+                    @staticmethod
+                    def forward(ctx, x):
+                        ctx.save_for_backward(x)
+                        return torch.matmul(x, x)
+
+                    @staticmethod
+                    def backward(ctx, grad_out):
+                        (x,) = ctx.saved_tensors
+                        return grad_out * x
+
+                with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
+                    for i in [10, 100, 10, 20, 30]:
+                        x = torch.randn(i, requires_grad=True)
+                        MyMatMul.apply(x).sum().backward()
+                        yield x.grad
+
+            i = 0
+
+            def check(gm):
+                nonlocal i
+                if i == 0:
+                    i += 1
+                    return
+
+                graph_code = normalize_gm(gm.print_readable(print_output=False))
+                self.assertExpectedInline(
+                    graph_code,
+                    """\
+class CompiledAutograd1(torch.nn.Module):
+    def forward(self, inputs, sizes, scalars, hooks, packed_data):
+        getitem = inputs[0]
+        getitem_1 = inputs[1]; inputs = None
+        getitem_2 = sizes[0]; getitem_2 = None
+        getitem_3 = sizes[1]; sizes = None
+
+        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem = None
+        getitem_4 = validate_outputs[0]; validate_outputs = None
+
+        sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_4], [True], []); getitem_4 = None
+        getitem_5 = sum_backward0[0]; sum_backward0 = None
+        validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_5], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem_5 = None
+        getitem_6 = validate_outputs_1[0]; validate_outputs_1 = None
+
+        getitem_7 = hooks[0]
+        getitem_8 = packed_data[0]; packed_data = None
+        getitem_9 = hooks[1]; hooks = None
+        call_hook = torch__dynamo_external_utils_call_hook(getitem_7, getitem_8, hook_type = 'unpack_hook'); getitem_7 = getitem_8 = None
+        call_backward = torch__dynamo_external_utils_call_backward(getitem_9, (call_hook,), getitem_6); getitem_9 = call_hook = getitem_6 = None
+        getitem_11 = call_backward[0]; call_backward = None
+        validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_11], [((None, None, device(type='cpu'), 6, 0, None), [getitem_3], False)]); getitem_11 = getitem_3 = None
+        getitem_12 = validate_outputs_2[0]; validate_outputs_2 = None
+
+        accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_12); getitem_1 = getitem_12 = accumulate_grad_ = None
+        _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
+        return []
+""",  # noqa: B950
+                )
+
+            # 1 graph break on torch.load -> 2 dynamo graphs
+            self.check_output_and_recompiles(
+                fn,
+                count=[2, 4],
+                compiler_fn=make_compiler_fn(fullgraph=False, gm_hook=check),
+            )

     @skipIfWindows(msg="node name demangling inconsistent on windows")
     def test_backward_hook_relative_ordering_partial(self):
@@ -3617,7 +3858,7 @@ class CompiledAutograd0(torch.nn.Module):

         self.check_output_and_recompiles(fn)

-    def test_sac(self):
+    def test_checkpointing_sac(self):
         # circular import
         from torch.utils.checkpoint import (
             checkpoint,
@@ -3666,7 +3907,9 @@ class CompiledAutograd0(torch.nn.Module):
             yield model.layer4.weight.grad
             yield model.layer4.bias.grad

-        self.check_output_and_recompiles(fn)
+        self.check_output_and_recompiles(
+            fn, count=[1, 5], compiler_fn=make_compiler_fn(fullgraph=False)
+        )


 def load_test_module(name):
@@ -3754,6 +3997,26 @@ known_graph_breaks_tests = {
     "test_deep_reentrant",  # reentrant .backward
     "test_reentrant_priority",  # reentrant .backward
     "test_simple_reentrant",  # reentrant .backward
+    "test_checkpoint_detects_non_determinism",  # unpack hook in skip files
+    "test_checkpoint_valid_reset_on_error",  # unpack hook in skip files
+    "test_checkpointing_non_reentrant_autocast_cpu",  # unpack hook in skip files
+    "test_checkpointing_non_reentrant_autocast_gpu",  # unpack hook in skip files
+    "test_checkpointing_without_reentrant_arbitrary_input_output",  # unpack hook in skip files
+    "test_checkpointing_without_reentrant_correct_grad",  # unpack hook in skip files
+    "test_checkpointing_without_reentrant_custom_function_works",  # unpack hook in skip files
+    "test_checkpointing_without_reentrant_dataparallel",  # _get_device_index in skip files
+    "test_checkpointing_without_reentrant_detached_tensor_use_reentrant_True",  # reentrant .backward
+    "test_checkpointing_without_reentrant_parameter_used_in_an_out",  # unpack hook in skip files
+    "test_checkpointing_without_reentrant_with_context_fn",  # unpack hook in skip files
+    "test_save_on_cpu_and_checkpoint",  # unpack hook in skip files
+    "test_saved_tensor_hooks_custom_error_propagation",  # CustomError
+    "test_access_saved_tensor_twice_without_recomputation_works",  # unpack hook in skip files
+    "test_saved_tensor_hooks_extra_enter_during_bw_no_leak",  # ctx in skip files
+    "test_saved_tensor_hooks_extra_exit_during_bw_no_crash",  # ctx in skip files
+    "test_checkpointing",  # reentrant .backward
+    "test_checkpointing_without_reentrant_input_requires_grad_False",  # reentrant .backward
+    "test_checkpointing_without_reentrant_input_requires_grad_True",  # reentrant .backward
+    "test_checkpointing_without_reentrant_memory_savings",  # reentrant .backward
 }

 test_contexts = {
@@ -3764,9 +4027,7 @@ test_contexts = {
 }

 # These groups of tests aren't supported yet
-known_failures_re = re.compile(
-    r"^test_(sparse|profiler|gradcheck|checkpoint|named_tensor)"
-)
+known_failures_re = re.compile(r"^test_(sparse|profiler|gradcheck|named_tensor)")

 # Bugs needing investigation:
 skipped_tests = {
@@ -3837,7 +4098,7 @@ known_failing_tests = {
     # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors)
     "test_grad_nonleaf_register_hook",
     "test_backward_twice_without_saved_values",  # https://github.com/pytorch/pytorch/issues/129938
-    # Category: Dynamo
+    # Category: Dynamo (pass when directly running CA graph)
     "test_accumulate_grad_tensor_reference",  # Out of bounds: frame_state_entry.stride[i] is None
     "test_custom_function_exception",  # torch.no_grad(), torch._dynamo.exc.Unsupported: missing: WITH_EXCEPT_START
     "test_to_sparse_backward",  # Out of bounds: frame_state_entry.stride[i] is None
@@ -3849,7 +4110,20 @@ known_failing_tests = {
     "test_return_duplicate",  # gradient batching rule not implemented for aten::sym_size.int
     "test_return_duplicate_inplace",  # gradient batching rule not implemented for aten::sym_size.int
     "test_setitem",  # CopySlices accuracy error
-    # Category: Inductor
+    "test_save_on_cpu_and_checkpoint",  # https://github.com/pytorch/pytorch/issues/147565
+    "test_checkpoint_detects_non_determinism",  # different error
+    "test_checkpointing_non_reentrant_autocast_cpu",  # saved != recompute
+    "test_checkpointing_non_reentrant_autocast_gpu",  # saved != recompute
+    "test_checkpointing_without_reentrant_saved_object_identity",  # same as https://github.com/pytorch/pytorch/issues/136193
+    "test_saved_variable_packing_unpacking_did_not_save_original_with_hooks",  # register_hooks multiple times
+    "test_saved_variable_saved_original_inplace_detach",  # RuntimeError not raised
+    "test_access_saved_tensor_twice_without_recomputation_works",  # saved != recompute
+    "test_checkpointing_without_reentrant_dataparallel",  # https://github.com/pytorch/pytorch/issues/127115
+    "test_checkpointing",  # takes very very long
+    "test_checkpointing_without_reentrant_input_requires_grad_False",  # takes very very long
+    "test_checkpointing_without_reentrant_input_requires_grad_True",  # takes very very long
+    "test_checkpointing_without_reentrant_memory_savings",  # takes very very long
+    # Category: Inductor (pass on backend="aot_eager")
     "test_input_buffer_accum",  # does not support sparse_grad=True: https://github.com/pytorch/pytorch/issues/120267
     "test_graph_save_on_cpu",  # does not support pin_memory: https://github.com/pytorch/pytorch/issues/134173
     # Category: FakeTensor
@@ -3861,6 +4135,7 @@ known_failing_tests = {
     "test_invalid_gradients",  # can't give autograd error due to inaccurate output metadata of lifted backward
     "test_autograd_node_isinstance",  # backward ctx is a fake cls and not directly a Node instance
     "test_backward_hook_relative_ordering",  # compiled autograd collects breadth first, and module backward hook not supported
+    "test_checkpointing_without_reentrant_custom_function_works",  # ctx.saved_tensors are cached by CA
     # Category: Subclasses
     "test_dtensor_basic",
     "test_dtensor_contiguous_dtensor_noncontiguous_local_as_tangent",
--- a/torch/_dynamo/compiled_autograd.py
+++ b/torch/_dynamo/compiled_autograd.py
@@ -134,7 +134,7 @@ class Op:
 ops = OpNamespace()


-_graph_placeholders = ["inputs", "sizes", "scalars", "hooks"]
+_graph_placeholders = ["inputs", "sizes", "scalars", "hooks", "packed_data"]
 _impure_targets = OrderedSet(
     [
         call_hook,
@@ -206,7 +206,13 @@ class AutogradCompilerInstance:
         self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
         self.fx_tracer.tensor_attrs = {}
         self.symnode_proxy_lookup = {}
-        args_proxy, self.sizes_proxy, self.scalars_proxy, self.hooks_proxy = (
+        (
+            args_proxy,
+            self.sizes_proxy,
+            self.scalars_proxy,
+            self.hooks_proxy,
+            self.packed_data_proxy,
+        ) = (
             self.fx_tracer.create_proxy("placeholder", name, (), {})
             for name in _graph_placeholders
         )
@@ -268,7 +274,12 @@ class AutogradCompilerInstance:
         self.stack.enter_context(
             torch.fx.experimental.symbolic_shapes._suppress_guards(env)
         )
-        return str(CompileContext.current_compile_id()), inputs, sizes, scalars
+        return (
+            str(CompileContext.current_compile_id()),
+            inputs,
+            sizes,
+            scalars,
+        )

     def log_compile_reasons(
         self,
@@ -567,6 +578,19 @@ class AutogradCompilerInstance:
             kwargs,
         )

+    def unpack_hook(self, hook_id, data_id):
+        assert self.hooks_proxy is not None
+        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
+        data = self.packed_data_proxy[data_id]  # type: ignore[index]
+        proxy = self.proxy_call_hook(
+            hook,
+            data,
+            hook_type="unpack_hook",
+        )
+        out = self.allocate_dummy()
+        self.bind_objects_to_proxies([out], [proxy])
+        return out
+
     def tensor_pre_hook(self, inputs, hook_id, i: int):
         assert self.hooks_proxy is not None
         hook = self.hooks_proxy[hook_id]  # type: ignore[index]
@@ -706,6 +730,9 @@ class AutogradCompilerInstance:
         after = len(self.fx_tracer.graph.nodes)
         verbose_log.debug("DCE removed %d nodes", before - after)

+    def create_graph_module(self, id):
+        return GraphModule(self.fx_tracer.root, self.fx_tracer.graph, id)
+
     def end_capture(self, outputs):
         self.fx_tracer.create_proxy(
             "call_function",
@@ -745,6 +772,7 @@ class AutogradCompilerInstance:
             ).print_readable(print_output=False),
         )
         self.rename_aot_dispatcher_nodes()
+        self.delay_unpack_hook_nodes()
         self.reorder_tensor_pre_hook_nodes()
         self.reorder_pre_hook_nodes_to_schedule_asap()
         self.reorder_accumulate_grad_nodes()
@@ -763,9 +791,7 @@ class AutogradCompilerInstance:
         # should prevent these ops from going into the CA graph.
         self.dce()

-        graph = GraphModule(
-            self.fx_tracer.root, self.fx_tracer.graph, f"CompiledAutograd{self.id}"
-        )
+        graph = self.create_graph_module(f"CompiledAutograd{self.id}")
         set_locals_to_steal(graph, ["inputs"])
         lazy_graph_code = lazy_format_graph_code(
             "Compiled autograd graph",
@@ -781,7 +807,7 @@ class AutogradCompilerInstance:
             payload_fn=lambda: graph.print_readable(print_output=False),
         )

-        def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks):
+        def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks, packed_inputs):
             global in_compiled_autograd_region
             try:
                 in_compiled_autograd_region = True
@@ -789,7 +815,7 @@ class AutogradCompilerInstance:
                     inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)

                 with _disable(), make_compile_context(self.id):
-                    return compiled_fn(inputs, sizes, scalars, hooks)
+                    return compiled_fn(inputs, sizes, scalars, hooks, packed_inputs)
             finally:
                 in_compiled_autograd_region = False

@@ -938,6 +964,19 @@ class AutogradCompilerInstance:
                 if getitem_node is not None:
                     arg.append(getitem_node)

+    def delay_unpack_hook_nodes(self):
+        """
+        We can delay unpack hooks until they are needed, even later than in the eager autograd engine.
+        """
+        for node in self.fx_tracer.graph.find_nodes(
+            op="call_function", target=call_hook
+        ):
+            if node.kwargs.get("hook_type", None) != "unpack_hook":
+                continue
+
+            first_user = min(node.users)
+            first_user.prepend(node)
+
     def reorder_tensor_pre_hook_nodes(self):
         """
         Usage of AOTAutograd causes all the tensor_pre_hook nodes to get pushed
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -1334,6 +1334,7 @@ auto Engine::execute(
         !AnomalyMode::is_enabled(),
         "compiled_autograd does not support AnomalyMode")
     GraphTaskGuard guard(graph_task);
+    CheckpointValidGuard cpvguard(graph_task);
     return (*compiled_autograd)(
         graph_root, *graph_task, accumulate_grad, outputs);
   }
--- a/torch/csrc/autograd/python_saved_variable_hooks.cpp
+++ b/torch/csrc/autograd/python_saved_variable_hooks.cpp
@@ -46,6 +46,15 @@ at::Tensor PySavedVariableHooks::call_unpack_hook() {
   // unpack_hook_ will be manually decrefed when the saved variable is released
 }

+std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
+PySavedVariableHooks::retrieve_unpack_hook_data() const {
+  Py_INCREF(unpack_hook_);
+  Py_INCREF(data_);
+  return std::make_pair(
+      c10::SafePyObject(unpack_hook_, getPyInterpreter()),
+      c10::SafePyObject(data_, getPyInterpreter()));
+}
+
 // NOLINTNEXTLINE(bugprone-exception-escape)
 PySavedVariableHooks::~PySavedVariableHooks() {
   // If python is already dead, leak the wrapped python objects
--- a/torch/csrc/autograd/python_saved_variable_hooks.h
+++ b/torch/csrc/autograd/python_saved_variable_hooks.h
@@ -1,6 +1,7 @@
 #pragma once

 #include <ATen/ATen.h>
+#include <c10/core/SafePyObject.h>
 #include <pybind11/pybind11.h>
 #include <torch/csrc/Export.h>
 #include <torch/csrc/autograd/python_variable.h>
@@ -17,6 +18,8 @@ struct PySavedVariableHooks : public SavedVariableHooks {
   void call_pack_hook(const at::Tensor& tensor) override;
   at::Tensor call_unpack_hook() override;
   ~PySavedVariableHooks() override;
+  std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
+  retrieve_unpack_hook_data() const override;

  private:
   PyObject* pack_hook_;
--- a/torch/csrc/autograd/saved_variable.cpp
+++ b/torch/csrc/autograd/saved_variable.cpp
@@ -59,6 +59,7 @@ SavedVariable::SavedVariable(
   if (maybe_hooks && !variable.unsafeGetTensorImpl()->is_wrapped_number()) {
     save_metadata(variable);
     set_hooks_and_pack_data(std::move(maybe_hooks), variable);
+    TORCH_INTERNAL_ASSERT(!data_.defined());
     return;
   }

@@ -134,9 +135,14 @@ Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
   // We want grad_fn here to provide the most helpful debug message to the user
   // if versions don't match

-  auto grad_fn = is_inplace_on_view_ ? weak_grad_fn_.lock()
-      : !hooks_ ? saved_original_ ? data_.grad_fn() : nullptr
-                : grad_fn_;
+  std::shared_ptr<Node> grad_fn;
+  if (is_inplace_on_view_) {
+    grad_fn = weak_grad_fn_.lock();
+  } else if (!hooks_) {
+    grad_fn = saved_original_ ? data_.grad_fn() : nullptr;
+  } else {
+    grad_fn = grad_fn_;
+  }

   if (!is_leaf_ && !grad_fn) {
     // This issue was introduced when we added logic to save the original
--- a/torch/csrc/autograd/saved_variable.h
+++ b/torch/csrc/autograd/saved_variable.h
@@ -1,5 +1,6 @@
 #pragma once

+#include <c10/core/SafePyObject.h>
 #include <torch/csrc/Export.h>
 #include <torch/csrc/autograd/forward_grad.h>
 #include <torch/csrc/autograd/saved_variable_hooks.h>
@@ -53,6 +54,15 @@ class TORCH_API SavedVariable {
     return (bool)hooks_;
   }

+  // Used by compiled autograd
+  std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
+  retrieve_unpack_hook_data() const {
+    if (!hooks_) {
+      return std::nullopt;
+    }
+    return hooks_->retrieve_unpack_hook_data();
+  }
+
  private:
   // This field contains either:
   // 1. the variable to save
--- a/torch/csrc/autograd/saved_variable_hooks.h
+++ b/torch/csrc/autograd/saved_variable_hooks.h
@@ -1,6 +1,7 @@
 #pragma once

 #include <ATen/core/Tensor.h>
+#include <c10/core/SafePyObject.h>

 namespace torch::autograd {

@@ -8,6 +9,11 @@ struct TORCH_API SavedVariableHooks {
   virtual void call_pack_hook(const at::Tensor& tensor) = 0;
   virtual at::Tensor call_unpack_hook() = 0;
   virtual ~SavedVariableHooks() = default;
+  virtual std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
+  retrieve_unpack_hook_data() const {
+    throw std::runtime_error(
+        "Compiled Autograd only supports python saved tensor hooks ");
+  }
 };

 } // namespace torch::autograd
--- a/torch/csrc/dynamo/compiled_autograd.h
+++ b/torch/csrc/dynamo/compiled_autograd.h
@@ -17,6 +17,78 @@
 namespace torch::dynamo::autograd {
 using namespace torch::autograd;

+// This is a layer of indirection for calling methods on the Python
+// AutogradCompilerInstance (referred to as the "py_compiler") from
+// libtorch_cpu (where Python is not available).
+// A PyCompilerInterfaceImpl in libtorch_python subclasses it and
+// overrides the methods to do the actual calls back to Python.
+struct TORCH_API PyCompilerInterface {
+  PyCompilerInterface() = default;
+  PyCompilerInterface(const PyCompilerInterface&) = delete;
+  PyCompilerInterface& operator=(const PyCompilerInterface&) = delete;
+  PyCompilerInterface(PyCompilerInterface&&) = delete;
+  PyCompilerInterface& operator=(PyCompilerInterface&&) = delete;
+  virtual ~PyCompilerInterface() = default;
+
+  // Invokes py_compiler.bind_function
+  virtual std::string bind_function(
+      PyObject* py_compiler,
+      const std::string& fn_name,
+      // NOLINTNEXTLINE(performance-unnecessary-value-param)
+      functional_apply_t fn,
+      // NOLINTNEXTLINE(performance-unnecessary-value-param)
+      std::vector<at::TypePtr> packed_args_schema,
+      bool is_custom_function = false,
+      bool is_traceable = true) {
+    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
+  }
+
+  // Invokes py_compiler.method_name(fn_name, inputs, packed_args,
+  // output_metadata)
+  virtual variable_list call_function(
+      PyObject* py_compiler,
+      const char* method_name,
+      const std::string& fn_name,
+      const variable_list& inputs,
+      const ivalue_list& packed_args,
+      const c10::IValue& output_metadata) {
+    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
+  }
+  virtual variable_list call_copy_slices_prologue(
+      PyObject* py_compiler,
+      const variable_list& inputs,
+      const at::TensorGeometry& base,
+      const at::TensorGeometry& view) {
+    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
+  }
+  virtual variable_list call_copy_slices_epilogue(
+      PyObject* py_compiler,
+      const std::vector<bool>& needs_input_grad,
+      const at::Tensor& result,
+      const variable_list& res,
+      const at::Tensor& grad_slice) {
+    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
+  }
+  virtual at::Tensor call_unpack(
+      PyObject* py_compiler,
+      std::optional<size_t> hook_id,
+      size_t hook_input_id) {
+    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
+  }
+};
+
+TORCH_API const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface();
+TORCH_API void setPyCompilerInterface(
+    std::unique_ptr<PyCompilerInterface>&& impl);
+TORCH_API void resetPyCompilerInterface();
+
+// including torch/csrc/autograd/engine.h breaks BC by somehow introducing
+// symbol resolution issues. Instead requiring downstream users to include
+// engine.h to access collect_input_metadata, we provide it here (with a
+// different name to avoid ambigous symbols...)
+TORCH_API std::vector<std::optional<InputMetadata>> get_input_metadata(
+    const edge_list& edges);
+
 struct SizeInput {
   // Note: int value is still needed when dynamic to pass as an arg
   enum DynType : uint8_t { STATIC = 0, DYNAMIC = 1 };
@@ -154,9 +226,14 @@ struct TensorArgs {
   }

   TensorArg& lookup(const SavedVariable& sv) {
-    auto it = _saved_variables.find(&sv);
-    TORCH_INTERNAL_ASSERT(it != _saved_variables.end());
-    return *it->second;
+    if (auto it = _saved_variables.find(&sv); it != _saved_variables.end()) {
+      // unpacked before graph
+      return *it->second;
+    }
+    // unpacked in graph
+    auto it2 = _saved_variables_proxies.find(&sv);
+    TORCH_INTERNAL_ASSERT(it2 != _saved_variables_proxies.end());
+    return *it2->second;
   }

   TensorArg& add(const at::Tensor& tensor) {
@@ -164,9 +241,7 @@ struct TensorArgs {
   }

   TensorArg& add(const SavedVariable& sv, const std::shared_ptr<Node>& node) {
-    // TODO(jansel): Here we unpack the SavedVariable exactly once. This might
-    // fire SavedTensor hooks. In the future we should try to put saved tensor
-    // hooks into the graph.
+    // no unpack hooks in this codepath
     at::Tensor tensor = sv.unpack(node);
     TensorArg& arg = add(tensor);
     _saved_variables.emplace(&sv, &arg);
@@ -185,6 +260,7 @@ struct TensorArgs {
   // Every TensorArg from this is actually owned by _args (or _undefined) and
   // that's why we have an un-owned pointer here.
   std::unordered_map<const SavedVariable*, TensorArg*> _saved_variables;
+  std::unordered_map<const SavedVariable*, TensorArg*> _saved_variables_proxies;
   TensorArg _undefined;
   uint32_t _next_id = 1; // id=0 used by _undefined
 };
@@ -245,6 +321,11 @@ struct AutogradCompilerCall {
     return hooks.size() - 1;
   }

+  size_t emplace_packed_input(c10::SafePyObject&& input) {
+    packed_inputs.emplace_back(std::move(input));
+    return packed_inputs.size() - 1;
+  }
+
   void set_active_node_call_idx(size_t node_call_idx) {
     active_node_call_idx = node_call_idx;
   }
@@ -255,10 +336,13 @@ struct AutogradCompilerCall {
   LiftedIValueArgs lifted_ivalue_args;
   std::vector<int64_t> dyn_size_inputs;
   std::vector<c10::SafePyObject> hooks;
+  std::vector<c10::SafePyObject> packed_inputs;
   NodeCalls node_calls;
   SizeInput::DynType default_dyn_type;
   // NodeCall id of each size, only when verbose logging is enabled
   std::vector<uint32_t> size_input_origins;
+  std::unordered_map<const SavedVariable*, std::pair<size_t, size_t>>
+      sv_to_hooks;
 };

 class CompiledNodeArgs {
@@ -285,8 +369,19 @@ class CompiledNodeArgs {
     collect(_compiler.tensor_args.add(t));
   }
   void collect(const SavedVariable& sv, bool is_output) {
-    collect(
-        _compiler.tensor_args.add(sv, is_output ? _node_call.node : nullptr));
+    if (auto hook_data = sv.retrieve_unpack_hook_data();
+        hook_data.has_value()) {
+      // hooks, unpack in graph
+      auto& [hook, packed_input] = hook_data.value();
+      size_t hook_id = _compiler.emplace_hook(std::move(hook));
+      // rely on dynamo to dedup packed tensors from unpacked tensors
+      size_t input_id = _compiler.emplace_packed_input(std::move(packed_input));
+      _compiler.sv_to_hooks.emplace(&sv, std::make_pair(hook_id, input_id));
+    } else {
+      // no hooks, unpack now
+      collect(
+          _compiler.tensor_args.add(sv, is_output ? _node_call.node : nullptr));
+    }
   }
   void collect(const c10::SymInt& t) {
     _compiler.add_size_input(t);
@@ -655,13 +750,26 @@ class SwapSavedVariables {
   }

   void before(SavedVariable& t) {
-    TensorArg& arg = compiler.tensor_args.lookup(t);
-    stashed_variables.save(&t, std::move(t));
-    if (arg.defined()) {
-      bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
-      TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
-      t = SavedVariable(arg.proxy_tensor, false);
-      at::SavedTensorDefaultHooks::set_tracing(prior);
+    if (auto it = compiler.sv_to_hooks.find(&t);
+        it != compiler.sv_to_hooks.end()) {
+      const auto& pyinterface =
+          torch::dynamo::autograd::getPyCompilerInterface();
+      auto proxy_tensor = pyinterface->call_unpack(
+          get_py_compiler(), it->second.first, it->second.second);
+      stashed_variables.save(&t, std::move(t));
+      bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
+      t = SavedVariable(proxy_tensor, false);
+      at::SavedTensorDefaultHooks::set_tracing(prior);
+    } else {
+      // no hooks, was already unpacked
+      TensorArg& arg = compiler.tensor_args.lookup(t);
+      stashed_variables.save(&t, std::move(t));
+      if (arg.defined()) {
+        bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
+        TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
+        t = SavedVariable(arg.proxy_tensor, false);
+        at::SavedTensorDefaultHooks::set_tracing(prior);
+      }
     }
   }
   void after(SavedVariable& t) {
@@ -1370,73 +1478,6 @@ struct PackedArgs {
   int64_t idx = 0;
 };

-// This is a layer of indirection for calling methods on the Python
-// AutogradCompilerInstance (referred to as the "py_compiler") from
-// libtorch_cpu (where Python is not available).
-// A PyCompilerInterfaceImpl in libtorch_python subclasses it and
-// overrides the methods to do the actual calls back to Python.
-struct TORCH_API PyCompilerInterface {
-  PyCompilerInterface() = default;
-  PyCompilerInterface(const PyCompilerInterface&) = delete;
-  PyCompilerInterface& operator=(const PyCompilerInterface&) = delete;
-  PyCompilerInterface(PyCompilerInterface&&) = delete;
-  PyCompilerInterface& operator=(PyCompilerInterface&&) = delete;
-  virtual ~PyCompilerInterface() = default;
-
-  // Invokes py_compiler.bind_function
-  virtual std::string bind_function(
-      PyObject* py_compiler,
-      const std::string& fn_name,
-      // NOLINTNEXTLINE(performance-unnecessary-value-param)
-      functional_apply_t fn,
-      // NOLINTNEXTLINE(performance-unnecessary-value-param)
-      std::vector<at::TypePtr> packed_args_schema,
-      bool is_custom_function = false,
-      bool is_traceable = true) {
-    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
-  }
-
-  // Invokes py_compiler.method_name(fn_name, inputs, packed_args,
-  // output_metadata)
-  virtual variable_list call_function(
-      PyObject* py_compiler,
-      const char* method_name,
-      const std::string& fn_name,
-      const variable_list& inputs,
-      const ivalue_list& packed_args,
-      const c10::IValue& output_metadata) {
-    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
-  }
-
-  virtual variable_list call_copy_slices_prologue(
-      PyObject* py_compiler,
-      const variable_list& inputs,
-      const at::TensorGeometry& base,
-      const at::TensorGeometry& view) {
-    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
-  }
-  virtual variable_list call_copy_slices_epilogue(
-      PyObject* py_compiler,
-      const std::vector<bool>& needs_input_grad,
-      const at::Tensor& result,
-      const variable_list& res,
-      const at::Tensor& grad_slice) {
-    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
-  }
-};
-
-TORCH_API const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface();
-TORCH_API void setPyCompilerInterface(
-    std::unique_ptr<PyCompilerInterface>&& impl);
-TORCH_API void resetPyCompilerInterface();
-
-// including torch/csrc/autograd/engine.h breaks BC by somehow introducing
-// symbol resolution issues. Instead requiring downstream users to include
-// engine.h to access collect_input_metadata, we provide it here (with a
-// different name to avoid ambigous symbols...)
-TORCH_API std::vector<std::optional<InputMetadata>> get_input_metadata(
-    const edge_list& edges);
-
 } // namespace torch::dynamo::autograd

 template <>
--- a/torch/csrc/dynamo/python_compiled_autograd.cpp
+++ b/torch/csrc/dynamo/python_compiled_autograd.cpp
@@ -203,6 +203,16 @@ struct PyCompilerInterfaceImpl : PyCompilerInterface {
     auto output = py::cast<std::vector<std::optional<at::Tensor>>>(stuff);
     return toTensorList(output);
   }
+  at::Tensor call_unpack(
+      PyObject* py_compiler,
+      std::optional<size_t> hook_id,
+      size_t hook_input_id) override {
+    py::handle handle(py_compiler);
+    py::object proxy = handle.attr("unpack_hook")(hook_id, hook_input_id);
+    auto tmp = py::cast<std::optional<at::Tensor>>(proxy);
+    TORCH_INTERNAL_ASSERT(tmp.has_value());
+    return tmp.value();
+  }
 };

 static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) {
@@ -213,7 +223,7 @@ static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) {
   return pyinput;
 }

-static PyObject* convert_hook_list(std::vector<c10::SafePyObject>& inputs) {
+static PyObject* convert_pyobj_list(std::vector<c10::SafePyObject>& inputs) {
   // inplace, consumes the input hooks
   PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(inputs.size()));
   for (const auto i : c10::irange(inputs.size())) {
@@ -654,7 +664,7 @@ static PyObject* wrap_string_list(const std::vector<std::string>& strs) {
   return pystrs;
 }

-std::string unwrap_string(PyObject* pystr) {
+static std::string unwrap_string(PyObject* pystr) {
   TORCH_INTERNAL_ASSERT(PyUnicode_Check(pystr));
   const char* str = PyUnicode_AsUTF8(pystr);
   TORCH_INTERNAL_ASSERT(str != nullptr);
@@ -796,7 +806,8 @@ static CacheNode* _compiled_autograd_impl(
     THPObjectPtr* graph_arg_inputs,
     THPObjectPtr* graph_arg_sizes,
     THPObjectPtr* graph_arg_ivalue_args,
-    THPObjectPtr* graph_arg_hooks) {
+    THPObjectPtr* graph_arg_hooks,
+    THPObjectPtr* graph_arg_packed_inputs) {
   std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
   std::vector<std::shared_ptr<Node>> worklist{graph_root};
   AutogradCompilerCall compiler_call(get_default_dyn_type());
@@ -1052,7 +1063,8 @@ static CacheNode* _compiled_autograd_impl(
   *graph_arg_sizes = wrap_int_list(compiler_call.dyn_size_inputs);
   *graph_arg_ivalue_args =
       wrap_lifted_ivalue_args(compiler_call.lifted_ivalue_args.args);
-  *graph_arg_hooks = convert_hook_list(compiler_call.hooks);
+  *graph_arg_hooks = convert_pyobj_list(compiler_call.hooks);
+  *graph_arg_packed_inputs = convert_pyobj_list(compiler_call.packed_inputs);
   return cache;
 }

@@ -1093,6 +1105,7 @@ static variable_list compiled_autograd(
   THPObjectPtr sizes;
   THPObjectPtr ivalue_args;
   THPObjectPtr hooks;
+  THPObjectPtr packed_inputs;
   CacheNode* cache = _compiled_autograd_impl(
       graph_root,
       graph_task,
@@ -1101,7 +1114,8 @@ static variable_list compiled_autograd(
       &inputs,
       &sizes,
       &ivalue_args,
-      &hooks);
+      &hooks,
+      &packed_inputs);

   THPObjectPtr pyresult(check(PyObject_CallFunctionObjArgs(
       cache->runtime_wrapper.get(),
@@ -1110,6 +1124,7 @@ static variable_list compiled_autograd(
       sizes.get(),
       ivalue_args.get(),
       hooks.get(),
+      packed_inputs.get(),
       NULL)));
   variable_list outputs = THPVariable_UnpackList(pyresult);
   TORCH_INTERNAL_ASSERT(outputs.size() == output_edges.size());