add test_noncontiguous_samples_skips for xpu

This commit is contained in:
Deng, Daisy
2025-11-07 06:53:50 +00:00
parent 2aa9a642a6
commit 2a87efebbe

View File

@ -101,9 +101,9 @@ _ref_test_ops = tuple(
)
)
ops_skips = defaultdict(dict)
test_out_skips = defaultdict(dict)
ops_skips["xpu"] = {
test_out_skips["xpu"] = {
"_native_batch_norm_legit": {torch.float32},
"addmv": {torch.float32},
"cholesky_inverse": {torch.float32},
@ -118,6 +118,17 @@ ops_skips["xpu"] = {
"var": {torch.float32},
}
test_noncontiguous_samples_skips = defaultdict(dict)
# Disable cases for https://github.com/intel/torch-xpu-ops/issues/2295
test_noncontiguous_samples_skips["xpu"] = {
"masked.amax": {torch.int64},
"masked.amin": {torch.int64},
"nn.functional.interpolate": {torch.float32},
"amax": {torch.int64},
"masked.prod": {torch.int64},
}
def reduction_dtype_filter(op):
if (
@ -845,6 +856,12 @@ class TestCommon(TestCase):
@suppress_warnings
@ops(op_db, allowed_dtypes=(torch.float32, torch.long, torch.complex64))
def test_noncontiguous_samples(self, device, dtype, op):
if dtype in test_noncontiguous_samples_skips[device.split(":")[0]].get(
op.name, set()
):
raise unittest.SkipTest(
f"Skipped! {op.name} does not support dtype {dtype} on {device} for this test."
)
test_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)
sample_inputs = op.sample_inputs(device, dtype, requires_grad=test_grad)
for sample_input in sample_inputs:
@ -1072,9 +1089,9 @@ class TestCommon(TestCase):
# - if device, dtype are passed, device and dtype should match
@ops(ops_and_refs, dtypes=OpDTypes.any_one)
def test_out(self, device, dtype, op):
if dtype in ops_skips[device.split(":")[0]].get(op.name, set()):
if dtype in test_out_skips[device.split(":")[0]].get(op.name, set()):
raise unittest.SkipTest(
f"Skipped! {op.name} does not support dtype {dtype} on {device}."
f"Skipped! {op.name} does not support dtype {dtype} on {device} for this test."
)
# Prefers running in float32 but has a fallback for the first listed supported dtype