Compare commits

...

1 Commit

Author SHA1 Message Date
0aa576e373 foo 2024-10-15 22:06:38 -04:00
5 changed files with 405 additions and 129 deletions

View File

@@ -82,6 +82,77 @@ class AutogradCompilerInstance:
        def source(name, idx) -> GetItemSource:
            return GetItemSource(LocalSource(name), idx)

    def capture(self, inputs, sizes, scalars, origins, nodecalls, num_outputs):
        counters["compiled_autograd"]["captures"] += 1
        inputs_origins, sizes_origins, scalars_origins = origins
        self.fx_tracer.root = torch.nn.Module()
        self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
        self.fx_tracer.tensor_attrs = {}
        inputs_proxy, sizes_proxy, scalars_proxy, self.hooks_proxy = (
            self.fx_tracer.create_proxy("placeholder", name, (), {})
            for name in self.graph_placeholders
        )
        graph_outputs = [None] * num_outputs
        self.fx_tracer.create_proxy(
            kind="call_function",
            target=CA_input_buffers_init,
            args=(),
            kwargs={},
        )
        # Offsets into the flat saved-tensor/size/scalar placeholders; they
        # must persist across nodes so each node consumes its own slice.
        inputs_idx = 0
        sizes_idx = 0
        scalars_idx = 0
        for node_idx, call in enumerate(nodecalls):
            input_buffer = self.fx_tracer.create_proxy(
                kind="call_function",
                target=CA_input_buffers_lookup,
                args=(node_idx,),
                kwargs={},
            )
            num_saved_inputs = call.num_saved_tensors
            num_saved_sizes = call.num_saved_sizes
            num_saved_scalars = call.num_saved_ivalues
            saved_inputs = inputs_proxy[inputs_idx : inputs_idx + num_saved_inputs]
            saved_sizes = sizes_proxy[sizes_idx : sizes_idx + num_saved_sizes]
            saved_scalars = scalars_proxy[scalars_idx : scalars_idx + num_saved_scalars]
            inputs_idx += num_saved_inputs
            sizes_idx += num_saved_sizes
            scalars_idx += num_saved_scalars
            for input_nr, result_idx in call.graph_output:
                graph_outputs[result_idx] = input_buffer[input_nr]
            if not call.needed:
                continue
            outputs = self.fx_tracer.create_proxy(
                kind="call_function",
                target=CA_apply_with_saved,
                args=(node_idx, input_buffer, saved_inputs, saved_sizes, saved_scalars),
                kwargs={},
            )
            self.fx_tracer.create_proxy(
                kind="call_function",
                target=CA_validate_outputs,
                args=(node_idx, outputs),
                kwargs={},
            )
            self.fx_tracer.create_proxy(
                kind="call_function",
                target=CA_update_input_buffers,
                args=(node_idx, outputs),
                kwargs={},
            )
        return self.end_capture(graph_outputs)

    def begin_capture(
        self,
        inputs: List[torch.Tensor],
@@ -308,8 +379,8 @@ class AutogradCompilerInstance:
            (self.fx_tracer.create_arg(self.to_proxy(outputs)),),
            {},
        )
-        self.rename_aot_dispatcher_nodes()
-        self.reorder_accumulate_grad_nodes()
+        # self.rename_aot_dispatcher_nodes()
+        # self.reorder_accumulate_grad_nodes()
        runtime_inputs_to_move: List[int] = []
        if snapshot_cudagraph_enabled():
            runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
@@ -562,3 +633,106 @@ def reset() -> None:
    assert not in_compiled_autograd_region
    torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
    torch._C._dynamo.compiled_autograd.set_verbose_logger(None)


global_input_buffers = None
global_nodecalls = None


def get_node(idx):
    return global_nodecalls[idx].node


def set_global_nodecalls(nodecalls):
    global global_nodecalls
    global_nodecalls = nodecalls


# @torch._dynamo.allow_in_graph
@torch.fx.wrap
def CA_input_buffers_init():
    global global_input_buffers
    global_input_buffers = InputBuffers()


# @torch._dynamo.allow_in_graph
@torch.fx.wrap
def CA_input_buffers_lookup(node_idx):
    node = get_node(node_idx)
    result = global_input_buffers.lookup(node).buffer
    breakpoint()
    return result


import torch._C._autograd as _ca


# @torch._dynamo.allow_in_graph
@torch.fx.wrap
def CA_apply_with_saved(node_idx, inputs, saved_tensors, saved_sizes, saved_scalars):
    # Replay the node with the real saved values via the C++ helper bound as
    # torch._C._autograd.apply_with_saved (SwapWithReal mode).
    node_call = global_nodecalls[node_idx]
    outputs = _ca.apply_with_saved(
        node_call, inputs, saved_tensors, saved_sizes, saved_scalars
    )
    breakpoint()
    return outputs


# @torch._dynamo.allow_in_graph
@torch.fx.wrap
def CA_validate_outputs(node_idx, outputs):
    # TODO(rzou): wire this up to validate_outputs below
    breakpoint()


# @torch._dynamo.allow_in_graph
@torch.fx.wrap
def CA_update_input_buffers(node_idx, outputs):
    node = get_node(node_idx)
    for output_idx, output in enumerate(outputs):
        next_edge = node.next_edge(output_idx)
        if next_edge.is_valid() and output is not None:
            global_input_buffers.lookup(next_edge.function.get()).add(
                next_edge.input_nr, output
            )


class InputBuffers:
    def __init__(self):
        self.dct = {}

    def lookup(self, node):
        if node not in self.dct:
            self.dct[node] = InputBuffer(node.num_inputs())
        return self.dct[node]

    def get(self, node):
        return self.dct[node]


class InputBuffer:
    def __init__(self, size):
        self.buffer = [None] * size

    def __getitem__(self, pos):
        return self.buffer[pos]

    def add(self, pos, var):
        if var is None:
            return
        old_var = self.buffer[pos]
        if old_var is None:
            self.buffer[pos] = var
        else:
            accumulate(self.buffer, pos, var)


def accumulate(buffer, pos, var):
    # TODO(rzou): some more stuff here
    buffer[pos] = buffer[pos] + var


def validate_outputs(edges, grads, format_error):
    if len(grads) != len(edges):
        raise ValueError(
            f"Invalid number of gradients - expected {len(edges)}, but got {len(grads)}"
        )
    # TODO(rzou): some more stuff here
    for idx, grad in enumerate(grads):
        edge = edges[idx]
        if not edge.is_valid():
            continue
        if grad is None:
            continue
        metadata = edge.function.input_metadata(edge.input_nr)
        grads[idx] = metadata.maybe_reduce(idx, grad, format_error)
    # TODO(rzou): some more stuff here
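
For orientation: capture() above traces one CA_input_buffers_init call, then a CA_input_buffers_lookup / CA_apply_with_saved / CA_validate_outputs / CA_update_input_buffers sequence per needed node. A rough sketch of the FX graph this would produce for a hypothetical two-node backward (illustrative only; node names, arities, and slice bounds are made up):

    # Illustrative only -- roughly what the traced graph's generated code
    # might look like for two nodes that each save one tensor and one size:
    #
    # def forward(self, inputs, sizes, scalars, hooks):
    #     CA_input_buffers_init()
    #     buf0 = CA_input_buffers_lookup(0)
    #     out0 = CA_apply_with_saved(0, buf0, inputs[0:1], sizes[0:1], scalars[0:0])
    #     CA_validate_outputs(0, out0)
    #     CA_update_input_buffers(0, out0)
    #     buf1 = CA_input_buffers_lookup(1)
    #     out1 = CA_apply_with_saved(1, buf1, inputs[1:2], sizes[1:2], scalars[0:0])
    #     CA_validate_outputs(1, out1)
    #     CA_update_input_buffers(1, out1)
    #     return (buf1[0],)

And a small self-contained sanity check of the InputBuffer semantics above (plain Python, no torch required):

    # First write to a slot stores the value; later writes accumulate; None is skipped.
    buf = InputBuffer(2)
    buf.add(0, 1.0)
    buf.add(0, 2.0)   # accumulate() sums into slot 0
    buf.add(1, None)  # None gradients are ignored
    assert buf[0] == 3.0 and buf[1] is None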

View File

@@ -3,6 +3,7 @@
#include <c10/util/ThreadLocal.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/dynamo/compiled_autograd.h>
#include <ATen/ATen.h>

View File

@@ -1,4 +1,5 @@
#include <torch/csrc/python_headers.h>
#include <memory>
#include <ATen/PythonTorchFunctionTLS.h>
#include <ATen/SavedTensorHooks.h>
@@ -35,6 +36,7 @@
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/python_raii.h>
#include <torch/csrc/utils/python_torch_function_mode.h>
#include <torch/csrc/dynamo/compiled_autograd.h>
#include <set>
#include <unordered_set>
@@ -42,6 +44,7 @@
using torch::impl::py_context_manager;
using torch::impl::py_context_manager_DEPRECATED;
using namespace torch::dynamo::autograd;
namespace {
@@ -79,6 +82,51 @@ struct EnablePythonDispatcher {
  c10::impl::PyInterpreter* old_;
};

variable_list apply_with_saved314(
    const NodeCall& nodecall,
    const variable_list& inputs,
    const std::vector<at::Tensor>& saved_tensors,
    const std::vector<at::SymInt>& saved_sizes,
    const std::vector<at::IValue>& saved_ivalues) {
  // Replay the node with the real saved values (SwapWithReal mode).
  auto swap_saved_variable = SwapSavedVariables(
      saved_tensors, saved_sizes, saved_ivalues, nullptr, nodecall);
  return nodecall.node->apply_with_saved(inputs, swap_saved_variable);
}
/*
struct PySwapInterface : public SwapInterface {
 public:
  using SwapInterface::SwapInterface; // Inherit constructors

  std::optional<at::Tensor> tensor(const at::Tensor& tensor) override {
    PYBIND11_OVERRIDE_PURE(
        std::optional<at::Tensor>, // Return type
        SwapInterface,             // Parent class
        tensor,                    // Function name
        tensor                     // Argument
    );
  }
  std::optional<at::Tensor> tensor(const SavedVariable& sv) override {
    PYBIND11_OVERRIDE_PURE(
        std::optional<at::Tensor>, // Return type
        SwapInterface,             // Parent class
        tensor,                    // Function name
        sv                         // Argument
    );
  }
  std::optional<c10::SymInt> next_size() override {
    PYBIND11_OVERRIDE_PURE(
        std::optional<c10::SymInt>, // Return type
        SwapInterface,              // Parent class
        next_size                   // Function name
    );
  }
  c10::IValue next_ivalue() override {
    PYBIND11_OVERRIDE_PURE(
        c10::IValue,   // Return type
        SwapInterface, // Parent class
        next_ivalue    // Function name
    );
  }
};
*/

struct EnablePreDispatch {
  EnablePreDispatch() : guard_(c10::DispatchKey::PreDispatch) {}
  c10::impl::IncludeDispatchKeyGuard guard_;
@@ -491,6 +539,32 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
    }
  });

  // compiled_autograd stuff
  py::class_<torch::autograd::Node, std::shared_ptr<torch::autograd::Node>>(m, "Node")
      .def("compiled_args", &torch::autograd::Node::compiled_args)
      .def("num_inputs", &torch::autograd::Node::num_inputs);
  py::class_<torch::dynamo::autograd::NodeCall>(m, "NodeCall")
      .def_readonly("node", &NodeCall::node)
      .def_readonly("num_saved_tensors", &NodeCall::num_saved_tensors)
      .def_readonly("num_saved_sizes", &NodeCall::num_saved_sizes)
      .def_readonly("num_saved_ivalues", &NodeCall::num_saved_ivalues)
      .def_readonly("graph_output", &NodeCall::graph_output)
      .def_readonly("needed", &NodeCall::needed);
  py::class_<torch::dynamo::autograd::CompiledNodeArgs>(m, "CompiledNodeArgs")
      .def(py::init<AutogradCompilerCall&, NodeCall&>());
  py::class_<torch::dynamo::autograd::AutogradCompilerCall>(m, "AutogradCompilerCall")
      .def(py::init<>());
  // Bound on the _autograd submodule so Python's
  // `import torch._C._autograd as _ca` can reach it.
  m.def("apply_with_saved", &apply_with_saved314);
  // py::class_<SwapInterface, PySwapInterface>(m, "SwapInterface");
  // py::class_<SwapWithReal, SwapInterface>(m, "SwapWithReal")
  //     .def(py::init<std::vector<at::Tensor>, std::vector<c10::SymInt>, std::vector<c10::IValue>>());
  // py::class_<SwapSavedVariables>(m, "SwapSavedVariables")
  //     .def(py::init<std::vector<at::Tensor>, std::vector<c10::SymInt>, std::vector<c10::IValue>, PyObject*, const NodeCall&>());

  _C_m.def("_activate_gpu_trace", []() { activateGPUTrace(); });
  py_context_manager_DEPRECATED<c10::InferenceMode, bool>(
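
A hedged sketch of how these bindings line up with the Python side above (the `calls` list and the saved-value arguments are placeholders for whatever the capture path produced):

    import torch._C._autograd as _ca

    # `calls` is the List[NodeCall] that C++ passes to set_global_nodecalls.
    call = calls[0]
    print(call.needed, call.num_saved_tensors, call.node.num_inputs())

    # Replay one node with real saved values (SwapWithReal under the hood):
    outputs = _ca.apply_with_saved(
        call, inputs, saved_tensors, saved_sizes, saved_ivalues
    )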

View File

@@ -69,7 +69,7 @@ struct CacheKey {
  const uint8_t* key;
};

-struct NodeCall {
+struct TORCH_API NodeCall {
  NodeCall(uint32_t id_, std::shared_ptr<Node> node_)
      : id(id_), node(std::move(node_)) {}
@@ -84,6 +84,9 @@ struct NodeCall {
  std::vector<int> post_hooks;
  std::vector<int> post_acc_grad_hooks;
  std::vector<std::pair<int, int>> graph_output;
  int num_saved_tensors = 0;
  int num_saved_sizes = 0;
  int num_saved_ivalues = 0;
  bool needed = true;
};
@@ -137,15 +140,17 @@ struct TensorArgs {
      : active_node_call_idx(active_node_call_idx) {}

  TensorArg& lookup(const at::Tensor& tensor, bool create = false) {
+   // unconditionally add the tensor to inputs... Dynamo will de-dupe them later
+   inputs.emplace_back(tensor);
    if (!tensor.defined()) {
      return _undefined;
    }
    auto impl = tensor.unsafeGetTensorImpl();
    auto it = _args.find(impl);
    if (it == _args.end()) {
-     TORCH_INTERNAL_ASSERT(create && inputs.size() == _next_id - 1);
+     // TORCH_INTERNAL_ASSERT(create && inputs.size() == _next_id - 1);
      it = _args.emplace(impl, TensorArg(_next_id++)).first;
-     inputs.emplace_back(tensor);
+     // inputs.emplace_back(tensor);
      if (active_node_call_idx.has_value()) {
        input_origins.emplace_back(active_node_call_idx.value());
      }
@@ -208,6 +213,11 @@ struct LiftedIValueArgs {
    return iv_arg.proxy;
  }

  at::IValue& next_proxy() {
    auto& iv_arg = args.at(next++);
    return iv_arg.proxy;
  }

  void add(const at::IValue* iv) {
    args.emplace_back(iv);
    if (active_node_call_idx.has_value()) {
@@ -266,6 +276,7 @@ class CompiledNodeArgs {
  // key.
 public:
  void collect(const TensorArg& t) {
    _node_call.num_saved_tensors++;
    collect_size(t.id);
    if (t.defined()) {
      const at::Tensor& tensor = _compiler.tensor_args.inputs[t.index()];
@@ -366,6 +377,7 @@ class CompiledNodeArgs {
        !nested &&
        (iv.isInt() || iv.isSymInt() || iv.isDouble() || iv.isSymFloat())) {
      // can't lift ivalues nested in collections
      _node_call.num_saved_ivalues++;
      _compiler.lifted_ivalue_args.add(&iv);
    } else {
      try {
@@ -550,6 +562,7 @@ class CompiledNodeArgs {
  // -Wunknown-pragmas
  template <typename T>
  std::enable_if_t<std::is_unsigned_v<T>, void> collect_size(T s) {
    _node_call.num_saved_sizes++;
    // we expect sizes to be small, so try to cram them into a single byte
    constexpr uint8_t encode_as_u64 = std::numeric_limits<uint8_t>::max();
    constexpr uint8_t encode_as_u32 = encode_as_u64 - 1;
@@ -629,17 +642,93 @@ struct TraceState {
  variable_list outputs;
};

// Strategy interface for what a saved tensor/size/ivalue should be swapped
// with while a node's apply_with_saved runs.
struct TORCH_API SwapInterface {
  virtual ~SwapInterface() = default;
  virtual std::optional<at::Tensor> tensor(const at::Tensor& tensor) = 0;
  virtual std::optional<at::Tensor> tensor(const SavedVariable& tensor) = 0;
  virtual std::optional<c10::SymInt> next_size() = 0;
  virtual c10::IValue next_ivalue() = 0;
};

// Tracing mode: swap saved values for FX proxies.
struct SwapWithProxies : public SwapInterface {
  explicit SwapWithProxies(AutogradCompilerCall& compiler, TraceState& state)
      : compiler_(compiler), state_(state) {}

  std::optional<at::Tensor> tensor(const at::Tensor& tensor) override {
    TensorArg& arg = compiler_.tensor_args.lookup(tensor);
    if (arg.defined()) {
      TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
      return arg.proxy_tensor;
    }
    return std::nullopt;
  }

  std::optional<at::Tensor> tensor(const SavedVariable& t) override {
    TensorArg& arg = compiler_.tensor_args.lookup(t);
    if (arg.defined()) {
      return arg.proxy_tensor;
    }
    return std::nullopt;
  }

  std::optional<c10::SymInt> next_size() override {
    return state_.next_sym_size();
  }

  c10::IValue next_ivalue() override {
    return compiler_.lifted_ivalue_args.next_proxy();
  }

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  AutogradCompilerCall& compiler_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  TraceState& state_;
};

// Replay mode: swap saved values for previously collected real values,
// consumed in the same order they were collected.
struct SwapWithReal : public SwapInterface {
  explicit SwapWithReal(
      std::vector<at::Tensor> tensors,
      std::vector<c10::SymInt> sizes,
      std::vector<c10::IValue> ivalues)
      : tensors_(std::move(tensors)),
        sizes_(std::move(sizes)),
        ivalues_(std::move(ivalues)) {}

  std::optional<at::Tensor> tensor(const at::Tensor& /*unused*/) override {
    return tensors_[tensors_idx++];
  }

  std::optional<at::Tensor> tensor(const SavedVariable& /*unused*/) override {
    return tensors_[tensors_idx++];
  }

  std::optional<c10::SymInt> next_size() override {
    return sizes_[sizes_idx++];
  }

  c10::IValue next_ivalue() override {
    return ivalues_[ivalues_idx++];
  }

  std::vector<at::Tensor> tensors_;
  int64_t tensors_idx = 0;
  std::vector<c10::SymInt> sizes_;
  int64_t sizes_idx = 0;
  std::vector<c10::IValue> ivalues_;
  int64_t ivalues_idx = 0;
};
class SwapSavedVariables {
  // SwapSavedVariables is used during the tracing/compilation phase after a
  // cache-miss. It swaps any 'lifted' inputs (tensors, symints) to proxy nodes,
  // allows tracing to happen, then swaps them back afterwards.
 public:
  void before(at::Tensor& t) {
-   TensorArg& arg = compiler.tensor_args.lookup(t);
+   auto replacement = state->tensor(t);
    stashed_tensors.save(&t, std::move(t));
-   if (arg.defined()) {
-     TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
-     t = arg.proxy_tensor;
+   if (replacement.has_value()) {
+     t = *replacement;
    }
  }
  void after(at::Tensor& t) {
@@ -647,12 +736,11 @@ class SwapSavedVariables {
  }

  void before(SavedVariable& t) {
-   TensorArg& arg = compiler.tensor_args.lookup(t);
+   auto replacement = state->tensor(t);
    stashed_variables.save(&t, std::move(t));
-   if (arg.defined()) {
+   if (replacement.has_value()) {
      bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
-     TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
-     t = SavedVariable(arg.proxy_tensor, false);
+     t = SavedVariable(replacement.value(), false);
      at::SavedTensorDefaultHooks::set_tracing(prior);
    }
  }
@@ -662,7 +750,7 @@ class SwapSavedVariables {
  void before(c10::SymInt& t) {
    stashed_symints.save(&t, c10::SymInt(t));
-   auto opt_value = state.next_sym_size();
+   auto opt_value = state->next_size();
    if (opt_value.has_value()) {
      t = *opt_value; // dynamic shape
    }
@@ -677,7 +765,7 @@ class SwapSavedVariables {
    } else {
      stashed_ivalues.save(&iv, at::IValue(iv));
      if (iv.isInt() || iv.isSymInt() || iv.isDouble() || iv.isSymFloat()) {
-       iv = compiler.lifted_ivalue_args.next_proxy(&iv);
+       iv = state->next_ivalue();
      }
    }
  }
@@ -824,7 +912,17 @@ class SwapSavedVariables {
      TraceState& s,
      PyObject* p,
      const NodeCall& n)
-     : compiler(c), state(s), py_compiler(p), curr_node_call(n) {}
+     : py_compiler(p), curr_node_call(n) {
+   state = std::make_shared<SwapWithProxies>(c, s);
+ }
+
+ SwapSavedVariables(
+     std::vector<at::Tensor> a,
+     std::vector<at::SymInt> b,
+     std::vector<at::IValue> c,
+     PyObject* p,
+     const NodeCall& n)
+     : state(std::make_shared<SwapWithReal>(std::move(a), std::move(b), std::move(c))),
+       py_compiler(p),
+       curr_node_call(n) {}

  PyObject* get_py_compiler() {
    return py_compiler;
@@ -875,9 +973,10 @@ class SwapSavedVariables {
  };

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
- AutogradCompilerCall& compiler;
+ // AutogradCompilerCall& compiler;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
- TraceState& state;
+ std::shared_ptr<SwapInterface> state;
+ // TraceState& state;
  // This is a borrowed reference; we do not increment or decrement ownership.
  // Its lifecycle is strictly longer than this object's.
  PyObject* py_compiler;
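
For orientation, a minimal sketch of how the two constructors select a swapping strategy (hypothetical calling code, mirroring apply_with_saved314 and the old tracing loop; not part of this diff):

    // Tracing (cache miss): saved values are swapped for FX proxies.
    SwapSavedVariables tracing(compiler_call, state, py_compiler, call);
    variable_list traced = call.node->apply_with_saved(inputs, tracing);

    // Replay: collected tensors/sizes/ivalues are swapped back in, consumed
    // in collection order by SwapWithReal.
    SwapSavedVariables replay(
        saved_tensors, saved_sizes, saved_ivalues, /*py_compiler=*/nullptr, call);
    variable_list replayed = call.node->apply_with_saved(inputs, replay);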

View File

@@ -451,6 +451,32 @@ void set_ivalue_proxies(
  }
}

static PyObject* call_capture(
    PyObject* self,
    CacheNode& cache,
    AutogradCompilerCall& compiler_call,
    size_t num_outputs,
    PyObject* nodecalls) {
  static PyObject* method_name = PyUnicode_InternFromString("capture");
  THPObjectPtr pyinput(THPVariable_WrapList(compiler_call.tensor_args.inputs));
  THPObjectPtr pysizeinput(cache.wrap_dynamic_inputs());
  THPObjectPtr pyivalueargsinput(
      wrap_lifted_ivalue_args(compiler_call.lifted_ivalue_args.args));
  THPObjectPtr pynodeorigins(
      wrap_node_origins(compiler_call, PyTuple_GET_SIZE(pysizeinput.get())));
  // Own the packed int so it is not leaked on return.
  THPObjectPtr py_num_outputs(THPUtils_packUInt32(num_outputs));
  return check(PyObject_CallMethodObjArgs(
      self,
      method_name,
      pyinput.get(),
      pysizeinput.get(),
      pyivalueargsinput.get(),
      pynodeorigins.get(),
      nodecalls,
      py_num_outputs.get(),
      nullptr));
}

static TraceState call_begin_capture(
    PyObject* self,
    CacheNode& cache,
@@ -600,112 +626,10 @@ CacheNode* _compiled_autograd_impl(
    ClosingTHPObjectPtr py_compiler(
        check(PyObject_CallNoArgs((the_autograd_compiler))));

-   TraceState state = call_begin_capture(
-       py_compiler, *cache, compiler_call, output_edges.size());
-   InputBuffers input_buffers;
+   // nodes
+   py::object nodecalls = py::cast(calls);
+   PyObject* res = call_capture(
+       py_compiler, *cache, compiler_call, output_edges.size(), nodecalls.ptr());
-   for (size_t i = 0; i < calls.size(); i++) {
-     NodeCall& call = *calls[i];
-     // TODO(jansel): consider adding some of this stuff:
-     // guard(local_graph_task); NodeGuard ndguard(task.fn_); const auto
-     // opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
-     // c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
-     // CheckpointValidGuard cpvguard(graph_task);
-     // at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION);
-     // if (C10_UNLIKELY(step_callbacks.has_value())) { ... }
-     variable_list inputs =
-         std::move(input_buffers.lookup(call.node.get()).buffer);
-     input_buffers.erase(call.node.get());
-     if (!call.tensor_pre_hooks.empty()) {
-       THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
-       for (const auto& hook : call.tensor_pre_hooks) {
-         pyinputs = check(PyObject_CallMethod(
-             py_compiler,
-             "tensor_pre_hook",
-             "Oii",
-             pyinputs.get(),
-             hook.first,
-             hook.second));
-       }
-       inputs = THPVariable_UnpackList(pyinputs);
-     }
-     for (const auto& graph_output : call.graph_output) {
-       int input_nr = graph_output.first;
-       int output_index = graph_output.second;
-       TORCH_INTERNAL_ASSERT(
-           output_index < static_cast<int>(state.outputs.size()));
-       TORCH_INTERNAL_ASSERT(!state.outputs[output_index].defined());
-       state.outputs[output_index] = inputs[input_nr];
-     }
-     if (!call.needed) {
-       continue;
-     }
-     if (!call.pre_hooks.empty()) {
-       THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
-       for (const auto hook : call.pre_hooks) {
-         pyinputs = check(PyObject_CallMethod(
-             py_compiler.get(), "pre_hook", "Oi", pyinputs.get(), hook));
-       }
-       inputs = THPVariable_UnpackList(pyinputs);
-     }
-     std::string _node_name = call.node->name();
-     THPObjectPtr node_name(PyUnicode_FromString(_node_name.data()));
-     TORCH_INTERNAL_ASSERT(node_name != nullptr);
-     THPObjectPtr set_node_origin(
-         PyObject_GetAttrString(py_compiler.get(), "set_node_origin"));
-     PyObject* pyobj = Py_None;
-     if (auto pynode = std::dynamic_pointer_cast<PyNode>(call.node)) {
-       pyobj = pynode->obj;
-     }
-     check(PyObject_CallFunction(
-         set_node_origin, "OIO", node_name.get(), i, pyobj, nullptr));
-     SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call);
-     variable_list outputs = call.node->apply_with_saved(inputs, saved);
-     saved.debug_asserts();
-     saved.before(call.node->next_edges());
-     validate_outputs(
-         call.node->next_edges(), outputs, [&](const std::string& msg) {
-           std::ostringstream ss;
-           ss << "[Compiled Autograd Tracing: " << call.node->name() << "] "
-              << msg;
-           return ss.str();
-         });
-     saved.after(call.node->next_edges());
-     saved.debug_asserts();
-     if (!call.post_hooks.empty()) {
-       THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
-       THPObjectPtr pyoutputs(THPVariable_WrapList(outputs));
-       for (const auto hook : call.post_hooks) {
-         pyoutputs = check(PyObject_CallMethod(
-             py_compiler.get(),
-             "post_hook",
-             "OOi",
-             pyoutputs.get(),
-             pyinputs.get(),
-             hook));
-       }
-       outputs = THPVariable_UnpackList(pyoutputs);
-     }
-     for (const auto i : c10::irange(outputs.size())) {
-       auto& output = outputs[i];
-       const auto& next = call.node->next_edge(i);
-       if (next.is_valid() && output.defined()) {
-         input_buffers.lookup(next.function.get())
-             .add(next.input_nr, std::move(output), std::nullopt, std::nullopt);
-       }
-     }
-   }
-   PyObject* res = check(call_end_capture(py_compiler, state.outputs));
    TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple");
    TORCH_CHECK(
        PyTuple_Size(res) == 2,
@@ -718,15 +642,19 @@ CacheNode* _compiled_autograd_impl(
    TORCH_CHECK(
        PyCallable_Check(cache->compiled_fn),
        "Expected end_capture to return compiled_fn");
-   state.debug_asserts();
+   // TODO(rzou): what is this?
+   // state.debug_asserts();
  } // End cache miss region

+ // TODO(rzou): need some mechanism to release the variables when we're ready.
  // TODO(jansel): clear grads we will overwrite below
- if (!graph_task.keep_graph_) {
-   for (auto& call : calls) {
-     call->node->release_variables();
-   }
- }
+ // if (!graph_task.keep_graph_) {
+ //   for (auto& call : calls) {
+ //     call->node->release_variables();
+ //   }
+ // }
+ auto ca = py::module::import("torch._dynamo.compiled_autograd");
+ ca.attr("set_global_nodecalls")(calls);

  *graph_arg_inputs = THPVariable_WrapList(compiler_call.tensor_args.inputs);
  *graph_arg_sizes = wrap_int_list(compiler_call.dyn_size_inputs);
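
In sum: the cache-miss path no longer traces node-by-node in C++. It hands the NodeCall list to Python's capture(), which emits the CA_* ops into the FX graph; at runtime those ops replay each node through torch._C._autograd.apply_with_saved. A condensed sketch of the split (names from this diff; ordering simplified):

    # Capture (cache miss, traced once):
    #   res = py_compiler.capture(inputs, sizes, scalars, origins, calls, num_outputs)
    #   set_global_nodecalls(calls)   # so the CA_* wrappers can find their nodes
    # Replay (every backward):
    #   CA_input_buffers_init / CA_input_buffers_lookup manage InputBuffers
    #   CA_apply_with_saved -> torch._C._autograd.apply_with_saved (SwapWithReal)
    #   CA_update_input_buffers routes outputs along next_edges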