mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
RFC: https://github.com/pytorch/rfcs/pull/54 First commit is the contents of https://github.com/Quansight-Labs/numpy_pytorch_interop/ We have already been using this in core for the last few months as an external dependency. This PR pulls all these into core. In the next commits, I do a number of things in this order - Fix a few small issues - Make the tests that this PR adds pass - Bend over backwards until lintrunner passes - Remove the optional dependency on `torch_np` and simply rely on the upstreamed code - Fix a number of dynamo tests that were passing before (they were not testing anything I think) and are not passing now. Missing from this PR (but not blocking): - Have a flag that deactivates tracing NumPy functions and simply breaks. There used to be one but after the merge stopped working and I removed it. @lezcano to investigate. - https://github.com/pytorch/pytorch/pull/106431#issuecomment-1667079543. @voznesenskym to submit a fix after we merge. All the tests in `tests/torch_np` take about 75s to run. This was a work by @ev-br, @rgommers, @honno and I. I did not create this PR via ghstack (which would have been convenient) as this is a collaboration, and ghstack doesn't allow for shared contributions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/106211 Approved by: https://github.com/ezyang
80 lines
2.1 KiB
Python
80 lines
2.1 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
|
|
import sys
|
|
|
|
import pytest
|
|
|
|
import torch._numpy as tnp
|
|
|
|
|
|
def pytest_configure(config):
    """Declare the ``slow`` marker by adding it to the ``markers`` ini entry.

    Args:
        config: the pytest configuration object handed to this hook.
    """
    marker_description = "slow: very slow tests"
    config.addinivalue_line("markers", marker_description)
|
|
|
|
|
|
def pytest_addoption(parser):
    """Add the ``--runslow`` and ``--nonp`` boolean command-line flags.

    Args:
        parser: the pytest option parser handed to this hook.
    """
    # (flag, help text) pairs; every flag is a plain on/off switch.
    switches = (
        ("--runslow", "run slow tests"),
        ("--nonp", "error when NumPy is accessed"),
    )
    for flag, description in switches:
        parser.addoption(flag, action="store_true", help=description)
|
|
|
|
|
|
class Inaccessible:
    """Stand-in object on which every attribute lookup fails loudly.

    The error message names the attribute that was requested, prefixed
    with ``np.``.
    """

    def __getattribute__(self, attr):
        """Raise ``RuntimeError`` for any attribute ``attr``."""
        message = f"Using --nonp but accessed np.{attr}"
        raise RuntimeError(message)
|
|
|
|
|
|
def pytest_sessionstart(session):
    """Replace the ``numpy`` entry in ``sys.modules`` when ``--nonp`` is set.

    With the flag on, ``sys.modules["numpy"]`` is pointed at an
    ``Inaccessible`` instance, so attribute access on it raises. Without the
    flag, nothing is changed.

    Args:
        session: the pytest session object handed to this hook.
    """
    numpy_forbidden = session.config.getoption("--nonp")
    if not numpy_forbidden:
        return
    sys.modules["numpy"] = Inaccessible()
|
|
|
|
|
|
def pytest_generate_tests(metafunc):
    """
    Parametrize tests that request an ``np`` argument.

    See https://docs.pytest.org/en/6.2.x/parametrize.html#pytest-generate-tests

    This hook makes it possible to run a test against both NumPy-proper and
    torch._numpy. A test that only targets torch._numpy is usually written as

        import torch._numpy as np
        ...
        def test_foo():
            np.array([42])
            ...

    whereas a test that takes ``np`` as an argument, e.g.

        def test_foo(np):
            np.array([42])
            ...

    receives ``np`` as a pytest parameter that is either torch._numpy or
    NumPy-proper. Running against NumPy-proper is a sanity check on the tests
    themselves: the behaviour they assert should match NumPy-proper.

    The generated test ids reflect the library under test, e.g.

        $ pytest --collect-only
            test_foo[torch._numpy]
            test_foo[numpy]

    """
    backends = [tnp]

    # The import is attempted for every test function, matching the order of
    # operations pytest sees: first resolve the candidate libraries, then
    # decide whether this particular test wants them.
    try:
        import numpy as real_numpy
    except ImportError:
        real_numpy = None

    # An Inaccessible instance in place of the module means NumPy has been
    # deliberately blocked, so only a genuine module is added.
    if real_numpy is not None and not isinstance(real_numpy, Inaccessible):
        backends.append(real_numpy)

    if "np" not in metafunc.fixturenames:
        return
    metafunc.parametrize("np", backends)
|
|
|
|
|
|
def pytest_collection_modifyitems(config, items):
    """Attach a skip marker to every ``slow`` test unless ``--runslow`` is set.

    Args:
        config: the pytest configuration object; queried for ``--runslow``.
        items: the collected test items; those with ``"slow"`` in their
            keywords get a skip marker added.
    """
    if config.getoption("--runslow"):
        # Slow tests were explicitly requested; leave the collection alone.
        return

    slow_skip = pytest.mark.skip(reason="slow test, use --runslow to run")
    for test_item in items:
        if "slow" in test_item.keywords:
            test_item.add_marker(slow_skip)
|