mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Reland: Add base forward grad logic (#49734)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49734 RFC: https://github.com/pytorch/rfcs/pull/11 This PR add the basic logic to handle forward grad as dual Tensors. It contains the following: - Mechanism to save dual state on a Tensor and clear it up when the dual level ends - C++ and python user facing API - Updated view system that is able to track both forward and backward views The current PR has the following limitations: - Extensive tests are in the next PR in the stack as formulas are needed to write full tests. - Only the manual formulas have been audited and no other formula is actually implemented here (they are in the next PR in the stack) - Only level 0 is allowed for now. This was discussed and agreed that it is not needed for the first version of this PR. - We can save one ViewInfo creation when both the forward and backward views have the same base. This can be done by adding a boolean flag to the DifferentiableViewMeta and extra logic in the `as_view` method. This is left out to keep this PR concise. - We can skip tracking forward views if the base has a forward grad. This can be done by adding extra logic in the `as_view` method. This is left out to keep this PR concise. Reading guide: - Updated view handling in [gen_variable_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-f6553cec68caeaea36f6c8b14ff76a6d39dfd774e0ea9ef2f76e8d81fd9af5df), [VariableTypeUtils.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-ec71cfa45954dece1236c661d170e6341879c5be637f4abf52e826d61b40695a), [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285) (skip code below "[Forward Grad View]" for now), [variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-1604bcd0e4350ed99ec45e437cee7ac9ebe337392c9ea16a236247aeeb35b02bR266-R542) and [custom_function.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-dd85f452082b5bb6612bbc12adb496f8827defa228509f7b493de1d517522d5d). This introduces the new ViewInfo to hold view informations shared for forward and backward. It also updates the differentiable view meta to use this. And it updates the as_view function to handle both forward and backward view. - New forward grad class that handle storing gradients and tracking at each level [forward_grad.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c6c5b9ab2d7e5dde4102495faa1b6bbbfc23aa3e47deb7359c0bfe1eb004c0cb), [forward_grad.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-de2ab54ade7312701850d71a119a4f4ee4b9fc5a9c42a467cdd4e73c033531dd) and [build_variables.bzl](https://github.com/pytorch/pytorch/pull/49097/files#diff-dfdfa2efb17beddfd9094524f95351fd197db6c8857e96b436fb599870359325). EDIT: These files also contain the new flag to globally disable forward AD that allows us to reduce performance issues while this is in development. 
- Lowest level API and binding between Tensor and AutogradMeta in [TensorBody.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-7554853205392fa743357bf845ecc350a974ec049383248c12daaf2f4de04911), [TensorImpl.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-052bd9150ef8e09289ddf644b5a6830ede49207201cd41728f6d7cc6d9cead94), [TensorImpl.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-a15aae4cf23da44970db7cece62ff981265575c798c62f7b52d87c8809dfe2e1) and the rest of [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285R557-R677) - API to access the forward primal that needs to be a differentiable function (and so in native_functions.yaml) [native_functions.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991) [NamedRegistrations.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-69bd3bea510c9b64e1633fa18c3ea63d4b8348dbad3a78ad9de844ab3e43dc1d), [VariableMethodsStub.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-23f5fcb737a2b289811fe0f4b65aef775e7c824b2e629ecd343df51405cd434f), [derivatives.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_python_functions.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_trace_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-54e0b976027bf8debefb959ff360b89ae93466970c843365b1b3a03806d868ce), [TraceTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-f34636741ad4a23d018e0c289bc750c3bad887b45660e1d6eaf440d234a78fbf) and [part of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R198-R243) - c++ API [autograd.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-349028fbe8291a965a7a263c323b208fe071c35c66179ee997ef84fa81aa4b1e), [autograd.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-a3fe908d67dfec16a1fcde300de68b0701bf68b88db7451f29f2bee255cf30c9) - python binding [init.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-c58a67c85191c22c9b3bb439117d8053edfd9dea839fa010cf967d404c3c630d) - python API [forward_ad.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a4efad4ba18fffdfb264c21e5475997a24a743089a899f8ec1a5ff962c6738d9), [autograd/__init__.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-743abcafd32ad0e69f39ac5a91df4197b7e1921c135cacee7ef6dc829a8a7af8) - c++ and python printing [Formatting.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-881dba501e71662e2e4818b4b016f739b344c8aed2f5edc6b871eda47a2aced0), [_tensor_str.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a7911f8d5e73adbff914d99fd7818ace2a7030b6a3748abe06ec6fc6e3df9cc3) - Utility for formulas and updated manual functions to respect new view system as well as forward grad [FunctionsManual.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-6378bb6dc81a64dab676d61731341fa5d1088418f32a1473a33a0ccfc2357dc1), [FunctionsManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-4adbd88239afcd60e8198aab65d4f5e43b62314e34b80551e997a1ea503adea5) [rest of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R264-R433) 
- Ensure SavedVariable saves the forward grad properly [saved_variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c1b8039d776241abe177d5aa99b79dd9489a9b3e529da8ab24c2e386c1238ae2), [saved_variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-cc9fba479b5beae06b2eea2e390d17796e0341c5b037a20b5bcaccbb0c341030)

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D25678797

Pulled By: albanD

fbshipit-source-id: 3d58550c11b5f58b9b73fd30596d042b857fb9dd
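For orientation, here is a minimal sketch of how the user-facing API introduced here is meant to be used (based on the torch.autograd.forward_ad module added in this diff; the function `f` and the tensors are illustrative, and since most formulas only land in the next PR of the stack, the tangent of `out` may still be undefined for many ops in this commit):

```python
import torch
import torch.autograd.forward_ad as fwAD

def f(x):
    return x.exp()  # stand-in for any op that has a forward-mode formula

x = torch.randn(3)   # primal
v = torch.randn(3)   # tangent: the direction for the jacobian-vector product

with fwAD.dual_level():
    dual = fwAD.make_dual(x, v)       # dual Tensor carrying x and its tangent v
    out = f(dual)
    y, jvp = fwAD.unpack_dual(out)    # y = f(x), jvp = J @ v (once the formula exists)
# exiting the level clears all forward grads created inside it
```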
committed by Facebook GitHub Bot
parent eabe05ab72
commit c23808d8e8
@@ -292,6 +292,11 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi
      stream << ", axis: " << tensor_.q_per_channel_axis();
    }
  }

  auto& fw_grad = tensor.fw_grad(/* level */ 0);
  if (fw_grad.defined()) {
    stream << ", tangent:" << std::endl << fw_grad;
  }
  stream << " ]";
}
return stream;

@@ -510,4 +510,5 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
  m.impl("_version", CppFunction::makeFallthrough());
  m.impl("requires_grad_", CppFunction::makeFallthrough());
  m.impl("retain_grad", CppFunction::makeFallthrough());
  m.impl("_fw_primal", CppFunction::makeFallthrough());
}
aten/src/ATen/native/AutogradComposite.cpp (new file, +27 lines)
@@ -0,0 +1,27 @@
#include <ATen/ATen.h>

namespace at {
namespace native {

/// This function can be used to create a dual Tensor that holds a tangent to compute forward mode gradients.
/// Note that the dual Tensor's primal is a view of the given primal and the given tangent is used as-is.
/// This function is backward differentiable.
at::Tensor make_dual(const at::Tensor& primal, const at::Tensor& tangent, int64_t level) {
  TORCH_CHECK(!primal.fw_grad(level).defined(), "Making a dual Tensor based on a Tensor that "
              "already has a forward gradient at the same level ", level, " is not supported.");

  auto dual_tensor = primal.view(primal.sizes());
  dual_tensor.set_fw_grad(tangent, level, /* is_inplace_op */ false);
  return dual_tensor;
}

/// This function can be used to unpack a given dual Tensor to get its primal and tangent. The returned primal
/// is a view of the dual and the tangent is returned as is.
/// This function is backward differentiable.
std::tuple<at::Tensor, at::Tensor> unpack_dual(const at::Tensor& tensor, int64_t level) {
  return std::tuple<at::Tensor, at::Tensor>(tensor._fw_primal(level), tensor.fw_grad(level));
}

} // namespace native

} // namespace at
@@ -40,5 +40,9 @@ void retain_grad(Tensor& self) {
  AT_ERROR("retain_grad is not implemented for Tensor");
}

Tensor _fw_primal(const Tensor& self, int64_t level) {
  AT_ERROR("_fw_primal is not implemented for Tensor");
}

} // namespace native
} // namespace at
@@ -105,6 +105,20 @@
  manual_kernel_registration: True
  variants: method

- func: _fw_primal(Tensor(a) self, int level) -> Tensor(a)
  use_c10_dispatcher: full
  variants: method
  dispatch:
    DefaultBackend: _fw_primal

- func: make_dual(Tensor(a) primal, Tensor tangent, int level) -> Tensor(a)
  use_c10_dispatcher: full
  variants: function

- func: unpack_dual(Tensor(a) dual, int level) -> (Tensor(a) primal, Tensor tangent)
  use_c10_dispatcher: full
  variants: function

- func: rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)
  use_c10_dispatcher: full
  variants: method
@@ -599,6 +599,23 @@ class TORCH_API Tensor {
    return impl_->grad();
  }

  // The Forward AD API functions below are low level and are not to be used by end
  // users who should use the API provided in torch/csrc/autograd.h

  /// This function returns the forward gradient for this Tensor at the given level.
  const Tensor& fw_grad(uint64_t level) const {
    return impl_->fw_grad(level, *this);
  }

  /// This function can be used to set the value of the forward grad.
  /// Note that the given new_grad might not be used directly if it has different
  /// metadata (size/stride/storage offset) compared to this Tensor. In that case,
  /// new_grad content will be copied into a new Tensor
  void set_fw_grad(const Tensor& new_grad, uint64_t level, bool is_inplace_op) {
    impl_->set_fw_grad(new_grad, *this, level, is_inplace_op);
  }

  // STOP. Thinking of adding a method here, which only makes use
  // of other ATen methods? Define it in native_functions.yaml.
@@ -44,6 +44,17 @@ const at::Tensor& TensorImpl::grad() const {
  return autograd_meta_->grad();
}

const at::Tensor& TensorImpl::fw_grad(uint64_t level, const at::Tensor& self) const {
  // See TensorImpl::grad() above for explanation about the line below
  if (!autograd_meta_) return impl::GetAutogradMetaFactory()->undefined_tensor();
  return autograd_meta_->fw_grad(level, self);
}

void TensorImpl::set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op) {
  if (!autograd_meta_) autograd_meta_ = impl::GetAutogradMetaFactory()->make();
  autograd_meta_->set_fw_grad(new_grad, self, level, is_inplace_op);
}

TensorImpl::TensorImpl(
    Storage&& storage,
    DispatchKeySet key_set,
@@ -136,6 +136,8 @@ struct C10_API AutogradMetaInterface {
  virtual bool requires_grad() const = 0;
  virtual at::Tensor& mutable_grad() = 0;
  virtual const at::Tensor& grad() const = 0;
  virtual const at::Tensor& fw_grad(uint64_t level, const at::Tensor& self) const = 0;
  virtual void set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op) = 0;
  virtual ~AutogradMetaInterface();
};

@@ -598,6 +600,42 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   */
  const at::Tensor& grad() const;

  /**
   * Return the accumulated gradient of a tensor. This gradient is computed
   * using forward mode AD.
   *
   * This is an internal API that should never be used by end users.
   *
   * The API is as follows:
   * - "level" allows to specify the level of forward AD nesting for which the
   *   gradient should be returned. Note that since levels are not fully
   *   supported yet, this argument should be 0. See documentation for
   *   torch::autograd::enter_dual_level for more details about forward AD nesting.
   * - "self" should represent the Tensor whose forward grad is accessed. It is
   *   required when dealing with view.
   */
  const at::Tensor& fw_grad(uint64_t level, const at::Tensor& self) const;

  /**
   * Sets the forward gradient for this Tensor.
   * The given Tensor might not be used directly and its content will be copied.
   *
   * This is an internal API that should never be used by end users.
   *
   * The API is as follows:
   * - "new_grad" is a Tensor containing the new value of the gradient that should
   *   be set
   * - "self" should represent the Tensor whose forward grad is accessed. It is
   *   required when dealing with view.
   * - "level" allows to specify the level of forward AD nesting for which the
   *   gradient should be set. Note that since levels are not fully supported
   *   yet, this argument should be 0. See documentation for torch::autograd::enter_dual_level
   *   for more details about forward AD nesting.
   * - "is_inplace_op" is a boolean flag that tells if this gradient was generated
   *   by an inplace operation or an out of place one. This allows better error checking.
   */
  void set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op);

  /**
   * Return a typed data pointer to the actual data which this tensor refers to.
   * This checks that the requested type (from the template parameter) matches
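A hedged Python illustration of the metadata constraint documented above for `set_fw_grad`: a tangent whose sizes/strides/storage offset already match the primal is stored as-is, while one with different metadata has its content copied into a fresh Tensor (the shapes are arbitrary, and the identity checks mirror what the docs and the test added in this PR describe):

```python
import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(2, 3)
t_same = torch.randn(2, 3)        # same sizes/strides/offset as x (both contiguous)
t_diff = torch.randn(3, 2).t()    # same shape but different strides

with fwAD.dual_level():
    d1 = fwAD.make_dual(x, t_same)
    assert fwAD.unpack_dual(d1)[1] is t_same      # reused as-is

    d2 = fwAD.make_dual(x.clone(), t_diff)
    assert fwAD.unpack_dual(d2)[1] is not t_diff  # content copied into a new Tensor
```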
@@ -35,6 +35,7 @@ from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoL
    IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck)
from torch.autograd import Variable, Function, detect_anomaly, kineto_available
from torch.autograd.function import InplaceFunction
import torch.autograd.forward_ad as fwAD
from torch.testing import randn_like
from torch.testing._internal.common_methods_invocations import (method_tests,
    create_input, unpack_variables,

@@ -5326,6 +5327,26 @@ class TestAutogradComplex(TestCase):

        self.assertEqual(x.grad, y.grad)

    def test_view_with_multi_output(self):
        x = torch.randn(2, 2, 2, dtype=torch.double)

        x1 = torch.view_as_complex(x)
        # Taking an invalid view should always be allowed as long as it is not
        # modified inplace
        res = x1.unbind(0)

        with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
            res[0] += torch.rand(2, requires_grad=True)

        x.requires_grad_(True)
        x1 = torch.view_as_complex(x)
        # Taking an invalid view should always be allowed as long as it is not
        # modified inplace
        res = x1.unbind(0)

        with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
            res[0] += torch.rand(2, requires_grad=True)

    def as_identity(self):
        # view_as_real and view_as_complex behavior should be like an identity
        def func(z):

@@ -6324,6 +6345,66 @@ class TestAutogradFunctional(TestCase):
        self.assertEqual(hvp, torch.mm(hes, v.unsqueeze(1)).squeeze(1))
        self.assertEqual(vhp, torch.mm(v.unsqueeze(0), hes).squeeze(0))

class TestAutogradForwardMode(TestCase):
    def test_forward_level_cleanup(self):
        import weakref

        def get_tensor_and_weak_ref():
            # Helper function to get a Tensor and a weak ref that tells us
            # if the c++ version of this Tensor is still alive or not.
            #
            # Create the following reference chain to do so:
            # - python Tensor t
            # - c++ Tensor corresponding to t
            # - c++ Node corresponding to t.grad_fn
            # - python dict of metadata from this Node
            # - an object in this dict that we can take a weakref of

            # Create a new Tensor and Node
            t = torch.rand(2, requires_grad=True).clone()
            # Create the metadata dict
            meta_dict = t.grad_fn.metadata
            # Create the object in the dict

            class Foo(object):
                pass
            my_obj = Foo()
            meta_dict[0] = my_obj

            # After exiting this function, the python Tensor t is the only
            # thing keeping ref alive
            ref = weakref.ref(my_obj)
            return t, ref

        # Sanity check that the helper function works as expected
        t, t_ref = get_tensor_and_weak_ref()
        self.assertIsNotNone(t_ref())

        del t
        self.assertIsNone(t_ref())

        # Main test code
        foo = torch.rand(2)

        with fwAD.dual_level():
            tangent, tangent_ref = get_tensor_and_weak_ref()
            self.assertIsNotNone(tangent_ref())

            dual = fwAD.make_dual(foo, tangent)
            self.assertIsNotNone(tangent_ref())

            # Make sure that the tangent we provided has been re-used as is
            self.assertTrue(fwAD.unpack_dual(dual)[1] is tangent)

            # Make sure that dual is keeping the tangent alive
            del tangent
            self.assertIsNotNone(tangent_ref())

            # Make sure that the dual level does not keep the c++
            # version of the tangent alive
            del dual
            self.assertIsNone(tangent_ref())

# Generic device type autograd tests.
class TestAutogradDeviceType(TestCase):
@@ -12,7 +12,7 @@ aten_native_yaml = os.path.join(path, '../aten/src/ATen/native/native_functions.
all_operators_with_namedtuple_return = {
    'max', 'min', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig', 'eig',
    'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk', 'lstsq',
    'triangular_solve', 'cummax', 'cummin', 'linalg_eigh'
    'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "unpack_dual"
}

@@ -65,6 +65,7 @@ class TestNamedTupleAPI(unittest.TestCase):
            op(operators=['triangular_solve'], input=(a,), names=('solution', 'cloned_coefficient'), hasout=True),
            op(operators=['lstsq'], input=(a,), names=('solution', 'QR'), hasout=True),
            op(operators=['linalg_eigh'], input=("L",), names=('eigenvalues', 'eigenvectors'), hasout=True),
            op(operators=['unpack_dual'], input=(a, 0), names=('primal', 'tangent'), hasout=False),
        ]

        for op in operators:

@@ -75,7 +76,9 @@ class TestNamedTupleAPI(unittest.TestCase):
                for i, name in enumerate(op.names):
                    self.assertIs(getattr(ret, name), ret[i])
            else:
                ret = getattr(a, f)(*op.input)
                # Handle op that are not methods
                func = getattr(a, f) if hasattr(a, f) else getattr(torch, f)
                ret = func(*op.input)
                for i, name in enumerate(op.names):
                    self.assertIs(getattr(ret, name), ret[i])
                if op.hasout:
@@ -80,7 +80,8 @@ SKIP_PYTHON_BINDINGS = [
    'nonzero(_(out|numpy))?',
    'set_data',
    '.*_overrideable',  # overrideable functions for backend extension
    'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retain_grad', 'set_'
    'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retain_grad', 'set_',
    '_fw_primal'
]

# These function signatures are not exposed to Python. Note that this signature
@@ -25,7 +25,7 @@ MANUAL_BACKEND = set([
# For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys.
# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
MANUAL_AUTOGRAD_AND_TRACER = set([
    'resize_', 'resize_as_', 'detach', 'detach_', 'copy_',
    'resize_', 'resize_as_', 'detach', 'detach_', 'copy_', '_fw_primal',
])

# Currently MANUAL_AUTOGRAD and MANUAL_TRACER share the same set of ops:
@@ -689,7 +689,7 @@ def emit_body(declaration):

    if len(differentiable_output_vars) == 0:
        # no output is differentiable (.indices() for SparseTensors for example)
        rhs_value = 'as_view({}, {}, /* is_differentiable */ false)'.format(view_info, var)
        rhs_value = f'as_view({view_info}, {var}, /* is_bw_differentiable */ false, /* is_fw_differentiable */ false)'
    elif len(differentiable_output_vars) == 1:
        # Single differentiable output (Tensor or Tensor[])
        return_info = differentiable_outputs[0]

@@ -704,13 +704,15 @@ def emit_body(declaration):
            creation_meta = "CreationMeta::MULTI_OUTPUT_SAFE"
        else:
            creation_meta = "CreationMeta::MULTI_OUTPUT_NODE"
        call += ("as_view(/* base */ {}, /* output */ {}, /* is_differentiable */ true, "
                 "/* creation_meta */ {});\n").format(view_info, var, creation_meta)
        call += ("as_view(/* base */ {}, /* output */ {}, /* is_bw_differentiable */ true, "
                 "/* is_fw_differentiable */ true, "
                 "/* creation_meta */ {});").format(view_info, var, creation_meta)
        rhs_value = 'std::move({})'.format(var)
    else:
        call += emit_view_lambda()
        creation_meta = "GradMode::is_enabled() ? CreationMeta::DEFAULT: CreationMeta::NO_GRAD_MODE"
        rhs_value = ("as_view(/* base */ {}, /* output */ {}, /* is_differentiable */ true, "
        rhs_value = ("as_view(/* base */ {}, /* output */ {}, /* is_bw_differentiable */ true, "
                     "/* is_fw_differentiable */ true, "
                     "/* view_func */ func, /* creation_meta */ {})").format(view_info, var, creation_meta)
else:
    # This could be supported but we don't need it at the moment, so keeping things simple.
@@ -90,6 +90,8 @@ core_sources_common = [
    "torch/csrc/autograd/profiler_legacy.cpp",
    "torch/csrc/autograd/profiler_kineto.cpp",
    "torch/csrc/autograd/profiler_utils.cpp",
    "torch/csrc/autograd/autograd_meta.cpp",
    "torch/csrc/autograd/forward_grad.cpp",
    "torch/csrc/jit/frontend/edit_distance.cpp",
    "torch/csrc/jit/frontend/string_to_type.cpp",
    "torch/csrc/jit/mobile/type_parser.cpp",
@@ -522,6 +522,12 @@ def autocast_increment_nesting() -> _int: ...
def autocast_decrement_nesting() -> _int: ...
def set_anomaly_enabled(enabled: _bool) -> None: ...
def is_anomaly_enabled() -> _bool: ...
def _enter_dual_level() -> _int: ...
def _exit_dual_level(level: _int) -> None: ...
def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ...
def _unpack_dual(tensor: Tensor, level: _int) -> Tensor: ...
def __set_forward_AD_enabled(enabled: _bool) -> None: ...
def __is_forward_AD_enabled() -> _bool: ...

# Defined in torch/csrc/jit/python/script_init.cpp
class LoggerBase(object):
@@ -275,11 +275,16 @@ def get_summarized_data(self):
    else:
        return torch.stack([get_summarized_data(x) for x in self])

def _str_intern(self):
def _str_intern(inp):
    prefix = 'tensor('
    indent = len(prefix)
    suffixes = []

    # This is used to extract the primal value and thus disable the forward AD
    # within this function.
    # TODO(albanD) This needs to be updated when more than one level is supported
    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)

    # Note [Print tensor device]:
    # A general logic here is we only print device when it doesn't match
    # the device specified in default tensor type.

@@ -355,17 +360,22 @@ def _str_intern(self):
    if self.layout != torch.strided:
        suffixes.append('layout=' + str(self.layout))

    if self.grad_fn is not None:
        name = type(self.grad_fn).__name__
    # Use inp here to get the original grad_fn and not the one generated by the forward grad
    # unpacking.
    if inp.grad_fn is not None:
        name = type(inp.grad_fn).__name__
        if name == 'CppFunction':
            name = self.grad_fn.name().rsplit('::', 1)[-1]
            name = inp.grad_fn.name().rsplit('::', 1)[-1]
        suffixes.append('grad_fn=<{}>'.format(name))
    elif self.requires_grad:
    elif inp.requires_grad:
        suffixes.append('requires_grad=True')

    if self.has_names():
        suffixes.append('names={}'.format(self.names))

    if tangent is not None:
        suffixes.append('tangent={}'.format(tangent))

    return _add_suffixes(prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse)

def _str(self):
@@ -19,6 +19,7 @@ from .grad_mode import no_grad, enable_grad, set_grad_enabled
from .anomaly_mode import detect_anomaly, set_detect_anomaly
from ..overrides import has_torch_function, handle_torch_function
from . import functional
from . import forward_ad

__all__ = ['Variable', 'Function', 'backward', 'grad_mode']
torch/autograd/forward_ad.py (new file, +116 lines)
@@ -0,0 +1,116 @@
import torch
from .grad_mode import _DecoratorContextManager

from typing import Any

# TODO(alband): Once most of the formulas are implemented, these functions need to be added
# to the main doc to make them fully "public".

# Global variable used to make the python API simpler to use
_current_level = -1

def enter_dual_level():
    r"""Function that can be used to enter a new forward grad level.
    This level can be used to make and unpack dual Tensors to compute
    forward gradients.

    This function also updates the current level that is used by default
    by the other functions in this API.
    """
    global _current_level
    new_level = torch._C._enter_dual_level()
    if new_level != _current_level + 1:
        raise RuntimeError("Entering a new forward AD level but the current level "
                           "is not valid. Make sure you did not modified it directly.")
    _current_level = new_level
    return new_level

def exit_dual_level(*, level=None):
    r"""Function that can be used to exit a forward grad level.
    This function deletes all the gradients associated with this
    level. Only deleting the latest entered level is allowed.

    This function also updates the current level that is used by default
    by the other functions in this API.
    """
    global _current_level
    if level is None:
        level = _current_level
    if level != _current_level:
        raise RuntimeError("Trying to exit a forward AD level that was not the last one "
                           "that was created. This is not supported.")
    torch._C._exit_dual_level(level=level)
    _current_level = level - 1

def make_dual(tensor, tangent, *, level=None):
    r"""Function that creates a "dual object" that can be used to compute forward AD gradients
    based on the given Tensor and its tangent. It returns a new Tensor that shares memory with
    :attr:`tensor` and the :attr:`tangent` is used as-is.

    This function is backward differentiable.

    Given a function `f` whose jacobian is `J`, it allows to compute the jacobian vector product,
    named `jvp`, between `J` and a given vector `v` as follows.

    Example::
        >>> inp = make_dual(x, v)
        >>> out = f(inp)
        >>> y, jvp = unpack_dual(out)

    """
    if level is None:
        level = _current_level

    if level < 0:
        raise RuntimeError("Trying to create a dual Tensor for forward AD but no level "
                           "exists, make sure to enter_dual_level() first.")

    return torch.make_dual(tensor, tangent, level=level)

def unpack_dual(tensor, *, level=None):
    r"""Function that unpacks a "dual object" to recover two plain tensors, one representing
    the primal and the other the tangent (both are views of :attr:`tensor`. Neither of these
    tensors can be dual tensor of level :attr:`level`.

    This function is backward differentiable.
    """
    if level is None:
        level = _current_level

    if level < 0:
        return tensor, None

    return torch.unpack_dual(tensor, level=level)

class dual_level(_DecoratorContextManager):
    r"""Context-manager that controls the current forward ad level. It
    appropriately enters and exit the dual level.

    This function also updates the current level that is used by default
    by the other functions in this API.

    Example::

        >>> x = torch.tensor([1])
        >>> x_t = torch.tensor([1])
        >>> with dual_level():
        ...     inp = make_dual(x, x_t)
        ...     # Do computations with inp
        ...     out = your_fn(inp)
        ...     _, grad = unpack_dual(out)
        >>> grad is None
        False
        >>> # After exiting the level, the grad is deleted
        >>> _, grad_after = unpack_dual(out)
        >>> grad is None
        True

    """
    def __init__(self):
        super().__init__()

    def __enter__(self):
        return enter_dual_level()

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        exit_dual_level()
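As a usage note for the module above: outside of any dual level `_current_level` is -1, so `unpack_dual` simply returns the input with a `None` tangent and `make_dual` raises. A small sketch:

```python
import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(2)
primal, tangent = fwAD.unpack_dual(x)   # no active level (level < 0)
assert primal is x and tangent is None

try:
    fwAD.make_dual(x, torch.randn(2))   # no enter_dual_level() / dual_level() yet
except RuntimeError:
    pass  # "Trying to create a dual Tensor for forward AD but no level exists..."
```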
@@ -35,10 +35,22 @@ bool isDefined(const c10::optional<Tensor>& t) {
  return t.has_value() && t->defined();
}

bool isFwGradDefined(const c10::optional<Tensor>& t) {
  return t.has_value() && t->defined() && t->fw_grad(/*level */ 0).defined();
}

Tensor toLegacyTensor(const c10::optional<Tensor>& t) {
  return t.has_value() ? *t : Tensor();
}

Tensor toLegacyFwGrad(const c10::optional<Tensor>& t) {
  return (t.has_value() && t->defined()) ? t->fw_grad(/*level */ 0) : Tensor();
}

Tensor toLegacyPrimal(const c10::optional<Tensor>& t) {
  return (t.has_value() && t->defined()) ? t->_fw_primal(/*level */ 0) : Tensor();
}

void copy_range(variable_list& out, IndexRange range, const Tensor & t) {
  AT_ASSERT(range.second <= out.size());
  AT_ASSERTM(range.second - range.first == 1, "inconsistent range for Tensor output");
@@ -29,6 +29,10 @@ struct IndexRangeGenerator {
  size_t i = 0;
};

bool isFwGradDefined(const c10::optional<Tensor>& t);
Tensor toLegacyFwGrad(const c10::optional<Tensor>& t);
Tensor toLegacyPrimal(const c10::optional<Tensor>& t);

bool any_variable_defined(variable_list& variables);
void copy_range(variable_list& out, IndexRange range, const at::Tensor & t);
void copy_range(variable_list& out, IndexRange range, at::ArrayRef<at::Tensor> t);
@@ -139,6 +139,7 @@ TORCH_LIBRARY_IMPL(aten, Tracer, m) {
  m.impl("_version", CppFunction::makeFallthrough());
  m.impl("requires_grad_", CppFunction::makeFallthrough());
  m.impl("retain_grad", CppFunction::makeFallthrough());
  m.impl("_fw_primal", CppFunction::makeFallthrough());
}

} // namespace
@@ -1,6 +1,7 @@
#include <c10/util/Optional.h>
#include <c10/core/ScalarType.h>
#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/csrc/autograd/FunctionsManual.h>
#include <torch/csrc/utils/memory.h>
#include <torch/csrc/autograd/utils/error_messages.h>
#include <torch/csrc/autograd/autograd.h>

@@ -194,6 +195,39 @@ void retain_grad(Tensor & self) {
  impl::get_autograd_meta(self)->retains_grad_ = true;
}

// Taken from codegened version
Tensor _fw_primal(const Tensor & self, int64_t level) {
  auto& self_ = unpack(self, "self", 0);
  std::shared_ptr<Identity> grad_fn;
  if (compute_requires_grad( self )) {
    grad_fn = std::make_shared<Identity>();
    grad_fn->set_next_edges(collect_next_edges( self ));
  }
  auto tmp = ([&]() {
    at::AutoNonVariableTypeMode non_var_type_mode(true);
    return self_.alias();
  })();
  c10::optional<std::function<at::Tensor(const at::Tensor&)>> func=c10::nullopt;
  if (!self.unsafeGetTensorImpl()->support_as_strided()) {
    auto size_vec = self.sizes().vec();
    func = [=](const at::Tensor& input_base) {
      return input_base.view(size_vec);
    };
  }
  auto result = as_view(/* base */ self, /* output */ tmp, /* is_bw_differentiable */ true,
                        /* is_fw_differentiable */ false, /* view_func */ func, /* creation_meta */ CreationMeta::DEFAULT);
  if (grad_fn) {
    set_history(flatten_tensor_args( result ), grad_fn);
  }
  if (generated::details::isFwGradDefined(self)) {
    // Modified from original codegen
    // We explicitly want to ignore the forward grad at the given level
    TORCH_CHECK(level == 0, "Invalid level given to _fw_primal");
    // End modified from original codegen
  }
  return result;
}

// We don't have an outplace copy, so this can't be generated automatically
Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) {
  jit::Value* output = nullptr;

@@ -217,6 +251,24 @@ Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) {
  }
  increment_version(self);
  rebase_history(self , std::move(grad_fn));

  if (isDifferentiableType(self.scalar_type()) &&
      (generated::details::isFwGradDefined(self) || generated::details::isFwGradDefined(src))) {
    auto self_fw_grad = generated::details::toLegacyFwGrad(self);
    auto src_fw_grad = generated::details::toLegacyFwGrad(src);
    Tensor new_fw_grad;
    if (self_fw_grad.defined()) {
      if (src_fw_grad.defined()) {
        new_fw_grad = self_fw_grad.copy_(src_fw_grad);
      } else {
        new_fw_grad = self_fw_grad.fill_(0);
      }
    } else {
      new_fw_grad = src_fw_grad;
    }
    self.set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ true);
  }

  return self;
}

@@ -232,6 +284,11 @@ Tensor& resize_(
    at::AutoNonVariableTypeMode non_var_type_mode(true);
    self_.resize_(size, std::move(optional_memory_format));
  }

  if (self.fw_grad(/* level */ 0).defined()) {
    AT_ERROR("cannot resize variables that has a forward grad");
  }

  return self;
}

@@ -248,13 +305,28 @@ Tensor& resize_as_(
    at::AutoNonVariableTypeMode non_var_type_mode(true);
    at::resize_as_(self_, the_template_, std::move(optional_memory_format));
  }

  // Handle fw grad
  if (self.fw_grad(/* level */ 0).defined()) {
    AT_ERROR("cannot resize variables that has a forward grad");
  }
  return self;
}

Tensor detach(const Tensor & self) {
  RECORD_FUNCTION("detach", std::vector<c10::IValue>({self}));
  auto result = make_variable_non_differentiable_view(self, self, /*allow_tensor_metadata_change=*/false);
  c10::optional<std::function<at::Tensor(const at::Tensor&)>> func=c10::nullopt;
  auto result = as_view(/* base */ self, /* output */ self, /* is_bw_differentiable */ false,
                        /* is_fw_differentiable */ true, /* view_func */ func, /* creation_meta */ CreationMeta::DEFAULT,
                        /*allow_tensor_metadata_change=*/false);
  namedinference::propagate_names(result, self);

  // detach only backward gradients for both primal and tangent
  if (self.fw_grad(/* level */ 0).defined()) {
    auto new_fw_grad = self.fw_grad(/* level */ 0).detach();
    result.set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ false);
  }

  return result;
}

@@ -264,7 +336,7 @@ Tensor & detach_(Tensor & self) {
  // NB: is_view() ==> get_autograd_meta()
  auto diff_view_meta = static_cast<torch::autograd::DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(self));
  // See NOTE [ View + Inplace detection ]
  if (diff_view_meta->creation_meta == CreationMeta::MULTI_OUTPUT_SAFE) {
  if (diff_view_meta->get_creation_meta() == CreationMeta::MULTI_OUTPUT_SAFE) {
    TORCH_WARN("This view is an output of a function that "
               "returns multiple views. Detaching such views inplace "
               "is being deprecated and will be forbidden "

@@ -272,7 +344,8 @@ Tensor & detach_(Tensor & self) {
               "of detach_(). Alternatively, create this view with an "
               "`unsafe_` version of the function that produced it.");
  } else {
    AT_ERROR("If you are using DistributedDataParallel (DDP) for training, "
    AT_ERROR("Can't detach views in-place. Use detach() instead. "
             "If you are using DistributedDataParallel (DDP) for training, "
             "and gradient_as_bucket_view is set as True, gradients are "
             "views of DDP buckets, and hence detach_() cannot be called "
             "on these gradients. To fix this error, please refer to the "

@@ -290,6 +363,12 @@ Tensor & detach_(Tensor & self) {
  autograd_meta->set_requires_grad(false, self.unsafeGetTensorImpl());
  autograd_meta->grad_fn_.reset();
  autograd_meta->output_nr_ = 0;

  // detach only backward gradients for both primal and tangent
  if (self.fw_grad(/* level */ 0).defined()) {
    self.fw_grad(/* level */ 0).detach_();
  }

  return self;
}

@@ -321,6 +400,7 @@ TORCH_LIBRARY_IMPL(aten, Autograd, m) {
  // and requires_grad_(), then remove the backend Autograd kernel here, only leaving the Math kernel.
  m.impl("_backward", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::_backward)));
  m.impl("requires_grad_", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::requires_grad_)));
  m.impl("_fw_primal", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::_fw_primal)));
}

TORCH_LIBRARY_IMPL(aten, DefaultBackend, m) {
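To make the manual copy_ handling above concrete, a hedged sketch of "case 4" from the [Forward Grad View/inplace] note further down in autograd_meta.cpp: copying a dual Tensor into a view of a regular Tensor propagates the tangent to the base (the exact observable behavior may still evolve while formulas are being landed):

```python
import torch
import torch.autograd.forward_ad as fwAD

base = torch.zeros(2, 3)     # regular Tensor, no tangent yet
src = torch.zeros(3)
tangent = torch.ones(3)

with fwAD.dual_level():
    dual_src = fwAD.make_dual(src, tangent)
    view = base[0]
    view.copy_(dual_src)     # handled by the manual copy_ above
    # set_fw_grad gives `base` a zero-filled tangent and writes `tangent`
    # into the slice that corresponds to `view` (case 4 of the note)
    _, base_tangent = fwAD.unpack_dual(base)
```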
@@ -134,88 +134,111 @@ template<typename... Args> inline variable_list flatten_tensor_args(Args&&... ar
}

// See NOTE [ Autograd View Variables ] for details.
inline Tensor as_view(const Tensor & base, const Tensor& tensor, bool is_differentiable,
                      c10::optional<std::function<Tensor(const Tensor&)>> view_func=c10::nullopt,
                      CreationMeta creation_meta=CreationMeta::DEFAULT) {
  auto base_var = Variable(base);
  if (base_var.is_view()) {
    // Set `view_func` using the root base as input.
    // `view_func` is used to recover views in backward when either as_strided is not supported
    // or the view function changes the metadata which is not recorded by as_strided
    // See Note [View + Inplace update on base tensor] and [View + Inplace update on view tensor]
    // for more details how we use this function in backward.
    auto diff_view_meta = static_cast<DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(base_var));
    if (view_func.has_value()) {
      auto fn = view_func.value();
      // both current_view and its parent have a view_func
      if (diff_view_meta->has_view_fn()) {
        auto prev_fn = diff_view_meta->view_fn();
        view_func = [=](const at::Tensor& root_base) {
          auto temp = prev_fn(root_base);
          return fn(temp);
        };
inline Tensor as_view(const Tensor & base, const Tensor & tensor, bool is_bw_differentiable,
                      bool is_fw_differentiable, c10::optional<std::function<Tensor(const Tensor&)>> view_func=c10::nullopt,
                      CreationMeta creation_meta=CreationMeta::DEFAULT, bool allow_tensor_metadata_change=true) {
  if (!isForwardADEnabled()) {
    // Fast codepath for backward only code
    // It is useful as it avoids the creation of the temporary c10<optional> which makes
    // a significant difference when measuring instruction count for a single "t.view(-1)" call from c++.
    if (is_bw_differentiable) {
      if (base.is_view()) {
        auto diff_view_meta = static_cast<DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(base));
        const auto& base_bw_info = diff_view_meta->get_backward_view();
        return make_variable_differentiable_view(tensor, base_bw_info.chain(base, tensor, view_func),
                                                 c10::nullopt, creation_meta, allow_tensor_metadata_change);
      } else {
        // current_view has a view_func but its parent doesn't have one
        if(base_var.unsafeGetTensorImpl()->support_as_strided()) {
          auto size = base.sizes().vec();
          auto stride = base.strides().vec();
          auto storage_offset = base.storage_offset();
          view_func = [=](const at::Tensor& root_base) {
            auto temp = root_base.as_strided(size, stride, storage_offset);
            return fn(temp);
          };
        } else {
          // When base_var is a view but doesn't carry a view_fn in DifferentiableViewMeta, it's
          // a view that doesn't support inplace update, e.g. unbind.
          // In this case we should throw an error when inplace update happens in **forward**.
          // One would naturally think the following function will be first called in backward pass.
          // But the first call site is indeed in **forward** pass when we refresh `grad_fn`
          // triggered by inplace update.
          // Search Note [View + Inplace update for view tensor] to for the call site.
          view_func = [=](const at::Tensor& root_base) {
            TORCH_CHECK(false, "This view is the output of a function that returns multiple views."
                        "Such functions do not allow the output views to be modified inplace."
                        "You should replace the inplace operation by an out-of-place one");
            return root_base;
          };
        }
        return make_variable_differentiable_view(tensor, ViewInfo(base, view_func),
                                                 c10::nullopt, creation_meta, allow_tensor_metadata_change);
      }
    } else if(diff_view_meta->has_view_fn()) {
      // if current_view doesn't have a view_func but its parent has one
      auto prev_view_fn = diff_view_meta->view_fn();
      auto size = tensor.sizes().vec();
      auto stride = tensor.strides().vec();
      auto storage_offset = tensor.storage_offset();
      view_func = [=](const at::Tensor& root_base) {
        auto temp = prev_view_fn(root_base);
        return temp.as_strided(size, stride, storage_offset);
      };
    } else {
      TORCH_CHECK(creation_meta == CreationMeta::DEFAULT,
                  "Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT");
      return make_variable_non_differentiable_view(base, std::move(tensor), allow_tensor_metadata_change);
    }
    base_var = base_var._base();
  }
  if (is_differentiable) {
    return make_variable_differentiable_view(std::move(base_var), tensor, creation_meta, std::move(view_func));
  // Create both the forward and backward info that are needed
  c10::optional<ViewInfo> new_bw_info;
  c10::optional<ViewInfo> new_fw_info;

  if (is_bw_differentiable) {
    if (base.is_view()) {
      auto diff_view_meta = static_cast<DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(base));
      const auto& base_bw_info = diff_view_meta->get_backward_view();
      new_bw_info = base_bw_info.chain(base, tensor, view_func);
    } else {
      new_bw_info = ViewInfo(base, view_func);
    }
  } else {
    TORCH_CHECK(creation_meta == CreationMeta::DEFAULT,
                "Non-differentiable views must have creation_meta=CreationMeta::DEFAULT");
    return make_variable_non_differentiable_view(std::move(base_var), tensor);
                "Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT");
  }

  if (is_fw_differentiable) {
    // Check if base is a forward differentiable view
    auto base_meta = torch::autograd::impl::get_autograd_meta(base);
    auto is_view = base_meta && base_meta->is_view_;
    if (is_view && static_cast<DifferentiableViewMeta*>(base_meta)->has_fw_view()) {
      auto diff_view_meta = static_cast<DifferentiableViewMeta*>(base_meta);
      const auto& base_fw_info = diff_view_meta->get_forward_view();
      new_fw_info = base_fw_info.chain(base, tensor, view_func);
    } else {
      new_fw_info = ViewInfo(base, view_func);
    }
  }

  if (is_fw_differentiable || is_bw_differentiable) {
    return make_variable_differentiable_view(tensor, std::move(new_bw_info), std::move(new_fw_info),
                                             creation_meta, allow_tensor_metadata_change);
  } else {
    return make_variable_non_differentiable_view(base, tensor, allow_tensor_metadata_change);
  }
}

// See NOTE [ Autograd View Variables ] for details.
inline std::vector<Tensor> as_view(const Tensor & base, std::vector<Tensor>& tensors, bool is_differentiable,
                                   CreationMeta creation_meta=CreationMeta::DEFAULT) {
  auto base_var = Variable(base);
  if (base_var.is_view()) {
    base_var = base_var._base();
  }
  for(Tensor &tensor : tensors) {
    if (is_differentiable) {
      tensor = make_variable_differentiable_view(base_var, tensor, creation_meta);
inline std::vector<Tensor> as_view(const Tensor & base, std::vector<Tensor>& tensors, bool is_bw_differentiable,
                                   bool is_fw_differentiable, CreationMeta creation_meta=CreationMeta::DEFAULT) {
  c10::optional<ViewInfo> new_bw_info = c10::nullopt;
  c10::optional<ViewInfo> new_fw_info = c10::nullopt;

  if (is_bw_differentiable) {
    if (base.is_view()) {
      auto diff_view_meta = static_cast<DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(base));
      const auto& base_bw_info = diff_view_meta->get_backward_view();
      TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::MULTI_OUTPUT_NODE || creation_meta == CreationMeta::MULTI_OUTPUT_SAFE,
                            "Functions that result multiple view must have a creation meta reflecting this behavior.");
      // It is ok to create a ViewInfo where only the base is correct in this case as inplace operations on such views are
      // not allowed
      new_bw_info = ViewInfo(base_bw_info.base_, /* view_func */ c10::nullopt);
    } else {
      TORCH_CHECK(creation_meta == CreationMeta::DEFAULT,
                  "Non-differentiable views must have creation_meta=CreationMeta::DEFAULT");
      tensor = make_variable_non_differentiable_view(base_var, tensor);
      new_bw_info = ViewInfo(base, /* view_func */ c10::nullopt);
    }
  } else {
    TORCH_CHECK(creation_meta == CreationMeta::DEFAULT,
                "Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT");
  }
  if (isForwardADEnabled() && is_fw_differentiable) {
    // Check if base is a forward differentiable view
    auto base_meta = torch::autograd::impl::get_autograd_meta(base);
    auto is_view = base_meta && base_meta->is_view_;
    if (is_view && static_cast<DifferentiableViewMeta*>(base_meta)->has_fw_view()) {
      auto diff_view_meta = static_cast<DifferentiableViewMeta*>(base_meta);
      const auto& base_fw_info = diff_view_meta->get_forward_view();
      TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::MULTI_OUTPUT_NODE || creation_meta == CreationMeta::MULTI_OUTPUT_SAFE,
                            "Functions that result multiple view must have a creation meta reflecting this behavior.");
      // It is ok to create a ViewInfo where only the base is correct in this case as inplace operations on such views are
      // not allowed
      new_fw_info = ViewInfo(base_fw_info.base_, /* view_func */ c10::nullopt);
    } else {
      new_fw_info = ViewInfo(base, /* view_func */ c10::nullopt);
    }
  }

  for(Tensor &tensor : tensors) {
    if (is_fw_differentiable || is_bw_differentiable) {
      tensor = make_variable_differentiable_view(tensor, new_bw_info, new_fw_info, creation_meta);
    } else {
      tensor = make_variable_non_differentiable_view(base, tensor);
    }
  }
  return tensors;
@@ -155,5 +155,18 @@ variable_list grad(
      outputs, gradients, retain_graph.value(), create_graph, inputs, allow_unused, /*accumulate_grad=*/false);
}

namespace forward_ad {

uint64_t enter_dual_level() {
  return ForwardADLevel::get_next_idx();
}

void exit_dual_level(uint64_t level) {
  ForwardADLevel::release_idx(level);
}

} // namespace forward_ad

} // namespace autograd
} // namespace torch
@@ -75,5 +75,20 @@ TORCH_API variable_list grad(
    bool create_graph = false,
    bool allow_unused = false);

namespace forward_ad {

/// Creates a new dual level and returns its index. This level index should then be used to call
/// into the other functions below.
/// This API supports entering a new level before the previous one is exited. We call them nested
/// forward AD levels. These can be used to compute higher order derivatives.
TORCH_API uint64_t enter_dual_level();

/// Exits the given level. This will clear up all the gradients from this level and all dual Tensors
/// that had gradients for this level will become regular Tensors again.
/// This function can only be used to exit the innermost nesting level and so exiting must happen in
/// reverse order compared to the entering that was done with the function above.
TORCH_API void exit_dual_level(uint64_t level);

} // namespace forward_ad
} // namespace autograd
} // namespace torch
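The doc comments above already describe nested levels, but as the PR summary notes only level 0 is allowed for now. A hedged sketch of that restriction seen from Python (the exact error type and message are an assumption):

```python
import torch.autograd.forward_ad as fwAD

with fwAD.dual_level():          # level 0
    try:
        with fwAD.dual_level():  # nesting is described in autograd.h but not enabled yet
            pass
    except RuntimeError:
        pass  # assumption: a second level is rejected while only level 0 is supported
```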
torch/csrc/autograd/autograd_meta.cpp (new file, +218 lines)
@ -0,0 +1,218 @@
|
||||
#include <torch/csrc/autograd/variable.h>
|
||||
|
||||
namespace torch {
|
||||
namespace autograd {
|
||||
|
||||
using at::Tensor;
|
||||
|
||||
// [Forward Grad View/inplace]
|
||||
// It is important to us to allow view and inplace to work with dual Tensors. These operations
|
||||
// should either compute the right gradient or raise a user-friendly error.
|
||||
|
||||
// The basic case where all Tensors are dual Tensors is as follows:
|
||||
// # Have:
|
||||
// # foo is a dual Tensor that is not a view
|
||||
// # bar is a dual Tensor of appropriate size (depending on cases) that is not a view
|
||||
//
|
||||
// # Case 1: no view
|
||||
// foo.copy_(bar)
|
||||
//
|
||||
// # Case 2: with view, propagate from view to base
|
||||
// view = foo[0]
|
||||
// view.copy_(bar)
|
||||
//
|
||||
// # Case 3: with view, propagate from base to view
|
||||
// view = foo[0]
|
||||
// foo.copy_(bar)
|
||||
//
|
||||
// # In both cases, the forward grad of foo must be properly updated.
|
||||
// # In the second and third cases, the forward grad of view must match
|
||||
// # the one of foo for the subset they have in common.
|
||||
//
|
||||
// All these cases can be handled by the following layout constraint on the forward grad:
|
||||
// - A Tensor and its forward grad (for all levels) must have the same metadata (size, stride
|
||||
// and storage offset). Storage offset must be in this metadata because of as_strided.
|
||||
// - View operations must create a forward grad that is a view of the base's forward grad.
|
||||
// - Inplace operations must modify the input's forward grad inplace.
|
||||
//
|
||||
// This layout constraint is ensured in the `set_fw_grad` function below
|
||||
|
||||
|
||||
// More complex cases arrise when non-dual Tensor interact with dual Tensors.
|
||||
// The two most important cases are:
|
||||
//
|
||||
// # Have:
|
||||
// # foo is a regular Tensor that is not a view
|
||||
// # bar is a dual Tensor of appropriate size (depending on cases) that is not a view
|
||||
//
|
||||
// # Case 4: Changes on the view must propagate to its base
|
||||
// view = foo[0]
|
||||
// # view is still a regular Tensor here
|
||||
// view.copy_(bar)
|
||||
// # Now both view and foo are dual Tensor with appropriate forward grad
|
||||
//
|
||||
// # Case 5: Changes on the base must propagate on all its views
|
||||
// view = foo[0]
|
||||
// # view is still a regular Tensor here
|
||||
// base.copy_(bar)
|
||||
// # Now both view and foo are dual Tensor with appropriate forward grad
|
||||
//
|
||||
// # NB there is a case 6 involving changes on a view propagating to other views
|
||||
// # but it is fully described by the two others and is skipped in this discussion.
|
||||
//
|
||||
// Case 4 is handled by set_fw_grad by properly setting the forward grad of the base if needed.
|
||||
// Case 5 is handled in fw_grad by reading the forward grad from the base if needed.
|
||||
|
||||
|
||||
namespace {
|
||||
// Check if two Tensor have the same storage offset, sizes and strides
|
||||
bool has_same_meta(const Variable& base, const Variable& other) {
|
||||
if (!base.defined() || !other.defined()) {
|
||||
return false;
|
||||
}
|
||||
if (base.storage_offset() != other.storage_offset()) {
|
||||
return false;
|
||||
}
|
||||
if (base.dim() != other.dim()) {
|
||||
return false;
|
||||
}
|
||||
for (int64_t i=0; i<base.dim(); ++i) {
|
||||
if (base.sizes()[i] != other.sizes()[i]) {
|
||||
return false;
|
||||
}
|
||||
if (base.strides()[i] != other.strides()[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Tensor new_with_same_meta(const Variable& base) {
|
||||
// We need to create a storage of the same size to be able to have the same
|
||||
// viewing behavior in all cases
|
||||
// Explicit type here to appease Windows build
|
||||
int64_t nelement_in_storage = base.storage().nbytes() / base.itemsize();
|
||||
auto new_tensor = at::zeros({nelement_in_storage}, base.options());
|
||||
auto res = new_tensor.as_strided(base.sizes(), base.strides(), base.storage_offset());
|
||||
return res;
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
// This function is will ensure that the fw_grad_ is properly a view of the base for inplace ops on
|
||||
// Tensors that do not have forward grad originally.
|
||||
void AutogradMeta::set_fw_grad(const Variable& new_grad_, const Variable& self, uint64_t level, bool is_inplace_op) {
|
||||
// Lazy initialization
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (!fw_grad_) {
|
||||
fw_grad_ = std::make_shared<ForwardGrad>();
|
||||
}
|
||||
}
|
||||
if (fw_grad_->contains(level)) {
|
||||
// Setting the forward grad again is only allowed if it is a no-op.
|
||||
// We do allow this case to simplify writing codegen for inplace ops.
|
||||
TORCH_INTERNAL_ASSERT(new_grad_.defined(), "Cannot set a forward grad that is an undefined Tensor. Use "
|
||||
"_fw_primal(level) to get a new Tensor with this forward grad unset.");
|
||||
|
||||
TORCH_INTERNAL_ASSERT(is_inplace_op, "Only inplace operations can re-set the forward grad of a Tensor that "
|
||||
"already has one.");
|
||||
|
||||
TORCH_INTERNAL_ASSERT(fw_grad_->value(level).is_same(new_grad_), "Cannot set a value of a forward grad if it "
|
||||
"already exists. Inplace operations should modify it inplace.");
|
||||
} else {
|
||||
// TODO(alband) remove this spurious version counter bump
|
||||
auto new_grad = new_grad_;
|
||||
|
||||
if (is_inplace_op && is_view_) {
|
||||
auto this_view_meta = static_cast<DifferentiableViewMeta*>(this);
|
||||
|
||||
// For inplace ops on a Tensor that does not already have a forward grad and is a view, we propagate
|
||||
// the tangent to the base and ensure that the new_grad is a view of that base's tangent.
|
||||
// This ensures that case 4 from [Forward Grad View/inplace] above works fine
|
||||
// What happens in this long if statement is:
|
||||
// - Check if the base already has a grad
|
||||
// - If not, set a new fw_grad for it full of zeros
|
||||
// - Take a view of the base's forward grad
|
||||
// - Copy the given new_grad into this view
|
||||
// - Use this view as the new new_grad
|
||||
if (this_view_meta->has_fw_view()) {
|
||||
auto view_info = this_view_meta->get_forward_view();
|
||||
auto& base = view_info.base_;
|
||||
|
||||
if (!base.fw_grad(level).defined()) {
|
||||
// Enforce same meta here to make sure that the view op below is always valid
|
||||
Tensor new_base_fw_grad;
|
||||
if (has_same_meta(new_grad, base)) {
|
||||
// TODO extend this special case to when the underlying storage of new_grad
|
||||
// can be re-used.
|
||||
new_base_fw_grad = new_grad;
|
||||
} else {
|
||||
new_base_fw_grad = new_with_same_meta(base);
|
||||
|
||||
// Update new_grad to be a view of the base
|
||||
Tensor new_fw_grad_value;
|
||||
if (view_info.has_view_fn()) {
|
||||
new_fw_grad_value = view_info.view_fn()(new_base_fw_grad);
|
||||
} else {
|
||||
new_fw_grad_value = new_base_fw_grad.as_strided(self.sizes(), self.strides(), self.storage_offset());
|
||||
}
|
||||
|
||||
new_fw_grad_value.copy_(new_grad);
|
||||
new_grad = new_fw_grad_value;
|
||||
}
|
||||
|
||||
base.set_fw_grad(new_base_fw_grad, level, /* is_inplace_op */ false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Enforce the basic layout constraint
|
||||
if (!has_same_meta(new_grad, self)) {
|
||||
Tensor new_grad_with_meta = new_with_same_meta(self);
|
||||
new_grad_with_meta.copy_(new_grad);
|
||||
new_grad = new_grad_with_meta;
|
||||
}
|
||||
|
||||
fw_grad_->set_value(new_grad, level);
|
||||
}
|
||||
}
|
||||
|
||||
const Variable& AutogradMeta::fw_grad(uint64_t level, const Variable& self) const {
|
||||
// Ensure that concurrent fw_grad() "reads" are thread safe
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
|
||||
const auto& direct_fw_grad = fw_grad_ ? fw_grad_->value(level) : ForwardGrad::undef_grad();
|
||||
|
||||
if (!direct_fw_grad.defined() && is_view_) {
|
||||
// For views that don't have a forward grad, check if their base has one that
// has been defined by an inplace operation.
// This ensures that case 5 from [Forward Grad View/inplace] above works fine
|
||||
auto const_view_meta = static_cast<const torch::autograd::DifferentiableViewMeta*>(this);
|
||||
// This is ok to do as we ONLY modify fw_grad_ and this field is properly locked in all methods
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
auto this_view_meta = const_cast<torch::autograd::DifferentiableViewMeta*>(const_view_meta);
|
||||
if (this_view_meta->has_fw_view()) {
|
||||
const auto& view_info = this_view_meta->get_forward_view();
|
||||
const auto& base = view_info.base_;
|
||||
|
||||
const auto& base_val = base.fw_grad(level);
|
||||
if (base_val.defined()) {
|
||||
// Lazy initialization of fw_grad_
|
||||
this_view_meta->fw_grad_ = std::make_shared<ForwardGrad>();
|
||||
|
||||
Variable new_val;
|
||||
if (view_info.has_view_fn()) {
|
||||
new_val = view_info.view_fn()(base_val);
|
||||
} else {
|
||||
new_val = base_val.as_strided(self.sizes(), self.strides(), self.storage_offset());
|
||||
}
|
||||
|
||||
this_view_meta->fw_grad_->set_value(new_val, level);
|
||||
return this_view_meta->fw_grad_->value(level);
|
||||
}
|
||||
}
|
||||
}
|
||||
return direct_fw_grad;
|
||||
}
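And the matching sketch for case 5, under the same hedged assumptions as the previous example: the view is created before the base receives a tangent, and fw_grad() above lazily materializes the view's tangent from the base's tangent.

```python
import torch
import torch.autograd.forward_ad as fwAD

torch._C._set_forward_AD_enabled(True)  # temporary global switch added in this PR

with fwAD.dual_level():  # assumed context manager from forward_ad.py
    foo = torch.rand(3, 4)   # regular Tensor, acts as the base
    view = foo[0]            # plain view, no tangent yet

    bar = fwAD.make_dual(torch.rand(3, 4), torch.ones(3, 4))
    foo.copy_(bar)           # in-place write on the base turns foo into a dual Tensor

    _, view_tangent = fwAD.unpack_dual(view)
    # view_tangent is all ones and is itself a view of foo's tangent
```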
|
||||
|
||||
}} // torch::autograd
|
@ -124,7 +124,7 @@ variable_list _wrap_outputs(const variable_list &input_vars,
|
||||
if (!(is_input && is_modified) && var.is_view()) {
|
||||
// NB: is_view() ==> get_autograd_meta()
|
||||
auto diff_view_meta = static_cast<DifferentiableViewMeta*>(impl::get_autograd_meta(var));
|
||||
diff_view_meta->creation_meta = CreationMeta::IN_CUSTOM_FUNCTION;
|
||||
diff_view_meta->set_creation_meta(CreationMeta::IN_CUSTOM_FUNCTION);
|
||||
}
|
||||
|
||||
if (is_differentiable) {
|
||||
@ -142,7 +142,7 @@ variable_list _wrap_outputs(const variable_list &input_vars,
|
||||
if (var.is_view()) {
|
||||
// NB: is_view() ==> get_autograd_meta()
|
||||
auto diff_view_meta = static_cast<DifferentiableViewMeta*>(impl::get_autograd_meta(var));
|
||||
diff_view_meta->creation_meta = CreationMeta::MULTI_OUTPUT_NODE;
|
||||
diff_view_meta->set_creation_meta(CreationMeta::MULTI_OUTPUT_NODE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
torch/csrc/autograd/forward_grad.cpp (new file)
@ -0,0 +1,90 @@
|
||||
#include <torch/csrc/autograd/forward_grad.h>
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
namespace {
|
||||
// See discussion in forward_grad.h for why these are global variables and not
|
||||
// thread local
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
static std::mutex all_forward_levels_mutex_;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
static uint64_t next_forward_idx_ = 0;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
static std::vector<std::shared_ptr<ForwardADLevel>> all_forward_levels_;
|
||||
|
||||
const static at::Tensor singleton_undefined_tensor;
|
||||
|
||||
// Temporary flag to disable forward mode
|
||||
// TODO(alband) remove these when perf issues are solved
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
static bool is_forward_grad_enabled = false;
|
||||
}
|
||||
|
||||
uint64_t ForwardADLevel::get_next_idx() {
|
||||
std::lock_guard<std::mutex> lock(all_forward_levels_mutex_);
|
||||
TORCH_CHECK(next_forward_idx_ == 0, "Nested forward mode AD is not supported at the moment");
|
||||
auto new_index = next_forward_idx_++;
|
||||
TORCH_INTERNAL_ASSERT(new_index == all_forward_levels_.size());
|
||||
all_forward_levels_.push_back(std::make_shared<ForwardADLevel>(new_index));
|
||||
return new_index;
|
||||
}
|
||||
|
||||
void ForwardADLevel::release_idx(uint64_t idx) {
|
||||
std::lock_guard<std::mutex> lock(all_forward_levels_mutex_);
|
||||
TORCH_CHECK(idx == all_forward_levels_.size() - 1, "Exiting a forward AD level that is not the "
"last that was created is not supported. Ensure they are released in the reverse "
"order they were created.");
|
||||
TORCH_CHECK(idx >= 0, "No forward AD level was created so you cannot exit it.");
|
||||
next_forward_idx_--;
|
||||
all_forward_levels_.pop_back();
|
||||
|
||||
}
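These two checks are what a user hits through the private bindings added further down in this diff (`torch._C._enter_dual_level` / `torch._C._exit_dual_level`). A minimal sketch of the expected behaviour, given the single-level restriction of this PR:

```python
import torch

level = torch._C._enter_dual_level()    # maps to ForwardADLevel::get_next_idx()
assert level == 0                       # only level 0 is allowed in this first version

try:
    torch._C._enter_dual_level()        # hits the "Nested forward mode AD" check above
except RuntimeError as err:
    print(err)

torch._C._exit_dual_level(level=level)  # maps to ForwardADLevel::release_idx(idx)
```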
|
||||
std::shared_ptr<ForwardADLevel> ForwardADLevel::get_by_idx(uint64_t idx) {
|
||||
std::lock_guard<std::mutex> lock(all_forward_levels_mutex_);
|
||||
TORCH_CHECK(idx < all_forward_levels_.size(), "Trying to access a forward AD level with an invalid index. "
|
||||
"This index was either not created or is already deleted.");
|
||||
return all_forward_levels_[idx];
|
||||
}
|
||||
|
||||
std::shared_ptr<ForwardADLevel> ForwardADLevel::try_get_by_idx(uint64_t idx) {
|
||||
std::lock_guard<std::mutex> lock(all_forward_levels_mutex_);
|
||||
if (idx < all_forward_levels_.size()) {
|
||||
return all_forward_levels_[idx];
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
ForwardADLevel::~ForwardADLevel() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
auto it = grads_.begin();
|
||||
while (it != grads_.end()) {
|
||||
// Warning this will lock *it mutex
|
||||
// This is ok as this function is the *only* one to call back into another class's method.
|
||||
(*it)->reset(idx_, /* update_level */ false);
|
||||
it = grads_.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
const at::Tensor& ForwardGrad::value(uint64_t level) const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
const auto& it = content_.find(level);
|
||||
return it == content_.end() ? singleton_undefined_tensor : (*it).second;
|
||||
}
|
||||
|
||||
const at::Tensor& ForwardGrad::undef_grad() {
|
||||
return singleton_undefined_tensor;
|
||||
}
|
||||
|
||||
// Temporary functions to disable forward AD
|
||||
// TODO(alband) remove these when perf issues are solved
|
||||
bool isForwardADEnabled() {
|
||||
return is_forward_grad_enabled;
|
||||
}
|
||||
|
||||
void setForwardADEnabled(bool value) {
|
||||
is_forward_grad_enabled = value;
|
||||
}
|
||||
|
||||
}} // namespace torch::autograd
|
torch/csrc/autograd/forward_grad.h (new file)
@ -0,0 +1,193 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
// [ Using ForwardGrad ]
|
||||
// ForwardGrad needs to be a shared_ptr to satisfy constraints of its inner design. But
|
||||
// this shared_ptr must be uniquely associated with the object that stores it (as of
|
||||
// writing, either AutogradMeta or SavedVariable). This object is called the "owning object"
|
||||
// in the discussions below. This owning object must call `ForwardGrad::clear()` when it
|
||||
// is destroyed to ensure that the ForwardGrad is properly de-allocated.
|
||||
|
||||
struct ForwardGrad;
|
||||
|
||||
// This file contains two classes that are used to store forward AD gradients and
|
||||
// ensure that they are scoped properly.
|
||||
// Because forward AD runs concurrently with the evaluation of the function, we need
|
||||
// a mechanism to separate different forward AD invocations and be able to compute the
|
||||
// right gradients. We model such invocations as levels here.
|
||||
// The particular scoping issue mentioned above has two main drivers:
|
||||
// - Ensure that we can conveniently use forward AD within a high level API without
|
||||
// leaking the forward AD states outside.
|
||||
// - Ensure that we can keep the level that we expose to the user API simple (an integer
|
||||
// that represents the nesting depth) while avoiding confusions when the level index
|
||||
// is re-used.
|
||||
|
||||
// The important external APIs from this file are:
|
||||
// - ForwardADLevel::get_next_idx() that can be used to enter a new level and get its index
|
||||
// - ForwardADLevel::release_idx() that can be used to exit a given level.
|
||||
// - ForwardGrad() can be used to store a given forward gradient that will handle the level
|
||||
// tracking automatically.
|
||||
|
||||
// The basic implementation strategy is as follows:
|
||||
// Every tensor has a ForwardGrad, maintaining a map from levels to tangents.
|
||||
// ForwardGrad is responsible for registering itself to the appropriate ForwardADLevel when a new
|
||||
// tangent is added to it via ForwardGrad::set_value and to un-register itself from this same level
|
||||
// if that tangent is removed via ForwardGrad::reset.
|
||||
// The ForwardADLevel is created when a new level is entered via ForwardADLevel::get_next_idx.
|
||||
// A reference to the new ForwardADLevel is stored in a global (process-wide) vector that
// ensures it can be accessed via ForwardADLevel::get_by_idx. This reference is deleted when the index is
|
||||
// released by the user when calling ForwardADLevel::release_idx.
|
||||
// When it is destructed, the ForwardADLevel is responsible for clearing all the tangents for its
|
||||
// level stored in all the ForwardGrad that registered with it.
|
||||
//
|
||||
// This process-wide level design, compared to a thread local one, allows us to use a very simple user-facing
// handle for the level (an int) while enabling cross-thread forward AD.
|
||||
// The only required synchronization for the user is when entering and exiting the levels.
|
||||
// Some discussion on alternative design is in https://github.com/pytorch/pytorch/pull/49097#discussion_r543716453
|
||||
// and can be refined in the future.
|
||||
|
||||
// Correctness of concurrency:
|
||||
// Each class uses its own lock when reading or modifying internal storages. This allows in particular
|
||||
// to safely remove tangents from ForwardGrad when the ForwardADLevel is being exited.
|
||||
// We ensure no deadlock by ensuring that a method never calls into another class's method while
|
||||
// the local class's lock is held except in one single case: calling from ForwardADLevel's destructor
|
||||
// into ForwardGrad::reset with update_level=false.
|
||||
|
||||
// The lifetime of these objects is as follows:
|
||||
// The ForwardADLevel can be in three states:
|
||||
// - Initialized: where one of its references is held by the global vector and there may be more
|
||||
// references held by temporary variables in ForwardGrad's methods.
|
||||
// - About to be destructed: where "release_idx" has been called and the only reason for the
|
||||
// ForwardADLevel not to be destructed right away is that some methods in ForwardGrad hold an
// owning reference to it. This is done so that a ForwardADLevel can never be destructed when
|
||||
// a ForwardGrad is registered with it and in the process of adding something to its internal state.
|
||||
// - Being destructed: Here the ForwardADLevel is not referenced anymore and can safely reset
// all of the ForwardGrads registered with it. Note that more than one reset can be called here (which is ok)
// but we are guaranteed that there is at least one.
|
||||
// The ForwardGrad is simpler as there is no intermediary state and no special destructor for it. The logic to
// unregister it from the different ForwardADLevels is done when the owning object (AutogradMeta or
// SavedVariable) is being destroyed.
|
||||
|
||||
// Another design we considered:
// To avoid needing ForwardGrad::clear, we considered storing weak_ptr inside the ForwardADLevel. While this
// would work, it would mean that the set inside the ForwardADLevel would only grow unless we do an
// expensive linear scan to remove all the dangling weak pointers. Hence this approach was not used.
|
||||
|
||||
// Data structures in this file are optimized for this maximum number of levels.
|
||||
// The number of levels corresponds to the degree of the gradient being
|
||||
// computed using forward AD and we don't expect more than second order gradients
|
||||
// to be common.
|
||||
#define EXPECTED_MAX_LEVEL 2
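A small sketch of what the level lifecycle above means at the Python level, assuming the `enter_dual_level`/`exit_dual_level`/`make_dual`/`unpack_dual` helpers from `forward_ad.py` in this stack (names hedged, behaviour inferred from the destructor logic described above): a tangent registered while a level is active is cleared when that level is destructed, and the index is then free to be re-used.

```python
import torch
import torch.autograd.forward_ad as fwAD

torch._C._set_forward_AD_enabled(True)  # temporary global switch added in this PR

fwAD.enter_dual_level()                 # creates ForwardADLevel 0
dual = fwAD.make_dual(torch.rand(2), torch.ones(2))
assert fwAD.unpack_dual(dual)[1] is not None  # ForwardGrad::set_value registered the tangent at level 0

fwAD.exit_dual_level()                  # ~ForwardADLevel() resets every registered ForwardGrad

fwAD.enter_dual_level()                 # index 0 is re-used for a brand new level
assert fwAD.unpack_dual(dual)[1] is None  # the tangent stored for the previous level 0 is gone
fwAD.exit_dual_level()
```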
|
||||
|
||||
struct TORCH_API ForwardADLevel {
|
||||
ForwardADLevel(uint64_t idx): idx_(idx) {}
|
||||
~ForwardADLevel();
|
||||
|
||||
static uint64_t get_next_idx();
|
||||
static void release_idx(uint64_t idx);
|
||||
static std::shared_ptr<ForwardADLevel> get_by_idx(uint64_t idx);
|
||||
static std::shared_ptr<ForwardADLevel> try_get_by_idx(uint64_t idx);
|
||||
|
||||
void erase(const std::shared_ptr<ForwardGrad>& grad) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
grads_.erase(grad);
|
||||
}
|
||||
|
||||
void insert(const std::shared_ptr<ForwardGrad>& grad) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
grads_.insert(grad);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_set<std::shared_ptr<ForwardGrad>> grads_;
|
||||
std::mutex mutex_;
|
||||
uint64_t idx_;
|
||||
|
||||
};
|
||||
|
||||
struct TORCH_API ForwardGrad : std::enable_shared_from_this<ForwardGrad> {
|
||||
|
||||
ForwardGrad() {}
|
||||
|
||||
// This function must only be called when AutogradMeta or SavedVariable is being
|
||||
// destructed as it ensures that:
|
||||
// - The only (potential) other references to this ForwardGrad are the
// different levels it is registered with
// - No other thread will try to call `set_value` or `value` ever from now on
// - Any of the ForwardADLevels that this ForwardGrad is registered with might
// call `reset` at any point during this function
|
||||
void clear() {
|
||||
c10::SmallVector<uint64_t, EXPECTED_MAX_LEVEL> levels_idx;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
for (auto& c: content_) {
|
||||
levels_idx.push_back(c.first);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto l_idx: levels_idx) {
|
||||
// Use "try" version here as another thread might have deleted this
|
||||
// level before we got here
|
||||
// This is an owning reference as we want to keep the level alive
|
||||
// until we successfully unregister ourselves
|
||||
auto level = ForwardADLevel::try_get_by_idx(l_idx);
|
||||
if (level) {
|
||||
level->erase(shared_from_this());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void set_value(const at::Tensor& value, uint64_t level) {
|
||||
// Owning reference to ensure the forward_level is not destroyed
|
||||
// while we are updating our internal state
|
||||
auto forward_level = ForwardADLevel::get_by_idx(level);
|
||||
forward_level->insert(shared_from_this());
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
content_.insert({level, value});
|
||||
}
|
||||
|
||||
// This function removes the tangent for a given level from this ForwardGrad
|
||||
// Use the update_level flag to disable notifying the level about this reset
|
||||
// This flag is most notably used by the ForwardADLevel destructor.
|
||||
void reset(uint64_t level, bool update_level=true) {
|
||||
if (update_level) {
|
||||
ForwardADLevel::get_by_idx(level)->erase(shared_from_this());
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
content_.erase(level);
|
||||
}
|
||||
|
||||
const at::Tensor& value(uint64_t level) const;
|
||||
|
||||
bool contains(uint64_t level) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return content_.count(level) > 0;
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
return content_.empty();
|
||||
}
|
||||
|
||||
static const at::Tensor& undef_grad();
|
||||
|
||||
|
||||
private:
|
||||
// TODO(albanD): replace this with a SmallVector
|
||||
std::unordered_map<uint64_t, at::Tensor> content_;
|
||||
mutable std::mutex mutex_;
|
||||
|
||||
};
|
||||
|
||||
// Temporary functions to disable forward AD
|
||||
// TODO(alband) remove these when perf issues are solved
|
||||
bool TORCH_API isForwardADEnabled();
|
||||
void TORCH_API setForwardADEnabled(bool value);
|
||||
|
||||
}} // namespace torch::autograd
|
@ -47,4 +47,8 @@ auto UndefinedGradBackward::apply(variable_list&& output_grads) -> variable_list
|
||||
return input_grads;
|
||||
}
|
||||
|
||||
auto Identity::apply(variable_list&& grads) -> variable_list {
|
||||
return std::move(grads);
|
||||
}
|
||||
|
||||
}} // namespace torch::autograd
|
||||
|
@ -83,4 +83,8 @@ struct TORCH_API GraphRoot : public Node {
|
||||
variable_list outputs;
|
||||
};
|
||||
|
||||
struct TORCH_API Identity : public Node {
|
||||
variable_list apply(variable_list&& inputs) override;
|
||||
};
|
||||
|
||||
}}
|
||||
|
@ -3,11 +3,15 @@
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <torch/csrc/Exceptions.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/autograd/autograd.h>
|
||||
#include <torch/csrc/autograd/grad_mode.h>
|
||||
#include <ATen/autocast_mode.h>
|
||||
#include <torch/csrc/autograd/profiler.h>
|
||||
#include <torch/csrc/autograd/python_function.h>
|
||||
#include <torch/csrc/autograd/function.h>
|
||||
#include <torch/csrc/autograd/utils/wrap_outputs.h>
|
||||
#include <torch/csrc/autograd/utils/python_arg_parsing.h>
|
||||
#include <torch/csrc/utils/pycfunction_helpers.h>
|
||||
|
||||
PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
|
||||
using namespace torch::autograd::profiler;
|
||||
@ -230,6 +234,26 @@ static PyObject * autocast_decrement_nesting(PyObject* _unused, PyObject *arg) {
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject * set_forward_AD_enabled(PyObject* _unused, PyObject *arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (!PyBool_Check(arg)) {
|
||||
throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
|
||||
}
|
||||
setForwardADEnabled(arg == Py_True);
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject * is_forward_AD_enabled(PyObject* _unused, PyObject *arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (isForwardADEnabled()) {
|
||||
Py_RETURN_TRUE;
|
||||
} else {
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
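The two functions above are registered as `torch._C._set_forward_AD_enabled` and `torch._C._is_forward_AD_enabled` in the methods table below. A short sketch of how they are expected to behave (the default value comes from `is_forward_grad_enabled = false` in forward_grad.cpp):

```python
import torch

# Forward AD is globally disabled by default while the perf issues are being worked on.
assert torch._C._is_forward_AD_enabled() is False

torch._C._set_forward_AD_enabled(True)
assert torch._C._is_forward_AD_enabled() is True

try:
    torch._C._set_forward_AD_enabled(1)  # rejected by the PyBool_Check above
except TypeError as err:
    print(err)                           # "enabled must be a bool (got int)"

torch._C._set_forward_AD_enabled(False)  # restore the default
```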
|
||||
|
||||
static PyObject * set_grad_enabled(PyObject* _unused, PyObject *arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (!PyBool_Check(arg)) {
|
||||
@ -270,10 +294,34 @@ static PyObject * is_anomaly_mode_enabled(PyObject* _unused, PyObject *arg) {
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject * python_enter_dual_level(PyObject* _unused, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
// It is unlikely that the depth of forward nesting will overflow int64_t so we
|
||||
// just static cast here.
|
||||
return utils::wrap(static_cast<int64_t>(forward_ad::enter_dual_level()));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject * python_exit_dual_level(PyObject* _unused, PyObject* args, PyObject* kwargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
static PythonArgParser parser({
|
||||
"exit_dual_level(int64_t level)"
|
||||
});
|
||||
|
||||
ParsedArgs<1> parsed_args;
|
||||
auto _r = parser.parse(args, kwargs, parsed_args);
|
||||
|
||||
forward_ad::exit_dual_level(_r.toInt64(0));
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
// autograd methods on torch._C
|
||||
static PyMethodDef methods[] = { // NOLINT
|
||||
{"_set_grad_enabled", set_grad_enabled, METH_O, nullptr},
|
||||
{"is_grad_enabled", is_grad_enabled, METH_NOARGS, nullptr},
|
||||
{"_set_forward_AD_enabled", set_forward_AD_enabled, METH_O, nullptr},
|
||||
{"_is_forward_AD_enabled", is_forward_AD_enabled, METH_NOARGS, nullptr},
|
||||
{"set_autocast_enabled", set_autocast_enabled, METH_O, nullptr},
|
||||
{"is_autocast_enabled", is_autocast_enabled, METH_NOARGS, nullptr},
|
||||
{"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr},
|
||||
@ -281,6 +329,8 @@ static PyMethodDef methods[] = { // NOLINT
|
||||
{"autocast_decrement_nesting", autocast_decrement_nesting, METH_NOARGS, nullptr},
|
||||
{"set_anomaly_enabled", set_anomaly_mode_enabled, METH_O, nullptr},
|
||||
{"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr},
|
||||
{"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
|
||||
{"_exit_dual_level", castPyCFunctionWithKeywords(python_exit_dual_level), METH_VARARGS | METH_KEYWORDS, nullptr},
|
||||
{nullptr, nullptr, 0, nullptr}
|
||||
};
|
||||
|
||||
|
@ -24,6 +24,12 @@ SavedVariable::SavedVariable(const Variable& variable, bool is_output, bool is_i
|
||||
// These copies are all shared_ptr copies, so slightly more expensive.
|
||||
// Do them here instead of in the init list in case data is undefined.
|
||||
data_ = variable.tensor_data();
|
||||
// TODO(albanD) This needs to be updated when moving to multiple levels
|
||||
const auto& fw_grad = variable.fw_grad(/* level */ 0);
|
||||
if (fw_grad.defined()) {
|
||||
fw_grad_ = std::make_shared<ForwardGrad>();
|
||||
fw_grad_->set_value(fw_grad, /* level */ 0);
|
||||
}
|
||||
if (variable.is_leaf()) {
|
||||
grad_accumulator_ = impl::grad_accumulator(variable);
|
||||
} else if (!is_output) {
|
||||
@ -100,6 +106,16 @@ Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
|
||||
throw std::logic_error("No grad accumulator for a saved leaf!");
|
||||
impl::set_grad_accumulator(var, grad_accumulator_);
|
||||
|
||||
// NB: var here is never a view so there is no need to do anything special
// for the case where the saved Tensor was a view. This whole argument relies
|
||||
// on the fact that the Tensor returned by this function is never
|
||||
// modified in-place.
|
||||
if (fw_grad_ && !fw_grad_->empty()) {
|
||||
// TODO(albanD) This needs to be updated when moving to multiple levels
|
||||
auto new_fw_grad = fw_grad_->value(/* level */ 0);
|
||||
var.set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ false);
|
||||
}
|
||||
|
||||
return var;
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <torch/csrc/autograd/forward_grad.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
@ -23,6 +24,12 @@ class TORCH_API SavedVariable {
|
||||
SavedVariable(const c10::optional<Variable>& variable, bool is_output, bool is_inplace_view=false);
|
||||
SavedVariable(SavedVariable&&) = default;
|
||||
SavedVariable& operator=(SavedVariable&&) = default;
|
||||
~SavedVariable() {
|
||||
if (fw_grad_) {
|
||||
// See note [ Using ForwardGrad ]
|
||||
fw_grad_->clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// Reconstructs the saved variable. Pass `saved_for` as the gradient
|
||||
/// function if constructing the `SavedVariable` with it would have caused a
|
||||
@ -40,6 +47,11 @@ class TORCH_API SavedVariable {
|
||||
private:
|
||||
at::Tensor data_;
|
||||
|
||||
// This field is used to store the forward AD gradients associated with
|
||||
// the saved Tensor. Note that this shared_ptr must never be shared with
|
||||
// either the saved Tensor or the unpacked Tensor. See note [ Using ForwardGrad ]
|
||||
std::shared_ptr<ForwardGrad> fw_grad_;
|
||||
|
||||
// The gradient function associated with this node. If has_grad_fn
|
||||
// is false, then this is a leaf node. Note that the grad_fn is not saved if
|
||||
// it would create a circular reference. In that case, the grad_fn must be
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include <ATen/core/VariableHooksInterface.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/MemoryOverlap.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
#include <list>
|
||||
@ -20,28 +21,83 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <typeinfo>
|
||||
#include <iostream>
|
||||
|
||||
namespace torch {
|
||||
namespace autograd {
|
||||
|
||||
|
||||
DifferentiableViewMeta::DifferentiableViewMeta(at::TensorImpl* self_impl, Variable base,
|
||||
c10::optional<std::function<at::Tensor(const at::Tensor&)>> view_fn,
|
||||
DifferentiableViewMeta::DifferentiableViewMeta(at::TensorImpl* self_impl,
|
||||
c10::optional<ViewInfo> backward_info,
|
||||
c10::optional<ViewInfo> forward_info,
|
||||
CreationMeta creation_meta)
|
||||
: AutogradMeta(self_impl), creation_meta(creation_meta) {
|
||||
base_ = std::move(base);
|
||||
view_fn_ = std::move(view_fn);
|
||||
TORCH_CHECK(base_.defined(), "base is undefined");
|
||||
if (base_.is_view()) {
|
||||
base_ = base_._base();
|
||||
}
|
||||
: AutogradMeta(self_impl),
|
||||
backward_info_(std::move(backward_info)),
|
||||
forward_info_(std::move(forward_info)),
|
||||
creation_meta(creation_meta) {
|
||||
is_view_ = true;
|
||||
self_impl->set_version_counter(impl::version_counter(base_));
|
||||
attr_version = self_impl->version_counter().current_version();
|
||||
if (backward_info_.has_value()) {
|
||||
self_impl->set_version_counter(impl::version_counter(backward_info_.value().base_));
|
||||
attr_version = self_impl->version_counter().current_version();
|
||||
}
|
||||
}
|
||||
|
||||
DifferentiableViewMeta::~DifferentiableViewMeta() {
|
||||
base_.reset();
|
||||
// Chain this view info with the new view op between base and tensor
|
||||
ViewInfo ViewInfo::chain(const Variable & base, const Variable & tensor,
|
||||
c10::optional<std::function<Variable(const Variable&)>> view_func) const {
|
||||
// Set `view_func` using the root base as input.
|
||||
// `view_func` is used to recover views in backward when either as_strided is not supported
|
||||
// or the view function changes the metadata which is not recorded by as_strided
|
||||
// See Note [View + Inplace update on base tensor] and [View + Inplace update on view tensor]
|
||||
// for more details on how we use this function in backward.
|
||||
if (view_func.has_value()) {
|
||||
auto fn = view_func.value();
|
||||
// both current_view and its parent have a view_func
|
||||
if (view_fn_.has_value()) {
|
||||
auto prev_fn = view_fn_.value();
|
||||
view_func = [=](const at::Tensor& root_base) {
|
||||
auto temp = prev_fn(root_base);
|
||||
return fn(temp);
|
||||
};
|
||||
} else {
|
||||
// current_view has a view_func but its parent doesn't have one
|
||||
if (base.unsafeGetTensorImpl()->support_as_strided()) {
|
||||
auto size = base.sizes().vec();
|
||||
auto stride = base.strides().vec();
|
||||
auto storage_offset = base.storage_offset();
|
||||
view_func = [=](const at::Tensor& root_base) {
|
||||
auto temp = root_base.as_strided(size, stride, storage_offset);
|
||||
return fn(temp);
|
||||
};
|
||||
} else {
|
||||
// When base is a view but doesn't carry a view_fn in DifferentiableViewMeta, it's
|
||||
// a view that doesn't support inplace update, e.g. unbind.
|
||||
// In this case we should throw an error when inplace update happens in **forward**.
|
||||
// One would naturally think the following function will first be called in the backward pass.
// But the first call site is actually in the **forward** pass, when we refresh `grad_fn`
// triggered by an inplace update.
// Search Note [View + Inplace update for view tensor] for the call site.
|
||||
view_func = [=](const at::Tensor& root_base) {
|
||||
TORCH_CHECK(false, "This view is the output of a function that returns multiple views."
|
||||
"Such functions do not allow the output views to be modified inplace."
|
||||
"You should replace the inplace operation by an out-of-place one");
|
||||
return root_base;
|
||||
};
|
||||
}
|
||||
}
|
||||
} else if(view_fn_.has_value()) {
|
||||
// if current_view doesn't have a view_func but its parent has one
|
||||
auto prev_view_fn = view_fn_.value();
|
||||
auto size = tensor.sizes().vec();
|
||||
auto stride = tensor.strides().vec();
|
||||
auto storage_offset = tensor.storage_offset();
|
||||
view_func = [=](const at::Tensor& root_base) {
|
||||
auto temp = prev_view_fn(root_base);
|
||||
return temp.as_strided(size, stride, storage_offset);
|
||||
};
|
||||
}
|
||||
|
||||
return ViewInfo(base_, view_func);
|
||||
}
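The branch structure of `ViewInfo::chain` can be summarized with a purely illustrative Python sketch (not the actual implementation; `parent_fn` stands for this instance's `view_fn_`, `view_fn` for the `view_func` argument, and the `*_meta` tuples for (sizes, strides, storage_offset)):

```python
def chain_view_fn(parent_fn, view_fn, base_meta, tensor_meta, base_supports_as_strided):
    """Return a function mapping the root base to the new view, mirroring ViewInfo::chain."""
    if view_fn is not None:
        if parent_fn is not None:
            # Both the parent and the current view carry a view function: compose them.
            return lambda root: view_fn(parent_fn(root))
        if base_supports_as_strided:
            # Only the current view has one: recover the parent with as_strided first.
            return lambda root: view_fn(root.as_strided(*base_meta))

        # The parent is a view that cannot be replayed (e.g. an unbind output):
        # the first call, triggered by an in-place update, must raise.
        def error_fn(root):
            raise RuntimeError("Output of a multi-view function cannot be modified inplace.")
        return error_fn
    if parent_fn is not None:
        # Only the parent has a view function: replay it, then as_strided into the new view.
        return lambda root: parent_fn(root).as_strided(*tensor_meta)
    # Neither has one: nothing is stored and the default as_strided recovery is used.
    return None
```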
|
||||
|
||||
namespace {
|
||||
@ -81,21 +137,23 @@ namespace impl {
|
||||
auto diff_view_meta = static_cast<DifferentiableViewMeta*>(get_autograd_meta(self));
|
||||
|
||||
// See NOTE [ View + Inplace detection ]
|
||||
if (diff_view_meta->creation_meta != CreationMeta::MULTI_OUTPUT_SAFE) {
|
||||
auto creation_meta = diff_view_meta->get_creation_meta();
|
||||
if (creation_meta != CreationMeta::MULTI_OUTPUT_SAFE) {
|
||||
// Do not use handle_view_on_rebase here as check_inplace should have been called before this
|
||||
// and either throw an error or clear the warning
|
||||
// Temporary error message as a full fix is too risky for now
|
||||
// Should be an internal assert again
|
||||
TORCH_INTERNAL_ASSERT(diff_view_meta->creation_meta == CreationMeta::DEFAULT);
|
||||
TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::DEFAULT);
|
||||
TORCH_INTERNAL_ASSERT(gradient_edge.input_nr == 0);
|
||||
TORCH_INTERNAL_ASSERT(gradient_edge.function);
|
||||
TORCH_CHECK(
|
||||
gradient_edge.function->num_inputs() == 1,
|
||||
"Functions which modify views in-place must return a single Variable");
|
||||
auto view_info = diff_view_meta->get_backward_view();
|
||||
diff_view_meta->output_nr_ = gradient_edge.input_nr;
|
||||
auto copy_slices = std::make_shared<CopySlices>(
|
||||
diff_view_meta->base_, at::TensorGeometry(self), diff_view_meta->view_fn_, std::move(gradient_edge.function));
|
||||
set_gradient_edge(diff_view_meta->base_, {std::move(copy_slices), 0});
|
||||
view_info.base_, at::TensorGeometry(self), view_info.view_fn_, std::move(gradient_edge.function));
|
||||
set_gradient_edge(view_info.base_, {std::move(copy_slices), 0});
|
||||
self.grad_fn(); // trigger an update to the view's grad_fn
|
||||
return;
|
||||
}
|
||||
@ -181,7 +239,7 @@ namespace impl {
|
||||
if (self.is_view()) {
|
||||
// NB: is_view() ==> get_autograd_meta()
|
||||
auto diff_view_meta = static_cast<torch::autograd::DifferentiableViewMeta*>(meta);
|
||||
diff_view_meta->attr_version = self._version();
|
||||
diff_view_meta->set_attr_version(self._version());
|
||||
}
|
||||
}
|
||||
|
||||
@ -298,12 +356,14 @@ Tensor VariableHooks::tensor_data(const Tensor& self) const {
|
||||
return at::Tensor(self_impl_copy);
|
||||
}
|
||||
|
||||
// View Variables
|
||||
// Backward View Variables
|
||||
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
bool VariableHooks::is_view(const Tensor& self) const {
|
||||
if (torch::autograd::impl::get_autograd_meta(self)) {
|
||||
return torch::autograd::impl::get_autograd_meta(self)->is_view_;
|
||||
auto meta = torch::autograd::impl::get_autograd_meta(self);
|
||||
if (meta && meta->is_view_) {
|
||||
auto diff_view_meta = static_cast<torch::autograd::DifferentiableViewMeta*>(meta);
|
||||
return diff_view_meta->has_bw_view();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
@ -313,9 +373,10 @@ const Tensor& VariableHooks::base(const Tensor& self) const {
|
||||
if (self.is_view()) {
|
||||
// is_view() implies get_autograd_meta()
|
||||
auto diff_view_meta = static_cast<torch::autograd::DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(self));
|
||||
return diff_view_meta->base_;
|
||||
TORCH_CHECK(diff_view_meta->has_bw_view(), "Can't get base of non-backward view Tensor");
|
||||
return diff_view_meta->get_backward_view().base_;
|
||||
} else {
|
||||
throw std::runtime_error("Can't get base of non-view Variable");
|
||||
throw std::runtime_error("Can't get base of non-view Tensor");
|
||||
}
|
||||
}
|
||||
|
||||
@ -342,13 +403,14 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(const Tenso
|
||||
auto diff_view_meta = static_cast<torch::autograd::DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(self));
|
||||
|
||||
// See NOTE [ View + Inplace detection ]
|
||||
if (diff_view_meta->creation_meta != CreationMeta::MULTI_OUTPUT_SAFE) {
|
||||
if (diff_view_meta->get_creation_meta() != CreationMeta::MULTI_OUTPUT_SAFE) {
|
||||
std::lock_guard<std::mutex> lock(diff_view_meta->mutex_);
|
||||
if (!diff_view_meta->grad_fn_ && !diff_view_meta->base_.requires_grad()) {
|
||||
auto view_info = diff_view_meta->get_backward_view();
|
||||
if (!diff_view_meta->grad_fn_ && !view_info.base_.requires_grad()) {
|
||||
return diff_view_meta->grad_fn_;
|
||||
}
|
||||
auto current_version = self._version();
|
||||
if (diff_view_meta->attr_version != current_version) {
|
||||
if (diff_view_meta->get_attr_version() != current_version) {
|
||||
// This is an indirect rebase_history due to another view or the base being modified inplace
|
||||
handle_view_on_rebase(diff_view_meta, /* indirect */ true);
|
||||
TORCH_INTERNAL_ASSERT(diff_view_meta->output_nr_ == 0);
|
||||
@ -377,24 +439,24 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(const Tenso
|
||||
//
|
||||
// TODO: Potentially the following logic can be replaced by special logic in VariableType_x.cpp
|
||||
// that would provide a way to recreate the grad_fn chain.
|
||||
if (diff_view_meta->has_view_fn()) {
|
||||
auto view_fn = diff_view_meta->view_fn();
|
||||
auto diff_view = view_fn(diff_view_meta->base_);
|
||||
if (view_info.has_view_fn()) {
|
||||
auto view_fn = view_info.view_fn();
|
||||
auto diff_view = view_fn(view_info.base_);
|
||||
diff_view_meta->grad_fn_ = diff_view.grad_fn();
|
||||
} else {
|
||||
auto fn = std::make_shared<torch::autograd::generated::AsStridedBackward>();
|
||||
fn->self_geometry = at::TensorGeometry(diff_view_meta->base_);
|
||||
fn->self_geometry = at::TensorGeometry(view_info.base_);
|
||||
fn->size = self.sizes().vec();
|
||||
fn->stride = self.strides().vec();
|
||||
fn->storage_offset = self.storage_offset();
|
||||
fn->set_next_edges(torch::autograd::collect_next_edges(diff_view_meta->base_));
|
||||
fn->set_next_edges(torch::autograd::collect_next_edges(view_info.base_));
|
||||
fn->add_input_metadata(
|
||||
diff_view_meta->base_.options(),
|
||||
view_info.base_.options(),
|
||||
self.sizes(), // Note: sizes(), not base_.sizes(), is intentional
|
||||
diff_view_meta->base_.device());
|
||||
view_info.base_.device());
|
||||
diff_view_meta->grad_fn_ = std::move(fn);
|
||||
}
|
||||
diff_view_meta->attr_version = current_version;
|
||||
diff_view_meta->set_attr_version(current_version);
|
||||
}
|
||||
return diff_view_meta->grad_fn_;
|
||||
}
|
||||
@ -429,7 +491,8 @@ unsigned VariableHooks::_register_hook(const Tensor& self, std::function<Tensor(
|
||||
|
||||
void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect) {
|
||||
/// See NOTE [ View + Inplace detection ] for justification of the logic below
|
||||
if (diff_view_meta->creation_meta != CreationMeta::DEFAULT) {
|
||||
auto creation_meta = diff_view_meta->get_creation_meta();
|
||||
if (creation_meta != CreationMeta::DEFAULT) {
|
||||
auto grad_fn = diff_view_meta->grad_fn_.get();
|
||||
std::string msg;
|
||||
std::string modified_obj;
|
||||
@ -446,24 +509,24 @@ void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect
|
||||
msg = c10::str("A view was created in no_grad mode and ", modified_obj, " modified inplace with grad mode enabled.");
|
||||
}
|
||||
|
||||
if (diff_view_meta->creation_meta == CreationMeta::MULTI_OUTPUT_NODE) {
|
||||
if (creation_meta == CreationMeta::MULTI_OUTPUT_NODE) {
|
||||
TORCH_CHECK(false, msg, " This view is the output of a function that returns multiple views. Such functions do not"
|
||||
" allow the output views to be modified inplace. You should replace the inplace operation by an"
|
||||
" out-of-place one.");
|
||||
} else {
|
||||
if (diff_view_meta->creation_meta == CreationMeta::NO_GRAD_MODE) {
|
||||
if (creation_meta == CreationMeta::NO_GRAD_MODE) {
|
||||
TORCH_INTERNAL_ASSERT(!grad_fn);
|
||||
msg = c10::str(msg, " Given that this use case is ambiguous and error-prone, it is deprecated and will be forbidden"
|
||||
" starting 1.6 (see https://github.com/pytorch/pytorch/pull/32839 for more details about this). You"
|
||||
" can clarify your code and remove this warning by moving both the view and the inplace either both"
|
||||
" inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want"
|
||||
" the inplace to be tracked).");
|
||||
} else if (diff_view_meta->creation_meta == CreationMeta::IN_CUSTOM_FUNCTION) {
|
||||
} else if (creation_meta == CreationMeta::IN_CUSTOM_FUNCTION) {
|
||||
msg = c10::str(msg, " This view was created inside a custom Function (or because an input was returned as-is) and the"
|
||||
" autograd logic to handle view+inplace would override the custom backward associated with the custom"
|
||||
" Function, leading to incorrect gradients. This behavior is deprecated and will be forbidden starting"
|
||||
" version 1.6. You can remove this warning by cloning the output of the custom Function.");
|
||||
} else if (diff_view_meta->creation_meta == CreationMeta::MULTI_OUTPUT_SAFE) {
|
||||
} else if (creation_meta == CreationMeta::MULTI_OUTPUT_SAFE) {
|
||||
msg = c10::str(msg, " This view is an output of a function that "
|
||||
"returns multiple views. Inplace operators on such "
|
||||
"views are being deprecated and will be forbidden "
|
||||
@ -487,8 +550,10 @@ void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect
|
||||
// We warn only once per view
|
||||
// Note that if a Tensor is modified inplace from two threads at the same time, this is not thread safe and can warn
|
||||
// multiple time. This is ok as it should be a rare event.
|
||||
diff_view_meta->creation_meta = CreationMeta::DEFAULT;
|
||||
diff_view_meta->set_creation_meta(CreationMeta::DEFAULT);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
}} // namespace torch::autograd
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <torch/csrc/autograd/edge.h>
|
||||
#include <torch/csrc/autograd/function_hook.h>
|
||||
#include <torch/csrc/autograd/cpp_hook.h>
|
||||
#include <torch/csrc/autograd/forward_grad.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/NamedTensorUtils.h>
|
||||
@ -193,6 +194,17 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
|
||||
std::shared_ptr<Node> grad_fn_;
|
||||
std::weak_ptr<Node> grad_accumulator_;
|
||||
|
||||
// This field is used to store all the forward AD gradients
|
||||
// associated with this AutogradMeta (and the Tensor it corresponds to)
|
||||
// There is a semantic 1:1 correspondence between AutogradMeta and
|
||||
// ForwardGrad but:
|
||||
// - This field is lazily populated.
|
||||
// - This field is a shared_ptr but it must never be
|
||||
// shared by multiple Tensors. See Note [ Using ForwardGrad ]
|
||||
// Any transition from not_initialized to initialized
|
||||
// must be protected by mutex_
|
||||
std::shared_ptr<ForwardGrad> fw_grad_;
|
||||
|
||||
std::vector<std::shared_ptr<FunctionPreHook>> hooks_;
|
||||
std::shared_ptr<hooks_list> cpp_hooks_list;
|
||||
|
||||
@ -211,9 +223,11 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
|
||||
uint32_t output_nr_;
|
||||
|
||||
// Mutex to ensure that concurrent read operations that modify internal
|
||||
// state are still thread-safe. Used by grad_fn() and
|
||||
// grad_accumulator().
|
||||
std::mutex mutex_;
|
||||
// state are still thread-safe. Used by grad_fn(), grad_accumulator(),
|
||||
// fw_grad() and set_fw_grad()
|
||||
// This is mutable because we need to be able to acquire this from const
|
||||
// version of this class for the functions above
|
||||
mutable std::mutex mutex_;
|
||||
|
||||
/// Sets the `requires_grad` property of `Variable`. This should be true for
|
||||
/// leaf variables that want to accumulate gradients, and false for all other
|
||||
@ -238,6 +252,10 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
|
||||
return grad_;
|
||||
}
|
||||
|
||||
const Variable& fw_grad(uint64_t level, const Variable& self) const override;
|
||||
|
||||
void set_fw_grad(const Variable& new_grad, const Variable& self, uint64_t level, bool is_inplace_op) override;
|
||||
|
||||
AutogradMeta(at::TensorImpl* self_impl = nullptr, bool requires_grad = false, Edge gradient_edge = Edge() ) {
|
||||
grad_fn_ = std::move(gradient_edge.function);
|
||||
requires_grad_ = false;
|
||||
@ -254,6 +272,55 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
|
||||
!grad_fn_ || !requires_grad_,
|
||||
"requires_grad should be false if grad_fn is set");
|
||||
}
|
||||
|
||||
~AutogradMeta() override {
|
||||
// If AutogradMeta is being destroyed, it means that there is no other reference to its
|
||||
// corresponding Tensor. It implies that no other thread can be using this object and so there is
|
||||
// no need to lock mutex_ here to guard the check if fw_grad_ is populated.
|
||||
if (fw_grad_) {
|
||||
// See note [ Using ForwardGrad ]
|
||||
fw_grad_->clear();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct TORCH_API ViewInfo {
|
||||
/// The base `Variable`
|
||||
/// If this ViewInfo represents a forward (respectively backward) AD gradient,
|
||||
/// then this Tensor cannot be a forward (respectively backward) view.
|
||||
Variable base_;
|
||||
|
||||
/// By default we use as_strided to recover views which is more efficient.
|
||||
/// view_fn is only saved when as_strided is not supported.
|
||||
/// If view_fn has value, we use it to recover views in backward.
|
||||
c10::optional<std::function<Variable(const Variable&)>> view_fn_;
|
||||
|
||||
/// Accessors for the view function
|
||||
bool has_view_fn() const {
|
||||
return view_fn_.has_value();
|
||||
}
|
||||
|
||||
std::function<Variable(const Variable&)> view_fn() const {
|
||||
TORCH_CHECK(has_view_fn(), "Can only access the view function if it exists.");
|
||||
return view_fn_.value();
|
||||
}
|
||||
|
||||
/// The chain function can be used to build a new ViewInfo for a differentiable view
|
||||
/// function. It will return a new view info that accurately represents how "tensor" is
|
||||
/// a view of this instance's "base_".
|
||||
/// The "base" and "tensor" are respectively the input and output of the differentiable
|
||||
/// view function that happened. They are required to properly set the optional
|
||||
/// view_fn_ when it is not provided.
|
||||
/// The "view_func", if provided, should be a function that allows to re-do the view
|
||||
/// between "base" and "tensor".
|
||||
ViewInfo chain(const Variable & base, const Variable & tensor,
|
||||
c10::optional<std::function<Variable(const Variable&)>> view_func=c10::nullopt) const;
|
||||
|
||||
ViewInfo(Variable base, c10::optional<std::function<Variable(const Variable&)>> view_fn) :
|
||||
base_(std::move(base)),
|
||||
view_fn_(std::move(view_fn)) {
|
||||
TORCH_CHECK(base_.defined(), "base is undefined");
|
||||
}
|
||||
};
|
||||
|
||||
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@ -274,6 +341,27 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
|
||||
///
|
||||
/// Differentiable Views
|
||||
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
/// This class allows tracking both forward and backward AD differentiable views.
/// These views can have different bases as the set of functions that are
/// non-differentiable views for forward and backward mode AD are not the same.
|
||||
///
|
||||
/// Most functions are either both forward and backward differentiable views (for
/// example: view, select, narrow, transpose, etc) or neither forward nor
/// backward differentiable views (for example: indices, values, eq, lt, etc).
/// But there are also functions that are forward but not backward differentiable
/// views (only detach for now) and functions that are backward but not forward
/// differentiable views (only make_dual and unpack_dual for now).
|
||||
///
|
||||
/// A concrete example of two views with different bases is as follows:
|
||||
///
|
||||
/// # Have:
|
||||
/// # dual is a dual Tensor that is neither a forward nor a backward view
|
||||
/// detached_dual = dual.detach()
|
||||
/// view = detached_dual.view_as(dual)
|
||||
/// # The forward base of view is dual
|
||||
/// # The backward base of view is detached_dual
|
||||
///
|
||||
/// - Backward Mode View
|
||||
/// Differentiable views are the view variables where you want gradients to flow
|
||||
/// back to the base variables. Out-of-place operations on views are quite
|
||||
/// straightforward, but in-place ones are very tricky. Even if the base
|
||||
@ -300,6 +388,34 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
|
||||
/// var[1] filled with all ones and
|
||||
/// zeros everywhere else
|
||||
///
|
||||
/// - Forward Mode View
|
||||
/// Forward differentiable views follow the same semantics as backward ones but
/// show up differently as they are computed along with the forward evaluation.
/// The hard examples above thus look very similar:
|
||||
///
|
||||
/// (1) in-place operation on view, e.g.,
|
||||
///
|
||||
/// # Have:
|
||||
/// # base is a regular Tensor
|
||||
/// # var is a dual Tensor whose tangent is all ones
|
||||
/// base[1] = var # i.e., base[1].copy_(var)
|
||||
/// # Now, base is a dual Tensor
|
||||
/// _, fw_grad = fwAD.unpack_dual(base) <- fw_grad should be a tensor with
|
||||
/// fw_grad[1] filled with all ones and
|
||||
/// zeros everywhere else
|
||||
///
|
||||
/// (2) in-place operation on base after view is created, e.g.,
|
||||
///
|
||||
/// # Have:
|
||||
/// # base is a regular Tensor
|
||||
/// # var is a dual Tensor whose tangent is all ones
|
||||
/// view = base[1]
|
||||
/// base.copy_(var)
|
||||
/// _, fw_grad = fwAD.unpack_dual(view) <- fw_grad should be an all ones tensor
|
||||
///
|
||||
/// See Note [Forward Grad View/inplace] for more details on how we handle these hard cases.
|
||||
///
|
||||
///
|
||||
/// DifferentiableViewMeta is created to support gradient tracking of
|
||||
/// such **in-place** operations. In particular,
|
||||
/// + if an in-place op is done on base, the grad_fn field of the view may
|
||||
@ -392,37 +508,66 @@ enum class CreationMeta: uint8_t { DEFAULT, IN_CUSTOM_FUNCTION, MULTI_OUTPUT_NOD
|
||||
TORCH_API void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect=false);
|
||||
|
||||
struct TORCH_API DifferentiableViewMeta : public AutogradMeta {
|
||||
/// The base `Variable` (never a view).
|
||||
Variable base_;
|
||||
private:
|
||||
/// Information about the views
|
||||
c10::optional<ViewInfo> backward_info_;
|
||||
c10::optional<ViewInfo> forward_info_;
|
||||
|
||||
/// The two following fields are extra information that we track to ensure that
|
||||
/// any operation on this backward view is valid.
|
||||
|
||||
/// The value of the version_counter at the time grad_fn was created. The
|
||||
/// grad_fn field is stale if attr_version !=
|
||||
/// version_counter.current_version().
|
||||
/// grad_fn field is stale if attr_version != version_counter.current_version().
|
||||
uint32_t attr_version;
|
||||
|
||||
/// By default we use as_strided to recover views which is more efficient.
|
||||
/// view_fn is only saved when as_strided is not supported.
|
||||
/// If view_fn has value, we use it to recover views in backward.
|
||||
c10::optional<std::function<at::Tensor(const at::Tensor&)>> view_fn_;
|
||||
|
||||
CreationMeta creation_meta;
|
||||
|
||||
public:
|
||||
/// requires_grad is a backward AD field so we only use the view specific logic
|
||||
/// for backward differentiable views
|
||||
bool requires_grad() const override {
|
||||
return requires_grad_ || grad_fn_ || (is_view_ && base_.requires_grad());
|
||||
return requires_grad_ || grad_fn_ || (has_bw_view() && get_backward_view().base_.requires_grad());
|
||||
}
|
||||
|
||||
bool has_view_fn() const {
|
||||
return view_fn_.has_value();
|
||||
bool has_bw_view() const {
|
||||
return backward_info_.has_value();
|
||||
}
|
||||
|
||||
std::function<at::Tensor(const at::Tensor&)> view_fn() const {
|
||||
TORCH_CHECK(has_view_fn(), "view_fn is not set.");
|
||||
return view_fn_.value();
|
||||
const ViewInfo& get_backward_view() const {
|
||||
TORCH_CHECK(has_bw_view(), "backward view info can only exist for backward views.");
|
||||
return backward_info_.value();
|
||||
}
|
||||
|
||||
DifferentiableViewMeta(at::TensorImpl* self_impl, Variable base, c10::optional<std::function<at::Tensor(const at::Tensor&)>> view_fn,
|
||||
CreationMeta creation_meta=CreationMeta::DEFAULT);
|
||||
~DifferentiableViewMeta();
|
||||
uint32_t get_attr_version() const {
|
||||
TORCH_CHECK(has_bw_view(), "attr_version can only exist for backward views.");
|
||||
return attr_version;
|
||||
}
|
||||
|
||||
void set_attr_version(uint32_t new_attr_version) {
|
||||
TORCH_CHECK(has_bw_view(), "attr_version can only exist for backward views.");
|
||||
attr_version = new_attr_version;
|
||||
}
|
||||
|
||||
CreationMeta get_creation_meta() const {
|
||||
TORCH_CHECK(has_bw_view(), "creation_meta can only exist for backward views.");
|
||||
return creation_meta;
|
||||
}
|
||||
|
||||
void set_creation_meta(CreationMeta new_creation_meta) {
|
||||
TORCH_CHECK(has_bw_view(), "creation_meta can only exist for backward views.");
|
||||
creation_meta = new_creation_meta;
|
||||
}
|
||||
|
||||
bool has_fw_view() const {
|
||||
return forward_info_.has_value();
|
||||
}
|
||||
|
||||
const ViewInfo& get_forward_view() const {
|
||||
TORCH_CHECK(has_fw_view(), "forward view info can only exist for forward views.");
|
||||
return forward_info_.value();
|
||||
}
|
||||
|
||||
DifferentiableViewMeta(at::TensorImpl* self_impl, c10::optional<ViewInfo> backward_info,
|
||||
c10::optional<ViewInfo> forward_info, CreationMeta creation_meta=CreationMeta::DEFAULT);
|
||||
};
|
||||
|
||||
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@ -448,10 +593,11 @@ struct TORCH_API DifferentiableViewMeta : public AutogradMeta {
|
||||
// See NOTE [ Autograd View Variables ] for details.
|
||||
// Differentiable view. Track history with DifferentiableViewMeta.
|
||||
inline Variable make_variable_differentiable_view(
|
||||
Variable base,
|
||||
const at::Tensor& data,
|
||||
c10::optional<ViewInfo> backward_info,
|
||||
c10::optional<ViewInfo> forward_info,
|
||||
CreationMeta creation_meta,
|
||||
c10::optional<std::function<at::Tensor(const at::Tensor&)>> view_func = c10::nullopt) {
|
||||
bool allow_tensor_metadata_change = true) {
|
||||
if (data.defined()) {
|
||||
// If we already did a TensorImpl allocation for data, just reuse it.
|
||||
// Otherwise(e.g tensor.swapdim(0, 0) when we return the same tensor as input),
|
||||
@ -461,14 +607,16 @@ inline Variable make_variable_differentiable_view(
|
||||
if (data.getIntrusivePtr().unique() && data.getIntrusivePtr()->unique_version()) {
|
||||
at::TensorImpl* data_impl = data.unsafeGetTensorImpl();
|
||||
data_impl->set_autograd_meta(std::make_unique<DifferentiableViewMeta>(
|
||||
data_impl, std::move(base), std::move(view_func), creation_meta));
|
||||
data_impl, std::move(backward_info), std::move(forward_info),
|
||||
creation_meta));
|
||||
return data;
|
||||
} else {
|
||||
c10::intrusive_ptr<at::TensorImpl> data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
|
||||
/*version_counter=*/0,
|
||||
/*allow_tensor_metadata_change=*/true);
|
||||
data_impl_copy->set_autograd_meta(std::make_unique<DifferentiableViewMeta>(
|
||||
data_impl_copy.get(), std::move(base), std::move(view_func), creation_meta));
|
||||
data_impl_copy.get(), std::move(backward_info), std::move(forward_info),
|
||||
creation_meta));
|
||||
return Variable(data_impl_copy);
|
||||
}
|
||||
}
|
||||
|
@ -181,6 +181,8 @@ def get_ignored_functions() -> Set[Callable]:
|
||||
torch.is_deterministic,
|
||||
torch.set_deterministic,
|
||||
torch.unify_type_list,
|
||||
torch.make_dual,
|
||||
torch.unpack_dual,
|
||||
Tensor.__delitem__,
|
||||
Tensor.__dir__,
|
||||
Tensor.__getattribute__,
|
||||