Add num_store to inductor_meta and use it to scale persistent reduction x block (#162446)

Scale up XBLOCK for contiguous persistent reductions, based on rnumel and the number of loads + stores in the kernel.

![Benchmark screenshot, 2025-09-18](https://github.com/user-attachments/assets/ec3c561f-2a3f-4459-9e14-653715898da3)
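At a high level, the heuristic works roughly as follows. This is a minimal sketch; the function name `scale_persistent_xblock`, the doubling policy, and the `32768` cap are illustrative assumptions, not the landed code:

```python
# Minimal sketch of the idea above; the function name, doubling policy,
# and cap are illustrative assumptions, not the landed heuristic.
def scale_persistent_xblock(
    xblock: int, rnumel: int, num_load: int, num_store: int
) -> int:
    """Scale up XBLOCK for a contiguous persistent reduction.

    Each program of a persistent reduction keeps a whole reduction
    (rnumel elements) live at once. Handling more rows per program
    (a larger XBLOCK) improves occupancy for small reductions, but each
    extra row multiplies the footprint of every load/store, so the
    scale factor is bounded by rnumel and the number of memory ops.
    """
    num_mem_ops = max(num_load + num_store, 1)
    cap = 32768  # hypothetical per-program footprint budget
    while rnumel * (xblock * 2) * num_mem_ops <= cap:
        xblock *= 2
    return xblock
```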

Differential Revision: [](https://our.internmc.facebook.com/intern/diff/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162446
Approved by: https://github.com/v0i0, https://github.com/eellison, https://github.com/shunting314
ghstack dependencies: #162296
commit 600267ea56 (parent f11ac803d7)
Author: PaulZhang12
Date: 2025-10-05 19:41:29 -07:00
Committed by: PyTorch MergeBot
3 changed files with 24 additions and 2 deletions

torch/_inductor/codegen/common.py

@@ -2053,6 +2053,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
         self.stores = IndentedBuffer()
         self.num_load = 0
+        self.num_store = 0
         self.num_reduction = 0
         self.cse: CSE[CSEVariableType, Any] = CSE(self.newvar_prefix, self.suffix)
@@ -2266,6 +2267,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
                     name, fused_node_names
                 )
             ):
+                self.num_store -= 1
                 names_to_remove.add(name)
 
         for name in names_to_remove:
@@ -2732,6 +2734,7 @@ class CSEProxy(DefaultHandler):
         self._update_store_cache(name, value)
         if name not in V.graph.removed_buffers:
             self.kernel.store(name, index, value, mode=mode)
+            self.kernel.num_store += 1
 
     def device_assert_async(self, cond: CSEVariable, msg: str) -> None:
         self.kernel.device_assert_async(cond, msg)
@@ -2741,6 +2744,7 @@ class CSEProxy(DefaultHandler):
         self._update_store_cache(name, value)
         if name not in V.graph.removed_buffers:
+            self.kernel.num_store += 1
             return self.kernel.store_reduction(name, index, value)
 
     def reduction(
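
Tying the hunks back to the title: the `num_store` counter maintained in `Kernel` is what gets recorded in `inductor_meta` and later consulted when sizing the persistent-reduction XBLOCK. A hedged sketch of that plumbing follows; the helper names, dict keys, and call sites are assumptions for illustration, and only the fact that `num_store` reaches `inductor_meta` comes from this commit:

```python
# Hypothetical plumbing sketch; helper names, dict keys, and call sites
# are assumptions for illustration only.
def build_inductor_meta(kernel) -> dict:
    # Producer side: snapshot the counters Kernel tracks (see the hunks
    # above) at codegen time so the runtime heuristic can read them.
    return {
        "num_load": kernel.num_load,
        "num_store": kernel.num_store,  # newly tracked by this commit
        "num_reduction": kernel.num_reduction,
    }


def pick_persistent_xblock(inductor_meta: dict, rnumel: int) -> int:
    # Consumer side: grow XBLOCK while the per-program footprint
    # (rows * reduction length * memory ops) stays under a budget.
    num_mem_ops = max(
        inductor_meta.get("num_load", 0) + inductor_meta.get("num_store", 0), 1
    )
    xblock = 1
    while rnumel * (xblock * 2) * num_mem_ops <= 32768:  # hypothetical cap
        xblock *= 2
    return xblock
```

The decrement in the second hunk keeps the count honest: stores elided by fusion never reach the generated kernel, so they should not inflate the footprint estimate.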