inductor: Log the specific triton kernel that fails (#161452)
Added an optional name argument to SubprocPool.submit. We record this name in a dictionary and, when raising exceptions, include it in the message; the entries are managed with the same lifecycle as the pending futures. Added a specific test case to make sure this logs correctly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161452
Approved by: https://github.com/masnesral
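To make the mechanism concrete, here is a minimal, self-contained sketch (not the actual SubprocPool implementation): the PoolSketch class, its pending_names dictionary, the inline execution in submit, and the example kernel name are illustrative stand-ins; only the SubprocException message format follows the diff below.

# Minimal, self-contained sketch of the mechanism (not the real SubprocPool code):
# the name passed to submit() is tracked in a dict with the same lifecycle as the
# pending future, and attached to the SubprocException if the job fails.
from concurrent.futures import Future
from itertools import count
from typing import Any, Callable


class SubprocException(Exception):
    # Mirrors the message format introduced in the diff below.
    def __init__(self, details: str, name: str = "<unknown>") -> None:
        self.details = details
        super().__init__(
            f"An exception occurred in a subprocess:\n\nName={name}\n{details}"
        )


class PoolSketch:
    def __init__(self) -> None:
        self._job_ids = count()
        self.pending_futures: dict[int, Future] = {}
        self.pending_names: dict[int, str] = {}  # job_id -> name, removed together with the future

    def submit(self, fn: Callable[..., Any], *args: Any, name: str = "<unknown>") -> Future:
        job_id = next(self._job_ids)
        future: Future = Future()
        self.pending_futures[job_id] = future
        self.pending_names[job_id] = name
        try:
            # The real pool ships the job to a worker process; this sketch runs it inline.
            future.set_result(fn(*args))
        except Exception as e:
            future.set_exception(SubprocException(repr(e), self.pending_names[job_id]))
        finally:
            del self.pending_futures[job_id]
            del self.pending_names[job_id]
        return future


pool = PoolSketch()
fut = pool.submit(lambda: 1 / 0, name="triton_poi_fused_add_0")
try:
    fut.result()
except SubprocException as e:
    print(e)  # message includes "Name=triton_poi_fused_add_0"

Keying the name map by the same job id as the pending futures means a single cleanup path removes both, so the map cannot leak entries for completed jobs.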
@@ -4,6 +4,7 @@ from unittest.mock import patch
 import torch
 from torch._inductor import config
 from torch._inductor.async_compile import AsyncCompile, shutdown_compile_workers
+from torch._inductor.compile_worker.subproc_pool import SubprocException
 from torch._inductor.runtime.triton_compat import Config
 from torch._inductor.runtime.triton_heuristics import (
     generate_lookup_hash_from_source_code,
@@ -41,6 +42,21 @@ class TestAsyncCompile(TestCase):
         compiled_fn = torch.compile(fn)
         self.assertEqual(fn(x, y), compiled_fn(x, y))
 
+    @requires_gpu()
+    @requires_triton()
+    def test_bad_kernel(self):
+        shutdown_compile_workers()
+
+        with config.patch(worker_start_method="subprocess", compile_threads=8):
+            async_compile = AsyncCompile()
+            AsyncCompile.wait_pool_ready()
+            # Working around bug in wait_pool_ready()
+            async_compile._ready_future.result(timeout=120)
+            with self.assertRaises(SubprocException):
+                async_compile.triton(
+                    "fake_kernel_name", source_code="This definitely doesn't exist"
+                ).result()
+
     @requires_gpu()
     @requires_triton()
     @patch("torch._inductor.runtime.coordinate_descent_tuner.CoordescTuner.autotune")
@@ -38,7 +38,11 @@ from torch._inductor.codecache import (
     StaticAutotunerFuture,
     torch_key,
 )
-from torch._inductor.compile_worker.subproc_pool import AnyPool, SubprocPool
+from torch._inductor.compile_worker.subproc_pool import (
+    AnyPool,
+    SubprocException,
+    SubprocPool,
+)
 from torch._inductor.compile_worker.tracked_process_pool import (
     TrackedProcessPoolExecutor,
 )
@@ -450,7 +454,11 @@ class AsyncCompile:
             )
 
             def get_result() -> CachingAutotuner:
-                kernel, elapsed_us = task.result()
+                try:
+                    kernel, elapsed_us = task.result()
+                except SubprocException as e:
+                    raise e.with_name(kernel_name) from e
+
                 # Now that we've compiled, we should clear the future
                 # so it can't be used again
                 kernel.set_compile_info(compile_id, is_backward)
@@ -90,8 +90,14 @@ class SubprocException(Exception):
     Thrown when a job in a subprocess raises an Exception.
     """
 
-    def __init__(self, details: str) -> None:
-        super().__init__(f"An exception occurred in a subprocess:\n\n{details}")
+    def __init__(self, details: str, name: str = "<unknown>") -> None:
+        self.details = details
+        super().__init__(
+            f"An exception occurred in a subprocess:\n\nName={name}\n{details}"
+        )
+
+    def with_name(self, name: str) -> "SubprocException":
+        return SubprocException(self.details, name)
 
 
 class SubprocPickler: