Revert "Reify view_func() closures as ViewFuncs (#118404)"
This reverts commit d5a6762263a98e5153bc057c8ba4f377542c7e55. Reverted https://github.com/pytorch/pytorch/pull/118404 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](https://github.com/pytorch/pytorch/pull/118404#issuecomment-1938600260))
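For context, the PR being reverted replaced the lambda-based view replay closures stored on autograd view metadata with reified `ViewFunc` objects whose saved SymInt / tensor state can be inspected and swapped; this revert restores the `std::function` form throughout the diff below. A rough, hedged C++ sketch of the two shapes follows — it is simplified from the declarations visible further down in this diff, assumes a libtorch build (ATen headers), and uses the placeholder names `ViewFn` / `ViewFuncSketch` rather than the exact upstream definitions:

```cpp
// Sketch only: contrasts the two representations touched by this revert.
// Assumes ATen/c10 headers from a libtorch build; not the exact upstream code.
#include <ATen/ATen.h>
#include <functional>
#include <memory>
#include <optional>
#include <vector>

// Form restored by this revert: view replay saved as an opaque closure.
using ViewFn = std::function<at::Tensor(const at::Tensor&)>;

// Form removed by this revert: a reified object whose saved SymInt / tensor
// state can be read and replaced (e.g. for fake-ification) before replay.
struct ViewFuncSketch {
  virtual ~ViewFuncSketch() = default;
  virtual std::vector<c10::SymInt> get_symints() const { return {}; }
  virtual std::vector<at::Tensor> get_tensors() const { return {}; }
  virtual at::Tensor operator()(const at::Tensor& new_base) const = 0;
  virtual std::unique_ptr<ViewFuncSketch> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = std::nullopt,
      std::optional<std::vector<at::Tensor>> = std::nullopt) const = 0;
};
```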
@@ -960,7 +960,6 @@ def define_buck_targets(
"Functions.h": ":gen_aten_libtorch[autograd/generated/Functions.h]",
"VariableType.h": ":gen_aten_libtorch[autograd/generated/VariableType.h]",
"variable_factories.h": ":gen_aten_libtorch[autograd/generated/variable_factories.h]",
"ViewFuncs.h": ":gen_aten_libtorch[autograd/generated/ViewFuncs.h]",
# Don't build python bindings on mobile.
#"python_functions.h",
},
@@ -1467,7 +1466,6 @@ def define_buck_targets(
"torch/csrc/jit/mobile/train/random.cpp",
"torch/csrc/jit/mobile/train/sequential.cpp",
":gen_aten_libtorch[autograd/generated/Functions.cpp]",
":gen_aten_libtorch[autograd/generated/ViewFuncs.cpp]",
],
compiler_flags = get_pt_compiler_flags(),
exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DUSE_MOBILE_CLASSTYPE"],
@@ -261,7 +261,6 @@ _GENERATED_AUTOGRAD_PYTHON_HEADERS = [
_GENERATED_AUTOGRAD_CPP_HEADERS = [
"torch/csrc/autograd/generated/Functions.h",
"torch/csrc/autograd/generated/VariableType.h",
"torch/csrc/autograd/generated/ViewFuncs.h",
"torch/csrc/autograd/generated/variable_factories.h",
]

@@ -304,7 +303,6 @@ GENERATED_AUTOGRAD_CPP = [
"torch/csrc/autograd/generated/VariableType_2.cpp",
"torch/csrc/autograd/generated/VariableType_3.cpp",
"torch/csrc/autograd/generated/VariableType_4.cpp",
"torch/csrc/autograd/generated/ViewFuncs.cpp",
"torch/csrc/autograd/generated/TraceType_0.cpp",
"torch/csrc/autograd/generated/TraceType_1.cpp",
"torch/csrc/autograd/generated/TraceType_2.cpp",
@@ -24,7 +24,6 @@ def libtorch_generated_sources(gencode_pattern):
"torch/csrc/autograd/generated/VariableType_2.cpp",
"torch/csrc/autograd/generated/VariableType_3.cpp",
"torch/csrc/autograd/generated/VariableType_4.cpp",
"torch/csrc/autograd/generated/ViewFuncs.cpp",
"torch/csrc/autograd/generated/TraceType_0.cpp",
"torch/csrc/autograd/generated/TraceType_1.cpp",
"torch/csrc/autograd/generated/TraceType_2.cpp",
@@ -351,7 +351,6 @@ configure_file("${TORCH_SRC_DIR}/csrc/api/include/torch/version.h.in"

set(GENERATED_CXX_TORCH
"${TORCH_SRC_DIR}/csrc/autograd/generated/Functions.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/ViewFuncs.cpp"
)

if(NOT INTERN_DISABLE_AUTOGRAD AND NOT BUILD_LITE_INTERPRETER)
@@ -381,7 +380,6 @@ endif()
set(GENERATED_H_TORCH
"${TORCH_SRC_DIR}/csrc/autograd/generated/Functions.h"
"${TORCH_SRC_DIR}/csrc/autograd/generated/variable_factories.h"
"${TORCH_SRC_DIR}/csrc/autograd/generated/ViewFuncs.h"
)

if(NOT INTERN_DISABLE_AUTOGRAD)
@@ -131,8 +131,6 @@ def get_generate_code_bin_outs():
"autograd/generated/VariableType_3.cpp": ["autograd/generated/VariableType_3.cpp"],
"autograd/generated/VariableType_4.cpp": ["autograd/generated/VariableType_4.cpp"],
"autograd/generated/variable_factories.h": ["autograd/generated/variable_factories.h"],
"autograd/generated/ViewFuncs.cpp": ["autograd/generated/ViewFuncs.cpp"],
"autograd/generated/ViewFuncs.h": ["autograd/generated/ViewFuncs.h"],
}

if is_arvr_mode():
@@ -8792,7 +8792,6 @@ get_out().sum().backward()
_assert_match_metadata(new_inp, inp)
new_out = out._view_func(new_inp)
_assert_match_metadata(new_out, out)
self.assertEqual(new_out, out)

# reverse view_func
new_out = out.detach()
@@ -8831,7 +8830,7 @@ get_out().sum().backward()
_test_fn(
lambda x: x.chunk(2, -1)[0].transpose(0, 1).unsqueeze(-1), torch.randn(2, 3, 4))
_test_fn(
lambda x: x.split_with_sizes([1, 3], -1)[0].chunk(2, 0), torch.randn(2, 3, 4))
lambda x: x.split_with_sizes([1, 3], -1)[0].chunk(2, -1), torch.randn(2, 3, 4))

# chains with missing view_func()s use as_strided() to cover the gaps
def chain_with_only_parent_view_func(x):
@@ -8839,7 +8838,7 @@ get_out().sum().backward()
x = x.split_with_sizes([1, 3], -1)[0]

with torch.autograd._force_original_view_tracking(False):
x = x.chunk(2, 0)
x = x.chunk(2, -1)

return x

@@ -8850,50 +8849,12 @@ get_out().sum().backward()
x = x.split_with_sizes([1, 3], -1)[0]

with torch.autograd._force_original_view_tracking(True):
x = x.chunk(2, 0)
x = x.chunk(2, -1)

return x

_test_fn(chain_with_only_current_view_func, torch.randn(2, 3, 4))

def test_view_func_replay_with_modified_state(self):
with torch.autograd._force_original_view_tracking(True):
base = torch.randn(3, 4, 5)
view = base.select(1, 2)

def symint_visitor_fn(x):
# modify saved index
return x + 1

# ensure modifying state changes view replay
new_base = torch.randn_like(base)
new_view = view._view_func(new_base, symint_visitor_fn=symint_visitor_fn)
self.assertEqual(new_view, new_base.select(1, 3))

# ensure saved state reverts back afterwards
self.assertEqual(view._view_func(new_base), new_base.select(1, 2))

# check modifying tensor state. currently, slice_inverse() is the only
# view that saves a tensor
base = torch.randn(3, 4, 5)
sliced = base[:, 2:3, :].detach()
view = torch.ops.aten.slice_inverse(sliced, base, 1, 2, 3, 1)

replacement_shape = (1, 2, 3)

def tensor_visitor_fn(x):
# return tensor with a smaller shape than the saved one
return torch.randn(*replacement_shape)

# ensure modifying state changes view replay
new_sliced = torch.ones_like(base)[:, 2:3, :].detach()
new_view = view._view_func(new_sliced, tensor_visitor_fn=tensor_visitor_fn)
self.assertEqual(new_view.shape, replacement_shape)
self.assertEqual(new_view, new_sliced.as_strided(replacement_shape, (6, 3, 1)))

# ensure saved state reverts back afterwards
self.assertEqual(view._view_func(sliced), base)

def test_setup_context_when_forward_has_default_args(self):
class PowFunction(Function):
@staticmethod
@@ -1564,7 +1564,6 @@ class TestCompositeCompliance(TestCase):
_assert_match_metadata(new_inp, inp)
new_out = out._view_func_unsafe(new_inp)
_assert_match_metadata(new_out, out)
self.assertEqual(new_out, out)

# reverse view_func
new_out = out.detach()
@@ -124,8 +124,6 @@ def define_tools_targets(
"autograd/templates/TraceType.cpp",
"autograd/templates/VariableType.cpp",
"autograd/templates/VariableType.h",
"autograd/templates/ViewFuncs.cpp",
"autograd/templates/ViewFuncs.h",
"autograd/templates/annotated_fn_args.py.in",
"autograd/templates/python_enum_tag.cpp",
"autograd/templates/python_fft_functions.cpp",
@@ -43,7 +43,6 @@ from .gen_inplace_or_view_type import gen_inplace_or_view_type
from .gen_trace_type import gen_trace_type
from .gen_variable_factories import gen_variable_factories
from .gen_variable_type import gen_variable_type
from .gen_view_funcs import gen_view_funcs
from .load_derivatives import load_derivatives


@@ -96,9 +95,6 @@ def gen_autograd(
# Generate variable_factories.h
gen_variable_factories(out, native_functions_path, tags_path, template_path)

# Generate ViewFuncs.h/cpp
gen_view_funcs(out, fns_with_diff_infos, template_path)


def gen_autograd_python(
native_functions_path: str,
@@ -4,7 +4,7 @@
# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
# The fallback is expected to mimick this codegen, so we should keep the two in sync.

from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Sequence, Tuple

from torchgen.api import cpp
from torchgen.api.autograd import (
@@ -172,7 +172,7 @@ for (auto ${view_idx} : c10::irange(${var}.size())) {

SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate(
"""\
std::unique_ptr<torch::autograd::ViewFunc> func(nullptr);
std::function<at::Tensor(const at::Tensor&)> func=nullptr;
std::function<at::Tensor(const at::Tensor&)> rev_func=nullptr;
if (${is_view_with_metadata_change} ||
!self.unsafeGetTensorImpl()->support_as_strided() ||
@@ -184,9 +184,11 @@ if (${is_view_with_metadata_change} ||
"""
)

REPLAY_VIEW_FUNC = CodeTemplate(
REPLAY_VIEW_LAMBDA_FUNC = CodeTemplate(
"""\
func = std::make_unique<${view_func_name}>(${view_func_args});
func = [=](const at::Tensor& ${input_base}) {
return ${replay_view_call}${view_indexing};
};
"""
)

@@ -344,13 +346,24 @@ def get_view_info(f: NativeFunction) -> Optional[str]:
return view_info


def emit_view_func(
# For view replay calls, we generate an ordinary Dispatcher::call() instead, because:
# - We want to replay the entire call into the op, including any previously-set dispatch keys (including autograd!).
# - The view replay call also is not part of the hot path.
def emit_view_call(
f: NativeFunction, input_base: str, unpacked_args: Sequence[str]
) -> str:
# View replay functions use the standard Dispatcher::call API.
return CALL_DISPATCH.substitute(
unambiguous_name=f.func.name.unambiguous_name(), unpacked_args=unpacked_args
)


def emit_view_lambda(
f: NativeFunction, bindings: List[Binding], view_idx: Optional[str] = None
) -> str:
"""Generate an additional lambda function to recover views in backward when as_strided is not supported.
See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details.
"""
# TODO: Clean this logic up if we get rid of reverse view funcs or reify them.
input_base = "input_base"
replay_view_func = ""
updated_args: List[str] = []
@@ -399,14 +412,11 @@ def emit_view_func(
else:
updated_args.append(arg)

from .gen_view_funcs import view_func_name

view_func_args = [b.name for b in bindings if b.name != "self"]
if view_idx is not None:
view_func_args.append(f"{view_idx}")
replay_view_func += REPLAY_VIEW_FUNC.substitute(
view_func_name=view_func_name(f, include_namespace=True),
view_func_args=view_func_args,
replay_view_call = emit_view_call(f, input_base, updated_args)
replay_view_func += REPLAY_VIEW_LAMBDA_FUNC.substitute(
input_base=input_base,
replay_view_call=replay_view_call,
view_indexing=("" if view_idx is None else f"[{view_idx}]"),
)

input_view = "input_view"
@@ -483,26 +493,26 @@ def emit_view_body(
if is_tensor_list_type(return_info.type):
creation_meta = get_creation_meta_in_mode("CreationMeta::MULTI_OUTPUT_NODE")
view_idx = "view_idx"
view_func = emit_view_func(
view_lambda = emit_view_lambda(
f, extract_bindings(f), view_idx=view_idx
).strip()
as_view_call = (
f"as_view(/* base */ {view_info}, /* output */ {var}[{view_idx}], "
"/* is_bw_differentiable */ true, /* is_fw_differentiable */ true, "
"/* view_func */ std::move(func), /* rev_view_func */ rev_func, "
"/* view_func */ func, /* rev_view_func */ rev_func, "
f"/* creation_meta */ {creation_meta});"
)
call += MULTI_OUTPUT_VIEW_ITERATION.substitute(
var=var, view_idx=view_idx, body=f"{view_func}\n{as_view_call}"
var=var, view_idx=view_idx, body=f"{view_lambda}\n{as_view_call}"
)
rhs_value = f"std::move({var})"
else:
call += emit_view_func(f, extract_bindings(f), view_idx=None)
call += emit_view_lambda(f, extract_bindings(f), view_idx=None)
creation_meta = get_creation_meta_in_mode("CreationMeta::DEFAULT")
rhs_value = (
f"as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, "
"/* is_fw_differentiable */ true, "
f"/* view_func */ std::move(func), /* rev_view_func */ rev_func, /* creation_meta */ {creation_meta})"
f"/* view_func */ func, /* rev_view_func */ rev_func, /* creation_meta */ {creation_meta})"
)
else:
# This could be supported but we don't need it at the moment, so keeping things simple.
@@ -1,334 +0,0 @@
# Generates ViewFuncs.h/cpp
#
# NOTE: If any changes are being made to the ViewFunc codegen please also check
# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
# The fallback is expected to mimic this codegen, so we should keep the two in sync.

from typing import List, Tuple

import torchgen.api.dispatcher as dispatcher
from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo
from torchgen.api.translate import translate
from torchgen.api.types import (
BaseCType,
Binding,
NamedCType,
SymIntT,
tensorT,
VectorCType,
)
from torchgen.code_template import CodeTemplate
from torchgen.model import Argument, NativeFunction, OptionalType
from torchgen.utils import FileManager

from .gen_inplace_or_view_type import (
CALL_DISPATCH,
extract_bindings,
get_view_info,
modifies_arguments,
use_derived,
)

FUNCTION_DECLARATION = CodeTemplate(
"""\
#define ${uppercase_op}_AVAILABLE
struct ${op} : public ${superclass} {
${op}(${constructor_args}) ${initializer_list}
{};
virtual ~${op}() override {};
virtual std::vector<c10::SymInt> get_symints() const override;
virtual size_t num_symints() const override;
virtual std::vector<at::Tensor> get_tensors() const override;
virtual size_t num_tensors() const override;
virtual at::Tensor operator()(const at::Tensor&) const override;
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;

protected:
virtual void set_symints(std::vector<c10::SymInt>) override;
virtual void set_tensors(std::vector<at::Tensor>) override;

private:
${state}
};

"""
)

FUNCTION_DEFINITION = CodeTemplate(
"""\
std::vector<c10::SymInt> ${op}::get_symints() const {
${get_symints}
}

size_t ${op}::num_symints() const {
return static_cast<size_t>(${num_symints});
}

void ${op}::set_symints(std::vector<c10::SymInt> ${symints_vec}) {
TORCH_INTERNAL_ASSERT(${symints_vec}.size() == num_symints());
${set_symints}
}

std::vector<at::Tensor> ${op}::get_tensors() const {
${get_tensors}
}

size_t ${op}::num_tensors() const {
return static_cast<size_t>(${num_tensors});
}

void ${op}::set_tensors(std::vector<at::Tensor> ${tensors_vec}) {
TORCH_INTERNAL_ASSERT(${tensors_vec}.size() == num_tensors());
${set_tensors}
}

at::Tensor ${op}::operator()(const at::Tensor& ${call_input_name}) const {
return ${op_call};
}

std::unique_ptr<ViewFunc> ${op}::clone_and_set(
std::optional<std::vector<c10::SymInt>> ${symints_vec},
std::optional<std::vector<at::Tensor>> ${tensors_vec}) const {
auto output = std::make_unique<${op}>(${clone_args});
if (${symints_vec}.has_value()) {
output->set_symints(std::move(*(${symints_vec})));
}
if (${tensors_vec}.has_value()) {
output->set_tensors(std::move(*(${tensors_vec})));
}
return output;
}

"""
)


# e.g. as_strided -> AsStridedViewFunc for camel case or
# as_strided_view_func otherwise
def view_func_name(
f: NativeFunction, include_namespace: bool = False, camel_case: bool = True
) -> str:
name = f.func.name.unambiguous_name()
view_func_name = f"{name.replace('.', '_')}_view_func"
if camel_case:
is_private = view_func_name.startswith("_")
view_func_name = "".join(
[p.title() for p in view_func_name.replace(".", "_").split("_")]
)
if is_private:
# put the leading underscore back in
view_func_name = f"_{view_func_name}"
namespace = "torch::autograd::generated::" if include_namespace else ""
return f"{namespace}{view_func_name}"


def is_symint_or_tensor(arg: Argument) -> bool:
return arg.type.is_tensor_like() or arg.type.is_symint_like()


def remove_const_ref(binding: Binding) -> Binding:
return Binding(
name=binding.name,
nctype=binding.nctype.remove_const_ref(),
argument=binding.argument,
default=binding.default,
)


def returns_multi_tensor(fn: NativeFunction) -> bool:
returns = fn.func.returns
assert len(returns) == 1
returns_list_like = returns[0].type.is_list_like() is not None
returns_tensor_like = returns[0].type.is_tensor_like()
return returns_list_like and returns_tensor_like


# Generates strings with logic for getting / setting state of a particular type.
#
# Args:
# bindings (list): List of state bindings of interest (may be empty)
# state_vec_type (NamedCType): Type of vector to either return or copy from
#
# Returns:
# tuple: (list of getter logic strings, list of setter logic strings, string
# with num items expression)
def generate_state_getter_setter(
bindings: List[Binding],
state_vec_type: NamedCType,
) -> Tuple[List[str], List[str], str]:
getter_logic = []
setter_logic = []

state_vec = state_vec_type.name
getter_logic.append(f"{state_vec_type.cpp_type()} {state_vec};")
if len(bindings) > 0:
setter_logic.append("auto i = 0;")

num_exprs = []
for i, b in enumerate(bindings):
assert isinstance(b.argument, Argument)
if b.argument.type.is_list_like():
# Handle list-likes.
num_expr = f"{b.name}.size()"
num_exprs.append(num_expr)
getter = f"{state_vec}.insert({state_vec}.end(), {b.name}.begin(), {b.name}.end());"
setter = f"std::copy({state_vec}.begin() + i, {state_vec}.begin() + i + {b.name}.size(), {b.name}.begin());"
elif isinstance(b.argument.type, OptionalType):
# Handle optionals.
num_expr = f"({b.name}.has_value() ? 1 : 0)"
num_exprs.append(num_expr)
conditional = f"if({b.name}.has_value())"
getter = (
f"{conditional} {state_vec}.insert({state_vec}.end(), *({b.name}));"
)
setter = f"{conditional} {b.name} = {state_vec}[i];"
else:
num_expr = "1"
num_exprs.append(num_expr)
getter = f"{state_vec}.push_back({b.name});"
setter = f"{b.name} = {state_vec}[i];"

getter_logic.append(getter)
setter_logic.append(setter)
if i < len(bindings) - 1:
setter_logic.append(f"i += {num_expr};")

# Reserve / assert based on the total number of items expression.
num_items = "0" if len(num_exprs) == 0 else " + ".join(num_exprs)
if len(bindings) > 0:
getter_logic.insert(1, f"{state_vec}.reserve({num_items});")

getter_logic.append(f"return {state_vec};")

return getter_logic, setter_logic, num_items


def process_function(fn: NativeFunction, template: CodeTemplate) -> str:
bindings = extract_bindings(fn)
non_self_bindings = [b for b in bindings if b.name != "self"]

non_self_args = fn.func.arguments.flat_all[1:]
non_self_value_bindings = [
dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
]

# Generate constructor / clone args for the generated struct.
constructor_args = [b.defn() for b in non_self_bindings]
clone_args = [b.name for b in non_self_bindings]

# Generate state variable declarations for the generated struct.
state_variables = [
f"{remove_const_ref(b).defn()};" for b in non_self_value_bindings
]

# Generate initializer list expressions for the generated struct.
# allow_expensive_conversions=True because we need to store e.g. SymIntArrayRefs as
# vector<SymInt>s.
init_exprs = translate(
non_self_bindings, non_self_value_bindings, allow_expensive_conversions=True
)
initializers = []
for b, init_expr in zip(non_self_bindings, init_exprs):
name = b.nctype.name
assert isinstance(name, str)
initializers.append(f"{name}({init_expr.expr})")

# Generate call to underlying view op
call_input_name = "input_base"
op_call_args = [call_input_name, *(b.name for b in non_self_bindings)]
op_call = CALL_DISPATCH.substitute(
unambiguous_name=fn.func.name.unambiguous_name(),
unpacked_args=op_call_args,
)

# Multi-output views additionally require a view_idx for disambiguation.
if returns_multi_tensor(fn):
view_idx_name = "view_idx"
view_idx_typename = "int64_t"
view_idx_decl = f"{view_idx_typename} {view_idx_name}"
constructor_args.append(view_idx_decl)
clone_args.append(view_idx_name)
state_variables.append(f"{view_idx_decl};")
initializers.append(f"{view_idx_name}({view_idx_name})")
op_call += f"[{view_idx_name}]"

# Generate initializer list for the generated struct.
initializer_list = f": {', '.join(initializers)}" if len(initializers) > 0 else ""

# Generate getter / setter logic for any symints.
symint_bindings = [
b
for b in non_self_bindings
if isinstance(b.argument, Argument) and b.argument.type.is_symint_like()
]
symints_vec_type = NamedCType("symints", VectorCType(BaseCType(SymIntT)))
get_symints, set_symints, num_symints = generate_state_getter_setter(
symint_bindings, symints_vec_type
)

# Generate getter / setter logic for any tensors.
tensor_bindings = [
b
for b in non_self_bindings
if isinstance(b.argument, Argument) and b.argument.type.is_tensor_like()
]
tensors_vec_type = NamedCType("tensors", VectorCType(BaseCType(tensorT)))
get_tensors, set_tensors, num_tensors = generate_state_getter_setter(
tensor_bindings, tensors_vec_type
)

return template.substitute(
op=view_func_name(fn),
uppercase_op=view_func_name(fn, camel_case=False).upper(),
superclass="torch::autograd::ViewFunc",
initializer_list=initializer_list,
state=state_variables,
constructor_args=constructor_args,
clone_args=clone_args,
symints_vec=symints_vec_type.name,
get_symints=get_symints,
set_symints=set_symints,
num_symints=num_symints,
tensors_vec=tensors_vec_type.name,
get_tensors=get_tensors,
set_tensors=set_tensors,
num_tensors=num_tensors,
call_input_name=call_input_name,
op_call=op_call,
)


def gen_view_funcs(
out: str,
fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
template_path: str,
) -> None:
# don't need the info parts, just the function
fns = [fn.func for fn in fns_with_infos if use_derived(fn)]
# only want out-of-place views
view_fns = [
fn for fn in fns if get_view_info(fn) is not None and not modifies_arguments(fn)
]

declarations = [process_function(fn, FUNCTION_DECLARATION) for fn in view_fns]
definitions = [process_function(fn, FUNCTION_DEFINITION) for fn in view_fns]
ops_headers = [f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in view_fns]

file_basename = "ViewFuncs"
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
for suffix in [".h", ".cpp"]:
fname = file_basename + suffix
fm.write_with_template(
fname,
fname,
lambda: {
"generated_comment": "@"
+ f"generated from {fm.template_dir_for_comments()}/"
+ fname,
"view_func_declarations": declarations,
"view_func_definitions": definitions,
"ops_headers": ops_headers,
},
)
@@ -1,6 +1,5 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include "torch/csrc/autograd/VariableTypeUtils.h"
#include "torch/csrc/autograd/generated/ViewFuncs.h"

#include <torch/library.h>
#include <ATen/FunctionalInverses.h>
@@ -1,71 +0,0 @@
#include "torch/csrc/autograd/generated/ViewFuncs.h"

// ${generated_comment}

using at::Tensor;
using at::Scalar;
using at::IntArrayRef;
using at::TensorList;

namespace torch::autograd {

std::vector<c10::SymInt> ChainedViewFunc::get_symints() const {
auto symints = first->get_symints();
auto second_symints = second->get_symints();
symints.reserve(symints.size() + second_symints.size());
symints.insert(
symints.end(),
std::make_move_iterator(second_symints.begin()),
std::make_move_iterator(second_symints.end()));
return symints;
}

std::vector<at::Tensor> ChainedViewFunc::get_tensors() const {
auto tensors = first->get_tensors();
auto second_tensors = second->get_tensors();
tensors.reserve(tensors.size() + second_tensors.size());
tensors.insert(
tensors.end(),
std::make_move_iterator(second_tensors.begin()),
std::make_move_iterator(second_tensors.end()));
return tensors;
}

at::Tensor ChainedViewFunc::operator()(const at::Tensor& input_base) const {
return (*second)((*first)(input_base));
}

std::unique_ptr<ViewFunc> ChainedViewFunc::clone_and_set(
std::optional<std::vector<c10::SymInt>> symints,
std::optional<std::vector<at::Tensor>> tensors) const {
std::optional<std::vector<c10::SymInt>> first_symints;
std::optional<std::vector<c10::SymInt>> second_symints;
if (symints.has_value()) {
TORCH_INTERNAL_ASSERT(symints->size() == num_symints());
first_symints = std::vector<c10::SymInt>(
symints->begin(), symints->begin() + first->num_symints());
second_symints = std::vector<c10::SymInt>(
symints->begin() + first->num_symints(), symints->end());
}

std::optional<std::vector<at::Tensor>> first_tensors;
std::optional<std::vector<at::Tensor>> second_tensors;
if (tensors.has_value()) {
TORCH_INTERNAL_ASSERT(tensors->size() == num_tensors());
first_tensors = std::vector<at::Tensor>(
tensors->begin(), tensors->begin() + first->num_tensors());
second_tensors = std::vector<at::Tensor>(
tensors->begin() + first->num_tensors(), tensors->end());
}

return std::make_unique<ChainedViewFunc>(
first->clone_and_set(first_symints, first_tensors),
second->clone_and_set(second_symints, second_tensors));
}

namespace generated {

${view_func_definitions}

} // namespace torch::autograd
} // namespace generated
@@ -1,106 +0,0 @@
#pragma once

// ${generated_comment}

#include <torch/library.h>

#include <c10/core/SymIntArrayRef.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Operators.h>
#else
$ops_headers
#endif

namespace torch { namespace autograd {

/// Base class for view functions, providing reapplication of a view on a new base.
/// Each view op should get a codegenerated subclass of this class containing
/// any state needed to reconstruct the view. The class also provides convenience
/// accessors for saved SymInts / tensor state. This is useful for e.g. fake-ification,
/// where we want to use symbolic values or fake tensors instead.
struct TORCH_API ViewFunc {
virtual ~ViewFunc() {}
/// Returns any SymInts in the saved state.
virtual std::vector<c10::SymInt> get_symints() const { return {}; }
/// Returns the number of SymInts in the saved state.
virtual size_t num_symints() const { return 0; }
/// Returns any tensors in the saved state.
virtual std::vector<at::Tensor> get_tensors() const { return {}; }
/// Returns the number of tensors in the saved state.
virtual size_t num_tensors() const { return 0; }
/// Reapplies the view on the given base using the saved state.
virtual at::Tensor operator()(const at::Tensor&) const = 0;
/// Returns a clone of this ViewFunc, optionally with the specified saved state.
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const = 0;

protected:
/// Sets the values of any SymInts in the saved state. The input vector size must
/// match the number of SymInts in the saved state (i.e. the size of the list
/// returned by get_symints()).
virtual void set_symints(std::vector<c10::SymInt>) {}
/// Sets the values of any Tensors in the saved state. The input vector size must
/// match the number of Tensors in the saved state (i.e. the size of the list
/// returned by get_tensors()).
virtual void set_tensors(std::vector<at::Tensor>) {}
};

/// ViewFunc that represents a chain of two ViewFuncs.
struct ChainedViewFunc : public ViewFunc {
ChainedViewFunc(
std::unique_ptr<ViewFunc> first,
std::unique_ptr<ViewFunc> second)
: first(std::move(first)),
second(std::move(second)) {}
virtual ~ChainedViewFunc() override {};
virtual std::vector<c10::SymInt> get_symints() const override;
virtual size_t num_symints() const override {
return first->num_symints() + second->num_symints();
}
virtual std::vector<at::Tensor> get_tensors() const override;
virtual size_t num_tensors() const override {
return first->num_tensors() + second->num_tensors();
}
virtual at::Tensor operator()(const at::Tensor&) const override;
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;

private:
std::unique_ptr<ViewFunc> first;
std::unique_ptr<ViewFunc> second;
};

/// ViewFunc that errors with a specified error message when called.
struct ErroringViewFunc : public ViewFunc {
ErroringViewFunc(const std::string& error_msg) : error_msg(error_msg) {}
virtual ~ErroringViewFunc() override {};
virtual at::Tensor operator()(const at::Tensor&) const override {
TORCH_CHECK(false, error_msg);
}
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const override {
return std::make_unique<ErroringViewFunc>(error_msg);
}

private:
std::string error_msg;
};

namespace generated {

using at::Scalar;
using at::Tensor;
using at::IntArrayRef;
using at::ArrayRef;
using at::Type;
using at::ScalarType;
using c10::optional;
using c10::fmap;

${view_func_declarations}

}}} // namespace torch::autograd::generated
@@ -9,7 +9,6 @@
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/autograd/generated/ViewFuncs.h>
#include <torch/library.h>

#include <utility>
@@ -475,10 +474,13 @@ static Tensor _fw_primal(
at::AutoDispatchBelowADInplaceOrView guard;
return at::alias(self);
})();
std::unique_ptr<torch::autograd::ViewFunc> func(nullptr);
std::function<at::Tensor(const at::Tensor&)> func = nullptr;
std::function<at::Tensor(const at::Tensor&)> rev_func = nullptr;
if (!self.unsafeGetTensorImpl()->support_as_strided()) {
func = std::make_unique<ViewViewFunc>(self.sym_sizes());
auto size_vec = self.sizes().vec();
func = [=](const at::Tensor& input_base) {
return input_base.view(size_vec);
};
rev_func = [=](const at::Tensor& input_view) {
TORCH_INTERNAL_ASSERT(
false,
@@ -508,10 +510,13 @@ static Tensor _make_dual(
at::AutoDispatchBelowADInplaceOrView guard;
return at::alias(primal);
})();
std::unique_ptr<torch::autograd::ViewFunc> func(nullptr);
std::function<at::Tensor(const at::Tensor&)> func = nullptr;
std::function<at::Tensor(const at::Tensor&)> rev_func = nullptr;
if (!primal.unsafeGetTensorImpl()->support_as_strided()) {
func = std::make_unique<ViewViewFunc>(primal.sym_sizes());
auto size_vec = primal.sizes().vec();
func = [=](const at::Tensor& input_base) {
return input_base.view(size_vec);
};
rev_func = [=](const at::Tensor& input_view) {
TORCH_INTERNAL_ASSERT(
false,
@@ -160,7 +160,7 @@ inline at::Tensor as_view(
const at::Tensor& tensor,
bool is_bw_differentiable,
bool is_fw_differentiable,
std::unique_ptr<ViewFunc> view_func = nullptr,
std::function<at::Tensor(const at::Tensor&)> view_func = nullptr,
std::function<at::Tensor(const at::Tensor&)> rev_view_func = nullptr,
CreationMeta creation_meta = CreationMeta::DEFAULT,
bool allow_tensor_metadata_change = true) {
@@ -208,13 +208,11 @@ inline at::Tensor as_view(
c10::optional<ViewInfo> new_fw_info;

if (is_bw_differentiable) {
auto bw_view_func = view_func ? view_func->clone_and_set() : nullptr;
if (diff_view_meta && diff_view_meta->has_bw_view()) {
const auto& base_bw_info = diff_view_meta->get_backward_view();
new_bw_info = base_bw_info.chain(
base, tensor, std::move(bw_view_func), rev_view_func);
new_bw_info = base_bw_info.chain(base, tensor, view_func, rev_view_func);
} else {
new_bw_info = ViewInfo(base, std::move(bw_view_func), rev_view_func);
new_bw_info = ViewInfo(base, view_func, rev_view_func);
}
} else {
TORCH_CHECK(
@@ -213,7 +213,7 @@ void AutogradMeta::set_fw_grad(
// - Copy the given new_grad into this view
// - Use this view as the new new_grad
if (this_view_meta->has_fw_view()) {
auto& view_info = this_view_meta->get_forward_view();
auto view_info = this_view_meta->get_forward_view();
auto& base = view_info.base_;

if (!base._fw_grad(level).defined()) {
@@ -530,14 +530,18 @@ static void autogradNotImplementedInplaceOrViewFallbackImpl(
(*stack)[stack->size() - num_returns + aliased_output_idx];

// See NOTE [ View + Inplace detection ] for more details about this logic
// We always need this view_func because otherwise if we do in-place
// on this view, we would implicitly use AsStridedBackward instead
// of the NotImplemented node. For the cross-dtype/non-strided
// cases, we would create something like this anyway
auto error_msg =
("Mutating the view " + op_name +
"which does not have a derivative implemented is forbidden.");
auto erroring_view_func = std::make_unique<ErroringViewFunc>(error_msg);
const auto erroring_view_func = [op_name = op_name](const at::Tensor&) {
// We always need this view_func because otherwise if we do in-place
// on this view, we would implicitly use AsStridedBackward instead
// of the NotImplemented node. For the cross-dtype/non-strided
// cases, we would create something like this anyway
TORCH_CHECK(
false,
"Mutating the view ",
op_name,
" which does not have a derivative implemented is forbidden.");
return at::Tensor();
};

const auto erroring_rev_view_func = [op_name = op_name](const at::Tensor&) {
TORCH_CHECK(
@@ -556,7 +560,7 @@ static void autogradNotImplementedInplaceOrViewFallbackImpl(
/* tensor=*/sub_output,
/* is_bw_differentiable=*/true,
/* is_fw_differentiable=*/true,
/* view_func=*/std::move(erroring_view_func),
/* view_func=*/erroring_view_func,
/* rev_view_func=*/erroring_rev_view_func,
/* creation_meta=*/
InferenceMode::is_enabled()
@@ -573,7 +577,7 @@ static void autogradNotImplementedInplaceOrViewFallbackImpl(
/* tensor=*/std::move(aliased_output_iv).toTensor(),
/* is_bw_differentiable=*/true,
/* is_fw_differentiable=*/true,
/* view_func=*/std::move(erroring_view_func),
/* view_func=*/erroring_view_func,
/* rev_view_func=*/erroring_rev_view_func,
/* creation_meta=*/
InferenceMode::is_enabled()
@@ -55,7 +55,7 @@ variable_list CopyBackwards::apply_with_saved(
CopySlices::CopySlices(
const Variable& base_var,
at::TensorGeometry view_,
std::unique_ptr<ViewFunc> view_fn_,
std::function<at::Tensor(const at::Tensor&)> view_fn_,
std::shared_ptr<Node> fn_)
: Node(),
base(base_var),
@@ -98,7 +98,7 @@ inline variable_list CopySlices::apply_impl(

at::Tensor grad_slice;
if (view_fn) {
grad_slice = (*view_fn)(result);
grad_slice = view_fn(result);
} else {
auto offset = view.sym_storage_offset() - base.sym_storage_offset();
grad_slice =
@@ -79,7 +79,7 @@ struct TORCH_API CopyBackwards : public Node {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//
// We need to perform grad_view = fn(grad_view), but out-of-place.
// view_fn_ is an optional function saved in DifferentiableViewMeta
// view_fn_ is an optional lambda function saved in DifferentiableViewMeta
// from forward pass, so that we can recover we when as_strided is not
// supported. It preserves the invariants:
// view = view_fn_(base)
@@ -160,7 +160,7 @@ struct TORCH_API CopySlices : public Node {
CopySlices(
const Variable& base_var,
at::TensorGeometry view_,
std::unique_ptr<ViewFunc> view_fn_,
std::function<at::Tensor(const at::Tensor&)> view_fn_,
std::shared_ptr<Node> fn_);

// common code between apply/apply_with_saved
@@ -178,7 +178,7 @@ struct TORCH_API CopySlices : public Node {
// view and view_fn are redundant and view_fn will be used if available.
// See Note [View + Inplace update for base tensor] for details.
at::TensorGeometry view;
std::unique_ptr<ViewFunc> view_fn;
std::function<at::Tensor(const at::Tensor&)> view_fn;
std::shared_ptr<Node> fn;
};

@@ -524,36 +524,16 @@ static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) {
Py_RETURN_NONE;
}

// Maps the given python callable over a vector of items, returning a vector
// of the same type of items.
template <typename T>
static std::vector<T> map_py_func(
const py::function& func,
const std::vector<T>& items) {
std::vector<T> new_items;
new_items.reserve(items.size());
for (auto& item : items) {
new_items.push_back(py::cast<T>(func(item)));
}
return new_items;
}

static PyObject* view_func_impl(
PyObject* _self,
PyObject* args,
PyObject* kwargs,
PyObject* self_,
PyObject* arg,
bool check_has_same_meta) {
HANDLE_TH_ERRORS
const auto& self = THPVariable_Unpack(_self);

static PythonArgParser parser({
"_view_func(Tensor new_base, PyObject* symint_visitor_fn=None, PyObject* tensor_visitor_fn=None)",
});
ParsedArgs<3> parsed_args{};
auto r = parser.parse(_self, args, kwargs, parsed_args);
auto new_base = r.tensor(0);
PyObject* symint_visitor_fn = r.pyobject(1);
PyObject* tensor_visitor_fn = r.pyobject(2);
const auto& self = THPVariable_Unpack(self_);
TORCH_CHECK(
THPVariable_Check(arg),
"_view_func expect a single argument that is a Tensor");
const auto& new_base = THPVariable_Unpack(arg);

// Ensure that self is indeed a backward differentiable view
// If not, we return an undefined Tensor (None) and let the user handle it.
@@ -566,29 +546,7 @@ static PyObject* view_func_impl(
torch::autograd::utils::has_same_meta(new_base, view_info.base_)) {
// Do the actual view replay
if (view_info.has_view_fn()) {
auto& view_func = view_info.view_fn();

// Determine new SymInt / tensor state as needed.
c10::optional<std::vector<c10::SymInt>> new_symints = c10::nullopt;
if (symint_visitor_fn != Py_None) {
new_symints = map_py_func(
py::cast<py::function>(symint_visitor_fn),
view_func.get_symints());
}

c10::optional<std::vector<at::Tensor>> new_tensors = c10::nullopt;
if (tensor_visitor_fn != Py_None) {
new_tensors = map_py_func(
py::cast<py::function>(tensor_visitor_fn),
view_func.get_tensors());
}

// call view func
if (new_symints.has_value() || new_tensors.has_value()) {
out = (*view_func.clone_and_set(new_symints, new_tensors))(new_base);
} else {
out = view_func(new_base);
}
out = view_info.view_fn()(new_base);
} else {
out = new_base.as_strided(
self.sizes(), self.strides(), self.storage_offset());
@@ -599,18 +557,12 @@ static PyObject* view_func_impl(
END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable_view_func(
PyObject* self_,
PyObject* args,
PyObject* kwargs) {
return view_func_impl(self_, args, kwargs, /*check_has_same_meta=*/true);
static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
return view_func_impl(self_, arg, /*check_has_same_meta=*/true);
}

static PyObject* THPVariable_view_func_unsafe(
PyObject* self_,
PyObject* args,
PyObject* kwargs) {
return view_func_impl(self_, args, kwargs, /*check_has_same_meta=*/false);
static PyObject* THPVariable_view_func_unsafe(PyObject* self_, PyObject* arg) {
return view_func_impl(self_, arg, /*check_has_same_meta=*/false);
}

static PyObject* rev_view_func_impl(PyObject* self_, PyObject* arg) {
@@ -1715,14 +1667,8 @@ static PyMethodDef extra_methods[] = {
METH_STATIC | METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr},
{"_view_func",
castPyCFunctionWithKeywords(THPVariable_view_func),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_view_func_unsafe",
castPyCFunctionWithKeywords(THPVariable_view_func_unsafe),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_view_func", THPVariable_view_func, METH_O, nullptr},
{"_view_func_unsafe", THPVariable_view_func_unsafe, METH_O, nullptr},
{"_rev_view_func_unsafe",
THPVariable_rev_view_func_unsafe,
METH_O,
@@ -8,7 +8,6 @@
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/functions/tensor.h>
#include <torch/csrc/autograd/generated/Functions.h>
#include <torch/csrc/autograd/generated/ViewFuncs.h>
#include <torch/csrc/autograd/utils/error_messages.h>

#include <ATen/ATen.h>
@@ -26,19 +25,6 @@
namespace torch {
namespace autograd {

// Returns a ViewFunc with a corresponding view that matches the shape,
// stride, and storage offset of the given tensor.
// NB: On mobile, the as_strided() op and thus the generated AsStridedViewFunc
// may not be available.
static std::unique_ptr<ViewFunc> create_view_func_matching(const Variable& t) {
#ifdef AS_STRIDED_VIEW_FUNC_AVAILABLE
return std::make_unique<torch::autograd::generated::AsStridedViewFunc>(
t.sym_sizes(), t.sym_strides(), t.sym_storage_offset());
#else
return std::make_unique<ErroringViewFunc>("as_strided() not available");
#endif
}

DifferentiableViewMeta::DifferentiableViewMeta(
at::TensorImpl* self_impl,
c10::optional<ViewInfo> backward_info,
@@ -72,7 +58,7 @@ DifferentiableViewMeta::DifferentiableViewMeta(
ViewInfo ViewInfo::chain(
const Variable& base,
const Variable& tensor,
std::unique_ptr<ViewFunc> view_func,
std::function<Variable(const Variable&)> view_func,
std::function<Variable(const Variable&)> rev_view_func) const {
// Set `view_func` using the root base as input.
// `view_func` is used to recover views in backward when either as_strided is
@@ -83,8 +69,12 @@ ViewInfo ViewInfo::chain(
if (view_func) {
// both current_view and it's parent have a view_func
if (view_fn_) {
view_func = std::make_unique<ChainedViewFunc>(
view_fn_->clone_and_set(), std::move(view_func));
// Copy parent view function to gain ownership
auto prev_fn = view_fn_;
view_func = [=](const at::Tensor& root_base) {
auto temp = prev_fn(root_base);
return view_func(temp);
};

// assume view_fn_ / rev_view_fn_ always exist together or neither are set
auto prev_rev_fn = rev_view_fn_;
@@ -95,9 +85,13 @@ ViewInfo ViewInfo::chain(
} else {
// current_view has a view_func and but it's parent doesn't have one
if (base.unsafeGetTensorImpl()->support_as_strided()) {
auto match_base_view_func = create_view_func_matching(base);
view_func = std::make_unique<ChainedViewFunc>(
std::move(match_base_view_func), std::move(view_func));
auto size = base.sym_sizes().vec();
auto stride = base.sym_strides().vec();
auto storage_offset = base.sym_storage_offset();
view_func = [=](const at::Tensor& root_base) {
auto temp = root_base.as_strided_symint(size, stride, storage_offset);
return view_func(temp);
};

// assume view_fn_ / rev_view_fn_ always exist together or neither are
// set
@@ -117,7 +111,12 @@ ViewInfo ViewInfo::chain(
auto error_msg =
("Attempted to chain views when the parent view has no view_func() and "
"does not support as_strided(). This is not supported.");
view_func = std::make_unique<ErroringViewFunc>(error_msg);

view_func = [=](const at::Tensor& root_base) {
TORCH_CHECK(false, error_msg);
return root_base;
};

rev_view_func = [=](const at::Tensor& root_view) {
TORCH_CHECK(false, error_msg);
return root_view;
@@ -126,9 +125,15 @@ ViewInfo ViewInfo::chain(
}
} else if (view_fn_) {
// if current_view doesn't have a view_func but it's parent has one
auto match_tensor_view_func = create_view_func_matching(tensor);
view_func = std::make_unique<ChainedViewFunc>(
view_fn_->clone_and_set(), std::move(match_tensor_view_func));
// Copy parent view function to gain ownership
auto prev_view_fn = view_fn_;
auto size = tensor.sym_sizes().vec();
auto stride = tensor.sym_strides().vec();
auto storage_offset = tensor.sym_storage_offset();
view_func = [=](const at::Tensor& root_base) {
auto temp = prev_view_fn(root_base);
return temp.as_strided_symint(size, stride, storage_offset);
};

// assume view_fn_ / rev_view_fn_ always exist together or neither are set
auto prev_rev_view_fn = rev_view_fn_;
@@ -227,12 +232,12 @@ void rebase_history(const Variable& self, Edge gradient_edge) {
TORCH_CHECK(
gradient_edge.function->num_inputs() == 1,
"Functions which modify views in-place must return a single Variable");
const auto& view_info = diff_view_meta->get_backward_view();
auto view_info = diff_view_meta->get_backward_view();
diff_view_meta->output_nr_ = gradient_edge.input_nr;
auto copy_slices = std::make_shared<CopySlices>(
view_info.base_,
at::TensorGeometry(self),
view_info.has_view_fn() ? view_info.view_fn().clone_and_set() : nullptr,
view_info.view_fn_,
std::move(gradient_edge.function));
if (self.requires_grad()) {
// If self did not previously require grad, there are no hooks to move
@@ -651,7 +656,7 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
if (diff_view_meta && diff_view_meta->has_bw_view()) {
// See NOTE [ View + Inplace detection ]
std::lock_guard<std::mutex> lock(diff_view_meta->mutex_);
auto& view_info = diff_view_meta->get_backward_view();
auto view_info = diff_view_meta->get_backward_view();
if (!diff_view_meta->grad_fn_ && !view_info.base_.requires_grad()) {
return diff_view_meta->grad_fn_;
}
@@ -691,7 +696,7 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
// in VariableType_x.cpp
// that would provide a way to recreate the grad_fn chain.
if (view_info.has_view_fn()) {
auto& view_fn = view_info.view_fn();
auto view_fn = view_info.view_fn();
Tensor diff_view;
{
// We can reach this path with grad_mode disabled, e.g. engine
@@ -7,7 +7,6 @@
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/forward_grad.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/autograd/generated/ViewFuncs.h>

#include <ATen/NamedTensorUtils.h>
#include <ATen/core/Tensor.h>
@@ -332,8 +331,7 @@ struct TORCH_API ViewInfo {
/// By default we use as_strided to recover views which is more efficient.
/// view_fn is only saved when as_strided is not supported.
/// If view_fn has value, we use it to recover views in backward.
std::unique_ptr<ViewFunc> view_fn_;

std::function<Variable(const Variable&)> view_fn_;
/// Analogue of view_fn but in reverse: given a view -> produce the base by
/// applying the inverse view.
std::function<Variable(const Variable&)> rev_view_fn_;
@@ -344,10 +342,10 @@ struct TORCH_API ViewInfo {
return view_fn_ != nullptr;
}

const ViewFunc& view_fn() const {
std::function<Variable(const Variable&)> view_fn() const {
TORCH_CHECK(
has_view_fn(), "Can only access the view function if it exists.");
return *view_fn_;
return view_fn_;
}

std::function<Variable(const Variable&)> rev_view_fn() const {
@@ -368,12 +366,12 @@ struct TORCH_API ViewInfo {
ViewInfo chain(
const Variable& base,
const Variable& tensor,
std::unique_ptr<ViewFunc> view_func = nullptr,
std::function<Variable(const Variable&)> view_func = nullptr,
std::function<Variable(const Variable&)> rev_view_func = nullptr) const;

ViewInfo(
Variable base,
std::unique_ptr<ViewFunc> view_fn,
std::function<Variable(const Variable&)> view_fn,
std::function<Variable(const Variable&)> rev_view_fn)
: base_(std::move(base)),
view_fn_(std::move(view_fn)),