mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-28 10:34:54 +08:00

Compare commits: 1 commit (ciflow/tru...fca)

Commit f85a0b82eb

@@ -37,6 +37,16 @@ struct TORCH_API TensorGeometry {
        has_symbolic_sizes_strides_(
            t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {}

  explicit TensorGeometry(
      std::vector<at::SymInt> sizes,
      std::vector<at::SymInt> strides,
      at::SymInt storage_offset)
      : sizes_(std::move(sizes)),
        strides_(std::move(strides)),
        storage_offset_(std::move(storage_offset)) {
    recompute();
  }

  // true if the tensor is contiguous
  bool is_contiguous() const;

@@ -93,7 +93,9 @@ c10::TypePtr IValue::TagType<c10::Type>::get(const IValue& v) {
    case Tag::None:
      return NoneType::get();
    case Tag::Tensor:
      return TensorType::create(v.toTensor());
      return TensorType::get();
      // TODO(rzou): following errors
      // return TensorType::create(v.toTensor());
    case Tag::Storage:
      return StorageType::get();
    case Tag::Double:

@@ -2075,6 +2075,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_id, m) {
            cpp_sources=cpp_source,
            functions="custom_op_backed_by_autograd_fn",
            verbose=True,
            extra_cflags=["-g", "-O0"],
        )

        def same_autograd_fn():

@@ -2113,8 +2114,8 @@ TORCH_LIBRARY(test_autograd_cpp_node_id, m) {

        self.check_output_and_recompiles(different_autograd_fn, 2)

    @scoped_load_inline
    def test_autograd_cpp_node_saved(self, load_inline):
    @unittest.skip("Flaky, cache from test ordering affects test. #135369")
    def test_autograd_cpp_node_saved(self):
        cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
  static constexpr bool is_traceable = true;

@@ -2190,7 +2191,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved, m) {

        self.check_output_and_recompiles(fn, 2)

    @scoped_load_inline
    def test_autograd_cpp_node_saved_dynamic(self, load_inline):
    def test_autograd_cpp_node_saved_dynamic(self):
        cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
  static constexpr bool is_traceable = true;

@@ -64,6 +64,9 @@ struct TORCH_API ${op} : public ${superclass} {
  }
  ${will_release_variables}
  void compiled_args(CompiledNodeArgs& args) override;
  ivalue_list get_state();
  ivalue_list retrieve_saved(SwapSavedVariables& saved) override;
  functional_apply_t get_functional() override;
  variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override;
  ${saved_variables}
  ${saved_list_sizes}

@@ -82,15 +85,22 @@ void will_release_variables() override {

FUNCTION_DEFINITION = CodeTemplate(
    """\
variable_list ${op}::apply(variable_list&& grads) {
  ${thread_lock}
  ${asserts}
static variable_list ${op}_apply_functional(variable_list&& grads, std::array<bool,${num_vars}> needs_input_grad ${unpacked_saved_vars_signature}) {
  IndexRangeGenerator gen;
  ${compute_index_ranges}
  variable_list grad_inputs(gen.size());
  ${body}
  return grad_inputs;
}

variable_list ${op}::apply(variable_list&& grads) {
  ${thread_lock}
  ${asserts}
  ${unpacks}
  ${compute_needs_input_grad}
  return ${op}_apply_functional(std::move(grads), grad_input_mask ${unpacked_saved_vars});
}

void ${op}::compiled_args(CompiledNodeArgs& args) {
  ${compiled_args}
}

@@ -100,6 +110,28 @@ variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) {
  ${apply_with_saved_after}
  return result;
}
ivalue_list ${op}::get_state() {
  SavedState saved_state;
  ${unpacks}
  ${get_state}
  return saved_state.stack;
}
ivalue_list ${op}::retrieve_saved(SwapSavedVariables& saved) {
  ${apply_with_saved_before}
  auto state = get_state();
  ${apply_with_saved_after}
  return state;
}

functional_apply_t ${op}::get_functional() {
  ${compute_needs_input_grad}
  return [grad_input_mask](const variable_list& inputs, const std::vector<c10::IValue>& saved) {
    SavedState state;
    state.stack = saved;
    ${saved_var_dequeues}
    return ${op}_apply_functional(variable_list(inputs), grad_input_mask ${unpacked_saved_vars});
  };
}
"""
)

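Note: the template above splits each generated backward node into a pure `${op}_apply_functional` function plus a thin stateful wrapper. A minimal Python sketch of the same pattern, with illustrative names rather than the actual generated C++:

# Sketch of the functional split the template generates. All names here are
# hypothetical, chosen only to show the shape of the refactor.
def mul_backward_apply_functional(grads, needs_input_grad, other, self_):
    # Pure function: every saved value arrives as an explicit argument.
    grad_inputs = [None, None]
    if needs_input_grad[0]:
        grad_inputs[0] = grads[0] * other
    if needs_input_grad[1]:
        grad_inputs[1] = grads[0] * self_
    return grad_inputs

class MulBackward:
    def __init__(self, other, self_, needs_input_grad):
        self.other, self.self_ = other, self_
        self.needs_input_grad = needs_input_grad

    def apply(self, grads):
        # Thin wrapper: unpack member state, then call the pure function.
        return mul_backward_apply_functional(
            grads, self.needs_input_grad, self.other, self.self_)

Keeping the math in a function of explicit arguments is what lets compiled autograd call it with swapped-in saved state.
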
@@ -107,13 +139,23 @@ GRAD_INPUT_MASK = CodeTemplate(
    """\
  auto grad_input_mask = std::array<bool, ${n}>{
    ${masks}
  };\
  };
"""
)

COMPUTE_NEEDS_INPUT_GRAD = CodeTemplate(
    """\
${ix_ranges}
auto grad_input_mask = std::array<bool, ${n}>{
  ${masks}
};\
"""
)


DERIVATIVE_SINGLE = CodeTemplate(
    """\
if (task_should_compute_output({ ${name}_ix })) {
if (needs_input_grad[std::get<0>(${name}_ix)]) {
  auto grad_result = ${derivative};
  copy_range(grad_inputs, ${name}_ix, grad_result);
}

@@ -126,7 +168,7 @@ if (task_should_compute_output({ ${name}_ix })) {
# to each `Tensor`(s) of `self`, and the others.
DERIVATIVE_SINGLE_FOREACH = CodeTemplate(
    """\
if (task_should_compute_output({ ${name}_ix })) {
if (needs_input_grad[std::get<0>(${name}_ix)]) {
  std::vector<Tensor> grad_result;
  grad_result.reserve(grads.size());
  for (const auto & i : c10::irange(grads.size())) {

@@ -143,7 +185,7 @@ if (task_should_compute_output({ ${name}_ix })) {

DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(
    """\
  if (task_should_compute_output({ ${name}_ix })) {
  if (needs_input_grad[std::get<0>(${name}_ix)]) {
    copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result));
  }
"""

@@ -151,7 +193,7 @@ DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(

DERIVATIVE_MULTI = CodeTemplate(
    """\
if (task_should_compute_output({ ${idx_ranges} })) {
if (${needs_input_grad}) {
  ${grad_input_mask}
  auto grad_result = ${derivative};
  ${copy_ranges}

@@ -552,10 +594,16 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
    apply_with_saved_before: list[str] = []
    apply_with_saved_after: list[str] = []

    for arg in info.args_with_derivatives:
    unpacked_saved_vars = []
    unpacked_saved_vars_ref_type = []

    for idx, arg in enumerate(info.args_with_derivatives):
        # compute_index_ranges.append(f"auto {arg.name}_ix = {idx};")
        if arg.type in TENSOR_LIST_LIKE_CTYPES:
            size = f"{arg.name}_size_"
            saved_list_sizes.append(f"size_t {arg.name}_size_;")
            unpacked_saved_vars.append(f"{arg.name}_size_")
            unpacked_saved_vars_ref_type.append("size_t")
        else:
            size = "1"
        compute_index_ranges.append(f"auto {arg.name}_ix = gen.range({size});")

@@ -567,6 +615,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
        should_append_raw_getsetdef = False
        visit_name = name
        uses_cpp_saved_variable_cls = False
        unpacked_ref_type = None

        if (
            type == BaseCType(tensorT)

@@ -591,6 +640,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
            )
            should_append_raw_getsetdef = True
            visit_name = f"{name}_"
            unpacked_ref_type = "Tensor&"
        elif (
            type == BaseCType(tensorListT)
            or type == BaseCType(iTensorListRefT)

@@ -630,6 +680,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
            )
            should_append_raw_getsetdef = True
            visit_name = f"{name}_"
            unpacked_ref_type = "std::vector<Tensor>&"
        elif type == ListCType(OptionalCType(BaseCType(tensorT))):
            uses_cpp_saved_variable_cls = True
            saved_variables.append(f"std::vector<SavedVariable> {name}_;")

@@ -652,6 +703,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
            )
            should_append_raw_getsetdef = True
            visit_name = f"{name}_"
            unpacked_ref_type = "torch::List<std::optional<Tensor>>&"
        elif type == BaseCType(intArrayRefT):
            saved_variables.append(f"std::vector<int64_t> {name};")
            getter_definitions.append(

@@ -733,6 +785,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
            elem=BaseCType(type=BaseCppType(ns="at", name="Scalar"))
        ):
            saved_variables.append(f"std::vector<at::Scalar> {name};")
            unpacked_ref_type = "std::vector<at::Scalar>&"
            saved_variables.append(f"bool {name}_released_ = false;")
            # Just clear() is sufficient, we don't need to loop and clear each variable.
            # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.

@@ -803,6 +856,14 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
        apply_with_saved_before.append(f"saved.before({visit_name});")
        apply_with_saved_after.append(f"saved.after({visit_name});")

        if unpacked_ref_type is None:
            # TODO(rzou): should reformulate in terms of type, then ref
            unpacked_ref_type = f"{saved_variables[-1].split(' ')[0]}&"
        if unpacked_ref_type.startswith("const "):
            unpacked_ref_type = unpacked_ref_type[6:]
        unpacked_saved_vars.append(name)
        unpacked_saved_vars_ref_type.append(unpacked_ref_type)

    for var in sorted(info.all_saved_inputs, key=lambda sa: str(sa.nctype.name)):
        save_var(var, is_output=False)
    for var in sorted(info.all_saved_outputs, key=lambda sa: str(sa.nctype.name)):

@@ -816,6 +877,8 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
        thread_lock = ""

    if uses_retain_variables(info):
        unpacked_saved_vars.append("retain_variables")
        unpacked_saved_vars_ref_type.append("bool")
        will_release_variables = WILL_RELEASE_VARIABLES.substitute()
    else:
        will_release_variables = ""

@@ -834,9 +897,11 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
    def emit_derivative(
        derivative: Derivative,
        args_with_derivatives: Sequence[Binding],
        num_grad_inputs: int,
    ) -> tuple[bool, str]:
        formula = derivative.formula
        var_names = derivative.var_names

        if len(var_names) == 1:
            checks_any_grad_defined = False
            if "not_implemented" not in formula:

@@ -857,35 +922,45 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
                derivative_template = DERIVATIVE_SINGLE
            return (
                checks_any_grad_defined,
                derivative_template.substitute(name=var_names[0], derivative=formula),
                derivative_template.substitute(
                    name=var_names[0],
                    derivative=formula,
                    idx=num_grad_inputs,
                ),
            )

        else:
            if "grad_input_mask" in formula:
                masks = [
                    f"task_should_compute_output({{ {n}_ix }})," for n in var_names
                    f"needs_input_grad[std::get<0>({n}_ix)]," for n in var_names
                ]
                grad_input_mask = GRAD_INPUT_MASK.substitute(
                    masks=masks, n=len(var_names)
                    n=len(var_names),
                    masks=masks
                )
            else:
                grad_input_mask = ""
            idx_ranges = ", ".join(f"{n}_ix" for n in var_names)
            needs_input_grad = [f"needs_input_grad[std::get<0>({var_names[i]}_ix)]" for i in range(len(var_names))]
            needs_input_grad = " || ".join(needs_input_grad)
            # idx_ranges = ", ".join(f"{n}_ix" for n in var_names)
            copy_ranges: list[str] = []
            for i, n in enumerate(var_names):
                copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
                copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i, idx=num_grad_inputs + i))
            return False, DERIVATIVE_MULTI.substitute(
                idx_ranges=idx_ranges,
                needs_input_grad=needs_input_grad,
                copy_ranges=copy_ranges,
                derivative=formula,
                grad_input_mask=grad_input_mask,
            )

    body.extend(unpack)
    num_grad_inputs = 0

    need_any_grad_defined_var = False
    for derivative in info.derivatives:
    for idx, derivative in enumerate(info.derivatives):
        checks_any_grad_defined, derivative_text = emit_derivative(
            derivative, info.args_with_derivatives
            derivative, info.args_with_derivatives, num_grad_inputs
        )
        num_grad_inputs += len(derivative.var_names)
        body.append(derivative_text)
        need_any_grad_defined_var |= checks_any_grad_defined
    # Since single-output derivative formulas need to check if grads are

@@ -896,6 +971,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
            "bool any_grad_defined = any_variable_defined(grads);",
        )

    if info.name in UNTRACEABLE_FUNCTIONS:
        superclass = "Node"
    else:

@@ -906,8 +982,41 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
    )
    all_getter_definitions = "\n".join(getter_definitions)

    get_state = "\n".join(
        f"saved_state.enqueue({name});"
        for name in unpacked_saved_vars
    )
    saved_var_dequeues = []
    for typ, name in zip(unpacked_saved_vars_ref_type, unpacked_saved_vars):
        if typ.endswith("&"):
            typ = typ[:-1]
        saved_var_dequeues.append(f"{typ} {name};")
        saved_var_dequeues.append(f"state.dequeue({name});")

    masks = [
        f"task_should_compute_output({n})," for n in range(num_grad_inputs)
    ]
    compute_needs_input_grad = COMPUTE_NEEDS_INPUT_GRAD.substitute(
        ix_ranges="",
        n=num_grad_inputs,
        masks=masks)
    if len(unpacked_saved_vars) > 0:
        unpacked_saved_vars_signature = ", " + ",".join(f"{T} {x}" for T, x in zip(unpacked_saved_vars_ref_type, unpacked_saved_vars))
    else:
        unpacked_saved_vars_signature = ""
    if len(unpacked_saved_vars) > 0:
        unpacked_saved_vars = ", " + ", ".join(unpacked_saved_vars)
    else:
        unpacked_saved_vars = ""

    return template.substitute(
        unpacks="\n".join(unpack),
        op=info.op,
        saved_var_dequeues="\n".join(saved_var_dequeues),
        unpacked_saved_vars=unpacked_saved_vars,
        unpacked_saved_vars_signature=unpacked_saved_vars_signature,
        compute_needs_input_grad=compute_needs_input_grad,
        num_vars=num_grad_inputs,
        compute_index_ranges=compute_index_ranges,
        saved_variables=saved_variables,
        release_variables=release_variables,

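Note: the generated `get_state` and `saved_var_dequeues` strings are mirror images. Whatever order the codegen enqueues values onto the saved-state stack, the functional lambda must dequeue them in exactly that order. A small plain-Python sketch of the invariant the generated code relies on (the stack class here is an assumption, not a PyTorch API):

# Minimal sketch of the flat saved-state stack: values are pushed in a fixed
# order chosen by the codegen and must be popped in the same order.
class SavedState:
    def __init__(self):
        self.stack, self.idx = [], 0

    def enqueue(self, v):
        self.stack.append(v)

    def dequeue(self):
        v = self.stack[self.idx]
        self.idx += 1
        return v

state = SavedState()
for v in (2.0, [3, 4]):       # enqueue order: alpha, then self_sizes
    state.enqueue(v)
alpha = state.dequeue()        # dequeue order must mirror it exactly
self_sizes = state.dequeue()
assert (alpha, self_sizes) == (2.0, [3, 4])
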
@@ -922,4 +1031,5 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
        compiled_args=compiled_args,
        apply_with_saved_before=apply_with_saved_before,
        apply_with_saved_after=apply_with_saved_after,
        get_state=get_state,
    )

@@ -26,6 +26,7 @@ from torch.fx.experimental.proxy_tensor import (
    PythonKeyTracer,
    track_tensor_tree,
)
import torch.utils._pytree as pytree
from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
from torch.fx.traceback import preserve_node_meta, set_stack_trace
from torch.utils._traceback import CapturedTraceback

@@ -54,6 +55,35 @@ def maybe_clone(x):
        return clone_preserve_strides(x)
    return x

counter = 0

class OpNamespace:
    def __init__(self):
        self.next_id = {}

    def add(self, base_name, fn):
        if base_name not in self.next_id:
            self.next_id[base_name] = 0
        nid = self.next_id[base_name]
        name = f"{base_name}_{nid}"
        self.next_id[base_name] += 1
        result = Op(name, fn)
        torch._dynamo.allow_in_graph(result)
        setattr(self, name, result)
        return result

class Op:
    def __init__(self, name, fn):
        self.fn = fn
        self.__name__ = name
        self.__module__ = "torch._dynamo.compiled_autograd.ops"

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


ops = OpNamespace()


class AutogradCompilerInstance:
    def __init__(self, compiler_fn) -> None:

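Note: each call to `ops.add` mints a uniquely numbered callable that Dynamo is allowed to keep in the graph. A standalone usage sketch of the naming scheme, with the `torch._dynamo.allow_in_graph` registration stubbed out since it needs torch:

# Standalone illustration of OpNamespace's naming scheme; the real class also
# registers each Op with torch._dynamo.allow_in_graph (omitted here).
class Op:
    def __init__(self, name, fn):
        self.fn, self.__name__ = fn, name

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

class OpNamespace:
    def __init__(self):
        self.next_id = {}

    def add(self, base_name, fn):
        nid = self.next_id.setdefault(base_name, 0)
        self.next_id[base_name] = nid + 1
        result = Op(f"{base_name}_{nid}", fn)
        setattr(self, result.__name__, result)
        return result

ops = OpNamespace()
a = ops.add("MulBackward0", lambda x: x * 2)
b = ops.add("MulBackward0", lambda x: x * 3)
assert (a.__name__, b.__name__) == ("MulBackward0_0", "MulBackward0_1")
assert ops.MulBackward0_1(2) == 6
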
@@ -70,6 +100,7 @@ class AutogradCompilerInstance:
        self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic")
        self.hooks_proxy: Optional[Proxy] = None
        self.graph_placeholders = ["inputs", "sizes", "scalars", "hooks"]
        self.old_inline_behavior = True

    def wrap_fake(self, x, source):
        assert isinstance(x, torch.Tensor)

@@ -187,6 +218,55 @@ class AutogradCompilerInstance:
        self.bind_tensors_to_proxies(grad_ins, proxies)
        return tuple(grad_ins)

    def allocate_dummy(self, *examples):
        with disable_proxy_modes_tracing():
            return torch.zeros(0)

    def apply_functional(self, fn, inputs, stack, num_outputs, debug_name):
        if self.old_inline_behavior:
            result = fn(inputs, *stack)
            return result
        # TODO: if the node is a python autograd.Function or a CompiledFunctionBackward,
        # we should probably "plop" the subgraph into the graph instead
        # of allow_in_graph the node through Dynamo.
        proxy_inputs, proxy_stack = pytree.tree_map(lambda t: self.to_proxy(t) if isinstance(t, torch.Tensor) else t, (inputs, stack))
        op = ops.add(debug_name, fn)
        proxy_out = self.fx_tracer.create_proxy(
            "call_function",
            op,
            args=(proxy_inputs, *proxy_stack),
            kwargs={})
        result = [self.allocate_dummy(*inputs, *stack) for _ in range(num_outputs)]
        self.bind_tensors_to_proxies(result, [proxy_out[i] for i in range(num_outputs)])
        return result

    def validate_outputs(self, fn, outputs, stack, _0, _1):
        if self.old_inline_behavior:
            return fn(outputs, *stack)
        proxy_outputs, proxy_stack = pytree.tree_map(lambda t: self.to_proxy(t) if isinstance(t, torch.Tensor) else t, (outputs, stack))
        op = ops.add("validate_outputs", fn)
        new_proxy_outputs = self.fx_tracer.create_proxy(
            "call_function",
            op,
            args=(proxy_outputs, *proxy_stack),
            kwargs={})
        self.bind_tensors_to_proxies(outputs, new_proxy_outputs)
        return outputs

    def accumulate(self, old_var, new_var):
        if self.old_inline_behavior:
            return torch.add(old_var, new_var)
        old_var_proxy = self.to_proxy(old_var)
        new_var_proxy = self.to_proxy(new_var)
        proxy_out = self.fx_tracer.create_proxy(
            "call_function",
            torch.add,
            args=(old_var_proxy, new_var_proxy),
            kwargs={})
        result = self.allocate_dummy(old_var)
        self.bind_tensors_to_proxies([result], [proxy_out])
        return result

    def proxy_call_hook(self, hook, *args, **kwargs):
        return self.fx_tracer.create_proxy(
            "call_function",

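Note: in the new path, `apply_functional` does not run the node's math during capture. It records a `call_function` node and hands back zero-size dummy tensors whose identities are bound to the proxy outputs. A toy tracer in plain Python (torch-free, all names hypothetical) showing that record-don't-execute shape:

# Toy record-don't-execute tracer: instead of running fn, append a graph node
# and return placeholder results bound to it. Conceptual sketch only.
class ToyTracer:
    def __init__(self):
        self.graph = []       # recorded (name, args) nodes
        self.bindings = {}    # placeholder id -> (node index, output index)

    def apply_functional(self, name, args, num_outputs):
        self.graph.append((name, args))
        node_idx = len(self.graph) - 1
        results = [object() for _ in range(num_outputs)]  # dummy outputs
        for out_idx, r in enumerate(results):
            self.bindings[id(r)] = (node_idx, out_idx)
        return results

tracer = ToyTracer()
outs = tracer.apply_functional("MulBackward0_0", ("grad",), num_outputs=2)
assert tracer.graph == [("MulBackward0_0", ("grad",))]
assert tracer.bindings[id(outs[0])] == (0, 0)
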
@@ -710,8 +790,6 @@ class AutogradCompilerInstance:
            return [self.to_proxy(x) for x in t]
        if isinstance(t, tuple):
            return tuple(self.to_proxy(x) for x in t)
        # can it be torch.SymInt as the code used to imply?
        assert isinstance(t, torch.Tensor)
        proxy_tensor = fetch_object_proxy(self.fx_tracer, t)
        assert isinstance(proxy_tensor, torch.fx.experimental.proxy_tensor._ProxyTensor)
        return proxy_tensor.proxy

@@ -3273,6 +3273,7 @@ if torch.distributed.is_available():
MOD_INLINELIST = [
    "torch._decomp",
    "torch._dynamo._trace_wrapped_higher_order_op",
    "torch._dynamo.compiled_autograd.ops",
    "torch._dynamo.comptime",
    "torch._dynamo.polyfills",
    "torch._functorch.autograd_function",

@@ -530,9 +530,16 @@ variable_list AutogradContext::get_saved_variables() const {
  variable_list saved;
  saved.reserve(saved_variables_.size());
  auto ptr = grad_fn_.lock();
  TORCH_INTERNAL_ASSERT(ptr);
  for (auto& var : saved_variables_) {
    saved.push_back(var.unpack(ptr));
  // TORCH_INTERNAL_ASSERT(ptr);
  // TODO(rzou): hacky, can do this in a more legit way
  if (ptr) {
    for (auto& var : saved_variables_) {
      saved.push_back(var.unpack(ptr));
    }
  } else {
    for (auto& var : saved_variables_) {
      saved.push_back(var.unpack());
    }
  }
  return saved;
}

@@ -543,6 +550,7 @@ bool AutogradContext::needs_input_grad(size_t output_edge_index) const {
  return ptr->task_should_compute_output(output_edge_index);
}

// TODO(rzou): might segfault, need to make this functional
bool AutogradContext::needs_input_grad(
    std::initializer_list<IndexRange> idxs) const {
  auto ptr = grad_fn_.lock();

@@ -241,6 +241,111 @@ struct CppNode : public Node {
    saved.after(output_info_);
    return results;
  }

  functional_apply_t get_functional() override {
    auto name = this->name();

    // TODO(rzou): probably need to pre compute needs_input_grad
    return [name](const variable_list& inputs, const std::vector<c10::IValue>& saved) {
      SavedState state;
      state.stack = saved;
      auto ctx = AutogradContext();
      std::vector<VariableInfo> output_info;
      std::vector<bool> is_variable_input;
      state.dequeue(ctx.saved_data);
      state.dequeue(ctx.saved_variables_);
      state.dequeue(ctx.materialize_grads_);
      state.dequeue(output_info);
      state.dequeue(is_variable_input);

      // TODO(rzou): refactor to share code with CppNode<T>::apply
      at::OptionalDeviceGuard _device_guard;
      auto num_inputs = inputs.size();
      variable_list backward_inputs;
      backward_inputs.reserve(num_inputs);
      for (const auto i : c10::irange(num_inputs)) {
        if (inputs[i].defined() || !ctx.materialize_grads_) {
          backward_inputs.emplace_back(inputs[i]);
        } else {
          backward_inputs.emplace_back(output_info[i].zeros(_device_guard));
        }
      }

      auto outputs = T::backward(&ctx, backward_inputs);

      const auto num_forward_inputs =
          static_cast<int64_t>(is_variable_input.size());
      auto num_outputs = static_cast<int64_t>(outputs.size());
      // Returning too many results is ok, but only as long as they're all
      // undefined. Truncate the result vector in that case.
      if (num_outputs > num_forward_inputs) {
        bool all_undef = true;
        for (const auto i : c10::irange(num_forward_inputs, num_outputs)) {
          all_undef &= (!outputs[i].defined());
        }
        if (all_undef) {
          outputs.resize(num_forward_inputs);
          num_outputs = num_forward_inputs;
        }
      }

      if (num_outputs != num_forward_inputs) {
        std::string msg("function ");
        msg += name + " returned an incorrect number of gradients (expected ";
        msg += std::to_string(num_forward_inputs) + ", got ";
        msg += std::to_string(num_outputs) + ")";
        throw std::runtime_error(msg);
      }

      variable_list results;
      results.reserve(num_outputs);
      for (const auto i : c10::irange(num_outputs)) {
        if (!is_variable_input[i]) {
          if (outputs[i].defined()) {
            std::string msg("function ");
            msg += name +
                " returned a gradient different that is defined at position ";
            msg += std::to_string(i + 1) +
                ", but the corresponding forward input was not a Variable";
            throw std::runtime_error(msg);
          }
          continue;
        }
        results.emplace_back(outputs[i]);
      }
      return results;
    };
  }

  ivalue_list retrieve_saved(SwapSavedVariables& saved) override {
    saved.before(ctx_.saved_data);
    TORCH_INTERNAL_ASSERT(ctx_.non_differentiable_.empty());
    TORCH_INTERNAL_ASSERT(ctx_.dirty_inputs_.empty());
    saved.before(ctx_.saved_variables_);
    TORCH_INTERNAL_ASSERT(ctx_.to_save_.empty());
    saved.before(ctx_.materialize_grads_);
    saved.before(ctx_.has_freed_buffers_);
    saved.before(input_info_);
    saved.before(output_info_);

    SavedState state;
    state.enqueue(ctx_.saved_data);
    state.enqueue(ctx_.saved_variables_, shared_from_this());
    state.enqueue(ctx_.materialize_grads_);
    state.enqueue(output_info_);
    state.enqueue(is_variable_input_);

    saved.after(ctx_.saved_data);
    TORCH_INTERNAL_ASSERT(ctx_.non_differentiable_.empty());
    TORCH_INTERNAL_ASSERT(ctx_.dirty_inputs_.empty());
    saved.after(ctx_.saved_variables_);
    TORCH_INTERNAL_ASSERT(ctx_.to_save_.empty());
    saved.after(ctx_.materialize_grads_);
    saved.after(ctx_.has_freed_buffers_);
    saved.after(input_info_);
    saved.after(output_info_);

    return state.stack;
  }
};

struct ExtractVariables : IterArgs<ExtractVariables> {

@@ -859,18 +859,38 @@ void validate_outputs(
    const edge_list& edges,
    variable_list& grads,
    const std::function<std::string(const std::string&)>& format_error) {
  if (grads.size() != edges.size()) {
  // TODO(rzou): probably too many heap allocations here...
  auto input_metadata = collect_input_metadata(edges);
  validate_outputs(input_metadata, grads, format_error);
}

std::vector<c10::optional<InputMetadata>> collect_input_metadata(const edge_list& edges) {
  std::vector<c10::optional<InputMetadata>> input_metadata;
  for (const auto& edge : edges) {
    if (!edge.is_valid()) {
      input_metadata.emplace_back(c10::nullopt);
      continue;
    }
    input_metadata.emplace_back(edge.function->input_metadata(edge.input_nr));
  }
  return input_metadata;
}

void validate_outputs(
    const std::vector<c10::optional<InputMetadata>>& input_metadata,
    variable_list& grads,
    const std::function<std::string(const std::string&)>& format_error) {
  if (grads.size() != input_metadata.size()) {
    std::stringstream ss;
    ss << "invalid number of gradients - expected ";
    ss << edges.size() << ", but got " << grads.size();
    ss << input_metadata.size() << ", but got " << grads.size();
    TORCH_CHECK(false, format_error(ss.str()));
  }
  for (const auto i : c10::irange(grads.size())) {
    const auto& edge = edges[i];
    if (!edge.is_valid())
    if (!input_metadata[i].has_value()) {
      continue;

    const auto& metadata = edge.function->input_metadata(edge.input_nr);
    }
    const auto& metadata = input_metadata[i].value();
    auto& grad = grads[i];
    if (!grad.defined()) {
      // FIXME: TestJit.test_ge_optimized fails this assertion.

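Note: the refactor splits edge-walking from checking. Metadata is snapshotted once by `collect_input_metadata`, and validation runs against the snapshot alone, so it no longer needs live graph edges, which is what makes it callable from the functional path. A hedged plain-Python sketch of the same split (names and fields are illustrative):

# Illustrative split of validate_outputs: collect a metadata snapshot first,
# then validate gradients against the snapshot alone.
from dataclasses import dataclass
from typing import Optional

@dataclass
class InputMetadata:
    shape: tuple

def collect_input_metadata(edges) -> list[Optional[InputMetadata]]:
    return [InputMetadata(e["shape"]) if e.get("valid") else None for e in edges]

def validate_outputs(metadata, grads):
    if len(grads) != len(metadata):
        raise ValueError(
            f"invalid number of gradients - expected {len(metadata)}, but got {len(grads)}")
    for meta, grad in zip(metadata, grads):
        if meta is None or grad is None:
            continue  # invalid edge or undefined grad: nothing to check
        if grad.shape != meta.shape:
            raise ValueError(f"grad shape {grad.shape} does not match {meta.shape}")

class G:  # stand-in gradient carrying only a shape
    def __init__(self, shape):
        self.shape = shape

edges = [{"valid": True, "shape": (2, 3)}, {"valid": False}]
snapshot = collect_input_metadata(edges)   # snapshot can outlive the edges
validate_outputs(snapshot, [G((2, 3)), None])
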
@@ -43,6 +43,11 @@ TORCH_API void validate_outputs(
    const edge_list& edges,
    variable_list& grads,
    const std::function<std::string(const std::string&)>& format_error);
TORCH_API void validate_outputs(
    const std::vector<c10::optional<InputMetadata>>& input_metadata,
    variable_list& grads,
    const std::function<std::string(const std::string&)>& format_error);
TORCH_API std::vector<c10::optional<InputMetadata>> collect_input_metadata(const edge_list& edges);

struct NodeTask {
  std::weak_ptr<GraphTask> base_;

@@ -34,8 +34,12 @@ using tensor_list = std::vector<at::Tensor>;
using variable_list = std::vector<Variable>;
using edge_list = std::vector<Edge>;
using saved_variable_list = std::vector<SavedVariable>;
using ivalue_list = std::vector<c10::IValue>;
using functional_apply_t = std::function<
    variable_list(const variable_list&, const std::vector<c10::IValue>&)>;
using IndexRange = std::pair<size_t, size_t>;
using torch::dynamo::autograd::CompiledNodeArgs;
using torch::dynamo::autograd::SavedState;
using torch::dynamo::autograd::SwapSavedVariables;

// Custom deleter to prevent stack overflows.

@@ -604,6 +608,17 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
        std::string("apply_with_saved not implemented: ") + name());
  }

  virtual ivalue_list retrieve_saved(SwapSavedVariables& saved) {
    throw std::runtime_error(
        std::string("retrieve_saved not implemented: ") + name());
  }
  virtual std::function<
      variable_list(const variable_list&, const std::vector<c10::IValue>&)>
  get_functional() {
    throw std::runtime_error(
        std::string("get_functional not implemented: ") + name());
  }

 protected:
  /// Performs the `Node`'s actual operation.
  virtual variable_list apply(variable_list&& inputs) = 0;

@@ -8,6 +8,7 @@
namespace torch::dynamo::autograd {
class CompiledNodeArgs;
class SwapSavedVariables;
struct SavedState;
} // namespace torch::dynamo::autograd

// A hook that's called on gradients

@@ -103,4 +103,40 @@ variable_list AccumulateGrad::apply_with_saved(
  return variable_list();
}

ivalue_list AccumulateGrad::retrieve_saved(SwapSavedVariables& saved) {
  auto should_visit = variable.defined() && variable.requires_grad();
  if (should_visit) {
    saved.before(variable);
  }

  SavedState state;
  state.enqueue(variable);

  if (should_visit) {
    saved.after(variable);
  }

  return state.stack;
}

functional_apply_t AccumulateGrad::get_functional() {
  return [](const variable_list& inputs,
            const std::vector<c10::IValue>& saved) -> variable_list {
    SavedState state;
    state.stack = saved;
    Variable foo;
    state.dequeue(foo);
    if (!(foo.defined() && foo.requires_grad()) || !inputs[0].defined()) {
      return variable_list();
    }
    // op is intentionally static
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("inductor::accumulate_grad_", "")
                         .typed<void(const at::Tensor&, const at::Tensor&)>();
    op.call(foo, inputs[0]);
    // TODO(rzou): tensor_post_acc_grad_hooks
    return variable_list();
  };
}

} // namespace torch::autograd

@@ -267,6 +267,9 @@ struct TORCH_API AccumulateGrad : public Node {
      const variable_list& inputs,
      SwapSavedVariables& saved) override;

  ivalue_list retrieve_saved(SwapSavedVariables& saved) override;
  functional_apply_t get_functional() override;

  Variable variable;
};

@@ -77,5 +77,22 @@ variable_list GraphRoot::apply_with_saved(
  saved.after(outputs);
  return result;
}
ivalue_list GraphRoot::retrieve_saved(SwapSavedVariables& saved) {
  saved.before(outputs);
  SavedState state;
  state.enqueue(outputs);
  saved.after(outputs);
  return state.stack;
}
functional_apply_t GraphRoot::get_functional() {
  return [](const variable_list& inputs,
            const std::vector<c10::IValue>& saved) -> variable_list {
    SavedState state;
    state.stack = saved;
    variable_list outputs;
    state.dequeue(outputs);
    return outputs;
  };
}

} // namespace torch::autograd

@@ -97,6 +97,8 @@ struct TORCH_API GraphRoot : public Node {
  variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) override;
  ivalue_list retrieve_saved(SwapSavedVariables& saved) override;
  functional_apply_t get_functional() override;

  variable_list outputs;
};

@@ -103,7 +103,7 @@ struct TORCH_API InputMetadata {
  bool maybe_expandable_to(const at::Tensor& grad) const;

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const at::TensorOptions options_;
  at::TensorOptions options_;
  MetadataShape shape_;
  c10::Stream stream_ = c10::Stream(c10::Stream::Default::DEFAULT, device());
  bool is_tensor_subclass_ = false;

@@ -25,6 +25,7 @@
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/dynamo/compiled_autograd.h>
#include <torch/csrc/dynamo/python_compiled_autograd.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/python/pybind_utils.h>

@@ -396,6 +397,99 @@ variable_list PyNode::apply_with_saved(
  return result;
}

ivalue_list PyNode::retrieve_saved(SwapSavedVariables& saved) {
  auto f = (THPFunction*)obj;
  saved.before(f->compiled_autograd_symints);
  saved.before(f->saved_variables);
  saved.before(f->needs_input_grad);
  saved.before(f->materialize_non_diff_grads);
  saved.before(f->output_info);
  saved.before(f->input_info);

  SavedState state;
  state.enqueue(f->compiled_autograd_symints);
  state.enqueue(f->saved_variables, shared_from_this());
  // state.enqueue(f->needs_input_grad);
  // state.enqueue(f->materialize_non_diff_grads);
  // state.enqueue(f->output_info);
  // state.enqueue(f->input_info);

  saved.after(f->compiled_autograd_symints);
  saved.after(f->saved_variables);
  saved.after(f->needs_input_grad);
  saved.after(f->materialize_non_diff_grads);
  saved.after(f->output_info);
  saved.after(f->input_info);

  state.enqueue(f->compiled_autograd_symints);
  state.enqueue(f->saved_variables, shared_from_this());
  // state.enqueue(f->needs_input_grad);
  // state.enqueue(f->materialize_non_diff_grads);
  // state.enqueue(f->output_info);
  // state.enqueue(f->input_info);

  return state.stack;
}

// TODO(rzou): compiled autograd needs special handling of the following.
std::function<
    variable_list(const variable_list&, const std::vector<c10::IValue>&)>
PyNode::get_functional() {
  auto node = std::static_pointer_cast<PyNode>(shared_from_this());
  // TODO(rzou): probably need to pre compute needs_input_grad
  return
      [node](
          const variable_list& inputs, const std::vector<c10::IValue>& saved) {
        SavedState state;
        state.stack = saved;

        auto f = (THPFunction*)node->obj;

        state.dequeue(f->compiled_autograd_symints);
        state.dequeue(f->saved_variables);
        // state.dequeue(f->needs_input_grad);
        // state.dequeue(f->materialize_non_diff_grads);
        // state.dequeue(f->output_info);
        // state.dequeue(f->input_info);

        f->compiled_autograd_tracing = true;
        variable_list result;
        if (!node->compiled_autograd_should_lift()) {
          if (node->_backward_state_idx.has_value()) {
            PyObject* r = PyObject_CallMethod(
                torch::dynamo::autograd::current_py_compiler(),
                "bind_backward_state",
                "i",
                *node->_backward_state_idx);
            if (r == nullptr) {
              throw python_error();
            }
            THPObjectPtr prior(f->compiled_autograd_backward_state);
            f->compiled_autograd_backward_state = r;
            result = node->apply(variable_list(inputs));
            Py_CLEAR(f->compiled_autograd_backward_state);
            f->compiled_autograd_backward_state = prior.release();
          } else {
            result = node->apply(variable_list(inputs));
          }
        } else {
          result = node->defer_to_dynamo(
              variable_list(inputs),
              torch::dynamo::autograd::current_py_compiler());
        }
        f->compiled_autograd_tracing = false;

        state.dequeue(f->compiled_autograd_symints);
        state.dequeue(f->saved_variables);
        // state.dequeue(f->needs_input_grad);
        // state.dequeue(f->materialize_non_diff_grads);
        // state.dequeue(f->output_info);
        // state.dequeue(f->input_info);

        return result;
      };
}

PyObject* PyNode::to_py_args(
    const variable_list& inputs,
    at::OptionalDeviceGuard* device_guard) {

@@ -70,6 +70,11 @@ struct PyNode : public Node {
      Py_DECREF(obj);
    }
  }

  std::function<
      variable_list(const variable_list&, const std::vector<c10::IValue>&)>
  get_functional() override;
  ivalue_list retrieve_saved(SwapSavedVariables& saved) override;
};

/**

@@ -898,6 +898,321 @@ class SwapSavedVariables {
  StashedVars<at::IValue> stashed_ivalues;
};

struct SavedState {
  std::vector<at::IValue> stack;
  int64_t idx = 0;

  void enqueue(
      const SavedVariable& sv,
      const std::shared_ptr<Node>& saved_for) {
    stack.emplace_back(sv.unpack(saved_for));
  }
  void dequeue(SavedVariable& sv) {
    sv = SavedVariable(stack[idx++].toTensor(), /*is_output*/ true);
  }

  void enqueue(
      const std::vector<SavedVariable>& sv,
      const std::shared_ptr<Node>& saved_for) {
    enqueue(static_cast<int64_t>(sv.size()));
    for (const auto& v : sv) {
      enqueue(v, saved_for);
    }
  }
  void dequeue(std::vector<SavedVariable>& sv) {
    int64_t size = 0;
    dequeue(size);
    sv.clear();
    for (int64_t idx = 0; idx < size; idx++) {
      sv.emplace_back();
      dequeue(sv.back());
    }
  }

  /*
  void enqueue(const PyObject*& t) {
    enqueue_ivalue(t);
  }
  void dequeue(PyObject*& t) {
    dequeue_ivalue(t);
  }
  */

  void enqueue(const VariableInfo& t) {
    enqueue(t.layout);
    enqueue(t.device);
    enqueue(t.scalar_type);
    enqueue(t.size);
    enqueue(t.requires_grad);
    enqueue(t.is_empty);
  }
  void dequeue(VariableInfo& t) {
    dequeue(t.layout);
    dequeue(t.device);
    dequeue(t.scalar_type);
    dequeue(t.size);
    dequeue(t.requires_grad);
    dequeue(t.is_empty);
  }

  void enqueue(size_t t) {
    enqueue(static_cast<int64_t>(t));
  }
  void dequeue(size_t& t) {
    int64_t tmp = 0;
    dequeue(tmp);
    t = static_cast<size_t>(tmp);
  }

  // TODO: probably wildly inefficient
  template <class T>
  void enqueue(const c10::List<T> t) {
    enqueue(t.vec());
  }
  template <class T>
  void dequeue(c10::List<T>& t) {
    std::vector<T> tmp;
    dequeue(tmp);
    t = c10::List<T>(tmp);
  }

  void enqueue(const TypeAndSize& value) {
    enqueue(value.sym_sizes);
    enqueue(value.options);
  }
  void dequeue(TypeAndSize& value) {
    dequeue(value.sym_sizes);
    dequeue(value.options);
  }

  void enqueue(const InputMetadata& value) {
    enqueue(value.options());
    enqueue(value.shape_as_dim_vector().vec());
    enqueue(value.is_tensor_subclass());
    TORCH_INTERNAL_ASSERT(!value.is_nested_tensor());
  }
  // Special case: InputMetadata has no copy ctor
  // TODO(rzou): ??
  void dequeue(InputMetadata& value) {
    at::TensorOptions options;
    dequeue(options);
    std::vector<at::SymInt> shape;
    dequeue(shape);
    bool is_tensor_subclass = false;
    dequeue(is_tensor_subclass);
    SymIntSmallVec sym_shape;
    for (const auto& s : shape) {
      sym_shape.emplace_back(s);
    }
    value = InputMetadata(options, sym_shape, is_tensor_subclass, false);
  }

  void enqueue(const ska::flat_hash_map<std::string, at::IValue>& dct) {
    std::vector<std::string> keys;
    std::vector<at::IValue> values;
    for (const auto& [key, value] : dct) {
      keys.emplace_back(key);
      values.emplace_back(value);
    }
    enqueue(keys);
    enqueue(values);
  }
  void enqueue(const at::IValue& iv) {
    stack.emplace_back(iv);
  }
  void dequeue(at::IValue& iv) {
    iv = stack[idx++];
  }
  void dequeue(ska::flat_hash_map<std::string, at::IValue>& dct) {
    std::vector<std::string> keys;
    std::vector<at::IValue> values;
    dequeue(keys);
    dequeue(values);
    dct.clear();
    for (const auto i : c10::irange(keys.size())) {
      dct.insert({keys[i], values[i]});
    }
  }

  void enqueue(const at::TensorOptions& value) {
    enqueue(value.requires_grad_opt());
    enqueue(value.memory_format_opt());
    enqueue(value.device_opt());
    enqueue(value.dtype_opt());
    enqueue(value.layout_opt());
    enqueue(value.pinned_memory_opt());
  }
  void dequeue(at::TensorOptions& value) {
    auto result = at::TensorOptions();
    c10::optional<bool> requires_grad_opt;
    dequeue(requires_grad_opt);
    if (requires_grad_opt) {
      result = result.requires_grad(*requires_grad_opt);
    }
    c10::optional<c10::MemoryFormat> memory_format_opt;
    dequeue(memory_format_opt);
    if (memory_format_opt) {
      result = result.memory_format(*memory_format_opt);
    }
    c10::optional<c10::Device> device_opt;
    dequeue(device_opt);
    if (device_opt) {
      result = result.device(*device_opt);
    }
    c10::optional<caffe2::TypeMeta> dtype_opt;
    dequeue(dtype_opt);
    if (dtype_opt) {
      result = result.dtype(*dtype_opt);
    }
    c10::optional<c10::Layout> layout_opt;
    dequeue(layout_opt);
    if (layout_opt) {
      result = result.layout(*layout_opt);
    }
    c10::optional<bool> pinned_memory_opt;
    dequeue(pinned_memory_opt);
    if (pinned_memory_opt) {
      result = result.pinned_memory(*pinned_memory_opt);
    }
    value = result;
  }

  void enqueue(const caffe2::TypeMeta& value) {
    enqueue(at::typeMetaToScalarType(value));
  }
  void dequeue(caffe2::TypeMeta& value) {
    at::ScalarType result = at::kFloat;
    dequeue(result);
    value = caffe2::TypeMeta::fromScalarType(result);
  }

  template <typename T>
  void enqueue(const c10::OptionalArray<T>& t) {
    enqueue(t.list);
  }
  template <typename T>
  void dequeue(c10::OptionalArray<T>& t) {
    dequeue(t.list);
  }

  template <typename T>
  void enqueue(const std::optional<T>& t) {
    enqueue(t.has_value());
    if (t.has_value()) {
      enqueue(*t);
    }
  }
  template <typename T>
  void dequeue(c10::optional<T>& value) {
    bool has_value = false;
    dequeue(has_value);
    T tmp;
    if (has_value) {
      dequeue(tmp);
    }
    value = tmp;
  }

  void enqueue(const at::TensorGeometry& t) {
    enqueue(t.sym_sizes().vec());
    enqueue(t.sym_strides().vec());
    enqueue(t.sym_storage_offset());
  }
  void dequeue(at::TensorGeometry& t) {
    std::vector<at::SymInt> sym_sizes;
    std::vector<at::SymInt> sym_strides;
    at::SymInt sym_storage_offset;
    dequeue(sym_sizes);
    dequeue(sym_strides);
    dequeue(sym_storage_offset);
    t = at::TensorGeometry(sym_sizes, sym_strides, sym_storage_offset);
  }

  template <typename T>
  void enqueue(const std::vector<T>& t) {
    enqueue(static_cast<int64_t>(t.size()));
    for (const T& i : t) {
      enqueue(i);
    }
  }
  template <typename T>
  void dequeue(std::vector<T>& t) {
    int64_t size = 0;
    dequeue(size);
    t.clear();
    for (int64_t idx = 0; idx < size; idx++) {
      t.emplace_back();
      dequeue(t.back());
    }
  }

  void enqueue(const c10::SymInt& t) {
    stack.emplace_back(t);
  }
  void dequeue(c10::SymInt& t) {
    t = stack[idx++].toSymInt();
  }

  void enqueue(int64_t t) {
    stack.emplace_back(t);
  }
  void dequeue(int64_t& t) {
    t = stack[idx++].toInt();
  }

  void enqueue(const std::vector<c10::SymInt>& t) {
    enqueue_ivalue(t);
  }
  void dequeue(std::vector<c10::SymInt>& t) {
    t = stack[idx++].toSymIntVector();
  }

  void enqueue(const std::vector<int64_t>& t) {
    enqueue_ivalue(t);
  }
  void dequeue(std::vector<int64_t>& t) {
    t = stack[idx++].toIntVector();
  }

  template <class ivalue_t>
  void enqueue_ivalue(const ivalue_t& t) {
    stack.emplace_back(t);
  }
  template <class ivalue_t>
  void dequeue_ivalue(ivalue_t& value) {
    value = stack[idx++].to<ivalue_t>();
  }
#define HANDLE_IVALUE(ivalue_t)                            \
  void enqueue(const ivalue_t& value) {                    \
    return enqueue_ivalue<ivalue_t>(value);                \
  }                                                        \
  void enqueue(const std::vector<ivalue_t>& value) {       \
    return enqueue_ivalue<std::vector<ivalue_t>>(value);   \
  }                                                        \
  void enqueue(const c10::optional<ivalue_t>& value) {     \
    return enqueue_ivalue<c10::optional<ivalue_t>>(value); \
  }                                                        \
  void dequeue(ivalue_t& value) {                          \
    return dequeue_ivalue<ivalue_t>(value);                \
  }                                                        \
  void dequeue(std::vector<ivalue_t>& value) {             \
    return dequeue_ivalue<std::vector<ivalue_t>>(value);   \
  }                                                        \
  void dequeue(c10::optional<ivalue_t>& value) {           \
    return dequeue_ivalue<c10::optional<ivalue_t>>(value); \
  }
  HANDLE_IVALUE(at::Tensor)
  HANDLE_IVALUE(c10::ScalarType)
  HANDLE_IVALUE(c10::Scalar)
  HANDLE_IVALUE(c10::Layout)
  HANDLE_IVALUE(c10::Device)
  HANDLE_IVALUE(c10::MemoryFormat)
  HANDLE_IVALUE(bool)
  HANDLE_IVALUE(double)
  HANDLE_IVALUE(std::string)
#undef HANDLE_IVALUE
};

} // namespace torch::dynamo::autograd

template <>

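Note: `SavedState` is a flat IValue stack with a read cursor. Each `enqueue` overload flattens a structured value into primitive entries, and the matching `dequeue` must read them back in exactly the same order; optionals, for example, become a has_value flag optionally followed by the payload, and vectors become a length prefix followed by elements. A plain-Python sketch of that protocol, with illustrative names:

# Plain-Python sketch of the SavedState flattening protocol.
class SavedState:
    def __init__(self):
        self.stack, self.idx = [], 0

    def enqueue(self, v):
        self.stack.append(v)

    def dequeue(self):
        v = self.stack[self.idx]
        self.idx += 1
        return v

    def enqueue_optional(self, v):
        self.enqueue(v is not None)   # has_value flag first...
        if v is not None:
            self.enqueue(v)           # ...then the payload, if any

    def dequeue_optional(self):
        has_value = self.dequeue()
        return self.dequeue() if has_value else None

    def enqueue_list(self, xs):
        self.enqueue(len(xs))         # length prefix, then elements
        for x in xs:
            self.enqueue(x)

    def dequeue_list(self):
        return [self.dequeue() for _ in range(self.dequeue())]

s = SavedState()
s.enqueue_optional(None)
s.enqueue_list([1, 2, 3])
assert s.dequeue_optional() is None
assert s.dequeue_list() == [1, 2, 3]
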
@@ -52,6 +52,12 @@ Notes:
namespace torch::dynamo::autograd {
using c10::SymInt;

static PyObject* kPyCompiler;

PyObject* current_py_compiler() {
  return kPyCompiler;
}

static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) {
  PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(inputs.size()));
  for (const auto i : c10::irange(inputs.size())) {

@@ -89,6 +95,23 @@ static void check(bool result) {
    check(nullptr);
}

static variable_list validate_outputs(
    variable_list& outputs,
    const ivalue_list& saved) {
  SavedState r;
  r.stack = saved;
  std::vector<c10::optional<InputMetadata>> value;
  r.dequeue(value);

  torch::autograd::validate_outputs(
      value, outputs, [&](const std::string& msg) {
        std::ostringstream ss;
        ss << "[Compiled Autograd Tracing:]" << msg;
        return ss.str();
      });
  return outputs;
}

// snapshot of python verbose logging toggle
static PyObject* python_verbose_logger = nullptr;

@@ -495,6 +518,91 @@ void set_ivalue_proxies(
  }
}

template <typename Func>
static variable_list call_function(
    PyObject* py_compiler,
    const char* name,
    Func fn,
    const variable_list& inputs,
    const ivalue_list& saved_state,
    int64_t num_outputs,
    const std::string& debug) {
  // Need this to do PyObject* -> IValue conversion
  std::vector<at::TypePtr> schema;
  schema.reserve(saved_state.size());
  for (const auto& ivalue : saved_state) {
    schema.emplace_back(ivalue.type());
  }

  // We are going to bind the following function to Python
  auto py_func = py::cpp_function(
      [schema, fn](
          std::vector<c10::optional<at::Tensor>>& inputs,
          const py::args& args) -> py::object {
        // It reconstructs the saved_state from args via the schema
        std::vector<at::IValue> stack;
        TORCH_INTERNAL_ASSERT(args.size() == schema.size());
        auto tuple_args = jit::tuple_slice(args);
        for (uint64_t idx = 0; idx < schema.size(); idx++) {
          stack.emplace_back(
              jit::toIValue(tuple_args[idx], schema[idx], c10::nullopt));
        }
        std::vector<at::Tensor> inputs_;
        for (const auto& inp : inputs) {
          if (inp.has_value()) {
            inputs_.emplace_back(*inp);
          } else {
            inputs_.emplace_back();
          }
        }
        auto outputs = fn(inputs_, stack);
        return jit::toPyObject(at::IValue(outputs));
      });

  // convert ivalue_list -> PyObject*
  PyObject* py_saved_state =
      PyTuple_New(static_cast<Py_ssize_t>(schema.size()));
  for (const auto i : c10::irange(schema.size())) {
    py::object obj = jit::toPyObject(saved_state[i]);
    Py_INCREF(obj.ptr());
    PyTuple_SET_ITEM(py_saved_state, i, obj.ptr());
  }

  // call the corresponding method on the py_compiler
  // That method will figure out what to do with the function
  // (it can either inline it or plop it straight into the FX graph).
  py::handle handle(py_compiler);
  py::object stuff = handle.attr(name)(
      py_func, inputs, py::handle(py_saved_state), num_outputs, debug);

  // Convert the output from PyObject* to vector<Tensor>
  auto tmp = py::cast<std::vector<std::optional<at::Tensor>>>(stuff);
  variable_list outputs;
  for (const auto& t : tmp) {
    if (t.has_value()) {
      outputs.emplace_back(t.value());
    } else {
      outputs.emplace_back();
    }
  }
  return outputs;
}

static at::Tensor call_accumulate(
    PyObject* py_compiler,
    const at::Tensor& old_var,
    const at::Tensor& new_var) {
  if (!old_var.defined()) {
    return new_var;
  }
  if (!new_var.defined()) {
    return old_var;
  }
  py::handle handle(py_compiler);
  py::object stuff = handle.attr("accumulate")(old_var, new_var);
  return py::cast<at::Tensor>(stuff);
}

static TraceState call_begin_capture(
    PyObject* self,
    CacheNode& cache,

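Note: from the Python side, the `py_compiler` object only has to expose methods that accept `(fn, inputs, saved_state, num_outputs, debug_name)` and return a list of optional tensors. A minimal stub satisfying that contract, inferred from the call site above rather than taken from the real class:

# Hypothetical stub of the py_compiler contract assumed by call_function.
class EagerCompilerStub:
    def apply_functional(self, fn, inputs, saved_state, num_outputs, debug_name):
        # "Inline" strategy: just run the bound function immediately.
        return fn(list(inputs), *saved_state)

    def validate_outputs(self, fn, outputs, saved_state, num_outputs, debug_name):
        return fn(list(outputs), *saved_state)

    def accumulate(self, old_var, new_var):
        return old_var + new_var
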
@@ -648,6 +756,7 @@ CacheNode* _compiled_autograd_impl(
  // cache miss, need to capture FX graph
  ClosingTHPObjectPtr py_compiler(
      check(PyObject_CallNoArgs((the_autograd_compiler))));
  kPyCompiler = py_compiler.get();

  TraceState state = call_begin_capture(
      py_compiler, *cache, compiler_call, output_edges.size());

@@ -714,17 +823,42 @@ CacheNode* _compiled_autograd_impl(
      }

      SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call);
      variable_list outputs = call.node->apply_with_saved(inputs, saved);
      auto saved_state = call.node->retrieve_saved(saved);
      // std::cout << call.node->name() << std::endl;
      // std::cout << saved_state.size() << std::endl;
      // for (const auto& ivalue: saved_state) {
      //   if (ivalue.isTensor()) {
      //     std::cout << "tensor" << std::endl;
      //   } else {
      //     ivalue.dump();
      //   }
      // }
      auto outputs = call_function(
          py_compiler,
          "apply_functional",
          call.node->get_functional(),
          inputs,
          saved_state,
          call.node->num_outputs(),
          call.node->name());
      // variable_list outputs = call.node->apply_with_saved(inputs, saved);

      saved.debug_asserts();
      saved.before(call.node->next_edges());
      validate_outputs(
          call.node->next_edges(), outputs, [&](const std::string& msg) {
            std::ostringstream ss;
            ss << "[Compiled Autograd Tracing: " << call.node->name() << "] "
               << msg;
            return ss.str();
          });

      auto input_metadata = collect_input_metadata(call.node->next_edges());
      SavedState state;
      state.enqueue(input_metadata);
      ivalue_list& input_metadata_state = state.stack;
      outputs = call_function(
          py_compiler,
          "validate_outputs",
          validate_outputs,
          outputs,
          input_metadata_state,
          outputs.size(),
          "validate_outputs");

      saved.after(call.node->next_edges());
      saved.debug_asserts();

@@ -746,13 +880,14 @@ CacheNode* _compiled_autograd_impl(
        auto& output = outputs[i];
        const auto& next = call.node->next_edge(i);
        if (next.is_valid() && output.defined()) {
          input_buffers.lookup(next.function.get())
              .add(
                  next.input_nr, std::move(output), std::nullopt, std::nullopt);
          auto& buffer = input_buffers.lookup(next.function.get());
          buffer.buffer[next.input_nr] = call_accumulate(
              py_compiler, buffer.buffer[next.input_nr], output);
        }
      }
    }

    kPyCompiler = nullptr;
    PyObject* res = check(call_end_capture(py_compiler, state.outputs));
    TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple");
    TORCH_CHECK(

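Note: this hunk makes gradient accumulation visible to the compiler. Instead of an opaque `InputBuffer::add`, each buffer slot is combined through `call_accumulate`, so the add shows up in the traced graph. A small plain-Python sketch of that buffer-slot accumulation (illustrative names):

# Illustrative sketch of compiler-visible accumulation into input buffers.
def call_accumulate(old, new):
    if old is None:
        return new
    if new is None:
        return old
    return old + new          # in the real flow this add is traced, not eager

buffers = {}                  # node -> list of per-input gradient slots

def accumulate_into(node, input_nr, output, num_inputs):
    slots = buffers.setdefault(node, [None] * num_inputs)
    slots[input_nr] = call_accumulate(slots[input_nr], output)

accumulate_into("AddBackward0", 0, 1.5, num_inputs=2)
accumulate_into("AddBackward0", 0, 2.5, num_inputs=2)
assert buffers["AddBackward0"][0] == 4.0
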
@@ -4,4 +4,5 @@
// see [Note: Compiled Autograd]
namespace torch::dynamo::autograd {
PyObject* torch_c_dynamo_compiled_autograd_init();
PyObject* current_py_compiler();
} // namespace torch::dynamo::autograd

@@ -369,8 +369,18 @@ IValue toIValue(py::handle obj, const TypePtr& type, std::optional<int32_t> N) {
    }
    case TypeKind::BoolType:
      return IValue(py::cast<std::vector<bool>>(obj));
    case TypeKind::TensorType:
      return IValue(py::cast<std::vector<at::Tensor>>(obj));
    case TypeKind::TensorType: {
      auto thing = py::cast<std::vector<std::optional<at::Tensor>>>(obj);
      auto thing2 = std::vector<at::Tensor>();
      for (const auto& inp : thing) {
        if (inp.has_value()) {
          thing2.emplace_back(*inp);
        } else {
          thing2.emplace_back();
        }
      }
      return IValue(thing2);
    }
    default:
      return createGenericList(obj, elem_type);
  }