Added an optional name argument to SubprocPool.submit. We record this in a dictionary and, when raising exceptions, add the name. We manage the lifecycle the same as the pending futures. Added a specific test case to make sure this logs correctly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161452 Approved by: https://github.com/masnesral
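For context, a minimal sketch of the pattern that change describes, assuming a generic pool wrapper rather than the actual SubprocPool internals (NamedPool, NamedJobError, and every signature below are hypothetical): the optional job name is recorded alongside each pending future, attached to any exception the caller sees, and dropped once the job finishes.

from __future__ import annotations

from concurrent.futures import Future, ThreadPoolExecutor


class NamedJobError(RuntimeError):
    """Raised when a submitted job fails; carries the job's optional name."""


class NamedPool:
    """Hypothetical wrapper illustrating the name-tracking pattern."""

    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(max_workers=4)
        # Names live exactly as long as their pending futures.
        self._names: dict[Future, str] = {}

    def submit(self, fn, *args, name: str | None = None) -> Future:
        outer: Future = Future()
        inner = self._executor.submit(fn, *args)
        if name is not None:
            self._names[inner] = name

        def _done(f: Future) -> None:
            job_name = self._names.pop(f, "<unnamed>")  # drop the entry once finished
            exc = f.exception()
            if exc is not None:
                # Attach the recorded name so failures are attributable.
                outer.set_exception(NamedJobError(f"job {job_name!r} failed: {exc}"))
            else:
                outer.set_result(f.result())

        inner.add_done_callback(_done)
        return outer

With this sketch, NamedPool().submit(lambda: 1 / 0, name="triton:my_kernel").result() raises a NamedJobError that names the failing job.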
# Owner(s): ["module: inductor"]
from unittest.mock import patch

import torch
from torch._inductor import config
from torch._inductor.async_compile import AsyncCompile, shutdown_compile_workers
from torch._inductor.compile_worker.subproc_pool import SubprocException
from torch._inductor.runtime.triton_compat import Config
from torch._inductor.runtime.triton_heuristics import (
    generate_lookup_hash_from_source_code,
)
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_cache
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.inductor_utils import (
    GPU_TYPE,
    requires_gpu,
    requires_triton,
)


@instantiate_parametrized_tests
class TestAsyncCompile(TestCase):
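    # Compile a trivial elementwise add under each worker_start_method and
    # check the compiled result against eager.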
    @requires_gpu()
    @requires_triton()
    @parametrize("method", ("subprocess", "fork", "spawn"))
    def test_pool(self, method):
        def fn(x, y):
            return x + y

        x = torch.rand(10).to(GPU_TYPE)
        y = torch.rand(10).to(GPU_TYPE)

        with config.patch("worker_start_method", method):
            shutdown_compile_workers()
            AsyncCompile.wait_pool_ready()

            with fresh_cache():
                compiled_fn = torch.compile(fn)
                self.assertEqual(fn(x, y), compiled_fn(x, y))

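    # A kernel whose source code cannot be compiled should surface a
    # SubprocException from the worker pool.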
    @requires_gpu()
    @requires_triton()
    def test_bad_kernel(self):
        shutdown_compile_workers()

        with config.patch(worker_start_method="subprocess", compile_threads=8):
            async_compile = AsyncCompile()
            AsyncCompile.wait_pool_ready()
            # Working around bug in wait_pool_ready()
            async_compile._ready_future.result(timeout=120)
            with self.assertRaises(SubprocException):
                async_compile.triton(
                    "fake_kernel_name", source_code="This definitely doesn't exist"
                ).result()

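    # The autotune lookup table maps a hash of the kernel source (and its size
    # hints) to a block config; coordinate-descent tuning should then be seeded
    # with exactly that config.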
    @requires_gpu()
    @requires_triton()
    @patch("torch._inductor.runtime.coordinate_descent_tuner.CoordescTuner.autotune")
    @parametrize("method", ("subprocess", "fork", "spawn"))
    def test_autotune_lookup_table(self, mock_autotune, method):
        def f(a, b):
            return (a @ b).to(torch.float32).sum(dim=1)

        # Fake name to make sure the lookup table is name agnostic
        func_def = """
def triton_fused_fake_name(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1024
    r0_numel = 11776
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_1 + 11776*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tmp0.to(tl.float32)
        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
        tmp4 = _tmp3 + tmp2
        _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3)
    tmp3 = tl.sum(_tmp3, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp3, xmask)

"""

        fn_hash = generate_lookup_hash_from_source_code(
            str({"x": 1024, "r0_": 16384}), func_def
        )
        block_configs = {
            "XBLOCK": 1,
            "R0_BLOCK": 128,
        }
        num_warps = 16
        num_stages = 1
        autotune_lookup_table = {
            fn_hash: {**block_configs, "num_warps": num_warps, "num_stages": num_stages}
        }
        autotune_config = Config(
            block_configs, num_warps=num_warps, num_stages=num_stages
        )
        mock_autotune.return_value = autotune_config

        a = torch.randn(1152, 1024, device=GPU_TYPE, dtype=torch.float16).T
        b = torch.randn(1152, 11776, device=GPU_TYPE, dtype=torch.float16)
        compiled_f = torch.compile(f)

        with config.patch(
            {
                "autotune_lookup_table": autotune_lookup_table,
                "coordinate_descent_tuning": True,
                "worker_start_method": method,
            }
        ):
            shutdown_compile_workers()
            AsyncCompile.wait_pool_ready()
            with fresh_cache():
                compiled_f(a, b)

        # Check that the input to coordinate descent (the resulting chosen config)
        # is the same as the one in the lookup table
        mock_autotune.assert_called_once()
        args, _ = mock_autotune.call_args
        self.assertTrue(isinstance(args[1], Config))

        self.assertEqual(args[1].kwargs, autotune_config.kwargs)
        self.assertEqual(args[1].num_warps, autotune_config.num_warps)
        self.assertEqual(args[1].num_stages, autotune_config.num_stages)


if __name__ == "__main__":
    run_tests()