Compare commits

...

1 Commits

Author SHA1 Message Date
38407821b1 Setting the USA_TMA to be true by default 2025-10-06 14:39:25 -07:00
2 changed files with 4 additions and 4 deletions

View File

@ -325,8 +325,8 @@ def flex_attention(
"num_buffers_warp_spec", num_buffers_warp_spec
)
# USE TMA = false by default
cur_kernel_options.setdefault("USE_TMA", False)
# USE TMA = true by default
cur_kernel_options.setdefault("USE_TMA", True)
cur_kernel_options.setdefault("BLOCK_M", conf.block_m)
cur_kernel_options.setdefault("BLOCK_N", conf.block_n)

View File

@ -328,8 +328,8 @@ def create_flex_decoding_kernel(*args, **kwargs):
"num_buffers_warp_spec", num_buffers_warp_spec
)
# Set default to False
cur_kernel_options.setdefault("USE_TMA", False)
# Set default to True
cur_kernel_options.setdefault("USE_TMA", True)
# Add ROCm-specific parameters if they exist in the config
for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: