[ROCM] Support Multi-GPU offline tuning in TunableOp (#139673)

This PR enhances offline tuning to support multiple GPUs.

High-level description of algorithm:
- Duplicate GEMMs are first eliminated
- GEMMs are distributed across multiple GPUs for tuning
- Results are gathered into a file with `_full` in the filename

Also adding support for GemmAndBias and ScaledGemm
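The high-level steps above can be modeled in plain Python (an illustrative sketch only; the name `dedup_and_distribute` is hypothetical and not part of this PR — the actual implementation lives in `torch.cuda.tunable.mgpu_tune_gemm_in_file`):

```python
def dedup_and_distribute(untuned_files, num_gpus):
    """Illustrative model: dedupe GEMM lines, then assign them round-robin to GPUs."""
    unique_gemms = set()  # a set eliminates duplicate GEMM entries
    for lines in untuned_files:
        for line in lines:
            if line.startswith(("Gemm", "ScaledGemm")):
                unique_gemms.add(line)
    # Distribute the unique GEMMs to GPUs in round-robin order
    assignments = {gpu: [] for gpu in range(num_gpus)}
    for i, gemm in enumerate(sorted(unique_gemms)):
        assignments[i % num_gpus].append(gemm)
    return assignments

# Two untuned files share one duplicate GEMM; 3 unique GEMMs over 2 GPUs
files = [
    ["GemmTunableOp_float_NN,nn_32_32_32", "GemmTunableOp_float_NT,nt_64_64_64"],
    ["GemmTunableOp_float_NN,nn_32_32_32", "GemmAndBiasTunableOp_float_TN,tn_16_16_16"],
]
assignments = dedup_and_distribute(files, 2)
assert sum(len(g) for g in assignments.values()) == 3  # duplicates removed
```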

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139673
Approved by: https://github.com/jeffdaily, https://github.com/hongxiayang
Nichols A. Romero
2024-11-26 19:07:41 +00:00
committed by PyTorch MergeBot
parent 5b4c864672
commit a99332eb25
5 changed files with 381 additions and 56 deletions


@ -952,6 +952,9 @@ test_distributed() {
python test/run_test.py --cpp --verbose -i cpp/HashStoreTest
python test/run_test.py --cpp --verbose -i cpp/TCPStoreTest
echo "Testing multi-GPU linalg tests"
python test/run_test.py -i test_linalg.py -k test_matmul_offline_mgpu_tunable --verbose
if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then
MPIEXEC=$(command -v mpiexec)
if [[ -n "$MPIEXEC" ]]; then


@ -80,18 +80,23 @@ fastest available implementation across both rocblas and hipblaslt.
## Offline Tuning
### Motivation
There are a couple of use cases for offline tuning.
One use case is a workload with high memory utilization, where one might run out of memory with regular tuning.
Another use case is a workload that is compute-intensive, where it is more resource-efficient to collect
the GEMMs for the workload once and then tune repeatedly with different tuning parameters or libraries.
### Workflow
There are two steps:
1) Set the environment variables to collect the untuned GEMMs; this will generate `tunableop_untuned0.csv`, like:
```
PYTORCH_TUNABLEOP_ENABLED=1
PYTORCH_TUNABLEOP_TUNING=0
PYTORCH_TUNABLEOP_RECORD_UNTUNED=1
...
```
2) Run a Python script that reads the `tunableop_untuned0.csv` and generates the `tunableop_results0.csv`, like:
```
import torch.cuda.tunable as tunable
import os
@ -99,9 +104,29 @@ import os
os.putenv('PYTORCH_TUNABLEOP_ENABLED', '1')
os.putenv('PYTORCH_TUNABLEOP_TUNING', '1')
os.putenv('PYTORCH_TUNABLEOP_RECORD_UNTUNED', '0')
tunable.tune_gemm_in_file("tunableop_untuned0.csv")
```
It is also possible to take multiple untuned files and distribute the GEMMs for tuning to multiple GPUs
within a single node. In the first step, the GEMMs are gathered and duplicate GEMMs are eliminated.
Next, the GEMMs are distributed to the different GPUs for tuning. After all GEMMs are tuned, the results from
all the GPUs are gathered into a single file whose base filename has `_full0` appended to it
(e.g. `tunableop_results_full0.csv`). Finally, this new file, containing the gathered results, is
duplicated N times, once for each GPU, as a convenience for users who will run the workload with the tuned
configuration on N GPUs.
```
if __name__ == "__main__":
num_gpus = 8 # number of GPUs that will be used during the tuning process
tunable.mgpu_tune_gemm_in_file("tunableop_untuned?.csv", num_gpus)
```
Note that the usage of the `mgpu_tune_gemm_in_file` API is different from its single-GPU counterpart
(`tune_gemm_in_file`). The body of the Python script that calls the API must be wrapped in an
`if __name__ == "__main__":` guard as shown, due to the use of the concurrent futures module. The argument to
`mgpu_tune_gemm_in_file` must contain a wildcard expression (`?` or `*`) to generate the list of untuned files
containing the GEMMs to be processed. `num_gpus` must be between 1 and the total number of GPUs available.
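As a quick illustration of the wildcard requirement (a standalone sketch using Python's `fnmatch`, which implements the same `?`/`*` glob semantics as the file matching used here):

```python
import fnmatch

# "?" matches exactly one character, e.g. the GPU ordinal in the filename
names = ["tunableop_untuned0.csv", "tunableop_untuned1.csv", "tunableop_results0.csv"]
matched = [n for n in names if fnmatch.fnmatch(n, "tunableop_untuned?.csv")]
print(matched)  # ['tunableop_untuned0.csv', 'tunableop_untuned1.csv']
```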
## Tuning Context
The behavior of TunableOp is currently manipulated through environment variables, the C++ interface of
at::cuda::tunable::getTuningContext(), or the `torch.cuda.tunable` python interfaces. The environment variables take
@ -153,6 +178,7 @@ All python APIs exist in the `torch.cuda.tunable` module.
| write_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| tune_gemm_in_file(filename: str) -> None | read an untuned file and tune GEMMs in it. |
| mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None | read one or more untuned files and tune all unique GEMMs on one or more GPUs. |
### C++ Interface
Example:


@ -33,3 +33,4 @@ API Reference
.. autofunction:: write_file
.. autofunction:: read_file
.. autofunction:: tune_gemm_in_file
.. autofunction:: mgpu_tune_gemm_in_file


@ -32,7 +32,7 @@ from torch.testing._internal.common_dtype import (
floating_and_complex_types_and, floating_types_and, complex_types,
)
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \
_get_torch_cuda_version, CDNA2OrLater, TEST_MULTIGPU
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.distributions.binomial import Binomial
@ -61,11 +61,9 @@ def set_tunableop_defaults():
return
# disable TunableOp and restore to default values
ordinal = torch.cuda.current_device()
torch.cuda.tunable.enable(False)
torch.cuda.tunable.record_untuned_enable(False)
torch.cuda.tunable.tuning_enable(True)
torch.cuda.tunable.set_max_tuning_duration(30)
torch.cuda.tunable.set_max_tuning_iterations(100)
@ -4673,6 +4671,103 @@ class TestLinalg(TestCase):
assert torch.cuda.tunable.is_enabled() is False, "TunableOp should be off after resetting"
assert torch.cuda.tunable.get_max_tuning_iterations() == 100
@unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
@onlyCUDA
@skipCUDAIfNotRocm
@dtypes(torch.float)
def test_matmul_offline_mgpu_tunableop(self, device, dtype):
# Offline tuning with multiple GPUs.
# Case where you record GEMMs on one GPU, but then tune
# on multiple GPUs
import tempfile
import os
tmp_dir = tempfile.mkdtemp()
# Use all available GPUs for this test
total_gpus = torch.cuda.device_count()
# Test in try-finally block to avoid leaking state
# if test is interrupted.
try:
set_tunableop_defaults()
os.environ["PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE"] = "0"
# Pointing to temp files. The test cannot remove them on Windows because
# they are in use and locked
os.putenv("PYTORCH_TUNABLEOP_UNTUNED_FILENAME", os.path.join(tmp_dir, "tunableop_untuned.csv"))
os.putenv("PYTORCH_TUNABLEOP_FILENAME", os.path.join(tmp_dir, "tunableop_results.csv"))
# turn on untuned GEMM recording and turn off tuning
torch.cuda.tunable.enable(True)
torch.cuda.tunable.tuning_enable(False)
torch.cuda.tunable.record_untuned_enable(True)
# Choose matrix sizes that have not been used before
m = n = k = 23
# Create at least one GEMM per GPU, so when the GEMMs
# are distributed to the GPUs there is at least one
# GEMM per GPU.
for g in range(1, total_gpus + 1):
A = torch.rand(m * g, k * g, device=device, dtype=dtype)
B = torch.rand(k * g, n * g, device=device, dtype=dtype)
C = torch.matmul(A, B)
# check the untuned file was written
ordinal = torch.cuda.current_device()
untuned_filename = os.path.join(tmp_dir, f"tunableop_untuned{ordinal}.csv")
self.assertTrue(os.path.exists(untuned_filename))
# turn off untuned GEMM recording and turn on tuning
# We need to set the environment variables here instead of using
# the Python API, so that the child processes created will inherit
these settings
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
os.environ["PYTORCH_TUNABLEOP_TUNING"] = "1"
os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"] = "1"
torch.cuda.tunable.mgpu_tune_gemm_in_file(untuned_filename, total_gpus)
assert torch.cuda.tunable.write_file()
# check the results files were written, one per GPU
for i in range(total_gpus):
result_filename = os.path.join(tmp_dir, f"tunableop_results{i}.csv")
self.assertTrue(os.path.exists(result_filename))
# Check the full results files were written, one per GPU
# for i in range(total_gpus):
# result_full_filename = os.path.join(tmp_dir, f"tunableop_results_full{i}.csv")
# self.assertTrue(os.path.exists(result_full_filename))
finally:
# disables TunableOp
torch.cuda.tunable.enable(False)
# undo all the environment variables set
try:
del os.environ["PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE"]
del os.environ["PYTORCH_TUNABLEOP_UNTUNED_FILENAME"]
del os.environ["PYTORCH_TUNABLEOP_FILENAME"]
del os.environ["PYTORCH_TUNABLEOP_ENABLED"]
del os.environ["PYTORCH_TUNABLEOP_TUNING"]
del os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"]
except KeyError:
pass
# clean up, remove any files that were generated
try:
untuned_filename = os.path.join(tmp_dir, "tunableop_untuned0.csv")
os.remove(untuned_filename)
for i in range(total_gpus):
result_filename = os.path.join(tmp_dir, f"tunableop_results{i}.csv")
result_full_filename = os.path.join(tmp_dir, f"tunableop_results_full{i}.csv")
os.remove(result_filename)
os.remove(result_full_filename)
except FileNotFoundError:
pass
@onlyCUDA
@skipCUDAIfNotRocm
@dtypes(torch.float)


@ -49,7 +49,7 @@ like so::
GemmTunableOp_float_NT,nt_25088_4096_64,1219,1.262
GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033
Note the "Validator" lines. If you change a library version, or ROCm version, or
PyTorch version, TunableOp will detect this and reject the tunings file because
the prior tunings are likely affected by other software changes.
@ -112,6 +112,11 @@ environment variables take precedence over any setting you manipulate using the
C++ or Python APIs.
"""
import concurrent.futures
import glob
import multiprocessing as mp
import os
import shutil
import warnings
from typing import Optional, Tuple
@ -137,6 +142,7 @@ __all__ = [
"write_file",
"read_file",
"tune_gemm_in_file",
"mgpu_tune_gemm_in_file",
]
@ -265,56 +271,250 @@ def tune_gemm_in_file(filename: str) -> None:
assert is_enabled()
assert tuning_is_enabled()
deviceid = torch.cuda.current_device()
with open(filename) as file:
for line in file:
if line.startswith(("Gemm", "ScaledGemm")):
_process_single_offline_gemm(line, deviceid)
def _gather_unique_untuned_gemm_from_files(filename_pattern: str) -> set[str]:
r"""Process multiple untuned results file and return a set with duplicates removed."""
unique_gemm_entries = set() # set will avoid duplicates
for file_path in glob.glob(filename_pattern):
with open(file_path) as file:
for line in file:
if line.startswith(("Gemm", "ScaledGemm")):
unique_gemm_entries.add(line)
return unique_gemm_entries
def _gather_tunableop_results() -> None:
r"""Gather results from multiple tunableop results file and create a single file."""
gemm_lines = set()
validator_lines = []
# Need to allow for the possibility that results filename was
# set with the Python API instead of with environment variable.
# Also possible that results filename was not set at all.
# There are several test cases to check, but ultimately we
# need a glob-able expression
results_filename = get_filename() # Note empty string could be returned here
if (
results_filename is not None and results_filename != ""
): # Case where the Python API was used to set the filename
dot_pos = results_filename.find(".")
if dot_pos != -1 and dot_pos > 0:
# Replace the character just to the left of the dot
filename_pattern = (
results_filename[: dot_pos - 1] + "?" + results_filename[dot_pos:]
)
else:
filename_pattern = "" # Needed to make linter happy
else: # Case where the environment variable was used to set the filename.
results_filename_env = os.getenv("PYTORCH_TUNABLEOP_FILENAME")
if results_filename_env is None or results_filename_env == "":
filename_pattern = "tunableop_results?.csv"
elif "%d" in results_filename_env:
filename_pattern = results_filename_env.replace("%d", "?")
else:
filename_pattern = results_filename_env.replace(".", "?.")
assert "?" in filename_pattern
FirstFile = False
matching_files = glob.glob(filename_pattern)
num_matching_files = len(matching_files)
for file_path in matching_files:
with open(file_path) as file:
for line in file:
if line.startswith("Validator"):
if not FirstFile:
# Only read Validator from first file
validator_lines.append(line)
else:
gemm_lines.add(line)
FirstFile = True
output_file = filename_pattern.replace("?", "_full0")
with open(output_file, "w") as out_file:
for line in validator_lines:
out_file.write(line)
for line in gemm_lines:
out_file.write(line)
# Create num_matching_copies of the results file
for i in range(1, num_matching_files):
duplicate_file = output_file.replace("0", str(i))
shutil.copy(output_file, duplicate_file)
def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
r"""Process a single untuned GEMM."""
deviceid = "cuda:" + str(gpu_id)
dtype_dict = {
"float": torch.float32,
"double": torch.float64,
"BFloat16": torch.bfloat16,
"Half": torch.half,
"c10::complex<double>": torch.complex128,
"c10::complex<float>": torch.complex64,
"Float8_e4m3fn": torch.float8_e4m3fn,
"Float8_e5m2": torch.float8_e5m2,
"Float8_e4m3fnuz": torch.float8_e4m3fnuz,
"Float8_e5m2fnuz": torch.float8_e5m2fnuz,
}
untuned_gemm = untuned_gemm_line.strip().split(",")[:]
underscore_count = untuned_gemm[0].count("_")
# Initialize dtype to make linter happy
dtype = None
dtypeA = None
dtypeB = None
dtypeC = None
if underscore_count == 2:
[op_sig, data_type, layout] = untuned_gemm[0].split("_")
transA = layout[0] == "T"
transB = layout[1] == "T"
dtype = dtype_dict.get(data_type)
else: # ScaledGEMM
untuned_gemm_temp = untuned_gemm[0].split("_")
op_sig = untuned_gemm_temp[0]
data_typeA = untuned_gemm_temp[1] + "_" + untuned_gemm_temp[2]
data_typeB = untuned_gemm_temp[3] + "_" + untuned_gemm_temp[4]
data_typeC = untuned_gemm_temp[5] + "_" + untuned_gemm_temp[6]
transA = untuned_gemm_temp[7][0] == "T"
transB = untuned_gemm_temp[7][1] == "T"
dtypeA = dtype_dict.get(data_typeA)
dtypeB = dtype_dict.get(data_typeB)
dtypeC = dtype_dict.get(data_typeC)
[n, m, k] = [int(g) for g in untuned_gemm[1].split("_")[1:4]]
if op_sig == "GemmTunableOp":
matA = (
torch.rand(k, m, dtype=dtype, device=deviceid).t()
if transB
else torch.rand(m, k, dtype=dtype, device=deviceid)
)
matB = (
torch.rand(n, k, dtype=dtype, device=deviceid).t()
if transA
else torch.rand(k, n, dtype=dtype, device=deviceid)
)
torch.mm(matA, matB)
elif op_sig == "GemmStridedBatchedTunableOp":
[b] = [int(g) for g in untuned_gemm[1].split("_")[5:6]]
matA = (
torch.rand(b, k, m, dtype=dtype, device=deviceid)
if transB
else torch.rand(b, m, k, dtype=dtype, device=deviceid)
)
matB = (
torch.rand(b, n, k, dtype=dtype, device=deviceid)
if transA
else torch.rand(b, k, n, dtype=dtype, device=deviceid)
)
matA = matA.transpose(1, 2) if transB else matA
matB = matB.transpose(1, 2) if transA else matB
torch.bmm(matA, matB)
elif op_sig == "ScaledGemmTunableOp":
fillA = 0.25
fillB = 0.75
scaleA = torch.tensor(0.8, device=deviceid)
scaleB = torch.tensor(0.9, device=deviceid)
matA = (
torch.full((k, m), fillA, dtype=dtypeA, device=deviceid).t()
if transB
else torch.full((m, k), fillA, dtype=dtypeA, device=deviceid)
)
matB = (
torch.full((n, k), fillB, dtype=dtypeB, device=deviceid).t()
if transA
else torch.full((k, n), fillB, dtype=dtypeB, device=deviceid)
)
torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=dtypeC)
elif op_sig == "GemmAndBiasTunableOp":
# y = x*A^T + b
assert transA != transB
X = (
torch.rand(k, m, dtype=dtype, device=deviceid).t()
if transB
else torch.rand(m, k, dtype=dtype, device=deviceid)
)
matA = (
torch.rand(n, k, dtype=dtype, device=deviceid)
if transA
else torch.rand(k, n, dtype=dtype, device=deviceid).t()
)
bias = (
torch.rand(n, dtype=dtype, device=deviceid)
if transA
else torch.rand(m, dtype=dtype, device=deviceid)
)
torch.nn.functional.linear(X, matA, bias)
else:
warnings.warn(f"error: unknown op {op_sig}")
def _check_tuning_assertions() -> None:
r"""Helper function for multi-GPU tuning case. Need to check that TunableOp feature
is enabled and that tuning is enabled.
"""
assert is_enabled()
assert tuning_is_enabled()
def mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None:
r"""Process one or more files and distribute work over one or more GPUs."""
unique_gemm_entries = _gather_unique_untuned_gemm_from_files(filename_pattern)
total_gpus = torch.cuda.device_count()
assert 1 <= num_gpus <= total_gpus
mp_context = mp.get_context("spawn")
checks = [] # empty list to hold check futures
futures = [] # empty list to hold futures
# GEMMs are assigned to GPUs in a round-robin manner
h = 0
with concurrent.futures.ProcessPoolExecutor(
max_workers=num_gpus, mp_context=mp_context
) as executor:
# The workers are a separate process. TunableOp will be
# enabled in the child processes if the environment variable
# is set. However, if we enable TunableOp via the API
# the workers do not inherit this state. As a precaution,
we need to check that the TunableOp feature and tuning are
# enabled in the pool of processes.
for g in range(num_gpus):
check = executor.submit(_check_tuning_assertions)
checks.append(check)
for check in concurrent.futures.as_completed(checks):
check.result()
for line in unique_gemm_entries:
future = executor.submit(_process_single_offline_gemm, line, h)
futures.append(future)
h = (h + 1) % num_gpus
for future in concurrent.futures.as_completed(futures):
future.result()
torch.cuda.synchronize()
_gather_tunableop_results()