mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Paren-matching kernel launch check without external deps (#60778)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60778 Matches parens and the opening `<<<` to make a more accurate kernel launch check. Test Plan: ``` buck test //caffe2/test:kernel_launch_checks ``` Reviewed By: ngimel Differential Revision: D29401624 fbshipit-source-id: 8649af7c33e67dbb24044af0134b1cea6f2e5dc3
This commit is contained in:
committed by
Facebook GitHub Bot
parent
88b0518a83
commit
94cdbbf48d
@ -1,5 +1,5 @@
|
|||||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||||
from torch.testing import check_cuda_kernel_launches, check_code_for_cuda_kernel_launches
|
from torch.testing._check_kernel_launches import check_cuda_kernel_launches, check_code_for_cuda_kernel_launches
|
||||||
|
|
||||||
|
|
||||||
class AlwaysCheckCudaLaunchTest(TestCase):
|
class AlwaysCheckCudaLaunchTest(TestCase):
|
||||||
@ -38,6 +38,36 @@ some_function_call<TemplateArg><<<1,2,0,stream>>> ( arg1 , arg2 , arg3 ) ;
|
|||||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||||
"""))
|
"""))
|
||||||
|
|
||||||
|
# Does it work for lambdas?
|
||||||
|
self.assertEqual(1, check_code_for_cuda_kernel_launches(r"""
|
||||||
|
rrelu_with_noise_cuda_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
|
||||||
|
numel,
|
||||||
|
rng_engine_inputs,
|
||||||
|
output_data,
|
||||||
|
input_data,
|
||||||
|
noise_data,
|
||||||
|
lower,
|
||||||
|
upper,
|
||||||
|
[] __device__ (curandStatePhilox4_32_10_t* state) {
|
||||||
|
return curand_uniform2_double(state);
|
||||||
|
});
|
||||||
|
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||||
|
|
||||||
|
rrelu_with_noise_cuda_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
|
||||||
|
numel,
|
||||||
|
rng_engine_inputs,
|
||||||
|
output_data,
|
||||||
|
input_data,
|
||||||
|
noise_data,
|
||||||
|
lower,
|
||||||
|
upper,
|
||||||
|
[] __device__ (curandStatePhilox4_32_10_t* state) {
|
||||||
|
return curand_uniform2_double(state);
|
||||||
|
});
|
||||||
|
uh oh;
|
||||||
|
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||||
|
"""))
|
||||||
|
|
||||||
def test_check_cuda_launches(self):
|
def test_check_cuda_launches(self):
|
||||||
unsafeLaunchesCount = check_cuda_kernel_launches()
|
unsafeLaunchesCount = check_cuda_kernel_launches()
|
||||||
self.assertTrue(unsafeLaunchesCount == 0)
|
self.assertTrue(unsafeLaunchesCount == 0)
|
||||||
|
@ -1,44 +1,63 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
from typing import List
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"check_code_for_cuda_kernel_launches",
|
"check_code_for_cuda_kernel_launches",
|
||||||
"check_cuda_kernel_launches",
|
"check_cuda_kernel_launches",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Files to exclude (match is done with suffix)
|
# FILES TO EXCLUDE (match is done with suffix using `endswith`)
|
||||||
exclude_files = [
|
# You wouldn't drive without a seatbelt, though, so why would you
|
||||||
"aten/src/ATen/native/cuda/Activation.cu"
|
# launch a kernel without some safety? Use this as a quick workaround
|
||||||
]
|
# for a problem with the checker, fix the checker, then de-exclude
|
||||||
|
# the files in question.
|
||||||
|
exclude_files: List[str] = []
|
||||||
|
|
||||||
# Regular expression identifies a kernel launch indicator by
|
# Without using a C++ AST we can't 100% detect kernel launches, so we
|
||||||
# finding something approximating the pattern ">>>(arguments);"
|
# model them as having the pattern "<<<parameters>>>(arguments);"
|
||||||
# It then requires that `C10_CUDA_KERNEL_LAUNCH_CHECK` be
|
# We then require that `C10_CUDA_KERNEL_LAUNCH_CHECK` be
|
||||||
# the next command.
|
# the next statement.
|
||||||
# It allows a single backslash `\` between the end of the launch
|
|
||||||
# command and the beginning of the kernel check. This handles
|
|
||||||
# cases where the kernel launch is in a multiline preprocessor
|
|
||||||
# definition.
|
|
||||||
#
|
#
|
||||||
# There are various ways this can fail:
|
# We model the next statement as ending at the next `}` or `;`.
|
||||||
# * If the semicolon is in a string for some reason
|
# If we see `}` then a clause ended (bad); if we see a semicolon then
|
||||||
# * If there's a triply-nested template
|
# we expect the launch check just before it.
|
||||||
# But this should be sufficient to detect and fix most problem
|
#
|
||||||
# instances and can be refined before the test is made binding
|
# Since the kernel launch can include lambda statements, it's important
|
||||||
kernel_launch_regex = re.compile(r"""
|
# to find the correct end-paren of the kernel launch. Doing this with
|
||||||
^.*>>> # Identifies kernel launch
|
# pure regex requires recursive regexes, which aren't part of the Python
|
||||||
\s* # Maybe some whitespace (includes newlines)
|
# standard library. To avoid an additional dependency, we build a prefix
|
||||||
\([^;]+\); # And then arguments in parens and semi-colon
|
# regex that finds the start of a kernel launch, use a paren-matching
|
||||||
(?! # Negative lookahead: we trigger if we don't find the launch guard
|
# algorithm to find the end of the launch, and then another regex to
|
||||||
\s* # Maybe some whitespace (includes newlines)
|
# determine if a launch check is present.
|
||||||
\\? # 0 or 1 backslashes (for launches in preprocessor macros)
|
|
||||||
\s* # Maybe some whitespace (includes newlines)
|
# Finds potential starts of kernel launches
|
||||||
(?:[0-9]+: )? # Detects and ignores a line numbering, if present
|
kernel_launch_start = re.compile(
|
||||||
\s* # Maybe some whitespace (includes newlines)
|
r"^.*<<<[^>]+>>>\s*\(", flags=re.MULTILINE
|
||||||
C10_CUDA_KERNEL_LAUNCH_CHECK\(\); # Kernel launch guard!
|
)
|
||||||
) # End negative lookahead
|
|
||||||
""", flags=re.MULTILINE | re.VERBOSE)
|
# This pattern should start at the character after the final paren of the
|
||||||
|
# kernel launch. NOTE: despite the name `has_check`, a match means the launch check is NOT the next statement (i.e. it is missing)
|
||||||
|
has_check = re.compile(
|
||||||
|
r"\s*;(?![^;}]*C10_CUDA_KERNEL_LAUNCH_CHECK\(\);)", flags=re.MULTILINE
|
||||||
|
)
|
||||||
|
|
||||||
|
def find_matching_paren(s: str, startpos: int) -> int:
    """Return the index one past the `)` that closes the `(` at `startpos`.

    Given a string of the form "prefix (unknown number of characters) suffix"
    and the position of the opening `(`, scan forward while tracking paren
    nesting depth and return the index of the character immediately after
    the matching `)`.

    Only the characters `(` and `)` are counted; parens that appear inside
    string literals or comments of the scanned text are not special-cased.

    Args:
        s: Text to scan.
        startpos: Non-negative index at which scanning starts; expected to be
            the position of the opening `(`.

    Returns:
        Index into `s` one past the closing paren.

    Raises:
        IndexError: If the end of `s` is reached before the parens balance.
    """
    depth = 0
    # Walk indices directly instead of slicing `s[startpos:]`, which would copy
    # the whole remainder of the text on every call.
    for pos in range(startpos, len(s)):
        c = s[pos]
        if c == '(':
            depth += 1
        elif c == ')':
            depth -= 1
            if depth == 0:
                return pos + 1

    raise IndexError("Closing parens not found!")
