mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 06:11:27 +08:00
Compare commits
1 Commits
ciflow/tru
...
python_com
Author | SHA1 | Date | |
---|---|---|---|
e5937dc68c |
140
r2.py
Normal file
140
r2.py
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
# type: ignore
|
||||||
|
import torch
|
||||||
|
import torch.utils.cpp_extension
|
||||||
|
|
||||||
|
def compiler_fn(gm):
|
||||||
|
# return gm
|
||||||
|
return torch.compile(gm, backend="eager", fullgraph=False)
|
||||||
|
|
||||||
|
# ===========================================================
|
||||||
|
# Basic test with a hook that has side effects
|
||||||
|
|
||||||
|
|
||||||
|
# Test case 1: a hook
|
||||||
|
x = torch.tensor([1., 2., 3.], requires_grad=True)
|
||||||
|
y = x ** 2
|
||||||
|
z = y.sum()
|
||||||
|
|
||||||
|
im_grad = []
|
||||||
|
|
||||||
|
def hook(grad):
|
||||||
|
im_grad.append(grad)
|
||||||
|
return 2 * grad
|
||||||
|
|
||||||
|
y.register_hook(hook)
|
||||||
|
|
||||||
|
with torch._dynamo.compiled_autograd.enable(compiler_fn):
|
||||||
|
z.backward()
|
||||||
|
|
||||||
|
assert torch.allclose(x.grad, 4 * x)
|
||||||
|
assert torch.allclose(im_grad[0], torch.ones_like(y))
|
||||||
|
|
||||||
|
# ===========================================================
|
||||||
|
# Unsupported C++ autograd node should graph break.
|
||||||
|
# This is better than the current compiled autograd behavior of "error out"
|
||||||
|
# and brings us a step closer to having "compiled autograd on by default".
|
||||||
|
# In theory we can also add a config that automatically treats
|
||||||
|
# it as an opaque callable, but such a config is unsound.
|
||||||
|
|
||||||
|
cpp_source = """
|
||||||
|
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
|
||||||
|
static constexpr bool is_traceable = false;
|
||||||
|
static torch::Tensor forward(
|
||||||
|
torch::autograd::AutogradContext* ctx,
|
||||||
|
const torch::Tensor& x) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
static torch::autograd::variable_list backward(
|
||||||
|
torch::autograd::AutogradContext *ctx,
|
||||||
|
torch::autograd::variable_list grad_output) {
|
||||||
|
// not traceable
|
||||||
|
*grad_output[0].data_ptr<float>() = 3.14;
|
||||||
|
return grad_output;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) {
|
||||||
|
return CustomOpAutogradFunction::apply(x);
|
||||||
|
}
|
||||||
|
TORCH_LIBRARY_FRAGMENT(mylib, m) {
|
||||||
|
m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
module = torch.utils.cpp_extension.load_inline(
|
||||||
|
name="mylib",
|
||||||
|
cpp_sources=cpp_source,
|
||||||
|
functions="custom_op_backed_by_autograd_fn",
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = torch.ones(2, 2, requires_grad=True)
|
||||||
|
out = torch.ops.mylib.custom_op_backed_by_autograd_fn(
|
||||||
|
x
|
||||||
|
)
|
||||||
|
loss = out.sum()
|
||||||
|
with torch._dynamo.compiled_autograd.enable(compiler_fn):
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
expected = torch.ones_like(x) * 3.14
|
||||||
|
assert torch.allclose(x.grad, expected)
|
||||||
|
|
||||||
|
# ===========================================================
|
||||||
|
# Tests that we don't bake in "guessed" metadata.
|
||||||
|
# This test case would have erroed out in the previous
|
||||||
|
# compiled autograd.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.cpp_extension
|
||||||
|
cpp_source2 = """
|
||||||
|
struct CustomOpAutogradFunction2 : public torch::autograd::Function<CustomOpAutogradFunction2> {
|
||||||
|
static constexpr bool is_traceable = true;
|
||||||
|
static torch::Tensor forward(
|
||||||
|
torch::autograd::AutogradContext* ctx,
|
||||||
|
const torch::Tensor& x) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
static torch::autograd::variable_list backward(
|
||||||
|
torch::autograd::AutogradContext *ctx,
|
||||||
|
torch::autograd::variable_list grad_output) {
|
||||||
|
if (grad_output[0].is_contiguous()) {
|
||||||
|
return {2 * grad_output[0]};
|
||||||
|
} else {
|
||||||
|
return {3 * grad_output[0]};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
torch::Tensor custom_op_backed_by_autograd_fn2(torch::Tensor x) {
|
||||||
|
return CustomOpAutogradFunction2::apply(x);
|
||||||
|
}
|
||||||
|
TORCH_LIBRARY_FRAGMENT(mylib, m) {
|
||||||
|
m.def("custom_op_backed_by_autograd_fn2", custom_op_backed_by_autograd_fn2);
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
module = torch.utils.cpp_extension.load_inline(
|
||||||
|
name="mylib",
|
||||||
|
cpp_sources=cpp_source2,
|
||||||
|
functions="custom_op_backed_by_autograd_fn2",
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
x = torch.tensor([[1., 2., 3.], [4, 5, 6]], requires_grad=True)
|
||||||
|
y = torch.ops.mylib.custom_op_backed_by_autograd_fn2(x)
|
||||||
|
z = y.clone()
|
||||||
|
w = z.sum()
|
||||||
|
|
||||||
|
def hook(grad):
|
||||||
|
# return a contiguous grad.
|
||||||
|
# The previous compiled autograd would have "guessed" that
|
||||||
|
# the tensor is not contiguous.
|
||||||
|
assert not grad.is_contiguous()
|
||||||
|
return grad.contiguous()
|
||||||
|
|
||||||
|
z.register_hook(hook)
|
||||||
|
|
||||||
|
with torch._dynamo.compiled_autograd.enable(lambda x: x):
|
||||||
|
w.backward()
|
||||||
|
|
||||||
|
assert torch.allclose(x.grad, 2 * torch.ones_like(x))
|
190
torch/_compiled_autograd.py
Normal file
190
torch/_compiled_autograd.py
Normal file
@ -0,0 +1,190 @@
|
|||||||
|
# type: ignore
|
||||||
|
import threading
|
||||||
|
import torch
|
||||||
|
from ._compile import _disable_dynamo
|
||||||
|
from ._C import _autograd
|
||||||
|
# TODO(rzou): why doesn't torch.fx.wrap work directly?
|
||||||
|
from torch.fx._symbolic_trace import _create_wrapped_func as wrap
|
||||||
|
|
||||||
|
"""
|
||||||
|
TODO(rzou): did we really need a new file? I did it to appease trace_rules.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def python_autograd(saved_state, hooks, nodecalls, num_outputs, arange):
|
||||||
|
"""Given the state of the autograd graph (the saved tensors/sizes/scalar,
|
||||||
|
hooks, and the actual nodes), execute it in Python.
|
||||||
|
|
||||||
|
Compiled Autograd uses the equivalent of torch.fx.symbolic_trace over
|
||||||
|
this function to produce a graph that can then be Dynamo'ed.
|
||||||
|
|
||||||
|
NB: Before executing this function (or an acquired graph version of it)
|
||||||
|
on real Tensors, please call set_global_nodecalls(nodecalls) to set the
|
||||||
|
current autograd nodes structure state. We intentionally hide this state
|
||||||
|
from the graph so that Dynamo doesn't need to deal with proxying it into
|
||||||
|
the graph.
|
||||||
|
|
||||||
|
TODO(rzou): Compiled Autograd is responsible for calling set_global_nodecalls
|
||||||
|
using the current nodecalls data structure. If the user did not specify
|
||||||
|
retain_graph=True, then something needs to free it later,
|
||||||
|
so we don't end up keeping the nodes around forever.
|
||||||
|
"""
|
||||||
|
node_to_idx_data = {node_id(call.node): idx for idx, call in enumerate(nodecalls)}
|
||||||
|
|
||||||
|
def node_to_idx(node):
|
||||||
|
return node_to_idx_data[torch._compiled_autograd.node_id(node)]
|
||||||
|
|
||||||
|
input_buffers = {}
|
||||||
|
|
||||||
|
def lookup_input_buffer(node_idx, num_inputs):
|
||||||
|
if node_idx not in input_buffers:
|
||||||
|
input_buffers[node_idx] = [None] * num_inputs
|
||||||
|
return input_buffers[node_idx]
|
||||||
|
|
||||||
|
saved_state = iter(SavedState(
|
||||||
|
nodecalls,
|
||||||
|
saved_state[0],
|
||||||
|
saved_state[1],
|
||||||
|
saved_state[2],
|
||||||
|
))
|
||||||
|
|
||||||
|
graph_outputs = [None] * num_outputs
|
||||||
|
|
||||||
|
for idx, call in enumerate(nodecalls):
|
||||||
|
node_idx = arange[idx]
|
||||||
|
inputs = lookup_input_buffer(idx, call.node.num_inputs())
|
||||||
|
|
||||||
|
# Given all of the saved state, retrieve the saved state that matters
|
||||||
|
# for the current node call.
|
||||||
|
apply_state, validate_outputs_state = next(saved_state)
|
||||||
|
|
||||||
|
for hook_idx, input_idx in call.tensor_pre_hooks:
|
||||||
|
inputs[input_idx] = call_hook(hooks[hook_idx], inputs[input_idx], hook_type="pre_hook")
|
||||||
|
for input_nr, result_idx in call.graph_output:
|
||||||
|
graph_outputs[result_idx] = inputs[input_nr]
|
||||||
|
if not call.needed:
|
||||||
|
continue
|
||||||
|
if call.node.is_compiled_autograd_traceable():
|
||||||
|
outputs = apply_with_saved(node_idx, inputs, *apply_state)
|
||||||
|
else:
|
||||||
|
outputs = apply_with_saved_dynamo_disabled(node_idx, inputs, *apply_state)
|
||||||
|
outputs = validate_outputs(node_idx, outputs, *validate_outputs_state)
|
||||||
|
for hook_idx, input_idx in call.post_hooks:
|
||||||
|
call_hook(hooks[hook_idx], outputs, inputs, hook_type="post_hook")
|
||||||
|
for output_idx in range(call.node.num_outputs()):
|
||||||
|
output = outputs[output_idx]
|
||||||
|
next_edge = call.node.next_edge(output_idx)
|
||||||
|
if not next_edge.is_valid():
|
||||||
|
continue
|
||||||
|
next_node = next_edge.function
|
||||||
|
input_buffer = lookup_input_buffer(node_to_idx(next_node), next_node.num_inputs())
|
||||||
|
updated_buffer = accumulate(input_buffer[next_edge.input_nr], output)
|
||||||
|
input_buffer[next_edge.input_nr] = updated_buffer
|
||||||
|
|
||||||
|
return graph_outputs
|
||||||
|
|
||||||
|
|
||||||
|
global_nodecalls = threading.local()
|
||||||
|
|
||||||
|
|
||||||
|
def get_node(idx):
|
||||||
|
return global_nodecalls.thread_local[idx].node
|
||||||
|
|
||||||
|
|
||||||
|
def set_global_nodecalls(nodecalls):
|
||||||
|
global_nodecalls.thread_local = nodecalls
|
||||||
|
|
||||||
|
|
||||||
|
@wrap
|
||||||
|
def apply_with_saved(node_idx, inputs, saved_tensors, saved_sizes, saved_scalars):
|
||||||
|
"""
|
||||||
|
Applies the node at global_nodecalls[node_idx] using the inputs and saved values.
|
||||||
|
"""
|
||||||
|
node = get_node(node_idx)
|
||||||
|
outputs = _autograd.apply_with_saved(global_nodecalls.thread_local[node_idx], inputs, saved_tensors, list(saved_sizes), saved_scalars)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
@_disable_dynamo
|
||||||
|
@wrap
|
||||||
|
def apply_with_saved_dynamo_disabled(node_idx, inputs, saved_tensors, saved_sizes, saved_scalars):
|
||||||
|
"""
|
||||||
|
This is apply_with_saved, but also induces a graph break in Dynamo.
|
||||||
|
"""
|
||||||
|
return apply_with_saved(node_idx, inputs, saved_tensors, saved_sizes, saved_scalars)
|
||||||
|
|
||||||
|
|
||||||
|
@wrap
|
||||||
|
def validate_outputs(node_idx, outputs, saved_tensors, saved_sizes, saved_scalars):
|
||||||
|
"""
|
||||||
|
Validates the outputs of the node at global_nodecalls[node_idx]. This requires
|
||||||
|
swizzling out some input metadata state of the next nodes, which is why
|
||||||
|
it also accepts some saved variables.
|
||||||
|
"""
|
||||||
|
outputs = _autograd.validate_outputs_with_saved(global_nodecalls.thread_local[node_idx], outputs, saved_tensors, list(saved_sizes), saved_scalars)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def node_id(node):
|
||||||
|
if node is None:
|
||||||
|
breakpoint()
|
||||||
|
assert node is not None
|
||||||
|
return _autograd.node_id(node)
|
||||||
|
|
||||||
|
|
||||||
|
def arange(num):
|
||||||
|
return list(range(num))
|
||||||
|
|
||||||
|
|
||||||
|
@wrap
|
||||||
|
def call_hook(*args, **kwargs):
|
||||||
|
return torch._dynamo.external_utils.call_hook(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class IterableWrapper:
|
||||||
|
def __init__(self, noniterable, size):
|
||||||
|
self.noniterable = noniterable
|
||||||
|
self.idx = 0
|
||||||
|
self.size = size
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
assert self.idx < self.size
|
||||||
|
result = self.noniterable[self.idx]
|
||||||
|
self.idx += 1
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class SavedState:
|
||||||
|
def __init__(self, nodecalls, tensors, sizes, scalars):
|
||||||
|
self.tensors = tensors
|
||||||
|
self.sizes = sizes
|
||||||
|
self.scalars = scalars
|
||||||
|
self.nodecalls = iter(nodecalls)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
call = next(self.nodecalls)
|
||||||
|
|
||||||
|
def get_next(collection_info):
|
||||||
|
tensors = [next(self.tensors) for _ in range(collection_info.num_saved_tensors)]
|
||||||
|
sizes = [next(self.sizes) for _ in range(collection_info.num_saved_sizes)]
|
||||||
|
scalars = [next(self.scalars) for _ in range(collection_info.num_saved_ivalues)]
|
||||||
|
return (tensors, sizes, scalars)
|
||||||
|
|
||||||
|
saved_state_for_apply = get_next(call.compiled_args_info)
|
||||||
|
saved_state_for_validate_output = get_next(call.next_edges_info)
|
||||||
|
return saved_state_for_apply, saved_state_for_validate_output
|
||||||
|
|
||||||
|
|
||||||
|
@wrap
|
||||||
|
def accumulate(old_var, var):
|
||||||
|
if old_var is None:
|
||||||
|
return var
|
||||||
|
if var is None:
|
||||||
|
return old_var
|
||||||
|
return old_var + var
|
@ -82,6 +82,49 @@ class AutogradCompilerInstance:
|
|||||||
def source(name, idx) -> GetItemSource:
|
def source(name, idx) -> GetItemSource:
|
||||||
return GetItemSource(LocalSource(name), idx)
|
return GetItemSource(LocalSource(name), idx)
|
||||||
|
|
||||||
|
def capture(self, tensors, sizes, scalars, origins, nodecalls, num_outputs):
|
||||||
|
dynamic_sizes = tuple(s for s in sizes if s is not None)
|
||||||
|
|
||||||
|
counters["compiled_autograd"]["captures"] += 1
|
||||||
|
inputs_origins, sizes_origins, scalars_origins = origins
|
||||||
|
|
||||||
|
self.fx_tracer.root = torch.nn.Module()
|
||||||
|
self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
|
||||||
|
self.fx_tracer.tensor_attrs = {}
|
||||||
|
inputs_proxy, dynamic_sizes_proxy, scalars_proxy, self.hooks_proxy = (
|
||||||
|
self.fx_tracer.create_proxy("placeholder", name, (), {})
|
||||||
|
for name in self.graph_placeholders
|
||||||
|
)
|
||||||
|
|
||||||
|
sizes_proxy = [None] * len(sizes)
|
||||||
|
dynamic_sizes_next = 0
|
||||||
|
for idx in range(len(sizes)):
|
||||||
|
if sizes[idx] is not None:
|
||||||
|
sizes_proxy[idx] = dynamic_sizes[dynamic_sizes_next]
|
||||||
|
dynamic_sizes_next += 1
|
||||||
|
|
||||||
|
from torch._compiled_autograd import IterableWrapper, python_autograd, arange
|
||||||
|
|
||||||
|
arange_proxy = self.fx_tracer.create_proxy(
|
||||||
|
kind="call_function",
|
||||||
|
target=arange,
|
||||||
|
args=(len(nodecalls),),
|
||||||
|
kwargs={}
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_outputs = python_autograd(
|
||||||
|
(
|
||||||
|
IterableWrapper(inputs_proxy, len(tensors)),
|
||||||
|
IterableWrapper(sizes_proxy, len(sizes)),
|
||||||
|
IterableWrapper(scalars_proxy, len(scalars)),
|
||||||
|
),
|
||||||
|
self.hooks_proxy,
|
||||||
|
nodecalls,
|
||||||
|
num_outputs,
|
||||||
|
arange_proxy,
|
||||||
|
)
|
||||||
|
return self.end_capture(graph_outputs)
|
||||||
|
|
||||||
def begin_capture(
|
def begin_capture(
|
||||||
self,
|
self,
|
||||||
inputs: List[torch.Tensor],
|
inputs: List[torch.Tensor],
|
||||||
@ -308,8 +351,10 @@ class AutogradCompilerInstance:
|
|||||||
(self.fx_tracer.create_arg(self.to_proxy(outputs)),),
|
(self.fx_tracer.create_arg(self.to_proxy(outputs)),),
|
||||||
{},
|
{},
|
||||||
)
|
)
|
||||||
self.rename_aot_dispatcher_nodes()
|
# TODO(rzou): we didn't inline the AOTDispatcher nodes
|
||||||
self.reorder_accumulate_grad_nodes()
|
# self.rename_aot_dispatcher_nodes()
|
||||||
|
# TODO(rzou): we need to transform AccumulateGrad nodes into torch.inductor.accumulate_grad_.
|
||||||
|
# self.reorder_accumulate_grad_nodes()
|
||||||
runtime_inputs_to_move: List[int] = []
|
runtime_inputs_to_move: List[int] = []
|
||||||
if snapshot_cudagraph_enabled():
|
if snapshot_cudagraph_enabled():
|
||||||
runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
|
runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
|
||||||
@ -317,6 +362,7 @@ class AutogradCompilerInstance:
|
|||||||
graph = GraphModule(
|
graph = GraphModule(
|
||||||
self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd"
|
self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd"
|
||||||
)
|
)
|
||||||
|
graph.print_readable()
|
||||||
set_locals_to_steal(graph, ["inputs"])
|
set_locals_to_steal(graph, ["inputs"])
|
||||||
lazy_graph_code = lazy_format_graph_code(
|
lazy_graph_code = lazy_format_graph_code(
|
||||||
"Compiled autograd graph",
|
"Compiled autograd graph",
|
||||||
@ -562,3 +608,5 @@ def reset() -> None:
|
|||||||
assert not in_compiled_autograd_region
|
assert not in_compiled_autograd_region
|
||||||
torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
|
torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
|
||||||
torch._C._dynamo.compiled_autograd.set_verbose_logger(None)
|
torch._C._dynamo.compiled_autograd.set_verbose_logger(None)
|
||||||
|
|
||||||
|
from torch._compiled_autograd import set_global_nodecalls
|
||||||
|
@ -950,6 +950,9 @@ class OutputGraph:
|
|||||||
list_name = arg.source.local_name
|
list_name = arg.source.local_name
|
||||||
assert list_name in self.code_options["co_varnames"]
|
assert list_name in self.code_options["co_varnames"]
|
||||||
for x in needs_alias[list_name]:
|
for x in needs_alias[list_name]:
|
||||||
|
if not hasattr(x.source, "index"):
|
||||||
|
# TODO(rzou): idk
|
||||||
|
breakpoint()
|
||||||
list_idx = x.source.index
|
list_idx = x.source.index
|
||||||
if list_idx not in visited:
|
if list_idx not in visited:
|
||||||
alias_name = self.new_var(
|
alias_name = self.new_var(
|
||||||
|
@ -134,6 +134,13 @@ If you are removing an existing torch level API:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
manual_torch_name_rule_map = {
|
manual_torch_name_rule_map = {
|
||||||
|
"torch._compiled_autograd.CA_apply_with_saved": TorchInGraphFunctionVariable,
|
||||||
|
"torch._compiled_autograd.accumulate2": TorchInGraphFunctionVariable,
|
||||||
|
"torch._compiled_autograd.CA_validate_outputs": TorchInGraphFunctionVariable,
|
||||||
|
# "torch._compiled_autograd.CA_apply_with_saved_dynamo_disabled": TorchInGraphFunctionVariable,
|
||||||
|
"torch._compiled_autograd.CA_update_input_buffers": TorchInGraphFunctionVariable,
|
||||||
|
"torch._compiled_autograd.CA_input_buffers_init": TorchInGraphFunctionVariable,
|
||||||
|
"torch._compiled_autograd.CA_input_buffers_lookup": TorchInGraphFunctionVariable,
|
||||||
"torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
|
"torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
|
||||||
"torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
|
"torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
|
||||||
"torch.overrides.is_tensor_like": TorchInGraphFunctionVariable,
|
"torch.overrides.is_tensor_like": TorchInGraphFunctionVariable,
|
||||||
@ -3237,6 +3244,7 @@ if torch.distributed.is_available():
|
|||||||
# We are using python module name instead of file or directory object to avoid circular dependency.
|
# We are using python module name instead of file or directory object to avoid circular dependency.
|
||||||
# Please keep this sorted alphabetically.
|
# Please keep this sorted alphabetically.
|
||||||
MOD_INLINELIST = [
|
MOD_INLINELIST = [
|
||||||
|
"torch._compiled_autograd",
|
||||||
"torch._decomp",
|
"torch._decomp",
|
||||||
"torch._dynamo._trace_wrapped_higher_order_op",
|
"torch._dynamo._trace_wrapped_higher_order_op",
|
||||||
"torch._dynamo.comptime",
|
"torch._dynamo.comptime",
|
||||||
|
@ -1219,7 +1219,12 @@ class VariableBuilder:
|
|||||||
maybe_gm = self.tx.output.local_scope.get("self")
|
maybe_gm = self.tx.output.local_scope.get("self")
|
||||||
if isinstance(
|
if isinstance(
|
||||||
self.source, LocalSource
|
self.source, LocalSource
|
||||||
) and self.source.local_name in get_locals_to_steal(maybe_gm):
|
# TODO(rzou): We changed compiled autograd to pass all of the inputs saved
|
||||||
|
# instead of a de-duplicated list. Unfortunately that makes the input
|
||||||
|
# stealing logic go haywire. We can either fix it or figure out
|
||||||
|
# how to deal with a de-duplicated list (the problem is
|
||||||
|
# mapping the de-duplicated saved tensors back to the nodes that need them).
|
||||||
|
) and self.source.local_name in get_locals_to_steal(maybe_gm) and False:
|
||||||
# The input tensor list to dynamo from compiled autograd may contain activations
|
# The input tensor list to dynamo from compiled autograd may contain activations
|
||||||
# which are freed as they are used in inductor. Dynamo's default behavior is to
|
# which are freed as they are used in inductor. Dynamo's default behavior is to
|
||||||
# lift all tensors to the graph inputs, but this will cause dynamo to hold an
|
# lift all tensors to the graph inputs, but this will cause dynamo to hold an
|
||||||
@ -1249,13 +1254,17 @@ class VariableBuilder:
|
|||||||
source_i = GetItemSource(base=source, index=i, index_is_slice=False)
|
source_i = GetItemSource(base=source, index=i, index_is_slice=False)
|
||||||
# access unpacked tensor from this list instead of from a lifted arg
|
# access unpacked tensor from this list instead of from a lifted arg
|
||||||
self.tx.output.input_source_to_var[source_i] = tensor_variable
|
self.tx.output.input_source_to_var[source_i] = tensor_variable
|
||||||
tensor_variable.proxy.node.meta["tensor_dict"] = _extract_tensor_dict(
|
if isinstance(tensor_variable, TensorVariable):
|
||||||
value[i]
|
tensor_variable.proxy.node.meta["tensor_dict"] = _extract_tensor_dict(
|
||||||
)
|
value[i]
|
||||||
|
)
|
||||||
|
|
||||||
guard = functools.partial(
|
guard = functools.partial(
|
||||||
GuardBuilder.TENSOR_MATCH, value=TensorWeakRef(value[i])
|
GuardBuilder.TENSOR_MATCH, value=TensorWeakRef(value[i])
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# TODO(rzou): None guard?
|
||||||
|
pass
|
||||||
guards.append(source_i.make_guard(guard))
|
guards.append(source_i.make_guard(guard))
|
||||||
|
|
||||||
install_guard(*guards, skip=1)
|
install_guard(*guards, skip=1)
|
||||||
|
@ -188,16 +188,25 @@ struct CppNode : public Node {
|
|||||||
void set_ctx_grad_fn(const std::shared_ptr<Node>& node);
|
void set_ctx_grad_fn(const std::shared_ptr<Node>& node);
|
||||||
void save_variables_to_ctx();
|
void save_variables_to_ctx();
|
||||||
|
|
||||||
|
bool is_compiled_autograd_traceable() override {
|
||||||
|
static_assert(
|
||||||
|
std::is_same_v<std::remove_cv_t<decltype(T::is_traceable)>, bool>);
|
||||||
|
return T::is_traceable;
|
||||||
|
}
|
||||||
|
|
||||||
void compiled_args(CompiledNodeArgs& args) override {
|
void compiled_args(CompiledNodeArgs& args) override {
|
||||||
static_assert(
|
static_assert(
|
||||||
std::is_same_v<std::remove_cv_t<decltype(T::is_traceable)>, bool>);
|
std::is_same_v<std::remove_cv_t<decltype(T::is_traceable)>, bool>);
|
||||||
if (!T::is_traceable) {
|
// if (!T::is_traceable) {
|
||||||
throw std::runtime_error(
|
// throw std::runtime_error(
|
||||||
std::string(
|
// std::string(
|
||||||
"Attempting to trace a potentially unsafe C++ autograd function: ") +
|
// "Attempting to trace a potentially unsafe C++ autograd
|
||||||
name() +
|
// function: ") +
|
||||||
". It may be possible to trace it safely, please refer to the instructions in: https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY/.");
|
// name() +
|
||||||
}
|
// ". It may be possible to trace it safely, please refer to the
|
||||||
|
// instructions in:
|
||||||
|
// https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY/.");
|
||||||
|
// }
|
||||||
|
|
||||||
// although neither of the 2 methods below have uniqueness guarantees
|
// although neither of the 2 methods below have uniqueness guarantees
|
||||||
// it is unlikely for them to collide at the same time
|
// it is unlikely for them to collide at the same time
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
#include <c10/util/ThreadLocal.h>
|
#include <c10/util/ThreadLocal.h>
|
||||||
#include <torch/csrc/autograd/engine.h>
|
#include <torch/csrc/autograd/engine.h>
|
||||||
#include <torch/csrc/autograd/variable.h>
|
#include <torch/csrc/autograd/variable.h>
|
||||||
|
#include <torch/csrc/dynamo/compiled_autograd.h>
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
|
|
||||||
|
@ -563,6 +563,10 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
|
|||||||
/// release variables as they run.
|
/// release variables as they run.
|
||||||
virtual void will_release_variables() {}
|
virtual void will_release_variables() {}
|
||||||
|
|
||||||
|
virtual bool is_compiled_autograd_traceable() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns true if this function is traceable. An op is traceable if all
|
/// Returns true if this function is traceable. An op is traceable if all
|
||||||
/// operations happening within `apply()` are performed on autograd
|
/// operations happening within `apply()` are performed on autograd
|
||||||
/// `Variables` (i.e. apply mostly instantiates and applies other functions).
|
/// `Variables` (i.e. apply mostly instantiates and applies other functions).
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
#include <torch/csrc/python_headers.h>
|
#include <torch/csrc/python_headers.h>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
#include <ATen/PythonTorchFunctionTLS.h>
|
#include <ATen/PythonTorchFunctionTLS.h>
|
||||||
#include <ATen/SavedTensorHooks.h>
|
#include <ATen/SavedTensorHooks.h>
|
||||||
@ -14,6 +15,7 @@
|
|||||||
#include <torch/csrc/autograd/VariableTypeUtils.h>
|
#include <torch/csrc/autograd/VariableTypeUtils.h>
|
||||||
#include <torch/csrc/autograd/autograd.h>
|
#include <torch/csrc/autograd/autograd.h>
|
||||||
#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>
|
#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>
|
||||||
|
#include <torch/csrc/autograd/engine.h>
|
||||||
#include <torch/csrc/autograd/function.h>
|
#include <torch/csrc/autograd/function.h>
|
||||||
#include <torch/csrc/autograd/grad_mode.h>
|
#include <torch/csrc/autograd/grad_mode.h>
|
||||||
#include <torch/csrc/autograd/input_metadata.h>
|
#include <torch/csrc/autograd/input_metadata.h>
|
||||||
@ -26,6 +28,7 @@
|
|||||||
#include <torch/csrc/autograd/saved_variable.h>
|
#include <torch/csrc/autograd/saved_variable.h>
|
||||||
#include <torch/csrc/autograd/utils/python_arg_parsing.h>
|
#include <torch/csrc/autograd/utils/python_arg_parsing.h>
|
||||||
#include <torch/csrc/autograd/utils/wrap_outputs.h>
|
#include <torch/csrc/autograd/utils/wrap_outputs.h>
|
||||||
|
#include <torch/csrc/dynamo/compiled_autograd.h>
|
||||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||||
#include <torch/csrc/profiler/collection.h>
|
#include <torch/csrc/profiler/collection.h>
|
||||||
#include <torch/csrc/profiler/kineto_shim.h>
|
#include <torch/csrc/profiler/kineto_shim.h>
|
||||||
@ -42,6 +45,7 @@
|
|||||||
|
|
||||||
using torch::impl::py_context_manager;
|
using torch::impl::py_context_manager;
|
||||||
using torch::impl::py_context_manager_DEPRECATED;
|
using torch::impl::py_context_manager_DEPRECATED;
|
||||||
|
using namespace torch::dynamo::autograd;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@ -79,6 +83,55 @@ struct EnablePythonDispatcher {
|
|||||||
c10::impl::PyInterpreter* old_;
|
c10::impl::PyInterpreter* old_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::vector<at::Tensor> toVec(
|
||||||
|
const std::vector<std::optional<at::Tensor>>& ts) {
|
||||||
|
std::vector<at::Tensor> result;
|
||||||
|
for (const auto& opt_tensor : ts) {
|
||||||
|
if (opt_tensor.has_value()) {
|
||||||
|
result.push_back(opt_tensor.value());
|
||||||
|
} else {
|
||||||
|
result.emplace_back();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
variable_list validate_outputs_with_saved(
|
||||||
|
const NodeCall& nodecall,
|
||||||
|
std::vector<at::Tensor>& outputs,
|
||||||
|
const std::vector<at::Tensor>& saved_tensors,
|
||||||
|
const std::vector<std::optional<at::SymInt>>& saved_sizes,
|
||||||
|
const std::vector<at::IValue>& saved_ivalues) {
|
||||||
|
auto saved = SwapSavedVariables(
|
||||||
|
saved_tensors, saved_sizes, saved_ivalues, nullptr, nodecall);
|
||||||
|
saved.before(nodecall.node->next_edges());
|
||||||
|
torch::autograd::validate_outputs(
|
||||||
|
nodecall.node->next_edges(), outputs, [&](const std::string& msg) {
|
||||||
|
std::ostringstream ss;
|
||||||
|
ss << "[Compiled Autograd Tracing: " << nodecall.node->name() << "] "
|
||||||
|
<< msg;
|
||||||
|
return ss.str();
|
||||||
|
});
|
||||||
|
saved.after(nodecall.node->next_edges());
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
variable_list apply_with_saved314(
|
||||||
|
const NodeCall& nodecall,
|
||||||
|
const std::vector<std::optional<at::Tensor>>& inputs,
|
||||||
|
const std::vector<std::optional<at::Tensor>>& saved_tensors,
|
||||||
|
const std::vector<std::optional<at::SymInt>>& saved_sizes,
|
||||||
|
const std::vector<at::IValue>& saved_ivalues) {
|
||||||
|
auto saved = SwapSavedVariables(
|
||||||
|
toVec(saved_tensors), saved_sizes, saved_ivalues, nullptr, nodecall);
|
||||||
|
auto outputs = nodecall.node->apply_with_saved(toVec(inputs), saved);
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t node_id(const std::shared_ptr<Node>& node) {
|
||||||
|
return reinterpret_cast<uint64_t>(node.get());
|
||||||
|
}
|
||||||
|
|
||||||
struct EnablePreDispatch {
|
struct EnablePreDispatch {
|
||||||
EnablePreDispatch() : guard_(c10::DispatchKey::PreDispatch) {}
|
EnablePreDispatch() : guard_(c10::DispatchKey::PreDispatch) {}
|
||||||
c10::impl::IncludeDispatchKeyGuard guard_;
|
c10::impl::IncludeDispatchKeyGuard guard_;
|
||||||
@ -491,6 +544,50 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// compiled_autograd stuff
|
||||||
|
py::class_<torch::autograd::Node, std::shared_ptr<torch::autograd::Node>>(
|
||||||
|
m, "Node")
|
||||||
|
.def("compiled_args", &torch::autograd::Node::compiled_args)
|
||||||
|
.def("next_edge", &torch::autograd::Node::next_edge)
|
||||||
|
.def(
|
||||||
|
"is_compiled_autograd_traceable",
|
||||||
|
&torch::autograd::Node::is_compiled_autograd_traceable)
|
||||||
|
.def("name", &torch::autograd::Node::name)
|
||||||
|
.def("num_outputs", &torch::autograd::Node::num_outputs)
|
||||||
|
.def("num_inputs", &torch::autograd::Node::num_inputs);
|
||||||
|
py::class_<torch::autograd::Edge>(m, "Edge")
|
||||||
|
.def("is_valid", &torch::autograd::Edge::is_valid)
|
||||||
|
.def_readonly("input_nr", &torch::autograd::Edge::input_nr)
|
||||||
|
.def_readonly("function", &torch::autograd::Edge::function);
|
||||||
|
py::class_<CollectionInfo>(m, "CollectionInfo")
|
||||||
|
.def_readonly("num_saved_tensors", &CollectionInfo::num_saved_tensors)
|
||||||
|
.def_readonly("num_saved_sizes", &CollectionInfo::num_saved_sizes)
|
||||||
|
.def_readonly("num_saved_ivalues", &CollectionInfo::num_saved_ivalues);
|
||||||
|
py::class_<torch::dynamo::autograd::NodeCall>(m, "NodeCall")
|
||||||
|
.def_readonly("node", &NodeCall::node)
|
||||||
|
.def_readonly("compiled_args_info", &NodeCall::compiled_args_info)
|
||||||
|
.def_readonly("next_edges_info", &NodeCall::next_edges_info)
|
||||||
|
.def_readonly("tensor_pre_hooks", &NodeCall::tensor_pre_hooks)
|
||||||
|
.def_readonly("post_hooks", &NodeCall::post_hooks)
|
||||||
|
.def_readonly("graph_output", &NodeCall::graph_output)
|
||||||
|
.def_readonly("needed", &NodeCall::needed);
|
||||||
|
py::class_<torch::dynamo::autograd::CompiledNodeArgs>(m, "CompiledNodeArgs")
|
||||||
|
.def(py::init<AutogradCompilerCall&, NodeCall&>());
|
||||||
|
py::class_<torch::dynamo::autograd::AutogradCompilerCall>(
|
||||||
|
m, "AutogradCompilerCall")
|
||||||
|
.def(py::init<>());
|
||||||
|
m.def("apply_with_saved", &apply_with_saved314);
|
||||||
|
m.def("validate_outputs_with_saved", &validate_outputs_with_saved);
|
||||||
|
m.def("node_id", &node_id);
|
||||||
|
// py::class_<SwapInterface,PySwapInterface>(m, "SwapInterface");
|
||||||
|
// py::class_<SwapWithReal,SwapInterface>(m, "SwapWithReal")
|
||||||
|
// .def(py::init<std::vector<at::Tensor>,std::vector<c10::SymInt>,std::vector<c10::IValue>>())
|
||||||
|
// ;
|
||||||
|
// py::class_<SwapSavedVariables>(m, "SwapSavedVariables")
|
||||||
|
// .def(py::init<std::vector<at::Tensor>,std::vector<c10::SymInt>,std::vector<c10::IValue>,PyObject*,const
|
||||||
|
// NodeCall&>())
|
||||||
|
// ;
|
||||||
|
|
||||||
_C_m.def("_activate_gpu_trace", []() { activateGPUTrace(); });
|
_C_m.def("_activate_gpu_trace", []() { activateGPUTrace(); });
|
||||||
|
|
||||||
py_context_manager_DEPRECATED<c10::InferenceMode, bool>(
|
py_context_manager_DEPRECATED<c10::InferenceMode, bool>(
|
||||||
|
@ -69,7 +69,15 @@ struct CacheKey {
|
|||||||
const uint8_t* key;
|
const uint8_t* key;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct NodeCall {
|
struct CollectionInfo {
|
||||||
|
int num_saved_tensors = 0;
|
||||||
|
int num_saved_sizes = 0;
|
||||||
|
int num_saved_ivalues = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
enum CollectionMode { COMPILED_ARGS, NEXT_EDGES };
|
||||||
|
|
||||||
|
struct TORCH_API NodeCall {
|
||||||
NodeCall(uint32_t id_, std::shared_ptr<Node> node_)
|
NodeCall(uint32_t id_, std::shared_ptr<Node> node_)
|
||||||
: id(id_), node(std::move(node_)) {}
|
: id(id_), node(std::move(node_)) {}
|
||||||
|
|
||||||
@ -84,6 +92,24 @@ struct NodeCall {
|
|||||||
std::vector<int> post_hooks;
|
std::vector<int> post_hooks;
|
||||||
std::vector<int> post_acc_grad_hooks;
|
std::vector<int> post_acc_grad_hooks;
|
||||||
std::vector<std::pair<int, int>> graph_output;
|
std::vector<std::pair<int, int>> graph_output;
|
||||||
|
|
||||||
|
CollectionInfo& collection_info() {
|
||||||
|
if (mode == CollectionMode::NEXT_EDGES) {
|
||||||
|
return next_edges_info;
|
||||||
|
} else {
|
||||||
|
return compiled_args_info;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Given the full list of saved arguments (saved tensors, saved sizes,
|
||||||
|
// saved scalars), we want to be able to map them back to which node
|
||||||
|
// they came from.
|
||||||
|
// The way we do this is that we store information on how many
|
||||||
|
// tensors/sizes/scalars each Node uses.
|
||||||
|
CollectionMode mode = CollectionMode::COMPILED_ARGS;
|
||||||
|
CollectionInfo compiled_args_info;
|
||||||
|
CollectionInfo next_edges_info;
|
||||||
|
|
||||||
bool needed = true;
|
bool needed = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -143,9 +169,9 @@ struct TensorArgs {
|
|||||||
auto impl = tensor.unsafeGetTensorImpl();
|
auto impl = tensor.unsafeGetTensorImpl();
|
||||||
auto it = _args.find(impl);
|
auto it = _args.find(impl);
|
||||||
if (it == _args.end()) {
|
if (it == _args.end()) {
|
||||||
TORCH_INTERNAL_ASSERT(create && inputs.size() == _next_id - 1);
|
// TORCH_INTERNAL_ASSERT(create && inputs.size() == _next_id - 1);
|
||||||
it = _args.emplace(impl, TensorArg(_next_id++)).first;
|
it = _args.emplace(impl, TensorArg(_next_id++)).first;
|
||||||
inputs.emplace_back(tensor);
|
// inputs.emplace_back(tensor);
|
||||||
if (active_node_call_idx.has_value()) {
|
if (active_node_call_idx.has_value()) {
|
||||||
input_origins.emplace_back(active_node_call_idx.value());
|
input_origins.emplace_back(active_node_call_idx.value());
|
||||||
}
|
}
|
||||||
@ -160,6 +186,9 @@ struct TensorArgs {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TensorArg& add(const at::Tensor& tensor) {
|
TensorArg& add(const at::Tensor& tensor) {
|
||||||
|
// unconditionally add the tensor to inputs... Dynamo will de-dupe them
|
||||||
|
// later
|
||||||
|
inputs.emplace_back(tensor);
|
||||||
return lookup(tensor, true);
|
return lookup(tensor, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -208,6 +237,11 @@ struct LiftedIValueArgs {
|
|||||||
return iv_arg.proxy;
|
return iv_arg.proxy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
at::IValue& next_proxy() {
|
||||||
|
auto& iv_arg = args.at(next++);
|
||||||
|
return iv_arg.proxy;
|
||||||
|
}
|
||||||
|
|
||||||
void add(const at::IValue* iv) {
|
void add(const at::IValue* iv) {
|
||||||
args.emplace_back(iv);
|
args.emplace_back(iv);
|
||||||
if (active_node_call_idx.has_value()) {
|
if (active_node_call_idx.has_value()) {
|
||||||
@ -278,13 +312,16 @@ class CompiledNodeArgs {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void collect(const at::Tensor& t) {
|
void collect(const at::Tensor& t) {
|
||||||
|
_node_call.collection_info().num_saved_tensors++;
|
||||||
collect(_compiler.tensor_args.add(t));
|
collect(_compiler.tensor_args.add(t));
|
||||||
}
|
}
|
||||||
void collect(const SavedVariable& sv, bool is_output) {
|
void collect(const SavedVariable& sv, bool is_output) {
|
||||||
|
_node_call.collection_info().num_saved_tensors++;
|
||||||
collect(
|
collect(
|
||||||
_compiler.tensor_args.add(sv, is_output ? _node_call.node : nullptr));
|
_compiler.tensor_args.add(sv, is_output ? _node_call.node : nullptr));
|
||||||
}
|
}
|
||||||
void collect(const c10::SymInt& t) {
|
void collect(const c10::SymInt& t) {
|
||||||
|
_node_call.collection_info().num_saved_sizes++;
|
||||||
_compiler.add_size_input(t);
|
_compiler.add_size_input(t);
|
||||||
}
|
}
|
||||||
void collect(const std::vector<SavedVariable>& t, bool is_output) {
|
void collect(const std::vector<SavedVariable>& t, bool is_output) {
|
||||||
@ -366,6 +403,7 @@ class CompiledNodeArgs {
|
|||||||
!nested &&
|
!nested &&
|
||||||
(iv.isInt() || iv.isSymInt() || iv.isDouble() || iv.isSymFloat())) {
|
(iv.isInt() || iv.isSymInt() || iv.isDouble() || iv.isSymFloat())) {
|
||||||
// can't lift ivalues nested in collections
|
// can't lift ivalues nested in collections
|
||||||
|
_node_call.collection_info().num_saved_ivalues++;
|
||||||
_compiler.lifted_ivalue_args.add(&iv);
|
_compiler.lifted_ivalue_args.add(&iv);
|
||||||
} else {
|
} else {
|
||||||
try {
|
try {
|
||||||
@ -629,17 +667,110 @@ struct TraceState {
|
|||||||
variable_list outputs;
|
variable_list outputs;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct TORCH_API SwapInterface {
|
||||||
|
virtual ~SwapInterface() = default;
|
||||||
|
virtual std::optional<at::Tensor> tensor(const at::Tensor& tensor) = 0;
|
||||||
|
virtual std::optional<at::Tensor> tensor(const SavedVariable& tensor) = 0;
|
||||||
|
virtual std::optional<c10::SymInt> next_size() = 0;
|
||||||
|
virtual c10::IValue next_ivalue() = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct SwapWithProxies : public SwapInterface {
|
||||||
|
explicit SwapWithProxies(AutogradCompilerCall& compiler, TraceState& state)
|
||||||
|
: compiler_(compiler), state_(state) {}
|
||||||
|
|
||||||
|
~SwapWithProxies() override = default;
|
||||||
|
|
||||||
|
std::optional<at::Tensor> tensor(const at::Tensor& tensor) override {
|
||||||
|
TensorArg& arg = compiler_.tensor_args.lookup(tensor);
|
||||||
|
if (arg.defined()) {
|
||||||
|
TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
|
||||||
|
return arg.proxy_tensor;
|
||||||
|
}
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<at::Tensor> tensor(const SavedVariable& t) override {
|
||||||
|
TensorArg& arg = compiler_.tensor_args.lookup(t);
|
||||||
|
if (arg.defined()) {
|
||||||
|
return arg.proxy_tensor;
|
||||||
|
}
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<c10::SymInt> next_size() override {
|
||||||
|
return state_.next_sym_size();
|
||||||
|
}
|
||||||
|
c10::IValue next_ivalue() override {
|
||||||
|
return compiler_.lifted_ivalue_args.next_proxy();
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
||||||
|
AutogradCompilerCall& compiler_;
|
||||||
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
||||||
|
TraceState& state_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// The previous compiled autograd implementation was about swapping in
|
||||||
|
// ProxyTensors for a node. Given a single node and some saved
|
||||||
|
// tensors/sizes/scalars, we needed some way to swap in those saved
|
||||||
|
// tensors/sizes/scalars. That's what SwapWithReal is.
|
||||||
|
struct SwapWithReal : public SwapInterface {
|
||||||
|
explicit SwapWithReal(
|
||||||
|
std::vector<at::Tensor> tensors,
|
||||||
|
std::vector<std::optional<c10::SymInt>> sizes,
|
||||||
|
std::vector<c10::IValue> ivalues)
|
||||||
|
: tensors_(std::move(tensors)),
|
||||||
|
sizes_(std::move(sizes)),
|
||||||
|
ivalues_(std::move(ivalues)) {}
|
||||||
|
|
||||||
|
~SwapWithReal() override = default;
|
||||||
|
|
||||||
|
std::optional<at::Tensor> tensor(const at::Tensor& _ignored) override {
|
||||||
|
auto result = tensors_[tensors_idx];
|
||||||
|
tensors_idx++;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<at::Tensor> tensor(const SavedVariable& _ignored) override {
|
||||||
|
TORCH_INTERNAL_ASSERT(tensors_idx < tensors_.size());
|
||||||
|
auto result = tensors_[tensors_idx];
|
||||||
|
tensors_idx++;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<c10::SymInt> next_size() override {
|
||||||
|
TORCH_INTERNAL_ASSERT(sizes_idx < sizes_.size());
|
||||||
|
auto result = sizes_[sizes_idx];
|
||||||
|
sizes_idx++;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
c10::IValue next_ivalue() override {
|
||||||
|
TORCH_INTERNAL_ASSERT(ivalues_idx < ivalues_.size());
|
||||||
|
auto result = ivalues_[ivalues_idx];
|
||||||
|
ivalues_idx++;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<at::Tensor> tensors_;
|
||||||
|
size_t tensors_idx = 0;
|
||||||
|
std::vector<std::optional<c10::SymInt>> sizes_;
|
||||||
|
size_t sizes_idx = 0;
|
||||||
|
std::vector<c10::IValue> ivalues_;
|
||||||
|
size_t ivalues_idx = 0;
|
||||||
|
};
|
||||||
|
|
||||||
class SwapSavedVariables {
|
class SwapSavedVariables {
|
||||||
// SwapSavedVariables is used during the tracing/compilation phase after a
|
// SwapSavedVariables is used during the tracing/compilation phase after a
|
||||||
// cache-miss. It swaps any 'lifted' inputs (tensors, symints) to proxy nodes,
|
// cache-miss. It swaps any 'lifted' inputs (tensors, symints) to proxy nodes,
|
||||||
// allows tracing to happen, then swaps them back afterwards.
|
// allows tracing to happen, then swaps them back afterwards.
|
||||||
public:
|
public:
|
||||||
void before(at::Tensor& t) {
|
void before(at::Tensor& t) {
|
||||||
TensorArg& arg = compiler.tensor_args.lookup(t);
|
auto replacement = state->tensor(t);
|
||||||
stashed_tensors.save(&t, std::move(t));
|
stashed_tensors.save(&t, std::move(t));
|
||||||
if (arg.defined()) {
|
if (replacement.has_value()) {
|
||||||
TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
|
t = *replacement;
|
||||||
t = arg.proxy_tensor;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void after(at::Tensor& t) {
|
void after(at::Tensor& t) {
|
||||||
@ -647,12 +778,11 @@ class SwapSavedVariables {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void before(SavedVariable& t) {
|
void before(SavedVariable& t) {
|
||||||
TensorArg& arg = compiler.tensor_args.lookup(t);
|
auto replacement = state->tensor(t);
|
||||||
stashed_variables.save(&t, std::move(t));
|
stashed_variables.save(&t, std::move(t));
|
||||||
if (arg.defined()) {
|
if (replacement.has_value()) {
|
||||||
bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
|
bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
|
||||||
TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
|
t = SavedVariable(replacement.value(), false);
|
||||||
t = SavedVariable(arg.proxy_tensor, false);
|
|
||||||
at::SavedTensorDefaultHooks::set_tracing(prior);
|
at::SavedTensorDefaultHooks::set_tracing(prior);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -662,7 +792,7 @@ class SwapSavedVariables {
|
|||||||
|
|
||||||
void before(c10::SymInt& t) {
|
void before(c10::SymInt& t) {
|
||||||
stashed_symints.save(&t, c10::SymInt(t));
|
stashed_symints.save(&t, c10::SymInt(t));
|
||||||
auto opt_value = state.next_sym_size();
|
auto opt_value = state->next_size();
|
||||||
if (opt_value.has_value()) {
|
if (opt_value.has_value()) {
|
||||||
t = *opt_value; // dynamic shape
|
t = *opt_value; // dynamic shape
|
||||||
}
|
}
|
||||||
@ -677,7 +807,7 @@ class SwapSavedVariables {
|
|||||||
} else {
|
} else {
|
||||||
stashed_ivalues.save(&iv, at::IValue(iv));
|
stashed_ivalues.save(&iv, at::IValue(iv));
|
||||||
if (iv.isInt() || iv.isSymInt() || iv.isDouble() || iv.isSymFloat()) {
|
if (iv.isInt() || iv.isSymInt() || iv.isDouble() || iv.isSymFloat()) {
|
||||||
iv = compiler.lifted_ivalue_args.next_proxy(&iv);
|
iv = state->next_ivalue();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -824,7 +954,23 @@ class SwapSavedVariables {
|
|||||||
TraceState& s,
|
TraceState& s,
|
||||||
PyObject* p,
|
PyObject* p,
|
||||||
const NodeCall& n)
|
const NodeCall& n)
|
||||||
: compiler(c), state(s), py_compiler(p), curr_node_call(n) {}
|
: py_compiler(p), curr_node_call(n) {
|
||||||
|
state = std::make_shared<SwapWithProxies>(c, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
SwapSavedVariables(
|
||||||
|
std::vector<at::Tensor> a,
|
||||||
|
std::vector<std::optional<at::SymInt>> b,
|
||||||
|
std::vector<at::IValue> c,
|
||||||
|
PyObject* p,
|
||||||
|
const NodeCall& n)
|
||||||
|
: state(std::static_pointer_cast<SwapInterface>(
|
||||||
|
std::make_shared<SwapWithReal>(
|
||||||
|
std::move(a),
|
||||||
|
std::move(b),
|
||||||
|
std::move(c)))),
|
||||||
|
py_compiler(p),
|
||||||
|
curr_node_call(n) {}
|
||||||
|
|
||||||
PyObject* get_py_compiler() {
|
PyObject* get_py_compiler() {
|
||||||
return py_compiler;
|
return py_compiler;
|
||||||
@ -875,9 +1021,10 @@ class SwapSavedVariables {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
||||||
AutogradCompilerCall& compiler;
|
// AutogradCompilerCall& compiler;
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
||||||
TraceState& state;
|
std::shared_ptr<SwapInterface> state;
|
||||||
|
// TraceState& state;
|
||||||
// This is a borrowed reference, we do not increment ownership, or lower it,
|
// This is a borrowed reference, we do not increment ownership, or lower it,
|
||||||
// it's lifecycle is entirely longer than this objects.
|
// it's lifecycle is entirely longer than this objects.
|
||||||
PyObject* py_compiler;
|
PyObject* py_compiler;
|
||||||
|
@ -451,6 +451,37 @@ void set_ivalue_proxies(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static PyObject* call_capture(
|
||||||
|
PyObject* self,
|
||||||
|
CacheNode& cache,
|
||||||
|
AutogradCompilerCall& compiler_call,
|
||||||
|
size_t num_outputs,
|
||||||
|
PyObject* nodecalls) {
|
||||||
|
static PyObject* method_name = PyUnicode_InternFromString("capture");
|
||||||
|
THPObjectPtr pyinput(THPVariable_WrapList(compiler_call.tensor_args.inputs));
|
||||||
|
|
||||||
|
THPObjectPtr pysizeinput(cache.wrap_dynamic_inputs());
|
||||||
|
std::vector<std::optional<c10::SymInt>> dynamic_inputs =
|
||||||
|
cache.unwrap_dynamic_inputs(py::cast<py::list>(pysizeinput.get()).ptr());
|
||||||
|
|
||||||
|
THPObjectPtr pyivalueargsinput(
|
||||||
|
wrap_lifted_ivalue_args(compiler_call.lifted_ivalue_args.args));
|
||||||
|
THPObjectPtr pynodeorigins(
|
||||||
|
wrap_node_origins(compiler_call, PyTuple_GET_SIZE(pysizeinput.get())));
|
||||||
|
PyObject* py_num_outputs = THPUtils_packUInt32(num_outputs);
|
||||||
|
return check(PyObject_CallMethodObjArgs(
|
||||||
|
self,
|
||||||
|
method_name,
|
||||||
|
pyinput.get(),
|
||||||
|
// TODO(rzou): is this leaking memory?
|
||||||
|
py::cast(dynamic_inputs).ptr(),
|
||||||
|
pyivalueargsinput.get(),
|
||||||
|
pynodeorigins.get(),
|
||||||
|
nodecalls,
|
||||||
|
py_num_outputs,
|
||||||
|
nullptr));
|
||||||
|
}
|
||||||
|
|
||||||
static TraceState call_begin_capture(
|
static TraceState call_begin_capture(
|
||||||
PyObject* self,
|
PyObject* self,
|
||||||
CacheNode& cache,
|
CacheNode& cache,
|
||||||
@ -552,7 +583,9 @@ CacheNode* _compiled_autograd_impl(
|
|||||||
compiler_call.set_active_node_call_idx(i);
|
compiler_call.set_active_node_call_idx(i);
|
||||||
}
|
}
|
||||||
if (node_args.cond(call.needed)) {
|
if (node_args.cond(call.needed)) {
|
||||||
|
call.mode = CollectionMode::COMPILED_ARGS;
|
||||||
fn->compiled_args(node_args);
|
fn->compiled_args(node_args);
|
||||||
|
call.mode = CollectionMode::NEXT_EDGES;
|
||||||
node_args.collect(call.node->next_edges());
|
node_args.collect(call.node->next_edges());
|
||||||
}
|
}
|
||||||
CacheKey key = node_args.key();
|
CacheKey key = node_args.key();
|
||||||
@ -600,112 +633,15 @@ CacheNode* _compiled_autograd_impl(
|
|||||||
ClosingTHPObjectPtr py_compiler(
|
ClosingTHPObjectPtr py_compiler(
|
||||||
check(PyObject_CallNoArgs((the_autograd_compiler))));
|
check(PyObject_CallNoArgs((the_autograd_compiler))));
|
||||||
|
|
||||||
TraceState state = call_begin_capture(
|
// nodes
|
||||||
py_compiler, *cache, compiler_call, output_edges.size());
|
py::object nodecalls = py::cast(calls);
|
||||||
InputBuffers input_buffers;
|
PyObject* res = call_capture(
|
||||||
|
py_compiler,
|
||||||
|
*cache,
|
||||||
|
compiler_call,
|
||||||
|
output_edges.size(),
|
||||||
|
nodecalls.ptr());
|
||||||
|
|
||||||
for (size_t i = 0; i < calls.size(); i++) {
|
|
||||||
NodeCall& call = *calls[i];
|
|
||||||
// TODO(jansel): consider adding some of this stuff:
|
|
||||||
// guard(local_graph_task); NodeGuard ndguard(task.fn_); const auto
|
|
||||||
// opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
|
|
||||||
// c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
|
|
||||||
// CheckpointValidGuard cpvguard(graph_task);
|
|
||||||
// at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION);
|
|
||||||
// if (C10_UNLIKELY(step_callbacks.has_value())) { ... }
|
|
||||||
|
|
||||||
variable_list inputs =
|
|
||||||
std::move(input_buffers.lookup(call.node.get()).buffer);
|
|
||||||
input_buffers.erase(call.node.get());
|
|
||||||
|
|
||||||
if (!call.tensor_pre_hooks.empty()) {
|
|
||||||
THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
|
|
||||||
for (const auto& hook : call.tensor_pre_hooks) {
|
|
||||||
pyinputs = check(PyObject_CallMethod(
|
|
||||||
py_compiler,
|
|
||||||
"tensor_pre_hook",
|
|
||||||
"Oii",
|
|
||||||
pyinputs.get(),
|
|
||||||
hook.first,
|
|
||||||
hook.second));
|
|
||||||
}
|
|
||||||
inputs = THPVariable_UnpackList(pyinputs);
|
|
||||||
}
|
|
||||||
for (const auto& graph_output : call.graph_output) {
|
|
||||||
int input_nr = graph_output.first;
|
|
||||||
int output_index = graph_output.second;
|
|
||||||
TORCH_INTERNAL_ASSERT(
|
|
||||||
output_index < static_cast<int>(state.outputs.size()));
|
|
||||||
TORCH_INTERNAL_ASSERT(!state.outputs[output_index].defined());
|
|
||||||
state.outputs[output_index] = inputs[input_nr];
|
|
||||||
}
|
|
||||||
if (!call.needed) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (!call.pre_hooks.empty()) {
|
|
||||||
THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
|
|
||||||
for (const auto hook : call.pre_hooks) {
|
|
||||||
pyinputs = check(PyObject_CallMethod(
|
|
||||||
py_compiler.get(), "pre_hook", "Oi", pyinputs.get(), hook));
|
|
||||||
}
|
|
||||||
inputs = THPVariable_UnpackList(pyinputs);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string _node_name = call.node->name();
|
|
||||||
THPObjectPtr node_name(PyUnicode_FromString(_node_name.data()));
|
|
||||||
TORCH_INTERNAL_ASSERT(node_name != nullptr);
|
|
||||||
THPObjectPtr set_node_origin(
|
|
||||||
PyObject_GetAttrString(py_compiler.get(), "set_node_origin"));
|
|
||||||
|
|
||||||
PyObject* pyobj = Py_None;
|
|
||||||
if (auto pynode = std::dynamic_pointer_cast<PyNode>(call.node)) {
|
|
||||||
pyobj = pynode->obj;
|
|
||||||
}
|
|
||||||
|
|
||||||
check(PyObject_CallFunction(
|
|
||||||
set_node_origin, "OIO", node_name.get(), i, pyobj, nullptr));
|
|
||||||
|
|
||||||
SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call);
|
|
||||||
variable_list outputs = call.node->apply_with_saved(inputs, saved);
|
|
||||||
|
|
||||||
saved.debug_asserts();
|
|
||||||
saved.before(call.node->next_edges());
|
|
||||||
validate_outputs(
|
|
||||||
call.node->next_edges(), outputs, [&](const std::string& msg) {
|
|
||||||
std::ostringstream ss;
|
|
||||||
ss << "[Compiled Autograd Tracing: " << call.node->name() << "] "
|
|
||||||
<< msg;
|
|
||||||
return ss.str();
|
|
||||||
});
|
|
||||||
saved.after(call.node->next_edges());
|
|
||||||
saved.debug_asserts();
|
|
||||||
|
|
||||||
if (!call.post_hooks.empty()) {
|
|
||||||
THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
|
|
||||||
THPObjectPtr pyoutputs(THPVariable_WrapList(outputs));
|
|
||||||
for (const auto hook : call.post_hooks) {
|
|
||||||
pyoutputs = check(PyObject_CallMethod(
|
|
||||||
py_compiler.get(),
|
|
||||||
"post_hook",
|
|
||||||
"OOi",
|
|
||||||
pyoutputs.get(),
|
|
||||||
pyinputs.get(),
|
|
||||||
hook));
|
|
||||||
}
|
|
||||||
outputs = THPVariable_UnpackList(pyoutputs);
|
|
||||||
}
|
|
||||||
for (const auto i : c10::irange(outputs.size())) {
|
|
||||||
auto& output = outputs[i];
|
|
||||||
const auto& next = call.node->next_edge(i);
|
|
||||||
if (next.is_valid() && output.defined()) {
|
|
||||||
input_buffers.lookup(next.function.get())
|
|
||||||
.add(
|
|
||||||
next.input_nr, std::move(output), std::nullopt, std::nullopt);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
PyObject* res = check(call_end_capture(py_compiler, state.outputs));
|
|
||||||
TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple");
|
TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple");
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
PyTuple_Size(res) == 2,
|
PyTuple_Size(res) == 2,
|
||||||
@ -718,15 +654,25 @@ CacheNode* _compiled_autograd_impl(
|
|||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
PyCallable_Check(cache->compiled_fn),
|
PyCallable_Check(cache->compiled_fn),
|
||||||
"Expected end_capture to return compiled_fn");
|
"Expected end_capture to return compiled_fn");
|
||||||
state.debug_asserts();
|
// TODO(rzou): what is this?
|
||||||
|
// state.debug_asserts();
|
||||||
} // End cache miss region
|
} // End cache miss region
|
||||||
|
|
||||||
|
// TODO(rzou): need some mechanism to release the variables when we're ready.
|
||||||
// TODO(jansel): clear grads we will overwrite below
|
// TODO(jansel): clear grads we will overwrite below
|
||||||
if (!graph_task.keep_graph_) {
|
// if (!graph_task.keep_graph_) {
|
||||||
for (auto& call : calls) {
|
// for (auto& call : calls) {
|
||||||
call->node->release_variables();
|
// call->node->release_variables();
|
||||||
}
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// TODO(rzou): we probably shouldn't be copying the nodes in the hot path?
|
||||||
|
std::vector<NodeCall> persistent_node_calls;
|
||||||
|
for (NodeCall* call : calls) {
|
||||||
|
persistent_node_calls.push_back(*call);
|
||||||
}
|
}
|
||||||
|
auto ca = py::module::import("torch._dynamo.compiled_autograd");
|
||||||
|
ca.attr("set_global_nodecalls")(persistent_node_calls);
|
||||||
|
|
||||||
*graph_arg_inputs = THPVariable_WrapList(compiler_call.tensor_args.inputs);
|
*graph_arg_inputs = THPVariable_WrapList(compiler_call.tensor_args.inputs);
|
||||||
*graph_arg_sizes = wrap_int_list(compiler_call.dyn_size_inputs);
|
*graph_arg_sizes = wrap_int_list(compiler_call.dyn_size_inputs);
|
||||||
|
Reference in New Issue
Block a user