Integrate xpu into torch.Generator and torch.seed (#109866)

Integrate torch.xpu.Generator into torch.Generator
Integrate torch.xpu.seed into torch.seed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109866
Approved by: https://github.com/ezyang
This commit is contained in:
Lei, Zhenyuan
2023-09-27 17:44:42 +00:00
committed by PyTorch MergeBot
parent 0511df0ee9
commit 633bd0765e
4 changed files with 44 additions and 6 deletions

View File

@ -43,6 +43,9 @@ def manual_seed(seed) -> torch._C.Generator:
if not torch.mps._is_in_bad_fork():
torch.mps.manual_seed(seed)
if hasattr(torch, 'xpu') and not torch.xpu._is_in_bad_fork():
torch.xpu.manual_seed(seed)
_seed_custom_device(seed)
return default_generator.manual_seed(seed)
@ -62,6 +65,9 @@ def seed() -> int:
if not torch.mps._is_in_bad_fork():
torch.mps.manual_seed(seed)
if hasattr(torch, 'xpu') and not torch.xpu._is_in_bad_fork():
torch.xpu.manual_seed(seed)
_seed_custom_device(seed)
return seed