mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This PR reland #161142 which is reverted to be able to revert other PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161949 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
4cdaf8265d
commit
81b7b16618
@ -599,6 +599,8 @@ class AOTFxirTestCase(InductorTestCase):
|
||||
device = GPU_TYPE
|
||||
|
||||
def check(self, model, inp, dynamic_shapes=None, strict=False):
|
||||
if self.device == "xpu":
|
||||
raise unittest.SkipTest("The feature AOTFxir not currently ready for XPU")
|
||||
with torch.no_grad():
|
||||
ep = torch.export.export(
|
||||
model, inp, dynamic_shapes=dynamic_shapes, strict=strict
|
||||
|
@ -381,11 +381,19 @@ class CommonTemplate:
|
||||
input_reader = InputReader()
|
||||
load_args(input_reader)
|
||||
args = input_reader.args
|
||||
if self.device == "xpu":
|
||||
atol = 1e-7
|
||||
rtol = 1e-5
|
||||
else:
|
||||
atol = None
|
||||
rtol = None
|
||||
|
||||
self._run_and_compare(
|
||||
forward,
|
||||
*args,
|
||||
expected_num_block_pointers=4,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
|
||||
@parametrize(
|
||||
|
@ -250,6 +250,7 @@ XPU_BLOCKLIST = [
|
||||
"profiler/test_profiler_tree",
|
||||
"profiler/test_record_function",
|
||||
"profiler/test_torch_tidy",
|
||||
"test_openreg",
|
||||
]
|
||||
|
||||
XPU_TEST = [
|
||||
|
@ -1344,6 +1344,13 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
|
||||
|
||||
return flex_decode_configs
|
||||
|
||||
def _prune_exhaustive_configs(
|
||||
self,
|
||||
configs: list[BaseConfig],
|
||||
dtype_size: int,
|
||||
) -> list[BaseConfig]:
|
||||
return configs
|
||||
|
||||
|
||||
class MTIAConfigHeuristic(BaseConfigHeuristic):
|
||||
"""
|
||||
|
@ -21078,6 +21078,7 @@ op_db: list[OpInfo] = [
|
||||
# NOTE: Only run on MPS
|
||||
DecorateInfo(unittest.skip('Skipped!'), device_type='cpu'),
|
||||
DecorateInfo(unittest.skip('Skipped!'), device_type='cuda'),
|
||||
DecorateInfo(unittest.skip('Skipped!'), device_type='xpu'),
|
||||
DecorateInfo(unittest.skip('Skipped!'), device_type='meta'),
|
||||
),),
|
||||
OpInfo(
|
||||
|
Reference in New Issue
Block a user