Compare commits

...

2 Commits

Author SHA1 Message Date
5195be51d2 Update
[ghstack-poisoned]
2025-11-13 20:37:02 -08:00
46d5b69502 Update (base update)
[ghstack-poisoned]
2025-11-13 20:37:02 -08:00
2 changed files with 7 additions and 0 deletions

View File

@ -146,6 +146,8 @@ class CooperativeReductionTests(TestCase):
self.assertIn("cooperative_reduction_grid", source_code)
else:
self.assertIn("@triton_heuristics.cooperative_reduction", source_code)
if GPU_TYPE == "cuda":
self.assertIn("'launch_cooperative_grid': True", source_code)
if "async_compile.multi_kernel" not in source_code:
self.assertEqual(
torch._inductor.metrics.generated_kernel_count, expect_kernel_count

View File

@ -5059,6 +5059,11 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
),
}
if self.cooperative_reduction:
# Cooperative reductions rely on multi-block synchronization that
# requires cooperative-grid launches to avoid hanging.
triton_meta["launch_cooperative_grid"] = True
# Skip memory optimization for forward of the training loop where we expect
# every new node will increase the peak memory and our greedy approach would
# introduce a lot of unnecessary cpu copies.