mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-15 06:48:48 +08:00
Update build_triton_whl.py for new devices
This commit is contained in:
9
.github/scripts/build_triton_wheel.py
vendored
9
.github/scripts/build_triton_wheel.py
vendored
@ -136,7 +136,7 @@ def main() -> None:
|
||||
parser = ArgumentParser("Build Triton binaries")
|
||||
parser.add_argument("--release", action="store_true")
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda", choices=["cuda", "rocm", "xpu", "aarch64"]
|
||||
"--device", type=str, default="cuda", choices=["cuda", "rocm-n", "rocm-n-1", "xpu", "aarch64"]
|
||||
)
|
||||
parser.add_argument("--py-version", type=str)
|
||||
parser.add_argument("--commit-hash", type=str)
|
||||
@ -148,8 +148,13 @@ def main() -> None:
|
||||
if args.triton_version:
|
||||
triton_version = args.triton_version
|
||||
|
||||
# Normalize device name for rocm-n rocm-n-1 builds
|
||||
device = args.device
|
||||
if args.device.startswith("rocm"):
|
||||
device = "rocm"
|
||||
|
||||
build_triton(
|
||||
device=args.device,
|
||||
device=device,
|
||||
commit_hash=(
|
||||
args.commit_hash if args.commit_hash else read_triton_pin(args.device)
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user