Addresses: https://github.com/pytorch/pytorch/issues/35802
Design doc: https://docs.google.com/document/d/19xSib7FFknRQ5f3ptGFUmiOt3BrgXSUlTQH2xMcZJYg/edit#
### Changes in this PR
#### Implementation
- We have now have 3 fields: pre_hooks, retains_grad_hooks, and tensor_pre_hooks so that we can more precisely define their ordering and when they are executed.
- Since retains grad uses an entirely new field, we cannot reuse the old retains grad, logic. We refactor retains grad to call directly into the variable.cpp logic. Other logic in variable.cpp that handle cpp hooks must also be updated.
#### Hooks ordering and execution:
- Defines pre-hooks registered on tensor to run before pre-hooks registered on grad_fn
- Updates pre-hooks registered on tensor to always run, even if they are the inputs= to .grad()
- Post hooks (and pre hooks) can now observe the modifications to gradient by the tensor pre hook
#### Retains grad hooks
- retains grad hooks always execute last, even if there are other tensor pre-hooks registered
#### Unchanged:
- pre_hooks registered to grad_fn aren't expected to execute if they are the inputs= to .grad()
Follow ups:
- simplify retains_grad field to not be a vector, since it always holds a single hook
- potentially merge capture hooks with tensor pre hooks, this would involve some additional refactoring since
- python hooks registered to tensor behavior on in-place is still wrong
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85849
Approved by: https://github.com/albanD
This PR removes the autograd.Function extension feature flag. This was
previously used for development of the functorch <> autograd.Function
interaction.
It's been in master for long enough with the feature flag defaulting to
True, so it's time to remove it.
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92026
Approved by: https://github.com/soulitzer
This allows to know at any point during the backward pass what is running and where the Node currently running was created at:
```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.autograd import detect_anomaly
class MyMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args, kwargs=None):
node = torch._C._current_autograd_node()
print(f"Running {func} from within {node}")
if node is not None:
print("The Node was created at:")
print("\n ".join(node.metadata["traceback_"]))
return func(*args, **kwargs or {})
with MyMode(), detect_anomaly():
print("FW")
a = torch.rand(10, requires_grad=True)
b = a.mul(2)
b = b.div(3)
b = b.sum()
print("BW")
b.backward()
```
Gives
```
$ python foo.py
foo.py:15: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
with MyMode(), detect_anomaly():
FW
Running aten.rand.default from within None
Running aten.mul.Tensor from within None
Running aten.div.Tensor from within None
Running aten.sum.default from within None
BW
Running aten.ones_like.default from within None
Running aten.expand.default from within <SumBackward0 object at 0x7fa40c0c6dc0>
The Node was created at:
File "foo.py", line 20, in <module>
b = b.sum()
Running aten.isnan.default from within <SumBackward0 object at 0x7fa40c0c6500>
The Node was created at:
File "foo.py", line 20, in <module>
b = b.sum()
Running aten.any.default from within <SumBackward0 object at 0x7fa32b23a780>
The Node was created at:
File "foo.py", line 20, in <module>
b = b.sum()
Running aten._local_scalar_dense.default from within <SumBackward0 object at 0x7fa40c0c9190>
The Node was created at:
File "foo.py", line 20, in <module>
b = b.sum()
Running aten.div.Tensor from within <DivBackward0 object at 0x7fa40c0c9190>
The Node was created at:
File "foo.py", line 19, in <module>
b = b.div(3)
Running aten.isnan.default from within <DivBackward0 object at 0x7fa40c0c9190>
The Node was created at:
File "foo.py", line 19, in <module>
b = b.div(3)
Running aten.any.default from within <DivBackward0 object at 0x7fa40c0c9190>
The Node was created at:
File "foo.py", line 19, in <module>
b = b.div(3)
Running aten._local_scalar_dense.default from within <DivBackward0 object at 0x7fa40c0c9190>
The Node was created at:
File "foo.py", line 19, in <module>
b = b.div(3)
Running aten.mul.Tensor from within <MulBackward0 object at 0x7fa40c0c9190>
The Node was created at:
File "foo.py", line 18, in <module>
b = a.mul(2)
Running aten.isnan.default from within <MulBackward0 object at 0x7fa40c0c9190>
The Node was created at:
File "foo.py", line 18, in <module>
b = a.mul(2)
Running aten.any.default from within <MulBackward0 object at 0x7fa40c0c9190>
The Node was created at:
File "foo.py", line 18, in <module>
b = a.mul(2)
Running aten._local_scalar_dense.default from within <MulBackward0 object at 0x7fa40c0c9190>
The Node was created at:
File "foo.py", line 18, in <module>
b = a.mul(2)
Running aten.detach.default from within <AccumulateGrad object at 0x7fa40c0c9730>
The Node was created at:
File "foo.py", line 18, in <module>
b = a.mul(2)
Running aten.detach.default from within <AccumulateGrad object at 0x7fa40c0c94b0>
The Node was created at:
File "foo.py", line 18, in <module>
b = a.mul(2)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90867
Approved by: https://github.com/soulitzer
This PR adds a private runtime feature flag for the feature work we're going
to do with extending autograd.Function. The motivation of the feature flag
is:
- to guard the feature against unsuspecting users
- control the release of the feature to when we are ready to release it
We might not even need the feature flag (because we hope to have the
work done in the next month), but it is good practice and it does touch
currently public API (autograd.Function).
Concretely, "autograd.Function extension" refers to:
- adding an optional `setup_context` staticmethod to autograd.Function
- adding an optional `vmap` staticmethod to autograd.Function
- autograd.Function support for functorch
Test Plan:
- new test that the feature flag works
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89858
Approved by: https://github.com/soulitzer
In this PR:
- graph_task stores graph roots on construction so that we can later traverse through the graph
- before the nodes are returned, they needed to be converted from raw_ptr to shared_ptr, and this should be OK because the graph is guaranteed to be alive
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87507
Approved by: https://github.com/albanD
### Introduction
<!-- What did you change and why was it needed? -->
Removing unnecessary weight gradient calculation is very important for applications that need high-order derivatives during training. However, this is not supported by the current Autograd engine.
For more detail: The backward function of a `matmul` operator (e.g., `linear` `addmm` `mm`), has two matmuls, one for `input gradient` and another for `weight gradient`. For a typical neural network (nn) with a few linear layers and activation functions, if the user calls `torch.autograd.grad()` to calculate the derivative of the nn output `y` w.r.t the nn input `x`, only the `input gradient` of the `matmul` operator is needed, and the `weight gradient` is discarded. However, the current PyTorch autograd engine will always calculate the `weight gradient` if `weight` requires gradient (the calculation of the high-order derivative is performed during training).
The figure attached shows the autograd graph of the following code snippet:
```py
y = torch.nn.functional.linear(x, weight, bias)
y = y.pow(2)
# first order derivative
y__x, = torch.autograd.grad(y, x, grad_outputs=grad_outputs, create_graph=True)
# first order derivative
y__x__x, = torch.autograd.grad(y__x, x, grad_outputs=grad_outputs, create_graph=True)
```
The path with ❌ is not needed when calculating derivatives.
<img width="50%" alt="image" src="https://user-images.githubusercontent.com/9999318/182018117-719c5a23-bcc6-4a63-8e8d-1bca3ebda2e3.png">
### Issue
<!-- Link to Issue ticket or RFP -->
Related issue: https://github.com/pytorch/pytorch/issues/56500
### Method
When calling `torch.autograd.grad`, `exec_info_` is created for each GraphTask, which allows filtering paths on the graph that are not needed. However, when the GraphTask calls into the node, the node still does not know whether the edges are needed or not. In the case of matmul, `weight.requires_grad is True` so the weight gradient is always calculated.
Following https://github.com/pytorch/pytorch/issues/56500#issuecomment-825694656, this PR passes the graph task's thread_local `exec_info_` into the node, so it could trim unnecessary edges during `torch.autograd.grad` calls.
### Benchmark
Benchmark script: https://gist.github.com/yueyericardo/24158433a2021c51eeef9c3e2722df99
Benchmark result:
6 hidden layers, batch size 10000, on A100
FP32 result
| hessian benchmark | FP32 (before) | FP32 (After) | FP32 (Functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward) | 55.658 ms | 29.392 ms (1.90X) | 29.547 ms (1.90X) |
| Linear + ReLU (with backward) | 81.173 ms | 54.917 ms (1.47X) | 68.988 ms (1.18X) |
TF32 result
| hessian benchmark | TF32 (before) | TF32 (after) | TF32 (Functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward) | 19.801 ms | 11.259 ms (1.76X) | 10.754 ms (1.84X) |
| Linear + ReLU (with backward) | 29.167 ms | 20.466 ms (1.42X) | 22.784 ms (1.28X) |
For FP32 result, we could get 1.9X speed up for hessian calculation, and 1.47X speed up during training, which is even faster than functorch `vmap(jacfwd(jacrev` implementation. (functorch has performance regression on v0.2.0, https://github.com/pytorch/functorch/issues/989, so we are using v0.1.1 for benchmark)
@zou3519 does functorch also includes similar optimizations during hessian calculation? If not, what do we need to do so the functorch could also benefit from this PR?
### Testing
<!-- How did you test your change? -->
- [x] we need to figure out a way for unittest
### Thanks
Thanks for the great blog: [How Computational Graphs are Executed in PyTorch | PyTorch](https://pytorch.org/blog/how-computational-graphs-are-executed-in-pytorch/)
cc @zasdfgbnm @albanD
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82544
Approved by: https://github.com/soulitzer
Summary:
When we were pre-sampling this was a pretty important optimizaton. However now when we make a record function we can be sure that it will be called.
For the RECORD_FUNCTION macros I preserved the old behavior by making a `c10::optional<RecordFunction>` since we can't force callers to have separate paths the way Dispatcher does.
Maybe it makes sense to have a guard that handles the optional logic? If we can move enough out of the internals (e.g. replace `std::string`s with `char*`s) we might not even need the optional to get good perf.
Test Plan: The no-op observer overhead benchmark got a bit better, but even with lots of replicates it's hard to tell if that's just noise. This is primarily a change to simplify the semantics of RecordFunction.
Reviewed By: chaekit
Differential Revision: D35276157
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76016
Approved by: https://github.com/chaekit
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75807
There is a tension in RecordFunction between two use cases:
1) In the normal eager path we don't run any callbacks, so we need to bail out of the profiling path as soon as possible to minimize eager overhead.
2) When profiling we want to determine which callbacks to run as efficiently as possible to minimize instrumentation overhead.
The confounding factor in all of this is sampling callbacks because they change which callbacks will run on each call, even in steady state operation. This has traditionally been handled with a two stage procedure: first we flip a coin to determine if a sampled callback *might* run. If false (which it usually is), do nothing. This solves (1). If true, check to see if we need to build the full callback set or if it was a false positive. This procedure has two negative effects:
* It forces us to rebuild the set of callbacks to run on every step when profiling
* It leaks the sampling abstraction, requiring other parts of the code to bump certain values and forces RecordFunction to lazily initialize.
This change introduces a multi-level cache which can (in the common case) quickly determine which callbacks *will* run, rather than if callbacks *might* run. This means that rather than call `shouldRunRecordFunction`, we can simply get the callbacks for an invocation and check if they are empty. (And completely removes the pre-sampling heuristic.) Another major benefit of the new cache structure is that it allows thread-safe registration and unregistration of global callbacks.
It's worth briefly discussing how this maintains eager performance. In the standard eager case (only sampling callbacks registered) the cache first checks that the global callbacks haven't changed (atomic read), decrements a counter to see if a sampling callback fired, and then returns the active callbacks which is simply a SmallVector of pointer pairs and a couple POD values (scope, needs inputs/outputs/ids). The biggest cost according to perf is the SmallVector logic; we could consider adopting a hard limit on active callbacks; more than half a dozen callbacks *running* in a single step would be quite a lot. But the total cost relative to `PYTORCH_DISABLE_PER_OP_PROFILING` is only ~10ns, so debatable if it's worth it to switch to `std::array`.
The primary change is in `record_function.cpp`, which has a more detailed description of the new cache structure. `record_function.h` has some minor changes to align with the new calling convention and the remaining files are simply changes to the call sites.
Future work:
* RecordFunction no longer needs to be lazily initialized.
* We can deprecate the disable/reenable APIs, since we can not safely add and remove global callbacks.
Test Plan:
I tested eager mode performance using the overhead benchmark and found that the non-profiled path was unaffected. However the no-op observer dropped from 0.41us to 0.37us (0.25us if no observers are active) which is about 1/3rd reduction in the cost of the callback selection machinery.
I also added several C++ unit tests, as the core RecordFunction machinery (especially sampling) was largely untested.
Reviewed By: swolchok, davidberard98
Differential Revision: D35276158
fbshipit-source-id: 35135f444724fba4eb97c0ae7f3f710f0f9016fd
(cherry picked from commit 9e359b87422c18f2a195185f32e7e85c82f956fd)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70326
See D24145988 for context: it allows loops such as for(int i=0;i<10;i++) to be expressed as for(const auto i : c10::irange(10)). This is nice because it auto-types the loops and adds const-safety to the iteration variable.
Test Plan: buck run //caffe2/torch/fb/sparsenn:test
Reviewed By: r-barnes
Differential Revision: D33243400
fbshipit-source-id: b1f1b4163f4bf662031baea9e5268459b40c69a3
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68691
TraceType is a sharded file, so by only including specific operator
headers, we ensure that changing one (non-method) operator only needs
one shard to be re-compiled.
This also changes all the included autograd and jit headers from
including `ATen/ATen.h` to just including `ATen/core/Tensor.h`.
Test Plan: Imported from OSS
Reviewed By: gchanan
Differential Revision: D33336948
Pulled By: albanD
fbshipit-source-id: 4e40371592b9a5a7e7fcd1d8cecae11ffb873113
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68691
TraceType is a sharded file, so by only including specific operator
headers, we ensure that changing one (non-method) operator only needs
one shard to be re-compiled.
This also changes all the included autograd and jit headers from
including `ATen/ATen.h` to just including `ATen/core/Tensor.h`.
Test Plan: Imported from OSS
Reviewed By: jbschlosser, malfet
Differential Revision: D32596264
Pulled By: albanD
fbshipit-source-id: 2f28b62d7b9932f30fad7daacd8ac5bb7f63c621
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69421
I've hit a lot of build issues in D32671972, and I've come to realize that a lot of it boils down to header hygene. `function.h` includes `profiler.h` *solely* to transitively include `record_function.h` which winds up leaking the profiler symbols. Moreover several files are relying on transitive includes to get access to `getTime`. As long as I have to touch all the places that use `getTime`, I may as well also move them to the new namespace.
Test Plan: Unit tests and CI.
Reviewed By: aaronenyeshi, albanD
Differential Revision: D32865907
fbshipit-source-id: f87d6fd5afb784dca2146436e72c69e34623020e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65235
1. Updated the legacy type checks in `torch/csrc/autograd/engine.cpp` to individually validate the dtype, device, and layout equality for grad and tensor.
2. Removed device field from `InputMetadata` since it's already stored via storing options. Also, added `dtype()` and `layout()` methods to `InputMetadata`. To make this change, some calls had to be updated due to the change in constructor.
3. To fix https://github.com/pytorch/pytorch/issues/65016:
a. Added a `is_tensor_subclass` field in `InputMetadata` to skip device checks for grad and tensor when the tensor has
python key set on it (tensor subclass).
Test Plan: Imported from OSS
Reviewed By: jbschlosser
Differential Revision: D31117318
Pulled By: anjali411
fbshipit-source-id: 825401df98695c48bf9b320be54585f6aff500bd
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65235
1. Updated the legacy type checks in `torch/csrc/autograd/engine.cpp` to individually validate the dtype, device, and layout equality for grad and tensor.
2. Removed device field from `InputMetadata` since it's already stored via storing options. Also, added `dtype()` and `layout()` methods to `InputMetadata`. To make this change, some calls had to be updated due to the change in constructor.
3. To fix https://github.com/pytorch/pytorch/issues/65016:
a. Added a `is_tensor_subclass` field in `InputMetadata` to skip device checks for grad and tensor when the tensor has
python key set on it (tensor subclass).
Test Plan: Imported from OSS
Reviewed By: pbelevich
Differential Revision: D31082693
Pulled By: anjali411
fbshipit-source-id: cb551cd438c6ca40b0f18a4d0009e0861cf0fd4e
Summary:
This PR suppresses clang-tidy warnings in the codebase (for now) so that we can re-enable clang-tidy checks on master.
I ran this script to add the `NOLINTNEXTLINE` comments (on a devserver):
```bash
python3 setup.py develop
# Uses same script that's run on CI and adds the -j (parallel), -s (add comments), -k (continue if diagnostic errors are found) options
python3 tools/clang_tidy.py \
-j \
-s \
-k \
-v \
--paths torch/csrc/ \
-g"-torch/csrc/jit/passes/onnx/helper.cpp" \
-g"-torch/csrc/jit/passes/onnx/shape_type_inference.cpp" \
-g"-torch/csrc/jit/serialization/onnx.cpp" \
-g"-torch/csrc/jit/serialization/export.cpp" \
-g"-torch/csrc/jit/serialization/import.cpp" \
-g"-torch/csrc/jit/serialization/import_legacy.cpp" \
-g"-torch/csrc/onnx/init.cpp" \
-g"-torch/csrc/cuda/nccl.*" \
-g"-torch/csrc/cuda/python_nccl.cpp" \
-g"-torch/csrc/autograd/FunctionsManual.cpp" \
-g"-torch/csrc/generic/*.cpp" \
-g"-torch/csrc/jit/codegen/cuda/runtime/*" \
-g"-torch/csrc/deploy/interpreter/interpreter.cpp" \
-g"-torch/csrc/deploy/interpreter/interpreter.h" \
-g"-torch/csrc/deploy/interpreter/interpreter_impl.h" \
-g"-torch/csrc/deploy/interpreter/test_main.cpp"
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60649
Test Plan: Verified changes by re-running the script (without the `-s` option) and seeing no warnings/errors.
Reviewed By: walterddr, janeyx99
Differential Revision: D29504258
Pulled By: 1ntEgr8
fbshipit-source-id: 78310b30ee8213b73ddb4771ad874665323e7a4e
Summary:
This is an automatic change generated by the following script:
```
#!/usr/bin/env python3
from subprocess import check_output, check_call
import os
def get_compiled_files_list():
import json
with open("build/compile_commands.json") as f:
data = json.load(f)
files = [os.path.relpath(node['file']) for node in data]
for idx, fname in enumerate(files):
if fname.startswith('build/') and fname.endswith('.DEFAULT.cpp'):
files[idx] = fname[len('build/'):-len('.DEFAULT.cpp')]
return files
def run_clang_tidy(fname):
check_call(["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname,"-s"])
changes = check_output(["git", "ls-files", "-m"])
if len(changes) == 0:
return
check_call(["git", "commit","--all", "-m", f"NOLINT stubs for {fname}"])
def main():
git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n")
compiled_files = get_compiled_files_list()
for idx, fname in enumerate(git_files):
if fname not in compiled_files:
continue
if fname.startswith("caffe2/contrib/aten/"):
continue
print(f"[{idx}/{len(git_files)}] Processing {fname}")
run_clang_tidy(fname)
if __name__ == "__main__":
main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892
Reviewed By: H-Huang
Differential Revision: D27991944
Pulled By: malfet
fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55957
This diff adds an execution graph observer that tracks all operators (dispatcher autograd, jit, user defined, etc.) and their inputs and outputs. The results are written to a temp JSON file which can be used for further analysis. This support various use cases, such as dependency analysis, performance optimizations, etc.
Some minor refactoring of existing code for clarity and completeness.
Test Plan:
Example output:
{F603167736}
```
=> buck build caffe2/torch/fb/observers:execution_graph_observer_runner --show-output
=> buck-out/gen/caffe2/torch/fb/observers/execution_graph_observer_runner --pytorch_enable_execution_graph_observer=true --pytorch_execution_graph_observer_iter_label="## START ##" --pytorch_execution_graph_observer_iter_target=3
I0414 01:26:55.834039 1038798 ExecutionGraphObserver.cpp:408] Enabled PyTorch execution graph observer
I0414 01:26:55.834717 1038798 ExecutionGraphObserver.cpp:411] Matching iteration start label: "## START ##"
I0414 01:26:55.834940 1038798 ExecutionGraphObserver.cpp:423] Target iteration: 3
I0414 01:26:55.835962 1038798 ExecutionGraphObserverRunner.cpp:50] Running test execution graph observer runner.
I0414 01:26:55.836180 1038798 ExecutionGraphObserverRunner.cpp:51] iterations: 10
I0414 01:26:55.836419 1038798 ExecutionGraphObserverRunner.cpp:52] output file name: /tmp/pytorch_execution_graph_1618388815_1038798_3.json
I0414 01:26:56.246432 1038798 ExecutionGraphObserver.cpp:137] Writing PyTorch execution graph to: /tmp/pytorch_execution_graph_1618388815_1038798_3.json
I0414 01:26:56.278715 1038798 ExecutionGraphObserver.cpp:314] PyTorch execution graph is written to file: /tmp/pytorch_execution_graph_1618388815_1038798_3.json
```
see `/tmp/pytorch_execution_graph_[timestamp]_[process_id]_[iter_target].json`
Reviewed By: albanD
Differential Revision: D27238906
fbshipit-source-id: 3eb717d7d512e2d51d3162e9995b1ccd18e5a725
Summary:
Fixes https://github.com/pytorch/pytorch/issues/12635
This change will help us speed up autograd's discovery algorithm in cases where we use `.grad` and we try to "unroll" the training loop. For example the example in the issue and also https://github.com/pytorch/pytorch/pull/52180#issuecomment-783400832 observe an unbounded multiple of speed-up.
We do this by adding a new sequence_nr-type numbering: for each node, we maintain the length of the longest path from it to any leaf node. How does this help us speed up discovery (dfs)? Previously the bottleneck was that the dfs that computes which nodes need to be executed always explored every node. With this change, before we run dfs, we first compute the mininum seq_nr among all the nodes passed as the `inputs`. If let this be some number N, intuitively this means that dfs should stay at least N units away from any leaf node. So, if we find ourselves too close to any leaf node, we should stop our search early.
Edit:
After some discussion offline, the plan is:
- make old sequence_nr a construct of the profiler. This means we can avoid accessing thread local state in cases where the profiler is disabled. Note that we cannot replace sequence_nr as-is because profiler's use-case requires that thread-id + sequence_nr can uniquely identify a given node in order for downstream users/programs to correlate nodes from backward and forward passes. This means we must maintain two sequence_nr's and that we have an extra field in Node.
- In a future PR, we can potentially remove sequence_nr entirely from the profiler as well, but we avoid doing it now because we haven't measured, and its a larger effort because we'd have to mess around with the dispatcher and profiler
Testing with this [code](https://gist.github.com/kyunghyuncho/5fb9991ce1233f909051854a84b7148e), we see that runtime no longer increases as we iterate.
Before:
```
100: Time taken: 0.47s, loss: 1.1e+06
200: Time taken: 0.064s, loss: 6.5e+05
300: Time taken: 0.088s, loss: 4.4e+05
400: Time taken: 0.1s, loss: 3.2e+05
500: Time taken: 0.12s, loss: 2.5e+05
600: Time taken: 0.15s, loss: 2e+05
700: Time taken: 0.18s, loss: 1.7e+05
800: Time taken: 0.2s, loss: 1.4e+05
900: Time taken: 0.22s, loss: 1.2e+05
1000: Time taken: 0.24s, loss: 1.1e+05
1100: Time taken: 0.27s, loss: 9.3e+04
1200: Time taken: 0.3s, loss: 8.3e+04
1300: Time taken: 0.34s, loss: 7.4e+04
1400: Time taken: 0.36s, loss: 6.7e+04
1500: Time taken: 0.38s, loss: 6.1e+04
1600: Time taken: 0.4s, loss: 5.6e+04
1700: Time taken: 0.42s, loss: 5.1e+04
1800: Time taken: 0.44s, loss: 4.7e+04
1900: Time taken: 0.47s, loss: 4.4e+04
2000: Time taken: 0.5s, loss: 4.1e+04
```
After:
```
100: Time taken: 0.49s, loss: 1.2e+06
200: Time taken: 0.031s, loss: 6.9e+05
300: Time taken: 0.031s, loss: 4.6e+05
400: Time taken: 0.031s, loss: 3.3e+05
500: Time taken: 0.031s, loss: 2.6e+05
600: Time taken: 0.031s, loss: 2.1e+05
700: Time taken: 0.031s, loss: 1.7e+05
800: Time taken: 0.031s, loss: 1.4e+05
900: Time taken: 0.031s, loss: 1.2e+05
1000: Time taken: 0.031s, loss: 1.1e+05
1100: Time taken: 0.031s, loss: 9.6e+04
1200: Time taken: 0.031s, loss: 8.6e+04
1300: Time taken: 0.031s, loss: 7.7e+04
1400: Time taken: 0.031s, loss: 7e+04
1500: Time taken: 0.031s, loss: 6.3e+04
1600: Time taken: 0.031s, loss: 5.8e+04
1700: Time taken: 0.031s, loss: 5.3e+04
1800: Time taken: 0.031s, loss: 4.9e+04
1900: Time taken: 0.031s, loss: 4.5e+04
2000: Time taken: 0.032s, loss: 4.2e+04
```
Testing w/ small graph to check for regression:
```
import torch
from torch.utils.benchmark import Timer
setup="""
a = torch.rand((2, 2), requires_grad=True)
b = torch.rand((2, 2), requires_grad=True)
gradient = torch.ones(2, 2)
"""
stmt="""
torch.autograd.grad(a*b, [a, b], gradient)
"""
timer = Timer(stmt, setup)
print(timer.timeit(10000))
print(timer.collect_callgrind(100))
```
Result: there doesn't seem to be any significant regression
```
Time before: 12.74 us
Time after: 13.12 us
Instruction count before:
All Noisy symbols removed
Instructions: 8078960 8000882
Baseline: 4226 3838
Instruction count after:
All Noisy symbols removed
Instructions: 8091846 8017940
Baseline: 4336 3838
100 runs per measurement, 1 thread
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52180
Reviewed By: gchanan, zhangguanheng66
Differential Revision: D26794387
Pulled By: soulitzer
fbshipit-source-id: c00d387a29f151109c33dc6f1b56a8f275cdec58
Summary:
Fixes https://github.com/pytorch/pytorch/issues/49756
## Background
Fix applied here is to remove the grad enabled check from `collect_next_edges`, unconditionally returning the actual collected edges. This pushes the responsibility for determining whether the function should be called without grad mode to its call-sites. With this update, `collect_next_edges` will no longer incorrectly return an empty list, which caused the problem described in the issue. Three call-sites depended on this behavior and have been updated.
Beyond bad printing side effects, this fix addresses the more general issue of accessing `grad_fn` with grad mode disabled after an in-place operation on a view. The included test verifies this without the use of print.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51364
Test Plan:
```
python test/test_autograd.py TestAutogradDeviceTypeCPU.test_inplace_view_then_no_grad_cpu
```
Reviewed By: zou3519
Differential Revision: D26190451
Pulled By: jbschlosser
fbshipit-source-id: 9b004a393463f8bd4ac0690e5e53c07a609f87f0
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47550
I saw over 5% time spent in RecordFunction's ctor during one
of our framework overhead benchmarks in `perf`. Inspecting assembly,
it looks like we just create a lot of RecordFunctions and the
constructor has to initialize a relatively large number of member
variables.
This diff takes advantage of the observation that RecordFunction does
nothing most of the time by moving its state onto the heap and only
allocating it if needed. It does add the requirement that profiling is
actually active to use RecordFunction accessors, which I hope won't be
a problem.
ghstack-source-id: 117498489
Test Plan: Run framework overhead benchmarks. Savings ranging from 3% (InPlace_ndim_1) to 7.5% (empty_ndim_3) wall time.
Reviewed By: ilia-cher
Differential Revision: D24812213
fbshipit-source-id: 823a1e2ca573d9a8d7c5b7bb3972987faaacd11a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47549
In preparation for moving state onto the heap.
ghstack-source-id: 117027862
Test Plan: CI
Reviewed By: ilia-cher
Differential Revision: D24812214
fbshipit-source-id: 1455c2782b66f6a59c4d45ba58e1c4c92402a323
Summary:
Fixes https://github.com/pytorch/pytorch/issues/43405.
This pull request adds a feature of printing all tracebacks if a `detect_anomaly` mode detects `nan` in nested backward operations.
The way I did it is by assigning a node as a parent to all nodes it produces during its backward calculation. Then if one of the children produces `nan`, it will print the traceback from the parent and grand parents (if any).
The parent is assigned in `parent_node_` member in `Node` class which is accessible in C++ by function `node->parent()` and in Python by `node.parent_function`.
A node has a parent iff:
1. it is created from a backward operation, and
2. created when anomaly mode and grad mode are both enabled.
An example of this feature:
import torch
def example():
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(1e-8, requires_grad=True) # small to induce nan in n-th backward
a = x * y
b = x * y
z1 = a / b # can produce nan in n-th backward as long as https://github.com/pytorch/pytorch/issues/43414 is unsolved
z = z1 * z1
gy , = torch.autograd.grad( z , (y,), create_graph=True)
gy2, = torch.autograd.grad(gy , (y,), create_graph=True)
gy3, = torch.autograd.grad(gy2, (y,), create_graph=True)
gy4, = torch.autograd.grad(gy3, (y,), create_graph=True)
return gy4
with torch.autograd.detect_anomaly():
gy4 = example()
with output:
example.py:16: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
with torch.autograd.detect_anomaly():
/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py:190: UserWarning: Error detected in DivBackward0. Traceback of forward call that caused the error:
File "example.py", line 17, in <module>
gy4 = example()
File "example.py", line 12, in example
gy3, = torch.autograd.grad(gy2, (y,), create_graph=True)
File "/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 190, in grad
return Variable._execution_engine.run_backward(
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:61.)
return Variable._execution_engine.run_backward(
/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py:190: UserWarning:
Traceback of forward call that induces the previous calculation:
File "example.py", line 17, in <module>
gy4 = example()
File "example.py", line 11, in example
gy2, = torch.autograd.grad(gy , (y,), create_graph=True)
File "/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 190, in grad
return Variable._execution_engine.run_backward(
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:65.)
return Variable._execution_engine.run_backward(
/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py:190: UserWarning:
Traceback of forward call that induces the previous calculation:
File "example.py", line 17, in <module>
gy4 = example()
File "example.py", line 8, in example
z1 = a / b # can produce nan in n-th backward as long as https://github.com/pytorch/pytorch/issues/43414 is unsolved
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:65.)
return Variable._execution_engine.run_backward(
Traceback (most recent call last):
File "example.py", line 17, in <module>
gy4 = example()
File "example.py", line 13, in example
gy4, = torch.autograd.grad(gy3, (y,), create_graph=True)
File "/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 190, in grad
return Variable._execution_engine.run_backward(
RuntimeError: Function 'DivBackward0' returned nan values in its 1th output.
cc & thanks to albanD
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43626
Reviewed By: malfet
Differential Revision: D23397499
Pulled By: albanD
fbshipit-source-id: aa7435ec2a7f0d23a7a02ab7db751c198faf3b7d
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41947
Previously, if an op took an optional `Tensor?` argument, the C++ frontend (i.e. `at::op()` and `Tensor::op()`)
were generated to take `Tensor`. A previous PR (https://github.com/pytorch/pytorch/pull/41610) changed the kernels
to be written with `c10::optional<Tensor>` instead of `Tensor`, but that did not touch the C++ frontend yet.
This PR changes the C++ frontend API to take `c10::optional<Tensor>` instead of `Tensor` as well.
This should be mostly bc conserving. Since `Tensor` implicitly converts to `c10::optional<Tensor>`, any old code
calling an op with a `Tensor` would still work. There are likely corner cases that get broken though.
For example, C++ only ever does *one* implicit conversion. So if you call an op with a non-tensor object
that gets implicitly converted to a `Tensor`, then that previously worked since the API took a `Tensor` and
C++ allows one implicit conversion. Now it wouldn't work anymore because it would require two implicit conversions
(to `Tensor` and then to `c10::optional<Tensor>`) and C++ doesn't do that.
The main reasons for doing this are
- Make the C++ API more sane. Those arguments are optional and that should be visible from the signature.
- Allow easier integration for XLA and Autocast. Those backends generate code to wrap operators and forward
operator arguments to calls to at::op(). After https://github.com/pytorch/pytorch/pull/41610, there was
a mismatch because they had to implement operators with `optional<Tensor>` but call `at::op()` with `Tensor`,
so they had to manually convert between those. After this PR, they can just forward the `optional<Tensor>`
in their call to `at::op()`.
ghstack-source-id: 108873705
Test Plan: unit tests
Reviewed By: bhosmer
Differential Revision: D22704832
fbshipit-source-id: f4c00d457b178fbc124be9e884a538a3653aae1f
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37587
Lifting RecordFunction up into the dispatcher code
Test Plan: Imported from OSS
Differential Revision: D21374246
fbshipit-source-id: 19f9c1719e6fd3990e451c5bbd771121e91128f7
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37548
Moving RecordFunction from torch::autograd::profiler into at namespace
Test Plan:
CI
Imported from OSS
Differential Revision: D21315852
fbshipit-source-id: 4a4dbabf116c162f9aef0da8606590ec3f3847aa
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33157
This PR enables graph level thread parallelism on CPU for the Autograd
Engine. It replace https://github.com/pytorch/pytorch/pull/29574 for the
reason of task level parallelism drawbacks with the existing autograd
system.
Fixes https://github.com/pytorch/pytorch/issues/18333
The graph level parallelism on CPU design:
1. Remove the single CPU thread that init in the Engine itself and allow
the owning thread (which calls Engine::execute) to drive the Engine
execution so that we could let outer threading to enable thread
parallelism.
2. Maintain a separate ReadyQueue per CPU thread, and stash the
ReadyQueue for different devices/threads into the thread local
shared_ptr, the Engine itself will memorize the shared_ptr of the
ReadyQueue to different devices (other than CPU)
3. The CPU thread local ReadyQueue is initialized per CPU thread
Engine::execute call (or `backward()`, `grad()` call), and memorized
the shared_ptr into the GraphTask since every `backward()` call have
its own GraphTask
4. Cross device NodeTask push is accomplished by 2 and 3. we can refer
to device's ReadyQueue from Engine, and CPU's ReadyQueue from
GraphTask, which means if we can push to a different ReadyQueue
according to the device
5. Termination of the CPU thread: if we mark the graph_task as
completed, we will exit the while loop and terminate the current
backward execution, because it's guranteed that all other NodeTasks
is finished before we mark a GraphTask as complete
6. re-entrant thread logic keeps the same, reentrant thread detection is
similar as before, we set the worker_device to NO_DEVICE initially
and set to CPU afterward to detect if this is a reentrant call or not.
7. we still have the reentrant thread pool that create new threads if it's
a deep reentrant case, and reuse the ReadyQueue with the parent thread
for performance.
Since we introduce the thread parallelism on CPU, we have to ensure the
thread safety of the GraphTask. This is not a problem if we execute all
forward in different threads since we will build separate GraphTask in
different threads, and each GraphTask is a separate instance that share
nothing, i.e. Hogwild training on CPU should be fine on this case.
But there might be case that user would like to do some part of the task in
a single thread, and do the rest of work in several threads
concurrently, so thread safety is crucial in those cases. The thread
safety strategy for the multithread autograd is as follows:
1. Add a mutex to protect thread safety in Autograd Node/Function, and
hold the lock for different data racing cases
2. Lock the mutex during Node::apply(), this is to ensure Node that
writing to the shared variable are not racing across threads (i.e.
AccumulateGrad and custom C++ Autograd Node if writing to shared
variables )
3. Lock the mutex during Node::release_variables(), this serve the
purpose that when we release saved_variables from one thread, no
other threads can call the Node::apply(), this ensures the variable
references from other threads aren't dangling.
4. If we don't release any variables and no shared data read/write in
the Node i.e. purely functional, we don't lock the mutex
This way we could protect the thread safety on Autograd Node, but we
could still not protect the thread safety on Node pre/post C++ hooks
(python hooks are automatically thread safe), we rely on the user to
write thread safe C++ hooks if they want the hook to be correctly
applied in multithreading environment.
**User visiable changes**:
There're not too much user visiable changes, since we use the owning
thread to drive the autograd execution, user could write their own
threading code and does not block on the Autograd engine, some behaviors
that user should be aware of:
**Non-determinism**:
if we are calling backward() on multiple thread concurrently but with
shared inputs (i.e. Hogwild CPU training). Since parameters are automatically shared across threads, gradient accumulation might become non-deterministic on backward calls across threads, because two backward calls might access and try to accumulate the same .grad attribute. This is technically not safe, and it might result in racing condition and the result might be invalid to use.
But this is expected pattern if user are using the multithreading
approach to drive the whole training process but using shared
parameters, user who use multithreading should have the threading model
in mind and should expect this to happen. User should use the functional
interface `torch.autograd.grad()` to calculate the gradients instead of
`backward()` on loss.
**Graph retaining**:
If part of the autograd graph is shared between threads, i.e. run first
part of forward single thread, then run second part in multiple threads,
then the first part of graph is shared. In this case different threads execute grad() or backward() on the same graph might
have issue of destroying the graph on the fly of one thread, and the
other thread will crash in this case. We will error out to the user
similar to what call `backward()` twice with out `retain_graph=True`, and let the user know they should use `retain_graph=True`.
**TODOs**:
[ ] benchmark the PR with example models and datasets to demonstrate
the performance gain in CPU training
[ ] ensure that we don't regress the single thread autograd performance
**Follow ups**:
[ ] a correct and tight integration with distributed autograd
[ ] try to unify the thread pool between JIT and Autograd, and see if
there's unifying pattern that we could apply universally
Test Plan: Imported from OSS
Differential Revision: D20236771
Pulled By: wanchaol
fbshipit-source-id: 1e0bd4eec14ffebeffdb60b763b8d6f0e427eb64
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31117
After this diff, we will have completely removed the named tensor
feature flagging. This means that named tensors are always on and that
there is no mechanism to turn them off. There should be no more follow-up
diffs.
I performed the deletion of the header with
```
find . -type f -print0 | xargs -0 sed -i '/#include
<ATen\/core\/EnableNamedTensor.h>/d'
```
Test Plan: - wait for CI
Differential Revision: D18934952
Pulled By: zou3519
fbshipit-source-id: 253d059074b910fef15bdf885ebf71e0edf5bea5
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30894
This PR begins the process of removing BUILD_NAMEDTENSOR macros. There
will be followups.
Reasons for removing the macros:
- BUILD_NAMEDTENSOR is always on and has been on since pytorch 1.3.0.
- Since we don't test building without it, it is useless to keep around.
- Code becomes nicer to read without the macros
Reasons for not removing the macros:
- potential for feature flagging
Now, I argue against needing to feature flag. The main reason why we
might want to feature flag is if we need to disable the feature.
We'd need a fast switch to disable the feature if someone discovers
in the future that named tensors caused some regression in some existing workflows.
In https://github.com/pytorch/pytorch/pull/25798, I did a variety of
macro- and micro- benchmarks to determine the performance impact of named
tensors on regular tensors.
[The
microbenchmarks](https://github.com/pytorch/pytorch/pull/25798#issuecomment-529014810)
were not very stable, and running the
microbenchmarks for more iterations doesn't actually help because the
noise is not distributed in a nice way. Instead of microbenchmarks I ran
a [profiler
(perf)](https://github.com/pytorch/pytorch/pull/25798#issuecomment-555707645)
to estimate how much overhead named tensors add to unnamed code. I
estimated the overhead to be less than 100ns for `add` and even smaller
for `mm`; there are ways to optimize even futher if we find this to be a
problem.
[Initial
macrobenchmarks](https://github.com/pytorch/pytorch/pull/25798#issuecomment-530539104)
were also not very stable. I ran imagenet for some number of epochs. To
make them more stable, I got rid of the data loading (which seemed to
vary between runs). [In some benchmarkers without data
loading](https://github.com/pytorch/pytorch/pull/25798#issuecomment-562214053),
we can see that the results are less noisy now. These results support
no noticeable regressions in speed.
Test Plan: - wait for CI
Differential Revision: D18858543
Pulled By: zou3519
fbshipit-source-id: 08bf3853a9f506c6b084808dc9ddd1e835f48c13
Summary:
Fixes https://github.com/pytorch/pytorch/issues/29161.
I looked a bit at the code changes related to this and think I have all of the use cases of `DeprecatedTypeProperties` covered in the message, but suggestions from someone with more context on this would be very much appreciated :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30281
Differential Revision: D18830818
Pulled By: ezyang
fbshipit-source-id: 1a7fcee15354ae09e6644577e7fa33bd26acfe20
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29665
Our intention is to merge the static distinction between Tensor and
Variable. Ordinarily, this would entail merging the methods of Tensor
and Variable. But there are a lot of "private"-ish methods on Variable
that we don't actually want to dump onto the Tensor class. So, as prep
work, we move all of those methods off of Variable and into
the torch::autograd::impl namespace (impl as in, please don't use this
end users). This ends up being a fairly large patch because all of
the call sites have to play ball too.
While I was on the topic, I also moved any of the touched functions into
the C++ file, so that modifying them would not trigger a recompilation of
all of torch.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Differential Revision: D18496169
Pulled By: ezyang
fbshipit-source-id: afb203252620ec274be596b3e7b1d84d321bad3a