mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Remove patterns `**`, `test/**`, and `torch/**` in `tools/linter/adapters/pyfmt_linter.py` and run `lintrunner`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/132576 Approved by: https://github.com/ezyang, https://github.com/Skylion007 ghstack dependencies: #132574
		
			
				
	
	
		
			69 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			69 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| Custom pytest shard plugin
 | |
| https://github.com/AdamGleave/pytest-shard/blob/64610a08dac6b0511b6d51cf895d0e1040d162ad/pytest_shard/pytest_shard.py#L1
 | |
| Modifications:
 | |
| * shards are now 1 indexed instead of 0 indexed
 | |
| * option for printing items in shard
 | |
| """
 | |
| 
 | |
| import hashlib
 | |
| 
 | |
| from _pytest.config.argparsing import Parser
 | |
| 
 | |
| 
 | |
def pytest_addoptions(parser: Parser):
    """Register the command-line options that control sharding.

    Every option is added to the ``shard`` option group obtained from
    ``parser``:

    * ``--shard-id`` (int, default 1): number of this shard.
    * ``--num-shards`` (int, default 1): total number of shards.
    * ``--print-items`` (flag, default off): print the items being tested
      in this shard.
    """
    shard_group = parser.getgroup("shard")

    # (flag, keyword arguments) pairs; they are registered in this order.
    option_specs = (
        (
            "--shard-id",
            {
                "dest": "shard_id",
                "type": int,
                "default": 1,
                "help": "Number of this shard.",
            },
        ),
        (
            "--num-shards",
            {
                "dest": "num_shards",
                "type": int,
                "default": 1,
                "help": "Total number of shards.",
            },
        ),
        (
            "--print-items",
            {
                "dest": "print_items",
                "action": "store_true",
                "default": False,
                "help": "Print out the items being tested in this shard.",
            },
        ),
    )
    for flag, settings in option_specs:
        shard_group.addoption(flag, **settings)
 | |
| 
 | |
| 
 | |
class PytestShardPlugin:
    """Plugin object that restricts a collected test run to a single shard.

    An item is assigned to a shard by hashing its ``nodeid`` with SHA-256 and
    taking the result modulo the number of shards. Shard IDs are 1-indexed.
    """

    def __init__(self, config):
        """Keep a reference to the ``config`` object this plugin was built with."""
        self.config = config

    def pytest_report_collectionfinish(self, config, items) -> str:
        """Return a summary line with the number of items in this shard.

        When the ``print_items`` option is set, the node IDs of all items are
        appended to the summary as a comma-separated list.
        """
        summary = f"Running {len(items)} items in this shard"
        if not config.getoption("print_items"):
            return summary
        node_ids = ", ".join(item.nodeid for item in items)
        return summary + ": " + node_ids

    def sha256hash(self, x: str) -> int:
        """Return the SHA-256 digest of ``x`` as a little-endian integer."""
        digest = hashlib.sha256(x.encode()).digest()
        return int.from_bytes(digest, "little")

    def filter_items_by_shard(self, items, shard_id: int, num_shards: int):
        """Return the items of ``items`` that belong to shard ``shard_id``.

        An item belongs to the shard when the hash of its ``nodeid`` modulo
        ``num_shards`` equals ``shard_id - 1`` (shard IDs start at 1). The
        result is a new list that keeps the input order.
        """
        wanted_remainder = shard_id - 1
        selected = []
        for candidate in items:
            if self.sha256hash(candidate.nodeid) % num_shards == wanted_remainder:
                selected.append(candidate)
        return selected

    def pytest_collection_modifyitems(self, config, items):
        """Shrink ``items`` in place to the items selected for this shard.

        The shard ID and shard count are read from the ``shard_id`` and
        ``num_shards`` options.

        Raises:
            ValueError: if the shard ID is not between 1 and the shard count,
                inclusive.
        """
        shard_id = config.getoption("shard_id")
        shard_total = config.getoption("num_shards")

        in_range = 1 <= shard_id <= shard_total
        if not in_range:
            raise ValueError(
                f"{shard_id} is not a valid shard ID out of {shard_total} total shards"
            )

        # Slice assignment mutates the caller's list instead of rebinding it.
        items[:] = self.filter_items_by_shard(items, shard_id, shard_total)
 |