mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fix torch.normal ignores default_device (#144070)
Fixes #122886 1. Enable `torch.normal` working with `DeviceContext` to get default device which set via `set_default_device`. 2. Add hint in `set_default_device` doc, suggest use `torch.Tensor.to` method move to desired device explicitly. **Test Result** 1. **Doc Preview**  2. **Local Test** ```python >>> import torch >>> torch.normal(0.,1., (10,10)).device device(type='cpu') >>> torch.set_default_device('cuda') >>> torch.normal(0.,1., (10,10)).device device(type='cuda', index=0) ``` ```bash pytest test/test_tensor_creation_ops.py ```  ```bash lintrunner ```  Pull Request resolved: https://github.com/pytorch/pytorch/pull/144070 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
1fe3af2c68
commit
184549b2d7
@ -3386,6 +3386,14 @@ class TestRandomTensorCreation(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, r'normal expects all elements of std >= 0.0'):
|
||||
torch.normal(input, std)
|
||||
|
||||
def test_normal_default_device(self, device):
|
||||
try:
|
||||
torch.set_default_device(device)
|
||||
t = torch.normal(0, 1, (10, 10))
|
||||
finally:
|
||||
torch.set_default_device(None)
|
||||
self.assertEqual(str(t.device), device)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/126834
|
||||
@xfailIfTorchDynamo
|
||||
@dtypes(torch.float, torch.double, torch.half)
|
||||
|
@ -1148,7 +1148,7 @@ def set_default_device(
|
||||
.. note::
|
||||
|
||||
This doesn't affect functions that create tensors that share the same memory as the input, like:
|
||||
:func:`torch.from_numpy` and :func:`torch.frombuffer`
|
||||
:func:`torch.from_numpy` and :func:`torch.frombuffer`. Using :func:`torch.Tensor.to` move tensor to desired device.
|
||||
|
||||
Args:
|
||||
device (device or string): the device to set as default
|
||||
|
@ -31,8 +31,7 @@ def _device_constructors():
|
||||
torch.linspace,
|
||||
torch.logspace,
|
||||
torch.nested.nested_tensor,
|
||||
# This function doesn't actually take a device argument
|
||||
# torch.normal,
|
||||
torch.normal,
|
||||
torch.ones,
|
||||
torch.rand,
|
||||
torch.randn,
|
||||
|
Reference in New Issue
Block a user