mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Refactor nightly pull tool to use venv
and pip
(#141281)
Resolves #141238 - #141238 Example output: ```console $ python3.12 tools/nightly.py checkout -b my-nightly-branch -p my-env --python python3.10 log file: /Users/PanXuehai/Projects/pytorch/nightly/log/2024-11-22_04h15m45s_63f8b29e-a845-11ef-bbf9-32c784498a7b/nightly.log Creating virtual environment Creating venv (Python 3.10.15): /Users/PanXuehai/Projects/pytorch/my-env Installing packages Upgrading package(s) (https://download.pytorch.org/whl/nightly/cpu): pip, setuptools, wheel Installing packages took 5.576 [s] Creating virtual environment took 9.505 [s] Downloading packages Downloading package(s) (https://download.pytorch.org/whl/nightly/cpu): torch Downloaded 9 file(s) to /var/folders/sq/7sf73d5s2qnb3w6jjsmhsw3h0000gn/T/pip-download-lty5dvz4: - mpmath-1.3.0-py3-none-any.whl - torch-2.6.0.dev20241121-cp310-none-macosx_11_0_arm64.whl - jinja2-3.1.4-py3-none-any.whl - sympy-1.13.1-py3-none-any.whl - MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl - networkx-3.4.2-py3-none-any.whl - fsspec-2024.10.0-py3-none-any.whl - filelock-3.16.1-py3-none-any.whl - typing_extensions-4.12.2-py3-none-any.whl Downloading packages took 7.628 [s] Installing dependencies Installing packages Installing package(s) (https://download.pytorch.org/whl/nightly/cpu): numpy, cmake, ninja, packaging, ruff, mypy, pytest, hypothesis, ipython, rich, clang-format, clang-tidy, sphinx, mpmath-1.3.0-py3-none-any.whl, jinja2-3.1.4-py3-none-any.whl, sympy-1.13.1-py3-none-any.whl, MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl, networkx-3.4.2-py3-none-any.whl, fsspec-2024.10.0-py3-none-any.whl, filelock-3.16.1-py3-none-any.whl, typing_extensions-4.12.2-py3-none-any.whl Installing packages took 42.514 [s] Installing dependencies took 42.515 [s] Unpacking wheel file Unpacking wheel file took 3.223 [s] Checking out nightly PyTorch Found released git version ac47a2d9714278889923ddd40e4210d242d8d4ee Found nightly release version e0482fdf95eb3ce679fa442b50871d113ceb673b Switched to a new branch 'my-nightly-branch' Checking out nightly PyTorch took 0.198 [s] Moving nightly files into repo Linking /var/folders/sq/7sf73d5s2qnb3w6jjsmhsw3h0000gn/T/wheel-dljxil5i/torch-2.6.0.dev20241121/torch/_C.cpython-310-darwin.so -> /Users/PanXuehai/Projects/pytorch/torch/_C.cpython-310-darwin.so Linking /var/folders/sq/7sf73d5s2qnb3w6jjsmhsw3h0000gn/T/wheel-dljxil5i/torch-2.6.0.dev20241121/torch/lib/libtorch_python.dylib -> /Users/PanXuehai/Projects/pytorch/torch/lib/libtorch_python.dylib ... Linking /var/folders/sq/7sf73d5s2qnb3w6jjsmhsw3h0000gn/T/wheel-dljxil5i/torch-2.6.0.dev20241121/torch/include/c10/macros/Macros.h -> /Users/PanXuehai/Projects/pytorch/torch/include/c10/macros/Macros.h Moving nightly files into repo took 11.426 [s] Writing pytorch-nightly.pth Writing pytorch-nightly.pth took 0.036 [s] ------- PyTorch Development Environment set up! Please activate to enable this environment: $ source /Users/PanXuehai/Projects/pytorch/my-env/bin/activate ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/141281 Approved by: https://github.com/seemethere
This commit is contained in:
committed by
PyTorch MergeBot
parent
75cecba164
commit
2a6eaa2e6f
@ -15,28 +15,21 @@ import os
|
|||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: Also update the CUDA sources in tools/nightly.py when changing this list
|
||||||
CUDA_ARCHES = ["11.8", "12.4", "12.6"]
|
CUDA_ARCHES = ["11.8", "12.4", "12.6"]
|
||||||
|
|
||||||
|
|
||||||
CUDA_ARCHES_FULL_VERSION = {"11.8": "11.8.0", "12.4": "12.4.1", "12.6": "12.6.2"}
|
CUDA_ARCHES_FULL_VERSION = {"11.8": "11.8.0", "12.4": "12.4.1", "12.6": "12.6.2"}
|
||||||
|
|
||||||
|
|
||||||
CUDA_ARCHES_CUDNN_VERSION = {"11.8": "9", "12.4": "9", "12.6": "9"}
|
CUDA_ARCHES_CUDNN_VERSION = {"11.8": "9", "12.4": "9", "12.6": "9"}
|
||||||
|
|
||||||
|
|
||||||
ROCM_ARCHES = ["6.1", "6.2"]
|
ROCM_ARCHES = ["6.1", "6.2"]
|
||||||
|
|
||||||
XPU_ARCHES = ["xpu"]
|
XPU_ARCHES = ["xpu"]
|
||||||
|
|
||||||
CPU_CXX11_ABI_ARCH = ["cpu-cxx11-abi"]
|
CPU_CXX11_ABI_ARCH = ["cpu-cxx11-abi"]
|
||||||
|
|
||||||
|
|
||||||
CPU_AARCH64_ARCH = ["cpu-aarch64"]
|
CPU_AARCH64_ARCH = ["cpu-aarch64"]
|
||||||
|
|
||||||
|
|
||||||
CPU_S390X_ARCH = ["cpu-s390x"]
|
CPU_S390X_ARCH = ["cpu-s390x"]
|
||||||
|
|
||||||
|
|
||||||
CUDA_AARCH64_ARCH = ["cuda-aarch64"]
|
CUDA_AARCH64_ARCH = ["cuda-aarch64"]
|
||||||
|
|
||||||
|
|
||||||
|
@ -79,7 +79,7 @@ cd pytorch
|
|||||||
git remote add upstream git@github.com:pytorch/pytorch.git
|
git remote add upstream git@github.com:pytorch/pytorch.git
|
||||||
|
|
||||||
make setup-env # or make setup-env-cuda for pre-built CUDA binaries
|
make setup-env # or make setup-env-cuda for pre-built CUDA binaries
|
||||||
conda activate pytorch-deps
|
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
|
||||||
```
|
```
|
||||||
|
|
||||||
### Tips and Debugging
|
### Tips and Debugging
|
||||||
@ -166,7 +166,7 @@ conda activate pytorch-deps
|
|||||||
## Nightly Checkout & Pull
|
## Nightly Checkout & Pull
|
||||||
|
|
||||||
The `tools/nightly.py` script is provided to ease pure Python development of
|
The `tools/nightly.py` script is provided to ease pure Python development of
|
||||||
PyTorch. This uses `conda` and `git` to check out the nightly development
|
PyTorch. This uses `venv` and `git` to check out the nightly development
|
||||||
version of PyTorch and installs pre-built binaries into the current repository.
|
version of PyTorch and installs pre-built binaries into the current repository.
|
||||||
This is like a development or editable install, but without needing the ability
|
This is like a development or editable install, but without needing the ability
|
||||||
to compile any C++ code.
|
to compile any C++ code.
|
||||||
@ -175,33 +175,33 @@ You can use this script to check out a new nightly branch with the following:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
./tools/nightly.py checkout -b my-nightly-branch
|
./tools/nightly.py checkout -b my-nightly-branch
|
||||||
conda activate pytorch-deps
|
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
|
||||||
```
|
```
|
||||||
|
|
||||||
Or if you would like to re-use an existing conda environment, you can pass in
|
Or if you would like to re-use an existing conda environment, you can pass in
|
||||||
the regular environment parameters (`--name` or `--prefix`):
|
the prefix argument (`--prefix`):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./tools/nightly.py checkout -b my-nightly-branch -n my-env
|
./tools/nightly.py checkout -b my-nightly-branch -p my-env
|
||||||
conda activate my-env
|
source my-env/bin/activate # or `& .\my-env\Scripts\Activate.ps1` on Windows
|
||||||
```
|
```
|
||||||
|
|
||||||
To install the nightly binaries built with CUDA, you can pass in the flag `--cuda`:
|
To install the nightly binaries built with CUDA, you can pass in the flag `--cuda`:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./tools/nightly.py checkout -b my-nightly-branch --cuda
|
./tools/nightly.py checkout -b my-nightly-branch --cuda
|
||||||
conda activate pytorch-deps
|
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
|
||||||
```
|
```
|
||||||
|
|
||||||
You can also use this tool to pull the nightly commits into the current branch:
|
You can also use this tool to pull the nightly commits into the current branch:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./tools/nightly.py pull -n my-env
|
./tools/nightly.py pull -p my-env
|
||||||
conda activate my-env
|
source my-env/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
|
||||||
```
|
```
|
||||||
|
|
||||||
Pulling will reinstall the PyTorch dependencies as well as the nightly binaries
|
Pulling will recreate a fresh virtual environment and reinstall the development
|
||||||
into the repo directory.
|
dependencies as well as the nightly binaries into the repo directory.
|
||||||
|
|
||||||
## Codebase structure
|
## Codebase structure
|
||||||
|
|
||||||
|
780
tools/nightly.py
780
tools/nightly.py
@ -1,43 +1,42 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Much of the logging code here was forked from https://github.com/ezyang/ghstack
|
# Much of the logging code here was forked from https://github.com/ezyang/ghstack
|
||||||
# Copyright (c) Edward Z. Yang <ezyang@mit.edu>
|
# Copyright (c) Edward Z. Yang <ezyang@mit.edu>
|
||||||
"""Checks out the nightly development version of PyTorch and installs pre-built
|
r"""Checks out the nightly development version of PyTorch and installs pre-built
|
||||||
binaries into the repo.
|
binaries into the repo.
|
||||||
|
|
||||||
You can use this script to check out a new nightly branch with the following::
|
You can use this script to check out a new nightly branch with the following::
|
||||||
|
|
||||||
$ ./tools/nightly.py checkout -b my-nightly-branch
|
$ ./tools/nightly.py checkout -b my-nightly-branch
|
||||||
$ conda activate pytorch-deps
|
$ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
|
||||||
|
|
||||||
Or if you would like to re-use an existing conda environment, you can pass in
|
Or if you would like to re-use an existing virtual environment, you can pass in
|
||||||
the regular environment parameters (--name or --prefix)::
|
the prefix argument (--prefix)::
|
||||||
|
|
||||||
$ ./tools/nightly.py checkout -b my-nightly-branch -n my-env
|
$ ./tools/nightly.py checkout -b my-nightly-branch -p my-env
|
||||||
$ conda activate my-env
|
$ source my-env/bin/activate # or `& .\my-env\Scripts\Activate.ps1` on Windows
|
||||||
|
|
||||||
To install the nightly binaries built with CUDA, you can pass in the flag --cuda::
|
To install the nightly binaries built with CUDA, you can pass in the flag --cuda::
|
||||||
|
|
||||||
$ ./tools/nightly.py checkout -b my-nightly-branch --cuda
|
$ ./tools/nightly.py checkout -b my-nightly-branch --cuda
|
||||||
$ conda activate pytorch-deps
|
$ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
|
||||||
|
|
||||||
You can also use this tool to pull the nightly commits into the current branch as
|
You can also use this tool to pull the nightly commits into the current branch as
|
||||||
well. This can be done with::
|
well. This can be done with::
|
||||||
|
|
||||||
$ ./tools/nightly.py pull -n my-env
|
$ ./tools/nightly.py pull
|
||||||
$ conda activate my-env
|
$ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
|
||||||
|
|
||||||
Pulling will reinstall the conda dependencies as well as the nightly binaries into
|
Pulling will recreate a fresh virtual environment and reinstall the development
|
||||||
the repo directory.
|
dependencies as well as the nightly binaries into the repo directory.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import atexit
|
||||||
import contextlib
|
import contextlib
|
||||||
import functools
|
import functools
|
||||||
import glob
|
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@ -51,16 +50,46 @@ from ast import literal_eval
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from platform import system as platform_system
|
from platform import system as platform_system
|
||||||
from typing import Any, Callable, cast, Generator, Iterable, Iterator, Sequence, TypeVar
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
cast,
|
||||||
|
Generator,
|
||||||
|
Iterable,
|
||||||
|
Iterator,
|
||||||
|
NamedTuple,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from packaging.version import Version
|
||||||
|
except ImportError:
|
||||||
|
Version = None # type: ignore[assignment,misc]
|
||||||
|
|
||||||
|
|
||||||
REPO_ROOT = Path(__file__).absolute().parent.parent
|
REPO_ROOT = Path(__file__).absolute().parent.parent
|
||||||
GITHUB_REMOTE_URL = "https://github.com/pytorch/pytorch.git"
|
GITHUB_REMOTE_URL = "https://github.com/pytorch/pytorch.git"
|
||||||
SPECS_TO_INSTALL = ("pytorch", "mypy", "pytest", "hypothesis", "ipython", "sphinx")
|
PACKAGES_TO_INSTALL = (
|
||||||
DEFAULT_ENV_NAME = "pytorch-deps"
|
"torch",
|
||||||
|
"numpy",
|
||||||
|
"cmake",
|
||||||
|
"ninja",
|
||||||
|
"packaging",
|
||||||
|
"ruff",
|
||||||
|
"mypy",
|
||||||
|
"pytest",
|
||||||
|
"hypothesis",
|
||||||
|
"ipython",
|
||||||
|
"rich",
|
||||||
|
"clang-format",
|
||||||
|
"clang-tidy",
|
||||||
|
"sphinx",
|
||||||
|
)
|
||||||
|
DEFAULT_VENV_DIR = REPO_ROOT / "venv"
|
||||||
|
|
||||||
|
|
||||||
LOGGER: logging.Logger | None = None
|
LOGGER: logging.Logger | None = None
|
||||||
URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
|
|
||||||
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
|
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
|
||||||
SHA1_RE = re.compile(r"(?P<sha1>[0-9a-fA-F]{40})")
|
SHA1_RE = re.compile(r"(?P<sha1>[0-9a-fA-F]{40})")
|
||||||
USERNAME_PASSWORD_RE = re.compile(r":\/\/(.*?)\@")
|
USERNAME_PASSWORD_RE = re.compile(r":\/\/(.*?)\@")
|
||||||
@ -70,6 +99,49 @@ LOG_DIRNAME_RE = re.compile(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
PLATFORM = platform_system().replace("Darwin", "macOS")
|
||||||
|
LINUX = PLATFORM == "Linux"
|
||||||
|
MACOS = PLATFORM == "macOS"
|
||||||
|
WINDOWS = PLATFORM == "Windows"
|
||||||
|
POSIX = LINUX or MACOS
|
||||||
|
|
||||||
|
|
||||||
|
class PipSource(NamedTuple):
|
||||||
|
name: str
|
||||||
|
index_url: str
|
||||||
|
supported_platforms: set[str]
|
||||||
|
accelerator: str
|
||||||
|
|
||||||
|
|
||||||
|
PYTORCH_NIGHTLY_PIP_INDEX_URL = "https://download.pytorch.org/whl/nightly"
|
||||||
|
PIP_SOURCES = {
|
||||||
|
"cpu": PipSource(
|
||||||
|
name="cpu",
|
||||||
|
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cpu",
|
||||||
|
supported_platforms={"Linux", "macOS", "Windows"},
|
||||||
|
accelerator="cpu",
|
||||||
|
),
|
||||||
|
"cuda-11.8": PipSource(
|
||||||
|
name="cuda-11.8",
|
||||||
|
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu118",
|
||||||
|
supported_platforms={"Linux", "Windows"},
|
||||||
|
accelerator="cuda",
|
||||||
|
),
|
||||||
|
"cuda-12.1": PipSource(
|
||||||
|
name="cuda-12.1",
|
||||||
|
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu121",
|
||||||
|
supported_platforms={"Linux", "Windows"},
|
||||||
|
accelerator="cuda",
|
||||||
|
),
|
||||||
|
"cuda-12.4": PipSource(
|
||||||
|
name="cuda-12.4",
|
||||||
|
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu124",
|
||||||
|
supported_platforms={"Linux", "Windows"},
|
||||||
|
accelerator="cuda",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class Formatter(logging.Formatter):
|
class Formatter(logging.Formatter):
|
||||||
redactions: dict[str, str]
|
redactions: dict[str, str]
|
||||||
|
|
||||||
@ -108,6 +180,290 @@ class Formatter(logging.Formatter):
|
|||||||
self.redactions[needle] = replace
|
self.redactions[needle] = replace
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def timer(logger: logging.Logger, prefix: str) -> Iterator[None]:
|
||||||
|
"""Timed context manager"""
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
yield
|
||||||
|
logger.info("%s took %.3f [s]", prefix, time.perf_counter() - start_time)
|
||||||
|
|
||||||
|
|
||||||
|
F = TypeVar("F", bound=Callable[..., Any])
|
||||||
|
|
||||||
|
|
||||||
|
def timed(prefix: str) -> Callable[[F], F]:
|
||||||
|
"""Decorator for timing functions"""
|
||||||
|
|
||||||
|
def decorator(f: F) -> F:
|
||||||
|
@functools.wraps(f)
|
||||||
|
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
logger = cast(logging.Logger, LOGGER)
|
||||||
|
logger.info(prefix)
|
||||||
|
with timer(logger, prefix):
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
|
return cast(F, wrapper)
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class Venv:
|
||||||
|
"""Virtual environment manager"""
|
||||||
|
|
||||||
|
AGGRESSIVE_UPDATE_PACKAGES = ("pip", "setuptools", "packaging", "wheel")
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefix: Path | str,
|
||||||
|
pip_source: PipSource,
|
||||||
|
*,
|
||||||
|
base_executable: Path | str | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.prefix = Path(prefix).absolute()
|
||||||
|
self.pip_source = pip_source
|
||||||
|
self.base_executable = Path(base_executable or sys.executable).absolute()
|
||||||
|
self._executable: Path | None = None
|
||||||
|
self._env = {"PIP_EXTRA_INDEX_URL": self.pip_source.index_url}
|
||||||
|
|
||||||
|
def is_venv(self) -> bool:
|
||||||
|
"""Check if the prefix is a virtual environment."""
|
||||||
|
return self.prefix.is_dir() and (self.prefix / "pyvenv.cfg").is_file()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def executable(self) -> Path:
|
||||||
|
"""Get the Python executable for the virtual environment."""
|
||||||
|
assert self.is_venv()
|
||||||
|
if self._executable is None:
|
||||||
|
if WINDOWS:
|
||||||
|
executable = self.prefix / "Scripts" / "python.exe"
|
||||||
|
else:
|
||||||
|
executable = self.prefix / "bin" / "python"
|
||||||
|
assert executable.is_file() or executable.is_symlink()
|
||||||
|
assert os.access(executable, os.X_OK), f"{executable} is not executable"
|
||||||
|
self._executable = executable
|
||||||
|
return self._executable
|
||||||
|
|
||||||
|
def site_packages(self, python: Path | str | None = None) -> Path:
|
||||||
|
"""Get the site-packages directory for the virtual environment."""
|
||||||
|
output = self.python(
|
||||||
|
"-c",
|
||||||
|
"import site; [print(p) for p in site.getsitepackages()]",
|
||||||
|
python=python,
|
||||||
|
capture_output=True,
|
||||||
|
).stdout
|
||||||
|
candidates = list(map(Path, filter(None, map(str.strip, output.splitlines()))))
|
||||||
|
candidates = [p for p in candidates if p.is_dir() and p.name == "site-packages"]
|
||||||
|
if not candidates:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"No site-packages directory found for excecutable {python}"
|
||||||
|
)
|
||||||
|
return candidates[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def activate_command(self) -> str:
|
||||||
|
"""Get the command to activate the virtual environment."""
|
||||||
|
if WINDOWS:
|
||||||
|
# Assume PowerShell
|
||||||
|
return f"& {self.prefix / 'Scripts' / 'Activate.ps1'}"
|
||||||
|
return f"source {self.prefix}/bin/activate"
|
||||||
|
|
||||||
|
@timed("Creating virtual environment")
|
||||||
|
def create(self, *, remove_if_exists: bool = False) -> Path:
|
||||||
|
"""Create a virtual environment."""
|
||||||
|
if self.prefix.exists():
|
||||||
|
if remove_if_exists:
|
||||||
|
# If the venv directory already exists, remove it first
|
||||||
|
if not self.is_venv():
|
||||||
|
raise RuntimeError(
|
||||||
|
f"The path {self.prefix} already exists and is not a virtual environment. "
|
||||||
|
"Please remove it manually or choose a different prefix."
|
||||||
|
)
|
||||||
|
if self.prefix in [
|
||||||
|
Path(p).absolute()
|
||||||
|
for p in [
|
||||||
|
sys.prefix,
|
||||||
|
sys.exec_prefix,
|
||||||
|
sys.base_prefix,
|
||||||
|
sys.base_exec_prefix,
|
||||||
|
]
|
||||||
|
]:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"The path {self.prefix} trying to remove is the same as the interpreter "
|
||||||
|
"to run this script. Please choose a different prefix or deactivate the "
|
||||||
|
"current virtual environment."
|
||||||
|
)
|
||||||
|
if self.prefix in [
|
||||||
|
Path(
|
||||||
|
self.base_python(
|
||||||
|
"-c",
|
||||||
|
f"import os, sys; print(os.path.abspath({p}))",
|
||||||
|
capture_output=True,
|
||||||
|
).stdout.strip()
|
||||||
|
).absolute()
|
||||||
|
for p in [
|
||||||
|
"sys.prefix",
|
||||||
|
"sys.exec_prefix",
|
||||||
|
"sys.base_prefix",
|
||||||
|
"sys.base_exec_prefix",
|
||||||
|
]
|
||||||
|
]:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"The Python executable {self.base_executable} trying to remove is the "
|
||||||
|
"same as the interpreter to create the virtual environment. Please choose "
|
||||||
|
"a different prefix or a different Python interpreter."
|
||||||
|
)
|
||||||
|
print(f"Removing existing venv: {self.prefix}")
|
||||||
|
_remove_existing(self.prefix)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Path {self.prefix} already exists.")
|
||||||
|
|
||||||
|
print(f"Creating venv (Python {self.base_python_version()}): {self.prefix}")
|
||||||
|
self.base_python("-m", "venv", str(self.prefix))
|
||||||
|
assert self.is_venv(), "Failed to create virtual environment."
|
||||||
|
(self.prefix / ".gitignore").write_text("*\n", encoding="utf-8")
|
||||||
|
return self.ensure()
|
||||||
|
|
||||||
|
def ensure(self) -> Path:
|
||||||
|
"""Ensure the virtual environment exists."""
|
||||||
|
if not self.is_venv():
|
||||||
|
return self.create(remove_if_exists=True)
|
||||||
|
|
||||||
|
self.pip_install(*self.AGGRESSIVE_UPDATE_PACKAGES, upgrade=True)
|
||||||
|
return self.prefix
|
||||||
|
|
||||||
|
def python(
|
||||||
|
self,
|
||||||
|
*args: str,
|
||||||
|
python: Path | str | None = None,
|
||||||
|
**popen_kwargs: Any,
|
||||||
|
) -> subprocess.CompletedProcess[str]:
|
||||||
|
"""Run a Python command in the virtual environment."""
|
||||||
|
if python is None:
|
||||||
|
python = self.executable
|
||||||
|
cmd = [str(python), *args]
|
||||||
|
env = popen_kwargs.pop("env", None) or {}
|
||||||
|
return subprocess.run(
|
||||||
|
cmd,
|
||||||
|
check=True,
|
||||||
|
text=True,
|
||||||
|
encoding="utf-8",
|
||||||
|
env={**self._env, **env},
|
||||||
|
**popen_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def base_python(
|
||||||
|
self,
|
||||||
|
*args: str,
|
||||||
|
**popen_kwargs: Any,
|
||||||
|
) -> subprocess.CompletedProcess[str]:
|
||||||
|
"""Run a Python command in the base environment."""
|
||||||
|
return self.python(*args, python=self.base_executable, **popen_kwargs)
|
||||||
|
|
||||||
|
def python_version(self, *, python: Path | str | None = None) -> str:
|
||||||
|
"""Get the Python version for the virtual environment."""
|
||||||
|
return self.python(
|
||||||
|
"-c",
|
||||||
|
(
|
||||||
|
"import sys; print('{0.major}.{0.minor}.{0.micro}{1}'."
|
||||||
|
"format(sys.version_info, getattr(sys, 'abiflags', '')))"
|
||||||
|
),
|
||||||
|
python=python,
|
||||||
|
capture_output=True,
|
||||||
|
).stdout.strip()
|
||||||
|
|
||||||
|
def base_python_version(self) -> str:
|
||||||
|
"""Get the Python version for the base environment."""
|
||||||
|
return self.python_version(python=self.base_executable)
|
||||||
|
|
||||||
|
def pip(self, *args: str, **popen_kwargs: Any) -> subprocess.CompletedProcess[str]:
|
||||||
|
"""Run a pip command in the virtual environment."""
|
||||||
|
return self.python("-m", "pip", *args, **popen_kwargs)
|
||||||
|
|
||||||
|
@timed("Installing packages")
|
||||||
|
def pip_install(
|
||||||
|
self,
|
||||||
|
*packages: str,
|
||||||
|
prerelease: bool = False,
|
||||||
|
upgrade: bool = False,
|
||||||
|
**popen_kwargs: Any,
|
||||||
|
) -> subprocess.CompletedProcess[str]:
|
||||||
|
"""Run a pip install command in the virtual environment."""
|
||||||
|
if upgrade:
|
||||||
|
args = ["--upgrade", *packages]
|
||||||
|
verb = "Upgrading"
|
||||||
|
else:
|
||||||
|
args = list(packages)
|
||||||
|
verb = "Installing"
|
||||||
|
if prerelease:
|
||||||
|
args = ["--pre", *args]
|
||||||
|
print(
|
||||||
|
f"{verb} package(s) ({self.pip_source.index_url}): "
|
||||||
|
f"{', '.join(map(os.path.basename, packages))}"
|
||||||
|
)
|
||||||
|
return self.pip("install", *args, **popen_kwargs)
|
||||||
|
|
||||||
|
@timed("Downloading packages")
|
||||||
|
def pip_download(
|
||||||
|
self,
|
||||||
|
*packages: str,
|
||||||
|
prerelease: bool = False,
|
||||||
|
**popen_kwargs: Any,
|
||||||
|
) -> list[Path]:
|
||||||
|
"""Download a package in the virtual environment."""
|
||||||
|
tmpdir = tempfile.TemporaryDirectory(prefix="pip-download-")
|
||||||
|
atexit.register(tmpdir.cleanup)
|
||||||
|
tempdir = Path(tmpdir.name).absolute()
|
||||||
|
print(
|
||||||
|
f"Downloading package(s) ({self.pip_source.index_url}): "
|
||||||
|
f"{', '.join(packages)}"
|
||||||
|
)
|
||||||
|
if prerelease:
|
||||||
|
args = ["--pre", *packages]
|
||||||
|
else:
|
||||||
|
args = list(packages)
|
||||||
|
self.pip("download", "--dest", str(tempdir), *args, **popen_kwargs)
|
||||||
|
files = list(tempdir.iterdir())
|
||||||
|
print(f"Downloaded {len(files)} file(s) to {tempdir}:")
|
||||||
|
for file in files:
|
||||||
|
print(f" - {file.name}")
|
||||||
|
return files
|
||||||
|
|
||||||
|
def wheel(
|
||||||
|
self,
|
||||||
|
*args: str,
|
||||||
|
**popen_kwargs: Any,
|
||||||
|
) -> subprocess.CompletedProcess[str]:
|
||||||
|
"""Run a wheel command in the virtual environment."""
|
||||||
|
return self.python("-m", "wheel", *args, **popen_kwargs)
|
||||||
|
|
||||||
|
@timed("Unpacking wheel file")
|
||||||
|
def wheel_unpack(
|
||||||
|
self,
|
||||||
|
wheel: Path | str,
|
||||||
|
dest: Path | str,
|
||||||
|
**popen_kwargs: Any,
|
||||||
|
) -> subprocess.CompletedProcess[str]:
|
||||||
|
"""Unpack a wheel into a directory."""
|
||||||
|
wheel = Path(wheel).absolute()
|
||||||
|
dest = Path(dest).absolute()
|
||||||
|
assert wheel.is_file() and wheel.suffix.lower() == ".whl"
|
||||||
|
return self.wheel("unpack", "--dest", str(dest), str(wheel), **popen_kwargs)
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def extracted_wheel(self, wheel: Path | str) -> Generator[Path]:
|
||||||
|
"""Download and extract a wheel into a temporary directory."""
|
||||||
|
with tempfile.TemporaryDirectory(prefix="wheel-") as tempdir:
|
||||||
|
self.wheel_unpack(wheel, tempdir)
|
||||||
|
subdirs = [p for p in Path(tempdir).absolute().iterdir() if p.is_dir()]
|
||||||
|
if len(subdirs) != 1:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Expected exactly one directory in {tempdir}, "
|
||||||
|
f"got {[str(d) for d in subdirs]}."
|
||||||
|
)
|
||||||
|
yield subdirs[0]
|
||||||
|
|
||||||
|
|
||||||
def git(*args: str) -> list[str]:
|
def git(*args: str) -> list[str]:
|
||||||
return ["git", "-C", str(REPO_ROOT), *args]
|
return ["git", "-C", str(REPO_ROOT), *args]
|
||||||
|
|
||||||
@ -134,7 +490,10 @@ def logging_record_argv() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def logging_record_exception(e: BaseException) -> None:
|
def logging_record_exception(e: BaseException) -> None:
|
||||||
(logging_run_dir() / "exception").write_text(type(e).__name__, encoding="utf-8")
|
text = f"{type(e).__name__}: {e}"
|
||||||
|
if isinstance(e, subprocess.CalledProcessError):
|
||||||
|
text += f"\n\nstdout: {e.stdout}\n\nstderr: {e.stderr}"
|
||||||
|
(logging_run_dir() / "exception").write_text(text, encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
def logging_rotate() -> None:
|
def logging_rotate() -> None:
|
||||||
@ -156,7 +515,7 @@ def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, N
|
|||||||
stderr (INFO) and file handler (DEBUG).
|
stderr (INFO) and file handler (DEBUG).
|
||||||
"""
|
"""
|
||||||
formatter = Formatter(fmt="%(levelname)s: %(message)s", datefmt="")
|
formatter = Formatter(fmt="%(levelname)s: %(message)s", datefmt="")
|
||||||
root_logger = logging.getLogger("conda-pytorch")
|
root_logger = logging.getLogger("pytorch-nightly")
|
||||||
root_logger.setLevel(logging.DEBUG)
|
root_logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
console_handler = logging.StreamHandler()
|
console_handler = logging.StreamHandler()
|
||||||
@ -213,161 +572,13 @@ def check_branch(subcommand: str, branch: str | None) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def check_conda_env_exists(name: str | None = None, prefix: str | None = None) -> bool:
|
|
||||||
"""Checks that the conda environment exists."""
|
|
||||||
if name is not None and prefix is not None:
|
|
||||||
raise ValueError("Cannot specify both --name and --prefix")
|
|
||||||
if name is None and prefix is None:
|
|
||||||
raise ValueError("Must specify either --name or --prefix")
|
|
||||||
|
|
||||||
try:
|
|
||||||
cmd = ["conda", "info", "--envs"]
|
|
||||||
output = subprocess.check_output(cmd, text=True, encoding="utf-8")
|
|
||||||
except subprocess.CalledProcessError:
|
|
||||||
logger = cast(logging.Logger, LOGGER)
|
|
||||||
logger.warning("Failed to list conda environments", exc_info=True)
|
|
||||||
return False
|
|
||||||
|
|
||||||
if name is not None:
|
|
||||||
return len(re.findall(rf"^{name}\s+", output, flags=re.MULTILINE)) > 0
|
|
||||||
assert prefix is not None
|
|
||||||
prefix = Path(prefix).absolute()
|
|
||||||
return len(re.findall(rf"\s+{prefix}$", output, flags=re.MULTILINE)) > 0
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def timer(logger: logging.Logger, prefix: str) -> Iterator[None]:
|
|
||||||
"""Timed context manager"""
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
yield
|
|
||||||
logger.info("%s took %.3f [s]", prefix, time.perf_counter() - start_time)
|
|
||||||
|
|
||||||
|
|
||||||
F = TypeVar("F", bound=Callable[..., Any])
|
|
||||||
|
|
||||||
|
|
||||||
def timed(prefix: str) -> Callable[[F], F]:
|
|
||||||
"""Decorator for timing functions"""
|
|
||||||
|
|
||||||
def dec(f: F) -> F:
|
|
||||||
@functools.wraps(f)
|
|
||||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
||||||
logger = cast(logging.Logger, LOGGER)
|
|
||||||
logger.info(prefix)
|
|
||||||
with timer(logger, prefix):
|
|
||||||
return f(*args, **kwargs)
|
|
||||||
|
|
||||||
return cast(F, wrapper)
|
|
||||||
|
|
||||||
return dec
|
|
||||||
|
|
||||||
|
|
||||||
def _make_channel_args(
|
|
||||||
channels: Iterable[str] = ("pytorch-nightly",),
|
|
||||||
override_channels: bool = False,
|
|
||||||
) -> list[str]:
|
|
||||||
args = []
|
|
||||||
for channel in channels:
|
|
||||||
args.extend(["--channel", channel])
|
|
||||||
if override_channels:
|
|
||||||
args.append("--override-channels")
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
@timed("Solving conda environment")
|
|
||||||
def conda_solve(
|
|
||||||
specs: Iterable[str],
|
|
||||||
*,
|
|
||||||
name: str | None = None,
|
|
||||||
prefix: str | None = None,
|
|
||||||
channels: Iterable[str] = ("pytorch-nightly",),
|
|
||||||
override_channels: bool = False,
|
|
||||||
) -> tuple[list[str], str, str, bool, list[str]]:
|
|
||||||
"""Performs the conda solve and splits the deps from the package."""
|
|
||||||
# compute what environment to use
|
|
||||||
if prefix is not None:
|
|
||||||
existing_env = True
|
|
||||||
env_opts = ["--prefix", prefix]
|
|
||||||
elif name is not None:
|
|
||||||
existing_env = True
|
|
||||||
env_opts = ["--name", name]
|
|
||||||
else:
|
|
||||||
# create new environment
|
|
||||||
existing_env = False
|
|
||||||
env_opts = ["--name", DEFAULT_ENV_NAME]
|
|
||||||
# run solve
|
|
||||||
if existing_env:
|
|
||||||
cmd = [
|
|
||||||
"conda",
|
|
||||||
"install",
|
|
||||||
"--yes",
|
|
||||||
"--dry-run",
|
|
||||||
"--json",
|
|
||||||
*env_opts,
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
cmd = [
|
|
||||||
"conda",
|
|
||||||
"create",
|
|
||||||
"--yes",
|
|
||||||
"--dry-run",
|
|
||||||
"--json",
|
|
||||||
"--name",
|
|
||||||
"__pytorch__",
|
|
||||||
]
|
|
||||||
channel_args = _make_channel_args(
|
|
||||||
channels=channels,
|
|
||||||
override_channels=override_channels,
|
|
||||||
)
|
|
||||||
cmd.extend(channel_args)
|
|
||||||
cmd.extend(specs)
|
|
||||||
stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
|
|
||||||
# parse solution
|
|
||||||
solve = json.loads(stdout)
|
|
||||||
link = solve["actions"]["LINK"]
|
|
||||||
deps = []
|
|
||||||
pytorch, platform = "", ""
|
|
||||||
for pkg in link:
|
|
||||||
url = URL_FORMAT.format(**pkg)
|
|
||||||
if pkg["name"] == "pytorch":
|
|
||||||
pytorch = url
|
|
||||||
platform = pkg["platform"]
|
|
||||||
else:
|
|
||||||
deps.append(url)
|
|
||||||
assert pytorch, "PyTorch package not found in solve"
|
|
||||||
assert platform, "Platform not found in solve"
|
|
||||||
return deps, pytorch, platform, existing_env, env_opts
|
|
||||||
|
|
||||||
|
|
||||||
@timed("Installing dependencies")
|
@timed("Installing dependencies")
|
||||||
def deps_install(deps: list[str], existing_env: bool, env_opts: list[str]) -> None:
|
def install_packages(venv: Venv, packages: Iterable[str]) -> None:
|
||||||
"""Install dependencies to deps environment"""
|
"""Install dependencies to deps environment"""
|
||||||
if not existing_env:
|
# install packages
|
||||||
# first remove previous pytorch-deps env
|
packages = list(dict.fromkeys(packages))
|
||||||
if check_conda_env_exists(name=DEFAULT_ENV_NAME):
|
if packages:
|
||||||
cmd = ["conda", "env", "remove", "--yes", *env_opts]
|
venv.pip_install(*packages)
|
||||||
subprocess.check_output(cmd)
|
|
||||||
# install new deps
|
|
||||||
install_command = "install" if existing_env else "create"
|
|
||||||
cmd = ["conda", install_command, "--yes", "--no-deps", *env_opts, *deps]
|
|
||||||
subprocess.check_call(cmd)
|
|
||||||
|
|
||||||
|
|
||||||
@timed("Installing pytorch nightly binaries")
|
|
||||||
def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]:
|
|
||||||
"""Install pytorch into a temporary directory"""
|
|
||||||
pytorch_dir = tempfile.TemporaryDirectory(prefix="conda-pytorch-")
|
|
||||||
cmd = ["conda", "create", "--yes", "--no-deps", f"--prefix={pytorch_dir.name}", url]
|
|
||||||
subprocess.check_call(cmd)
|
|
||||||
return pytorch_dir
|
|
||||||
|
|
||||||
|
|
||||||
def _site_packages(dirname: str, platform: str) -> Path:
|
|
||||||
if platform.startswith("win"):
|
|
||||||
template = os.path.join(dirname, "Lib", "site-packages")
|
|
||||||
else:
|
|
||||||
template = os.path.join(dirname, "lib", "python*.*", "site-packages")
|
|
||||||
return Path(next(glob.iglob(template))).absolute()
|
|
||||||
|
|
||||||
|
|
||||||
def _ensure_commit(git_sha1: str) -> None:
|
def _ensure_commit(git_sha1: str) -> None:
|
||||||
@ -437,7 +648,7 @@ def _get_listing_linux(source_dir: Path) -> list[Path]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_listing_osx(source_dir: Path) -> list[Path]:
|
def _get_listing_macos(source_dir: Path) -> list[Path]:
|
||||||
# oddly, these are .so files even on Mac
|
# oddly, these are .so files even on Mac
|
||||||
return list(
|
return list(
|
||||||
itertools.chain(
|
itertools.chain(
|
||||||
@ -447,7 +658,7 @@ def _get_listing_osx(source_dir: Path) -> list[Path]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_listing_win(source_dir: Path) -> list[Path]:
|
def _get_listing_windows(source_dir: Path) -> list[Path]:
|
||||||
return list(
|
return list(
|
||||||
itertools.chain(
|
itertools.chain(
|
||||||
source_dir.glob("*.pyd"),
|
source_dir.glob("*.pyd"),
|
||||||
@ -468,15 +679,15 @@ def _find_missing_pyi(source_dir: Path, target_dir: Path) -> list[Path]:
|
|||||||
return missing_pyis
|
return missing_pyis
|
||||||
|
|
||||||
|
|
||||||
def _get_listing(source_dir: Path, target_dir: Path, platform: str) -> list[Path]:
|
def _get_listing(source_dir: Path, target_dir: Path) -> list[Path]:
|
||||||
if platform.startswith("linux"):
|
if LINUX:
|
||||||
listing = _get_listing_linux(source_dir)
|
listing = _get_listing_linux(source_dir)
|
||||||
elif platform.startswith("osx"):
|
elif MACOS:
|
||||||
listing = _get_listing_osx(source_dir)
|
listing = _get_listing_macos(source_dir)
|
||||||
elif platform.startswith("win"):
|
elif WINDOWS:
|
||||||
listing = _get_listing_win(source_dir)
|
listing = _get_listing_windows(source_dir)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Platform {platform!r} not recognized")
|
raise RuntimeError(f"Platform {platform_system()!r} not recognized")
|
||||||
listing.extend(_find_missing_pyi(source_dir, target_dir))
|
listing.extend(_find_missing_pyi(source_dir, target_dir))
|
||||||
listing.append(source_dir / "version.py")
|
listing.append(source_dir / "version.py")
|
||||||
listing.append(source_dir / "testing" / "_internal" / "generated")
|
listing.append(source_dir / "testing" / "_internal" / "generated")
|
||||||
@ -532,14 +743,14 @@ def _link_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None
|
|||||||
|
|
||||||
|
|
||||||
@timed("Moving nightly files into repo")
|
@timed("Moving nightly files into repo")
|
||||||
def move_nightly_files(site_dir: Path, platform: str) -> None:
|
def move_nightly_files(site_dir: Path) -> None:
|
||||||
"""Moves PyTorch files from temporary installed location to repo."""
|
"""Moves PyTorch files from temporary installed location to repo."""
|
||||||
# get file listing
|
# get file listing
|
||||||
source_dir = site_dir / "torch"
|
source_dir = site_dir / "torch"
|
||||||
target_dir = REPO_ROOT / "torch"
|
target_dir = REPO_ROOT / "torch"
|
||||||
listing = _get_listing(source_dir, target_dir, platform)
|
listing = _get_listing(source_dir, target_dir)
|
||||||
# copy / link files
|
# copy / link files
|
||||||
if platform.startswith("win"):
|
if WINDOWS:
|
||||||
_copy_files(listing, source_dir, target_dir)
|
_copy_files(listing, source_dir, target_dir)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
@ -548,31 +759,10 @@ def move_nightly_files(site_dir: Path, platform: str) -> None:
|
|||||||
_copy_files(listing, source_dir, target_dir)
|
_copy_files(listing, source_dir, target_dir)
|
||||||
|
|
||||||
|
|
||||||
def _available_envs() -> dict[str, str]:
|
|
||||||
cmd = ["conda", "env", "list"]
|
|
||||||
stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
|
|
||||||
envs = {}
|
|
||||||
for line in map(str.strip, stdout.splitlines()):
|
|
||||||
if not line or line.startswith("#"):
|
|
||||||
continue
|
|
||||||
parts = line.split()
|
|
||||||
if len(parts) == 1:
|
|
||||||
# unnamed env
|
|
||||||
continue
|
|
||||||
envs[parts[0]] = parts[-1]
|
|
||||||
return envs
|
|
||||||
|
|
||||||
|
|
||||||
@timed("Writing pytorch-nightly.pth")
|
@timed("Writing pytorch-nightly.pth")
|
||||||
def write_pth(env_opts: list[str], platform: str) -> None:
|
def write_pth(venv: Venv) -> None:
|
||||||
"""Writes Python path file for this dir."""
|
"""Writes Python path file for this dir."""
|
||||||
env_type, env_dir = env_opts
|
(venv.site_packages() / "pytorch-nightly.pth").write_text(
|
||||||
if env_type == "--name":
|
|
||||||
# have to find directory
|
|
||||||
envs = _available_envs()
|
|
||||||
env_dir = envs[env_dir]
|
|
||||||
site_dir = _site_packages(env_dir, platform)
|
|
||||||
(site_dir / "pytorch-nightly.pth").write_text(
|
|
||||||
"# This file was autogenerated by PyTorch's tools/nightly.py\n"
|
"# This file was autogenerated by PyTorch's tools/nightly.py\n"
|
||||||
"# Please delete this file if you no longer need the following development\n"
|
"# Please delete this file if you no longer need the following development\n"
|
||||||
"# version of PyTorch to be importable\n"
|
"# version of PyTorch to be importable\n"
|
||||||
@ -582,50 +772,66 @@ def write_pth(env_opts: list[str], platform: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def install(
|
def install(
|
||||||
specs: Iterable[str],
|
|
||||||
*,
|
*,
|
||||||
logger: logging.Logger,
|
venv: Venv,
|
||||||
|
packages: Iterable[str],
|
||||||
subcommand: str = "checkout",
|
subcommand: str = "checkout",
|
||||||
branch: str | None = None,
|
branch: str | None = None,
|
||||||
name: str | None = None,
|
logger: logging.Logger,
|
||||||
prefix: str | None = None,
|
|
||||||
channels: Iterable[str] = ("pytorch-nightly",),
|
|
||||||
override_channels: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Development install of PyTorch"""
|
"""Development install of PyTorch"""
|
||||||
specs = list(specs)
|
use_existing = subcommand == "checkout"
|
||||||
deps, pytorch, platform, existing_env, env_opts = conda_solve(
|
if use_existing:
|
||||||
specs=specs,
|
venv.ensure()
|
||||||
name=name,
|
else:
|
||||||
prefix=prefix,
|
venv.create(remove_if_exists=True)
|
||||||
channels=channels,
|
|
||||||
override_channels=override_channels,
|
|
||||||
)
|
|
||||||
if deps:
|
|
||||||
deps_install(deps, existing_env, env_opts)
|
|
||||||
|
|
||||||
with pytorch_install(pytorch) as pytorch_dir:
|
packages = [p for p in packages if p != "torch"]
|
||||||
site_dir = _site_packages(pytorch_dir, platform)
|
|
||||||
|
dependencies = venv.pip_download("torch", prerelease=True)
|
||||||
|
torch_wheel = [
|
||||||
|
dep
|
||||||
|
for dep in dependencies
|
||||||
|
if dep.name.startswith("torch-") and dep.name.endswith(".whl")
|
||||||
|
]
|
||||||
|
if len(torch_wheel) != 1:
|
||||||
|
raise RuntimeError(f"Expected exactly one torch wheel, got {torch_wheel}")
|
||||||
|
torch_wheel = torch_wheel[0]
|
||||||
|
dependencies = [deps for deps in dependencies if deps != torch_wheel]
|
||||||
|
|
||||||
|
install_packages(venv, [*packages, *map(str, dependencies)])
|
||||||
|
|
||||||
|
with venv.extracted_wheel(torch_wheel) as wheel_site_dir:
|
||||||
if subcommand == "checkout":
|
if subcommand == "checkout":
|
||||||
checkout_nightly_version(cast(str, branch), site_dir)
|
checkout_nightly_version(cast(str, branch), wheel_site_dir)
|
||||||
elif subcommand == "pull":
|
elif subcommand == "pull":
|
||||||
pull_nightly_version(site_dir)
|
pull_nightly_version(wheel_site_dir)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Subcommand {subcommand} must be one of: checkout, pull.")
|
raise ValueError(f"Subcommand {subcommand} must be one of: checkout, pull.")
|
||||||
move_nightly_files(site_dir, platform)
|
move_nightly_files(wheel_site_dir)
|
||||||
|
|
||||||
write_pth(env_opts, platform)
|
write_pth(venv)
|
||||||
logger.info(
|
logger.info(
|
||||||
"-------\nPyTorch Development Environment set up!\nPlease activate to "
|
"-------\n"
|
||||||
"enable this environment:\n $ conda activate %s",
|
"PyTorch Development Environment set up!\n"
|
||||||
env_opts[1],
|
"Please activate to enable this environment:\n\n"
|
||||||
|
" $ %s",
|
||||||
|
venv.activate_command,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def make_parser() -> argparse.ArgumentParser:
|
def make_parser() -> argparse.ArgumentParser:
|
||||||
p = argparse.ArgumentParser()
|
def find_executable(name: str) -> Path:
|
||||||
|
executable = shutil.which(name)
|
||||||
|
if executable is None:
|
||||||
|
raise argparse.ArgumentTypeError(
|
||||||
|
f"Could not find executable {name} in PATH."
|
||||||
|
)
|
||||||
|
return Path(executable).absolute()
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
# subcommands
|
# subcommands
|
||||||
subcmd = p.add_subparsers(dest="subcmd", help="subcommand to execute")
|
subcmd = parser.add_subparsers(dest="subcmd", help="subcommand to execute")
|
||||||
checkout = subcmd.add_parser("checkout", help="checkout a new branch")
|
checkout = subcmd.add_parser("checkout", help="checkout a new branch")
|
||||||
checkout.add_argument(
|
checkout.add_argument(
|
||||||
"-b",
|
"-b",
|
||||||
@ -642,19 +848,24 @@ def make_parser() -> argparse.ArgumentParser:
|
|||||||
subparsers = [checkout, pull]
|
subparsers = [checkout, pull]
|
||||||
for subparser in subparsers:
|
for subparser in subparsers:
|
||||||
subparser.add_argument(
|
subparser.add_argument(
|
||||||
"-n",
|
"--python",
|
||||||
"--name",
|
"--base-executable",
|
||||||
help="Name of environment",
|
type=find_executable,
|
||||||
dest="name",
|
help=(
|
||||||
|
"Path to Python interpreter to use for creating the virtual environment. "
|
||||||
|
"Defaults to the interpreter running this script."
|
||||||
|
),
|
||||||
|
dest="base_executable",
|
||||||
default=None,
|
default=None,
|
||||||
metavar="ENVIRONMENT",
|
metavar="PYTHON",
|
||||||
)
|
)
|
||||||
subparser.add_argument(
|
subparser.add_argument(
|
||||||
"-p",
|
"-p",
|
||||||
"--prefix",
|
"--prefix",
|
||||||
help="Full path to environment location (i.e. prefix)",
|
type=lambda p: Path(p).absolute(),
|
||||||
|
help='Path to virtual environment directory (e.g. "./venv")',
|
||||||
dest="prefix",
|
dest="prefix",
|
||||||
default=None,
|
default=str(DEFAULT_VENV_DIR),
|
||||||
metavar="PATH",
|
metavar="PATH",
|
||||||
)
|
)
|
||||||
subparser.add_argument(
|
subparser.add_argument(
|
||||||
@ -665,25 +876,6 @@ def make_parser() -> argparse.ArgumentParser:
|
|||||||
default=False,
|
default=False,
|
||||||
action="store_true",
|
action="store_true",
|
||||||
)
|
)
|
||||||
subparser.add_argument(
|
|
||||||
"--override-channels",
|
|
||||||
help="Do not search default or .condarc channels.",
|
|
||||||
dest="override_channels",
|
|
||||||
default=False,
|
|
||||||
action="store_true",
|
|
||||||
)
|
|
||||||
subparser.add_argument(
|
|
||||||
"-c",
|
|
||||||
"--channel",
|
|
||||||
help=(
|
|
||||||
"Additional channel to search for packages. "
|
|
||||||
"'pytorch-nightly' will always be prepended to this list."
|
|
||||||
),
|
|
||||||
dest="channels",
|
|
||||||
action="append",
|
|
||||||
metavar="CHANNEL",
|
|
||||||
)
|
|
||||||
if platform_system() in {"Linux", "Windows"}:
|
|
||||||
subparser.add_argument(
|
subparser.add_argument(
|
||||||
"--cuda",
|
"--cuda",
|
||||||
help=(
|
help=(
|
||||||
@ -695,42 +887,60 @@ def make_parser() -> argparse.ArgumentParser:
|
|||||||
default=argparse.SUPPRESS,
|
default=argparse.SUPPRESS,
|
||||||
metavar="VERSION",
|
metavar="VERSION",
|
||||||
)
|
)
|
||||||
return p
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def main(args: Sequence[str] | None = None) -> None:
|
def parse_arguments() -> argparse.Namespace:
|
||||||
|
parser = make_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.branch = getattr(args, "branch", None)
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
"""Main entry point"""
|
"""Main entry point"""
|
||||||
global LOGGER
|
global LOGGER
|
||||||
p = make_parser()
|
args = parse_arguments()
|
||||||
ns = p.parse_args(args)
|
status = check_branch(args.subcmd, args.branch)
|
||||||
ns.branch = getattr(ns, "branch", None)
|
|
||||||
status = check_branch(ns.subcmd, ns.branch)
|
|
||||||
if status:
|
if status:
|
||||||
sys.exit(status)
|
sys.exit(status)
|
||||||
specs = list(SPECS_TO_INSTALL)
|
|
||||||
channels = ["pytorch-nightly"]
|
pip_source = None
|
||||||
if hasattr(ns, "cuda"):
|
if hasattr(args, "cuda"):
|
||||||
if ns.cuda is not None:
|
available_sources = {
|
||||||
specs.append(f"pytorch-cuda={ns.cuda}")
|
src.name[len("cuda-") :]: src
|
||||||
|
for src in PIP_SOURCES.values()
|
||||||
|
if src.name.startswith("cuda-") and PLATFORM in src.supported_platforms
|
||||||
|
}
|
||||||
|
if not available_sources:
|
||||||
|
print(f"No CUDA versions available on platform {PLATFORM}.")
|
||||||
|
sys.exit(1)
|
||||||
|
if args.cuda is not None:
|
||||||
|
pip_source = available_sources.get(args.cuda)
|
||||||
|
if pip_source is None:
|
||||||
|
print(
|
||||||
|
f"CUDA {args.cuda} is not available on platform {PLATFORM}. "
|
||||||
|
f"Available version(s): {', '.join(sorted(available_sources, key=Version))}"
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
else:
|
else:
|
||||||
specs.append("pytorch-cuda")
|
pip_source = available_sources[max(available_sources, key=Version)]
|
||||||
specs.append("pytorch-mutex=*=*cuda*")
|
|
||||||
channels.append("nvidia")
|
|
||||||
else:
|
else:
|
||||||
specs.append("pytorch-mutex=*=*cpu*")
|
pip_source = PIP_SOURCES["cpu"] # always available
|
||||||
if ns.channels:
|
|
||||||
channels.extend(ns.channels)
|
with logging_manager(debug=args.verbose) as logger:
|
||||||
with logging_manager(debug=ns.verbose) as logger:
|
|
||||||
LOGGER = logger
|
LOGGER = logger
|
||||||
|
venv = Venv(
|
||||||
|
prefix=args.prefix,
|
||||||
|
pip_source=pip_source,
|
||||||
|
base_executable=args.base_executable,
|
||||||
|
)
|
||||||
install(
|
install(
|
||||||
specs=specs,
|
venv=venv,
|
||||||
subcommand=ns.subcmd,
|
packages=PACKAGES_TO_INSTALL,
|
||||||
branch=ns.branch,
|
subcommand=args.subcmd,
|
||||||
name=ns.name,
|
branch=args.branch,
|
||||||
prefix=ns.prefix,
|
|
||||||
logger=logger,
|
logger=logger,
|
||||||
channels=channels,
|
|
||||||
override_channels=ns.override_channels,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user