diff --git a/buckbuild.bzl b/buckbuild.bzl index 305a0cf3c896..380d330600ab 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -960,6 +960,7 @@ def define_buck_targets( "Functions.h": ":gen_aten_libtorch[autograd/generated/Functions.h]", "VariableType.h": ":gen_aten_libtorch[autograd/generated/VariableType.h]", "variable_factories.h": ":gen_aten_libtorch[autograd/generated/variable_factories.h]", + "ViewFuncs.h": ":gen_aten_libtorch[autograd/generated/ViewFuncs.h]", # Don't build python bindings on mobile. #"python_functions.h", }, @@ -1466,6 +1467,7 @@ def define_buck_targets( "torch/csrc/jit/mobile/train/random.cpp", "torch/csrc/jit/mobile/train/sequential.cpp", ":gen_aten_libtorch[autograd/generated/Functions.cpp]", + ":gen_aten_libtorch[autograd/generated/ViewFuncs.cpp]", ], compiler_flags = get_pt_compiler_flags(), exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DUSE_MOBILE_CLASSTYPE"], diff --git a/build.bzl b/build.bzl index 6490a7f3839e..5ab9f92acecc 100644 --- a/build.bzl +++ b/build.bzl @@ -261,6 +261,7 @@ _GENERATED_AUTOGRAD_PYTHON_HEADERS = [ _GENERATED_AUTOGRAD_CPP_HEADERS = [ "torch/csrc/autograd/generated/Functions.h", "torch/csrc/autograd/generated/VariableType.h", + "torch/csrc/autograd/generated/ViewFuncs.h", "torch/csrc/autograd/generated/variable_factories.h", ] @@ -303,6 +304,7 @@ GENERATED_AUTOGRAD_CPP = [ "torch/csrc/autograd/generated/VariableType_2.cpp", "torch/csrc/autograd/generated/VariableType_3.cpp", "torch/csrc/autograd/generated/VariableType_4.cpp", + "torch/csrc/autograd/generated/ViewFuncs.cpp", "torch/csrc/autograd/generated/TraceType_0.cpp", "torch/csrc/autograd/generated/TraceType_1.cpp", "torch/csrc/autograd/generated/TraceType_2.cpp", diff --git a/build_variables.bzl b/build_variables.bzl index d423ad5793c7..71b52e8d5481 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -24,6 +24,7 @@ def libtorch_generated_sources(gencode_pattern): "torch/csrc/autograd/generated/VariableType_2.cpp", "torch/csrc/autograd/generated/VariableType_3.cpp", "torch/csrc/autograd/generated/VariableType_4.cpp", + "torch/csrc/autograd/generated/ViewFuncs.cpp", "torch/csrc/autograd/generated/TraceType_0.cpp", "torch/csrc/autograd/generated/TraceType_1.cpp", "torch/csrc/autograd/generated/TraceType_2.cpp", diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index dd4d6dfcb437..3600398b6726 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -351,6 +351,7 @@ configure_file("${TORCH_SRC_DIR}/csrc/api/include/torch/version.h.in" set(GENERATED_CXX_TORCH "${TORCH_SRC_DIR}/csrc/autograd/generated/Functions.cpp" + "${TORCH_SRC_DIR}/csrc/autograd/generated/ViewFuncs.cpp" ) if(NOT INTERN_DISABLE_AUTOGRAD AND NOT BUILD_LITE_INTERPRETER) @@ -380,6 +381,7 @@ endif() set(GENERATED_H_TORCH "${TORCH_SRC_DIR}/csrc/autograd/generated/Functions.h" "${TORCH_SRC_DIR}/csrc/autograd/generated/variable_factories.h" + "${TORCH_SRC_DIR}/csrc/autograd/generated/ViewFuncs.h" ) if(NOT INTERN_DISABLE_AUTOGRAD) diff --git a/pt_template_srcs.bzl b/pt_template_srcs.bzl index 6bfcfc6f2310..6d42026ba6ca 100644 --- a/pt_template_srcs.bzl +++ b/pt_template_srcs.bzl @@ -131,6 +131,8 @@ def get_generate_code_bin_outs(): "autograd/generated/VariableType_3.cpp": ["autograd/generated/VariableType_3.cpp"], "autograd/generated/VariableType_4.cpp": ["autograd/generated/VariableType_4.cpp"], "autograd/generated/variable_factories.h": ["autograd/generated/variable_factories.h"], + "autograd/generated/ViewFuncs.cpp": ["autograd/generated/ViewFuncs.cpp"], + 
"autograd/generated/ViewFuncs.h": ["autograd/generated/ViewFuncs.h"], } if is_arvr_mode(): diff --git a/test/test_autograd.py b/test/test_autograd.py index bce6b3538c6c..ec570fab4c6a 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -8792,6 +8792,7 @@ get_out().sum().backward() _assert_match_metadata(new_inp, inp) new_out = out._view_func(new_inp) _assert_match_metadata(new_out, out) + self.assertEqual(new_out, out) # reverse view_func new_out = out.detach() @@ -8830,7 +8831,7 @@ get_out().sum().backward() _test_fn( lambda x: x.chunk(2, -1)[0].transpose(0, 1).unsqueeze(-1), torch.randn(2, 3, 4)) _test_fn( - lambda x: x.split_with_sizes([1, 3], -1)[0].chunk(2, -1), torch.randn(2, 3, 4)) + lambda x: x.split_with_sizes([1, 3], -1)[0].chunk(2, 0), torch.randn(2, 3, 4)) # chains with missing view_func()s use as_strided() to cover the gaps def chain_with_only_parent_view_func(x): @@ -8838,7 +8839,7 @@ get_out().sum().backward() x = x.split_with_sizes([1, 3], -1)[0] with torch.autograd._force_original_view_tracking(False): - x = x.chunk(2, -1) + x = x.chunk(2, 0) return x @@ -8849,12 +8850,50 @@ get_out().sum().backward() x = x.split_with_sizes([1, 3], -1)[0] with torch.autograd._force_original_view_tracking(True): - x = x.chunk(2, -1) + x = x.chunk(2, 0) return x _test_fn(chain_with_only_current_view_func, torch.randn(2, 3, 4)) + def test_view_func_replay_with_modified_state(self): + with torch.autograd._force_original_view_tracking(True): + base = torch.randn(3, 4, 5) + view = base.select(1, 2) + + def symint_visitor_fn(x): + # modify saved index + return x + 1 + + # ensure modifying state changes view replay + new_base = torch.randn_like(base) + new_view = view._view_func(new_base, symint_visitor_fn=symint_visitor_fn) + self.assertEqual(new_view, new_base.select(1, 3)) + + # ensure saved state reverts back afterwards + self.assertEqual(view._view_func(new_base), new_base.select(1, 2)) + + # check modifying tensor state. 
currently, slice_inverse() is the only + # view that saves a tensor + base = torch.randn(3, 4, 5) + sliced = base[:, 2:3, :].detach() + view = torch.ops.aten.slice_inverse(sliced, base, 1, 2, 3, 1) + + replacement_shape = (1, 2, 3) + + def tensor_visitor_fn(x): + # return tensor with a smaller shape than the saved one + return torch.randn(*replacement_shape) + + # ensure modifying state changes view replay + new_sliced = torch.ones_like(base)[:, 2:3, :].detach() + new_view = view._view_func(new_sliced, tensor_visitor_fn=tensor_visitor_fn) + self.assertEqual(new_view.shape, replacement_shape) + self.assertEqual(new_view, new_sliced.as_strided(replacement_shape, (6, 3, 1))) + + # ensure saved state reverts back afterwards + self.assertEqual(view._view_func(sliced), base) + def test_setup_context_when_forward_has_default_args(self): class PowFunction(Function): @staticmethod diff --git a/test/test_ops.py b/test/test_ops.py index 74a1f86b01e7..22023059ea62 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1564,6 +1564,7 @@ class TestCompositeCompliance(TestCase): _assert_match_metadata(new_inp, inp) new_out = out._view_func_unsafe(new_inp) _assert_match_metadata(new_out, out) + self.assertEqual(new_out, out) # reverse view_func new_out = out.detach() diff --git a/tools/BUCK.bzl b/tools/BUCK.bzl index 7ad0882e44f5..5b410274ff77 100644 --- a/tools/BUCK.bzl +++ b/tools/BUCK.bzl @@ -124,6 +124,8 @@ def define_tools_targets( "autograd/templates/TraceType.cpp", "autograd/templates/VariableType.cpp", "autograd/templates/VariableType.h", + "autograd/templates/ViewFuncs.cpp", + "autograd/templates/ViewFuncs.h", "autograd/templates/annotated_fn_args.py.in", "autograd/templates/python_enum_tag.cpp", "autograd/templates/python_fft_functions.cpp", diff --git a/tools/autograd/gen_autograd.py b/tools/autograd/gen_autograd.py index c4d1df00a95d..0d4aa91d3fad 100644 --- a/tools/autograd/gen_autograd.py +++ b/tools/autograd/gen_autograd.py @@ -43,6 +43,7 @@ from .gen_inplace_or_view_type import gen_inplace_or_view_type from .gen_trace_type import gen_trace_type from .gen_variable_factories import gen_variable_factories from .gen_variable_type import gen_variable_type +from .gen_view_funcs import gen_view_funcs from .load_derivatives import load_derivatives @@ -95,6 +96,9 @@ def gen_autograd( # Generate variable_factories.h gen_variable_factories(out, native_functions_path, tags_path, template_path) + # Generate ViewFuncs.h/cpp + gen_view_funcs(out, fns_with_diff_infos, template_path) + def gen_autograd_python( native_functions_path: str, diff --git a/tools/autograd/gen_inplace_or_view_type.py b/tools/autograd/gen_inplace_or_view_type.py index 6e713579445d..df9841312fab 100644 --- a/tools/autograd/gen_inplace_or_view_type.py +++ b/tools/autograd/gen_inplace_or_view_type.py @@ -4,7 +4,7 @@ # if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp # The fallback is expected to mimick this codegen, so we should keep the two in sync. 
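The test additions above exercise the new keyword arguments that Tensor._view_func() gains in this patch (wired up in the python_variable.cpp changes further down). A minimal usage sketch, mirroring test_view_func_replay_with_modified_state and assuming a build that includes this patch (_view_func is a private API):

    import torch

    # Record the view func for a select() view, then replay it on a new base,
    # optionally rewriting the saved SymInt state via symint_visitor_fn.
    with torch.autograd._force_original_view_tracking(True):
        base = torch.randn(3, 4, 5)
        view = base.select(1, 2)

    new_base = torch.randn_like(base)

    # Plain replay: same view relationship, different base.
    assert torch.equal(view._view_func(new_base), new_base.select(1, 2))

    # Replay with modified saved state: bump the saved index from 2 to 3.
    shifted = view._view_func(new_base, symint_visitor_fn=lambda i: i + 1)
    assert torch.equal(shifted, new_base.select(1, 3))

The tensor_visitor_fn keyword works the same way for any tensors in the saved state (currently only slice_inverse() saves one, per the test above).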
-from typing import Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Optional, Tuple from torchgen.api import cpp from torchgen.api.autograd import ( @@ -172,7 +172,7 @@ for (auto ${view_idx} : c10::irange(${var}.size())) { SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate( """\ -std::function func=nullptr; +std::unique_ptr func(nullptr); std::function rev_func=nullptr; if (${is_view_with_metadata_change} || !self.unsafeGetTensorImpl()->support_as_strided() || @@ -184,11 +184,9 @@ if (${is_view_with_metadata_change} || """ ) -REPLAY_VIEW_LAMBDA_FUNC = CodeTemplate( +REPLAY_VIEW_FUNC = CodeTemplate( """\ -func = [=](const at::Tensor& ${input_base}) { - return ${replay_view_call}${view_indexing}; -}; +func = std::make_unique<${view_func_name}>(${view_func_args}); """ ) @@ -346,24 +344,13 @@ def get_view_info(f: NativeFunction) -> Optional[str]: return view_info -# For view replay calls, we generate an ordinary Dispatcher::call() instead, because: -# - We want to replay the entire call into the op, including any previously-set dispatch keys (including autograd!). -# - The view replay call also is not part of the hot path. -def emit_view_call( - f: NativeFunction, input_base: str, unpacked_args: Sequence[str] -) -> str: - # View replay functions use the standard Dispatcher::call API. - return CALL_DISPATCH.substitute( - unambiguous_name=f.func.name.unambiguous_name(), unpacked_args=unpacked_args - ) - - -def emit_view_lambda( +def emit_view_func( f: NativeFunction, bindings: List[Binding], view_idx: Optional[str] = None ) -> str: """Generate an additional lambda function to recover views in backward when as_strided is not supported. See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details. """ + # TODO: Clean this logic up if we get rid of reverse view funcs or reify them. 
input_base = "input_base" replay_view_func = "" updated_args: List[str] = [] @@ -412,11 +399,14 @@ def emit_view_lambda( else: updated_args.append(arg) - replay_view_call = emit_view_call(f, input_base, updated_args) - replay_view_func += REPLAY_VIEW_LAMBDA_FUNC.substitute( - input_base=input_base, - replay_view_call=replay_view_call, - view_indexing=("" if view_idx is None else f"[{view_idx}]"), + from .gen_view_funcs import view_func_name + + view_func_args = [b.name for b in bindings if b.name != "self"] + if view_idx is not None: + view_func_args.append(f"{view_idx}") + replay_view_func += REPLAY_VIEW_FUNC.substitute( + view_func_name=view_func_name(f, include_namespace=True), + view_func_args=view_func_args, ) input_view = "input_view" @@ -493,26 +483,26 @@ def emit_view_body( if is_tensor_list_type(return_info.type): creation_meta = get_creation_meta_in_mode("CreationMeta::MULTI_OUTPUT_NODE") view_idx = "view_idx" - view_lambda = emit_view_lambda( + view_func = emit_view_func( f, extract_bindings(f), view_idx=view_idx ).strip() as_view_call = ( f"as_view(/* base */ {view_info}, /* output */ {var}[{view_idx}], " "/* is_bw_differentiable */ true, /* is_fw_differentiable */ true, " - "/* view_func */ func, /* rev_view_func */ rev_func, " + "/* view_func */ std::move(func), /* rev_view_func */ rev_func, " f"/* creation_meta */ {creation_meta});" ) call += MULTI_OUTPUT_VIEW_ITERATION.substitute( - var=var, view_idx=view_idx, body=f"{view_lambda}\n{as_view_call}" + var=var, view_idx=view_idx, body=f"{view_func}\n{as_view_call}" ) rhs_value = f"std::move({var})" else: - call += emit_view_lambda(f, extract_bindings(f), view_idx=None) + call += emit_view_func(f, extract_bindings(f), view_idx=None) creation_meta = get_creation_meta_in_mode("CreationMeta::DEFAULT") rhs_value = ( f"as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, " "/* is_fw_differentiable */ true, " - f"/* view_func */ func, /* rev_view_func */ rev_func, /* creation_meta */ {creation_meta})" + f"/* view_func */ std::move(func), /* rev_view_func */ rev_func, /* creation_meta */ {creation_meta})" ) else: # This could be supported but we don't need it at the moment, so keeping things simple. diff --git a/tools/autograd/gen_view_funcs.py b/tools/autograd/gen_view_funcs.py new file mode 100644 index 000000000000..c9f7561dca17 --- /dev/null +++ b/tools/autograd/gen_view_funcs.py @@ -0,0 +1,334 @@ +# Generates ViewFuncs.h/cpp +# +# NOTE: If any changes are being made to the ViewFunc codegen please also check +# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp +# The fallback is expected to mimic this codegen, so we should keep the two in sync. 
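The structs this new generator emits are what the REPLAY_VIEW_FUNC template in gen_inplace_or_view_type.py (above) now instantiates in place of the old lambda. A rough sketch of that substitution, using string.Template as a stand-in for torchgen's CodeTemplate (same ${...} placeholder syntax) and an illustrative op/argument set rather than real codegen output:

    from string import Template

    # Stand-in for the REPLAY_VIEW_FUNC CodeTemplate shown above.
    REPLAY_VIEW_FUNC = Template(
        "func = std::make_unique<${view_func_name}>(${view_func_args});"
    )

    # Hypothetical values for a select(dim, index)-style view; the real generator
    # derives these from the op's bindings and view_func_name(f, include_namespace=True).
    line = REPLAY_VIEW_FUNC.substitute(
        view_func_name="torch::autograd::generated::SelectIntViewFunc",
        view_func_args=", ".join(["dim", "index"]),
    )
    print(line)
    # func = std::make_unique<torch::autograd::generated::SelectIntViewFunc>(dim, index);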
+ +from typing import List, Tuple + +import torchgen.api.dispatcher as dispatcher +from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo +from torchgen.api.translate import translate +from torchgen.api.types import ( + BaseCType, + Binding, + NamedCType, + SymIntT, + tensorT, + VectorCType, +) +from torchgen.code_template import CodeTemplate +from torchgen.model import Argument, NativeFunction, OptionalType +from torchgen.utils import FileManager + +from .gen_inplace_or_view_type import ( + CALL_DISPATCH, + extract_bindings, + get_view_info, + modifies_arguments, + use_derived, +) + +FUNCTION_DECLARATION = CodeTemplate( + """\ +#define ${uppercase_op}_AVAILABLE +struct ${op} : public ${superclass} { + ${op}(${constructor_args}) ${initializer_list} + {}; + virtual ~${op}() override {}; + virtual std::vector get_symints() const override; + virtual size_t num_symints() const override; + virtual std::vector get_tensors() const override; + virtual size_t num_tensors() const override; + virtual at::Tensor operator()(const at::Tensor&) const override; + virtual std::unique_ptr clone_and_set( + std::optional> = c10::nullopt, + std::optional> = c10::nullopt) const override; + +protected: + virtual void set_symints(std::vector) override; + virtual void set_tensors(std::vector) override; + +private: + ${state} +}; + +""" +) + +FUNCTION_DEFINITION = CodeTemplate( + """\ +std::vector ${op}::get_symints() const { + ${get_symints} +} + +size_t ${op}::num_symints() const { + return static_cast(${num_symints}); +} + +void ${op}::set_symints(std::vector ${symints_vec}) { + TORCH_INTERNAL_ASSERT(${symints_vec}.size() == num_symints()); + ${set_symints} +} + +std::vector ${op}::get_tensors() const { + ${get_tensors} +} + +size_t ${op}::num_tensors() const { + return static_cast(${num_tensors}); +} + +void ${op}::set_tensors(std::vector ${tensors_vec}) { + TORCH_INTERNAL_ASSERT(${tensors_vec}.size() == num_tensors()); + ${set_tensors} +} + +at::Tensor ${op}::operator()(const at::Tensor& ${call_input_name}) const { + return ${op_call}; +} + +std::unique_ptr ${op}::clone_and_set( + std::optional> ${symints_vec}, + std::optional> ${tensors_vec}) const { + auto output = std::make_unique<${op}>(${clone_args}); + if (${symints_vec}.has_value()) { + output->set_symints(std::move(*(${symints_vec}))); + } + if (${tensors_vec}.has_value()) { + output->set_tensors(std::move(*(${tensors_vec}))); + } + return output; +} + +""" +) + + +# e.g. 
as_strided -> AsStridedViewFunc for camel case or +# as_strided_view_func otherwise +def view_func_name( + f: NativeFunction, include_namespace: bool = False, camel_case: bool = True +) -> str: + name = f.func.name.unambiguous_name() + view_func_name = f"{name.replace('.', '_')}_view_func" + if camel_case: + is_private = view_func_name.startswith("_") + view_func_name = "".join( + [p.title() for p in view_func_name.replace(".", "_").split("_")] + ) + if is_private: + # put the leading underscore back in + view_func_name = f"_{view_func_name}" + namespace = "torch::autograd::generated::" if include_namespace else "" + return f"{namespace}{view_func_name}" + + +def is_symint_or_tensor(arg: Argument) -> bool: + return arg.type.is_tensor_like() or arg.type.is_symint_like() + + +def remove_const_ref(binding: Binding) -> Binding: + return Binding( + name=binding.name, + nctype=binding.nctype.remove_const_ref(), + argument=binding.argument, + default=binding.default, + ) + + +def returns_multi_tensor(fn: NativeFunction) -> bool: + returns = fn.func.returns + assert len(returns) == 1 + returns_list_like = returns[0].type.is_list_like() is not None + returns_tensor_like = returns[0].type.is_tensor_like() + return returns_list_like and returns_tensor_like + + +# Generates strings with logic for getting / setting state of a particular type. +# +# Args: +# bindings (list): List of state bindings of interest (may be empty) +# state_vec_type (NamedCType): Type of vector to either return or copy from +# +# Returns: +# tuple: (list of getter logic strings, list of setter logic strings, string +# with num items expression) +def generate_state_getter_setter( + bindings: List[Binding], + state_vec_type: NamedCType, +) -> Tuple[List[str], List[str], str]: + getter_logic = [] + setter_logic = [] + + state_vec = state_vec_type.name + getter_logic.append(f"{state_vec_type.cpp_type()} {state_vec};") + if len(bindings) > 0: + setter_logic.append("auto i = 0;") + + num_exprs = [] + for i, b in enumerate(bindings): + assert isinstance(b.argument, Argument) + if b.argument.type.is_list_like(): + # Handle list-likes. + num_expr = f"{b.name}.size()" + num_exprs.append(num_expr) + getter = f"{state_vec}.insert({state_vec}.end(), {b.name}.begin(), {b.name}.end());" + setter = f"std::copy({state_vec}.begin() + i, {state_vec}.begin() + i + {b.name}.size(), {b.name}.begin());" + elif isinstance(b.argument.type, OptionalType): + # Handle optionals. + num_expr = f"({b.name}.has_value() ? 1 : 0)" + num_exprs.append(num_expr) + conditional = f"if({b.name}.has_value())" + getter = ( + f"{conditional} {state_vec}.insert({state_vec}.end(), *({b.name}));" + ) + setter = f"{conditional} {b.name} = {state_vec}[i];" + else: + num_expr = "1" + num_exprs.append(num_expr) + getter = f"{state_vec}.push_back({b.name});" + setter = f"{b.name} = {state_vec}[i];" + + getter_logic.append(getter) + setter_logic.append(setter) + if i < len(bindings) - 1: + setter_logic.append(f"i += {num_expr};") + + # Reserve / assert based on the total number of items expression. 
+ num_items = "0" if len(num_exprs) == 0 else " + ".join(num_exprs) + if len(bindings) > 0: + getter_logic.insert(1, f"{state_vec}.reserve({num_items});") + + getter_logic.append(f"return {state_vec};") + + return getter_logic, setter_logic, num_items + + +def process_function(fn: NativeFunction, template: CodeTemplate) -> str: + bindings = extract_bindings(fn) + non_self_bindings = [b for b in bindings if b.name != "self"] + + non_self_args = fn.func.arguments.flat_all[1:] + non_self_value_bindings = [ + dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args + ] + + # Generate constructor / clone args for the generated struct. + constructor_args = [b.defn() for b in non_self_bindings] + clone_args = [b.name for b in non_self_bindings] + + # Generate state variable declarations for the generated struct. + state_variables = [ + f"{remove_const_ref(b).defn()};" for b in non_self_value_bindings + ] + + # Generate initializer list expressions for the generated struct. + # allow_expensive_conversions=True because we need to store e.g. SymIntArrayRefs as + # vectors. + init_exprs = translate( + non_self_bindings, non_self_value_bindings, allow_expensive_conversions=True + ) + initializers = [] + for b, init_expr in zip(non_self_bindings, init_exprs): + name = b.nctype.name + assert isinstance(name, str) + initializers.append(f"{name}({init_expr.expr})") + + # Generate call to underlying view op + call_input_name = "input_base" + op_call_args = [call_input_name, *(b.name for b in non_self_bindings)] + op_call = CALL_DISPATCH.substitute( + unambiguous_name=fn.func.name.unambiguous_name(), + unpacked_args=op_call_args, + ) + + # Multi-output views additionally require a view_idx for disambiguation. + if returns_multi_tensor(fn): + view_idx_name = "view_idx" + view_idx_typename = "int64_t" + view_idx_decl = f"{view_idx_typename} {view_idx_name}" + constructor_args.append(view_idx_decl) + clone_args.append(view_idx_name) + state_variables.append(f"{view_idx_decl};") + initializers.append(f"{view_idx_name}({view_idx_name})") + op_call += f"[{view_idx_name}]" + + # Generate initializer list for the generated struct. + initializer_list = f": {', '.join(initializers)}" if len(initializers) > 0 else "" + + # Generate getter / setter logic for any symints. + symint_bindings = [ + b + for b in non_self_bindings + if isinstance(b.argument, Argument) and b.argument.type.is_symint_like() + ] + symints_vec_type = NamedCType("symints", VectorCType(BaseCType(SymIntT))) + get_symints, set_symints, num_symints = generate_state_getter_setter( + symint_bindings, symints_vec_type + ) + + # Generate getter / setter logic for any tensors. 
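As a reference for the flattening scheme that generate_state_getter_setter() above turns into C++: list-like arguments contribute len(x) items, optionals contribute 0 or 1, and plain values contribute exactly 1, so the flat vector handed to clone_and_set() can be split back deterministically. A small standalone Python analogue (illustrative only; the real getters/setters are generated C++ operating on the struct's saved members):

    from typing import List, Optional, Tuple

    def pack_state(sizes: List[int], dim: Optional[int], offset: int) -> List[int]:
        # Mirrors a generated get_symints(): list-like, then optional, then scalar.
        state = list(sizes)
        if dim is not None:
            state.append(dim)
        state.append(offset)
        return state

    def unpack_state(
        state: List[int], sizes: List[int], dim: Optional[int], offset: int
    ) -> Tuple[List[int], Optional[int], int]:
        # Mirrors a generated set_symints(): walk the flat vector with a running index.
        i = 0
        new_sizes = state[i:i + len(sizes)]
        i += len(sizes)
        new_dim = dim
        if dim is not None:
            new_dim = state[i]
            i += 1
        new_offset = state[i]
        return new_sizes, new_dim, new_offset

    packed = pack_state([2, 3, 4], None, 7)  # -> [2, 3, 4, 7]
    assert unpack_state(packed, [2, 3, 4], None, 7) == ([2, 3, 4], None, 7)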
+ tensor_bindings = [ + b + for b in non_self_bindings + if isinstance(b.argument, Argument) and b.argument.type.is_tensor_like() + ] + tensors_vec_type = NamedCType("tensors", VectorCType(BaseCType(tensorT))) + get_tensors, set_tensors, num_tensors = generate_state_getter_setter( + tensor_bindings, tensors_vec_type + ) + + return template.substitute( + op=view_func_name(fn), + uppercase_op=view_func_name(fn, camel_case=False).upper(), + superclass="torch::autograd::ViewFunc", + initializer_list=initializer_list, + state=state_variables, + constructor_args=constructor_args, + clone_args=clone_args, + symints_vec=symints_vec_type.name, + get_symints=get_symints, + set_symints=set_symints, + num_symints=num_symints, + tensors_vec=tensors_vec_type.name, + get_tensors=get_tensors, + set_tensors=set_tensors, + num_tensors=num_tensors, + call_input_name=call_input_name, + op_call=op_call, + ) + + +def gen_view_funcs( + out: str, + fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo], + template_path: str, +) -> None: + # don't need the info parts, just the function + fns = [fn.func for fn in fns_with_infos if use_derived(fn)] + # only want out-of-place views + view_fns = [ + fn for fn in fns if get_view_info(fn) is not None and not modifies_arguments(fn) + ] + + declarations = [process_function(fn, FUNCTION_DECLARATION) for fn in view_fns] + definitions = [process_function(fn, FUNCTION_DEFINITION) for fn in view_fns] + ops_headers = [f"#include " for fn in view_fns] + + file_basename = "ViewFuncs" + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + for suffix in [".h", ".cpp"]: + fname = file_basename + suffix + fm.write_with_template( + fname, + fname, + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/" + + fname, + "view_func_declarations": declarations, + "view_func_definitions": definitions, + "ops_headers": ops_headers, + }, + ) diff --git a/tools/autograd/templates/ADInplaceOrViewType.cpp b/tools/autograd/templates/ADInplaceOrViewType.cpp index 7a19047dd5c9..e8276697eee0 100644 --- a/tools/autograd/templates/ADInplaceOrViewType.cpp +++ b/tools/autograd/templates/ADInplaceOrViewType.cpp @@ -1,5 +1,6 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include "torch/csrc/autograd/VariableTypeUtils.h" +#include "torch/csrc/autograd/generated/ViewFuncs.h" #include #include diff --git a/tools/autograd/templates/ViewFuncs.cpp b/tools/autograd/templates/ViewFuncs.cpp new file mode 100644 index 000000000000..0f104023ba2c --- /dev/null +++ b/tools/autograd/templates/ViewFuncs.cpp @@ -0,0 +1,71 @@ +#include "torch/csrc/autograd/generated/ViewFuncs.h" + +// ${generated_comment} + +using at::Tensor; +using at::Scalar; +using at::IntArrayRef; +using at::TensorList; + +namespace torch::autograd { + +std::vector ChainedViewFunc::get_symints() const { + auto symints = first->get_symints(); + auto second_symints = second->get_symints(); + symints.reserve(symints.size() + second_symints.size()); + symints.insert( + symints.end(), + std::make_move_iterator(second_symints.begin()), + std::make_move_iterator(second_symints.end())); + return symints; +} + +std::vector ChainedViewFunc::get_tensors() const { + auto tensors = first->get_tensors(); + auto second_tensors = second->get_tensors(); + tensors.reserve(tensors.size() + second_tensors.size()); + tensors.insert( + tensors.end(), + std::make_move_iterator(second_tensors.begin()), + std::make_move_iterator(second_tensors.end())); + return tensors; +} + +at::Tensor 
ChainedViewFunc::operator()(const at::Tensor& input_base) const { + return (*second)((*first)(input_base)); +} + +std::unique_ptr ChainedViewFunc::clone_and_set( + std::optional> symints, + std::optional> tensors) const { + std::optional> first_symints; + std::optional> second_symints; + if (symints.has_value()) { + TORCH_INTERNAL_ASSERT(symints->size() == num_symints()); + first_symints = std::vector( + symints->begin(), symints->begin() + first->num_symints()); + second_symints = std::vector( + symints->begin() + first->num_symints(), symints->end()); + } + + std::optional> first_tensors; + std::optional> second_tensors; + if (tensors.has_value()) { + TORCH_INTERNAL_ASSERT(tensors->size() == num_tensors()); + first_tensors = std::vector( + tensors->begin(), tensors->begin() + first->num_tensors()); + second_tensors = std::vector( + tensors->begin() + first->num_tensors(), tensors->end()); + } + + return std::make_unique( + first->clone_and_set(first_symints, first_tensors), + second->clone_and_set(second_symints, second_tensors)); +} + +namespace generated { + +${view_func_definitions} + +} // namespace torch::autograd +} // namespace generated diff --git a/tools/autograd/templates/ViewFuncs.h b/tools/autograd/templates/ViewFuncs.h new file mode 100644 index 000000000000..f9acce387a9d --- /dev/null +++ b/tools/autograd/templates/ViewFuncs.h @@ -0,0 +1,106 @@ +#pragma once + +// ${generated_comment} + +#include + +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +namespace torch { namespace autograd { + +/// Base class for view functions, providing reapplication of a view on a new base. +/// Each view op should get a codegenerated subclass of this class containing +/// any state needed to reconstruct the view. The class also provides convenience +/// accessors for saved SymInts / tensor state. This is useful for e.g. fake-ification, +/// where we want to use symbolic values or fake tensors instead. +struct TORCH_API ViewFunc { + virtual ~ViewFunc() {} + /// Returns any SymInts in the saved state. + virtual std::vector get_symints() const { return {}; } + /// Returns the number of SymInts in the saved state. + virtual size_t num_symints() const { return 0; } + /// Returns any tensors in the saved state. + virtual std::vector get_tensors() const { return {}; } + /// Returns the number of tensors in the saved state. + virtual size_t num_tensors() const { return 0; } + /// Reapplies the view on the given base using the saved state. + virtual at::Tensor operator()(const at::Tensor&) const = 0; + /// Returns a clone of this ViewFunc, optionally with the specified saved state. + virtual std::unique_ptr clone_and_set( + std::optional> = c10::nullopt, + std::optional> = c10::nullopt) const = 0; + +protected: + /// Sets the values of any SymInts in the saved state. The input vector size must + /// match the number of SymInts in the saved state (i.e. the size of the list + /// returned by get_symints()). + virtual void set_symints(std::vector) {} + /// Sets the values of any Tensors in the saved state. The input vector size must + /// match the number of Tensors in the saved state (i.e. the size of the list + /// returned by get_tensors()). + virtual void set_tensors(std::vector) {} +}; + +/// ViewFunc that represents a chain of two ViewFuncs. 
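Before the C++ declaration below, a minimal Python analogue of how such a chain behaves as implemented in ViewFuncs.cpp above: apply first, then second, and split a flat replacement-state list between the two parts by their counts. Illustrative only, not part of the patch; only the SymInt half is sketched, tensors are handled identically:

    class ChainedViewFuncSketch:
        """Python analogue of ChainedViewFunc (illustrative only)."""

        def __init__(self, first, second):
            self.first, self.second = first, second

        def __call__(self, base):
            # (*second)((*first)(input_base))
            return self.second(self.first(base))

        def num_symints(self):
            return self.first.num_symints() + self.second.num_symints()

        def clone_and_set(self, symints=None):
            first_symints = second_symints = None
            if symints is not None:
                assert len(symints) == self.num_symints()
                n = self.first.num_symints()
                first_symints, second_symints = symints[:n], symints[n:]
            return ChainedViewFuncSketch(
                self.first.clone_and_set(first_symints),
                self.second.clone_and_set(second_symints),
            )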
+struct ChainedViewFunc : public ViewFunc { + ChainedViewFunc( + std::unique_ptr first, + std::unique_ptr second) + : first(std::move(first)), + second(std::move(second)) {} + virtual ~ChainedViewFunc() override {}; + virtual std::vector get_symints() const override; + virtual size_t num_symints() const override { + return first->num_symints() + second->num_symints(); + } + virtual std::vector get_tensors() const override; + virtual size_t num_tensors() const override { + return first->num_tensors() + second->num_tensors(); + } + virtual at::Tensor operator()(const at::Tensor&) const override; + virtual std::unique_ptr clone_and_set( + std::optional> = c10::nullopt, + std::optional> = c10::nullopt) const override; + +private: + std::unique_ptr first; + std::unique_ptr second; +}; + +/// ViewFunc that errors with a specified error message when called. +struct ErroringViewFunc : public ViewFunc { + ErroringViewFunc(const std::string& error_msg) : error_msg(error_msg) {} + virtual ~ErroringViewFunc() override {}; + virtual at::Tensor operator()(const at::Tensor&) const override { + TORCH_CHECK(false, error_msg); + } + virtual std::unique_ptr clone_and_set( + std::optional> = c10::nullopt, + std::optional> = c10::nullopt) const override { + return std::make_unique(error_msg); + } + +private: + std::string error_msg; +}; + +namespace generated { + +using at::Scalar; +using at::Tensor; +using at::IntArrayRef; +using at::ArrayRef; +using at::Type; +using at::ScalarType; +using c10::optional; +using c10::fmap; + +${view_func_declarations} + +}}} // namespace torch::autograd::generated diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index 83e963d64f62..38a63640c11e 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -474,13 +475,10 @@ static Tensor _fw_primal( at::AutoDispatchBelowADInplaceOrView guard; return at::alias(self); })(); - std::function func = nullptr; + std::unique_ptr func(nullptr); std::function rev_func = nullptr; if (!self.unsafeGetTensorImpl()->support_as_strided()) { - auto size_vec = self.sizes().vec(); - func = [=](const at::Tensor& input_base) { - return input_base.view(size_vec); - }; + func = std::make_unique(self.sym_sizes()); rev_func = [=](const at::Tensor& input_view) { TORCH_INTERNAL_ASSERT( false, @@ -510,13 +508,10 @@ static Tensor _make_dual( at::AutoDispatchBelowADInplaceOrView guard; return at::alias(primal); })(); - std::function func = nullptr; + std::unique_ptr func(nullptr); std::function rev_func = nullptr; if (!primal.unsafeGetTensorImpl()->support_as_strided()) { - auto size_vec = primal.sizes().vec(); - func = [=](const at::Tensor& input_base) { - return input_base.view(size_vec); - }; + func = std::make_unique(primal.sym_sizes()); rev_func = [=](const at::Tensor& input_view) { TORCH_INTERNAL_ASSERT( false, diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h index 8f87317d8478..9794ca9a4ad6 100644 --- a/torch/csrc/autograd/VariableTypeUtils.h +++ b/torch/csrc/autograd/VariableTypeUtils.h @@ -160,7 +160,7 @@ inline at::Tensor as_view( const at::Tensor& tensor, bool is_bw_differentiable, bool is_fw_differentiable, - std::function view_func = nullptr, + std::unique_ptr view_func = nullptr, std::function rev_view_func = nullptr, CreationMeta creation_meta = CreationMeta::DEFAULT, bool allow_tensor_metadata_change = true) { @@ -208,11 +208,13 
@@ inline at::Tensor as_view( c10::optional new_fw_info; if (is_bw_differentiable) { + auto bw_view_func = view_func ? view_func->clone_and_set() : nullptr; if (diff_view_meta && diff_view_meta->has_bw_view()) { const auto& base_bw_info = diff_view_meta->get_backward_view(); - new_bw_info = base_bw_info.chain(base, tensor, view_func, rev_view_func); + new_bw_info = base_bw_info.chain( + base, tensor, std::move(bw_view_func), rev_view_func); } else { - new_bw_info = ViewInfo(base, view_func, rev_view_func); + new_bw_info = ViewInfo(base, std::move(bw_view_func), rev_view_func); } } else { TORCH_CHECK( diff --git a/torch/csrc/autograd/autograd_meta.cpp b/torch/csrc/autograd/autograd_meta.cpp index bea070a973a3..f6de7599c6b0 100644 --- a/torch/csrc/autograd/autograd_meta.cpp +++ b/torch/csrc/autograd/autograd_meta.cpp @@ -213,7 +213,7 @@ void AutogradMeta::set_fw_grad( // - Copy the given new_grad into this view // - Use this view as the new new_grad if (this_view_meta->has_fw_view()) { - auto view_info = this_view_meta->get_forward_view(); + auto& view_info = this_view_meta->get_forward_view(); auto& base = view_info.base_; if (!base._fw_grad(level).defined()) { diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp index b334e6f097f7..3659a826b913 100644 --- a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp +++ b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp @@ -530,18 +530,14 @@ static void autogradNotImplementedInplaceOrViewFallbackImpl( (*stack)[stack->size() - num_returns + aliased_output_idx]; // See NOTE [ View + Inplace detection ] for more details about this logic - const auto erroring_view_func = [op_name = op_name](const at::Tensor&) { - // We always need this view_func because otherwise if we do in-place - // on this view, we would implicitly use AsStridedBackward instead - // of the NotImplemented node. For the cross-dtype/non-strided - // cases, we would create something like this anyway - TORCH_CHECK( - false, - "Mutating the view ", - op_name, - " which does not have a derivative implemented is forbidden."); - return at::Tensor(); - }; + // We always need this view_func because otherwise if we do in-place + // on this view, we would implicitly use AsStridedBackward instead + // of the NotImplemented node. 
For the cross-dtype/non-strided + // cases, we would create something like this anyway + auto error_msg = + ("Mutating the view " + op_name + + "which does not have a derivative implemented is forbidden."); + auto erroring_view_func = std::make_unique(error_msg); const auto erroring_rev_view_func = [op_name = op_name](const at::Tensor&) { TORCH_CHECK( @@ -560,7 +556,7 @@ static void autogradNotImplementedInplaceOrViewFallbackImpl( /* tensor=*/sub_output, /* is_bw_differentiable=*/true, /* is_fw_differentiable=*/true, - /* view_func=*/erroring_view_func, + /* view_func=*/std::move(erroring_view_func), /* rev_view_func=*/erroring_rev_view_func, /* creation_meta=*/ InferenceMode::is_enabled() @@ -577,7 +573,7 @@ static void autogradNotImplementedInplaceOrViewFallbackImpl( /* tensor=*/std::move(aliased_output_iv).toTensor(), /* is_bw_differentiable=*/true, /* is_fw_differentiable=*/true, - /* view_func=*/erroring_view_func, + /* view_func=*/std::move(erroring_view_func), /* rev_view_func=*/erroring_rev_view_func, /* creation_meta=*/ InferenceMode::is_enabled() diff --git a/torch/csrc/autograd/functions/tensor.cpp b/torch/csrc/autograd/functions/tensor.cpp index 2f5015f9a2cf..a47ac8f77cad 100644 --- a/torch/csrc/autograd/functions/tensor.cpp +++ b/torch/csrc/autograd/functions/tensor.cpp @@ -55,7 +55,7 @@ variable_list CopyBackwards::apply_with_saved( CopySlices::CopySlices( const Variable& base_var, at::TensorGeometry view_, - std::function view_fn_, + std::unique_ptr view_fn_, std::shared_ptr fn_) : Node(), base(base_var), @@ -98,7 +98,7 @@ inline variable_list CopySlices::apply_impl( at::Tensor grad_slice; if (view_fn) { - grad_slice = view_fn(result); + grad_slice = (*view_fn)(result); } else { auto offset = view.sym_storage_offset() - base.sym_storage_offset(); grad_slice = diff --git a/torch/csrc/autograd/functions/tensor.h b/torch/csrc/autograd/functions/tensor.h index 29f9259170f3..6e99ed6ae2aa 100644 --- a/torch/csrc/autograd/functions/tensor.h +++ b/torch/csrc/autograd/functions/tensor.h @@ -79,7 +79,7 @@ struct TORCH_API CopyBackwards : public Node { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // // We need to perform grad_view = fn(grad_view), but out-of-place. -// view_fn_ is an optional lambda function saved in DifferentiableViewMeta +// view_fn_ is an optional function saved in DifferentiableViewMeta // from forward pass, so that we can recover we when as_strided is not // supported. It preserves the invariants: // view = view_fn_(base) @@ -160,7 +160,7 @@ struct TORCH_API CopySlices : public Node { CopySlices( const Variable& base_var, at::TensorGeometry view_, - std::function view_fn_, + std::unique_ptr view_fn_, std::shared_ptr fn_); // common code between apply/apply_with_saved @@ -178,7 +178,7 @@ struct TORCH_API CopySlices : public Node { // view and view_fn are redundant and view_fn will be used if available. // See Note [View + Inplace update for base tensor] for details. at::TensorGeometry view; - std::function view_fn; + std::unique_ptr view_fn; std::shared_ptr fn; }; diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 10c3c2de5300..52dda12906ac 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -525,16 +525,36 @@ static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) { Py_RETURN_NONE; } +// Maps the given python callable over a vector of items, returning a vector +// of the same type of items. 
+template +static std::vector map_py_func( + const py::function& func, + const std::vector& items) { + std::vector new_items; + new_items.reserve(items.size()); + for (auto& item : items) { + new_items.push_back(py::cast(func(item))); + } + return new_items; +} + static PyObject* view_func_impl( - PyObject* self_, - PyObject* arg, + PyObject* _self, + PyObject* args, + PyObject* kwargs, bool check_has_same_meta) { HANDLE_TH_ERRORS - const auto& self = THPVariable_Unpack(self_); - TORCH_CHECK( - THPVariable_Check(arg), - "_view_func expect a single argument that is a Tensor"); - const auto& new_base = THPVariable_Unpack(arg); + const auto& self = THPVariable_Unpack(_self); + + static PythonArgParser parser({ + "_view_func(Tensor new_base, PyObject* symint_visitor_fn=None, PyObject* tensor_visitor_fn=None)", + }); + ParsedArgs<3> parsed_args{}; + auto r = parser.parse(_self, args, kwargs, parsed_args); + auto new_base = r.tensor(0); + PyObject* symint_visitor_fn = r.pyobject(1); + PyObject* tensor_visitor_fn = r.pyobject(2); // Ensure that self is indeed a backward differentiable view // If not, we return an undefined Tensor (None) and let the user handle it. @@ -547,7 +567,29 @@ static PyObject* view_func_impl( torch::autograd::utils::has_same_meta(new_base, view_info.base_)) { // Do the actual view replay if (view_info.has_view_fn()) { - out = view_info.view_fn()(new_base); + auto& view_func = view_info.view_fn(); + + // Determine new SymInt / tensor state as needed. + c10::optional> new_symints = c10::nullopt; + if (symint_visitor_fn != Py_None) { + new_symints = map_py_func( + py::cast(symint_visitor_fn), + view_func.get_symints()); + } + + c10::optional> new_tensors = c10::nullopt; + if (tensor_visitor_fn != Py_None) { + new_tensors = map_py_func( + py::cast(tensor_visitor_fn), + view_func.get_tensors()); + } + + // call view func + if (new_symints.has_value() || new_tensors.has_value()) { + out = (*view_func.clone_and_set(new_symints, new_tensors))(new_base); + } else { + out = view_func(new_base); + } } else { out = new_base.as_strided( self.sizes(), self.strides(), self.storage_offset()); @@ -558,12 +600,18 @@ static PyObject* view_func_impl( END_HANDLE_TH_ERRORS } -static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) { - return view_func_impl(self_, arg, /*check_has_same_meta=*/true); +static PyObject* THPVariable_view_func( + PyObject* self_, + PyObject* args, + PyObject* kwargs) { + return view_func_impl(self_, args, kwargs, /*check_has_same_meta=*/true); } -static PyObject* THPVariable_view_func_unsafe(PyObject* self_, PyObject* arg) { - return view_func_impl(self_, arg, /*check_has_same_meta=*/false); +static PyObject* THPVariable_view_func_unsafe( + PyObject* self_, + PyObject* args, + PyObject* kwargs) { + return view_func_impl(self_, args, kwargs, /*check_has_same_meta=*/false); } static PyObject* rev_view_func_impl(PyObject* self_, PyObject* arg) { @@ -1668,8 +1716,14 @@ static PyMethodDef extra_methods[] = { METH_STATIC | METH_VARARGS | METH_KEYWORDS, nullptr}, {"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr}, - {"_view_func", THPVariable_view_func, METH_O, nullptr}, - {"_view_func_unsafe", THPVariable_view_func_unsafe, METH_O, nullptr}, + {"_view_func", + castPyCFunctionWithKeywords(THPVariable_view_func), + METH_VARARGS | METH_KEYWORDS, + nullptr}, + {"_view_func_unsafe", + castPyCFunctionWithKeywords(THPVariable_view_func_unsafe), + METH_VARARGS | METH_KEYWORDS, + nullptr}, {"_rev_view_func_unsafe", THPVariable_rev_view_func_unsafe, 
METH_O, diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp index 821eea07c4b7..5eab310766ce 100644 --- a/torch/csrc/autograd/variable.cpp +++ b/torch/csrc/autograd/variable.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -25,6 +26,19 @@ namespace torch { namespace autograd { +// Returns a ViewFunc with a corresponding view that matches the shape, +// stride, and storage offset of the given tensor. +// NB: On mobile, the as_strided() op and thus the generated AsStridedViewFunc +// may not be available. +static std::unique_ptr create_view_func_matching(const Variable& t) { +#ifdef AS_STRIDED_VIEW_FUNC_AVAILABLE + return std::make_unique( + t.sym_sizes(), t.sym_strides(), t.sym_storage_offset()); +#else + return std::make_unique("as_strided() not available"); +#endif +} + DifferentiableViewMeta::DifferentiableViewMeta( at::TensorImpl* self_impl, c10::optional backward_info, @@ -58,7 +72,7 @@ DifferentiableViewMeta::DifferentiableViewMeta( ViewInfo ViewInfo::chain( const Variable& base, const Variable& tensor, - std::function view_func, + std::unique_ptr view_func, std::function rev_view_func) const { // Set `view_func` using the root base as input. // `view_func` is used to recover views in backward when either as_strided is @@ -69,12 +83,8 @@ ViewInfo ViewInfo::chain( if (view_func) { // both current_view and it's parent have a view_func if (view_fn_) { - // Copy parent view function to gain ownership - auto prev_fn = view_fn_; - view_func = [=](const at::Tensor& root_base) { - auto temp = prev_fn(root_base); - return view_func(temp); - }; + view_func = std::make_unique( + view_fn_->clone_and_set(), std::move(view_func)); // assume view_fn_ / rev_view_fn_ always exist together or neither are set auto prev_rev_fn = rev_view_fn_; @@ -85,13 +95,9 @@ ViewInfo ViewInfo::chain( } else { // current_view has a view_func and but it's parent doesn't have one if (base.unsafeGetTensorImpl()->support_as_strided()) { - auto size = base.sym_sizes().vec(); - auto stride = base.sym_strides().vec(); - auto storage_offset = base.sym_storage_offset(); - view_func = [=](const at::Tensor& root_base) { - auto temp = root_base.as_strided_symint(size, stride, storage_offset); - return view_func(temp); - }; + auto match_base_view_func = create_view_func_matching(base); + view_func = std::make_unique( + std::move(match_base_view_func), std::move(view_func)); // assume view_fn_ / rev_view_fn_ always exist together or neither are // set @@ -111,12 +117,7 @@ ViewInfo ViewInfo::chain( auto error_msg = ("Attempted to chain views when the parent view has no view_func() and " "does not support as_strided(). 
This is not supported."); - - view_func = [=](const at::Tensor& root_base) { - TORCH_CHECK(false, error_msg); - return root_base; - }; - + view_func = std::make_unique(error_msg); rev_view_func = [=](const at::Tensor& root_view) { TORCH_CHECK(false, error_msg); return root_view; @@ -125,15 +126,9 @@ ViewInfo ViewInfo::chain( } } else if (view_fn_) { // if current_view doesn't have a view_func but it's parent has one - // Copy parent view function to gain ownership - auto prev_view_fn = view_fn_; - auto size = tensor.sym_sizes().vec(); - auto stride = tensor.sym_strides().vec(); - auto storage_offset = tensor.sym_storage_offset(); - view_func = [=](const at::Tensor& root_base) { - auto temp = prev_view_fn(root_base); - return temp.as_strided_symint(size, stride, storage_offset); - }; + auto match_tensor_view_func = create_view_func_matching(tensor); + view_func = std::make_unique( + view_fn_->clone_and_set(), std::move(match_tensor_view_func)); // assume view_fn_ / rev_view_fn_ always exist together or neither are set auto prev_rev_view_fn = rev_view_fn_; @@ -232,12 +227,12 @@ void rebase_history(const Variable& self, Edge gradient_edge) { TORCH_CHECK( gradient_edge.function->num_inputs() == 1, "Functions which modify views in-place must return a single Variable"); - auto view_info = diff_view_meta->get_backward_view(); + const auto& view_info = diff_view_meta->get_backward_view(); diff_view_meta->output_nr_ = gradient_edge.input_nr; auto copy_slices = std::make_shared( view_info.base_, at::TensorGeometry(self), - view_info.view_fn_, + view_info.has_view_fn() ? view_info.view_fn().clone_and_set() : nullptr, std::move(gradient_edge.function)); if (self.requires_grad()) { // If self did not previously require grad, there are no hooks to move @@ -656,7 +651,7 @@ const std::shared_ptr& VariableHooks::grad_fn( if (diff_view_meta && diff_view_meta->has_bw_view()) { // See NOTE [ View + Inplace detection ] std::lock_guard lock(diff_view_meta->mutex_); - auto view_info = diff_view_meta->get_backward_view(); + auto& view_info = diff_view_meta->get_backward_view(); if (!diff_view_meta->grad_fn_ && !view_info.base_.requires_grad()) { return diff_view_meta->grad_fn_; } @@ -696,7 +691,7 @@ const std::shared_ptr& VariableHooks::grad_fn( // in VariableType_x.cpp // that would provide a way to recreate the grad_fn chain. if (view_info.has_view_fn()) { - auto view_fn = view_info.view_fn(); + auto& view_fn = view_info.view_fn(); Tensor diff_view; { // We can reach this path with grad_mode disabled, e.g. engine diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h index 2ff5b149aeb2..1c98c3317afa 100644 --- a/torch/csrc/autograd/variable.h +++ b/torch/csrc/autograd/variable.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -331,7 +332,8 @@ struct TORCH_API ViewInfo { /// By default we use as_strided to recover views which is more efficient. /// view_fn is only saved when as_strided is not supported. /// If view_fn has value, we use it to recover views in backward. - std::function view_fn_; + std::unique_ptr view_fn_; + /// Analogue of view_fn but in reverse: given a view -> produce the base by /// applying the inverse view. 
std::function rev_view_fn_; @@ -342,10 +344,10 @@ struct TORCH_API ViewInfo { return view_fn_ != nullptr; } - std::function view_fn() const { + const ViewFunc& view_fn() const { TORCH_CHECK( has_view_fn(), "Can only access the view function if it exists."); - return view_fn_; + return *view_fn_; } std::function rev_view_fn() const { @@ -366,12 +368,12 @@ struct TORCH_API ViewInfo { ViewInfo chain( const Variable& base, const Variable& tensor, - std::function view_func = nullptr, + std::unique_ptr view_func = nullptr, std::function rev_view_func = nullptr) const; ViewInfo( Variable base, - std::function view_fn, + std::unique_ptr view_fn, std::function rev_view_fn) : base_(std::move(base)), view_fn_(std::move(view_fn)),
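Taken together, view chains built from these pieces should replay transparently through the new ViewFunc/ChainedViewFunc machinery. A quick end-to-end check in Python, mirroring the chain cases in test_view_func_replay above and assuming a build that includes this patch:

    import torch

    # Build a chain of views with original view tracking on, then replay the whole
    # chain onto a fresh base via the recorded (possibly chained) view func.
    with torch.autograd._force_original_view_tracking(True):
        base = torch.randn(2, 3, 4)
        out = base.split_with_sizes([1, 3], -1)[0].transpose(0, 1).unsqueeze(-1)

    new_base = torch.randn_like(base)
    replayed = out._view_func(new_base)
    expected = new_base.split_with_sizes([1, 3], -1)[0].transpose(0, 1).unsqueeze(-1)
    assert torch.equal(replayed, expected)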