#!/usr/bin/env python3
# Much of the logging code here was forked from https://github.com/ezyang/ghstack
# Copyright (c) Edward Z. Yang <ezyang@mit.edu>
"""Checks out the nightly development version of PyTorch and installs pre-built
binaries into the repo.

You can use this script to check out a new nightly branch with the following::

    $ ./tools/nightly.py checkout -b my-nightly-branch
    $ conda activate pytorch-deps

Or if you would like to re-use an existing conda environment, you can pass in
the regular environment parameters (--name or --prefix)::

    $ ./tools/nightly.py checkout -b my-nightly-branch -n my-env
    $ conda activate my-env

To install the nightly binaries built with CUDA, you can pass in the flag --cuda::

    $ ./tools/nightly.py checkout -b my-nightly-branch --cuda
    $ conda activate pytorch-deps

You can also use this tool to pull the nightly commits into the current branch.
This can be done with::

    $ ./tools/nightly.py pull -n my-env
    $ conda activate my-env

Pulling will reinstall the conda dependencies as well as the nightly binaries into
the repo directory.
"""

from __future__ import annotations

import argparse
import contextlib
import functools
import glob
import itertools
import json
import logging
import os
import re
import shutil
import subprocess
import sys
import tempfile
import time
import uuid
from ast import literal_eval
from datetime import datetime
from pathlib import Path
from platform import system as platform_system
from typing import Any, Callable, cast, Generator, Iterable, Iterator, Sequence, TypeVar


REPO_ROOT = Path(__file__).absolute().parent.parent
GITHUB_REMOTE_URL = "https://github.com/pytorch/pytorch.git"
SPECS_TO_INSTALL = ("pytorch", "mypy", "pytest", "hypothesis", "ipython", "sphinx")

LOGGER: logging.Logger | None = None
URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
SHA1_RE = re.compile(r"(?P<sha1>[0-9a-fA-F]{40})")
USERNAME_PASSWORD_RE = re.compile(r":\/\/(.*?)\@")
LOG_DIRNAME_RE = re.compile(
    r"(?P<datetime>\d{4}-\d\d-\d\d_\d\dh\d\dm\d\ds)_"
    r"(?P<uuid>[0-9a-f]{8}-(?:[0-9a-f]{4}-){3}[0-9a-f]{12})",
)


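# The Formatter below strips credentials from any URL that reaches the logs.
# Illustrative example (not taken from real output): _filter() turns
#   "https://user:token@github.com/pytorch/pytorch.git"
# into
#   "https://<USERNAME>:<PASSWORD>@github.com/pytorch/pytorch.git"
# before the record is written to the console or the log file.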
class Formatter(logging.Formatter):
    redactions: dict[str, str]

    def __init__(self, fmt: str | None = None, datefmt: str | None = None) -> None:
        super().__init__(fmt, datefmt)
        self.redactions = {}

    # Remove sensitive information from URLs
    def _filter(self, s: str) -> str:
        s = USERNAME_PASSWORD_RE.sub(r"://<USERNAME>:<PASSWORD>@", s)
        for needle, replace in self.redactions.items():
            s = s.replace(needle, replace)
        return s

    def formatMessage(self, record: logging.LogRecord) -> str:
        if record.levelno == logging.INFO or record.levelno == logging.DEBUG:
            # Log INFO/DEBUG without any adornment
            return record.getMessage()
        else:
            # I'm not sure why, but formatMessage doesn't show up
            # even though it's in the typeshed for Python >3
            return super().formatMessage(record)

    def format(self, record: logging.LogRecord) -> str:
        return self._filter(super().format(record))

    def redact(self, needle: str, replace: str = "<REDACTED>") -> None:
        """Redact specific strings; e.g., authorization tokens. This won't
        retroactively redact stuff you've already leaked, so make sure
        you redact things as soon as possible.
        """
        # Don't redact empty strings; this will lead to something
        # that looks like s<REDACTED>t<REDACTED>r<REDACTED>...
        if needle == "":
            return
        self.redactions[needle] = replace


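# The git() helper below prefixes every invocation with "-C REPO_ROOT" so the tool
# works regardless of the current working directory.  Illustrative example, assuming
# the checkout lives at /home/user/pytorch:
#   git("status", "--porcelain") -> ["git", "-C", "/home/user/pytorch", "status", "--porcelain"]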
def git(*args: str) -> list[str]:
    return ["git", "-C", str(REPO_ROOT), *args]


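# The logging helpers below keep per-run logs under REPO_ROOT/nightly/log.  Each run
# gets its own directory named "<timestamp>_<uuid1>", e.g. (illustrative only)
#   nightly/log/2024-08-01_12h34m56s_9a7b1c3e-.../
# which holds the recorded argv, the exception name if one was raised, and nightly.log.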
@functools.lru_cache
def logging_base_dir() -> Path:
    base_dir = REPO_ROOT / "nightly" / "log"
    base_dir.mkdir(parents=True, exist_ok=True)
    return base_dir


@functools.lru_cache
def logging_run_dir() -> Path:
    base_dir = logging_base_dir()
    cur_dir = base_dir / f"{datetime.now().strftime(DATETIME_FORMAT)}_{uuid.uuid1()}"
    cur_dir.mkdir(parents=True, exist_ok=True)
    return cur_dir


@functools.lru_cache
def logging_record_argv() -> None:
    s = subprocess.list2cmdline(sys.argv)
    (logging_run_dir() / "argv").write_text(s, encoding="utf-8")


def logging_record_exception(e: BaseException) -> None:
    (logging_run_dir() / "exception").write_text(type(e).__name__, encoding="utf-8")


def logging_rotate() -> None:
    log_base = logging_base_dir()
    old_logs = sorted(log_base.iterdir(), reverse=True)
    for stale_log in old_logs[1000:]:
        # Sanity check that it looks like a log
        if LOG_DIRNAME_RE.fullmatch(stale_log.name) is not None:
            shutil.rmtree(stale_log)


@contextlib.contextmanager
def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, None]:
    """Setup logging. If a failure starts here we won't
    be able to save the user in a reasonable way.

    Logging structure: there is one logger (the root logger)
    and it processes all events. There are two handlers:
    stderr (INFO) and file handler (DEBUG).
    """
    formatter = Formatter(fmt="%(levelname)s: %(message)s", datefmt="")
    root_logger = logging.getLogger("conda-pytorch")
    root_logger.setLevel(logging.DEBUG)

    console_handler = logging.StreamHandler()
    if debug:
        console_handler.setLevel(logging.DEBUG)
    else:
        console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    root_logger.addHandler(console_handler)

    log_file = logging_run_dir() / "nightly.log"

    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)
    root_logger.addHandler(file_handler)
    logging_record_argv()

    try:
        logging_rotate()
        print(f"log file: {log_file}")
        yield root_logger
    except Exception as e:
        logging.exception("Fatal exception")
        logging_record_exception(e)
        print(f"log file: {log_file}")
        sys.exit(1)
    except BaseException as e:
        # You could logging.debug here to suppress the backtrace
        # entirely, but there is no reason to hide it from technically
        # savvy users.
        logging.info("", exc_info=True)
        logging_record_exception(e)
        print(f"log file: {log_file}")
        sys.exit(1)


def check_branch(subcommand: str, branch: str | None) -> str | None:
    """Checks that the branch name can be checked out."""
    if subcommand != "checkout":
        return None
    # first make sure actual branch name was given
    if branch is None:
        return "Branch name to checkout must be supplied with '-b' option"
    # next check that the local repo is clean
    cmd = git("status", "--untracked-files=no", "--porcelain")
    stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
    if stdout.strip():
        return "Need to have clean working tree to checkout!\n\n" + stdout
    # next check that the branch name doesn't already exist
    cmd = git("show-ref", "--verify", "--quiet", f"refs/heads/{branch}")
    p = subprocess.run(cmd, capture_output=True, check=False)  # type: ignore[assignment]
    if not p.returncode:
        return f"Branch {branch!r} already exists"
    return None


@contextlib.contextmanager
def timer(logger: logging.Logger, prefix: str) -> Iterator[None]:
    """Timed context manager"""
    start_time = time.perf_counter()
    yield
    logger.info("%s took %.3f [s]", prefix, time.perf_counter() - start_time)


F = TypeVar("F", bound=Callable[..., Any])


def timed(prefix: str) -> Callable[[F], F]:
    """Decorator for timing functions"""

    def dec(f: F) -> F:
        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            logger = cast(logging.Logger, LOGGER)
            logger.info(prefix)
            with timer(logger, prefix):
                return f(*args, **kwargs)

        return cast(F, wrapper)

    return dec


def _make_channel_args(
    channels: Iterable[str] = ("pytorch-nightly",),
    override_channels: bool = False,
) -> list[str]:
    args = []
    for channel in channels:
        args.extend(["--channel", channel])
    if override_channels:
        args.append("--override-channels")
    return args


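# conda_solve() below drives `conda {install,create} --dry-run --json` and reads only
# a handful of fields from the resulting JSON.  Illustrative (heavily trimmed) shape
# of what it expects:
#   {"actions": {"LINK": [{"name": "pytorch", "base_url": "...", "platform": "linux-64",
#                          "dist_name": "pytorch-<version>-<build>", ...}, ...]}}
# Each LINK entry is turned into a download URL via URL_FORMAT; the "pytorch" entry is
# split off from the rest, which become the deps to install.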
@timed("Solving conda environment")
def conda_solve(
    specs: Iterable[str],
    *,
    name: str | None = None,
    prefix: str | None = None,
    channels: Iterable[str] = ("pytorch-nightly",),
    override_channels: bool = False,
) -> tuple[list[str], str, str, bool, list[str]]:
    """Performs the conda solve and splits the deps from the package."""
    # compute what environment to use
    if prefix is not None:
        existing_env = True
        env_opts = ["--prefix", prefix]
    elif name is not None:
        existing_env = True
        env_opts = ["--name", name]
    else:
        # create new environment
        existing_env = False
        env_opts = ["--name", "pytorch-deps"]
    # run solve
    if existing_env:
        cmd = [
            "conda",
            "install",
            "--yes",
            "--dry-run",
            "--json",
        ]
        cmd.extend(env_opts)
    else:
        cmd = [
            "conda",
            "create",
            "--yes",
            "--dry-run",
            "--json",
            "--name",
            "__pytorch__",
        ]
    channel_args = _make_channel_args(
        channels=channels,
        override_channels=override_channels,
    )
    cmd.extend(channel_args)
    cmd.extend(specs)
    stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
    # parse solution
    solve = json.loads(stdout)
    link = solve["actions"]["LINK"]
    deps = []
    pytorch, platform = "", ""
    for pkg in link:
        url = URL_FORMAT.format(**pkg)
        if pkg["name"] == "pytorch":
            pytorch = url
            platform = pkg["platform"]
        else:
            deps.append(url)
    assert pytorch, "PyTorch package not found in solve"
    assert platform, "Platform not found in solve"
    return deps, pytorch, platform, existing_env, env_opts


@timed("Installing dependencies")
def deps_install(deps: list[str], existing_env: bool, env_opts: list[str]) -> None:
    """Install dependencies to deps environment"""
    if not existing_env:
        # first remove previous pytorch-deps env
        cmd = ["conda", "env", "remove", "--yes", *env_opts]
        subprocess.check_call(cmd)
    # install new deps
    install_command = "install" if existing_env else "create"
    cmd = ["conda", install_command, "--yes", "--no-deps", *env_opts, *deps]
    subprocess.check_call(cmd)


@timed("Installing pytorch nightly binaries")
def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]:
    """Install pytorch into a temporary directory"""
    pytorch_dir = tempfile.TemporaryDirectory(prefix="conda-pytorch-")
    cmd = ["conda", "create", "--yes", "--no-deps", f"--prefix={pytorch_dir.name}", url]
    subprocess.check_call(cmd)
    return pytorch_dir


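# _site_packages() below resolves the site-packages directory inside a conda prefix.
# Illustrative results (the exact Python version is an assumption):
#   Linux/macOS: <prefix>/lib/python3.10/site-packages
#   Windows:     <prefix>\Lib\site-packages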
def _site_packages(dirname: str, platform: str) -> Path:
    if platform.startswith("win"):
        template = os.path.join(dirname, "Lib", "site-packages")
    else:
        template = os.path.join(dirname, "lib", "python*.*", "site-packages")
    return Path(next(glob.iglob(template))).absolute()


def _ensure_commit(git_sha1: str) -> None:
    """Make sure that we actually have the commit locally"""
    cmd = git("cat-file", "-e", git_sha1 + r"^{commit}")
    p = subprocess.run(cmd, capture_output=True, check=False)
    if p.returncode == 0:
        # we have the commit locally
        return
    # we don't have the commit, must fetch
    cmd = git("fetch", GITHUB_REMOTE_URL, git_sha1)
    subprocess.check_call(cmd)


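# _nightly_version() below reads torch.version.git_version from the freshly installed
# nightly package, then inspects that commit's subject line via
# `git show --no-patch --format=%s`.  The subject is expected to contain a 40-character
# SHA-1 (matched by SHA1_RE), which is taken as the commit to check out or merge.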
def _nightly_version(site_dir: Path) -> str:
    # first get the git version from the installed module
    version_file = site_dir / "torch" / "version.py"
    with version_file.open(encoding="utf-8") as f:
        for line in f:
            if not line.startswith("git_version"):
                continue
            git_version = literal_eval(line.partition("=")[2].strip())
            break
        else:
            raise RuntimeError(f"Could not find git_version in {version_file}")

    print(f"Found released git version {git_version}")
    # now cross reference with nightly version
    _ensure_commit(git_version)
    cmd = git("show", "--no-patch", "--format=%s", git_version)
    stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
    m = SHA1_RE.search(stdout)
    if m is None:
        raise RuntimeError(
            f"Could not find nightly release in git history:\n {stdout}"
        )
    nightly_version = m.group("sha1")
    print(f"Found nightly release version {nightly_version}")
    # now checkout nightly version
    _ensure_commit(nightly_version)
    return nightly_version


@timed("Checking out nightly PyTorch")
def checkout_nightly_version(branch: str, site_dir: Path) -> None:
    """Gets the nightly version and then checks it out."""
    nightly_version = _nightly_version(site_dir)
    cmd = git("checkout", "-b", branch, nightly_version)
    subprocess.check_call(cmd)


@timed("Pulling nightly PyTorch")
def pull_nightly_version(site_dir: Path) -> None:
    """Fetches the nightly version and then merges it."""
    nightly_version = _nightly_version(site_dir)
    cmd = git("merge", nightly_version)
    subprocess.check_call(cmd)


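# The _get_listing_* helpers below collect the platform-specific binary artifacts
# (shared libraries / DLLs / .pyd files) from the temporary install; _get_listing()
# then adds any .pyi stubs missing from the repo, plus version.py, the generated
# testing helpers, and the bin/ and include/ directories.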
def _get_listing_linux(source_dir: Path) -> list[Path]:
    return list(
        itertools.chain(
            source_dir.glob("*.so"),
            (source_dir / "lib").glob("*.so"),
            (source_dir / "lib").glob("*.so.*"),
        )
    )


def _get_listing_osx(source_dir: Path) -> list[Path]:
    # oddly, these are .so files even on Mac
    return list(
        itertools.chain(
            source_dir.glob("*.so"),
            (source_dir / "lib").glob("*.dylib"),
        )
    )


def _get_listing_win(source_dir: Path) -> list[Path]:
    return list(
        itertools.chain(
            source_dir.glob("*.pyd"),
            (source_dir / "lib").glob("*.lib"),
            # match all DLLs, not a file literally named ".dll"
            (source_dir / "lib").glob("*.dll"),
        )
    )


def _glob_pyis(d: Path) -> set[str]:
    return {p.relative_to(d).as_posix() for p in d.rglob("*.pyi")}


def _find_missing_pyi(source_dir: Path, target_dir: Path) -> list[Path]:
    source_pyis = _glob_pyis(source_dir)
    target_pyis = _glob_pyis(target_dir)
    missing_pyis = sorted(source_dir / p for p in (source_pyis - target_pyis))
    return missing_pyis


def _get_listing(source_dir: Path, target_dir: Path, platform: str) -> list[Path]:
    if platform.startswith("linux"):
        listing = _get_listing_linux(source_dir)
    elif platform.startswith("osx"):
        listing = _get_listing_osx(source_dir)
    elif platform.startswith("win"):
        listing = _get_listing_win(source_dir)
    else:
        raise RuntimeError(f"Platform {platform!r} not recognized")
    listing.extend(_find_missing_pyi(source_dir, target_dir))
    listing.append(source_dir / "version.py")
    listing.append(source_dir / "testing" / "_internal" / "generated")
    listing.append(source_dir / "bin")
    listing.append(source_dir / "include")
    return listing


def _remove_existing(path: Path) -> None:
    if path.exists():
        if path.is_dir():
            shutil.rmtree(path)
        else:
            path.unlink()


def _move_single(
    src: Path,
    source_dir: Path,
    target_dir: Path,
    mover: Callable[[Path, Path], None],
    verb: str,
) -> None:
    relpath = src.relative_to(source_dir)
    trg = target_dir / relpath
    _remove_existing(trg)
    # move over new files
    if src.is_dir():
        trg.mkdir(parents=True, exist_ok=True)
        for root, dirs, files in os.walk(src):
            relroot = Path(root).relative_to(src)
            for name in files:
                relname = relroot / name
                s = src / relname
                t = trg / relname
                print(f"{verb} {s} -> {t}")
                mover(s, t)
            for name in dirs:
                (trg / relroot / name).mkdir(parents=True, exist_ok=True)
    else:
        print(f"{verb} {src} -> {trg}")
        mover(src, trg)


def _copy_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None:
    for src in listing:
        _move_single(src, source_dir, target_dir, shutil.copy2, "Copying")


def _link_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None:
    for src in listing:
        _move_single(src, source_dir, target_dir, os.link, "Linking")


@timed("Moving nightly files into repo")
def move_nightly_files(site_dir: Path, platform: str) -> None:
    """Moves PyTorch files from temporary installed location to repo."""
    # get file listing
    source_dir = site_dir / "torch"
    target_dir = REPO_ROOT / "torch"
    listing = _get_listing(source_dir, target_dir, platform)
    # copy / link files
    if platform.startswith("win"):
        _copy_files(listing, source_dir, target_dir)
    else:
        try:
            _link_files(listing, source_dir, target_dir)
        except Exception:
            _copy_files(listing, source_dir, target_dir)


def _available_envs() -> dict[str, str]:
    cmd = ["conda", "env", "list"]
    stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
    envs = {}
    for line in map(str.strip, stdout.splitlines()):
        if not line or line.startswith("#"):
            continue
        parts = line.split()
        if len(parts) == 1:
            # unnamed env
            continue
        envs[parts[0]] = parts[-1]
    return envs


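# write_pth() below makes the checkout importable from the target environment by
# dropping a .pth file into its site-packages.  Illustrative result (the path is an
# assumption, not real output):
#   <env>/lib/python3.10/site-packages/pytorch-nightly.pth
# containing the repo root, so `import torch` inside that environment resolves to
# this working tree.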
@timed("Writing pytorch-nightly.pth")
def write_pth(env_opts: list[str], platform: str) -> None:
    """Writes Python path file for this dir."""
    env_type, env_dir = env_opts
    if env_type == "--name":
        # have to find directory
        envs = _available_envs()
        env_dir = envs[env_dir]
    site_dir = _site_packages(env_dir, platform)
    (site_dir / "pytorch-nightly.pth").write_text(
        "# This file was autogenerated by PyTorch's tools/nightly.py\n"
        "# Please delete this file if you no longer need the following development\n"
        "# version of PyTorch to be importable\n"
        f"{REPO_ROOT}\n",
        encoding="utf-8",
    )


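# install() below ties the steps above together: solve the conda environment, install
# the non-pytorch dependencies, install the nightly pytorch package into a throwaway
# prefix, check out (or merge) the matching source commit, hard-link or copy the
# prebuilt binaries into the repo, and finally write the .pth file.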
def install(
    specs: Iterable[str],
    *,
    logger: logging.Logger,
    subcommand: str = "checkout",
    branch: str | None = None,
    name: str | None = None,
    prefix: str | None = None,
    channels: Iterable[str] = ("pytorch-nightly",),
    override_channels: bool = False,
) -> None:
    """Development install of PyTorch"""
    specs = list(specs)
    deps, pytorch, platform, existing_env, env_opts = conda_solve(
        specs=specs,
        name=name,
        prefix=prefix,
        channels=channels,
        override_channels=override_channels,
    )
    if deps:
        deps_install(deps, existing_env, env_opts)

    with pytorch_install(pytorch) as pytorch_dir:
        site_dir = _site_packages(pytorch_dir, platform)
        if subcommand == "checkout":
            checkout_nightly_version(cast(str, branch), site_dir)
        elif subcommand == "pull":
            pull_nightly_version(site_dir)
        else:
            raise ValueError(f"Subcommand {subcommand} must be one of: checkout, pull.")
        move_nightly_files(site_dir, platform)

    write_pth(env_opts, platform)
    logger.info(
        "-------\nPyTorch Development Environment set up!\nPlease activate to "
        "enable this environment:\n $ conda activate %s",
        env_opts[1],
    )


def make_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser()
    # subcommands
    subcmd = p.add_subparsers(dest="subcmd", help="subcommand to execute")
    checkout = subcmd.add_parser("checkout", help="checkout a new branch")
    checkout.add_argument(
        "-b",
        "--branch",
        help="Branch name to checkout",
        dest="branch",
        default=None,
        metavar="NAME",
    )
    pull = subcmd.add_parser(
        "pull", help="pulls the nightly commits into the current branch"
    )
    # general arguments
    subparsers = [checkout, pull]
    for subparser in subparsers:
        subparser.add_argument(
            "-n",
            "--name",
            help="Name of environment",
            dest="name",
            default=None,
            metavar="ENVIRONMENT",
        )
        subparser.add_argument(
            "-p",
            "--prefix",
            help="Full path to environment location (i.e. prefix)",
            dest="prefix",
            default=None,
            metavar="PATH",
        )
        subparser.add_argument(
            "-v",
            "--verbose",
            help="Provide debugging info",
            dest="verbose",
            default=False,
            action="store_true",
        )
        subparser.add_argument(
            "--override-channels",
            help="Do not search default or .condarc channels.",
            dest="override_channels",
            default=False,
            action="store_true",
        )
        subparser.add_argument(
            "-c",
            "--channel",
            help=(
                "Additional channel to search for packages. "
                "'pytorch-nightly' will always be prepended to this list."
            ),
            dest="channels",
            action="append",
            metavar="CHANNEL",
        )
        if platform_system() in {"Linux", "Windows"}:
            subparser.add_argument(
                "--cuda",
                help=(
                    "CUDA version to install "
                    "(defaults to the latest version available on the platform)"
                ),
                dest="cuda",
                nargs="?",
                default=argparse.SUPPRESS,
                metavar="VERSION",
            )
    return p


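# main() below turns the parsed arguments into conda specs and channels.  Illustrative
# examples (the CUDA version is an assumed value; the spec strings are the ones
# hard-coded in the function):
#   ./tools/nightly.py checkout -b my-branch
#       -> specs += ["pytorch-mutex=*=*cpu*"]
#   ./tools/nightly.py checkout -b my-branch --cuda 12.1
#       -> specs += ["pytorch-cuda=12.1", "pytorch-mutex=*=*cuda*"], channels += ["nvidia"]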
def main(args: Sequence[str] | None = None) -> None:
    """Main entry point"""
    global LOGGER
    p = make_parser()
    ns = p.parse_args(args)
    ns.branch = getattr(ns, "branch", None)
    status = check_branch(ns.subcmd, ns.branch)
    if status:
        sys.exit(status)
    specs = list(SPECS_TO_INSTALL)
    channels = ["pytorch-nightly"]
    if hasattr(ns, "cuda"):
        if ns.cuda is not None:
            specs.append(f"pytorch-cuda={ns.cuda}")
        else:
            specs.append("pytorch-cuda")
        specs.append("pytorch-mutex=*=*cuda*")
        channels.append("nvidia")
    else:
        specs.append("pytorch-mutex=*=*cpu*")
    if ns.channels:
        channels.extend(ns.channels)
    with logging_manager(debug=ns.verbose) as logger:
        LOGGER = logger
        install(
            specs=specs,
            subcommand=ns.subcmd,
            branch=ns.branch,
            name=ns.name,
            prefix=ns.prefix,
            logger=logger,
            channels=channels,
            override_channels=ns.override_channels,
        )


if __name__ == "__main__":
    main()