mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-31 04:04:57 +08:00
Compare commits
3 Commits
ciflow/tru
...
fca5
| Author | SHA1 | Date | |
|---|---|---|---|
| 8cf38caf9c | |||
| 06c9bbc70a | |||
| 94d8ba4625 |
@ -37,6 +37,16 @@ struct TORCH_API TensorGeometry {
|
||||
has_symbolic_sizes_strides_(
|
||||
t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {}
|
||||
|
||||
explicit TensorGeometry(
|
||||
std::vector<at::SymInt> sizes,
|
||||
std::vector<at::SymInt> strides,
|
||||
at::SymInt storage_offset)
|
||||
: sizes_(std::move(sizes)),
|
||||
strides_(std::move(strides)),
|
||||
storage_offset_(std::move(storage_offset)) {
|
||||
recompute();
|
||||
}
|
||||
|
||||
// true if the tensor is contiguous
|
||||
bool is_contiguous() const;
|
||||
|
||||
|
||||
@ -683,6 +683,8 @@ struct TORCH_API IValue final {
|
||||
c10::List<int64_t> toIntList() &&;
|
||||
c10::List<int64_t> toIntList() const&;
|
||||
std::vector<int64_t> toIntVector() const;
|
||||
c10::List<c10::SymInt> toSymIntList() &&;
|
||||
c10::List<c10::SymInt> toSymIntList() const&;
|
||||
std::vector<c10::SymInt> toSymIntVector() const;
|
||||
at::DimVector toDimVector() const;
|
||||
|
||||
|
||||
@ -1734,6 +1734,7 @@ DEFINE_TO(c10::intrusive_ptr<ivalue::ConstantString>, toString)
|
||||
DEFINE_TO(c10::intrusive_ptr<ivalue::Object>, toObject)
|
||||
DEFINE_TO(at::Scalar, toScalar)
|
||||
DEFINE_TO(c10::List<int64_t>, toIntList)
|
||||
DEFINE_TO(c10::List<c10::SymInt>, toSymIntList)
|
||||
DEFINE_TO(c10::List<double>, toDoubleList)
|
||||
DEFINE_TO(c10::List<c10::complex<double>>, toComplexDoubleList)
|
||||
DEFINE_TO(c10::List<bool>, toBoolList)
|
||||
@ -1990,6 +1991,20 @@ inline std::vector<int64_t> IValue::toIntVector() const {
|
||||
return createVectorFromList<int64_t>(
|
||||
static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
|
||||
}
|
||||
inline c10::List<c10::SymInt> IValue::toSymIntList() && {
|
||||
AT_ASSERT(
|
||||
isSymIntList() || isIntList(),
|
||||
"Expected SymIntList or IntList but got ",
|
||||
tagKind());
|
||||
return c10::List<c10::SymInt>(moveToIntrusivePtr<c10::detail::ListImpl>());
|
||||
}
|
||||
inline c10::List<c10::SymInt> IValue::toSymIntList() const& {
|
||||
AT_ASSERT(
|
||||
isSymIntList() || isIntList(),
|
||||
"Expected SymIntList or IntList but got ",
|
||||
tagKind());
|
||||
return c10::List<c10::SymInt>(toIntrusivePtr<c10::detail::ListImpl>());
|
||||
}
|
||||
inline std::vector<c10::SymInt> IValue::toSymIntVector() const {
|
||||
AT_ASSERT(isSymIntList() || isIntList(), "Expected SymIntList or IntList but got ", tagKind());
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
|
||||
@ -609,6 +609,33 @@ TEST(IValueTest, isAliasOf) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(IValueTest, toSymIntList) {
|
||||
std::vector<int64_t> int_list = {2, 3};
|
||||
auto iv = IValue(int_list);
|
||||
auto result = iv.toSymIntList();
|
||||
EXPECT_EQ(result.size(), 2);
|
||||
EXPECT_EQ(result.get(0), 2);
|
||||
EXPECT_EQ(result.get(1), 3);
|
||||
}
|
||||
|
||||
TEST(IValueTest, toSymIntListTemplate) {
|
||||
std::vector<int64_t> int_list = {2, 3};
|
||||
auto iv = IValue(int_list);
|
||||
auto result = iv.to<c10::List<c10::SymInt>>();
|
||||
EXPECT_EQ(result.size(), 2);
|
||||
EXPECT_EQ(result.get(0), 2);
|
||||
EXPECT_EQ(result.get(1), 3);
|
||||
}
|
||||
|
||||
TEST(IValueTest, toSymIntVector) {
|
||||
std::vector<int64_t> int_list = {2, 3};
|
||||
auto iv = IValue(int_list);
|
||||
auto result = iv.to<std::vector<c10::SymInt>>();
|
||||
EXPECT_EQ(result.size(), 2);
|
||||
EXPECT_EQ(result[0], 2);
|
||||
EXPECT_EQ(result[1], 3);
|
||||
}
|
||||
|
||||
TEST(IValueTest, internalToPointer) {
|
||||
IValue tensor(at::rand({3, 4}));
|
||||
IValue str("hello");
|
||||
|
||||
@ -476,6 +476,7 @@ inductor_core_resources = [
|
||||
"torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp",
|
||||
"torch/csrc/inductor/inductor_ops.cpp",
|
||||
"torch/csrc/jit/serialization/pickle.cpp",
|
||||
"torch/csrc/dynamo/compiled_autograd.cpp",
|
||||
]
|
||||
|
||||
libtorch_core_sources = sorted(
|
||||
|
||||
@ -2908,7 +2908,8 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
|
||||
"aot0_le",
|
||||
"aot0_permute_2",
|
||||
"code: CompiledFunctionBackward0 (NodeCall 2)",
|
||||
"aot0_tangents_1",
|
||||
# TODO(rzou): what is going on here?
|
||||
# "aot0_tangents_1",
|
||||
"aot0_full_default",
|
||||
"aot0_where",
|
||||
"aot0_mm",
|
||||
@ -3516,6 +3517,7 @@ known_failing_tests = {
|
||||
"test_autograd_node_isinstance", # backward ctx is a fake cls and not directly a Node instance
|
||||
"test_backward_hook_relative_ordering", # compiled autograd collects breadth first, and module backward hook not supported
|
||||
# Uncategorized
|
||||
"test_not_implemented_grad", # Dynamo changes the types of exceptions
|
||||
}
|
||||
|
||||
if not HAS_CUDA:
|
||||
|
||||
@ -7,7 +7,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Sequence
|
||||
|
||||
from torchgen.api.autograd import (
|
||||
Derivative,
|
||||
@ -47,10 +47,6 @@ from torchgen.utils import FileManager
|
||||
from .gen_inplace_or_view_type import VIEW_FUNCTIONS
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
FUNCTION_DECLARATION = CodeTemplate(
|
||||
"""\
|
||||
#ifdef _WIN32
|
||||
@ -68,6 +64,7 @@ struct TORCH_API ${op} : public ${superclass} {
|
||||
}
|
||||
${will_release_variables}
|
||||
void compiled_args(CompiledNodeArgs& args) override;
|
||||
ivalue_list get_state();
|
||||
variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override;
|
||||
${saved_variables}
|
||||
${saved_list_sizes}
|
||||
@ -99,7 +96,7 @@ FUNCTION_DEFINITION = CodeTemplate(
|
||||
"""\
|
||||
static variable_list ${op}_apply_functional(
|
||||
variable_list&& grads,
|
||||
std::array<bool,${num_vars}> needs_input_grad${,unpacked_saved_vars_signature})
|
||||
std::array<bool,${num_inputs}> needs_input_grad${,apply_functional_args_signature})
|
||||
{
|
||||
IndexRangeGenerator gen;
|
||||
${compute_index_ranges}
|
||||
@ -107,24 +104,58 @@ static variable_list ${op}_apply_functional(
|
||||
${body}
|
||||
return grad_inputs;
|
||||
}
|
||||
static variable_list ${op}_apply_functional_ivalue(const variable_list& grads, const ivalue_list& stack)
|
||||
{
|
||||
auto state = SavedState(stack);
|
||||
auto needs_input_grad = state.unpack<std::array<bool, ${num_inputs}>>();
|
||||
${saved_var_dequeues}
|
||||
return ${op}_apply_functional(variable_list(grads), needs_input_grad${,apply_functional_args});
|
||||
}
|
||||
|
||||
variable_list ${op}::apply(variable_list&& grads) {
|
||||
${thread_lock}
|
||||
${asserts}
|
||||
${unpacks}
|
||||
${compute_needs_input_grad}
|
||||
return ${op}_apply_functional(std::move(grads), needs_input_grad${,unpacked_saved_vars});
|
||||
return ${op}_apply_functional(std::move(grads), needs_input_grad${,apply_functional_args});
|
||||
}
|
||||
|
||||
void ${op}::compiled_args(CompiledNodeArgs& args) {
|
||||
${compiled_args}
|
||||
}
|
||||
variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) {
|
||||
${apply_with_saved_before}
|
||||
variable_list result = apply(variable_list(grads));
|
||||
${apply_with_saved_after}
|
||||
return result;
|
||||
${apply_with_saved_before}
|
||||
|
||||
variable_list result;
|
||||
auto state = get_state();
|
||||
${compute_schema}
|
||||
const auto& interface = torch::dynamo::autograd::getPyCompilerInterface();
|
||||
result = interface->call_function(
|
||||
saved.get_py_compiler(),
|
||||
"apply_functional",
|
||||
${op}_apply_functional_ivalue,
|
||||
grads,
|
||||
state,
|
||||
num_outputs(),
|
||||
name(),
|
||||
schema,
|
||||
/*builtin*/true);
|
||||
|
||||
|
||||
${apply_with_saved_after}
|
||||
return result;
|
||||
}
|
||||
ivalue_list ${op}::get_state() {
|
||||
SavedState saved_state;
|
||||
${unpacks}
|
||||
${compute_needs_input_grad}
|
||||
saved_state.pack(needs_input_grad);
|
||||
${get_state}
|
||||
std::vector<std::optional<InputMetadata>> input_metadata = collect_input_metadata(next_edges());
|
||||
saved_state.pack(input_metadata);
|
||||
return saved_state.stack;
|
||||
}
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
@ -587,24 +618,27 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
|
||||
compiled_args: list[str] = []
|
||||
apply_with_saved_before: list[str] = []
|
||||
apply_with_saved_after: list[str] = []
|
||||
unpacked_saved_vars: list[str] = []
|
||||
unpacked_saved_vars_ref_type: list[str] = []
|
||||
# Maps var_name to a unique index. The var_name is the
|
||||
# name of an input to the operator that needs a gradient (like "self", "other").
|
||||
# The index is the order in which they appear. We use this mapping
|
||||
# to populate needs_input_grad in some order and then grab values from it.
|
||||
var_name_map: dict[str, int] = {}
|
||||
apply_functional_args: list[str] = []
|
||||
apply_functional_args_ref_types: list[str] = []
|
||||
# Maps the name of an input (to the original forward operator;
|
||||
# examples are "self", "other") to the order in which they appear in the
|
||||
# operator.
|
||||
# For example; if the operator is foo(Tensor self, int64_t k, Tensor other),
|
||||
# the mapping is: {"self": 0, "other": 1}.
|
||||
# We use this mapping to populate needs_input_grad in some order and then grab
|
||||
# values from it.
|
||||
input_name_to_idx: dict[str, int] = {}
|
||||
|
||||
for idx, arg in enumerate(info.args_with_derivatives):
|
||||
if arg.type in TENSOR_LIST_LIKE_CTYPES:
|
||||
size = f"{arg.name}_size_"
|
||||
saved_list_sizes.append(f"size_t {arg.name}_size_;")
|
||||
unpacked_saved_vars.append(f"{arg.name}_size_")
|
||||
unpacked_saved_vars_ref_type.append("size_t")
|
||||
apply_functional_args.append(f"{arg.name}_size_")
|
||||
apply_functional_args_ref_types.append("size_t")
|
||||
else:
|
||||
size = "1"
|
||||
compute_index_ranges.append(f"auto {arg.name}_ix = gen.range({size});")
|
||||
var_name_map[arg.name] = idx
|
||||
input_name_to_idx[arg.name] = idx
|
||||
|
||||
def save_var(var: SavedAttribute, is_output: bool) -> None:
|
||||
name = var.nctype.name
|
||||
@ -856,8 +890,8 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
|
||||
|
||||
if unpacked_ref_type is None:
|
||||
unpacked_ref_type = f"{saved_variables[-1].split(' ')[0]}&"
|
||||
unpacked_saved_vars.append(str(name))
|
||||
unpacked_saved_vars_ref_type.append(unpacked_ref_type)
|
||||
apply_functional_args.append(str(name))
|
||||
apply_functional_args_ref_types.append(unpacked_ref_type)
|
||||
|
||||
for var in sorted(info.all_saved_inputs, key=lambda sa: str(sa.nctype.name)):
|
||||
save_var(var, is_output=False)
|
||||
@ -872,8 +906,8 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
|
||||
thread_lock = ""
|
||||
|
||||
if uses_retain_variables(info):
|
||||
unpacked_saved_vars.append("retain_variables")
|
||||
unpacked_saved_vars_ref_type.append("bool")
|
||||
apply_functional_args.append("retain_variables")
|
||||
apply_functional_args_ref_types.append("bool")
|
||||
will_release_variables = WILL_RELEASE_VARIABLES.substitute()
|
||||
else:
|
||||
will_release_variables = ""
|
||||
@ -919,14 +953,15 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
|
||||
derivative_template.substitute(
|
||||
name=var_names[0],
|
||||
derivative=formula,
|
||||
idx=var_name_map[var_names[0]],
|
||||
idx=input_name_to_idx[var_names[0]],
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
if "grad_input_mask" in formula:
|
||||
masks = [
|
||||
f"needs_input_grad[{var_name_map[name]}]," for name in var_names
|
||||
f"needs_input_grad[{input_name_to_idx[name]}],"
|
||||
for name in var_names
|
||||
]
|
||||
grad_input_mask = GRAD_INPUT_MASK.substitute(
|
||||
n=len(var_names), masks=masks
|
||||
@ -934,14 +969,14 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
|
||||
else:
|
||||
grad_input_mask = ""
|
||||
needs_input_grad = [
|
||||
f"needs_input_grad[{var_name_map[name]}]" for name in var_names
|
||||
f"needs_input_grad[{input_name_to_idx[name]}]" for name in var_names
|
||||
]
|
||||
needs_input_grad = " || ".join(needs_input_grad)
|
||||
copy_ranges: list[str] = []
|
||||
for i, n in enumerate(var_names):
|
||||
copy_ranges.append(
|
||||
DERIVATIVE_MULTI_COPY_RANGE.substitute(
|
||||
name=n, i=i, idx=var_name_map[n]
|
||||
name=n, i=i, idx=input_name_to_idx[n]
|
||||
)
|
||||
)
|
||||
return False, DERIVATIVE_MULTI.substitute(
|
||||
@ -961,7 +996,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
|
||||
body.append(derivative_text)
|
||||
need_any_grad_defined_var |= checks_any_grad_defined
|
||||
|
||||
for name in var_name_map:
|
||||
for name in input_name_to_idx:
|
||||
masks.append(f"task_should_compute_output({{ {name}_ix }}),")
|
||||
|
||||
# Since single-output derivative formulas need to check if grads are
|
||||
@ -985,17 +1020,45 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
|
||||
compute_needs_input_grad = COMPUTE_NEEDS_INPUT_GRAD.substitute(
|
||||
n=len(masks), compute_index_ranges=compute_index_ranges, masks=masks
|
||||
)
|
||||
unpacked_saved_vars_signature = [
|
||||
f"{T} {x}" for T, x in zip(unpacked_saved_vars_ref_type, unpacked_saved_vars)
|
||||
apply_functional_args_signature = [
|
||||
f"{T} {x}"
|
||||
for T, x in zip(apply_functional_args_ref_types, apply_functional_args)
|
||||
]
|
||||
get_state = "\n".join(
|
||||
f"saved_state.pack({name});" for name in apply_functional_args
|
||||
)
|
||||
saved_var_dequeues = []
|
||||
for typ, name in zip(apply_functional_args_ref_types, apply_functional_args):
|
||||
if typ.endswith("&"):
|
||||
typ = typ[:-1]
|
||||
saved_var_dequeues.append(f"auto {name} = state.unpack<{typ}>();")
|
||||
|
||||
schema_args = [f"std::array<bool, {len(input_name_to_idx)}>"]
|
||||
for typ in apply_functional_args_ref_types:
|
||||
if typ.endswith("&"):
|
||||
typ = typ[:-1]
|
||||
if typ.startswith("const"):
|
||||
typ = typ[5:]
|
||||
schema_args.append(typ.strip())
|
||||
compute_schema = ["std::vector<at::TypePtr> schema = {"]
|
||||
for arg in schema_args:
|
||||
compute_schema.append(
|
||||
f" torch::dynamo::autograd::IValuePacker<{arg}>::packed_type(),"
|
||||
)
|
||||
# compute_schema.append(
|
||||
# f" torch::dynamo::autograd::IValuePacker<std::vector<std::optional<InputMetadata>>>::packed_type()"
|
||||
# )
|
||||
compute_schema.append("};")
|
||||
|
||||
return template.substitute(
|
||||
unpacks="\n".join(unpack),
|
||||
op=info.op,
|
||||
unpacked_saved_vars=unpacked_saved_vars,
|
||||
unpacked_saved_vars_signature=unpacked_saved_vars_signature,
|
||||
compute_schema="\n".join(compute_schema),
|
||||
apply_functional_args=apply_functional_args,
|
||||
apply_functional_args_signature=apply_functional_args_signature,
|
||||
compute_needs_input_grad=compute_needs_input_grad,
|
||||
num_vars=len(var_name_map),
|
||||
num_inputs=len(input_name_to_idx),
|
||||
saved_var_dequeues="\n".join(saved_var_dequeues),
|
||||
compute_index_ranges=compute_index_ranges,
|
||||
saved_variables=saved_variables,
|
||||
release_variables=release_variables,
|
||||
@ -1010,4 +1073,5 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
|
||||
compiled_args=compiled_args,
|
||||
apply_with_saved_before=apply_with_saved_before,
|
||||
apply_with_saved_after=apply_with_saved_after,
|
||||
get_state=get_state,
|
||||
)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
#include "torch/csrc/autograd/FunctionsManual.h"
|
||||
#include "torch/csrc/autograd/engine.h"
|
||||
#include "torch/csrc/dynamo/compiled_autograd.h"
|
||||
|
||||
// ${generated_comment}
|
||||
|
||||
@ -5,6 +5,7 @@ import operator
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._dynamo.external_utils import (
|
||||
call_backward,
|
||||
call_hook,
|
||||
@ -56,6 +57,46 @@ def maybe_clone(x):
|
||||
return x
|
||||
|
||||
|
||||
class OpNamespace:
|
||||
def __init__(self):
|
||||
self.next_id = {}
|
||||
|
||||
def add(self, base_name, fn, builtin):
|
||||
if builtin and hasattr(self, base_name):
|
||||
return getattr(self, base_name)
|
||||
|
||||
name = base_name
|
||||
if not builtin:
|
||||
if base_name not in self.next_id:
|
||||
self.next_id[base_name] = 0
|
||||
nid = self.next_id[base_name]
|
||||
name = f"{base_name}_{nid}"
|
||||
self.next_id[base_name] += 1
|
||||
result = Op(name, fn)
|
||||
torch._dynamo.allow_in_graph(result)
|
||||
setattr(self, name, result)
|
||||
return result
|
||||
|
||||
|
||||
class Op:
|
||||
def __init__(self, name, fn):
|
||||
self.fn = fn
|
||||
self.__name__ = name
|
||||
self.__module__ = "torch._dynamo.compiled_autograd.ops"
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.fn(*args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__module__ + "." + self.__name__
|
||||
|
||||
def __str__(self):
|
||||
return self.__module__ + "." + self.__name__
|
||||
|
||||
|
||||
ops = OpNamespace()
|
||||
|
||||
|
||||
_graph_placeholders = ["inputs", "sizes", "scalars", "hooks"]
|
||||
_impure_targets = OrderedSet(
|
||||
[
|
||||
@ -103,7 +144,8 @@ class AutogradCompilerInstance:
|
||||
self.fx_tracer.root = torch.nn.Module()
|
||||
self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
|
||||
self.fx_tracer.tensor_attrs = {}
|
||||
args_proxy, sizes_proxy, scalars_proxy, self.hooks_proxy = (
|
||||
self.symnode_proxy_lookup = {}
|
||||
args_proxy, self.sizes_proxy, self.scalars_proxy, self.hooks_proxy = (
|
||||
self.fx_tracer.create_proxy("placeholder", name, (), {})
|
||||
for name in _graph_placeholders
|
||||
)
|
||||
@ -126,7 +168,9 @@ class AutogradCompilerInstance:
|
||||
)
|
||||
for idx, val in enumerate(sizes)
|
||||
]
|
||||
self.bind_tensors_to_proxies(sizes, sizes_proxy, sizes_origins)
|
||||
self.bind_tensors_to_proxies(sizes, self.sizes_proxy, sizes_origins)
|
||||
for i, symint in enumerate(sizes):
|
||||
self.symnode_proxy_lookup[id(symint.node)] = self.sizes_proxy[i]
|
||||
|
||||
for idx, val in enumerate(scalars):
|
||||
source = self.source("scalars", idx)
|
||||
@ -148,7 +192,9 @@ class AutogradCompilerInstance:
|
||||
)
|
||||
else:
|
||||
raise AssertionError("Unexpected scalar type: ", type(val))
|
||||
self.bind_tensors_to_proxies(scalars, scalars_proxy, scalars_origins)
|
||||
self.bind_tensors_to_proxies(scalars, self.scalars_proxy, scalars_origins)
|
||||
for i, symval in enumerate(scalars):
|
||||
self.symnode_proxy_lookup[id(symval.node)] = self.scalars_proxy[i] # type: ignore[union-attr]
|
||||
|
||||
# TODO(jansel): are all these modes needed?
|
||||
self.stack.enter_context(decompose({}))
|
||||
@ -170,7 +216,6 @@ class AutogradCompilerInstance:
|
||||
saved_tensors,
|
||||
backward_idx: int,
|
||||
):
|
||||
assert self.hooks_proxy is not None
|
||||
backward_c_function = self.hooks_proxy[backward_idx] # type: ignore[index]
|
||||
proxies = self.fx_tracer.create_proxy(
|
||||
kind="call_function",
|
||||
@ -182,7 +227,6 @@ class AutogradCompilerInstance:
|
||||
),
|
||||
kwargs={},
|
||||
)
|
||||
|
||||
with disable_proxy_modes_tracing():
|
||||
# create fake Tensors
|
||||
grad_ins: List[Optional[torch.Tensor]] = []
|
||||
@ -198,6 +242,72 @@ class AutogradCompilerInstance:
|
||||
self.bind_tensors_to_proxies(grad_ins, proxies)
|
||||
return tuple(grad_ins)
|
||||
|
||||
def allocate_dummy(self):
|
||||
with disable_proxy_modes_tracing():
|
||||
return torch.zeros(0)
|
||||
|
||||
def apply_functional(self, fn, inputs, stack, num_outputs, debug_name, builtin):
|
||||
input_metadata = stack[-1]
|
||||
other_stuff = stack[:-1]
|
||||
proxy_inputs, proxy_stack = pytree.tree_map(
|
||||
lambda e: self.to_proxy(e),
|
||||
(inputs, other_stuff),
|
||||
)
|
||||
op = ops.add(debug_name, fn, builtin)
|
||||
proxy_out = self.fx_tracer.create_proxy(
|
||||
"call_function", op, args=(proxy_inputs, *proxy_stack), kwargs={}
|
||||
)
|
||||
result = [self.zeros_like(input_metadata[i]) for i in range(num_outputs)]
|
||||
# result = [self.allocate_dummy() for i in range(num_outputs)]
|
||||
self.bind_tensors_to_proxies(result, [proxy_out[i] for i in range(num_outputs)])
|
||||
return result
|
||||
|
||||
def proxy_call(self, fn, args, num_outputs):
|
||||
flat_args, _ = pytree.tree_flatten(args)
|
||||
proxy_args = pytree.tree_map(lambda e: self.to_proxy(e), args)
|
||||
proxy_out = self.fx_tracer.create_proxy(
|
||||
"call_function", fn, args=proxy_args, kwargs={}
|
||||
)
|
||||
result = [self.allocate_dummy(*flat_args) for _ in range(num_outputs)]
|
||||
self.bind_tensors_to_proxies(result, [proxy_out[i] for i in range(num_outputs)])
|
||||
return result
|
||||
|
||||
def zeros_like(self, input_metadata):
|
||||
if input_metadata is None:
|
||||
return None
|
||||
if not isinstance(input_metadata, tuple):
|
||||
breakpoint()
|
||||
tensoroptions, shape, _ = input_metadata
|
||||
kwargs = {}
|
||||
names = ["requires_grad", "memory_format", "device", "dtype", "layout", "pinned_memory"]
|
||||
for name, option in zip(names, tensoroptions):
|
||||
if option is not None:
|
||||
kwargs[name] = option
|
||||
|
||||
with disable_proxy_modes_tracing():
|
||||
return torch.ops.aten.zeros(shape, **kwargs)
|
||||
|
||||
def validate_outputs(self, fn, outputs, stack, _0, _1, _2):
|
||||
proxy_outputs, proxy_stack = pytree.tree_map(
|
||||
lambda e: self.to_proxy(e),
|
||||
(outputs, stack),
|
||||
)
|
||||
op = ops.add("validate_outputs", fn, True)
|
||||
new_proxy_outputs = self.fx_tracer.create_proxy(
|
||||
"call_function", op, args=(proxy_outputs, *proxy_stack), kwargs={}
|
||||
)
|
||||
|
||||
# Guess what the outputs should be from the InputMetadata.
|
||||
# This is not sound in general (we guess contiguous strides
|
||||
# and no Tensor subclass-ness); we will stop guessing
|
||||
# the output metadata in a follow-up.
|
||||
input_metadatas = stack[0]
|
||||
assert len(input_metadatas) == len(outputs)
|
||||
outputs = [None if output is None or metadata is None else self.zeros_like(metadata) for output, metadata in zip(outputs, input_metadatas)]
|
||||
|
||||
self.bind_tensors_to_proxies(outputs, new_proxy_outputs)
|
||||
return outputs
|
||||
|
||||
def proxy_call_hook(self, hook, *args, **kwargs):
|
||||
return self.fx_tracer.create_proxy(
|
||||
"call_function",
|
||||
@ -280,6 +390,7 @@ class AutogradCompilerInstance:
|
||||
assert nodes[first_getitem_idx] == inputs_users[0]
|
||||
last_getitem_idx = first_getitem_idx + len(inputs_users) - 1
|
||||
assert nodes[last_getitem_idx] == inputs_users[-1]
|
||||
# getitem nodes on inputs
|
||||
for i, node in enumerate(inputs_users):
|
||||
if not has_cuda_inputs and node.meta["val"].device.type == "cuda":
|
||||
has_cuda_inputs = True
|
||||
@ -289,18 +400,20 @@ class AutogradCompilerInstance:
|
||||
is_scalar = len(node.meta["val"].size()) == 0
|
||||
if is_cpu and is_scalar:
|
||||
node_users = list(node.users.keys())
|
||||
# We can only move the cpu scalar if it is not exposed to user code.
|
||||
# The only possible user code using the Op class is custom C++ autograd functions and C++ nodes.
|
||||
if all(
|
||||
isinstance(user.target, torch._ops.OpOverload)
|
||||
and user.target.namespace in ("prims", "aten")
|
||||
isinstance(user.target, torch._dynamo.compiled_autograd.Op)
|
||||
and "CppFunction" not in user.target.__name__
|
||||
for user in node_users
|
||||
):
|
||||
# all users are prims/aten, can move safely
|
||||
to_move[i] = node
|
||||
|
||||
# only move cpu scalars to cuda if there were cuda activations in this graph,
|
||||
# this is to handle the case where cudagraphs is enabled on a cpu-only graph
|
||||
if has_cuda_inputs:
|
||||
for node in to_move.values():
|
||||
verbose_log.debug("Moving node %s from cpu to cuda", node)
|
||||
node.meta["val"] = node.meta["val"].cuda()
|
||||
|
||||
# return runtime indices we need to move to cuda
|
||||
@ -334,7 +447,10 @@ class AutogradCompilerInstance:
|
||||
or (node.op == "call_function" and node.target in _impure_targets)
|
||||
)
|
||||
|
||||
before = len(list(self.fx_tracer.graph.nodes))
|
||||
self.fx_tracer.graph.eliminate_dead_code(is_impure)
|
||||
after = len(list(self.fx_tracer.graph.nodes))
|
||||
verbose_log.debug("DCE removed %d nodes", before - after)
|
||||
|
||||
def end_capture(self, outputs):
|
||||
self.fx_tracer.create_proxy(
|
||||
@ -350,6 +466,10 @@ class AutogradCompilerInstance:
|
||||
(self.fx_tracer.create_arg(self.to_proxy(outputs)),),
|
||||
{},
|
||||
)
|
||||
runtime_inputs_to_move: List[int] = []
|
||||
if snapshot_cudagraph_enabled():
|
||||
runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
|
||||
# TODO: remove the graph node's dummy metadata
|
||||
self.rename_aot_dispatcher_nodes()
|
||||
self.reorder_tensor_pre_hook_nodes()
|
||||
self.reorder_pre_hook_nodes_to_schedule_asap()
|
||||
@ -368,9 +488,6 @@ class AutogradCompilerInstance:
|
||||
# Proper fix is Richard's Python compiled autograd effort which will avoid calling make_fx and
|
||||
# should prevent these ops from going into the CA graph.
|
||||
self.dce()
|
||||
runtime_inputs_to_move: List[int] = []
|
||||
if snapshot_cudagraph_enabled():
|
||||
runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
|
||||
|
||||
graph = GraphModule(
|
||||
self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd"
|
||||
@ -728,8 +845,10 @@ class AutogradCompilerInstance:
|
||||
return [self.to_proxy(x) for x in t]
|
||||
if isinstance(t, tuple):
|
||||
return tuple(self.to_proxy(x) for x in t)
|
||||
# can it be torch.SymInt as the code used to imply?
|
||||
assert isinstance(t, torch.Tensor)
|
||||
if isinstance(t, (torch.SymInt, torch.SymFloat)):
|
||||
return self.symnode_proxy_lookup[id(t.node)]
|
||||
if not isinstance(t, torch.Tensor):
|
||||
return t
|
||||
proxy_tensor = fetch_object_proxy(self.fx_tracer, t)
|
||||
assert isinstance(proxy_tensor, torch.fx.experimental.proxy_tensor._ProxyTensor)
|
||||
return proxy_tensor.proxy
|
||||
|
||||
@ -262,6 +262,7 @@ auto ReadyQueue::pop() -> NodeTask {
|
||||
// Lock mutex for accesses to heap_
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
not_empty_.wait(lock, [this] { return !heap_.empty(); });
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
auto task = std::move(const_cast<NodeTask&>(heap_.top()));
|
||||
heap_.pop();
|
||||
return task;
|
||||
@ -734,14 +735,14 @@ void GraphTask::exec_post_processing() {
|
||||
// the stashed streams should be enough. If leaf_stream.device_index()
|
||||
// happens to be for a new device, operator* on the std::nullopt should
|
||||
// throw an error.
|
||||
const auto& caller_current_stream =
|
||||
caller_current_streams_[leaf_stream.device_index()];
|
||||
const auto caller_current_stream =
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
*caller_current_streams_[leaf_stream.device_index()];
|
||||
|
||||
if (caller_current_stream.has_value() &&
|
||||
caller_current_stream != leaf_stream) {
|
||||
if (caller_current_stream != leaf_stream) {
|
||||
auto event = c10::Event{leaf_stream.device_type()};
|
||||
event.record(leaf_stream);
|
||||
caller_current_stream->wait(event);
|
||||
caller_current_stream.wait(event);
|
||||
}
|
||||
}
|
||||
|
||||
@ -874,7 +875,6 @@ const InputMetadata& get_input_metadata(const T& thing);
|
||||
template <>
|
||||
const InputMetadata& get_input_metadata<c10::optional<InputMetadata>>(
|
||||
const c10::optional<InputMetadata>& thing) {
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
return thing.value();
|
||||
}
|
||||
|
||||
@ -898,6 +898,19 @@ bool has_input_metadata<Edge>(const Edge& thing) {
|
||||
return thing.is_valid();
|
||||
}
|
||||
|
||||
std::vector<c10::optional<InputMetadata>> collect_input_metadata(
|
||||
const edge_list& edges) {
|
||||
std::vector<c10::optional<InputMetadata>> input_metadata;
|
||||
for (const auto& edge : edges) {
|
||||
if (!edge.is_valid()) {
|
||||
input_metadata.emplace_back(c10::nullopt);
|
||||
continue;
|
||||
}
|
||||
input_metadata.emplace_back(edge.function->input_metadata(edge.input_nr));
|
||||
}
|
||||
return input_metadata;
|
||||
}
|
||||
|
||||
// Given an vector<Edge> or vector<optional<InputMetdata>>, validate the
|
||||
// outputs. This involves using the InputMetadata to check the outputs and also
|
||||
// potentially calling .sum_to on the outputs.
|
||||
@ -913,9 +926,12 @@ void validate_outputs_impl(
|
||||
TORCH_CHECK(false, format_error(ss.str()));
|
||||
}
|
||||
for (const auto i : c10::irange(grads.size())) {
|
||||
if (!has_input_metadata(input_metadata_container[i])) {
|
||||
// std::cout << "validate_outputs_impl: " << i << std::endl;
|
||||
if (!has_input_metadata(input_metadata_container.at(i))) {
|
||||
continue;
|
||||
}
|
||||
// std::cout << "validate_outputs_impl get_input_metadata: " << i <<
|
||||
// std::endl;
|
||||
const auto& metadata = get_input_metadata(input_metadata_container[i]);
|
||||
auto& grad = grads[i];
|
||||
if (!grad.defined()) {
|
||||
@ -1602,7 +1618,6 @@ void Engine::add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task) {
|
||||
// Remembers current streams on all devices where a context has been created for
|
||||
// This function assumes the accelerator device is available.
|
||||
void GraphTask::stash_current_streams() {
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
const auto accelerator = at::getAccelerator(true).value();
|
||||
const auto guard = c10::impl::VirtualGuardImpl{accelerator};
|
||||
auto num_devices = guard.deviceCount();
|
||||
|
||||
@ -47,6 +47,8 @@ TORCH_API void validate_outputs(
|
||||
const std::vector<c10::optional<InputMetadata>>& input_metadata,
|
||||
variable_list& grads,
|
||||
const std::function<std::string(const std::string&)>& format_error);
|
||||
TORCH_API std::vector<c10::optional<InputMetadata>> collect_input_metadata(
|
||||
const edge_list& edges);
|
||||
|
||||
struct NodeTask {
|
||||
std::weak_ptr<GraphTask> base_;
|
||||
|
||||
@ -34,8 +34,12 @@ using tensor_list = std::vector<at::Tensor>;
|
||||
using variable_list = std::vector<Variable>;
|
||||
using edge_list = std::vector<Edge>;
|
||||
using saved_variable_list = std::vector<SavedVariable>;
|
||||
using ivalue_list = std::vector<c10::IValue>;
|
||||
using functional_apply_t = std::function<
|
||||
variable_list(const variable_list&, const std::vector<c10::IValue>&)>;
|
||||
using IndexRange = std::pair<size_t, size_t>;
|
||||
using torch::dynamo::autograd::CompiledNodeArgs;
|
||||
using torch::dynamo::autograd::SavedState;
|
||||
using torch::dynamo::autograd::SwapSavedVariables;
|
||||
|
||||
// Custom deleter to prevent stack overflows.
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
namespace torch::dynamo::autograd {
|
||||
class CompiledNodeArgs;
|
||||
class SwapSavedVariables;
|
||||
struct SavedState;
|
||||
} // namespace torch::dynamo::autograd
|
||||
|
||||
// A hook that's called on gradients
|
||||
|
||||
22
torch/csrc/dynamo/compiled_autograd.cpp
Normal file
22
torch/csrc/dynamo/compiled_autograd.cpp
Normal file
@ -0,0 +1,22 @@
|
||||
#include <torch/csrc/dynamo/compiled_autograd.h>
|
||||
|
||||
namespace torch::dynamo::autograd {
|
||||
|
||||
std::unique_ptr<PyCompilerInterface> kPyCompilerInterface;
|
||||
|
||||
const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface() {
|
||||
TORCH_INTERNAL_ASSERT(kPyCompilerInterface != nullptr);
|
||||
return kPyCompilerInterface;
|
||||
}
|
||||
|
||||
void setPyCompilerInterface(std::unique_ptr<PyCompilerInterface>&& impl) {
|
||||
TORCH_INTERNAL_ASSERT(impl != nullptr);
|
||||
std::swap(kPyCompilerInterface, impl);
|
||||
TORCH_INTERNAL_ASSERT(kPyCompilerInterface != nullptr);
|
||||
}
|
||||
|
||||
void resetPyCompilerInterface() {
|
||||
kPyCompilerInterface.reset();
|
||||
}
|
||||
|
||||
} // namespace torch::dynamo::autograd
|
||||
@ -324,7 +324,6 @@ class CompiledNodeArgs {
|
||||
template <typename T>
|
||||
void collect(const std::optional<T>& t) {
|
||||
if (cond(t.has_value())) {
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
collect(*t);
|
||||
}
|
||||
}
|
||||
@ -900,6 +899,511 @@ class SwapSavedVariables {
|
||||
StashedVars<at::IValue> stashed_ivalues;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct dependent_false : std::false_type {};
|
||||
|
||||
// NOTE: [Compiled Autograd and backward functions]
|
||||
// Built-in autograd nodes have functional apply variants
|
||||
// (e.g. MulBackward0_apply_functional). Compiled Autograd's initial graph
|
||||
// capture wants to take a variant of this function and proxy it into the graph.
|
||||
// Every autograd node defines an apply_with_saved function, that when invoked,
|
||||
// proxys a call to a function into the Compiled Autograd graph.
|
||||
//
|
||||
// Some requirements that we have are:
|
||||
// - The proxy'ed function must have inputs that are FX-graphable types.
|
||||
// - Windows has a DLL symbol limit of 65536.
|
||||
// - Node::apply_with_saved is in libtorch_cpu which does not have direct access
|
||||
// to Python
|
||||
//
|
||||
// There were multiple ways to skin the cat, but what we end up doing is:
|
||||
// - for e.g. MulBackward0_apply_functional, we create a new C++ function
|
||||
// MulBackward0_apply_functional_ivalue that accepts IValues.
|
||||
// - We define how to pack and unpack arbitrary C++ types into IValues.
|
||||
// - apply_with_saved passes MulBackward0_apply_functional_ivalue and
|
||||
// the IValue arguments to Python via an indirection.
|
||||
// In Python, these get proxy'ed into a graph.
|
||||
|
||||
// Helper struct for packing/unpacking an arbitrary C++ type into a single
|
||||
// IValue. There are various full and partial specializations for IValuePacker
|
||||
// to handle packing specific types (like TensorOptions) into an IValue.
|
||||
template <typename T>
|
||||
struct IValuePacker {
|
||||
// Pack a T into an IValue.
|
||||
static at::IValue pack(const T& t) {
|
||||
return t;
|
||||
}
|
||||
// Unpacks an IValue into a T.
|
||||
static T unpack(const at::IValue& t) {
|
||||
return t.to<T>();
|
||||
}
|
||||
// Returns the TypePtr for the IValue. This is used when
|
||||
// passing the IValue from Python into C++; we use it to
|
||||
// parse the Python object into an IValue.
|
||||
static at::TypePtr packed_type() {
|
||||
if constexpr (std::is_same_v<T, at::Tensor>) {
|
||||
return at::TensorType::get();
|
||||
} else if constexpr (std::is_same_v<T, int64_t>) {
|
||||
return at::IntType::get();
|
||||
} else if constexpr (std::is_same_v<T, c10::SymInt>) {
|
||||
return at::SymIntType::get();
|
||||
} else if constexpr (std::is_same_v<T, bool>) {
|
||||
return at::BoolType::get();
|
||||
} else if constexpr (std::is_same_v<T, double>) {
|
||||
return at::FloatType::get();
|
||||
} else if constexpr (std::is_same_v<T, c10::SymFloat>) {
|
||||
return at::SymFloatType::get();
|
||||
} else if constexpr (std::is_same_v<T, c10::SymBool>) {
|
||||
return at::SymBoolType::get();
|
||||
} else if constexpr (std::is_same_v<T, c10::Layout>) {
|
||||
return at::LayoutType::get();
|
||||
} else if constexpr (std::is_same_v<T, std::string>) {
|
||||
return at::StringType::get();
|
||||
} else if constexpr (std::is_same_v<T, at::Device>) {
|
||||
return at::DeviceObjType::get();
|
||||
} else if constexpr (std::is_same_v<T, at::Scalar>) {
|
||||
return at::NumberType::get();
|
||||
} else if constexpr (std::is_same_v<T, at::MemoryFormat>) {
|
||||
return at::MemoryFormatType::get();
|
||||
} else if constexpr (std::is_same_v<T, at::ScalarType>) {
|
||||
return at::ScalarTypeType::get();
|
||||
} else {
|
||||
// If you got here, you have probably added a member of a new type
|
||||
// to a built-in C++ autograd node.
|
||||
// To get this new type to work with Compiled Autograd, please
|
||||
// either change it to be an IValue-constructible type, or
|
||||
// define how to pack and unpack an object of this time into an IValue.
|
||||
// See NOTE: [Compiled Autograd and backward functions] for context.
|
||||
static_assert(dependent_false<T>::value);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct IValuePacker<uint64_t> {
|
||||
static at::TypePtr packed_type() {
|
||||
return at::IntType::get();
|
||||
}
|
||||
static at::IValue pack(const uint64_t& t) {
|
||||
return static_cast<int64_t>(t);
|
||||
}
|
||||
static uint64_t unpack(const at::IValue& t) {
|
||||
return static_cast<uint64_t>(t.toInt());
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct IValuePacker<std::vector<at::SymInt>> {
|
||||
static at::TypePtr packed_type() {
|
||||
return at::ListType::create(at::SymIntType::get());
|
||||
}
|
||||
static at::IValue pack(const std::vector<at::SymInt>& t) {
|
||||
return t;
|
||||
}
|
||||
static std::vector<at::SymInt> unpack(const at::IValue& t) {
|
||||
return t.toSymIntVector();
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct IValuePacker<VariableInfo> {
|
||||
static at::TypePtr packed_type() {
|
||||
return at::TupleType::create({
|
||||
at::LayoutType::get(),
|
||||
at::DeviceObjType::get(),
|
||||
at::ScalarTypeType::get(),
|
||||
at::ListType::create(at::SymIntType::get()),
|
||||
at::BoolType::get(),
|
||||
at::BoolType::get(),
|
||||
});
|
||||
}
|
||||
static at::IValue pack(const VariableInfo& t) {
|
||||
auto tuple = std::make_tuple(
|
||||
t.layout, t.device, t.scalar_type, t.size, t.requires_grad, t.is_empty);
|
||||
return tuple;
|
||||
}
|
||||
static VariableInfo unpack(const at::IValue& t) {
|
||||
auto tuple = t.to<std::tuple<
|
||||
at::Layout,
|
||||
at::Device,
|
||||
at::ScalarType,
|
||||
std::vector<at::SymInt>,
|
||||
bool,
|
||||
bool>>();
|
||||
VariableInfo v;
|
||||
v.layout = std::get<0>(tuple);
|
||||
v.device = std::get<1>(tuple);
|
||||
v.scalar_type = std::get<2>(tuple);
|
||||
v.size = std::get<3>(tuple);
|
||||
v.requires_grad = std::get<4>(tuple);
|
||||
v.is_empty = std::get<5>(tuple);
|
||||
return v;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct IValuePacker<caffe2::TypeMeta> {
|
||||
static at::TypePtr packed_type() {
|
||||
return at::ScalarTypeType::get();
|
||||
}
|
||||
static at::IValue pack(const caffe2::TypeMeta& t) {
|
||||
return at::typeMetaToScalarType(t);
|
||||
}
|
||||
static caffe2::TypeMeta unpack(const at::IValue& t) {
|
||||
return caffe2::TypeMeta::fromScalarType(t.to<at::ScalarType>());
|
||||
}
|
||||
};
|
||||
|
||||
inline std::optional<at::ScalarType> optTypeMetaToScalarType(
|
||||
const std::optional<caffe2::TypeMeta>& t) {
|
||||
if (t.has_value()) {
|
||||
return at::typeMetaToScalarType(t.value());
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
using packed_tensoroptions_t = std::tuple<
|
||||
std::optional<bool>,
|
||||
std::optional<at::MemoryFormat>,
|
||||
std::optional<at::Device>,
|
||||
std::optional<at::ScalarType>,
|
||||
std::optional<at::Layout>,
|
||||
std::optional<bool>>;
|
||||
|
||||
inline packed_tensoroptions_t pack_TensorOptions(const at::TensorOptions& t) {
|
||||
auto tuple = std::make_tuple(
|
||||
t.requires_grad_opt(),
|
||||
t.memory_format_opt(),
|
||||
t.device_opt(),
|
||||
optTypeMetaToScalarType(t.dtype_opt()),
|
||||
t.layout_opt(),
|
||||
t.pinned_memory_opt());
|
||||
return tuple;
|
||||
}
|
||||
inline at::TensorOptions unpack_TensorOptions(
|
||||
const packed_tensoroptions_t& tuple) {
|
||||
at::TensorOptions result;
|
||||
if (std::get<0>(tuple).has_value()) {
|
||||
result = result.requires_grad(std::get<0>(tuple).value());
|
||||
}
|
||||
if (std::get<1>(tuple).has_value()) {
|
||||
result = result.memory_format(std::get<1>(tuple).value());
|
||||
}
|
||||
if (std::get<2>(tuple).has_value()) {
|
||||
result = result.device(std::get<2>(tuple).value());
|
||||
}
|
||||
if (std::get<3>(tuple).has_value()) {
|
||||
result = result.dtype(
|
||||
caffe2::TypeMeta::fromScalarType(std::get<3>(tuple).value()));
|
||||
}
|
||||
if (std::get<4>(tuple).has_value()) {
|
||||
result = result.layout(std::get<4>(tuple).value());
|
||||
}
|
||||
if (std::get<5>(tuple).has_value()) {
|
||||
result = result.pinned_memory(std::get<5>(tuple).value());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
struct IValuePacker<at::TensorOptions> {
|
||||
static at::TypePtr packed_type() {
|
||||
return at::TupleType::create(
|
||||
{at::OptionalType::create(at::BoolType::get()),
|
||||
at::OptionalType::create(at::MemoryFormatType::get()),
|
||||
at::OptionalType::create(at::DeviceObjType::get()),
|
||||
at::OptionalType::create(at::ScalarTypeType::get()),
|
||||
at::OptionalType::create(at::LayoutType::get()),
|
||||
at::OptionalType::create(at::BoolType::get())});
|
||||
}
|
||||
static at::IValue pack(const at::TensorOptions& t) {
|
||||
return pack_TensorOptions(t);
|
||||
}
|
||||
static at::TensorOptions unpack(const at::IValue& t) {
|
||||
auto tuple = t.to<packed_tensoroptions_t>();
|
||||
return unpack_TensorOptions(tuple);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct IValuePacker<TypeAndSize> {
|
||||
static at::TypePtr packed_type() {
|
||||
return at::TupleType::create(
|
||||
{IValuePacker<std::vector<at::SymInt>>::packed_type(),
|
||||
IValuePacker<at::TensorOptions>::packed_type()});
|
||||
}
|
||||
static at::IValue pack(const TypeAndSize& t) {
|
||||
auto tuple = std::make_tuple(t.sym_sizes, pack_TensorOptions(t.options));
|
||||
return tuple;
|
||||
}
|
||||
static TypeAndSize unpack(const at::IValue& t) {
|
||||
auto tuple =
|
||||
t.to<std::tuple<std::vector<at::SymInt>, packed_tensoroptions_t>>();
|
||||
TypeAndSize result;
|
||||
result.sym_sizes = std::get<0>(tuple);
|
||||
result.options = unpack_TensorOptions(std::get<1>(tuple));
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct IValuePacker<std::optional<T>> {
|
||||
static at::TypePtr packed_type() {
|
||||
return at::OptionalType::create(IValuePacker<T>::packed_type());
|
||||
}
|
||||
static at::IValue pack(const std::optional<T>& t) {
|
||||
if (t.has_value()) {
|
||||
return IValuePacker<T>::pack(t.value());
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
static std::optional<T> unpack(const at::IValue& t) {
|
||||
if (t.isNone()) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
return IValuePacker<T>::unpack(t);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct IValuePacker<std::vector<T>> {
|
||||
static at::TypePtr packed_type() {
|
||||
return at::ListType::create(IValuePacker<T>::packed_type());
|
||||
}
|
||||
static at::IValue pack(const std::vector<T>& t) {
|
||||
if constexpr (std::is_constructible_v<at::IValue, T>) {
|
||||
return t;
|
||||
}
|
||||
if (t.empty()) {
|
||||
auto lst = c10::impl::GenericList(at::AnyType::get());
|
||||
return lst;
|
||||
}
|
||||
auto type_ptr = IValuePacker<T>::pack(t[0]).type();
|
||||
auto lst = c10::impl::GenericList(type_ptr);
|
||||
for (const auto& elt : t) {
|
||||
lst.emplace_back(IValuePacker<T>::pack(elt));
|
||||
}
|
||||
return lst;
|
||||
}
|
||||
static std::vector<T> unpack(const at::IValue& t) {
|
||||
if constexpr (std::is_constructible_v<at::IValue, T>) {
|
||||
return t.to<std::vector<T>>();
|
||||
}
|
||||
std::vector<T> result;
|
||||
auto lst = t.toList();
|
||||
for (const at::IValue& elt : lst) {
|
||||
result.emplace_back(IValuePacker<T>::unpack(elt));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct IValuePacker<c10::List<T>> {
|
||||
static at::TypePtr packed_type() {
|
||||
return IValuePacker<std::vector<T>>::packed_type();
|
||||
}
|
||||
static at::IValue pack(const c10::List<T>& t) {
|
||||
return IValuePacker<std::vector<T>>::pack(t.vec());
|
||||
}
|
||||
static c10::List<T> unpack(const at::IValue& t) {
|
||||
return c10::List<T>(IValuePacker<std::vector<T>>::unpack(t));
|
||||
}
|
||||
};
|
||||
|
||||
template <size_t N>
|
||||
struct IValuePacker<std::array<bool, N>> {
|
||||
static at::TypePtr packed_type() {
|
||||
return IValuePacker<std::vector<bool>>::packed_type();
|
||||
}
|
||||
static at::IValue pack(const std::array<bool, N>& t) {
|
||||
std::vector<bool> result(t.begin(), t.end());
|
||||
return IValuePacker<std::vector<bool>>::pack(result);
|
||||
}
|
||||
static std::array<bool, N> unpack(const at::IValue& t) {
|
||||
std::array<bool, N> result;
|
||||
auto packed = IValuePacker<std::vector<bool>>::unpack(t);
|
||||
for (size_t i = 0; i < packed.size(); i++) {
|
||||
result[i] = packed[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct IValuePacker<at::TensorGeometry> {
|
||||
static at::TypePtr packed_type() {
|
||||
return at::TupleType::create(
|
||||
{IValuePacker<std::vector<at::SymInt>>::packed_type(),
|
||||
IValuePacker<std::vector<at::SymInt>>::packed_type(),
|
||||
at::SymIntType::get()});
|
||||
}
|
||||
static at::IValue pack(const at::TensorGeometry& t) {
|
||||
auto tuple = std::make_tuple(
|
||||
t.sym_sizes().vec(), t.sym_strides().vec(), t.sym_storage_offset());
|
||||
return tuple;
|
||||
}
|
||||
static at::TensorGeometry unpack(const at::IValue& t) {
|
||||
auto tuple = t.to<std::tuple<
|
||||
std::vector<at::SymInt>,
|
||||
std::vector<at::SymInt>,
|
||||
at::SymInt>>();
|
||||
return at::TensorGeometry(
|
||||
std::get<0>(tuple), std::get<1>(tuple), std::get<2>(tuple));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct IValuePacker<InputMetadata> {
|
||||
static at::TypePtr packed_type() {
|
||||
return at::TupleType::create(
|
||||
{IValuePacker<at::TensorOptions>::packed_type(),
|
||||
IValuePacker<std::vector<at::SymInt>>::packed_type(),
|
||||
at::BoolType::get()});
|
||||
}
|
||||
static at::IValue pack(const InputMetadata& t) {
|
||||
TORCH_INTERNAL_ASSERT(!t.is_nested_tensor());
|
||||
auto tuple = std::make_tuple(
|
||||
pack_TensorOptions(t.options()),
|
||||
t.shape_as_dim_vector().vec(),
|
||||
t.is_tensor_subclass());
|
||||
return tuple;
|
||||
}
|
||||
static InputMetadata unpack(const at::IValue& t) {
|
||||
auto tuple = t.to<
|
||||
std::tuple<packed_tensoroptions_t, std::vector<at::SymInt>, bool>>();
|
||||
|
||||
return InputMetadata(
|
||||
unpack_TensorOptions(std::get<0>(tuple)),
|
||||
SymIntSmallVec(std::get<1>(tuple)),
|
||||
std::get<2>(tuple),
|
||||
false);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct IValuePacker<at::OptionalArray<T>> {
|
||||
static at::TypePtr packed_type() {
|
||||
return IValuePacker<std::optional<std::vector<T>>>::packed_type();
|
||||
}
|
||||
static at::IValue pack(const at::OptionalArray<T>& t) {
|
||||
return IValuePacker<std::optional<std::vector<T>>>::pack(t.list);
|
||||
}
|
||||
static at::OptionalArray<T> unpack(const at::IValue& t) {
|
||||
auto result = IValuePacker<std::optional<std::vector<T>>>::unpack(t);
|
||||
if (result.has_value()) {
|
||||
return {result.value()};
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct IValuePacker<ska::flat_hash_map<std::string, at::IValue>> {
|
||||
static at::TypePtr packed_type() {
|
||||
return at::DictType::create(at::StringType::get(), at::AnyType::get());
|
||||
}
|
||||
static at::IValue pack(const ska::flat_hash_map<std::string, at::IValue>& t) {
|
||||
auto result =
|
||||
c10::impl::GenericDict(at::StringType::get(), at::AnyType::get());
|
||||
for (const auto& [key, value] : t) {
|
||||
result.insert(key, value);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
static ska::flat_hash_map<std::string, at::IValue> unpack(
|
||||
const at::IValue& t) {
|
||||
auto dct = t.toGenericDict();
|
||||
auto result = ska::flat_hash_map<std::string, at::IValue>();
|
||||
for (const auto& entry : dct) {
|
||||
result.insert({entry.key().to<std::string>(), entry.value()});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
using saved_data_t = ska::flat_hash_map<std::string, at::IValue>;
|
||||
|
||||
struct SavedState {
|
||||
SavedState() = default;
|
||||
|
||||
explicit SavedState(std::vector<at::IValue> stack_)
|
||||
: stack(std::move(stack_)) {}
|
||||
|
||||
std::vector<at::IValue> stack;
|
||||
int64_t idx = 0;
|
||||
|
||||
template <typename T>
|
||||
void pack(const T& t) {
|
||||
stack.emplace_back(IValuePacker<T>::pack(t));
|
||||
}
|
||||
template <typename T>
|
||||
T unpack() {
|
||||
return IValuePacker<T>::unpack(std::move(stack[idx++]));
|
||||
}
|
||||
|
||||
void pack_saved_data(const ska::flat_hash_map<std::string, at::IValue>& dct) {
|
||||
std::vector<std::string> keys;
|
||||
std::vector<at::IValue> values;
|
||||
for (const auto& [key, value] : dct) {
|
||||
keys.emplace_back(key);
|
||||
values.emplace_back(value);
|
||||
}
|
||||
pack(keys);
|
||||
for (const auto& value : values) {
|
||||
pack(value);
|
||||
}
|
||||
}
|
||||
|
||||
saved_data_t unpack_saved_data() {
|
||||
ska::flat_hash_map<std::string, at::IValue> dct;
|
||||
auto keys = unpack<std::vector<std::string>>();
|
||||
for (const auto& key : keys) {
|
||||
dct.insert({key, std::move(stack[idx++])});
|
||||
}
|
||||
return dct;
|
||||
}
|
||||
};
|
||||
|
||||
struct TORCH_API PyCompilerInterface {
|
||||
virtual ~PyCompilerInterface(){};
|
||||
virtual variable_list call_function(
|
||||
PyObject* py_compiler,
|
||||
const char* name,
|
||||
functional_apply_t fn,
|
||||
const variable_list& inputs,
|
||||
const ivalue_list& saved_state,
|
||||
int64_t num_outputs,
|
||||
const std::string& debug,
|
||||
const std::vector<at::TypePtr>& saved_state_schema,
|
||||
bool builtin) {
|
||||
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
|
||||
}
|
||||
virtual variable_list call_copy_slices_prologue(
|
||||
PyObject* py_compiler,
|
||||
const variable_list& inputs,
|
||||
const at::TensorGeometry& base,
|
||||
const at::TensorGeometry& view) {
|
||||
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
|
||||
}
|
||||
virtual variable_list call_copy_slices_epilogue(
|
||||
PyObject* py_compiler,
|
||||
const std::vector<bool>& needs_input_grad,
|
||||
const at::Tensor& result,
|
||||
const variable_list& res,
|
||||
const at::Tensor& grad_slice) {
|
||||
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
|
||||
}
|
||||
};
|
||||
|
||||
TORCH_API const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface();
|
||||
TORCH_API void setPyCompilerInterface(
|
||||
std::unique_ptr<PyCompilerInterface>&& impl);
|
||||
TORCH_API void resetPyCompilerInterface();
|
||||
|
||||
} // namespace torch::dynamo::autograd
|
||||
|
||||
template <>
|
||||
|
||||
@ -52,6 +52,122 @@ Notes:
|
||||
namespace torch::dynamo::autograd {
|
||||
using c10::SymInt;
|
||||
|
||||
static PyObject* kPyCompiler;
|
||||
|
||||
PyObject* current_py_compiler() {
|
||||
return kPyCompiler;
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
static variable_list call_function(
|
||||
PyObject* py_compiler,
|
||||
const char* name,
|
||||
Func fn,
|
||||
const variable_list& inputs,
|
||||
const ivalue_list& saved_state,
|
||||
int64_t num_outputs,
|
||||
const std::string& debug,
|
||||
const std::vector<TypePtr>& schema,
|
||||
bool builtin) {
|
||||
// TORCH_INTERNAL_ASSERT(schema.size() == saved_state.size());
|
||||
|
||||
// We are going to bind the following function to Python
|
||||
auto py_func = py::cpp_function(
|
||||
[schema, fn](
|
||||
std::vector<c10::optional<at::Tensor>>& inputs,
|
||||
const py::args& args) -> py::object {
|
||||
// It reconstructs the saved_state from args via the schema
|
||||
std::vector<at::IValue> stack;
|
||||
TORCH_INTERNAL_ASSERT(args.size() == schema.size());
|
||||
auto tuple_args = jit::tuple_slice(args);
|
||||
for (uint64_t idx = 0; idx < schema.size(); idx++) {
|
||||
stack.emplace_back(
|
||||
jit::toIValue(tuple_args[idx], schema[idx], c10::nullopt));
|
||||
}
|
||||
std::vector<at::Tensor> inputs_;
|
||||
for (const auto& inp : inputs) {
|
||||
if (inp.has_value()) {
|
||||
inputs_.emplace_back(*inp);
|
||||
} else {
|
||||
inputs_.emplace_back();
|
||||
}
|
||||
}
|
||||
auto outputs = fn(inputs_, stack);
|
||||
return jit::toPyObject(at::IValue(outputs));
|
||||
});
|
||||
|
||||
// convert ivalue_list -> PyObject*
|
||||
PyObject* py_saved_state =
|
||||
PyTuple_New(static_cast<Py_ssize_t>(saved_state.size()));
|
||||
for (const auto i : c10::irange(saved_state.size())) {
|
||||
py::object obj = jit::toPyObject(saved_state[i]);
|
||||
Py_INCREF(obj.ptr());
|
||||
PyTuple_SET_ITEM(py_saved_state, i, obj.ptr());
|
||||
}
|
||||
|
||||
// call the corresponding method on the py_compiler
|
||||
py::handle handle(py_compiler);
|
||||
py::object stuff = handle.attr(name)(
|
||||
py_func, inputs, py::handle(py_saved_state), num_outputs, debug, builtin);
|
||||
|
||||
// Convert the output from PyObject* to vector<Tensor>
|
||||
auto tmp = py::cast<std::vector<std::optional<at::Tensor>>>(stuff);
|
||||
variable_list outputs;
|
||||
for (const auto& t : tmp) {
|
||||
if (t.has_value()) {
|
||||
outputs.emplace_back(t.value());
|
||||
} else {
|
||||
outputs.emplace_back();
|
||||
}
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
struct PyCompilerInterfaceImpl : PyCompilerInterface {
|
||||
variable_list call_function(
|
||||
PyObject* py_compiler,
|
||||
const char* name,
|
||||
functional_apply_t fn,
|
||||
const variable_list& inputs,
|
||||
const ivalue_list& saved_state,
|
||||
int64_t num_outputs,
|
||||
const std::string& debug,
|
||||
const std::vector<at::TypePtr>& saved_state_schema,
|
||||
bool builtin) override {
|
||||
return torch::dynamo::autograd::call_function(
|
||||
py_compiler,
|
||||
name,
|
||||
fn,
|
||||
inputs,
|
||||
saved_state,
|
||||
num_outputs,
|
||||
debug,
|
||||
saved_state_schema,
|
||||
builtin);
|
||||
}
|
||||
variable_list call_copy_slices_prologue(
|
||||
PyObject* py_compiler,
|
||||
const variable_list& inputs,
|
||||
const at::TensorGeometry& base,
|
||||
const at::TensorGeometry& view) override {
|
||||
py::handle handle(py_compiler);
|
||||
py::object stuff =
|
||||
handle.attr("call_copy_slices_prologue")(inputs, base, view);
|
||||
return py::cast<std::vector<at::Tensor>>(stuff);
|
||||
}
|
||||
virtual variable_list call_copy_slices_epilogue(
|
||||
PyObject* py_compiler,
|
||||
const std::vector<bool>& needs_input_grad,
|
||||
const at::Tensor& result,
|
||||
const variable_list& res,
|
||||
const at::Tensor& grad_slice) override {
|
||||
py::handle handle(py_compiler);
|
||||
py::object stuff = handle.attr("call_copy_slices_epilogue")(
|
||||
needs_input_grad, result, res, grad_slice);
|
||||
return py::cast<std::vector<at::Tensor>>(stuff);
|
||||
}
|
||||
};
|
||||
|
||||
static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) {
|
||||
PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(inputs.size()));
|
||||
for (const auto i : c10::irange(inputs.size())) {
|
||||
@ -89,6 +205,22 @@ static void check(bool result) {
|
||||
check(nullptr);
|
||||
}
|
||||
|
||||
static variable_list validate_outputs(
|
||||
variable_list& outputs,
|
||||
const ivalue_list& saved) {
|
||||
SavedState r;
|
||||
r.stack = saved;
|
||||
auto value = r.unpack<std::vector<c10::optional<InputMetadata>>>();
|
||||
|
||||
torch::autograd::validate_outputs(
|
||||
value, outputs, [&](const std::string& msg) {
|
||||
std::ostringstream ss;
|
||||
ss << "[Compiled Autograd Tracing:]" << msg;
|
||||
return ss.str();
|
||||
});
|
||||
return outputs;
|
||||
}
|
||||
|
||||
// snapshot of python verbose logging toggle
|
||||
static PyObject* python_verbose_logger = nullptr;
|
||||
|
||||
@ -656,6 +788,9 @@ CacheNode* _compiled_autograd_impl(
|
||||
// cache miss, need to capture FX graph
|
||||
ClosingTHPObjectPtr py_compiler(
|
||||
check(PyObject_CallNoArgs((the_autograd_compiler))));
|
||||
kPyCompiler = py_compiler.get();
|
||||
|
||||
setPyCompilerInterface(std::make_unique<PyCompilerInterfaceImpl>());
|
||||
|
||||
TraceState state = call_begin_capture(
|
||||
py_compiler, *cache, compiler_call, output_edges.size());
|
||||
@ -723,16 +858,27 @@ CacheNode* _compiled_autograd_impl(
|
||||
|
||||
SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call);
|
||||
variable_list outputs = call.node->apply_with_saved(inputs, saved);
|
||||
|
||||
saved.debug_asserts();
|
||||
saved.before(call.node->next_edges());
|
||||
validate_outputs(
|
||||
call.node->next_edges(), outputs, [&](const std::string& msg) {
|
||||
std::ostringstream ss;
|
||||
ss << "[Compiled Autograd Tracing: " << call.node->name() << "] "
|
||||
<< msg;
|
||||
return ss.str();
|
||||
});
|
||||
|
||||
auto input_metadata = collect_input_metadata(call.node->next_edges());
|
||||
TORCH_INTERNAL_ASSERT(input_metadata.size() == outputs.size());
|
||||
|
||||
SavedState state;
|
||||
state.pack(input_metadata);
|
||||
ivalue_list& input_metadata_state = state.stack;
|
||||
outputs = call_function(
|
||||
py_compiler,
|
||||
"validate_outputs",
|
||||
validate_outputs,
|
||||
outputs,
|
||||
input_metadata_state,
|
||||
outputs.size(),
|
||||
"validate_outputs",
|
||||
{IValuePacker<
|
||||
std::vector<c10::optional<InputMetadata>>>::packed_type()},
|
||||
/*builtin*/ true);
|
||||
|
||||
saved.after(call.node->next_edges());
|
||||
saved.debug_asserts();
|
||||
|
||||
@ -761,6 +907,8 @@ CacheNode* _compiled_autograd_impl(
|
||||
}
|
||||
}
|
||||
|
||||
resetPyCompilerInterface();
|
||||
kPyCompiler = nullptr;
|
||||
PyObject* res = check(call_end_capture(py_compiler, state.outputs));
|
||||
TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple");
|
||||
TORCH_CHECK(
|
||||
|
||||
Reference in New Issue
Block a user