mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Remove patterns `**`, `test/**`, and `torch/**` in `tools/linter/adapters/pyfmt_linter.py` and run `lintrunner`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/132576 Approved by: https://github.com/ezyang, https://github.com/Skylion007 ghstack dependencies: #132574
69 lines
2.3 KiB
Python
69 lines
2.3 KiB
Python
"""
|
|
Custom pytest shard plugin, adapted from
https://github.com/AdamGleave/pytest-shard/blob/64610a08dac6b0511b6d51cf895d0e1040d162ad/pytest_shard/pytest_shard.py#L1

Modifications:
* shards are now 1-indexed instead of 0-indexed
* option for printing the items in a shard
|
|
"""
|
|
|
|
import hashlib
|
|
|
|
from _pytest.config.argparsing import Parser
|
|
|
|
|
|
def pytest_addoptions(parser: Parser):
    """Register the ``shard`` option group and its command-line flags.

    Flags added (in this order):
      * ``--shard-id``    -> ``shard_id`` (int, default 1): which shard to run.
      * ``--num-shards``  -> ``num_shards`` (int, default 1): total shard count.
      * ``--print-items`` -> ``print_items`` (store_true): list the selected
        items in the collection report.

    NOTE(review): this is named ``pytest_addoptions`` (plural), not pytest's
    standard ``pytest_addoption`` hook, so it is presumably invoked explicitly
    by a caller such as a ``conftest.py`` -- verify before renaming.
    """
    shard_group = parser.getgroup("shard")

    # (flag, keyword arguments) pairs, registered in declaration order.
    option_specs = (
        (
            "--shard-id",
            {"dest": "shard_id", "type": int, "default": 1, "help": "Number of this shard."},
        ),
        (
            "--num-shards",
            {"dest": "num_shards", "type": int, "default": 1, "help": "Total number of shards."},
        ),
        (
            "--print-items",
            {
                "dest": "print_items",
                "action": "store_true",
                "default": False,
                "help": "Print out the items being tested in this shard.",
            },
        ),
    )
    for flag, spec in option_specs:
        shard_group.addoption(flag, **spec)
|
|
|
|
|
|
class PytestShardPlugin:
    """pytest plugin that restricts a run to a single shard of the collected tests.

    Every collected item is assigned to a shard by hashing its ``nodeid`` with
    SHA-256 and taking the result modulo the number of shards. Shard IDs are
    1-indexed, so shard ``k`` receives the items whose bucket is ``k - 1``.
    """

    def __init__(self, config):
        # Stored for reference; the hook methods below read options from the
        # ``config`` argument they are called with instead.
        self.config = config

    def pytest_report_collectionfinish(self, config, items) -> str:
        """Build the collection-report line for this shard.

        The line states how many items were selected; when the ``print_items``
        option is set, the selected node IDs are appended, joined by ", ".
        """
        header = f"Running {len(items)} items in this shard"
        if not config.getoption("print_items"):
            return header
        listing = ", ".join(item.nodeid for item in items)
        return f"{header}: {listing}"

    def sha256hash(self, x: str) -> int:
        """Hash ``x`` (encoded with ``str.encode()``'s default UTF-8) with SHA-256.

        Returns the 32-byte digest interpreted as a little-endian integer.
        """
        digest = hashlib.sha256(x.encode()).digest()
        return int.from_bytes(digest, byteorder="little")

    def filter_items_by_shard(self, items, shard_id: int, num_shards: int):
        """Return a new list of the ``items`` assigned to ``shard_id`` of ``num_shards``.

        ``shard_id`` is 1-indexed; the relative order of ``items`` is preserved.
        """

        def in_this_shard(item) -> bool:
            # Hash buckets are 0-indexed while shard IDs start at 1.
            bucket = self.sha256hash(item.nodeid) % num_shards
            return bucket == shard_id - 1

        return [item for item in items if in_this_shard(item)]

    def pytest_collection_modifyitems(self, config, items):
        """Narrow ``items`` in place to the tests assigned to the selected shard.

        Raises:
            ValueError: if ``shard_id`` is not within ``1..num_shards``.
        """
        shard_id = config.getoption("shard_id")
        total_shards = config.getoption("num_shards")
        if not 1 <= shard_id <= total_shards:
            raise ValueError(
                f"{shard_id} is not a valid shard ID out of {total_shards} total shards"
            )

        # Slice assignment mutates the very list pytest passed in; this hook is
        # expected to filter the collection in place rather than return a new one.
        items[:] = self.filter_items_by_shard(items, shard_id, total_shards)
|