Revert "Reify view_func() closures as ViewFuncs (#118404)"

This reverts commit d5a6762263a98e5153bc057c8ba4f377542c7e55.

Reverted https://github.com/pytorch/pytorch/pull/118404 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](https://github.com/pytorch/pytorch/pull/118404#issuecomment-1938600260))
PyTorch MergeBot
2024-02-12 12:38:51 +00:00
parent 79df897608
commit 24bdd03d23
23 changed files with 117 additions and 718 deletions
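For context: the reverted PR (#118404) replaced the opaque std::function<at::Tensor(const at::Tensor&)> closures used for view replay with codegenerated ViewFunc subclasses whose saved SymInt/Tensor state can be inspected and swapped out (e.g. for fake-ification); this revert restores the closure-based representation throughout the files below. A minimal standalone sketch of the two representations, using toy types and hypothetical names (ToyTensor, ToyNarrowViewFunc) rather than the real PyTorch classes:

// Toy illustration only; not the real at::Tensor / torch::autograd::ViewFunc API.
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <vector>

using ToyTensor = std::vector<int64_t>;  // stand-in for at::Tensor

// Closure style (what this revert restores): the saved arguments are captured
// by value and cannot be observed or modified afterwards.
std::function<ToyTensor(const ToyTensor&)> make_narrow_closure(int64_t start, int64_t length) {
  return [=](const ToyTensor& base) {
    return ToyTensor(base.begin() + start, base.begin() + start + length);
  };
}

// Reified style (roughly what the reverted PR introduced): the saved arguments
// live in named members, with an accessor and a clone-with-new-state operation.
struct ToyNarrowViewFunc {
  int64_t start;
  int64_t length;
  ToyTensor operator()(const ToyTensor& base) const {
    return ToyTensor(base.begin() + start, base.begin() + start + length);
  }
  std::vector<int64_t> get_ints() const { return {start, length}; }
  std::unique_ptr<ToyNarrowViewFunc> clone_and_set(std::vector<int64_t> ints) const {
    return std::make_unique<ToyNarrowViewFunc>(ToyNarrowViewFunc{ints[0], ints[1]});
  }
};

int main() {
  ToyTensor base{0, 1, 2, 3, 4, 5};
  auto closure = make_narrow_closure(/*start=*/1, /*length=*/3);
  ToyNarrowViewFunc reified{/*start=*/1, /*length=*/3};
  std::cout << closure(base).size() << "\n";      // 3; the closure's state is opaque
  auto shifted = reified.clone_and_set({2, 4});   // swap in new state before replay
  std::cout << (*shifted)(base).front() << "\n";  // 2; replay with the swapped state
}

The trade-off visible in the diff is exactly this: the closure form is simpler to generate and pass around, while the reified form exposes its captured state through get_*/clone_and_set-style accessors.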

View File

@@ -960,7 +960,6 @@ def define_buck_targets(
"Functions.h": ":gen_aten_libtorch[autograd/generated/Functions.h]",
"VariableType.h": ":gen_aten_libtorch[autograd/generated/VariableType.h]",
"variable_factories.h": ":gen_aten_libtorch[autograd/generated/variable_factories.h]",
"ViewFuncs.h": ":gen_aten_libtorch[autograd/generated/ViewFuncs.h]",
# Don't build python bindings on mobile.
#"python_functions.h",
},
@@ -1467,7 +1466,6 @@ def define_buck_targets(
"torch/csrc/jit/mobile/train/random.cpp",
"torch/csrc/jit/mobile/train/sequential.cpp",
":gen_aten_libtorch[autograd/generated/Functions.cpp]",
":gen_aten_libtorch[autograd/generated/ViewFuncs.cpp]",
],
compiler_flags = get_pt_compiler_flags(),
exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DUSE_MOBILE_CLASSTYPE"],

View File

@@ -261,7 +261,6 @@ _GENERATED_AUTOGRAD_PYTHON_HEADERS = [
_GENERATED_AUTOGRAD_CPP_HEADERS = [
"torch/csrc/autograd/generated/Functions.h",
"torch/csrc/autograd/generated/VariableType.h",
"torch/csrc/autograd/generated/ViewFuncs.h",
"torch/csrc/autograd/generated/variable_factories.h",
]
@@ -304,7 +303,6 @@ GENERATED_AUTOGRAD_CPP = [
"torch/csrc/autograd/generated/VariableType_2.cpp",
"torch/csrc/autograd/generated/VariableType_3.cpp",
"torch/csrc/autograd/generated/VariableType_4.cpp",
"torch/csrc/autograd/generated/ViewFuncs.cpp",
"torch/csrc/autograd/generated/TraceType_0.cpp",
"torch/csrc/autograd/generated/TraceType_1.cpp",
"torch/csrc/autograd/generated/TraceType_2.cpp",

View File

@@ -24,7 +24,6 @@ def libtorch_generated_sources(gencode_pattern):
"torch/csrc/autograd/generated/VariableType_2.cpp",
"torch/csrc/autograd/generated/VariableType_3.cpp",
"torch/csrc/autograd/generated/VariableType_4.cpp",
"torch/csrc/autograd/generated/ViewFuncs.cpp",
"torch/csrc/autograd/generated/TraceType_0.cpp",
"torch/csrc/autograd/generated/TraceType_1.cpp",
"torch/csrc/autograd/generated/TraceType_2.cpp",

View File

@@ -351,7 +351,6 @@ configure_file("${TORCH_SRC_DIR}/csrc/api/include/torch/version.h.in"
set(GENERATED_CXX_TORCH
"${TORCH_SRC_DIR}/csrc/autograd/generated/Functions.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/ViewFuncs.cpp"
)
if(NOT INTERN_DISABLE_AUTOGRAD AND NOT BUILD_LITE_INTERPRETER)
@@ -381,7 +380,6 @@ endif()
set(GENERATED_H_TORCH
"${TORCH_SRC_DIR}/csrc/autograd/generated/Functions.h"
"${TORCH_SRC_DIR}/csrc/autograd/generated/variable_factories.h"
"${TORCH_SRC_DIR}/csrc/autograd/generated/ViewFuncs.h"
)
if(NOT INTERN_DISABLE_AUTOGRAD)

View File

@@ -131,8 +131,6 @@ def get_generate_code_bin_outs():
"autograd/generated/VariableType_3.cpp": ["autograd/generated/VariableType_3.cpp"],
"autograd/generated/VariableType_4.cpp": ["autograd/generated/VariableType_4.cpp"],
"autograd/generated/variable_factories.h": ["autograd/generated/variable_factories.h"],
"autograd/generated/ViewFuncs.cpp": ["autograd/generated/ViewFuncs.cpp"],
"autograd/generated/ViewFuncs.h": ["autograd/generated/ViewFuncs.h"],
}
if is_arvr_mode():

View File

@@ -8792,7 +8792,6 @@ get_out().sum().backward()
_assert_match_metadata(new_inp, inp)
new_out = out._view_func(new_inp)
_assert_match_metadata(new_out, out)
self.assertEqual(new_out, out)
# reverse view_func
new_out = out.detach()
@@ -8831,7 +8830,7 @@ get_out().sum().backward()
_test_fn(
lambda x: x.chunk(2, -1)[0].transpose(0, 1).unsqueeze(-1), torch.randn(2, 3, 4))
_test_fn(
lambda x: x.split_with_sizes([1, 3], -1)[0].chunk(2, 0), torch.randn(2, 3, 4))
lambda x: x.split_with_sizes([1, 3], -1)[0].chunk(2, -1), torch.randn(2, 3, 4))
# chains with missing view_func()s use as_strided() to cover the gaps
def chain_with_only_parent_view_func(x):
@@ -8839,7 +8838,7 @@ get_out().sum().backward()
x = x.split_with_sizes([1, 3], -1)[0]
with torch.autograd._force_original_view_tracking(False):
x = x.chunk(2, 0)
x = x.chunk(2, -1)
return x
@@ -8850,50 +8849,12 @@ get_out().sum().backward()
x = x.split_with_sizes([1, 3], -1)[0]
with torch.autograd._force_original_view_tracking(True):
x = x.chunk(2, 0)
x = x.chunk(2, -1)
return x
_test_fn(chain_with_only_current_view_func, torch.randn(2, 3, 4))
def test_view_func_replay_with_modified_state(self):
with torch.autograd._force_original_view_tracking(True):
base = torch.randn(3, 4, 5)
view = base.select(1, 2)
def symint_visitor_fn(x):
# modify saved index
return x + 1
# ensure modifying state changes view replay
new_base = torch.randn_like(base)
new_view = view._view_func(new_base, symint_visitor_fn=symint_visitor_fn)
self.assertEqual(new_view, new_base.select(1, 3))
# ensure saved state reverts back afterwards
self.assertEqual(view._view_func(new_base), new_base.select(1, 2))
# check modifying tensor state. currently, slice_inverse() is the only
# view that saves a tensor
base = torch.randn(3, 4, 5)
sliced = base[:, 2:3, :].detach()
view = torch.ops.aten.slice_inverse(sliced, base, 1, 2, 3, 1)
replacement_shape = (1, 2, 3)
def tensor_visitor_fn(x):
# return tensor with a smaller shape than the saved one
return torch.randn(*replacement_shape)
# ensure modifying state changes view replay
new_sliced = torch.ones_like(base)[:, 2:3, :].detach()
new_view = view._view_func(new_sliced, tensor_visitor_fn=tensor_visitor_fn)
self.assertEqual(new_view.shape, replacement_shape)
self.assertEqual(new_view, new_sliced.as_strided(replacement_shape, (6, 3, 1)))
# ensure saved state reverts back afterwards
self.assertEqual(view._view_func(sliced), base)
def test_setup_context_when_forward_has_default_args(self):
class PowFunction(Function):
@staticmethod

View File

@@ -1564,7 +1564,6 @@ class TestCompositeCompliance(TestCase):
_assert_match_metadata(new_inp, inp)
new_out = out._view_func_unsafe(new_inp)
_assert_match_metadata(new_out, out)
self.assertEqual(new_out, out)
# reverse view_func
new_out = out.detach()

View File

@@ -124,8 +124,6 @@ def define_tools_targets(
"autograd/templates/TraceType.cpp",
"autograd/templates/VariableType.cpp",
"autograd/templates/VariableType.h",
"autograd/templates/ViewFuncs.cpp",
"autograd/templates/ViewFuncs.h",
"autograd/templates/annotated_fn_args.py.in",
"autograd/templates/python_enum_tag.cpp",
"autograd/templates/python_fft_functions.cpp",

View File

@@ -43,7 +43,6 @@ from .gen_inplace_or_view_type import gen_inplace_or_view_type
from .gen_trace_type import gen_trace_type
from .gen_variable_factories import gen_variable_factories
from .gen_variable_type import gen_variable_type
from .gen_view_funcs import gen_view_funcs
from .load_derivatives import load_derivatives
@@ -96,9 +95,6 @@ def gen_autograd(
# Generate variable_factories.h
gen_variable_factories(out, native_functions_path, tags_path, template_path)
# Generate ViewFuncs.h/cpp
gen_view_funcs(out, fns_with_diff_infos, template_path)
def gen_autograd_python(
native_functions_path: str,

View File

@@ -4,7 +4,7 @@
# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
# The fallback is expected to mimic this codegen, so we should keep the two in sync.
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Sequence, Tuple
from torchgen.api import cpp
from torchgen.api.autograd import (
@@ -172,7 +172,7 @@ for (auto ${view_idx} : c10::irange(${var}.size())) {
SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate(
"""\
std::unique_ptr<torch::autograd::ViewFunc> func(nullptr);
std::function<at::Tensor(const at::Tensor&)> func=nullptr;
std::function<at::Tensor(const at::Tensor&)> rev_func=nullptr;
if (${is_view_with_metadata_change} ||
!self.unsafeGetTensorImpl()->support_as_strided() ||
@@ -184,9 +184,11 @@ if (${is_view_with_metadata_change} ||
"""
)
REPLAY_VIEW_FUNC = CodeTemplate(
REPLAY_VIEW_LAMBDA_FUNC = CodeTemplate(
"""\
func = std::make_unique<${view_func_name}>(${view_func_args});
func = [=](const at::Tensor& ${input_base}) {
return ${replay_view_call}${view_indexing};
};
"""
)
@@ -344,13 +346,24 @@ def get_view_info(f: NativeFunction) -> Optional[str]:
return view_info
def emit_view_func(
# For view replay calls, we generate an ordinary Dispatcher::call() instead, because:
# - We want to replay the entire call into the op, including any previously-set dispatch keys (including autograd!).
# - The view replay call also is not part of the hot path.
def emit_view_call(
f: NativeFunction, input_base: str, unpacked_args: Sequence[str]
) -> str:
# View replay functions use the standard Dispatcher::call API.
return CALL_DISPATCH.substitute(
unambiguous_name=f.func.name.unambiguous_name(), unpacked_args=unpacked_args
)
def emit_view_lambda(
f: NativeFunction, bindings: List[Binding], view_idx: Optional[str] = None
) -> str:
"""Generate an additional lambda function to recover views in backward when as_strided is not supported.
See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details.
"""
# TODO: Clean this logic up if we get rid of reverse view funcs or reify them.
input_base = "input_base"
replay_view_func = ""
updated_args: List[str] = []
@@ -399,14 +412,11 @@ def emit_view_func(
else:
updated_args.append(arg)
from .gen_view_funcs import view_func_name
view_func_args = [b.name for b in bindings if b.name != "self"]
if view_idx is not None:
view_func_args.append(f"{view_idx}")
replay_view_func += REPLAY_VIEW_FUNC.substitute(
view_func_name=view_func_name(f, include_namespace=True),
view_func_args=view_func_args,
replay_view_call = emit_view_call(f, input_base, updated_args)
replay_view_func += REPLAY_VIEW_LAMBDA_FUNC.substitute(
input_base=input_base,
replay_view_call=replay_view_call,
view_indexing=("" if view_idx is None else f"[{view_idx}]"),
)
input_view = "input_view"
@@ -483,26 +493,26 @@ def emit_view_body(
if is_tensor_list_type(return_info.type):
creation_meta = get_creation_meta_in_mode("CreationMeta::MULTI_OUTPUT_NODE")
view_idx = "view_idx"
view_func = emit_view_func(
view_lambda = emit_view_lambda(
f, extract_bindings(f), view_idx=view_idx
).strip()
as_view_call = (
f"as_view(/* base */ {view_info}, /* output */ {var}[{view_idx}], "
"/* is_bw_differentiable */ true, /* is_fw_differentiable */ true, "
"/* view_func */ std::move(func), /* rev_view_func */ rev_func, "
"/* view_func */ func, /* rev_view_func */ rev_func, "
f"/* creation_meta */ {creation_meta});"
)
call += MULTI_OUTPUT_VIEW_ITERATION.substitute(
var=var, view_idx=view_idx, body=f"{view_func}\n{as_view_call}"
var=var, view_idx=view_idx, body=f"{view_lambda}\n{as_view_call}"
)
rhs_value = f"std::move({var})"
else:
call += emit_view_func(f, extract_bindings(f), view_idx=None)
call += emit_view_lambda(f, extract_bindings(f), view_idx=None)
creation_meta = get_creation_meta_in_mode("CreationMeta::DEFAULT")
rhs_value = (
f"as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, "
"/* is_fw_differentiable */ true, "
f"/* view_func */ std::move(func), /* rev_view_func */ rev_func, /* creation_meta */ {creation_meta})"
f"/* view_func */ func, /* rev_view_func */ rev_func, /* creation_meta */ {creation_meta})"
)
else:
# This could be supported but we don't need it at the moment, so keeping things simple.

View File

@@ -1,334 +0,0 @@
# Generates ViewFuncs.h/cpp
#
# NOTE: If any changes are being made to the ViewFunc codegen please also check
# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
# The fallback is expected to mimic this codegen, so we should keep the two in sync.
from typing import List, Tuple
import torchgen.api.dispatcher as dispatcher
from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo
from torchgen.api.translate import translate
from torchgen.api.types import (
BaseCType,
Binding,
NamedCType,
SymIntT,
tensorT,
VectorCType,
)
from torchgen.code_template import CodeTemplate
from torchgen.model import Argument, NativeFunction, OptionalType
from torchgen.utils import FileManager
from .gen_inplace_or_view_type import (
CALL_DISPATCH,
extract_bindings,
get_view_info,
modifies_arguments,
use_derived,
)
FUNCTION_DECLARATION = CodeTemplate(
"""\
#define ${uppercase_op}_AVAILABLE
struct ${op} : public ${superclass} {
${op}(${constructor_args}) ${initializer_list}
{};
virtual ~${op}() override {};
virtual std::vector<c10::SymInt> get_symints() const override;
virtual size_t num_symints() const override;
virtual std::vector<at::Tensor> get_tensors() const override;
virtual size_t num_tensors() const override;
virtual at::Tensor operator()(const at::Tensor&) const override;
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;
protected:
virtual void set_symints(std::vector<c10::SymInt>) override;
virtual void set_tensors(std::vector<at::Tensor>) override;
private:
${state}
};
"""
)
FUNCTION_DEFINITION = CodeTemplate(
"""\
std::vector<c10::SymInt> ${op}::get_symints() const {
${get_symints}
}
size_t ${op}::num_symints() const {
return static_cast<size_t>(${num_symints});
}
void ${op}::set_symints(std::vector<c10::SymInt> ${symints_vec}) {
TORCH_INTERNAL_ASSERT(${symints_vec}.size() == num_symints());
${set_symints}
}
std::vector<at::Tensor> ${op}::get_tensors() const {
${get_tensors}
}
size_t ${op}::num_tensors() const {
return static_cast<size_t>(${num_tensors});
}
void ${op}::set_tensors(std::vector<at::Tensor> ${tensors_vec}) {
TORCH_INTERNAL_ASSERT(${tensors_vec}.size() == num_tensors());
${set_tensors}
}
at::Tensor ${op}::operator()(const at::Tensor& ${call_input_name}) const {
return ${op_call};
}
std::unique_ptr<ViewFunc> ${op}::clone_and_set(
std::optional<std::vector<c10::SymInt>> ${symints_vec},
std::optional<std::vector<at::Tensor>> ${tensors_vec}) const {
auto output = std::make_unique<${op}>(${clone_args});
if (${symints_vec}.has_value()) {
output->set_symints(std::move(*(${symints_vec})));
}
if (${tensors_vec}.has_value()) {
output->set_tensors(std::move(*(${tensors_vec})));
}
return output;
}
"""
)
# e.g. as_strided -> AsStridedViewFunc for camel case or
# as_strided_view_func otherwise
def view_func_name(
f: NativeFunction, include_namespace: bool = False, camel_case: bool = True
) -> str:
name = f.func.name.unambiguous_name()
view_func_name = f"{name.replace('.', '_')}_view_func"
if camel_case:
is_private = view_func_name.startswith("_")
view_func_name = "".join(
[p.title() for p in view_func_name.replace(".", "_").split("_")]
)
if is_private:
# put the leading underscore back in
view_func_name = f"_{view_func_name}"
namespace = "torch::autograd::generated::" if include_namespace else ""
return f"{namespace}{view_func_name}"
def is_symint_or_tensor(arg: Argument) -> bool:
return arg.type.is_tensor_like() or arg.type.is_symint_like()
def remove_const_ref(binding: Binding) -> Binding:
return Binding(
name=binding.name,
nctype=binding.nctype.remove_const_ref(),
argument=binding.argument,
default=binding.default,
)
def returns_multi_tensor(fn: NativeFunction) -> bool:
returns = fn.func.returns
assert len(returns) == 1
returns_list_like = returns[0].type.is_list_like() is not None
returns_tensor_like = returns[0].type.is_tensor_like()
return returns_list_like and returns_tensor_like
# Generates strings with logic for getting / setting state of a particular type.
#
# Args:
# bindings (list): List of state bindings of interest (may be empty)
# state_vec_type (NamedCType): Type of vector to either return or copy from
#
# Returns:
# tuple: (list of getter logic strings, list of setter logic strings, string
# with num items expression)
def generate_state_getter_setter(
bindings: List[Binding],
state_vec_type: NamedCType,
) -> Tuple[List[str], List[str], str]:
getter_logic = []
setter_logic = []
state_vec = state_vec_type.name
getter_logic.append(f"{state_vec_type.cpp_type()} {state_vec};")
if len(bindings) > 0:
setter_logic.append("auto i = 0;")
num_exprs = []
for i, b in enumerate(bindings):
assert isinstance(b.argument, Argument)
if b.argument.type.is_list_like():
# Handle list-likes.
num_expr = f"{b.name}.size()"
num_exprs.append(num_expr)
getter = f"{state_vec}.insert({state_vec}.end(), {b.name}.begin(), {b.name}.end());"
setter = f"std::copy({state_vec}.begin() + i, {state_vec}.begin() + i + {b.name}.size(), {b.name}.begin());"
elif isinstance(b.argument.type, OptionalType):
# Handle optionals.
num_expr = f"({b.name}.has_value() ? 1 : 0)"
num_exprs.append(num_expr)
conditional = f"if({b.name}.has_value())"
getter = (
f"{conditional} {state_vec}.insert({state_vec}.end(), *({b.name}));"
)
setter = f"{conditional} {b.name} = {state_vec}[i];"
else:
num_expr = "1"
num_exprs.append(num_expr)
getter = f"{state_vec}.push_back({b.name});"
setter = f"{b.name} = {state_vec}[i];"
getter_logic.append(getter)
setter_logic.append(setter)
if i < len(bindings) - 1:
setter_logic.append(f"i += {num_expr};")
# Reserve / assert based on the total number of items expression.
num_items = "0" if len(num_exprs) == 0 else " + ".join(num_exprs)
if len(bindings) > 0:
getter_logic.insert(1, f"{state_vec}.reserve({num_items});")
getter_logic.append(f"return {state_vec};")
return getter_logic, setter_logic, num_items
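To make the string assembly in generate_state_getter_setter() above easier to follow, here is a hand-written approximation of the C++ it would emit for a hypothetical view op whose non-self state is a list-like size, an optional start, and a plain dim. Toy types (int64_t instead of c10::SymInt, no Tensor case) and names are illustrative only:

// Hand-written sketch mirroring the generated getter/setter strings; not generated code.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

struct ToySavedState {
  std::vector<int64_t> size;     // list-like binding
  std::optional<int64_t> start;  // optional binding
  int64_t dim;                   // plain binding

  // Getter: flatten all saved items into one vector, reserving the total count.
  std::vector<int64_t> get_ints() const {
    std::vector<int64_t> ints;
    ints.reserve(size.size() + (start.has_value() ? 1 : 0) + 1);
    ints.insert(ints.end(), size.begin(), size.end());
    if (start.has_value()) ints.push_back(*start);
    ints.push_back(dim);
    return ints;
  }

  // Setter: copy the flat vector back into the members using a running index i.
  void set_ints(std::vector<int64_t> ints) {
    assert(ints.size() == size.size() + (start.has_value() ? 1 : 0) + 1);
    size_t i = 0;
    std::copy(ints.begin() + i, ints.begin() + i + size.size(), size.begin());
    i += size.size();
    if (start.has_value()) start = ints[i];
    i += (start.has_value() ? 1 : 0);
    dim = ints[i];
  }
};

int main() {
  ToySavedState state{{2, 3}, 7, 1};
  for (auto v : state.get_ints()) std::cout << v << " ";  // 2 3 7 1
  state.set_ints({4, 5, 8, 0});                           // round-trips back into the members
  std::cout << "\n" << state.size[0] << " " << *state.start << " " << state.dim << "\n";  // 4 8 0
}

The real generated code performs this flattening separately for SymInt state (get_symints/set_symints) and Tensor state (get_tensors/set_tensors), as shown in the FUNCTION_DEFINITION template above.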
def process_function(fn: NativeFunction, template: CodeTemplate) -> str:
bindings = extract_bindings(fn)
non_self_bindings = [b for b in bindings if b.name != "self"]
non_self_args = fn.func.arguments.flat_all[1:]
non_self_value_bindings = [
dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
]
# Generate constructor / clone args for the generated struct.
constructor_args = [b.defn() for b in non_self_bindings]
clone_args = [b.name for b in non_self_bindings]
# Generate state variable declarations for the generated struct.
state_variables = [
f"{remove_const_ref(b).defn()};" for b in non_self_value_bindings
]
# Generate initializer list expressions for the generated struct.
# allow_expensive_conversions=True because we need to store e.g. SymIntArrayRefs as
# vector<SymInt>s.
init_exprs = translate(
non_self_bindings, non_self_value_bindings, allow_expensive_conversions=True
)
initializers = []
for b, init_expr in zip(non_self_bindings, init_exprs):
name = b.nctype.name
assert isinstance(name, str)
initializers.append(f"{name}({init_expr.expr})")
# Generate call to underlying view op
call_input_name = "input_base"
op_call_args = [call_input_name, *(b.name for b in non_self_bindings)]
op_call = CALL_DISPATCH.substitute(
unambiguous_name=fn.func.name.unambiguous_name(),
unpacked_args=op_call_args,
)
# Multi-output views additionally require a view_idx for disambiguation.
if returns_multi_tensor(fn):
view_idx_name = "view_idx"
view_idx_typename = "int64_t"
view_idx_decl = f"{view_idx_typename} {view_idx_name}"
constructor_args.append(view_idx_decl)
clone_args.append(view_idx_name)
state_variables.append(f"{view_idx_decl};")
initializers.append(f"{view_idx_name}({view_idx_name})")
op_call += f"[{view_idx_name}]"
# Generate initializer list for the generated struct.
initializer_list = f": {', '.join(initializers)}" if len(initializers) > 0 else ""
# Generate getter / setter logic for any symints.
symint_bindings = [
b
for b in non_self_bindings
if isinstance(b.argument, Argument) and b.argument.type.is_symint_like()
]
symints_vec_type = NamedCType("symints", VectorCType(BaseCType(SymIntT)))
get_symints, set_symints, num_symints = generate_state_getter_setter(
symint_bindings, symints_vec_type
)
# Generate getter / setter logic for any tensors.
tensor_bindings = [
b
for b in non_self_bindings
if isinstance(b.argument, Argument) and b.argument.type.is_tensor_like()
]
tensors_vec_type = NamedCType("tensors", VectorCType(BaseCType(tensorT)))
get_tensors, set_tensors, num_tensors = generate_state_getter_setter(
tensor_bindings, tensors_vec_type
)
return template.substitute(
op=view_func_name(fn),
uppercase_op=view_func_name(fn, camel_case=False).upper(),
superclass="torch::autograd::ViewFunc",
initializer_list=initializer_list,
state=state_variables,
constructor_args=constructor_args,
clone_args=clone_args,
symints_vec=symints_vec_type.name,
get_symints=get_symints,
set_symints=set_symints,
num_symints=num_symints,
tensors_vec=tensors_vec_type.name,
get_tensors=get_tensors,
set_tensors=set_tensors,
num_tensors=num_tensors,
call_input_name=call_input_name,
op_call=op_call,
)
def gen_view_funcs(
out: str,
fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
template_path: str,
) -> None:
# don't need the info parts, just the function
fns = [fn.func for fn in fns_with_infos if use_derived(fn)]
# only want out-of-place views
view_fns = [
fn for fn in fns if get_view_info(fn) is not None and not modifies_arguments(fn)
]
declarations = [process_function(fn, FUNCTION_DECLARATION) for fn in view_fns]
definitions = [process_function(fn, FUNCTION_DEFINITION) for fn in view_fns]
ops_headers = [f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in view_fns]
file_basename = "ViewFuncs"
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
for suffix in [".h", ".cpp"]:
fname = file_basename + suffix
fm.write_with_template(
fname,
fname,
lambda: {
"generated_comment": "@"
+ f"generated from {fm.template_dir_for_comments()}/"
+ fname,
"view_func_declarations": declarations,
"view_func_definitions": definitions,
"ops_headers": ops_headers,
},
)

View File

@@ -1,6 +1,5 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include "torch/csrc/autograd/VariableTypeUtils.h"
#include "torch/csrc/autograd/generated/ViewFuncs.h"
#include <torch/library.h>
#include <ATen/FunctionalInverses.h>

View File

@@ -1,71 +0,0 @@
#include "torch/csrc/autograd/generated/ViewFuncs.h"
// ${generated_comment}
using at::Tensor;
using at::Scalar;
using at::IntArrayRef;
using at::TensorList;
namespace torch::autograd {
std::vector<c10::SymInt> ChainedViewFunc::get_symints() const {
auto symints = first->get_symints();
auto second_symints = second->get_symints();
symints.reserve(symints.size() + second_symints.size());
symints.insert(
symints.end(),
std::make_move_iterator(second_symints.begin()),
std::make_move_iterator(second_symints.end()));
return symints;
}
std::vector<at::Tensor> ChainedViewFunc::get_tensors() const {
auto tensors = first->get_tensors();
auto second_tensors = second->get_tensors();
tensors.reserve(tensors.size() + second_tensors.size());
tensors.insert(
tensors.end(),
std::make_move_iterator(second_tensors.begin()),
std::make_move_iterator(second_tensors.end()));
return tensors;
}
at::Tensor ChainedViewFunc::operator()(const at::Tensor& input_base) const {
return (*second)((*first)(input_base));
}
std::unique_ptr<ViewFunc> ChainedViewFunc::clone_and_set(
std::optional<std::vector<c10::SymInt>> symints,
std::optional<std::vector<at::Tensor>> tensors) const {
std::optional<std::vector<c10::SymInt>> first_symints;
std::optional<std::vector<c10::SymInt>> second_symints;
if (symints.has_value()) {
TORCH_INTERNAL_ASSERT(symints->size() == num_symints());
first_symints = std::vector<c10::SymInt>(
symints->begin(), symints->begin() + first->num_symints());
second_symints = std::vector<c10::SymInt>(
symints->begin() + first->num_symints(), symints->end());
}
std::optional<std::vector<at::Tensor>> first_tensors;
std::optional<std::vector<at::Tensor>> second_tensors;
if (tensors.has_value()) {
TORCH_INTERNAL_ASSERT(tensors->size() == num_tensors());
first_tensors = std::vector<at::Tensor>(
tensors->begin(), tensors->begin() + first->num_tensors());
second_tensors = std::vector<at::Tensor>(
tensors->begin() + first->num_tensors(), tensors->end());
}
return std::make_unique<ChainedViewFunc>(
first->clone_and_set(first_symints, first_tensors),
second->clone_and_set(second_symints, second_tensors));
}
namespace generated {
${view_func_definitions}
} // namespace generated
} // namespace torch::autograd
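The removed ChainedViewFunc above composes two ViewFuncs: applying the chain applies the first then the second, the saved state is exposed as one flat vector, and clone_and_set() splits an incoming flat vector between the two halves according to how many items each owns. A self-contained toy sketch of that behaviour (stand-in types and names, not the real torch::autograd classes):

// Toy composition sketch; mirrors the shape of ChainedViewFunc, not its API.
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

using ToyTensor = std::vector<int64_t>;

struct ToyFunc {
  virtual ~ToyFunc() = default;
  virtual ToyTensor operator()(const ToyTensor&) const = 0;
  virtual std::vector<int64_t> get_ints() const = 0;
  virtual size_t num_ints() const = 0;
  virtual std::unique_ptr<ToyFunc> clone_and_set(std::vector<int64_t>) const = 0;
};

struct ToyNarrow : ToyFunc {
  int64_t start, length;
  ToyNarrow(int64_t s, int64_t l) : start(s), length(l) {}
  ToyTensor operator()(const ToyTensor& base) const override {
    return ToyTensor(base.begin() + start, base.begin() + start + length);
  }
  std::vector<int64_t> get_ints() const override { return {start, length}; }
  size_t num_ints() const override { return 2; }
  std::unique_ptr<ToyFunc> clone_and_set(std::vector<int64_t> ints) const override {
    return std::make_unique<ToyNarrow>(ints[0], ints[1]);
  }
};

struct ToyChained : ToyFunc {
  std::unique_ptr<ToyFunc> first, second;
  ToyChained(std::unique_ptr<ToyFunc> f, std::unique_ptr<ToyFunc> s)
      : first(std::move(f)), second(std::move(s)) {}
  // Apply first, then second.
  ToyTensor operator()(const ToyTensor& base) const override { return (*second)((*first)(base)); }
  // Concatenate the two halves' state into one flat vector.
  std::vector<int64_t> get_ints() const override {
    auto ints = first->get_ints();
    auto second_ints = second->get_ints();
    ints.insert(ints.end(), second_ints.begin(), second_ints.end());
    return ints;
  }
  size_t num_ints() const override { return first->num_ints() + second->num_ints(); }
  // Split an incoming flat vector by the first half's item count.
  std::unique_ptr<ToyFunc> clone_and_set(std::vector<int64_t> ints) const override {
    std::vector<int64_t> a(ints.begin(), ints.begin() + first->num_ints());
    std::vector<int64_t> b(ints.begin() + first->num_ints(), ints.end());
    return std::make_unique<ToyChained>(first->clone_and_set(a), second->clone_and_set(b));
  }
};

int main() {
  ToyTensor base{0, 1, 2, 3, 4, 5, 6, 7};
  ToyChained chain(std::make_unique<ToyNarrow>(1, 6), std::make_unique<ToyNarrow>(2, 3));
  std::cout << chain(base).front() << "\n";          // narrow(1,6) then narrow(2,3) -> 3
  auto swapped = chain.clone_and_set({0, 6, 1, 2});  // flat state split 2 / 2
  std::cout << (*swapped)(base).front() << "\n";     // 1
}

This is the machinery that the symint_visitor_fn / tensor_visitor_fn hooks in the removed Python test relied on: visit the flat state, then clone_and_set() routes each piece back to the right sub-func.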

View File

@@ -1,106 +0,0 @@
#pragma once
// ${generated_comment}
#include <torch/library.h>
#include <c10/core/SymIntArrayRef.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Operators.h>
#else
$ops_headers
#endif
namespace torch { namespace autograd {
/// Base class for view functions, providing reapplication of a view on a new base.
/// Each view op should get a codegenerated subclass of this class containing
/// any state needed to reconstruct the view. The class also provides convenience
/// accessors for saved SymInts / tensor state. This is useful for e.g. fake-ification,
/// where we want to use symbolic values or fake tensors instead.
struct TORCH_API ViewFunc {
virtual ~ViewFunc() {}
/// Returns any SymInts in the saved state.
virtual std::vector<c10::SymInt> get_symints() const { return {}; }
/// Returns the number of SymInts in the saved state.
virtual size_t num_symints() const { return 0; }
/// Returns any tensors in the saved state.
virtual std::vector<at::Tensor> get_tensors() const { return {}; }
/// Returns the number of tensors in the saved state.
virtual size_t num_tensors() const { return 0; }
/// Reapplies the view on the given base using the saved state.
virtual at::Tensor operator()(const at::Tensor&) const = 0;
/// Returns a clone of this ViewFunc, optionally with the specified saved state.
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const = 0;
protected:
/// Sets the values of any SymInts in the saved state. The input vector size must
/// match the number of SymInts in the saved state (i.e. the size of the list
/// returned by get_symints()).
virtual void set_symints(std::vector<c10::SymInt>) {}
/// Sets the values of any Tensors in the saved state. The input vector size must
/// match the number of Tensors in the saved state (i.e. the size of the list
/// returned by get_tensors()).
virtual void set_tensors(std::vector<at::Tensor>) {}
};
/// ViewFunc that represents a chain of two ViewFuncs.
struct ChainedViewFunc : public ViewFunc {
ChainedViewFunc(
std::unique_ptr<ViewFunc> first,
std::unique_ptr<ViewFunc> second)
: first(std::move(first)),
second(std::move(second)) {}
virtual ~ChainedViewFunc() override {};
virtual std::vector<c10::SymInt> get_symints() const override;
virtual size_t num_symints() const override {
return first->num_symints() + second->num_symints();
}
virtual std::vector<at::Tensor> get_tensors() const override;
virtual size_t num_tensors() const override {
return first->num_tensors() + second->num_tensors();
}
virtual at::Tensor operator()(const at::Tensor&) const override;
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;
private:
std::unique_ptr<ViewFunc> first;
std::unique_ptr<ViewFunc> second;
};
/// ViewFunc that errors with a specified error message when called.
struct ErroringViewFunc : public ViewFunc {
ErroringViewFunc(const std::string& error_msg) : error_msg(error_msg) {}
virtual ~ErroringViewFunc() override {};
virtual at::Tensor operator()(const at::Tensor&) const override {
TORCH_CHECK(false, error_msg);
}
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const override {
return std::make_unique<ErroringViewFunc>(error_msg);
}
private:
std::string error_msg;
};
namespace generated {
using at::Scalar;
using at::Tensor;
using at::IntArrayRef;
using at::ArrayRef;
using at::Type;
using at::ScalarType;
using c10::optional;
using c10::fmap;
${view_func_declarations}
}}} // namespace torch::autograd::generated

View File

@@ -9,7 +9,6 @@
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/autograd/generated/ViewFuncs.h>
#include <torch/library.h>
#include <utility>
@@ -475,10 +474,13 @@ static Tensor _fw_primal(
at::AutoDispatchBelowADInplaceOrView guard;
return at::alias(self);
})();
std::unique_ptr<torch::autograd::ViewFunc> func(nullptr);
std::function<at::Tensor(const at::Tensor&)> func = nullptr;
std::function<at::Tensor(const at::Tensor&)> rev_func = nullptr;
if (!self.unsafeGetTensorImpl()->support_as_strided()) {
func = std::make_unique<ViewViewFunc>(self.sym_sizes());
auto size_vec = self.sizes().vec();
func = [=](const at::Tensor& input_base) {
return input_base.view(size_vec);
};
rev_func = [=](const at::Tensor& input_view) {
TORCH_INTERNAL_ASSERT(
false,
@@ -508,10 +510,13 @@ static Tensor _make_dual(
at::AutoDispatchBelowADInplaceOrView guard;
return at::alias(primal);
})();
std::unique_ptr<torch::autograd::ViewFunc> func(nullptr);
std::function<at::Tensor(const at::Tensor&)> func = nullptr;
std::function<at::Tensor(const at::Tensor&)> rev_func = nullptr;
if (!primal.unsafeGetTensorImpl()->support_as_strided()) {
func = std::make_unique<ViewViewFunc>(primal.sym_sizes());
auto size_vec = primal.sizes().vec();
func = [=](const at::Tensor& input_base) {
return input_base.view(size_vec);
};
rev_func = [=](const at::Tensor& input_view) {
TORCH_INTERNAL_ASSERT(
false,

View File

@@ -160,7 +160,7 @@ inline at::Tensor as_view(
const at::Tensor& tensor,
bool is_bw_differentiable,
bool is_fw_differentiable,
std::unique_ptr<ViewFunc> view_func = nullptr,
std::function<at::Tensor(const at::Tensor&)> view_func = nullptr,
std::function<at::Tensor(const at::Tensor&)> rev_view_func = nullptr,
CreationMeta creation_meta = CreationMeta::DEFAULT,
bool allow_tensor_metadata_change = true) {
@@ -208,13 +208,11 @@ inline at::Tensor as_view(
c10::optional<ViewInfo> new_fw_info;
if (is_bw_differentiable) {
auto bw_view_func = view_func ? view_func->clone_and_set() : nullptr;
if (diff_view_meta && diff_view_meta->has_bw_view()) {
const auto& base_bw_info = diff_view_meta->get_backward_view();
new_bw_info = base_bw_info.chain(
base, tensor, std::move(bw_view_func), rev_view_func);
new_bw_info = base_bw_info.chain(base, tensor, view_func, rev_view_func);
} else {
new_bw_info = ViewInfo(base, std::move(bw_view_func), rev_view_func);
new_bw_info = ViewInfo(base, view_func, rev_view_func);
}
} else {
TORCH_CHECK(

View File

@@ -213,7 +213,7 @@ void AutogradMeta::set_fw_grad(
// - Copy the given new_grad into this view
// - Use this view as the new new_grad
if (this_view_meta->has_fw_view()) {
auto& view_info = this_view_meta->get_forward_view();
auto view_info = this_view_meta->get_forward_view();
auto& base = view_info.base_;
if (!base._fw_grad(level).defined()) {

View File

@@ -530,14 +530,18 @@ static void autogradNotImplementedInplaceOrViewFallbackImpl(
(*stack)[stack->size() - num_returns + aliased_output_idx];
// See NOTE [ View + Inplace detection ] for more details about this logic
const auto erroring_view_func = [op_name = op_name](const at::Tensor&) {
// We always need this view_func because otherwise if we do in-place
// on this view, we would implicitly use AsStridedBackward instead
// of the NotImplemented node. For the cross-dtype/non-strided
// cases, we would create something like this anyway
auto error_msg =
("Mutating the view " + op_name +
TORCH_CHECK(
false,
"Mutating the view ",
op_name,
" which does not have a derivative implemented is forbidden.");
auto erroring_view_func = std::make_unique<ErroringViewFunc>(error_msg);
return at::Tensor();
};
const auto erroring_rev_view_func = [op_name = op_name](const at::Tensor&) {
TORCH_CHECK(
@@ -556,7 +560,7 @@ static void autogradNotImplementedInplaceOrViewFallbackImpl(
/* tensor=*/sub_output,
/* is_bw_differentiable=*/true,
/* is_fw_differentiable=*/true,
/* view_func=*/std::move(erroring_view_func),
/* view_func=*/erroring_view_func,
/* rev_view_func=*/erroring_rev_view_func,
/* creation_meta=*/
InferenceMode::is_enabled()
@@ -573,7 +577,7 @@ static void autogradNotImplementedInplaceOrViewFallbackImpl(
/* tensor=*/std::move(aliased_output_iv).toTensor(),
/* is_bw_differentiable=*/true,
/* is_fw_differentiable=*/true,
/* view_func=*/std::move(erroring_view_func),
/* view_func=*/erroring_view_func,
/* rev_view_func=*/erroring_rev_view_func,
/* creation_meta=*/
InferenceMode::is_enabled()

View File

@@ -55,7 +55,7 @@ variable_list CopyBackwards::apply_with_saved(
CopySlices::CopySlices(
const Variable& base_var,
at::TensorGeometry view_,
std::unique_ptr<ViewFunc> view_fn_,
std::function<at::Tensor(const at::Tensor&)> view_fn_,
std::shared_ptr<Node> fn_)
: Node(),
base(base_var),
@@ -98,7 +98,7 @@ inline variable_list CopySlices::apply_impl(
at::Tensor grad_slice;
if (view_fn) {
grad_slice = (*view_fn)(result);
grad_slice = view_fn(result);
} else {
auto offset = view.sym_storage_offset() - base.sym_storage_offset();
grad_slice =

View File

@@ -79,7 +79,7 @@ struct TORCH_API CopyBackwards : public Node {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//
// We need to perform grad_view = fn(grad_view), but out-of-place.
// view_fn_ is an optional function saved in DifferentiableViewMeta
// view_fn_ is an optional lambda function saved in DifferentiableViewMeta
// from forward pass, so that we can recover the view when as_strided is not
// supported. It preserves the invariants:
// view = view_fn_(base)
@@ -160,7 +160,7 @@ struct TORCH_API CopySlices : public Node {
CopySlices(
const Variable& base_var,
at::TensorGeometry view_,
std::unique_ptr<ViewFunc> view_fn_,
std::function<at::Tensor(const at::Tensor&)> view_fn_,
std::shared_ptr<Node> fn_);
// common code between apply/apply_with_saved
@@ -178,7 +178,7 @@ struct TORCH_API CopySlices : public Node {
// view and view_fn are redundant and view_fn will be used if available.
// See Note [View + Inplace update for base tensor] for details.
at::TensorGeometry view;
std::unique_ptr<ViewFunc> view_fn;
std::function<at::Tensor(const at::Tensor&)> view_fn;
std::shared_ptr<Node> fn;
};

View File

@@ -524,36 +524,16 @@ static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) {
Py_RETURN_NONE;
}
// Maps the given python callable over a vector of items, returning a vector
// of the same type of items.
template <typename T>
static std::vector<T> map_py_func(
const py::function& func,
const std::vector<T>& items) {
std::vector<T> new_items;
new_items.reserve(items.size());
for (auto& item : items) {
new_items.push_back(py::cast<T>(func(item)));
}
return new_items;
}
static PyObject* view_func_impl(
PyObject* _self,
PyObject* args,
PyObject* kwargs,
PyObject* self_,
PyObject* arg,
bool check_has_same_meta) {
HANDLE_TH_ERRORS
const auto& self = THPVariable_Unpack(_self);
static PythonArgParser parser({
"_view_func(Tensor new_base, PyObject* symint_visitor_fn=None, PyObject* tensor_visitor_fn=None)",
});
ParsedArgs<3> parsed_args{};
auto r = parser.parse(_self, args, kwargs, parsed_args);
auto new_base = r.tensor(0);
PyObject* symint_visitor_fn = r.pyobject(1);
PyObject* tensor_visitor_fn = r.pyobject(2);
const auto& self = THPVariable_Unpack(self_);
TORCH_CHECK(
THPVariable_Check(arg),
"_view_func expect a single argument that is a Tensor");
const auto& new_base = THPVariable_Unpack(arg);
// Ensure that self is indeed a backward differentiable view
// If not, we return an undefined Tensor (None) and let the user handle it.
@@ -566,29 +546,7 @@ static PyObject* view_func_impl(
torch::autograd::utils::has_same_meta(new_base, view_info.base_)) {
// Do the actual view replay
if (view_info.has_view_fn()) {
auto& view_func = view_info.view_fn();
// Determine new SymInt / tensor state as needed.
c10::optional<std::vector<c10::SymInt>> new_symints = c10::nullopt;
if (symint_visitor_fn != Py_None) {
new_symints = map_py_func(
py::cast<py::function>(symint_visitor_fn),
view_func.get_symints());
}
c10::optional<std::vector<at::Tensor>> new_tensors = c10::nullopt;
if (tensor_visitor_fn != Py_None) {
new_tensors = map_py_func(
py::cast<py::function>(tensor_visitor_fn),
view_func.get_tensors());
}
// call view func
if (new_symints.has_value() || new_tensors.has_value()) {
out = (*view_func.clone_and_set(new_symints, new_tensors))(new_base);
} else {
out = view_func(new_base);
}
out = view_info.view_fn()(new_base);
} else {
out = new_base.as_strided(
self.sizes(), self.strides(), self.storage_offset());
@@ -599,18 +557,12 @@ static PyObject* view_func_impl(
END_HANDLE_TH_ERRORS
}
static PyObject* THPVariable_view_func(
PyObject* self_,
PyObject* args,
PyObject* kwargs) {
return view_func_impl(self_, args, kwargs, /*check_has_same_meta=*/true);
static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
return view_func_impl(self_, arg, /*check_has_same_meta=*/true);
}
static PyObject* THPVariable_view_func_unsafe(
PyObject* self_,
PyObject* args,
PyObject* kwargs) {
return view_func_impl(self_, args, kwargs, /*check_has_same_meta=*/false);
static PyObject* THPVariable_view_func_unsafe(PyObject* self_, PyObject* arg) {
return view_func_impl(self_, arg, /*check_has_same_meta=*/false);
}
static PyObject* rev_view_func_impl(PyObject* self_, PyObject* arg) {
@@ -1715,14 +1667,8 @@ static PyMethodDef extra_methods[] = {
METH_STATIC | METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr},
{"_view_func",
castPyCFunctionWithKeywords(THPVariable_view_func),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_view_func_unsafe",
castPyCFunctionWithKeywords(THPVariable_view_func_unsafe),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_view_func", THPVariable_view_func, METH_O, nullptr},
{"_view_func_unsafe", THPVariable_view_func_unsafe, METH_O, nullptr},
{"_rev_view_func_unsafe",
THPVariable_rev_view_func_unsafe,
METH_O,

View File

@@ -8,7 +8,6 @@
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/functions/tensor.h>
#include <torch/csrc/autograd/generated/Functions.h>
#include <torch/csrc/autograd/generated/ViewFuncs.h>
#include <torch/csrc/autograd/utils/error_messages.h>
#include <ATen/ATen.h>
@@ -26,19 +25,6 @@
namespace torch {
namespace autograd {
// Returns a ViewFunc with a corresponding view that matches the shape,
// stride, and storage offset of the given tensor.
// NB: On mobile, the as_strided() op and thus the generated AsStridedViewFunc
// may not be available.
static std::unique_ptr<ViewFunc> create_view_func_matching(const Variable& t) {
#ifdef AS_STRIDED_VIEW_FUNC_AVAILABLE
return std::make_unique<torch::autograd::generated::AsStridedViewFunc>(
t.sym_sizes(), t.sym_strides(), t.sym_storage_offset());
#else
return std::make_unique<ErroringViewFunc>("as_strided() not available");
#endif
}
DifferentiableViewMeta::DifferentiableViewMeta(
at::TensorImpl* self_impl,
c10::optional<ViewInfo> backward_info,
@@ -72,7 +58,7 @@ DifferentiableViewMeta::DifferentiableViewMeta(
ViewInfo ViewInfo::chain(
const Variable& base,
const Variable& tensor,
std::unique_ptr<ViewFunc> view_func,
std::function<Variable(const Variable&)> view_func,
std::function<Variable(const Variable&)> rev_view_func) const {
// Set `view_func` using the root base as input.
// `view_func` is used to recover views in backward when either as_strided is
@@ -83,8 +69,12 @@ ViewInfo ViewInfo::chain(
if (view_func) {
// both current_view and its parent have a view_func
if (view_fn_) {
view_func = std::make_unique<ChainedViewFunc>(
view_fn_->clone_and_set(), std::move(view_func));
// Copy parent view function to gain ownership
auto prev_fn = view_fn_;
view_func = [=](const at::Tensor& root_base) {
auto temp = prev_fn(root_base);
return view_func(temp);
};
// assume view_fn_ / rev_view_fn_ always exist together or neither are set
auto prev_rev_fn = rev_view_fn_;
@@ -95,9 +85,13 @@ ViewInfo ViewInfo::chain(
} else {
// current_view has a view_func but its parent doesn't have one
if (base.unsafeGetTensorImpl()->support_as_strided()) {
auto match_base_view_func = create_view_func_matching(base);
view_func = std::make_unique<ChainedViewFunc>(
std::move(match_base_view_func), std::move(view_func));
auto size = base.sym_sizes().vec();
auto stride = base.sym_strides().vec();
auto storage_offset = base.sym_storage_offset();
view_func = [=](const at::Tensor& root_base) {
auto temp = root_base.as_strided_symint(size, stride, storage_offset);
return view_func(temp);
};
// assume view_fn_ / rev_view_fn_ always exist together or neither are
// set
@@ -117,7 +111,12 @@ ViewInfo ViewInfo::chain(
auto error_msg =
("Attempted to chain views when the parent view has no view_func() and "
"does not support as_strided(). This is not supported.");
view_func = std::make_unique<ErroringViewFunc>(error_msg);
view_func = [=](const at::Tensor& root_base) {
TORCH_CHECK(false, error_msg);
return root_base;
};
rev_view_func = [=](const at::Tensor& root_view) {
TORCH_CHECK(false, error_msg);
return root_view;
@@ -126,9 +125,15 @@ ViewInfo ViewInfo::chain(
}
} else if (view_fn_) {
// if current_view doesn't have a view_func but its parent has one
auto match_tensor_view_func = create_view_func_matching(tensor);
view_func = std::make_unique<ChainedViewFunc>(
view_fn_->clone_and_set(), std::move(match_tensor_view_func));
// Copy parent view function to gain ownership
auto prev_view_fn = view_fn_;
auto size = tensor.sym_sizes().vec();
auto stride = tensor.sym_strides().vec();
auto storage_offset = tensor.sym_storage_offset();
view_func = [=](const at::Tensor& root_base) {
auto temp = prev_view_fn(root_base);
return temp.as_strided_symint(size, stride, storage_offset);
};
// assume view_fn_ / rev_view_fn_ always exist together or neither are set
auto prev_rev_view_fn = rev_view_fn_;
@@ -227,12 +232,12 @@ void rebase_history(const Variable& self, Edge gradient_edge) {
TORCH_CHECK(
gradient_edge.function->num_inputs() == 1,
"Functions which modify views in-place must return a single Variable");
const auto& view_info = diff_view_meta->get_backward_view();
auto view_info = diff_view_meta->get_backward_view();
diff_view_meta->output_nr_ = gradient_edge.input_nr;
auto copy_slices = std::make_shared<CopySlices>(
view_info.base_,
at::TensorGeometry(self),
view_info.has_view_fn() ? view_info.view_fn().clone_and_set() : nullptr,
view_info.view_fn_,
std::move(gradient_edge.function));
if (self.requires_grad()) {
// If self did not previously require grad, there are no hooks to move
@@ -651,7 +656,7 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
if (diff_view_meta && diff_view_meta->has_bw_view()) {
// See NOTE [ View + Inplace detection ]
std::lock_guard<std::mutex> lock(diff_view_meta->mutex_);
auto& view_info = diff_view_meta->get_backward_view();
auto view_info = diff_view_meta->get_backward_view();
if (!diff_view_meta->grad_fn_ && !view_info.base_.requires_grad()) {
return diff_view_meta->grad_fn_;
}
@@ -691,7 +696,7 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
// in VariableType_x.cpp
// that would provide a way to recreate the grad_fn chain.
if (view_info.has_view_fn()) {
auto& view_fn = view_info.view_fn();
auto view_fn = view_info.view_fn();
Tensor diff_view;
{
// We can reach this path with grad_mode disabled, e.g. engine

View File

@@ -7,7 +7,6 @@
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/forward_grad.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/autograd/generated/ViewFuncs.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/core/Tensor.h>
@@ -332,8 +331,7 @@ struct TORCH_API ViewInfo {
/// By default we use as_strided to recover views which is more efficient.
/// view_fn is only saved when as_strided is not supported.
/// If view_fn has value, we use it to recover views in backward.
std::unique_ptr<ViewFunc> view_fn_;
std::function<Variable(const Variable&)> view_fn_;
/// Analogue of view_fn but in reverse: given a view -> produce the base by
/// applying the inverse view.
std::function<Variable(const Variable&)> rev_view_fn_;
@@ -344,10 +342,10 @@ struct TORCH_API ViewInfo {
return view_fn_ != nullptr;
}
const ViewFunc& view_fn() const {
std::function<Variable(const Variable&)> view_fn() const {
TORCH_CHECK(
has_view_fn(), "Can only access the view function if it exists.");
return *view_fn_;
return view_fn_;
}
std::function<Variable(const Variable&)> rev_view_fn() const {
@@ -368,12 +366,12 @@ struct TORCH_API ViewInfo {
ViewInfo chain(
const Variable& base,
const Variable& tensor,
std::unique_ptr<ViewFunc> view_func = nullptr,
std::function<Variable(const Variable&)> view_func = nullptr,
std::function<Variable(const Variable&)> rev_view_func = nullptr) const;
ViewInfo(
Variable base,
std::unique_ptr<ViewFunc> view_fn,
std::function<Variable(const Variable&)> view_fn,
std::function<Variable(const Variable&)> rev_view_fn)
: base_(std::move(base)),
view_fn_(std::move(view_fn)),