add fused support for xpu devices (#104517)

We want to add fused support for xpu devices in optimizer so we add 'xpu' to the fused support list.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104517
Approved by: https://github.com/ezyang
This commit is contained in:
Deng, Weishi
2023-07-05 21:06:56 +00:00
committed by PyTorch MergeBot
parent b5c2404116
commit f00f1d4cfb

View File

@ -14,7 +14,7 @@ def _get_fused_kernels_supported_devices() -> List[str]:
r"""
Return the device type list that supports fused kernels in optimizer.
"""
return ["cuda", torch._C._get_privateuse1_backend_name()]
return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]
# This util function splits tensors into groups by device and dtype, which is useful before sending
# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.