|
||||||
|
|
||||||
|
|
||||||
def should_exclude_file(filename) -> bool:
|
def should_exclude_file(filename) -> bool:
|
||||||
@ -68,10 +87,15 @@ def check_code_for_cuda_kernel_launches(code, filename=None):
|
|||||||
code = [f"{lineno}: {linecode}" for lineno, linecode in code] # Number the lines
|
code = [f"{lineno}: {linecode}" for lineno, linecode in code] # Number the lines
|
||||||
code = '\n'.join(code) # Put it back together
|
code = '\n'.join(code) # Put it back together
|
||||||
|
|
||||||
results = kernel_launch_regex.findall(code) # Search for bad launches
|
num_launches_without_checks = 0
|
||||||
for r in results:
|
for m in kernel_launch_start.finditer(code):
|
||||||
print(f"Missing C10_CUDA_KERNEL_LAUNCH_CHECK in '{filename}'. Context:\n{r}", file=sys.stderr)
|
end_paren = find_matching_paren(code, m.end() - 1)
|
||||||
return len(results)
|
if has_check.match(code, end_paren):
|
||||||
|
num_launches_without_checks += 1
|
||||||
|
context = code[m.start():end_paren + 1]
|
||||||
|
print(f"Missing C10_CUDA_KERNEL_LAUNCH_CHECK in '{filename}'. Context:\n{context}", file=sys.stderr)
|
||||||
|
|
||||||
|
return num_launches_without_checks
|
||||||
|
|
||||||
|
|
||||||
def check_file(filename):
|
def check_file(filename):
|
||||||
|
Reference in New Issue
Block a user