* Automatically applies ruff rule PERF401, turning loops into equivalent list comprehensions, which are faster and do not leak the loop variable into the enclosing scope (see the sketch after this list).
* List comprehensions not only often type-check better, but they also cut 50%+ of the per-iteration loop overhead. They also preserve length information and are easier for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt.
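For illustration, a minimal sketch (hypothetical, not taken from the diff) of the kind of rewrite PERF401 performs:
```python
# Before: building a list with an explicit loop (flagged by PERF401)
doubled = []
for x in range(10):
    doubled.append(x * 2)

# After: the equivalent list comprehension; `x` no longer leaks into the
# enclosing scope, and the per-iteration append-call overhead goes away
doubled = [x * 2 for x in range(10)]
```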
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
The regression from https://github.com/pytorch/pytorch/issues/132281 pinpoints e4ace1a396 as the cause. The main delta that commit introduces is that we now manually check `is_inference()` and call `increment_version()` (a pybind call) on every mutated input tensor to the graph.
This PR attempts to reduce overhead a bit by bundling all of those checks into a single pybind call (sketched below), by:
(1) updating `torch.autograd.graph.increment_version()` to accept a `Union[Tensor, List[Tensor]]`
(2) updating its semantics to no-op when you pass in a tensor with no version counter, instead of erroring
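A minimal sketch of the batched call after this change (assuming the `Union` overload described above):
```python
import torch
from torch.autograd.graph import increment_version

a, b = torch.rand(2), torch.rand(2)
increment_version(a)       # single tensor still works
increment_version([a, b])  # one pybind call covers all mutated inputs
```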
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132652
Approved by: https://github.com/albanD
This is useful for splitting grad to run in two parts while preserving intermediates:
<details>
<summary>
Click to see code
</summary>
```python
import collections
import weakref
from typing import Deque

import torch
from torch.autograd.graph import GradientEdge


def _get_grad_fn_or_grad_acc(t):
    # For leaf tensors, materialize the AccumulateGrad node via a no-op view.
    if t.requires_grad and t.grad_fn is None:
        return t.view_as(t).grad_fn.next_functions[0][0]
    else:
        return t.grad_fn


def reverse_closure(roots, target_nodes):
    # Recurse until we reach a target node
    closure = set()
    actual_target_nodes = set()
    q: Deque = collections.deque()
    for node in roots:
        if node is not None and node not in closure:
            closure.add(node)
            q.append(node)
    while q:
        node = q.popleft()
        reverse_edges = node.metadata.get("reverse_edges", [])
        for holder_ref, idx in reverse_edges:
            ref = holder_ref()
            if ref is None:
                raise RuntimeError("Reverse graph is no longer alive")
            fn = ref.node
            if fn in closure or fn is None:
                continue
            if fn in target_nodes:
                actual_target_nodes.add(fn)
                continue
            closure.add(fn)
            q.append(fn)
    return closure, actual_target_nodes


# Enable weak pointer
class Holder:
    def __init__(self, node):
        self.node = node


# TODO: use weak references to avoid reference cycle
def construct_reverse_graph(roots):
    q: Deque = collections.deque()
    root_seen = set()
    reverse_graph_refs = []
    for node in roots:
        if node is not None and node not in root_seen:
            q.append(node)
            root_seen.add(node)
    while q:
        node = q.popleft()
        for fn, idx in node.next_functions:
            if fn is not None:
                # Don't necessarily need to store on the graph
                reverse_edges = fn.metadata.get("reverse_edges", [])
                if len(reverse_edges) == 0:
                    q.append(fn)
                holder = Holder(node)
                holder_ref = weakref.ref(holder)
                reverse_graph_refs.append(holder)
                reverse_edges.append((holder_ref, idx))
                fn.metadata["reverse_edges"] = reverse_edges
    return reverse_graph_refs


def get_param_groups(inputs, params):
    inputs_closure, _ = reverse_closure(inputs, set())
    param_groups = dict()  # keyed on intermediates
    for i, param in enumerate(params):
        closure, intersected = reverse_closure([param], inputs_closure)
        param_group = {
            "params": set([param]),
            "intermediates": set(intersected),
        }
        for input_node in intersected:
            existing = param_groups.get(input_node, None)
            if existing is not None:
                existing["params"] = existing["params"].union(param_group["params"])
                existing["intermediates"] = existing["intermediates"].union(
                    param_group["intermediates"]
                )
                param_group = existing
            else:
                param_groups[input_node] = param_group
    # Sanity check: union of all param_groups params should be equal to all params
    union_params = set()
    seen_ids = set()
    unique_param_groups = []
    for param_group in param_groups.values():
        if id(param_group) not in seen_ids:
            seen_ids.add(id(param_group))
            unique_param_groups.append(param_group)
            union_params = union_params.union(param_group["params"])
    assert union_params == set(params)
    return unique_param_groups


def compute_grads_only_inputs2(roots, inps, weights):
    root_grad_fns = list(map(_get_grad_fn_or_grad_acc, roots))
    inp_grad_fns = list(map(_get_grad_fn_or_grad_acc, inps))
    weight_grad_fns = list(map(_get_grad_fn_or_grad_acc, weights))

    reverse_graph_refs = construct_reverse_graph(root_grad_fns)
    param_groups = get_param_groups(inp_grad_fns, weight_grad_fns)
    del reverse_graph_refs

    for param_group in param_groups:
        for i, intermediate in enumerate(param_group["intermediates"]):

            def get_hook(param_group, i):
                def hook(grad_inputs):
                    if param_group.get("grads", None) is None:
                        param_group["grads"] = [None] * len(param_group["intermediates"])
                    param_group["grads"][i] = grad_inputs

                return hook

            # These are always "split" nodes that we need to recompute, so
            # save their inputs.
            intermediate.register_prehook(get_hook(param_group, i))

    # Use roots (rather than a global) so the helper is self-contained.
    dinputs = torch.autograd.grad(
        roots,
        inputs=tuple(inps),
        grad_outputs=tuple(torch.ones_like(r) for r in roots),
        retain_graph=True,
    )
    return dinputs, param_groups


def compute_grads_only_weights2(user_weights, param_groups):
    all_dweights = dict()
    for param_group in param_groups:
        # TODO: Handle case where intermediate can have multiple outputs
        intermediate_edges = tuple(GradientEdge(i, 0) for i in param_group["intermediates"])
        weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])

        assert all(len(g) == 1 for g in param_group["grads"])
        # [NEW!] Able to pass a GradientEdge to autograd.grad as output
        # We do not need to retain_graph because... guarantee no overlap?
        print("trying to execute: ", intermediate_edges, weights_edges)
        dweights = torch.autograd.grad(
            intermediate_edges,
            weights_edges,
            grad_outputs=sum(param_group["grads"], tuple()),
        )
        for w, dw in zip(param_group["params"], dweights):
            all_dweights[w] = dw
    # Return grads in the original order weights were provided in.
    out = []
    for w in user_weights:
        grad_acc = _get_grad_fn_or_grad_acc(w)
        out.append(all_dweights[grad_acc])
    return tuple(out)
```
</details>
```python
import torch
import torch.nn as nn
from torch.utils._python_dispatch import TorchDispatchMode

# Setup
mod1 = nn.Linear(10, 10)
mod2 = nn.Linear(10, 10)

a = torch.rand(10, requires_grad=True)
weights = tuple(mod1.parameters()) + tuple(mod2.parameters())
inps = (a,)

out = mod2(mod1(a))


class LoggingTensorMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        rs = func(*args, **kwargs)
        print(f"{func.__module__}.{func.__name__}")
        return rs


print(" -- SPLIT -- ")
# Compute gradients in two parts
with LoggingTensorMode():
    print("PART 1")
    dinputs, state = compute_grads_only_inputs2((out,), inps, weights)
    print("PART 2")
    dweights = compute_grads_only_weights2(weights, state)

out = mod2(mod1(a))
print(" -- REF -- ")
# Compare with reference
with LoggingTensorMode():
    ref_all_gradients = torch.autograd.grad(
        out, inputs=tuple(inps) + weights, grad_outputs=(torch.ones_like(out),)
    )
for actual, ref in zip(dinputs + dweights, ref_all_gradients):
    print(torch.allclose(actual, ref))
```
<img width="598" alt="image" src="https://github.com/pytorch/pytorch/assets/13428986/3681b8a7-3ab4-4d1d-a836-abef6913e671">
```
PART 1
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.ones_like.default
V0603 10:17:21.590878 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1ee160> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591204 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591578 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x100d7ae50> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
V0603 10:17:21.591747 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a60> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591834 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591922 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a90> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
PART 2
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1e4bb0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a21b130>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b7c0>, output_nr=0))
V0603 10:17:21.592223 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.592421 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a1cad60> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1ee0d0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a1e41c0>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b670>, output_nr=0))
V0603 10:17:21.593481 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.593750 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a21b2b0> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127766
Approved by: https://github.com/albanD
For some reason, if we construct `class Handle(RemovableHandle)` inside `register_multi_grad_hook`, then over time the call to `RemovableHandle.__init__` slows down more and more (when we have GC disabled). Perhaps this is related to the class attribute `next_id: int = 0`. Python experts: please let me know if you have thoughts 😅
I am open to suggestions on how we should deal with this `Handle` class. For now, I changed it to a private `_MultiHandle`.
<details>
<summary> Experiment Script </summary>
```python
import gc
import time

import torch

NUM_TENSORS = int(5e4)
ts = [torch.empty(1, requires_grad=True) for _ in range(NUM_TENSORS)]


def hook(grad) -> None:
    return


gc.disable()

times = []
for i, t in enumerate(ts):
    start_time = time.time()
    torch.autograd.graph.register_multi_grad_hook([t], hook)
    end_time = time.time()
    times.append(end_time - start_time)
print([f"{t * 1e6:.3f} us" for t in times[1:6]])  # print first few times
print([f"{t * 1e6:.3f} us" for t in times[-5:]])  # print last few times

times = []
for i, t in enumerate(ts):
    start_time = time.time()
    t.register_hook(hook)
    end_time = time.time()
    times.append(end_time - start_time)
print([f"{t * 1e6:.3f} us" for t in times[1:6]])  # print first few times
print([f"{t * 1e6:.3f} us" for t in times[-5:]])  # print last few times
```
</details>
<details>
<summary> Results </summary>
Before fix:
```
['23.603 us', '19.550 us', '15.497 us', '12.875 us', '13.828 us']
['327.110 us', '341.177 us', '329.733 us', '332.832 us', '341.177 us']
['318.050 us', '315.189 us', '319.719 us', '311.613 us', '308.990 us']
['374.317 us', '394.821 us', '350.714 us', '337.362 us', '331.402 us']
```
Calling `register_multi_grad_hook` slows down subsequent calls both to itself and to `register_hook` (in fact, any call to `RemovableHandle.__init__`).
After fix:
```
['13.590 us', '9.060 us', '12.875 us', '7.153 us', '8.583 us']
['4.530 us', '5.245 us', '6.437 us', '4.768 us', '5.007 us']
['2.623 us', '1.907 us', '1.431 us', '1.669 us', '1.192 us']
['1.431 us', '1.431 us', '1.192 us', '1.192 us', '1.431 us']
```
</details>
Update, from @soulitzer:
> Your suspicion about next_id is right. I think what is happening is that whenever a class attribute is set, it needs to invalidate some cached data for the subclasses one-by-one. eefff682f0/Objects/typeobject.c (L845)

This PR fixes the issue by avoiding the dynamic creation of many subclasses. Changing `next_id` to something like `List[int]`, or incrementing a global instead, also fixes it.
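Below is a hypothetical micro-benchmark (not from the PR; all names are made up) illustrating the mechanism @soulitzer describes: assigning to a class attribute walks the class's subclass list in CPython, so dynamically creating many subclasses makes each assignment slower.
```python
import time


class Base:
    next_id: int = 0


def time_assignment() -> float:
    start = time.time()
    Base.next_id += 1  # setattr on a type invalidates each subclass's cache
    return time.time() - start


subclasses = []
for n in (0, 10_000, 50_000):
    while len(subclasses) < n:
        subclasses.append(type(f"Sub{len(subclasses)}", (Base,), {}))
    print(f"{n:>6} subclasses: {time_assignment() * 1e6:.1f} us")
```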
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122847
Approved by: https://github.com/soulitzer
ghstack dependencies: #122726
This PR:
- disallows FakeTensor.data_ptr when it is called inside PT2 or fx tracing.
- disallows FunctionalTensor.data_ptr (python FunctionalTensor is only used in
PT2)
The motivation behind this is that the leading cause of segfaults when
using custom ops with PT2 is calling .data_ptr on FunctionalTensor or
FakeTensor.
This change is BC-breaking. If your code broke as a result of this, it's because there was a bug in it (these `.data_ptr` calls should never have been made!). You can either fix the bug (recommended) or get the previous behavior back with:
```python
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor

# `tensor` is whatever Tensor you were calling .data_ptr() on
data_ptr = 0 if isinstance(tensor, (FakeTensor, FunctionalTensor)) else tensor.data_ptr()
```
Test Plan:
- existing tests
Differential Revision: [D55366199](https://our.internmc.facebook.com/intern/diff/D55366199)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122514
Approved by: https://github.com/ezyang, https://github.com/albanD, https://github.com/yifuwang, https://github.com/kurtamohler
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```python
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
Instead of printing the tensor's data, print the tensor's dtype and shape metadata:
```
Executing: <VarMeanBackward0 object at 0x1352d0e20> with grad_outputs: [None,f32[]]
```
This is important to avoid triggering a CUDA sync, and it is also useful for reducing verbosity.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116523
Approved by: https://github.com/albanD
Fixes #112595
- `torch/autograd/profiler.py` <br/>
**Before: 37**
```
torch/autograd/profiler.py:1 at module level:
D100: Missing docstring in public module
torch/autograd/profiler.py:91 in public class `profile`:
D205: 1 blank line required between summary line and description (found 0)
torch/autograd/profiler.py:175 in public method `__init__`:
D107: Missing docstring in __init__
torch/autograd/profiler.py:261 in public method `config`:
D102: Missing docstring in public method
torch/autograd/profiler.py:272 in public method `__enter__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:290 in public method `__exit__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:308 in public method `__repr__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:313 in public method `__str__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:322 in public method `table`:
D102: Missing docstring in public method
torch/autograd/profiler.py:346 in public method `export_chrome_trace`:
D102: Missing docstring in public method
torch/autograd/profiler.py:355 in public method `export_stacks`:
D102: Missing docstring in public method
torch/autograd/profiler.py:361 in public method `key_averages`:
D102: Missing docstring in public method
torch/autograd/profiler.py:368 in public method `total_average`:
D102: Missing docstring in public method
torch/autograd/profiler.py:377 in public method `self_cpu_time_total`:
D205: 1 blank line required between summary line and description (found 0)
torch/autograd/profiler.py:377 in public method `self_cpu_time_total`:
D400: First line should end with a period (not 'f')
torch/autograd/profiler.py:555 in public class `record_function`:
D205: 1 blank line required between summary line and description (found 0)
torch/autograd/profiler.py:555 in public class `record_function`:
D400: First line should end with a period (not 'f')
torch/autograd/profiler.py:591 in public method `__init__`:
D107: Missing docstring in __init__
torch/autograd/profiler.py:602 in public method `__enter__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:608 in public method `__exit__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:625 in private method `_call_end_callbacks_on_future`:
D205: 1 blank line required between summary line and description (found 0)
torch/autograd/profiler.py:625 in private method `_call_end_callbacks_on_future`:
D400: First line should end with a period (not 'c')
torch/autograd/profiler.py:707 in public method `__init__`:
D107: Missing docstring in __init__
torch/autograd/profiler.py:712 in public method `__enter__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:733 in public method `__exit__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:826 in public method `__init__`:
D107: Missing docstring in __init__
torch/autograd/profiler.py:831 in public method `__enter__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:853 in public method `__exit__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:863 in public function `load_nvprof`:
D401: First line should be in imperative mood (perhaps 'Open', not 'Opens')
torch/autograd/profiler.py:874 in public method `__init__`:
D107: Missing docstring in __init__
torch/autograd/profiler.py:877 in public method `see`:
D102: Missing docstring in public method
torch/autograd/profiler.py:883 in public function `parse_nvprof_trace`:
D103: Missing docstring in public function
torch/autograd/profiler.py:951 in public class `KinetoStepTracker`:
D205: 1 blank line required between summary line and description (found 0)
torch/autograd/profiler.py:991 in public method `init_step_count`:
D102: Missing docstring in public method
torch/autograd/profiler.py:995 in public method `erase_step_count`:
D102: Missing docstring in public method
torch/autograd/profiler.py:1000 in public method `increment_step`:
D205: 1 blank line required between summary line and description (found 0)
torch/autograd/profiler.py:1023 in public method `current_step`:
D102: Missing docstring in public method
37
```
**After: 27**
```
torch/autograd/profiler.py:1 at module level:
D100: Missing docstring in public module
torch/autograd/profiler.py:176 in public method `__init__`:
D107: Missing docstring in __init__
torch/autograd/profiler.py:262 in public method `config`:
D102: Missing docstring in public method
torch/autograd/profiler.py:273 in public method `__enter__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:291 in public method `__exit__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:309 in public method `__repr__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:314 in public method `__str__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:323 in public method `table`:
D102: Missing docstring in public method
torch/autograd/profiler.py:347 in public method `export_chrome_trace`:
D102: Missing docstring in public method
torch/autograd/profiler.py:356 in public method `export_stacks`:
D102: Missing docstring in public method
torch/autograd/profiler.py:362 in public method `key_averages`:
D102: Missing docstring in public method
torch/autograd/profiler.py:369 in public method `total_average`:
D102: Missing docstring in public method
torch/autograd/profiler.py:593 in public method `__init__`:
D107: Missing docstring in __init__
torch/autograd/profiler.py:604 in public method `__enter__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:610 in public method `__exit__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:708 in public method `__init__`:
D107: Missing docstring in __init__
torch/autograd/profiler.py:713 in public method `__enter__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:734 in public method `__exit__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:827 in public method `__init__`:
D107: Missing docstring in __init__
torch/autograd/profiler.py:832 in public method `__enter__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:854 in public method `__exit__`:
D105: Missing docstring in magic method
torch/autograd/profiler.py:875 in public method `__init__`:
D107: Missing docstring in __init__
torch/autograd/profiler.py:878 in public method `see`:
D102: Missing docstring in public method
torch/autograd/profiler.py:884 in public function `parse_nvprof_trace`:
D103: Missing docstring in public function
torch/autograd/profiler.py:993 in public method `init_step_count`:
D102: Missing docstring in public method
torch/autograd/profiler.py:997 in public method `erase_step_count`:
D102: Missing docstring in public method
torch/autograd/profiler.py:1025 in public method `current_step`:
D102: Missing docstring in public method
27
```
- `torch/autograd/graph.py` <br/>
**Before: 22**
```
torch/autograd/graph.py:1 at module level:
D100: Missing docstring in public module
torch/autograd/graph.py:24 in public class `Node`:
D101: Missing docstring in public class
torch/autograd/graph.py:27 in public method `name`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
torch/autograd/graph.py:42 in public method `next_functions`:
D102: Missing docstring in public method
torch/autograd/graph.py:47 in public method `metadata`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
torch/autograd/graph.py:56 in public method `register_hook`:
D401: First line should be in imperative mood (perhaps 'Register', not 'Registers')
torch/autograd/graph.py:94 in public method `register_prehook`:
D401: First line should be in imperative mood (perhaps 'Register', not 'Registers')
torch/autograd/graph.py:129 in public method `__subclasshook__`:
D105: Missing docstring in magic method
torch/autograd/graph.py:147 in public function `get_gradient_edge`:
D205: 1 blank line required between summary line and description (found 0)
torch/autograd/graph.py:147 in public function `get_gradient_edge`:
D400: First line should end with a period (not 'f')
torch/autograd/graph.py:147 in public function `get_gradient_edge`:
D401: First line should be in imperative mood; try rephrasing (found 'This')
torch/autograd/graph.py:166 in public function `increment_version`:
D205: 1 blank line required between summary line and description (found 0)
torch/autograd/graph.py:166 in public function `increment_version`:
D400: First line should end with a period (not 'd')
torch/autograd/graph.py:166 in public function `increment_version`:
D401: First line should be in imperative mood; try rephrasing (found 'This')
torch/autograd/graph.py:243 in public method `__init__`:
D107: Missing docstring in __init__
torch/autograd/graph.py:251 in public method `__enter__`:
D105: Missing docstring in magic method
torch/autograd/graph.py:256 in public method `__exit__`:
D105: Missing docstring in magic method
torch/autograd/graph.py:261 in public class `save_on_cpu`:
D205: 1 blank line required between summary line and description (found 0)
torch/autograd/graph.py:261 in public class `save_on_cpu`:
D400: First line should end with a period (not 'e')
torch/autograd/graph.py:303 in public method `__init__`:
D107: Missing docstring in __init__
torch/autograd/graph.py:365 in public function `register_multi_grad_hook`:
D401: First line should be in imperative mood (perhaps 'Register', not 'Registers')
torch/autograd/graph.py:588 in public function `allow_mutation_on_saved_tensors`:
D400: First line should end with a period (not 'd')
22
```
**After: 8**
```
torch/autograd/graph.py:1 at module level:
D100: Missing docstring in public module
torch/autograd/graph.py:24 in public class `Node`:
D101: Missing docstring in public class
torch/autograd/graph.py:42 in public method `next_functions`:
D102: Missing docstring in public method
torch/autograd/graph.py:129 in public method `__subclasshook__`:
D105: Missing docstring in magic method
torch/autograd/graph.py:244 in public method `__init__`:
D107: Missing docstring in __init__
torch/autograd/graph.py:252 in public method `__enter__`:
D105: Missing docstring in magic method
torch/autograd/graph.py:257 in public method `__exit__`:
D105: Missing docstring in magic method
torch/autograd/graph.py:303 in public method `__init__`:
D107: Missing docstring in __init__
8
```
- `torch/multiprocessing/pool.py` <br/>
**Before: 6**
```
torch/multiprocessing/pool.py:1 at module level:
D100: Missing docstring in public module
torch/multiprocessing/pool.py:7 in public function `clean_worker`:
D103: Missing docstring in public function
torch/multiprocessing/pool.py:18 in public class `Pool`:
D205: 1 blank line required between summary line and description (found 0)
torch/multiprocessing/pool.py:18 in public class `Pool`:
D209: Multi-line docstring closing quotes should be on a separate line
torch/multiprocessing/pool.py:29 in private method `_repopulate_pool`:
D205: 1 blank line required between summary line and description (found 0)
torch/multiprocessing/pool.py:29 in private method `_repopulate_pool`:
D400: First line should end with a period (not ',')
6
```
**After: 2**
```
torch/multiprocessing/pool.py:1 at module level:
D100: Missing docstring in public module
torch/multiprocessing/pool.py:7 in public function `clean_worker`:
D103: Missing docstring in public function
2
```
- `torch/multiprocessing/queue.py` <br/>
**Before: 11**
```
torch/multiprocessing/queue.py:1 at module level:
D100: Missing docstring in public module
torch/multiprocessing/queue.py:8 in public class `ConnectionWrapper`:
D205: 1 blank line required between summary line and description (found 0)
torch/multiprocessing/queue.py:8 in public class `ConnectionWrapper`:
D209: Multi-line docstring closing quotes should be on a separate line
torch/multiprocessing/queue.py:8 in public class `ConnectionWrapper`:
D400: First line should end with a period (not 'o')
torch/multiprocessing/queue.py:11 in public method `__init__`:
D107: Missing docstring in __init__
torch/multiprocessing/queue.py:14 in public method `send`:
D102: Missing docstring in public method
torch/multiprocessing/queue.py:19 in public method `recv`:
D102: Missing docstring in public method
torch/multiprocessing/queue.py:23 in public method `__getattr__`:
D105: Missing docstring in magic method
torch/multiprocessing/queue.py:29 in public class `Queue`:
D101: Missing docstring in public class
torch/multiprocessing/queue.py:30 in public method `__init__`:
D107: Missing docstring in __init__
torch/multiprocessing/queue.py:38 in public class `SimpleQueue`:
D101: Missing docstring in public class
11
```
**After: 8**
```
torch/multiprocessing/queue.py:1 at module level:
D100: Missing docstring in public module
torch/multiprocessing/queue.py:10 in public method `__init__`:
D107: Missing docstring in __init__
torch/multiprocessing/queue.py:13 in public method `send`:
D102: Missing docstring in public method
torch/multiprocessing/queue.py:18 in public method `recv`:
D102: Missing docstring in public method
torch/multiprocessing/queue.py:22 in public method `__getattr__`:
D105: Missing docstring in magic method
torch/multiprocessing/queue.py:28 in public class `Queue`:
D101: Missing docstring in public class
torch/multiprocessing/queue.py:29 in public method `__init__`:
D107: Missing docstring in __init__
torch/multiprocessing/queue.py:37 in public class `SimpleQueue`:
D101: Missing docstring in public class
8
```
- `torch/multiprocessing/reductions.py` <br/>
**Before: 31**
```
torch/multiprocessing/reductions.py:1 at module level:
D100: Missing docstring in public module
torch/multiprocessing/reductions.py:24 in public class `StorageWeakRef`:
D209: Multi-line docstring closing quotes should be on a separate line
torch/multiprocessing/reductions.py:31 in public method `__init__`:
D107: Missing docstring in __init__
torch/multiprocessing/reductions.py:38 in public method `from_weakref`:
D102: Missing docstring in public method
torch/multiprocessing/reductions.py:44 in public method `expired`:
D102: Missing docstring in public method
torch/multiprocessing/reductions.py:47 in public method `__del__`:
D105: Missing docstring in magic method
torch/multiprocessing/reductions.py:50 in public method `__hash__`:
D105: Missing docstring in magic method
torch/multiprocessing/reductions.py:53 in public method `__eq__`:
D105: Missing docstring in magic method
torch/multiprocessing/reductions.py:60 in public class `SharedCache`:
D400: First line should end with a period (not 'f')
torch/multiprocessing/reductions.py:62 in public method `__init__`:
D107: Missing docstring in __init__
torch/multiprocessing/reductions.py:75 in public method `get`:
D102: Missing docstring in public method
torch/multiprocessing/reductions.py:79 in public method `__setitem__`:
D105: Missing docstring in magic method
torch/multiprocessing/reductions.py:85 in public method `free_dead_references`:
D102: Missing docstring in public method
torch/multiprocessing/reductions.py:99 in public function `rebuild_event`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:103 in public function `reduce_event`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:108 in public function `rebuild_tensor`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:121 in public function `rebuild_cuda_tensor`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:189 in public function `reduce_tensor`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:347 in public function `rebuild_nested_tensor`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:364 in public function `reduce_nested_tensor`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:389 in public function `fd_id`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:397 in public function `storage_from_cache`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:404 in public function `rebuild_storage_fd`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:417 in public function `rebuild_storage_filename`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:437 in public function `rebuild_storage_empty`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:441 in public function `rebuild_typed_storage`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:446 in public function `reduce_typed_storage`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:450 in public function `rebuild_typed_storage_child`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:455 in public function `reduce_typed_storage_child`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:459 in public function `reduce_storage`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:488 in public function `init_reductions`:
D103: Missing docstring in public function
31
```
**After: 29**
```
torch/multiprocessing/reductions.py:1 at module level:
D100: Missing docstring in public module
torch/multiprocessing/reductions.py:32 in public method `__init__`:
D107: Missing docstring in __init__
torch/multiprocessing/reductions.py:39 in public method `from_weakref`:
D102: Missing docstring in public method
torch/multiprocessing/reductions.py:45 in public method `expired`:
D102: Missing docstring in public method
torch/multiprocessing/reductions.py:48 in public method `__del__`:
D105: Missing docstring in magic method
torch/multiprocessing/reductions.py:51 in public method `__hash__`:
D105: Missing docstring in magic method
torch/multiprocessing/reductions.py:54 in public method `__eq__`:
D105: Missing docstring in magic method
torch/multiprocessing/reductions.py:63 in public method `__init__`:
D107: Missing docstring in __init__
torch/multiprocessing/reductions.py:76 in public method `get`:
D102: Missing docstring in public method
torch/multiprocessing/reductions.py:80 in public method `__setitem__`:
D105: Missing docstring in magic method
torch/multiprocessing/reductions.py:86 in public method `free_dead_references`:
D102: Missing docstring in public method
torch/multiprocessing/reductions.py:100 in public function `rebuild_event`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:104 in public function `reduce_event`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:109 in public function `rebuild_tensor`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:122 in public function `rebuild_cuda_tensor`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:190 in public function `reduce_tensor`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:348 in public function `rebuild_nested_tensor`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:365 in public function `reduce_nested_tensor`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:390 in public function `fd_id`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:398 in public function `storage_from_cache`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:405 in public function `rebuild_storage_fd`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:418 in public function `rebuild_storage_filename`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:438 in public function `rebuild_storage_empty`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:442 in public function `rebuild_typed_storage`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:447 in public function `reduce_typed_storage`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:451 in public function `rebuild_typed_storage_child`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:456 in public function `reduce_typed_storage_child`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:460 in public function `reduce_storage`:
D103: Missing docstring in public function
torch/multiprocessing/reductions.py:489 in public function `init_reductions`:
D103: Missing docstring in public function
29
```
- `torch/multiprocessing/spawn.py` <br/>
**Before: 19**
```
torch/multiprocessing/spawn.py:1 at module level:
D100: Missing docstring in public module
torch/multiprocessing/spawn.py:11 in public class `ProcessException`:
D101: Missing docstring in public class
torch/multiprocessing/spawn.py:14 in public method `__init__`:
D107: Missing docstring in __init__
torch/multiprocessing/spawn.py:20 in public method `__reduce__`:
D105: Missing docstring in magic method
torch/multiprocessing/spawn.py:25 in public class `ProcessRaisedException`:
D205: 1 blank line required between summary line and description (found 0)
torch/multiprocessing/spawn.py:25 in public class `ProcessRaisedException`:
D400: First line should end with a period (not 'n')
torch/multiprocessing/spawn.py:30 in public method `__init__`:
D107: Missing docstring in __init__
torch/multiprocessing/spawn.py:40 in public class `ProcessExitedException`:
D205: 1 blank line required between summary line and description (found 0)
torch/multiprocessing/spawn.py:40 in public class `ProcessExitedException`:
D400: First line should end with a period (not 'l')
torch/multiprocessing/spawn.py:47 in public method `__init__`:
D107: Missing docstring in __init__
torch/multiprocessing/spawn.py:59 in public method `__reduce__`:
D105: Missing docstring in magic method
torch/multiprocessing/spawn.py:85 in public class `ProcessContext`:
D101: Missing docstring in public class
torch/multiprocessing/spawn.py:86 in public method `__init__`:
D107: Missing docstring in __init__
torch/multiprocessing/spawn.py:93 in public method `pids`:
D102: Missing docstring in public method
torch/multiprocessing/spawn.py:97 in public method `join`:
D205: 1 blank line required between summary line and description (found 0)
torch/multiprocessing/spawn.py:97 in public method `join`:
D401: First line should be in imperative mood (perhaps 'Try', not 'Tries')
torch/multiprocessing/spawn.py:166 in public class `SpawnContext`:
D101: Missing docstring in public class
torch/multiprocessing/spawn.py:167 in public method `__init__`:
D107: Missing docstring in __init__
torch/multiprocessing/spawn.py:180 in public function `start_processes`:
D103: Missing docstring in public function
19
```
**After: 13**
```
torch/multiprocessing/spawn.py:1 at module level:
D100: Missing docstring in public module
torch/multiprocessing/spawn.py:11 in public class `ProcessException`:
D101: Missing docstring in public class
torch/multiprocessing/spawn.py:14 in public method `__init__`:
D107: Missing docstring in __init__
torch/multiprocessing/spawn.py:20 in public method `__reduce__`:
D105: Missing docstring in magic method
torch/multiprocessing/spawn.py:27 in public method `__init__`:
D107: Missing docstring in __init__
torch/multiprocessing/spawn.py:41 in public method `__init__`:
D107: Missing docstring in __init__
torch/multiprocessing/spawn.py:53 in public method `__reduce__`:
D105: Missing docstring in magic method
torch/multiprocessing/spawn.py:79 in public class `ProcessContext`:
D101: Missing docstring in public class
torch/multiprocessing/spawn.py:80 in public method `__init__`:
D107: Missing docstring in __init__
torch/multiprocessing/spawn.py:87 in public method `pids`:
D102: Missing docstring in public method
torch/multiprocessing/spawn.py:161 in public class `SpawnContext`:
D101: Missing docstring in public class
torch/multiprocessing/spawn.py:162 in public method `__init__`:
D107: Missing docstring in __init__
torch/multiprocessing/spawn.py:175 in public function `start_processes`:
D103: Missing docstring in public function
13
```
- `torch/multiprocessing/__init__.py` <br/>
**Before: 5**
```
torch/multiprocessing/__init__.py:1 at module level:
D205: 1 blank line required between summary line and description (found 0)
torch/multiprocessing/__init__.py:1 at module level:
D400: First line should end with a period (not '`')
torch/multiprocessing/__init__.py:57 in public function `set_sharing_strategy`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
torch/multiprocessing/__init__.py:69 in public function `get_sharing_strategy`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
torch/multiprocessing/__init__.py:74 in public function `get_all_sharing_strategies`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
5
```
**After: 0**
- `torch/nn/__init__.py` <br/>
**Before: 3**
```
torch/nn/__init__.py:1 at module level:
D104: Missing docstring in public package
torch/nn/__init__.py:14 in public function `factory_kwargs`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/__init__.py:14 in public function `factory_kwargs`:
D400: First line should end with a period (not 'd')
3
```
**After: 1**
```
torch/nn/__init__.py:1 at module level:
D104: Missing docstring in public package
1
```
- `torch/nn/cpp.py` <br/>
**Before: 16**
```
torch/nn/cpp.py:7 in public class `OrderedDictWrapper`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/cpp.py:7 in public class `OrderedDictWrapper`:
D400: First line should end with a period (not 'e')
torch/nn/cpp.py:16 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/cpp.py:21 in public method `cpp_dict`:
D102: Missing docstring in public method
torch/nn/cpp.py:27 in public method `items`:
D102: Missing docstring in public method
torch/nn/cpp.py:30 in public method `keys`:
D102: Missing docstring in public method
torch/nn/cpp.py:33 in public method `values`:
D102: Missing docstring in public method
torch/nn/cpp.py:36 in public method `__iter__`:
D105: Missing docstring in magic method
torch/nn/cpp.py:39 in public method `__len__`:
D105: Missing docstring in magic method
torch/nn/cpp.py:42 in public method `__contains__`:
D105: Missing docstring in magic method
torch/nn/cpp.py:45 in public method `__getitem__`:
D105: Missing docstring in magic method
torch/nn/cpp.py:50 in public class `ModuleWrapper`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/cpp.py:50 in public class `ModuleWrapper`:
D400: First line should end with a period (not 'd')
torch/nn/cpp.py:55 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/cpp.py:83 in public method `training`:
D102: Missing docstring in public method
torch/nn/cpp.py:90 in public method `__repr__`:
D105: Missing docstring in magic method
16
```
**After: 12**
```
torch/nn/cpp.py:16 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/cpp.py:21 in public method `cpp_dict`:
D102: Missing docstring in public method
torch/nn/cpp.py:27 in public method `items`:
D102: Missing docstring in public method
torch/nn/cpp.py:30 in public method `keys`:
D102: Missing docstring in public method
torch/nn/cpp.py:33 in public method `values`:
D102: Missing docstring in public method
torch/nn/cpp.py:36 in public method `__iter__`:
D105: Missing docstring in magic method
torch/nn/cpp.py:39 in public method `__len__`:
D105: Missing docstring in magic method
torch/nn/cpp.py:42 in public method `__contains__`:
D105: Missing docstring in magic method
torch/nn/cpp.py:45 in public method `__getitem__`:
D105: Missing docstring in magic method
torch/nn/cpp.py:52 in public method `__init__`:
D107: Missing docstring in __init__
torch/nn/cpp.py:80 in public method `training`:
D102: Missing docstring in public method
torch/nn/cpp.py:87 in public method `__repr__`:
D105: Missing docstring in magic method
12
```
- `torch/nn/grad.py` <br/>
**Before: 10**
```
torch/nn/grad.py:1 at module level:
D400: First line should end with a period (not 'e')
torch/nn/grad.py:8 in public function `conv1d_input`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/grad.py:8 in public function `conv1d_input`:
D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
torch/nn/grad.py:40 in public function `conv1d_weight`:
D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
torch/nn/grad.py:71 in public function `conv2d_input`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/grad.py:71 in public function `conv2d_input`:
D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
torch/nn/grad.py:103 in public function `conv2d_weight`:
D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
torch/nn/grad.py:134 in public function `conv3d_input`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/grad.py:134 in public function `conv3d_input`:
D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
torch/nn/grad.py:166 in public function `conv3d_weight`:
D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
10
```
**After: 0**
- `torch/nn/parameter.py` <br/>
**Before: 17**
```
torch/nn/parameter.py:1 at module level:
D100: Missing docstring in public module
torch/nn/parameter.py:14 in public class `Parameter`:
D204: 1 blank line required after class docstring (found 0)
torch/nn/parameter.py:33 in public method `__new__`:
D102: Missing docstring in public method
torch/nn/parameter.py:54 in public method `__deepcopy__`:
D105: Missing docstring in magic method
torch/nn/parameter.py:62 in public method `__repr__`:
D105: Missing docstring in magic method
torch/nn/parameter.py:65 in public method `__reduce_ex__`:
D105: Missing docstring in magic method
torch/nn/parameter.py:84 in public class `UninitializedTensorMixin`:
D101: Missing docstring in public class
torch/nn/parameter.py:105 in public method `materialize`:
D205: 1 blank line required between summary line and description (found 0)
torch/nn/parameter.py:125 in public method `shape`:
D102: Missing docstring in public method
torch/nn/parameter.py:132 in public method `share_memory_`:
D102: Missing docstring in public method
torch/nn/parameter.py:138 in public method `__repr__`:
D105: Missing docstring in magic method
torch/nn/parameter.py:141 in public method `__reduce_ex__`:
D105: Missing docstring in magic method
torch/nn/parameter.py:149 in public method `__torch_function__`:
D105: Missing docstring in magic method
torch/nn/parameter.py:164 in public function `is_lazy`:
D103: Missing docstring in public function
torch/nn/parameter.py:186 in public method `__new__`:
D102: Missing docstring in public method
torch/nn/parameter.py:191 in public method `__deepcopy__`:
D105: Missing docstring in magic method
torch/nn/parameter.py:217 in public method `__new__`:
D102: Missing docstring in public method
17
```
**After: 15**
```
torch/nn/parameter.py:1 at module level:
D100: Missing docstring in public module
torch/nn/parameter.py:34 in public method `__new__`:
D102: Missing docstring in public method
torch/nn/parameter.py:55 in public method `__deepcopy__`:
D105: Missing docstring in magic method
torch/nn/parameter.py:63 in public method `__repr__`:
D105: Missing docstring in magic method
torch/nn/parameter.py:66 in public method `__reduce_ex__`:
D105: Missing docstring in magic method
torch/nn/parameter.py:85 in public class `UninitializedTensorMixin`:
D101: Missing docstring in public class
torch/nn/parameter.py:127 in public method `shape`:
D102: Missing docstring in public method
torch/nn/parameter.py:134 in public method `share_memory_`:
D102: Missing docstring in public method
torch/nn/parameter.py:140 in public method `__repr__`:
D105: Missing docstring in magic method
torch/nn/parameter.py:143 in public method `__reduce_ex__`:
D105: Missing docstring in magic method
torch/nn/parameter.py:151 in public method `__torch_function__`:
D105: Missing docstring in magic method
torch/nn/parameter.py:166 in public function `is_lazy`:
D103: Missing docstring in public function
torch/nn/parameter.py:188 in public method `__new__`:
D102: Missing docstring in public method
torch/nn/parameter.py:193 in public method `__deepcopy__`:
D105: Missing docstring in magic method
torch/nn/parameter.py:219 in public method `__new__`:
D102: Missing docstring in public method
15
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113052
Approved by: https://github.com/mikaylagawarecki, https://github.com/soulitzer
This can be useful for advanced users (like AOTAutograd) who don't want to keep the corresponding Tensor alive (for memory reasons, for example), or when an inplace op will change the Tensor's grad_fn but gradients with respect to the original value are needed.
I went with a minimal API change, but I am open to suggestions.
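For example (a minimal sketch, not from the PR, assuming `get_gradient_edge` from `torch.autograd.graph`):
```python
import torch
from torch.autograd.graph import get_gradient_edge

x = torch.rand(2, requires_grad=True)
y = x.clone()
edge = get_gradient_edge(y)  # capture the edge before the inplace op

y.add_(1)  # the inplace op rebases y.grad_fn
out = (y * 2).sum()

# Gradient w.r.t. the pre-mutation value of y, without keeping an extra
# tensor alive just for this purpose.
(gy,) = torch.autograd.grad(out, inputs=(edge,))
print(gy)  # tensor([2., 2.])
```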
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110867
Approved by: https://github.com/soulitzer
Fixes #ISSUE_NUMBER
1. The class named `Type` is no longer used anywhere, so I added a warning message saying it will be removed in the future.
2. Added an argument (default `"cuda"`) to `save_on_cpu` so that it can support more device types (like `privateuse1`).
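A sketch of the resulting usage, assuming the new argument is named `device_type` and defaults to `"cuda"`:
```python
import torch

x = torch.randn(8, device="cuda", requires_grad=True)
# Activations are staged in (pinned) CPU memory during forward and moved
# back to the original device during backward.
with torch.autograd.graph.save_on_cpu(pin_memory=True, device_type="cuda"):
    y = (x * x).sum()
y.backward()
```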
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103245
Approved by: https://github.com/soulitzer
This PR:
- registers all of the codegened Nodes to the torch._C._functions module, which is where special nodes like AccumulateGrad are already registered.
- creates an autograd.graph.Node abstract base class that all of the newly registered nodes subclass (illustrated below). We make the subclassing happen by implementing the ``__subclasshook__`` method.
- enables static type checking to work and also enables Sphinx to generate documentation for the Node and its methods
- handles both the custom Function and codegened cases
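A quick sketch of what the new base class enables (illustrative, not from the PR):
```python
import torch
from torch.autograd.graph import Node

t = torch.rand(2, requires_grad=True) * 2
# MulBackward0 is codegened; AccumulateGrad is a special node. Both now
# count as subclasses of the Node ABC via __subclasshook__.
assert isinstance(t.grad_fn, Node)
assert isinstance(t.grad_fn.next_functions[0][0], Node)
print(t.grad_fn.name())  # MulBackward0
```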
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91475
Approved by: https://github.com/albanD
The rationale for this is that functorch doesn't currently work with saved
variable hooks or checkpointing, and we need some way to
disable them.
Concretely:
- there's a context manager that does the disabling
- this feature is disabled on a thread-local basis
- one can set an error message or use the default error message that
says the feature has been disabled
Since it is thread-local, I needed to update ATen/ThreadLocalState. To
make things nicer, this PR refactors all of the "saved tensors hooks"
related TLS into a single struct.
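A minimal sketch of the resulting behavior, assuming the Python binding is `torch.autograd.graph.disable_saved_tensors_hooks`:
```python
import torch

x = torch.rand(2, requires_grad=True)
msg = "saved tensors hooks are disabled here (e.g. under a functorch transform)"

with torch.autograd.graph.disable_saved_tensors_hooks(msg):
    try:
        # save_on_cpu pushes saved-tensors hooks, which now raises with
        # the custom error message supplied above.
        with torch.autograd.graph.save_on_cpu():
            y = x.sin().sum()
    except RuntimeError as e:
        print(e)
```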
Test Plan:
- new test
Differential Revision: [D39970936](https://our.internmc.facebook.com/intern/diff/D39970936)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85971
Approved by: https://github.com/albanD, https://github.com/soulitzer