This PR removes the autograd.Function extension feature flag. This was
previously used for development of the functorch <> autograd.Function
interaction.
It's been in master for long enough with the feature flag defaulting to
True, so it's time to remove it.
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92026
Approved by: https://github.com/soulitzer
The autograd.Function <> functorch interaction is in a mostly completed
state now. There are some minor action items remaining
(https://github.com/pytorch/pytorch/issues/90224), but I want to enable
the feature by default so that PyTorch CI / other parties / etc can
begin testing to see if there is any impact on the original
autograd.Function API (there shouldn't be).
The longer-term plan for the feature flag is:
- keep it around until at least the next release (so that people can
turn off the feature if it breaks something in existing code)
- delete the flag then (either before or after the release, I haven't
decided yet)
Test Plan:
- new test
- wait for CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91441
Approved by: https://github.com/albanD, https://github.com/soulitzer
This allows to know at any point during the backward pass what is running and where the Node currently running was created at:
```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.autograd import detect_anomaly
class MyMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args, kwargs=None):
node = torch._C._current_autograd_node()
print(f"Running {func} from within {node}")
if node is not None:
print("The Node was created at:")
print("\n ".join(node.metadata["traceback_"]))
return func(*args, **kwargs or {})
with MyMode(), detect_anomaly():
print("FW")
a = torch.rand(10, requires_grad=True)
b = a.mul(2)
b = b.div(3)
b = b.sum()
print("BW")
b.backward()
```
Gives
```
$ python foo.py
foo.py:15: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
with MyMode(), detect_anomaly():
FW
Running aten.rand.default from within None
Running aten.mul.Tensor from within None
Running aten.div.Tensor from within None
Running aten.sum.default from within None
BW
Running aten.ones_like.default from within None
Running aten.expand.default from within <SumBackward0 object at 0x7fa40c0c6dc0>
The Node was created at:
File "foo.py", line 20, in <module>
b = b.sum()
Running aten.isnan.default from within <SumBackward0 object at 0x7fa40c0c6500>
The Node was created at:
File "foo.py", line 20, in <module>
b = b.sum()
Running aten.any.default from within <SumBackward0 object at 0x7fa32b23a780>
The Node was created at:
File "foo.py", line 20, in <module>
b = b.sum()
Running aten._local_scalar_dense.default from within <SumBackward0 object at 0x7fa40c0c9190>
The Node was created at:
File "foo.py", line 20, in <module>
b = b.sum()
Running aten.div.Tensor from within <DivBackward0 object at 0x7fa40c0c9190>
The Node was created at:
File "foo.py", line 19, in <module>
b = b.div(3)
Running aten.isnan.default from within <DivBackward0 object at 0x7fa40c0c9190>
The Node was created at:
File "foo.py", line 19, in <module>
b = b.div(3)
Running aten.any.default from within <DivBackward0 object at 0x7fa40c0c9190>
The Node was created at:
File "foo.py", line 19, in <module>
b = b.div(3)
Running aten._local_scalar_dense.default from within <DivBackward0 object at 0x7fa40c0c9190>
The Node was created at:
File "foo.py", line 19, in <module>
b = b.div(3)
Running aten.mul.Tensor from within <MulBackward0 object at 0x7fa40c0c9190>
The Node was created at:
File "foo.py", line 18, in <module>
b = a.mul(2)
Running aten.isnan.default from within <MulBackward0 object at 0x7fa40c0c9190>
The Node was created at:
File "foo.py", line 18, in <module>
b = a.mul(2)
Running aten.any.default from within <MulBackward0 object at 0x7fa40c0c9190>
The Node was created at:
File "foo.py", line 18, in <module>
b = a.mul(2)
Running aten._local_scalar_dense.default from within <MulBackward0 object at 0x7fa40c0c9190>
The Node was created at:
File "foo.py", line 18, in <module>
b = a.mul(2)
Running aten.detach.default from within <AccumulateGrad object at 0x7fa40c0c9730>
The Node was created at:
File "foo.py", line 18, in <module>
b = a.mul(2)
Running aten.detach.default from within <AccumulateGrad object at 0x7fa40c0c94b0>
The Node was created at:
File "foo.py", line 18, in <module>
b = a.mul(2)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90867
Approved by: https://github.com/soulitzer
This PR adds a private runtime feature flag for the feature work we're going
to do with extending autograd.Function. The motivation of the feature flag
is:
- to guard the feature against unsuspecting users
- control the release of the feature to when we are ready to release it
We might not even need the feature flag (because we hope to have the
work done in the next month), but it is good practice and it does touch
currently public API (autograd.Function).
Concretely, "autograd.Function extension" refers to:
- adding an optional `setup_context` staticmethod to autograd.Function
- adding an optional `vmap` staticmethod to autograd.Function
- autograd.Function support for functorch
Test Plan:
- new test that the feature flag works
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89858
Approved by: https://github.com/soulitzer
Summary:
As GoogleTest `TEST` macro is non-compliant with it as well as `DEFINE_DISPATCH`
All changes but the ones to `.clang-tidy` are generated using following script:
```
for i in `find . -type f -iname "*.c*" -or -iname "*.h"|xargs grep cppcoreguidelines-avoid-non-const-global-variables|cut -f1 -d:|sort|uniq`; do sed -i "/\/\/ NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)/d" $i; done
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62008
Reviewed By: driazati, r-barnes
Differential Revision: D29838584
Pulled By: malfet
fbshipit-source-id: 1b2f8602c945bd4ce50a9bfdd204755556e31d13
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59017
See the comment in ThreadLocal.h for context.
I used a slightly dirty preprocessor hack to minimize the number of changes.
The hope is that we'll be able to revert all of these soon.
Test Plan:
CI.
Built FB4A with gnustl and saw no references to cxa_thread_atexit
in the PyTorch libraries.
Reviewed By: ilia-cher
Differential Revision: D28720762
fbshipit-source-id: 0f13c7ac5a108b95f8fde6dbc63c6b8bdb8599de
Summary:
This is an automatic change generated by the following script:
```
#!/usr/bin/env python3
from subprocess import check_output, check_call
import os
def get_compiled_files_list():
import json
with open("build/compile_commands.json") as f:
data = json.load(f)
files = [os.path.relpath(node['file']) for node in data]
for idx, fname in enumerate(files):
if fname.startswith('build/') and fname.endswith('.DEFAULT.cpp'):
files[idx] = fname[len('build/'):-len('.DEFAULT.cpp')]
return files
def run_clang_tidy(fname):
check_call(["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname,"-s"])
changes = check_output(["git", "ls-files", "-m"])
if len(changes) == 0:
return
check_call(["git", "commit","--all", "-m", f"NOLINT stubs for {fname}"])
def main():
git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n")
compiled_files = get_compiled_files_list()
for idx, fname in enumerate(git_files):
if fname not in compiled_files:
continue
if fname.startswith("caffe2/contrib/aten/"):
continue
print(f"[{idx}/{len(git_files)}] Processing {fname}")
run_clang_tidy(fname)
if __name__ == "__main__":
main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892
Reviewed By: H-Huang
Differential Revision: D27991944
Pulled By: malfet
fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179
Summary:
Fixes https://github.com/pytorch/pytorch/issues/43405.
This pull request adds a feature of printing all tracebacks if a `detect_anomaly` mode detects `nan` in nested backward operations.
The way I did it is by assigning a node as a parent to all nodes it produces during its backward calculation. Then if one of the children produces `nan`, it will print the traceback from the parent and grand parents (if any).
The parent is assigned in `parent_node_` member in `Node` class which is accessible in C++ by function `node->parent()` and in Python by `node.parent_function`.
A node has a parent iff:
1. it is created from a backward operation, and
2. created when anomaly mode and grad mode are both enabled.
An example of this feature:
import torch
def example():
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(1e-8, requires_grad=True) # small to induce nan in n-th backward
a = x * y
b = x * y
z1 = a / b # can produce nan in n-th backward as long as https://github.com/pytorch/pytorch/issues/43414 is unsolved
z = z1 * z1
gy , = torch.autograd.grad( z , (y,), create_graph=True)
gy2, = torch.autograd.grad(gy , (y,), create_graph=True)
gy3, = torch.autograd.grad(gy2, (y,), create_graph=True)
gy4, = torch.autograd.grad(gy3, (y,), create_graph=True)
return gy4
with torch.autograd.detect_anomaly():
gy4 = example()
with output:
example.py:16: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
with torch.autograd.detect_anomaly():
/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py:190: UserWarning: Error detected in DivBackward0. Traceback of forward call that caused the error:
File "example.py", line 17, in <module>
gy4 = example()
File "example.py", line 12, in example
gy3, = torch.autograd.grad(gy2, (y,), create_graph=True)
File "/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 190, in grad
return Variable._execution_engine.run_backward(
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:61.)
return Variable._execution_engine.run_backward(
/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py:190: UserWarning:
Traceback of forward call that induces the previous calculation:
File "example.py", line 17, in <module>
gy4 = example()
File "example.py", line 11, in example
gy2, = torch.autograd.grad(gy , (y,), create_graph=True)
File "/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 190, in grad
return Variable._execution_engine.run_backward(
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:65.)
return Variable._execution_engine.run_backward(
/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py:190: UserWarning:
Traceback of forward call that induces the previous calculation:
File "example.py", line 17, in <module>
gy4 = example()
File "example.py", line 8, in example
z1 = a / b # can produce nan in n-th backward as long as https://github.com/pytorch/pytorch/issues/43414 is unsolved
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:65.)
return Variable._execution_engine.run_backward(
Traceback (most recent call last):
File "example.py", line 17, in <module>
gy4 = example()
File "example.py", line 13, in example
gy4, = torch.autograd.grad(gy3, (y,), create_graph=True)
File "/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 190, in grad
return Variable._execution_engine.run_backward(
RuntimeError: Function 'DivBackward0' returned nan values in its 1th output.
cc & thanks to albanD
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43626
Reviewed By: malfet
Differential Revision: D23397499
Pulled By: albanD
fbshipit-source-id: aa7435ec2a7f0d23a7a02ab7db751c198faf3b7d
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37587
Lifting RecordFunction up into the dispatcher code
Test Plan: Imported from OSS
Differential Revision: D21374246
fbshipit-source-id: 19f9c1719e6fd3990e451c5bbd771121e91128f7
Summary:
Anywhere we used #include "foo.h", we now say #include <foo.h>
Paths are adjusted to be rooted out of aten/src, torch/lib, or
the root level directory.
I modified CMakeLists.txt by hand to remove TH and THC from
the include paths.
I used the following script to do the canonicalization:
```
import subprocess
import re
import os.path
files = subprocess.check_output(['git', 'ls-files']).decode('utf-8').rstrip().split('\n')
for fn in files:
if not any(fn.endswith(suff) for suff in ['.cu', '.cpp', '.in', '.h', '.hpp', '.cu', '.cuh', '.cc']):
continue
if not any(fn.startswith(pref) for pref in ["aten/", "torch/"]):
continue
with open(fn, 'r') as f:
c = f.read()
def fmt(p):
return "#include <{}>".format(p)
def repl(m):
p = m.group(1)
if p in ["dlfcn.h", "unistd.h", "nvrtc.h", "cuda.h", "cuda_runtime.h", "cstdint", "cudnn.h", "Python.h", "cusparse.h", "cuda_runtime_api.h", "cuda_fp16.h", "cublas_v2.h", "stdint.h", "curand_kernel.h"]:
return fmt(p)
if any(p.startswith(pref) for pref in ["torch/csrc", "c10/", "ATen/", "caffe2/", "TH/", "THC/", "Eigen/", "gtest/", "zdl/", "gloo/", "onnx/", "miopen/"]):
return fmt(p)
for root in ["aten/src", "torch/lib", ""]:
for bad_root in [os.path.dirname(fn), "aten/src/TH", "aten/src/THC", "torch/csrc"]:
new_p = os.path.relpath(os.path.join(bad_root, p), root)
if not new_p.startswith("../") and (os.path.exists(os.path.join(root, new_p)) or os.path.exists(os.path.join(root, new_p + ".in"))):
return fmt(new_p)
print("ERROR: ", fn, p)
return m.group(0)
new_c = re.sub(r'#include "([^"]+)"', repl, c)
if new_c != c:
print(fn)
with open(fn, 'w') as f:
f.write(new_c)
```
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14849
Reviewed By: dzhulgakov
Differential Revision: D13363445
Pulled By: ezyang
fbshipit-source-id: 52361f878a672785f9306c9e9ab2513128092b68
Summary:
There are still a few work to be done:
- Move logging and unify AT_WARN with LOG(ERROR).
- A few header files are still being plumbed through, need cleaning.
- caffe2::EnforceNotMet aliasing is not done yet.
- need to unify the macros. See c10/util/Exception.h
This is mainly a codemod and not causing functional changes. If you find your job failing and trace back to this diff, usually it can be fixed by the following approaches:
(1) add //caffe2/c10:c10 to your dependency (or transitive dependency).
(2) change objects such as at::Error, at::Optional to the c10 namespace.
(3) change functions to the c10 namespace. Especially, caffe2::MakeString is not overridden by the unified c10::str function. Nothing else changes.
Please kindly consider not reverting this diff - it involves multiple rounds of rebasing and the fix is usually simple. Contact jiayq@ or AI Platform Dev for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12354
Reviewed By: orionr
Differential Revision: D10238910
Pulled By: Yangqing
fbshipit-source-id: 7794d5bf2797ab0ca6ebaccaa2f7ebbd50ff8f32
Summary:
This eliminates the need for any heuristics regarding stack size limits.
This is a re-do #11534 with a fix to properly handle cases where
multiple edges exist between a pair of functions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11611
Differential Revision: D9991198
Pulled By: resistor
fbshipit-source-id: fecd2c5cac7e78f82a0f20cf33268bb1617bb4a0
Summary:
Often, we find ourselves looking at some long-running kernel or emit_nvtx range on an nvvp profile and trying to connect it to the offending line in a training script. If the op is in the forward pass that's easy: ops are enqueued explicitly from the Python side, so tracking it down with manual nvtx ranges supplemented by the built-in emit_nvtx ranges is straightforward. If the op is in the backward pass, it's much more difficult. From the Python side, all you can do is wrap loss.backward() in an nvtx range, and if you also use emit_nvtx, the automatic ranges provide only local information. Right now, the only consistent way to connect backward-pass kernels to their associated forward-pass lines of Python is to understand your script line by line, and know exactly where in the backward pass you are.
This PR augments the existing nvtx machinery to bridge the gap between forward and backward, allowing connection of backward-pass Function apply calls to the forward-pass operations that required/created those Functions.
The method is simple and surgical. During the forward pass, when running with emit_nvtx, the nvtx range for each function in VariableType is tagged with the current sequence number. During the backward pass, the nvtx range associated with each Function's operator() is tagged with that Function's stashed sequence number, which can be compared to "current sequence numbers" from the forward pass to locate the associated op.
Double-backward is not a problem. If a backward pass with create_graph = True is underway, the relationship between backward and double-backward is conceptually the same as the relationship between forward and backward: The functions in VariableType still spit out current-sequence-number-tagged ranges, the Function objects they create still stash those sequence numbers, and in the eventual double-backward execution, their operator() ranges are still tagged with the stashed numbers, which can be compared to "current sequence numbers" from the backward pass.
Minor caveats:
- The sequence number is thread-local, and many VariableType functions (specifically, those without a derivative explicitly defined in derivatives.yaml) don't create an associated function object (instead delegating that to sub-functions further down the call chain, perhaps called from within at::native functions that route back through VariableType by calling at::function_name). So the correspondence of stashed sequence numbers in Function operator() ranges with numbers in forward-pass ranges is not guaranteed to be 1 to 1. However, it's still a vast improvement over the current situation, and I don't think this issue should be a blocker.
- Feel free to litigate my use of stringstream in profiler.cpp. I did it because it was easy and clean. If that's too big a hammer, let's figure out something more lightweight.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10881
Differential Revision: D9833371
Pulled By: apaszke
fbshipit-source-id: 1844f2e697117880ef5e31394e36e801d1de6088
Summary:
This eliminates the need for any heuristics regarding stack size limits.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11534
Differential Revision: D9779866
Pulled By: resistor
fbshipit-source-id: 96753eead7904bbdc2869fb01f7bd42141032347
Summary:
Linting `torch/csrc/` (non-recursive) and `torch/csrc/autograd` (non-recursive).
Fixed things like:
- `typedef` vs `using`
- Use `.empty()` instead of comparing with empty string/using `.size() == 0`
- Use range for loops instead of old style loops (`modernize-`)
- Remove some `virtual` + `override`
- Replace `stdint.h` with `cstdint`
- Replace `return Type(x, y)` with `return {x, y}`
- Use boolean values (`true`/`false`) instead of numbers (1/0)
- More ...
ezyang apaszke cpuhrsch
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11050
Differential Revision: D9597505
Pulled By: goldsborough
fbshipit-source-id: cb0fb4793ade885a8dbf4b10484487b84c64c7f2
Summary:
Prior to this diff, there have been two ways of compiling the bulk of the torch codebase. There was no interaction between them - you had to pick one or the other.
1) with setup.py. This method
- used the setuptools C extension functionality
- worked on all platforms
- did not build test_jit/test_api binaries
- did not include the C++ api
- always included python functionality
- produced _C.so
2) with cpp_build. This method
- used CMake
- did not support Windows or ROCM
- was capable of building the test binaries
- included the C++ api
- did not build the python functionality
- produced libtorch.so
This diff combines the two.
1) cpp_build/CMakeLists.txt has become torch/CMakeLists.txt. This build
- is CMake-based
- works on all platforms
- builds the test binaries
- includes the C++ api
- does not include the python functionality
- produces libtorch.so
2) the setup.py build
- compiles the python functionality
- calls into the CMake build to build libtorch.so
- produces _C.so, which has a dependency on libtorch.so
In terms of code changes, this mostly means extending the cmake build to support the full variety of environments and platforms. There are also a small number of changes related to the fact that there are now two shared objects - in particular, windows requires annotating some symbols with dllimport/dllexport, and doesn't allow exposing thread_local globals directly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/8792
Reviewed By: ezyang
Differential Revision: D8764181
Pulled By: anderspapitto
fbshipit-source-id: abec43834f739049da25f4583a0794b38eb0a94f
* Prevent stack overflow on deletion of deep graph
Fixes#5534.
Sometimes one can end up with a very big computation graph of Functions
and Edges. Each std::shared_ptr<Function> contains a list of Edge, and
each Edge contains a std::shared_ptr<Function>. Deleting a
std::shared_ptr<Function> can trigger the recursive deletion of other
std::shared_ptr<Function>'s: this can stack overflow if the graph
is deep enough. Here is an example of such a graph:
shared_ptr<Function> -> Edge -> shared_ptr<Function> -> Edge -> ... -> shared_ptr<Function>
The solution here is to use a custom deleter with each
std::shared_ptr<Function>. The custom deleter keeps track of how many
nested deleters it is in. When this number exceeds the maximum allowed
depth, the Function* to be deleted are accumulated in a per-thread
delete queue and handled by one of the deleters.
Example code that could trigger the overflow (set ``depth`` to something >
100000) is below. I also benchmarked the below code before/after the
changes to see if there are any significant performance differences.
```
import torch
def scope():
depth = 80000
x = torch.randn(9, requires_grad=True)
y = x.clone()
# build deeply nested computation graph
for i in range(depth):
y = y + y * 0.000001
%timeit -n 100 scope()
376 ms ± 3.94 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
Without changes:
352 ms ± 6.58 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
With the change, the above code is 6.8% slower.
UPDATE: I did some more benchmarking. It looks like it takes 25% more time to free the computation graph in the case of the straight chain graph: https://gist.github.com/zou3519/93cf84d96ae431356ae7f7c1923ef51a
* WIP
* Add custom deleter to PyFunctions created by THPFunction
* Address some comments; pick new value
* Address some more comments
* Add more complicated test; special case the windows depth constant
This PR adds the possibility to build the C++ parts of autograd and jit, with no dependency on Python.
The goal is to allow taking a PyTorch IR representation (a tree s-expr) and running it with provided inputs.
Prerequisite: build PyTorch so that codegen runs once.
Instructions:
cd tools/cpp_build
bash build_all.sh
This will build libtorchjit and torchjit_test in tools/cpp_build/build/torchjit-build. The latter basically runs the code in test_jit.cpp for now.
While writing the PR, it turned out that a few of Python.h includes were redundant. They were removed here (PyTorch tests still pass on my machine, we'll see CI).
* Introduce Python-free builds of autograd and jit
* Remove NO_PYTHON ifdef in functions/special
* Add source information to IR nodes
SourceRange information from the script is not propagated to IR nodes.
This information is only used in two places now: the interpreter
wraps errors that occur when an instruction executions and shape
propagation now reports errors on the line where it fails:
Traceback (most recent call last):
File "test/test_jit.py", line 1655, in test_script_error
bar(Variable(torch.rand(10), requires_grad=True), Variable(torch.rand(9), requires_grad=True))
RuntimeError:
The size of tensor a (10) must match the size of tensor b (9) at non-singleton dimension 0:
@torch.jit.script
def bar(c, b):
return c / b
~~~~~ <--- HERE
In the future, shape propagation should really not report any size
errors and instead just not propagate shapes and let the actual
execution fail. However, this is hard to accomplish while we still
depend on running the op to do shape propagation.
Support shape propagation with control-flow
* This allows us to enable optimization in the GraphExecutor for most
script tests.
* Changes Type to always be present (non-null) on a Value, removing `hasType()`
and `typeOption()`. A new type kind 'DynamicType' now represents when
a specific type has not been determined.
* If/Loop nodes propagate shapes/types in the simple cases where types of
outputs do not change depending on where control flows. In other
cases, we propagate DynamicType to indicate we do not know what
the shape will be.
* Remove the `cond` input to the body of Loop to simplify handling in
interpreter and shape propagation.
* Bugfix for zero-dim contiguousStridesOf
* Improve Function interface
* Undo tracer changes
* Fix bug in VariableType.set_history
* Rename function_counter and sequence_number to sequence_nr
* Clarify Function documentation
* Replace swap_next_edges with next_edges() getter
* Bring back set_gradient_edge
* Simplify special.cpp
* add_gradient_edge -> create_gradient_edge
* Add mutable getters for pre/post hooks
* Use make_variable with Edge
* Remove remove_gradient_edge in favor of detach_
* Fix documentation and remove create_gradient_edge friend method
* Canonicalize some includes
Suppose you are given a list of arguments, each of which may be Tensor or
TensorList. How can you write a function that can treat these arguments
uniformly as a list of tensors? This patch solves the problem using
variadic templates.
Why variadic templates? Use of variadic templates means anyone working
with this code has to understand universal references, perfect
forwarding, parameter packs and some idioms of C++ template design.
However, I argue that variadic templates are the *right* tool for
supporting the implementation of functions which must take an
arbitrarily heterogenous set of inputs. We were able to limp by
in old code because, for the most part, tensor inputs were homogenous,
but this is no longer the case for some non-primitively differentiable
functions; and with the upcoming cuDNN RNN in ATen PR, will no longer be
the case for primitively differentiable functions too.
There are two parts to the PR.
First, we add torch/csrc/utils/variadic.h, which defines a mix-in
IterArgs that takes any class which supports operator(), and augments
with a new variadic function apply() which calls operator() on each
argument passed to it. In an original draft of the patch, I wrote the
recursion for each parameter pack from scratch for each function;
however, it turns out there are no fewer than seven instances where we
need this idiom, and the mix-in reduces the lines of code, and also
helps centralize the most important (and easy to forget) boilerplate
for perfect forwarding.
To verify that IterArgs is compiled away into an unrolled form per
call site, I inspected the assembly on some synthetic examples.
Next, we modify the following functions to make use of IterArgs:
- compute_requires_grad
- Function::flags (Variable and Tensor variants)
- flatten
- isTracing
- count_tensors / count_variables
Finally, the tuple packer is rewritten to be variadic, although we
cannot make use of IterArgs (since we are given a tuple). It might
make sense to refactor the code into a generic piece which invokes
a function with the arguments specified by a tuple, and then an
appropriate IterArgs, but we leave this for future work.
One thing to note: we cannot write a function with overloads for both
Tensor and Variable, because both ArrayRef<Variable> and Tensor have
implicit conversions from Variable, making such an overload ambiguous.
It may be interesting to remove the implicit conversion from ArrayRef.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
This removes volatile from Variable. The functionality is mostly
replaced by a global (thread-local) flag, which is controlled by
torch.set_grad_enabled() and the context manager torch.no_grad().
In C++, the flag is exposed through GradMode::is_enabled() and GradMode::set_enabled()
Fixes#3627
* Add interpreter support for Handles/PythonOp/CppOp
This treats Handles as a first-class type in the interpreter
since this turned out to be conceptually simpler than treating
them as a separate concept, which requires a second channel for
register allocating and moving data from one op to the next.
Notes:
* The refcounting nature of tensors is factored into its own base type
so that it can be shared with other refcounted types such as handle.
* Some methods redundant with TensorBase have been deleted from Tensor
* The interpreter uses raw refcounted handles. In addition to being
able to treat Tensors and Handles as the same base object, it removes
a lot of redundant refcounting as objects moved from tensors to input/
output lists.
* aten_dispatch has been updated to work directly on the raw refcounted
lists to avoid refcounting and duplicate lists.
* Removing jit_closure.cpp, The interpreter can now handle all pathways.
* Functions like `unsafeToTensorShare` describe how
ownership transfers in the interpreter. The `Steal` variants
take rvalue references as arguments, and invalidate those
arguments to prevent potential problems.
* Make TensorTemporary is not a subtype relationship because it is too easy to
do something horribly unsafe:
```
void foo(at::Tensor bar) {
// bar destructor call release on a temporary!
}
foo(TensorTemporary(retainable)); // structure slicing!
```
This commit adds a Value type similar to the one @ezyang suggested a while
ago for handling multi-return nodes.
Previously if we had a graph like:
a = op1(b)
c, d = op2(a)
Then its in-memory format would look like:
%0 = op1(b)
%1 = op2(%0)
%2 = select(%1, 0)
%2 = select(%1, 1)
Select nodes were used only to handle the multi-output case. In the
single-output case ops referred directly to their uses.
This required special handling for the single- and multi- output cases,
and was confusing when used with ONNX which distinguishes values (the
inputs/outputs of a node) from the nodes themselves (e.g. a Conv).
This commit adds the Node/Value distinction to the IR. In the example
above, `a`, `b`, `c`, and `d` are now Value objects, while `op1` and
`op2` are now Node objects. Inputs/Outputs to the graph are values.
* Nodes now always have multiple outputs, accessible through their `output()`
method.
* Methods exist for adding/removing outputs from a node.
* Nodes own their output Values, destroying a node destroys its outputs and it
is only valid to destroy a node when no uses of its outputs remain.
* Unlike select, Values do not appear in the nodes list.
* The method `node()` on `Value` retrieves its defining node. Calling it
is always valid. For inputs, its kind is "Param". Like "Return" there is a single Param
node representing all inputs.
* For single-output Nodes, the method `output()` retrieves the single
output Value, asserting that the node is in-fact single output.
* Functions are the same, but some functions like `type()` have moved to
Value.
* `replaceAllUsesWith` is now sanely defined for both Values and Nodes.
In the case of Nodes, it replaces all outputs of the node with the outputs
of the replacement node.
* stage is defined both on Node/Value. This is because Inputs require a stage.
* Apart from changing data types from Node->Value most passes remain the same.
Things that previously assumed single-output nodes now have to call output()
to get the node.
* This removes the uses = [...] field in the outputs because it was
getting confusing even before this commit when uses would refer to nodes,
but we print the names of Values. The lint pass validates the use list,
so printing it out seems less necessary.
* Generate torch.cat autograd via ATen.
Most of the change is around supporting generation of:
1) TensorList arguments
2) Arguments to "size", "sizes", i.e. "sizes(dim)"
To be honest, this was the whole point of this refactor set.
I noticed that in a lot of code, we were repeatedly copying lots of metadata
from old nodes to new nodes. This was quite concerning because I wanted to
add some more metadata (alias information) and I didn't want to have to
get it right in all cases. Plus, in a lot of cases we were forgetting
to set more optional properties like debug names when we "copied".
To solve this, I first made cloneFrom() copy all of this metadata. Then,
I searched for all occurrences of setType() (a proxy for "I'm cloning this
node), looked for cases where we really were morally doing a copy, and rewrote
the code to use cloneFrom() instead, allowing us to drop explicit setType()
(and getting more metadata preservation in the process.)
Finally, I refactored tryToMoveChunk. The code is modestly longer,
but the new version has the nice property that the initialization of
selects for input_chunk are next to the creation of the node (as opposed
to delayed for later.) I also added a lot more comments for invariants
I noticed when I was working on the code.
One minor extra change: TensorType grew a new constructor and a withSizesStride
"immutable setter" which returns a new copy of TensorType with different info.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Variable is now a subclass of at::Tensor backed by a VariableImpl* pImpl. The implementation of the ATen functions is defined in the auto-generated VariableType.h/cpp file.
Currently, only functions which fall through to the base type, such as sizes() and isCuda() are implemented. Differentiable ops like add() and mul() will be added in a subsequent PR.