Remove Variable::Impl and DifferentiableViewImpl (#17072)
Summary: As part of the Variable/Tensor merge work (https://github.com/pytorch/pytorch/issues/13638), this PR makes the following changes:

1. Remove the `Variable::Impl` class and the `DifferentiableViewImpl` class.
2. Change all `Variable.data()` call sites to either use `Variable` directly, or use `Variable.tensor_data()`.
3. Remove the `Variable.data()` API.
4. Add `Variable.variable_data()`, which matches `tensor.data` in the Python API: it creates a new `Variable` that shares the same storage and tensor metadata with the original `Variable`, but with a completely new autograd history.

After this PR, a Variable no longer wraps a Tensor internally; both Variable and Tensor use the same TensorImpl class as their `impl_`. The only difference is that a Variable always has AutogradMeta in its TensorImpl, while a Tensor doesn't.

**Note that this PR is BC-breaking in the following use cases:**

**Use Case 1:** Previously, `x.data = y` worked even if `x` and `y` were of different TensorImpl types (e.g. `x` is a CPU dense tensor whose impl is of type TensorImpl, while `y` is a CPU sparse tensor whose impl is of type SparseTensorImpl). After this PR, `x.data = y` no longer works if `x` and `y` are of different TensorImpl types, because the underlying implementation `variable.set_data(tensor)` no longer works if `variable` and `tensor` have different TensorImpl types.

**Use Case 2:** If a tensor `x`'s `grad` is sparse, accumulating dense gradients to `x` will change the tensor that `x.grad` points to. This is better illustrated with the following example:

```python
params = torch.tensor([1.5, 1.5]).requires_grad_()
with torch.no_grad():
    # Change gradient to a sparse tensor
    params.grad = torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.]))

grad_saved = params.grad
params.backward(torch.tensor([1.5, 1.5]))
assert id(grad_saved) == id(params.grad)  # This will fail after this PR
```

The assertion in the last line will fail after this PR, because adding dense gradients to a sparse gradient changes the `params.grad` tensor reference.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17072
Differential Revision: D14075257
Pulled By: yf225
fbshipit-source-id: 0e681df641270dea586042dd26db59f2e76b5957
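To make Use Case 1 concrete, here is a minimal Python sketch of the new failure mode. It mirrors the `test_set_data_tensorimpl_type` tests added in this PR (the error-message substring is taken from those tests); the `dtype=torch.long` on the indices is a small adaptation for clarity:

```python
import torch

x = torch.randn(1, 2)  # dense tensor; its impl is of type TensorImpl
# sparse tensor; its impl is of type SparseTensorImpl
x_s = torch.sparse_coo_tensor(torch.zeros([1, 1], dtype=torch.long), torch.ones([1]))

try:
    x.data = x_s  # worked before this PR; rejected after it
except RuntimeError as e:
    assert 'different types of TensorImpl' in str(e)
```

And here is a minimal sketch of the `variable_data()` semantics described in item 4, as exposed through Python's `tensor.data` (assuming standard in-place semantics for tensors that share storage):

```python
a = torch.tensor([1., 2.], requires_grad=True)
b = a.data                 # shares storage and tensor metadata with `a`
assert b.requires_grad is False and b.grad_fn is None  # completely new autograd history
b[0] = 42.                 # in-place writes through `b` are visible through `a`
assert a[0].item() == 42.
```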
Committed by: Facebook GitHub Bot
Parent: f93e0619f3
Commit: 8cde4c4d22
@@ -112,9 +112,12 @@ bool Context::setFlushDenormal(bool on) {
   return at::cpu::set_flush_denormal(on);
 }
 
+// NOTE: We also check `at::NonVariableTypeMode`, and if it's enabled we always
+// return non-Variable type in this function.
+// See NOTE [ Treating Variables as non-Variables in type dispatch ]
 TypeExtendedInterface& getType(TensorOptions options) {
   return globalContext().getType(
-      options.backend(), typeMetaToScalarType(options.dtype()), options.is_variable());
+      options.backend(), typeMetaToScalarType(options.dtype()), options.is_variable() && !at::NonVariableTypeMode::is_enabled());
 }
 
 // NOTE: We also check `at::NonVariableTypeMode`, and if it's enabled we always
@@ -77,41 +77,66 @@ struct CAFFE2_API OpaqueTensorImpl : public TensorImpl {
     AT_ERROR("opaque tensors do not have storage");
   }
 
-  // NOTE: `shallow_copy_and_detach()` does not copy the following TensorImpl fields:
-  // 1. the AutogradMeta pointer, because it is unique for each Variable.
-  // 2. the version counter, because it is set to the passed in `version_counter`.
-  // See NOTE [ Version Counter Sharing ] for details.
-  //
-  // NOTE: `allow_tensor_metadata_change` determines whether the TensorImpl shallow-copy
-  // allows changes to its metadata (e.g. sizes / strides / storage / storage_offset).
-  // See NOTE [ Metadata Change for a Detached Tensor ] for details.
-  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
-      const c10::VariableVersion& version_counter,
-      bool allow_tensor_metadata_change) const override {
-    //AT_ASSERT(false);
-    auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
-        type_id(), dtype(), device(), opaque_handle_, sizes_);
-    // TensorImpl general fields
-    // Note that some of these fields are not used in opaque tensor code,
-    // and we copy them here only for completeness.
-    impl->sizes_ = sizes_;
-    impl->strides_ = strides_;
-    impl->storage_offset_ = storage_offset_;
-    impl->is_contiguous_ = is_contiguous_;
-    impl->is_wrapped_number_ = is_wrapped_number_;
-    impl->reserved_ = reserved_;
-    impl->set_version_counter(version_counter);
-    impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
+  /**
+   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
+   *
+   * For usage of `version_counter` and `allow_tensor_metadata_change`,
+   * see NOTE [ TensorImpl Shallow-Copying ].
+   */
+  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
+      const c10::VariableVersion& version_counter,
+      bool allow_tensor_metadata_change) const override {
+    auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
+        type_id(), dtype(), device(), opaque_handle_, sizes_);
+    copy_tensor_data(
+      /*src_impl=*/this,
+      /*dest_impl=*/impl.get(),
+      /*version_counter=*/version_counter,
+      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
+    impl->refresh_numel();
+    return impl;
+  }
+
+  /**
+   * Shallow-copies data from another TensorImpl into this TensorImpl.
+   *
+   * For why this function doesn't check this TensorImpl's `allow_tensor_metadata_change_`,
+   * see NOTE [ TensorImpl Shallow-Copying ].
+   */
+  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
+    AT_ASSERT(typeid(*(impl.get())) == typeid(OpaqueTensorImpl<OpaqueHandle>));
+    auto opaque_impl = static_cast<const OpaqueTensorImpl<OpaqueHandle>*>(impl.get());
+    copy_tensor_data(
+      /*src_impl=*/opaque_impl,
+      /*dest_impl=*/this,
+      /*version_counter=*/version_counter(),
+      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
+    refresh_numel();
+  }
 
-    // OpaqueTensorImpl-specific fields (none currently).
-    return impl;
-  }
   OpaqueHandle& unsafe_opaque_handle() {
     return opaque_handle_;
   }
 
 private:
   OpaqueHandle opaque_handle_;
 
+  /**
+   * Copy the storage pointer and the tensor metadata fields (e.g. sizes / strides / storage_offset)
+   * from one TensorImpl to another TensorImpl.
+   *
+   * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE [ TensorImpl Shallow-Copying ].
+   */
+  static void copy_tensor_data(
+      const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
+      OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
+      const c10::VariableVersion& version_counter,
+      bool allow_tensor_metadata_change) {
+    TensorImpl::copy_tensor_data(src_opaque_impl, dest_opaque_impl, version_counter, allow_tensor_metadata_change);
+
+    // OpaqueTensorImpl-specific fields.
+    dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
+  }
 };
 
 } // namespace at
@@ -183,41 +183,64 @@ public:
   // make it happen
   void set_indices_and_values_unsafe(const Tensor& indices, const Tensor& values);
 
-  // NOTE: `shallow_copy_and_detach()` does not copy the following TensorImpl fields:
-  // 1. the AutogradMeta pointer, because it is unique for each Variable.
-  // 2. the version counter, because it is set to the passed in `version_counter`.
-  // See NOTE [ Version Counter Sharing ] for details.
-  //
-  // NOTE: `allow_tensor_metadata_change` determines whether the TensorImpl shallow-copy
-  // allows changes to its metadata (e.g. sizes / strides / storage / storage_offset).
-  // See NOTE [ Metadata Change for a Detached Tensor ] for details.
+  /**
+   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
+   *
+   * For usage of `version_counter` and `allow_tensor_metadata_change`,
+   * see NOTE [ TensorImpl Shallow-Copying ].
+   */
   c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
       const c10::VariableVersion& version_counter,
       bool allow_tensor_metadata_change) const override {
     auto impl = c10::make_intrusive<SparseTensorImpl>(type_id(), dtype());
-    // TensorImpl general fields
-    // Note that these fields are not used in sparse tensor code, and we copy them here only for completeness.
-    impl->sizes_ = sizes_;
-    impl->strides_ = strides_;
-    impl->storage_offset_ = storage_offset_;
-    impl->is_contiguous_ = is_contiguous_;
-    impl->is_wrapped_number_ = is_wrapped_number_;
-    impl->reserved_ = reserved_;
-    impl->set_version_counter(version_counter);
-    impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
-
-    // Sparse-specific fields
-    impl->sparse_dim_ = sparse_dim();
-    impl->dense_dim_ = dense_dim();
-    impl->indices_ = indices();
-    impl->values_ = values();
-    impl->device_opt_ = device();
-    impl->coalesced_ = coalesced();
+    copy_tensor_data(
+      /*src_impl=*/this,
+      /*dest_impl=*/impl.get(),
+      /*version_counter=*/version_counter,
+      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
+    impl->refresh_numel();
     return impl;
   }
 
+  /**
+   * Shallow-copies data from another TensorImpl into this TensorImpl.
+   *
+   * For why this function doesn't check this TensorImpl's `allow_tensor_metadata_change_`,
+   * see NOTE [ TensorImpl Shallow-Copying ].
+   */
+  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
+    AT_ASSERT(typeid(*(impl.get())) == typeid(SparseTensorImpl));
+    auto sparse_impl = static_cast<const SparseTensorImpl*>(impl.get());
+    copy_tensor_data(
+      /*src_impl=*/sparse_impl,
+      /*dest_impl=*/this,
+      /*version_counter=*/version_counter(),
+      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
+    refresh_numel();
+  }
 private:
   explicit SparseTensorImpl(at::TensorTypeId, const caffe2::TypeMeta&, at::Tensor indices, at::Tensor values);
 
+  /**
+   * Copy the storage pointer and the tensor metadata fields (e.g. sizes / strides / storage_offset)
+   * from one TensorImpl to another TensorImpl.
+   *
+   * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE [ TensorImpl Shallow-Copying ].
+   */
+  static void copy_tensor_data(
+      const SparseTensorImpl* src_sparse_impl,
+      SparseTensorImpl* dest_sparse_impl,
+      const c10::VariableVersion& version_counter,
+      bool allow_tensor_metadata_change) {
+    TensorImpl::copy_tensor_data(src_sparse_impl, dest_sparse_impl, version_counter, allow_tensor_metadata_change);
+
+    // Sparse-specific fields
+    dest_sparse_impl->sparse_dim_ = src_sparse_impl->sparse_dim();
+    dest_sparse_impl->dense_dim_ = src_sparse_impl->dense_dim();
+    dest_sparse_impl->indices_ = src_sparse_impl->indices();
+    dest_sparse_impl->values_ = src_sparse_impl->values();
+    dest_sparse_impl->coalesced_ = src_sparse_impl->coalesced();
+  }
 };
 
 } // namespace at
@@ -42,11 +42,11 @@ Tensor & celu_(Tensor & self, Scalar alpha) {
 }
 
 Tensor rrelu(const Tensor & self, Scalar lower, Scalar upper, bool training, Generator* generator) {
-  return at::rrelu_with_noise(self, at::empty({0}, self.options()), lower, upper, training, generator);
+  return at::rrelu_with_noise(self, at::empty_like(self), lower, upper, training, generator);
 }
 
 Tensor & rrelu_(Tensor & self, Scalar lower, Scalar upper, bool training, Generator* generator) {
-  return at::rrelu_with_noise_(self, at::empty({0}, self.options()), lower, upper, training, generator);
+  return at::rrelu_with_noise_(self, at::empty_like(self), lower, upper, training, generator);
 }
 
 // computes `result = self <= threshold ? value : other`
@@ -9,7 +9,7 @@ namespace at { namespace native {
 // Methods
 
 void* data_ptr(const Tensor & self) {
-  return self.unsafeGetTensorImpl()->slow_data();
+  return self.unsafeGetTensorImpl()->data();
 }
 
 Tensor & set_(Tensor& self, Storage source) {
@@ -25,24 +25,64 @@ struct CAFFE2_API QTensorImpl : public c10::TensorImpl {
     return quantizer_;
   }
 
+  /**
+   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
+   *
+   * For usage of `version_counter` and `allow_tensor_metadata_change`,
+   * see NOTE [ TensorImpl Shallow-Copying ].
+   */
   c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
       const c10::VariableVersion& version_counter,
       bool allow_tensor_metadata_change) const override {
     auto impl = c10::make_intrusive<QTensorImpl>(
         Storage(storage()), type_id(), quantizer_);
-    impl->set_sizes_and_strides(sizes(), strides());
-    impl->storage_offset_ = storage_offset_;
-    impl->is_wrapped_number_ = is_wrapped_number_;
-    impl->reserved_ = reserved_;
+    copy_tensor_data(
+      /*src_impl=*/this,
+      /*dest_impl=*/impl.get(),
+      /*version_counter=*/version_counter,
+      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
     impl->refresh_numel();
     impl->refresh_contiguous();
-    impl->set_version_counter(version_counter);
-    impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
     return impl;
   }
 
+  /**
+   * Shallow-copies data from another TensorImpl into this TensorImpl.
+   *
+   * For why this function doesn't check this TensorImpl's `allow_tensor_metadata_change_`,
+   * see NOTE [ TensorImpl Shallow-Copying ].
+   */
+  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
+    AT_ASSERT(typeid(*(impl.get())) == typeid(QTensorImpl));
+    auto q_impl = static_cast<const QTensorImpl*>(impl.get());
+    copy_tensor_data(
+      /*src_impl=*/q_impl,
+      /*dest_impl=*/this,
+      /*version_counter=*/version_counter(),
+      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
+    refresh_numel();
+    refresh_contiguous();
+  }
 
 private:
   QuantizerPtr quantizer_;
 
+  /**
+   * Copy the storage pointer and the tensor metadata fields (e.g. sizes / strides / storage_offset)
+   * from one TensorImpl to another TensorImpl.
+   *
+   * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE [ TensorImpl Shallow-Copying ].
+   */
+  static void copy_tensor_data(
+      const QTensorImpl* src_q_impl,
+      QTensorImpl* dest_q_impl,
+      const c10::VariableVersion& version_counter,
+      bool allow_tensor_metadata_change) {
+    TensorImpl::copy_tensor_data(src_q_impl, dest_q_impl, version_counter, allow_tensor_metadata_change);
+
+    // OpaqueTensorImpl-specific fields.
+    dest_q_impl->quantizer_ = src_q_impl->quantizer_;
+  }
 };
 
 } // namespace at
@@ -164,8 +164,7 @@ void THTensor_stealAndSetStoragePtr(THTensor* tensor, THStorage* storage) {
   // Caffe2 also has uninitialized dtype states, which we disallow here
   AT_ASSERT(tensor->storage().dtype() == storage->dtype());
 
-  // We used to allow this, but this breaks device caching,
-  // see Note [We regret making Variable hold a Tensor]
+  // We used to allow this, but this breaks device caching.
   // Let's put an actual error message for this one.
   TORCH_CHECK(tensor->storage().device() == storage->device(),
             "Attempted to set the storage of a tensor on device \"", tensor->storage().device(),
@@ -82,6 +82,7 @@ bool TensorImpl::compute_contiguous() const {
 }
 
 void TensorImpl::release_resources() {
+  autograd_meta_.reset();
   if (storage_) {
     storage_ = {};
   }
@@ -493,12 +493,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * the amount of code we have to write for add, when actually
    * a Tensor-Scalar addition is really just a Tensor-Tensor
    * addition when the RHS is 0-dim (except for promotion behavior.)
-   *
-   * WARNING: It is NOT valid to call this method on a Variable.
-   * See Note [We regret making Variable hold a Tensor]
    */
   bool is_wrapped_number() const {
-    TORCH_INTERNAL_ASSERT(!is_variable()); // TODO: remove this when Variable and Tensor are merged
     return is_wrapped_number_;
   }
 
@@ -506,12 +502,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * Set whether or not a tensor was auto-wrapped from a C++ or Python
    * number. You probably don't want to call this, unless you are
    * writing binding code.
-   *
-   * WARNING: It is NOT valid to call this method on a Variable.
-   * See Note [We regret making Variable hold a Tensor]
   */
   void set_wrapped_number(bool value) {
-    TORCH_INTERNAL_ASSERT(!is_variable()); // TODO: remove this when Variable and Tensor are merged
     TORCH_INTERNAL_ASSERT(dim() == 0);
     is_wrapped_number_ = value;
   }
@@ -522,36 +514,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   //
-  // Note [Tensor versus Variable in C++]
-  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-  // Autograd methods are only valid for the Variable::Impl subclass
-  // of Tensor. This is due to some questionable life choices, where
-  // a Variable has a Tensor (so they are not the same thing), but
-  // a Variable is a Tensor (they are subclassed, so that you can write
-  // code on Tensor that works both with Variables and Tensors. Poor
-  // man's polymorphism). Variable does NOT satisfy the Liskov Substitution
-  // Principle for Tensor; generally you want to work with all Variables,
-  // or all Tensors, but not a mix of both. We intend to fix this in
-  // the future.
-  //
-  // Note [We regret making Variable hold a Tensor]
-  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-  // Tensor has a bunch of fields in it. Are those fields always valid?
-  // Not necessarily: the Variable::Impl subclass of a tensor doesn't use these
-  // fields; instead, it *forwards* them to a contained, inner tensor
-  // (the 'data' tensor). It doesn't even bother keeping the fields on the
-  // outer tensor up-to-date, because an end user could grab the inner
-  // tensor and directly, e.g., resize it (making any outer fields we track
-  // stale).
-  //
-  // As you might imagine, this is a TERRIBLE state of affairs to be in.
-  // It makes implementing everything on TensorImpl complicated: if
-  // you directly access a field on TensorImpl, you must *virtualize*
-  // the function, if you want it to work correctly when called from
-  // Variable (because we need to override the method to avoid looking
-  // in our fields, and look in the data tensor's fields.) Anything that
-  // isn't virtualized, won't work if called on a variable.
-  //
-  // The way to fix this is to make Variable::Impl stop holding a tensor;
-  // instead, it should just *be* a tensor.
+  // Autograd methods are only valid for Variables (i.e. Tensors that contain
+  // autograd metadata).
 
   /**
    * Set whether or not a tensor requires gradient.
@@ -609,13 +573,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * performing index calculations to determine the location of elements in
    * the tensor. We recommend using 'TensorAccessor' to handle this computation
    * for you; this class is available from 'Tensor'.
-   *
-   * WARNING: It is NOT valid to call this method on a Variable.
-   * See Note [We regret making Variable hold a Tensor]
   */
   template <typename T>
   inline T * data() const {
-    TORCH_INTERNAL_ASSERT(!is_variable()); // TODO: remove this when Variable and Tensor are merged
     TORCH_CHECK(has_storage(),
         "Cannot access data pointer of Tensor that doesn't have storage");
     TORCH_CHECK(
@@ -643,12 +603,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * WARNING: The data pointed to by this tensor may not contiguous; do NOT
    * assume that itemsize() * numel() is sufficient to compute the bytes that
    * can be validly read from this tensor.
-   *
-   * WARNING: It is NOT valid to call this method on a Variable.
-   * See Note [We regret making Variable hold a Tensor]
   */
   inline void* data() const {
-    TORCH_INTERNAL_ASSERT(!is_variable()); // TODO: remove this when Variable and Tensor are merged
     TORCH_CHECK(has_storage(),
         "Cannot access data pointer of Tensor that doesn't have storage");
     TORCH_CHECK(dtype_initialized(),
@@ -659,21 +615,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
         data_type_.itemsize() * storage_offset_);
   }
 
-  /**
-   * This is just like data(), except it works with Variables.
-   * This function will go away once Variable and Tensor are merged.
-   * See Note [We regret making Variable hold a Tensor]
-   */
-  virtual void* slow_data() const {
-    return data();
-  }
-
   /**
    * Like data<T>(), but performs no checks. You are responsible for ensuring
    * that all invariants required by data() are upheld here.
-   *
-   * WARNING: It is NOT valid to call this method on a Variable.
-   * See Note [We regret making Variable hold a Tensor]
   */
   template <typename T>
   inline T * unsafe_data() const {
@@ -783,13 +727,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * WARNING: This function does not check if the requested
    * sizes/strides are in bounds for the storage that is allocated;
    * this is the responsibility of the caller
-   *
-   * WARNING: It is NOT valid to call this method on a Variable.
-   * See Note [We regret making Variable hold a Tensor]
   */
   void set_sizes_contiguous(IntArrayRef new_size) {
     TORCH_CHECK(allow_tensor_metadata_change(), "set_sizes_contiguous is not allowed on Tensor created from .data or .detach()");
-    TORCH_INTERNAL_ASSERT(!is_variable()); // TODO: remove this when Variable and Tensor are merged
     auto old_dim = sizes_.size();
     auto new_dim = new_size.size();
 
@@ -808,12 +748,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * WARNING: This function does not check if the requested
    * sizes/strides are in bounds for the storage that is allocated;
    * this is the responsibility of the caller
-   *
-   * WARNING: It is NOT valid to call this method on a Variable.
-   * See Note [We regret making Variable hold a Tensor]
   */
   void set_sizes_and_strides(IntArrayRef new_size, IntArrayRef new_stride) {
-    TORCH_INTERNAL_ASSERT(!is_variable()); // TODO: remove this when Variable and Tensor are merged
     TORCH_CHECK(allow_tensor_metadata_change(), "set_sizes_and_strides is not allowed on Tensor created from .data or .detach()");
     TORCH_CHECK(
         new_size.size() == new_stride.size(),
@@ -907,30 +843,69 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return std::move(autograd_meta_);
   }
 
-  // NOTE: `shallow_copy_and_detach()` does not copy the following TensorImpl fields:
-  // 1. the AutogradMeta pointer, because it is unique for each Variable.
-  // 2. the version counter, because it is set to the passed in `version_counter`.
-  // See NOTE [ Version Counter Sharing ] for details.
+  // NOTE [ TensorImpl Shallow-Copying ]
   //
-  // NOTE: `allow_tensor_metadata_change` determines whether the TensorImpl shallow-copy
-  // allows changes to its metadata (e.g. sizes / strides / storage / storage_offset).
-  // See NOTE [ Metadata Change for a Detached Tensor ] for details.
+  // TensorImpl shallow-copying is used when we want to have two Variables share the same storage pointer
+  // and tensor metadata, but each with a different autograd history. Example call sites:
+  //
+  // 1. `var_detached = var.detach()` uses `shallow_copy_and_detach()` to create `var_detached` that shares
+  // the same storage pointer and tensor metadata with `var`, but with a completely new autograd history.
+  // 2. `var.set_data(tensor)` uses `shallow_copy_from()` to copy storage pointer and tensor metadata from
+  // `tensor` into `var`, while keeping `var`'s original AutogradMeta.
+  //
+  // Functions that shallow-copy a TensorImpl (such as `shallow_copy_and_detach()` / `shallow_copy_from()` /
+  // `copy_tensor_data()`) copy the storage pointer and the tensor metadata fields (e.g. sizes / strides /
+  // storage_offset) by value. However, the following fields are not copied:
+  //
+  // 1. the AutogradMeta pointer, because it is unique for each Variable.
+  // 2. the version counter, because the destination TensorImpl's version counter is either set to the
+  // passed-in `version_counter` (in `shallow_copy_and_detach()` and `copy_tensor_data()`), or it is kept
+  // intact (in `shallow_copy_from()`). See NOTE [ Version Counter Sharing ] for details.
+  //
+  // In `shallow_copy_and_detach()` and `copy_tensor_data()`, the passed-in `allow_tensor_metadata_change`
+  // determines whether the TensorImpl shallow-copy allows changes to its metadata (e.g. sizes / strides /
+  // storage / storage_offset). See NOTE [ Metadata Change for a Detached Tensor ] for details.
+  //
+  // In `shallow_copy_from()`, we don't check the destination TensorImpl's `allow_tensor_metadata_change_`,
+  // because `shallow_copy_from()` is used for implementing functions such as `var.set_data(tensor)`, which
+  // changes `var`'s tensor metadata and expects its `allow_tensor_metadata_change_` to be ignored.
 
+  /**
+   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
+   *
+   * For usage of `version_counter` and `allow_tensor_metadata_change`,
+   * see NOTE [ TensorImpl Shallow-Copying ].
+   */
   virtual c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
       const c10::VariableVersion& version_counter,
       bool allow_tensor_metadata_change) const {
-    TORCH_INTERNAL_ASSERT(!is_variable()); // TODO: remove this when Variable and Tensor are merged
     auto impl = c10::make_intrusive<TensorImpl>(Storage(storage()), type_id());
-    impl->set_sizes_and_strides(sizes(), strides());
-    impl->storage_offset_ = storage_offset_;
-    impl->is_wrapped_number_ = is_wrapped_number_;
-    impl->reserved_ = reserved_;
+    copy_tensor_data(
+      /*src_impl=*/this,
+      /*dest_impl=*/impl.get(),
+      /*version_counter=*/version_counter,
+      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
     impl->refresh_numel();
     impl->refresh_contiguous();
-    impl->set_version_counter(version_counter);
-    impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
     return impl;
   }
 
+  /**
+   * Shallow-copies data from another TensorImpl into this TensorImpl.
+   *
+   * For why this function doesn't check this TensorImpl's `allow_tensor_metadata_change_`,
+   * see NOTE [ TensorImpl Shallow-Copying ].
+   */
+  virtual void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) {
+    copy_tensor_data(
+      /*src_impl=*/impl.get(),
+      /*dest_impl=*/this,
+      /*version_counter=*/version_counter(),
+      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
+    refresh_numel();
+    refresh_contiguous();
+  }
+
   void set_version_counter(
       const c10::VariableVersion& version_counter) noexcept {
     version_counter_ = version_counter;
@@ -966,7 +941,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * The device type of a Tensor, e.g., DeviceType::CPU or DeviceType::CUDA.
   */
   DeviceType device_type() const {
-    TORCH_INTERNAL_ASSERT(!is_variable()); // TODO: remove this when Variable and Tensor are merged
     // TODO: A useful internal assert would be to show that device_opt_ is null
     // only if you are an undefined tensor
     TORCH_CHECK(device_opt_.has_value(), "device_type cannot be run on undefined Tensor");
@@ -1438,7 +1412,6 @@ protected:
    * Recompute the cached numel of a tensor. Call this if you modify sizes.
   */
   void refresh_numel() {
-    TORCH_INTERNAL_ASSERT(!is_variable()); // TODO: remove this when Variable and Tensor are merged
     numel_ = compute_numel();
   }
 
@@ -1447,10 +1420,34 @@ protected:
    * or strides.
   */
   void refresh_contiguous() {
-    TORCH_INTERNAL_ASSERT(!is_variable()); // TODO: remove this when Variable and Tensor are merged
     is_contiguous_ = compute_contiguous();
   }
 
+  /**
+   * Copy the storage pointer and the tensor metadata fields (e.g. sizes / strides / storage_offset)
+   * from one TensorImpl to another TensorImpl.
+   *
+   * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE [ TensorImpl Shallow-Copying ].
+   */
+  static void copy_tensor_data(
+      const TensorImpl* src_impl,
+      TensorImpl* dest_impl,
+      const c10::VariableVersion& version_counter,
+      bool allow_tensor_metadata_change) {
+    dest_impl->storage_ = src_impl->storage_;
+    dest_impl->sizes_ = src_impl->sizes_;
+    dest_impl->strides_ = src_impl->strides_;
+    dest_impl->storage_offset_ = src_impl->storage_offset_;
+    dest_impl->data_type_ = src_impl->data_type_;
+    dest_impl->device_opt_ = src_impl->device_opt_;
+    dest_impl->type_id_ = src_impl->type_id_;
+    dest_impl->is_contiguous_ = src_impl->is_contiguous_;
+    dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_;
+    dest_impl->reserved_ = src_impl->reserved_;
+    dest_impl->set_version_counter(version_counter);
+    dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
+  }
+
 protected:
   Storage storage_;
   // This pointer points to an AutogradMeta struct that stores autograd-specific fields
@@ -83,7 +83,7 @@ class ScriptModuleOp final : public Operator<Context> {
   }
 
   static caffe2::Tensor castIValueToTensor(const IValue& v) {
-    return caffe2::Tensor(torch::autograd::Variable(v.toTensor()).data());
+    return caffe2::Tensor(torch::autograd::Variable(v.toTensor()).tensor_data());
   }
 
   bool RunOnDevice() override {
@@ -422,7 +422,15 @@ void addObjectMethods(py::module& m) {
       .def("_wrap_tensor_impl", [](Blob* blob, void* ptr) {
         auto p = c10::intrusive_ptr<c10::TensorImpl, at::UndefinedTensorImpl>::
             unsafe_reclaim_from_nonowning(static_cast<c10::TensorImpl*>(ptr));
+        // TODO: In the near future, a PyTorch tensor without AutogradMeta will be
+        // a valid tensor. At that point, we will only accept non-requires-grad
+        // tensor into Caffe2 workspace, and don't need to perform shallow-copying
+        // here anymore.
+        p = p->shallow_copy_and_detach(
+            /*version_counter=*/p->version_counter(),
+            /*allow_tensor_metadata_change=*/p->allow_tensor_metadata_change());
         TORCH_CHECK(p.defined(), "Can't wrap undefined tensor");
+        TORCH_CHECK(!p->is_variable(), "Can wrap only non-variable tensor");
         auto at_tensor = at::Tensor::wrap_tensor_impl(std::move(p));
         BlobSetTensor(blob, Tensor(std::move(at_tensor)));
       });
@@ -308,6 +308,18 @@ class TestWorkspace(unittest.TestCase):
         workspace.FeedBlob('bar', z)
         workspace.RunOperatorOnce(
             core.CreateOperator("Reshape", ['bar'], ['bar', '_'], shape=(2,2)))
+        # NOTE: `workspace.FeedBlob('bar', z)` above creates a shallow-copy of `z`
+        # and assign it to `bar` in the Caffe2 workspace. Since it's a shallow-copy,
+        # any sizes or strides change to `bar` will not be propagated back to `z`,
+        # and we need to call `z = workspace.FetchTorch("bar")` to manually put
+        # the value of `bar` back into `z`.
+        #
+        # In the near future, we won't need to perform the shallow-copying of `z` and
+        # can directly pass it into the Caffe2 workspace, as long as `z` doesn't require
+        # grad. At that point we won't need to use `z = workspace.FetchTorch("bar")`
+        # to fetch `z` from the Caffe2 workspace, since it will exactly be the same as
+        # the original tensor `z`.
+        z = workspace.FetchTorch("bar")
         z[0,1] = 123
         np.testing.assert_array_equal(
             workspace.FetchBlob("bar"), np.array([[1,123],[1,1]]))
@@ -398,6 +410,18 @@ class TestWorkspaceGPU(test_util.TestCase):
         workspace.RunOperatorOnce(
             core.CreateOperator("Reshape", ['bar'], ['bar', '_'], shape=(2,2),
                                 device_option=core.DeviceOption(workspace.GpuDeviceType)))
+        # NOTE: `workspace.FeedBlob('bar', z)` above creates a shallow-copy of `z`
+        # and assign it to `bar` in the Caffe2 workspace. Since it's a shallow-copy,
+        # any sizes or strides change to `bar` will not be propagated back to `z`,
+        # and we need to call `z = workspace.FetchTorch("bar")` to manually put
+        # the value of `bar` back into `z`.
+        #
+        # In the near future, we won't need to perform the shallow-copying of `z` and
+        # can directly pass it into the Caffe2 workspace, as long as `z` doesn't require
+        # grad. At that point we won't need to use `z = workspace.FetchTorch("bar")`
+        # to fetch `z` from the Caffe2 workspace, since it will exactly be the same as
+        # the original tensor `z`.
+        z = workspace.FetchTorch("bar")
         z[0,1] = 123
         np.testing.assert_array_equal(
             workspace.FetchBlob("bar"), np.array([[1,123],[1,1]]))
@@ -93,7 +93,7 @@ variable_list grad(
 }
 
 void testADFormulas() {
-  const auto unwrap = [](const Variable& v) { return v.data(); };
+  const auto cast = [](const Variable& v) { return static_cast<at::Tensor>(v); };
 
   using VL = variable_list;
   const var_meta_list binary_pointwise = {{2, 3, 4, 5}, {2, 3, 4, 5}};
@@ -155,15 +155,15 @@ void testADFormulas() {
     auto grad_spec = differentiate(graph);
     LowerGradOf(*grad_spec.df);
     // Get outputs from the interpreter
-    auto tensors_in = fmap(vars_in, unwrap);
-    auto tensor_grads_in = fmap(var_grads_in, unwrap);
+    auto tensors_in = fmap(vars_in, cast);
+    auto tensor_grads_in = fmap(var_grads_in, cast);
     tensor_list tensors_out, tensor_grads_out;
     std::tie(tensors_out, tensor_grads_out) =
         runGradient(grad_spec, tensors_in, tensor_grads_in);
 
     // Compare results
-    auto expected_tensors_out = fmap(vars_out, unwrap);
-    auto expected_tensor_grads_out = fmap(var_grads_out, unwrap);
+    auto expected_tensors_out = fmap(vars_out, cast);
+    auto expected_tensor_grads_out = fmap(var_grads_out, cast);
     assertAllClose(tensors_out, expected_tensors_out);
     assertAllClose(tensor_grads_out, expected_tensor_grads_out);
   }
@@ -29,8 +29,8 @@ void testGraphExecutor() {
   ASSERT_EQ(stack.size(), 2);
   at::Tensor r0, r1;
   std::tie(r0, r1) = lstm(input, hx, cx, w_ih, w_hh);
-  ASSERT_TRUE(almostEqual(Variable(stack[0].toTensor()).data(), r0));
-  ASSERT_TRUE(almostEqual(Variable(stack[1].toTensor()).data(), r1));
+  ASSERT_TRUE(almostEqual(stack[0].toTensor(), v(r0)));
+  ASSERT_TRUE(almostEqual(stack[1].toTensor(), v(r1)));
 }
 
 } // namespace test
@@ -805,8 +805,8 @@ void testModuleConversion() {
 
     m->to(at::kCUDA);
     m->to(at::kCPU);
-    AT_ASSERT(m->get_parameter("foo").data().device().is_cpu());
-    AT_ASSERT(m->get_buffer("bar").data().device().is_cpu());
+    AT_ASSERT(m->get_parameter("foo").device().is_cpu());
+    AT_ASSERT(m->get_buffer("bar").device().is_cpu());
   }
   {
     // test cpu to cuda for params and buffers
@@ -814,8 +814,8 @@ void testModuleConversion() {
     m->register_buffer("bar", torch::ones({}));
 
     m->to(at::kCUDA);
-    AT_ASSERT(m->get_parameter("foo").data().device().is_cuda());
-    AT_ASSERT(m->get_buffer("bar").data().device().is_cuda());
+    AT_ASSERT(m->get_parameter("foo").device().is_cuda());
+    AT_ASSERT(m->get_buffer("bar").device().is_cuda());
   }
 }
@@ -50,11 +50,7 @@ Tensor kl_div_backward_override(
   return get_dtype_tensor(self.dtype());
 }
 
-// numel and ones_like are needed for autograd backwards
-int64_t numel_override(const Tensor & self) {
-  return 1;
-}
-
+// ones_like is needed for autograd backwards
 Tensor ones_like_override(const Tensor & self, const TensorOptions & options) {
   return get_dtype_tensor(options.dtype());
 }
@@ -81,9 +77,6 @@ void init_msnpu_extension() {
       Backend::MSNPU,
       "kl_div_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction) -> Tensor",
       &kl_div_backward_override);
-  register_extension_backend_op(
-      Backend::MSNPU,
-      "numel(Tensor self) -> int64_t", &numel_override);
   register_extension_backend_op(
       Backend::MSNPU,
       "ones_like(Tensor self, TensorOptions options) -> Tensor",
@@ -203,6 +203,32 @@ class TestAutograd(TestCase):
         x_grad, x_grad_clone = compute_grad(create_graph=True)
         self.assertEqual(x_grad, x_grad_clone)
 
+    def test_accumulate_grad_tensor_reference(self):
+        def _test_grad_tensor(params_grad_tensor, backward_grad_tensor, should_preserve_reference):
+            params = torch.tensor([1.5, 1.5]).requires_grad_()
+            params.grad = params_grad_tensor
+            grad_saved = params.grad
+            params.backward(backward_grad_tensor)
+            self.assertEqual(id(grad_saved) == id(params.grad), should_preserve_reference)
+
+        # Accumulate dense gradient to sparse gradient will change the `params.grad` reference
+        _test_grad_tensor(
+            torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.])),
+            torch.tensor([1.5, 1.5]),
+            False)
+
+        # Accumulate dense gradient to dense gradient will preserve the `params.grad` reference
+        _test_grad_tensor(
+            torch.tensor([1.5, 1.5]),
+            torch.tensor([1.5, 1.5]),
+            True)
+
+        # Accumulate sparse gradient to sparse gradient will preserve the `params.grad` reference
+        _test_grad_tensor(
+            torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.])),
+            torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.])),
+            True)
+
     def test_slogdet_sign(self):
         a = torch.randn(3, 3, requires_grad=True)
         s, logdet = a.slogdet()
@@ -3052,6 +3078,21 @@ class TestAutograd(TestCase):
         xz.add_(1)
         self.assertTrue(x._version == xz._version)
 
+    def test_set_data_tensorimpl_type(self):
+        # Dense tensor has impl of type `TensorImpl`, while sparse tensor has impl
+        # of type `SparseTensorImpl`.
+        x = torch.randn(1, 2)
+        x_s = torch.sparse_coo_tensor(torch.zeros([1, 1]), torch.ones([1]))
+        with self.assertRaisesRegex(RuntimeError, 'different types of TensorImpl'):
+            x.data = x_s
+
+    def test_set_data_preserve_pyobj(self):
+        a = torch.randn(1, 2)
+        b = torch.randn(1, 2)
+        b_id_saved = id(b)
+        b.data = a
+        self.assertTrue(b_id_saved == id(b))
+
 
 def index_variable(shape, max_indices):
     if not isinstance(shape, tuple):
@@ -670,7 +670,7 @@ class TestMSNPUTensor(common.TestCase):
         d = c.sum()
         self.assertEqual(msnpu_extension.get_test_int(), 2)
 
-        d.backward()
+        d.backward(torch.zeros(0, device='msnpu'))
         self.assertEqual(msnpu_extension.get_test_int(), 4)
 
@@ -274,6 +274,14 @@ class TestMkldnn(TestCase):
             module(*inputs).to_dense(),
             traced(*inputs).to_dense())
 
+    def test_set_data_tensorimpl_type(self):
+        # Dense tensor has impl of type `TensorImpl`, while MKL-DNN tensor has impl
+        # of type `OpaqueTensorImpl<IDeepTensorWrapperPtr>`.
+        x = torch.randn((1, 2), dtype=torch.float, device=torch.device('cpu'))
+        x_mkldnn = x.to_mkldnn()
+        with self.assertRaisesRegex(RuntimeError, 'different types of TensorImpl'):
+            x.data = x_mkldnn
+
 
 if __name__ == '__main__':
     run_tests()
@@ -68,10 +68,11 @@ class TestOptim(TestCase):
             y = grad[1]
             v = torch.DoubleTensor([y - y / 4., y / 4.])
             x = sparse.DoubleTensor(i, v, torch.Size([2]))
-            if sparse_grad:
-                params.grad.data = x
-            else:
-                params.grad.data = x.to_dense()
+            with torch.no_grad():
+                if sparse_grad:
+                    params.grad = x
+                else:
+                    params.grad = x.to_dense()
             return loss
 
         for i in range(2000):
@@ -120,6 +120,14 @@ graph(%x : (Tensor, float, int)):
             """,
         )
 
+    def test_set_data_tensorimpl_type(self):
+        # Dense tensor has impl of type `TensorImpl`, while quantized tensor has impl
+        # of type `QTensorImpl`.
+        x = torch.randn(1, 2)
+        x_q = torch.ops.c10.quantize(torch.randn(1, 2))
+        with self.assertRaisesRegex(RuntimeError, 'different types of TensorImpl'):
+            x.data = x_q
+
 
 class TestQuantizedOps(unittest.TestCase):
     """Tests the correctness of the quantized::relu op."""
@@ -395,7 +395,7 @@ static PyObject * THPVariable_numpy(PyObject* self, PyObject* arg)
         "Can't call numpy() on Variable that requires grad. "
         "Use var.detach().numpy() instead.");
   }
-  return torch::utils::tensor_to_numpy(self_.data());
+  return torch::utils::tensor_to_numpy(self_.tensor_data());
   END_HANDLE_TH_ERRORS
 }
 
@@ -614,7 +614,7 @@ static PyObject * THPVariable_tolist(PyObject* self, PyObject* args)
   HANDLE_TH_ERRORS
   jit::tracer::warn("Converting a tensor to a Python list", jit::tracer::WARN_PYTHON_DATAFLOW);
   auto self_ = reinterpret_cast<THPVariable*>(self)->cdata;
-  return torch::utils::tensor_to_list(self_.data());
+  return torch::utils::tensor_to_list(self_);
   END_HANDLE_TH_ERRORS
 }
@@ -69,7 +69,7 @@ static PyObject * THPGenerator_getState(THPGenerator *self)
   HANDLE_TH_ERRORS
   THGenerator *generator = THPGenerator_TH_CData(self);
   Variable var = torch::empty({0}, at::device(at::kCPU).dtype(at::kByte));
-  THByteTensor_getRNGState(generator, (THByteTensor*)(var.data().unsafeGetTensorImpl()));
+  THByteTensor_getRNGState(generator, (THByteTensor*)(var.unsafeGetTensorImpl()));
   return THPVariable_Wrap(std::move(var));
   END_HANDLE_TH_ERRORS
 }
@@ -81,7 +81,7 @@ static PyObject * THPGenerator_setState(THPGenerator *self, PyObject *_new_state
   if (!THPVariable_Check(_new_state)) {
     throw TypeError("expected a torch.ByteTensor, but got %s", Py_TYPE(_new_state)->tp_name);
   }
-  auto& tensor = ((THPVariable*)_new_state)->cdata.data();
+  auto& tensor = ((THPVariable*)_new_state)->cdata;
   if (tensor.layout() != kStrided || tensor.device().type() != kCPU || tensor.scalar_type() != kByte) {
     auto type_name = torch::utils::type_to_string(tensor.dispatch_type(), tensor.scalar_type());
     throw TypeError("expected a torch.ByteTensor, but got %s", type_name.c_str());
@@ -352,7 +352,7 @@ PyObject *THPModule_toDLPack(PyObject *_unused, PyObject *data)
 {
   HANDLE_TH_ERRORS
   THPUtils_assert(THPVariable_Check(data), "data must be a Tensor");
-  DLManagedTensor* dlMTensor = at::toDLPack(THPVariable_UnpackData(data));
+  DLManagedTensor* dlMTensor = at::toDLPack(THPVariable_Unpack(data));
   return PyCapsule_New(dlMTensor, "dltensor", DLPack_Capsule_Destructor);
   END_HANDLE_TH_ERRORS
 }
@@ -48,7 +48,7 @@ class Cloneable : public virtual Module {
         "Are you sure you called register_parameter() inside reset() "
         "and not the constructor?");
     for (const auto& parameter : parameters_) {
-      auto data = autograd::Variable(*parameter).data().clone();
+      auto data = autograd::Variable(*parameter).clone();
       copy->parameters_[parameter.key()].set_data(
           device ? data.to(*device) : data);
     }
@@ -59,7 +59,7 @@ class Cloneable : public virtual Module {
         "Are you sure you called register_buffer() inside reset() "
         "and not the constructor?");
     for (const auto& buffer : buffers_) {
-      auto data = autograd::Variable(*buffer).data().clone();
+      auto data = autograd::Variable(*buffer).clone();
       copy->buffers_[buffer.key()].set_data(device ? data.to(*device) : data);
     }
     TORCH_CHECK(
@@ -591,11 +591,11 @@ void Module::to_impl(Ts&&... ts) {
   }
   // Then move every parameter to the new dtype/device.
   for (auto& parameter : parameters_) {
-    parameter->set_data(autograd::Variable(*parameter).data().to(ts...));
+    parameter->set_data(autograd::Variable(*parameter).to(ts...));
   }
   // Then move every buffer to the new dtype/device.
   for (auto& buffer : buffers_) {
-    buffer->set_data(autograd::Variable(*buffer).data().to(ts...));
+    buffer->set_data(autograd::Variable(*buffer).to(ts...));
   }
 }
@@ -38,7 +38,7 @@ bool InputArchive::try_read(
   if (tensor.defined()) {
     torch::NoGradGuard guard;
     if (tensor.device() != read_tensor.device()) {
-      tensor.set_data(autograd::Variable(read_tensor).data());
+      tensor.set_data(read_tensor);
     } else {
       tensor.set_(read_tensor);
     }
@@ -176,15 +176,15 @@ Variable & VariableType::checked_cast_variable(Tensor & t, const char * name, in
 }
 
 const Tensor & VariableType::unpack(const Tensor & t, const char * name, int pos) {
-  return checked_cast_variable(t, name, pos).data();
+  return checked_cast_variable(t, name, pos);
 }
 
 Tensor & VariableType::unpack(Tensor & t, const char * name, int pos) {
-  return checked_cast_variable(t, name, pos).data();
+  return checked_cast_variable(t, name, pos);
 }
 
 SparseTensorRef VariableType::unpack(SparseTensorRef t, const char * name, int pos) {
-  return SparseTensorRef(checked_cast_variable(t.tref, name, pos).data());
+  return SparseTensorRef(checked_cast_variable(t.tref, name, pos));
 }
 
 Tensor VariableType::unpack_opt(const Tensor & t, const char * name, int pos) {
@@ -205,7 +205,7 @@ std::vector<at::Tensor> VariableType::unpack(at::TensorList tl, const char *name
       AT_ERROR("Expected object of type Variable but found type ", t.dispatch_type().toString(), " at position #", i, " "
                "for iterable argument #", pos, " '", name, "'");
     }
-    ret[i] = static_cast<const Variable&>(t).data();
+    ret[i] = static_cast<const Variable&>(t);
   }
   return ret;
 }
@@ -147,13 +147,24 @@ inline std::vector<SavedVariable> make_saved_variable_list(TensorList tensors) {
     return SavedVariable{tensor, false /* is output */}; });
 }
 
+// NOTE: For now, there is no guarantee that the tensors returned from
+// out-of-place ATen ops are not Variables. For example, the following operators:
+//
+// 1. `coalesce()` (called from `VariableType::coalesce()`)
+// 2. `_embedding_bag_cpu()` (called from `VariableType::_embedding_bag()`)
+//
+// can return its input or tensors created using the input's options, which can
+// potentially be Variables because inputs to ATen ops can be Variables.
+//
+// In the near future, once we make every tensor a Variable, these two
+// `as_variable()` functions are no-op and we can remove them.
 inline Tensor as_variable(Tensor tensor) {
-  return make_variable(std::move(tensor), /*requires_grad=*/false);
+  return tensor.is_variable() ? tensor : make_variable(std::move(tensor), /*requires_grad=*/false);
 }
 
 inline std::vector<Tensor> as_variable(TensorList tl) {
   return fmap(tl, [](const Tensor& t) -> Tensor {
-    return make_variable(t, /*requires_grad=*/false);
+    return t.is_variable() ? t : make_variable(t, /*requires_grad=*/false);
   });
 }
@@ -57,16 +57,28 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
       variable.grad() = new_grad.clone();
     }
   } else if (!GradMode::is_enabled()) {
-    Variable& grad_variable = as_variable_ref(grad);
     // This case is not strictly necessary, but it makes the first-order only case
-    // slightly more efficient and, what's more important, more predictable for
-    // the users. Thanks to this case we can avoid changing the grad tensor,
-    // a thing never promised and documented, but used in some hacks seen
-    // on the internet.
+    // slightly more efficient.
+    Variable& grad_variable = as_variable_ref(grad);
     if (grad_variable.is_sparse() && !new_grad.is_sparse()) {
-      grad_variable.set_data(new_grad.data() + grad_variable.data());
+      // If `grad_variable` is sparse and `new_grad` is not sparse, their sum is not
+      // sparse, and we must change the TensorImpl type of `grad_variable` for it to
+      // store the result. However, changing the TensorImpl type of a tensor requires
+      // changing the tensor itself, and thus in this case we have to change the grad
+      // tensor.
+      grad_variable = new_grad + grad_variable;
     } else {
-      grad_variable.data() += new_grad.data();
+      // In this case we can avoid changing the grad tensor. There are three scenarios
+      // when we'll hit this case:
+      //
+      // 1. `grad_variable` is sparse, and `new_grad` is sparse.
+      // 2. `grad_variable` is dense, and `new_grad` is sparse.
+      // 3. `grad_variable` is dense, and `new_grad` is dense.
+      //
+      // In all of these three cases, `grad_variable += new_grad` is a valid operation
+      // which adds `new_grad` to `grad_variable` in place. `grad_variable` is thus
+      // still referring to the same tensor after the operation.
+      grad_variable += new_grad;
     }
   } else {
     variable.grad() = grad + new_grad;
@@ -20,7 +20,7 @@ auto DelayedError::apply(variable_list&& inputs) -> variable_list {
   outputs.reserve(inputs.size());
   for (auto& var : inputs) {
     // FIXME: share version counters
-    outputs.emplace_back(var.defined() ? var.data() : at::Tensor());
+    outputs.emplace_back(var.defined() ? var.tensor_data() : at::Tensor());
   }
   return wrap_outputs(inputs, std::move(outputs), [&](edge_list&& next_edges) {
     return std::make_shared<Error>(msg, std::move(next_edges));
@@ -82,7 +82,7 @@ auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list {
       msg += "')'";
       throw std::runtime_error(msg);
     }
-    tensor_results[i] = ((THPVariable*)obj)->cdata.data();
+    tensor_results[i] = ((THPVariable*)obj)->cdata.tensor_data();
   }
 }
@@ -157,8 +157,8 @@ static void check_single_result(PyObject* _original, PyObject* _result, PyObject
     throw python_error();
   }
 
-  auto& original = ((THPVariable*)_original)->cdata.data();
-  auto& result = ((THPVariable*)_result)->cdata.data();
+  auto& original = ((THPVariable*)_original)->cdata;
+  auto& result = ((THPVariable*)_result)->cdata;
 
   if (original.type() != result.type()) {
     std::stringstream ss;
@@ -48,9 +48,9 @@ static PyObject *THPVariable_pynew(PyTypeObject* type, PyObject *args, PyObject
     // by nn.Parameter() with no arguments.
     auto scalar_type = torch::tensors::get_default_scalar_type();
     auto var = at::empty({0}, torch::tensors::get_default_tensor_type().options(scalar_type));
-    tensor = static_cast<Variable&>(var).data();
+    tensor = static_cast<Variable&>(var).tensor_data();
   } else if (THPVariable_Check(data)) {
-    tensor = ((THPVariable*)data)->cdata.data();
+    tensor = ((THPVariable*)data)->cdata.tensor_data();
   } else {
     throw torch::TypeError("Variable data has to be a tensor, but got %s",
         Py_TYPE(data)->tp_name);
@@ -150,7 +150,7 @@ static PyObject* THPVariable_make_subclass(PyObject* _ignored, PyObject* args, P
   if (!PyType_Check(cls)) {
     throw TypeError("cls must be a type (got %s)", Py_TYPE(cls)->tp_name);
   }
-  auto& data = as_variable_ref(r.tensor(1)).data();
+  auto data = as_variable_ref(r.tensor(1)).tensor_data();
   auto var = make_variable(data, r.toBool(2));
   return THPVariable_NewWithVar((PyTypeObject*)cls, std::move(var));
   END_HANDLE_TH_ERRORS
@@ -163,7 +163,7 @@ PyObject *THPVariable_get_cdata(THPVariable *self)
 {
   HANDLE_TH_ERRORS
   auto& var = self->cdata;
-  return PyLong_FromVoidPtr(var.data().unsafeGetTensorImpl());
+  return PyLong_FromVoidPtr(var.unsafeGetTensorImpl());
   END_HANDLE_TH_ERRORS
 }
@@ -206,15 +206,7 @@ static PyObject *THPVariable_is_leaf(THPVariable *self)
 static PyObject * THPVariable_get_data(THPVariable *self)
 {
   HANDLE_TH_ERRORS
-  /// NOTE: Previously, if we change the tensor metadata (e.g. sizes / strides /
-  /// storage / storage_offset) of a tensor created from `.data`, those metadata
-  /// in the original tensor will also be updated. However, the new behavior is that
-  /// those metadata changes to the `.data` tensor will not update the original tensor
-  /// anymore, and here we need to set `allow_tensor_metadata_change_` to false to
-  /// make such changes explicitly illegal, in order to prevent users from changing
-  /// metadata of the `.data` tensor and expecting the original tensor to also
-  /// be updated.
-  auto var = make_variable(self->cdata.data(), /*requires_grad=*/false, /*allow_tensor_metadata_change=*/false);
+  auto var = self->cdata.variable_data();
   return THPVariable_Wrap(var);
   END_HANDLE_TH_ERRORS
 }
@@ -227,7 +219,7 @@ int THPVariable_set_data(THPVariable *self, PyObject *data)
     throw torch::TypeError("Variable data has to be a tensor, but got %s", Py_TYPE(data)->tp_name);
   }
 
-  self->cdata.set_data(THPVariable_UnpackData(data));
+  self->cdata.set_data(THPVariable_Unpack(data));
   return 0;
   END_HANDLE_TH_ERRORS_RET(-1)
 }
@@ -524,10 +516,9 @@ void initTensorImplConversion(PyObject* module) {
   });
   // set on the module level to avoid mixing pybind and plain CPython extensions
   m.def("_tensor_impl_raw_handle", [](torch::autograd::Variable* t) -> void* {
-    auto p = t->data().getIntrusivePtr();
     // We return a raw non-owning pointer here, we rely on surrounding
     // code to keep the original tensor alive
-    return p.get();
+    return t->getIntrusivePtr().get();
   });
 }
 }}
@@ -32,8 +32,3 @@ inline torch::autograd::Variable& THPVariable_Unpack(PyObject* obj) {
   auto var = (THPVariable*)obj;
   return var->cdata;
 }
-
-inline at::Tensor& THPVariable_UnpackData(PyObject* obj) {
-  auto var = (THPVariable*)obj;
-  return var->cdata.data();
-}
@@ -22,7 +22,7 @@ SavedVariable::SavedVariable(const Variable& variable, bool is_output) {
     has_grad_fn_ = !variable.is_leaf();
     // These copies are all shared_ptr copies, so slightly more expensive.
     // Do them here instead of in the init list in case data is undefined.
-    data_ = variable.data();
+    data_ = variable.tensor_data();
     if (variable.is_leaf()) {
      grad_accumulator_ = variable.grad_accumulator();
     } else if (!is_output) {
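`tensor_data()` is the right replacement here because the shallow TensorImpl copy it returns shares the original's version counter (see its inline definition later in this diff), so in-place writes to the variable are still detectable when the saved value is unpacked. A hedged sketch of that property (the demo function is illustrative):

```cpp
#include <torch/torch.h>

// Sketch: the copy returned by tensor_data() shares the version counter,
// so bumping the original is visible through the copy.
void shared_version_counter_demo() {
  auto v = torch::ones({2});
  auto saved = torch::autograd::as_variable_ref(v).tensor_data();
  v.add_(1);  // in-place op bumps the counter that `saved` shares
}
```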
@@ -17,89 +17,21 @@
 #include <stdexcept>
 #include <string>
 #include <vector>
+#include <typeinfo>

 namespace torch {
 namespace autograd {
-Variable::Impl::Impl(at::Tensor data, std::unique_ptr<Variable::AutogradMeta> autograd_meta, bool requires_grad, Edge gradient_edge)
-    : TensorImpl(data.type_id(), data.dtype(), data.device()),
-      data_(std::move(data)) {
-  autograd_meta->grad_fn_ = std::move(gradient_edge.function);
-  autograd_meta->requires_grad_ = false;
-  autograd_meta->is_view_ = false;
-  autograd_meta->output_nr_ = gradient_edge.input_nr;
+Variable::AutogradMeta::AutogradMeta(at::TensorImpl* self_impl, bool requires_grad, Edge gradient_edge) {
+  grad_fn_ = std::move(gradient_edge.function);
+  requires_grad_ = false;
+  is_view_ = false;
+  output_nr_ = gradient_edge.input_nr;

   // set_requires_grad also checks error conditions.
-  autograd_meta->set_requires_grad(requires_grad, this);
+  set_requires_grad(requires_grad, self_impl);
   TORCH_CHECK(
-      !autograd_meta->grad_fn_ || !autograd_meta->requires_grad_,
+      !grad_fn_ || !requires_grad_,
       "requires_grad should be false if grad_fn is set");
-  if (!data_.defined()) {
-    throw std::runtime_error("data is undefined");
-  }
-
-  set_autograd_meta(std::move(autograd_meta));
 }

-Variable::Impl::~Impl() = default;
-
-int64_t Variable::Impl::numel() const {
-  return data_.numel();
-}
-
-IntArrayRef Variable::Impl::sizes() const {
-  return data_.sizes();
-}
-
-IntArrayRef Variable::Impl::strides() const {
-  return data_.strides();
-}
-
-bool Variable::Impl::is_contiguous(MemoryFormat memory_format) const {
-  return data_.is_contiguous(memory_format);
-}
-
-int64_t Variable::Impl::dim() const {
-  return data_.dim();
-}
-
-int64_t Variable::Impl::size(int64_t d) const {
-  return data_.size(d);
-}
-
-int64_t Variable::Impl::stride(int64_t d) const {
-  return data_.stride(d);
-}
-
-void Variable::Impl::resize_dim(int64_t ndim) {
-  AT_ERROR("variable impl does not have resize_dim");
-}
-
-void Variable::Impl::set_size(int64_t dim, int64_t new_size) {
-  AT_ERROR("variable impl does not have set_size");
-}
-
-void Variable::Impl::set_stride(int64_t dim, int64_t new_stride) {
-  AT_ERROR("variable impl does not have set_stride");
-}
-
-void Variable::Impl::set_storage_offset(int64_t storage_offset) {
-  AT_ERROR("variable impl does not have set_storage_offset");
-}
-
-void* Variable::Impl::slow_data() const {
-  return data_.unsafeGetTensorImpl()->slow_data();
-}
-
-bool Variable::Impl::has_storage() const {
-  return data_.has_storage();
-}
-
-const at::Storage& Variable::Impl::storage() const {
-  return data_.storage();
-}
-
-int64_t Variable::Impl::storage_offset() const {
-  return data_.storage_offset();
-}
-
 std::shared_ptr<Function> Variable::grad_accumulator() const {
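This is the heart of the PR: every `Variable::Impl` method above was a forwarding shim around an inner `data_` tensor, and all of it disappears because a `Variable` now uses the same `TensorImpl` as a plain `Tensor`. A small sketch of the consequence (the demo function is illustrative):

```cpp
#include <torch/torch.h>

// Sketch: shape queries on a Variable now hit TensorImpl directly; there is
// no Variable::Impl forwarding layer left to keep in sync.
void shape_query_demo() {
  auto v = torch::ones({2, 3}, torch::requires_grad());
  TORCH_CHECK(v.sizes().equals({2, 3}));
  TORCH_CHECK(v.dim() == 2 && v.numel() == 6);
}
```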
@@ -119,7 +51,7 @@ std::shared_ptr<Function> Variable::grad_accumulator() const {
     return result;

   c10::raw::intrusive_ptr::incref(unsafeGetTensorImpl());
-  auto intrusive_from_this = c10::intrusive_ptr<Variable::Impl>::reclaim(static_cast<Variable::Impl*>(unsafeGetTensorImpl()));
+  auto intrusive_from_this = c10::intrusive_ptr<at::TensorImpl>::reclaim(unsafeGetTensorImpl());
   result = std::make_shared<AccumulateGrad>(Variable(std::move(intrusive_from_this)));
   autograd_meta->grad_accumulator_ = result;
   return result;
@@ -145,56 +77,58 @@ void Variable::backward(

   std::vector<Variable> inputs;
   if (!gradient.has_value()) {
-    gradient = make_variable(at::ones_like(data()), /*requires_grad=*/false);
+    gradient = at::ones_like(*this);
   }
   inputs.push_back(std::move(as_variable_ref(*gradient)));
   Engine::get_default_engine().execute(edges, inputs, keep_graph, create_graph);
 }
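With the merge in place, ATen factory functions applied to a `Variable` yield a `Variable`, so the default gradient no longer needs the unwrap-and-rewrap dance through `.data()`. A minimal sketch (the helper name is illustrative):

```cpp
#include <torch/torch.h>

// Sketch: `at::ones_like` on a Variable returns a usable Variable directly,
// which is exactly what the new `backward` body relies on.
torch::Tensor default_gradient(const torch::Tensor& self) {
  return at::ones_like(self);
}
```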

-void Variable::Impl::set_data(const at::Tensor &new_data) {
+void Variable::set_data(const at::Tensor &new_data) {
+  // `var.set_data(new_data)` shallow-copies all non-autograd TensorImpl fields
+  // from `new_data` to `var`. It requires that `new_data` has the same derived
+  // type of TensorImpl as `var`.
+  TORCH_CHECK(
+    typeid(*(this->unsafeGetTensorImpl())) == typeid(*(new_data.unsafeGetTensorImpl())),
+    "Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have different types of TensorImpl.");
+
   // Resets gradient accumulator if metadata is out of date
-  auto autograd_meta = get_autograd_meta();
+  Variable::AutogradMeta* autograd_meta = get_autograd_meta();
   std::lock_guard<std::mutex> lock(autograd_meta->mutex_);
   auto prior_accumulator = autograd_meta->grad_accumulator_.lock();
   if (prior_accumulator) {
     const auto prior_device = prior_accumulator->input_metadata(0).device();
     const auto new_device = new_data.device();

-    if (new_data.type() != data_.type() || prior_device != new_device) {
+    if (new_data.type() != type() || prior_device != new_device) {
       autograd_meta->grad_accumulator_.reset();
     }
   }

-  // Updates metadata
-  data_type_ = new_data.type().typeMeta();
-  device_opt_ = new_data.device();
-  type_id_ = new_data.dispatch_type().type_id();
-
-  // Version counter is not shared when we replace a `Variable`'s underlying `Tensor`
+  // Version counter is not shared when we replace a `Variable`'s tensor data
   // by calling `set_data(...)`. The original version of the `Variable` is always preserved.
   // See NOTE [ Version Counter Sharing ] for details.
-  auto new_data_impl_copy = new_data.getIntrusivePtr()->shallow_copy_and_detach(
-      /*version_counter=*/data_.unsafeGetTensorImpl()->version_counter(),
-      /*allow_tensor_metadata_change=*/true);
-  data_ = std::move(at::Tensor(new_data_impl_copy));
+  //
+  // `var.set_data(new_data)` always ignores `var`'s `allow_tensor_metadata_change_`, because
+  // users need this API as an escape hatch for changing a tensor's metadata regardless of its
+  // `allow_tensor_metadata_change_` value, and the users are responsible for ensuring this is
+  // the behavior they want.
+  get()->shallow_copy_from(new_data.getIntrusivePtr());
 }
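The `typeid` guard above is what makes Use Case 1 from the commit message BC-breaking: `x.data = y` now requires both sides to carry the same derived `TensorImpl` type. A hedged C++ sketch (the demo function is illustrative, not part of the diff):

```cpp
#include <torch/torch.h>

// Sketch: set_data succeeds for matching TensorImpl types and throws for a
// dense/sparse mix, per the TORCH_CHECK in Variable::set_data above.
void set_data_demo() {
  auto x = torch::ones({2});
  auto y = torch::zeros({2});
  torch::autograd::as_variable_ref(x).set_data(y);  // ok: both dense CPU

  auto s = torch::zeros({2}).to_sparse();  // backed by SparseTensorImpl
  // The next line would now throw ("different types of TensorImpl"):
  // torch::autograd::as_variable_ref(x).set_data(s);
}
```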

-void Variable::Impl::release_resources() {
-  autograd_meta_.reset();
-  data_.reset();
-}
-
-Variable::DifferentiableViewImpl::DifferentiableViewImpl(Variable base, at::Tensor data, Edge gradient_edge, std::unique_ptr<Variable::DifferentiableViewMeta> autograd_meta)
-    : Variable::Impl(std::move(data), std::move(autograd_meta), false, std::move(gradient_edge)) {
-  auto diff_view_meta = static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta());
-  diff_view_meta->base_ = std::move(base);
-  TORCH_CHECK(diff_view_meta->base_.defined(), "base is undefined");
-  if (diff_view_meta->base_.is_view()) {
-    diff_view_meta->base_ = diff_view_meta->base_.base();
+Variable::DifferentiableViewMeta::DifferentiableViewMeta(at::TensorImpl* self_impl, Variable base, Edge gradient_edge)
+    : Variable::AutogradMeta(self_impl, false, std::move(gradient_edge)) {
+  base_ = std::move(base);
+  TORCH_CHECK(base_.defined(), "base is undefined");
+  if (base_.is_view()) {
+    base_ = base_.base();
   }
-  diff_view_meta->is_view_ = true;
-  data_.unsafeGetTensorImpl()->set_version_counter(diff_view_meta->base_.version_counter());
-  diff_view_meta->attr_version = data_.unsafeGetTensorImpl()->version_counter().current_version();
+  is_view_ = true;
+  self_impl->set_version_counter(base_.version_counter());
+  attr_version = self_impl->version_counter().current_version();
 }

+Variable::DifferentiableViewMeta::~DifferentiableViewMeta() {
+  base_.reset();
+}
+
 const std::shared_ptr<Function>& Variable::grad_fn() const {
@@ -211,7 +145,7 @@ const std::shared_ptr<Function>& Variable::grad_fn() const {
       fn->self_geometry = at::TensorGeometry(diff_view_meta->base_);
       fn->size = sizes().vec();
       fn->stride = strides().vec();
-      fn->storage_offset = data().storage_offset();
+      fn->storage_offset = storage_offset();
       fn->set_next_edges(collect_next_edges(diff_view_meta->base_));
       fn->add_input_metadata(
         diff_view_meta->base_.type()
@@ -226,12 +160,6 @@ const std::shared_ptr<Function>& Variable::grad_fn() const {
   }
 }

-void Variable::DifferentiableViewImpl::release_resources() {
-  auto diff_view_meta = static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta());
-  diff_view_meta->base_.reset();
-  Variable::Impl::release_resources();
-}
-
 void Variable::rebase_history(Edge gradient_edge) {
   AT_ASSERT(gradient_edge.function != nullptr);
   if (is_view()) {
@@ -243,7 +171,7 @@ void Variable::rebase_history(Edge gradient_edge) {
         "Functions which modify views in-place must return a single Variable");
     diff_view_meta->output_nr_ = gradient_edge.input_nr;
     auto copy_slices = std::make_shared<CopySlices>(
-        diff_view_meta->base_, at::TensorGeometry(data()), std::move(gradient_edge.function));
+        diff_view_meta->base_, at::TensorGeometry(*this), std::move(gradient_edge.function));
     diff_view_meta->base_.set_gradient_edge({std::move(copy_slices), 0});
     grad_fn(); // trigger an update to the view's grad_fn
   } else {
@@ -74,10 +74,10 @@ struct Function;
 /// can thus call functions defined on `Tensor`s also with `Variable`s. For
 /// this, the `Variable` class allows implicit construction from `Tensor`. It is
 /// the responsibility of calling code to ensure that this constructor is
-/// invoked only when the `Tensor`'s dynamic type is actually `Variable`. Most
-/// notably, it is *not* correct to construct a brand new `Variable` from a
-/// `Tensor` using this constructor. To do so, you must use the `make_variable`
-/// free function instead. To create a view variable, use `make_variable_view`.
+/// invoked only when the `Tensor` contains autograd metadata. Most notably, it
+/// is *not* correct to construct a brand new `Variable` from a `Tensor` using
+/// this constructor. To do so, you must use the `make_variable` free function
+/// instead. To create a view variable, use `make_variable_view`.
 ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 struct TORCH_API Variable : public at::Tensor {
@@ -87,8 +87,8 @@ struct TORCH_API Variable : public at::Tensor {
   // Factory Functions
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-  // NOTE: These factory functions have to be friends to access the
-  // `Variable::Impl`. As a side effect, it allows us to keep them in the class.
+  // TODO: These factory functions don't need to be friends anymore. Move them out of
+  // the Variable class.

   /// Creates a `Variable` that is a *view* of another (*base*) variable.
   /// The `gradient_edge` is an optional (gradient_function, input_number) pair.
@@ -122,8 +122,8 @@ struct TORCH_API Variable : public at::Tensor {
       bool requires_grad,
       bool allow_tensor_metadata_change);

-  /// Creates a `Variable` from the given `Tensor` and specify a
-  /// `gradient_edge`, i.e. a (function, input_nr) pair specifying the function
+  /// Creates a `Variable` from the given `Tensor`, copying its underlying `TensorImpl`.
+  /// `gradient_edge` should be a (function, input_nr) pair specifying the function
   /// in the autograd graph, and what particular input of that function, this
   /// variable is connected to.
   friend Variable make_variable(
@@ -151,8 +151,28 @@ struct TORCH_API Variable : public at::Tensor {

   // NOTE: Assignment operators to Tensor come for free from the constructors.

-  const at::Tensor& data() const noexcept;
-  at::Tensor& data() noexcept;
+  /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended
+  /// to be used from functions that need to access the `Variable`'s equivalent `Tensor`
+  /// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`).
+  ///
+  /// One notable difference with the legacy `.data()` function is that changes to the
+  /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset)
+  /// will not update the original `Variable`, due to the fact that this function
+  /// shallow-copies the `Variable`'s underlying TensorImpl.
+  at::Tensor tensor_data() const noexcept;
+
+  /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data`
+  /// in Python, which creates a new `Variable` that shares the same storage and
+  /// tensor metadata with the original `Variable`, but with a completely new
+  /// autograd history.
+  ///
+  /// NOTE: If we change the tensor metadata (e.g. sizes / strides /
+  /// storage / storage_offset) of a variable created from `var.variable_data()`, those
+  /// changes will not update the original variable `var`. In `.variable_data()`, we set
+  /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal,
+  /// in order to prevent users from changing metadata of `var.variable_data()`
+  /// and expecting the original variable `var` to also be updated.
+  at::Tensor variable_data() const noexcept;

   // Gradient Function and Edges
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
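These two declarations are the replacements for the removed `.data()` accessors. A hedged usage sketch contrasting them (the demo function is illustrative):

```cpp
#include <torch/torch.h>
using torch::autograd::as_variable_ref;

// Sketch of the two post-merge replacements for the removed `.data()`:
void data_accessors_demo() {
  auto v = torch::ones({2}, torch::requires_grad());

  // tensor_data(): same storage and version counter; meant for internal code
  // that needs the Variable's "tensor view" of itself.
  at::Tensor t = as_variable_ref(v).tensor_data();

  // variable_data(): same storage, brand-new autograd history, metadata
  // changes forbidden; matches Python's `tensor.data`.
  at::Tensor d = as_variable_ref(v).variable_data();
  (void)t; (void)d;
}
```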
@@ -305,18 +325,13 @@ struct TORCH_API Variable : public at::Tensor {
   Variable::AutogradMeta* get_autograd_meta() const noexcept;

  private:
-  /// Private implementation struct of the `Variable`. This struct declaration
-  /// and the `get()` method which exposes it shall forever remain private and
-  /// never be exposed to the public interface of this class.
-  struct Impl;
-  struct DifferentiableViewImpl;
+  struct DifferentiableViewMeta;

   // Private Methods
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-  Variable(c10::intrusive_ptr<Variable::Impl> self);
-  Impl* get() const;
+  Variable(c10::intrusive_ptr<at::TensorImpl> self);
+  at::TensorImpl* get() const;
 };

 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -373,73 +388,15 @@ struct TORCH_API Variable::AutogradMeta : public c10::AutogradMetaInterface {
   const Variable& grad() const override {
     return grad_;
   }
+
+  AutogradMeta(
+      at::TensorImpl* self_impl,
+      bool requires_grad = false,
+      Edge gradient_edge = Edge());
 };

-//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-// Variable::DifferentiableViewMeta
-//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-struct TORCH_API Variable::DifferentiableViewMeta : public Variable::AutogradMeta {
-  /// The base `Variable` (never a view).
-  Variable base_;
-
-  /// The value of the version_counter at the time grad_fn was created. The
-  /// grad_fn field is stale if attr_version !=
-  /// version_counter.current_version().
-  uint32_t attr_version;
-
-  bool requires_grad() const override {
-    return requires_grad_ || grad_fn_ || (is_view_ && base_.requires_grad());
-  }
-};
-
-//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-// Variable::Impl
-//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-struct TORCH_API Variable::Impl : public at::TensorImpl {
-  explicit Impl(
-      at::Tensor data,
-      std::unique_ptr<Variable::AutogradMeta> autograd_meta,
-      bool requires_grad = false,
-      Edge gradient_edge = Edge());
-
-  ~Impl() override;
-
-  int64_t numel() const override;
-  at::IntArrayRef sizes() const override;
-  at::IntArrayRef strides() const override;
-  bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Any) const override;
-  int64_t size(int64_t d) const override;
-  int64_t stride(int64_t d) const override;
-  void resize_dim(int64_t ndim) override;
-  void set_size(int64_t dim, int64_t new_size) override;
-  void set_stride(int64_t dim, int64_t new_stride) override;
-  void set_storage_offset(int64_t storage_offset) override;
-
-  int64_t dim() const override;
-  bool has_storage() const override;
-  const at::Storage& storage() const override;
-  void* slow_data() const override;
-
-  void set_data(const at::Tensor &new_data);
-
-  /// Reset all expensive fields to free up resources
-  void release_resources() override;
-
-  Variable::AutogradMeta* get_autograd_meta() const {
-    return static_cast<Variable::AutogradMeta*>(autograd_meta());
-  }
-
-  int64_t storage_offset() const override;
-
-  /// The underlying data tensor for this Variable.
-  /// This field will be removed once VariableImpl and TensorImpl are merged.
-  at::Tensor data_;
-};
-
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-// Variable::DifferentiableViewImpl
+// Variable::DifferentiableViewMeta
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 /// NOTE [ Autograd View Variables ]
@@ -482,7 +439,7 @@ struct TORCH_API Variable::Impl : public at::TensorImpl {
 ///                 var[1] filled with all ones and
 ///                 zeros everywhere else
 ///
-/// Variable::DifferentiableViewImpl is created to support gradient tracking of
+/// Variable::DifferentiableViewMeta is created to support gradient tracking of
 /// such **in-place** operations. In particular,
 ///   + if an in-place op is done on base, the grad_fn field of the view may
 ///     become stale. So accesses should always go through grad_fn(), which
@@ -497,8 +454,8 @@ struct TORCH_API Variable::Impl : public at::TensorImpl {
 /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 /// In certain cases, although function outputs share storage with inputs, they
 /// will **never** require gradient history tracking. Instead of registering the
-/// view relation via DifferentiableViewImpl in autograd, the views will be
-/// using usual Variable::Impl and just share the version counters with the base
+/// view relation via DifferentiableViewMeta in autograd, the views will be
+/// using usual AutogradMeta and just share the version counters with the base
 /// Variables.
 /// Such views include:
 ///   1. Views created from .detach()
@@ -511,15 +468,21 @@ struct TORCH_API Variable::Impl : public at::TensorImpl {
 /// through the view relation.
 /// Relevant logic for non-differentiable views is implemented in
 /// make_variable_view below, and wrap_output of gen_variable_type.py.
-struct TORCH_API Variable::DifferentiableViewImpl : public Variable::Impl {
-  DifferentiableViewImpl(
-      Variable base,
-      at::Tensor data,
-      Edge gradient_edge,
-      std::unique_ptr<Variable::DifferentiableViewMeta> autograd_meta);
+struct TORCH_API Variable::DifferentiableViewMeta : public Variable::AutogradMeta {
+  /// The base `Variable` (never a view).
+  Variable base_;

-  /// Reset all expensive fields to free up resources
-  void release_resources() override;
+  /// The value of the version_counter at the time grad_fn was created. The
+  /// grad_fn field is stale if attr_version !=
+  /// version_counter.current_version().
+  uint32_t attr_version;
+
+  bool requires_grad() const override {
+    return requires_grad_ || grad_fn_ || (is_view_ && base_.requires_grad());
+  }
+
+  DifferentiableViewMeta(at::TensorImpl* self_impl, Variable base, Edge gradient_edge);
+  ~DifferentiableViewMeta();
 };

 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -545,24 +508,21 @@ inline Variable make_variable_view(
     Edge gradient_edge = Edge()) {
   if (data.defined()) {
     if (is_differentiable) {
-      /// Differentiable view. Track history with DifferentiableViewImpl.
+      /// Differentiable view. Track history with DifferentiableViewMeta.
       auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
         /*version_counter=*/0,
         /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
-      auto data_copy = at::Tensor(data_impl_copy);
-      auto diff_view_meta = c10::guts::make_unique<Variable::DifferentiableViewMeta>();
-      return Variable(c10::make_intrusive<Variable::DifferentiableViewImpl>(
-        std::move(base), std::move(data_copy), std::move(gradient_edge), std::move(diff_view_meta)));
+      data_impl_copy->set_autograd_meta(c10::guts::make_unique<Variable::DifferentiableViewMeta>(
+        data_impl_copy.get(), std::move(base), std::move(gradient_edge)));
+      return Variable(data_impl_copy);
     } else {
       /// Non-differentiable view. Just share version counter.
       auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
         /*version_counter=*/base.version_counter(),
         /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
-      auto data_copy = at::Tensor(data_impl_copy);
-      auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>();
-      auto var = Variable(c10::make_intrusive<Variable::Impl>(
-        std::move(data_copy), std::move(autograd_meta), false, std::move(gradient_edge)));
-      return var;
+      data_impl_copy->set_autograd_meta(c10::guts::make_unique<Variable::AutogradMeta>(
+        data_impl_copy.get(), false, std::move(gradient_edge)));
+      return Variable(data_impl_copy);
     }
   }
   return Variable();
@@ -574,14 +534,14 @@ inline Variable make_variable(
     bool allow_tensor_metadata_change = true) {
   TORCH_CHECK(
       !data.is_variable(),
-      "Must not create a new variable from a variable, use its .data()");
+      "Must not create a new variable from a variable, use its .tensor_data()");
   if (data.defined()) {
     auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
       /*version_counter=*/0,
       /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
-    auto data_copy = at::Tensor(data_impl_copy);
-    auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>();
-    return Variable(c10::make_intrusive<Variable::Impl>(data_copy, std::move(autograd_meta), requires_grad));
+    data_impl_copy->set_autograd_meta(c10::guts::make_unique<Variable::AutogradMeta>(
+      data_impl_copy.get(), requires_grad));
+    return Variable(data_impl_copy);
   }
   return Variable();
 }
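All of these factories now follow one pattern: detach-copy the `TensorImpl`, attach an `AutogradMeta` (or `DifferentiableViewMeta`) that points back at it, and wrap the result. From the caller's side nothing changes except the precondition spelled out in the `TORCH_CHECK`. A hedged sketch (the helper name is illustrative):

```cpp
#include <torch/torch.h>

// Sketch: make_variable still takes a plain (non-Variable) at::Tensor; a
// Variable input must be unwrapped with tensor_data() first.
torch::autograd::Variable rewrap(const torch::autograd::Variable& v, bool requires_grad) {
  return torch::autograd::make_variable(v.tensor_data(), requires_grad);
}
```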
@@ -592,12 +552,13 @@ inline Variable make_variable_consuming(
     bool allow_tensor_metadata_change = true) {
   TORCH_CHECK(
       !data.is_variable(),
-      "Must not create a new variable from a variable, use its .data()");
+      "Must not create a new variable from a variable, use its .tensor_data()");
   if (data.defined()) {
     AT_ASSERT(data.getIntrusivePtr().use_count() == 1);
-    data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
-    auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>();
-    return Variable(c10::make_intrusive<Variable::Impl>(std::move(data), std::move(autograd_meta), requires_grad));
+    auto data_impl = data.getIntrusivePtr();
+    data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
+    data_impl->set_autograd_meta(c10::guts::make_unique<Variable::AutogradMeta>(data_impl.get(), requires_grad));
+    return Variable(std::move(data_impl));
   }
   return Variable();
 }
@@ -608,14 +569,14 @@ inline Variable make_variable(
     bool allow_tensor_metadata_change = true) {
   TORCH_CHECK(
       !data.is_variable(),
-      "Must not create a new variable from a variable, use its .data()");
+      "Must not create a new variable from a variable, use its .tensor_data()");
   if (data.defined()) {
     auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
       /*version_counter=*/0,
       /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
-    auto data_copy = at::Tensor(data_impl_copy);
-    auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>();
-    return Variable(c10::make_intrusive<Variable::Impl>(data_copy, std::move(autograd_meta), false, std::move(gradient_edge)));
+    data_impl_copy->set_autograd_meta(c10::guts::make_unique<Variable::AutogradMeta>(
+      data_impl_copy.get(), false, std::move(gradient_edge)));
+    return Variable(data_impl_copy);
   }
   return Variable();
 }
@@ -642,12 +603,19 @@ inline const Variable& as_variable_ref(const at::Tensor& tensor) {
   return static_cast<const Variable&>(tensor);
 }

-inline const at::Tensor& Variable::data() const noexcept {
-  return get()->data_;
+inline at::Tensor Variable::tensor_data() const noexcept {
+  auto self_impl_copy = get()->shallow_copy_and_detach(
+    /*version_counter=*/get()->version_counter(),
+    /*allow_tensor_metadata_change=*/get()->allow_tensor_metadata_change());
+  return at::Tensor(self_impl_copy);
 }

-inline at::Tensor& Variable::data() noexcept {
-  return get()->data_;
+inline at::Tensor Variable::variable_data() const noexcept {
+  auto self_impl_copy = get()->shallow_copy_and_detach(
+    /*version_counter=*/0,
+    /*allow_tensor_metadata_change=*/false);
+  self_impl_copy->set_autograd_meta(c10::guts::make_unique<Variable::AutogradMeta>(self_impl_copy.get(), false));
+  return at::Tensor(self_impl_copy);
 }

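The two definitions above also pin down the version-counter contract: `tensor_data()` shares the counter (passing `get()->version_counter()` into the copy), while `variable_data()` resets it to zero. A hedged sketch (the demo function is illustrative):

```cpp
#include <torch/torch.h>

// Sketch of the version-counter contract implied by the definitions above.
void version_counter_demo() {
  auto v = torch::ones({2});
  auto& var = torch::autograd::as_variable_ref(v);
  auto t = var.tensor_data();   // shares v's version counter
  auto d = var.variable_data(); // fresh counter starting at 0
  v.add_(1);                    // bump is visible through `t`, not `d`
  (void)t; (void)d;
}
```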
 // Gradient Function and Edges
@@ -667,14 +635,10 @@ inline std::shared_ptr<Function> Variable::try_get_grad_accumulator() const {
 }

 inline Variable Variable::detach() const {
-  auto var = make_variable_view(*this, get()->data_, /*is_differentiable=*/false, /*allow_tensor_metadata_change=*/false, Edge());
+  auto var = make_variable_view(*this, *this, /*is_differentiable=*/false, /*allow_tensor_metadata_change=*/false, Edge());
   return var;
 }

-inline void Variable::set_data(const at::Tensor &new_data) {
-  get()->set_data(new_data);
-}
-
 inline void Variable::set_gradient_edge(Edge edge) noexcept {
   get_autograd_meta()->grad_fn_ = std::move(edge.function);
   get_autograd_meta()->output_nr_ = edge.input_nr;
@@ -693,19 +657,19 @@ inline bool Variable::is_leaf() const noexcept {

 inline void Variable::set_version_counter(
     const c10::VariableVersion& version_counter) noexcept {
-  data().unsafeGetTensorImpl()->set_version_counter(version_counter);
+  unsafeGetTensorImpl()->set_version_counter(version_counter);
 }

 inline void Variable::bump_version() noexcept {
-  data().unsafeGetTensorImpl()->bump_version();
+  unsafeGetTensorImpl()->bump_version();
 }

 inline uint32_t Variable::current_version() const noexcept {
-  return data().unsafeGetTensorImpl()->version_counter().current_version();
+  return unsafeGetTensorImpl()->version_counter().current_version();
 }

 inline const c10::VariableVersion& Variable::version_counter() const noexcept {
-  return data().unsafeGetTensorImpl()->version_counter();
+  return unsafeGetTensorImpl()->version_counter();
 }

 // Hooks
@@ -760,17 +724,17 @@ inline PyObject* Variable::pyobj() const noexcept {
 }

 inline Variable::AutogradMeta* Variable::get_autograd_meta() const noexcept {
-  return get()->get_autograd_meta();
+  return static_cast<Variable::AutogradMeta*>(get()->autograd_meta());
 }

 // Private Methods
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-inline Variable::Variable(c10::intrusive_ptr<Variable::Impl> self)
+inline Variable::Variable(c10::intrusive_ptr<at::TensorImpl> self)
     : at::Tensor(std::move(self)) {}

-inline Variable::Impl* Variable::get() const {
+inline at::TensorImpl* Variable::get() const {
   TORCH_CHECK(defined(), "Called Variable::get() on an undefined Variable");
-  return static_cast<Variable::Impl*>(impl_.get());
+  return unsafeGetTensorImpl();
 }
 }} // namespace torch::autograd
@@ -136,7 +136,7 @@ PyObject * THCPModule_getRNGState(PyObject *_unused)
   using namespace torch::autograd;
   HANDLE_TH_ERRORS
   Variable var = torch::empty(0, at::device(at::kCPU).dtype(at::kByte));
-  THCRandom_getRNGState(state, (THByteTensor*)(var.data().unsafeGetTensorImpl()));
+  THCRandom_getRNGState(state, (THByteTensor*)(var.unsafeGetTensorImpl()));
   return THPVariable_Wrap(var);
   END_HANDLE_TH_ERRORS
 }
@@ -150,7 +150,7 @@ PyObject * THCPModule_setRNGState(PyObject *_unused, PyObject *obj)
     throw TypeError("set_rng_state expects a torch.ByteTensor, but got %s",
                     Py_TYPE(obj)->tp_name);
   }
-  auto& tensor = THPVariable_UnpackData(obj);
+  auto& tensor = THPVariable_Unpack(obj);
   THCRandom_setRNGState(state, (THByteTensor*)tensor.unsafeGetTensorImpl());
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
@@ -144,7 +144,7 @@ tensor_list2d broadcast_coalesced(TensorList tensors, IntArrayRef devices, size_
         // See NOTE [ Version Counter in comm.*_coalesced ]
         AT_ASSERT(t.is_variable());
         Variable var = t;
-        device_outputs.push_back(make_variable(var.data(), false));
+        device_outputs.push_back(make_variable(var.tensor_data(), false));
       }
     }
   } else {
@@ -157,7 +157,7 @@ tensor_list2d broadcast_coalesced(TensorList tensors, IntArrayRef devices, size_
         // See NOTE [ Version Counter in comm.*_coalesced ]
         AT_ASSERT(t.is_variable());
         Variable var = t;
-        device_outputs.push_back(make_variable(var.data(), false));
+        device_outputs.push_back(make_variable(var.tensor_data(), false));
       }
     }
   }
@@ -374,7 +374,7 @@ static std::vector<at::Tensor> extract_tensors(PyObject* obj) {
           "expected Tensor at %d (got %s)", (int)i, Py_TYPE(item)->tp_name);
     }
     auto var = (THPVariable*)item;
-    list.emplace_back(var->cdata.data());
+    list.emplace_back(var->cdata);
   }
   return list;
 }
@@ -108,7 +108,7 @@ extern PyObject* THCPByteTensorClass;

 THDTensorDescriptor THDPModule_makeDescriptor(PyObject *obj) {
   auto var = (THPVariable*)obj;
-  return var->cdata.data();
+  return var->cdata.tensor_data();
 }

 static THDRequest* _unpackRequest(PyObject *obj)
@@ -40,7 +40,7 @@ std::vector<IValue> runNode(Node* n) {
       // error gets caught within propagateNode()
       throw c10::Error("Can't insert requires grad as constant", "");
     }
-    return IValue(autograd::as_variable_ref(t).data());
+    return IValue(t);
   } else {
     return t;
   }
@@ -12,7 +12,7 @@ at::Tensor unwrap_tensor(at::Tensor&& tensor) {
     throw std::runtime_error("Autograd not yet supported for c10 ops.");
   }
   if (tensor.is_variable()) {
-    return torch::autograd::Variable(std::move(tensor)).data();
+    return torch::autograd::Variable(std::move(tensor)).tensor_data();
   } else {
     return std::move(tensor);
   }
@@ -64,7 +64,7 @@ void checkImplicitTensorToNum(at::Tensor t, bool toInt) {
         "Cannot input a tensor of dimension other than 0 as a scalar argument");
   }
   if (toInt &&
-      !isIntegralType(autograd::as_variable_ref(t).data().scalar_type())) {
+      !isIntegralType(t.scalar_type())) {
     std::stringstream ss;
     ss << "Cannot input a tensor of type " << t.scalar_type()
        << " as an integral argument";
@@ -28,7 +28,7 @@ using caffe2::int8::Int8TensorCPU;
 namespace {

 caffe2::Tensor from_at_tensor(const c10::IValue& v) {
-  return caffe2::Tensor(autograd::Variable(std::move(v).toTensor()).data());
+  return caffe2::Tensor(autograd::Variable(std::move(v).toTensor()).tensor_data());
 }

 Int8TensorCPU from_proxy(const c10::IValue& proxy) {
@@ -108,11 +108,10 @@ void module_state_to(
     bool non_blocking) {
   // Need to access the `at::Tensor` as a `Variable` here.
   autograd::Variable variable = s.value().toTensor();
-  at::Tensor data = variable.data();
   // Use the data's original device or dtype if not supplied here.
-  auto new_data = data.to(
-      device.value_or(data.device()),
-      dtype.value_or(data.scalar_type()),
+  auto new_data = variable.to(
+      device.value_or(variable.device()),
+      dtype.value_or(variable.scalar_type()),
       non_blocking);
   variable.set_data(new_data);
 }
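This hunk shows the simplification pattern repeated across the call sites above: because a `Variable` is now a `Tensor`, conversion helpers call `.to(...)` on the variable itself and swap the result in with `set_data`, with no `.data()` detour. A hedged sketch (the helper name is illustrative):

```cpp
#include <torch/torch.h>

// Sketch: move a parameter to CPU in place, mirroring module_state_to above.
void to_device_demo(torch::Tensor param) {
  auto moved = param.to(torch::kCPU, param.scalar_type(), /*non_blocking=*/false);
  torch::autograd::as_variable_ref(param).set_data(moved);
}
```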
@@ -19,7 +19,7 @@ inline bool check_type(PyObject* obj, at::TensorTypeId id, at::ScalarType dtype)

 template<typename T>
 inline T* unpack(PyObject* obj) {
-  return (T*) ((THPVariable*)obj)->cdata.data().unsafeGetTensorImpl();
+  return (T*) ((THPVariable*)obj)->cdata.unsafeGetTensorImpl();
 }

 }} // namespace torch::nn
@@ -28,7 +28,7 @@ static inline int get_device(PyObject* args) {
   for (int i = 0, n = PyTuple_GET_SIZE(args); i != n; i++) {
     PyObject* arg = PyTuple_GET_ITEM(args, i);
     if (THPVariable_Check(arg)) {
-      auto& tensor = THPVariable_UnpackData(arg);
+      auto& tensor = THPVariable_Unpack(arg);
       if (tensor.is_cuda()) {
         return tensor.get_device();
       }