mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-18 01:15:12 +08:00
add test_noncontiguous_samples_skips for xpu
This commit is contained in:
@ -101,9 +101,9 @@ _ref_test_ops = tuple(
|
||||
)
|
||||
)
|
||||
|
||||
ops_skips = defaultdict(dict)
|
||||
test_out_skips = defaultdict(dict)
|
||||
|
||||
ops_skips["xpu"] = {
|
||||
test_out_skips["xpu"] = {
|
||||
"_native_batch_norm_legit": {torch.float32},
|
||||
"addmv": {torch.float32},
|
||||
"cholesky_inverse": {torch.float32},
|
||||
@ -118,6 +118,17 @@ ops_skips["xpu"] = {
|
||||
"var": {torch.float32},
|
||||
}
|
||||
|
||||
test_noncontiguous_samples_skips = defaultdict(dict)
|
||||
|
||||
# Disable cases for https://github.com/intel/torch-xpu-ops/issues/2295
|
||||
test_noncontiguous_samples_skips["xpu"] = {
|
||||
"masked.amax": {torch.int64},
|
||||
"masked.amin": {torch.int64},
|
||||
"nn.functional.interpolate": {torch.float32},
|
||||
"amax": {torch.int64},
|
||||
"masked.prod": {torch.int64},
|
||||
}
|
||||
|
||||
|
||||
def reduction_dtype_filter(op):
|
||||
if (
|
||||
@ -845,6 +856,12 @@ class TestCommon(TestCase):
|
||||
@suppress_warnings
|
||||
@ops(op_db, allowed_dtypes=(torch.float32, torch.long, torch.complex64))
|
||||
def test_noncontiguous_samples(self, device, dtype, op):
|
||||
if dtype in test_noncontiguous_samples_skips[device.split(":")[0]].get(
|
||||
op.name, set()
|
||||
):
|
||||
raise unittest.SkipTest(
|
||||
f"Skipped! {op.name} does not support dtype {dtype} on {device} for this test."
|
||||
)
|
||||
test_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)
|
||||
sample_inputs = op.sample_inputs(device, dtype, requires_grad=test_grad)
|
||||
for sample_input in sample_inputs:
|
||||
@ -1072,9 +1089,9 @@ class TestCommon(TestCase):
|
||||
# - if device, dtype are passed, device and dtype should match
|
||||
@ops(ops_and_refs, dtypes=OpDTypes.any_one)
|
||||
def test_out(self, device, dtype, op):
|
||||
if dtype in ops_skips[device.split(":")[0]].get(op.name, set()):
|
||||
if dtype in test_out_skips[device.split(":")[0]].get(op.name, set()):
|
||||
raise unittest.SkipTest(
|
||||
f"Skipped! {op.name} does not support dtype {dtype} on {device}."
|
||||
f"Skipped! {op.name} does not support dtype {dtype} on {device} for this test."
|
||||
)
|
||||
|
||||
# Prefers running in float32 but has a fallback for the first listed supported dtype
|
||||
|
||||
Reference in New Issue
Block a user