Files
pytorch/test/conftest.py
Catherine Lee d21577f28c Run more tests through pytest (#95844)
Run more tests through pytest.

Use a blocklist for tests that shouldn't run through pytest. As far as I can tell, for tests not on the blocklist, the numbers of tests run, skipped, and xfailed are the same as before.

Regarding the main module:

Usually, when tests are run in CI, we call `python <test file>`, which causes the file to be imported under the module name `__main__`. However, pytest imports the module under its file name, so the file gets imported a second time. This can cause issues for tests that run module-level code and change global state, such as test_nn, which modifies lists imported from another file, or the tests in test/lazy, which initialize a backend that cannot coexist with a second copy of itself.

My workaround for this is to run tests from the `__main__` module. However, this leaves pytest unable to rewrite assertions (and possibly breaks other things, but I don't know what else pytest does here right now). A better solution might be to call `pytest <test file>` directly and move all the code in run_tests(argv) to module-level code, or put it in a hook in conftest.py.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95844
Approved by: https://github.com/huydhn
2023-03-03 17:32:26 +00:00

182 lines
6.4 KiB
Python

from _pytest.junitxml import LogXML, _NodeReporter, bin_xml_escape
from _pytest.terminal import _get_raw_skip_reason
from _pytest.stash import StashKey
from _pytest.reports import TestReport
from _pytest.config.argparsing import Parser
from _pytest.config import filename_arg
from _pytest.config import Config
from _pytest._code.code import ReprFileLocation
from _pytest.python import Module
from typing import Union
from typing import Optional
from types import MethodType
import xml.etree.ElementTree as ET
import functools
import pytest
import sys
# a lot of this file is copied from _pytest.junitxml and modified to get rerun info
# Stash key under which the reruns-aware JUnit XML plugin instance is kept on
# ``config.stash``: pytest_configure stores it there and pytest_unconfigure
# removes it again.
xml_key: StashKey["LogXMLReruns"] = StashKey["LogXMLReruns"]()
def pytest_addoption(parser: Parser) -> None:
    """Register the command-line flags and ini keys used by this conftest.

    * ``--use-main-module``: boolean flag checked during module collection.
    * ``--junit-xml-reruns`` (dest ``xmlpath_reruns``): path of the
      rerun-aware JUnit XML report.
    * ``--junit-prefix-reruns``: classname prefix for that report. No explicit
      dest is given, so pytest stores it as ``junit_prefix_reruns``.
    * ``junit_*_reruns`` ini keys: settings for the reruns report, named after
      pytest's own ``junit_*`` ini keys with a ``_reruns`` suffix.
    """
    parser.addoption("--use-main-module", action="store_true")

    reporting = parser.getgroup("terminal reporting")
    reporting.addoption(
        "--junit-xml-reruns",
        action="store",
        dest="xmlpath_reruns",
        metavar="path",
        type=functools.partial(filename_arg, optname="--junit-xml-reruns"),
        default=None,
        help="create junit-xml style report file at given path.",
    )
    reporting.addoption(
        "--junit-prefix-reruns",
        action="store",
        metavar="str",
        default=None,
        help="prepend prefix to classnames in junit-xml output",
    )

    # (ini key, help text, extra keyword arguments), registered in this order.
    ini_specs = (
        (
            "junit_suite_name_reruns",
            "Test suite name for JUnit report",
            {"default": "pytest"},
        ),
        (
            "junit_logging_reruns",
            "Write captured log messages to JUnit report: "
            "one of no|log|system-out|system-err|out-err|all",
            {"default": "no"},
        ),
        (
            "junit_log_passing_tests_reruns",
            "Capture log information for passing tests to JUnit report: ",
            {"type": "bool", "default": True},
        ),
        (
            "junit_duration_report_reruns",
            "Duration time to report: one of total|call",
            {"default": "total"},
        ),
        (
            "junit_family_reruns",
            "Emit XML for schema: one of legacy|xunit1|xunit2",
            {"default": "xunit2"},
        ),
    )
    for ini_name, ini_help, ini_kwargs in ini_specs:
        parser.addini(ini_name, ini_help, **ini_kwargs)
def pytest_configure(config: Config) -> None:
    """Create and register the reruns-aware JUnit XML plugin when requested.

    Nothing is done unless ``--junit-xml-reruns`` was given. On xdist worker
    processes (which carry a ``workerinput`` attribute) the plugin is also not
    created. The instance is kept in ``config.stash[xml_key]`` so that
    ``pytest_unconfigure`` can unregister it.
    """
    xmlpath = config.option.xmlpath_reruns
    # Prevent opening xmllog on worker nodes (xdist).
    if xmlpath and not hasattr(config, "workerinput"):
        junit_family = config.getini("junit_family_reruns")
        # ``--junit-prefix-reruns`` is registered without an explicit dest, so
        # pytest stores it as ``junit_prefix_reruns``; honor it here. When it is
        # not given, fall back to the standard ``--junit-prefix`` (``junitprefix``)
        # so existing invocations keep their prefix. getattr() tolerates either
        # option being absent (e.g. if the plugin defining it is not loaded).
        prefix = getattr(config.option, "junit_prefix_reruns", None)
        if prefix is None:
            prefix = getattr(config.option, "junitprefix", None)
        config.stash[xml_key] = LogXMLReruns(
            xmlpath,
            prefix,
            config.getini("junit_suite_name_reruns"),
            config.getini("junit_logging_reruns"),
            config.getini("junit_duration_report_reruns"),
            junit_family,
            config.getini("junit_log_passing_tests_reruns"),
        )
        config.pluginmanager.register(config.stash[xml_key])
def pytest_unconfigure(config: Config) -> None:
    """Unregister and forget the reruns XML plugin if pytest_configure created one."""
    reruns_xml = config.stash.get(xml_key, None)
    if not reruns_xml:
        return
    del config.stash[xml_key]
    config.pluginmanager.unregister(reruns_xml)
class _NodeReporterReruns(_NodeReporter):
    """Per-test XML node reporter used by ``LogXMLReruns``.

    Captured output is written without the *header* text, and empty
    sections produce no element at all.
    """

    def _prepare_content(self, content: str, header: str) -> str:
        # Hand the captured content back untouched; *header* is deliberately unused.
        return content

    def _write_content(self, report: TestReport, content: str, jheader: str) -> None:
        # Append a <jheader> element holding the XML-escaped content, but only
        # when there is non-empty content to write. *report* is unused.
        if content != "":
            element = ET.Element(jheader)
            element.text = bin_xml_escape(content)
            self.append(element)
class LogXMLReruns(LogXML):
    """``LogXML`` subclass that also records reruns in the JUnit XML output.

    Reports whose outcome is ``"rerun"`` get a ``<rerun>`` element on their
    testcase. Per-test reporters are ``_NodeReporterReruns`` instances.
    Construction is inherited unchanged from ``LogXML``.
    """

    def append_rerun(self, reporter: _NodeReporter, report: TestReport) -> None:
        """Add a ``<rerun>`` element describing *report* to *reporter*.

        Reports carrying ``wasxfail`` get a ``<skipped>`` element instead. The
        message is the crash message when ``longrepr`` has a ``reprcrash``,
        otherwise the full ``longrepr`` text.
        """
        if hasattr(report, "wasxfail"):
            reporter._add_simple("skipped", "xfail-marked test passes unexpectedly")
            return
        assert report.longrepr is not None
        crash: Optional[ReprFileLocation] = getattr(report.longrepr, "reprcrash", None)
        summary = str(report.longrepr) if crash is None else crash.message
        reporter._add_simple("rerun", bin_xml_escape(summary), str(report.longrepr))

    def pytest_runtest_logreport(self, report: TestReport) -> None:
        """Let ``LogXML`` process *report* first, then add rerun/skip handling."""
        super().pytest_runtest_logreport(report)
        outcome = report.outcome
        if outcome == "rerun":
            self.append_rerun(self._opentestcase(report), report)
        elif outcome == "skipped" and isinstance(report.longrepr, tuple):
            # Rewrite the skip reason as "<nodeid>: <raw reason>" on the shared
            # report object. NOTE(review): presumably so later skip summaries
            # (which key on the reason) list each test separately -- confirm.
            # This runs after super(), so the XML entry above used the old reason.
            fspath, lineno, _ = report.longrepr
            report.longrepr = (
                fspath,
                lineno,
                f"{report.nodeid}: {_get_raw_skip_reason(report)}",
            )

    def node_reporter(self, report: Union[TestReport, str]) -> _NodeReporterReruns:
        """Return the reporter cached for *report*'s node, creating one on first use."""
        nodeid: Union[str, TestReport] = getattr(report, "nodeid", report)
        # Local hack to handle xdist report order.
        key = nodeid, getattr(report, "node", None)
        cached = self.node_reporters.get(key)
        if cached is not None:
            # TODO: breaks for --dist=each
            return cached
        fresh = _NodeReporterReruns(nodeid, self)
        self.node_reporters[key] = fresh
        self.node_reporters_ordered.append(fresh)
        return fresh
def _write_rerun_tracebacks(terminalreporter) -> None:
    """Write a "RERUNS" section listing every rerun report, honoring ``--tb``.

    Nothing is written for ``--tb=no`` or when there were no reruns. With
    ``--tb=line`` each rerun gets a single crash line. Otherwise each gets a
    headline separator followed by its full report and teardown sections.
    """
    tbstyle = terminalreporter.config.option.tbstyle
    if tbstyle == "no":
        return
    rerun_reports = terminalreporter.getreports("rerun")
    if not rerun_reports:
        return
    terminalreporter.write_sep("=", "RERUNS")
    if tbstyle == "line":
        for rep in rerun_reports:
            terminalreporter.write_line(terminalreporter._getcrashline(rep))
        return
    for rep in rerun_reports:
        headline = terminalreporter._getfailureheadline(rep)
        terminalreporter.write_sep("_", headline, red=True, bold=True)
        terminalreporter._outrep_summary(rep)
        terminalreporter._handle_teardown_sections(rep.nodeid)


# imitating summary_failures in pytest's terminal.py
# both hookwrapper and tryfirst to make sure this runs before pytest's
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_terminal_summary(terminalreporter, exitstatus, config):
    """Print stack traces for reruns, then hand control to the other hooks."""
    _write_rerun_tracebacks(terminalreporter)
    yield
@pytest.hookimpl(tryfirst=True)
def pytest_pycollect_makemodule(module_path, path, parent) -> Optional[Module]:
    """Back collected test modules with ``sys.modules['__main__']`` when requested.

    When ``--use-main-module`` is set, the returned ``Module`` node resolves
    its underlying object to the already-running ``__main__`` module. The
    test file is therefore not imported a second time under its file name.
    Without the flag this returns ``None``, leaving module creation to the
    other implementations of this hook. ``path`` is accepted but unused
    (``module_path`` is used instead).
    """
    if not parent.config.getoption("--use-main-module"):
        return None

    mod = Module.from_parent(parent, path=module_path)

    def _main_module(self):
        # Replacement for Module._getobj: look up __main__ lazily, at the time
        # pytest asks for the module object.
        return sys.modules["__main__"]

    mod._getobj = MethodType(_main_module, mod)
    return mod