mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
With ufmt in place https://github.com/pytorch/pytorch/pull/81157, we can now use it to gradually format all files. I'm breaking this down into multiple smaller batches to avoid too many merge conflicts later on. This batch (as copied from the current BLACK linter config):

* `tools/**/*.py`

Upcoming batches:

* `torchgen/**/*.py`
* `torch/package/**/*.py`
* `torch/onnx/**/*.py`
* `torch/_refs/**/*.py`
* `torch/_prims/**/*.py`
* `torch/_meta_registrations.py`
* `torch/_decomp/**/*.py`
* `test/onnx/**/*.py`

Once they are all formatted, the BLACK linter will be removed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/81285 Approved by: https://github.com/suo
104 lines
3.7 KiB
Python
104 lines
3.7 KiB
Python
import contextlib
|
|
import os
|
|
import typing
|
|
import unittest
|
|
import unittest.mock
|
|
from typing import Iterator, Optional, Sequence
|
|
|
|
import tools.setup_helpers.cmake
|
|
|
|
import tools.setup_helpers.env # noqa: F401 unused but resolves circular import
|
|
|
|
|
|
# Element type shared by the two sequences passed to
# TestCMake.assert_contains_sequence (used in its Sequence[T] annotations).
T = typing.TypeVar("T")
|
|
|
|
|
|
class TestCMake(unittest.TestCase):
    """Unit tests for ``tools.setup_helpers.cmake.CMake``."""

    @unittest.mock.patch("multiprocessing.cpu_count")
    def test_build_jobs(self, mock_cpu_count: unittest.mock.MagicMock) -> None:
        """Tests that the number of build jobs comes out correctly.

        ``multiprocessing.cpu_count`` is patched to return 13 so the CPU-count
        fallback is deterministic. Each case sets the ``MAX_JOBS`` environment
        variable and patches the ``USE_NINJA`` / ``IS_WINDOWS`` attributes of
        ``tools.setup_helpers.cmake``. It then calls ``CMake.build`` with
        ``CMake.run`` mocked out and checks the argument list ``run`` received.
        """
        mock_cpu_count.return_value = 13

        cases = [
            # MAX_JOBS, USE_NINJA, IS_WINDOWS, want
            (("8", True, False), ["-j", "8"]),  # noqa: E201,E241
            ((None, True, False), None),  # noqa: E201,E241
            (("7", False, False), ["-j", "7"]),  # noqa: E201,E241
            ((None, False, False), ["-j", "13"]),  # noqa: E201,E241
            (("6", True, True), ["-j", "6"]),  # noqa: E201,E241
            ((None, True, True), None),  # noqa: E201,E241
            (("11", False, True), ["/p:CL_MPCount=11"]),  # noqa: E201,E241
            ((None, False, True), ["/p:CL_MPCount=13"]),  # noqa: E201,E241
        ]

        for (max_jobs, use_ninja, is_windows), want in cases:
            with self.subTest(
                MAX_JOBS=max_jobs, USE_NINJA=use_ninja, IS_WINDOWS=is_windows
            ):
                with contextlib.ExitStack() as stack:
                    # A MAX_JOBS of None clears the variable for this case.
                    stack.enter_context(env_var("MAX_JOBS", max_jobs))
                    stack.enter_context(
                        unittest.mock.patch.object(
                            tools.setup_helpers.cmake, "USE_NINJA", use_ninja
                        )
                    )
                    stack.enter_context(
                        unittest.mock.patch.object(
                            tools.setup_helpers.cmake, "IS_WINDOWS", is_windows
                        )
                    )

                    cmake = tools.setup_helpers.cmake.CMake()

                    # Replace ``run`` with a mock so its arguments can be
                    # captured instead of the real method being invoked.
                    with unittest.mock.patch.object(cmake, "run") as cmake_run:
                        cmake.build({})

                    cmake_run.assert_called_once()
                    (call,) = cmake_run.mock_calls
                    # ``run`` receives two positional arguments; only the
                    # first (the build argument list) is checked here.
                    build_args, _ = call.args

                # ``None`` means no "-j" flag may appear in the arguments.
                if want is None:
                    self.assertNotIn("-j", build_args)
                else:
                    self.assert_contains_sequence(build_args, want)

    @staticmethod
    def assert_contains_sequence(
        sequence: Sequence[T], subsequence: Sequence[T]
    ) -> None:
        """Raises an assertion if the subsequence is not contained in the sequence.

        ``subsequence`` must occur as a contiguous run of elements inside
        ``sequence``. Both arguments are copied into lists before comparing.
        Sequences of different concrete types (e.g. a tuple and a list) are
        therefore matched element by element. Comparing the raw slices would
        never succeed, because a list never equals a tuple.

        Raises:
            AssertionError: if no contiguous window of ``sequence`` equals
                ``subsequence``.
        """
        needle = list(subsequence)
        if not needle:
            return  # all sequences contain the empty subsequence

        haystack = list(sequence)
        width = len(needle)
        # Iterate over all windows of len(subsequence). Stop if the
        # window matches.
        for start in range(len(haystack) - width + 1):
            if haystack[start : start + width] == needle:
                return  # found it
        raise AssertionError(f"{subsequence} not found in {sequence}")
|
|
|
|
|
|
@contextlib.contextmanager
def env_var(key: str, value: Optional[str]) -> Iterator[None]:
    """Sets/clears environment variable ``key`` for the duration of a ``with`` block.

    Args:
        key: Name of the environment variable.
        value: Value to install, or ``None`` to remove the variable.

    On exit, whether normal or via an exception, the variable is put back to
    the value it had on entry. If it was unset on entry, it is removed again.
    """
    saved_value = os.environ.get(key)  # None when ``key`` is currently unset
    set_env_var(key, value)
    try:
        yield
    finally:
        # Reinstate the entry-time state (removal if it was unset).
        set_env_var(key, saved_value)
|
|
|
|
|
|
def set_env_var(key: str, value: Optional[str]) -> None:
    """Sets ``os.environ[key]`` to ``value``; removes ``key`` when ``value`` is None.

    Removing a variable that is not currently set is a no-op.
    """
    if value is not None:
        os.environ[key] = value
        return
    # pop() with a default tolerates the variable already being absent.
    os.environ.pop(key, None)
|
|
|
|
|
|
# Allow running this test module directly as a script via unittest's runner.
if __name__ == "__main__":
    unittest.main()
|