[ROCM] Support Multi-GPU offline tuning in TunableOp (#139673)
This PR enhances offline tuning to support multiple GPUs. High-level description of the algorithm:
- Duplicate GEMMs are first eliminated
- GEMMs are distributed to multiple GPUs for tuning
- Results are gathered into a file with `_full` in the filename

Also adds support for GemmAndBias and ScaledGemm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139673
Approved by: https://github.com/jeffdaily, https://github.com/hongxiayang
committed by PyTorch MergeBot
parent 5b4c864672
commit a99332eb25
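As a rough sketch of the flow summarized in the PR description above (hypothetical helper names and file layout, not the actual implementation; the real code lives in the `torch/cuda/tunable.py` hunk below):

```
import glob


def dedupe_untuned_gemms(pattern: str) -> set[str]:
    # Step 1: gather GEMM signatures from every per-GPU untuned file,
    # using a set so duplicate GEMMs are tuned only once.
    unique = set()
    for path in glob.glob(pattern):
        with open(path) as f:
            unique.update(line for line in f if line.startswith(("Gemm", "ScaledGemm")))
    return unique


def assign_round_robin(gemms: set[str], num_gpus: int) -> dict[int, list[str]]:
    # Step 2: distribute the unique GEMMs across GPUs in round-robin order.
    buckets: dict[int, list[str]] = {g: [] for g in range(num_gpus)}
    for i, line in enumerate(sorted(gemms)):
        buckets[i % num_gpus].append(line)
    return buckets


def gather_results(result_pattern: str, output_path: str) -> None:
    # Step 3: merge the per-GPU result files into a single "_full" file
    # (the real implementation also duplicates it once per GPU).
    merged = set()
    for path in glob.glob(result_pattern):
        with open(path) as f:
            merged.update(f)
    with open(output_path, "w") as out:
        out.writelines(sorted(merged))
```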
@@ -952,6 +952,9 @@ test_distributed() {
  python test/run_test.py --cpp --verbose -i cpp/HashStoreTest
  python test/run_test.py --cpp --verbose -i cpp/TCPStoreTest

  echo "Testing multi-GPU linalg tests"
  python test/run_test.py -i test_linalg.py -k test_matmul_offline_mgpu_tunable --verbose

  if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then
    MPIEXEC=$(command -v mpiexec)
    if [[ -n "$MPIEXEC" ]]; then
@@ -80,18 +80,23 @@ fastest available implementation across both rocblas and hipblaslt.

## Offline Tuning

### Motivation
Basically it is used for workloads with high memory utilization where one might run out of memory with regular tuning.
There are a couple of use cases for offline tuning.

One use case is a workload with high memory utilization, where one might run out of memory with regular tuning.

Another use case is a workload that is compute intensive to run, where it is more resource efficient to collect
the GEMMs for the workload once and then tune repeatedly with different tuning parameters or libraries.

### Workflow
There are basically two steps:
1) Set the environment variables to collect the untuned GEMMs; this will generate `tunableop_untuned?.csv` ("?" is a placeholder for the GPU ID), like:
1) Set the environment variables to collect the untuned GEMMs; this will generate `tunableop_untuned0.csv` (a minimal collection script is sketched after step 2):
```
PYTORCH_TUNABLEOP_ENABLED=1
PYTORCH_TUNABLEOP_TUNING=0
PYTORCH_TUNABLEOP_RECORD_UNTUNED=1
...
```
2) Run a Python script that reads the `tunableop_untuned?.csv` and generates the `tunableop_results?.csv`, like:
2) Run a Python script that reads the `tunableop_untuned0.csv` and generates the `tunableop_results0.csv`, like:
```
import torch.cuda.tunable as tunable
import os
@@ -99,9 +104,29 @@ import os
os.putenv('PYTORCH_TUNABLEOP_ENABLED', '1')
os.putenv('PYTORCH_TUNABLEOP_TUNING', '1')
os.putenv('PYTORCH_TUNABLEOP_RECORD_UNTUNED', '0')
tunable.tune_gemm_in_file("tunableop_results?.csv")
tunable.tune_gemm_in_file("tunableop_untuned0.csv")
```
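For step 1, a minimal collection script might look like the following (a sketch; `my_workload` is a stand-in for whatever model or benchmark actually issues the GEMMs):

```
import os

# The collection environment variables are set before any GEMM runs.
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
os.environ["PYTORCH_TUNABLEOP_TUNING"] = "0"
os.environ["PYTORCH_TUNABLEOP_RECORD_UNTUNED"] = "1"

import torch


def my_workload() -> None:
    # Stand-in for the real workload; any GEMMs executed here are recorded
    # (untuned) into tunableop_untuned<gpu id>.csv.
    a = torch.rand(1024, 512, device="cuda")
    b = torch.rand(512, 2048, device="cuda")
    torch.matmul(a, b)


if __name__ == "__main__":
    my_workload()
```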

It is also possible to take multiple untuned files and distribute the GEMMs for tuning to multiple GPUs
within a single node. In the first step, the GEMMs are gathered and duplicate GEMMs are eliminated.
Next, the GEMMs are distributed to different GPUs for tuning. After all GEMMs are tuned, the results from
all the GPUs are then gathered into a single file whose base filename has `_full0` appended to it
(e.g. `tunableop_results_full0.csv`). Finally, this new file, containing the gathered results, is
duplicated N times, once for each GPU, as a convenience for users who will run the workload with the tuned
configuration on N GPUs.

```
if __name__ == "__main__":
    num_gpus = 8  # number of GPUs that will be used during the tuning process
    tunable.mgpu_tune_gemm_in_file("tunableop_untuned?.csv", num_gpus)
```

Note that the usage of the `mgpu_tune_gemm_in_file` API is different from its single-GPU counterpart
(`tune_gemm_in_file`). The body of the Python script that calls the API must be wrapped in a `__main__`
guard as shown, due to the use of the concurrent futures module. The argument to `mgpu_tune_gemm_in_file`
must contain a wildcard expression (`?` or `*`) to generate the list of untuned files containing the GEMMs
to be processed. The `num_gpus` must be between 1 and the total number of GPUs available. A more complete
end-to-end sketch follows below.
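Putting the pieces together, an end-to-end multi-GPU tuning script might look like the following (a sketch under the assumptions above; the untuned filenames and the `num_gpus` value are illustrative):

```
import os

# Enable TunableOp and tuning via environment variables so that the worker
# processes spawned by mgpu_tune_gemm_in_file inherit the settings.
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
os.environ["PYTORCH_TUNABLEOP_TUNING"] = "1"
os.environ["PYTORCH_TUNABLEOP_RECORD_UNTUNED"] = "0"

import torch
import torch.cuda.tunable as tunable

if __name__ == "__main__":
    # Illustrative value; must be between 1 and torch.cuda.device_count().
    num_gpus = min(8, torch.cuda.device_count())
    # The wildcard matches tunableop_untuned0.csv, tunableop_untuned1.csv, ...
    tunable.mgpu_tune_gemm_in_file("tunableop_untuned?.csv", num_gpus)
```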

## Tuning Context
The behavior of TunableOp is currently manipulated through environment variables, the C++ interface of
at::cuda::tunable::getTuningContext(), or the `torch.cuda.tunable` python interfaces. The environment variables take

@@ -153,6 +178,7 @@ All python APIs exist in the `torch.cuda.tunable` module.

| write_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| tune_gemm_in_file(filename: str) -> None | Read an untuned file and tune the GEMMs in it. |
| mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None | Read one or more untuned files and tune all unique GEMMs on one or more GPUs. |
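As a brief illustration of the Python interface (a sketch; the matrix sizes and iteration count are arbitrary):

```
import torch
import torch.cuda.tunable as tunable

# Enable TunableOp with online tuning through the Python API.
tunable.enable(True)
tunable.tuning_enable(True)
tunable.set_max_tuning_iterations(10)

# Any GEMM executed now is tuned on first encounter.
a = torch.rand(512, 256, device="cuda")
b = torch.rand(256, 1024, device="cuda")
torch.matmul(a, b)

# Persist the selected solutions; by default this writes to get_filename().
tunable.write_file()
```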

### C++ Interface
Example:
@@ -33,3 +33,4 @@ API Reference
.. autofunction:: write_file
.. autofunction:: read_file
.. autofunction:: tune_gemm_in_file
.. autofunction:: mgpu_tune_gemm_in_file
@@ -32,7 +32,7 @@ from torch.testing._internal.common_dtype import (
    floating_and_complex_types_and, floating_types_and, complex_types,
)
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \
    _get_torch_cuda_version, CDNA2OrLater
    _get_torch_cuda_version, CDNA2OrLater, TEST_MULTIGPU
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.distributions.binomial import Binomial
@@ -61,11 +61,9 @@ def set_tunableop_defaults():
        return

    # disable TunableOp and restore to default values
    ordinal = torch.cuda.current_device()
    filename = f"tunableop_results{ordinal}.csv"
    torch.cuda.tunable.enable(False)
    torch.cuda.tunable.record_untuned_enable(False)
    torch.cuda.tunable.tuning_enable(True)
    torch.cuda.tunable.set_filename(filename)  # reset back to default filename for next unit test
    torch.cuda.tunable.set_max_tuning_duration(30)
    torch.cuda.tunable.set_max_tuning_iterations(100)
@@ -4673,6 +4671,103 @@ class TestLinalg(TestCase):
        assert torch.cuda.tunable.is_enabled() is False, "TunableOp should be off after resetting"
        assert torch.cuda.tunable.get_max_tuning_iterations() == 100

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    @onlyCUDA
    @skipCUDAIfNotRocm
    @dtypes(torch.float)
    def test_matmul_offline_mgpu_tunableop(self, device, dtype):
        # Offline tuning with multiple GPUs.
        # Case where you record GEMMs on one GPU, but then tune
        # on multiple GPUs
        import tempfile
        import os

        tmp_dir = tempfile.mkdtemp()

        # Use all available GPUs for this test
        total_gpus = torch.cuda.device_count()

        # Test in try-finally block to avoid leaking state
        # if test is interrupted.
        try:
            set_tunableop_defaults()
            os.environ["PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE"] = "0"

            # Pointing to temp files. The test cannot remove them on Windows because
            # they are in use and locked
            os.putenv("PYTORCH_TUNABLEOP_UNTUNED_FILENAME", os.path.join(tmp_dir, "tunableop_untuned.csv"))
            os.putenv("PYTORCH_TUNABLEOP_FILENAME", os.path.join(tmp_dir, "tunableop_results.csv"))

            # turn on untuned GEMM recording and turn off tuning
            torch.cuda.tunable.enable(True)
            torch.cuda.tunable.tuning_enable(False)
            torch.cuda.tunable.record_untuned_enable(True)

            # Choose matrix sizes that have not been used before
            m = n = k = 23

            # Create at least one GEMM per GPU, so when the GEMMs
            # are distributed to the GPUs there is at least one
            # GEMM per GPU.
            for g in range(1, total_gpus + 1):
                A = torch.rand(m * g, k * g, device=device, dtype=dtype)
                B = torch.rand(k * g, n * g, device=device, dtype=dtype)
                C = torch.matmul(A, B)

            # check the untuned file was written
            ordinal = torch.cuda.current_device()
            untuned_filename = os.path.join(tmp_dir, f"tunableop_untuned{ordinal}.csv")
            self.assertTrue(os.path.exists(untuned_filename))

            # turn off untuned GEMM recording and turn on tuning
            # We need to set the environment variables here instead of using
            # the Python API, so that the child processes created will inherit
            # these settings
            os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
            os.environ["PYTORCH_TUNABLEOP_TUNING"] = "1"
            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"] = "1"

            torch.cuda.tunable.mgpu_tune_gemm_in_file(untuned_filename, total_gpus)
            assert torch.cuda.tunable.write_file()

            # check the results files were written, one per GPU
            for i in range(total_gpus):
                result_filename = os.path.join(tmp_dir, f"tunableop_results{i}.csv")
                self.assertTrue(os.path.exists(result_filename))

            # Check the full results files were written, one per GPU
            # for i in range(total_gpus):
            #     result_full_filename = os.path.join(tmp_dir, f"tunableop_results_full{i}.csv")
            #     self.assertTrue(os.path.exists(result_full_filename))

        finally:
            # disables TunableOp
            torch.cuda.tunable.enable(False)

            # undo all the environment variables set
            try:
                del os.environ["PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE"]
                del os.environ["PYTORCH_TUNABLEOP_UNTUNED_FILENAME"]
                del os.environ["PYTORCH_TUNABLEOP_FILENAME"]
                del os.environ["PYTORCH_TUNABLEOP_ENABLED"]
                del os.environ["PYTORCH_TUNABLEOP_TUNING"]
                del os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"]
            except KeyError:
                pass

            # clean up, remove any files that were generated
            try:
                untuned_filename = os.path.join(tmp_dir, "tunableop_untuned0.csv")
                os.remove(untuned_filename)
                for i in range(total_gpus):
                    result_filename = os.path.join(tmp_dir, f"tunableop_results{i}.csv")
                    result_full_filename = os.path.join(tmp_dir, f"tunableop_results_full{i}.csv")
                    os.remove(result_filename)
                    os.remove(result_full_filename)
            except FileNotFoundError:
                pass

    @onlyCUDA
    @skipCUDAIfNotRocm
    @dtypes(torch.float)
@@ -49,7 +49,7 @@ like so::

  GemmTunableOp_float_NT,nt_25088_4096_64,1219,1.262
  GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033

Note the "Validator" lines. If you change a library verison, or ROCm version, or
Note the "Validator" lines. If you change a library version, or ROCm version, or
PyTorch version, TunableOp will detect this and reject the tunings file because
the prior tunings are likely affected by other software changes.
@@ -112,6 +112,11 @@ environment variables take precedence over any setting you manipulate using the
C++ or Python APIs.

"""
import concurrent.futures
import glob
import multiprocessing as mp
import os
import shutil
import warnings
from typing import Optional, Tuple

@@ -137,6 +142,7 @@ __all__ = [
    "write_file",
    "read_file",
    "tune_gemm_in_file",
    "mgpu_tune_gemm_in_file",
]
@@ -265,56 +271,250 @@ def tune_gemm_in_file(filename: str) -> None:
    assert is_enabled()
    assert tuning_is_enabled()

    deviceid = torch.cuda.current_device()

    with open(filename) as file:
        for line in file:
            if line.startswith("Gemm"):
                untuned_gemm = line.strip().split(",")[:]
                [op_sig, data_type, layout] = untuned_gemm[0].split("_")
            if line.startswith(("Gemm", "ScaledGemm")):
                _process_single_offline_gemm(line, deviceid)

                transA = True if layout[0] == "T" else False
                transB = True if layout[1] == "T" else False

                dtype = {
                    "float": torch.float32,
                    "double": torch.float64,
                    "BFloat16": torch.bfloat16,
                    "Half": torch.half,
                    "c10::complex<double>": torch.complex128,
                    "c10::complex<float>": torch.complex64,
                    "Float8_e4m3fn": torch.float8_e4m3fn,
                    "Float8_e5m2": torch.float8_e5m2,
                    "Float8_e4m3fnuz": torch.float8_e4m3fnuz,
                    "Float8_e5m2fnuz": torch.float8_e5m2fnuz,
                }.get(data_type, torch.half)


def _gather_unique_untuned_gemm_from_files(filename_pattern: str) -> set[str]:
    r"""Process multiple untuned results files and return a set with duplicates removed."""
    unique_gemm_entries = set()  # set will avoid duplicates

                if op_sig == "GemmTunableOp":
                    [n, m, k] = [int(g) for g in untuned_gemm[1].split("_")[1:]]
                    matA = (
                        torch.rand(k, m, dtype=dtype, device="cuda").t()
                        if transB
                        else torch.rand(m, k, dtype=dtype, device="cuda")
                    )
                    matB = (
                        torch.rand(n, k, dtype=dtype, device="cuda").t()
                        if transA
                        else torch.rand(k, n, dtype=dtype, device="cuda")
                    )
                    torch.mm(matA, matB)
                elif op_sig == "GemmStridedBatchedTunableOp":
                    [n, m, k] = [int(g) for g in untuned_gemm[1].split("_")[1:4]]
                    [b] = [int(g) for g in untuned_gemm[1].split("_")[5:6]]
                    matA = (
                        torch.rand(b, k, m, dtype=dtype, device="cuda")
                        if transB
                        else torch.rand(b, m, k, dtype=dtype, device="cuda")
                    )
                    matB = (
                        torch.rand(b, n, k, dtype=dtype, device="cuda")
                        if transA
                        else torch.rand(b, k, n, dtype=dtype, device="cuda")
                    )
                    matA = matA.transpose(1, 2) if transB else matA
                    matB = matB.transpose(1, 2) if transA else matB
                    torch.bmm(matA, matB)
    for file_path in glob.glob(filename_pattern):
        with open(file_path) as file:
            for line in file:
                if line.startswith(("Gemm", "ScaledGemm")):
                    unique_gemm_entries.add(line)

    return unique_gemm_entries

def _gather_tunableop_results() -> None:
    r"""Gather results from multiple TunableOp results files and create a single file."""
    gemm_lines = set()
    validator_lines = []

    # Need to allow for the possibility that results filename was
    # set with the Python API instead of with environment variable.
    # Also possible that results filename was not set at all.
    # There are several test cases to check, but ultimately we
    # need a glob-able expression
    results_filename = get_filename()  # Note empty string could be returned here

    if (
        results_filename is not None and results_filename != ""
    ):  # Case where the Python API was used to set the filename
        dot_pos = results_filename.find(".")
        if dot_pos != -1 and dot_pos > 0:
            # Replace the character just to the left of the dot
            filename_pattern = (
                results_filename[: dot_pos - 1] + "?" + results_filename[dot_pos:]
            )
        else:
            filename_pattern = ""  # Needed to make linter happy
    else:  # Case where the environment variable was used to set the filename.
        results_filename_env = os.getenv("PYTORCH_TUNABLEOP_FILENAME")
        if results_filename_env is None or results_filename_env == "":
            filename_pattern = "tunableop_results?.csv"
        elif "%d" in results_filename_env:
            filename_pattern = results_filename_env.replace("%d", "?")
        else:
            filename_pattern = results_filename_env.replace(".", "?.")

    assert "?" in filename_pattern

    FirstFile = False
    matching_files = glob.glob(filename_pattern)
    num_matching_files = len(matching_files)
    for file_path in matching_files:
        with open(file_path) as file:
            for line in file:
                if line.startswith("Validator"):
                    if not (FirstFile):
                        # Only read Validator from first file
                        validator_lines.append(line)
                else:
                    warnings.warn(f"error: unkown op {op_sig}")
                    gemm_lines.add(line)

        FirstFile = True

    output_file = filename_pattern.replace("?", "_full0")

    with open(output_file, "w") as out_file:
        for line in validator_lines:
            out_file.write(line)
        for line in gemm_lines:
            out_file.write(line)

    # Create num_matching_files copies of the results file
    for i in range(1, num_matching_files):
        duplicate_file = output_file.replace("0", str(i))
        shutil.copy(output_file, duplicate_file)

def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
    r"""Process a single untuned GEMM."""

    deviceid = "cuda:" + str(gpu_id)

    dtype_dict = {
        "float": torch.float32,
        "double": torch.float64,
        "BFloat16": torch.bfloat16,
        "Half": torch.half,
        "c10::complex<double>": torch.complex128,
        "c10::complex<float>": torch.complex64,
        "Float8_e4m3fn": torch.float8_e4m3fn,
        "Float8_e5m2": torch.float8_e5m2,
        "Float8_e4m3fnuz": torch.float8_e4m3fnuz,
        "Float8_e5m2fnuz": torch.float8_e5m2fnuz,
    }

    untuned_gemm = untuned_gemm_line.strip().split(",")[:]

    underscore_count = untuned_gemm[0].count("_")

    # Initialize dtype to make linter happy
    dtype = None
    dtypeA = None
    dtypeB = None
    dtypeC = None

    if underscore_count == 2:
        [op_sig, data_type, layout] = untuned_gemm[0].split("_")
        transA = layout[0] == "T"
        transB = layout[1] == "T"
        dtype = dtype_dict.get(data_type)
    else:  # ScaledGEMM
        untuned_gemm_temp = untuned_gemm[0].split("_")
        op_sig = untuned_gemm_temp[0]
        data_typeA = untuned_gemm_temp[1] + "_" + untuned_gemm_temp[2]
        data_typeB = untuned_gemm_temp[3] + "_" + untuned_gemm_temp[4]
        data_typeC = untuned_gemm_temp[5] + "_" + untuned_gemm_temp[6]
        transA = untuned_gemm_temp[7][0] == "T"
        transB = untuned_gemm_temp[7][1] == "T"
        dtypeA = dtype_dict.get(data_typeA)
        dtypeB = dtype_dict.get(data_typeB)
        dtypeC = dtype_dict.get(data_typeC)

    [n, m, k] = [int(g) for g in untuned_gemm[1].split("_")[1:4]]
    if op_sig == "GemmTunableOp":
        matA = (
            torch.rand(k, m, dtype=dtype, device=deviceid).t()
            if transB
            else torch.rand(m, k, dtype=dtype, device=deviceid)
        )
        matB = (
            torch.rand(n, k, dtype=dtype, device=deviceid).t()
            if transA
            else torch.rand(k, n, dtype=dtype, device=deviceid)
        )
        torch.mm(matA, matB)
    elif op_sig == "GemmStridedBatchedTunableOp":
        [b] = [int(g) for g in untuned_gemm[1].split("_")[5:6]]
        matA = (
            torch.rand(b, k, m, dtype=dtype, device=deviceid)
            if transB
            else torch.rand(b, m, k, dtype=dtype, device=deviceid)
        )
        matB = (
            torch.rand(b, n, k, dtype=dtype, device=deviceid)
            if transA
            else torch.rand(b, k, n, dtype=dtype, device=deviceid)
        )
        matA = matA.transpose(1, 2) if transB else matA
        matB = matB.transpose(1, 2) if transA else matB
        torch.bmm(matA, matB)
    elif op_sig == "ScaledGemmTunableOp":
        fillA = 0.25
        fillB = 0.75
        scaleA = torch.tensor(0.8, device=deviceid)
        scaleB = torch.tensor(0.9, device=deviceid)
        matA = (
            torch.full((k, m), fillA, dtype=dtypeA, device=deviceid).t()
            if transB
            else torch.full((m, k), fillA, dtype=dtypeA, device=deviceid)
        )
        matB = (
            torch.full((n, k), fillB, dtype=dtypeB, device=deviceid).t()
            if transA
            else torch.full((k, n), fillB, dtype=dtypeB, device=deviceid)
        )
        torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=dtypeC)
    elif op_sig == "GemmAndBiasTunableOp":
        # y = x*A^T + b
        assert transA != transB

        X = (
            torch.rand(k, m, dtype=dtype, device=deviceid).t()
            if transB
            else torch.rand(m, k, dtype=dtype, device=deviceid)
        )
        matA = (
            torch.rand(n, k, dtype=dtype, device=deviceid)
            if transA
            else torch.rand(k, n, dtype=dtype, device=deviceid).t()
        )
        bias = (
            torch.rand(n, dtype=dtype, device=deviceid)
            if transA
            else torch.rand(m, dtype=dtype, device=deviceid)
        )
        torch.nn.functional.linear(X, matA, bias)
    else:
        warnings.warn(f"error: unknown op {op_sig}")

def _check_tuning_assertions() -> None:
    r"""Helper function for the multi-GPU tuning case. Need to check that the TunableOp feature
    is enabled and that tuning is enabled.
    """
    assert is_enabled()
    assert tuning_is_enabled()


def mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None:
    r"""Process one or more files and distribute work over one or more GPUs."""
    unique_gemm_entries = _gather_unique_untuned_gemm_from_files(filename_pattern)

    total_gpus = torch.cuda.device_count()

    assert 1 <= num_gpus <= total_gpus

    mp_context = mp.get_context("spawn")

    checks = []  # empty list to hold checks
    futures = []  # empty list to hold futures

    # GEMMs are assigned to GPUs in a round-robin manner
    h = 0
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=num_gpus, mp_context=mp_context
    ) as executor:
        # The workers are separate processes. TunableOp will be
        # enabled in the child processes if the environment variable
        # is set. However, if we enable TunableOp via the API
        # the workers do not inherit this state. As a precaution,
        # we need to check that the TunableOp feature and tuning are
        # enabled in the pool of processes.
        for g in range(num_gpus):
            check = executor.submit(_check_tuning_assertions)
            checks.append(check)

        for check in concurrent.futures.as_completed(checks):
            check.result()

        for line in unique_gemm_entries:
            future = executor.submit(_process_single_offline_gemm, line, h)
            futures.append(future)
            h = (h + 1) % num_gpus

        for future in concurrent.futures.as_completed(futures):
            future.result()

        torch.cuda.synchronize()

    _gather_tunableop_results()