[ROCm] Pyt 2.0 rocm staging (#94660)

Add triton support for ROCm builds of PyTorch.

* Enables inductor and dynamo when ROCm is detected
* Adds support for pytorch-triton-mlir backend
* Adds check_rocm support for verify_dynamo.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94660
Approved by: https://github.com/malfet
This commit is contained in:
Douglas Lehr
2023-02-15 06:15:18 +00:00
committed by PyTorch MergeBot
parent 71ec2617d2
commit 77d1135566
3 changed files with 68 additions and 1 deletions

View File

@ -0,0 +1 @@
pytorch-triton-rocm>=2.0.0.dev

View File

@ -939,6 +939,13 @@ def configure_extension_build():
# These extensions are built by cmake and copied manually in build_extensions() # These extensions are built by cmake and copied manually in build_extensions()
# inside the build_ext implementation # inside the build_ext implementation
if cmake_cache_vars['USE_ROCM']:
triton_req_file = os.path.join(cwd, ".github", "requirements", "triton-requirements-rocm.txt")
if os.path.exists(triton_req_file):
with open(triton_req_file) as f:
triton_req = f.read().strip()
extra_install_requires.append(triton_req)
if cmake_cache_vars['BUILD_CAFFE2']: if cmake_cache_vars['BUILD_CAFFE2']:
extensions.append( extensions.append(
Extension( Extension(

View File

@ -8,6 +8,7 @@ import warnings
from pkg_resources import packaging from pkg_resources import packaging
MIN_CUDA_VERSION = packaging.version.parse("11.6") MIN_CUDA_VERSION = packaging.version.parse("11.6")
MIN_ROCM_VERSION = packaging.version.parse("5.4")
MIN_PYTHON_VERSION = (3, 8) MIN_PYTHON_VERSION = (3, 8)
@ -52,6 +53,31 @@ def get_cuda_version():
return packaging.version.parse(cuda_str_version) return packaging.version.parse(cuda_str_version)
def get_rocm_version():
    """Return the HIP version (``major.minor``) reported by ``hipcc``.

    The ``hipcc`` binary is looked up under ``<ROCm home>/bin``, where the
    ROCm home directory is whatever ``cpp_extension._find_rocm_home()`` returns.

    Returns:
        The parsed version taken from the ``HIP version: X.Y`` line of
        ``hipcc --version``.

    Raises:
        VerifyDynamoError: if no ROCm home directory is found, or if the
            ``hipcc --version`` output contains no ``HIP version: X.Y`` text.
    """
    from torch.utils import cpp_extension

    rocm_home = cpp_extension._find_rocm_home()
    if not rocm_home:
        raise VerifyDynamoError(
            "ROCM was not found on the system, please set ROCM_HOME environment variable"
        )

    # Ask the compiler driver itself which HIP release it belongs to.
    hipcc_path = os.path.join(rocm_home, "bin", "hipcc")
    raw_output = subprocess.check_output([hipcc_path, "--version"])
    version_text = raw_output.strip().decode(*cpp_extension.SUBPROCESS_DECODE_ARGS)

    match = re.search(r"HIP version: (\d+[.]\d+)", version_text)
    if match is None:
        raise VerifyDynamoError("HIP version not found in `hipcc --version` output")

    # Only the captured major.minor pair is parsed.
    return packaging.version.parse(match.group(1))
def check_cuda(): def check_cuda():
import torch import torch
@ -81,7 +107,38 @@ def check_cuda():
f"- minimum requirement: {MIN_CUDA_VERSION}" f"- minimum requirement: {MIN_CUDA_VERSION}"
) )
return cuda_ver return cuda_ver if torch.version.hip is None else "None"
def check_rocm():
    """Check the ROCm version `torch` was built with against the environment's.

    Returns:
        ``None`` when ``torch.cuda.is_available()`` is false or
        ``torch.version.hip`` is ``None``; otherwise the ROCm version
        obtained from ``get_rocm_version()``.

    Emits ``warnings.warn`` messages (it does not raise for them) when:
        * the `torch` ROCm version and the environment version differ,
        * the `torch` ROCm version is below ``MIN_ROCM_VERSION``,
        * the environment ROCm version is below ``MIN_ROCM_VERSION``.
    """
    import torch

    if not torch.cuda.is_available() or torch.version.hip is None:
        return None

    # Extract the main (major.minor) ROCm version from the full version string.
    torch_rocm_ver = packaging.version.parse(
        ".".join(torch.version.hip.split(".")[:2])
    )

    # Check whether torch's ROCm version matches the system ROCm version.
    rocm_ver = get_rocm_version()
    if rocm_ver != torch_rocm_ver:
        warnings.warn(
            f"ROCm version mismatch, `torch` version: {torch_rocm_ver}, env version: {rocm_ver}"
        )
    if torch_rocm_ver < MIN_ROCM_VERSION:
        warnings.warn(
            f"(`torch`) ROCm version not supported: {torch_rocm_ver} "
            f"- minimum requirement: {MIN_ROCM_VERSION}"
        )
    if rocm_ver < MIN_ROCM_VERSION:
        warnings.warn(
            f"(env) ROCm version not supported: {rocm_ver} "
            f"- minimum requirement: {MIN_ROCM_VERSION}"
        )

    # torch.version.hip was already verified to be set by the guard above, so
    # no further check is needed before returning the environment version.
    return rocm_ver
def check_dynamo(backend, device, err_msg): def check_dynamo(backend, device, err_msg):
@ -150,10 +207,12 @@ def main():
python_ver = check_python() python_ver = check_python()
torch_ver = check_torch() torch_ver = check_torch()
cuda_ver = check_cuda() cuda_ver = check_cuda()
rocm_ver = check_rocm()
print( print(
f"Python version: {python_ver.major}.{python_ver.minor}.{python_ver.micro}\n" f"Python version: {python_ver.major}.{python_ver.minor}.{python_ver.micro}\n"
f"`torch` version: {torch_ver}\n" f"`torch` version: {torch_ver}\n"
f"CUDA version: {cuda_ver}\n" f"CUDA version: {cuda_ver}\n"
f"ROCM version: {rocm_ver}\n"
) )
for args in _SANITY_CHECK_ARGS: for args in _SANITY_CHECK_ARGS:
check_dynamo(*args) check_dynamo(*args)