Refactor nightly pull tool to use venv and pip (#141281)

Resolves #141238

- #141238

Example output:

```console
$ python3.12 tools/nightly.py checkout -b my-nightly-branch -p my-env --python python3.10
log file: /Users/PanXuehai/Projects/pytorch/nightly/log/2024-11-22_04h15m45s_63f8b29e-a845-11ef-bbf9-32c784498a7b/nightly.log
Creating virtual environment
Creating venv (Python 3.10.15): /Users/PanXuehai/Projects/pytorch/my-env
Installing packages
Upgrading package(s) (https://download.pytorch.org/whl/nightly/cpu): pip, setuptools, wheel
Installing packages took 5.576 [s]
Creating virtual environment took 9.505 [s]
Downloading packages
Downloading package(s) (https://download.pytorch.org/whl/nightly/cpu): torch
Downloaded 9 file(s) to /var/folders/sq/7sf73d5s2qnb3w6jjsmhsw3h0000gn/T/pip-download-lty5dvz4:
  - mpmath-1.3.0-py3-none-any.whl
  - torch-2.6.0.dev20241121-cp310-none-macosx_11_0_arm64.whl
  - jinja2-3.1.4-py3-none-any.whl
  - sympy-1.13.1-py3-none-any.whl
  - MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl
  - networkx-3.4.2-py3-none-any.whl
  - fsspec-2024.10.0-py3-none-any.whl
  - filelock-3.16.1-py3-none-any.whl
  - typing_extensions-4.12.2-py3-none-any.whl
Downloading packages took 7.628 [s]
Installing dependencies
Installing packages
Installing package(s) (https://download.pytorch.org/whl/nightly/cpu): numpy, cmake, ninja, packaging, ruff, mypy, pytest, hypothesis, ipython, rich, clang-format, clang-tidy, sphinx, mpmath-1.3.0-py3-none-any.whl, jinja2-3.1.4-py3-none-any.whl, sympy-1.13.1-py3-none-any.whl, MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl, networkx-3.4.2-py3-none-any.whl, fsspec-2024.10.0-py3-none-any.whl, filelock-3.16.1-py3-none-any.whl, typing_extensions-4.12.2-py3-none-any.whl
Installing packages took 42.514 [s]
Installing dependencies took 42.515 [s]
Unpacking wheel file
Unpacking wheel file took 3.223 [s]
Checking out nightly PyTorch
Found released git version ac47a2d9714278889923ddd40e4210d242d8d4ee
Found nightly release version e0482fdf95eb3ce679fa442b50871d113ceb673b
Switched to a new branch 'my-nightly-branch'
Checking out nightly PyTorch took 0.198 [s]
Moving nightly files into repo
Linking /var/folders/sq/7sf73d5s2qnb3w6jjsmhsw3h0000gn/T/wheel-dljxil5i/torch-2.6.0.dev20241121/torch/_C.cpython-310-darwin.so -> /Users/PanXuehai/Projects/pytorch/torch/_C.cpython-310-darwin.so
Linking /var/folders/sq/7sf73d5s2qnb3w6jjsmhsw3h0000gn/T/wheel-dljxil5i/torch-2.6.0.dev20241121/torch/lib/libtorch_python.dylib -> /Users/PanXuehai/Projects/pytorch/torch/lib/libtorch_python.dylib
...
Linking /var/folders/sq/7sf73d5s2qnb3w6jjsmhsw3h0000gn/T/wheel-dljxil5i/torch-2.6.0.dev20241121/torch/include/c10/macros/Macros.h -> /Users/PanXuehai/Projects/pytorch/torch/include/c10/macros/Macros.h
Moving nightly files into repo took 11.426 [s]
Writing pytorch-nightly.pth
Writing pytorch-nightly.pth took 0.036 [s]
-------
PyTorch Development Environment set up!
Please activate to enable this environment:

  $ source /Users/PanXuehai/Projects/pytorch/my-env/bin/activate
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141281
Approved by: https://github.com/seemethere
This commit is contained in:
Xuehai Pan
2024-11-23 01:46:42 +08:00
committed by PyTorch MergeBot
parent 75cecba164
commit 2a6eaa2e6f
3 changed files with 514 additions and 311 deletions

View File

@ -15,28 +15,21 @@ import os
from typing import Dict, List, Optional, Tuple
# NOTE: Also update the CUDA sources in tools/nightly.py when changing this list
CUDA_ARCHES = ["11.8", "12.4", "12.6"]
CUDA_ARCHES_FULL_VERSION = {"11.8": "11.8.0", "12.4": "12.4.1", "12.6": "12.6.2"}
CUDA_ARCHES_CUDNN_VERSION = {"11.8": "9", "12.4": "9", "12.6": "9"}
ROCM_ARCHES = ["6.1", "6.2"]
XPU_ARCHES = ["xpu"]
CPU_CXX11_ABI_ARCH = ["cpu-cxx11-abi"]
CPU_AARCH64_ARCH = ["cpu-aarch64"]
CPU_S390X_ARCH = ["cpu-s390x"]
CUDA_AARCH64_ARCH = ["cuda-aarch64"]

View File

@ -79,7 +79,7 @@ cd pytorch
git remote add upstream git@github.com:pytorch/pytorch.git
make setup-env # or make setup-env-cuda for pre-built CUDA binaries
conda activate pytorch-deps
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
```
### Tips and Debugging
@ -166,7 +166,7 @@ conda activate pytorch-deps
## Nightly Checkout & Pull
The `tools/nightly.py` script is provided to ease pure Python development of
PyTorch. This uses `conda` and `git` to check out the nightly development
PyTorch. This uses `venv` and `git` to check out the nightly development
version of PyTorch and installs pre-built binaries into the current repository.
This is like a development or editable install, but without needing the ability
to compile any C++ code.
@ -175,33 +175,33 @@ You can use this script to check out a new nightly branch with the following:
```bash
./tools/nightly.py checkout -b my-nightly-branch
conda activate pytorch-deps
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
```
Or if you would like to re-use an existing conda environment, you can pass in
the regular environment parameters (`--name` or `--prefix`):
the prefix argument (`--prefix`):
```bash
./tools/nightly.py checkout -b my-nightly-branch -n my-env
conda activate my-env
./tools/nightly.py checkout -b my-nightly-branch -p my-env
source my-env/bin/activate # or `& .\my-env\Scripts\Activate.ps1` on Windows
```
To install the nightly binaries built with CUDA, you can pass in the flag `--cuda`:
```bash
./tools/nightly.py checkout -b my-nightly-branch --cuda
conda activate pytorch-deps
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
```
You can also use this tool to pull the nightly commits into the current branch:
```bash
./tools/nightly.py pull -n my-env
conda activate my-env
./tools/nightly.py pull -p my-env
source my-env/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
```
Pulling will reinstall the PyTorch dependencies as well as the nightly binaries
into the repo directory.
Pulling will recreate a fresh virtual environment and reinstall the development
dependencies as well as the nightly binaries into the repo directory.
## Codebase structure

View File

@ -1,43 +1,42 @@
#!/usr/bin/env python3
# Much of the logging code here was forked from https://github.com/ezyang/ghstack
# Copyright (c) Edward Z. Yang <ezyang@mit.edu>
"""Checks out the nightly development version of PyTorch and installs pre-built
r"""Checks out the nightly development version of PyTorch and installs pre-built
binaries into the repo.
You can use this script to check out a new nightly branch with the following::
$ ./tools/nightly.py checkout -b my-nightly-branch
$ conda activate pytorch-deps
$ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
Or if you would like to re-use an existing conda environment, you can pass in
the regular environment parameters (--name or --prefix)::
Or if you would like to re-use an existing virtual environment, you can pass in
the prefix argument (--prefix)::
$ ./tools/nightly.py checkout -b my-nightly-branch -n my-env
$ conda activate my-env
$ ./tools/nightly.py checkout -b my-nightly-branch -p my-env
$ source my-env/bin/activate # or `& .\my-env\Scripts\Activate.ps1` on Windows
To install the nightly binaries built with CUDA, you can pass in the flag --cuda::
$ ./tools/nightly.py checkout -b my-nightly-branch --cuda
$ conda activate pytorch-deps
$ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
You can also use this tool to pull the nightly commits into the current branch as
well. This can be done with::
$ ./tools/nightly.py pull -n my-env
$ conda activate my-env
$ ./tools/nightly.py pull
$ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
Pulling will reinstall the conda dependencies as well as the nightly binaries into
the repo directory.
Pulling will recreate a fresh virtual environment and reinstall the development
dependencies as well as the nightly binaries into the repo directory.
"""
from __future__ import annotations
import argparse
import atexit
import contextlib
import functools
import glob
import itertools
import json
import logging
import os
import re
@ -51,16 +50,46 @@ from ast import literal_eval
from datetime import datetime
from pathlib import Path
from platform import system as platform_system
from typing import Any, Callable, cast, Generator, Iterable, Iterator, Sequence, TypeVar
from typing import (
Any,
Callable,
cast,
Generator,
Iterable,
Iterator,
NamedTuple,
TypeVar,
)
try:
from packaging.version import Version
except ImportError:
Version = None # type: ignore[assignment,misc]
REPO_ROOT = Path(__file__).absolute().parent.parent
GITHUB_REMOTE_URL = "https://github.com/pytorch/pytorch.git"
SPECS_TO_INSTALL = ("pytorch", "mypy", "pytest", "hypothesis", "ipython", "sphinx")
DEFAULT_ENV_NAME = "pytorch-deps"
PACKAGES_TO_INSTALL = (
"torch",
"numpy",
"cmake",
"ninja",
"packaging",
"ruff",
"mypy",
"pytest",
"hypothesis",
"ipython",
"rich",
"clang-format",
"clang-tidy",
"sphinx",
)
DEFAULT_VENV_DIR = REPO_ROOT / "venv"
LOGGER: logging.Logger | None = None
URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
SHA1_RE = re.compile(r"(?P<sha1>[0-9a-fA-F]{40})")
USERNAME_PASSWORD_RE = re.compile(r":\/\/(.*?)\@")
@ -70,6 +99,49 @@ LOG_DIRNAME_RE = re.compile(
)
PLATFORM = platform_system().replace("Darwin", "macOS")
LINUX = PLATFORM == "Linux"
MACOS = PLATFORM == "macOS"
WINDOWS = PLATFORM == "Windows"
POSIX = LINUX or MACOS
class PipSource(NamedTuple):
name: str
index_url: str
supported_platforms: set[str]
accelerator: str
PYTORCH_NIGHTLY_PIP_INDEX_URL = "https://download.pytorch.org/whl/nightly"
PIP_SOURCES = {
"cpu": PipSource(
name="cpu",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cpu",
supported_platforms={"Linux", "macOS", "Windows"},
accelerator="cpu",
),
"cuda-11.8": PipSource(
name="cuda-11.8",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu118",
supported_platforms={"Linux", "Windows"},
accelerator="cuda",
),
"cuda-12.1": PipSource(
name="cuda-12.1",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu121",
supported_platforms={"Linux", "Windows"},
accelerator="cuda",
),
"cuda-12.4": PipSource(
name="cuda-12.4",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu124",
supported_platforms={"Linux", "Windows"},
accelerator="cuda",
),
}
class Formatter(logging.Formatter):
redactions: dict[str, str]
@ -108,6 +180,290 @@ class Formatter(logging.Formatter):
self.redactions[needle] = replace
@contextlib.contextmanager
def timer(logger: logging.Logger, prefix: str) -> Iterator[None]:
"""Timed context manager"""
start_time = time.perf_counter()
yield
logger.info("%s took %.3f [s]", prefix, time.perf_counter() - start_time)
F = TypeVar("F", bound=Callable[..., Any])
def timed(prefix: str) -> Callable[[F], F]:
"""Decorator for timing functions"""
def decorator(f: F) -> F:
@functools.wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> Any:
logger = cast(logging.Logger, LOGGER)
logger.info(prefix)
with timer(logger, prefix):
return f(*args, **kwargs)
return cast(F, wrapper)
return decorator
class Venv:
"""Virtual environment manager"""
AGGRESSIVE_UPDATE_PACKAGES = ("pip", "setuptools", "packaging", "wheel")
def __init__(
self,
prefix: Path | str,
pip_source: PipSource,
*,
base_executable: Path | str | None = None,
) -> None:
self.prefix = Path(prefix).absolute()
self.pip_source = pip_source
self.base_executable = Path(base_executable or sys.executable).absolute()
self._executable: Path | None = None
self._env = {"PIP_EXTRA_INDEX_URL": self.pip_source.index_url}
def is_venv(self) -> bool:
"""Check if the prefix is a virtual environment."""
return self.prefix.is_dir() and (self.prefix / "pyvenv.cfg").is_file()
@property
def executable(self) -> Path:
"""Get the Python executable for the virtual environment."""
assert self.is_venv()
if self._executable is None:
if WINDOWS:
executable = self.prefix / "Scripts" / "python.exe"
else:
executable = self.prefix / "bin" / "python"
assert executable.is_file() or executable.is_symlink()
assert os.access(executable, os.X_OK), f"{executable} is not executable"
self._executable = executable
return self._executable
def site_packages(self, python: Path | str | None = None) -> Path:
"""Get the site-packages directory for the virtual environment."""
output = self.python(
"-c",
"import site; [print(p) for p in site.getsitepackages()]",
python=python,
capture_output=True,
).stdout
candidates = list(map(Path, filter(None, map(str.strip, output.splitlines()))))
candidates = [p for p in candidates if p.is_dir() and p.name == "site-packages"]
if not candidates:
raise RuntimeError(
f"No site-packages directory found for excecutable {python}"
)
return candidates[0]
@property
def activate_command(self) -> str:
"""Get the command to activate the virtual environment."""
if WINDOWS:
# Assume PowerShell
return f"& {self.prefix / 'Scripts' / 'Activate.ps1'}"
return f"source {self.prefix}/bin/activate"
@timed("Creating virtual environment")
def create(self, *, remove_if_exists: bool = False) -> Path:
"""Create a virtual environment."""
if self.prefix.exists():
if remove_if_exists:
# If the venv directory already exists, remove it first
if not self.is_venv():
raise RuntimeError(
f"The path {self.prefix} already exists and is not a virtual environment. "
"Please remove it manually or choose a different prefix."
)
if self.prefix in [
Path(p).absolute()
for p in [
sys.prefix,
sys.exec_prefix,
sys.base_prefix,
sys.base_exec_prefix,
]
]:
raise RuntimeError(
f"The path {self.prefix} trying to remove is the same as the interpreter "
"to run this script. Please choose a different prefix or deactivate the "
"current virtual environment."
)
if self.prefix in [
Path(
self.base_python(
"-c",
f"import os, sys; print(os.path.abspath({p}))",
capture_output=True,
).stdout.strip()
).absolute()
for p in [
"sys.prefix",
"sys.exec_prefix",
"sys.base_prefix",
"sys.base_exec_prefix",
]
]:
raise RuntimeError(
f"The Python executable {self.base_executable} trying to remove is the "
"same as the interpreter to create the virtual environment. Please choose "
"a different prefix or a different Python interpreter."
)
print(f"Removing existing venv: {self.prefix}")
_remove_existing(self.prefix)
else:
raise RuntimeError(f"Path {self.prefix} already exists.")
print(f"Creating venv (Python {self.base_python_version()}): {self.prefix}")
self.base_python("-m", "venv", str(self.prefix))
assert self.is_venv(), "Failed to create virtual environment."
(self.prefix / ".gitignore").write_text("*\n", encoding="utf-8")
return self.ensure()
def ensure(self) -> Path:
"""Ensure the virtual environment exists."""
if not self.is_venv():
return self.create(remove_if_exists=True)
self.pip_install(*self.AGGRESSIVE_UPDATE_PACKAGES, upgrade=True)
return self.prefix
def python(
self,
*args: str,
python: Path | str | None = None,
**popen_kwargs: Any,
) -> subprocess.CompletedProcess[str]:
"""Run a Python command in the virtual environment."""
if python is None:
python = self.executable
cmd = [str(python), *args]
env = popen_kwargs.pop("env", None) or {}
return subprocess.run(
cmd,
check=True,
text=True,
encoding="utf-8",
env={**self._env, **env},
**popen_kwargs,
)
def base_python(
self,
*args: str,
**popen_kwargs: Any,
) -> subprocess.CompletedProcess[str]:
"""Run a Python command in the base environment."""
return self.python(*args, python=self.base_executable, **popen_kwargs)
def python_version(self, *, python: Path | str | None = None) -> str:
"""Get the Python version for the virtual environment."""
return self.python(
"-c",
(
"import sys; print('{0.major}.{0.minor}.{0.micro}{1}'."
"format(sys.version_info, getattr(sys, 'abiflags', '')))"
),
python=python,
capture_output=True,
).stdout.strip()
def base_python_version(self) -> str:
"""Get the Python version for the base environment."""
return self.python_version(python=self.base_executable)
def pip(self, *args: str, **popen_kwargs: Any) -> subprocess.CompletedProcess[str]:
"""Run a pip command in the virtual environment."""
return self.python("-m", "pip", *args, **popen_kwargs)
@timed("Installing packages")
def pip_install(
self,
*packages: str,
prerelease: bool = False,
upgrade: bool = False,
**popen_kwargs: Any,
) -> subprocess.CompletedProcess[str]:
"""Run a pip install command in the virtual environment."""
if upgrade:
args = ["--upgrade", *packages]
verb = "Upgrading"
else:
args = list(packages)
verb = "Installing"
if prerelease:
args = ["--pre", *args]
print(
f"{verb} package(s) ({self.pip_source.index_url}): "
f"{', '.join(map(os.path.basename, packages))}"
)
return self.pip("install", *args, **popen_kwargs)
@timed("Downloading packages")
def pip_download(
self,
*packages: str,
prerelease: bool = False,
**popen_kwargs: Any,
) -> list[Path]:
"""Download a package in the virtual environment."""
tmpdir = tempfile.TemporaryDirectory(prefix="pip-download-")
atexit.register(tmpdir.cleanup)
tempdir = Path(tmpdir.name).absolute()
print(
f"Downloading package(s) ({self.pip_source.index_url}): "
f"{', '.join(packages)}"
)
if prerelease:
args = ["--pre", *packages]
else:
args = list(packages)
self.pip("download", "--dest", str(tempdir), *args, **popen_kwargs)
files = list(tempdir.iterdir())
print(f"Downloaded {len(files)} file(s) to {tempdir}:")
for file in files:
print(f" - {file.name}")
return files
def wheel(
self,
*args: str,
**popen_kwargs: Any,
) -> subprocess.CompletedProcess[str]:
"""Run a wheel command in the virtual environment."""
return self.python("-m", "wheel", *args, **popen_kwargs)
@timed("Unpacking wheel file")
def wheel_unpack(
self,
wheel: Path | str,
dest: Path | str,
**popen_kwargs: Any,
) -> subprocess.CompletedProcess[str]:
"""Unpack a wheel into a directory."""
wheel = Path(wheel).absolute()
dest = Path(dest).absolute()
assert wheel.is_file() and wheel.suffix.lower() == ".whl"
return self.wheel("unpack", "--dest", str(dest), str(wheel), **popen_kwargs)
@contextlib.contextmanager
def extracted_wheel(self, wheel: Path | str) -> Generator[Path]:
"""Download and extract a wheel into a temporary directory."""
with tempfile.TemporaryDirectory(prefix="wheel-") as tempdir:
self.wheel_unpack(wheel, tempdir)
subdirs = [p for p in Path(tempdir).absolute().iterdir() if p.is_dir()]
if len(subdirs) != 1:
raise RuntimeError(
f"Expected exactly one directory in {tempdir}, "
f"got {[str(d) for d in subdirs]}."
)
yield subdirs[0]
def git(*args: str) -> list[str]:
return ["git", "-C", str(REPO_ROOT), *args]
@ -134,7 +490,10 @@ def logging_record_argv() -> None:
def logging_record_exception(e: BaseException) -> None:
(logging_run_dir() / "exception").write_text(type(e).__name__, encoding="utf-8")
text = f"{type(e).__name__}: {e}"
if isinstance(e, subprocess.CalledProcessError):
text += f"\n\nstdout: {e.stdout}\n\nstderr: {e.stderr}"
(logging_run_dir() / "exception").write_text(text, encoding="utf-8")
def logging_rotate() -> None:
@ -156,7 +515,7 @@ def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, N
stderr (INFO) and file handler (DEBUG).
"""
formatter = Formatter(fmt="%(levelname)s: %(message)s", datefmt="")
root_logger = logging.getLogger("conda-pytorch")
root_logger = logging.getLogger("pytorch-nightly")
root_logger.setLevel(logging.DEBUG)
console_handler = logging.StreamHandler()
@ -213,161 +572,13 @@ def check_branch(subcommand: str, branch: str | None) -> str | None:
return None
def check_conda_env_exists(name: str | None = None, prefix: str | None = None) -> bool:
"""Checks that the conda environment exists."""
if name is not None and prefix is not None:
raise ValueError("Cannot specify both --name and --prefix")
if name is None and prefix is None:
raise ValueError("Must specify either --name or --prefix")
try:
cmd = ["conda", "info", "--envs"]
output = subprocess.check_output(cmd, text=True, encoding="utf-8")
except subprocess.CalledProcessError:
logger = cast(logging.Logger, LOGGER)
logger.warning("Failed to list conda environments", exc_info=True)
return False
if name is not None:
return len(re.findall(rf"^{name}\s+", output, flags=re.MULTILINE)) > 0
assert prefix is not None
prefix = Path(prefix).absolute()
return len(re.findall(rf"\s+{prefix}$", output, flags=re.MULTILINE)) > 0
@contextlib.contextmanager
def timer(logger: logging.Logger, prefix: str) -> Iterator[None]:
"""Timed context manager"""
start_time = time.perf_counter()
yield
logger.info("%s took %.3f [s]", prefix, time.perf_counter() - start_time)
F = TypeVar("F", bound=Callable[..., Any])
def timed(prefix: str) -> Callable[[F], F]:
"""Decorator for timing functions"""
def dec(f: F) -> F:
@functools.wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> Any:
logger = cast(logging.Logger, LOGGER)
logger.info(prefix)
with timer(logger, prefix):
return f(*args, **kwargs)
return cast(F, wrapper)
return dec
def _make_channel_args(
channels: Iterable[str] = ("pytorch-nightly",),
override_channels: bool = False,
) -> list[str]:
args = []
for channel in channels:
args.extend(["--channel", channel])
if override_channels:
args.append("--override-channels")
return args
@timed("Solving conda environment")
def conda_solve(
specs: Iterable[str],
*,
name: str | None = None,
prefix: str | None = None,
channels: Iterable[str] = ("pytorch-nightly",),
override_channels: bool = False,
) -> tuple[list[str], str, str, bool, list[str]]:
"""Performs the conda solve and splits the deps from the package."""
# compute what environment to use
if prefix is not None:
existing_env = True
env_opts = ["--prefix", prefix]
elif name is not None:
existing_env = True
env_opts = ["--name", name]
else:
# create new environment
existing_env = False
env_opts = ["--name", DEFAULT_ENV_NAME]
# run solve
if existing_env:
cmd = [
"conda",
"install",
"--yes",
"--dry-run",
"--json",
*env_opts,
]
else:
cmd = [
"conda",
"create",
"--yes",
"--dry-run",
"--json",
"--name",
"__pytorch__",
]
channel_args = _make_channel_args(
channels=channels,
override_channels=override_channels,
)
cmd.extend(channel_args)
cmd.extend(specs)
stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
# parse solution
solve = json.loads(stdout)
link = solve["actions"]["LINK"]
deps = []
pytorch, platform = "", ""
for pkg in link:
url = URL_FORMAT.format(**pkg)
if pkg["name"] == "pytorch":
pytorch = url
platform = pkg["platform"]
else:
deps.append(url)
assert pytorch, "PyTorch package not found in solve"
assert platform, "Platform not found in solve"
return deps, pytorch, platform, existing_env, env_opts
@timed("Installing dependencies")
def deps_install(deps: list[str], existing_env: bool, env_opts: list[str]) -> None:
def install_packages(venv: Venv, packages: Iterable[str]) -> None:
"""Install dependencies to deps environment"""
if not existing_env:
# first remove previous pytorch-deps env
if check_conda_env_exists(name=DEFAULT_ENV_NAME):
cmd = ["conda", "env", "remove", "--yes", *env_opts]
subprocess.check_output(cmd)
# install new deps
install_command = "install" if existing_env else "create"
cmd = ["conda", install_command, "--yes", "--no-deps", *env_opts, *deps]
subprocess.check_call(cmd)
@timed("Installing pytorch nightly binaries")
def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]:
"""Install pytorch into a temporary directory"""
pytorch_dir = tempfile.TemporaryDirectory(prefix="conda-pytorch-")
cmd = ["conda", "create", "--yes", "--no-deps", f"--prefix={pytorch_dir.name}", url]
subprocess.check_call(cmd)
return pytorch_dir
def _site_packages(dirname: str, platform: str) -> Path:
if platform.startswith("win"):
template = os.path.join(dirname, "Lib", "site-packages")
else:
template = os.path.join(dirname, "lib", "python*.*", "site-packages")
return Path(next(glob.iglob(template))).absolute()
# install packages
packages = list(dict.fromkeys(packages))
if packages:
venv.pip_install(*packages)
def _ensure_commit(git_sha1: str) -> None:
@ -437,7 +648,7 @@ def _get_listing_linux(source_dir: Path) -> list[Path]:
)
def _get_listing_osx(source_dir: Path) -> list[Path]:
def _get_listing_macos(source_dir: Path) -> list[Path]:
# oddly, these are .so files even on Mac
return list(
itertools.chain(
@ -447,7 +658,7 @@ def _get_listing_osx(source_dir: Path) -> list[Path]:
)
def _get_listing_win(source_dir: Path) -> list[Path]:
def _get_listing_windows(source_dir: Path) -> list[Path]:
return list(
itertools.chain(
source_dir.glob("*.pyd"),
@ -468,15 +679,15 @@ def _find_missing_pyi(source_dir: Path, target_dir: Path) -> list[Path]:
return missing_pyis
def _get_listing(source_dir: Path, target_dir: Path, platform: str) -> list[Path]:
if platform.startswith("linux"):
def _get_listing(source_dir: Path, target_dir: Path) -> list[Path]:
if LINUX:
listing = _get_listing_linux(source_dir)
elif platform.startswith("osx"):
listing = _get_listing_osx(source_dir)
elif platform.startswith("win"):
listing = _get_listing_win(source_dir)
elif MACOS:
listing = _get_listing_macos(source_dir)
elif WINDOWS:
listing = _get_listing_windows(source_dir)
else:
raise RuntimeError(f"Platform {platform!r} not recognized")
raise RuntimeError(f"Platform {platform_system()!r} not recognized")
listing.extend(_find_missing_pyi(source_dir, target_dir))
listing.append(source_dir / "version.py")
listing.append(source_dir / "testing" / "_internal" / "generated")
@ -532,14 +743,14 @@ def _link_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None
@timed("Moving nightly files into repo")
def move_nightly_files(site_dir: Path, platform: str) -> None:
def move_nightly_files(site_dir: Path) -> None:
"""Moves PyTorch files from temporary installed location to repo."""
# get file listing
source_dir = site_dir / "torch"
target_dir = REPO_ROOT / "torch"
listing = _get_listing(source_dir, target_dir, platform)
listing = _get_listing(source_dir, target_dir)
# copy / link files
if platform.startswith("win"):
if WINDOWS:
_copy_files(listing, source_dir, target_dir)
else:
try:
@ -548,31 +759,10 @@ def move_nightly_files(site_dir: Path, platform: str) -> None:
_copy_files(listing, source_dir, target_dir)
def _available_envs() -> dict[str, str]:
cmd = ["conda", "env", "list"]
stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
envs = {}
for line in map(str.strip, stdout.splitlines()):
if not line or line.startswith("#"):
continue
parts = line.split()
if len(parts) == 1:
# unnamed env
continue
envs[parts[0]] = parts[-1]
return envs
@timed("Writing pytorch-nightly.pth")
def write_pth(env_opts: list[str], platform: str) -> None:
def write_pth(venv: Venv) -> None:
"""Writes Python path file for this dir."""
env_type, env_dir = env_opts
if env_type == "--name":
# have to find directory
envs = _available_envs()
env_dir = envs[env_dir]
site_dir = _site_packages(env_dir, platform)
(site_dir / "pytorch-nightly.pth").write_text(
(venv.site_packages() / "pytorch-nightly.pth").write_text(
"# This file was autogenerated by PyTorch's tools/nightly.py\n"
"# Please delete this file if you no longer need the following development\n"
"# version of PyTorch to be importable\n"
@ -582,50 +772,66 @@ def write_pth(env_opts: list[str], platform: str) -> None:
def install(
specs: Iterable[str],
*,
logger: logging.Logger,
venv: Venv,
packages: Iterable[str],
subcommand: str = "checkout",
branch: str | None = None,
name: str | None = None,
prefix: str | None = None,
channels: Iterable[str] = ("pytorch-nightly",),
override_channels: bool = False,
logger: logging.Logger,
) -> None:
"""Development install of PyTorch"""
specs = list(specs)
deps, pytorch, platform, existing_env, env_opts = conda_solve(
specs=specs,
name=name,
prefix=prefix,
channels=channels,
override_channels=override_channels,
)
if deps:
deps_install(deps, existing_env, env_opts)
use_existing = subcommand == "checkout"
if use_existing:
venv.ensure()
else:
venv.create(remove_if_exists=True)
with pytorch_install(pytorch) as pytorch_dir:
site_dir = _site_packages(pytorch_dir, platform)
packages = [p for p in packages if p != "torch"]
dependencies = venv.pip_download("torch", prerelease=True)
torch_wheel = [
dep
for dep in dependencies
if dep.name.startswith("torch-") and dep.name.endswith(".whl")
]
if len(torch_wheel) != 1:
raise RuntimeError(f"Expected exactly one torch wheel, got {torch_wheel}")
torch_wheel = torch_wheel[0]
dependencies = [deps for deps in dependencies if deps != torch_wheel]
install_packages(venv, [*packages, *map(str, dependencies)])
with venv.extracted_wheel(torch_wheel) as wheel_site_dir:
if subcommand == "checkout":
checkout_nightly_version(cast(str, branch), site_dir)
checkout_nightly_version(cast(str, branch), wheel_site_dir)
elif subcommand == "pull":
pull_nightly_version(site_dir)
pull_nightly_version(wheel_site_dir)
else:
raise ValueError(f"Subcommand {subcommand} must be one of: checkout, pull.")
move_nightly_files(site_dir, platform)
move_nightly_files(wheel_site_dir)
write_pth(env_opts, platform)
write_pth(venv)
logger.info(
"-------\nPyTorch Development Environment set up!\nPlease activate to "
"enable this environment:\n $ conda activate %s",
env_opts[1],
"-------\n"
"PyTorch Development Environment set up!\n"
"Please activate to enable this environment:\n\n"
" $ %s",
venv.activate_command,
)
def make_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser()
def find_executable(name: str) -> Path:
executable = shutil.which(name)
if executable is None:
raise argparse.ArgumentTypeError(
f"Could not find executable {name} in PATH."
)
return Path(executable).absolute()
parser = argparse.ArgumentParser()
# subcommands
subcmd = p.add_subparsers(dest="subcmd", help="subcommand to execute")
subcmd = parser.add_subparsers(dest="subcmd", help="subcommand to execute")
checkout = subcmd.add_parser("checkout", help="checkout a new branch")
checkout.add_argument(
"-b",
@ -642,19 +848,24 @@ def make_parser() -> argparse.ArgumentParser:
subparsers = [checkout, pull]
for subparser in subparsers:
subparser.add_argument(
"-n",
"--name",
help="Name of environment",
dest="name",
"--python",
"--base-executable",
type=find_executable,
help=(
"Path to Python interpreter to use for creating the virtual environment. "
"Defaults to the interpreter running this script."
),
dest="base_executable",
default=None,
metavar="ENVIRONMENT",
metavar="PYTHON",
)
subparser.add_argument(
"-p",
"--prefix",
help="Full path to environment location (i.e. prefix)",
type=lambda p: Path(p).absolute(),
help='Path to virtual environment directory (e.g. "./venv")',
dest="prefix",
default=None,
default=str(DEFAULT_VENV_DIR),
metavar="PATH",
)
subparser.add_argument(
@ -666,71 +877,70 @@ def make_parser() -> argparse.ArgumentParser:
action="store_true",
)
subparser.add_argument(
"--override-channels",
help="Do not search default or .condarc channels.",
dest="override_channels",
default=False,
action="store_true",
)
subparser.add_argument(
"-c",
"--channel",
"--cuda",
help=(
"Additional channel to search for packages. "
"'pytorch-nightly' will always be prepended to this list."
"CUDA version to install "
"(defaults to the latest version available on the platform)"
),
dest="channels",
action="append",
metavar="CHANNEL",
dest="cuda",
nargs="?",
default=argparse.SUPPRESS,
metavar="VERSION",
)
if platform_system() in {"Linux", "Windows"}:
subparser.add_argument(
"--cuda",
help=(
"CUDA version to install "
"(defaults to the latest version available on the platform)"
),
dest="cuda",
nargs="?",
default=argparse.SUPPRESS,
metavar="VERSION",
)
return p
return parser
def main(args: Sequence[str] | None = None) -> None:
def parse_arguments() -> argparse.Namespace:
parser = make_parser()
args = parser.parse_args()
args.branch = getattr(args, "branch", None)
return args
def main() -> None:
"""Main entry point"""
global LOGGER
p = make_parser()
ns = p.parse_args(args)
ns.branch = getattr(ns, "branch", None)
status = check_branch(ns.subcmd, ns.branch)
args = parse_arguments()
status = check_branch(args.subcmd, args.branch)
if status:
sys.exit(status)
specs = list(SPECS_TO_INSTALL)
channels = ["pytorch-nightly"]
if hasattr(ns, "cuda"):
if ns.cuda is not None:
specs.append(f"pytorch-cuda={ns.cuda}")
pip_source = None
if hasattr(args, "cuda"):
available_sources = {
src.name[len("cuda-") :]: src
for src in PIP_SOURCES.values()
if src.name.startswith("cuda-") and PLATFORM in src.supported_platforms
}
if not available_sources:
print(f"No CUDA versions available on platform {PLATFORM}.")
sys.exit(1)
if args.cuda is not None:
pip_source = available_sources.get(args.cuda)
if pip_source is None:
print(
f"CUDA {args.cuda} is not available on platform {PLATFORM}. "
f"Available version(s): {', '.join(sorted(available_sources, key=Version))}"
)
sys.exit(1)
else:
specs.append("pytorch-cuda")
specs.append("pytorch-mutex=*=*cuda*")
channels.append("nvidia")
pip_source = available_sources[max(available_sources, key=Version)]
else:
specs.append("pytorch-mutex=*=*cpu*")
if ns.channels:
channels.extend(ns.channels)
with logging_manager(debug=ns.verbose) as logger:
pip_source = PIP_SOURCES["cpu"] # always available
with logging_manager(debug=args.verbose) as logger:
LOGGER = logger
venv = Venv(
prefix=args.prefix,
pip_source=pip_source,
base_executable=args.base_executable,
)
install(
specs=specs,
subcommand=ns.subcmd,
branch=ns.branch,
name=ns.name,
prefix=ns.prefix,
venv=venv,
packages=PACKAGES_TO_INSTALL,
subcommand=args.subcmd,
branch=args.branch,
logger=logger,
channels=channels,
override_channels=ns.override_channels,
)