Compare commits

...

1 Commits

Author SHA1 Message Date
4b691036ad save 2025-02-26 09:54:07 -08:00
3 changed files with 31 additions and 3 deletions

View File

@ -3355,7 +3355,10 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
def codegen_kernel(self, name=None):
code = IndentedBuffer()
# triton_meta?
size_hints = {}
breakpoint()
for prefix, numel in self.numels.items():
if prefix_is_reduction(prefix) and not self.inside_reduction:
continue
@ -3390,13 +3393,24 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
code.splice(self.imports_for_benchmark_kernel())
argdefs, _, signature, _ = self.args.python_argdefs()
breakpoint()
# maps actual expression to SizeArg if it is in sizevars replacements
for i, arg in enumerate(signature):
if isinstance(arg, SizeArg):
# mypy is unhappy about the sympy.Expr
breakpoint()
# mypy is unhappy about the sykimpy.Expr
# type for the key of the dict below
symbol = cast(sympy.Symbol, arg.expr)
if (arg.expr == 1):
breakpoint()
# if triton_version_uses_attrs_dict():
# for arg in argdefs:
# if arg.name == 'load_seed_offset' and arg.is_constexpr == False:
# argdefs.pop(-1)
# argdefs.append(ArgName(name='load_seed_offset', is_constexpr=True))
if symbol in V.graph.sizevars.inv_precomputed_replacements:
breakpoint()
signature[i] = SizeArg(
arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol]
)
@ -3470,11 +3484,20 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
triton_meta_signature = signature_to_meta(
signature, size_dtype=self.index_dtype, argdefs=argdefs
)
breakpoint()
for name, arg_type in triton_meta_signature.items():
breakpoint()
if name == 'load_seed_offset' and arg_type == 'constexpr':
triton_meta_signature = {'in_ptr0': '*i64', 'out_ptr0': '*fp32', 'load_seed_offset': 'i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}
#check triton_meta_signature here
breakpoint()
triton_meta: dict[str, Any] = {
"signature": triton_meta_signature,
"device": DeviceProperties.create(V.graph.get_current_device_or_throw()),
"constants": {},
}
breakpoint()
# Skip memory optimization for forward of the training loop where we expect
# every new node will increase the peak memory and our greedy approach would
@ -3501,6 +3524,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
inductor_meta["kernel_num_gb"] = num_gb
triton_meta["configs"] = [config_of(signature)]
breakpoint()
# Triton compiler includes equal_to_1 args into constants even
# when they are not constexpr. otherwise there may be a segfault
@ -3508,9 +3532,11 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
# https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
# https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
for arg_num in equal_1_arg_indices(signature): # type: ignore[index]
breakpoint()
triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index,union-attr]
triton_meta["constants"] = {}
self.triton_meta = triton_meta
breakpoint()
self.codegen_body()

View File

@ -1270,6 +1270,7 @@ class PythonWrapperCodegen(CodeGen):
file_path,
)
# Execute the code to autotune kernels
breakpoint()
try:
exec(tuning_code, scope)
except Exception as e:

View File

@ -33,8 +33,9 @@ def get_kernel_category_by_source_code(src_code: str) -> str:
Similar to get_kernel_category but use the source code. Call this API
if we have not compile the src_code to module yet.
"""
breakpoint()
choices = [
ch for ch in _kernel_category_choices if f"@triton_heuristics.{ch}" in src_code
ch for ch in _kernel_category_choices if ch and f"@triton_heuristics.{ch}" in src_code
]
if len(choices) == 1:
return choices[0]