This is useful for splitting grad to run in two parts while preserving intermediates:
<details>
<summary>
Click to see code
</summary>
```python
import collections
import weakref
from torch.autograd.graph import GradientEdge
def _get_grad_fn_or_grad_acc(t):
if t.requires_grad and t.grad_fn is None:
return t.view_as(t).grad_fn.next_functions[0][0]
else:
return t.grad_fn
def reverse_closure(roots, target_nodes):
# Recurse until we reach a target node
closure = set()
actual_target_nodes = set()
q: Deque = collections.deque()
for node in roots:
if node is not None and node not in closure:
closure.add(node)
q.append(node)
while q:
node = q.popleft()
reverse_edges = node.metadata.get("reverse_edges", [])
for holder_ref, idx in reverse_edges:
ref = holder_ref()
if ref is not None:
raise RuntimeError("Reverse graph is no longer alive")
fn = ref.node
if fn in closure or fn is None:
continue
if fn in target_nodes:
actual_target_nodes.add(fn)
continue
closure.add(fn)
q.append(fn)
return closure, actual_target_nodes
# Enable weak pointer
class Holder():
def __init__(self, node):
self.node = node
# TODO: use weak references to avoid reference cycle
def construct_reverse_graph(roots):
q: Deque = collections.deque()
root_seen = set()
reverse_graph_refs = []
for node in roots:
if node is not None and node not in root_seen:
q.append(node)
root_seen.add(node)
while q:
node = q.popleft()
for fn, idx in node.next_functions:
if fn is not None:
# Don't necessarily need to store on the graph
reverse_edges = fn.metadata.get("reverse_edges", [])
if len(reverse_edges) == 0:
q.append(fn)
holder = Holder(node)
holder_ref = weakref.ref(holder)
reverse_graph_refs.append(holder)
reverse_edges.append((holder_ref, idx))
fn.metadata["reverse_edges"] = reverse_edges
return reverse_graph_refs
def get_param_groups(inputs, params):
inputs_closure, _ = reverse_closure(inputs, set())
param_groups = dict() # keyed on intermediates
for i, param in enumerate(params):
closure, intersected = reverse_closure([param], inputs_closure)
param_group = {
"params": set([param]),
"intermediates": set(intersected),
}
for input_node in intersected:
existing = param_groups.get(input_node, None)
if existing is not None:
existing["params"] = existing["params"].union(param_group["params"])
existing["intermediates"] = existing["intermediates"].union(param_group["intermediates"])
param_group = existing
else:
param_groups[input_node] = param_group
# Sanity check: union of all param_groups params should be equal to all params
union_params = set()
seen_ids = set()
unique_param_groups = []
for param_group in param_groups.values():
if id(param_group) not in seen_ids:
seen_ids.add(id(param_group))
unique_param_groups.append(param_group)
union_params = union_params.union(param_group["params"])
assert union_params == set(params)
return unique_param_groups
def compute_grads_only_inputs2(roots, inps, weights):
root_grad_fns = list(map(_get_grad_fn_or_grad_acc, roots))
inp_grad_fns = list(map(_get_grad_fn_or_grad_acc, inps))
weight_grad_fns = list(map(_get_grad_fn_or_grad_acc, weights))
reverse_graph_refs = construct_reverse_graph(root_grad_fns)
param_groups = get_param_groups(inp_grad_fns, weight_grad_fns)
del reverse_graph_refs
for param_group in param_groups:
for i, intermediate in enumerate(param_group["intermediates"]):
def get_hook(param_group, i):
def hook(grad_inputs):
if param_group.get("grads", None) is None:
param_group["grads"] = [None] * len(param_group["intermediates"])
param_group["grads"][i] = grad_inputs
return hook
# These are always "split" nodes that we need to recompute, so
# save their inputs.
intermediate.register_prehook(get_hook(param_group, i))
dinputs = torch.autograd.grad((out,), inputs=tuple(inps), grad_outputs=(torch.ones_like(out),), retain_graph=True)
return dinputs, param_groups
def compute_grads_only_weights2(user_weights, param_groups):
all_dweights = dict()
for param_group in param_groups:
# TODO: Handle case where intermediate can have multiple outputs
intermediate_edges = tuple(GradientEdge(i, 0) for i in param_group["intermediates"])
weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])
assert all(len(g) == 1 for g in param_group["grads"])
# [NEW!] Able to pass a GradientEdge to autograd.grad as output
# We do not need to retain_graph because... guarantee no overlap?
print("trying to execute: ", intermediate_edges, weights_edges)
dweights = torch.autograd.grad(intermediate_edges, weights_edges, grad_outputs=sum(param_group["grads"], tuple()))
for w, dw in zip(param_group["params"], dweights):
all_dweights[w] = dw
# return grads in the original order weights were provided in
out = []
for w in user_weights:
grad_acc = _get_grad_fn_or_grad_acc(w)
out.append(all_dweights[grad_acc])
return tuple(out)
```
</details>
```python
import torch.nn as nn
# Setup
mod1 = nn.Linear(10, 10)
mod2 = nn.Linear(10, 10)
a = torch.rand(10, requires_grad=True)
weights = tuple(mod1.parameters()) + tuple(mod2.parameters())
inps = (a,)
out = mod2(mod1(a))
class LoggingTensorMode(torch.utils._python_dispatch.TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
rs = func(*args, **kwargs)
print(f"{func.__module__}.{func.__name__}")
return rs
print(" -- SPLIT -- ")
# Compute gradients in two parts
with LoggingTensorMode():
print("PART 1")
dinputs, state = compute_grads_only_inputs2((out,), inps, weights)
print("PART 2")
dweights = compute_grads_only_weights2(weights, state)
out = mod2(mod1(a))
print(" -- REF -- ")
# Compare with reference
with LoggingTensorMode():
ref_all_gradients = torch.autograd.grad(out, inputs=tuple(inps) + weights, grad_outputs=(torch.ones_like(out),))
for actual, ref in zip(dinputs + dweights, ref_all_gradients):
print(torch.allclose(actual, ref))
```
<img width="598" alt="image" src="https://github.com/pytorch/pytorch/assets/13428986/3681b8a7-3ab4-4d1d-a836-abef6913e671">
```
PART 1
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.ones_like.default
V0603 10:17:21.590878 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1ee160> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591204 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591578 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x100d7ae50> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
V0603 10:17:21.591747 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a60> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591834 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591922 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a90> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
PART 2
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1e4bb0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a21b130>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b7c0>, output_nr=0))
V0603 10:17:21.592223 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.592421 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a1cad60> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1ee0d0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a1e41c0>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b670>, output_nr=0))
V0603 10:17:21.593481 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.593750 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a21b2b0> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127766
Approved by: https://github.com/albanD
The existing try-catch doesn't work because it doesn't call err.persist(). This is in contrast to the try-catch for evaluate_function which does work because it calls into python_engine's thread_on_exception which calls persist.
Calling persist on a python_error stashes the PyErr state from the thread-local PyThreadState onto the python_error object, so that when this error object is stored onto the future and passed back to the calling cpu thread, python_engine's execute try-catch can then err.restore() the error state. Finally, the python_engine's execute would re-raise so that this is re-caught by the HANDLE_TH_ERRORS macro.
Fixes https://github.com/pytorch/pytorch/issues/75750
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113702
Approved by: https://github.com/albanD
This can be useful for advanced users (like AOTAutograd) who don't want to keep the corresponding Tensor alive (for memory reasons for example) or when inplace op will change the Tensor's grad_fn (but gradients wrt to the original value is needed).
I went minimal API change but open to suggestions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110867
Approved by: https://github.com/soulitzer
This PR enables the misc-XX checks in clang-tidy. Meanwhile, I excluded some of them that require a lot of code changes and have no immediate benefits. Some additional fixes and suppression were also given.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110283
Approved by: https://github.com/albanD
We have an older torch.vmap implementation. It is no longer supported.
It still needs to exist somewhere for the sake of BC with
torch.autograd.functional.
This PR makes it clear what files are meant for implementing the old
vmap implementation. I've seen a couple of PRs recently adding support
for the old vmap implementation, so this will lessen the confusion.
Test Plan:
- CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90324
Approved by: https://github.com/samdow
We define specializations for pybind11 defined templates
(in particular, PYBIND11_DECLARE_HOLDER_TYPE) and consequently
it is important that these specializations *always* be #include'd
when making use of pybind11 templates whose behavior depends on
these specializations, otherwise we can cause an ODR violation.
The easiest way to ensure that all the specializations are always
loaded is to designate a header (in this case, torch/csrc/util/pybind.h)
that ensures the specializations are defined, and then add a lint
to ensure this header is included whenever pybind11 headers are
included.
The existing grep linter didn't have enough knobs to do this
conveniently, so I added some features. I'm open to suggestions
for how to structure the features better. The main changes:
- Added an --allowlist-pattern flag, which turns off the grep lint
if some other line exists. This is used to stop the grep
lint from complaining about pybind11 includes if the util
include already exists.
- Added --match-first-only flag, which lets grep only match against
the first matching line. This is because, even if there are multiple
includes that are problematic, I only need to fix one of them.
We don't /really/ need this, but when I was running lintrunner -a
to fixup the preexisting codebase it was annoying without this,
as the lintrunner overall driver fails if there are multiple edits
on the same file.
I excluded any files that didn't otherwise have a dependency on
torch/ATen, this was mostly caffe2 and the valgrind wrapper compat
bindings.
Note the grep replacement is kind of crappy, but clang-tidy lint
cleaned it up in most cases.
See also https://github.com/pybind/pybind11/issues/4099
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82552
Approved by: https://github.com/albanD
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64620
`autograd` extension module's shutdown logic destructs `PyThreadState` by `pybind11::gil_scoped_acquire` using the RAII pattern.
The problem is that torch.deploy also destructs `PyThreadState` as part of its shutdown process (https://www.internalfb.com/phabricator/paste/view/P456363738), causing double destruction, use-after-free.
This change adds `defined(USE_DEPLOY)` as a special case to avoid destruction of `PyThreadState` to the existing special treatment for `IS_PYTHON_3_9_PLUS`.
Test Plan: Added `TorchpyTest.Autograd` unittest to ensure that torch.deploy can create multiple instances that use autograd without causing a crash.
Reviewed By: albanD
Differential Revision: D30779080
fbshipit-source-id: 4de3283cc2d394acc9b8141c17cacbfab5eea052
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62563
Expose a pair of functions to Python users: torch.autograd.graph.set_saved_tensors_default_hooks(pack, unpack) and torch.autograd.graph.reset_saved_tensors_default_hooks().
These functions control the hooks applied to saved tensors: all tensors saved in that context will be packed using the pack function, then unpacked accordingly when needed.
Currently, this works by simply calling register_hooks (cf #60975) directly at the end of the constructor of a SavedVariable. This could be optimized further by not performing the copy before registering default hooks, but this would require a small refactor. Edit: the refactor is done in #61927.
A current limitation is that if users create tensors in this context, they will not be able to register additional hooks on the saved tensor.
For instance, to perform something like #28997, one could define a pack function that saves to disk whenever the tensor size is too big and returns a filename, then unpack simply reads the content of the file and outputs a tensor, e.g.:
```
def pack(x):
name = os.path.join(tmp_dir, str(uuid.uuid4()))
torch.save(x, name)
return name
def unpack(name):
return torch.load(name)
```
Relanding previous PR: https://github.com/pytorch/pytorch/pull/61834
Original PR led to timeout error in: https://www.internalfb.com/mast/job/yuguo-release_canary_offline_training-inlinecvrp_a-canary_offline_train_28a7ecfc
Now passing: https://www.internalfb.com/mast/job/quach-release_canary_offline_training-inlinecvrp_a-canary_offline_train_9bb57e98
The difference with the new version is we don't need to acquire the GIL when calling `PyDefaultSavedVariableHooks::get_hooks`.
Test Plan: Imported from OSS
Reviewed By: iramazanli
Differential Revision: D30045405
Pulled By: Varal7
fbshipit-source-id: 7f6c07af3a56fe8835d5edcc815c15ea4fb4e332
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61834
Expose a pair of functions to Python users: torch.autograd.graph.set_saved_tensors_default_hooks(pack, unpack) and torch.autograd.graph.reset_saved_tensors_default_hooks().
These functions control the hooks applied to saved tensors: all tensors saved in that context will be packed using the pack function, then unpacked accordingly when needed.
Currently, this works by simply calling register_hooks (cf #60975) directly at the end of the constructor of a SavedVariable. This could be optimized further by not performing the copy before registering default hooks, but this would require a small refactor. Edit: the refactor is done in #61927.
A current limitation is that if users create tensors in this context, they will not be able to register additional hooks on the saved tensor.
For instance, to perform something like #28997, one could define a pack function that saves to disk whenever the tensor size is too big and returns a filename, then unpack simply reads the content of the file and outputs a tensor, e.g.:
```
def pack(x):
name = os.path.join(tmp_dir, str(uuid.uuid4()))
torch.save(x, name)
return name
def unpack(name):
return torch.load(name)
```
Test Plan: Imported from OSS
Reviewed By: zou3519
Differential Revision: D29792193
Pulled By: Varal7
fbshipit-source-id: 33e931230ef59faa3ec8b5d11ef7c05539bce77c
Summary:
As GoogleTest `TEST` macro is non-compliant with it as well as `DEFINE_DISPATCH`
All changes but the ones to `.clang-tidy` are generated using following script:
```
for i in `find . -type f -iname "*.c*" -or -iname "*.h"|xargs grep cppcoreguidelines-avoid-non-const-global-variables|cut -f1 -d:|sort|uniq`; do sed -i "/\/\/ NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)/d" $i; done
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62008
Reviewed By: driazati, r-barnes
Differential Revision: D29838584
Pulled By: malfet
fbshipit-source-id: 1b2f8602c945bd4ce50a9bfdd204755556e31d13
Summary:
Switches most of the simple for loops outside of `jit` directories to use `c10::irange`.
Generated with D28874212.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59481
Test Plan: Sandcastle
Reviewed By: ngimel
Differential Revision: D28909681
fbshipit-source-id: ec9ab1bd602933238d9d0f73d4d8d027b75d9d85
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58420
In https://github.com/pytorch/pytorch/pull/57636 I migrated most uses of Future to an intrusive_ptr. I thought I had all of them but I missed a couple. These are the remaining ones. (The next PR will make it impossible to add new usages of shared_ptr).
ghstack-source-id: 129567071
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D28477285
fbshipit-source-id: 75008276baa59e26b450e942c009ec7e78f89b13
Summary:
Such a deadlock was found for PyFunctionPreHook after adding https://github.com/pytorch/pytorch/pull/57057
This is fixing all occurrences in torch/csrc/autograd
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57488
Reviewed By: malfet
Differential Revision: D28163321
Pulled By: albanD
fbshipit-source-id: 4daf1db69674e73967fc7c5ca2a240c61340e7ca
Summary:
This is an automatic change generated by the following script:
```
#!/usr/bin/env python3
from subprocess import check_output, check_call
import os
def get_compiled_files_list():
import json
with open("build/compile_commands.json") as f:
data = json.load(f)
files = [os.path.relpath(node['file']) for node in data]
for idx, fname in enumerate(files):
if fname.startswith('build/') and fname.endswith('.DEFAULT.cpp'):
files[idx] = fname[len('build/'):-len('.DEFAULT.cpp')]
return files
def run_clang_tidy(fname):
check_call(["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname,"-s"])
changes = check_output(["git", "ls-files", "-m"])
if len(changes) == 0:
return
check_call(["git", "commit","--all", "-m", f"NOLINT stubs for {fname}"])
def main():
git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n")
compiled_files = get_compiled_files_list()
for idx, fname in enumerate(git_files):
if fname not in compiled_files:
continue
if fname.startswith("caffe2/contrib/aten/"):
continue
print(f"[{idx}/{len(git_files)}] Processing {fname}")
run_clang_tidy(fname)
if __name__ == "__main__":
main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892
Reviewed By: H-Huang
Differential Revision: D27991944
Pulled By: malfet
fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56808
For information about data-race-on-vptr in general, see https://www.internalfb.com/intern/wiki/TSAN/Common_Concurrency_Mistakes/Stopping_a_Thread_in_Destructor/
Engine::~Engine() was previously tasked with stopping the threads. This causes a data race on the object's vptr when PythonEngine is being destructed. This fixes the data race by making ~PythonEngine trigger the thread stopping before going down to the base class's destructor.
Test Plan:
Many tests are affected, but here's one example:
buck test mode/dev-tsan -c fbcode.tsan_strict_mode=true //oculus/research/orcoptics/deep_learning/srg_nn/tests:test_grating_net -- 'test_train (oculus.research.orcoptics.deep_learning.srg_nn.tests.test_grating_net.TestGratingNet)' --run-disabled
Reviewed By: walterddr, albanD
Differential Revision: D27972384
fbshipit-source-id: 8b70fec8d9326497c591a2777b355ea590a85082
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55799
I'm going to change the implementation of cdata soon so I need to
abstract over cdata access with a function. Additionally, many
users are casting manually casting to THPVariable to access
the member so I can remove these unsafe casts in the client code
(the implementation, of course, is still doing an unsafe cast.)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D27712130
Pulled By: ezyang
fbshipit-source-id: 95fcc013bf3913d67f2c634068eb5b3aab144cb3
Summary:
Fixes https://github.com/pytorch/pytorch/issues/39784
At the time the issue was filed, there was only issue (1) below.
There are actually now two issues here:
1. We always set all inputs passed in through `inputs` arg as `needed = True` in exec_info. So if we pass in an input that has a grad_fn that is not materialized, we create an entry of exec_info with nullptr as key with `needed = True`. Coincidentally, when we perform simple arithmetic operations, such as "2 * x", one of the next edges of mul is an invalid edge, meaning that its grad_fn is also nullptr. This causes the discovery algorithm to set all grad_fns that have a path to this invalid_edge as `needed = True`.
2. Before the commit that enabled the engine skipped the dummy node, we knew that root node is always needed, i.e., we hardcode `exec_info[&graph_root]=true`. The issue was that this logic wasn't updated after the code was updated to skip the graph root.
To address (1), instead of passing in an invalid edge if an input in `inputs` has no grad_fn, we create a dummy grad_fn. This is done in both python and cpp entry points. The alternative is to add logic for both backward() and grad() cases to check whether the grad_fn is nullptr and set needed=false in that case (the .grad() case would be slightly more complicated than the .backward() case here).
For (2), we perform one final iteration of the discovery algorithm so that we really know whether we need to execute the graph root.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51940
Reviewed By: VitalyFedyunin
Differential Revision: D26369529
Pulled By: soulitzer
fbshipit-source-id: 14a01ae7988a8de621b967a31564ce1d7a00084e
Summary:
Remove `THPWrapper` from PyTorch C code since it is not used anymore and because we have dropped Python 2 compatibility, its usage can be replaced by capsule objects (`PyCapsule_New`, `PyCapsule_CheckExact`, `PyCapsule_GetPointer` and `PyCapsule_GetDestructor`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49871
Reviewed By: mruberry
Differential Revision: D25715038
Pulled By: albanD
fbshipit-source-id: cc3b6f967bbe0dc42c692adf76dff4e4b667fdd5
Summary:
Fixes https://github.com/pytorch/pytorch/issues/46373
As noted in https://github.com/pytorch/pytorch/issues/46373, there needs to be a flag passed into the engine that indicates whether it was executed through the backward api or grad api. Tentatively named the flag `accumulate_grad` since functionally, backward api accumulates grad into .grad while grad api captures the grad and returns it.
Moving changes not necessary to the python api (cpp, torchscript) to a new PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46855
Reviewed By: ngimel
Differential Revision: D24649054
Pulled By: soulitzer
fbshipit-source-id: 6925d5a67d583eeb781fc7cfaec807c410e1fc65
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46227
Follow up from https://github.com/pytorch/pytorch/issues/45419, in
this PR I've removed as many PyCFunction casts as I could from the codebase.
The only ones I didn't remove were the ones with `METH_VARARGS | METH_KEYWORDS`
which have 3 parameters instead of 2 and had to be casted. Example: `
{"copy_", (PyCFunction)(void(*)(void))THPStorage_(copy_), METH_VARARGS |
METH_KEYWORDS, nullptr},`
ghstack-source-id: 114632704
Test Plan: waitforbuildbot
Reviewed By: albanD
Differential Revision: D24269435
fbshipit-source-id: 025cfd43a9a2a3e59f6b2951c1a78749193d77cf
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45461
This PR disables autograd for all C -> C, R -> C functions which are not included in the whitelist `GRADIENT_IMPLEMENTED_FOR_COMPLEX`. In practice, there will be a RuntimeError during forward computation when the outputs are differentiable:
```
>>> x=torch.randn(4, 4, requires_grad=True, dtype=torch.cdouble)
>>> x.pow(3)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: pow does not support automatic differentiation for outputs with complex dtype.
```
The implicit assumption here is that all the C -> R functions have correct backward definitions. So before merging this PR, the following functions must be tested and verified to have correct backward definitions:
`torch.abs` (updated in #39955 ), `torch.angle`, `torch.norm`, `torch.irfft`, `torch.istft`.
Test Plan: Imported from OSS
Reviewed By: malfet
Differential Revision: D23998156
Pulled By: anjali411
fbshipit-source-id: 370eb07fe56ac84dd8e2233ef7bf3a3eb8aeb179
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43676
This is one part of https://github.com/pytorch/pytorch/issues/41574 to
ensure we consolidate everything around ivalue::Future.
I've removed the use of torch/csrc/utils/future.h from the autograd engines and
used ivalue::Future instead.
ghstack-source-id: 110895545
Test Plan: waitforbuildbot.
Reviewed By: albanD
Differential Revision: D23362415
fbshipit-source-id: aa109b3f8acf0814d59fc5264a85a8c27ef4bdb6
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42876
Previously, the error messages were pretty bad. This PR adds nice
error messages for the following cases:
- user attempts to call .backward() inside vmap for any reason
whatsoever
- user attempts to call autograd.grad(outputs, inputs, grad_outputs),
where outputs or inputs is being vmapped over (so they are
BatchedTensors).
The case we do support is calling autograd.grad(outputs, inputs,
grad_outputs) where `grad_outputs` is being vmapped over. This is the
case for batched gradient support (e.g., user passes in a batched
grad_output).
Test Plan: - new tests: `pytest test/test_vmap.py -v`
Reviewed By: ezyang
Differential Revision: D23059836
Pulled By: zou3519
fbshipit-source-id: 2fd4e3fd93f558e67e2f0941b18f0d00d8ab439f
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40312
As part of https://github.com/pytorch/pytorch/issues/40255, we
realized that GPU support for distributed autograd was broken as part of our
multithreaded autograd change.
To fix this in the short term for 1.6, this PR includes the following changes:
1) Long lived CPU thread in DistEngine to execute GPU->CPU continuations in the
autograd graph.
2) The long lived CPU thread has its own ready_queue and this queue is used for
all GraphTasks created by DistEngine.
3) In thread_main(), the CPU thread cannot exit once the GraphTask is done
processing because of the new CPU thread added in 1).
4) To resolve this, thread_main() now has a parameter `device_thread` instead
of `reentrant_thread`. When device_thread is True, we expect this to be a long
lived device thread that does not exit.
5) When device_thread is False, thread_main is expected to run a GraphTask and
return once done.
ghstack-source-id: 106391329
Test Plan: waitforbuildbot
Differential Revision: D22146183
fbshipit-source-id: dd146b7a95f55db75f6767889b7255e9d62d5825
Summary:
If Engine is created shortly before application exits, then non-reentrant thread might not have a chance to spawn which would result in an infinite wait in `Engine::~Engine()`
Prevent this by actually waiting for threads to spawn before returning from `Engine::start_device_threads()`
Make sure that thread count is incremented before GIL is acquired in PythonThread
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39194
Differential Revision: D21789219
Pulled By: malfet
fbshipit-source-id: d9b5e74d5ddeb2474b575af2e4f33d022efcfe53
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36606
This PR refactor the continuation logic of the async mode on autograd
engine, to avoid launch spinning works. To achieve that:
1. remove the continuation logic in
execute_graph_task_with_continuiation
2. separate the usage of execute_graph_task between dist_engine and
local engine, now dist_engine universally use
`execute_graph_task_until_ready_queue_empty` (a better name appreciated
here).
3. remove enqueue_blocked_task_on_cpu
4. remove the async mode in `execute_with_graph_task` as we don't need
to use it in dist_engine
Test Plan: Imported from OSS
Differential Revision: D21032731
Pulled By: wanchaol
fbshipit-source-id: 708ea3bc14815bdc151b56afa15eb85b4ac0f4b1
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33157
This PR enables graph level thread parallelism on CPU for the Autograd
Engine. It replace https://github.com/pytorch/pytorch/pull/29574 for the
reason of task level parallelism drawbacks with the existing autograd
system.
Fixes https://github.com/pytorch/pytorch/issues/18333
The graph level parallelism on CPU design:
1. Remove the single CPU thread that init in the Engine itself and allow
the owning thread (which calls Engine::execute) to drive the Engine
execution so that we could let outer threading to enable thread
parallelism.
2. Maintain a separate ReadyQueue per CPU thread, and stash the
ReadyQueue for different devices/threads into the thread local
shared_ptr, the Engine itself will memorize the shared_ptr of the
ReadyQueue to different devices (other than CPU)
3. The CPU thread local ReadyQueue is initialized per CPU thread
Engine::execute call (or `backward()`, `grad()` call), and memorized
the shared_ptr into the GraphTask since every `backward()` call have
its own GraphTask
4. Cross device NodeTask push is accomplished by 2 and 3. we can refer
to device's ReadyQueue from Engine, and CPU's ReadyQueue from
GraphTask, which means if we can push to a different ReadyQueue
according to the device
5. Termination of the CPU thread: if we mark the graph_task as
completed, we will exit the while loop and terminate the current
backward execution, because it's guranteed that all other NodeTasks
is finished before we mark a GraphTask as complete
6. re-entrant thread logic keeps the same, reentrant thread detection is
similar as before, we set the worker_device to NO_DEVICE initially
and set to CPU afterward to detect if this is a reentrant call or not.
7. we still have the reentrant thread pool that create new threads if it's
a deep reentrant case, and reuse the ReadyQueue with the parent thread
for performance.
Since we introduce the thread parallelism on CPU, we have to ensure the
thread safety of the GraphTask. This is not a problem if we execute all
forward in different threads since we will build separate GraphTask in
different threads, and each GraphTask is a separate instance that share
nothing, i.e. Hogwild training on CPU should be fine on this case.
But there might be case that user would like to do some part of the task in
a single thread, and do the rest of work in several threads
concurrently, so thread safety is crucial in those cases. The thread
safety strategy for the multithread autograd is as follows:
1. Add a mutex to protect thread safety in Autograd Node/Function, and
hold the lock for different data racing cases
2. Lock the mutex during Node::apply(), this is to ensure Node that
writing to the shared variable are not racing across threads (i.e.
AccumulateGrad and custom C++ Autograd Node if writing to shared
variables )
3. Lock the mutex during Node::release_variables(), this serve the
purpose that when we release saved_variables from one thread, no
other threads can call the Node::apply(), this ensures the variable
references from other threads aren't dangling.
4. If we don't release any variables and no shared data read/write in
the Node i.e. purely functional, we don't lock the mutex
This way we could protect the thread safety on Autograd Node, but we
could still not protect the thread safety on Node pre/post C++ hooks
(python hooks are automatically thread safe), we rely on the user to
write thread safe C++ hooks if they want the hook to be correctly
applied in multithreading environment.
**User visiable changes**:
There're not too much user visiable changes, since we use the owning
thread to drive the autograd execution, user could write their own
threading code and does not block on the Autograd engine, some behaviors
that user should be aware of:
**Non-determinism**:
if we are calling backward() on multiple thread concurrently but with
shared inputs (i.e. Hogwild CPU training). Since parameters are automatically shared across threads, gradient accumulation might become non-deterministic on backward calls across threads, because two backward calls might access and try to accumulate the same .grad attribute. This is technically not safe, and it might result in racing condition and the result might be invalid to use.
But this is expected pattern if user are using the multithreading
approach to drive the whole training process but using shared
parameters, user who use multithreading should have the threading model
in mind and should expect this to happen. User should use the functional
interface `torch.autograd.grad()` to calculate the gradients instead of
`backward()` on loss.
**Graph retaining**:
If part of the autograd graph is shared between threads, i.e. run first
part of forward single thread, then run second part in multiple threads,
then the first part of graph is shared. In this case different threads execute grad() or backward() on the same graph might
have issue of destroying the graph on the fly of one thread, and the
other thread will crash in this case. We will error out to the user
similar to what call `backward()` twice with out `retain_graph=True`, and let the user know they should use `retain_graph=True`.
**TODOs**:
[ ] benchmark the PR with example models and datasets to demonstrate
the performance gain in CPU training
[ ] ensure that we don't regress the single thread autograd performance
**Follow ups**:
[ ] a correct and tight integration with distributed autograd
[ ] try to unify the thread pool between JIT and Autograd, and see if
there's unifying pattern that we could apply universally
Test Plan: Imported from OSS
Differential Revision: D20236771
Pulled By: wanchaol
fbshipit-source-id: 1e0bd4eec14ffebeffdb60b763b8d6f0e427eb64