[Fix XPU CI] [Inductor UT] Fix test cases broken by community changes. (#165406)

Fixes #163159, Fixes #164098, Fixes #164097, Fixes #164099, Fixes #165025

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165406
Approved by: https://github.com/EikanWang, https://github.com/jansel
xinan.lin
2025-10-16 00:53:32 +00:00
committed by PyTorch MergeBot
parent 36371b8ec7
commit e5a9c247bc
5 changed files with 13 additions and 3 deletions

View File

@@ -5338,7 +5338,7 @@ class AOTInductorTestsTemplate:
record_shapes=True,
activities=[
torch.profiler.ProfilerActivity.CPU,
-torch.profiler.ProfilerActivity.CUDA,
+getattr(torch.profiler.ProfilerActivity, GPU_TYPE.upper()),
],
) as prof,
):
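
A minimal standalone sketch of the device-agnostic lookup used above; GPU_TYPE here is a stand-in for the test-suite helper of the same name and is assumed to resolve to "cuda" or "xpu":

import torch

# Stand-in for the test-suite GPU_TYPE helper: pick whichever GPU backend is present.
GPU_TYPE = "xpu" if getattr(torch, "xpu", None) and torch.xpu.is_available() else "cuda"

# getattr resolves ProfilerActivity.XPU or ProfilerActivity.CUDA at runtime,
# so the same profiling test body works on either backend.
activity = getattr(torch.profiler.ProfilerActivity, GPU_TYPE.upper())
print(activity)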

View File

@@ -4640,6 +4640,7 @@ class CommonTemplate:
(torch.randn([4, 4, 4]),),
)
+@skipIfXpu(msg="Incorrect reference on XPU, see issue #165392")
def test_conv1d_with_permute(self):
# fix https://github.com/pytorch/pytorch/issues/159462
class ConvModel(nn.Module):
@@ -15783,7 +15784,7 @@ if RUN_GPU:
).run(code)
else:
FileCheck().check_count(
"with torch.cuda._DeviceGuard(0)", 1, exactly=True
f"with torch.{GPU_TYPE}._DeviceGuard(0)", 1, exactly=True
).run(code)
class RNNTest(TestCase):

View File

@@ -111,6 +111,8 @@ test_failures = {
# Failed to find dynamic for loop variable:
#
"test_conv1d_with_permute_dynamic_shapes": TestFailure(("cpu",), is_skip=True),
+# XPU always converts conv1d to conv2d and so cannot match the expected codegen result.
+"test_conv1d_depthwise_dynamic_shapes": TestFailure(("xpu",), is_skip=True),
"test_arange1_dynamic_shapes": TestFailure(("cpu",)),
"test_arange2_dynamic_shapes": TestFailure(("cpu",)),
"test_arange3_dynamic_shapes": TestFailure(("cpu",)),

View File

@@ -646,7 +646,7 @@ inductor_override_kwargs["xpu"] = {
("tanh", f16): {"atol": 1e-4, "rtol": 1e-2},
("nn.functional.embedding_bag", f32): {"check_gradient": False},
("nn.functional.embedding_bag", f64): {"check_gradient": False},
("_unsafe_masked_index_put_accumulate", f16): {"atol": 1e-5, "rtol": 5e-3},
("_unsafe_masked_index_put_accumulate", f16): {"atol": 1e-4, "rtol": 0.01},
("_unsafe_masked_index", f16): {
"reference_in_float": True,
"atol": 3e-4,

View File

@@ -1152,6 +1152,13 @@ class TritonOverrides(OpOverrides):
out = f"triton.language.div_rn({x}, {y})"
else:
out = f"({x} / {y})"
+# Workaround here since div_rn is not yet functional on XPU.
+# TODO: remove this workaround once https://github.com/intel/intel-xpu-backend-for-triton/issues/5306
+# is resolved.
+if torch.xpu.is_available():
+out = f"({x} / {y})"
if low_precision_fp_var(x) or low_precision_fp_var(y):
out_dtype = get_dtype_handler().truediv(x, y)
if out_dtype in (torch.float16, torch.float32):