Compare commits

1 commit
e5937dc68c [wip] "Python compiled autograd II"
Today, compiled autograd runs in two phases:
- a make_fx-like phase that uses FakeTensors + fx.Proxy
  to create an fx.Graph from the current autograd graph
- a second phase that applies torch.compile to the result of
  the previous phase.

This PR changes it so that compiled autograd no longer uses FakeTensors in
its first phase.

At a high level:
- [Here's an example of the new graph](https://gist.github.com/zou3519/20272a3e31124621843f53ae66671ed7)
  that compiled autograd's first phase produces.
- In order to acquire this graph, we get compiled autograd to effectively
  torch.fx.symbolic_trace over a new `python_autograd` function that runs the
  autograd graph.
- The graph contains calls to `apply_with_saved`, which is a way to apply a
  given node with some inputs and some specific saved values. This is different
  from the existing `Node::apply_with_saved` because that one accepts
  the saved values for the *entire graph*.
- There are also calls to `validate_outputs`, which also needs some
  saved values because it needs to swizzle out input metadata state.
- We support graph breaks on unsupported C++ custom ops by emitting
  a special `apply_with_saved_dynamo_disabled` function. The state of a
  C++ torch::autograd::Function is fully enumerable by us, since
  we ask users to only save values via `ctx->save_for_backward` and
  `ctx->saved_data[...]`.
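As a rough, self-contained illustration of the tracing trick (this is not the PR's code; `apply_with_saved` and `validate_outputs` below are simplified stand-ins for the helpers described above), registering the per-node helpers as fx leaf functions and symbolic-tracing an interpreter-style runner yields a graph of opaque calls, similar in shape to the gist above:

```python
# Minimal sketch (assumed torch.fx usage; not part of this PR).
import torch
import torch.fx


def apply_with_saved(node_idx, inputs, saved_tensors):
    # Stand-in for "apply autograd node `node_idx` to `inputs` with these saved values".
    return [inputs[0] * saved_tensors[0]]


def validate_outputs(node_idx, outputs):
    # Stand-in for output metadata validation.
    return outputs


# Register the helpers as fx "leaf functions" so tracing records calls to them
# instead of tracing through their bodies.
torch.fx.wrap("apply_with_saved")
torch.fx.wrap("validate_outputs")


def runner(inputs, saved_tensors):
    outs = apply_with_saved(0, inputs, saved_tensors)
    outs = validate_outputs(0, outs)
    return outs


gm = torch.fx.symbolic_trace(runner)
print(gm.graph)  # apply_with_saved / validate_outputs show up as call_function nodes
```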

There's a long tail of things that don't work yet:
- we don't support all types of hooks yet
- we don't inline user-defined autograd.Function into this graph yet
- we don't inline the backward of torch.compile'd regions
- we need to somehow free the autograd graph when we're done with it
- many more TODOs inline.

ghstack-source-id: 23a98023d271db220a29db66631e9087fb8e2325
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138101
2024-10-17 19:09:43 -04:00
12 changed files with 744 additions and 142 deletions

r2.py (new file)

@ -0,0 +1,140 @@
# type: ignore
import torch
import torch.utils.cpp_extension
def compiler_fn(gm):
    # return gm
    return torch.compile(gm, backend="eager", fullgraph=False)
# ===========================================================
# Basic test with a hook that has side effects
# Test case 1: a hook
x = torch.tensor([1., 2., 3.], requires_grad=True)
y = x ** 2
z = y.sum()
im_grad = []
def hook(grad):
    im_grad.append(grad)
    return 2 * grad
y.register_hook(hook)
with torch._dynamo.compiled_autograd.enable(compiler_fn):
    z.backward()
assert torch.allclose(x.grad, 4 * x)
assert torch.allclose(im_grad[0], torch.ones_like(y))
# ===========================================================
# Unsupported C++ autograd node should graph break.
# This is better than the current compiled autograd behavior of "error out"
# and brings us a step closer to having "compiled autograd on by default".
# In theory we can also add a config that automatically treats
# it as an opaque callable, but such a config is unsound.
cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
  static constexpr bool is_traceable = false;

  static torch::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& x) {
    return x;
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext *ctx,
      torch::autograd::variable_list grad_output) {
    // not traceable
    *grad_output[0].data_ptr<float>() = 3.14;
    return grad_output;
  }
};

torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) {
  return CustomOpAutogradFunction::apply(x);
}

TORCH_LIBRARY_FRAGMENT(mylib, m) {
  m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
}
"""
module = torch.utils.cpp_extension.load_inline(
    name="mylib",
    cpp_sources=cpp_source,
    functions="custom_op_backed_by_autograd_fn",
    verbose=True,
)
x = torch.ones(2, 2, requires_grad=True)
out = torch.ops.mylib.custom_op_backed_by_autograd_fn(
    x
)
loss = out.sum()
with torch._dynamo.compiled_autograd.enable(compiler_fn):
    loss.backward()
expected = torch.ones_like(x) * 3.14
assert torch.allclose(x.grad, expected)
# ===========================================================
# Tests that we don't bake in "guessed" metadata.
# This test case would have errored out in the previous
# compiled autograd.
import torch
import torch.utils.cpp_extension
cpp_source2 = """
struct CustomOpAutogradFunction2 : public torch::autograd::Function<CustomOpAutogradFunction2> {
  static constexpr bool is_traceable = true;

  static torch::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& x) {
    return x;
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext *ctx,
      torch::autograd::variable_list grad_output) {
    if (grad_output[0].is_contiguous()) {
      return {2 * grad_output[0]};
    } else {
      return {3 * grad_output[0]};
    }
  }
};

torch::Tensor custom_op_backed_by_autograd_fn2(torch::Tensor x) {
  return CustomOpAutogradFunction2::apply(x);
}

TORCH_LIBRARY_FRAGMENT(mylib, m) {
  m.def("custom_op_backed_by_autograd_fn2", custom_op_backed_by_autograd_fn2);
}
"""
module = torch.utils.cpp_extension.load_inline(
    name="mylib",
    cpp_sources=cpp_source2,
    functions="custom_op_backed_by_autograd_fn2",
    verbose=True,
)
x = torch.tensor([[1., 2., 3.], [4, 5, 6]], requires_grad=True)
y = torch.ops.mylib.custom_op_backed_by_autograd_fn2(x)
z = y.clone()
w = z.sum()
def hook(grad):
    # return a contiguous grad.
    # The previous compiled autograd would have "guessed" that
    # the tensor is not contiguous.
    assert not grad.is_contiguous()
    return grad.contiguous()
z.register_hook(hook)
with torch._dynamo.compiled_autograd.enable(lambda x: x):
    w.backward()
assert torch.allclose(x.grad, 2 * torch.ones_like(x))

torch/_compiled_autograd.py (new file)

@ -0,0 +1,190 @@
# type: ignore
import threading
import torch
from ._compile import _disable_dynamo
from ._C import _autograd
# TODO(rzou): why doesn't torch.fx.wrap work directly?
from torch.fx._symbolic_trace import _create_wrapped_func as wrap
"""
TODO(rzou): did we really need a new file? I did it to appease trace_rules.
"""
def python_autograd(saved_state, hooks, nodecalls, num_outputs, arange):
"""Given the state of the autograd graph (the saved tensors/sizes/scalar,
hooks, and the actual nodes), execute it in Python.
Compiled Autograd uses the equivalent of torch.fx.symbolic_trace over
this function to produce a graph that can then be Dynamo'ed.
NB: Before executing this function (or an acquired graph version of it)
on real Tensors, please call set_global_nodecalls(nodecalls) to set the
current autograd nodes structure state. We intentionally hide this state
from the graph so that Dynamo doesn't need to deal with proxying it into
the graph.
TODO(rzou): Compiled Autograd is responsible for calling set_global_nodecalls
using the current nodecalls data structure. If the user did not specify
retain_graph=True, then something needs to free it later,
so we don't end up keeping the nodes around forever.
"""
    node_to_idx_data = {node_id(call.node): idx for idx, call in enumerate(nodecalls)}

    def node_to_idx(node):
        return node_to_idx_data[torch._compiled_autograd.node_id(node)]

    input_buffers = {}

    def lookup_input_buffer(node_idx, num_inputs):
        if node_idx not in input_buffers:
            input_buffers[node_idx] = [None] * num_inputs
        return input_buffers[node_idx]

    saved_state = iter(SavedState(
        nodecalls,
        saved_state[0],
        saved_state[1],
        saved_state[2],
    ))
    graph_outputs = [None] * num_outputs
    for idx, call in enumerate(nodecalls):
        node_idx = arange[idx]
        inputs = lookup_input_buffer(idx, call.node.num_inputs())
        # Given all of the saved state, retrieve the saved state that matters
        # for the current node call.
        apply_state, validate_outputs_state = next(saved_state)
        for hook_idx, input_idx in call.tensor_pre_hooks:
            inputs[input_idx] = call_hook(hooks[hook_idx], inputs[input_idx], hook_type="pre_hook")
        for input_nr, result_idx in call.graph_output:
            graph_outputs[result_idx] = inputs[input_nr]
        if not call.needed:
            continue
        if call.node.is_compiled_autograd_traceable():
            outputs = apply_with_saved(node_idx, inputs, *apply_state)
        else:
            outputs = apply_with_saved_dynamo_disabled(node_idx, inputs, *apply_state)
        outputs = validate_outputs(node_idx, outputs, *validate_outputs_state)
        for hook_idx, input_idx in call.post_hooks:
            call_hook(hooks[hook_idx], outputs, inputs, hook_type="post_hook")
        for output_idx in range(call.node.num_outputs()):
            output = outputs[output_idx]
            next_edge = call.node.next_edge(output_idx)
            if not next_edge.is_valid():
                continue
            next_node = next_edge.function
            input_buffer = lookup_input_buffer(node_to_idx(next_node), next_node.num_inputs())
            updated_buffer = accumulate(input_buffer[next_edge.input_nr], output)
            input_buffer[next_edge.input_nr] = updated_buffer
    return graph_outputs
global_nodecalls = threading.local()
def get_node(idx):
    return global_nodecalls.thread_local[idx].node


def set_global_nodecalls(nodecalls):
    global_nodecalls.thread_local = nodecalls


@wrap
def apply_with_saved(node_idx, inputs, saved_tensors, saved_sizes, saved_scalars):
    """
    Applies the node at global_nodecalls[node_idx] using the inputs and saved values.
    """
    node = get_node(node_idx)
    outputs = _autograd.apply_with_saved(global_nodecalls.thread_local[node_idx], inputs, saved_tensors, list(saved_sizes), saved_scalars)
    return outputs


@_disable_dynamo
@wrap
def apply_with_saved_dynamo_disabled(node_idx, inputs, saved_tensors, saved_sizes, saved_scalars):
    """
    This is apply_with_saved, but also induces a graph break in Dynamo.
    """
    return apply_with_saved(node_idx, inputs, saved_tensors, saved_sizes, saved_scalars)


@wrap
def validate_outputs(node_idx, outputs, saved_tensors, saved_sizes, saved_scalars):
    """
    Validates the outputs of the node at global_nodecalls[node_idx]. This requires
    swizzling out some input metadata state of the next nodes, which is why
    it also accepts some saved variables.
    """
    outputs = _autograd.validate_outputs_with_saved(global_nodecalls.thread_local[node_idx], outputs, saved_tensors, list(saved_sizes), saved_scalars)
    return outputs


def node_id(node):
    if node is None:
        breakpoint()
    assert node is not None
    return _autograd.node_id(node)


def arange(num):
    return list(range(num))


@wrap
def call_hook(*args, **kwargs):
    return torch._dynamo.external_utils.call_hook(*args, **kwargs)
class IterableWrapper:
    """Adapts an indexable object (e.g. a list or an fx.Proxy) into an
    iterator that can be advanced at most `size` times."""

    def __init__(self, noniterable, size):
        self.noniterable = noniterable
        self.idx = 0
        self.size = size

    def __iter__(self):
        return self

    def __next__(self):
        assert self.idx < self.size
        result = self.noniterable[self.idx]
        self.idx += 1
        return result


class SavedState:
    """Iterates over the flat saved state. For each NodeCall, yields two
    (tensors, sizes, scalars) triples: one for apply_with_saved and one for
    validate_outputs, using the per-node counts recorded in CollectionInfo."""

    def __init__(self, nodecalls, tensors, sizes, scalars):
        self.tensors = tensors
        self.sizes = sizes
        self.scalars = scalars
        self.nodecalls = iter(nodecalls)

    def __iter__(self):
        return self

    def __next__(self):
        call = next(self.nodecalls)

        def get_next(collection_info):
            tensors = [next(self.tensors) for _ in range(collection_info.num_saved_tensors)]
            sizes = [next(self.sizes) for _ in range(collection_info.num_saved_sizes)]
            scalars = [next(self.scalars) for _ in range(collection_info.num_saved_ivalues)]
            return (tensors, sizes, scalars)

        saved_state_for_apply = get_next(call.compiled_args_info)
        saved_state_for_validate_output = get_next(call.next_edges_info)
        return saved_state_for_apply, saved_state_for_validate_output


@wrap
def accumulate(old_var, var):
    if old_var is None:
        return var
    if var is None:
        return old_var
    return old_var + var

@ -82,6 +82,49 @@ class AutogradCompilerInstance:
    def source(name, idx) -> GetItemSource:
        return GetItemSource(LocalSource(name), idx)

    def capture(self, tensors, sizes, scalars, origins, nodecalls, num_outputs):
        dynamic_sizes = tuple(s for s in sizes if s is not None)
        counters["compiled_autograd"]["captures"] += 1
        inputs_origins, sizes_origins, scalars_origins = origins
        self.fx_tracer.root = torch.nn.Module()
        self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
        self.fx_tracer.tensor_attrs = {}
        inputs_proxy, dynamic_sizes_proxy, scalars_proxy, self.hooks_proxy = (
            self.fx_tracer.create_proxy("placeholder", name, (), {})
            for name in self.graph_placeholders
        )
        sizes_proxy = [None] * len(sizes)
        dynamic_sizes_next = 0
        for idx in range(len(sizes)):
            if sizes[idx] is not None:
                sizes_proxy[idx] = dynamic_sizes[dynamic_sizes_next]
                dynamic_sizes_next += 1
        from torch._compiled_autograd import IterableWrapper, python_autograd, arange

        arange_proxy = self.fx_tracer.create_proxy(
            kind="call_function",
            target=arange,
            args=(len(nodecalls),),
            kwargs={}
        )
        graph_outputs = python_autograd(
            (
                IterableWrapper(inputs_proxy, len(tensors)),
                IterableWrapper(sizes_proxy, len(sizes)),
                IterableWrapper(scalars_proxy, len(scalars)),
            ),
            self.hooks_proxy,
            nodecalls,
            num_outputs,
            arange_proxy,
        )
        return self.end_capture(graph_outputs)

    def begin_capture(
        self,
        inputs: List[torch.Tensor],
@ -308,8 +351,10 @@ class AutogradCompilerInstance:
(self.fx_tracer.create_arg(self.to_proxy(outputs)),),
{},
)
self.rename_aot_dispatcher_nodes()
self.reorder_accumulate_grad_nodes()
# TODO(rzou): we didn't inline the AOTDispatcher nodes
# self.rename_aot_dispatcher_nodes()
# TODO(rzou): we need to transform AccumulateGrad nodes into torch.inductor.accumulate_grad_.
# self.reorder_accumulate_grad_nodes()
runtime_inputs_to_move: List[int] = []
if snapshot_cudagraph_enabled():
runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
@ -317,6 +362,7 @@ class AutogradCompilerInstance:
graph = GraphModule(
self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd"
)
graph.print_readable()
set_locals_to_steal(graph, ["inputs"])
lazy_graph_code = lazy_format_graph_code(
"Compiled autograd graph",
@ -562,3 +608,5 @@ def reset() -> None:
assert not in_compiled_autograd_region
torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
torch._C._dynamo.compiled_autograd.set_verbose_logger(None)
from torch._compiled_autograd import set_global_nodecalls

@ -950,6 +950,9 @@ class OutputGraph:
list_name = arg.source.local_name
assert list_name in self.code_options["co_varnames"]
for x in needs_alias[list_name]:
if not hasattr(x.source, "index"):
# TODO(rzou): idk
breakpoint()
list_idx = x.source.index
if list_idx not in visited:
alias_name = self.new_var(

@ -134,6 +134,13 @@ If you are removing an existing torch level API:
"""
manual_torch_name_rule_map = {
"torch._compiled_autograd.CA_apply_with_saved": TorchInGraphFunctionVariable,
"torch._compiled_autograd.accumulate2": TorchInGraphFunctionVariable,
"torch._compiled_autograd.CA_validate_outputs": TorchInGraphFunctionVariable,
# "torch._compiled_autograd.CA_apply_with_saved_dynamo_disabled": TorchInGraphFunctionVariable,
"torch._compiled_autograd.CA_update_input_buffers": TorchInGraphFunctionVariable,
"torch._compiled_autograd.CA_input_buffers_init": TorchInGraphFunctionVariable,
"torch._compiled_autograd.CA_input_buffers_lookup": TorchInGraphFunctionVariable,
"torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
"torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
"torch.overrides.is_tensor_like": TorchInGraphFunctionVariable,
@ -3237,6 +3244,7 @@ if torch.distributed.is_available():
# We are using python module name instead of file or directory object to avoid circular dependency.
# Please keep this sorted alphabetically.
MOD_INLINELIST = [
"torch._compiled_autograd",
"torch._decomp",
"torch._dynamo._trace_wrapped_higher_order_op",
"torch._dynamo.comptime",

@ -1219,7 +1219,12 @@ class VariableBuilder:
maybe_gm = self.tx.output.local_scope.get("self")
if isinstance(
self.source, LocalSource
) and self.source.local_name in get_locals_to_steal(maybe_gm):
# TODO(rzou): We changed compiled autograd to pass all of the inputs saved
# instead of a de-duplicated list. Unfortunately that makes the input
# stealing logic go haywire. We can either fix it or figure out
# how to deal with a de-duplicated list (the problem is
# mapping the de-duplicated saved tensors back to the nodes that need them).
) and self.source.local_name in get_locals_to_steal(maybe_gm) and False:
# The input tensor list to dynamo from compiled autograd may contain activations
# which are freed as they are used in inductor. Dynamo's default behavior is to
# lift all tensors to the graph inputs, but this will cause dynamo to hold an
@ -1249,13 +1254,17 @@ class VariableBuilder:
source_i = GetItemSource(base=source, index=i, index_is_slice=False)
# access unpacked tensor from this list instead of from a lifted arg
self.tx.output.input_source_to_var[source_i] = tensor_variable
tensor_variable.proxy.node.meta["tensor_dict"] = _extract_tensor_dict(
value[i]
)
if isinstance(tensor_variable, TensorVariable):
tensor_variable.proxy.node.meta["tensor_dict"] = _extract_tensor_dict(
value[i]
)
guard = functools.partial(
GuardBuilder.TENSOR_MATCH, value=TensorWeakRef(value[i])
)
guard = functools.partial(
GuardBuilder.TENSOR_MATCH, value=TensorWeakRef(value[i])
)
else:
# TODO(rzou): None guard?
pass
guards.append(source_i.make_guard(guard))
install_guard(*guards, skip=1)

@ -188,16 +188,25 @@ struct CppNode : public Node {
void set_ctx_grad_fn(const std::shared_ptr<Node>& node);
void save_variables_to_ctx();
bool is_compiled_autograd_traceable() override {
static_assert(
std::is_same_v<std::remove_cv_t<decltype(T::is_traceable)>, bool>);
return T::is_traceable;
}
void compiled_args(CompiledNodeArgs& args) override {
static_assert(
std::is_same_v<std::remove_cv_t<decltype(T::is_traceable)>, bool>);
if (!T::is_traceable) {
throw std::runtime_error(
std::string(
"Attempting to trace a potentially unsafe C++ autograd function: ") +
name() +
". It may be possible to trace it safely, please refer to the instructions in: https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY/.");
}
// if (!T::is_traceable) {
// throw std::runtime_error(
// std::string(
// "Attempting to trace a potentially unsafe C++ autograd
// function: ") +
// name() +
// ". It may be possible to trace it safely, please refer to the
// instructions in:
// https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY/.");
// }
// although neither of the 2 methods below have uniqueness guarantees
// it is unlikely for them to collide at the same time

@ -3,6 +3,7 @@
#include <c10/util/ThreadLocal.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/dynamo/compiled_autograd.h>
#include <ATen/ATen.h>

@ -563,6 +563,10 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
/// release variables as they run.
virtual void will_release_variables() {}
virtual bool is_compiled_autograd_traceable() {
return true;
}
/// Returns true if this function is traceable. An op is traceable if all
/// operations happening within `apply()` are performed on autograd
/// `Variables` (i.e. apply mostly instantiates and applies other functions).

@ -1,4 +1,5 @@
#include <torch/csrc/python_headers.h>
#include <memory>
#include <ATen/PythonTorchFunctionTLS.h>
#include <ATen/SavedTensorHooks.h>
@ -14,6 +15,7 @@
#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/input_metadata.h>
@ -26,6 +28,7 @@
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/utils/python_arg_parsing.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/dynamo/compiled_autograd.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/kineto_shim.h>
@ -42,6 +45,7 @@
using torch::impl::py_context_manager;
using torch::impl::py_context_manager_DEPRECATED;
using namespace torch::dynamo::autograd;
namespace {
@ -79,6 +83,55 @@ struct EnablePythonDispatcher {
c10::impl::PyInterpreter* old_;
};
std::vector<at::Tensor> toVec(
const std::vector<std::optional<at::Tensor>>& ts) {
std::vector<at::Tensor> result;
for (const auto& opt_tensor : ts) {
if (opt_tensor.has_value()) {
result.push_back(opt_tensor.value());
} else {
result.emplace_back();
}
}
return result;
}
variable_list validate_outputs_with_saved(
const NodeCall& nodecall,
std::vector<at::Tensor>& outputs,
const std::vector<at::Tensor>& saved_tensors,
const std::vector<std::optional<at::SymInt>>& saved_sizes,
const std::vector<at::IValue>& saved_ivalues) {
auto saved = SwapSavedVariables(
saved_tensors, saved_sizes, saved_ivalues, nullptr, nodecall);
saved.before(nodecall.node->next_edges());
torch::autograd::validate_outputs(
nodecall.node->next_edges(), outputs, [&](const std::string& msg) {
std::ostringstream ss;
ss << "[Compiled Autograd Tracing: " << nodecall.node->name() << "] "
<< msg;
return ss.str();
});
saved.after(nodecall.node->next_edges());
return outputs;
}
variable_list apply_with_saved314(
const NodeCall& nodecall,
const std::vector<std::optional<at::Tensor>>& inputs,
const std::vector<std::optional<at::Tensor>>& saved_tensors,
const std::vector<std::optional<at::SymInt>>& saved_sizes,
const std::vector<at::IValue>& saved_ivalues) {
auto saved = SwapSavedVariables(
toVec(saved_tensors), saved_sizes, saved_ivalues, nullptr, nodecall);
auto outputs = nodecall.node->apply_with_saved(toVec(inputs), saved);
return outputs;
}
uint64_t node_id(const std::shared_ptr<Node>& node) {
return reinterpret_cast<uint64_t>(node.get());
}
struct EnablePreDispatch {
EnablePreDispatch() : guard_(c10::DispatchKey::PreDispatch) {}
c10::impl::IncludeDispatchKeyGuard guard_;
@ -491,6 +544,50 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
}
});
// compiled_autograd stuff
py::class_<torch::autograd::Node, std::shared_ptr<torch::autograd::Node>>(
m, "Node")
.def("compiled_args", &torch::autograd::Node::compiled_args)
.def("next_edge", &torch::autograd::Node::next_edge)
.def(
"is_compiled_autograd_traceable",
&torch::autograd::Node::is_compiled_autograd_traceable)
.def("name", &torch::autograd::Node::name)
.def("num_outputs", &torch::autograd::Node::num_outputs)
.def("num_inputs", &torch::autograd::Node::num_inputs);
py::class_<torch::autograd::Edge>(m, "Edge")
.def("is_valid", &torch::autograd::Edge::is_valid)
.def_readonly("input_nr", &torch::autograd::Edge::input_nr)
.def_readonly("function", &torch::autograd::Edge::function);
py::class_<CollectionInfo>(m, "CollectionInfo")
.def_readonly("num_saved_tensors", &CollectionInfo::num_saved_tensors)
.def_readonly("num_saved_sizes", &CollectionInfo::num_saved_sizes)
.def_readonly("num_saved_ivalues", &CollectionInfo::num_saved_ivalues);
py::class_<torch::dynamo::autograd::NodeCall>(m, "NodeCall")
.def_readonly("node", &NodeCall::node)
.def_readonly("compiled_args_info", &NodeCall::compiled_args_info)
.def_readonly("next_edges_info", &NodeCall::next_edges_info)
.def_readonly("tensor_pre_hooks", &NodeCall::tensor_pre_hooks)
.def_readonly("post_hooks", &NodeCall::post_hooks)
.def_readonly("graph_output", &NodeCall::graph_output)
.def_readonly("needed", &NodeCall::needed);
py::class_<torch::dynamo::autograd::CompiledNodeArgs>(m, "CompiledNodeArgs")
.def(py::init<AutogradCompilerCall&, NodeCall&>());
py::class_<torch::dynamo::autograd::AutogradCompilerCall>(
m, "AutogradCompilerCall")
.def(py::init<>());
m.def("apply_with_saved", &apply_with_saved314);
m.def("validate_outputs_with_saved", &validate_outputs_with_saved);
m.def("node_id", &node_id);
// py::class_<SwapInterface,PySwapInterface>(m, "SwapInterface");
// py::class_<SwapWithReal,SwapInterface>(m, "SwapWithReal")
// .def(py::init<std::vector<at::Tensor>,std::vector<c10::SymInt>,std::vector<c10::IValue>>())
// ;
// py::class_<SwapSavedVariables>(m, "SwapSavedVariables")
// .def(py::init<std::vector<at::Tensor>,std::vector<c10::SymInt>,std::vector<c10::IValue>,PyObject*,const
// NodeCall&>())
// ;
_C_m.def("_activate_gpu_trace", []() { activateGPUTrace(); });
py_context_manager_DEPRECATED<c10::InferenceMode, bool>(

@ -69,7 +69,15 @@ struct CacheKey {
const uint8_t* key;
};
struct NodeCall {
struct CollectionInfo {
int num_saved_tensors = 0;
int num_saved_sizes = 0;
int num_saved_ivalues = 0;
};
enum CollectionMode { COMPILED_ARGS, NEXT_EDGES };
struct TORCH_API NodeCall {
NodeCall(uint32_t id_, std::shared_ptr<Node> node_)
: id(id_), node(std::move(node_)) {}
@ -84,6 +92,24 @@ struct NodeCall {
std::vector<int> post_hooks;
std::vector<int> post_acc_grad_hooks;
std::vector<std::pair<int, int>> graph_output;
CollectionInfo& collection_info() {
if (mode == CollectionMode::NEXT_EDGES) {
return next_edges_info;
} else {
return compiled_args_info;
}
}
// Given the full list of saved arguments (saved tensors, saved sizes,
// saved scalars), we want to be able to map them back to which node
// they came from.
// The way we do this is that we store information on how many
// tensors/sizes/scalars each Node uses.
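// Illustrative example with made-up counts: if node 0's compiled_args saved
// 2 tensors and 1 size, and node 1's saved 1 tensor and 1 scalar, then the
// flat saved state is tensors [t0, t1, t2], sizes [s0], scalars [iv0], and
// the per-node counts {2, 1, 0} and {1, 0, 1} let the Python side (SavedState
// in torch/_compiled_autograd.py) slice out each node's share in order;
// next_edges_info plays the same role for validate_outputs.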
CollectionMode mode = CollectionMode::COMPILED_ARGS;
CollectionInfo compiled_args_info;
CollectionInfo next_edges_info;
bool needed = true;
};
@ -143,9 +169,9 @@ struct TensorArgs {
auto impl = tensor.unsafeGetTensorImpl();
auto it = _args.find(impl);
if (it == _args.end()) {
TORCH_INTERNAL_ASSERT(create && inputs.size() == _next_id - 1);
// TORCH_INTERNAL_ASSERT(create && inputs.size() == _next_id - 1);
it = _args.emplace(impl, TensorArg(_next_id++)).first;
inputs.emplace_back(tensor);
// inputs.emplace_back(tensor);
if (active_node_call_idx.has_value()) {
input_origins.emplace_back(active_node_call_idx.value());
}
@ -160,6 +186,9 @@ struct TensorArgs {
}
TensorArg& add(const at::Tensor& tensor) {
// unconditionally add the tensor to inputs... Dynamo will de-dupe them
// later
inputs.emplace_back(tensor);
return lookup(tensor, true);
}
@ -208,6 +237,11 @@ struct LiftedIValueArgs {
return iv_arg.proxy;
}
at::IValue& next_proxy() {
auto& iv_arg = args.at(next++);
return iv_arg.proxy;
}
void add(const at::IValue* iv) {
args.emplace_back(iv);
if (active_node_call_idx.has_value()) {
@ -278,13 +312,16 @@ class CompiledNodeArgs {
}
void collect(const at::Tensor& t) {
_node_call.collection_info().num_saved_tensors++;
collect(_compiler.tensor_args.add(t));
}
void collect(const SavedVariable& sv, bool is_output) {
_node_call.collection_info().num_saved_tensors++;
collect(
_compiler.tensor_args.add(sv, is_output ? _node_call.node : nullptr));
}
void collect(const c10::SymInt& t) {
_node_call.collection_info().num_saved_sizes++;
_compiler.add_size_input(t);
}
void collect(const std::vector<SavedVariable>& t, bool is_output) {
@ -366,6 +403,7 @@ class CompiledNodeArgs {
!nested &&
(iv.isInt() || iv.isSymInt() || iv.isDouble() || iv.isSymFloat())) {
// can't lift ivalues nested in collections
_node_call.collection_info().num_saved_ivalues++;
_compiler.lifted_ivalue_args.add(&iv);
} else {
try {
@ -629,17 +667,110 @@ struct TraceState {
variable_list outputs;
};
struct TORCH_API SwapInterface {
virtual ~SwapInterface() = default;
virtual std::optional<at::Tensor> tensor(const at::Tensor& tensor) = 0;
virtual std::optional<at::Tensor> tensor(const SavedVariable& tensor) = 0;
virtual std::optional<c10::SymInt> next_size() = 0;
virtual c10::IValue next_ivalue() = 0;
};
struct SwapWithProxies : public SwapInterface {
explicit SwapWithProxies(AutogradCompilerCall& compiler, TraceState& state)
: compiler_(compiler), state_(state) {}
~SwapWithProxies() override = default;
std::optional<at::Tensor> tensor(const at::Tensor& tensor) override {
TensorArg& arg = compiler_.tensor_args.lookup(tensor);
if (arg.defined()) {
TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
return arg.proxy_tensor;
}
return std::nullopt;
}
std::optional<at::Tensor> tensor(const SavedVariable& t) override {
TensorArg& arg = compiler_.tensor_args.lookup(t);
if (arg.defined()) {
return arg.proxy_tensor;
}
return std::nullopt;
}
std::optional<c10::SymInt> next_size() override {
return state_.next_sym_size();
}
c10::IValue next_ivalue() override {
return compiler_.lifted_ivalue_args.next_proxy();
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
AutogradCompilerCall& compiler_;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
TraceState& state_;
};
// The previous compiled autograd implementation only swapped in ProxyTensors
// for a node. Now, given a single node and some specific saved
// tensors/sizes/scalars, we also need a way to swap in those real saved
// values. That's what SwapWithReal is for.
struct SwapWithReal : public SwapInterface {
explicit SwapWithReal(
std::vector<at::Tensor> tensors,
std::vector<std::optional<c10::SymInt>> sizes,
std::vector<c10::IValue> ivalues)
: tensors_(std::move(tensors)),
sizes_(std::move(sizes)),
ivalues_(std::move(ivalues)) {}
~SwapWithReal() override = default;
std::optional<at::Tensor> tensor(const at::Tensor& _ignored) override {
auto result = tensors_[tensors_idx];
tensors_idx++;
return result;
}
std::optional<at::Tensor> tensor(const SavedVariable& _ignored) override {
TORCH_INTERNAL_ASSERT(tensors_idx < tensors_.size());
auto result = tensors_[tensors_idx];
tensors_idx++;
return result;
}
std::optional<c10::SymInt> next_size() override {
TORCH_INTERNAL_ASSERT(sizes_idx < sizes_.size());
auto result = sizes_[sizes_idx];
sizes_idx++;
return result;
}
c10::IValue next_ivalue() override {
TORCH_INTERNAL_ASSERT(ivalues_idx < ivalues_.size());
auto result = ivalues_[ivalues_idx];
ivalues_idx++;
return result;
}
std::vector<at::Tensor> tensors_;
size_t tensors_idx = 0;
std::vector<std::optional<c10::SymInt>> sizes_;
size_t sizes_idx = 0;
std::vector<c10::IValue> ivalues_;
size_t ivalues_idx = 0;
};
class SwapSavedVariables {
// SwapSavedVariables is used during the tracing/compilation phase after a
// cache-miss. It swaps any 'lifted' inputs (tensors, symints) to proxy nodes,
// allows tracing to happen, then swaps them back afterwards.
public:
void before(at::Tensor& t) {
TensorArg& arg = compiler.tensor_args.lookup(t);
auto replacement = state->tensor(t);
stashed_tensors.save(&t, std::move(t));
if (arg.defined()) {
TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
t = arg.proxy_tensor;
if (replacement.has_value()) {
t = *replacement;
}
}
void after(at::Tensor& t) {
@ -647,12 +778,11 @@ class SwapSavedVariables {
}
void before(SavedVariable& t) {
TensorArg& arg = compiler.tensor_args.lookup(t);
auto replacement = state->tensor(t);
stashed_variables.save(&t, std::move(t));
if (arg.defined()) {
if (replacement.has_value()) {
bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
t = SavedVariable(arg.proxy_tensor, false);
t = SavedVariable(replacement.value(), false);
at::SavedTensorDefaultHooks::set_tracing(prior);
}
}
@ -662,7 +792,7 @@ class SwapSavedVariables {
void before(c10::SymInt& t) {
stashed_symints.save(&t, c10::SymInt(t));
auto opt_value = state.next_sym_size();
auto opt_value = state->next_size();
if (opt_value.has_value()) {
t = *opt_value; // dynamic shape
}
@ -677,7 +807,7 @@ class SwapSavedVariables {
} else {
stashed_ivalues.save(&iv, at::IValue(iv));
if (iv.isInt() || iv.isSymInt() || iv.isDouble() || iv.isSymFloat()) {
iv = compiler.lifted_ivalue_args.next_proxy(&iv);
iv = state->next_ivalue();
}
}
}
@ -824,7 +954,23 @@ class SwapSavedVariables {
TraceState& s,
PyObject* p,
const NodeCall& n)
: compiler(c), state(s), py_compiler(p), curr_node_call(n) {}
: py_compiler(p), curr_node_call(n) {
state = std::make_shared<SwapWithProxies>(c, s);
}
SwapSavedVariables(
std::vector<at::Tensor> a,
std::vector<std::optional<at::SymInt>> b,
std::vector<at::IValue> c,
PyObject* p,
const NodeCall& n)
: state(std::static_pointer_cast<SwapInterface>(
std::make_shared<SwapWithReal>(
std::move(a),
std::move(b),
std::move(c)))),
py_compiler(p),
curr_node_call(n) {}
PyObject* get_py_compiler() {
return py_compiler;
@ -875,9 +1021,10 @@ class SwapSavedVariables {
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
AutogradCompilerCall& compiler;
// AutogradCompilerCall& compiler;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
TraceState& state;
std::shared_ptr<SwapInterface> state;
// TraceState& state;
// This is a borrowed reference, we do not increment ownership, or lower it,
// it's lifecycle is entirely longer than this objects.
PyObject* py_compiler;

@ -451,6 +451,37 @@ void set_ivalue_proxies(
}
}
static PyObject* call_capture(
PyObject* self,
CacheNode& cache,
AutogradCompilerCall& compiler_call,
size_t num_outputs,
PyObject* nodecalls) {
static PyObject* method_name = PyUnicode_InternFromString("capture");
THPObjectPtr pyinput(THPVariable_WrapList(compiler_call.tensor_args.inputs));
THPObjectPtr pysizeinput(cache.wrap_dynamic_inputs());
std::vector<std::optional<c10::SymInt>> dynamic_inputs =
cache.unwrap_dynamic_inputs(py::cast<py::list>(pysizeinput.get()).ptr());
THPObjectPtr pyivalueargsinput(
wrap_lifted_ivalue_args(compiler_call.lifted_ivalue_args.args));
THPObjectPtr pynodeorigins(
wrap_node_origins(compiler_call, PyTuple_GET_SIZE(pysizeinput.get())));
PyObject* py_num_outputs = THPUtils_packUInt32(num_outputs);
return check(PyObject_CallMethodObjArgs(
self,
method_name,
pyinput.get(),
// TODO(rzou): is this leaking memory?
py::cast(dynamic_inputs).ptr(),
pyivalueargsinput.get(),
pynodeorigins.get(),
nodecalls,
py_num_outputs,
nullptr));
}
static TraceState call_begin_capture(
PyObject* self,
CacheNode& cache,
@ -552,7 +583,9 @@ CacheNode* _compiled_autograd_impl(
compiler_call.set_active_node_call_idx(i);
}
if (node_args.cond(call.needed)) {
call.mode = CollectionMode::COMPILED_ARGS;
fn->compiled_args(node_args);
call.mode = CollectionMode::NEXT_EDGES;
node_args.collect(call.node->next_edges());
}
CacheKey key = node_args.key();
@ -600,112 +633,15 @@ CacheNode* _compiled_autograd_impl(
ClosingTHPObjectPtr py_compiler(
check(PyObject_CallNoArgs((the_autograd_compiler))));
TraceState state = call_begin_capture(
py_compiler, *cache, compiler_call, output_edges.size());
InputBuffers input_buffers;
// nodes
py::object nodecalls = py::cast(calls);
PyObject* res = call_capture(
py_compiler,
*cache,
compiler_call,
output_edges.size(),
nodecalls.ptr());
for (size_t i = 0; i < calls.size(); i++) {
NodeCall& call = *calls[i];
// TODO(jansel): consider adding some of this stuff:
// guard(local_graph_task); NodeGuard ndguard(task.fn_); const auto
// opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
// c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
// CheckpointValidGuard cpvguard(graph_task);
// at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION);
// if (C10_UNLIKELY(step_callbacks.has_value())) { ... }
variable_list inputs =
std::move(input_buffers.lookup(call.node.get()).buffer);
input_buffers.erase(call.node.get());
if (!call.tensor_pre_hooks.empty()) {
THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
for (const auto& hook : call.tensor_pre_hooks) {
pyinputs = check(PyObject_CallMethod(
py_compiler,
"tensor_pre_hook",
"Oii",
pyinputs.get(),
hook.first,
hook.second));
}
inputs = THPVariable_UnpackList(pyinputs);
}
for (const auto& graph_output : call.graph_output) {
int input_nr = graph_output.first;
int output_index = graph_output.second;
TORCH_INTERNAL_ASSERT(
output_index < static_cast<int>(state.outputs.size()));
TORCH_INTERNAL_ASSERT(!state.outputs[output_index].defined());
state.outputs[output_index] = inputs[input_nr];
}
if (!call.needed) {
continue;
}
if (!call.pre_hooks.empty()) {
THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
for (const auto hook : call.pre_hooks) {
pyinputs = check(PyObject_CallMethod(
py_compiler.get(), "pre_hook", "Oi", pyinputs.get(), hook));
}
inputs = THPVariable_UnpackList(pyinputs);
}
std::string _node_name = call.node->name();
THPObjectPtr node_name(PyUnicode_FromString(_node_name.data()));
TORCH_INTERNAL_ASSERT(node_name != nullptr);
THPObjectPtr set_node_origin(
PyObject_GetAttrString(py_compiler.get(), "set_node_origin"));
PyObject* pyobj = Py_None;
if (auto pynode = std::dynamic_pointer_cast<PyNode>(call.node)) {
pyobj = pynode->obj;
}
check(PyObject_CallFunction(
set_node_origin, "OIO", node_name.get(), i, pyobj, nullptr));
SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call);
variable_list outputs = call.node->apply_with_saved(inputs, saved);
saved.debug_asserts();
saved.before(call.node->next_edges());
validate_outputs(
call.node->next_edges(), outputs, [&](const std::string& msg) {
std::ostringstream ss;
ss << "[Compiled Autograd Tracing: " << call.node->name() << "] "
<< msg;
return ss.str();
});
saved.after(call.node->next_edges());
saved.debug_asserts();
if (!call.post_hooks.empty()) {
THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
THPObjectPtr pyoutputs(THPVariable_WrapList(outputs));
for (const auto hook : call.post_hooks) {
pyoutputs = check(PyObject_CallMethod(
py_compiler.get(),
"post_hook",
"OOi",
pyoutputs.get(),
pyinputs.get(),
hook));
}
outputs = THPVariable_UnpackList(pyoutputs);
}
for (const auto i : c10::irange(outputs.size())) {
auto& output = outputs[i];
const auto& next = call.node->next_edge(i);
if (next.is_valid() && output.defined()) {
input_buffers.lookup(next.function.get())
.add(
next.input_nr, std::move(output), std::nullopt, std::nullopt);
}
}
}
PyObject* res = check(call_end_capture(py_compiler, state.outputs));
TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple");
TORCH_CHECK(
PyTuple_Size(res) == 2,
@ -718,15 +654,25 @@ CacheNode* _compiled_autograd_impl(
TORCH_CHECK(
PyCallable_Check(cache->compiled_fn),
"Expected end_capture to return compiled_fn");
state.debug_asserts();
// TODO(rzou): what is this?
// state.debug_asserts();
} // End cache miss region
// TODO(rzou): need some mechanism to release the variables when we're ready.
// TODO(jansel): clear grads we will overwrite below
if (!graph_task.keep_graph_) {
for (auto& call : calls) {
call->node->release_variables();
}
// if (!graph_task.keep_graph_) {
// for (auto& call : calls) {
// call->node->release_variables();
// }
// }
// TODO(rzou): we probably shouldn't be copying the nodes in the hot path?
std::vector<NodeCall> persistent_node_calls;
for (NodeCall* call : calls) {
persistent_node_calls.push_back(*call);
}
auto ca = py::module::import("torch._dynamo.compiled_autograd");
ca.attr("set_global_nodecalls")(persistent_node_calls);
*graph_arg_inputs = THPVariable_WrapList(compiler_call.tensor_args.inputs);
*graph_arg_sizes = wrap_int_list(compiler_call.dyn_size_inputs);