Files
pytorch/torch/_numpy/_funcs.py
lezcano a9dca53438 NumPy support in torch.compile (#106211)
RFC: https://github.com/pytorch/rfcs/pull/54
First commit is the contents of https://github.com/Quansight-Labs/numpy_pytorch_interop/

We have already been using this in core for the last few months as an external dependency. This PR pulls all of this code into core.

In the next commits, I do a number of things, in this order:
- Fix a few small issues
- Make the tests that this PR adds pass
- Bend backwards until lintrunner passes
- Remove the optional dependency on `torch_np` and simply rely on the upstreamed code
- Fix a number of dynamo tests that were passing before (I believe they were not really testing anything) but are not passing now.

Missing from this PR (but not blocking):
- Have a flag that deactivates tracing NumPy functions and simply graph-breaks instead. There used to be one, but it stopped working after the merge, so I removed it. @lezcano to investigate.
- https://github.com/pytorch/pytorch/pull/106431#issuecomment-1667079543. @voznesenskym to submit a fix after we merge.

All the tests in `tests/torch_np` take about 75s to run.

This was joint work by @ev-br, @rgommers, @honno, and me. I did not create this PR via ghstack (which would have been convenient), as this is a collaboration and ghstack doesn't allow for shared contributions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106211
Approved by: https://github.com/ezyang
2023-08-11 00:39:32 +00:00

74 lines
2.0 KiB
Python

import inspect
import itertools
from . import _funcs_impl, _reductions_impl
from ._normalizations import normalizer
# _funcs_impl.py contains functions which mimic NumPy's eponymous equivalents,
# and consume/return PyTorch tensors/dtypes.
# They are also type annotated.
# Pull these functions from _funcs_impl and decorate them with @normalizer, which
# - Converts any input `np.ndarray`, `torch._numpy.ndarray`, list of lists, Python scalars, etc into a `torch.Tensor`.
# - Maps NumPy dtypes to PyTorch dtypes
# - If the input to the `axis` kwarg is an ndarray, it maps it into a tuple
# - Implements the semantics for the `out=` arg
# - Wraps back the outputs into `torch._numpy.ndarrays`
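# For example (illustrative only; assumes this module is consumed through the
# public `torch._numpy` namespace):
#
#     import torch._numpy as tnp
#     out = tnp.concatenate([[1, 2], [3, 4]])  # list of lists is normalized
#                                              # into a torch.Tensor input
#     isinstance(out, tnp.ndarray)             # result is wrapped back into
#                                              # a torch._numpy.ndarray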
def _public_functions(mod):
    def is_public_function(f):
        return inspect.isfunction(f) and not f.__name__.startswith("_")

    return inspect.getmembers(mod, is_public_function)
# We fill in __all__ in the loop below
__all__ = []
# decorate implementer functions with argument normalizers and export to the top namespace
for name, func in itertools.chain(
    _public_functions(_funcs_impl), _public_functions(_reductions_impl)
):
    if name in ["percentile", "quantile", "median"]:
        decorated = normalizer(func, promote_scalar_result=True)
    elif name == "einsum":
        # normalized manually
        decorated = func
    else:
        decorated = normalizer(func)

    decorated.__qualname__ = name
    decorated.__name__ = name
    vars()[name] = decorated
    __all__.append(name)
"""
Vendored objects from numpy.lib.index_tricks
"""
class IndexExpression:
    """
    Written by Konrad Hinsen <hinsen@cnrs-orleans.fr>
    last revision: 1999-7-23
    Cosmetic changes by T. Oliphant 2001
    """

    def __init__(self, maketuple):
        self.maketuple = maketuple

    def __getitem__(self, item):
        if self.maketuple and not isinstance(item, tuple):
            return (item,)
        else:
            return item
index_exp = IndexExpression(maketuple=True)
s_ = IndexExpression(maketuple=False)
__all__ += ["index_exp", "s_"]
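
# Example behaviour (mirrors numpy.s_ / numpy.index_exp, from which these
# objects are vendored):
#
#     s_[1:5]          ->  slice(1, 5, None)
#     s_[1:5, ::2]     ->  (slice(1, 5, None), slice(None, None, 2))
#     index_exp[1:5]   ->  (slice(1, 5, None),)
#
# i.e. `index_exp` always yields a tuple suitable for use as an index,
# while `s_` returns the item unchanged.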