Dynamo creates Tensors when tracing through numpy ufuncs like np.sin, np.minimum etc. When running, np functions generally return Tensors when run with `torch.compile`. However, we currently require when normalizing `out` arguments that the input is an ndarray. This creates assertion errors when running torch.compile on any numpy function with an out argument:
```
def test_numpy_ufunc_out(self):
@torch.compile(backend="eager")
def foo():
x = np.arange(5)
out = np.empty((x.shape[0], x.shape[0]))
res_out = np.sin(x, out=out)
assert res_out is out
foo()
```
Failure with stack trace: https://gist.github.com/jamesjwu/68e217638d735678b3de968584dba23f
Instead, we can wrap tensors in an ndarray in normalize_outarray to handle the case correctly. Fixing this resolves ~220 tests under dynamo_test_failures, but also exposes a followup bug.
In the presence of a graph break, ndarrays don't preserve their id, which can affect assertions and `is` checks between numpy arrays:
```
def test_x_and_out_broadcast(self, ufunc):
x = self.get_x(ufunc)
out = np.empty((x.shape[0], x.shape[0]))
x_b = np.broadcast_to(x, out.shape)
# ufunc is just np.sin here
res_out = ufunc(x, out=out)
res_bcast = ufunc(x_b)
# passes
assert res_out is out
graph_break()
# fails
assert res_out is out
```
Regular tensors preserve their id because Dynamo caches their example tensor values across a graph break. However, with ndarrays, we only store their converted tensor values, and construct new ndarrays around those values:
eebe7e1d37/torch/_dynamo/variables/builder.py (L1083)
Added a test with expected failure to showcase this — we can then fix that issue separately.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118248
Approved by: https://github.com/lezcano
This is a lot of files changed! Don't panic! Here's how it works:
* Previously, we set `follow_imports = silent` for our mypy.ini configuration. Per https://mypy.readthedocs.io/en/stable/running_mypy.html#follow-imports, what this does is whenever we have an import to a module which is not listed as a file to be typechecked in mypy, we typecheck it as normal but suppress all errors that occurred in that file.
* When mypy is run inside lintrunner, the list of files is precisely the files covered by the glob in lintrunner.toml, but with files in excludes excluded.
* The top-level directive `# mypy: ignore-errors` instructs mypy to typecheck the file as normal, but ignore all errors.
* Therefore, it should be equivalent to set `follow_imports = normal`, if we put `# mypy: ignore-errors` on all files that were previously excluded from the file list.
* Having done this, we can remove the exclude list from .lintrunner.toml, since excluding a file from typechecking is baked into the files themselves.
* torch/_dynamo and torch/_inductor were previously in the exclude list, because they were covered by MYPYINDUCTOR. It is not OK to mark these as `# mypy: ignore-errors` as this will impede typechecking on the alternate configuration. So they are temporarily being checked twice, but I am suppressing the errors in these files as the configurations are not quite the same. I plan to unify the configurations so this is only a temporary state.
* There were some straggler type errors after these changes somehow, so I fixed them as needed. There weren't that many.
In the future, to start type checking a file, just remove the ignore-errors directive from the top of the file.
The codemod was done with this script authored by GPT-4:
```
import glob
exclude_patterns = [
...
]
for pattern in exclude_patterns:
for filepath in glob.glob(pattern, recursive=True):
if filepath.endswith('.py'):
with open(filepath, 'r+') as f:
content = f.read()
f.seek(0, 0)
f.write('# mypy: ignore-errors\n\n' + content)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118414
Approved by: https://github.com/thiagocrepaldi, https://github.com/albanD
Enable pylint rules `PLC0131` and `PLC0132`. There was a violation of the `PLC0132` so this commit also fixes it and enables the rules so the violation do not occur again. `PLC0205` checks accidentally setting your `__slots__` to a string which is almost always a bug.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115015
Approved by: https://github.com/jansel, https://github.com/malfet
Use conditional imports: when running under dynamo, import the original NumPy not torch._numpy. This is what we want to trace, not our implementation.
With this, the test suite passes with and without `PYTORCH_TEST_WITH_DYNAMO=1` (modulo a couple of test modules which are not meant to be compiled, e.g. `test_nep50_examples`). There are two new decorators, `x{fail,pass}ifTorchDynamo`, the `xpass` in most cases indicates a graph break and a fallback to eager for things we do not implement.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110401
Approved by: https://github.com/lezcano
RFC: https://github.com/pytorch/rfcs/pull/54
First commit is the contents of https://github.com/Quansight-Labs/numpy_pytorch_interop/
We have already been using this in core for the last few months as a external dependency. This PR pulls all these into core.
In the next commits, I do a number of things in this order
- Fix a few small issues
- Make the tests that this PR adds pass
- Bend backwards until lintrunner passes
- Remove the optional dependency on `torch_np` and simply rely on the upstreamed code
- Fix a number dynamo tests that were passing before (they were not tasting anything I think) and are not passing now.
Missing from this PR (but not blocking):
- Have a flag that deactivates tracing NumPy functions and simply breaks. There used to be one but after the merge stopped working and I removed it. @lezcano to investigate.
- https://github.com/pytorch/pytorch/pull/106431#issuecomment-1667079543. @voznesenskym to submit a fix after we merge.
All the tests in `tests/torch_np` take about 75s to run.
This was a work by @ev-br, @rgommers @honno and I. I did not create this PR via ghstack (which would have been convenient) as this is a collaboration, and ghstack doesn't allow for shared contributions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106211
Approved by: https://github.com/ezyang