Compare commits

...

11 Commits

Author SHA1 Message Date
46ae8482b7 Update on "[Inductor] Fix combo kernels for cpu backend"
This PR fixes two issues
Fixes: #167780 combo_kernel fails with CppScheduling backend
Fixes: #168067 combo_kernel fails with mixed cpu/cuda nodes


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
2025-11-18 16:04:50 -08:00
0bcd348d04 Update base for Update on "[Inductor] Fix combo kernels for cpu backend"
This PR fixes two issues
Fixes: #167780 combo_kernel fails with CppScheduling backend
Fixes: #168067 combo_kernel fails with mixed cpu/cuda nodes


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
2025-11-18 16:04:50 -08:00
5d99920d92 Update on "[Inductor] Fix combo kernels for cpu backend"
This PR fixes two issues
Fixes: #167780 combo_kernel fails with CppScheduling backend
Fixes: #168067 combo_kernel fails with mixed cpu/cuda nodes


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
2025-11-18 13:52:28 -08:00
145c27b573 Update base for Update on "[Inductor] Fix combo kernels for cpu backend"
This PR fixes two issues
Fixes: #167780 combo_kernel fails with CppScheduling backend
Fixes: #168067 combo_kernel fails with mixed cpu/cuda nodes


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
2025-11-18 13:52:28 -08:00
61bfcdecdb Update on "[Inductor] Fix combo kernels for cpu backend"
This PR fixes two issues
Fixes: #167780 combo_kernel fails with CppScheduling backend
Fixes: #168067 combo_kernel fails with mixed cpu/cuda nodes


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
2025-11-17 23:13:28 -08:00
93cc036d0e Update base for Update on "[Inductor] Fix combo kernels for cpu backend"
This PR fixes two issues
Fixes: #167780 combo_kernel fails with CppScheduling backend
Fixes: #168067 combo_kernel fails with mixed cpu/cuda nodes


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
2025-11-17 23:13:28 -08:00
cc59358752 Update on "[Inductor] Fix combo kernels for cpu backend"
Fixes: #167780 


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
2025-11-17 21:57:02 -08:00
c8ac2cd7f4 Update base for Update on "[Inductor] Fix combo kernels for cpu backend"
Fixes: #167780 


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
2025-11-17 21:57:02 -08:00
1589528680 Update on "[Inductor] Fix combo kernels for cpu backend"
Fixes: #167780 


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
2025-11-17 16:04:19 -08:00
030c5e7a40 Update base for Update on "[Inductor] Fix combo kernels for cpu backend"
Fixes: #167780 


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
2025-11-17 16:04:19 -08:00
7f62da025d [WIP][Inductor] Fix combo kernels for cpu backend
[ghstack-poisoned]
2025-11-13 16:41:15 -08:00
4 changed files with 50 additions and 38 deletions

View File

@ -171,7 +171,8 @@ if RUN_CPU:
BaseTest("test_add_complex4"),
BaseTest("test_add_complex4", test_build_separate=True),
BaseTest("test_as_strided"), # buffer reuse
BaseTest("test_bernoulli1"),
BaseTest("test_bernoulli1_combo_kernels_False"),
BaseTest("test_bernoulli1_combo_kernels_True"),
BaseTest("test_bitwise"), # int32
BaseTest("test_bmm1"),
BaseTest("test_bmm1", test_build_separate=True),

View File

@ -196,7 +196,8 @@ if RUN_GPU:
BaseTest("test_add_complex4"),
BaseTest("test_as_strided"), # buffer reuse
BaseTest("test_batch_norm_2d_2"),
BaseTest("test_bernoulli1"),
BaseTest("test_bernoulli1_combo_kernels_False"),
BaseTest("test_bernoulli1_combo_kernels_True"),
BaseTest("test_bitwise"), # int32
BaseTest("test_bmm1"),
BaseTest("test_bmm2"),

View File

@ -4878,29 +4878,32 @@ class CommonTemplate:
@skip_if_gpu_halide # slow
@xfail_if_mps # Non-divisible input sizes are not implemented on MPS device
def test_adaptive_avg_pool2d1(self):
def fn(x):
return aten._adaptive_avg_pool2d(x, (6, 6)), aten._adaptive_avg_pool2d(
x + 1, (2, 5)
@parametrize("combo_kernels", (False, True))
def test_adaptive_avg_pool2d1(self, combo_kernels):
with config.patch(combo_kernels=combo_kernels):
def fn(x):
return aten._adaptive_avg_pool2d(x, (6, 6)), aten._adaptive_avg_pool2d(
x + 1, (2, 5)
)
self.common(
fn,
(torch.randn(2, 4, 16, 16),),
check_lowp=False,
)
self.common(
fn,
(torch.randn(2, 4, 16, 16),),
check_lowp=False,
)
# lowering to avg_pool2d case
self.common(
fn,
(torch.randn(2, 4, 3, 3),),
)
# lowering to avg_pool2d case
self.common(
fn,
(torch.randn(2, 4, 3, 3),),
)
# no-op case
self.common(
fn,
(torch.randn(2, 4, 6, 6),),
)
# no-op case
self.common(
fn,
(torch.randn(2, 4, 6, 6),),
)
@xfail_if_mps # Non-divisible input sizes are not implemented on MPS device
def test_adaptive_avg_pool2d2(self):
@ -8484,22 +8487,25 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
self.common(fn, [torch.randn(1, 1024), torch.randn(1, 1024, 2)])
@parametrize("combo_kernels", (False, True))
@config.patch(fallback_random=True)
def test_bernoulli1(self):
def fn(a):
b = a.clone()
# aten.bernoulli_() uses aten.bernoulli.p() behind the scene, so it will be decomposed.
return aten.bernoulli_(b).sum() / torch.prod(torch.tensor(a.size()))
def test_bernoulli1(self, combo_kernels):
with config.patch(combo_kernels=combo_kernels):
p = 0.3
self.common(
fn,
[
torch.ones(200, 200) * p,
],
atol=p * 0.06,
rtol=0.06,
)
def fn(a):
b = a.clone()
# aten.bernoulli_() uses aten.bernoulli.p() behind the scene, so it will be decomposed.
return aten.bernoulli_(b).sum() / torch.prod(torch.tensor(a.size()))
p = 0.3
self.common(
fn,
[
torch.ones(200, 200) * p,
],
atol=p * 0.06,
rtol=0.06,
)
@skip_if_triton_cpu
def test_bernoulli2(self):

View File

@ -6107,14 +6107,18 @@ class Scheduler:
If config.benchmark_fusion is False, always return True.
Otherwise, return True if fusion can bring speedup.
"""
if not config.benchmark_combo_kernel:
return True
subkernel_nodes = nodes
device = subkernel_nodes[0].get_device()
if not all(node.get_device() == device for node in subkernel_nodes):
return False
# don't support benchmark fusion for CPU C++ backend right now.
if device is None or (device.type == "cpu" and config.cpu_backend != "triton"):
return False
if not config.benchmark_combo_kernel:
return True
from triton.compiler.errors import CompilationError