mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-03 23:45:05 +08:00
Compare commits
132 Commits
export-D79
...
test-old
| Author | SHA1 | Date | |
|---|---|---|---|
| c0d981fb92 | |||
| f5f3fd155f | |||
| 161048c20a | |||
| d8bc95af0b | |||
| 4e4e1f81b3 | |||
| 24580f98c2 | |||
| de4c6b3f3f | |||
| ef6573629e | |||
| 36e115eca5 | |||
| dad3e16379 | |||
| 95490a1ad7 | |||
| 19f1f9960d | |||
| fd6655a0f5 | |||
| a7f3bdf550 | |||
| 510e8b4ae0 | |||
| 83ba3f1101 | |||
| 1fad16aacb | |||
| 444e2381d0 | |||
| 6085bf7565 | |||
| 8201dbf4bc | |||
| 26d045bb60 | |||
| 356ac3103a | |||
| d4109a0f99 | |||
| 7ea789ccfb | |||
| 7e8197e34d | |||
| 50eac811a6 | |||
| 4e0f179d0b | |||
| 36e59d9b12 | |||
| fc340d0ca3 | |||
| 53e47af0f7 | |||
| 66ad881fc7 | |||
| 1d3eef27ac | |||
| dd95900cec | |||
| 1cdd665526 | |||
| 7cb2dcd2dd | |||
| e5a81aa7ba | |||
| 3e2aa4b0e3 | |||
| 6646461764 | |||
| f74da2a136 | |||
| d35b27dde5 | |||
| a9dc1566d4 | |||
| 33a1996714 | |||
| ee62177c19 | |||
| 64cbaa876c | |||
| 4516c59f5f | |||
| 8bc843a9ec | |||
| e39a62c70d | |||
| 978e3a9142 | |||
| e2a5c42e7e | |||
| 5116c49b52 | |||
| fecdebe385 | |||
| e136a9175b | |||
| 9a680e14b7 | |||
| 805a102beb | |||
| 6e8d705a22 | |||
| 9c18901bfd | |||
| a29ed5e1ac | |||
| d2792f51b2 | |||
| be71000ff5 | |||
| 3f86076775 | |||
| 1616777cd2 | |||
| 38895c0ac2 | |||
| 310f901a71 | |||
| e11b1cd97e | |||
| b599d91738 | |||
| fd6a6658c3 | |||
| 04973496a8 | |||
| 1548b011ea | |||
| e57a92734d | |||
| 79ff3b320b | |||
| 426f249f20 | |||
| d33a484763 | |||
| a81ffbc5f5 | |||
| 465fe4d9f7 | |||
| 9477af1063 | |||
| dcc36e38bb | |||
| efd78584a8 | |||
| 135762ea20 | |||
| e2ee9cfaa2 | |||
| 06d28de17a | |||
| df9720b8b5 | |||
| 85e74d5ace | |||
| 0450f05658 | |||
| 595a65f5c2 | |||
| 8c6c2e40eb | |||
| 32840d19f9 | |||
| 2040f00112 | |||
| c137f9da0b | |||
| 5e8b95605f | |||
| 8ea86a6e31 | |||
| acad808545 | |||
| c687446374 | |||
| dd22ba09b4 | |||
| c0e0126399 | |||
| e4b123b5e4 | |||
| 5711a8f069 | |||
| b4b71d011e | |||
| 52376b9b6f | |||
| 1371a98b0e | |||
| 2a286cbdf4 | |||
| 7c37b8e1e0 | |||
| ee2649219c | |||
| b0b3e6e48b | |||
| 3967dbedf4 | |||
| 4396b15aa7 | |||
| bb6766053b | |||
| a4fc051c9a | |||
| 5cc6a0abc1 | |||
| 90f13f3b2a | |||
| cb9b74872b | |||
| c964204829 | |||
| 2ac45c2752 | |||
| 83e2ea8135 | |||
| d994027a41 | |||
| cb4f41e125 | |||
| 690fc9cf88 | |||
| eb853e222b | |||
| 06395276e4 | |||
| 8becf646ef | |||
| fa68216ca1 | |||
| 25ef3d315d | |||
| 7e00f2ec9d | |||
| 490cb3f1a4 | |||
| b95cf5c91d | |||
| 5e2ef2a465 | |||
| 9f753f8c0d | |||
| db437690d1 | |||
| 669009bcd1 | |||
| e4e2701429 | |||
| 64cc649275 | |||
| b1fb552974 | |||
| bb62e1f769 |
@ -1 +1 @@
|
||||
11ec6354315768a85da41032535e3b7b99c5f706
|
||||
f7888497a1eb9e98d4c07537f0d0bcfe180d1363
|
||||
|
||||
@ -103,5 +103,5 @@ fi
|
||||
# It depends on torch and triton. We don't want to install
|
||||
# triton and torch from production on Docker CI images
|
||||
if [[ "$ANACONDA_PYTHON_VERSION" != 3.9* ]]; then
|
||||
pip_install helion==0.0.10 --no-deps
|
||||
pip_install helion --no-deps
|
||||
fi
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
sphinx==5.3.0
|
||||
#Description: This is used to generate PyTorch docs
|
||||
#Pinned versions: 5.3.0
|
||||
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@pytorch_sphinx_theme2#egg=pytorch_sphinx_theme2
|
||||
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@722b7e6f9ca512fcc526ad07d62b3d28c50bb6cd#egg=pytorch_sphinx_theme2
|
||||
|
||||
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
|
||||
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
|
||||
@ -50,7 +50,7 @@ IPython==8.12.0
|
||||
#Pinned versions: 8.12.0
|
||||
|
||||
myst-nb==0.17.2
|
||||
#Description: This is used to generate PyTorch functorch and torch.compile docs
|
||||
#Description: This is used to generate PyTorch functorch and torch.compile docs.
|
||||
#Pinned versions: 0.17.2
|
||||
|
||||
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
|
||||
|
||||
31
.ci/lumen_cli/README.md
Normal file
31
.ci/lumen_cli/README.md
Normal file
@ -0,0 +1,31 @@
|
||||
# 🔧 Lumen_cli
|
||||
A Python CLI tool for building and testing PyTorch-based components, using a YAML configuration file for structured, repeatable workflows.
|
||||
|
||||
|
||||
## Features
|
||||
- **Build**
|
||||
- external projects (e.g. vLLM)
|
||||
|
||||
## 📦 Installation
|
||||
at the root of the pytorch repo
|
||||
```bash
|
||||
pip install -e .ci/lumen_cli
|
||||
```
|
||||
|
||||
## Run the cli tool
|
||||
The cli tool must be used at root of pytorch repo, as example to run build external vllm:
|
||||
```bash
|
||||
python -m cli.run build external vllm
|
||||
```
|
||||
this will run the build steps with default behaviour for vllm project.
|
||||
|
||||
to see help messages, run
|
||||
```bash
|
||||
python3 -m cli.run --help
|
||||
```
|
||||
|
||||
## Add customized external build logics
|
||||
To add a new external build, for instance, add a new external build logics:
|
||||
1. create the build function in cli/lib folder
|
||||
2. register your target and the main build function at EXTERNAL_BUILD_TARGET_DISPATCH in `cli/build_cli/register_build.py`
|
||||
3. [optional] create your ci config file in .github/ci_configs/${EXTERNAL_PACKAGE_NAME}.yaml
|
||||
37
.ci/lumen_cli/cli/build_cli/register_build.py
Normal file
37
.ci/lumen_cli/cli/build_cli/register_build.py
Normal file
@ -0,0 +1,37 @@
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from cli.lib.common.cli_helper import register_targets, RichHelp, TargetSpec
|
||||
from cli.lib.core.vllm import VllmBuildRunner
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maps targets to their argparse configuration and runner
|
||||
# it adds new target to path python -m cli.run build external {target} with buildrunner
|
||||
_TARGETS: dict[str, TargetSpec] = {
|
||||
"vllm": {
|
||||
"runner": VllmBuildRunner,
|
||||
"help": "Build vLLM using docker buildx.",
|
||||
}
|
||||
# add yours ...
|
||||
}
|
||||
|
||||
|
||||
def register_build_commands(subparsers: argparse._SubParsersAction) -> None:
|
||||
build_parser = subparsers.add_parser(
|
||||
"build",
|
||||
help="Build related commands",
|
||||
formatter_class=RichHelp,
|
||||
)
|
||||
build_subparsers = build_parser.add_subparsers(dest="build_command", required=True)
|
||||
overview = "\n".join(
|
||||
f" {name:12} {spec.get('help', '')}" for name, spec in _TARGETS.items()
|
||||
)
|
||||
external_parser = build_subparsers.add_parser(
|
||||
"external",
|
||||
help="Build external targets",
|
||||
description="Build third-party targets.\n\nAvailable targets:\n" + overview,
|
||||
formatter_class=RichHelp,
|
||||
)
|
||||
register_targets(external_parser, _TARGETS)
|
||||
71
.ci/lumen_cli/cli/lib/common/cli_helper.py
Normal file
71
.ci/lumen_cli/cli/lib/common/cli_helper.py
Normal file
@ -0,0 +1,71 @@
|
||||
"""
|
||||
Cli Argparser Utility helpers for CLI tasks.
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
try:
|
||||
from typing import Any, Callable, Required, TypedDict # Python 3.11+
|
||||
except ImportError:
|
||||
from typing import Any, Callable, TypedDict
|
||||
|
||||
from typing_extensions import Required # Fallback for Python <3.11
|
||||
|
||||
|
||||
class BaseRunner(ABC):
|
||||
def __init__(self, args: Any) -> None:
|
||||
self.args = args
|
||||
|
||||
@abstractmethod
|
||||
def run(self) -> None:
|
||||
"""runs main logics, required"""
|
||||
|
||||
|
||||
# Pretty help: keep newlines + show defaults
|
||||
class RichHelp(
|
||||
argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class TargetSpec(TypedDict, total=False):
|
||||
"""CLI subcommand specification with bA."""
|
||||
|
||||
runner: Required[type[BaseRunner]]
|
||||
help: str
|
||||
description: str
|
||||
add_arguments: Callable[[argparse.ArgumentParser], None]
|
||||
|
||||
|
||||
def register_targets(
|
||||
parser: argparse.ArgumentParser,
|
||||
target_specs: dict[str, TargetSpec],
|
||||
common_args: Callable[[argparse.ArgumentParser], None] = lambda _: None,
|
||||
) -> None:
|
||||
"""Register target subcommands."""
|
||||
targets = parser.add_subparsers(
|
||||
dest="target",
|
||||
required=True,
|
||||
metavar="{" + ",".join(target_specs.keys()) + "}",
|
||||
)
|
||||
|
||||
for name, spec in target_specs.items():
|
||||
desc = spec.get("description") or spec["runner"].__doc__ or ""
|
||||
|
||||
p = targets.add_parser(
|
||||
name,
|
||||
help=spec.get("help", ""),
|
||||
description=desc.strip(),
|
||||
formatter_class=RichHelp,
|
||||
)
|
||||
p.set_defaults(
|
||||
func=lambda args, cls=spec["runner"]: cls(args).run(),
|
||||
_runner_class=spec["runner"],
|
||||
)
|
||||
if "add_arguments" in spec and callable(spec["add_arguments"]):
|
||||
spec["add_arguments"](p)
|
||||
if common_args:
|
||||
common_args(p)
|
||||
42
.ci/lumen_cli/cli/lib/common/docker_helper.py
Normal file
42
.ci/lumen_cli/cli/lib/common/docker_helper.py
Normal file
@ -0,0 +1,42 @@
|
||||
"""
|
||||
Docker Utility helpers for CLI tasks.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import docker
|
||||
from docker.errors import APIError, NotFound
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# lazy singleton so we don't reconnect every call
|
||||
_docker_client: Optional[docker.DockerClient] = None
|
||||
|
||||
|
||||
def _get_client() -> docker.DockerClient:
|
||||
global _docker_client
|
||||
if _docker_client is None:
|
||||
_docker_client = docker.from_env()
|
||||
return _docker_client
|
||||
|
||||
|
||||
def local_image_exists(
|
||||
image_name: str, client: Optional[docker.DockerClient] = None
|
||||
) -> bool:
|
||||
"""Return True if a local Docker image exists."""
|
||||
if not image_name:
|
||||
return False
|
||||
|
||||
client = client or _get_client()
|
||||
try:
|
||||
client.images.get(image_name)
|
||||
return True
|
||||
except (NotFound, APIError) as e:
|
||||
logger.error(
|
||||
"Error when checking Docker image '%s': %s",
|
||||
image_name,
|
||||
e.explanation if hasattr(e, "explanation") else str(e),
|
||||
)
|
||||
return False
|
||||
110
.ci/lumen_cli/cli/lib/common/envs_helper.py
Normal file
110
.ci/lumen_cli/cli/lib/common/envs_helper.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""
|
||||
Environment Variables and Dataclasses Utility helpers for CLI tasks.
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import field, fields, is_dataclass, MISSING
|
||||
from pathlib import Path
|
||||
from textwrap import indent
|
||||
from typing import Optional, Union
|
||||
|
||||
from cli.lib.common.utils import str2bool
|
||||
|
||||
|
||||
def get_env(name: str, default: str = "") -> str:
|
||||
"""Get environment variable with default fallback."""
|
||||
return os.environ.get(name) or default
|
||||
|
||||
|
||||
def env_path_optional(
|
||||
name: str,
|
||||
default: Optional[Union[str, Path]] = None,
|
||||
resolve: bool = True,
|
||||
) -> Optional[Path]:
|
||||
"""Get environment variable as optional Path."""
|
||||
val = get_env(name) or default
|
||||
if not val:
|
||||
return None
|
||||
|
||||
path = Path(val)
|
||||
return path.resolve() if resolve else path
|
||||
|
||||
|
||||
def env_path(
|
||||
name: str,
|
||||
default: Optional[Union[str, Path]] = None,
|
||||
resolve: bool = True,
|
||||
) -> Path:
|
||||
"""Get environment variable as Path, raise if missing."""
|
||||
path = env_path_optional(name, default, resolve)
|
||||
if not path:
|
||||
raise ValueError(f"Missing path value for {name}")
|
||||
return path
|
||||
|
||||
|
||||
def env_bool(
|
||||
name: str,
|
||||
default: bool = False,
|
||||
) -> bool:
|
||||
val = get_env(name)
|
||||
if not val:
|
||||
return default
|
||||
return str2bool(val)
|
||||
|
||||
|
||||
def env_bool_field(
|
||||
name: str,
|
||||
default: bool = False,
|
||||
):
|
||||
return field(default_factory=lambda: env_bool(name, default))
|
||||
|
||||
|
||||
def env_path_field(
|
||||
name: str,
|
||||
default: Union[str, Path] = "",
|
||||
*,
|
||||
resolve: bool = True,
|
||||
) -> Path:
|
||||
return field(default_factory=lambda: env_path(name, default, resolve=resolve))
|
||||
|
||||
|
||||
def env_str_field(
|
||||
name: str,
|
||||
default: str = "",
|
||||
) -> str:
|
||||
return field(default_factory=lambda: get_env(name, default))
|
||||
|
||||
|
||||
def generate_dataclass_help(cls) -> str:
|
||||
"""Auto-generate help text for dataclass fields."""
|
||||
if not is_dataclass(cls):
|
||||
raise TypeError(f"{cls} is not a dataclass")
|
||||
|
||||
def get_value(f):
|
||||
if f.default is not MISSING:
|
||||
return f.default
|
||||
if f.default_factory is not MISSING:
|
||||
try:
|
||||
return f.default_factory()
|
||||
except Exception as e:
|
||||
return f"<error: {e}>"
|
||||
return "<required>"
|
||||
|
||||
lines = [f"{f.name:<22} = {repr(get_value(f))}" for f in fields(cls)]
|
||||
return indent("\n".join(lines), " ")
|
||||
|
||||
|
||||
def with_params_help(params_cls: type, title: str = "Parameter defaults"):
|
||||
"""
|
||||
Class decorator that appends a help table generated from another dataclass
|
||||
(e.g., VllmParameters) to the decorated class's docstring.
|
||||
"""
|
||||
if not is_dataclass(params_cls):
|
||||
raise TypeError(f"{params_cls} must be a dataclass")
|
||||
|
||||
def _decorator(cls: type) -> type:
|
||||
block = generate_dataclass_help(params_cls)
|
||||
cls.__doc__ = (cls.__doc__ or "") + f"\n\n{title}:\n{block}"
|
||||
return cls
|
||||
|
||||
return _decorator
|
||||
84
.ci/lumen_cli/cli/lib/common/git_helper.py
Normal file
84
.ci/lumen_cli/cli/lib/common/git_helper.py
Normal file
@ -0,0 +1,84 @@
|
||||
"""
|
||||
Git Utility helpers for CLI tasks.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from cli.lib.common.path_helper import remove_dir
|
||||
from cli.lib.common.utils import run_command
|
||||
from git import GitCommandError, RemoteProgress, Repo
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PrintProgress(RemoteProgress):
|
||||
"""Simple progress logger for git operations."""
|
||||
|
||||
def __init__(self, interval: int = 5):
|
||||
super().__init__()
|
||||
self._last_percent = -1
|
||||
self._interval = interval
|
||||
|
||||
def update(self, op_code, cur, max=None, message=""):
|
||||
msg = self._cur_line or message
|
||||
if max and cur:
|
||||
percent = int(cur / max * 100)
|
||||
if percent != self._last_percent and percent % self._interval == 0:
|
||||
self._last_percent = percent
|
||||
logger.info("Progress: %d%% - %s", percent, msg)
|
||||
elif msg:
|
||||
logger.info(msg)
|
||||
|
||||
|
||||
def clone_external_repo(target: str, repo: str, dst: str = "", update_submodules=False):
|
||||
"""Clone repository with pinned commit and optional submodules."""
|
||||
dst = dst or target
|
||||
|
||||
try:
|
||||
logger.info("Cloning %s to %s", target, dst)
|
||||
|
||||
# Clone and fetch
|
||||
remove_dir(dst)
|
||||
r = Repo.clone_from(repo, dst, progress=PrintProgress())
|
||||
r.git.fetch("--all", "--tags")
|
||||
|
||||
# Checkout pinned commit
|
||||
commit = get_post_build_pinned_commit(target)
|
||||
logger.info("Checking out pinned commit %s", commit)
|
||||
r.git.checkout(commit)
|
||||
|
||||
# Update submodules if requested
|
||||
if update_submodules and r.submodules:
|
||||
logger.info("Updating %d submodule(s)", len(r.submodules))
|
||||
for sm in r.submodules:
|
||||
sm.update(init=True, recursive=True, progress=PrintProgress())
|
||||
|
||||
logger.info("Successfully cloned %s", target)
|
||||
return r
|
||||
|
||||
except GitCommandError as e:
|
||||
logger.error("Git operation failed: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
def clone_vllm_pure(commit: str):
|
||||
"""
|
||||
cloning vllm and checkout pinned commit
|
||||
"""
|
||||
print("clonening vllm....", flush=True)
|
||||
cwd = "vllm"
|
||||
# delete the directory if it exists
|
||||
remove_dir(cwd)
|
||||
# Clone the repo & checkout commit
|
||||
run_command("git clone https://github.com/vllm-project/vllm.git")
|
||||
run_command(f"git checkout {commit}", cwd=cwd)
|
||||
run_command("git submodule update --init --recursive", cwd=cwd)
|
||||
|
||||
|
||||
def get_post_build_pinned_commit(name: str, prefix=".github/ci_commit_pins") -> str:
|
||||
path = Path(prefix) / f"{name}.txt"
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Pin file not found: {path}")
|
||||
return path.read_text(encoding="utf-8").strip()
|
||||
14
.ci/lumen_cli/cli/lib/common/logger.py
Normal file
14
.ci/lumen_cli/cli/lib/common/logger.py
Normal file
@ -0,0 +1,14 @@
|
||||
"""
|
||||
Logger Utility helpers for CLI tasks.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
|
||||
def setup_logging(level: int = logging.INFO):
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
stream=sys.stdout,
|
||||
)
|
||||
62
.ci/lumen_cli/cli/lib/common/path_helper.py
Normal file
62
.ci/lumen_cli/cli/lib/common/path_helper.py
Normal file
@ -0,0 +1,62 @@
|
||||
"""Path utility helpers for CLI tasks."""
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_path(path: Union[str, Path], resolve: bool = False) -> Path:
|
||||
"""Convert to Path object, optionally resolving to absolute path."""
|
||||
if not path:
|
||||
raise ValueError("Path cannot be None or empty")
|
||||
result = Path(path)
|
||||
return result.resolve() if resolve else result
|
||||
|
||||
|
||||
def ensure_dir_exists(path: Union[str, Path]) -> Path:
|
||||
"""Create directory if it doesn't exist."""
|
||||
path_obj = get_path(path)
|
||||
path_obj.mkdir(parents=True, exist_ok=True)
|
||||
return path_obj
|
||||
|
||||
|
||||
def remove_dir(path: Union[str, Path, None]) -> None:
|
||||
"""Remove directory if it exists."""
|
||||
if not path:
|
||||
return
|
||||
path_obj = get_path(path)
|
||||
if path_obj.exists():
|
||||
shutil.rmtree(path_obj)
|
||||
|
||||
|
||||
def force_create_dir(path: Union[str, Path]) -> Path:
|
||||
"""Remove directory if exists, then create fresh empty directory."""
|
||||
remove_dir(path)
|
||||
return ensure_dir_exists(path)
|
||||
|
||||
|
||||
def copy(src: Union[str, Path], dst: Union[str, Path]) -> None:
|
||||
"""Copy file or directory from src to dst."""
|
||||
src_path = get_path(src, resolve=True)
|
||||
dst_path = get_path(dst, resolve=True)
|
||||
|
||||
if not src_path.exists():
|
||||
raise FileNotFoundError(f"Source does not exist: {src_path}")
|
||||
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if src_path.is_file():
|
||||
shutil.copy2(src_path, dst_path)
|
||||
elif src_path.is_dir():
|
||||
shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
|
||||
else:
|
||||
raise ValueError(f"Unsupported path type: {src_path}")
|
||||
|
||||
|
||||
def is_path_exist(path: Union[str, Path, None]) -> bool:
|
||||
"""Check if path exists."""
|
||||
return bool(path and get_path(path).exists())
|
||||
69
.ci/lumen_cli/cli/lib/common/pip_helper.py
Normal file
69
.ci/lumen_cli/cli/lib/common/pip_helper.py
Normal file
@ -0,0 +1,69 @@
|
||||
import glob
|
||||
import logging
|
||||
import shlex
|
||||
import shutil
|
||||
import sys
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
from cli.lib.common.utils import run_command
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def pip_install_packages(
|
||||
packages: Iterable[str] = (),
|
||||
env=None,
|
||||
*,
|
||||
requirements: Optional[str] = None,
|
||||
constraints: Optional[str] = None,
|
||||
prefer_uv: bool = False,
|
||||
) -> None:
|
||||
use_uv = prefer_uv and shutil.which("uv") is not None
|
||||
base = (
|
||||
[sys.executable, "-m", "uv", "pip", "install"]
|
||||
if use_uv
|
||||
else [sys.executable, "-m", "pip", "install"]
|
||||
)
|
||||
|
||||
if use_uv:
|
||||
logger.info("Installing packages using uv pip")
|
||||
|
||||
cmd = base[:]
|
||||
if requirements:
|
||||
cmd += ["-r", requirements]
|
||||
if constraints:
|
||||
cmd += ["-c", constraints]
|
||||
cmd += list(packages)
|
||||
|
||||
logger.info("pip installing packages: %s", " ".join(map(shlex.quote, cmd)))
|
||||
run_command(" ".join(map(shlex.quote, cmd)), env=env)
|
||||
logger.info("Done installing packages")
|
||||
|
||||
|
||||
def pip_install_first_match(pattern: str, extras: Optional[str] = None, pref_uv=False):
|
||||
"""
|
||||
Install the first local whl that matches the given glob pattern.
|
||||
|
||||
Args:
|
||||
pattern (str): Glob pattern for the wheel file(s).
|
||||
extras (str | None): Optional extras (e.g., "opt_einsum") to install with the wheel.
|
||||
"""
|
||||
matches = sorted(glob.glob(pattern))
|
||||
if not matches:
|
||||
raise FileNotFoundError(f"No files match: {pattern}")
|
||||
wheel = matches[0]
|
||||
target = f"{wheel}[{extras}]" if extras else wheel
|
||||
logger.info("Installing wheel: %s", target)
|
||||
pip_install_packages([target], prefer_uv=pref_uv)
|
||||
|
||||
|
||||
def run_python(args: Union[str, list[str]], env=None):
|
||||
"""
|
||||
Run the python in the current environment.
|
||||
"""
|
||||
if isinstance(args, str):
|
||||
args = shlex.split(args)
|
||||
cmd = [sys.executable] + args
|
||||
run_command(" ".join(map(shlex.quote, cmd)), env=env)
|
||||
117
.ci/lumen_cli/cli/lib/common/utils.py
Normal file
117
.ci/lumen_cli/cli/lib/common/utils.py
Normal file
@ -0,0 +1,117 @@
|
||||
"""
|
||||
General Utility helpers for CLI tasks.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_command(
|
||||
cmd: str,
|
||||
use_shell: bool = False,
|
||||
log_cmd: bool = True,
|
||||
cwd: Optional[str] = None,
|
||||
env: Optional[dict] = None,
|
||||
check: bool = True,
|
||||
) -> int:
|
||||
"""Run a command with optional shell execution."""
|
||||
if use_shell:
|
||||
args = cmd
|
||||
log_prefix = "[shell]"
|
||||
executable = "/bin/bash"
|
||||
else:
|
||||
args = shlex.split(cmd)
|
||||
log_prefix = "[cmd]"
|
||||
executable = None
|
||||
|
||||
if log_cmd:
|
||||
display_cmd = cmd if use_shell else " ".join(args)
|
||||
logger.info("%s %s", log_prefix, display_cmd)
|
||||
|
||||
run_env = {**os.environ, **(env or {})}
|
||||
|
||||
proc = subprocess.run(
|
||||
args,
|
||||
shell=use_shell,
|
||||
executable=executable,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stderr,
|
||||
cwd=cwd,
|
||||
env=run_env,
|
||||
check=False,
|
||||
)
|
||||
|
||||
if check and proc.returncode != 0:
|
||||
logger.error(
|
||||
"%s Command failed (exit %s): %s", log_prefix, proc.returncode, cmd
|
||||
)
|
||||
raise subprocess.CalledProcessError(
|
||||
proc.returncode, args if not use_shell else cmd
|
||||
)
|
||||
|
||||
return proc.returncode
|
||||
|
||||
|
||||
def str2bool(value: Optional[str]) -> bool:
|
||||
"""Convert environment variables to boolean values."""
|
||||
if not value:
|
||||
return False
|
||||
if not isinstance(value, str):
|
||||
raise ValueError(
|
||||
f"Expected a string value for boolean conversion, got {type(value)}"
|
||||
)
|
||||
value = value.strip().lower()
|
||||
|
||||
true_value_set = {"1", "true", "t", "yes", "y", "on", "enable", "enabled", "found"}
|
||||
false_value_set = {"0", "false", "f", "no", "n", "off", "disable"}
|
||||
|
||||
if value in true_value_set:
|
||||
return True
|
||||
if value in false_value_set:
|
||||
return False
|
||||
raise ValueError(f"Invalid string value for boolean conversion: {value}")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temp_environ(updates: dict[str, str]):
|
||||
"""
|
||||
Temporarily set environment variables and restore them after the block.
|
||||
Args:
|
||||
updates: Dict of environment variables to set.
|
||||
"""
|
||||
missing = object()
|
||||
old: dict[str, str | object] = {k: os.environ.get(k, missing) for k in updates}
|
||||
try:
|
||||
os.environ.update(updates)
|
||||
yield
|
||||
finally:
|
||||
for k, v in old.items():
|
||||
if v is missing:
|
||||
os.environ.pop(k, None)
|
||||
else:
|
||||
os.environ[k] = v # type: ignore[arg-type]
|
||||
|
||||
|
||||
@contextmanager
|
||||
def working_directory(path: str):
|
||||
"""
|
||||
Temporarily change the working directory inside a context.
|
||||
"""
|
||||
if not path:
|
||||
# No-op context
|
||||
yield
|
||||
return
|
||||
prev_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(path)
|
||||
yield
|
||||
finally:
|
||||
os.chdir(prev_cwd)
|
||||
650
.ci/lumen_cli/cli/lib/core/vllm.py
Normal file
650
.ci/lumen_cli/cli/lib/core/vllm.py
Normal file
@ -0,0 +1,650 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from cli.lib.common.cli_helper import BaseRunner
|
||||
from cli.lib.common.docker_helper import local_image_exists
|
||||
from cli.lib.common.envs_helper import (
|
||||
env_bool_field,
|
||||
env_path_field,
|
||||
env_str_field,
|
||||
get_env,
|
||||
with_params_help,
|
||||
)
|
||||
from cli.lib.common.git_helper import clone_vllm_pure, get_post_build_pinned_commit
|
||||
from cli.lib.common.path_helper import (
|
||||
copy,
|
||||
ensure_dir_exists,
|
||||
force_create_dir,
|
||||
get_path,
|
||||
is_path_exist,
|
||||
remove_dir,
|
||||
)
|
||||
from cli.lib.common.pip_helper import (
|
||||
pip_install_first_match,
|
||||
pip_install_packages,
|
||||
run_python,
|
||||
)
|
||||
from cli.lib.common.utils import run_command, temp_environ, working_directory
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Default path for docker build artifacts
|
||||
_DEFAULT_RESULT_PATH = "./shared"
|
||||
|
||||
# Temp folder in vllm work place to cp torch whls in vllm work directory for docker build
|
||||
_VLLM_TEMP_FOLDER = "tmp"
|
||||
|
||||
|
||||
@dataclass
|
||||
class VllmBuildParameters:
|
||||
"""
|
||||
Parameters defining the vllm external input configurations.
|
||||
Combine with VllmDockerBuildArgs to define the vllm build environment
|
||||
"""
|
||||
|
||||
# USE_TORCH_WHEEL: when true, use local Torch wheels; requires TORCH_WHEELS_PATH.
|
||||
# Otherwise docker build pull torch nightly during build
|
||||
# TORCH_WHEELS_PATH: directory containing local torch wheels when use_torch_whl is True
|
||||
use_torch_whl: bool = env_bool_field("USE_TORCH_WHEEL", True)
|
||||
torch_whls_path: Path = env_path_field("TORCH_WHEELS_PATH", "./dist")
|
||||
|
||||
# USE_LOCAL_BASE_IMAGE: when true, use an existing local Docker base image; requires BASE_IMAGE
|
||||
# Otherwise, pull dockerfile's default image remotely
|
||||
# BASE_IMAGE: name:tag (only needed when use_local_base_image is True)
|
||||
use_local_base_image: bool = env_bool_field("USE_LOCAL_BASE_IMAGE", True)
|
||||
base_image: str = env_str_field("BASE_IMAGE")
|
||||
|
||||
# USE_LOCAL_DOCKERFILE: when true("1"), use a local Dockerfile; requires DOCKERFILE_PATH.
|
||||
# otherwise, use vllm's default dockerfile.torch_nightly for build
|
||||
# DOCKERFILE_PATH: path to Dockerfile used when use_local_dockerfile is True"
|
||||
use_local_dockerfile: bool = env_bool_field("USE_LOCAL_DOCKERFILE", True)
|
||||
dockerfile_path: Path = env_path_field(
|
||||
"DOCKERFILE_PATH", ".github/ci_configs/vllm/Dockerfile.tmp_vllm"
|
||||
)
|
||||
|
||||
# OUTPUT_DIR: where docker buildx (local exporter) will write artifacts
|
||||
output_dir: Path = env_path_field("OUTPUT_DIR", "shared")
|
||||
|
||||
# --- Build args ----------------------------------------------------------
|
||||
target_stage: str = env_str_field("TARGET_STAGE", "export-wheels")
|
||||
|
||||
tag_name: str = env_str_field("TAG", "vllm-wheels")
|
||||
|
||||
cuda_version: str = env_str_field("CUDA_VERSION", "12.8.1")
|
||||
|
||||
python_version: str = env_str_field("PYTHON_VERSION", "3.12")
|
||||
|
||||
max_jobs: str = env_str_field("MAX_JOBS", "64")
|
||||
|
||||
sccache_bucket: str = env_str_field("SCCACHE_BUCKET")
|
||||
|
||||
sccache_region: str = env_str_field("SCCACHE_REGION")
|
||||
|
||||
torch_cuda_arch_list: str = env_str_field("TORCH_CUDA_ARCH_LIST", "8.9")
|
||||
|
||||
def __post_init__(self):
|
||||
checks = [
|
||||
(
|
||||
self.use_torch_whl, # flag
|
||||
True, # trigger_value
|
||||
"torch_whls_path", # resource
|
||||
is_path_exist, # check_func
|
||||
"TORCH_WHEELS_PATH is not provided, but USE_TORCH_WHEEL is set to 1",
|
||||
),
|
||||
(
|
||||
self.use_local_base_image,
|
||||
True,
|
||||
"base_image",
|
||||
local_image_exists,
|
||||
f"BASE_IMAGE {self.base_image} does not found, but USE_LOCAL_BASE_IMAGE is set to 1",
|
||||
),
|
||||
(
|
||||
self.use_local_dockerfile,
|
||||
True,
|
||||
"dockerfile_path",
|
||||
is_path_exist,
|
||||
" DOCKERFILE_PATH path does not found, but USE_LOCAL_DOCKERFILE is set to 1",
|
||||
),
|
||||
]
|
||||
for flag, trigger_value, attr_name, check_func, error_msg in checks:
|
||||
value = getattr(self, attr_name)
|
||||
if flag == trigger_value:
|
||||
if not value or not check_func(value):
|
||||
raise ValueError(error_msg)
|
||||
else:
|
||||
logger.info("flag %s is not set", flag)
|
||||
if not self.output_dir:
|
||||
raise ValueError("missing required output_dir")
|
||||
|
||||
|
||||
@with_params_help(VllmBuildParameters)
|
||||
class VllmBuildRunner(BaseRunner):
|
||||
"""
|
||||
Build vLLM using docker buildx.
|
||||
|
||||
Environment variable options:
|
||||
"USE_TORCH_WHEEL": "1: use local wheels; 0: pull nightly from pypi",
|
||||
"TORCH_WHEELS_PATH": "Path to local wheels (when USE_TORCH_WHEEL=1)",
|
||||
|
||||
"USE_LOCAL_BASE_IMAGE": "1: use local base image; 0: default image",
|
||||
"BASE_IMAGE": "name:tag to indicate base image the dockerfile depends on (when USE_LOCAL_BASE_IMAGE=1)",
|
||||
|
||||
"USE_LOCAL_DOCKERFILE": "1: use local Dockerfile; 0: vllm repo default dockerfile.torch_nightly",
|
||||
"DOCKERFILE_PATH": "Path to Dockerfile (when USE_LOCAL_DOCKERFILE=1)",
|
||||
|
||||
"OUTPUT_DIR": "e.g. './shared'",
|
||||
|
||||
"TORCH_CUDA_ARCH_LIST": "e.g. '8.0' or '8.0;9.0'",
|
||||
"CUDA_VERSION": "e.g. '12.8.1'",
|
||||
"PYTHON_VERSION": "e.g. '3.12'",
|
||||
"MAX_JOBS": "e.g. '64'",
|
||||
"SCCACHE_BUCKET": "e.g. 'my-bucket'",
|
||||
"SCCACHE_REGION": "e.g. 'us-west-2'",
|
||||
"""
|
||||
|
||||
def __init__(self, args=None):
|
||||
self.work_directory = "vllm"
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
main function to run vllm build
|
||||
1. prepare vllm build environment
|
||||
2. prepare the docker build command args
|
||||
3. run docker build
|
||||
"""
|
||||
inputs = VllmBuildParameters()
|
||||
logger.info("Running vllm build with inputs: %s", inputs)
|
||||
clone_vllm()
|
||||
|
||||
self.cp_dockerfile_if_exist(inputs)
|
||||
|
||||
# cp torch wheels from root direct to vllm workspace if exist
|
||||
self.cp_torch_whls_if_exist(inputs)
|
||||
|
||||
ensure_dir_exists(inputs.output_dir)
|
||||
|
||||
cmd = self._generate_docker_build_cmd(inputs)
|
||||
logger.info("Running docker build: \n %s", cmd)
|
||||
run_command(cmd, cwd="vllm", env=os.environ.copy())
|
||||
|
||||
def cp_torch_whls_if_exist(self, inputs: VllmBuildParameters) -> str:
|
||||
if not inputs.use_torch_whl:
|
||||
return ""
|
||||
tmp_dir = f"./{self.work_directory}/{_VLLM_TEMP_FOLDER}"
|
||||
tmp_path = Path(tmp_dir)
|
||||
force_create_dir(tmp_path)
|
||||
copy(inputs.torch_whls_path, tmp_dir)
|
||||
return tmp_dir
|
||||
|
||||
def cp_dockerfile_if_exist(self, inputs: VllmBuildParameters):
|
||||
if not inputs.use_local_dockerfile:
|
||||
logger.info("using vllm default dockerfile.torch_nightly for build")
|
||||
return
|
||||
dockerfile_path = get_path(inputs.dockerfile_path, resolve=True)
|
||||
vllm_torch_dockerfile = Path(
|
||||
f"./{self.work_directory}/docker/Dockerfile.nightly_torch"
|
||||
)
|
||||
copy(dockerfile_path, vllm_torch_dockerfile)
|
||||
|
||||
def get_result_path(self, path):
|
||||
"""
|
||||
Get the absolute path of the result path
|
||||
"""
|
||||
if not path:
|
||||
path = _DEFAULT_RESULT_PATH
|
||||
abs_path = get_path(path, resolve=True)
|
||||
return abs_path
|
||||
|
||||
def _get_torch_wheel_path_arg(self, torch_whl_dir: Optional[Path]) -> str:
|
||||
if not torch_whl_dir:
|
||||
return ""
|
||||
return f"--build-arg TORCH_WHEELS_PATH={_VLLM_TEMP_FOLDER}"
|
||||
|
||||
def _get_base_image_args(self, inputs: VllmBuildParameters) -> tuple[str, str, str]:
|
||||
"""
|
||||
Returns:
|
||||
- base_image_arg: docker buildx arg string for base image
|
||||
- final_base_image_arg: docker buildx arg string for vllm-base stage
|
||||
- pull_flag: --pull=true or --pull=false depending on whether the image exists locally
|
||||
"""
|
||||
if not inputs.use_local_base_image:
|
||||
return "", "", ""
|
||||
|
||||
base_image = inputs.base_image
|
||||
|
||||
# set both base image and final base image to the same local image
|
||||
base_image_arg = f"--build-arg BUILD_BASE_IMAGE={base_image}"
|
||||
final_base_image_arg = f"--build-arg FINAL_BASE_IMAGE={base_image}"
|
||||
|
||||
if local_image_exists(base_image):
|
||||
pull_flag = "--pull=false"
|
||||
return base_image_arg, final_base_image_arg, pull_flag
|
||||
logger.info(
|
||||
"[INFO] Local image not found:%s will try to pull from remote", {base_image}
|
||||
)
|
||||
return base_image_arg, final_base_image_arg, ""
|
||||
|
||||
def _generate_docker_build_cmd(
|
||||
self,
|
||||
inputs: VllmBuildParameters,
|
||||
) -> str:
|
||||
base_image_arg, final_base_image_arg, pull_flag = self._get_base_image_args(
|
||||
inputs
|
||||
)
|
||||
torch_arg = self._get_torch_wheel_path_arg(inputs.torch_whls_path)
|
||||
|
||||
return textwrap.dedent(
|
||||
f"""
|
||||
docker buildx build \
|
||||
--output type=local,dest={inputs.output_dir} \
|
||||
-f docker/Dockerfile.nightly_torch \
|
||||
{pull_flag} \
|
||||
{torch_arg} \
|
||||
{base_image_arg} \
|
||||
{final_base_image_arg} \
|
||||
--build-arg max_jobs={inputs.max_jobs} \
|
||||
--build-arg CUDA_VERSION={inputs.cuda_version} \
|
||||
--build-arg PYTHON_VERSION={inputs.python_version} \
|
||||
--build-arg USE_SCCACHE={int(bool(inputs.sccache_bucket and inputs.sccache_region))} \
|
||||
--build-arg SCCACHE_BUCKET_NAME={inputs.sccache_bucket} \
|
||||
--build-arg SCCACHE_REGION_NAME={inputs.sccache_region} \
|
||||
--build-arg torch_cuda_arch_list='{inputs.torch_cuda_arch_list}' \
|
||||
--target {inputs.target_stage} \
|
||||
-t {inputs.tag_name} \
|
||||
--progress=plain .
|
||||
"""
|
||||
).strip()
|
||||
|
||||
|
||||
@dataclass
|
||||
class VllmTestParameters:
|
||||
"""
|
||||
Parameters defining the vllm external test input
|
||||
|
||||
!!!DO NOT ADD SECRETS IN THIS CLASS!!!
|
||||
you can put environment variable name in VllmTestParameters if it's not the same as the secret one
|
||||
fetch secrests directly from env variables during runtime
|
||||
"""
|
||||
|
||||
torch_whls_path: Path = env_path_field("TORCH_WHEELS_PATH", "./dist")
|
||||
|
||||
vllm_whls_path: Path = env_path_field("VLLM_WHEELS_PATH", "./shared")
|
||||
|
||||
torch_cuda_arch_list: str = env_str_field("TORCH_CUDA_ARCH_LIST", "8.9")
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.torch_whls_path.exists():
|
||||
raise ValueError("missing torch_whls_path")
|
||||
if not self.vllm_whls_path.exists():
|
||||
raise ValueError("missing vllm_whls_path")
|
||||
|
||||
|
||||
class TestInpuType(Enum):
|
||||
TEST_PLAN = "test_plan"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class VllmTestRunner(BaseRunner):
|
||||
def __init__(self, args: Any):
|
||||
self.work_directory = "vllm"
|
||||
|
||||
self.test_plan = ""
|
||||
self.test_type = TestInpuType.UNKNOWN
|
||||
|
||||
if args.test_plan:
|
||||
self.test_plan = args.test_plan
|
||||
self.test_type = TestInpuType.TEST_PLAN
|
||||
|
||||
# Matches the structeur in the artifacts.zip from torcb build
|
||||
self.TORCH_WHL_PATH_REGEX = "torch*.whl"
|
||||
self.TORCH_WHL_EXTRA = "opt-einsum"
|
||||
self.TORCH_ADDITIONAL_WHLS_REGEX = [
|
||||
"vision/torchvision*.whl",
|
||||
"audio/torchaudio*.whl",
|
||||
]
|
||||
|
||||
# Match the structure of the artifacts.zip from vllm external build
|
||||
self.VLLM_TEST_WHLS_REGEX = [
|
||||
"wheels/xformers/xformers*.whl",
|
||||
"wheels/vllm/vllm*.whl",
|
||||
"wheels/flashinfer-python/flashinfer*.whl",
|
||||
]
|
||||
|
||||
def prepare(self):
|
||||
"""
|
||||
prepare test environment for vllm. This includes clone vllm repo, install all wheels, test dependencies and set env
|
||||
"""
|
||||
params = VllmTestParameters()
|
||||
logger.info("Display VllmTestParameters %s", params)
|
||||
self._set_envs(params)
|
||||
|
||||
clone_vllm(dst=self.work_directory)
|
||||
with working_directory(self.work_directory):
|
||||
remove_dir(Path("vllm"))
|
||||
self._install_wheels(params)
|
||||
self._install_dependencies()
|
||||
# verify the torches are not overridden by test dependencies
|
||||
self.check_versions()
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
main function to run vllm test
|
||||
"""
|
||||
self.prepare()
|
||||
with working_directory(self.work_directory):
|
||||
if self.test_type == TestInpuType.TEST_PLAN:
|
||||
self.run_test_plan(self.test_plan)
|
||||
else:
|
||||
raise ValueError(f"Unknown test type {self.test_type}")
|
||||
|
||||
def _install_wheels(self, params: VllmTestParameters):
|
||||
logger.info("Running vllm test with inputs: %s", params)
|
||||
logger.info("Installing torch wheel")
|
||||
# torch_p = f"{str(params.torch_whls_path)}/{self.TORCH_WHL_PATH_REGEX}"
|
||||
# pip_install_first_match(torch_p, self.TORCH_WHL_EXTRA)
|
||||
|
||||
logger.info("Installing other torch-related wheels")
|
||||
torch_whls_path = [
|
||||
f"{str(params.torch_whls_path)}/{whl_path}"
|
||||
for whl_path in self.TORCH_ADDITIONAL_WHLS_REGEX
|
||||
]
|
||||
for torch_whl in torch_whls_path:
|
||||
pip_install_first_match(torch_whl)
|
||||
logger.info("Done. Installed torch and other torch-related wheels ")
|
||||
|
||||
logger.info("Installing vllm wheels")
|
||||
vllm_whls_path = [
|
||||
f"{str(params.vllm_whls_path)}/{whl_path}"
|
||||
for whl_path in self.VLLM_TEST_WHLS_REGEX
|
||||
]
|
||||
for vllm_whl in vllm_whls_path:
|
||||
pip_install_first_match(vllm_whl)
|
||||
logger.info("Done. Installed vllm wheels")
|
||||
|
||||
def _install_test_dependencies(self):
|
||||
"""
|
||||
Install test dependencies for vllm test
|
||||
This method replaces torch dependencies with local torch wheel info in
|
||||
requirements/test.in file from vllm repo.
|
||||
|
||||
Then generates the test.txt file using uv pip compile, along with requirements/test.txt,
|
||||
which is generated by the test.in with torch stable as soft constraint to match
|
||||
packages' version
|
||||
|
||||
"""
|
||||
# TODO(elainewy): move this as part of vllm build, to generate the test.txt file
|
||||
logger.info("generate test.txt from requirements/test.in with local torch whls")
|
||||
preprocess_test_in()
|
||||
|
||||
copy(
|
||||
Path("requirements/test.txt"),
|
||||
Path("snapshot_constraint.txt"),
|
||||
)
|
||||
run_command(
|
||||
f"{sys.executable} -m uv pip compile requirements/test.in "
|
||||
"-o test.txt "
|
||||
"--index-strategy unsafe-best-match "
|
||||
"--constraint snapshot_constraint.txt "
|
||||
"--torch-backend cu128"
|
||||
)
|
||||
logger.info("install requirements from test.txt")
|
||||
pip_install_packages(requirements="test.txt", prefer_uv=True)
|
||||
logger.info("Done. install requirements from test.txt")
|
||||
|
||||
# install mambda from source since it does not work now with pip
|
||||
# TODO(elainewy): move this as part of vllm build
|
||||
pip_install_packages(
|
||||
packages=[
|
||||
"--no-build-isolation",
|
||||
"git+https://github.com/state-spaces/mamba@v2.2.4",
|
||||
],
|
||||
prefer_uv=True,
|
||||
)
|
||||
logger.info("Done. installed requirements from test.txt")
|
||||
|
||||
def _install_dependencies(self):
|
||||
pip_install_packages(packages=["-e", "tests/vllm_test_utils"], prefer_uv=True)
|
||||
pip_install_packages(packages=["hf_transfer"], prefer_uv=True)
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
|
||||
# using script from vllm repo to remove all torch packages from requirements txt
|
||||
run_python("use_existing_torch.py")
|
||||
|
||||
# install common packages
|
||||
for requirements in ["requirements/common.txt", "requirements/build.txt"]:
|
||||
pip_install_packages(
|
||||
requirements=requirements,
|
||||
prefer_uv=True,
|
||||
)
|
||||
# install test packages
|
||||
self._install_test_dependencies()
|
||||
|
||||
def check_versions(self):
|
||||
"""
|
||||
check installed packages version
|
||||
"""
|
||||
logger.info("double check installed packages")
|
||||
patterns = ["torch", "xformers", "torchvision", "torchaudio", "vllm"]
|
||||
for pkg in patterns:
|
||||
try:
|
||||
module = __import__(pkg)
|
||||
version = getattr(module, "__version__", None)
|
||||
version = version if version else "Unknown version"
|
||||
logger.info("%s: %s", pkg, version)
|
||||
except ImportError:
|
||||
logger.info(" %s: Not installed", pkg)
|
||||
logger.info("Done. checked installed packages")
|
||||
|
||||
def _set_envs(self, inputs: VllmTestParameters):
|
||||
os.environ["TORCH_CUDA_ARCH_LIST"] = inputs.torch_cuda_arch_list
|
||||
if not self.validate_cuda(get_env("TORCH_CUDA_ARCH_LIST")):
|
||||
logger.warning(
|
||||
"Missing supported TORCH_CUDA_ARCH_LIST. "
|
||||
"Currently support TORCH_CUDA_ARCH_LIST env var "
|
||||
"with supported arch [8.0, 8.9, 9.0]"
|
||||
)
|
||||
self.validate_cuda(get_env("TORCH_CUDA_ARCH_LIST"))
|
||||
|
||||
os.environ["HF_TOKEN"] = os.getenv("VLLM_TEST_HUGGING_FACE_TOKEN", "")
|
||||
if not get_env("HF_TOKEN"):
|
||||
raise ValueError(
|
||||
"missing required HF_TOKEN, please set VLLM_TEST_HUGGING_FACE_TOKEN env var"
|
||||
)
|
||||
if not get_env("TORCH_CUDA_ARCH_LIST"):
|
||||
raise ValueError(
|
||||
"missing required TORCH_CUDA_ARCH_LIST, please set TORCH_CUDA_ARCH_LIST env var"
|
||||
)
|
||||
|
||||
def run_test_plan(self, test_plan: str):
|
||||
"""
|
||||
a method to run list of tests based on the test plan. currently this only
|
||||
used to run vllm tests.
|
||||
"""
|
||||
logger.info("run vllm tests.....")
|
||||
tests_map = sample_test_plans()
|
||||
if test_plan not in tests_map:
|
||||
raise RuntimeError(
|
||||
f"test {test_plan} not found, please add it to test plan pool"
|
||||
)
|
||||
tests = tests_map[test_plan]
|
||||
logger.info("Running tests: %s", tests["title"])
|
||||
|
||||
pkgs = tests.get("package_install", [])
|
||||
if pkgs:
|
||||
logger.info("Installing packages: %s", pkgs)
|
||||
pip_install_packages(packages=pkgs, prefer_uv=True)
|
||||
with (
|
||||
temp_environ(tests.get("env_var", {})),
|
||||
working_directory(tests.get("working_directory", "tests")),
|
||||
):
|
||||
failures = []
|
||||
for step in tests["steps"]:
|
||||
with temp_environ(step.get("env_var", {})):
|
||||
code = run_command(cmd=step["command"], check=False)
|
||||
if code != 0:
|
||||
failures.append(step)
|
||||
if failures:
|
||||
logger.error("Failed tests: %s", failures)
|
||||
raise RuntimeError(f"{len(failures)} pytest runs failed: {failures}")
|
||||
logger.info("Done. All tests passed")
|
||||
|
||||
def validate_cuda(self, value: str) -> bool:
|
||||
VALID_VALUES = {"8.0", "8.9", "9.0"}
|
||||
return all(v in VALID_VALUES for v in value.split())
|
||||
|
||||
|
||||
def clone_vllm(dst: str = "vllm"):
|
||||
clone_vllm_pure(get_post_build_pinned_commit(dst))
|
||||
"""
|
||||
clone_external_repo(
|
||||
target="vllm",
|
||||
repo="https://github.com/vllm-project/vllm.git",
|
||||
dst=dst,
|
||||
update_submodules=True,
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
def preprocess_test_in(
|
||||
target_file: str = "requirements/test.in", additional_packages: Iterable[str] = ()
|
||||
):
|
||||
"""
|
||||
This modifies the target_file file in place. by default, it points to vllm's requirements/test.in
|
||||
It removes torch packages in target_file and replace with local torch whls
|
||||
"""
|
||||
additional_package_to_move = list(additional_packages or ())
|
||||
pkgs_to_remove = [
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torchaudio",
|
||||
"xformers",
|
||||
"mamba_ssm",
|
||||
] + additional_package_to_move
|
||||
# Read current requirements
|
||||
target_path = Path(target_file)
|
||||
lines = target_path.read_text().splitlines()
|
||||
|
||||
# Remove lines starting with the package names (==, @, >=) — case-insensitive
|
||||
pattern = re.compile(rf"^({'|'.join(pkgs_to_remove)})\s*(==|@|>=)", re.IGNORECASE)
|
||||
kept_lines = [line for line in lines if not pattern.match(line)]
|
||||
|
||||
# Get local torch/vision/audio installs from pip freeze
|
||||
# this is hacky, but it works
|
||||
pip_freeze = subprocess.check_output(["pip", "freeze"], text=True)
|
||||
header_lines = [
|
||||
line
|
||||
for line in pip_freeze.splitlines()
|
||||
if re.match(
|
||||
r"^(torch|torchvision|torchaudio)\s*@\s*file://", line, re.IGNORECASE
|
||||
)
|
||||
]
|
||||
|
||||
# Write back: header_lines + blank + kept_lines
|
||||
out = "\n".join(header_lines + [""] + kept_lines) + "\n"
|
||||
target_path.write_text(out)
|
||||
logger.info("[INFO] Updated %s", target_file)
|
||||
|
||||
|
||||
def sample_test_plans():
|
||||
"""
|
||||
Simple sample to unblock the vllm ci development, which is mimic to
|
||||
https://github.com/vllm-project/vllm/blob/main/.buildkite/test-pipeline.yaml
|
||||
"""
|
||||
# TODO(elainewy): Read from yaml file to handle the env and tests for vllm
|
||||
# TODO(elainewy): implement logics to handle package_install
|
||||
return {
|
||||
# test plan:
|
||||
# required id, title, and steps
|
||||
# optional: env_var, package_install, working_directory
|
||||
# by default the working_drectory is "tests/", but it can be changed based on tests, for instance,
|
||||
# vllm sample test happens in samples/
|
||||
"vllm_basic_correctness_test": {
|
||||
"title": "Basic Correctness Test",
|
||||
"id": "vllm_basic_correctness_test",
|
||||
"env_var": {
|
||||
"VLLM_WORKER_MULTIPROC_METHOD": "spawn",
|
||||
},
|
||||
# test step:
|
||||
# required: command
|
||||
# available fields: env_var (env_var only set within the scope of the test step), package_install(pip package)
|
||||
"steps": [
|
||||
{
|
||||
"command": "pytest -v -s basic_correctness/test_cumem.py",
|
||||
},
|
||||
{
|
||||
"command": "pytest -v -s basic_correctness/test_basic_correctness.py",
|
||||
},
|
||||
{
|
||||
"command": "pytest -v -s basic_correctness/test_cpu_offload.py",
|
||||
},
|
||||
{
|
||||
"command": "pytest -v -s basic_correctness/test_preemption.py",
|
||||
"env_var": {
|
||||
"VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT": "1",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
"vllm_basic_models_test": {
|
||||
"title": "Basic models test",
|
||||
"id": "vllm_basic_models_test",
|
||||
"steps": [
|
||||
{"command": "pytest -v -s models/test_transformers.py"},
|
||||
{"command": "pytest -v -s models/test_registry.py"},
|
||||
{"command": "pytest -v -s models/test_utils.py"},
|
||||
{"command": "pytest -v -s models/test_vision.py"},
|
||||
{"command": "pytest -v -s models/test_initialization.py"},
|
||||
],
|
||||
},
|
||||
"vllm_entrypoints_test": {
|
||||
"title": "Entrypoints Test ",
|
||||
"id": "vllm_entrypoints_test",
|
||||
"env_var": {
|
||||
"VLLM_WORKER_MULTIPROC_METHOD": "spawn",
|
||||
},
|
||||
"steps": [
|
||||
{
|
||||
"command": " ".join(
|
||||
[
|
||||
"pytest",
|
||||
"-v",
|
||||
"-s",
|
||||
"entrypoints/llm",
|
||||
"--ignore=entrypoints/llm/test_lazy_outlines.py",
|
||||
"--ignore=entrypoints/llm/test_generate.py",
|
||||
"--ignore=entrypoints/llm/test_generate_multiple_loras.py",
|
||||
"--ignore=entrypoints/llm/test_collective_rpc.py",
|
||||
]
|
||||
)
|
||||
},
|
||||
{"command": "pytest -v -s entrypoints/llm/test_lazy_outlines.py"},
|
||||
{"command": "pytest -v -s entrypoints/llm/test_generate.py "},
|
||||
{
|
||||
"command": "pytest -v -s entrypoints/llm/test_generate_multiple_loras.py"
|
||||
},
|
||||
{
|
||||
"env_var": {"VLLM_USE_V1": "0"},
|
||||
"command": "pytest -v -s entrypoints/offline_mode",
|
||||
},
|
||||
],
|
||||
},
|
||||
"vllm_regression_test": {
|
||||
"title": "Regression Test",
|
||||
"id": "vllm_regression_test",
|
||||
"package_install": ["modelscope"],
|
||||
"steps": [
|
||||
{"command": "pytest -v -s test_regression.py"},
|
||||
],
|
||||
},
|
||||
}
|
||||
40
.ci/lumen_cli/cli/run.py
Normal file
40
.ci/lumen_cli/cli/run.py
Normal file
@ -0,0 +1,40 @@
|
||||
# main.py
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from cli.build_cli.register_build import register_build_commands
|
||||
from cli.lib.common.logger import setup_logging
|
||||
from cli.test_cli.register_test import register_test_commands
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
# Define top-level parser
|
||||
parser = argparse.ArgumentParser(description="Lumos CLI")
|
||||
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||
parser.add_argument(
|
||||
"--log-level", default="INFO", help="Log level (DEBUG, INFO, WARNING, ERROR)"
|
||||
)
|
||||
|
||||
# registers second-level subcommands
|
||||
register_build_commands(subparsers)
|
||||
register_test_commands(subparsers)
|
||||
|
||||
# parse args after all options are registered
|
||||
args = parser.parse_args()
|
||||
|
||||
# setup global logging
|
||||
setup_logging(getattr(logging, args.log_level.upper(), logging.INFO))
|
||||
logger.debug("Parsed args: %s", args)
|
||||
|
||||
if hasattr(args, "func"):
|
||||
args.func(args)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
51
.ci/lumen_cli/cli/test_cli/register_test.py
Normal file
51
.ci/lumen_cli/cli/test_cli/register_test.py
Normal file
@ -0,0 +1,51 @@
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from cli.lib.common.cli_helper import register_targets, RichHelp, TargetSpec
|
||||
from cli.lib.core.vllm import VllmTestRunner
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maps targets to their argparse configuration and runner
|
||||
# it adds new target to path python -m cli.run build external {target} with buildrunner
|
||||
_TARGETS: dict[str, TargetSpec] = {
|
||||
"vllm": {
|
||||
"runner": VllmTestRunner,
|
||||
"help": "test vLLM unittests",
|
||||
}
|
||||
# add yours ...
|
||||
}
|
||||
|
||||
|
||||
def common_args(parser: argparse.ArgumentParser) -> None:
|
||||
"""
|
||||
Add common CLI arguments to the given parser.
|
||||
"""
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument(
|
||||
"-tp",
|
||||
"--test-plan",
|
||||
type=str,
|
||||
help="a pre-defined test plan to run, e.g. 'basic_correctness_test'",
|
||||
)
|
||||
# TODO(elainewy):add another common option that user can trigger a specific test with test config
|
||||
|
||||
|
||||
def register_test_commands(subparsers: argparse._SubParsersAction) -> None:
|
||||
build_parser = subparsers.add_parser(
|
||||
"test",
|
||||
help="test related commands",
|
||||
formatter_class=RichHelp,
|
||||
)
|
||||
build_subparsers = build_parser.add_subparsers(dest="test_command", required=True)
|
||||
overview = "\n".join(
|
||||
f" {name:12} {spec.get('help', '')}" for name, spec in _TARGETS.items()
|
||||
)
|
||||
external_parser = build_subparsers.add_parser(
|
||||
"external",
|
||||
help="Test external targets",
|
||||
description="Test third-party targets.\n\nAvailable targets:\n" + overview,
|
||||
formatter_class=RichHelp,
|
||||
)
|
||||
register_targets(external_parser, _TARGETS, common_args=common_args)
|
||||
23
.ci/lumen_cli/pyproject.toml
Normal file
23
.ci/lumen_cli/pyproject.toml
Normal file
@ -0,0 +1,23 @@
|
||||
[project]
|
||||
name = "lumen-ci"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"pyyaml==6.0.2",
|
||||
"GitPython==3.1.45",
|
||||
"docker==7.1.0",
|
||||
"pytest==7.3.2",
|
||||
"uv==0.8.4"
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["cli"]
|
||||
|
||||
[tool.setuptools.package-dir]
|
||||
cli = "cli"
|
||||
|
||||
[tool.ruff.lint]
|
||||
# Enable preview mode for linting
|
||||
preview = true
|
||||
|
||||
# Now you can select your preview rules, like RUF048
|
||||
extend-select = ["RUF048"]
|
||||
47
.ci/lumen_cli/tests/test_app.py
Normal file
47
.ci/lumen_cli/tests/test_app.py
Normal file
@ -0,0 +1,47 @@
|
||||
# tests/test_cli.py
|
||||
import io
|
||||
import sys
|
||||
import unittest
|
||||
from contextlib import redirect_stderr, redirect_stdout
|
||||
from unittest.mock import patch
|
||||
|
||||
from cli.run import main
|
||||
|
||||
|
||||
class TestArgparseCLI(unittest.TestCase):
|
||||
@patch("cli.build_cli.register_build.VllmBuildRunner.run", return_value=None)
|
||||
@patch("cli.build_cli.register_build.VllmBuildRunner.__init__", return_value=None)
|
||||
def test_cli_run_build_external(self, mock_init, mock_run):
|
||||
from cli.run import main # import after patches if needed
|
||||
|
||||
test_args = ["cli.run", "build", "external", "vllm"]
|
||||
with patch.object(sys, "argv", test_args):
|
||||
# argparse may call sys.exit on error; capture to avoid test aborts
|
||||
try:
|
||||
main()
|
||||
except SystemExit:
|
||||
pass
|
||||
mock_init.assert_called_once() # got constructed
|
||||
mock_run.assert_called_once_with() # run() called
|
||||
|
||||
def test_build_help(self):
|
||||
test_args = ["cli.run", "build", "--help"]
|
||||
|
||||
with patch.object(sys, "argv", test_args):
|
||||
stdout = io.StringIO()
|
||||
stderr = io.StringIO()
|
||||
|
||||
# --help always raises SystemExit(0)
|
||||
with self.assertRaises(SystemExit) as cm:
|
||||
with redirect_stdout(stdout), redirect_stderr(stderr):
|
||||
main()
|
||||
|
||||
self.assertEqual(cm.exception.code, 0)
|
||||
|
||||
output = stdout.getvalue()
|
||||
self.assertIn("usage", output)
|
||||
self.assertIn("external", output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
115
.ci/lumen_cli/tests/test_cli_helper.py
Normal file
115
.ci/lumen_cli/tests/test_cli_helper.py
Normal file
@ -0,0 +1,115 @@
|
||||
import argparse
|
||||
import io
|
||||
import unittest
|
||||
from contextlib import redirect_stderr
|
||||
from unittest.mock import patch
|
||||
|
||||
from cli.lib.common.cli_helper import BaseRunner, register_targets, RichHelp, TargetSpec
|
||||
|
||||
|
||||
# ---- Dummy runners for unittests----
|
||||
class FooRunner(BaseRunner):
|
||||
"""Foo description from docstring."""
|
||||
|
||||
def run(self) -> None: # replaced by mock
|
||||
pass
|
||||
|
||||
|
||||
class BarRunner(BaseRunner):
|
||||
def run(self) -> None: # replaced by mock
|
||||
pass
|
||||
|
||||
|
||||
def add_foo_args(p: argparse.ArgumentParser) -> None:
|
||||
p.add_argument("--x", type=int, required=True, help="x value")
|
||||
|
||||
|
||||
def common_args(p: argparse.ArgumentParser) -> None:
|
||||
p.add_argument("--verbose", action="store_true", help="verbose flag")
|
||||
|
||||
|
||||
def build_parser(specs: dict[str, TargetSpec]) -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(prog="app", formatter_class=RichHelp)
|
||||
register_targets(
|
||||
parser=parser,
|
||||
target_specs=specs,
|
||||
common_args=common_args,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def get_subparser(
|
||||
parser: argparse.ArgumentParser, name: str
|
||||
) -> argparse.ArgumentParser:
|
||||
subparsers_action = next(
|
||||
a
|
||||
for a in parser._subparsers._group_actions # type: ignore[attr-defined]
|
||||
if isinstance(a, argparse._SubParsersAction)
|
||||
)
|
||||
return subparsers_action.choices[name]
|
||||
|
||||
|
||||
class TestRegisterTargets(unittest.TestCase):
|
||||
def test_metavar_lists_targets(self):
|
||||
specs: dict[str, TargetSpec] = {
|
||||
"foo": {"runner": FooRunner, "add_arguments": add_foo_args},
|
||||
"bar": {"runner": BarRunner},
|
||||
}
|
||||
parser = build_parser(specs)
|
||||
subparsers_action = next(
|
||||
a
|
||||
for a in parser._subparsers._group_actions # type: ignore[attr-defined]
|
||||
if isinstance(a, argparse._SubParsersAction)
|
||||
)
|
||||
self.assertEqual(subparsers_action.metavar, "{foo,bar}")
|
||||
|
||||
def test_add_arguments_and_common_args_present(self):
|
||||
specs: dict[str, TargetSpec] = {
|
||||
"foo": {"runner": FooRunner, "add_arguments": add_foo_args},
|
||||
}
|
||||
parser = build_parser(specs)
|
||||
foo = get_subparser(parser, "foo")
|
||||
help_text = foo.format_help()
|
||||
self.assertIn("--x", help_text)
|
||||
self.assertIn("--verbose", help_text)
|
||||
|
||||
def test_runner_constructed_with_ns_and_run_called(self):
|
||||
specs: dict[str, TargetSpec] = {
|
||||
"foo": {"runner": FooRunner, "add_arguments": add_foo_args},
|
||||
}
|
||||
parser = build_parser(specs)
|
||||
|
||||
with (
|
||||
patch.object(FooRunner, "__init__", return_value=None) as mock_init,
|
||||
patch.object(FooRunner, "run", return_value=None) as mock_run,
|
||||
):
|
||||
ns = parser.parse_args(["foo", "--x", "3", "--verbose"])
|
||||
ns.func(ns) # set by register_targets
|
||||
# __init__ received the Namespace
|
||||
self.assertEqual(mock_init.call_count, 1)
|
||||
(called_ns,), _ = mock_init.call_args
|
||||
self.assertIsInstance(called_ns, argparse.Namespace)
|
||||
# run() called with no args
|
||||
mock_run.assert_called_once_with()
|
||||
|
||||
def test_runner_docstring_used_as_description_when_missing(self):
|
||||
specs: dict[str, TargetSpec] = {
|
||||
"foo": {"runner": FooRunner, "add_arguments": add_foo_args},
|
||||
}
|
||||
parser = build_parser(specs)
|
||||
foo = get_subparser(parser, "foo")
|
||||
help_text = foo.format_help()
|
||||
self.assertIn("Foo description from docstring.", help_text)
|
||||
|
||||
def test_missing_target_raises_systemexit_with_usage(self):
|
||||
specs: dict[str, TargetSpec] = {"foo": {"runner": FooRunner}}
|
||||
parser = build_parser(specs)
|
||||
buf = io.StringIO()
|
||||
with self.assertRaises(SystemExit), redirect_stderr(buf):
|
||||
parser.parse_args([])
|
||||
err = buf.getvalue()
|
||||
self.assertIn("usage:", err)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
75
.ci/lumen_cli/tests/test_docker_helper.py
Normal file
75
.ci/lumen_cli/tests/test_docker_helper.py
Normal file
@ -0,0 +1,75 @@
|
||||
import unittest
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import docker.errors as derr
|
||||
from cli.lib.common.docker_helper import _get_client, local_image_exists
|
||||
|
||||
|
||||
class TestDockerImageHelpers(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Reset the singleton in the target module
|
||||
patcher = mock.patch("cli.lib.common.docker_helper._docker_client", None)
|
||||
self.addCleanup(patcher.stop)
|
||||
patcher.start()
|
||||
|
||||
def test_local_image_exists_true(self):
|
||||
# Mock a docker client whose images.get returns an object (no exception)
|
||||
mock_client = MagicMock()
|
||||
mock_client.images.get.return_value = object()
|
||||
ok = local_image_exists("repo:tag", client=mock_client)
|
||||
self.assertTrue(ok)
|
||||
|
||||
def test_local_image_exists_not_found_false(self):
|
||||
mock_client = MagicMock()
|
||||
# Raise docker.errors.NotFound
|
||||
mock_client.images.get.side_effect = derr.NotFound("nope")
|
||||
ok = local_image_exists("missing:latest", client=mock_client)
|
||||
self.assertFalse(ok)
|
||||
|
||||
def test_local_image_exists_api_error_false(self):
|
||||
mock_client = MagicMock()
|
||||
mock_client.images.get.side_effect = derr.APIError("boom", None)
|
||||
|
||||
ok = local_image_exists("broken:tag", client=mock_client)
|
||||
self.assertFalse(ok)
|
||||
|
||||
def test_local_image_exists_uses_lazy_singleton(self):
|
||||
# Patch docker.from_env used by _get_client()
|
||||
with mock.patch(
|
||||
"cli.lib.common.docker_helper.docker.from_env"
|
||||
) as mock_from_env:
|
||||
mock_docker_client = MagicMock()
|
||||
mock_from_env.return_value = mock_docker_client
|
||||
|
||||
# First call should create and cache the client
|
||||
c1 = _get_client()
|
||||
self.assertIs(c1, mock_docker_client)
|
||||
mock_from_env.assert_called_once()
|
||||
|
||||
# Second call should reuse cached client (no extra from_env calls)
|
||||
c2 = _get_client()
|
||||
self.assertIs(c2, mock_docker_client)
|
||||
mock_from_env.assert_called_once() # still once
|
||||
|
||||
def test_local_image_exists_without_client_param_calls_get_client_once(self):
|
||||
# Ensure _get_client is called and cached; local_image_exists should reuse it
|
||||
with mock.patch("cli.lib.common.docker_helper._get_client") as mock_get_client:
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# 1st call
|
||||
local_image_exists("repo:tag")
|
||||
# 2nd call
|
||||
local_image_exists("repo:tag2")
|
||||
|
||||
# local_image_exists should call _get_client each time,
|
||||
# but your _get_client itself caches docker.from_env.
|
||||
self.assertEqual(mock_get_client.call_count, 2)
|
||||
self.assertEqual(mock_client.images.get.call_count, 2)
|
||||
mock_client.images.get.assert_any_call("repo:tag")
|
||||
mock_client.images.get.assert_any_call("repo:tag2")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
149
.ci/lumen_cli/tests/test_envs_helper.py
Normal file
149
.ci/lumen_cli/tests/test_envs_helper.py
Normal file
@ -0,0 +1,149 @@
|
||||
import os
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import cli.lib.common.envs_helper as m
|
||||
|
||||
|
||||
class TestEnvHelpers(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Keep a copy of the original environment to restore later
|
||||
self._env_backup = dict(os.environ)
|
||||
|
||||
def tearDown(self):
|
||||
# Restore environment to original state
|
||||
os.environ.clear()
|
||||
os.environ.update(self._env_backup)
|
||||
|
||||
# -------- get_env --------
|
||||
def test_get_env_unset_returns_default(self):
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
self.assertEqual(m.get_env("FOO", "default"), "default")
|
||||
|
||||
def test_get_env_empty_returns_default(self):
|
||||
with patch.dict(os.environ, {"FOO": ""}, clear=True):
|
||||
self.assertEqual(m.get_env("FOO", "default"), "default")
|
||||
|
||||
def test_get_env_set_returns_value(self):
|
||||
with patch.dict(os.environ, {"FOO": "bar"}, clear=True):
|
||||
self.assertEqual(m.get_env("FOO", "default"), "bar")
|
||||
|
||||
def test_get_env_not_exist_returns_default(self):
|
||||
with patch.dict(os.environ, {"FOO": "bar"}, clear=True):
|
||||
self.assertEqual(m.get_env("TEST_NOT_EXIST", "default"), "default")
|
||||
|
||||
def test_get_env_not_exist_without_default(self):
|
||||
with patch.dict(os.environ, {"FOO": "bar"}, clear=True):
|
||||
self.assertEqual(m.get_env("TEST_NOT_EXIST"), "")
|
||||
|
||||
# -------- env_bool --------
|
||||
def test_env_bool_uses_default_when_unset(self):
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
self.assertTrue(m.env_bool("FLAG", default=True))
|
||||
self.assertFalse(m.env_bool("FLAG", default=False))
|
||||
|
||||
def test_env_bool_uses_str2bool_when_set(self):
|
||||
# Patch str2bool used by env_bool so we don't depend on its exact behavior
|
||||
def fake_str2bool(s: str) -> bool:
|
||||
return s.lower() in {"1", "true", "yes", "on", "y"}
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {"FLAG": "yEs"}, clear=True),
|
||||
patch.object(m, "str2bool", fake_str2bool),
|
||||
):
|
||||
self.assertTrue(m.env_bool("FLAG", default=False))
|
||||
|
||||
# -------- env_path_optional / env_path --------
|
||||
def test_env_path_optional_unset_returns_none_by_default(self):
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
self.assertIsNone(m.env_path_optional("P"))
|
||||
|
||||
def test_env_path_optional_unset_returns_none_when_env_var_is_empty(self):
|
||||
with patch.dict(os.environ, {"P": ""}, clear=True):
|
||||
self.assertIsNone(m.env_path_optional("P"))
|
||||
|
||||
def test_env_path_optional_unset_returns_default_str(self):
|
||||
# default as string; resolve=True by default -> absolute path
|
||||
default_str = "x/y"
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
p = m.env_path_optional("P", default=default_str)
|
||||
self.assertIsInstance(p, Path)
|
||||
self.assertIsNotNone(p)
|
||||
if p:
|
||||
self.assertTrue(p.is_absolute())
|
||||
self.assertEqual(p.parts[-2:], ("x", "y"))
|
||||
|
||||
def test_env_path_optional_unset_returns_default_path_no_resolve(self):
|
||||
d = Path("z")
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
p = m.env_path_optional("P", default=d, resolve=False)
|
||||
self.assertEqual(p, d)
|
||||
|
||||
def test_env_path_optional_respects_resolve_true(self):
|
||||
with patch.dict(os.environ, {"P": "a/b"}, clear=True):
|
||||
p = m.env_path_optional("P", resolve=True)
|
||||
self.assertIsInstance(p, Path)
|
||||
if p:
|
||||
self.assertTrue(p.is_absolute())
|
||||
|
||||
def test_env_path_optional_respects_resolve_false(self):
|
||||
with patch.dict(os.environ, {"P": "rel/dir"}, clear=True):
|
||||
p = m.env_path_optional("P", resolve=False)
|
||||
self.assertEqual(p, Path("rel/dir"))
|
||||
if p:
|
||||
self.assertFalse(p.is_absolute())
|
||||
|
||||
def test_env_path_raises_when_missing_and_default_none(self):
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
with self.assertRaises(ValueError):
|
||||
m.env_path("P", None, resolve=True)
|
||||
|
||||
def test_env_path_returns_path_when_present(self):
|
||||
tmp = Path("./b").resolve()
|
||||
with patch.dict(os.environ, {"P": str(tmp)}, clear=True):
|
||||
p = m.env_path("P", None, resolve=True)
|
||||
self.assertEqual(p, tmp)
|
||||
|
||||
# -------- dataclass field helpers --------
|
||||
def test_dataclass_fields_read_env_at_instantiation(self):
|
||||
@dataclass
|
||||
class Cfg:
|
||||
flag: bool = m.env_bool_field("FLAG", default=False)
|
||||
out: Path = m.env_path_field("OUT", default="ab", resolve=True)
|
||||
name: str = m.env_str_field("NAME", default="anon")
|
||||
|
||||
# First instantiation
|
||||
with patch.dict(
|
||||
os.environ, {"FLAG": "true", "OUT": "outdir", "NAME": "alice"}, clear=True
|
||||
):
|
||||
cfg1 = Cfg()
|
||||
self.assertTrue(cfg1.flag)
|
||||
self.assertIsInstance(cfg1.out, Path)
|
||||
self.assertTrue(cfg1.out.is_absolute())
|
||||
self.assertEqual(cfg1.name, "alice")
|
||||
cfg1.name = "bob" # change instance value
|
||||
self.assertEqual(cfg1.name, "bob") # change is reflected
|
||||
|
||||
# Change env; new instance should reflect new values
|
||||
with patch.dict(os.environ, {"FLAG": "false", "NAME": ""}, clear=True):
|
||||
cfg2 = Cfg()
|
||||
self.assertFalse(cfg2.flag) # str2bool("false") -> False
|
||||
self.assertTrue("ab" in str(cfg2.out))
|
||||
self.assertIsInstance(cfg2.out, Path)
|
||||
self.assertTrue(cfg2.out.is_absolute())
|
||||
self.assertEqual(cfg2.name, "anon") # empty -> fallback to default
|
||||
|
||||
def test_dataclass_path_field_with_default_value(self):
|
||||
@dataclass
|
||||
class C2:
|
||||
out: Path = m.env_path_field("OUT", default="some/dir", resolve=False)
|
||||
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
c = C2()
|
||||
self.assertEqual(c.out, Path("some/dir"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
122
.ci/lumen_cli/tests/test_path_helper.py
Normal file
122
.ci/lumen_cli/tests/test_path_helper.py
Normal file
@ -0,0 +1,122 @@
|
||||
# test_path_utils.py
|
||||
# Run: pytest -q
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
from cli.lib.common.path_helper import (
|
||||
copy,
|
||||
ensure_dir_exists,
|
||||
force_create_dir,
|
||||
get_path,
|
||||
is_path_exist,
|
||||
remove_dir,
|
||||
)
|
||||
|
||||
|
||||
class TestPathHelper(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmpdir = TemporaryDirectory()
|
||||
self.tmp_path = Path(self.tmpdir.name)
|
||||
|
||||
def tearDown(self):
|
||||
self.tmpdir.cleanup()
|
||||
|
||||
# -------- get_path --------
|
||||
def test_get_path_returns_path_for_str(self):
|
||||
# Use relative path to avoid absolute-ness
|
||||
rel_str = "sub/f.txt"
|
||||
os.chdir(self.tmp_path)
|
||||
p = get_path(rel_str, resolve=False)
|
||||
self.assertIsInstance(p, Path)
|
||||
self.assertFalse(p.is_absolute())
|
||||
self.assertEqual(str(p), rel_str)
|
||||
|
||||
def test_get_path_resolves(self):
|
||||
rel_str = "sub/f.txt"
|
||||
p = get_path(str(self.tmp_path / rel_str), resolve=True)
|
||||
self.assertTrue(p.is_absolute())
|
||||
self.assertTrue(str(p).endswith(rel_str))
|
||||
|
||||
def test_get_path_with_path_input(self):
|
||||
p_in = self.tmp_path / "sub/f.txt"
|
||||
p_out = get_path(p_in, resolve=False)
|
||||
self.assertTrue(str(p_out) == str(p_in))
|
||||
|
||||
def test_get_path_with_none_raises(self):
|
||||
with self.assertRaises(ValueError):
|
||||
get_path(None) # type: ignore[arg-type]
|
||||
|
||||
def test_get_path_invalid_type_raises(self):
|
||||
with self.assertRaises(TypeError):
|
||||
get_path(123) # type: ignore[arg-type]
|
||||
|
||||
# -------- ensure_dir_exists / force_create_dir / remove_dir --------
|
||||
def test_ensure_dir_exists_creates_and_is_idempotent(self):
|
||||
d = self.tmp_path / "made"
|
||||
ensure_dir_exists(d)
|
||||
self.assertTrue(d.exists() and d.is_dir())
|
||||
ensure_dir_exists(d)
|
||||
|
||||
def test_force_create_dir_clears_existing(self):
|
||||
d = self.tmp_path / "fresh"
|
||||
(d / "inner").mkdir(parents=True)
|
||||
(d / "inner" / "f.txt").write_text("x")
|
||||
force_create_dir(d)
|
||||
self.assertTrue(d.exists())
|
||||
self.assertEqual(list(d.iterdir()), [])
|
||||
|
||||
def test_remove_dir_none_is_noop(self):
|
||||
remove_dir(None) # type: ignore[arg-type]
|
||||
|
||||
def test_remove_dir_nonexistent_is_noop(self):
|
||||
ghost = self.tmp_path / "ghost"
|
||||
remove_dir(ghost)
|
||||
|
||||
def test_remove_dir_accepts_str(self):
|
||||
d = self.tmp_path / "to_rm"
|
||||
d.mkdir()
|
||||
remove_dir(str(d))
|
||||
self.assertFalse(d.exists())
|
||||
|
||||
# -------- copy --------
|
||||
def test_copy_file_to_file(self):
|
||||
src = self.tmp_path / "src.txt"
|
||||
dst = self.tmp_path / "out" / "dst.txt"
|
||||
src.write_text("hello")
|
||||
copy(src, dst)
|
||||
self.assertEqual(dst.read_text(), "hello")
|
||||
|
||||
def test_copy_dir_to_new_dir(self):
|
||||
src = self.tmp_path / "srcdir"
|
||||
(src / "a").mkdir(parents=True)
|
||||
(src / "a" / "f.txt").write_text("content")
|
||||
dst = self.tmp_path / "destdir"
|
||||
copy(src, dst)
|
||||
self.assertEqual((dst / "a" / "f.txt").read_text(), "content")
|
||||
|
||||
def test_copy_dir_into_existing_dir_overwrite_true_merges(self):
|
||||
src = self.tmp_path / "srcdir"
|
||||
dst = self.tmp_path / "destdir"
|
||||
(src / "x").mkdir(parents=True)
|
||||
(src / "x" / "new.txt").write_text("new")
|
||||
dst.mkdir()
|
||||
(dst / "existing.txt").write_text("old")
|
||||
copy(src, dst)
|
||||
self.assertEqual((dst / "existing.txt").read_text(), "old")
|
||||
self.assertEqual((dst / "x" / "new.txt").read_text(), "new")
|
||||
|
||||
def test_is_str_path_exist(self):
|
||||
p = self.tmp_path / "x.txt"
|
||||
p.write_text("1")
|
||||
self.assertTrue(is_path_exist(str(p)))
|
||||
self.assertTrue(is_path_exist(p))
|
||||
self.assertFalse(is_path_exist(str(self.tmp_path / "missing")))
|
||||
self.assertFalse(is_path_exist(self.tmp_path / "missing"))
|
||||
self.assertFalse(is_path_exist(""))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
181
.ci/lumen_cli/tests/test_vllm.py
Normal file
181
.ci/lumen_cli/tests/test_vllm.py
Normal file
@ -0,0 +1,181 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import cli.lib.core.vllm as vllm
|
||||
|
||||
|
||||
class TestVllmBuildParameters(unittest.TestCase):
|
||||
@patch("cli.lib.core.vllm.local_image_exists", return_value=True)
|
||||
@patch("cli.lib.core.vllm.is_path_exist", return_value=True)
|
||||
@patch(
|
||||
"cli.lib.common.envs_helper.env_path_optional",
|
||||
side_effect=lambda name, default=None, resolve=True: {
|
||||
"DOCKERFILE_PATH": Path("/abs/vllm/Dockerfile"),
|
||||
"TORCH_WHEELS_PATH": Path("/abs/dist"),
|
||||
"OUTPUT_DIR": Path("/abs/shared"),
|
||||
}.get(name, Path(default) if default is not None else None),
|
||||
)
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"USE_TORCH_WHEEL": "1",
|
||||
"USE_LOCAL_BASE_IMAGE": "1",
|
||||
"USE_LOCAL_DOCKERFILE": "1",
|
||||
"BASE_IMAGE": "my/image:tag",
|
||||
"DOCKERFILE_PATH": "vllm/Dockerfile",
|
||||
"TORCH_WHEELS_PATH": "dist",
|
||||
"OUTPUT_DIR": "shared",
|
||||
},
|
||||
clear=True,
|
||||
)
|
||||
def test_params_success_normalizes_and_validates(
|
||||
self, mock_env_path, mock_is_path, mock_local_img
|
||||
):
|
||||
params = vllm.VllmBuildParameters()
|
||||
self.assertEqual(params.torch_whls_path, Path("/abs/dist"))
|
||||
self.assertEqual(params.dockerfile_path, Path("/abs/vllm/Dockerfile"))
|
||||
self.assertEqual(params.output_dir, Path("/abs/shared"))
|
||||
self.assertEqual(params.base_image, "my/image:tag")
|
||||
|
||||
@patch("cli.lib.core.vllm.is_path_exist", return_value=False)
|
||||
@patch.dict(
|
||||
os.environ, {"USE_TORCH_WHEEL": "1", "TORCH_WHEELS_PATH": "dist"}, clear=True
|
||||
)
|
||||
def test_params_missing_torch_whls_raises(self, _is_path):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
os.chdir(td)
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
vllm.VllmBuildParameters(
|
||||
use_local_base_image=False,
|
||||
use_local_dockerfile=False,
|
||||
)
|
||||
err = cm.exception
|
||||
self.assertIn("TORCH_WHEELS_PATH", str(err))
|
||||
|
||||
@patch("cli.lib.core.vllm.local_image_exists", return_value=False)
|
||||
@patch.dict(
|
||||
os.environ, {"USE_LOCAL_BASE_IMAGE": "1", "BASE_IMAGE": "img:tag"}, clear=True
|
||||
)
|
||||
def test_params_missing_local_base_image_raises(self, _local_img):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
os.chdir(td)
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
vllm.VllmBuildParameters(
|
||||
use_torch_whl=False,
|
||||
use_local_dockerfile=False,
|
||||
)
|
||||
err = cm.exception
|
||||
self.assertIn("BASE_IMAGE", str(err))
|
||||
|
||||
@patch("cli.lib.core.vllm.is_path_exist", return_value=False)
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{"USE_LOCAL_DOCKERFILE": "1", "DOCKERFILE_PATH": "Dockerfile"},
|
||||
clear=True,
|
||||
)
|
||||
def test_params_missing_dockerfile_raises(self, _is_path):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
os.chdir(td)
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
vllm.VllmBuildParameters(
|
||||
use_torch_whl=False,
|
||||
use_local_base_image=False,
|
||||
)
|
||||
err = cm.exception
|
||||
self.assertIn("DOCKERFILE_PATH", str(err))
|
||||
|
||||
@patch("cli.lib.core.vllm.is_path_exist", return_value=False)
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{"OUTPUT_DIR": ""},
|
||||
clear=True,
|
||||
)
|
||||
def test_params_missing_output_dir(self, _is_path):
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
vllm.VllmBuildParameters()
|
||||
|
||||
|
||||
class TestBuildCmdAndRun(unittest.TestCase):
|
||||
@patch("cli.lib.core.vllm.local_image_exists", return_value=True)
|
||||
def test_generate_docker_build_cmd_includes_bits(self, _exists):
|
||||
runner = vllm.VllmBuildRunner()
|
||||
# Craft inputs that simulate a prepared build
|
||||
inputs = MagicMock()
|
||||
inputs.output_dir = Path("/abs/out")
|
||||
inputs.use_local_base_image = True
|
||||
inputs.base_image = "img:tag"
|
||||
inputs.torch_whls_path = Path("./vllm/tmp")
|
||||
inputs.max_jobs = 64
|
||||
inputs.cuda_version = "12.8.1"
|
||||
inputs.python_version = "3.12"
|
||||
inputs.sccache_bucket = "my-bucket"
|
||||
inputs.sccache_region = "us-west-2"
|
||||
inputs.torch_cuda_arch_list = "8.0;9.0"
|
||||
inputs.target_stage = "export-wheels"
|
||||
inputs.tag_name = "vllm-wheels"
|
||||
|
||||
cmd = runner._generate_docker_build_cmd(inputs)
|
||||
squashed = " ".join(cmd.split()) # normalize whitespace for matching
|
||||
|
||||
self.assertIn("--output type=local,dest=/abs/out", squashed)
|
||||
self.assertIn("-f docker/Dockerfile.nightly_torch", squashed)
|
||||
self.assertIn("--pull=false", squashed)
|
||||
self.assertIn("--build-arg TORCH_WHEELS_PATH=tmp", squashed)
|
||||
self.assertIn("--build-arg BUILD_BASE_IMAGE=img:tag", squashed)
|
||||
self.assertIn("--build-arg FINAL_BASE_IMAGE=img:tag", squashed)
|
||||
self.assertIn("--build-arg max_jobs=64", squashed)
|
||||
self.assertIn("--build-arg CUDA_VERSION=12.8.1", squashed)
|
||||
self.assertIn("--build-arg PYTHON_VERSION=3.12", squashed)
|
||||
self.assertIn("--build-arg USE_SCCACHE=1", squashed)
|
||||
self.assertIn("--build-arg SCCACHE_BUCKET_NAME=my-bucket", squashed)
|
||||
self.assertIn("--build-arg SCCACHE_REGION_NAME=us-west-2", squashed)
|
||||
self.assertIn("--build-arg torch_cuda_arch_list='8.0;9.0'", squashed)
|
||||
self.assertIn("--target export-wheels", squashed)
|
||||
self.assertIn("-t vllm-wheels", squashed)
|
||||
|
||||
@patch("cli.lib.core.vllm.run_command")
|
||||
@patch("cli.lib.core.vllm.ensure_dir_exists")
|
||||
@patch("cli.lib.core.vllm.clone_vllm")
|
||||
@patch.object(
|
||||
vllm.VllmBuildRunner,
|
||||
"_generate_docker_build_cmd",
|
||||
return_value="docker buildx ...",
|
||||
)
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
# Make __post_init__ validations pass cheaply
|
||||
"USE_TORCH_WHEEL": "0",
|
||||
"USE_LOCAL_BASE_IMAGE": "0",
|
||||
"USE_LOCAL_DOCKERFILE": "0",
|
||||
"OUTPUT_DIR": "shared",
|
||||
},
|
||||
clear=True,
|
||||
)
|
||||
def test_run_calls_clone_prepare_and_build(
|
||||
self, mock_gen, mock_clone, mock_ensure, mock_run
|
||||
):
|
||||
# Stub parameters instance so we avoid FS/Docker accesses in run()
|
||||
params = MagicMock()
|
||||
params.output_dir = Path("shared")
|
||||
params.use_local_dockerfile = False
|
||||
params.use_torch_whl = False
|
||||
|
||||
with patch("cli.lib.core.vllm.VllmBuildParameters", return_value=params):
|
||||
runner = vllm.VllmBuildRunner()
|
||||
runner.run()
|
||||
|
||||
mock_clone.assert_called_once()
|
||||
mock_ensure.assert_called_once_with(Path("shared"))
|
||||
mock_gen.assert_called_once_with(params)
|
||||
mock_run.assert_called_once()
|
||||
# ensure we run in vllm workdir
|
||||
_, kwargs = mock_run.call_args
|
||||
assert kwargs.get("cwd") == "vllm"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -194,7 +194,7 @@ ROCBLAS_LIB_SRC=$ROCM_HOME/lib/rocblas/library
|
||||
ROCBLAS_LIB_DST=lib/rocblas/library
|
||||
ROCBLAS_ARCH_SPECIFIC_FILES=$(ls $ROCBLAS_LIB_SRC | grep -E $ARCH)
|
||||
ROCBLAS_OTHER_FILES=$(ls $ROCBLAS_LIB_SRC | grep -v gfx)
|
||||
ROCBLAS_LIB_FILES=($ROCBLAS_ARCH_SPECIFIC_FILES $OTHER_FILES)
|
||||
ROCBLAS_LIB_FILES=($ROCBLAS_ARCH_SPECIFIC_FILES $ROCBLAS_OTHER_FILES)
|
||||
|
||||
# hipblaslt library files
|
||||
HIPBLASLT_LIB_SRC=$ROCM_HOME/lib/hipblaslt/library
|
||||
|
||||
@ -627,6 +627,8 @@ test_perf_for_dashboard() {
|
||||
device=cuda_a10g
|
||||
elif [[ "${TEST_CONFIG}" == *h100* ]]; then
|
||||
device=cuda_h100
|
||||
elif [[ "${TEST_CONFIG}" == *b200* ]]; then
|
||||
device=cuda_b200
|
||||
elif [[ "${TEST_CONFIG}" == *rocm* ]]; then
|
||||
device=rocm
|
||||
fi
|
||||
@ -801,6 +803,16 @@ test_dynamo_benchmark() {
|
||||
if [[ "${TEST_CONFIG}" == *perf_compare* ]]; then
|
||||
test_single_dynamo_benchmark "training" "$suite" "$shard_id" --training --amp "$@"
|
||||
elif [[ "${TEST_CONFIG}" == *perf* ]]; then
|
||||
# TODO (huydhn): Just smoke test some sample models
|
||||
if [[ "${TEST_CONFIG}" == *b200* ]]; then
|
||||
if [[ "${suite}" == "huggingface" ]]; then
|
||||
export TORCHBENCH_ONLY_MODELS="DistillGPT2"
|
||||
elif [[ "${suite}" == "timm_models" ]]; then
|
||||
export TORCHBENCH_ONLY_MODELS="inception_v3"
|
||||
elif [[ "${suite}" == "torchbench" ]]; then
|
||||
export TORCHBENCH_ONLY_MODELS="hf_Bert"
|
||||
fi
|
||||
fi
|
||||
test_single_dynamo_benchmark "dashboard" "$suite" "$shard_id" "$@"
|
||||
else
|
||||
if [[ "${TEST_CONFIG}" == *cpu* ]]; then
|
||||
@ -1627,6 +1639,16 @@ elif [[ "${TEST_CONFIG}" == *xla* ]]; then
|
||||
install_torchvision
|
||||
build_xla
|
||||
test_xla
|
||||
elif [[ "$TEST_CONFIG" == *vllm* ]]; then
|
||||
(cd .ci/lumen_cli && python -m pip install -e .)
|
||||
if [[ "$BUILD_ENVIRONMENT" == *sm80* ]]; then
|
||||
export TORCH_CUDA_ARCH_LIST="8.0"
|
||||
elif [[ "$BUILD_ENVIRONMENT" == *sm90* ]]; then
|
||||
export TORCH_CUDA_ARCH_LIST="9.0"
|
||||
else
|
||||
export TORCH_CUDA_ARCH_LIST="8.9"
|
||||
fi
|
||||
python -m cli.run test external vllm --test-plan "$TEST_CONFIG"
|
||||
elif [[ "${TEST_CONFIG}" == *executorch* ]]; then
|
||||
test_executorch
|
||||
elif [[ "$TEST_CONFIG" == 'jit_legacy' ]]; then
|
||||
|
||||
2
.github/ci_commit_pins/audio.txt
vendored
2
.github/ci_commit_pins/audio.txt
vendored
@ -1 +1 @@
|
||||
bf305f538005f2e900f8850ed57146024a8bc559
|
||||
9b57c7bd5ad4db093c5bb31c802df9f04d933ac9
|
||||
|
||||
2
.github/ci_commit_pins/vllm.txt
vendored
2
.github/ci_commit_pins/vllm.txt
vendored
@ -1 +1 @@
|
||||
ca9e2be3ed6320b51f52f536595cd24e254f8bb2
|
||||
53d7c39271aeb0568afcae337396a972e1848586
|
||||
|
||||
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
@ -1 +1 @@
|
||||
29ae4c76c026185f417a25e841d2cd5e65f087a3
|
||||
b6a5b82b9948b610fa4c304d0d869c82b8f17db1
|
||||
|
||||
414
.github/ci_configs/vllm/Dockerfile.tmp_vllm
vendored
Normal file
414
.github/ci_configs/vllm/Dockerfile.tmp_vllm
vendored
Normal file
@ -0,0 +1,414 @@
|
||||
# TODO(elainwy): remove this file after the torch nightly dockerfile is in sync in vllm repo
|
||||
# The vLLM Dockerfile is used to construct vLLM image against torch nightly and torch main that can be directly used for testing
|
||||
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG PYTHON_VERSION=3.12
|
||||
|
||||
# BUILD_BASE_IMAGE: used to setup python build xformers, and vllm wheels, It can be replaced with a different base image from local machine,
|
||||
# by default, it uses the torch-nightly-base stage from this docker image
|
||||
ARG BUILD_BASE_IMAGE=torch-nightly-base
|
||||
|
||||
# FINAL_BASE_IMAGE: used to set up vllm-instaled environment and build flashinfer,
|
||||
# by default, it uses devel-ubuntu22.04 official image.
|
||||
ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
|
||||
|
||||
|
||||
#################### TORCH NIGHTLY BASE IMAGE ####################
|
||||
# A base image for building vLLM with devel ubuntu 22.04, this is mainly used to build vllm in vllm builtkite ci
|
||||
From nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 as torch-nightly-base
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG TARGETPLATFORM
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
|
||||
echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment
|
||||
|
||||
# Install Python and other dependencies if it does not existed
|
||||
RUN if ! command -v python3 >/dev/null || ! python3 --version | grep -q "${PYTHON_VERSION}"; then \
|
||||
echo "Installing Python ${PYTHON_VERSION}..." && \
|
||||
echo 'tzdata tzdata/Areas select America' | debconf-set-selections && \
|
||||
echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections && \
|
||||
apt-get update -y && \
|
||||
apt-get install -y ccache software-properties-common git curl sudo && \
|
||||
for i in 1 2 3; do \
|
||||
add-apt-repository -y ppa:deadsnakes/ppa && break || \
|
||||
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
|
||||
done && \
|
||||
apt-get update -y && \
|
||||
apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv && \
|
||||
update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 && \
|
||||
update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} && \
|
||||
ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config && \
|
||||
curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION}; \
|
||||
else \
|
||||
echo "Python ${PYTHON_VERSION} already present, skipping setup."; \
|
||||
fi \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519
|
||||
# as it was causing spam when compiling the CUTLASS kernels
|
||||
# Ensure gcc >= 10 to avoid CUTLASS issues (bug 92519)
|
||||
RUN current_gcc_version=$(gcc -dumpversion | cut -f1 -d.) && \
|
||||
if [ "$current_gcc_version" -lt 10 ]; then \
|
||||
echo "GCC version is $current_gcc_version, installing gcc-10..."; \
|
||||
apt-get update && \
|
||||
apt-get install -y gcc-10 g++-10 && \
|
||||
update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 100 && \
|
||||
update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-10 100; \
|
||||
else \
|
||||
echo "GCC version is $current_gcc_version, no need to install gcc-10."; \
|
||||
fi && \
|
||||
gcc --version && g++ --version
|
||||
|
||||
# install uv for faster pip installs
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 -m pip install uv==0.8.4
|
||||
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
|
||||
#################### TORCH NIGHTLY BASE IMAGE ####################
|
||||
|
||||
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
# A base image for building vLLM with torch nightly or torch wheels
|
||||
# prepare basic build environment
|
||||
FROM ${BUILD_BASE_IMAGE} AS base
|
||||
USER root
|
||||
|
||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||
# this won't be needed for future versions of this docker image
|
||||
# or future versions of triton.
|
||||
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
|
||||
|
||||
# Install uv for faster pip installs if not existed
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if ! python3 -m uv --version >/dev/null 2>&1; then \
|
||||
python3 -m pip install uv==0.8.4; \
|
||||
fi
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
# install build and runtime dependencies
|
||||
COPY requirements/common.txt requirements/common.txt
|
||||
COPY use_existing_torch.py use_existing_torch.py
|
||||
COPY pyproject.toml pyproject.toml
|
||||
|
||||
# install build and runtime dependencies without stable torch version
|
||||
RUN python3 use_existing_torch.py
|
||||
|
||||
# default mount file as placeholder, this just avoid the mount error
|
||||
# change to a different vllm folder if this does not exist anymore
|
||||
ARG TORCH_WHEELS_PATH="./requirements"
|
||||
ARG PINNED_TORCH_VERSION
|
||||
|
||||
# Install torch, torchaudio and torchvision based on the input
|
||||
# if TORCH_WHEELS_PATH is default "./requirements", it will pull thethe nightly versions using pip
|
||||
# otherwise, it will use the whls from TORCH_WHEELS_PATH from the host machine
|
||||
RUN --mount=type=bind,source=${TORCH_WHEELS_PATH},target=/dist \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
if [ -n "$TORCH_WHEELS_PATH" ] && [ "$TORCH_WHEELS_PATH" != "./requirements" ] && [ -d "/dist" ] && ls /dist/torch*.whl >/dev/null 2>&1; then \
|
||||
torch_whl=$(find /dist -maxdepth 1 -name 'torch-*.whl' -print -quit); \
|
||||
vision_whl=$(find /dist/vision -name 'torchvision*.whl' | head -n1 | xargs); \
|
||||
audio_whl=$(find /dist/audio -name 'torchaudio*.whl' | head -n1 | xargs); \
|
||||
uv pip install --system "${torch_whl}[opt-einsum]"; \
|
||||
uv pip install --system "${vision_whl}"; \
|
||||
uv pip install --system "${audio_whl}"; \
|
||||
elif [ -n "$PINNED_TORCH_VERSION" ]; then \
|
||||
echo "[INFO] Installing pinned torch nightly version: $PINNED_TORCH_VERSION"; \
|
||||
uv pip install --system "$PINNED_TORCH_VERSION" --index-url https://download.pytorch.org/whl/nightly/cu128; \
|
||||
else \
|
||||
echo "[INFO] Installing torch nightly with latest one"; \
|
||||
uv pip install --system torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128; \
|
||||
fi
|
||||
|
||||
# Install numba 0.61.2 for cuda environment
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system numba==0.61.2
|
||||
|
||||
# Install common dependencies from vllm common.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/common.txt
|
||||
|
||||
|
||||
# Must put before installing xformers, so it can install the correct version of xfomrers.
|
||||
ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
ARG max_jobs=16
|
||||
ENV MAX_JOBS=${max_jobs}
|
||||
|
||||
# Build xformers with cuda and torch nightly/wheel
|
||||
# following official xformers guidance: https://github.com/facebookresearch/xformers#build
|
||||
ARG XFORMERS_COMMIT=f2de641ef670510cadab099ce6954031f52f191c
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
echo 'git clone xformers...' \
|
||||
&& git clone https://github.com/facebookresearch/xformers.git --recursive \
|
||||
&& cd xformers \
|
||||
&& git checkout ${XFORMERS_COMMIT} \
|
||||
&& git submodule update --init --recursive \
|
||||
&& echo 'finish git clone xformers...' \
|
||||
&& rm -rf build \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=../xformers-dist --verbose \
|
||||
&& cd .. \
|
||||
&& rm -rf xformers
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system xformers-dist/*.whl --verbose
|
||||
|
||||
# Build can take a long time, and the torch nightly version fetched from url can be different in next docker stage.
|
||||
# track the nightly torch version used in the build, when we set up runtime environment we can make sure the version is the same
|
||||
RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio' > torch_build_versions.txt
|
||||
RUN cat torch_build_versions.txt
|
||||
|
||||
RUN pip freeze | grep -E 'torch|xformers|torchvision|torchaudio'
|
||||
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
|
||||
|
||||
#################### WHEEL BUILD IMAGE ####################
|
||||
# Image used to build vllm wheel
|
||||
FROM base AS build
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN python3 use_existing_torch.py
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/build.txt
|
||||
|
||||
ARG GIT_REPO_CHECK=0
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != "0" ]; then bash tools/check_repo.sh ; fi
|
||||
|
||||
# Max jobs used by Ninja to build extensions
|
||||
ARG max_jobs=16
|
||||
ENV MAX_JOBS=${max_jobs}
|
||||
ARG nvcc_threads=2
|
||||
ENV NVCC_THREADS=$nvcc_threads
|
||||
ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
|
||||
ARG USE_SCCACHE
|
||||
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
|
||||
ARG SCCACHE_REGION_NAME=us-west-2
|
||||
ARG SCCACHE_S3_NO_CREDENTIALS=0
|
||||
|
||||
# if USE_SCCACHE is set, use sccache to speed up compilation
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
if [ "$USE_SCCACHE" = "1" ]; then \
|
||||
echo "Installing sccache..." \
|
||||
&& curl -L -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz \
|
||||
&& tar -xzf sccache.tar.gz \
|
||||
&& sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
|
||||
&& rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
|
||||
&& export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \
|
||||
&& export SCCACHE_REGION=${SCCACHE_REGION_NAME} \
|
||||
&& export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \
|
||||
&& export SCCACHE_IDLE_TIMEOUT=0 \
|
||||
&& export CMAKE_BUILD_TYPE=Release \
|
||||
&& sccache --show-stats \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=vllm-dist --py-limited-api=cp38 \
|
||||
&& sccache --show-stats; \
|
||||
fi
|
||||
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
if [ "$USE_SCCACHE" != "1" ]; then \
|
||||
# Clean any existing CMake artifacts
|
||||
rm -rf .deps && \
|
||||
mkdir -p .deps && \
|
||||
python3 setup.py bdist_wheel --dist-dir=vllm-dist --py-limited-api=cp38; \
|
||||
fi
|
||||
|
||||
RUN echo "[DEBUG] Listing current directory:" && \
|
||||
ls -al && \
|
||||
echo "[DEBUG] Showing torch_build_versions.txt content:" && \
|
||||
cat torch_build_versions.txt
|
||||
|
||||
#################### WHEEL BUILD IMAGE ####################
|
||||
|
||||
|
||||
################### VLLM INSTALLED IMAGE ####################
|
||||
# Setup clean environment for vLLM for test and api server using ubuntu22.04 with AOT flashinfer
|
||||
FROM ${FINAL_BASE_IMAGE} AS vllm-base
|
||||
USER root
|
||||
# prepare for environment starts
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
|
||||
echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment
|
||||
|
||||
# Install Python and other dependencies if it does not existed
|
||||
RUN if ! command -v python3 >/dev/null || ! python3 --version | grep -q "${PYTHON_VERSION}"; then \
|
||||
echo "Installing Python ${PYTHON_VERSION}..." && \
|
||||
echo 'tzdata tzdata/Areas select America' | debconf-set-selections && \
|
||||
echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections && \
|
||||
apt-get update -y && \
|
||||
apt-get install -y ccache software-properties-common git curl sudo && \
|
||||
for i in 1 2 3; do \
|
||||
add-apt-repository -y ppa:deadsnakes/ppa && break || \
|
||||
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
|
||||
done && \
|
||||
apt-get update -y && \
|
||||
apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv && \
|
||||
update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 && \
|
||||
update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} && \
|
||||
ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config && \
|
||||
curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION}; \
|
||||
else \
|
||||
echo "Python ${PYTHON_VERSION} already present, skipping setup."; \
|
||||
fi \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
|
||||
# Get the torch versions, and whls used in previous stagtes for consistency
|
||||
COPY --from=base /workspace/torch_build_versions.txt ./torch_build_versions.txt
|
||||
COPY --from=base /workspace/xformers-dist /wheels/xformers
|
||||
COPY --from=build /workspace/vllm-dist /wheels/vllm
|
||||
RUN echo "[DEBUG] Listing current directory before torch install step:" && \
|
||||
ls -al && \
|
||||
echo "[DEBUG] Showing torch_build_versions.txt content:" && \
|
||||
cat torch_build_versions.txt
|
||||
|
||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||
# this won't be needed for future versions of this docker image
|
||||
# or future versions of triton.
|
||||
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
|
||||
|
||||
|
||||
# Install uv for faster pip installs if not existed
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if ! python3 -m uv --version > /dev/null 2>&1; then \
|
||||
python3 -m pip install uv==0.8.4; \
|
||||
fi
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
|
||||
# Default mount file as placeholder, this just avoid the mount error
|
||||
ARG TORCH_WHEELS_PATH="./requirements"
|
||||
# Install torch, torchaudio and torchvision
|
||||
# if TORCH_WHEELS_PATH is default "./requirements", it will pull the nightly versions using pip using torch_build_versions.txt
|
||||
# otherwise, it will use the whls from TORCH_WHEELS_PATH from the host machine
|
||||
RUN --mount=type=bind,source=${TORCH_WHEELS_PATH},target=/dist \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
if [ -n "$TORCH_WHEELS_PATH" ] && [ "$TORCH_WHEELS_PATH" != "./requirements" ] && [ -d "/dist" ] && ls /dist/torch*.whl >/dev/null 2>&1; then \
|
||||
torch_whl=$(find /dist -maxdepth 1 -name 'torch-*.whl' -print -quit); \
|
||||
vision_whl=$(find /dist/vision -name 'torchvision*.whl' | head -n1 | xargs); \
|
||||
audio_whl=$(find /dist/audio -name 'torchaudio*.whl' | head -n1 | xargs); \
|
||||
echo "Found: '${torch_whl}' '${audio_whl}' '${vision_whl}'"; \
|
||||
uv pip install --system "${torch_whl}[opt-einsum]"; \
|
||||
uv pip install --system "${vision_whl}"; \
|
||||
uv pip install --system "${audio_whl}"; \
|
||||
else \
|
||||
echo "[INFO] Installing torch versions from torch_build_versions.txt"; \
|
||||
uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu128; \
|
||||
fi
|
||||
|
||||
# Install the vllm wheel from previous stage
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system /wheels/vllm/*.whl --verbose
|
||||
|
||||
# Install xformers wheel from previous stage
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system /wheels/xformers/*.whl --verbose
|
||||
|
||||
|
||||
# Build flashinfer from source.
|
||||
ARG torch_cuda_arch_list='8.0;8.9;9.0a'
|
||||
# install package for build flashinfer
|
||||
# see issue: https://github.com/flashinfer-ai/flashinfer/issues/738
|
||||
|
||||
RUN pip install build==1.3.0
|
||||
RUN pip freeze | grep -E 'setuptools|packaging|build'
|
||||
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
# Build flashinfer for torch nightly from source around 10 mins
|
||||
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
||||
# Keep this in sync with https://github.com/vllm-project/vllm/blob/main/requirements/cuda.txt
|
||||
ARG FLASHINFER_GIT_REF="v0.2.9rc2"
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
git clone --depth 1 --recursive --shallow-submodules \
|
||||
--branch ${FLASHINFER_GIT_REF} \
|
||||
${FLASHINFER_GIT_REPO} flashinfer \
|
||||
&& echo "Building FlashInfer with AOT for arches: ${torch_cuda_arch_list}" \
|
||||
&& cd flashinfer \
|
||||
&& python3 -m flashinfer.aot \
|
||||
&& python3 -m build --no-isolation --wheel --outdir ../wheels/flashinfer \
|
||||
&& cd .. \
|
||||
&& rm -rf flashinfer
|
||||
|
||||
# install flashinfer python
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system wheels/flashinfer/*.whl --verbose
|
||||
|
||||
# Logging to confirm the torch versions
|
||||
RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer'
|
||||
################### VLLM INSTALLED IMAGE ####################
|
||||
|
||||
|
||||
#################### UNITTEST IMAGE #############################
|
||||
FROM vllm-base as test
|
||||
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
|
||||
COPY tests/ tests/
|
||||
COPY examples examples
|
||||
COPY benchmarks benchmarks
|
||||
COPY ./vllm/collect_env.py .
|
||||
COPY requirements/common.txt requirements/common.txt
|
||||
COPY use_existing_torch.py use_existing_torch.py
|
||||
COPY pyproject.toml pyproject.toml
|
||||
# Install build and runtime dependencies without stable torch version
|
||||
COPY requirements/nightly_torch_test.txt requirements/nightly_torch_test.txt
|
||||
|
||||
RUN python3 use_existing_torch.py
|
||||
|
||||
# install packages
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/common.txt
|
||||
# enable fast downloads from hf (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system hf_transfer
|
||||
ENV HF_HUB_ENABLE_HF_TRANSFER 1
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -e tests/vllm_test_utils
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/nightly_torch_test.txt
|
||||
|
||||
# Workaround for #17068
|
||||
# pinned commit for v2.2.4
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@95d8aba8a8c75aedcaa6143713b11e745e7cd0d9#egg=mamba-ssm"
|
||||
|
||||
# Logging to confirm the torch versions
|
||||
RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer'
|
||||
|
||||
# Logging to confirm all the packages are installed
|
||||
RUN pip freeze
|
||||
|
||||
#################### UNITTEST IMAGE #############################
|
||||
|
||||
#################### EXPORT STAGE ####################
|
||||
FROM scratch as export-wheels
|
||||
|
||||
# Just copy the wheels we prepared in previous stages
|
||||
COPY --from=base /workspace/xformers-dist /wheels/xformers
|
||||
COPY --from=build /workspace/vllm-dist /wheels/vllm
|
||||
COPY --from=vllm-base /workspace/wheels/flashinfer /wheels/flashinfer-python
|
||||
4
.github/merge_rules.yaml
vendored
4
.github/merge_rules.yaml
vendored
@ -488,6 +488,10 @@
|
||||
- torch/_dynamo/**
|
||||
- torch/csrc/dynamo/**
|
||||
- test/dynamo/**
|
||||
- test/dynamo_expected_failures/**
|
||||
- test/dynamo_skips/**
|
||||
- test/inductor_expected_failures/**
|
||||
- test/inductor_skips/**
|
||||
approved_by:
|
||||
- guilhermeleobas
|
||||
mandatory_checks_name:
|
||||
|
||||
1
.github/pytorch-probot.yml
vendored
1
.github/pytorch-probot.yml
vendored
@ -26,6 +26,7 @@ ciflow_push_tags:
|
||||
- ciflow/trunk
|
||||
- ciflow/unstable
|
||||
- ciflow/xpu
|
||||
- ciflow/vllm
|
||||
- ciflow/torchbench
|
||||
- ciflow/op-benchmark
|
||||
- ciflow/pull
|
||||
|
||||
@ -193,7 +193,7 @@ LIBTORCH_CONTAINER_IMAGES: dict[str, str] = {
|
||||
"cpu": "libtorch-cxx11-builder:cpu",
|
||||
}
|
||||
|
||||
FULL_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13", "3.13t"]
|
||||
FULL_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13", "3.13t", "3.14", "3.14t"]
|
||||
|
||||
|
||||
def translate_desired_cuda(gpu_arch_type: str, gpu_arch_version: str) -> str:
|
||||
@ -315,6 +315,11 @@ def generate_wheels_matrix(
|
||||
# TODO: Enable python 3.13t on cpu-s390x
|
||||
if gpu_arch_type == "cpu-s390x" and python_version == "3.13t":
|
||||
continue
|
||||
# TODO: Enable python 3.14 on non linux OSes
|
||||
if os != "linux" and (
|
||||
python_version == "3.14" or python_version == "3.14t"
|
||||
):
|
||||
continue
|
||||
|
||||
if use_split_build and (
|
||||
arch_version not in ["12.6", "12.8", "12.9", "cpu"] or os != "linux"
|
||||
|
||||
292
.github/workflows/_linux-external-build-main.yml
vendored
Normal file
292
.github/workflows/_linux-external-build-main.yml
vendored
Normal file
@ -0,0 +1,292 @@
|
||||
name: linux-external-build
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
build-environment:
|
||||
required: true
|
||||
type: string
|
||||
description: Top-level label for what's being built/tested.
|
||||
build-target:
|
||||
required: true
|
||||
type: string
|
||||
description: target library to build
|
||||
build-generates-artifacts:
|
||||
required: false
|
||||
type: boolean
|
||||
default: true
|
||||
description: If set, upload generated build artifacts.
|
||||
artifacts-folder-name:
|
||||
required: false
|
||||
type: string
|
||||
description: must be different from build-environment
|
||||
default: ""
|
||||
docker-image:
|
||||
required: true
|
||||
type: string
|
||||
description: Docker image to run in or replace the external base image.
|
||||
cuda-arch-list:
|
||||
required: false
|
||||
type: string
|
||||
default: "8.9"
|
||||
description: |
|
||||
List of CUDA architectures CI build should target.
|
||||
runner_prefix:
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
description: Prefix for runner label
|
||||
runner:
|
||||
required: false
|
||||
type: string
|
||||
default: "linux.2xlarge"
|
||||
description: |
|
||||
Label of the runner this job should run on.
|
||||
s3-bucket:
|
||||
description: S3 bucket to download artifact
|
||||
required: false
|
||||
type: string
|
||||
default: "gha-artifacts"
|
||||
use-gha:
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
description: f set to any value, use GHA to download the artifact. Otherwise use s3.
|
||||
aws-role-to-assume:
|
||||
description: Role to assume for downloading artifacts
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
disable-monitor:
|
||||
description: |
|
||||
Disable utilization monitoring for build job
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
monitor-log-interval:
|
||||
description: |
|
||||
Set the interval for the monitor script to log utilization.
|
||||
required: false
|
||||
type: number
|
||||
default: 5
|
||||
monitor-data-collect-interval:
|
||||
description: |
|
||||
Set the interval for the monitor script to collect data.
|
||||
required: false
|
||||
type: number
|
||||
default: 1
|
||||
secrets:
|
||||
HUGGING_FACE_HUB_TOKEN:
|
||||
required: false
|
||||
description: |
|
||||
HF Auth token to avoid rate limits when downloading models or datasets from hub
|
||||
SCRIBE_GRAPHQL_ACCESS_TOKEN:
|
||||
required: false
|
||||
description: |
|
||||
FB app token to write to scribe endpoint
|
||||
|
||||
jobs:
|
||||
build-external-lib:
|
||||
environment: ${{ github.ref == 'refs/heads/main' && 'scribe-protected' || startsWith(github.ref, 'refs/heads/release/') && 'scribe-protected' || contains(github.event.pull_request.labels.*.name, 'ci-scribe') && 'scribe-pr' || '' }}
|
||||
# Don't run on forked repos
|
||||
if: github.repository_owner == 'pytorch'
|
||||
runs-on: ${{ inputs.runner_prefix}}${{ inputs.runner }}
|
||||
timeout-minutes: 240
|
||||
steps:
|
||||
- name: Setup SSH (Click me for login details)
|
||||
uses: pytorch/test-infra/.github/actions/setup-ssh@main
|
||||
with:
|
||||
github-secret: ${{ secrets.GITHUB_TOKEN }}
|
||||
instructions: |
|
||||
Build is done inside the container, to start an interactive session run:
|
||||
docker exec -it $(docker container ps --format '{{.ID}}') bash
|
||||
|
||||
# [pytorch repo ref]
|
||||
# Use a pytorch/pytorch reference instead of a reference to the local
|
||||
# checkout because when we run this action we don't *have* a local
|
||||
# checkout. In other cases you should prefer a local checkout.
|
||||
- name: Checkout PyTorch
|
||||
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
|
||||
with:
|
||||
no-sudo: true
|
||||
|
||||
- name: Get workflow job id
|
||||
id: get-job-id
|
||||
uses: ./.github/actions/get-workflow-job-id
|
||||
if: always()
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: configure aws credentials
|
||||
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
|
||||
if: ${{ inputs.aws-role-to-assume != ''}}
|
||||
with:
|
||||
role-to-assume: ${{ inputs.aws-role-to-assume }}
|
||||
role-session-name: gha-linux-build
|
||||
aws-region: us-east-1
|
||||
|
||||
- name: Setup Linux
|
||||
uses: ./.github/actions/setup-linux
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
if: ${{ inputs.aws-role-to-assume != ''}}
|
||||
id: login-ecr
|
||||
continue-on-error: true
|
||||
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
|
||||
|
||||
- name: Parse ref
|
||||
id: parse-ref
|
||||
run: .github/scripts/parse_ref.py
|
||||
|
||||
- name: Start monitoring script
|
||||
id: monitor-script
|
||||
if: ${{ !inputs.disable-monitor }}
|
||||
shell: bash
|
||||
continue-on-error: true
|
||||
env:
|
||||
JOB_ID: ${{ steps.get-job-id.outputs.job-id }}
|
||||
JOB_NAME: ${{ steps.get-job-id.outputs.job-name }}
|
||||
WORKFLOW_NAME: ${{ github.workflow }}
|
||||
WORKFLOW_RUN_ID: ${{github.run_id}}
|
||||
MONITOR_LOG_INTERVAL: ${{ inputs.monitor-log-interval }}
|
||||
MONITOR_DATA_COLLECT_INTERVAL: ${{ inputs.monitor-data-collect-interval }}
|
||||
run: |
|
||||
mkdir -p ../../usage_logs
|
||||
python3 -m pip install psutil==5.9.8 dataclasses_json==0.6.7
|
||||
python3 -m tools.stats.monitor \
|
||||
--log-interval "$MONITOR_LOG_INTERVAL" \
|
||||
--data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" \
|
||||
> "../../usage_logs/usage_log_build_${JOB_ID}.txt" 2>&1 &
|
||||
echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
- name: Calculate docker image
|
||||
id: calculate-docker-image
|
||||
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
|
||||
with:
|
||||
docker-image-name: ${{ inputs.docker-image }}
|
||||
|
||||
- name: Use following to pull public copy of the image
|
||||
id: print-ghcr-mirror
|
||||
env:
|
||||
ECR_DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
shell: bash
|
||||
run: |
|
||||
tag=${ECR_DOCKER_IMAGE##*:}
|
||||
echo "docker pull ghcr.io/pytorch/ci-image:${tag/:/-}"
|
||||
|
||||
- name: Pull docker image
|
||||
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
|
||||
with:
|
||||
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
|
||||
- name: Download pytorch build artifacts
|
||||
uses: ./.github/actions/download-build-artifacts
|
||||
with:
|
||||
name: ${{ inputs.build-environment }}
|
||||
s3-bucket: ${{ inputs.s3-bucket }}
|
||||
use-gha: ${{ inputs.use-gha }}
|
||||
|
||||
- name: Download TD artifacts
|
||||
continue-on-error: true
|
||||
uses: ./.github/actions/download-td-artifacts
|
||||
|
||||
- name: Build external project
|
||||
id: build
|
||||
env:
|
||||
BUILD_ENVIRONMENT: ${{ inputs.build-environment }}
|
||||
BRANCH: ${{ steps.parse-ref.outputs.branch }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
# Do not set SCCACHE_S3_KEY_PREFIX to share the cache between all build jobs
|
||||
SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2
|
||||
SCCACHE_REGION: us-east-1
|
||||
PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }}
|
||||
TORCH_CUDA_ARCH_LIST: ${{ inputs.cuda-arch-list }}
|
||||
OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }}
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
BASE_IMAGE: ${{ inputs.docker-image }}
|
||||
BUILD_TARGET: ${{ inputs.build-target }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
python3 --version
|
||||
docker images
|
||||
START_TIME=$(date +%s)
|
||||
(
|
||||
cd .ci/lumen_cli
|
||||
python3 -m pip install -e .
|
||||
)
|
||||
|
||||
MAX_JOBS="$(nproc --ignore=6)"
|
||||
export MAX_JOBS
|
||||
|
||||
python3 -m cli.run build external "$BUILD_TARGET"
|
||||
END_TIME=$(date +%s)
|
||||
echo "build_time=$((END_TIME - START_TIME))" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Archive artifacts into zip
|
||||
if: ${{ inputs.build-generates-artifacts && steps.build.outcome && steps.build.outcome != 'skipped'}}
|
||||
run: |
|
||||
zip -1 -r artifacts.zip shared/
|
||||
|
||||
# By default it will upload the artifacts to <github_org>/<github_repo>/<workflow_id>/<name>-<target>-additional-build/
|
||||
# to avoid override the pytorch build artifacts
|
||||
- name: Store External Build Artifacts on S3
|
||||
if: ${{ inputs.build-generates-artifacts }}
|
||||
uses: seemethere/upload-artifact-s3@baba72d0712b404f646cebe0730933554ebce96a # v5.1.0
|
||||
with:
|
||||
name: ${{ inputs.artifacts-folder-name || format('{0}-{1}-additional-build', inputs.build-environment, inputs.build-target) }}
|
||||
retention-days: 14
|
||||
if-no-files-found: warn
|
||||
path: artifacts.zip
|
||||
s3-bucket: ${{ inputs.s3-bucket }}
|
||||
|
||||
- name: Stop monitoring script
|
||||
if: ${{ always() && steps.monitor-script.outputs.monitor-script-pid }}
|
||||
shell: bash
|
||||
continue-on-error: true
|
||||
env:
|
||||
MONITOR_SCRIPT_PID: ${{ steps.monitor-script.outputs.monitor-script-pid }}
|
||||
run: |
|
||||
kill "$MONITOR_SCRIPT_PID"
|
||||
|
||||
- name: Copy logs
|
||||
shell: bash
|
||||
if: ${{ always() && steps.build.outcome != 'skipped' && !inputs.disable-monitor}}
|
||||
continue-on-error: true
|
||||
run: |
|
||||
rm -f ./usage_logs
|
||||
mkdir -p ./usage_logs
|
||||
cp ../../usage_logs/usage_log_build_*.txt ./usage_logs/
|
||||
|
||||
- name: Upload raw usage log to s3
|
||||
if: ${{ always() && steps.build.outcome != 'skipped' && !inputs.disable-monitor}}
|
||||
uses: seemethere/upload-artifact-s3@v5
|
||||
with:
|
||||
s3-prefix: |
|
||||
${{ github.repository }}/${{ github.run_id }}/${{ github.run_attempt }}/artifact
|
||||
retention-days: 14
|
||||
if-no-files-found: warn
|
||||
path: usage_logs/usage_log_build_*.txt
|
||||
|
||||
- name: Upload utilization stats
|
||||
if: ${{ always() && steps.build.outcome != 'skipped' && !inputs.disable-monitor }}
|
||||
continue-on-error: true
|
||||
uses: ./.github/actions/upload-utilization-stats
|
||||
with:
|
||||
job_id: ${{ steps.get-job-id.outputs.job-id }}
|
||||
job_name: ${{ steps.get-job-id.outputs.job-name }}
|
||||
workflow_name: ${{ github.workflow }}
|
||||
workflow_run_id: ${{github.run_id}}
|
||||
workflow_attempt: ${{github.run_attempt}}
|
||||
artifact_prefix: usage_log_build_${{ steps.get-job-id.outputs.job-id }}
|
||||
|
||||
- name: Teardown Linux
|
||||
uses: pytorch/test-infra/.github/actions/teardown-linux@main
|
||||
if: always()
|
||||
|
||||
- name: Cleanup docker
|
||||
if: always()
|
||||
shell: bash
|
||||
run: |
|
||||
docker stop -a || true
|
||||
docker kill -a || true
|
||||
40
.github/workflows/_linux-test.yml
vendored
40
.github/workflows/_linux-test.yml
vendored
@ -47,6 +47,12 @@ on:
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
additional-artifact-name:
|
||||
description: |
|
||||
additional artifacts needed to be downloaded for testing
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
disable-monitor:
|
||||
description: |
|
||||
[Experimental] Disable utilization monitoring for tests.
|
||||
@ -72,6 +78,10 @@ on:
|
||||
required: false
|
||||
description: |
|
||||
HF Auth token to avoid rate limits when downloading models or datasets from hub
|
||||
VLLM_TEST_HUGGING_FACE_TOKEN:
|
||||
required: false
|
||||
description: |
|
||||
HF Auth token to test vllm
|
||||
SCRIBE_GRAPHQL_ACCESS_TOKEN:
|
||||
required: false
|
||||
description: |
|
||||
@ -96,7 +106,7 @@ jobs:
|
||||
steps:
|
||||
- name: Setup SSH (Click me for login details)
|
||||
uses: pytorch/test-infra/.github/actions/setup-ssh@main
|
||||
if: ${{ matrix.runner != 'B200' && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
|
||||
if: ${{ !contains(matrix.runner, 'b200') && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
|
||||
with:
|
||||
github-secret: ${{ secrets.GITHUB_TOKEN }}
|
||||
instructions: |
|
||||
@ -109,7 +119,7 @@ jobs:
|
||||
no-sudo: true
|
||||
|
||||
- name: Setup Python
|
||||
if: matrix.runner == 'B200'
|
||||
if: contains(matrix.runner, 'b200')
|
||||
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
|
||||
with:
|
||||
python-version: '3.12'
|
||||
@ -117,7 +127,7 @@ jobs:
|
||||
|
||||
- name: Setup Linux
|
||||
uses: ./.github/actions/setup-linux
|
||||
if: inputs.build-environment != 'linux-s390x-binary-manywheel' && matrix.runner != 'B200'
|
||||
if: inputs.build-environment != 'linux-s390x-binary-manywheel' && !contains(matrix.runner, 'b200')
|
||||
|
||||
- name: configure aws credentials
|
||||
if: ${{ inputs.aws-role-to-assume != '' && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
|
||||
@ -128,7 +138,7 @@ jobs:
|
||||
aws-region: us-east-1
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
if: ${{ inputs.aws-role-to-assume != '' && matrix.runner == 'B200' }}
|
||||
if: ${{ inputs.aws-role-to-assume != '' && contains(matrix.runner, 'b200') }}
|
||||
id: login-ecr
|
||||
continue-on-error: true
|
||||
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
|
||||
@ -166,17 +176,17 @@ jobs:
|
||||
uses: pytorch/test-infra/.github/actions/setup-nvidia@main
|
||||
with:
|
||||
driver-version: ${{ matrix.config == 'legacy_nvidia_driver' && '525.105.17' || '570.133.07' }}
|
||||
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && matrix.runner != 'B200' }}
|
||||
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && !contains(matrix.runner, 'b200') }}
|
||||
|
||||
- name: Setup GPU_FLAG for docker run
|
||||
id: setup-gpu-flag
|
||||
run: echo "GPU_FLAG=--gpus all -e NVIDIA_DRIVER_CAPABILITIES=all" >> "${GITHUB_ENV}"
|
||||
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && (steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' || matrix.runner == 'B200') }}
|
||||
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && (steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' || contains(matrix.runner, 'b200')) }}
|
||||
|
||||
- name: Setup SCCACHE_SERVER_PORT environment for docker run when on container
|
||||
id: setup-sscache-port-flag
|
||||
run: echo "SCCACHE_SERVER_PORT_DOCKER_FLAG=-e SCCACHE_SERVER_PORT=$((RUNNER_UID + 4226))" >> "${GITHUB_ENV}"
|
||||
if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' && matrix.runner != 'B200' }}
|
||||
if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' && !contains(matrix.runner, 'b200') }}
|
||||
|
||||
- name: Lock NVIDIA A100 40GB Frequency
|
||||
run: |
|
||||
@ -216,6 +226,14 @@ jobs:
|
||||
s3-bucket: ${{ inputs.s3-bucket }}
|
||||
use-gha: ${{ inputs.use-gha }}
|
||||
|
||||
- name: Download additional build artifacts
|
||||
if: ${{ inputs.additional-artifact-name != ''}}
|
||||
uses: ./.github/actions/download-build-artifacts
|
||||
with:
|
||||
name: ${{ inputs.additional-artifact-name }}
|
||||
s3-bucket: ${{ inputs.s3-bucket }}
|
||||
use-gha: ${{ inputs.use-gha }}
|
||||
|
||||
- name: Download TD artifacts
|
||||
continue-on-error: true
|
||||
uses: ./.github/actions/download-td-artifacts
|
||||
@ -277,8 +295,8 @@ jobs:
|
||||
NO_TD: ${{ steps.keep-going.outputs.ci-no-td }}
|
||||
TD_DISTRIBUTED: ${{ steps.keep-going.outputs.ci-td-distributed }}
|
||||
# Do not set SCCACHE_S3_KEY_PREFIX to share the cache between all build jobs
|
||||
SCCACHE_BUCKET: ${{ matrix.runner != 'B200' && 'ossci-compiler-cache-circleci-v2' || '' }}
|
||||
SCCACHE_REGION: ${{ matrix.runner != 'B200' && 'us-east-1' || '' }}
|
||||
SCCACHE_BUCKET: ${{ !contains(matrix.runner, 'b200') && 'ossci-compiler-cache-circleci-v2' || '' }}
|
||||
SCCACHE_REGION: ${{ !contains(matrix.runner, 'b200') && 'us-east-1' || '' }}
|
||||
SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }}
|
||||
DOCKER_IMAGE: ${{ inputs.docker-image }}
|
||||
XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }}
|
||||
@ -286,6 +304,7 @@ jobs:
|
||||
PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }}
|
||||
PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }}
|
||||
DASHBOARD_TAG: ${{ inputs.dashboard-tag }}
|
||||
VLLM_TEST_HUGGING_FACE_TOKEN: ${{ secrets.VLLM_TEST_HUGGING_FACE_TOKEN }}
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }}
|
||||
ARTIFACTS_FILE_SUFFIX: ${{ github.job }}-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }}_${{ steps.get-job-id.outputs.job-id }}
|
||||
@ -362,6 +381,7 @@ jobs:
|
||||
-e PYTORCH_TEST_RERUN_DISABLED_TESTS \
|
||||
-e SKIP_SCCACHE_INITIALIZATION=1 \
|
||||
-e HUGGING_FACE_HUB_TOKEN \
|
||||
-e VLLM_TEST_HUGGING_FACE_TOKEN \
|
||||
-e SCRIBE_GRAPHQL_ACCESS_TOKEN \
|
||||
-e DASHBOARD_TAG \
|
||||
-e ARTIFACTS_FILE_SUFFIX \
|
||||
@ -403,7 +423,7 @@ jobs:
|
||||
job_identifier: ${{ github.workflow }}_${{ inputs.build-environment }}
|
||||
|
||||
- name: Authenticate with AWS
|
||||
if: ${{ matrix.runner == 'B200' }}
|
||||
if: ${{ contains(matrix.runner, 'b200') }}
|
||||
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
|
||||
with:
|
||||
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results
|
||||
|
||||
3
.github/workflows/docker-builds.yml
vendored
3
.github/workflows/docker-builds.yml
vendored
@ -76,7 +76,8 @@ jobs:
|
||||
pytorch-linux-jammy-py3-clang12-onnx,
|
||||
pytorch-linux-jammy-linter,
|
||||
pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter,
|
||||
pytorch-linux-jammy-py3-clang12-executorch,
|
||||
# Executorch pin needs update
|
||||
# pytorch-linux-jammy-py3-clang12-executorch,
|
||||
pytorch-linux-jammy-py3.12-triton-cpu
|
||||
]
|
||||
include:
|
||||
|
||||
1226
.github/workflows/generated-linux-binary-manywheel-nightly.yml
generated
vendored
1226
.github/workflows/generated-linux-binary-manywheel-nightly.yml
generated
vendored
File diff suppressed because it is too large
Load Diff
154
.github/workflows/inductor-perf-test-b200.yml
vendored
Normal file
154
.github/workflows/inductor-perf-test-b200.yml
vendored
Normal file
@ -0,0 +1,154 @@
|
||||
name: inductor-perf-b200
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: 0 7 * * 1-6
|
||||
- cron: 0 7 * * 0
|
||||
# NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
|
||||
# out, let try to run torchao cudagraphs_low_precision as part of cudagraphs
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
training:
|
||||
description: Run training (on by default)?
|
||||
required: false
|
||||
type: boolean
|
||||
default: true
|
||||
inference:
|
||||
description: Run inference (on by default)?
|
||||
required: false
|
||||
type: boolean
|
||||
default: true
|
||||
default:
|
||||
description: Run inductor_default?
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
dynamic:
|
||||
description: Run inductor_dynamic_shapes?
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
cppwrapper:
|
||||
description: Run inductor_cpp_wrapper?
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
cudagraphs:
|
||||
description: Run inductor_cudagraphs?
|
||||
required: false
|
||||
type: boolean
|
||||
default: true
|
||||
freezing_cudagraphs:
|
||||
description: Run inductor_cudagraphs with freezing for inference?
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
aotinductor:
|
||||
description: Run aot_inductor for inference?
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
maxautotune:
|
||||
description: Run inductor_max_autotune?
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
benchmark_configs:
|
||||
description: The list of configs used the benchmark
|
||||
required: false
|
||||
type: string
|
||||
default: inductor_huggingface_perf_cuda_b200,inductor_timm_perf_cuda_b200,inductor_torchbench_perf_cuda_b200
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
get-label-type:
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
opt_out_experiments: lf
|
||||
|
||||
build:
|
||||
name: cuda12.8-py3.10-gcc9-sm100
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
# Use a bigger runner here because CUDA_ARCH 9.0 is only built for H100
|
||||
# or newer GPUs, so it doesn't benefit much from existing compiler cache
|
||||
# from trunk. Also use a memory-intensive runner here because memory is
|
||||
# usually the bottleneck
|
||||
runner: linux.12xlarge.memory
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
|
||||
cuda-arch-list: '10.0'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor_huggingface_perf_cuda_b200", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
|
||||
{ config: "inductor_timm_perf_cuda_b200", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
|
||||
{ config: "inductor_torchbench_perf_cuda_b200", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
|
||||
]}
|
||||
selected-test-configs: ${{ inputs.benchmark_configs }}
|
||||
build-additional-packages: "vision audio fbgemm torchao"
|
||||
secrets: inherit
|
||||
|
||||
test-periodically:
|
||||
name: cuda12.8-py3.10-gcc9-sm100
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: build
|
||||
if: github.event.schedule == '0 7 * * 1-6'
|
||||
with:
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
|
||||
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
|
||||
docker-image: ${{ needs.build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.build.outputs.test-matrix }}
|
||||
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
timeout-minutes: 720
|
||||
disable-monitor: false
|
||||
monitor-log-interval: 15
|
||||
monitor-data-collect-interval: 4
|
||||
secrets: inherit
|
||||
|
||||
test-weekly:
|
||||
name: cuda12.8-py3.10-gcc9-sm100
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: build
|
||||
if: github.event.schedule == '0 7 * * 0'
|
||||
with:
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
|
||||
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true
|
||||
docker-image: ${{ needs.build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.build.outputs.test-matrix }}
|
||||
timeout-minutes: 1440
|
||||
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
disable-monitor: false
|
||||
monitor-log-interval: 15
|
||||
monitor-data-collect-interval: 4
|
||||
secrets: inherit
|
||||
|
||||
test:
|
||||
name: cuda12.8-py3.10-gcc9-sm100
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: build
|
||||
with:
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
|
||||
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
|
||||
docker-image: ${{ needs.build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.build.outputs.test-matrix }}
|
||||
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
timeout-minutes: 720
|
||||
disable-monitor: false
|
||||
monitor-log-interval: 15
|
||||
monitor-data-collect-interval: 4
|
||||
secrets: inherit
|
||||
30
.github/workflows/inductor-periodic.yml
vendored
30
.github/workflows/inductor-periodic.yml
vendored
@ -81,21 +81,21 @@ jobs:
|
||||
sync-tag: rocm-build
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
|
||||
9
.github/workflows/nightly.yml
vendored
9
.github/workflows/nightly.yml
vendored
@ -75,10 +75,11 @@ jobs:
|
||||
repo-owner: pytorch
|
||||
branch: main
|
||||
pin-folder: .github/ci_commit_pins
|
||||
- repo-name: executorch
|
||||
repo-owner: pytorch
|
||||
branch: main
|
||||
pin-folder: .ci/docker/ci_commit_pins
|
||||
# executorch jobs are disabled since it needs some manual work for the hash update
|
||||
# - repo-name: executorch
|
||||
# repo-owner: pytorch
|
||||
# branch: main
|
||||
# pin-folder: .ci/docker/ci_commit_pins
|
||||
- repo-name: triton
|
||||
repo-owner: triton-lang
|
||||
branch: main
|
||||
|
||||
1
.github/workflows/pull.yml
vendored
1
.github/workflows/pull.yml
vendored
@ -434,6 +434,7 @@ jobs:
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-py3-clang12-executorch-build:
|
||||
if: false # Docker build needs pin update
|
||||
name: linux-jammy-py3-clang12-executorch
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
|
||||
70
.github/workflows/tools-unit-tests.yml
vendored
Normal file
70
.github/workflows/tools-unit-tests.yml
vendored
Normal file
@ -0,0 +1,70 @@
|
||||
name: test-scripts-and-ci-tools
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- scripts/lumen_cli/**
|
||||
- .github/workflows/tools-unit-tests.yml
|
||||
pull_request:
|
||||
paths:
|
||||
- scripts/lumen_cli/**
|
||||
- .github/workflows/tools-unit-tests.yml
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
lumen-cli-unit-tests-python312:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout pytorch
|
||||
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
|
||||
with:
|
||||
submodules: true
|
||||
fetch-depth: 0
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: pip
|
||||
|
||||
- name: Run tests
|
||||
continue-on-error: true
|
||||
run: |
|
||||
set -ex
|
||||
python3 -m venv /tmp/venv
|
||||
source /tmp/venv/bin/activate
|
||||
pip install -e .ci/lumen_cli/
|
||||
pytest -v -s .ci/lumen_cli/tests/*
|
||||
|
||||
lumen-cli-compatible-python39:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout pytorch
|
||||
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
|
||||
with:
|
||||
submodules: true
|
||||
fetch-depth: 0
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
|
||||
with:
|
||||
python-version: '3.9'
|
||||
cache: 'pip'
|
||||
- name: Run tests
|
||||
continue-on-error: true
|
||||
run: |
|
||||
set -ex
|
||||
python3 -m venv /tmp/venv
|
||||
source /tmp/venv/bin/activate
|
||||
pip install -e .ci/lumen_cli/
|
||||
2
.github/workflows/update-viablestrict.yml
vendored
2
.github/workflows/update-viablestrict.yml
vendored
@ -23,7 +23,7 @@ jobs:
|
||||
with:
|
||||
repository: pytorch/pytorch
|
||||
stable-branch: viable/strict
|
||||
requires: '[\"pull\", \"trunk\", \"lint\", \"linux-binary\"]'
|
||||
requires: '[\"pull\", \"trunk\", \"lint\", \"linux-binary\", \"linux-aarch64\"]'
|
||||
secret-bot-token: ${{ secrets.MERGEBOT_TOKEN }}
|
||||
clickhouse-url: ${{ secrets.CLICKHOUSE_URL }}
|
||||
clickhouse-username: ${{ secrets.CLICKHOUSE_VIABLESTRICT_USERNAME }}
|
||||
|
||||
123
.github/workflows/vllm.yml
vendored
Normal file
123
.github/workflows/vllm.yml
vendored
Normal file
@ -0,0 +1,123 @@
|
||||
name: vllm-test
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- ciflow/vllm/*
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
get-label-type:
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
opt_out_experiments: lf
|
||||
|
||||
torch-build-sm89:
|
||||
name: ci-vllm-test-sm89
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
build-additional-packages: "vision audio torchao"
|
||||
build-environment: linux-jammy-cuda12.8-py3.12-gcc11-sm89
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm
|
||||
cuda-arch-list: '8.9'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "vllm_basic_correctness_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
|
||||
{ config: "vllm_basic_models_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
|
||||
{ config: "vllm_regression_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
|
||||
{ config: "vllm_entrypoints_test", shard: 1, num_shards: 1,runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
torch-build-sm80:
|
||||
name: ci-vllm-test-sm80
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
build-additional-packages: "vision audio torchao"
|
||||
build-environment: linux-jammy-cuda12.8-py3.12-gcc11-sm80
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm
|
||||
cuda-arch-list: '8.0'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "vllm_basic_correctness_test", shard: 1, num_shards: 1, runner: "linux.aws.a100" },
|
||||
{ config: "vllm_basic_models_test", shard: 1, num_shards: 1, runner: "linux.aws.a100" },
|
||||
{ config: "vllm_regression_test", shard: 1, num_shards: 1, runner: "linux.aws.a100" },
|
||||
{ config: "vllm_entrypoints_test", shard: 1, num_shards: 1, runner: "linux.aws.a100" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
vllm-build-sm89:
|
||||
name: ci-vllm-test-sm89
|
||||
uses: ./.github/workflows/_linux-external-build-main.yml
|
||||
needs: [
|
||||
get-label-type,
|
||||
torch-build-sm89
|
||||
]
|
||||
with:
|
||||
build-environment: linux-jammy-cuda12.8-py3.12-gcc11-sm89
|
||||
build-target: vllm
|
||||
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
|
||||
docker-image: ${{ needs.torch-build-sm89.outputs.docker-image }}
|
||||
cuda-arch-list: '8.9'
|
||||
runner: linux.24xlarge.memory
|
||||
secrets: inherit
|
||||
|
||||
vllm-test-sm89:
|
||||
name: ci-vllm-test-sm89
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: [
|
||||
torch-build-sm89,
|
||||
vllm-build-sm89
|
||||
]
|
||||
with:
|
||||
build-environment: linux-jammy-cuda12.8-py3.12-gcc11-sm89
|
||||
docker-image: ${{ needs.torch-build-sm89.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.torch-build-sm89.outputs.test-matrix }}
|
||||
additional-artifact-name: linux-jammy-cuda12.8-py3.12-gcc11-sm89-vllm-additional-build
|
||||
secrets: inherit
|
||||
|
||||
vllm-build-sm80:
|
||||
name: ci-vllm-test-sm80
|
||||
uses: ./.github/workflows/_linux-external-build-main.yml
|
||||
needs: [
|
||||
get-label-type,
|
||||
torch-build-sm80
|
||||
]
|
||||
with:
|
||||
build-environment: linux-jammy-cuda12.8-py3.12-gcc11-sm80
|
||||
build-target: vllm
|
||||
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
|
||||
docker-image: ${{ needs.torch-build-sm80.outputs.docker-image }}
|
||||
runner: linux.24xlarge.memory
|
||||
cuda-arch-list: '8.0'
|
||||
secrets: inherit
|
||||
|
||||
vllm-test-sm80:
|
||||
name: ci-vllm-test-sm80
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: [
|
||||
torch-build-sm80,
|
||||
vllm-build-sm80
|
||||
]
|
||||
with:
|
||||
build-environment: linux-jammy-cuda12.8-py3.12-gcc11-sm80
|
||||
docker-image: ${{ needs.torch-build-sm80.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.torch-build-sm80.outputs.test-matrix }}
|
||||
additional-artifact-name: linux-jammy-cuda12.8-py3.12-gcc11-sm80-vllm-additional-build
|
||||
secrets: inherit
|
||||
@ -14,7 +14,6 @@
|
||||
/torch/csrc/autograd/ @albanD @soulitzer
|
||||
/torch/autograd/ @albanD @soulitzer
|
||||
/tools/autograd/ @albanD @soulitzer
|
||||
/torch/header_only_apis.txt @janeyx99
|
||||
/torch/nn/ @albanD @jbschlosser @mikaylagawarecki
|
||||
/torch/optim/ @albanD @janeyx99
|
||||
/test/test_public_bindings.py @albanD
|
||||
@ -196,3 +195,8 @@ torch/backends/cudnn/ @eqy @syed-ahmed
|
||||
/torch/utils/_cxx_pytree.py @XuehaiPan
|
||||
/torch/utils/pytree/ @XuehaiPan
|
||||
/torch/_dynamo/polyfills/pytree.py @XuehaiPan
|
||||
|
||||
# Relating to libtorch ABI
|
||||
/torch/csrc/stable/ @janeyx99 @mikaylagawarecki
|
||||
/torch/headeronly/ @janeyx99
|
||||
/torch/header_only_apis.txt @janeyx99
|
||||
|
||||
@ -439,6 +439,7 @@ if(USE_ROCM)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include)
|
||||
_pytorch_rocm_generate_ck_conf()
|
||||
@ -703,21 +704,17 @@ if(USE_MPS)
|
||||
if(CAN_COMPILE_METAL)
|
||||
foreach(SHADER ${native_mps_metal})
|
||||
cmake_path(GET SHADER STEM TGT_STEM)
|
||||
string(CONCAT TGT_BASIC ${TGT_STEM} "_30.air")
|
||||
string(CONCAT TGT_BFLOAT ${TGT_STEM} "_31.air")
|
||||
string(CONCAT TGT_BASIC ${TGT_STEM} "_31.air")
|
||||
list(APPEND AIR_BASIC ${TGT_BASIC})
|
||||
list(APPEND AIR_BFLOAT ${TGT_BFLOAT})
|
||||
metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.0")
|
||||
metal_to_air(${SHADER} ${TGT_BFLOAT} "-std=metal3.1")
|
||||
metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.1")
|
||||
endforeach()
|
||||
air_to_metallib(kernels_basic.metallib ${AIR_BASIC})
|
||||
air_to_metallib(kernels_bfloat.metallib ${AIR_BFLOAT})
|
||||
add_custom_command(
|
||||
COMMAND echo "// $$(date)" > metallib_dummy.cpp
|
||||
DEPENDS kernels_basic.metallib kernels_bfloat.metallib
|
||||
DEPENDS kernels_basic.metallib
|
||||
OUTPUT metallib_dummy.cpp
|
||||
COMMENT "Updating metallibs timestamp")
|
||||
add_custom_target(metallibs DEPENDS kernels_basic.metallib kernels_bfloat.metallib metallib_dummy.cpp)
|
||||
add_custom_target(metallibs DEPENDS kernels_basic.metallib metallib_dummy.cpp)
|
||||
else()
|
||||
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/native/mps")
|
||||
foreach(SHADER ${native_mps_metal})
|
||||
|
||||
@ -162,7 +162,7 @@ struct CUDACachingHostAllocatorImpl
|
||||
}
|
||||
|
||||
bool pinned_use_background_threads() override {
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::
|
||||
return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
|
||||
pinned_use_background_threads();
|
||||
}
|
||||
|
||||
|
||||
@ -24,6 +24,29 @@ static void _assert_match(const O& original, const C& compared, const std::strin
|
||||
}
|
||||
}
|
||||
|
||||
template<>
|
||||
void _assert_match<c10::Device, std::optional<c10::Device>>(
|
||||
const c10::Device& original,
|
||||
const std::optional<c10::Device>& compared,
|
||||
const std::string& name) {
|
||||
if (compared) {
|
||||
const c10::Device& expected = compared.value();
|
||||
if (original.type() != expected.type()) {
|
||||
std::stringstream msg;
|
||||
msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// If the expected device doesn't have an index (e.g., just "cuda"),
|
||||
// or if both devices have the same index, consider them equal
|
||||
if (expected.has_index() && original.has_index() && expected.index() != original.index()) {
|
||||
std::stringstream msg;
|
||||
msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional<c10::ScalarType> dtype, std::optional<c10::Device> device, std::optional<c10::Layout> layout) {
|
||||
_assert_match(tensor.sym_sizes(), sizes, "sizes");
|
||||
_assert_match(tensor.sym_strides(), strides, "strides");
|
||||
|
||||
@ -367,27 +367,27 @@ void int8pack_mm_kernel_(
|
||||
auto* C_data = C.data_ptr<T>();
|
||||
const auto* S_data = scales.const_data_ptr<T>();
|
||||
|
||||
int M = A.size(0);
|
||||
int N = B.size(0);
|
||||
int K = A.size(1);
|
||||
int lda = A.stride(0);
|
||||
constexpr int BLOCK_M = 4;
|
||||
constexpr int BLOCK_N = 4;
|
||||
int64_t M = A.size(0);
|
||||
int64_t N = B.size(0);
|
||||
int64_t K = A.size(1);
|
||||
int64_t lda = A.stride(0);
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 4;
|
||||
|
||||
const int MB = (M + BLOCK_M - 1) / BLOCK_M;
|
||||
const int NB = (N + BLOCK_N - 1) / BLOCK_N;
|
||||
const int64_t MB = (M + BLOCK_M - 1) / BLOCK_M;
|
||||
const int64_t NB = (N + BLOCK_N - 1) / BLOCK_N;
|
||||
|
||||
at::parallel_for(0, MB * NB, 0, [&](int begin, int end) {
|
||||
int mb{0}, nb{0};
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
for (const auto i : c10::irange(begin, end)) {
|
||||
(void)i;
|
||||
|
||||
int mb_start = mb * BLOCK_M;
|
||||
int mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
int nb_start = nb * BLOCK_N;
|
||||
int nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
const auto* A_ptr = A_data + mb_start * lda;
|
||||
const auto* B_ptr = B_data + nb_start * K;
|
||||
|
||||
@ -526,7 +526,7 @@ namespace {
|
||||
|
||||
|
||||
// we are dealing with packed tensor here. max index is the same as numel.
|
||||
// TODO: to really support input tensor large enought to go beyond int32,
|
||||
// TODO: to really support input tensor large enough to go beyond int32,
|
||||
// we will need to restrict out shared memory usage and adjust the launch
|
||||
// config;
|
||||
AT_ASSERT(input_.numel() < std::numeric_limits<int32_t>::max());
|
||||
@ -681,7 +681,7 @@ namespace {
|
||||
const dim3 grid(grid_x, grid_y, grid_z);
|
||||
|
||||
// we are dealing with packed tensor here. max index is the same as numel.
|
||||
// TODO: to really support input tensor large enought to go beyond int32,
|
||||
// TODO: to really support input tensor large enough to go beyond int32,
|
||||
// we will need to restrict out shared memory usage and adjust the launch
|
||||
// config;
|
||||
AT_ASSERT(input.numel() < std::numeric_limits<int32_t>::max());
|
||||
|
||||
@ -1634,6 +1634,9 @@ bool use_fast_accum) {
|
||||
TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
if (!a_is_2d || !b_is_2d) {
|
||||
TORCH_CHECK(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
mat_a.size(-1) % 16 == 0,
|
||||
"Expected trailing dimension of mat_a to be divisible by 16 ",
|
||||
@ -1716,6 +1719,9 @@ std::optional<c10::ScalarType> out_dtype) {
|
||||
TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
if (!a_is_2d || !b_is_2d) {
|
||||
TORCH_CHECK(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
|
||||
// check that the strides are valid, the fn will throw an error if not
|
||||
check_valid_strides_and_return_transposed(mat_a);
|
||||
|
||||
@ -223,7 +223,7 @@ inline CuFFTDataLayout as_cufft_embed(IntArrayRef strides, IntArrayRef sizes, bo
|
||||
class CuFFTConfig {
|
||||
public:
|
||||
|
||||
// Only move semantics is enought for this class. Although we already use
|
||||
// Only move semantics is enough for this class. Although we already use
|
||||
// unique_ptr for the plan, still remove copy constructor and assignment op so
|
||||
// we don't accidentally copy and take perf hit.
|
||||
CuFFTConfig(const CuFFTConfig&) = delete;
|
||||
|
||||
@ -241,6 +241,8 @@ void bf16bf16_grouped_gemm_impl_sm90_sm100(
|
||||
Strides tensor_StrideA = make_strides(mat_a.strides());
|
||||
Strides tensor_StrideB = make_strides(mat_b.strides());
|
||||
Strides tensor_StrideOutput = make_strides(out.strides());
|
||||
Strides tensor_ShapeA = make_strides(mat_a.sizes());
|
||||
Strides tensor_ShapeB = make_strides(mat_b.sizes());
|
||||
|
||||
at::cuda::detail::prepare_grouped_gemm_data<<<1, group_count, 0, stream>>>(
|
||||
reinterpret_cast<DtypeA*>(mat_a.data_ptr()),
|
||||
@ -264,6 +266,8 @@ void bf16bf16_grouped_gemm_impl_sm90_sm100(
|
||||
tensor_StrideA,
|
||||
tensor_StrideB,
|
||||
tensor_StrideOutput,
|
||||
tensor_ShapeA,
|
||||
tensor_ShapeB,
|
||||
0,
|
||||
0,
|
||||
a_row_major,
|
||||
|
||||
@ -38,18 +38,20 @@ __global__ void prepare_grouped_gemm_data(
|
||||
Strides tensor_StrideA,
|
||||
Strides tensor_StrideB,
|
||||
Strides tensor_StrideOutput,
|
||||
Strides tensor_ShapeA,
|
||||
Strides tensor_ShapeB,
|
||||
int64_t a_scale_stride,
|
||||
int64_t b_scale_stride,
|
||||
bool a_row_major = true,
|
||||
bool b_row_major = false) {
|
||||
int32_t tid = threadIdx.x;
|
||||
int32_t delta = 0;
|
||||
int32_t offset = 0;
|
||||
if (offs != nullptr) {
|
||||
int32_t start = tid == 0 ? 0 : offs[tid - 1];
|
||||
delta = offs[tid] - start;
|
||||
if (K < 0) {
|
||||
CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n");
|
||||
}
|
||||
offset = offs[tid];
|
||||
delta = offset - start;
|
||||
CUDA_KERNEL_ASSERT(delta >=0 && "expected gemm dimension to be greater or equal 0\n");
|
||||
|
||||
// TMA transfers require global memory tensor addresses to be
|
||||
// aligned to 16 bytes.
|
||||
@ -84,6 +86,7 @@ __global__ void prepare_grouped_gemm_data(
|
||||
int64_t lda, ldb, ldoutput;
|
||||
if (M < 0) {
|
||||
// A and output is 2d
|
||||
CUDA_KERNEL_ASSERT(offset <= tensor_ShapeA[0] && "expected offset to be less than tensor size\n");
|
||||
M = delta;
|
||||
lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1];
|
||||
ldb = b_row_major ? tensor_StrideB[1] : tensor_StrideB[2];
|
||||
@ -96,6 +99,7 @@ __global__ void prepare_grouped_gemm_data(
|
||||
output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1] * ldoutput;
|
||||
B_ptrs[tid] = B + tid * tensor_StrideB[0];
|
||||
} else if (N < 0) {
|
||||
CUDA_KERNEL_ASSERT(offset <= tensor_ShapeB[1] && "expected offset to be less than tensor size\n");
|
||||
N = delta;
|
||||
lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2];
|
||||
ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; // B is transposed
|
||||
@ -108,6 +112,7 @@ __global__ void prepare_grouped_gemm_data(
|
||||
inputB_scale_ptrs[tid] = tid == 0 ? scale_B : scale_B + offs[tid - 1];
|
||||
}
|
||||
} else if (K < 0) {
|
||||
CUDA_KERNEL_ASSERT(offset <= tensor_ShapeA[1] && offset <= tensor_ShapeB[0] && "expected offset to be less than tensor size\n");
|
||||
// A, B is 2d, output is 3d
|
||||
K = delta;
|
||||
lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1];
|
||||
|
||||
@ -298,6 +298,9 @@ void f8f8bf16_grouped_gemm_impl_sm90(
|
||||
Strides tensor_StrideA = make_strides(mat_a.strides());
|
||||
Strides tensor_StrideB = make_strides(mat_b.strides());
|
||||
Strides tensor_StrideOutput = make_strides(out.strides());
|
||||
Strides tensor_ShapeA = make_strides(mat_a.sizes());
|
||||
Strides tensor_ShapeB = make_strides(mat_b.sizes());
|
||||
|
||||
// scale stride will be used inside the kernel only if needed,
|
||||
// so for 1d scales the "1" assigned here won't be used
|
||||
int64_t a_scale_stride = scale_a.stride(0);
|
||||
@ -325,6 +328,8 @@ void f8f8bf16_grouped_gemm_impl_sm90(
|
||||
tensor_StrideA,
|
||||
tensor_StrideB,
|
||||
tensor_StrideOutput,
|
||||
tensor_ShapeA,
|
||||
tensor_ShapeB,
|
||||
a_scale_stride,
|
||||
b_scale_stride);
|
||||
|
||||
|
||||
@ -28,6 +28,22 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
TORCH_CHECK(false, "cudnn_batch_norm: ATen not compiled with cuDNN support");
|
||||
}
|
||||
|
||||
std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> cudnn_batch_norm_out(
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const std::optional<Tensor>& bias,
|
||||
const std::optional<Tensor>& running_mean,
|
||||
const std::optional<Tensor>& running_var,
|
||||
bool training,
|
||||
double exponential_average_factor,
|
||||
double epsilon,
|
||||
Tensor& out,
|
||||
Tensor& save_mean,
|
||||
Tensor& save_var,
|
||||
Tensor& reserve) {
|
||||
AT_ERROR("cudnn_batch_norm_out: ATen not compiled with cuDNN support");
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
|
||||
const Tensor& input,
|
||||
const Tensor& grad_output,
|
||||
@ -120,7 +136,12 @@ size_t _get_cudnn_batch_norm_reserve_space_size(
|
||||
return reserve_size;
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
// Param `reserve` is a placeholder, just passing an empty tensor.
|
||||
// usage:
|
||||
// auto reserve = torch::empty({0}, torch::device(torch::kCUDA));
|
||||
// at::native::cudnn_batch_norm_out(..., epsilon, output, save_mean, save_var,
|
||||
// reserve);
|
||||
std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> cudnn_batch_norm_out(
|
||||
const Tensor& input_t,
|
||||
const Tensor& weight_t,
|
||||
const std::optional<Tensor>& bias_t_opt,
|
||||
@ -128,7 +149,11 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
const std::optional<Tensor>& running_var_t_opt,
|
||||
bool training,
|
||||
double exponential_average_factor,
|
||||
double epsilon) {
|
||||
double epsilon,
|
||||
Tensor& output_t,
|
||||
Tensor& save_mean,
|
||||
Tensor& save_var,
|
||||
Tensor& reserve) {
|
||||
// See [Note: hacky wrapper removal for optional tensor]
|
||||
c10::MaybeOwned<Tensor> bias_t_maybe_owned =
|
||||
at::borrow_from_optional_tensor(bias_t_opt);
|
||||
@ -168,9 +193,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
cudnnBatchNormMode_t mode = getCudnnBatchNormMode(
|
||||
training, input->suggest_memory_format(), input->dim());
|
||||
|
||||
auto output_t =
|
||||
at::empty_like(*input, input->options(), input->suggest_memory_format());
|
||||
|
||||
TensorArg output{output_t, "output", 0};
|
||||
|
||||
auto handle = getCudnnHandle();
|
||||
@ -182,15 +204,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
|
||||
Constant one(dataType, 1);
|
||||
Constant zero(dataType, 0);
|
||||
Tensor save_mean, save_var;
|
||||
|
||||
Tensor reserve;
|
||||
|
||||
if (training) {
|
||||
int64_t num_features = input_t.size(1);
|
||||
save_mean = at::empty({num_features}, weight_t.options());
|
||||
save_var = at::empty({num_features}, weight_t.options());
|
||||
|
||||
auto op = CUDNN_BATCHNORM_OPS_BN;
|
||||
size_t workspace_size;
|
||||
AT_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
|
||||
@ -238,9 +253,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
reserve_size));
|
||||
} else {
|
||||
reserve = at::empty({0}, input->options().dtype(kByte));
|
||||
// This keeps a consistent output with native_batch_norm
|
||||
save_mean = at::empty({0}, weight_t.options());
|
||||
save_var = at::empty({0}, weight_t.options());
|
||||
AT_CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
|
||||
handle,
|
||||
mode,
|
||||
@ -261,10 +273,48 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
// save_mean and save_var can be undefined
|
||||
// If this causes problems, we can initialize them to empty tensors
|
||||
// of the correct type
|
||||
return std::tuple<Tensor, Tensor, Tensor, Tensor>{
|
||||
return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>{
|
||||
output_t, save_mean, save_var, reserve};
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
const Tensor& input_t,
|
||||
const Tensor& weight_t,
|
||||
const std::optional<Tensor>& bias_t_opt,
|
||||
const std::optional<Tensor>& running_mean_t_opt,
|
||||
const std::optional<Tensor>& running_var_t_opt,
|
||||
bool training,
|
||||
double exponential_average_factor,
|
||||
double epsilon) {
|
||||
auto output_t = at::empty_like(
|
||||
input_t, input_t.options(), input_t.suggest_memory_format());
|
||||
Tensor save_mean, save_var, reserve;
|
||||
|
||||
if (training) {
|
||||
int64_t num_features = input_t.size(1);
|
||||
save_mean = at::empty({num_features}, weight_t.options());
|
||||
save_var = at::empty({num_features}, weight_t.options());
|
||||
} else {
|
||||
// This keeps a consistent output with native_batch_norm
|
||||
save_mean = at::empty({0}, weight_t.options());
|
||||
save_var = at::empty({0}, weight_t.options());
|
||||
}
|
||||
|
||||
return cudnn_batch_norm_out(
|
||||
input_t,
|
||||
weight_t,
|
||||
bias_t_opt,
|
||||
running_mean_t_opt,
|
||||
running_var_t_opt,
|
||||
training,
|
||||
exponential_average_factor,
|
||||
epsilon,
|
||||
output_t,
|
||||
save_mean,
|
||||
save_var,
|
||||
reserve);
|
||||
}
|
||||
|
||||
// NB: CuDNN only implements the backward algorithm for batchnorm
|
||||
// in training mode (evaluation mode batchnorm has a different algorithm),
|
||||
// which is why this doesn't accept a 'training' parameter.
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/Config.h>
|
||||
#include <ATen/Context.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/native/mkldnn/Matmul.h>
|
||||
|
||||
@ -428,56 +427,74 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool use_mkldnn_typed_matmul(
|
||||
bool use_mkldnn_bf16_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Tensor& result) {
|
||||
bool dtype_check = false;
|
||||
if constexpr (std::is_same_v<T, c10::BFloat16>) {
|
||||
#if defined(__aarch64__)
|
||||
if (mkldnn_bf16_device_check_arm()) {
|
||||
// onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g.
|
||||
// Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16
|
||||
// inputs, allow it for float as well
|
||||
dtype_check = use_mkldnn_bf16_matmul() &&
|
||||
((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16));
|
||||
}
|
||||
#else
|
||||
dtype_check = dtype_check && use_mkldnn_bf16_matmul() &&
|
||||
(mat1.scalar_type() == kBFloat16);
|
||||
if (mkldnn_bf16_device_check_arm()) {
|
||||
// onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g.
|
||||
// Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16
|
||||
// inputs, allow it for float as well
|
||||
return (
|
||||
use_mkldnn_bf16_matmul() &&
|
||||
(mat1.scalar_type() == mat2.scalar_type()) &&
|
||||
(!result.defined() || (mat1.scalar_type() == result.scalar_type())) &&
|
||||
((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)) &&
|
||||
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
|
||||
} else
|
||||
#endif
|
||||
} else if constexpr (std::is_same_v<T, c10::Half>) {
|
||||
dtype_check = dtype_check && use_mkldnn_fp16_matmul() &&
|
||||
(mat1.scalar_type() == kHalf);
|
||||
} else if constexpr (std::is_same_v<T, float>) {
|
||||
dtype_check = dtype_check &&
|
||||
(use_mkldnn_bf32_matmul() || use_mkldnn_tf32_matmul()) &&
|
||||
(mat1.scalar_type() == kFloat);
|
||||
{
|
||||
return (
|
||||
use_mkldnn_bf16_matmul() && mat1.scalar_type() == kBFloat16 &&
|
||||
mat2.scalar_type() == kBFloat16 &&
|
||||
(!result.defined() || result.scalar_type() == kBFloat16) &&
|
||||
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
|
||||
}
|
||||
if (!dtype_check) {
|
||||
return false;
|
||||
}
|
||||
bool size_check =
|
||||
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2);
|
||||
dtype_check = (mat1.scalar_type() == mat2.scalar_type()) &&
|
||||
(!result.defined() || result.scalar_type() == mat1.scalar_type());
|
||||
return dtype_check && size_check;
|
||||
}
|
||||
|
||||
bool use_mkldnn_fp16_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Tensor& result) {
|
||||
return (
|
||||
use_mkldnn_fp16_matmul() && mat1.scalar_type() == kHalf &&
|
||||
mat2.scalar_type() == kHalf &&
|
||||
(!result.defined() || result.scalar_type() == kHalf) &&
|
||||
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
|
||||
}
|
||||
|
||||
bool use_mkldnn_bf32_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Tensor& result) {
|
||||
return (
|
||||
use_mkldnn_bf32_matmul() && mat1.scalar_type() == kFloat &&
|
||||
mat2.scalar_type() == kFloat &&
|
||||
(!result.defined() || result.scalar_type() == kFloat) &&
|
||||
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
|
||||
}
|
||||
|
||||
bool use_mkldnn_tf32_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Tensor& result) {
|
||||
return (
|
||||
use_mkldnn_tf32_matmul() && mat1.scalar_type() == kFloat &&
|
||||
mat2.scalar_type() == kFloat &&
|
||||
(!result.defined() || result.scalar_type() == kFloat) &&
|
||||
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
|
||||
}
|
||||
|
||||
bool use_mkldnn_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Tensor& result) {
|
||||
auto mat1_type = mat1.scalar_type();
|
||||
if (mat1_type != kBFloat16 || mat1_type != kHalf || mat1_type != kFloat) {
|
||||
return false;
|
||||
}
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
kBFloat16, kHalf, mat1.scalar_type(), "use_mkldnn_matmul", [&] {
|
||||
return use_mkldnn_typed_matmul<scalar_t>(mat1, mat2, result);
|
||||
});
|
||||
return false;
|
||||
return (
|
||||
use_mkldnn_bf16_matmul(mat1, mat2, result) ||
|
||||
use_mkldnn_fp16_matmul(mat1, mat2, result) ||
|
||||
use_mkldnn_bf32_matmul(mat1, mat2, result) ||
|
||||
use_mkldnn_tf32_matmul(mat1, mat2, result));
|
||||
}
|
||||
|
||||
static void _mkldnn_matmul_i8i8i32_with_primitive(
|
||||
|
||||
@ -469,4 +469,94 @@ Tensor _weight_int4pack_mm_xpu(
|
||||
|
||||
return C;
|
||||
}
|
||||
|
||||
Tensor& _int_mm_out_xpu(
|
||||
const Tensor& self,
|
||||
const Tensor& mat2,
|
||||
Tensor& result) {
|
||||
TORCH_CHECK(
|
||||
self.dim() == 2,
|
||||
"Expected self to be of dimension 2 but got ",
|
||||
self.dim());
|
||||
TORCH_CHECK(
|
||||
mat2.dim() == 2,
|
||||
"Expected mat2 to be of dimension 2 but got ",
|
||||
mat2.dim());
|
||||
TORCH_CHECK(
|
||||
self.size(1) == mat2.size(0),
|
||||
"self.size(1) needs to match mat2.size(0) but got ",
|
||||
self.size(1),
|
||||
" and ",
|
||||
mat2.size(0));
|
||||
|
||||
TORCH_CHECK(
|
||||
self.dtype() == at::kChar,
|
||||
"Expected self dtype to be of type int8 but got ",
|
||||
self.dtype());
|
||||
TORCH_CHECK(
|
||||
mat2.dtype() == at::kChar,
|
||||
"Expected mat2 dtype to be of type int8 but got ",
|
||||
mat2.dtype());
|
||||
TORCH_CHECK(
|
||||
result.dtype() == at::kInt,
|
||||
"Expected result dtype to be of type kInt but got ",
|
||||
result.dtype());
|
||||
TORCH_CHECK(
|
||||
result.size(0) == self.size(0),
|
||||
"Expected result.size(0) to be ",
|
||||
self.size(0),
|
||||
" but got ",
|
||||
result.size(0));
|
||||
TORCH_CHECK(
|
||||
result.size(1) == mat2.size(1),
|
||||
"Expected result.size(1) to be ",
|
||||
mat2.size(1),
|
||||
" but got ",
|
||||
result.size(1));
|
||||
|
||||
TORCH_CHECK(
|
||||
result.dim() == 2,
|
||||
"Expected result to be of dimension 2 but got ",
|
||||
result.dim());
|
||||
|
||||
TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous.");
|
||||
|
||||
if (result.numel() == 0 || self.size(1) == 0) {
|
||||
return result.zero_();
|
||||
}
|
||||
|
||||
Tensor bias = at::Tensor();
|
||||
Tensor mat2_scales = at::ones({1}, mat2.options().dtype(at::kFloat));
|
||||
Tensor mat2_zero_points = at::Tensor();
|
||||
auto post_op_args = torch::List<std::optional<at::Scalar>>();
|
||||
|
||||
at::native::onednn::quantized_matmul(
|
||||
self.contiguous(),
|
||||
1.0,
|
||||
0,
|
||||
mat2.contiguous(),
|
||||
mat2_scales,
|
||||
mat2_zero_points,
|
||||
bias,
|
||||
result,
|
||||
1.0,
|
||||
0,
|
||||
result.scalar_type(),
|
||||
/*other*/ std::nullopt,
|
||||
/*other scale*/ 1.0,
|
||||
/*other zp*/ 0,
|
||||
/*binary post op*/ "none",
|
||||
/*binary alpha*/ 1.0,
|
||||
/*post_op_name*/ "none",
|
||||
post_op_args,
|
||||
/*post_op_algorithm*/ "none",
|
||||
/*m2_trans*/ true);
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor _int_mm_xpu(const Tensor& self, const Tensor& mat2) {
|
||||
Tensor result =
|
||||
at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt));
|
||||
return _int_mm_out_xpu(self, mat2, result);
|
||||
}
|
||||
} // namespace at::native
|
||||
|
||||
@ -953,8 +953,7 @@ class BundledShaderLibary : public MetalShaderLibrary {
|
||||
if (C10_UNLIKELY(!library)) {
|
||||
auto device = MPSDevice::getInstance()->device();
|
||||
NSError* error = nil;
|
||||
auto section_name = is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? "metal_bfloat" : "metal_basic";
|
||||
library = [device newLibraryWithData:getSectionData(section_name) error:&error];
|
||||
library = [device newLibraryWithData:getSectionData("metal_basic") error:&error];
|
||||
TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]);
|
||||
}
|
||||
return library;
|
||||
|
||||
@ -33,21 +33,15 @@ struct shrink_backward_functor {
|
||||
|
||||
REGISTER_UNARY_ALPHA_OP(hardshrink, float, float, float);
|
||||
REGISTER_UNARY_ALPHA_OP(hardshrink, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_ALPHA_OP(hardshrink, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_UNARY_ALPHA_OP(softshrink, float, float, float);
|
||||
REGISTER_UNARY_ALPHA_OP(softshrink, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_ALPHA_OP(softshrink, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_BINARY_ALPHA_OP(shrink_backward, float, float, float);
|
||||
REGISTER_BINARY_ALPHA_OP(shrink_backward, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_ALPHA_OP(shrink_backward, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
struct hardsigmoid_functor {
|
||||
template <typename T>
|
||||
@ -67,15 +61,11 @@ struct hardsigmoid_backward_functor {
|
||||
|
||||
REGISTER_UNARY_OP(hardsigmoid, float, float);
|
||||
REGISTER_UNARY_OP(hardsigmoid, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_OP(hardsigmoid, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_BINARY_OP(hardsigmoid_backward, float, float);
|
||||
REGISTER_BINARY_OP(hardsigmoid_backward, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_OP(hardsigmoid_backward, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
struct hardswish_functor {
|
||||
template <typename T>
|
||||
@ -103,15 +93,11 @@ struct hardswish_backward_functor {
|
||||
|
||||
REGISTER_UNARY_OP(hardswish, float, float);
|
||||
REGISTER_UNARY_OP(hardswish, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_OP(hardswish, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_BINARY_OP(hardswish_backward, float, float);
|
||||
REGISTER_BINARY_OP(hardswish_backward, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_OP(hardswish_backward, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
struct leaky_relu_functor {
|
||||
template <typename T>
|
||||
@ -135,12 +121,8 @@ struct leaky_relu_backward_functor {
|
||||
|
||||
REGISTER_UNARY_ALPHA_OP(leaky_relu, float, float, float);
|
||||
REGISTER_UNARY_ALPHA_OP(leaky_relu, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_ALPHA_OP(leaky_relu, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, float, float, float);
|
||||
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
@ -113,18 +113,12 @@ kernel void ampUpdateScale(
|
||||
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(float);
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(bfloat);
|
||||
#endif
|
||||
|
||||
INSTANTIATE_AMP_UPDATE_SCALE(float);
|
||||
INSTANTIATE_AMP_UPDATE_SCALE(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_AMP_UPDATE_SCALE(bfloat);
|
||||
#endif
|
||||
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(float);
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(bfloat);
|
||||
#endif
|
||||
|
||||
@ -590,9 +590,7 @@ kernel void attention(
|
||||
|
||||
INSTANTIATE_SDPA_VECTOR_HEADS(float);
|
||||
INSTANTIATE_SDPA_VECTOR_HEADS(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
|
||||
#endif
|
||||
|
||||
#define INSTANTIATE_ATTN(DTYPE, bq, bk, bd, wm, wn) \
|
||||
template [[host_name("attention_" #DTYPE "_bq" #bq "_bk" #bk "_bd" #bd \
|
||||
@ -621,6 +619,4 @@ INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
|
||||
|
||||
INSTANTIATE_ATTN_SHAPES_HELPER(float);
|
||||
INSTANTIATE_ATTN_SHAPES_HELPER(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_ATTN_SHAPES_HELPER(bfloat);
|
||||
#endif
|
||||
|
||||
@ -209,38 +209,9 @@ struct hermite_polynomial_he_functor {
|
||||
};
|
||||
|
||||
struct nextafter_functor {
|
||||
#if __METAL_VERSION__ < 310
|
||||
template <typename U>
|
||||
struct bit_type {};
|
||||
template <>
|
||||
struct bit_type<float> {
|
||||
using type = int;
|
||||
};
|
||||
template <>
|
||||
struct bit_type<half> {
|
||||
using type = short;
|
||||
};
|
||||
#endif
|
||||
template <typename T>
|
||||
inline T operator()(const T a, const T b) {
|
||||
#if __METAL_VERSION__ >= 310
|
||||
return static_cast<T>(::metal::nextafter(a, b));
|
||||
#else
|
||||
using U = typename bit_type<T>::type;
|
||||
if (a == b) {
|
||||
return a;
|
||||
}
|
||||
if (::metal::isunordered(a, b)) {
|
||||
return NAN;
|
||||
}
|
||||
if (a == 0) {
|
||||
constexpr auto eps = as_type<T>(static_cast<U>(1));
|
||||
return b > 0 ? eps : -eps;
|
||||
}
|
||||
auto bits = as_type<U>(a);
|
||||
(a > 0) ^ (a > b) ? bits++ : bits--;
|
||||
return as_type<T>(bits);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@ -344,13 +315,6 @@ struct fmod_functor {
|
||||
}
|
||||
};
|
||||
|
||||
// Some helper defines
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#define _METAL_310_PLUS(x) x
|
||||
#else
|
||||
#define _METAL_310_PLUS(x)
|
||||
#endif
|
||||
|
||||
#define REGISTER_INTEGER_BINARY_OP(NAME) \
|
||||
REGISTER_BINARY_OP(NAME, long, long); \
|
||||
REGISTER_BINARY_OP(NAME, int, int); \
|
||||
@ -370,12 +334,12 @@ struct fmod_functor {
|
||||
#define REGISTER_FLOAT_BINARY_OP(NAME) \
|
||||
REGISTER_BINARY_OP(NAME, float, float); \
|
||||
REGISTER_BINARY_OP(NAME, half, half); \
|
||||
_METAL_310_PLUS(REGISTER_BINARY_OP(NAME, bfloat, bfloat))
|
||||
REGISTER_BINARY_OP(NAME, bfloat, bfloat)
|
||||
|
||||
#define REGISTER_OPMATH_FLOAT_BINARY_OP(NAME) \
|
||||
REGISTER_OPMATH_BINARY_OP(NAME, float, float); \
|
||||
REGISTER_OPMATH_BINARY_OP(NAME, half, half); \
|
||||
_METAL_310_PLUS(REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat))
|
||||
REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat)
|
||||
|
||||
REGISTER_FLOAT_BINARY_OP(copysign);
|
||||
REGISTER_INT2FLOAT_BINARY_OP(copysign);
|
||||
@ -447,11 +411,9 @@ REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar, uchar);
|
||||
REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char, char);
|
||||
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool, bool);
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat, bfloat);
|
||||
REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat, bfloat);
|
||||
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
// Complex binary functions
|
||||
REGISTER_BINARY_OP(polar, float, float2);
|
||||
|
||||
@ -180,10 +180,8 @@ REGISTER_SEARCHSORTED_OP(float, int);
|
||||
REGISTER_SEARCHSORTED_OP(float, long);
|
||||
REGISTER_SEARCHSORTED_OP(half, int);
|
||||
REGISTER_SEARCHSORTED_OP(half, long);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_SEARCHSORTED_OP(bfloat, int);
|
||||
REGISTER_SEARCHSORTED_OP(bfloat, long);
|
||||
#endif
|
||||
REGISTER_SEARCHSORTED_OP(char, int);
|
||||
REGISTER_SEARCHSORTED_OP(char, long);
|
||||
REGISTER_SEARCHSORTED_OP(uchar, int);
|
||||
|
||||
@ -96,6 +96,4 @@ kernel void col2im_kernel(
|
||||
INSTANTIATE_COL2IM(bool);
|
||||
INSTANTIATE_COL2IM(float);
|
||||
INSTANTIATE_COL2IM(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_COL2IM(bfloat);
|
||||
#endif
|
||||
|
||||
@ -20,9 +20,7 @@ REGISTER_CROSS_FUNC(short);
|
||||
REGISTER_CROSS_FUNC(char);
|
||||
REGISTER_CROSS_FUNC(uchar);
|
||||
REGISTER_CROSS_FUNC(bool);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_CROSS_FUNC(bfloat);
|
||||
#endif
|
||||
|
||||
template <typename T, typename U>
|
||||
kernel void cross(
|
||||
@ -68,6 +66,4 @@ REGISTER_CROSS_OP(short);
|
||||
REGISTER_CROSS_OP(char);
|
||||
REGISTER_CROSS_OP(uchar);
|
||||
REGISTER_CROSS_OP(bool);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_CROSS_OP(bfloat);
|
||||
#endif
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
using metal::max;
|
||||
#if __METAL_VERSION__ >= 310
|
||||
bfloat max(bfloat a, bfloat b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
#endif
|
||||
|
||||
#define kmaxThreadGroups 32
|
||||
#define kmaxTensors 32
|
||||
@ -306,11 +304,9 @@ REGISTER_ADAM_OPS_QUART(float, float);
|
||||
REGISTER_ADAM_OPS_QUART(float, half);
|
||||
REGISTER_ADAM_OPS_QUART(half, float);
|
||||
REGISTER_ADAM_OPS_QUART(half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_ADAM_OPS_QUART(float, bfloat);
|
||||
REGISTER_ADAM_OPS_QUART(bfloat, bfloat);
|
||||
REGISTER_ADAM_OPS_QUART(bfloat, float);
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
inline void sgd_momentum_math(
|
||||
@ -460,7 +456,5 @@ REGISTER_FUSED_SGD_OP(float);
|
||||
REGISTER_FUSED_SGD_OP(half);
|
||||
REGISTER_FUSED_SGD_MOMENTUM_OP(float);
|
||||
REGISTER_FUSED_SGD_MOMENTUM_OP(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_FUSED_SGD_OP(bfloat);
|
||||
REGISTER_FUSED_SGD_MOMENTUM_OP(bfloat);
|
||||
#endif
|
||||
|
||||
@ -106,9 +106,7 @@ kernel void polygamma(
|
||||
constant int64_t& order [[buffer(2)]], \
|
||||
uint id [[thread_position_in_grid]]);
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_GAMMA_KERNELS(bfloat, bfloat);
|
||||
#endif
|
||||
INSTANTIATE_GAMMA_KERNELS(half, half);
|
||||
INSTANTIATE_GAMMA_KERNELS(float, float);
|
||||
INSTANTIATE_GAMMA_KERNELS(bool, float);
|
||||
|
||||
@ -76,6 +76,4 @@ INSTANTIATE_IM2COL(float);
|
||||
INSTANTIATE_IM2COL(float2);
|
||||
INSTANTIATE_IM2COL(half);
|
||||
INSTANTIATE_IM2COL(half2);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_IM2COL(bfloat);
|
||||
#endif
|
||||
|
||||
@ -240,9 +240,7 @@ REGISTER_INDEX_OP(put_accumulate, short, short);
|
||||
REGISTER_INDEX_OP(put_accumulate, char, char);
|
||||
REGISTER_INDEX_OP(put_accumulate, uchar, uchar);
|
||||
REGISTER_INDEX_OP(put_accumulate, bool, bool);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
template <typename StridesT, typename DataT>
|
||||
kernel void kernel_index_offsets(
|
||||
@ -477,10 +475,8 @@ INSTANTIATE_INDEX_COPY(char, long);
|
||||
INSTANTIATE_INDEX_COPY(uchar, int);
|
||||
INSTANTIATE_INDEX_COPY(uchar, long);
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_INDEX_COPY(bfloat, int);
|
||||
INSTANTIATE_INDEX_COPY(bfloat, long);
|
||||
#endif
|
||||
INSTANTIATE_INDEX_COPY(float2, int);
|
||||
INSTANTIATE_INDEX_COPY(float2, long);
|
||||
INSTANTIATE_INDEX_COPY(half2, int);
|
||||
|
||||
@ -288,7 +288,6 @@ kernel void layer_norm_looped(
|
||||
#define instantiate_layer_norm(DTYPE) \
|
||||
instantiate_layer_norm_single_row(DTYPE) instantiate_layer_norm_looped(DTYPE)
|
||||
|
||||
instantiate_layer_norm(float) instantiate_layer_norm(half)
|
||||
#if __METAL_VERSION__ >= 310
|
||||
instantiate_layer_norm(bfloat)
|
||||
#endif
|
||||
instantiate_layer_norm(float);
|
||||
instantiate_layer_norm(half);
|
||||
instantiate_layer_norm(bfloat);
|
||||
|
||||
@ -635,9 +635,7 @@ kernel void applyPivots(
|
||||
|
||||
INSTANTIATE_NAIVE_MM(float);
|
||||
INSTANTIATE_NAIVE_MM(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_NAIVE_MM(bfloat);
|
||||
#endif
|
||||
|
||||
// Integral MM
|
||||
INSTANTIATE_NAIVE_MM(short);
|
||||
|
||||
@ -48,3 +48,14 @@ struct PoolingBackwardParams {
|
||||
::c10::metal::array<idx_type_t, N> grad_output_strides;
|
||||
::c10::metal::array<idx_type_t, N> indices_strides;
|
||||
};
|
||||
|
||||
template <unsigned N = 5, typename idx_type_t = int32_t>
|
||||
struct MaxUnpoolingParams {
|
||||
int32_t dims;
|
||||
int32_t pooling_dims;
|
||||
::c10::metal::array<idx_type_t, N> input_sizes;
|
||||
::c10::metal::array<idx_type_t, N> input_strides;
|
||||
::c10::metal::array<idx_type_t, N> output_sizes;
|
||||
::c10::metal::array<idx_type_t, N> output_strides;
|
||||
::c10::metal::array<idx_type_t, N> indices_strides;
|
||||
};
|
||||
|
||||
@ -168,6 +168,16 @@ PoolOffsets find_pool_offsets(
|
||||
leading_dims,
|
||||
return_indices,
|
||||
tid);
|
||||
case 3:
|
||||
return find_pool_offsets_dim_specific<3>(
|
||||
output_sizes,
|
||||
output_strides,
|
||||
indices_strides,
|
||||
input_strides,
|
||||
pooling_dim_indices,
|
||||
leading_dims,
|
||||
return_indices,
|
||||
tid);
|
||||
}
|
||||
return PoolOffsets();
|
||||
}
|
||||
@ -292,6 +302,68 @@ kernel void max_pool_backward(
|
||||
pooling_dims);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void max_unpool_impl(
|
||||
device T* output,
|
||||
T input_element,
|
||||
int32_t input_index,
|
||||
constant int32_t* output_sizes,
|
||||
constant int32_t* output_strides,
|
||||
int32_t pooling_dims) {
|
||||
int32_t size_prod = 1;
|
||||
int32_t pool_offset = 0;
|
||||
|
||||
for (auto dim = pooling_dims - 1; dim >= 0; dim--) {
|
||||
auto next_size_prod = output_sizes[dim] * size_prod;
|
||||
pool_offset +=
|
||||
output_strides[dim] * ((input_index % next_size_prod) / size_prod);
|
||||
size_prod *= output_sizes[dim];
|
||||
}
|
||||
|
||||
output[pool_offset] = input_element;
|
||||
}
|
||||
|
||||
// Kernel computes one element of the grad input per kernel call.
|
||||
template <typename T>
|
||||
kernel void max_unpool(
|
||||
device T* output [[buffer(0)]],
|
||||
constant T* input [[buffer(1)]],
|
||||
constant int64_t* indices [[buffer(2)]],
|
||||
constant MaxUnpoolingParams<5>& params [[buffer(3)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
auto pooling_dims = params.pooling_dims;
|
||||
auto dims = params.dims;
|
||||
auto input_sizes = params.input_sizes.data();
|
||||
auto input_strides = params.input_strides.data();
|
||||
auto output_sizes = params.output_sizes.data();
|
||||
auto output_strides = params.output_strides.data();
|
||||
auto indices_strides = params.indices_strides.data();
|
||||
|
||||
auto leading_dims = dims - pooling_dims;
|
||||
|
||||
// NOTE: Since we're doing unpooling, the variable names "input" and "output"
|
||||
// are reversed compared to the pooling operations. So in `find_pool_offsets`,
|
||||
// we need to map "input" -> "output" and "output" -> "input".
|
||||
PoolOffsets offsets = find_pool_offsets(
|
||||
/*output_sizes=*/input_sizes,
|
||||
/*output_strides=*/input_strides,
|
||||
indices_strides,
|
||||
/*input_strides=*/output_strides,
|
||||
/*pooling_dim_indices=*/nullptr,
|
||||
dims,
|
||||
leading_dims,
|
||||
/*return_indices=*/true,
|
||||
tid);
|
||||
|
||||
max_unpool_impl<T>(
|
||||
output + offsets.input_leading,
|
||||
input[offsets.output],
|
||||
indices[offsets.indices],
|
||||
output_sizes + leading_dims,
|
||||
output_strides + leading_dims,
|
||||
pooling_dims);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct AvgPoolIterBounds {
|
||||
T start;
|
||||
@ -428,18 +500,25 @@ kernel void avg_pool(
|
||||
params.divisor_override);
|
||||
}
|
||||
|
||||
#define REGISTER_POOL_OP(DTYPE) \
|
||||
template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
device int64_t* indices [[buffer(2)]], \
|
||||
constant PoolingParams<5>& params [[buffer(3)]], \
|
||||
uint tid [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name("avg_pool_" #DTYPE)]] kernel void avg_pool<DTYPE>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
constant AvgPoolingParams<5> & params [[buffer(2)]], \
|
||||
#define REGISTER_POOL_OP(DTYPE) \
|
||||
template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
device int64_t* indices [[buffer(2)]], \
|
||||
constant PoolingParams<5>& params [[buffer(3)]], \
|
||||
uint tid [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name("max_unpool_" #DTYPE)]] kernel void max_unpool<DTYPE>( \
|
||||
device DTYPE * output [[buffer(0)]], \
|
||||
constant DTYPE * input [[buffer(1)]], \
|
||||
constant int64_t* indices [[buffer(2)]], \
|
||||
constant MaxUnpoolingParams<5>& params [[buffer(3)]], \
|
||||
uint tid [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name("avg_pool_" #DTYPE)]] kernel void avg_pool<DTYPE>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
constant AvgPoolingParams<5> & params [[buffer(2)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
|
||||
#define REGISTER_MAX_POOL_BACKWARD_OP(DTYPE) \
|
||||
@ -453,6 +532,7 @@ kernel void avg_pool(
|
||||
|
||||
REGISTER_POOL_OP(float);
|
||||
REGISTER_POOL_OP(half);
|
||||
REGISTER_POOL_OP(bfloat);
|
||||
REGISTER_POOL_OP(int);
|
||||
REGISTER_POOL_OP(long);
|
||||
REGISTER_POOL_OP(short);
|
||||
@ -462,8 +542,4 @@ REGISTER_POOL_OP(bool);
|
||||
|
||||
REGISTER_MAX_POOL_BACKWARD_OP(float);
|
||||
REGISTER_MAX_POOL_BACKWARD_OP(half);
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_POOL_OP(bfloat);
|
||||
REGISTER_MAX_POOL_BACKWARD_OP(bfloat);
|
||||
#endif
|
||||
|
||||
@ -197,12 +197,10 @@ INSTANTIATE_INT4MV(float, 128);
|
||||
INSTANTIATE_INT4MV(half, 128);
|
||||
INSTANTIATE_INT4MV(float, 256);
|
||||
INSTANTIATE_INT4MV(half, 256);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_INT4MV(bfloat, 32);
|
||||
INSTANTIATE_INT4MV(bfloat, 64);
|
||||
INSTANTIATE_INT4MV(bfloat, 128);
|
||||
INSTANTIATE_INT4MV(bfloat, 256);
|
||||
#endif
|
||||
|
||||
// ------------------------------ int8 MM For M >= 12 ------------------------------------
|
||||
/**
|
||||
@ -234,12 +232,10 @@ template <> struct BlockType<half> {
|
||||
using simdgroup_type8x8 = simdgroup_half8x8;
|
||||
using type4 = half4;
|
||||
};
|
||||
#if __METAL_VERSION__ >= 310
|
||||
template <> struct BlockType<bfloat> {
|
||||
using simdgroup_type8x8 = simdgroup_bfloat8x8;
|
||||
using type4 = bfloat4;
|
||||
};
|
||||
#endif
|
||||
|
||||
template<typename T>
|
||||
float2 get_scale_zero_q8(constant T * scalesAndZeros, uint2 index) {
|
||||
@ -490,9 +486,7 @@ kernel void kernel_mul_mm<DTYPE, WDTYPE, DEQUANT_FUNC>( \
|
||||
|
||||
INSTANTIATE_MM(float, char, get_scale_zero_q8);
|
||||
INSTANTIATE_MM(half, char, get_scale_zero_q8);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_MM(bfloat, char, get_scale_zero_q8);
|
||||
#endif
|
||||
// ------------------------------ int8 MM For M < 12 ------------------------------------
|
||||
/* Matrix vector multiplication, used for small M size for matrix multiplication as well.
|
||||
|
||||
@ -646,6 +640,4 @@ kernel void kernel_mul_mv<DTYPE>(
|
||||
|
||||
INSTANTIATE_MV(float);
|
||||
INSTANTIATE_MV(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_MV(bfloat);
|
||||
#endif
|
||||
|
||||
@ -192,6 +192,4 @@ template <typename T>
|
||||
|
||||
instantiate_rms(float)
|
||||
instantiate_rms(half)
|
||||
#if __METAL_VERSION__ >= 310
|
||||
instantiate_rms(bfloat)
|
||||
#endif // clang-format on
|
||||
|
||||
@ -23,6 +23,4 @@ kernel void renorm(
|
||||
|
||||
REGISTER_RENORM_OP(float);
|
||||
REGISTER_RENORM_OP(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_RENORM_OP(bfloat);
|
||||
#endif
|
||||
|
||||
@ -25,379 +25,6 @@ struct LogAddExp {
|
||||
};
|
||||
};
|
||||
|
||||
#if __METAL_VERSION__ < 310
|
||||
template <typename T, typename acc_t = accum_t<T>>
|
||||
struct CumMinOp {
|
||||
static acc_t apply(acc_t a, acc_t b) {
|
||||
return metal::min(a, b);
|
||||
}
|
||||
static acc_t identity() {
|
||||
return static_cast<acc_t>(
|
||||
metal::is_floating_point_v<T> ? metal::numeric_limits<T>::infinity()
|
||||
: metal::numeric_limits<T>::max());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename acc_t = accum_t<T>>
|
||||
struct CumMaxOp {
|
||||
static acc_t apply(acc_t a, acc_t b) {
|
||||
return metal::max(a, b);
|
||||
}
|
||||
static acc_t identity() {
|
||||
return static_cast<acc_t>(
|
||||
metal::is_floating_point_v<T> ? -metal::numeric_limits<T>::infinity()
|
||||
: metal::numeric_limits<T>::lowest());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename acc_t = accum_t<T>>
|
||||
struct LogCumSumExpOp {
|
||||
static acc_t apply(acc_t x, acc_t y) {
|
||||
return LogAddExp{}(x, y);
|
||||
}
|
||||
static acc_t identity() {
|
||||
return -metal::numeric_limits<acc_t>::infinity();
|
||||
}
|
||||
};
|
||||
|
||||
// Inclusive scan along innermost dimension for contiguous tensors
|
||||
template <typename T, typename Op, typename acc_t = accum_t<T>>
|
||||
kernel void scan_contiguous_innermost_dim(
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* output [[buffer(1)]],
|
||||
constant uint& num_rows [[buffer(2)]],
|
||||
constant uint& row_size [[buffer(3)]],
|
||||
uint row [[thread_position_in_grid]]) {
|
||||
if (row >= num_rows)
|
||||
return;
|
||||
|
||||
const uint offset = row * row_size;
|
||||
|
||||
acc_t accumulator = Op::identity();
|
||||
|
||||
for (uint col = 0; col < row_size; col++) {
|
||||
T val = input[offset + col];
|
||||
acc_t accum_val = static_cast<acc_t>(val);
|
||||
accumulator = Op::apply(accumulator, accum_val);
|
||||
output[offset + col] = static_cast<T>(accumulator);
|
||||
}
|
||||
}
|
||||
|
||||
// Inclusive scan along outer dimension for contiguous tensors
|
||||
template <typename T, typename Op, typename acc_t = accum_t<T>>
|
||||
kernel void scan_contiguous_outer_dim(
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* output [[buffer(1)]],
|
||||
constant uint& num_orows [[buffer(2)]],
|
||||
constant uint& num_irows [[buffer(3)]],
|
||||
constant uint& row_size [[buffer(4)]],
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
const uint orow = thread_index / num_irows;
|
||||
const uint irow = thread_index % num_irows;
|
||||
|
||||
if (orow >= num_orows)
|
||||
return;
|
||||
|
||||
acc_t accumulator = Op::identity();
|
||||
|
||||
const uint idx_base = orow * row_size * num_irows + irow;
|
||||
for (uint col = 0, idx = idx_base; col < row_size; col++, idx += num_irows) {
|
||||
T val = input[idx];
|
||||
acc_t accum_val = static_cast<acc_t>(val);
|
||||
accumulator = Op::apply(accumulator, accum_val);
|
||||
output[idx] = static_cast<T>(accumulator);
|
||||
}
|
||||
}
|
||||
|
||||
// Inclusive scan with indices along innermost dimension for contiguous tensors
|
||||
template <typename T, typename Op, typename acc_t = accum_t<T>>
|
||||
kernel void scan_with_indices_contiguous_innermost_dim(
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* values [[buffer(1)]],
|
||||
device int64_t* indices [[buffer(2)]],
|
||||
constant uint& num_rows [[buffer(3)]],
|
||||
constant uint& row_size [[buffer(4)]],
|
||||
uint row [[thread_position_in_grid]]) {
|
||||
if (row >= num_rows)
|
||||
return;
|
||||
|
||||
const uint offset = row * row_size;
|
||||
|
||||
acc_t accumulator = Op::identity();
|
||||
int64_t best_idx = 0;
|
||||
|
||||
for (uint col = 0; col < row_size; col++) {
|
||||
T val = input[offset + col];
|
||||
acc_t accum_val = static_cast<acc_t>(val);
|
||||
if (col == 0 || Op::apply(accum_val, accumulator) == accum_val) {
|
||||
accumulator = accum_val;
|
||||
best_idx = col;
|
||||
}
|
||||
values[offset + col] = static_cast<T>(accumulator);
|
||||
indices[offset + col] = best_idx;
|
||||
}
|
||||
}
|
||||
|
||||
// Inclusive scan with indices along outer dimension for contiguous tensors
|
||||
template <typename T, typename Op, typename acc_t = accum_t<T>>
|
||||
kernel void scan_with_indices_contiguous_outer_dim(
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* values [[buffer(1)]],
|
||||
device int64_t* indices [[buffer(2)]],
|
||||
constant uint& num_orows [[buffer(3)]],
|
||||
constant uint& num_irows [[buffer(4)]],
|
||||
constant uint& row_size [[buffer(5)]],
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
const uint orow = thread_index / num_irows;
|
||||
const uint irow = thread_index % num_irows;
|
||||
|
||||
if (orow >= num_orows)
|
||||
return;
|
||||
|
||||
acc_t accumulator = Op::identity();
|
||||
int64_t best_idx = 0;
|
||||
|
||||
const uint idx_base = orow * row_size * num_irows + irow;
|
||||
for (uint col = 0, idx = idx_base; col < row_size; col++, idx += num_irows) {
|
||||
T val = input[idx];
|
||||
acc_t accum_val = static_cast<acc_t>(val);
|
||||
if (col == 0 || Op::apply(accum_val, accumulator) == accum_val) {
|
||||
accumulator = accum_val;
|
||||
best_idx = col;
|
||||
}
|
||||
values[idx] = static_cast<T>(accumulator);
|
||||
indices[idx] = best_idx;
|
||||
}
|
||||
}
|
||||
|
||||
// Shared utility functions for strided kernels
|
||||
inline long calculate_non_scan_elements(
|
||||
constant long* sizes,
|
||||
uint ndim,
|
||||
uint scan_dim) {
|
||||
long total = 1;
|
||||
for (uint i = 0; i < ndim; ++i) {
|
||||
if (i != scan_dim) {
|
||||
total *= sizes[i];
|
||||
}
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
inline void thread_index_to_coordinates(
|
||||
uint index,
|
||||
int pos[c10::metal::max_ndim],
|
||||
constant long* sizes,
|
||||
uint ndim,
|
||||
uint scan_dim) {
|
||||
long remaining_index = index;
|
||||
for (uint i = 0; i < ndim; ++i) {
|
||||
if (i != scan_dim) {
|
||||
pos[i] = remaining_index % sizes[i];
|
||||
remaining_index /= sizes[i];
|
||||
} else {
|
||||
pos[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline long calculate_base_offset(
|
||||
int pos[c10::metal::max_ndim],
|
||||
constant long* strides,
|
||||
uint ndim,
|
||||
uint scan_dim) {
|
||||
long offset = 0;
|
||||
for (uint i = 0; i < ndim; ++i) {
|
||||
if (i != scan_dim) {
|
||||
offset += pos[i] * strides[i];
|
||||
}
|
||||
}
|
||||
return offset;
|
||||
}
|
||||
|
||||
// Generic strided scan kernel
|
||||
template <typename T, typename Op, typename acc_t = accum_t<T>>
|
||||
kernel void scan_strided(
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* output [[buffer(1)]],
|
||||
constant long* sizes [[buffer(2)]],
|
||||
constant long* input_strides [[buffer(3)]],
|
||||
constant long* output_strides [[buffer(4)]],
|
||||
constant uint& ndim [[buffer(5)]],
|
||||
constant uint& scan_dim [[buffer(6)]],
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
const long total_non_scan_elements =
|
||||
calculate_non_scan_elements(sizes, ndim, scan_dim);
|
||||
if (thread_index >= total_non_scan_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
int pos[c10::metal::max_ndim];
|
||||
thread_index_to_coordinates(thread_index, pos, sizes, ndim, scan_dim);
|
||||
|
||||
const long input_base_offset =
|
||||
calculate_base_offset(pos, input_strides, ndim, scan_dim);
|
||||
const long output_base_offset =
|
||||
calculate_base_offset(pos, output_strides, ndim, scan_dim);
|
||||
|
||||
acc_t accumulator = Op::identity();
|
||||
const long scan_size = sizes[scan_dim];
|
||||
const long input_scan_stride = input_strides[scan_dim];
|
||||
const long output_scan_stride = output_strides[scan_dim];
|
||||
|
||||
for (long scan_idx = 0; scan_idx < scan_size; scan_idx++) {
|
||||
const long input_offset = input_base_offset + scan_idx * input_scan_stride;
|
||||
const long output_offset =
|
||||
output_base_offset + scan_idx * output_scan_stride;
|
||||
|
||||
T val = input[input_offset];
|
||||
acc_t accum_val = static_cast<acc_t>(val);
|
||||
accumulator = Op::apply(accumulator, accum_val);
|
||||
output[output_offset] = static_cast<T>(accumulator);
|
||||
}
|
||||
}
|
||||
|
||||
// Generic strided scan with indices kernel
|
||||
template <typename T, typename Op, typename acc_t = accum_t<T>>
|
||||
kernel void scan_with_indices_strided(
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* values [[buffer(1)]],
|
||||
device int64_t* indices [[buffer(2)]],
|
||||
constant long* sizes [[buffer(3)]],
|
||||
constant long* input_strides [[buffer(4)]],
|
||||
constant long* values_strides [[buffer(5)]],
|
||||
constant long* indices_strides [[buffer(6)]],
|
||||
constant uint& ndim [[buffer(7)]],
|
||||
constant uint& scan_dim [[buffer(8)]],
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
const long total_non_scan_elements =
|
||||
calculate_non_scan_elements(sizes, ndim, scan_dim);
|
||||
if (thread_index >= total_non_scan_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
int pos[c10::metal::max_ndim];
|
||||
thread_index_to_coordinates(thread_index, pos, sizes, ndim, scan_dim);
|
||||
|
||||
const long input_base_offset =
|
||||
calculate_base_offset(pos, input_strides, ndim, scan_dim);
|
||||
const long values_base_offset =
|
||||
calculate_base_offset(pos, values_strides, ndim, scan_dim);
|
||||
const long indices_base_offset =
|
||||
calculate_base_offset(pos, indices_strides, ndim, scan_dim);
|
||||
|
||||
acc_t accumulator = Op::identity();
|
||||
int64_t best_idx = 0;
|
||||
const long scan_size = sizes[scan_dim];
|
||||
const long input_scan_stride = input_strides[scan_dim];
|
||||
const long values_scan_stride = values_strides[scan_dim];
|
||||
const long indices_scan_stride = indices_strides[scan_dim];
|
||||
|
||||
for (long scan_idx = 0; scan_idx < scan_size; scan_idx++) {
|
||||
const long input_offset = input_base_offset + scan_idx * input_scan_stride;
|
||||
const long values_offset =
|
||||
values_base_offset + scan_idx * values_scan_stride;
|
||||
const long indices_offset =
|
||||
indices_base_offset + scan_idx * indices_scan_stride;
|
||||
|
||||
T val = input[input_offset];
|
||||
acc_t accum_val = static_cast<acc_t>(val);
|
||||
if (scan_idx == 0 || Op::apply(accum_val, accumulator) == accum_val) {
|
||||
accumulator = accum_val;
|
||||
best_idx = scan_idx;
|
||||
}
|
||||
values[values_offset] = static_cast<T>(accumulator);
|
||||
indices[indices_offset] = best_idx;
|
||||
}
|
||||
}
|
||||
|
||||
#define REGISTER_SCAN_OP(OP_NAME, OP_CLASS, DTYPE) \
|
||||
template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \
|
||||
scan_contiguous_innermost_dim<DTYPE, OP_CLASS<DTYPE>>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
constant uint & num_rows [[buffer(2)]], \
|
||||
constant uint & row_size [[buffer(3)]], \
|
||||
uint row [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void \
|
||||
scan_contiguous_outer_dim<DTYPE, OP_CLASS<DTYPE>>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
constant uint & num_orows [[buffer(2)]], \
|
||||
constant uint & num_irows [[buffer(3)]], \
|
||||
constant uint & row_size [[buffer(4)]], \
|
||||
uint thread_index [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void \
|
||||
scan_strided<DTYPE, OP_CLASS<DTYPE>>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
constant long* sizes [[buffer(2)]], \
|
||||
constant long* input_strides [[buffer(3)]], \
|
||||
constant long* output_strides [[buffer(4)]], \
|
||||
constant uint& ndim [[buffer(5)]], \
|
||||
constant uint& scan_dim [[buffer(6)]], \
|
||||
uint thread_index [[thread_position_in_grid]]);
|
||||
|
||||
#define REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, DTYPE) \
|
||||
template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \
|
||||
scan_with_indices_contiguous_innermost_dim<DTYPE, OP_CLASS<DTYPE>>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * values [[buffer(1)]], \
|
||||
device int64_t* indices [[buffer(2)]], \
|
||||
constant uint& num_rows [[buffer(3)]], \
|
||||
constant uint& row_size [[buffer(4)]], \
|
||||
uint row [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void \
|
||||
scan_with_indices_contiguous_outer_dim<DTYPE, OP_CLASS<DTYPE>>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * values [[buffer(1)]], \
|
||||
device int64_t* indices [[buffer(2)]], \
|
||||
constant uint& num_orows [[buffer(3)]], \
|
||||
constant uint& num_irows [[buffer(4)]], \
|
||||
constant uint& row_size [[buffer(5)]], \
|
||||
uint thread_index [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void \
|
||||
scan_with_indices_strided<DTYPE, OP_CLASS<DTYPE>>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * values [[buffer(1)]], \
|
||||
device int64_t* indices [[buffer(2)]], \
|
||||
constant long* sizes [[buffer(3)]], \
|
||||
constant long* input_strides [[buffer(4)]], \
|
||||
constant long* values_strides [[buffer(5)]], \
|
||||
constant long* indices_strides [[buffer(6)]], \
|
||||
constant uint& ndim [[buffer(7)]], \
|
||||
constant uint& scan_dim [[buffer(8)]], \
|
||||
uint thread_index [[thread_position_in_grid]]);
|
||||
|
||||
// Simple scan operations
|
||||
REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, float);
|
||||
REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, half);
|
||||
|
||||
// Scan operations with indices
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, float);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, half);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, long);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, int);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, short);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, char);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, uchar);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, bool);
|
||||
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, float);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, half);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, long);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, int);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool);
|
||||
|
||||
#else // __METAL_VERSION__ >= 310
|
||||
|
||||
C10_METAL_CONSTEXPR auto simd_size = c10::metal::simdgroup_size;
|
||||
|
||||
// The reminder of this file contains cummin and cummax implementations adapted
|
||||
@ -1159,5 +786,3 @@ REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short, 4);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char, 4);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar, 4);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool, 4);
|
||||
|
||||
#endif
|
||||
|
||||
@ -89,6 +89,4 @@ REGISTER_SPECIAL(short, float);
|
||||
REGISTER_SPECIAL(int, float);
|
||||
REGISTER_SPECIAL(long, float);
|
||||
REGISTER_SPECIAL(half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_SPECIAL(bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
@ -100,9 +100,7 @@ kernel void triul(
|
||||
|
||||
INSTANTIATE_TRIUL_KERNELS(float, int);
|
||||
INSTANTIATE_TRIUL_KERNELS(half, int);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_TRIUL_KERNELS(bfloat, int);
|
||||
#endif
|
||||
|
||||
INSTANTIATE_TRIUL_KERNELS(float2, int);
|
||||
INSTANTIATE_TRIUL_KERNELS(half2, int);
|
||||
|
||||
@ -556,11 +556,9 @@ REGISTER_UNARY_OP(abs, half, half);
|
||||
REGISTER_UNARY_OP(acos, DTYPE1, DTYPE0); \
|
||||
REGISTER_UNARY_OP(atan, DTYPE1, DTYPE0)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_UNARY_KERNELS2(bfloat, bfloat);
|
||||
REGISTER_UNARY_OP(neg, bfloat, bfloat);
|
||||
REGISTER_UNARY_OP(abs, bfloat, bfloat);
|
||||
#endif
|
||||
INSTANTIATE_UNARY_KERNELS2(half, half);
|
||||
INSTANTIATE_UNARY_KERNELS2(float, float);
|
||||
INSTANTIATE_UNARY_KERNELS2(float, bool);
|
||||
@ -600,6 +598,4 @@ INSTANTIATE_UNARY_KERNELS_VEC2(float);
|
||||
|
||||
REGISTER_UNARY_ALPHA_OP(round_decimals, float, long, float);
|
||||
REGISTER_UNARY_ALPHA_OP(round_decimals, half, long, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_ALPHA_OP(round_decimals, bfloat, long, bfloat);
|
||||
#endif
|
||||
|
||||
@ -70,6 +70,4 @@ kernel void unfold_backward(
|
||||
|
||||
INSTANTIATE_UNFOLD_BACKWARD(float);
|
||||
INSTANTIATE_UNFOLD_BACKWARD(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_UNFOLD_BACKWARD(bfloat);
|
||||
#endif
|
||||
|
||||
@ -852,6 +852,4 @@ INSTANTIATE_UPSAMPLE_2D(bilinear2d, uchar);
|
||||
INSTANTIATE_UPSAMPLE_3D(uchar);
|
||||
INSTANTIATE_UPSAMPLE_ALL(float);
|
||||
INSTANTIATE_UPSAMPLE_ALL(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_UPSAMPLE_ALL(bfloat);
|
||||
#endif
|
||||
|
||||
@ -21,6 +21,8 @@
|
||||
#include <ATen/ops/max_pool2d_with_indices_native.h>
|
||||
#include <ATen/ops/max_pool3d_with_indices_backward_native.h>
|
||||
#include <ATen/ops/max_pool3d_with_indices_native.h>
|
||||
#include <ATen/ops/max_unpool2d_native.h>
|
||||
#include <ATen/ops/max_unpool3d_native.h>
|
||||
#endif
|
||||
|
||||
namespace at::native {
|
||||
@ -492,6 +494,60 @@ static void max_pool_with_indices_backward_out_mps_template(Tensor& grad_input,
|
||||
});
|
||||
}
|
||||
|
||||
static void max_unpool_out_mps_template(const Tensor& input,
|
||||
const Tensor& indices,
|
||||
IntArrayRef output_size_,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
Tensor& output,
|
||||
const int32_t pooling_dims,
|
||||
const std::string& op_name) {
|
||||
auto dims = input.dim();
|
||||
auto leading_dims = input.dim() - pooling_dims;
|
||||
|
||||
const auto memory_format = input.suggest_memory_format();
|
||||
std::vector<int64_t> output_size(dims);
|
||||
for (int dim : c10::irange(leading_dims)) {
|
||||
output_size[dim] = input.sizes()[dim];
|
||||
}
|
||||
for (int dim : c10::irange(pooling_dims)) {
|
||||
output_size[leading_dims + dim] = output_size_[dim];
|
||||
}
|
||||
|
||||
output.resize_(output_size, memory_format);
|
||||
output.fill_(0);
|
||||
|
||||
id<MTLDevice> device = MPSDevice::getInstance()->device();
|
||||
MPSStream* mpsStream = getCurrentMPSStream();
|
||||
const auto numThreads = input.numel();
|
||||
MaxUnpoolingParams<5> params;
|
||||
|
||||
params.dims = dims;
|
||||
params.pooling_dims = pooling_dims;
|
||||
|
||||
for (const auto dim : c10::irange(dims)) {
|
||||
params.output_sizes[dim] = safe_downcast<int32_t, int64_t>(output.size(dim));
|
||||
params.output_strides[dim] = safe_downcast<int32_t, int64_t>(output.stride(dim));
|
||||
params.input_sizes[dim] = safe_downcast<int32_t, int64_t>(input.size(dim));
|
||||
params.input_strides[dim] = safe_downcast<int32_t, int64_t>(input.stride(dim));
|
||||
params.indices_strides[dim] = safe_downcast<int32_t, int64_t>(indices.stride(dim));
|
||||
}
|
||||
|
||||
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
|
||||
auto PSO = lib.getPipelineStateForFunc("max_unpool_" + scalarToMetalTypeString(input));
|
||||
|
||||
getMPSProfiler().beginProfileKernel(PSO, op_name, {input});
|
||||
[computeEncoder setComputePipelineState:PSO];
|
||||
mtl_setArgs(computeEncoder, output, input, indices, params);
|
||||
|
||||
mtl_dispatch1DJob(computeEncoder, PSO, numThreads);
|
||||
getMPSProfiler().endProfileKernel(PSO);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
static void avg_pool2d_template(const Tensor& input,
|
||||
const Tensor& output,
|
||||
const std::optional<Tensor>& grad_output_opt,
|
||||
@ -896,6 +952,68 @@ Tensor max_pool3d_with_indices_backward_mps(const Tensor& grad_output,
|
||||
return grad_input;
|
||||
}
|
||||
|
||||
Tensor& max_unpooling2d_forward_out_mps(const Tensor& self,
|
||||
const Tensor& indices,
|
||||
IntArrayRef output_size,
|
||||
Tensor& output) {
|
||||
mps::max_unpool_out_mps_template(self,
|
||||
indices,
|
||||
output_size,
|
||||
/*stride=*/{},
|
||||
/*padding=*/{},
|
||||
output,
|
||||
/*pooling_dims=*/2,
|
||||
"max_unpool2d");
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor max_unpooling2d_forward_mps(const Tensor& self, const Tensor& indices, IntArrayRef output_size) {
|
||||
auto output = at::empty({0}, self.options());
|
||||
mps::max_unpool_out_mps_template(self,
|
||||
indices,
|
||||
output_size,
|
||||
/*stride=*/{},
|
||||
/*padding=*/{},
|
||||
output,
|
||||
/*pooling_dims=*/2,
|
||||
"max_unpool2d");
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor& max_unpooling3d_forward_out_mps(const Tensor& self,
|
||||
const Tensor& indices,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
Tensor& output) {
|
||||
mps::max_unpool_out_mps_template(self,
|
||||
indices,
|
||||
output_size,
|
||||
stride,
|
||||
padding,
|
||||
output,
|
||||
/*pooling_dims=*/3,
|
||||
"max_unpool3d");
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor max_unpooling3d_forward_mps(const Tensor& self,
|
||||
const Tensor& indices,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding) {
|
||||
auto output = at::empty({0}, self.options());
|
||||
mps::max_unpool_out_mps_template(self,
|
||||
indices,
|
||||
output_size,
|
||||
stride,
|
||||
padding,
|
||||
output,
|
||||
/*pooling_dims=*/3,
|
||||
"max_unpool3d");
|
||||
return output;
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(avg_pool2d_out_mps)
|
||||
(const Tensor& input,
|
||||
int64_t kH,
|
||||
|
||||
@ -719,6 +719,7 @@
|
||||
dispatch:
|
||||
CPU, CUDA: all_out
|
||||
MPS: all_out_mps
|
||||
MTIA: all_out_mtia
|
||||
|
||||
- func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
|
||||
device_check: NoCheck # TensorIterator
|
||||
@ -808,6 +809,7 @@
|
||||
CPU, Meta: arange_out
|
||||
CUDA: arange_cuda_out
|
||||
MPS: arange_mps_out
|
||||
MTIA: arange_mtia_out
|
||||
cpp_no_default_args: ['step']
|
||||
|
||||
# This function is a temporary hack to allow tracing of arange like constructs with dynamic
|
||||
@ -1889,7 +1891,10 @@
|
||||
- func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)
|
||||
dispatch:
|
||||
CUDA: cudnn_batch_norm
|
||||
autogen: cudnn_batch_norm.out
|
||||
|
||||
- func: cudnn_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))
|
||||
dispatch:
|
||||
CUDA: cudnn_batch_norm_out
|
||||
|
||||
# NB: You can only use this if you used cudnn_batch_norm training=True
|
||||
- func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor)
|
||||
@ -4182,11 +4187,13 @@
|
||||
dispatch:
|
||||
CPU: _int_mm_cpu
|
||||
CUDA: _int_mm_cuda
|
||||
XPU: _int_mm_xpu
|
||||
|
||||
- func: _int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU: _int_mm_out_cpu
|
||||
CUDA: _int_mm_out_cuda
|
||||
XPU: _int_mm_out_xpu
|
||||
|
||||
- func: _convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor
|
||||
dispatch:
|
||||
@ -7124,18 +7131,21 @@
|
||||
dispatch:
|
||||
CPU: _scaled_mm_cpu
|
||||
CUDA: _scaled_mm_cuda
|
||||
tags: needs_exact_strides
|
||||
|
||||
- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: _scaled_mm_out_cpu
|
||||
CUDA: _scaled_mm_out_cuda
|
||||
tags: needs_exact_strides
|
||||
|
||||
|
||||
- func: _scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: _scaled_grouped_mm_cuda
|
||||
tags: needs_exact_strides
|
||||
|
||||
- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
|
||||
variants: function
|
||||
@ -10487,6 +10497,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_add_scalar_kernel_slow_
|
||||
CUDA: foreach_tensor_add_scalar_kernel_cuda_
|
||||
MTIA: foreach_tensor_add_scalar_kernel_mtia_
|
||||
autogen: _foreach_add.Scalar_out
|
||||
|
||||
- func: _foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]
|
||||
@ -10495,6 +10506,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow
|
||||
CUDA: foreach_tensor_add_list_kernel_cuda
|
||||
MTIA: foreach_tensor_add_list_kernel_mtia
|
||||
|
||||
- func: _foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> ()
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
@ -10502,6 +10514,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow_
|
||||
CUDA: foreach_tensor_add_list_kernel_cuda_
|
||||
MTIA: foreach_tensor_add_list_kernel_mtia_
|
||||
autogen: _foreach_add.List_out
|
||||
|
||||
- func: _foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
|
||||
@ -10532,6 +10545,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_add_tensor_kernel_slow_
|
||||
CUDA: foreach_tensor_add_tensor_kernel_cuda_
|
||||
MTIA: foreach_tensor_add_tensor_kernel_mtia_
|
||||
autogen: _foreach_add.Tensor_out
|
||||
|
||||
- func: _foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
|
||||
@ -10592,6 +10606,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_mul_scalar_kernel_slow_
|
||||
CUDA: foreach_tensor_mul_scalar_kernel_cuda_
|
||||
MTIA: foreach_tensor_mul_scalar_kernel_mtia_
|
||||
autogen: _foreach_mul.Scalar_out
|
||||
|
||||
- func: _foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[]
|
||||
@ -10600,6 +10615,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow
|
||||
CUDA: foreach_tensor_mul_list_kernel_cuda
|
||||
MTIA: foreach_tensor_mul_list_kernel_mtia
|
||||
|
||||
- func: _foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> ()
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
@ -10607,6 +10623,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow_
|
||||
CUDA: foreach_tensor_mul_list_kernel_cuda_
|
||||
MTIA: foreach_tensor_mul_list_kernel_mtia_
|
||||
autogen: _foreach_mul.List_out
|
||||
|
||||
- func: _foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
|
||||
@ -10630,6 +10647,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow
|
||||
CUDA: foreach_tensor_mul_tensor_kernel_cuda
|
||||
MTIA: foreach_tensor_mul_tensor_kernel_mtia
|
||||
|
||||
- func: _foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> ()
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
@ -10637,6 +10655,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow_
|
||||
CUDA: foreach_tensor_mul_tensor_kernel_cuda_
|
||||
MTIA: foreach_tensor_mul_tensor_kernel_mtia_
|
||||
autogen: _foreach_mul.Tensor_out
|
||||
|
||||
- func: _foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
|
||||
@ -10933,6 +10952,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow
|
||||
CUDA: foreach_tensor_addcmul_scalar_cuda
|
||||
MTIA: foreach_tensor_addcmul_scalar_mtia
|
||||
|
||||
- func: _foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
@ -10954,6 +10974,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow_
|
||||
CUDA: foreach_tensor_addcmul_scalar_cuda_
|
||||
MTIA: foreach_tensor_addcmul_scalar_mtia_
|
||||
autogen: _foreach_addcmul.Scalar_out
|
||||
|
||||
- func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
|
||||
@ -10978,6 +10999,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_abs_slow
|
||||
CUDA: foreach_tensor_abs_cuda
|
||||
MTIA: foreach_tensor_abs_mtia
|
||||
|
||||
- func: _foreach_abs_(Tensor(a!)[] self) -> ()
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
@ -10985,6 +11007,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_abs_slow_
|
||||
CUDA: foreach_tensor_abs_cuda_
|
||||
MTIA: foreach_tensor_abs_mtia_
|
||||
autogen: _foreach_abs.out
|
||||
|
||||
- func: _foreach_acos(Tensor[] self) -> Tensor[]
|
||||
@ -11319,6 +11342,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_norm_slow
|
||||
CUDA: foreach_tensor_norm_cuda
|
||||
MTIA: foreach_tensor_norm_mtia
|
||||
autogen: _foreach_norm.Scalar_out
|
||||
|
||||
- func: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[]
|
||||
@ -11491,6 +11515,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_sqrt_slow_
|
||||
CUDA: foreach_tensor_sqrt_cuda_
|
||||
MTIA: foreach_tensor_sqrt_mtia_
|
||||
autogen: _foreach_sqrt.out
|
||||
|
||||
- func: _foreach_tan(Tensor[] self) -> Tensor[]
|
||||
@ -11552,6 +11577,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_copy_list_kernel_slow_
|
||||
CUDA: foreach_tensor_copy_list_kernel_cuda_
|
||||
MTIA: foreach_tensor_copy_list_kernel_mtia_
|
||||
autogen: _foreach_copy.out
|
||||
|
||||
- func: _foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out
|
||||
@ -11559,6 +11585,7 @@
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: _foreach_copy
|
||||
MTIA: foreach_tensor_copy_list_kernel_mtia
|
||||
|
||||
- func: bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor
|
||||
dispatch:
|
||||
@ -12476,24 +12503,28 @@
|
||||
dispatch:
|
||||
CPU: max_unpooling2d_forward_out_cpu
|
||||
CUDA: max_unpooling2d_forward_out_cuda
|
||||
MPS: max_unpooling2d_forward_out_mps
|
||||
|
||||
- func: max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor
|
||||
python_module: nn
|
||||
dispatch:
|
||||
CPU: max_unpooling2d_forward_cpu
|
||||
CUDA: max_unpooling2d_forward_cuda
|
||||
MPS: max_unpooling2d_forward_mps
|
||||
|
||||
- func: max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: nn
|
||||
dispatch:
|
||||
CPU: max_unpooling3d_forward_out_cpu
|
||||
CUDA: max_unpooling3d_forward_out_cuda
|
||||
MPS: max_unpooling3d_forward_out_mps
|
||||
|
||||
- func: max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor
|
||||
python_module: nn
|
||||
dispatch:
|
||||
CPU: max_unpooling3d_forward_cpu
|
||||
CUDA: max_unpooling3d_forward_cuda
|
||||
MPS: max_unpooling3d_forward_mps
|
||||
|
||||
- func: reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: nn
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# generate a list of kernels, but not actually emit files at config stage
|
||||
execute_process(
|
||||
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
|
||||
--api fwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt
|
||||
--api fwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
@ -11,7 +11,27 @@ endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
|
||||
--api bwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt
|
||||
--api fwd_splitkv --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_splitkv_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD_SPLITKV kernels via Python.")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
|
||||
--api fwd_appendkv --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_appendkv_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD_APPENDKV kernels via Python.")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
|
||||
--api bwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
@ -19,15 +39,29 @@ if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of BWD kernels via Python.")
|
||||
endif()
|
||||
|
||||
# Generate the files for both fwd and bwd
|
||||
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
# Generate the files for both fwd, fwd_splitkv, fwd_appendkv, and bwd
|
||||
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD kernels.")
|
||||
endif()
|
||||
|
||||
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd_splitkv --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD_SPLITKV kernels.")
|
||||
endif()
|
||||
|
||||
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd_appendkv --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD_APPENDKV kernels.")
|
||||
endif()
|
||||
|
||||
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
@ -44,6 +78,22 @@ if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd pass")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/fwd_splitkv_blob_list.txt"
|
||||
RESULT_VARIABLE ret)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd_splitkv pass")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/fwd_appendkv_blob_list.txt"
|
||||
RESULT_VARIABLE ret)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd appendkv pass")
|
||||
endif()
|
||||
|
||||
# Change make_kernel to make_kernel_pt for bwd
|
||||
execute_process(
|
||||
COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt"
|
||||
|
||||
@ -21,6 +21,8 @@ while IFS= read -r file; do
|
||||
if [ -f "$file" ]; then
|
||||
# Use sed to replace "make_kernel" with "make_kernel_pt" in place
|
||||
sed -i 's/make_kernel/make_kernel_pt/g' "$file"
|
||||
sed -i 's/\#include \"fmha_fwd.hpp\"/\#include \"fmha_fwd.hpp\"\n\#include \"launch_kernel_pt.hpp\"/g' "$file"
|
||||
sed -i 's/\#include \"fmha_bwd.hpp\"/\#include \"fmha_bwd.hpp\"\n\#include \"launch_kernel_pt.hpp\"/g' "$file"
|
||||
echo "Updated: $file"
|
||||
else
|
||||
echo "Skipping: $file (not found)"
|
||||
|
||||
@ -1,100 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/ops/fmha.hpp>
|
||||
|
||||
// keep sync with BlockAttentionBiasEnum
|
||||
enum class bias_enum
|
||||
{
|
||||
no_bias = 0,
|
||||
elementwise_bias = 1,
|
||||
alibi = 2,
|
||||
};
|
||||
|
||||
struct bias_info
|
||||
{
|
||||
bias_enum type;
|
||||
/*
|
||||
* simple dispatch logic
|
||||
*
|
||||
* if type == elementwise_bias:
|
||||
* if rank_info == 0:
|
||||
* bias is 1*1*s*s
|
||||
* elif rank_info == 1:
|
||||
* bias is 1*h*s*s
|
||||
* elif rank_info == 2:
|
||||
* bias is b*h*s*s
|
||||
*
|
||||
* elif type == alibi:
|
||||
* if rank_info == 0:
|
||||
* alibi in 1*h
|
||||
* elif rank_info == 1:
|
||||
* alibi in b*h
|
||||
*/
|
||||
int rank_info;
|
||||
|
||||
void serialize(std::ostream& os) const
|
||||
{
|
||||
if(type == bias_enum::no_bias)
|
||||
os << "n";
|
||||
else if(type == bias_enum::elementwise_bias)
|
||||
{
|
||||
os << "e";
|
||||
if(rank_info != 0)
|
||||
{
|
||||
os << "[" << rank_info << "]";
|
||||
}
|
||||
}
|
||||
else if(type == bias_enum::alibi)
|
||||
{
|
||||
os << "alibi";
|
||||
if(rank_info != 0)
|
||||
{
|
||||
os << "[" << rank_info << "]";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static bias_info decode(std::string str)
|
||||
{
|
||||
bias_info info{bias_enum::no_bias, 0};
|
||||
if(str == "0" || str == "n")
|
||||
{
|
||||
info.type = bias_enum::no_bias;
|
||||
}
|
||||
else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 ||
|
||||
str.compare(0, 11, "elementwise") == 0)
|
||||
{
|
||||
info.type = bias_enum::elementwise_bias;
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
std::string e = str.substr(found_0 + 1);
|
||||
info.rank_info = atoi(e.c_str());
|
||||
}
|
||||
}
|
||||
else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 ||
|
||||
str.compare(0, 5, "alibi") == 0)
|
||||
{
|
||||
info.type = bias_enum::alibi;
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
std::string e = str.substr(found_0 + 1);
|
||||
info.rank_info = atoi(e.c_str());
|
||||
}
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const bias_info& bi)
|
||||
{
|
||||
bi.serialize(os);
|
||||
return os;
|
||||
}
|
||||
};
|
||||
@ -1,457 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/host/kernel_launch.hpp>
|
||||
#include <ck_tile/ops/fmha.hpp>
|
||||
#include <ck_tile/ops/epilogue.hpp>
|
||||
#include <mask.hpp>
|
||||
#include <bias.hpp>
|
||||
#include <launch_kernel_pt.hpp>
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
struct FmhaBwdFp16
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaBwdBf16
|
||||
{
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaBwdTypeConfig;
|
||||
|
||||
template <>
|
||||
struct FmhaBwdTypeConfig<FmhaBwdFp16>
|
||||
{
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
using VDataType = ck_tile::half_t;
|
||||
using GemmDataType = ck_tile::half_t;
|
||||
using BiasDataType = ck_tile::half_t;
|
||||
using LSEDataType = float;
|
||||
using AccDataType = float; // data type for gemm accumulation
|
||||
using DDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using ODataType = ck_tile::half_t;
|
||||
using OGradDataType = ck_tile::half_t;
|
||||
using QGradDataType = ck_tile::half_t;
|
||||
using KGradDataType = ck_tile::half_t;
|
||||
using VGradDataType = ck_tile::half_t;
|
||||
using BiasGradDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaBwdTypeConfig<FmhaBwdBf16>
|
||||
{
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
using VDataType = ck_tile::bf16_t;
|
||||
using GemmDataType = ck_tile::bf16_t;
|
||||
using BiasDataType = ck_tile::bf16_t;
|
||||
using LSEDataType = float;
|
||||
using AccDataType = float; // data type for gemm accumulation
|
||||
using DDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using OGradDataType = ck_tile::bf16_t;
|
||||
using QGradDataType = ck_tile::bf16_t;
|
||||
using KGradDataType = ck_tile::bf16_t;
|
||||
using VGradDataType = ck_tile::bf16_t;
|
||||
using BiasGradDataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
struct FmhaMasks
|
||||
{
|
||||
using NoMask = ck_tile::GenericAttentionMask<false>;
|
||||
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
|
||||
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
|
||||
};
|
||||
|
||||
// runtime args, some will passed to karg, some will used to compute grids/blocks
|
||||
struct fmha_bwd_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
const void* o_ptr;
|
||||
const void* lse_ptr;
|
||||
const void* do_ptr;
|
||||
void* d_ptr;
|
||||
void* rand_val_ptr;
|
||||
void* dq_ptr;
|
||||
void* dk_ptr;
|
||||
void* dv_ptr;
|
||||
void* dbias_ptr;
|
||||
void* dq_acc_ptr;
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t max_seqlen_k;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
float scale;
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t stride_randval;
|
||||
ck_tile::index_t stride_do;
|
||||
ck_tile::index_t stride_dq_acc;
|
||||
ck_tile::index_t stride_dq;
|
||||
ck_tile::index_t stride_dk;
|
||||
ck_tile::index_t stride_dv;
|
||||
ck_tile::index_t stride_dbias;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_do;
|
||||
ck_tile::index_t nhead_stride_lsed;
|
||||
ck_tile::index_t nhead_stride_dq_acc;
|
||||
ck_tile::index_t nhead_stride_dq;
|
||||
ck_tile::index_t nhead_stride_dk;
|
||||
ck_tile::index_t nhead_stride_dv;
|
||||
ck_tile::index_t nhead_stride_dbias;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_do;
|
||||
ck_tile::index_t batch_stride_lsed;
|
||||
ck_tile::index_t batch_stride_dq_acc;
|
||||
ck_tile::index_t batch_stride_dq;
|
||||
ck_tile::index_t batch_stride_dk;
|
||||
ck_tile::index_t batch_stride_dv;
|
||||
ck_tile::index_t batch_stride_dbias;
|
||||
ck_tile::index_t split_stride_dq_acc;
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
float p_drop;
|
||||
float p_undrop;
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
};
|
||||
|
||||
template <typename FmhaBwdDQDKDVKernel>
|
||||
auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.dq_acc_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dq_acc,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.dq_acc_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dq_acc,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_do,
|
||||
args.batch_stride_lsed,
|
||||
args.batch_stride_dq_acc,
|
||||
args.batch_stride_dk,
|
||||
args.batch_stride_dv,
|
||||
args.batch_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename FmhaBwdOGradDotOKernel>
|
||||
auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.p_undrop,
|
||||
args.seqstart_q_ptr,
|
||||
args.hdim_v,
|
||||
args.stride_do,
|
||||
args.stride_o,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_lsed);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.p_undrop,
|
||||
args.seqlen_q,
|
||||
args.hdim_v,
|
||||
args.stride_do,
|
||||
args.stride_o,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_lsed,
|
||||
args.batch_stride_do,
|
||||
args.batch_stride_o,
|
||||
args.batch_stride_lsed);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename FmhaBwdConvertQGradKernel>
|
||||
auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
|
||||
args.dq_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.split_stride_dq_acc);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
|
||||
args.dq_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.batch_stride_dq,
|
||||
args.batch_stride_dq_acc,
|
||||
args.split_stride_dq_acc);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
|
||||
typename FmhaMask_,
|
||||
typename FmhaDropout_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kIsDeterministic_>
|
||||
struct fmha_bwd_dq_dk_dv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_();
|
||||
|
||||
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
|
||||
struct fmha_bwd_dot_do_o_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_bwd_dot_do_o_get_name_();
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
bool kPadS_,
|
||||
bool kPadD_,
|
||||
bool kIsDeterministic_>
|
||||
struct fmha_bwd_convert_dq_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_bwd_convert_dq_get_name_();
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct fmha_bwd_traits
|
||||
{
|
||||
int hdim_q;
|
||||
int hdim_v;
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
mask_enum mask_type;
|
||||
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
|
||||
bool has_dbias;
|
||||
bool has_dropout;
|
||||
bool is_store_randval;
|
||||
bool is_deterministic;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
template <int Version = 2>
|
||||
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
|
||||
@ -1,824 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/host/kernel_launch.hpp>
|
||||
#include <ck_tile/ops/epilogue.hpp>
|
||||
#include <ck_tile/ops/fmha.hpp>
|
||||
|
||||
#include <bias.hpp>
|
||||
#include <mask.hpp>
|
||||
#include <rotary.hpp>
|
||||
#include <launch_kernel_pt.hpp>
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
struct FmhaFwdFp16
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaFwdBf16
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaFwdFp8
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaFwdBf8
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaFwdFp8Fp16
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaFwdFp8Bf16
|
||||
{
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaFwdTypeConfig;
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdFp16>
|
||||
{
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
using VDataType = ck_tile::half_t;
|
||||
using BiasDataType = ck_tile::half_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::half_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdBf16>
|
||||
{
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
using VDataType = ck_tile::bf16_t;
|
||||
using BiasDataType = ck_tile::bf16_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdFp8>
|
||||
{
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using KDataType = ck_tile::fp8_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using BiasDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::fp8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdBf8>
|
||||
{
|
||||
using QDataType = ck_tile::bf8_t;
|
||||
using KDataType = ck_tile::bf8_t;
|
||||
using VDataType = ck_tile::bf8_t;
|
||||
using BiasDataType = ck_tile::bf8_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf8_t;
|
||||
};
|
||||
|
||||
struct FmhaMasks
|
||||
{
|
||||
using NoMask = ck_tile::GenericAttentionMask<false>;
|
||||
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
|
||||
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
|
||||
};
|
||||
|
||||
// runtime args, some will passed to karg, some will used to compute grids/blocks
|
||||
struct fmha_fwd_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
void* rand_val_ptr;
|
||||
void* lse_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void*
|
||||
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
float scale_s;
|
||||
float scale_p;
|
||||
float scale_o;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
|
||||
ck_tile::index_t stride_randval;
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
|
||||
float p_drop;
|
||||
bool s_randval;
|
||||
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
};
|
||||
|
||||
struct fmha_fwd_splitkv_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
void* lse_acc_ptr;
|
||||
void* o_acc_ptr;
|
||||
void* lse_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
void* block_table_ptr;
|
||||
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
|
||||
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
|
||||
bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not
|
||||
// nullptr.
|
||||
|
||||
const void* cache_batch_idx;
|
||||
|
||||
// the real seqlen_q & seqlen_k are decided by following:
|
||||
// batch mode: seqlen_q = kargs.seqlen_q
|
||||
// seqlen_k = kargs.seqlen_k
|
||||
// group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
|
||||
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
|
||||
// or kargs.seqlen_k_ptr[b]
|
||||
//
|
||||
// batch mode (kvcache):
|
||||
// seqlen_q = kargs.seqlen_q
|
||||
// seqlen_k = kargs.seqlen_k_ptr[b]
|
||||
// group mode (kvcache):
|
||||
// seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
|
||||
//
|
||||
// when is_gappy=true:
|
||||
// seqlen_k = kargs.seqlen_k_ptr[b]
|
||||
// seqstart_k_ptr[b] now store local offset of each batch
|
||||
//
|
||||
// when is_gappy=false:
|
||||
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
|
||||
// or kargs.seqlen_k_ptr[b]
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
ck_tile::index_t num_splits;
|
||||
|
||||
float scale_s;
|
||||
float scale_p;
|
||||
float scale_o;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
|
||||
ck_tile::index_t stride_o_acc;
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
ck_tile::index_t nhead_stride_lse_acc;
|
||||
ck_tile::index_t nhead_stride_o_acc;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t split_stride_lse_acc;
|
||||
ck_tile::index_t split_stride_o_acc;
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
};
|
||||
|
||||
struct fmha_fwd_appendkv_args
|
||||
{
|
||||
void* q_ptr;
|
||||
void* k_ptr;
|
||||
const void* knew_ptr;
|
||||
void* v_ptr;
|
||||
const void* vnew_ptr;
|
||||
|
||||
const void* seqlen_k_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_knew;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0
|
||||
const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0
|
||||
ck_tile::index_t rotary_dim;
|
||||
bool has_mask;
|
||||
|
||||
void* block_table_ptr;
|
||||
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
|
||||
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
|
||||
|
||||
const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_knew;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_vnew;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_knew;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_vnew;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_knew;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_vnew;
|
||||
};
|
||||
|
||||
template <typename FmhaKernel>
|
||||
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(FmhaKernel::kIsGroupMode)
|
||||
{
|
||||
dim3 grids = FmhaKernel::GridSize(
|
||||
args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
else
|
||||
{
|
||||
dim3 grids =
|
||||
FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Kernel>
|
||||
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(Kernel::kIsGroupMode)
|
||||
{
|
||||
return Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_acc_ptr,
|
||||
args.o_acc_ptr,
|
||||
args.batch,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_splits,
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.is_gappy,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_o_acc,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc,
|
||||
args.batch_stride_k, // only used for paged-kvcache
|
||||
args.batch_stride_v, // only used for paged-kvcache
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_acc_ptr,
|
||||
args.o_acc_ptr,
|
||||
args.batch,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_splits,
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.cache_batch_idx,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_o_acc,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_lse_acc,
|
||||
args.batch_stride_o_acc,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = Kernel::GridSize(
|
||||
args.batch, args.nhead_q, args.nhead_k, args.max_seqlen_q, args.hdim_v, args.num_splits);
|
||||
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename Kernel>
|
||||
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel argumentszs
|
||||
if constexpr(Kernel::kIsGroupMode)
|
||||
{
|
||||
return Kernel::MakeKargs(args.lse_acc_ptr,
|
||||
args.o_acc_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.batch,
|
||||
args.seqstart_q_ptr,
|
||||
args.hdim_v,
|
||||
args.num_splits,
|
||||
args.scale_o,
|
||||
args.stride_o_acc,
|
||||
args.stride_o,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return Kernel::MakeKargs(args.lse_acc_ptr,
|
||||
args.o_acc_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.batch,
|
||||
args.seqlen_q,
|
||||
args.hdim_v,
|
||||
args.num_splits,
|
||||
args.scale_o,
|
||||
args.stride_o_acc,
|
||||
args.stride_o,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_lse_acc,
|
||||
args.batch_stride_o_acc,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
|
||||
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename Kernel>
|
||||
auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.knew_ptr,
|
||||
args.v_ptr,
|
||||
args.vnew_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k_ptr,
|
||||
args.seqlen_knew,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.rotary_cos_ptr,
|
||||
args.rotary_sin_ptr,
|
||||
args.rotary_dim,
|
||||
args.has_mask,
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.cache_batch_idx,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_knew,
|
||||
args.stride_v,
|
||||
args.stride_vnew,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_knew,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_vnew,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_knew,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_vnew);
|
||||
|
||||
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew);
|
||||
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::index_t kM0_,
|
||||
ck_tile::index_t kN0_,
|
||||
ck_tile::index_t kK0_,
|
||||
ck_tile::index_t kN1_,
|
||||
ck_tile::index_t kK1_,
|
||||
ck_tile::index_t kK0BlockLength_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
|
||||
typename FmhaMask_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kStoreLse_,
|
||||
bool kHasDropout_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_>
|
||||
struct fmha_fwd_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr ck_tile::index_t kM0 = kM0_;
|
||||
static constexpr ck_tile::index_t kN0 = kN0_;
|
||||
static constexpr ck_tile::index_t kK0 = kK0_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
static constexpr ck_tile::index_t kK1 = kK1_;
|
||||
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kStoreLse = kStoreLse_;
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::index_t kM0_,
|
||||
ck_tile::index_t kN0_,
|
||||
ck_tile::index_t kK0_,
|
||||
ck_tile::index_t kN1_,
|
||||
ck_tile::index_t kK1_,
|
||||
ck_tile::index_t kK0BlockLength_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
|
||||
typename FmhaMask_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kStoreLse_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
bool kIsPagedKV_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_>
|
||||
struct fmha_fwd_splitkv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr ck_tile::index_t kM0 = kM0_;
|
||||
static constexpr ck_tile::index_t kN0 = kN0_;
|
||||
static constexpr ck_tile::index_t kK0 = kK0_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
static constexpr ck_tile::index_t kK1 = kK1_;
|
||||
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kStoreLse = kStoreLse_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kIsPagedKV = kIsPagedKV_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_fwd_splitkv_get_name_();
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::index_t kN1_,
|
||||
bool kStoreLse_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
bool kPadS_,
|
||||
bool kPadDv_>
|
||||
struct fmha_fwd_splitkv_combine_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
static constexpr bool kStoreLse = kStoreLse_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_fwd_splitkv_combine_get_name_();
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
ck_tile::index_t kTileSizeS_,
|
||||
ck_tile::index_t kTileSizeSk_,
|
||||
ck_tile::index_t kTileSizeD_,
|
||||
ck_tile::index_t kTileSizeDv_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
bool kPadS_,
|
||||
bool kPadSk_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
ck_tile::RotaryEmbeddingEnum RotaryEnum_,
|
||||
bool kIsPagedKV_>
|
||||
struct fmha_fwd_appendkv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_;
|
||||
static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_;
|
||||
static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_;
|
||||
static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_;
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSk = kPadSk_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr auto RotaryEnum = RotaryEnum_;
|
||||
static constexpr bool kIsPagedKV = kIsPagedKV_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args);
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct fmha_fwd_traits
|
||||
{
|
||||
int hdim_q;
|
||||
int hdim_v;
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
bool is_v_rowmajor;
|
||||
mask_enum mask_type;
|
||||
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
|
||||
bool has_lse;
|
||||
bool has_dropout;
|
||||
bool do_fp8_static_quant;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
struct fmha_fwd_splitkv_traits
|
||||
{
|
||||
int hdim_q;
|
||||
int hdim_v;
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
bool is_v_rowmajor;
|
||||
mask_enum mask_type;
|
||||
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
|
||||
bool has_lse;
|
||||
bool do_fp8_static_quant;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,
|
||||
fmha_fwd_splitkv_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
struct fmha_fwd_appendkv_traits
|
||||
{
|
||||
int hdim_q;
|
||||
int hdim_v;
|
||||
std::string data_type;
|
||||
bool is_v_rowmajor;
|
||||
rope_enum rope_type;
|
||||
};
|
||||
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
|
||||
fmha_fwd_appendkv_args,
|
||||
const ck_tile::stream_config&);
|
||||
@ -1,157 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/ops/fmha.hpp>
|
||||
|
||||
// keep this in sync with ck_tile::GenericAttentionMaskEnum
|
||||
enum class mask_enum
|
||||
{
|
||||
no_mask = 0,
|
||||
mask_top_left,
|
||||
mask_bottom_right,
|
||||
window_generic,
|
||||
};
|
||||
|
||||
struct mask_info
|
||||
{
|
||||
mask_enum type;
|
||||
ck_tile::index_t y, x;
|
||||
ck_tile::index_t left, right; // FA style SWA left/right
|
||||
|
||||
void serialize(std::ostream& os) const
|
||||
{
|
||||
if(type == mask_enum::no_mask)
|
||||
os << "n";
|
||||
else if(type == mask_enum::mask_top_left)
|
||||
os << "t(" << left << ":" << right << ")";
|
||||
else if(type == mask_enum::mask_bottom_right)
|
||||
os << "b(" << left << ":" << right << ")";
|
||||
else
|
||||
{
|
||||
os << "g(" << y << ":" << x << ")";
|
||||
}
|
||||
}
|
||||
static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
|
||||
{
|
||||
ck_tile::index_t x_total = seqlen_k;
|
||||
ck_tile::index_t y_total = seqlen_q;
|
||||
mask_info tmp;
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
std::string t = str.substr(0, found_0);
|
||||
std::string v = str.substr(found_0 + 1);
|
||||
if(t == "xt" || t == "xb")
|
||||
{
|
||||
// xformer style sliding window attn from top-left
|
||||
ck_tile::index_t window_size = atoi(v.c_str());
|
||||
ck_tile::index_t left_size = -1;
|
||||
ck_tile::index_t right_size = 0;
|
||||
if(window_size > 0)
|
||||
{
|
||||
left_size = window_size / 2;
|
||||
right_size = window_size - 1 - left_size;
|
||||
}
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
left_size, right_size, y_total, x_total, t == "xt");
|
||||
|
||||
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = left_size;
|
||||
tmp.right = right_size;
|
||||
}
|
||||
else
|
||||
{
|
||||
auto found_1 = v.find(",");
|
||||
if(found_1 == std::string::npos)
|
||||
{
|
||||
printf("not supported value %s, %s\n", v.c_str(), str.c_str());
|
||||
assert(0);
|
||||
}
|
||||
tmp.type = mask_enum::window_generic;
|
||||
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
|
||||
ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
|
||||
// TODO: some validation
|
||||
if(t == "t")
|
||||
{
|
||||
tmp.type = mask_enum::mask_top_left;
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
v0, v1, y_total, x_total, true);
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = v0;
|
||||
tmp.right = v1;
|
||||
}
|
||||
else if(t == "b")
|
||||
{
|
||||
tmp.type = mask_enum::mask_bottom_right;
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
v0, v1, y_total, x_total, false);
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = v0;
|
||||
tmp.right = v1;
|
||||
}
|
||||
else if(t == "g")
|
||||
{
|
||||
tmp.y = v0;
|
||||
tmp.x = v1;
|
||||
tmp.left = v0; // TODO: don't use this?
|
||||
tmp.right = v1;
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("not supported type %s, %s\n", t.c_str(), str.c_str());
|
||||
assert(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
auto set_causal_top_left = [&]() {
|
||||
tmp.type = mask_enum::mask_top_left;
|
||||
tmp.y = seqlen_q;
|
||||
tmp.x = 1;
|
||||
tmp.left = -1;
|
||||
tmp.right = 0;
|
||||
};
|
||||
auto set_causal_bottom_right = [&]() {
|
||||
tmp.type = mask_enum::mask_bottom_right;
|
||||
tmp.y = seqlen_q;
|
||||
tmp.x = seqlen_k - seqlen_q + 1;
|
||||
tmp.left = -1;
|
||||
tmp.right = 0;
|
||||
};
|
||||
if(str == "t")
|
||||
set_causal_top_left();
|
||||
else if(str == "b")
|
||||
set_causal_bottom_right();
|
||||
else
|
||||
{
|
||||
tmp.type = static_cast<mask_enum>(atoi(str.c_str()));
|
||||
if(tmp.type == mask_enum::mask_top_left)
|
||||
{
|
||||
set_causal_top_left();
|
||||
}
|
||||
else if(tmp.type == mask_enum::mask_bottom_right)
|
||||
{
|
||||
set_causal_bottom_right();
|
||||
}
|
||||
}
|
||||
}
|
||||
return tmp;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
|
||||
{
|
||||
mi.serialize(os);
|
||||
return os;
|
||||
}
|
||||
};
|
||||
@ -22,6 +22,7 @@ fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
|
||||
dtype,
|
||||
false, // is_group_mode
|
||||
true, // is_v_rowmajor
|
||||
false, // has_logits_soft_cap
|
||||
mask.type,
|
||||
enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias,
|
||||
has_lse,
|
||||
@ -85,6 +86,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
|
||||
ck_tile::index_t stride_attn_bias = 0;
|
||||
ck_tile::index_t batch_stride_bias = 0;
|
||||
ck_tile::index_t nhead_stride_bias = 0;
|
||||
|
||||
if (attn_bias_.has_value()) {
|
||||
auto a_b = attn_bias_.value();
|
||||
CHECK_DEVICE(a_b);
|
||||
@ -94,7 +96,6 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
|
||||
nhead_stride_bias = a_b.stride(1);
|
||||
batch_stride_bias = a_b.stride(0);
|
||||
}
|
||||
|
||||
return fmha_fwd_args{q.data_ptr(),
|
||||
k.data_ptr(),
|
||||
v.data_ptr(),
|
||||
@ -116,6 +117,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
|
||||
softmax_scale, // scale_s
|
||||
1, // scale_p
|
||||
1, // scale_o
|
||||
0.0f, // logits_soft_cap
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
@ -139,6 +141,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
-1, // min_seqlen_q
|
||||
p_dropout,
|
||||
has_dropout_randval,
|
||||
drop_seed_offset};
|
||||
|
||||
@ -20,6 +20,7 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
|
||||
dtype,
|
||||
true, // is_group_mode
|
||||
true, // is_v_rowmajor
|
||||
false, // has_logits_soft_cap
|
||||
mask.type,
|
||||
enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias,
|
||||
has_lse,
|
||||
@ -117,6 +118,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
|
||||
softmax_scale, // scale_s
|
||||
1, // scale_p
|
||||
1, // scale_o
|
||||
0.0f, // logits_soft_cap
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
@ -140,6 +142,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
-1, // min_seqlen_q
|
||||
p_dropout,
|
||||
has_dropout_randval,
|
||||
drop_seed_offset};
|
||||
|
||||
@ -1,84 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/host/host_tensor.hpp>
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
#include <iterator>
|
||||
#include <optional>
|
||||
#include <random>
|
||||
#include <tuple>
|
||||
|
||||
// keep sync with RotaryEmbeddingEnum
|
||||
enum class rope_enum
|
||||
{
|
||||
none = 0,
|
||||
interleaved = 1,
|
||||
half_rotated = 2,
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
|
||||
generate_rotary_cos_sin(ck_tile::index_t seqlen,
|
||||
ck_tile::index_t rotary_dim,
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
{
|
||||
// return dummy tensors if we won't apply RoPE at all
|
||||
if(rotary_dim <= 0)
|
||||
{
|
||||
ck_tile::HostTensor<DataType> dummy({1, 1});
|
||||
return std::make_tuple(dummy, dummy);
|
||||
}
|
||||
|
||||
std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
|
||||
std::uniform_real_distribution<float> generator(0.0f, 1.0f);
|
||||
|
||||
const ck_tile::index_t num_rows = seqlen * 2;
|
||||
const ck_tile::index_t num_cols = rotary_dim / 2;
|
||||
|
||||
using std::begin, std::end;
|
||||
|
||||
ck_tile::HostTensor<float> angle({num_rows, num_cols});
|
||||
std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; });
|
||||
|
||||
ck_tile::HostTensor<DataType> cos({num_rows, num_cols});
|
||||
std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) {
|
||||
return ck_tile::type_convert<DataType>(std::cos(origin_value));
|
||||
});
|
||||
|
||||
ck_tile::HostTensor<DataType> sin({num_rows, num_cols});
|
||||
std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) {
|
||||
return ck_tile::type_convert<DataType>(std::sin(origin_value));
|
||||
});
|
||||
|
||||
return std::make_tuple(cos, sin);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
|
||||
slice_rotary_cos_sin(const ck_tile::HostTensor<DataType>& cos,
|
||||
const ck_tile::HostTensor<DataType>& sin,
|
||||
ck_tile::index_t seqlen_offset,
|
||||
ck_tile::index_t seqlen)
|
||||
{
|
||||
assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2);
|
||||
assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1));
|
||||
|
||||
assert(static_cast<std::size_t>(seqlen_offset + seqlen) <= cos.get_length(0));
|
||||
|
||||
const ck_tile::index_t num_rows = seqlen;
|
||||
const ck_tile::index_t num_cols = cos.get_length(1);
|
||||
|
||||
ck_tile::HostTensor<DataType> cos_pt({num_rows, num_cols});
|
||||
cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); });
|
||||
|
||||
ck_tile::HostTensor<DataType> sin_pt({num_rows, num_cols});
|
||||
sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); });
|
||||
|
||||
return std::make_tuple(cos_pt, sin_pt);
|
||||
}
|
||||
@ -5,6 +5,12 @@ import os
|
||||
import sys
|
||||
|
||||
|
||||
# Run only this selected group of models, leave this empty to run everything
|
||||
TORCHBENCH_ONLY_MODELS = [
|
||||
m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip()
|
||||
]
|
||||
|
||||
|
||||
# Note - hf and timm have their own version of this, torchbench does not
|
||||
# TODO(voz): Someday, consolidate all the files into one runner instead of a shim like this...
|
||||
def model_names(filename: str) -> set[str]:
|
||||
@ -17,6 +23,8 @@ def model_names(filename: str) -> set[str]:
|
||||
if len(line_parts) == 1:
|
||||
line_parts = line.split(",")
|
||||
model_name = line_parts[0]
|
||||
if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS:
|
||||
continue
|
||||
names.add(model_name)
|
||||
return names
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ import copy
|
||||
import csv
|
||||
import dataclasses
|
||||
import functools
|
||||
import gc
|
||||
import importlib
|
||||
import itertools
|
||||
import json
|
||||
@ -2387,6 +2388,7 @@ class BenchmarkRunner:
|
||||
)
|
||||
|
||||
def warmup(fn, model, example_inputs, mode, niters=10):
|
||||
gc.collect()
|
||||
peak_mem = 0
|
||||
start_stats = get_dynamo_stats()
|
||||
try:
|
||||
@ -2548,6 +2550,7 @@ class BenchmarkRunner:
|
||||
return experiment(*self.maybe_cast(model, example_inputs))
|
||||
|
||||
def warmup(fn, model, example_inputs, mode, niters=5):
|
||||
gc.collect()
|
||||
peak_mem = 0
|
||||
start_stats = get_dynamo_stats()
|
||||
try:
|
||||
|
||||
@ -106,6 +106,11 @@ finally:
|
||||
# on A100 GPUs - 40 GB.
|
||||
BATCH_SIZE_KNOWN_MODELS = {}
|
||||
|
||||
# Run only this selected group of models, leave this empty to run everything
|
||||
TORCHBENCH_ONLY_MODELS = [
|
||||
m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip()
|
||||
]
|
||||
|
||||
|
||||
# TODO(sdym): use batch-size-file parameter of common.main, like torchbench.py
|
||||
# Get the list of models and their batch sizes
|
||||
@ -116,6 +121,8 @@ with open(MODELS_FILENAME) as fh:
|
||||
lines = [line.rstrip() for line in lines]
|
||||
for line in lines:
|
||||
model_name, batch_size = line.split(",")
|
||||
if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS:
|
||||
continue
|
||||
batch_size = int(batch_size)
|
||||
BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size
|
||||
assert len(BATCH_SIZE_KNOWN_MODELS)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user