The previous PRs built up to this. We change compiled autograd's initial
trace to stop baking in metadata.
While tracing, we allocate some weirdly shaped tensors that we can put
proxies on. The initial trace should not be accessing any metadata of
these tensors (it will likely error out if it does because of how weird
the shapes are).
This involved fixing some various sites where we do specialize on the
metadata, like:
- we change CopySlices's apply_with_saved to proxy some calls
into the graph (this change is fairly hard to split out by itself).
- we stop calling InputBuffer::add
- we delete the weird metadata from the graph so that no graph passes
can make use of it.
Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143417
Approved by: https://github.com/jansel, https://github.com/xmfan
ghstack dependencies: #143296, #143304, #143387, #143405
Replaces `view_func()` closures with a reified `ViewFunc` data structure. Codegen generates a `ViewFunc` subclass for each view op (e.g. `NarrowViewFunc`) containing state needed to reconstruct the view. The `ViewFunc` API allows for querying and hot-swapping any `SymInt`s or `Tensors` in the state through `get_symints()` / `get_tensors()` / `clone_and_set()`, which will be essential for fake-ification later on.
```cpp
/// Base class for view functions, providing reapplication of a view on a new base.
/// Each view op should get a codegenerated subclass of this class containing
/// any state needed to reconstruct the view. The class also provides convenience
/// accessors for saved SymInts / tensor state. This is useful for e.g. fake-ification,
/// where we want to use symbolic values or fake tensors instead.
struct TORCH_API ViewFunc {
virtual ~ViewFunc() {}
/// Returns any SymInts in the saved state.
virtual std::vector<c10::SymInt> get_symints() const { return {}; }
/// Returns the number of SymInts in the saved state.
virtual size_t num_symints() const { return 0; }
/// Returns any tensors in the saved state.
virtual std::vector<at::Tensor> get_tensors() const { return {}; }
/// Returns the number of tensors in the saved state.
virtual size_t num_tensors() const { return 0; }
/// Reapplies the view on the given base using the saved state.
virtual at::Tensor operator()(const at::Tensor&) const = 0;
/// Returns a clone of this ViewFunc, optionally with the specified saved state.
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const = 0;
protected:
/// Sets the values of any SymInts in the saved state. The input vector size must
/// match the number of SymInts in the saved state (i.e. the size of the list
/// returned by get_symints()).
virtual void set_symints(std::vector<c10::SymInt>) {}
/// Sets the values of any Tensors in the saved state. The input vector size must
/// match the number of Tensors in the saved state (i.e. the size of the list
/// returned by get_tensors()).
virtual void set_tensors(std::vector<at::Tensor>) {}
};
```
New codegen files:
* `torch/csrc/autograd/generated/ViewFunc.h`
* `torch/csrc/autograd/generated/ViewFuncs.cpp`
The templates for these also contains impls for `ChainedViewFunc` and `ErroringViewFunc` which are used in a few places within autograd.
Example codegen for `slice.Tensor`:
```cpp
// torch/csrc/autograd/generated/ViewFuncs.h
#define SLICE_TENSOR_VIEW_FUNC_AVAILABLE
struct SliceTensorViewFunc : public torch::autograd::ViewFunc {
SliceTensorViewFunc(int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) : dim(dim), start(start), end(end), step(step)
{};
virtual ~SliceTensorViewFunc() override {};
virtual std::vector<c10::SymInt> get_symints() const override;
virtual size_t num_symints() const override;
virtual std::vector<at::Tensor> get_tensors() const override;
virtual size_t num_tensors() const override;
virtual at::Tensor operator()(const at::Tensor&) const override;
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;
protected:
virtual void set_symints(std::vector<c10::SymInt>) override;
virtual void set_tensors(std::vector<at::Tensor>) override;
private:
int64_t dim;
c10::optional<c10::SymInt> start;
c10::optional<c10::SymInt> end;
c10::SymInt step;
};
...
// torch/csrc/autograd/generated/ViewFuncs.cpp
std::vector<c10::SymInt> SliceTensorViewFunc::get_symints() const {
::std::vector<c10::SymInt> symints;
symints.reserve((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
if(start.has_value()) symints.insert(symints.end(), *(start));
if(end.has_value()) symints.insert(symints.end(), *(end));
symints.push_back(step);
return symints;
}
size_t SliceTensorViewFunc::num_symints() const {
return static_cast<size_t>((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
}
void SliceTensorViewFunc::set_symints(std::vector<c10::SymInt> symints) {
TORCH_INTERNAL_ASSERT(symints.size() == num_symints());
auto i = 0;
if(start.has_value()) start = symints[i];
i += (start.has_value() ? 1 : 0);
if(end.has_value()) end = symints[i];
i += (end.has_value() ? 1 : 0);
step = symints[i];
}
std::vector<at::Tensor> SliceTensorViewFunc::get_tensors() const {
::std::vector<at::Tensor> tensors;
return tensors;
}
size_t SliceTensorViewFunc::num_tensors() const {
return static_cast<size_t>(0);
}
void SliceTensorViewFunc::set_tensors(std::vector<at::Tensor> tensors) {
TORCH_INTERNAL_ASSERT(tensors.size() == num_tensors());
}
at::Tensor SliceTensorViewFunc::operator()(const at::Tensor& input_base) const {
return at::_ops::slice_Tensor::call(input_base, dim, start, end, step);
}
std::unique_ptr<ViewFunc> SliceTensorViewFunc::clone_and_set(
std::optional<std::vector<c10::SymInt>> symints,
std::optional<std::vector<at::Tensor>> tensors) const {
auto output = std::make_unique<SliceTensorViewFunc>(dim, start, end, step);
if (symints.has_value()) {
output->set_symints(std::move(*(symints)));
}
if (tensors.has_value()) {
output->set_tensors(std::move(*(tensors)));
}
return output;
}
```
The `_view_func()` / `_view_func_unsafe()` methods now accept two additional (optional) args for `symint_visitor_fn` / `tensor_visitor_fn`. If these are defined, they are expected to be python callables that operate on a single SymInt / tensor and return a new one. This allows for the hot-swapping needed during fake-ification.
For testing, there are extensive pre-existing tests, and I added a test to ensure that hot-swapping functions correctly.
```sh
python test/test_autograd.py -k test_view_func_replay
python test/test_ops.py -k test_view_replay
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118404
Approved by: https://github.com/ezyang
Replaces `view_func()` closures with a reified `ViewFunc` data structure. Codegen generates a `ViewFunc` subclass for each view op (e.g. `NarrowViewFunc`) containing state needed to reconstruct the view. The `ViewFunc` API allows for querying and hot-swapping any `SymInt`s or `Tensors` in the state through `get_symints()` / `get_tensors()` / `clone_and_set()`, which will be essential for fake-ification later on.
```cpp
/// Base class for view functions, providing reapplication of a view on a new base.
/// Each view op should get a codegenerated subclass of this class containing
/// any state needed to reconstruct the view. The class also provides convenience
/// accessors for saved SymInts / tensor state. This is useful for e.g. fake-ification,
/// where we want to use symbolic values or fake tensors instead.
struct TORCH_API ViewFunc {
virtual ~ViewFunc() {}
/// Returns any SymInts in the saved state.
virtual std::vector<c10::SymInt> get_symints() const { return {}; }
/// Returns the number of SymInts in the saved state.
virtual size_t num_symints() const { return 0; }
/// Returns any tensors in the saved state.
virtual std::vector<at::Tensor> get_tensors() const { return {}; }
/// Returns the number of tensors in the saved state.
virtual size_t num_tensors() const { return 0; }
/// Reapplies the view on the given base using the saved state.
virtual at::Tensor operator()(const at::Tensor&) const = 0;
/// Returns a clone of this ViewFunc, optionally with the specified saved state.
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const = 0;
protected:
/// Sets the values of any SymInts in the saved state. The input vector size must
/// match the number of SymInts in the saved state (i.e. the size of the list
/// returned by get_symints()).
virtual void set_symints(std::vector<c10::SymInt>) {}
/// Sets the values of any Tensors in the saved state. The input vector size must
/// match the number of Tensors in the saved state (i.e. the size of the list
/// returned by get_tensors()).
virtual void set_tensors(std::vector<at::Tensor>) {}
};
```
New codegen files:
* `torch/csrc/autograd/generated/ViewFunc.h`
* `torch/csrc/autograd/generated/ViewFuncs.cpp`
The templates for these also contains impls for `ChainedViewFunc` and `ErroringViewFunc` which are used in a few places within autograd.
Example codegen for `slice.Tensor`:
```cpp
// torch/csrc/autograd/generated/ViewFuncs.h
#define SLICE_TENSOR_VIEW_FUNC_AVAILABLE
struct SliceTensorViewFunc : public torch::autograd::ViewFunc {
SliceTensorViewFunc(int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) : dim(dim), start(start), end(end), step(step)
{};
virtual ~SliceTensorViewFunc() override {};
virtual std::vector<c10::SymInt> get_symints() const override;
virtual size_t num_symints() const override;
virtual std::vector<at::Tensor> get_tensors() const override;
virtual size_t num_tensors() const override;
virtual at::Tensor operator()(const at::Tensor&) const override;
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;
protected:
virtual void set_symints(std::vector<c10::SymInt>) override;
virtual void set_tensors(std::vector<at::Tensor>) override;
private:
int64_t dim;
c10::optional<c10::SymInt> start;
c10::optional<c10::SymInt> end;
c10::SymInt step;
};
...
// torch/csrc/autograd/generated/ViewFuncs.cpp
std::vector<c10::SymInt> SliceTensorViewFunc::get_symints() const {
::std::vector<c10::SymInt> symints;
symints.reserve((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
if(start.has_value()) symints.insert(symints.end(), *(start));
if(end.has_value()) symints.insert(symints.end(), *(end));
symints.push_back(step);
return symints;
}
size_t SliceTensorViewFunc::num_symints() const {
return static_cast<size_t>((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
}
void SliceTensorViewFunc::set_symints(std::vector<c10::SymInt> symints) {
TORCH_INTERNAL_ASSERT(symints.size() == num_symints());
auto i = 0;
if(start.has_value()) start = symints[i];
i += (start.has_value() ? 1 : 0);
if(end.has_value()) end = symints[i];
i += (end.has_value() ? 1 : 0);
step = symints[i];
}
std::vector<at::Tensor> SliceTensorViewFunc::get_tensors() const {
::std::vector<at::Tensor> tensors;
return tensors;
}
size_t SliceTensorViewFunc::num_tensors() const {
return static_cast<size_t>(0);
}
void SliceTensorViewFunc::set_tensors(std::vector<at::Tensor> tensors) {
TORCH_INTERNAL_ASSERT(tensors.size() == num_tensors());
}
at::Tensor SliceTensorViewFunc::operator()(const at::Tensor& input_base) const {
return at::_ops::slice_Tensor::call(input_base, dim, start, end, step);
}
std::unique_ptr<ViewFunc> SliceTensorViewFunc::clone_and_set(
std::optional<std::vector<c10::SymInt>> symints,
std::optional<std::vector<at::Tensor>> tensors) const {
auto output = std::make_unique<SliceTensorViewFunc>(dim, start, end, step);
if (symints.has_value()) {
output->set_symints(std::move(*(symints)));
}
if (tensors.has_value()) {
output->set_tensors(std::move(*(tensors)));
}
return output;
}
```
The `_view_func()` / `_view_func_unsafe()` methods now accept two additional (optional) args for `symint_visitor_fn` / `tensor_visitor_fn`. If these are defined, they are expected to be python callables that operate on a single SymInt / tensor and return a new one. This allows for the hot-swapping needed during fake-ification.
For testing, there are extensive pre-existing tests, and I added a test to ensure that hot-swapping functions correctly.
```sh
python test/test_autograd.py -k test_view_func_replay
python test/test_ops.py -k test_view_replay
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118404
Approved by: https://github.com/ezyang
- TensorGeometry supports symint
- check_size supports symint
- functorch batch rule improved symint
- Some operator support for symint in LTC
- More supported operations on SymInt and SymFloat
- More symint support in backwards formulas
This merge includes code contributions from bdhirsh and anjali411.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86160
Approved by: https://github.com/Chillee
### Introduction
<!-- What did you change and why was it needed? -->
Removing unnecessary weight gradient calculation is very important for applications that need high-order derivatives during training. However, this is not supported by the current Autograd engine.
For more detail: The backward function of a `matmul` operator (e.g., `linear` `addmm` `mm`), has two matmuls, one for `input gradient` and another for `weight gradient`. For a typical neural network (nn) with a few linear layers and activation functions, if the user calls `torch.autograd.grad()` to calculate the derivative of the nn output `y` w.r.t the nn input `x`, only the `input gradient` of the `matmul` operator is needed, and the `weight gradient` is discarded. However, the current PyTorch autograd engine will always calculate the `weight gradient` if `weight` requires gradient (the calculation of the high-order derivative is performed during training).
The figure attached shows the autograd graph of the following code snippet:
```py
y = torch.nn.functional.linear(x, weight, bias)
y = y.pow(2)
# first order derivative
y__x, = torch.autograd.grad(y, x, grad_outputs=grad_outputs, create_graph=True)
# first order derivative
y__x__x, = torch.autograd.grad(y__x, x, grad_outputs=grad_outputs, create_graph=True)
```
The path with ❌ is not needed when calculating derivatives.
<img width="50%" alt="image" src="https://user-images.githubusercontent.com/9999318/182018117-719c5a23-bcc6-4a63-8e8d-1bca3ebda2e3.png">
### Issue
<!-- Link to Issue ticket or RFP -->
Related issue: https://github.com/pytorch/pytorch/issues/56500
### Method
When calling `torch.autograd.grad`, `exec_info_` is created for each GraphTask, which allows filtering paths on the graph that are not needed. However, when the GraphTask calls into the node, the node still does not know whether the edges are needed or not. In the case of matmul, `weight.requires_grad is True` so the weight gradient is always calculated.
Following https://github.com/pytorch/pytorch/issues/56500#issuecomment-825694656, this PR passes the graph task's thread_local `exec_info_` into the node, so it could trim unnecessary edges during `torch.autograd.grad` calls.
### Benchmark
Benchmark script: https://gist.github.com/yueyericardo/24158433a2021c51eeef9c3e2722df99
Benchmark result:
6 hidden layers, batch size 10000, on A100
FP32 result
| hessian benchmark | FP32 (before) | FP32 (After) | FP32 (Functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward) | 55.658 ms | 29.392 ms (1.90X) | 29.547 ms (1.90X) |
| Linear + ReLU (with backward) | 81.173 ms | 54.917 ms (1.47X) | 68.988 ms (1.18X) |
TF32 result
| hessian benchmark | TF32 (before) | TF32 (after) | TF32 (Functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward) | 19.801 ms | 11.259 ms (1.76X) | 10.754 ms (1.84X) |
| Linear + ReLU (with backward) | 29.167 ms | 20.466 ms (1.42X) | 22.784 ms (1.28X) |
For FP32 result, we could get 1.9X speed up for hessian calculation, and 1.47X speed up during training, which is even faster than functorch `vmap(jacfwd(jacrev` implementation. (functorch has performance regression on v0.2.0, https://github.com/pytorch/functorch/issues/989, so we are using v0.1.1 for benchmark)
@zou3519 does functorch also includes similar optimizations during hessian calculation? If not, what do we need to do so the functorch could also benefit from this PR?
### Testing
<!-- How did you test your change? -->
- [x] we need to figure out a way for unittest
### Thanks
Thanks for the great blog: [How Computational Graphs are Executed in PyTorch | PyTorch](https://pytorch.org/blog/how-computational-graphs-are-executed-in-pytorch/)
cc @zasdfgbnm @albanD
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82544
Approved by: https://github.com/soulitzer
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60025
`to` already copies unconditionally if `src.device() != options.device()` so
specifying the copy argument is unnecessary.
`src.device()` is also completely equivalent to `src.options().device()` so
storing both is redundant.
Test Plan: Imported from OSS
Reviewed By: zou3519
Differential Revision: D29698627
Pulled By: albanD
fbshipit-source-id: eb091d39b71db688e6bcbb33a227c01b94b432bb
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60021
Dropping the imaginary component is expected and gives the correct gradient
formula, so silencing the warning is appropriate.
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D29589371
Pulled By: mruberry
fbshipit-source-id: 73e1511cae69207dc9abe576e2769ee1d03f1bbd
Summary:
This PR suppresses clang-tidy warnings in the codebase (for now) so that we can re-enable clang-tidy checks on master.
I ran this script to add the `NOLINTNEXTLINE` comments (on a devserver):
```bash
python3 setup.py develop
# Uses same script that's run on CI and adds the -j (parallel), -s (add comments), -k (continue if diagnostic errors are found) options
python3 tools/clang_tidy.py \
-j \
-s \
-k \
-v \
--paths torch/csrc/ \
-g"-torch/csrc/jit/passes/onnx/helper.cpp" \
-g"-torch/csrc/jit/passes/onnx/shape_type_inference.cpp" \
-g"-torch/csrc/jit/serialization/onnx.cpp" \
-g"-torch/csrc/jit/serialization/export.cpp" \
-g"-torch/csrc/jit/serialization/import.cpp" \
-g"-torch/csrc/jit/serialization/import_legacy.cpp" \
-g"-torch/csrc/onnx/init.cpp" \
-g"-torch/csrc/cuda/nccl.*" \
-g"-torch/csrc/cuda/python_nccl.cpp" \
-g"-torch/csrc/autograd/FunctionsManual.cpp" \
-g"-torch/csrc/generic/*.cpp" \
-g"-torch/csrc/jit/codegen/cuda/runtime/*" \
-g"-torch/csrc/deploy/interpreter/interpreter.cpp" \
-g"-torch/csrc/deploy/interpreter/interpreter.h" \
-g"-torch/csrc/deploy/interpreter/interpreter_impl.h" \
-g"-torch/csrc/deploy/interpreter/test_main.cpp"
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60649
Test Plan: Verified changes by re-running the script (without the `-s` option) and seeing no warnings/errors.
Reviewed By: walterddr, janeyx99
Differential Revision: D29504258
Pulled By: 1ntEgr8
fbshipit-source-id: 78310b30ee8213b73ddb4771ad874665323e7a4e
Summary:
Switches most of the simple for loops outside of `jit` directories to use `c10::irange`.
Generated with D28874212.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59481
Test Plan: Sandcastle
Reviewed By: ngimel
Differential Revision: D28909681
fbshipit-source-id: ec9ab1bd602933238d9d0f73d4d8d027b75d9d85
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47227
Motivation
----------
We would like to compute batched gradients for view+inplace operations.
This most notably shows up in internal implementation of operations.
For example, many view backward functions (SelectBackward, DiagonalBackward)
are implemented with view+inplace, so to support vectorized hessian
computation for e.g. torch.select and torch.diagonal we would need a
way to handle or workaround view+inplace.
Approach
--------
view+inplace creates a CopySlices node and transmute view backward nodes
into an AsStrided node. For example,
```
leaf = torch.randn(4, 5, requires_grad=True)
base = leaf * leaf
view = base[0]
view.cos_()
```
base.grad_fn is CopySlices and view.grad_fn is AsStridedBackward.
To support vmap over CopySlices and AsStridedBackward:
- We use `new_empty_strided` instead of `empty_strided` in CopySlices
so that the batch dims get propagated
- We use `new_zeros` inside AsStridedBackward so that the batch dims get
propagated.
Test Plan
---------
- New tests. When we get closer to having most operations support batched
grad computation via vmap, I'd like to add it as an option to gradcheck
and turn it on for our tests.
Test Plan: Imported from OSS
Reviewed By: kwanmacher, glaringlee
Differential Revision: D24741687
Pulled By: zou3519
fbshipit-source-id: 8210064f782a0a7a193752029a4340e505ffb5d8
Summary:
Adds the ability for all backward functions to accept undefined output gradient arguments. An undefined gradient is a Tensor that was created by the argumentless constructor `at::Tensor()`, where `tensor.defined() == false`.
Also adds new autograd nodes, UndefinedGrad and UndefinedGradBackward, that can be used from within Python code to inject undefined gradients into a backward function. A new test case is added to the backward function unit tests to use the UndefinedGrad node to ensure that undefined gradients do not break any backward functions.
Closes https://github.com/pytorch/pytorch/issues/33138
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39400
Differential Revision: D21936588
Pulled By: albanD
fbshipit-source-id: eccc5f55c77babe6dadcea4249d0c68a3c64e85d
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33157
This PR enables graph level thread parallelism on CPU for the Autograd
Engine. It replace https://github.com/pytorch/pytorch/pull/29574 for the
reason of task level parallelism drawbacks with the existing autograd
system.
Fixes https://github.com/pytorch/pytorch/issues/18333
The graph level parallelism on CPU design:
1. Remove the single CPU thread that init in the Engine itself and allow
the owning thread (which calls Engine::execute) to drive the Engine
execution so that we could let outer threading to enable thread
parallelism.
2. Maintain a separate ReadyQueue per CPU thread, and stash the
ReadyQueue for different devices/threads into the thread local
shared_ptr, the Engine itself will memorize the shared_ptr of the
ReadyQueue to different devices (other than CPU)
3. The CPU thread local ReadyQueue is initialized per CPU thread
Engine::execute call (or `backward()`, `grad()` call), and memorized
the shared_ptr into the GraphTask since every `backward()` call have
its own GraphTask
4. Cross device NodeTask push is accomplished by 2 and 3. we can refer
to device's ReadyQueue from Engine, and CPU's ReadyQueue from
GraphTask, which means if we can push to a different ReadyQueue
according to the device
5. Termination of the CPU thread: if we mark the graph_task as
completed, we will exit the while loop and terminate the current
backward execution, because it's guranteed that all other NodeTasks
is finished before we mark a GraphTask as complete
6. re-entrant thread logic keeps the same, reentrant thread detection is
similar as before, we set the worker_device to NO_DEVICE initially
and set to CPU afterward to detect if this is a reentrant call or not.
7. we still have the reentrant thread pool that create new threads if it's
a deep reentrant case, and reuse the ReadyQueue with the parent thread
for performance.
Since we introduce the thread parallelism on CPU, we have to ensure the
thread safety of the GraphTask. This is not a problem if we execute all
forward in different threads since we will build separate GraphTask in
different threads, and each GraphTask is a separate instance that share
nothing, i.e. Hogwild training on CPU should be fine on this case.
But there might be case that user would like to do some part of the task in
a single thread, and do the rest of work in several threads
concurrently, so thread safety is crucial in those cases. The thread
safety strategy for the multithread autograd is as follows:
1. Add a mutex to protect thread safety in Autograd Node/Function, and
hold the lock for different data racing cases
2. Lock the mutex during Node::apply(), this is to ensure Node that
writing to the shared variable are not racing across threads (i.e.
AccumulateGrad and custom C++ Autograd Node if writing to shared
variables )
3. Lock the mutex during Node::release_variables(), this serve the
purpose that when we release saved_variables from one thread, no
other threads can call the Node::apply(), this ensures the variable
references from other threads aren't dangling.
4. If we don't release any variables and no shared data read/write in
the Node i.e. purely functional, we don't lock the mutex
This way we could protect the thread safety on Autograd Node, but we
could still not protect the thread safety on Node pre/post C++ hooks
(python hooks are automatically thread safe), we rely on the user to
write thread safe C++ hooks if they want the hook to be correctly
applied in multithreading environment.
**User visiable changes**:
There're not too much user visiable changes, since we use the owning
thread to drive the autograd execution, user could write their own
threading code and does not block on the Autograd engine, some behaviors
that user should be aware of:
**Non-determinism**:
if we are calling backward() on multiple thread concurrently but with
shared inputs (i.e. Hogwild CPU training). Since parameters are automatically shared across threads, gradient accumulation might become non-deterministic on backward calls across threads, because two backward calls might access and try to accumulate the same .grad attribute. This is technically not safe, and it might result in racing condition and the result might be invalid to use.
But this is expected pattern if user are using the multithreading
approach to drive the whole training process but using shared
parameters, user who use multithreading should have the threading model
in mind and should expect this to happen. User should use the functional
interface `torch.autograd.grad()` to calculate the gradients instead of
`backward()` on loss.
**Graph retaining**:
If part of the autograd graph is shared between threads, i.e. run first
part of forward single thread, then run second part in multiple threads,
then the first part of graph is shared. In this case different threads execute grad() or backward() on the same graph might
have issue of destroying the graph on the fly of one thread, and the
other thread will crash in this case. We will error out to the user
similar to what call `backward()` twice with out `retain_graph=True`, and let the user know they should use `retain_graph=True`.
**TODOs**:
[ ] benchmark the PR with example models and datasets to demonstrate
the performance gain in CPU training
[ ] ensure that we don't regress the single thread autograd performance
**Follow ups**:
[ ] a correct and tight integration with distributed autograd
[ ] try to unify the thread pool between JIT and Autograd, and see if
there's unifying pattern that we could apply universally
Test Plan: Imported from OSS
Differential Revision: D20236771
Pulled By: wanchaol
fbshipit-source-id: 1e0bd4eec14ffebeffdb60b763b8d6f0e427eb64
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29665
Our intention is to merge the static distinction between Tensor and
Variable. Ordinarily, this would entail merging the methods of Tensor
and Variable. But there are a lot of "private"-ish methods on Variable
that we don't actually want to dump onto the Tensor class. So, as prep
work, we move all of those methods off of Variable and into
the torch::autograd::impl namespace (impl as in, please don't use this
end users). This ends up being a fairly large patch because all of
the call sites have to play ball too.
While I was on the topic, I also moved any of the touched functions into
the C++ file, so that modifying them would not trigger a recompilation of
all of torch.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Differential Revision: D18496169
Pulled By: ezyang
fbshipit-source-id: afb203252620ec274be596b3e7b1d84d321bad3a
Summary:
This is the first of a series of changes to reduce build size by cutting
autograd functions from mobile build.
When INTERN_DISABLE_AUTOGRAD is set:
* On CMake side we exclude Functions.h/cpp, VariableType*.h/cpp,
VariableTypeManual.cpp from the build process. Still keep variable_factories.h
as we rely on it to create variables instead of tensors.
* In source code we gate a couple autograd references (in autograd/variable.cpp)
with C10_MOBILE (technically we should use a dedicated c macro but its
maintenance cost is higher than cmake macro as we have several build systems
to change).
* Pass --disable-autograd flag to codegen script, which will stop generating
Functions/VariableType code. And for variable_factories.h it will stop
generating tracing code.
Edit: in this diff we will keep Functions.h/cpp to avoid changing source code.
Why we need this change if it's already not calling VariableType and autograd
stuff with USE_STATIC_DISPATCH=ON for mobile?
It's trying to reduce static library size for iOS build, for which it's
relatively harder to strip size with linker approach.
Why we need make involved change into codegen script?
There isn't a global config system in codegen - autograd/env.py provides similar
functionality but it says not adding anything there.
Test Plan:
- will check CI;
- test mobile build in sample app;
Differential Revision: D17202733
Pulled By: ljk53
fbshipit-source-id: 5701c6639b39ce58aba9bf5489a08d30d1dcd299
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25332
This method makes reference to a deprecated class, we now delete it.
This deletion was somewhat involved. Pre-existing use sites of
toType:
- Tensor::cpu()/cuda()/hip()
- native::type_as
- SummaryOps: toType(CPU(kDouble)) translated into to(kDouble) as weights
is an input argument and therefore assumed to be on CPU already. Similar
for CUDA.
- TensorTransformations: toType(CUDA(kLong)) translated into cuda(), as
the inputs are actually already the correct dtype, and this translation is just to move them to CUDA
- Adjusted native_test to take TensorOptions instead of
DeprecatedTypeProperties, killing toType along the way in favor of to
- Some tests for toType with UndefinedType which I just deleted
- CopyBackwards stores TensorOptions now instead of
DeprecatedTypeProperties
ghstack-source-id: 89177526
Test Plan: sandcastle and ossci
Differential Revision: D17096824
fbshipit-source-id: 964e5a073b9d37594e911d8bca98c9eab5766826
Summary:
Anywhere we used #include "foo.h", we now say #include <foo.h>
Paths are adjusted to be rooted out of aten/src, torch/lib, or
the root level directory.
I modified CMakeLists.txt by hand to remove TH and THC from
the include paths.
I used the following script to do the canonicalization:
```
import subprocess
import re
import os.path
files = subprocess.check_output(['git', 'ls-files']).decode('utf-8').rstrip().split('\n')
for fn in files:
if not any(fn.endswith(suff) for suff in ['.cu', '.cpp', '.in', '.h', '.hpp', '.cu', '.cuh', '.cc']):
continue
if not any(fn.startswith(pref) for pref in ["aten/", "torch/"]):
continue
with open(fn, 'r') as f:
c = f.read()
def fmt(p):
return "#include <{}>".format(p)
def repl(m):
p = m.group(1)
if p in ["dlfcn.h", "unistd.h", "nvrtc.h", "cuda.h", "cuda_runtime.h", "cstdint", "cudnn.h", "Python.h", "cusparse.h", "cuda_runtime_api.h", "cuda_fp16.h", "cublas_v2.h", "stdint.h", "curand_kernel.h"]:
return fmt(p)
if any(p.startswith(pref) for pref in ["torch/csrc", "c10/", "ATen/", "caffe2/", "TH/", "THC/", "Eigen/", "gtest/", "zdl/", "gloo/", "onnx/", "miopen/"]):
return fmt(p)
for root in ["aten/src", "torch/lib", ""]:
for bad_root in [os.path.dirname(fn), "aten/src/TH", "aten/src/THC", "torch/csrc"]:
new_p = os.path.relpath(os.path.join(bad_root, p), root)
if not new_p.startswith("../") and (os.path.exists(os.path.join(root, new_p)) or os.path.exists(os.path.join(root, new_p + ".in"))):
return fmt(new_p)
print("ERROR: ", fn, p)
return m.group(0)
new_c = re.sub(r'#include "([^"]+)"', repl, c)
if new_c != c:
print(fn)
with open(fn, 'w') as f:
f.write(new_c)
```
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14849
Reviewed By: dzhulgakov
Differential Revision: D13363445
Pulled By: ezyang
fbshipit-source-id: 52361f878a672785f9306c9e9ab2513128092b68
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13232
DeviceGuard should be device agnostic, which means that it shouldn't
assume that int64_t means select the CUDA device.
Reviewed By: gchanan
Differential Revision: D10858024
fbshipit-source-id: b40e8337e4046906fd8f83a95e6206367fb29dbe
Summary:
This PR:
1. Makes clang-tidy diff against `master` instead of `HEAD~1` in CI, which makes much more sense
2. Enables all checks in the `bugprone-*` category (see https://clang.llvm.org/extra/clang-tidy/checks/list.html) except one about parantheses in macros, because it doesn't always apply too well for us.
Fixed some nice code smells.
ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12378
Differential Revision: D10247972
Pulled By: goldsborough
fbshipit-source-id: 97dc9e262effa6874d2854584bf41a86684eb8bd