mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Add Python Module Bindings for the MPS backend (#94417)
- This PR is a prerequisite for the upcoming Memory Leak Detection PR. - Enable global manual seeding via `torch.manual_seed()` + test case - Add `torch.mps.synchronize()` to wait for MPS stream to finish + test case - Enable the following python interfaces for MPS: `torch.mps.[get_rng_state(), set_rng_state(), synchronize(), manual_seed(), seed()]` - Added some test cases in test_mps.py - Added `mps.rst` to document the `torch.mps` module. - Fixed the failure with `test_public_bindings.py` Description of new files added: - `torch/csrc/mps/Module.cpp`: implements `torch._C` module functions for `torch.mps` and `torch.backends.mps`. - `torch/mps/__init__.py`: implements Python bindings for `torch.mps` module. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94417 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
a0f9abdcb6
commit
bdd8f518d7
@ -39,6 +39,10 @@ def manual_seed(seed) -> torch._C.Generator:
|
||||
if not torch.cuda._is_in_bad_fork():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
import torch.mps
|
||||
if not torch.mps._is_in_bad_fork():
|
||||
torch.mps.manual_seed(seed)
|
||||
|
||||
return default_generator.manual_seed(seed)
|
||||
|
||||
|
||||
@ -52,6 +56,10 @@ def seed() -> int:
|
||||
if not torch.cuda._is_in_bad_fork():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
import torch.mps
|
||||
if not torch.mps._is_in_bad_fork():
|
||||
torch.mps.manual_seed(seed)
|
||||
|
||||
return seed
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user