James Wu 41dfdde9f5 Handle some numpy functions with out arguments correctly in dynamo (#118248)
Dynamo creates Tensors when tracing through NumPy ufuncs such as `np.sin` and `np.minimum`, so under `torch.compile`, NumPy functions generally return Tensors. However, when normalizing `out` arguments, we currently require the `out` argument to be an ndarray. This causes assertion errors when running `torch.compile` on any NumPy function with an `out` argument:
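```python
import numpy as np
import torch


def test_numpy_ufunc_out():
    @torch.compile(backend="eager")
    def foo():
        x = np.arange(5)
        out = np.empty((x.shape[0], x.shape[0]))
        # x broadcasts against out's (5, 5) shape; the ufunc must return out itself
        res_out = np.sin(x, out=out)
        assert res_out is out

    foo()
```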
Failure with stack trace: https://gist.github.com/jamesjwu/68e217638d735678b3de968584dba23f

Instead, we can wrap tensors in an ndarray in `normalize_outarray` to handle this case correctly. Fixing this resolves ~220 tests under `dynamo_test_failures`, but also exposes a follow-up bug.
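
The shape of the fix can be sketched in plain Python. This is a simplified illustration, not the code in `normalize_outarray`: `Tensor` and `ndarray` are hypothetical stand-ins for the real classes, and `normalize_outarray_sketch` is a name chosen for the example.

```python
class Tensor:
    """Hypothetical stand-in for torch.Tensor."""

    def __init__(self, data):
        self.data = list(data)


class ndarray:
    """Hypothetical stand-in for the compat layer's ndarray: a thin tensor wrapper."""

    def __init__(self, t):
        if not isinstance(t, Tensor):
            raise ValueError("ndarray expects a Tensor")
        self.tensor = t


def normalize_outarray_sketch(arg):
    """Return a usable `out` array, wrapping bare Tensors instead of rejecting them."""
    if arg is None or isinstance(arg, ndarray):
        return arg
    if isinstance(arg, Tensor):
        # Under Dynamo, `out` may arrive as a Tensor: wrap it rather than fail.
        return ndarray(arg)
    raise TypeError(f"'out' must be an array, got {type(arg).__name__}")


t = Tensor([0.0, 0.0])
out = normalize_outarray_sketch(t)
assert isinstance(out, ndarray) and out.tensor is t
```

Wrapping (rather than copying) keeps the caller's tensor as the storage that the ufunc writes into.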

In the presence of a graph break, ndarrays don't preserve their id, which can affect assertions and `is` checks between numpy arrays:
```python
def test_x_and_out_broadcast(self, ufunc):
    x = self.get_x(ufunc)
    out = np.empty((x.shape[0], x.shape[0]))

    x_b = np.broadcast_to(x, out.shape)
    # ufunc is just np.sin here
    res_out = ufunc(x, out=out)
    res_bcast = ufunc(x_b)
    # passes
    assert res_out is out
    graph_break()
    # fails
    assert res_out is out
```
Regular tensors preserve their id because Dynamo caches their example tensor values across a graph break. With ndarrays, however, we only store their converted tensor values and construct new ndarrays around those values:
eebe7e1d37/torch/_dynamo/variables/builder.py (L1083)
Added a test marked as an expected failure to showcase this; we can then fix that issue separately.
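
A simplified, hypothetical model of the reconstruction step shows why the second `is` check fails. No Dynamo is involved: `Tensor`, `ndarray` and the `saved` dictionary are stand-ins for the behaviour described above (keep only the tensor, then wrap it in a new ndarray), not a reproduction of Dynamo's code.

```python
class Tensor:
    """Hypothetical stand-in for the tensor value kept for an ndarray."""


class ndarray:
    """Hypothetical stand-in for the compat layer's ndarray wrapper."""

    def __init__(self, tensor):
        self.tensor = tensor


# Before the graph break: the ufunc returned `out` itself.
out = ndarray(Tensor())
res_out = out
assert res_out is out

# Simulated graph break: only the tensors are kept, and each local is
# rebuilt by wrapping its tensor in a brand-new ndarray.
saved = {"out": out.tensor, "res_out": res_out.tensor}
out = ndarray(saved["out"])
res_out = ndarray(saved["res_out"])

assert res_out.tensor is out.tensor  # the underlying data is still shared
assert res_out is not out            # but ndarray identity is lost
```

In this model, both names still share one tensor, which is why only identity checks (not value checks) are affected.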

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118248
Approved by: https://github.com/lezcano
2024-01-29 09:09:21 +00:00

NumPy <> PyTorch Compat Layer

This folder contains an implementation of (most of) the NumPy public API using PyTorch tensors. Note that this folder does not depend on NumPy in any way; it is a standalone implementation.

This implementation is used by Dynamo to trace through NumPy code and lower it into PyTorch code.

For the design decisions that went into this implementation, please see the RFC.

Structure of the code

This folder exports a drop-in replacement for the NumPy namespace and its modules linalg, fft and random via its __init__.py.

The implementation is split into two kinds of files: files that work with PyTorch objects (PyTorch tensors, dtypes, etc.), and wrapper files that build on these PyTorch-only files to provide functions and objects that accept all the types the corresponding NumPy functions accept. In particular, they accept torch._numpy.dtypes and torch._numpy.ndarrays.

The PyTorch-only files are the *_impl.py files, while the wrapper files are those without the _impl suffix. This creates a hierarchy in which, for example, _dtypes.py imports _dtypes_impl.py, but not the other way around. In particular, *_impl.py files only depend on other *_impl.py files.

As discussed in the RFC, we use types as tags in our PyTorch implementations. We then use a decorator called normalizer that inspects these types and preprocesses the inputs before passing them to the function. This preprocessing is responsible for mapping array-like objects to Tensors and dtype-like objects to PyTorch dtypes, implementing the out= behaviour, and so on.
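
The tag-plus-decorator pattern can be sketched in plain Python. This is a simplified illustration, not the code in _normalizations.py: Tensor and ndarray are hypothetical stand-ins, and ArrayLike, NORMALIZERS and double are names chosen for the example. The annotation on each parameter selects a conversion, the decorated function only ever sees Tensor objects, and Tensor results are wrapped back into ndarrays.

```python
import functools
import inspect
import typing


class Tensor:
    """Hypothetical stand-in for torch.Tensor."""

    def __init__(self, data):
        self.data = list(data)


class ndarray:
    """Hypothetical stand-in for torch._numpy.ndarray."""

    def __init__(self, tensor):
        self.tensor = tensor


# A type used purely as a tag: it tells the normalizer how to preprocess.
ArrayLike = typing.TypeVar("ArrayLike")


def normalize_array_like(x):
    if isinstance(x, ndarray):
        return x.tensor
    if isinstance(x, Tensor):
        return x
    return Tensor(x)  # e.g. a Python list or tuple


NORMALIZERS = {ArrayLike: normalize_array_like}


def normalizer(func):
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        for name, value in list(bound.arguments.items()):
            # Look up the conversion selected by this parameter's annotation.
            norm = NORMALIZERS.get(sig.parameters[name].annotation)
            if norm is not None:
                bound.arguments[name] = norm(value)
        result = func(*bound.args, **bound.kwargs)
        # Wrap Tensor results back into ndarrays for the NumPy-facing API.
        return ndarray(result) if isinstance(result, Tensor) else result

    return wrapped


@normalizer
def double(a: ArrayLike):
    # Inside the decorated function, `a` is always a Tensor.
    return Tensor([2 * v for v in a.data])
```

For instance, double([1, 2]) and double(ndarray(Tensor([1, 2]))) both reach the function body as a Tensor and both come back as an ndarray.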

In _funcs.py and _ufuncs.py, we apply the normalizer decorator to all the *_impl.py functions.
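
One plausible way to apply a decorator to every function of an implementation module is a loop over its public names. The sketch below assumes that approach for illustration only; the toy normalizer, funcs_impl and funcs are hypothetical and much simpler than the real files.

```python
import functools
import types


def normalizer(func):
    """Toy normalizer: converts every positional argument to a float."""

    @functools.wraps(func)
    def wrapped(*args):
        return func(*(float(a) for a in args))

    return wrapped


def add(a, b):
    return a + b


def negative(a):
    return -a


# Hypothetical stand-in for an *_impl.py module.
funcs_impl = types.SimpleNamespace(add=add, negative=negative)

# The wrapper namespace decorates every public impl function in one loop.
funcs = types.SimpleNamespace()
for name, impl in vars(funcs_impl).items():
    if callable(impl) and not name.startswith("_"):
        setattr(funcs, name, normalizer(impl))
```

A loop like this keeps the impl functions undecorated, so other *_impl.py code can keep calling them with plain PyTorch objects.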

In _ndarray.py, we define the ndarray class, which is just a thin wrapper around a PyTorch tensor. Many of its methods are implemented using the free functions plus a bit of metaprogramming.
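
The general mechanism of turning free functions into methods can be sketched in plain Python. The ndarray class here is a hypothetical stand-in whose tensor is a plain list, and amax, mean and absolute are illustrative free functions; this shows only the idea, not the code in _ndarray.py:

```python
class ndarray:
    """Hypothetical stand-in: a thin wrapper whose `tensor` is a plain list here."""

    def __init__(self, tensor):
        self.tensor = tensor


# Hypothetical free functions, playing the role of the module-level API.
def amax(a):
    return max(a.tensor)


def mean(a):
    return sum(a.tensor) / len(a.tensor)


def absolute(a):
    return ndarray([abs(v) for v in a.tensor])


# Attach free functions as methods. A plain function stored on a class
# becomes a method, so the instance is passed as its first argument.
for method_name, free_func in {"max": amax, "mean": mean, "__abs__": absolute}.items():
    setattr(ndarray, method_name, free_func)
```

With this in place, a.max() calls amax(a) and abs(a) calls absolute(a), so each behaviour is written once as a free function.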

Adding a new function

To add a new function, add it to the relevant *_impl.py file and tag its inputs with the relevant types. The normalizer makes sure that your function always receives PyTorch objects, and your function should return PyTorch tensors. If in doubt, see the normalization attached to each type annotation in _normalizations.py.
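
Putting the steps together, adding a function under these rules might look like the sketch below. Everything in it is hypothetical: ArrayLike and DTypeLike are illustrative tags, NORMALIZE and normalizer are a toy version of the machinery in _normalizations.py, a Python list stands in for a tensor, and filled is a made-up function:

```python
import functools
import inspect
import typing

# Hypothetical type tags; the real tags and their normalizations differ.
ArrayLike = typing.TypeVar("ArrayLike")
DTypeLike = typing.TypeVar("DTypeLike")

NORMALIZE = {
    ArrayLike: list,  # a Python list stands in for a PyTorch tensor
    DTypeLike: lambda d: {None: float, "float64": float, "int64": int}[d],
}


def normalizer(func):
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        for name, value in list(bound.arguments.items()):
            norm = NORMALIZE.get(sig.parameters[name].annotation)
            if norm is not None:
                bound.arguments[name] = norm(value)
        return func(*bound.args, **bound.kwargs)

    return wrapped


# The new function: tag its inputs, then assume they are already normalized.
@normalizer
def filled(a: ArrayLike, fill_value, dtype: DTypeLike = None):
    # Here `a` is a list and `dtype` is a Python type, whatever the caller passed.
    return [dtype(fill_value)] * len(a)
```

The body of filled contains no conversion logic at all; untagged parameters such as fill_value are passed through unchanged.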

Debugging

It may be useful to figure out whether a given bug is caused by Dynamo or by the compatibility layer. You can use the compat layer in eager mode simply by replacing import numpy as np with import torch._numpy as np in your program, without calling torch.compile at all. Note that torch._numpy is quite slow in eager mode, and it is in no way a replacement for, or an alternative to, the regular PyTorch API. It should only be used as a debugging tool.