Reland: Add base forward grad logic (#49734)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49734

RFC: https://github.com/pytorch/rfcs/pull/11

This PR adds the basic logic to handle forward gradients as dual Tensors.
It contains the following:
- Mechanism to save dual state on a Tensor and clear it up when the dual level ends
- C++ and Python user-facing API (a usage sketch is included after the limitations list below)
- Updated view system that is able to track both forward and backward views

The current PR has the following limitations:
- Extensive tests are in the next PR in the stack as formulas are needed to write full tests.
- Only the manual formulas have been audited; no other formulas are actually implemented here (they are in the next PR in the stack).
- Only level 0 is allowed for now. It was discussed and agreed that nested levels are not needed for the first version of this PR.
- We can save one ViewInfo creation when both the forward and backward views have the same base. This can be done by adding a boolean flag to the DifferentiableViewMeta and extra logic in the `as_view` method. This is left out to keep this PR concise.
- We can skip tracking forward views if the base has a forward grad. This can be done by adding extra logic in the `as_view` method. This is left out to keep this PR concise.
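
To make the items above concrete, here is a minimal usage sketch of the resulting Python API. This is a hedged example, not part of the PR: it assumes a build where forward AD is enabled (this PR adds a temporary global switch that disables it by default) and where the elementwise formulas from the next PR in the stack are available.

```python
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(3)
tangent = torch.randn(3)

with fwAD.dual_level():                     # enters level 0, the only level allowed for now
    dual = fwAD.make_dual(primal, tangent)  # dual shares memory with `primal`
    out = dual * 2                          # needs the mul formula from the next PR in the stack
    y, jvp = fwAD.unpack_dual(out)          # y is the primal output, jvp is J @ tangent
# Exiting the level clears every tangent registered at that level.
```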

Reading guide:
- Updated view handling in [gen_variable_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-f6553cec68caeaea36f6c8b14ff76a6d39dfd774e0ea9ef2f76e8d81fd9af5df), [VariableTypeUtils.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-ec71cfa45954dece1236c661d170e6341879c5be637f4abf52e826d61b40695a), [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285) (skip code below "[Forward Grad View]" for now), [variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-1604bcd0e4350ed99ec45e437cee7ac9ebe337392c9ea16a236247aeeb35b02bR266-R542) and [custom_function.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-dd85f452082b5bb6612bbc12adb496f8827defa228509f7b493de1d517522d5d). This introduces the new ViewInfo to hold view information shared between forward and backward mode, updates the DifferentiableViewMeta to use it, and updates the as_view function to handle both forward and backward views.
- New forward grad class that handles storing gradients and tracking them at each level [forward_grad.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c6c5b9ab2d7e5dde4102495faa1b6bbbfc23aa3e47deb7359c0bfe1eb004c0cb), [forward_grad.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-de2ab54ade7312701850d71a119a4f4ee4b9fc5a9c42a467cdd4e73c033531dd) and [build_variables.bzl](https://github.com/pytorch/pytorch/pull/49097/files#diff-dfdfa2efb17beddfd9094524f95351fd197db6c8857e96b436fb599870359325). EDIT: These files also contain the new flag to globally disable forward AD, which allows us to reduce performance issues while this is in development.
- Lowest level API and binding between Tensor and AutogradMeta in [TensorBody.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-7554853205392fa743357bf845ecc350a974ec049383248c12daaf2f4de04911), [TensorImpl.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-052bd9150ef8e09289ddf644b5a6830ede49207201cd41728f6d7cc6d9cead94), [TensorImpl.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-a15aae4cf23da44970db7cece62ff981265575c798c62f7b52d87c8809dfe2e1) and the rest of [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285R557-R677)
- API to access the forward primal that needs to be a differentiable function (and so in native_functions.yaml) [native_functions.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991) [NamedRegistrations.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-69bd3bea510c9b64e1633fa18c3ea63d4b8348dbad3a78ad9de844ab3e43dc1d), [VariableMethodsStub.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-23f5fcb737a2b289811fe0f4b65aef775e7c824b2e629ecd343df51405cd434f), [derivatives.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_python_functions.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_trace_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-54e0b976027bf8debefb959ff360b89ae93466970c843365b1b3a03806d868ce), [TraceTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-f34636741ad4a23d018e0c289bc750c3bad887b45660e1d6eaf440d234a78fbf) and [part of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R198-R243)
- c++ API [autograd.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-349028fbe8291a965a7a263c323b208fe071c35c66179ee997ef84fa81aa4b1e), [autograd.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-a3fe908d67dfec16a1fcde300de68b0701bf68b88db7451f29f2bee255cf30c9)
- python binding [init.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-c58a67c85191c22c9b3bb439117d8053edfd9dea839fa010cf967d404c3c630d)
- python API [forward_ad.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a4efad4ba18fffdfb264c21e5475997a24a743089a899f8ec1a5ff962c6738d9), [autograd/__init__.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-743abcafd32ad0e69f39ac5a91df4197b7e1921c135cacee7ef6dc829a8a7af8)
- c++ and python printing [Formatting.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-881dba501e71662e2e4818b4b016f739b344c8aed2f5edc6b871eda47a2aced0), [_tensor_str.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a7911f8d5e73adbff914d99fd7818ace2a7030b6a3748abe06ec6fc6e3df9cc3)
- Utility for formulas and updated manual functions to respect new view system as well as forward grad [FunctionsManual.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-6378bb6dc81a64dab676d61731341fa5d1088418f32a1473a33a0ccfc2357dc1), [FunctionsManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-4adbd88239afcd60e8198aab65d4f5e43b62314e34b80551e997a1ea503adea5) [rest of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R264-R433)
- Ensure SavedVariable saves forward grads properly [saved_variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c1b8039d776241abe177d5aa99b79dd9489a9b3e529da8ab24c2e386c1238ae2), [saved_variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-cc9fba479b5beae06b2eea2e390d17796e0341c5b037a20b5bcaccbb0c341030)

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D25678797

Pulled By: albanD

fbshipit-source-id: 3d58550c11b5f58b9b73fd30596d042b857fb9dd
albanD
2020-12-22 12:07:00 -08:00
committed by Facebook GitHub Bot
parent eabe05ab72
commit c23808d8e8
37 changed files with 1444 additions and 155 deletions

View File

@ -292,6 +292,11 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi
stream << ", axis: " << tensor_.q_per_channel_axis();
}
}
auto& fw_grad = tensor.fw_grad(/* level */ 0);
if (fw_grad.defined()) {
stream << ", tangent:" << std::endl << fw_grad;
}
stream << " ]";
}
return stream;

View File

@ -510,4 +510,5 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("_version", CppFunction::makeFallthrough());
m.impl("requires_grad_", CppFunction::makeFallthrough());
m.impl("retain_grad", CppFunction::makeFallthrough());
m.impl("_fw_primal", CppFunction::makeFallthrough());
}

View File

@ -0,0 +1,27 @@
#include <ATen/ATen.h>
namespace at {
namespace native {
/// This function can be used to create a dual Tensor that holds a tangent to compute forward mode gradients.
/// Note that the dual Tensor's primal is a view of the given primal and the given tangent is used as-is.
/// This function is backward differentiable.
at::Tensor make_dual(const at::Tensor& primal, const at::Tensor& tangent, int64_t level) {
TORCH_CHECK(!primal.fw_grad(level).defined(), "Making a dual Tensor based on a Tensor that "
"already has a forward gradient at the same level ", level, " is not supported.");
auto dual_tensor = primal.view(primal.sizes());
dual_tensor.set_fw_grad(tangent, level, /* is_inplace_op */ false);
return dual_tensor;
}
/// This function can be used to unpack a given dual Tensor to get its primal and tangent. The returned primal
/// is a view of the dual and the tangent is returned as is.
/// This function is backward differentiable.
std::tuple<at::Tensor, at::Tensor> unpack_dual(const at::Tensor& tensor, int64_t level) {
return std::tuple<at::Tensor, at::Tensor>(tensor._fw_primal(level), tensor.fw_grad(level));
}
} // namespace native
} // namespace at
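
A short, hedged illustration of the semantics documented above, using the Python bindings added later in this PR (the identity check mirrors the new test in test_autograd.py); it assumes a build where this forward AD support is enabled:

```python
import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(2)
t = torch.randn(2)
with fwAD.dual_level():
    dual = fwAD.make_dual(x, t)
    primal, tangent = fwAD.unpack_dual(dual)
    assert tangent is t                       # the tangent is used and returned as-is
    assert primal.data_ptr() == x.data_ptr()  # the primal is a view of x
```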

View File

@ -40,5 +40,9 @@ void retain_grad(Tensor& self) {
AT_ERROR("retain_grad is not implemented for Tensor");
}
Tensor _fw_primal(const Tensor& self, int64_t level) {
AT_ERROR("_fw_primal is not implemented for Tensor");
}
} // namespace native
} // namespace at

View File

@ -105,6 +105,20 @@
manual_kernel_registration: True
variants: method
- func: _fw_primal(Tensor(a) self, int level) -> Tensor(a)
use_c10_dispatcher: full
variants: method
dispatch:
DefaultBackend: _fw_primal
- func: make_dual(Tensor(a) primal, Tensor tangent, int level) -> Tensor(a)
use_c10_dispatcher: full
variants: function
- func: unpack_dual(Tensor(a) dual, int level) -> (Tensor(a) primal, Tensor tangent)
use_c10_dispatcher: full
variants: function
- func: rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)
use_c10_dispatcher: full
variants: method

View File

@ -599,6 +599,23 @@ class TORCH_API Tensor {
return impl_->grad();
}
// The Forward AD API functions below are low level and are not to be used by end
// users who should use the API provided in torch/csrc/autograd.h
/// This function returns the forward gradient for this Tensor at the given level.
const Tensor& fw_grad(uint64_t level) const {
return impl_->fw_grad(level, *this);
}
/// This function can be used to set the value of the forward grad.
/// Note that the given new_grad might not be used directly if it has different
/// metadata (size/stride/storage offset) compared to this Tensor. In that case,
/// new_grad content will be copied into a new Tensor
void set_fw_grad(const Tensor& new_grad, uint64_t level, bool is_inplace_op) {
impl_->set_fw_grad(new_grad, *this, level, is_inplace_op);
}
// STOP. Thinking of adding a method here, which only makes use
// of other ATen methods? Define it in native_functions.yaml.

View File

@ -44,6 +44,17 @@ const at::Tensor& TensorImpl::grad() const {
return autograd_meta_->grad();
}
const at::Tensor& TensorImpl::fw_grad(uint64_t level, const at::Tensor& self) const {
// See TensorImpl::grad() above for explanation about the line below
if (!autograd_meta_) return impl::GetAutogradMetaFactory()->undefined_tensor();
return autograd_meta_->fw_grad(level, self);
}
void TensorImpl::set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op) {
if (!autograd_meta_) autograd_meta_ = impl::GetAutogradMetaFactory()->make();
autograd_meta_->set_fw_grad(new_grad, self, level, is_inplace_op);
}
TensorImpl::TensorImpl(
Storage&& storage,
DispatchKeySet key_set,

View File

@ -136,6 +136,8 @@ struct C10_API AutogradMetaInterface {
virtual bool requires_grad() const = 0;
virtual at::Tensor& mutable_grad() = 0;
virtual const at::Tensor& grad() const = 0;
virtual const at::Tensor& fw_grad(uint64_t level, const at::Tensor& self) const = 0;
virtual void set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op) = 0;
virtual ~AutogradMetaInterface();
};
@ -598,6 +600,42 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
*/
const at::Tensor& grad() const;
/**
* Return the accumulated gradient of a tensor. This gradient is computed
* using forward mode AD.
*
* This is an internal API that should never be used by end users.
*
* The API is as follows:
* - "level" allows to specify the level of forward AD nesting for which the
* gradient should be returned. Note that since levels are not fully
* supported yet, this argument should be 0. See documentation for
* torch::autograd::enter_dual_level for more details about forward AD nesting.
* - "self" should represent the Tensor whose forward grad is accessed. It is
* required when dealing with views.
*/
const at::Tensor& fw_grad(uint64_t level, const at::Tensor& self) const;
/**
* Sets the forward gradient for this Tensor.
* The given Tensor might not be used directly and its content will be copied.
*
* This is an internal API that should never be used by end users.
*
* The API is as follows:
* - "new_grad" is a Tensor containing the new value of the gradient that should
* be set
* - "self" should reprensent the Tensor whose forward grad is accessed. It is
* required when dealing with view.
* - "level" allows to specify the level of forward AD nesting for which the
* gradient should be set. Note that since levels are not fully supported
* yet, this argument should be 0. See documentation for torch::autograd::enter_dual_level
* for more details about forward AD nesting.
* - "is_inplace_op" is a boolean flag that tells if this gradient was generated
* by an inplace operation or an out of place one. This allows better error checking.
*/
void set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op);
/**
* Return a typed data pointer to the actual data which this tensor refers to.
* This checks that the requested type (from the template parameter) matches

View File

@ -35,6 +35,7 @@ from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoL
IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck)
from torch.autograd import Variable, Function, detect_anomaly, kineto_available
from torch.autograd.function import InplaceFunction
import torch.autograd.forward_ad as fwAD
from torch.testing import randn_like
from torch.testing._internal.common_methods_invocations import (method_tests,
create_input, unpack_variables,
@ -5326,6 +5327,26 @@ class TestAutogradComplex(TestCase):
self.assertEqual(x.grad, y.grad)
def test_view_with_multi_output(self):
x = torch.randn(2, 2, 2, dtype=torch.double)
x1 = torch.view_as_complex(x)
# Taking an invalid view should always be allowed as long as it is not
# modified inplace
res = x1.unbind(0)
with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
res[0] += torch.rand(2, requires_grad=True)
x.requires_grad_(True)
x1 = torch.view_as_complex(x)
# Taking an invalid view should always be allowed as long as it is not
# modified inplace
res = x1.unbind(0)
with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
res[0] += torch.rand(2, requires_grad=True)
def as_identity(self):
# view_as_real and view_as_complex behavior should be like an identity
def func(z):
@ -6324,6 +6345,66 @@ class TestAutogradFunctional(TestCase):
self.assertEqual(hvp, torch.mm(hes, v.unsqueeze(1)).squeeze(1))
self.assertEqual(vhp, torch.mm(v.unsqueeze(0), hes).squeeze(0))
class TestAutogradForwardMode(TestCase):
def test_forward_level_cleanup(self):
import weakref
def get_tensor_and_weak_ref():
# Helper function to get a Tensor and a weak ref that tells us
# if the c++ version of this Tensor is still alive or not.
#
# Create the following reference chain to do so:
# - python Tensor t
# - c++ Tensor corresponding to t
# - c++ Node corresponding to t.grad_fn
# - python dict of metadata from this Node
# - an object in this dict that we can take a weakref of
# Create a new Tensor and Node
t = torch.rand(2, requires_grad=True).clone()
# Create the metadata dict
meta_dict = t.grad_fn.metadata
# Create the object in the dict
class Foo(object):
pass
my_obj = Foo()
meta_dict[0] = my_obj
# After exiting this function, the python Tensor t is the only
# thing keeping ref alive
ref = weakref.ref(my_obj)
return t, ref
# Sanity check that the helper function works as expected
t, t_ref = get_tensor_and_weak_ref()
self.assertIsNotNone(t_ref())
del t
self.assertIsNone(t_ref())
# Main test code
foo = torch.rand(2)
with fwAD.dual_level():
tangent, tangent_ref = get_tensor_and_weak_ref()
self.assertIsNotNone(tangent_ref())
dual = fwAD.make_dual(foo, tangent)
self.assertIsNotNone(tangent_ref())
# Make sure that the tangent we provided has been re-used as is
self.assertTrue(fwAD.unpack_dual(dual)[1] is tangent)
# Make sure that dual is keeping the tangent alive
del tangent
self.assertIsNotNone(tangent_ref())
# Make sure that the dual level does not keep the c++
# version of the tangent alive
del dual
self.assertIsNone(tangent_ref())
# Generic device type autograd tests.
class TestAutogradDeviceType(TestCase):

View File

@ -12,7 +12,7 @@ aten_native_yaml = os.path.join(path, '../aten/src/ATen/native/native_functions.
all_operators_with_namedtuple_return = {
'max', 'min', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig', 'eig',
'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk', 'lstsq',
'triangular_solve', 'cummax', 'cummin', 'linalg_eigh'
'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "unpack_dual"
}
@ -65,6 +65,7 @@ class TestNamedTupleAPI(unittest.TestCase):
op(operators=['triangular_solve'], input=(a,), names=('solution', 'cloned_coefficient'), hasout=True),
op(operators=['lstsq'], input=(a,), names=('solution', 'QR'), hasout=True),
op(operators=['linalg_eigh'], input=("L",), names=('eigenvalues', 'eigenvectors'), hasout=True),
op(operators=['unpack_dual'], input=(a, 0), names=('primal', 'tangent'), hasout=False),
]
for op in operators:
@ -75,7 +76,9 @@ class TestNamedTupleAPI(unittest.TestCase):
for i, name in enumerate(op.names):
self.assertIs(getattr(ret, name), ret[i])
else:
ret = getattr(a, f)(*op.input)
# Handle ops that are not methods
func = getattr(a, f) if hasattr(a, f) else getattr(torch, f)
ret = func(*op.input)
for i, name in enumerate(op.names):
self.assertIs(getattr(ret, name), ret[i])
if op.hasout:

View File

@ -80,7 +80,8 @@ SKIP_PYTHON_BINDINGS = [
'nonzero(_(out|numpy))?',
'set_data',
'.*_overrideable', # overrideable functions for backend extension
'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retain_grad', 'set_'
'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retain_grad', 'set_',
'_fw_primal'
]
# These function signatures are not exposed to Python. Note that this signature

View File

@ -25,7 +25,7 @@ MANUAL_BACKEND = set([
# For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys.
# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
MANUAL_AUTOGRAD_AND_TRACER = set([
'resize_', 'resize_as_', 'detach', 'detach_', 'copy_',
'resize_', 'resize_as_', 'detach', 'detach_', 'copy_', '_fw_primal',
])
# Currently MANUAL_AUTOGRAD and MANUAL_TRACER share the same set of ops:

View File

@ -689,7 +689,7 @@ def emit_body(declaration):
if len(differentiable_output_vars) == 0:
# no output is differentiable (.indices() for SparseTensors for example)
rhs_value = 'as_view({}, {}, /* is_differentiable */ false)'.format(view_info, var)
rhs_value = f'as_view({view_info}, {var}, /* is_bw_differentiable */ false, /* is_fw_differentiable */ false)'
elif len(differentiable_output_vars) == 1:
# Single differentiable output (Tensor or Tensor[])
return_info = differentiable_outputs[0]
@ -704,13 +704,15 @@ def emit_body(declaration):
creation_meta = "CreationMeta::MULTI_OUTPUT_SAFE"
else:
creation_meta = "CreationMeta::MULTI_OUTPUT_NODE"
call += ("as_view(/* base */ {}, /* output */ {}, /* is_differentiable */ true, "
"/* creation_meta */ {});\n").format(view_info, var, creation_meta)
call += ("as_view(/* base */ {}, /* output */ {}, /* is_bw_differentiable */ true, "
"/* is_fw_differentiable */ true, "
"/* creation_meta */ {});").format(view_info, var, creation_meta)
rhs_value = 'std::move({})'.format(var)
else:
call += emit_view_lambda()
creation_meta = "GradMode::is_enabled() ? CreationMeta::DEFAULT: CreationMeta::NO_GRAD_MODE"
rhs_value = ("as_view(/* base */ {}, /* output */ {}, /* is_differentiable */ true, "
rhs_value = ("as_view(/* base */ {}, /* output */ {}, /* is_bw_differentiable */ true, "
"/* is_fw_differentiable */ true, "
"/* view_func */ func, /* creation_meta */ {})").format(view_info, var, creation_meta)
else:
# This could be supported but we don't need it at the moment, so keeping things simple.

View File

@ -90,6 +90,8 @@ core_sources_common = [
"torch/csrc/autograd/profiler_legacy.cpp",
"torch/csrc/autograd/profiler_kineto.cpp",
"torch/csrc/autograd/profiler_utils.cpp",
"torch/csrc/autograd/autograd_meta.cpp",
"torch/csrc/autograd/forward_grad.cpp",
"torch/csrc/jit/frontend/edit_distance.cpp",
"torch/csrc/jit/frontend/string_to_type.cpp",
"torch/csrc/jit/mobile/type_parser.cpp",

View File

@ -522,6 +522,12 @@ def autocast_increment_nesting() -> _int: ...
def autocast_decrement_nesting() -> _int: ...
def set_anomaly_enabled(enabled: _bool) -> None: ...
def is_anomaly_enabled() -> _bool: ...
def _enter_dual_level() -> _int: ...
def _exit_dual_level(level: _int) -> None: ...
def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ...
def _unpack_dual(tensor: Tensor, level: _int) -> Tensor: ...
def __set_forward_AD_enabled(enabled: _bool) -> None: ...
def __is_forward_AD_enabled() -> _bool: ...
# Defined in torch/csrc/jit/python/script_init.cpp
class LoggerBase(object):

View File

@ -275,11 +275,16 @@ def get_summarized_data(self):
else:
return torch.stack([get_summarized_data(x) for x in self])
def _str_intern(self):
def _str_intern(inp):
prefix = 'tensor('
indent = len(prefix)
suffixes = []
# This is used to extract the primal value and thus disable the forward AD
# within this function.
# TODO(albanD) This needs to be updated when more than one level is supported
self, tangent = torch.autograd.forward_ad.unpack_dual(inp)
# Note [Print tensor device]:
# A general logic here is we only print device when it doesn't match
# the device specified in default tensor type.
@ -355,17 +360,22 @@ def _str_intern(self):
if self.layout != torch.strided:
suffixes.append('layout=' + str(self.layout))
if self.grad_fn is not None:
name = type(self.grad_fn).__name__
# Use inp here to get the original grad_fn and not the one generated by the forward grad
# unpacking.
if inp.grad_fn is not None:
name = type(inp.grad_fn).__name__
if name == 'CppFunction':
name = self.grad_fn.name().rsplit('::', 1)[-1]
name = inp.grad_fn.name().rsplit('::', 1)[-1]
suffixes.append('grad_fn=<{}>'.format(name))
elif self.requires_grad:
elif inp.requires_grad:
suffixes.append('requires_grad=True')
if self.has_names():
suffixes.append('names={}'.format(self.names))
if tangent is not None:
suffixes.append('tangent={}'.format(tangent))
return _add_suffixes(prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse)
def _str(self):
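
A hedged sketch of the resulting printing behavior (the C++ printer in Formatting.cpp and `_tensor_str.py` above both append the tangent); the exact output format is an expectation, not a guarantee:

```python
import torch
import torch.autograd.forward_ad as fwAD

with fwAD.dual_level():
    dual = fwAD.make_dual(torch.zeros(2), torch.ones(2))
    print(dual)
    # Expected to print something along the lines of:
    # tensor([0., 0.], tangent=tensor([1., 1.]))
```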

View File

@ -19,6 +19,7 @@ from .grad_mode import no_grad, enable_grad, set_grad_enabled
from .anomaly_mode import detect_anomaly, set_detect_anomaly
from ..overrides import has_torch_function, handle_torch_function
from . import functional
from . import forward_ad
__all__ = ['Variable', 'Function', 'backward', 'grad_mode']

View File

@ -0,0 +1,116 @@
import torch
from .grad_mode import _DecoratorContextManager
from typing import Any
# TODO(alband): Once most of the formulas are implemented, these functions need to be added
# to the main doc to make them fully "public".
# Global variable used to make the python API simpler to use
_current_level = -1
def enter_dual_level():
r"""Function that can be used to enter a new forward grad level.
This level can be used to make and unpack dual Tensors to compute
forward gradients.
This function also updates the current level that is used by default
by the other functions in this API.
"""
global _current_level
new_level = torch._C._enter_dual_level()
if new_level != _current_level + 1:
raise RuntimeError("Entering a new forward AD level but the current level "
"is not valid. Make sure you did not modified it directly.")
_current_level = new_level
return new_level
def exit_dual_level(*, level=None):
r"""Function that can be used to exit a forward grad level.
This function deletes all the gradients associated with this
level. Only deleting the latest entered level is allowed.
This function also updates the current level that is used by default
by the other functions in this API.
"""
global _current_level
if level is None:
level = _current_level
if level != _current_level:
raise RuntimeError("Trying to exit a forward AD level that was not the last one "
"that was created. This is not supported.")
torch._C._exit_dual_level(level=level)
_current_level = level - 1
def make_dual(tensor, tangent, *, level=None):
r"""Function that creates a "dual object" that can be used to compute forward AD gradients
based on the given Tensor and its tangent. It returns a new Tensor that shares memory with
:attr:`tensor` and the :attr:`tangent` is used as-is.
This function is backward differentiable.
Given a function `f` whose Jacobian is `J`, it allows computing the Jacobian-vector product,
named `jvp`, between `J` and a given vector `v` as follows.
Example::
>>> inp = make_dual(x, v)
>>> out = f(inp)
>>> y, jvp = unpack_dual(out)
"""
if level is None:
level = _current_level
if level < 0:
raise RuntimeError("Trying to create a dual Tensor for forward AD but no level "
"exists, make sure to enter_dual_level() first.")
return torch.make_dual(tensor, tangent, level=level)
def unpack_dual(tensor, *, level=None):
r"""Function that unpacks a "dual object" to recover two plain tensors, one representing
the primal and the other the tangent (both are views of :attr:`tensor`). Neither of these
tensors is a dual tensor at level :attr:`level`.
This function is backward differentiable.
"""
if level is None:
level = _current_level
if level < 0:
return tensor, None
return torch.unpack_dual(tensor, level=level)
class dual_level(_DecoratorContextManager):
r"""Context-manager that controls the current forward ad level. It
appropriately enters and exit the dual level.
This function also updates the current level that is used by default
by the other functions in this API.
Example::
>>> x = torch.tensor([1])
>>> x_t = torch.tensor([1])
>>> with dual_level():
... inp = make_dual(x, x_t)
... # Do computations with inp
... out = your_fn(inp)
... _, grad = unpack_dual(out)
>>> grad is None
False
>>> # After exiting the level, the grad is deleted
>>> _, grad_after = unpack_dual(out)
>>> grad_after is None
True
"""
def __init__(self):
super().__init__()
def __enter__(self):
return enter_dual_level()
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
exit_dual_level()
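
Since only level 0 is allowed in this PR (see the limitations in the summary), entering a nested level is expected to fail. A hedged sketch of that restriction, assuming the bindings from this PR:

```python
import torch.autograd.forward_ad as fwAD

with fwAD.dual_level():          # enters level 0
    try:
        fwAD.enter_dual_level()  # a second, nested level is rejected for now
    except RuntimeError as err:
        print(err)               # "Nested forward mode AD is not supported at the moment"
```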

View File

@ -35,10 +35,22 @@ bool isDefined(const c10::optional<Tensor>& t) {
return t.has_value() && t->defined();
}
bool isFwGradDefined(const c10::optional<Tensor>& t) {
return t.has_value() && t->defined() && t->fw_grad(/*level */ 0).defined();
}
Tensor toLegacyTensor(const c10::optional<Tensor>& t) {
return t.has_value() ? *t : Tensor();
}
Tensor toLegacyFwGrad(const c10::optional<Tensor>& t) {
return (t.has_value() && t->defined()) ? t->fw_grad(/*level */ 0) : Tensor();
}
Tensor toLegacyPrimal(const c10::optional<Tensor>& t) {
return (t.has_value() && t->defined()) ? t->_fw_primal(/*level */ 0) : Tensor();
}
void copy_range(variable_list& out, IndexRange range, const Tensor & t) {
AT_ASSERT(range.second <= out.size());
AT_ASSERTM(range.second - range.first == 1, "inconsistent range for Tensor output");

View File

@ -29,6 +29,10 @@ struct IndexRangeGenerator {
size_t i = 0;
};
bool isFwGradDefined(const c10::optional<Tensor>& t);
Tensor toLegacyFwGrad(const c10::optional<Tensor>& t);
Tensor toLegacyPrimal(const c10::optional<Tensor>& t);
bool any_variable_defined(variable_list& variables);
void copy_range(variable_list& out, IndexRange range, const at::Tensor & t);
void copy_range(variable_list& out, IndexRange range, at::ArrayRef<at::Tensor> t);

View File

@ -139,6 +139,7 @@ TORCH_LIBRARY_IMPL(aten, Tracer, m) {
m.impl("_version", CppFunction::makeFallthrough());
m.impl("requires_grad_", CppFunction::makeFallthrough());
m.impl("retain_grad", CppFunction::makeFallthrough());
m.impl("_fw_primal", CppFunction::makeFallthrough());
}
} // namespace

View File

@ -1,6 +1,7 @@
#include <c10/util/Optional.h>
#include <c10/core/ScalarType.h>
#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/csrc/autograd/FunctionsManual.h>
#include <torch/csrc/utils/memory.h>
#include <torch/csrc/autograd/utils/error_messages.h>
#include <torch/csrc/autograd/autograd.h>
@ -194,6 +195,39 @@ void retain_grad(Tensor & self) {
impl::get_autograd_meta(self)->retains_grad_ = true;
}
// Taken from codegened version
Tensor _fw_primal(const Tensor & self, int64_t level) {
auto& self_ = unpack(self, "self", 0);
std::shared_ptr<Identity> grad_fn;
if (compute_requires_grad( self )) {
grad_fn = std::make_shared<Identity>();
grad_fn->set_next_edges(collect_next_edges( self ));
}
auto tmp = ([&]() {
at::AutoNonVariableTypeMode non_var_type_mode(true);
return self_.alias();
})();
c10::optional<std::function<at::Tensor(const at::Tensor&)>> func=c10::nullopt;
if (!self.unsafeGetTensorImpl()->support_as_strided()) {
auto size_vec = self.sizes().vec();
func = [=](const at::Tensor& input_base) {
return input_base.view(size_vec);
};
}
auto result = as_view(/* base */ self, /* output */ tmp, /* is_bw_differentiable */ true,
/* is_fw_differentiable */ false, /* view_func */ func, /* creation_meta */ CreationMeta::DEFAULT);
if (grad_fn) {
set_history(flatten_tensor_args( result ), grad_fn);
}
if (generated::details::isFwGradDefined(self)) {
// Modified from original codegen
// We explicitly want to ignore the forward grad at the given level
TORCH_CHECK(level == 0, "Invalid level given to _fw_primal");
// End modified from original codegen
}
return result;
}
// We don't have an outplace copy, so this can't be generated automatically
Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) {
jit::Value* output = nullptr;
@ -217,6 +251,24 @@ Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) {
}
increment_version(self);
rebase_history(self , std::move(grad_fn));
if (isDifferentiableType(self.scalar_type()) &&
(generated::details::isFwGradDefined(self) || generated::details::isFwGradDefined(src))) {
auto self_fw_grad = generated::details::toLegacyFwGrad(self);
auto src_fw_grad = generated::details::toLegacyFwGrad(src);
Tensor new_fw_grad;
if (self_fw_grad.defined()) {
if (src_fw_grad.defined()) {
new_fw_grad = self_fw_grad.copy_(src_fw_grad);
} else {
new_fw_grad = self_fw_grad.fill_(0);
}
} else {
new_fw_grad = src_fw_grad;
}
self.set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ true);
}
return self;
}
@ -232,6 +284,11 @@ Tensor& resize_(
at::AutoNonVariableTypeMode non_var_type_mode(true);
self_.resize_(size, std::move(optional_memory_format));
}
if (self.fw_grad(/* level */ 0).defined()) {
AT_ERROR("cannot resize variables that has a forward grad");
}
return self;
}
@ -248,13 +305,28 @@ Tensor& resize_as_(
at::AutoNonVariableTypeMode non_var_type_mode(true);
at::resize_as_(self_, the_template_, std::move(optional_memory_format));
}
// Handle fw grad
if (self.fw_grad(/* level */ 0).defined()) {
AT_ERROR("cannot resize variables that has a forward grad");
}
return self;
}
Tensor detach(const Tensor & self) {
RECORD_FUNCTION("detach", std::vector<c10::IValue>({self}));
auto result = make_variable_non_differentiable_view(self, self, /*allow_tensor_metadata_change=*/false);
c10::optional<std::function<at::Tensor(const at::Tensor&)>> func=c10::nullopt;
auto result = as_view(/* base */ self, /* output */ self, /* is_bw_differentiable */ false,
/* is_fw_differentiable */ true, /* view_func */ func, /* creation_meta */ CreationMeta::DEFAULT,
/*allow_tensor_metadata_change=*/false);
namedinference::propagate_names(result, self);
// detach only backward gradients for both primal and tangent
if (self.fw_grad(/* level */ 0).defined()) {
auto new_fw_grad = self.fw_grad(/* level */ 0).detach();
result.set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ false);
}
return result;
}
@ -264,7 +336,7 @@ Tensor & detach_(Tensor & self) {
// NB: is_view() ==> get_autograd_meta()
auto diff_view_meta = static_cast<torch::autograd::DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(self));
// See NOTE [ View + Inplace detection ]
if (diff_view_meta->creation_meta == CreationMeta::MULTI_OUTPUT_SAFE) {
if (diff_view_meta->get_creation_meta() == CreationMeta::MULTI_OUTPUT_SAFE) {
TORCH_WARN("This view is an output of a function that "
"returns multiple views. Detaching such views inplace "
"is being deprecated and will be forbidden "
@ -272,7 +344,8 @@ Tensor & detach_(Tensor & self) {
"of detach_(). Alternatively, create this view with an "
"`unsafe_` version of the function that produced it.");
} else {
AT_ERROR("If you are using DistributedDataParallel (DDP) for training, "
AT_ERROR("Can't detach views in-place. Use detach() instead. "
"If you are using DistributedDataParallel (DDP) for training, "
"and gradient_as_bucket_view is set as True, gradients are "
"views of DDP buckets, and hence detach_() cannot be called "
"on these gradients. To fix this error, please refer to the "
@ -290,6 +363,12 @@ Tensor & detach_(Tensor & self) {
autograd_meta->set_requires_grad(false, self.unsafeGetTensorImpl());
autograd_meta->grad_fn_.reset();
autograd_meta->output_nr_ = 0;
// detach only backward gradients for both primal and tangent
if (self.fw_grad(/* level */ 0).defined()) {
self.fw_grad(/* level */ 0).detach_();
}
return self;
}
@ -321,6 +400,7 @@ TORCH_LIBRARY_IMPL(aten, Autograd, m) {
// and requires_grad_(), then remove the backend Autograd kernel here, only leaving the Math kernel.
m.impl("_backward", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::_backward)));
m.impl("requires_grad_", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::requires_grad_)));
m.impl("_fw_primal", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::_fw_primal)));
}
TORCH_LIBRARY_IMPL(aten, DefaultBackend, m) {

View File

@ -134,88 +134,111 @@ template<typename... Args> inline variable_list flatten_tensor_args(Args&&... ar
}
// See NOTE [ Autograd View Variables ] for details.
inline Tensor as_view(const Tensor & base, const Tensor& tensor, bool is_differentiable,
c10::optional<std::function<Tensor(const Tensor&)>> view_func=c10::nullopt,
CreationMeta creation_meta=CreationMeta::DEFAULT) {
auto base_var = Variable(base);
if (base_var.is_view()) {
// Set `view_func` using the root base as input.
// `view_func` is used to recover views in backward when either as_strided is not supported
// or the view function changes the metadata which is not recorded by as_strided
// See Note [View + Inplace update on base tensor] and [View + Inplace update on view tensor]
// for more details how we use this function in backward.
auto diff_view_meta = static_cast<DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(base_var));
if (view_func.has_value()) {
auto fn = view_func.value();
// both current_view and it's parent have a view_func
if (diff_view_meta->has_view_fn()) {
auto prev_fn = diff_view_meta->view_fn();
view_func = [=](const at::Tensor& root_base) {
auto temp = prev_fn(root_base);
return fn(temp);
};
inline Tensor as_view(const Tensor & base, const Tensor & tensor, bool is_bw_differentiable,
bool is_fw_differentiable, c10::optional<std::function<Tensor(const Tensor&)>> view_func=c10::nullopt,
CreationMeta creation_meta=CreationMeta::DEFAULT, bool allow_tensor_metadata_change=true) {
if (!isForwardADEnabled()) {
// Fast codepath for backward only code
// It is useful as it avoids the creation of the temporary c10::optional which makes
// a significant difference when measuring instruction count for a single "t.view(-1)" call from c++.
if (is_bw_differentiable) {
if (base.is_view()) {
auto diff_view_meta = static_cast<DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(base));
const auto& base_bw_info = diff_view_meta->get_backward_view();
return make_variable_differentiable_view(tensor, base_bw_info.chain(base, tensor, view_func),
c10::nullopt, creation_meta, allow_tensor_metadata_change);
} else {
// current_view has a view_func and but it's parent doesn't have one
if(base_var.unsafeGetTensorImpl()->support_as_strided()) {
auto size = base.sizes().vec();
auto stride = base.strides().vec();
auto storage_offset = base.storage_offset();
view_func = [=](const at::Tensor& root_base) {
auto temp = root_base.as_strided(size, stride, storage_offset);
return fn(temp);
};
} else {
// When base_var is a view but doesn't carry a view_fn in DifferentiableViewMeta, it's
// a view that doesn't support inplace update, e.g. unbind.
// In this case we should throw an error when inplace update happens in **forward**.
// One would naturally think the following function will be first called in backward pass.
// But the first call site is indeed in **forward** pass when we refresh `grad_fn`
// triggered by inplace update.
// Search Note [View + Inplace update for view tensor] to for the call site.
view_func = [=](const at::Tensor& root_base) {
TORCH_CHECK(false, "This view is the output of a function that returns multiple views."
"Such functions do not allow the output views to be modified inplace."
"You should replace the inplace operation by an out-of-place one");
return root_base;
};
}
return make_variable_differentiable_view(tensor, ViewInfo(base, view_func),
c10::nullopt, creation_meta, allow_tensor_metadata_change);
}
} else if(diff_view_meta->has_view_fn()) {
// if current_view doesn't have a view_func but it's parent has one
auto prev_view_fn = diff_view_meta->view_fn();
auto size = tensor.sizes().vec();
auto stride = tensor.strides().vec();
auto storage_offset = tensor.storage_offset();
view_func = [=](const at::Tensor& root_base) {
auto temp = prev_view_fn(root_base);
return temp.as_strided(size, stride, storage_offset);
};
} else {
TORCH_CHECK(creation_meta == CreationMeta::DEFAULT,
"Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT");
return make_variable_non_differentiable_view(base, std::move(tensor), allow_tensor_metadata_change);
}
base_var = base_var._base();
}
if (is_differentiable) {
return make_variable_differentiable_view(std::move(base_var), tensor, creation_meta, std::move(view_func));
// Create both the forward and backward info that are needed
c10::optional<ViewInfo> new_bw_info;
c10::optional<ViewInfo> new_fw_info;
if (is_bw_differentiable) {
if (base.is_view()) {
auto diff_view_meta = static_cast<DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(base));
const auto& base_bw_info = diff_view_meta->get_backward_view();
new_bw_info = base_bw_info.chain(base, tensor, view_func);
} else {
new_bw_info = ViewInfo(base, view_func);
}
} else {
TORCH_CHECK(creation_meta == CreationMeta::DEFAULT,
"Non-differentiable views must have creation_meta=CreationMeta::DEFAULT");
return make_variable_non_differentiable_view(std::move(base_var), tensor);
"Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT");
}
if (is_fw_differentiable) {
// Check if base is a forward differentiable view
auto base_meta = torch::autograd::impl::get_autograd_meta(base);
auto is_view = base_meta && base_meta->is_view_;
if (is_view && static_cast<DifferentiableViewMeta*>(base_meta)->has_fw_view()) {
auto diff_view_meta = static_cast<DifferentiableViewMeta*>(base_meta);
const auto& base_fw_info = diff_view_meta->get_forward_view();
new_fw_info = base_fw_info.chain(base, tensor, view_func);
} else {
new_fw_info = ViewInfo(base, view_func);
}
}
if (is_fw_differentiable || is_bw_differentiable) {
return make_variable_differentiable_view(tensor, std::move(new_bw_info), std::move(new_fw_info),
creation_meta, allow_tensor_metadata_change);
} else {
return make_variable_non_differentiable_view(base, tensor, allow_tensor_metadata_change);
}
}
// See NOTE [ Autograd View Variables ] for details.
inline std::vector<Tensor> as_view(const Tensor & base, std::vector<Tensor>& tensors, bool is_differentiable,
CreationMeta creation_meta=CreationMeta::DEFAULT) {
auto base_var = Variable(base);
if (base_var.is_view()) {
base_var = base_var._base();
}
for(Tensor &tensor : tensors) {
if (is_differentiable) {
tensor = make_variable_differentiable_view(base_var, tensor, creation_meta);
inline std::vector<Tensor> as_view(const Tensor & base, std::vector<Tensor>& tensors, bool is_bw_differentiable,
bool is_fw_differentiable, CreationMeta creation_meta=CreationMeta::DEFAULT) {
c10::optional<ViewInfo> new_bw_info = c10::nullopt;
c10::optional<ViewInfo> new_fw_info = c10::nullopt;
if (is_bw_differentiable) {
if (base.is_view()) {
auto diff_view_meta = static_cast<DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(base));
const auto& base_bw_info = diff_view_meta->get_backward_view();
TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::MULTI_OUTPUT_NODE || creation_meta == CreationMeta::MULTI_OUTPUT_SAFE,
"Functions that result multiple view must have a creation meta reflecting this behavior.");
// It is ok to create a ViewInfo where only the base is correct in this case as inplace operations on such views are
// not allowed
new_bw_info = ViewInfo(base_bw_info.base_, /* view_func */ c10::nullopt);
} else {
TORCH_CHECK(creation_meta == CreationMeta::DEFAULT,
"Non-differentiable views must have creation_meta=CreationMeta::DEFAULT");
tensor = make_variable_non_differentiable_view(base_var, tensor);
new_bw_info = ViewInfo(base, /* view_func */ c10::nullopt);
}
} else {
TORCH_CHECK(creation_meta == CreationMeta::DEFAULT,
"Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT");
}
if (isForwardADEnabled() && is_fw_differentiable) {
// Check if base is a forward differentiable view
auto base_meta = torch::autograd::impl::get_autograd_meta(base);
auto is_view = base_meta && base_meta->is_view_;
if (is_view && static_cast<DifferentiableViewMeta*>(base_meta)->has_fw_view()) {
auto diff_view_meta = static_cast<DifferentiableViewMeta*>(base_meta);
const auto& base_fw_info = diff_view_meta->get_forward_view();
TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::MULTI_OUTPUT_NODE || creation_meta == CreationMeta::MULTI_OUTPUT_SAFE,
"Functions that result multiple view must have a creation meta reflecting this behavior.");
// It is ok to create a ViewInfo where only the base is correct in this case as inplace operations on such views are
// not allowed
new_fw_info = ViewInfo(base_fw_info.base_, /* view_func */ c10::nullopt);
} else {
new_fw_info = ViewInfo(base, /* view_func */ c10::nullopt);
}
}
for(Tensor &tensor : tensors) {
if (is_fw_differentiable || is_bw_differentiable) {
tensor = make_variable_differentiable_view(tensor, new_bw_info, new_fw_info, creation_meta);
} else {
tensor = make_variable_non_differentiable_view(base, tensor);
}
}
return tensors;

View File

@ -155,5 +155,18 @@ variable_list grad(
outputs, gradients, retain_graph.value(), create_graph, inputs, allow_unused, /*accumulate_grad=*/false);
}
namespace forward_ad {
uint64_t enter_dual_level() {
return ForwardADLevel::get_next_idx();
}
void exit_dual_level(uint64_t level) {
ForwardADLevel::release_idx(level);
}
} // namespace forward_ad
} // namespace autograd
} // namespace torch

View File

@ -75,5 +75,20 @@ TORCH_API variable_list grad(
bool create_graph = false,
bool allow_unused = false);
namespace forward_ad {
/// Creates a new dual level and returns its index. This level index should then be used to call
/// into the other functions below.
/// This API supports entering a new level before the previous one is exited. We call them nested
/// forward AD levels. These can be used to compute higher order derivatives.
TORCH_API uint64_t enter_dual_level();
/// Exits the given level. This will clear up all the gradients from this level and all dual Tensors
/// that had gradients for this level will become regular Tensors again.
/// This function can only be used to exit the innermost nesting level and so exiting must happen in
/// reverse order compared to the entering that was done with the function above.
TORCH_API void exit_dual_level(uint64_t level);
} // namespace forward_ad
} // namespace autograd
} // namespace torch

View File

@ -0,0 +1,218 @@
#include <torch/csrc/autograd/variable.h>
namespace torch {
namespace autograd {
using at::Tensor;
// [Forward Grad View/inplace]
// It is important to us to allow view and inplace to work with dual Tensors. These operations
// should either compute the right gradient or raise a user-friendly error.
// The basic case where all Tensors are dual Tensors is as follows:
// # Have:
// # foo is a dual Tensor that is not a view
// # bar is a dual Tensor of appropriate size (depending on cases) that is not a view
//
// # Case 1: no view
// foo.copy_(bar)
//
// # Case 2: with view, propagate from view to base
// view = foo[0]
// view.copy_(bar)
//
// # Case 3: with view, propagate from base to view
// view = foo[0]
// foo.copy_(bar)
//
// # In both cases, the forward grad of foo must be properly updated.
// # In the second and third cases, the forward grad of view must match
// # the one of foo for the subset they have in common.
//
// All these cases can be handled by the following layout constraint on the forward grad:
// - A Tensor and its forward grad (for all levels) must have the same metadata (size, stride
// and storage offset). Storage offset must be in this metadata because of as_strided.
// - View operations must create a forward grad that is a view of the base's forward grad.
// - Inplace operations must modify the input's forward grad inplace.
//
// This layout constraint is ensured in the `set_fw_grad` function below
// More complex cases arise when non-dual Tensors interact with dual Tensors.
// The two most important cases are:
//
// # Have:
// # foo is a regular Tensor that is not a view
// # bar is a dual Tensor of appropriate size (depending on cases) that is not a view
//
// # Case 4: Changes on the view must propagate to its base
// view = foo[0]
// # view is still a regular Tensor here
// view.copy_(bar)
// # Now both view and foo are dual Tensor with appropriate forward grad
//
// # Case 5: Changes on the base must propagate to all its views
// view = foo[0]
// # view is still a regular Tensor here
// foo.copy_(bar)
// # Now both view and foo are dual Tensor with appropriate forward grad
//
// # NB there is a case 6 involving changes on a view propagating to other views
// # but it is fully described by the two others and is skipped in this discussion.
//
// Case 4 is handled by set_fw_grad by properly setting the forward grad of the base if needed.
// Case 5 is handled in fw_grad by reading the forward grad from the base if needed.
namespace {
// Check if two Tensors have the same storage offset, sizes and strides
bool has_same_meta(const Variable& base, const Variable& other) {
if (!base.defined() || !other.defined()) {
return false;
}
if (base.storage_offset() != other.storage_offset()) {
return false;
}
if (base.dim() != other.dim()) {
return false;
}
for (int64_t i=0; i<base.dim(); ++i) {
if (base.sizes()[i] != other.sizes()[i]) {
return false;
}
if (base.strides()[i] != other.strides()[i]) {
return false;
}
}
return true;
}
Tensor new_with_same_meta(const Variable& base) {
// We need to create a storage of the same size to be able to have the same
// viewing behavior in all cases
// Explicit type here to appease Windows build
int64_t nelement_in_storage = base.storage().nbytes() / base.itemsize();
auto new_tensor = at::zeros({nelement_in_storage}, base.options());
auto res = new_tensor.as_strided(base.sizes(), base.strides(), base.storage_offset());
return res;
}
} // anonymous namespace
// This function will ensure that the fw_grad_ is properly a view of the base for inplace ops on
// Tensors that do not have forward grad originally.
void AutogradMeta::set_fw_grad(const Variable& new_grad_, const Variable& self, uint64_t level, bool is_inplace_op) {
// Lazy initialization
{
std::lock_guard<std::mutex> lock(mutex_);
if (!fw_grad_) {
fw_grad_ = std::make_shared<ForwardGrad>();
}
}
if (fw_grad_->contains(level)) {
// Setting the forward grad again is only allowed if it is a no-op.
// We do allow this case to simplify writing codegen for inplace ops.
TORCH_INTERNAL_ASSERT(new_grad_.defined(), "Cannot set a forward grad that is an undefined Tensor. Use "
"_fw_primal(level) to get a new Tensor with this forward grad unset.");
TORCH_INTERNAL_ASSERT(is_inplace_op, "Only inplace operations can re-set the forward grad of a Tensor that "
"already has one.");
TORCH_INTERNAL_ASSERT(fw_grad_->value(level).is_same(new_grad_), "Cannot set a value of a forward grad if it "
"already exists. Inplace operations should modify it inplace.");
} else {
// TODO(alband) remove this spurious version counter bump
auto new_grad = new_grad_;
if (is_inplace_op && is_view_) {
auto this_view_meta = static_cast<DifferentiableViewMeta*>(this);
// For inplace ops on a Tensor that does not already have a forward grad and is a view, we propagate
// the tangent to the base and ensure that the new_grad is a view of that base's tangent.
// This ensures that case 4 from [Forward Grad View/inplace] above works fine
// What happens in this long if statement is:
// - Check if the base already has a grad
// - If not, set a new fw_grad for it full of zeros
// - Take a view of the base's forward grad
// - Copy the given new_grad into this view
// - Use this view as the new new_grad
if (this_view_meta->has_fw_view()) {
auto view_info = this_view_meta->get_forward_view();
auto& base = view_info.base_;
if (!base.fw_grad(level).defined()) {
// Enforce same meta here to make sure that the view op below is always valid
Tensor new_base_fw_grad;
if (has_same_meta(new_grad, base)) {
// TODO extend this special case to when the underlying storage of new_grad
// can be re-used.
new_base_fw_grad = new_grad;
} else {
new_base_fw_grad = new_with_same_meta(base);
// Update new_grad to be a view of the base
Tensor new_fw_grad_value;
if (view_info.has_view_fn()) {
new_fw_grad_value = view_info.view_fn()(new_base_fw_grad);
} else {
new_fw_grad_value = new_base_fw_grad.as_strided(self.sizes(), self.strides(), self.storage_offset());
}
new_fw_grad_value.copy_(new_grad);
new_grad = new_fw_grad_value;
}
base.set_fw_grad(new_base_fw_grad, level, /* is_inplace_op */ false);
}
}
}
// Enforce the basic layout constraint
if (!has_same_meta(new_grad, self)) {
Tensor new_grad_with_meta = new_with_same_meta(self);
new_grad_with_meta.copy_(new_grad);
new_grad = new_grad_with_meta;
}
fw_grad_->set_value(new_grad, level);
}
}
const Variable& AutogradMeta::fw_grad(uint64_t level, const Variable& self) const {
// Ensure that concurrent fw_grad() "reads" are thread safe
std::lock_guard<std::mutex> lock(mutex_);
const auto& direct_fw_grad = fw_grad_ ? fw_grad_->value(level) : ForwardGrad::undef_grad();
if (!direct_fw_grad.defined() && is_view_) {
// For views that don't have a forward grad, check if their base has one that
// has been defined by an inplace operation.
// This ensures that case 5 from [Forward Grad View/inplace] above works fine
auto const_view_meta = static_cast<const torch::autograd::DifferentiableViewMeta*>(this);
// This is ok to do as we ONLY modify fw_grad_ and this field is properly locked in all methods
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto this_view_meta = const_cast<torch::autograd::DifferentiableViewMeta*>(const_view_meta);
if (this_view_meta->has_fw_view()) {
const auto& view_info = this_view_meta->get_forward_view();
const auto& base = view_info.base_;
const auto& base_val = base.fw_grad(level);
if (base_val.defined()) {
// Lazy initialization of fw_grad_
this_view_meta->fw_grad_ = std::make_shared<ForwardGrad>();
Variable new_val;
if (view_info.has_view_fn()) {
new_val = view_info.view_fn()(base_val);
} else {
new_val = base_val.as_strided(self.sizes(), self.strides(), self.storage_offset());
}
this_view_meta->fw_grad_->set_value(new_val, level);
return this_view_meta->fw_grad_->value(level);
}
}
}
return direct_fw_grad;
}
}} // torch::autograd
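
The layout constraint described in the [Forward Grad View/inplace] note above (a tangent whose metadata differs from the primal's is copied into a fresh tensor with matching metadata) can be observed from the Python API. This is a hedged sketch assuming this PR's bindings, not an authoritative test:

```python
import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(2, 3)      # contiguous, strides (3, 1)
t = torch.randn(3, 2).t()  # same sizes as x but strides (1, 3)

with fwAD.dual_level():
    dual = fwAD.make_dual(x, t)
    stored = fwAD.unpack_dual(dual)[1]
    assert stored is not t                # metadata differed, so the tangent was copied
    assert stored.stride() == x.stride()  # the stored tangent matches the primal's layout
    assert torch.equal(stored, t)         # but it holds the same values
```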

View File

@ -124,7 +124,7 @@ variable_list _wrap_outputs(const variable_list &input_vars,
if (!(is_input && is_modified) && var.is_view()) {
// NB: is_view() ==> get_autograd_meta()
auto diff_view_meta = static_cast<DifferentiableViewMeta*>(impl::get_autograd_meta(var));
diff_view_meta->creation_meta = CreationMeta::IN_CUSTOM_FUNCTION;
diff_view_meta->set_creation_meta(CreationMeta::IN_CUSTOM_FUNCTION);
}
if (is_differentiable) {
@ -142,7 +142,7 @@ variable_list _wrap_outputs(const variable_list &input_vars,
if (var.is_view()) {
// NB: is_view() ==> get_autograd_meta()
auto diff_view_meta = static_cast<DifferentiableViewMeta*>(impl::get_autograd_meta(var));
diff_view_meta->creation_meta = CreationMeta::MULTI_OUTPUT_NODE;
diff_view_meta->set_creation_meta(CreationMeta::MULTI_OUTPUT_NODE);
}
}
}

View File

@ -0,0 +1,90 @@
#include <torch/csrc/autograd/forward_grad.h>
namespace torch { namespace autograd {
namespace {
// See discussion in forward_grad.h for why these are global variables and not
// thread local
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static std::mutex all_forward_levels_mutex_;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static uint64_t next_forward_idx_ = 0;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static std::vector<std::shared_ptr<ForwardADLevel>> all_forward_levels_;
const static at::Tensor singleton_undefined_tensor;
// Temporary flag to disable forward mode
// TODO(alband) remove these when perf issues are solved
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static bool is_forward_grad_enabled = false;
}
uint64_t ForwardADLevel::get_next_idx() {
std::lock_guard<std::mutex> lock(all_forward_levels_mutex_);
TORCH_CHECK(next_forward_idx_ == 0, "Nested forward mode AD is not supported at the moment");
auto new_index = next_forward_idx_++;
TORCH_INTERNAL_ASSERT(new_index == all_forward_levels_.size());
all_forward_levels_.push_back(std::make_shared<ForwardADLevel>(new_index));
return new_index;
}
void ForwardADLevel::release_idx(uint64_t idx) {
std::lock_guard<std::mutex> lock(all_forward_levels_mutex_);
TORCH_CHECK(idx == all_forward_levels_.size() - 1, "Exiting a forward AD level that is not the "
"last that was created is not support. Ensure they are released in the reverse "
"order they were created.");
TORCH_CHECK(idx >= 0, "No forward AD level was created so you cannot exit it.");
next_forward_idx_--;
all_forward_levels_.pop_back();
}
std::shared_ptr<ForwardADLevel> ForwardADLevel::get_by_idx(uint64_t idx) {
std::lock_guard<std::mutex> lock(all_forward_levels_mutex_);
TORCH_CHECK(idx < all_forward_levels_.size(), "Trying to access a forward AD level with an invalid index. "
"This index was either not created or is already deleted.");
return all_forward_levels_[idx];
}
std::shared_ptr<ForwardADLevel> ForwardADLevel::try_get_by_idx(uint64_t idx) {
std::lock_guard<std::mutex> lock(all_forward_levels_mutex_);
if (idx < all_forward_levels_.size()) {
return all_forward_levels_[idx];
} else {
return nullptr;
}
}
ForwardADLevel::~ForwardADLevel() {
std::lock_guard<std::mutex> lock(mutex_);
auto it = grads_.begin();
while (it != grads_.end()) {
// Warning this will lock *it mutex
// This is ok as this function is the *only* one to call back into another class's method.
(*it)->reset(idx_, /* update_level */ false);
it = grads_.erase(it);
}
}
const at::Tensor& ForwardGrad::value(uint64_t level) const {
std::lock_guard<std::mutex> lock(mutex_);
const auto& it = content_.find(level);
return it == content_.end() ? singleton_undefined_tensor : (*it).second;
}
const at::Tensor& ForwardGrad::undef_grad() {
return singleton_undefined_tensor;
}
// Temporary functions to disable forward AD
// TODO(alband) remove these when perf issues are solved
bool isForwardADEnabled() {
return is_forward_grad_enabled;
}
void setForwardADEnabled(bool value) {
is_forward_grad_enabled = value;
}
}} // namespace torch::autograd
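
A hedged sketch of the level lifecycle implemented above, driven through the lower-level functions in forward_ad.py; it assumes no other forward AD level is currently active in the process:

```python
import torch
import torch.autograd.forward_ad as fwAD

level = fwAD.enter_dual_level()    # registers a new ForwardADLevel and returns its index
assert level == 0                  # first (and, for now, only) allowed level
dual = fwAD.make_dual(torch.randn(2), torch.randn(2), level=level)
assert fwAD.unpack_dual(dual, level=level)[1] is not None

fwAD.exit_dual_level(level=level)  # destroys the level and clears every tangent stored for it
assert fwAD.unpack_dual(dual)[1] is None   # no active level, so no tangent is returned
```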

View File

@ -0,0 +1,193 @@
#pragma once
#include <ATen/ATen.h>
namespace torch { namespace autograd {
// [ Using ForwardGrad ]
// ForwardGrad needs to be a shared_ptr to satisfy constraints of its inner design. But
// this shared_ptr must be uniquely associated with the object that stores it (as of
// writing, either AutogradMeta or SavedVariable). This object is called the "owning object"
// in the discussions below. This owning object must call `ForwardGrad::clear()` when it
// is destroyed to ensure that the ForwardGrad is properly de-allocated.
struct ForwardGrad;
// This file contains two classes that are used to store forward AD gradients and
// ensure that they are scoped properly.
// Because forward AD runs concurrently with the evaluation of the function, we need
// a mechanism to separate different forward AD invocations and be able to compute the
// right gradients. We model such invocations as levels here.
// The particular scoping issue mentioned above has two main drivers:
// - Ensure that we can conveniently use forward AD within a high level API without
// leaking the forward AD states outside.
// - Ensure that we can keep the level that we expose to the user API simple (an integer
// that represents the nesting depth) while avoiding confusion when the level index
// is re-used.
// The important external APIs from this file are:
// - ForwardADLevel::get_next_idx() that can be used to enter a new level and get its index
// - ForwardADLevel::release_idx() that can be used to exit a given level.
// - ForwardGrad() can be used to store a given forward gradient that will handle the level
// tracking automatically.
// The basic implementation strategy is as follows:
// Every tensor has a ForwardGrad, maintaining a map from levels to tangents.
// ForwardGrad is responsible for registering itself to the appropriate ForwardADLevel when a new
// tangent is added to it via ForwardGrad::set_value and to un-register itself from this same level
// if that tangent is removed via ForwardGrad::reset.
// The ForwardADLevel is created when a new level is entered via ForwardADLevel::get_next_idx.
// A reference to the new ForwardADLevel is stored into a global (for the whole process) vector that
// ensures it can be accessed via ForwardADLevel::get_by_idx. This reference is deleted when the index is
// released by the user when calling ForwardADLevel::release_idx.
// When it is destructed, the ForwardADLevel is responsible for clearing all the tangents for its
// level stored in all the ForwardGrad that registered with it.
//
// This process-wide level design, compared to a thread local one, allows us to use a very simple user-facing
// handle for the level (an int) while enabling cross-thread forward AD.
// The only required synchronization for the user is when entering and exiting the levels.
// Some discussion on alternative design is in https://github.com/pytorch/pytorch/pull/49097#discussion_r543716453
// and can be refined in the future.
// Correctness of concurrency:
// Each class uses its own lock when reading or modifying its internal storage. In particular, this makes
// it safe to remove tangents from a ForwardGrad while the ForwardADLevel is being exited.
// We ensure no deadlock by ensuring that a method never calls into another class's method while
// the local class's lock is held except in one single case: calling from ForwardADLevel's destructor
// into ForwardGrad::reset with update_level=false.
// The lifetime of these objects is as follows:
// The ForwardADLevel can be in three states:
// - Initialized: where one of its references is held by the global vector and there may be more
// references held by temporary variables in ForwardGrad's methods.
// - About to be destructed: where "release_idx" has been called and the only reason for the
// ForwardADLevel not to be destructed right away is that some methods in ForwardGrad have
// owning references to it. This is done so that a ForwardADLevel can never be destructed when
// a ForwardGrad is registered with it and in the process of adding something to its internal state.
// - Being destructed: Here the ForwardADLevel is not referenced anymore and can safely reset
// all of the ForwardGrads. Note that we can have more than one reset being called here (which is ok)
// but we are guaranteed that there is at least one.
// The ForwardGrad is simpler as there is no intermediary state and no special destructor for it. The logic to
// unregister it from the different ForwardADLevel is done when the owning object (AutogradMeta or
// SavedVariable) is being destroyed.
// Other design considered:
// To avoid needing ForwardGrad::clear, we considered storing weak_ptr inside the ForwardADLevel. While this
// would work, it would mean that the set inside the ForwardADLevel would only grow unless we do an
// expensive linear scan to remove all the dangling weak pointers. Hence this approach was not used.
// Data structures in this file are optimized for this maximum number of levels.
// The number of levels corresponds to the degree of the gradient being
// computed using forward AD and we don't expect more than second order gradients
// to be common.
#define EXPECTED_MAX_LEVEL 2
struct TORCH_API ForwardADLevel {
ForwardADLevel(uint64_t idx): idx_(idx) {}
~ForwardADLevel();
static uint64_t get_next_idx();
static void release_idx(uint64_t idx);
static std::shared_ptr<ForwardADLevel> get_by_idx(uint64_t idx);
static std::shared_ptr<ForwardADLevel> try_get_by_idx(uint64_t idx);
void erase(const std::shared_ptr<ForwardGrad>& grad) {
std::lock_guard<std::mutex> lock(mutex_);
grads_.erase(grad);
}
void insert(const std::shared_ptr<ForwardGrad>& grad) {
std::lock_guard<std::mutex> lock(mutex_);
grads_.insert(grad);
}
private:
std::unordered_set<std::shared_ptr<ForwardGrad>> grads_;
std::mutex mutex_;
uint64_t idx_;
};
struct TORCH_API ForwardGrad : std::enable_shared_from_this<ForwardGrad> {
ForwardGrad() {}
// This function must only be called when AutogradMeta or SavedVariable is being
// destructed, as that guarantees that:
// - The only (potential) other references to this ForwardGrad are the
// different levels it is registered with
// - No other thread will try to call `set_value` or `value` ever from now on
// - Any of the ForwardADLevels that this ForwardGrad is registered with might
// call `reset` at any point during this function
void clear() {
c10::SmallVector<uint64_t, EXPECTED_MAX_LEVEL> levels_idx;
{
std::lock_guard<std::mutex> lock(mutex_);
for (auto& c: content_) {
levels_idx.push_back(c.first);
}
}
for (auto l_idx: levels_idx) {
// Use "try" version here as another thread might have deleted this
// level before we got here
// This is an owning reference as we want to keep the level alive
// until we successfully unregister ourselves
auto level = ForwardADLevel::try_get_by_idx(l_idx);
if (level) {
level->erase(shared_from_this());
}
}
}
void set_value(const at::Tensor& value, uint64_t level) {
// Owning reference to ensure the forward_level is not destroyed
// while we are updating our internal state
auto forward_level = ForwardADLevel::get_by_idx(level);
forward_level->insert(shared_from_this());
std::lock_guard<std::mutex> lock(mutex_);
content_.insert({level, value});
}
// This function removes the tangent for a given level from this ForwardGrad
// Use the update_level flag to disable notifying the level about this reset
// This flag is most notably used by the ForwardADLevel destructor.
void reset(uint64_t level, bool update_level=true) {
if (update_level) {
ForwardADLevel::get_by_idx(level)->erase(shared_from_this());
}
std::lock_guard<std::mutex> lock(mutex_);
content_.erase(level);
}
const at::Tensor& value(uint64_t level) const;
bool contains(uint64_t level) {
std::lock_guard<std::mutex> lock(mutex_);
return content_.count(level) > 0;
}
bool empty() const {
return content_.empty();
}
static const at::Tensor& undef_grad();
private:
// TODO(albanD): replace this with a SmallVector
std::unordered_map<uint64_t, at::Tensor> content_;
mutable std::mutex mutex_;
};
// Temporary functions to disable forward AD
// TODO(alband) remove these when perf issues are solved
bool TORCH_API isForwardADEnabled();
void TORCH_API setForwardADEnabled(bool value);
}} // namespace torch::autograd
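
To make the note [ Using ForwardGrad ] above more concrete, here is a rough sketch of the owning-object pattern it describes. `ToyOwner` is a hypothetical stand-in for AutogradMeta/SavedVariable (not part of the diff), and the sketch assumes an active level 0, i.e. `ForwardADLevel::get_next_idx()` has been called and not yet released.

```cpp
#include <memory>
#include <ATen/ATen.h>
#include <torch/csrc/autograd/forward_grad.h>

using namespace torch::autograd;

// Hypothetical stand-in for the real owning objects (AutogradMeta / SavedVariable).
struct ToyOwner {
  std::shared_ptr<ForwardGrad> fw_grad_;

  void set_tangent(const at::Tensor& tangent, uint64_t level) {
    if (!fw_grad_) {
      // Lazily populated; exactly one owner per ForwardGrad.
      fw_grad_ = std::make_shared<ForwardGrad>();
    }
    // Registers fw_grad_ with the ForwardADLevel for `level`.
    fw_grad_->set_value(tangent, level);
  }

  const at::Tensor& tangent(uint64_t level) const {
    // Falls back to the singleton undefined tensor when nothing is set.
    return fw_grad_ ? fw_grad_->value(level) : ForwardGrad::undef_grad();
  }

  ~ToyOwner() {
    if (fw_grad_) {
      // See note [ Using ForwardGrad ]: un-register from all levels.
      fw_grad_->clear();
    }
  }
};
```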

View File

@ -47,4 +47,8 @@ auto UndefinedGradBackward::apply(variable_list&& output_grads) -> variable_list
return input_grads;
}
auto Identity::apply(variable_list&& grads) -> variable_list {
return std::move(grads);
}
}} // namespace torch::autograd

View File

@ -83,4 +83,8 @@ struct TORCH_API GraphRoot : public Node {
variable_list outputs;
};
struct TORCH_API Identity : public Node {
variable_list apply(variable_list&& inputs) override;
};
}}

View File

@ -3,11 +3,15 @@
#include <c10/core/DeviceType.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <ATen/autocast_mode.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/autograd/utils/python_arg_parsing.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
using namespace torch::autograd::profiler;
@ -230,6 +234,26 @@ static PyObject * autocast_decrement_nesting(PyObject* _unused, PyObject *arg) {
END_HANDLE_TH_ERRORS
}
static PyObject * set_forward_AD_enabled(PyObject* _unused, PyObject *arg) {
HANDLE_TH_ERRORS
if (!PyBool_Check(arg)) {
throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
}
setForwardADEnabled(arg == Py_True);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject * is_forward_AD_enabled(PyObject* _unused, PyObject *arg) {
HANDLE_TH_ERRORS
if (isForwardADEnabled()) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
END_HANDLE_TH_ERRORS
}
static PyObject * set_grad_enabled(PyObject* _unused, PyObject *arg) {
HANDLE_TH_ERRORS
if (!PyBool_Check(arg)) {
@ -270,10 +294,34 @@ static PyObject * is_anomaly_mode_enabled(PyObject* _unused, PyObject *arg) {
END_HANDLE_TH_ERRORS
}
static PyObject * python_enter_dual_level(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
// It is unlikely that the depth of forward nesting will overflow int64_t so we
// just static cast here.
return utils::wrap(static_cast<int64_t>(forward_ad::enter_dual_level()));
END_HANDLE_TH_ERRORS
}
static PyObject * python_exit_dual_level(PyObject* _unused, PyObject* args, PyObject* kwargs) {
HANDLE_TH_ERRORS
static PythonArgParser parser({
"exit_dual_level(int64_t level)"
});
ParsedArgs<1> parsed_args;
auto _r = parser.parse(args, kwargs, parsed_args);
forward_ad::exit_dual_level(_r.toInt64(0));
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
// autograd methods on torch._C
static PyMethodDef methods[] = { // NOLINT
{"_set_grad_enabled", set_grad_enabled, METH_O, nullptr},
{"is_grad_enabled", is_grad_enabled, METH_NOARGS, nullptr},
{"_set_forward_AD_enabled", set_forward_AD_enabled, METH_O, nullptr},
{"_is_forward_AD_enabled", is_forward_AD_enabled, METH_NOARGS, nullptr},
{"set_autocast_enabled", set_autocast_enabled, METH_O, nullptr},
{"is_autocast_enabled", is_autocast_enabled, METH_NOARGS, nullptr},
{"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr},
@ -281,6 +329,8 @@ static PyMethodDef methods[] = { // NOLINT
{"autocast_decrement_nesting", autocast_decrement_nesting, METH_NOARGS, nullptr},
{"set_anomaly_enabled", set_anomaly_mode_enabled, METH_O, nullptr},
{"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr},
{"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
{"_exit_dual_level", castPyCFunctionWithKeywords(python_exit_dual_level), METH_VARARGS | METH_KEYWORDS, nullptr},
{nullptr, nullptr, 0, nullptr}
};
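
The two private bindings above (`_enter_dual_level` / `_exit_dual_level`) simply forward to the C++ helpers in the `forward_ad` namespace. Below is a hedged sketch of that round trip, assuming the helpers declared in torch/csrc/autograd/autograd.h take and return an integer level index as the casts above suggest; the wrapper function name is made up.

```cpp
#include <torch/csrc/autograd/autograd.h>

void dual_level_round_trip() {
  // Mirrors torch._C._enter_dual_level(): opens the next level and returns its index.
  auto level = torch::autograd::forward_ad::enter_dual_level();

  // ... run the forward computation with dual Tensors here ...

  // Mirrors torch._C._exit_dual_level(level=...): closes the level, which also
  // clears every tangent that was registered at that level.
  torch::autograd::forward_ad::exit_dual_level(level);
}
```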

View File

@ -24,6 +24,12 @@ SavedVariable::SavedVariable(const Variable& variable, bool is_output, bool is_i
// These copies are all shared_ptr copies, so slightly more expensive.
// Do them here instead of in the init list in case data is undefined.
data_ = variable.tensor_data();
// TODO(albanD) This needs to be updated when moving to multiple levels
const auto& fw_grad = variable.fw_grad(/* level */ 0);
if (fw_grad.defined()) {
fw_grad_ = std::make_shared<ForwardGrad>();
fw_grad_->set_value(fw_grad, /* level */ 0);
}
if (variable.is_leaf()) {
grad_accumulator_ = impl::grad_accumulator(variable);
} else if (!is_output) {
@ -100,6 +106,16 @@ Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
throw std::logic_error("No grad accumulator for a saved leaf!");
impl::set_grad_accumulator(var, grad_accumulator_);
// NB: var here is never a view so there is no need to do anything special
// for the case where the saved Tensor was a view. This whole argument relies
// on the fact that the Tensor returned by this function is never
// modified in-place.
if (fw_grad_ && !fw_grad_->empty()) {
// TODO(albanD) This needs to be updated when moving to multiple levels
auto new_fw_grad = fw_grad_->value(/* level */ 0);
var.set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ false);
}
return var;
}
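
For context, the pack/unpack code above relies on the per-Tensor accessors `fw_grad` / `set_fw_grad` added in this stack. A small hedged sketch of that round trip, assuming an active level 0; `t` and `tangent` are placeholder tensors of matching shape and the function name is made up.

```cpp
#include <torch/csrc/autograd/variable.h>

// Placeholder helper: `t` and `tangent` must have the same shape, and level 0
// must currently be open.
void tangent_round_trip(torch::autograd::Variable t, torch::autograd::Variable tangent) {
  // Attach a tangent at level 0 (this write is not part of an in-place op).
  t.set_fw_grad(tangent, /* level */ 0, /* is_inplace_op */ false);

  // Read it back; an undefined Tensor comes back when no tangent is set.
  const at::Tensor& fw = t.fw_grad(/* level */ 0);
  TORCH_INTERNAL_ASSERT(fw.defined());
}
```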

View File

@ -1,6 +1,7 @@
#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/autograd/forward_grad.h>
#include <ATen/ATen.h>
@ -23,6 +24,12 @@ class TORCH_API SavedVariable {
SavedVariable(const c10::optional<Variable>& variable, bool is_output, bool is_inplace_view=false);
SavedVariable(SavedVariable&&) = default;
SavedVariable& operator=(SavedVariable&&) = default;
~SavedVariable() {
if (fw_grad_) {
// See note [ Using ForwardGrad ]
fw_grad_->clear();
}
}
/// Reconstructs the saved variable. Pass `saved_for` as the gradient
/// function if constructing the `SavedVariable` with it would have caused a
@ -40,6 +47,11 @@ class TORCH_API SavedVariable {
private:
at::Tensor data_;
// This field is used to store the forward AD gradients associated with
// the saved Tensor. Note that this shared_ptr must never be shared with
// either the saved Tensor or the unpacked Tensor. See note [ Using ForwardGrad ]
std::shared_ptr<ForwardGrad> fw_grad_;
// The gradient function associated with this node. If has_grad_fn
// is false, then this is a leaf node. Note that the grad_fn is not saved if
// it would create a circular reference. In that case, the grad_fn must be

View File

@ -11,6 +11,7 @@
#include <ATen/core/VariableHooksInterface.h>
#include <ATen/ATen.h>
#include <ATen/MemoryOverlap.h>
#include <c10/util/Exception.h>
#include <list>
@ -20,28 +21,83 @@
#include <string>
#include <vector>
#include <typeinfo>
#include <iostream>
namespace torch {
namespace autograd {
DifferentiableViewMeta::DifferentiableViewMeta(at::TensorImpl* self_impl, Variable base,
c10::optional<std::function<at::Tensor(const at::Tensor&)>> view_fn,
DifferentiableViewMeta::DifferentiableViewMeta(at::TensorImpl* self_impl,
c10::optional<ViewInfo> backward_info,
c10::optional<ViewInfo> forward_info,
CreationMeta creation_meta)
: AutogradMeta(self_impl), creation_meta(creation_meta) {
base_ = std::move(base);
view_fn_ = std::move(view_fn);
TORCH_CHECK(base_.defined(), "base is undefined");
if (base_.is_view()) {
base_ = base_._base();
}
: AutogradMeta(self_impl),
backward_info_(std::move(backward_info)),
forward_info_(std::move(forward_info)),
creation_meta(creation_meta) {
is_view_ = true;
self_impl->set_version_counter(impl::version_counter(base_));
attr_version = self_impl->version_counter().current_version();
if (backward_info_.has_value()) {
self_impl->set_version_counter(impl::version_counter(backward_info_.value().base_));
attr_version = self_impl->version_counter().current_version();
}
}
DifferentiableViewMeta::~DifferentiableViewMeta() {
base_.reset();
// Chain this view info with the new view op between base and tensor
ViewInfo ViewInfo::chain(const Variable & base, const Variable & tensor,
c10::optional<std::function<Variable(const Variable&)>> view_func) const {
// Set `view_func` using the root base as input.
// `view_func` is used to recover views in backward when either as_strided is not supported
// or the view function changes the metadata which is not recorded by as_strided
// See Note [View + Inplace update on base tensor] and [View + Inplace update on view tensor]
// for more details how we use this function in backward.
if (view_func.has_value()) {
auto fn = view_func.value();
// both current_view and its parent have a view_func
if (view_fn_.has_value()) {
auto prev_fn = view_fn_.value();
view_func = [=](const at::Tensor& root_base) {
auto temp = prev_fn(root_base);
return fn(temp);
};
} else {
// current_view has a view_func but its parent doesn't have one
if (base.unsafeGetTensorImpl()->support_as_strided()) {
auto size = base.sizes().vec();
auto stride = base.strides().vec();
auto storage_offset = base.storage_offset();
view_func = [=](const at::Tensor& root_base) {
auto temp = root_base.as_strided(size, stride, storage_offset);
return fn(temp);
};
} else {
// When base is a view but doesn't carry a view_fn in DifferentiableViewMeta, it's
// a view that doesn't support inplace update, e.g. unbind.
// In this case we should throw an error when inplace update happens in **forward**.
// One would naturally think the following function will first be called in the backward pass.
// But the first call site is actually in the **forward** pass, when we refresh `grad_fn`
// triggered by an inplace update.
// Search Note [View + Inplace update for view tensor] for the call site.
view_func = [=](const at::Tensor& root_base) {
TORCH_CHECK(false, "This view is the output of a function that returns multiple views."
"Such functions do not allow the output views to be modified inplace."
"You should replace the inplace operation by an out-of-place one");
return root_base;
};
}
}
} else if(view_fn_.has_value()) {
// if current_view doesn't have a view_func but it's parent has one
auto prev_view_fn = view_fn_.value();
auto size = tensor.sizes().vec();
auto stride = tensor.strides().vec();
auto storage_offset = tensor.storage_offset();
view_func = [=](const at::Tensor& root_base) {
auto temp = prev_view_fn(root_base);
return temp.as_strided(size, stride, storage_offset);
};
}
return ViewInfo(base_, view_func);
}
namespace {
@ -81,21 +137,23 @@ namespace impl {
auto diff_view_meta = static_cast<DifferentiableViewMeta*>(get_autograd_meta(self));
// See NOTE [ View + Inplace detection ]
if (diff_view_meta->creation_meta != CreationMeta::MULTI_OUTPUT_SAFE) {
auto creation_meta = diff_view_meta->get_creation_meta();
if (creation_meta != CreationMeta::MULTI_OUTPUT_SAFE) {
// Do not use handle_view_on_rebase here as check_inplace should have been called before this
// and either throw an error or clear the warning
// Temporary error message as a full fix is too risky for now
// Should be an internal assert again
TORCH_INTERNAL_ASSERT(diff_view_meta->creation_meta == CreationMeta::DEFAULT);
TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::DEFAULT);
TORCH_INTERNAL_ASSERT(gradient_edge.input_nr == 0);
TORCH_INTERNAL_ASSERT(gradient_edge.function);
TORCH_CHECK(
gradient_edge.function->num_inputs() == 1,
"Functions which modify views in-place must return a single Variable");
auto view_info = diff_view_meta->get_backward_view();
diff_view_meta->output_nr_ = gradient_edge.input_nr;
auto copy_slices = std::make_shared<CopySlices>(
diff_view_meta->base_, at::TensorGeometry(self), diff_view_meta->view_fn_, std::move(gradient_edge.function));
set_gradient_edge(diff_view_meta->base_, {std::move(copy_slices), 0});
view_info.base_, at::TensorGeometry(self), view_info.view_fn_, std::move(gradient_edge.function));
set_gradient_edge(view_info.base_, {std::move(copy_slices), 0});
self.grad_fn(); // trigger an update to the view's grad_fn
return;
}
@ -181,7 +239,7 @@ namespace impl {
if (self.is_view()) {
// NB: is_view() ==> get_autograd_meta()
auto diff_view_meta = static_cast<torch::autograd::DifferentiableViewMeta*>(meta);
diff_view_meta->attr_version = self._version();
diff_view_meta->set_attr_version(self._version());
}
}
@ -298,12 +356,14 @@ Tensor VariableHooks::tensor_data(const Tensor& self) const {
return at::Tensor(self_impl_copy);
}
// View Variables
// Backward View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
bool VariableHooks::is_view(const Tensor& self) const {
if (torch::autograd::impl::get_autograd_meta(self)) {
return torch::autograd::impl::get_autograd_meta(self)->is_view_;
auto meta = torch::autograd::impl::get_autograd_meta(self);
if (meta && meta->is_view_) {
auto diff_view_meta = static_cast<torch::autograd::DifferentiableViewMeta*>(meta);
return diff_view_meta->has_bw_view();
} else {
return false;
}
@ -313,9 +373,10 @@ const Tensor& VariableHooks::base(const Tensor& self) const {
if (self.is_view()) {
// is_view() implies get_autograd_meta()
auto diff_view_meta = static_cast<torch::autograd::DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(self));
return diff_view_meta->base_;
TORCH_CHECK(diff_view_meta->has_bw_view(), "Can't get base of non-backward view Tensor");
return diff_view_meta->get_backward_view().base_;
} else {
throw std::runtime_error("Can't get base of non-view Variable");
throw std::runtime_error("Can't get base of non-view Tensor");
}
}
@ -342,13 +403,14 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(const Tenso
auto diff_view_meta = static_cast<torch::autograd::DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(self));
// See NOTE [ View + Inplace detection ]
if (diff_view_meta->creation_meta != CreationMeta::MULTI_OUTPUT_SAFE) {
if (diff_view_meta->get_creation_meta() != CreationMeta::MULTI_OUTPUT_SAFE) {
std::lock_guard<std::mutex> lock(diff_view_meta->mutex_);
if (!diff_view_meta->grad_fn_ && !diff_view_meta->base_.requires_grad()) {
auto view_info = diff_view_meta->get_backward_view();
if (!diff_view_meta->grad_fn_ && !view_info.base_.requires_grad()) {
return diff_view_meta->grad_fn_;
}
auto current_version = self._version();
if (diff_view_meta->attr_version != current_version) {
if (diff_view_meta->get_attr_version() != current_version) {
// This is an indirect rebase_history due to another view or the base being modified inplace
handle_view_on_rebase(diff_view_meta, /* indirect */ true);
TORCH_INTERNAL_ASSERT(diff_view_meta->output_nr_ == 0);
@ -377,24 +439,24 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(const Tenso
//
// TODO: Potentially the following logic can be replaced by special logic in VariableType_x.cpp
// that would provide a way to recreate the grad_fn chain.
if (diff_view_meta->has_view_fn()) {
auto view_fn = diff_view_meta->view_fn();
auto diff_view = view_fn(diff_view_meta->base_);
if (view_info.has_view_fn()) {
auto view_fn = view_info.view_fn();
auto diff_view = view_fn(view_info.base_);
diff_view_meta->grad_fn_ = diff_view.grad_fn();
} else {
auto fn = std::make_shared<torch::autograd::generated::AsStridedBackward>();
fn->self_geometry = at::TensorGeometry(diff_view_meta->base_);
fn->self_geometry = at::TensorGeometry(view_info.base_);
fn->size = self.sizes().vec();
fn->stride = self.strides().vec();
fn->storage_offset = self.storage_offset();
fn->set_next_edges(torch::autograd::collect_next_edges(diff_view_meta->base_));
fn->set_next_edges(torch::autograd::collect_next_edges(view_info.base_));
fn->add_input_metadata(
diff_view_meta->base_.options(),
view_info.base_.options(),
self.sizes(), // Note: sizes(), not base_.sizes(), is intentional
diff_view_meta->base_.device());
view_info.base_.device());
diff_view_meta->grad_fn_ = std::move(fn);
}
diff_view_meta->attr_version = current_version;
diff_view_meta->set_attr_version(current_version);
}
return diff_view_meta->grad_fn_;
}
@ -429,7 +491,8 @@ unsigned VariableHooks::_register_hook(const Tensor& self, std::function<Tensor(
void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect) {
/// See NOTE [ View + Inplace detection ] for justification of the logic below
if (diff_view_meta->creation_meta != CreationMeta::DEFAULT) {
auto creation_meta = diff_view_meta->get_creation_meta();
if (creation_meta != CreationMeta::DEFAULT) {
auto grad_fn = diff_view_meta->grad_fn_.get();
std::string msg;
std::string modified_obj;
@ -446,24 +509,24 @@ void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect
msg = c10::str("A view was created in no_grad mode and ", modified_obj, " modified inplace with grad mode enabled.");
}
if (diff_view_meta->creation_meta == CreationMeta::MULTI_OUTPUT_NODE) {
if (creation_meta == CreationMeta::MULTI_OUTPUT_NODE) {
TORCH_CHECK(false, msg, " This view is the output of a function that returns multiple views. Such functions do not"
" allow the output views to be modified inplace. You should replace the inplace operation by an"
" out-of-place one.");
} else {
if (diff_view_meta->creation_meta == CreationMeta::NO_GRAD_MODE) {
if (creation_meta == CreationMeta::NO_GRAD_MODE) {
TORCH_INTERNAL_ASSERT(!grad_fn);
msg = c10::str(msg, " Given that this use case is ambiguous and error-prone, it is deprecated and will be forbidden"
" starting 1.6 (see https://github.com/pytorch/pytorch/pull/32839 for more details about this). You"
" can clarify your code and remove this warning by moving both the view and the inplace either both"
" inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want"
" the inplace to be tracked).");
} else if (diff_view_meta->creation_meta == CreationMeta::IN_CUSTOM_FUNCTION) {
} else if (creation_meta == CreationMeta::IN_CUSTOM_FUNCTION) {
msg = c10::str(msg, " This view was created inside a custom Function (or because an input was returned as-is) and the"
" autograd logic to handle view+inplace would override the custom backward associated with the custom"
" Function, leading to incorrect gradients. This behavior is deprecated and will be forbidden starting"
" version 1.6. You can remove this warning by cloning the output of the custom Function.");
} else if (diff_view_meta->creation_meta == CreationMeta::MULTI_OUTPUT_SAFE) {
} else if (creation_meta == CreationMeta::MULTI_OUTPUT_SAFE) {
msg = c10::str(msg, " This view is an output of a function that "
"returns multiple views. Inplace operators on such "
"views are being deprecated and will be forbidden "
@ -487,8 +550,10 @@ void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect
// We warn only once per view
// Note that if a Tensor is modified inplace from two threads at the same time, this is not thread safe and can warn
// multiple times. This is ok as it should be a rare event.
diff_view_meta->creation_meta = CreationMeta::DEFAULT;
diff_view_meta->set_creation_meta(CreationMeta::DEFAULT);
}
}
}} // namespace torch::autograd
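
A hedged sketch (not in the diff) of how `ViewInfo::chain` above composes view functions across two successive view ops. The helper name and the choice of `unbind`/`select` are illustrative only, and `root` is assumed to be at least 2x2 so the indexing is valid; the real call sites live in the updated view-handling code.

```cpp
#include <functional>
#include <torch/csrc/autograd/variable.h>

using torch::autograd::Variable;
using torch::autograd::ViewInfo;

// Illustrative only: `root` is assumed to be (at least) a 2x2 Tensor.
ViewInfo chain_two_views(const Variable& root) {
  // First view op: root -> v1. unbind cannot be recovered with as_strided,
  // so an explicit view function is recorded.
  std::function<Variable(const Variable&)> take_first =
      [](const Variable& r) { return r.unbind(0)[0]; };
  Variable v1 = take_first(root);
  ViewInfo info1(root, take_first);

  // Second view op: v1 -> v2. chain() keeps `root` as base_ and returns a
  // view_fn_ that replays both ops starting from that root base.
  std::function<Variable(const Variable&)> take_second =
      [](const Variable& v) { return v.select(0, 1); };
  Variable v2 = take_second(v1);
  return info1.chain(v1, v2, take_second);
}
```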

View File

@ -6,6 +6,7 @@
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/autograd/cpp_hook.h>
#include <torch/csrc/autograd/forward_grad.h>
#include <ATen/ATen.h>
#include <ATen/NamedTensorUtils.h>
@ -193,6 +194,17 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
std::shared_ptr<Node> grad_fn_;
std::weak_ptr<Node> grad_accumulator_;
// This field is used to store all the forward AD gradients
// associated with this AutogradMeta (and the Tensor it corresponds to)
// There is a semantic 1:1 correspondence between AutogradMeta and
// ForwardGrad but:
// - This field is lazily populated.
// - This field is a shared_ptr but it must never be
// shared by multiple Tensors. See Note [ Using ForwardGrad ]
// Any transition from not_initialized to initialized
// must be protected by mutex_
std::shared_ptr<ForwardGrad> fw_grad_;
std::vector<std::shared_ptr<FunctionPreHook>> hooks_;
std::shared_ptr<hooks_list> cpp_hooks_list;
@ -211,9 +223,11 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
uint32_t output_nr_;
// Mutex to ensure that concurrent read operations that modify internal
// state are still thread-safe. Used by grad_fn() and
// grad_accumulator().
std::mutex mutex_;
// state are still thread-safe. Used by grad_fn(), grad_accumulator(),
// fw_grad() and set_fw_grad()
// This is mutable because we need to be able to acquire this from the const
// version of this class for the functions above
mutable std::mutex mutex_;
/// Sets the `requires_grad` property of `Variable`. This should be true for
/// leaf variables that want to accumulate gradients, and false for all other
@ -238,6 +252,10 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
return grad_;
}
const Variable& fw_grad(uint64_t level, const Variable& self) const override;
void set_fw_grad(const Variable& new_grad, const Variable& self, uint64_t level, bool is_inplace_op) override;
AutogradMeta(at::TensorImpl* self_impl = nullptr, bool requires_grad = false, Edge gradient_edge = Edge() ) {
grad_fn_ = std::move(gradient_edge.function);
requires_grad_ = false;
@ -254,6 +272,55 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
!grad_fn_ || !requires_grad_,
"requires_grad should be false if grad_fn is set");
}
~AutogradMeta() override {
// If AutogradMeta is being destroyed, it means that there is no other reference to its
// corresponding Tensor. It implies that no other thread can be using this object and so there is
// no need to lock mutex_ here to guard the check if fw_grad_ is populated.
if (fw_grad_) {
// See note [ Using ForwardGrad ]
fw_grad_->clear();
}
}
};
struct TORCH_API ViewInfo {
/// The base `Variable`
/// If this ViewInfo represents a forward (respectively backward) AD gradient,
/// then this Tensor cannot be a forward (respectively backward) view.
Variable base_;
/// By default we use as_strided to recover views which is more efficient.
/// view_fn is only saved when as_strided is not supported.
/// If view_fn has value, we use it to recover views in backward.
c10::optional<std::function<Variable(const Variable&)>> view_fn_;
/// Accessors for the view function
bool has_view_fn() const {
return view_fn_.has_value();
}
std::function<Variable(const Variable&)> view_fn() const {
TORCH_CHECK(has_view_fn(), "Can only access the view function if it exists.");
return view_fn_.value();
}
/// The chain function can be used to build a new ViewInfo for a differentiable view
/// function. It will return a new view info that accurately represents how "tensor" is
/// a view of this instance's "base_".
/// The "base" and "tensor" are respectively the input and output of the differentiable
/// view function that happened. They are required to properly set the optional
/// view_fn_ when it is not provided.
/// The "view_func", if provided, should be a function that allows to re-do the view
/// between "base" and "tensor".
ViewInfo chain(const Variable & base, const Variable & tensor,
c10::optional<std::function<Variable(const Variable&)>> view_func=c10::nullopt) const;
ViewInfo(Variable base, c10::optional<std::function<Variable(const Variable&)>> view_fn) :
base_(std::move(base)),
view_fn_(std::move(view_fn)) {
TORCH_CHECK(base_.defined(), "base is undefined");
}
};
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -274,6 +341,27 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
///
/// Differentiable Views
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// This class allows us to track both forward and backward AD differentiable views.
/// These views can have different bases because the set of non-differentiable view
/// functions is not the same for forward and backward mode AD.
///
/// Most functions are either both forward and backward differentiable views (for
/// example: view, select, narrow, transpose, etc) or neither forward nor
/// backward differentiable views (for example: indices, values, eq, lt, etc).
/// But there are also functions that are forward but not backward differentiable
/// views (only detach for now) or backward but not forward differentiable
/// views (only make_dual and unpack_dual for now).
///
/// A concrete example of two views with different bases is as follows:
///
/// # Have:
/// # dual is a dual Tensor that is neither a forward nor a backward view
/// detached_dual = dual.detach()
/// view = detached_dual.view_as(dual)
/// # The forward base of view is dual
/// # The backward base of view is detached_dual
///
/// - Backward Mode View
/// Differentiable views are the view variables where you want gradients to flow
/// back to the base variables. Out-of-place operations on views are quite
/// straightforward, but in-place ones are very tricky. Even if the base
@ -300,6 +388,34 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
/// var[1] filled with all ones and
/// zeros everywhere else
///
/// - Forward Mode View
/// Forward differentiable views follow the same semantics as backward ones but
/// show up differently as they are computed along with the forward evaluation.
/// The hard examples above are thus very similar
///
/// (1) in-place operation on view, e.g.,
///
/// # Have:
/// # base is a regular Tensor
/// # var is a dual Tensor whose tangent is all ones
/// base[1] = var # i.e., base[1].copy_(var)
/// # Now, base is a dual Tensor
/// _, fw_grad = fwAD.unpack_dual(base) <- fw_grad should be a tensor with
/// fw_grad[1] filled with all ones and
/// zeros everywhere else
///
/// (2) in-place operation on base after view is created, e.g.,
///
/// # Have:
/// # base is a regular Tensor
/// # var is a dual Tensor whose tangent is all ones
/// view = base[1]
/// base.copy_(var)
/// _, fw_grad = fwAD.unpack_dual(view) <- fw_grad should be an all ones tensor
///
/// See Note [Forward Grad View/inplace] for more details on how we handle these hard cases.
///
///
/// DifferentiableViewMeta is created to support gradient tracking of
/// such **in-place** operations. In particular,
/// + if an in-place op is done on base, the grad_fn field of the view may
@ -392,37 +508,66 @@ enum class CreationMeta: uint8_t { DEFAULT, IN_CUSTOM_FUNCTION, MULTI_OUTPUT_NOD
TORCH_API void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect=false);
struct TORCH_API DifferentiableViewMeta : public AutogradMeta {
/// The base `Variable` (never a view).
Variable base_;
private:
/// Information about the views
c10::optional<ViewInfo> backward_info_;
c10::optional<ViewInfo> forward_info_;
/// The two following fields are extra information that we track to ensure that
/// any operation on this backward view is valid.
/// The value of the version_counter at the time grad_fn was created. The
/// grad_fn field is stale if attr_version !=
/// version_counter.current_version().
/// grad_fn field is stale if attr_version != version_counter.current_version().
uint32_t attr_version;
/// By default we use as_strided to recover views which is more efficient.
/// view_fn is only saved when as_strided is not supported.
/// If view_fn has value, we use it to recover views in backward.
c10::optional<std::function<at::Tensor(const at::Tensor&)>> view_fn_;
CreationMeta creation_meta;
public:
/// requires_grad is a backward AD field so we only use the view specific logic
/// for backward differentiable views
bool requires_grad() const override {
return requires_grad_ || grad_fn_ || (is_view_ && base_.requires_grad());
return requires_grad_ || grad_fn_ || (has_bw_view() && get_backward_view().base_.requires_grad());
}
bool has_view_fn() const {
return view_fn_.has_value();
bool has_bw_view() const {
return backward_info_.has_value();
}
std::function<at::Tensor(const at::Tensor&)> view_fn() const {
TORCH_CHECK(has_view_fn(), "view_fn is not set.");
return view_fn_.value();
const ViewInfo& get_backward_view() const {
TORCH_CHECK(has_bw_view(), "backward view info can only exist for backward views.");
return backward_info_.value();
}
DifferentiableViewMeta(at::TensorImpl* self_impl, Variable base, c10::optional<std::function<at::Tensor(const at::Tensor&)>> view_fn,
CreationMeta creation_meta=CreationMeta::DEFAULT);
~DifferentiableViewMeta();
uint32_t get_attr_version() const {
TORCH_CHECK(has_bw_view(), "attr_version can only exist for backward views.");
return attr_version;
}
void set_attr_version(uint32_t new_attr_version) {
TORCH_CHECK(has_bw_view(), "attr_version can only exist for backward views.");
attr_version = new_attr_version;
}
CreationMeta get_creation_meta() const {
TORCH_CHECK(has_bw_view(), "creation_meta can only exist for backward views.");
return creation_meta;
}
void set_creation_meta(CreationMeta new_creation_meta) {
TORCH_CHECK(has_bw_view(), "creation_meta can only exist for backward views.");
creation_meta = new_creation_meta;
}
bool has_fw_view() const {
return forward_info_.has_value();
}
const ViewInfo& get_forward_view() const {
TORCH_CHECK(has_fw_view(), "forward view info can only exist for forward views.");
return forward_info_.value();
}
DifferentiableViewMeta(at::TensorImpl* self_impl, c10::optional<ViewInfo> backward_info,
c10::optional<ViewInfo> forward_info, CreationMeta creation_meta=CreationMeta::DEFAULT);
};
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -448,10 +593,11 @@ struct TORCH_API DifferentiableViewMeta : public AutogradMeta {
// See NOTE [ Autograd View Variables ] for details.
// Differentiable view. Track history with DifferentiableViewMeta.
inline Variable make_variable_differentiable_view(
Variable base,
const at::Tensor& data,
c10::optional<ViewInfo> backward_info,
c10::optional<ViewInfo> forward_info,
CreationMeta creation_meta,
c10::optional<std::function<at::Tensor(const at::Tensor&)>> view_func = c10::nullopt) {
bool allow_tensor_metadata_change = true) {
if (data.defined()) {
// If we already did a TensorImpl allocation for data, just reuse it.
// Otherwise(e.g tensor.swapdim(0, 0) when we return the same tensor as input),
@ -461,14 +607,16 @@ inline Variable make_variable_differentiable_view(
if (data.getIntrusivePtr().unique() && data.getIntrusivePtr()->unique_version()) {
at::TensorImpl* data_impl = data.unsafeGetTensorImpl();
data_impl->set_autograd_meta(std::make_unique<DifferentiableViewMeta>(
data_impl, std::move(base), std::move(view_func), creation_meta));
data_impl, std::move(backward_info), std::move(forward_info),
creation_meta));
return data;
} else {
c10::intrusive_ptr<at::TensorImpl> data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
/*version_counter=*/0,
/*allow_tensor_metadata_change=*/true);
data_impl_copy->set_autograd_meta(std::make_unique<DifferentiableViewMeta>(
data_impl_copy.get(), std::move(base), std::move(view_func), creation_meta));
data_impl_copy.get(), std::move(backward_info), std::move(forward_info),
creation_meta));
return Variable(data_impl_copy);
}
}
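
A hedged sketch (not in the diff) of how the new `make_variable_differentiable_view` overload above might be called with separate backward and forward ViewInfo. Both infos share the same base here purely for illustration (the dual/detach example in the comments above is the case where they would differ), and the wrapper name is made up.

```cpp
#include <torch/csrc/autograd/variable.h>

using namespace torch::autograd;

// Illustrative wrapper: wraps the result of a view op so that autograd can track
// it both as a forward and as a backward differentiable view.
Variable wrap_view(const Variable& base, const at::Tensor& view_output) {
  c10::optional<ViewInfo> bw_info(ViewInfo(base, /* view_fn */ c10::nullopt));
  c10::optional<ViewInfo> fw_info(ViewInfo(base, /* view_fn */ c10::nullopt));

  // Attaches a DifferentiableViewMeta carrying both ViewInfos to the output.
  return make_variable_differentiable_view(
      view_output, std::move(bw_info), std::move(fw_info), CreationMeta::DEFAULT);
}
```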

View File

@ -181,6 +181,8 @@ def get_ignored_functions() -> Set[Callable]:
torch.is_deterministic,
torch.set_deterministic,
torch.unify_type_list,
torch.make_dual,
torch.unpack_dual,
Tensor.__delitem__,
Tensor.__dir__,
Tensor.__getattribute__,