Summary:
As part of the Variable/Tensor merge work: https://github.com/pytorch/pytorch/issues/13638, we make the following changes in this PR:
1. Remove the `Variable::Impl` class and the `DifferentiableViewImpl` class
2. Change all `Variable.data()` call sites to either use `Variable` directly, or use `Variable.tensor_data()`
3. Remove the `Variable.data()` API
4. Add `Variable.variable_data()` that matches `tensor.data` in the Python API: it creates a new `Variable` that shares the same storage and tensor metadata with the original `Variable`, but with a completely new autograd history (a short illustration follows below).
After this PR, Variable no longer wraps a Tensor internally; both Variable and Tensor use the same TensorImpl class as their `impl_`. The only difference is that Variable always has AutogradMeta in its TensorImpl, while Tensor doesn't.
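To illustrate items 2-4 above, here is a small, hedged sketch of the Python `tensor.data` semantics that `variable_data()` mirrors (behavior as in current PyTorch, not taken from this PR's tests):
```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x.data              # shares storage and tensor metadata with x ...
y[0] = 5.0              # ... so in-place changes are visible through x
print(x)                # tensor([5., 2.], requires_grad=True)
print(y.requires_grad)  # False: y starts with a completely new autograd history
```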
**Note that this PR is BC-breaking in the following use cases:**
**Use Case 1:**
Previously, `x.data = y` worked even if `x` and `y` were of different TensorImpl types (e.g. `x` is a CPU dense tensor whose impl is of type TensorImpl, while `y` is a CPU sparse tensor whose impl is of type SparseTensorImpl). After this PR, `x.data = y` no longer works if `x` and `y` have different TensorImpl types, because the underlying implementation `variable.set_data(tensor)` no longer supports mismatched TensorImpl types.
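A minimal, hedged sketch of the affected pattern (the exact error raised after this PR may differ):
```python
import torch

x = torch.zeros(2, 3)              # dense CPU tensor (TensorImpl)
y = torch.zeros(2, 3).to_sparse()  # sparse CPU tensor (SparseTensorImpl)

x.data = y  # worked before this PR; now fails because variable.set_data(tensor)
            # rejects mismatched TensorImpl types
```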
**Use Case 2:**
If a tensor `x`'s `grad` is sparse, accumulating dense gradients to `x` will change the tensor that `x.grad` is pointing to. This is better illustrated with the following example:
```python
import torch

params = torch.tensor([1.5, 1.5]).requires_grad_()
with torch.no_grad():
    # Change gradient to a sparse tensor
    params.grad = torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.]))

grad_saved = params.grad
params.backward(torch.tensor([1.5, 1.5]))
assert id(grad_saved) == id(params.grad)  # This will fail after this PR
```
The assertion in the last line will fail after this PR, because adding dense gradients to sparse gradients will change the `params.grad` tensor reference.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17072
Differential Revision: D14075257
Pulled By: yf225
fbshipit-source-id: 0e681df641270dea586042dd26db59f2e76b5957
Summary:
Anywhere we used #include "foo.h", we now say #include <foo.h>
Paths are adjusted to be rooted out of aten/src, torch/lib, or
the root level directory.
I modified CMakeLists.txt by hand to remove TH and THC from
the include paths.
I used the following script to do the canonicalization:
```
import subprocess
import re
import os.path

files = subprocess.check_output(['git', 'ls-files']).decode('utf-8').rstrip().split('\n')
for fn in files:
    if not any(fn.endswith(suff) for suff in ['.cu', '.cpp', '.in', '.h', '.hpp', '.cu', '.cuh', '.cc']):
        continue
    if not any(fn.startswith(pref) for pref in ["aten/", "torch/"]):
        continue
    with open(fn, 'r') as f:
        c = f.read()
    def fmt(p):
        return "#include <{}>".format(p)
    def repl(m):
        p = m.group(1)
        if p in ["dlfcn.h", "unistd.h", "nvrtc.h", "cuda.h", "cuda_runtime.h", "cstdint", "cudnn.h", "Python.h", "cusparse.h", "cuda_runtime_api.h", "cuda_fp16.h", "cublas_v2.h", "stdint.h", "curand_kernel.h"]:
            return fmt(p)
        if any(p.startswith(pref) for pref in ["torch/csrc", "c10/", "ATen/", "caffe2/", "TH/", "THC/", "Eigen/", "gtest/", "zdl/", "gloo/", "onnx/", "miopen/"]):
            return fmt(p)
        for root in ["aten/src", "torch/lib", ""]:
            for bad_root in [os.path.dirname(fn), "aten/src/TH", "aten/src/THC", "torch/csrc"]:
                new_p = os.path.relpath(os.path.join(bad_root, p), root)
                if not new_p.startswith("../") and (os.path.exists(os.path.join(root, new_p)) or os.path.exists(os.path.join(root, new_p + ".in"))):
                    return fmt(new_p)
        print("ERROR: ", fn, p)
        return m.group(0)
    new_c = re.sub(r'#include "([^"]+)"', repl, c)
    if new_c != c:
        print(fn)
        with open(fn, 'w') as f:
            f.write(new_c)
```
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14849
Reviewed By: dzhulgakov
Differential Revision: D13363445
Pulled By: ezyang
fbshipit-source-id: 52361f878a672785f9306c9e9ab2513128092b68
Summary:
There's a bunch of legacy code where people explicitly instantiate Variable, and these call sites have thus far been untraceable (they appear as prim::Constant nodes holding the tensor's value at the time of tracing). This change makes the new variable inherit the traced Value* from the tensor it is being constructed from.
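A hedged sketch of the kind of legacy call site affected, written against the present-day `torch.jit.trace` signature (the exact graph printout depends on the PyTorch version):
```python
import torch
from torch.autograd import Variable

def f(x):
    # Legacy pattern: explicitly re-wrap a tensor in a Variable.
    # Before this change the tracer recorded `v` as a prim::Constant frozen at
    # trace time; now `v` inherits x's traced Value*, so the traced graph
    # remains a function of the input.
    v = Variable(x)
    return v * 2

traced = torch.jit.trace(f, torch.ones(3))
print(traced.graph)
```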
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11463
Differential Revision: D9756529
Pulled By: jamesr66a
fbshipit-source-id: da99c6a7621957a305f2699ec9cb9def69b1b2d7
Summary:
This PR extends the existing type and shape metadata tracing and verification done in autograd with device information. This expansion of tracing is required for #8354, is likely useful in other scenarios, and is a healthy sanity check, just like type and shape tracing.
The precise changes are (a rough sketch of the expanded metadata check follows the list):
- TypeAndShape -> InputMetadata, now includes device()
- Creating InputMetadata is simplified to just require a tensor, and callers were updated to use this simpler invocation wherever possible
- The gradient accumulator of a variable is now reset when set_data() is called if either the type or device changes, and this reset now locks to avoid contention with acquiring the gradient accumulator
- Mismatched devices during backward() will throw a runtime error, just like mismatched type and shape
- (Bonus!) Two uninitialized pointers in THCReduce are now initialized (to nullptr) to prevent build warnings
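A rough, Python-level sketch of the idea (hypothetical names; the real InputMetadata lives in C++ autograd and its exact fields and checks differ):
```python
from dataclasses import dataclass
from typing import Tuple

@dataclass
class InputMetadata:
    """Per-input record captured at forward time (sketch only)."""
    dtype: str                # e.g. "torch.float32"
    shape: Tuple[int, ...]    # sizes recorded for the input
    device: str               # NEW in this PR: e.g. "cpu" or "cuda:0"

    def check_grad(self, grad_dtype, grad_shape, grad_device):
        # Type/shape mismatches already raised; device mismatches now do too.
        if grad_dtype != self.dtype or tuple(grad_shape) != self.shape:
            raise RuntimeError("type or shape mismatch in backward()")
        if grad_device != self.device:
            raise RuntimeError("device mismatch in backward()")
```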
fyi colesbury
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9796
Reviewed By: goldsborough
Differential Revision: D9119325
Pulled By: ezyang
fbshipit-source-id: 76d1861b8d4f74db0575ff1f3bd965e18f9463de
* Bag of fixes
* Rename tensor_range.h to tensor_list_view.h
* Post rebase fixes
* Rename torch::tensor namespace to torch::tensors due to name conflict
* Avoid recursion in Module::to
This changes type(tensor) to return `torch.Tensor` instead of `torch.autograd.Variable` (a short illustration follows the list of implementation changes below).
This requires a few implementation changes:
- torch.Tensor is now a regular Python class instead of a
pseudo-factory like torch.FloatTensor/torch.DoubleTensor
- torch.autograd.Variable is just a shell with a __new__ function.
Since no instances are constructed it doesn't have any methods.
- Adds torch.get_default_dtype() since torch.Tensor.dtype returns
<attribute 'dtype' of 'torch._C._TensorBase' objects>
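A quick, hedged illustration of the resulting behavior (as observed in current PyTorch; the exact repr may vary):
```python
import torch

t = torch.zeros(3)
print(type(t))                                 # <class 'torch.Tensor'>, not Variable
print(isinstance(t, torch.autograd.Variable))  # True: the Variable shell keeps isinstance checks working
print(torch.get_default_dtype())               # torch.float32 unless the default was changed
```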