Refactor nightly pull tool to use venv and pip (#141281)

Resolves #141238

- #141238

Example output:

```console
$ python3.12 tools/nightly.py checkout -b my-nightly-branch -p my-env --python python3.10
log file: /Users/PanXuehai/Projects/pytorch/nightly/log/2024-11-22_04h15m45s_63f8b29e-a845-11ef-bbf9-32c784498a7b/nightly.log
Creating virtual environment
Creating venv (Python 3.10.15): /Users/PanXuehai/Projects/pytorch/my-env
Installing packages
Upgrading package(s) (https://download.pytorch.org/whl/nightly/cpu): pip, setuptools, wheel
Installing packages took 5.576 [s]
Creating virtual environment took 9.505 [s]
Downloading packages
Downloading package(s) (https://download.pytorch.org/whl/nightly/cpu): torch
Downloaded 9 file(s) to /var/folders/sq/7sf73d5s2qnb3w6jjsmhsw3h0000gn/T/pip-download-lty5dvz4:
  - mpmath-1.3.0-py3-none-any.whl
  - torch-2.6.0.dev20241121-cp310-none-macosx_11_0_arm64.whl
  - jinja2-3.1.4-py3-none-any.whl
  - sympy-1.13.1-py3-none-any.whl
  - MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl
  - networkx-3.4.2-py3-none-any.whl
  - fsspec-2024.10.0-py3-none-any.whl
  - filelock-3.16.1-py3-none-any.whl
  - typing_extensions-4.12.2-py3-none-any.whl
Downloading packages took 7.628 [s]
Installing dependencies
Installing packages
Installing package(s) (https://download.pytorch.org/whl/nightly/cpu): numpy, cmake, ninja, packaging, ruff, mypy, pytest, hypothesis, ipython, rich, clang-format, clang-tidy, sphinx, mpmath-1.3.0-py3-none-any.whl, jinja2-3.1.4-py3-none-any.whl, sympy-1.13.1-py3-none-any.whl, MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl, networkx-3.4.2-py3-none-any.whl, fsspec-2024.10.0-py3-none-any.whl, filelock-3.16.1-py3-none-any.whl, typing_extensions-4.12.2-py3-none-any.whl
Installing packages took 42.514 [s]
Installing dependencies took 42.515 [s]
Unpacking wheel file
Unpacking wheel file took 3.223 [s]
Checking out nightly PyTorch
Found released git version ac47a2d9714278889923ddd40e4210d242d8d4ee
Found nightly release version e0482fdf95eb3ce679fa442b50871d113ceb673b
Switched to a new branch 'my-nightly-branch'
Checking out nightly PyTorch took 0.198 [s]
Moving nightly files into repo
Linking /var/folders/sq/7sf73d5s2qnb3w6jjsmhsw3h0000gn/T/wheel-dljxil5i/torch-2.6.0.dev20241121/torch/_C.cpython-310-darwin.so -> /Users/PanXuehai/Projects/pytorch/torch/_C.cpython-310-darwin.so
Linking /var/folders/sq/7sf73d5s2qnb3w6jjsmhsw3h0000gn/T/wheel-dljxil5i/torch-2.6.0.dev20241121/torch/lib/libtorch_python.dylib -> /Users/PanXuehai/Projects/pytorch/torch/lib/libtorch_python.dylib
...
Linking /var/folders/sq/7sf73d5s2qnb3w6jjsmhsw3h0000gn/T/wheel-dljxil5i/torch-2.6.0.dev20241121/torch/include/c10/macros/Macros.h -> /Users/PanXuehai/Projects/pytorch/torch/include/c10/macros/Macros.h
Moving nightly files into repo took 11.426 [s]
Writing pytorch-nightly.pth
Writing pytorch-nightly.pth took 0.036 [s]
-------
PyTorch Development Environment set up!
Please activate to enable this environment:

  $ source /Users/PanXuehai/Projects/pytorch/my-env/bin/activate
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141281
Approved by: https://github.com/seemethere
This commit is contained in:
Xuehai Pan
2024-11-23 01:46:42 +08:00
committed by PyTorch MergeBot
parent 75cecba164
commit 2a6eaa2e6f
3 changed files with 514 additions and 311 deletions

View File

@ -15,28 +15,21 @@ import os
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
# NOTE: Also update the CUDA sources in tools/nightly.py when changing this list
CUDA_ARCHES = ["11.8", "12.4", "12.6"] CUDA_ARCHES = ["11.8", "12.4", "12.6"]
CUDA_ARCHES_FULL_VERSION = {"11.8": "11.8.0", "12.4": "12.4.1", "12.6": "12.6.2"} CUDA_ARCHES_FULL_VERSION = {"11.8": "11.8.0", "12.4": "12.4.1", "12.6": "12.6.2"}
CUDA_ARCHES_CUDNN_VERSION = {"11.8": "9", "12.4": "9", "12.6": "9"} CUDA_ARCHES_CUDNN_VERSION = {"11.8": "9", "12.4": "9", "12.6": "9"}
ROCM_ARCHES = ["6.1", "6.2"] ROCM_ARCHES = ["6.1", "6.2"]
XPU_ARCHES = ["xpu"] XPU_ARCHES = ["xpu"]
CPU_CXX11_ABI_ARCH = ["cpu-cxx11-abi"] CPU_CXX11_ABI_ARCH = ["cpu-cxx11-abi"]
CPU_AARCH64_ARCH = ["cpu-aarch64"] CPU_AARCH64_ARCH = ["cpu-aarch64"]
CPU_S390X_ARCH = ["cpu-s390x"] CPU_S390X_ARCH = ["cpu-s390x"]
CUDA_AARCH64_ARCH = ["cuda-aarch64"] CUDA_AARCH64_ARCH = ["cuda-aarch64"]

View File

@ -79,7 +79,7 @@ cd pytorch
git remote add upstream git@github.com:pytorch/pytorch.git git remote add upstream git@github.com:pytorch/pytorch.git
make setup-env # or make setup-env-cuda for pre-built CUDA binaries make setup-env # or make setup-env-cuda for pre-built CUDA binaries
conda activate pytorch-deps source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
``` ```
### Tips and Debugging ### Tips and Debugging
@ -166,7 +166,7 @@ conda activate pytorch-deps
## Nightly Checkout & Pull ## Nightly Checkout & Pull
The `tools/nightly.py` script is provided to ease pure Python development of The `tools/nightly.py` script is provided to ease pure Python development of
PyTorch. This uses `conda` and `git` to check out the nightly development PyTorch. This uses `venv` and `git` to check out the nightly development
version of PyTorch and installs pre-built binaries into the current repository. version of PyTorch and installs pre-built binaries into the current repository.
This is like a development or editable install, but without needing the ability This is like a development or editable install, but without needing the ability
to compile any C++ code. to compile any C++ code.
@ -175,33 +175,33 @@ You can use this script to check out a new nightly branch with the following:
```bash ```bash
./tools/nightly.py checkout -b my-nightly-branch ./tools/nightly.py checkout -b my-nightly-branch
conda activate pytorch-deps source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
``` ```
Or if you would like to re-use an existing conda environment, you can pass in Or if you would like to re-use an existing conda environment, you can pass in
the regular environment parameters (`--name` or `--prefix`): the prefix argument (`--prefix`):
```bash ```bash
./tools/nightly.py checkout -b my-nightly-branch -n my-env ./tools/nightly.py checkout -b my-nightly-branch -p my-env
conda activate my-env source my-env/bin/activate # or `& .\my-env\Scripts\Activate.ps1` on Windows
``` ```
To install the nightly binaries built with CUDA, you can pass in the flag `--cuda`: To install the nightly binaries built with CUDA, you can pass in the flag `--cuda`:
```bash ```bash
./tools/nightly.py checkout -b my-nightly-branch --cuda ./tools/nightly.py checkout -b my-nightly-branch --cuda
conda activate pytorch-deps source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
``` ```
You can also use this tool to pull the nightly commits into the current branch: You can also use this tool to pull the nightly commits into the current branch:
```bash ```bash
./tools/nightly.py pull -n my-env ./tools/nightly.py pull -p my-env
conda activate my-env source my-env/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
``` ```
Pulling will reinstall the PyTorch dependencies as well as the nightly binaries Pulling will recreate a fresh virtual environment and reinstall the development
into the repo directory. dependencies as well as the nightly binaries into the repo directory.
## Codebase structure ## Codebase structure

View File

@ -1,43 +1,42 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Much of the logging code here was forked from https://github.com/ezyang/ghstack # Much of the logging code here was forked from https://github.com/ezyang/ghstack
# Copyright (c) Edward Z. Yang <ezyang@mit.edu> # Copyright (c) Edward Z. Yang <ezyang@mit.edu>
"""Checks out the nightly development version of PyTorch and installs pre-built r"""Checks out the nightly development version of PyTorch and installs pre-built
binaries into the repo. binaries into the repo.
You can use this script to check out a new nightly branch with the following:: You can use this script to check out a new nightly branch with the following::
$ ./tools/nightly.py checkout -b my-nightly-branch $ ./tools/nightly.py checkout -b my-nightly-branch
$ conda activate pytorch-deps $ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
Or if you would like to re-use an existing conda environment, you can pass in Or if you would like to re-use an existing virtual environment, you can pass in
the regular environment parameters (--name or --prefix):: the prefix argument (--prefix)::
$ ./tools/nightly.py checkout -b my-nightly-branch -n my-env $ ./tools/nightly.py checkout -b my-nightly-branch -p my-env
$ conda activate my-env $ source my-env/bin/activate # or `& .\my-env\Scripts\Activate.ps1` on Windows
To install the nightly binaries built with CUDA, you can pass in the flag --cuda:: To install the nightly binaries built with CUDA, you can pass in the flag --cuda::
$ ./tools/nightly.py checkout -b my-nightly-branch --cuda $ ./tools/nightly.py checkout -b my-nightly-branch --cuda
$ conda activate pytorch-deps $ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
You can also use this tool to pull the nightly commits into the current branch as You can also use this tool to pull the nightly commits into the current branch as
well. This can be done with:: well. This can be done with::
$ ./tools/nightly.py pull -n my-env $ ./tools/nightly.py pull
$ conda activate my-env $ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
Pulling will reinstall the conda dependencies as well as the nightly binaries into Pulling will recreate a fresh virtual environment and reinstall the development
the repo directory. dependencies as well as the nightly binaries into the repo directory.
""" """
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import atexit
import contextlib import contextlib
import functools import functools
import glob
import itertools import itertools
import json
import logging import logging
import os import os
import re import re
@ -51,16 +50,46 @@ from ast import literal_eval
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from platform import system as platform_system from platform import system as platform_system
from typing import Any, Callable, cast, Generator, Iterable, Iterator, Sequence, TypeVar from typing import (
Any,
Callable,
cast,
Generator,
Iterable,
Iterator,
NamedTuple,
TypeVar,
)
try:
from packaging.version import Version
except ImportError:
Version = None # type: ignore[assignment,misc]
REPO_ROOT = Path(__file__).absolute().parent.parent REPO_ROOT = Path(__file__).absolute().parent.parent
GITHUB_REMOTE_URL = "https://github.com/pytorch/pytorch.git" GITHUB_REMOTE_URL = "https://github.com/pytorch/pytorch.git"
SPECS_TO_INSTALL = ("pytorch", "mypy", "pytest", "hypothesis", "ipython", "sphinx") PACKAGES_TO_INSTALL = (
DEFAULT_ENV_NAME = "pytorch-deps" "torch",
"numpy",
"cmake",
"ninja",
"packaging",
"ruff",
"mypy",
"pytest",
"hypothesis",
"ipython",
"rich",
"clang-format",
"clang-tidy",
"sphinx",
)
DEFAULT_VENV_DIR = REPO_ROOT / "venv"
LOGGER: logging.Logger | None = None LOGGER: logging.Logger | None = None
URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss" DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
SHA1_RE = re.compile(r"(?P<sha1>[0-9a-fA-F]{40})") SHA1_RE = re.compile(r"(?P<sha1>[0-9a-fA-F]{40})")
USERNAME_PASSWORD_RE = re.compile(r":\/\/(.*?)\@") USERNAME_PASSWORD_RE = re.compile(r":\/\/(.*?)\@")
@ -70,6 +99,49 @@ LOG_DIRNAME_RE = re.compile(
) )
PLATFORM = platform_system().replace("Darwin", "macOS")
LINUX = PLATFORM == "Linux"
MACOS = PLATFORM == "macOS"
WINDOWS = PLATFORM == "Windows"
POSIX = LINUX or MACOS
class PipSource(NamedTuple):
name: str
index_url: str
supported_platforms: set[str]
accelerator: str
PYTORCH_NIGHTLY_PIP_INDEX_URL = "https://download.pytorch.org/whl/nightly"
PIP_SOURCES = {
"cpu": PipSource(
name="cpu",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cpu",
supported_platforms={"Linux", "macOS", "Windows"},
accelerator="cpu",
),
"cuda-11.8": PipSource(
name="cuda-11.8",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu118",
supported_platforms={"Linux", "Windows"},
accelerator="cuda",
),
"cuda-12.1": PipSource(
name="cuda-12.1",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu121",
supported_platforms={"Linux", "Windows"},
accelerator="cuda",
),
"cuda-12.4": PipSource(
name="cuda-12.4",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu124",
supported_platforms={"Linux", "Windows"},
accelerator="cuda",
),
}
class Formatter(logging.Formatter): class Formatter(logging.Formatter):
redactions: dict[str, str] redactions: dict[str, str]
@ -108,6 +180,290 @@ class Formatter(logging.Formatter):
self.redactions[needle] = replace self.redactions[needle] = replace
@contextlib.contextmanager
def timer(logger: logging.Logger, prefix: str) -> Iterator[None]:
"""Timed context manager"""
start_time = time.perf_counter()
yield
logger.info("%s took %.3f [s]", prefix, time.perf_counter() - start_time)
F = TypeVar("F", bound=Callable[..., Any])
def timed(prefix: str) -> Callable[[F], F]:
"""Decorator for timing functions"""
def decorator(f: F) -> F:
@functools.wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> Any:
logger = cast(logging.Logger, LOGGER)
logger.info(prefix)
with timer(logger, prefix):
return f(*args, **kwargs)
return cast(F, wrapper)
return decorator
class Venv:
"""Virtual environment manager"""
AGGRESSIVE_UPDATE_PACKAGES = ("pip", "setuptools", "packaging", "wheel")
def __init__(
self,
prefix: Path | str,
pip_source: PipSource,
*,
base_executable: Path | str | None = None,
) -> None:
self.prefix = Path(prefix).absolute()
self.pip_source = pip_source
self.base_executable = Path(base_executable or sys.executable).absolute()
self._executable: Path | None = None
self._env = {"PIP_EXTRA_INDEX_URL": self.pip_source.index_url}
def is_venv(self) -> bool:
"""Check if the prefix is a virtual environment."""
return self.prefix.is_dir() and (self.prefix / "pyvenv.cfg").is_file()
@property
def executable(self) -> Path:
"""Get the Python executable for the virtual environment."""
assert self.is_venv()
if self._executable is None:
if WINDOWS:
executable = self.prefix / "Scripts" / "python.exe"
else:
executable = self.prefix / "bin" / "python"
assert executable.is_file() or executable.is_symlink()
assert os.access(executable, os.X_OK), f"{executable} is not executable"
self._executable = executable
return self._executable
def site_packages(self, python: Path | str | None = None) -> Path:
"""Get the site-packages directory for the virtual environment."""
output = self.python(
"-c",
"import site; [print(p) for p in site.getsitepackages()]",
python=python,
capture_output=True,
).stdout
candidates = list(map(Path, filter(None, map(str.strip, output.splitlines()))))
candidates = [p for p in candidates if p.is_dir() and p.name == "site-packages"]
if not candidates:
raise RuntimeError(
f"No site-packages directory found for excecutable {python}"
)
return candidates[0]
@property
def activate_command(self) -> str:
"""Get the command to activate the virtual environment."""
if WINDOWS:
# Assume PowerShell
return f"& {self.prefix / 'Scripts' / 'Activate.ps1'}"
return f"source {self.prefix}/bin/activate"
@timed("Creating virtual environment")
def create(self, *, remove_if_exists: bool = False) -> Path:
"""Create a virtual environment."""
if self.prefix.exists():
if remove_if_exists:
# If the venv directory already exists, remove it first
if not self.is_venv():
raise RuntimeError(
f"The path {self.prefix} already exists and is not a virtual environment. "
"Please remove it manually or choose a different prefix."
)
if self.prefix in [
Path(p).absolute()
for p in [
sys.prefix,
sys.exec_prefix,
sys.base_prefix,
sys.base_exec_prefix,
]
]:
raise RuntimeError(
f"The path {self.prefix} trying to remove is the same as the interpreter "
"to run this script. Please choose a different prefix or deactivate the "
"current virtual environment."
)
if self.prefix in [
Path(
self.base_python(
"-c",
f"import os, sys; print(os.path.abspath({p}))",
capture_output=True,
).stdout.strip()
).absolute()
for p in [
"sys.prefix",
"sys.exec_prefix",
"sys.base_prefix",
"sys.base_exec_prefix",
]
]:
raise RuntimeError(
f"The Python executable {self.base_executable} trying to remove is the "
"same as the interpreter to create the virtual environment. Please choose "
"a different prefix or a different Python interpreter."
)
print(f"Removing existing venv: {self.prefix}")
_remove_existing(self.prefix)
else:
raise RuntimeError(f"Path {self.prefix} already exists.")
print(f"Creating venv (Python {self.base_python_version()}): {self.prefix}")
self.base_python("-m", "venv", str(self.prefix))
assert self.is_venv(), "Failed to create virtual environment."
(self.prefix / ".gitignore").write_text("*\n", encoding="utf-8")
return self.ensure()
def ensure(self) -> Path:
"""Ensure the virtual environment exists."""
if not self.is_venv():
return self.create(remove_if_exists=True)
self.pip_install(*self.AGGRESSIVE_UPDATE_PACKAGES, upgrade=True)
return self.prefix
def python(
self,
*args: str,
python: Path | str | None = None,
**popen_kwargs: Any,
) -> subprocess.CompletedProcess[str]:
"""Run a Python command in the virtual environment."""
if python is None:
python = self.executable
cmd = [str(python), *args]
env = popen_kwargs.pop("env", None) or {}
return subprocess.run(
cmd,
check=True,
text=True,
encoding="utf-8",
env={**self._env, **env},
**popen_kwargs,
)
def base_python(
self,
*args: str,
**popen_kwargs: Any,
) -> subprocess.CompletedProcess[str]:
"""Run a Python command in the base environment."""
return self.python(*args, python=self.base_executable, **popen_kwargs)
def python_version(self, *, python: Path | str | None = None) -> str:
"""Get the Python version for the virtual environment."""
return self.python(
"-c",
(
"import sys; print('{0.major}.{0.minor}.{0.micro}{1}'."
"format(sys.version_info, getattr(sys, 'abiflags', '')))"
),
python=python,
capture_output=True,
).stdout.strip()
def base_python_version(self) -> str:
"""Get the Python version for the base environment."""
return self.python_version(python=self.base_executable)
def pip(self, *args: str, **popen_kwargs: Any) -> subprocess.CompletedProcess[str]:
"""Run a pip command in the virtual environment."""
return self.python("-m", "pip", *args, **popen_kwargs)
@timed("Installing packages")
def pip_install(
self,
*packages: str,
prerelease: bool = False,
upgrade: bool = False,
**popen_kwargs: Any,
) -> subprocess.CompletedProcess[str]:
"""Run a pip install command in the virtual environment."""
if upgrade:
args = ["--upgrade", *packages]
verb = "Upgrading"
else:
args = list(packages)
verb = "Installing"
if prerelease:
args = ["--pre", *args]
print(
f"{verb} package(s) ({self.pip_source.index_url}): "
f"{', '.join(map(os.path.basename, packages))}"
)
return self.pip("install", *args, **popen_kwargs)
@timed("Downloading packages")
def pip_download(
self,
*packages: str,
prerelease: bool = False,
**popen_kwargs: Any,
) -> list[Path]:
"""Download a package in the virtual environment."""
tmpdir = tempfile.TemporaryDirectory(prefix="pip-download-")
atexit.register(tmpdir.cleanup)
tempdir = Path(tmpdir.name).absolute()
print(
f"Downloading package(s) ({self.pip_source.index_url}): "
f"{', '.join(packages)}"
)
if prerelease:
args = ["--pre", *packages]
else:
args = list(packages)
self.pip("download", "--dest", str(tempdir), *args, **popen_kwargs)
files = list(tempdir.iterdir())
print(f"Downloaded {len(files)} file(s) to {tempdir}:")
for file in files:
print(f" - {file.name}")
return files
def wheel(
self,
*args: str,
**popen_kwargs: Any,
) -> subprocess.CompletedProcess[str]:
"""Run a wheel command in the virtual environment."""
return self.python("-m", "wheel", *args, **popen_kwargs)
@timed("Unpacking wheel file")
def wheel_unpack(
self,
wheel: Path | str,
dest: Path | str,
**popen_kwargs: Any,
) -> subprocess.CompletedProcess[str]:
"""Unpack a wheel into a directory."""
wheel = Path(wheel).absolute()
dest = Path(dest).absolute()
assert wheel.is_file() and wheel.suffix.lower() == ".whl"
return self.wheel("unpack", "--dest", str(dest), str(wheel), **popen_kwargs)
@contextlib.contextmanager
def extracted_wheel(self, wheel: Path | str) -> Generator[Path]:
"""Download and extract a wheel into a temporary directory."""
with tempfile.TemporaryDirectory(prefix="wheel-") as tempdir:
self.wheel_unpack(wheel, tempdir)
subdirs = [p for p in Path(tempdir).absolute().iterdir() if p.is_dir()]
if len(subdirs) != 1:
raise RuntimeError(
f"Expected exactly one directory in {tempdir}, "
f"got {[str(d) for d in subdirs]}."
)
yield subdirs[0]
def git(*args: str) -> list[str]: def git(*args: str) -> list[str]:
return ["git", "-C", str(REPO_ROOT), *args] return ["git", "-C", str(REPO_ROOT), *args]
@ -134,7 +490,10 @@ def logging_record_argv() -> None:
def logging_record_exception(e: BaseException) -> None: def logging_record_exception(e: BaseException) -> None:
(logging_run_dir() / "exception").write_text(type(e).__name__, encoding="utf-8") text = f"{type(e).__name__}: {e}"
if isinstance(e, subprocess.CalledProcessError):
text += f"\n\nstdout: {e.stdout}\n\nstderr: {e.stderr}"
(logging_run_dir() / "exception").write_text(text, encoding="utf-8")
def logging_rotate() -> None: def logging_rotate() -> None:
@ -156,7 +515,7 @@ def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, N
stderr (INFO) and file handler (DEBUG). stderr (INFO) and file handler (DEBUG).
""" """
formatter = Formatter(fmt="%(levelname)s: %(message)s", datefmt="") formatter = Formatter(fmt="%(levelname)s: %(message)s", datefmt="")
root_logger = logging.getLogger("conda-pytorch") root_logger = logging.getLogger("pytorch-nightly")
root_logger.setLevel(logging.DEBUG) root_logger.setLevel(logging.DEBUG)
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()
@ -213,161 +572,13 @@ def check_branch(subcommand: str, branch: str | None) -> str | None:
return None return None
def check_conda_env_exists(name: str | None = None, prefix: str | None = None) -> bool:
"""Checks that the conda environment exists."""
if name is not None and prefix is not None:
raise ValueError("Cannot specify both --name and --prefix")
if name is None and prefix is None:
raise ValueError("Must specify either --name or --prefix")
try:
cmd = ["conda", "info", "--envs"]
output = subprocess.check_output(cmd, text=True, encoding="utf-8")
except subprocess.CalledProcessError:
logger = cast(logging.Logger, LOGGER)
logger.warning("Failed to list conda environments", exc_info=True)
return False
if name is not None:
return len(re.findall(rf"^{name}\s+", output, flags=re.MULTILINE)) > 0
assert prefix is not None
prefix = Path(prefix).absolute()
return len(re.findall(rf"\s+{prefix}$", output, flags=re.MULTILINE)) > 0
@contextlib.contextmanager
def timer(logger: logging.Logger, prefix: str) -> Iterator[None]:
"""Timed context manager"""
start_time = time.perf_counter()
yield
logger.info("%s took %.3f [s]", prefix, time.perf_counter() - start_time)
F = TypeVar("F", bound=Callable[..., Any])
def timed(prefix: str) -> Callable[[F], F]:
"""Decorator for timing functions"""
def dec(f: F) -> F:
@functools.wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> Any:
logger = cast(logging.Logger, LOGGER)
logger.info(prefix)
with timer(logger, prefix):
return f(*args, **kwargs)
return cast(F, wrapper)
return dec
def _make_channel_args(
channels: Iterable[str] = ("pytorch-nightly",),
override_channels: bool = False,
) -> list[str]:
args = []
for channel in channels:
args.extend(["--channel", channel])
if override_channels:
args.append("--override-channels")
return args
@timed("Solving conda environment")
def conda_solve(
specs: Iterable[str],
*,
name: str | None = None,
prefix: str | None = None,
channels: Iterable[str] = ("pytorch-nightly",),
override_channels: bool = False,
) -> tuple[list[str], str, str, bool, list[str]]:
"""Performs the conda solve and splits the deps from the package."""
# compute what environment to use
if prefix is not None:
existing_env = True
env_opts = ["--prefix", prefix]
elif name is not None:
existing_env = True
env_opts = ["--name", name]
else:
# create new environment
existing_env = False
env_opts = ["--name", DEFAULT_ENV_NAME]
# run solve
if existing_env:
cmd = [
"conda",
"install",
"--yes",
"--dry-run",
"--json",
*env_opts,
]
else:
cmd = [
"conda",
"create",
"--yes",
"--dry-run",
"--json",
"--name",
"__pytorch__",
]
channel_args = _make_channel_args(
channels=channels,
override_channels=override_channels,
)
cmd.extend(channel_args)
cmd.extend(specs)
stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
# parse solution
solve = json.loads(stdout)
link = solve["actions"]["LINK"]
deps = []
pytorch, platform = "", ""
for pkg in link:
url = URL_FORMAT.format(**pkg)
if pkg["name"] == "pytorch":
pytorch = url
platform = pkg["platform"]
else:
deps.append(url)
assert pytorch, "PyTorch package not found in solve"
assert platform, "Platform not found in solve"
return deps, pytorch, platform, existing_env, env_opts
@timed("Installing dependencies") @timed("Installing dependencies")
def deps_install(deps: list[str], existing_env: bool, env_opts: list[str]) -> None: def install_packages(venv: Venv, packages: Iterable[str]) -> None:
"""Install dependencies to deps environment""" """Install dependencies to deps environment"""
if not existing_env: # install packages
# first remove previous pytorch-deps env packages = list(dict.fromkeys(packages))
if check_conda_env_exists(name=DEFAULT_ENV_NAME): if packages:
cmd = ["conda", "env", "remove", "--yes", *env_opts] venv.pip_install(*packages)
subprocess.check_output(cmd)
# install new deps
install_command = "install" if existing_env else "create"
cmd = ["conda", install_command, "--yes", "--no-deps", *env_opts, *deps]
subprocess.check_call(cmd)
@timed("Installing pytorch nightly binaries")
def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]:
"""Install pytorch into a temporary directory"""
pytorch_dir = tempfile.TemporaryDirectory(prefix="conda-pytorch-")
cmd = ["conda", "create", "--yes", "--no-deps", f"--prefix={pytorch_dir.name}", url]
subprocess.check_call(cmd)
return pytorch_dir
def _site_packages(dirname: str, platform: str) -> Path:
if platform.startswith("win"):
template = os.path.join(dirname, "Lib", "site-packages")
else:
template = os.path.join(dirname, "lib", "python*.*", "site-packages")
return Path(next(glob.iglob(template))).absolute()
def _ensure_commit(git_sha1: str) -> None: def _ensure_commit(git_sha1: str) -> None:
@ -437,7 +648,7 @@ def _get_listing_linux(source_dir: Path) -> list[Path]:
) )
def _get_listing_osx(source_dir: Path) -> list[Path]: def _get_listing_macos(source_dir: Path) -> list[Path]:
# oddly, these are .so files even on Mac # oddly, these are .so files even on Mac
return list( return list(
itertools.chain( itertools.chain(
@ -447,7 +658,7 @@ def _get_listing_osx(source_dir: Path) -> list[Path]:
) )
def _get_listing_win(source_dir: Path) -> list[Path]: def _get_listing_windows(source_dir: Path) -> list[Path]:
return list( return list(
itertools.chain( itertools.chain(
source_dir.glob("*.pyd"), source_dir.glob("*.pyd"),
@ -468,15 +679,15 @@ def _find_missing_pyi(source_dir: Path, target_dir: Path) -> list[Path]:
return missing_pyis return missing_pyis
def _get_listing(source_dir: Path, target_dir: Path, platform: str) -> list[Path]: def _get_listing(source_dir: Path, target_dir: Path) -> list[Path]:
if platform.startswith("linux"): if LINUX:
listing = _get_listing_linux(source_dir) listing = _get_listing_linux(source_dir)
elif platform.startswith("osx"): elif MACOS:
listing = _get_listing_osx(source_dir) listing = _get_listing_macos(source_dir)
elif platform.startswith("win"): elif WINDOWS:
listing = _get_listing_win(source_dir) listing = _get_listing_windows(source_dir)
else: else:
raise RuntimeError(f"Platform {platform!r} not recognized") raise RuntimeError(f"Platform {platform_system()!r} not recognized")
listing.extend(_find_missing_pyi(source_dir, target_dir)) listing.extend(_find_missing_pyi(source_dir, target_dir))
listing.append(source_dir / "version.py") listing.append(source_dir / "version.py")
listing.append(source_dir / "testing" / "_internal" / "generated") listing.append(source_dir / "testing" / "_internal" / "generated")
@ -532,14 +743,14 @@ def _link_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None
@timed("Moving nightly files into repo") @timed("Moving nightly files into repo")
def move_nightly_files(site_dir: Path, platform: str) -> None: def move_nightly_files(site_dir: Path) -> None:
"""Moves PyTorch files from temporary installed location to repo.""" """Moves PyTorch files from temporary installed location to repo."""
# get file listing # get file listing
source_dir = site_dir / "torch" source_dir = site_dir / "torch"
target_dir = REPO_ROOT / "torch" target_dir = REPO_ROOT / "torch"
listing = _get_listing(source_dir, target_dir, platform) listing = _get_listing(source_dir, target_dir)
# copy / link files # copy / link files
if platform.startswith("win"): if WINDOWS:
_copy_files(listing, source_dir, target_dir) _copy_files(listing, source_dir, target_dir)
else: else:
try: try:
@ -548,31 +759,10 @@ def move_nightly_files(site_dir: Path, platform: str) -> None:
_copy_files(listing, source_dir, target_dir) _copy_files(listing, source_dir, target_dir)
def _available_envs() -> dict[str, str]:
cmd = ["conda", "env", "list"]
stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
envs = {}
for line in map(str.strip, stdout.splitlines()):
if not line or line.startswith("#"):
continue
parts = line.split()
if len(parts) == 1:
# unnamed env
continue
envs[parts[0]] = parts[-1]
return envs
@timed("Writing pytorch-nightly.pth") @timed("Writing pytorch-nightly.pth")
def write_pth(env_opts: list[str], platform: str) -> None: def write_pth(venv: Venv) -> None:
"""Writes Python path file for this dir.""" """Writes Python path file for this dir."""
env_type, env_dir = env_opts (venv.site_packages() / "pytorch-nightly.pth").write_text(
if env_type == "--name":
# have to find directory
envs = _available_envs()
env_dir = envs[env_dir]
site_dir = _site_packages(env_dir, platform)
(site_dir / "pytorch-nightly.pth").write_text(
"# This file was autogenerated by PyTorch's tools/nightly.py\n" "# This file was autogenerated by PyTorch's tools/nightly.py\n"
"# Please delete this file if you no longer need the following development\n" "# Please delete this file if you no longer need the following development\n"
"# version of PyTorch to be importable\n" "# version of PyTorch to be importable\n"
@ -582,50 +772,66 @@ def write_pth(env_opts: list[str], platform: str) -> None:
def install( def install(
specs: Iterable[str],
*, *,
logger: logging.Logger, venv: Venv,
packages: Iterable[str],
subcommand: str = "checkout", subcommand: str = "checkout",
branch: str | None = None, branch: str | None = None,
name: str | None = None, logger: logging.Logger,
prefix: str | None = None,
channels: Iterable[str] = ("pytorch-nightly",),
override_channels: bool = False,
) -> None: ) -> None:
"""Development install of PyTorch""" """Development install of PyTorch"""
specs = list(specs) use_existing = subcommand == "checkout"
deps, pytorch, platform, existing_env, env_opts = conda_solve( if use_existing:
specs=specs, venv.ensure()
name=name, else:
prefix=prefix, venv.create(remove_if_exists=True)
channels=channels,
override_channels=override_channels,
)
if deps:
deps_install(deps, existing_env, env_opts)
with pytorch_install(pytorch) as pytorch_dir: packages = [p for p in packages if p != "torch"]
site_dir = _site_packages(pytorch_dir, platform)
dependencies = venv.pip_download("torch", prerelease=True)
torch_wheel = [
dep
for dep in dependencies
if dep.name.startswith("torch-") and dep.name.endswith(".whl")
]
if len(torch_wheel) != 1:
raise RuntimeError(f"Expected exactly one torch wheel, got {torch_wheel}")
torch_wheel = torch_wheel[0]
dependencies = [deps for deps in dependencies if deps != torch_wheel]
install_packages(venv, [*packages, *map(str, dependencies)])
with venv.extracted_wheel(torch_wheel) as wheel_site_dir:
if subcommand == "checkout": if subcommand == "checkout":
checkout_nightly_version(cast(str, branch), site_dir) checkout_nightly_version(cast(str, branch), wheel_site_dir)
elif subcommand == "pull": elif subcommand == "pull":
pull_nightly_version(site_dir) pull_nightly_version(wheel_site_dir)
else: else:
raise ValueError(f"Subcommand {subcommand} must be one of: checkout, pull.") raise ValueError(f"Subcommand {subcommand} must be one of: checkout, pull.")
move_nightly_files(site_dir, platform) move_nightly_files(wheel_site_dir)
write_pth(env_opts, platform) write_pth(venv)
logger.info( logger.info(
"-------\nPyTorch Development Environment set up!\nPlease activate to " "-------\n"
"enable this environment:\n $ conda activate %s", "PyTorch Development Environment set up!\n"
env_opts[1], "Please activate to enable this environment:\n\n"
" $ %s",
venv.activate_command,
) )
def make_parser() -> argparse.ArgumentParser: def make_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser() def find_executable(name: str) -> Path:
executable = shutil.which(name)
if executable is None:
raise argparse.ArgumentTypeError(
f"Could not find executable {name} in PATH."
)
return Path(executable).absolute()
parser = argparse.ArgumentParser()
# subcommands # subcommands
subcmd = p.add_subparsers(dest="subcmd", help="subcommand to execute") subcmd = parser.add_subparsers(dest="subcmd", help="subcommand to execute")
checkout = subcmd.add_parser("checkout", help="checkout a new branch") checkout = subcmd.add_parser("checkout", help="checkout a new branch")
checkout.add_argument( checkout.add_argument(
"-b", "-b",
@ -642,19 +848,24 @@ def make_parser() -> argparse.ArgumentParser:
subparsers = [checkout, pull] subparsers = [checkout, pull]
for subparser in subparsers: for subparser in subparsers:
subparser.add_argument( subparser.add_argument(
"-n", "--python",
"--name", "--base-executable",
help="Name of environment", type=find_executable,
dest="name", help=(
"Path to Python interpreter to use for creating the virtual environment. "
"Defaults to the interpreter running this script."
),
dest="base_executable",
default=None, default=None,
metavar="ENVIRONMENT", metavar="PYTHON",
) )
subparser.add_argument( subparser.add_argument(
"-p", "-p",
"--prefix", "--prefix",
help="Full path to environment location (i.e. prefix)", type=lambda p: Path(p).absolute(),
help='Path to virtual environment directory (e.g. "./venv")',
dest="prefix", dest="prefix",
default=None, default=str(DEFAULT_VENV_DIR),
metavar="PATH", metavar="PATH",
) )
subparser.add_argument( subparser.add_argument(
@ -665,25 +876,6 @@ def make_parser() -> argparse.ArgumentParser:
default=False, default=False,
action="store_true", action="store_true",
) )
subparser.add_argument(
"--override-channels",
help="Do not search default or .condarc channels.",
dest="override_channels",
default=False,
action="store_true",
)
subparser.add_argument(
"-c",
"--channel",
help=(
"Additional channel to search for packages. "
"'pytorch-nightly' will always be prepended to this list."
),
dest="channels",
action="append",
metavar="CHANNEL",
)
if platform_system() in {"Linux", "Windows"}:
subparser.add_argument( subparser.add_argument(
"--cuda", "--cuda",
help=( help=(
@ -695,42 +887,60 @@ def make_parser() -> argparse.ArgumentParser:
default=argparse.SUPPRESS, default=argparse.SUPPRESS,
metavar="VERSION", metavar="VERSION",
) )
return p return parser
def main(args: Sequence[str] | None = None) -> None: def parse_arguments() -> argparse.Namespace:
parser = make_parser()
args = parser.parse_args()
args.branch = getattr(args, "branch", None)
return args
def main() -> None:
"""Main entry point""" """Main entry point"""
global LOGGER global LOGGER
p = make_parser() args = parse_arguments()
ns = p.parse_args(args) status = check_branch(args.subcmd, args.branch)
ns.branch = getattr(ns, "branch", None)
status = check_branch(ns.subcmd, ns.branch)
if status: if status:
sys.exit(status) sys.exit(status)
specs = list(SPECS_TO_INSTALL)
channels = ["pytorch-nightly"] pip_source = None
if hasattr(ns, "cuda"): if hasattr(args, "cuda"):
if ns.cuda is not None: available_sources = {
specs.append(f"pytorch-cuda={ns.cuda}") src.name[len("cuda-") :]: src
for src in PIP_SOURCES.values()
if src.name.startswith("cuda-") and PLATFORM in src.supported_platforms
}
if not available_sources:
print(f"No CUDA versions available on platform {PLATFORM}.")
sys.exit(1)
if args.cuda is not None:
pip_source = available_sources.get(args.cuda)
if pip_source is None:
print(
f"CUDA {args.cuda} is not available on platform {PLATFORM}. "
f"Available version(s): {', '.join(sorted(available_sources, key=Version))}"
)
sys.exit(1)
else: else:
specs.append("pytorch-cuda") pip_source = available_sources[max(available_sources, key=Version)]
specs.append("pytorch-mutex=*=*cuda*")
channels.append("nvidia")
else: else:
specs.append("pytorch-mutex=*=*cpu*") pip_source = PIP_SOURCES["cpu"] # always available
if ns.channels:
channels.extend(ns.channels) with logging_manager(debug=args.verbose) as logger:
with logging_manager(debug=ns.verbose) as logger:
LOGGER = logger LOGGER = logger
venv = Venv(
prefix=args.prefix,
pip_source=pip_source,
base_executable=args.base_executable,
)
install( install(
specs=specs, venv=venv,
subcommand=ns.subcmd, packages=PACKAGES_TO_INSTALL,
branch=ns.branch, subcommand=args.subcmd,
name=ns.name, branch=args.branch,
prefix=ns.prefix,
logger=logger, logger=logger,
channels=channels,
override_channels=ns.override_channels,
) )