Fix torch.normal ignores default_device (#144070)

Fixes #122886

1. Enable `torch.normal` working with `DeviceContext` to get default device which set via `set_default_device`.
2. Add hint in `set_default_device` doc, suggest use `torch.Tensor.to` method move to desired device explicitly.

**Test Result**
1. **Doc Preview**
![image](https://github.com/user-attachments/assets/eb69c334-be2b-4dc5-bdce-567da21e1635)

2. **Local Test**
```python
>>> import torch
>>> torch.normal(0.,1., (10,10)).device
device(type='cpu')
>>> torch.set_default_device('cuda')
>>> torch.normal(0.,1., (10,10)).device
device(type='cuda', index=0)
```

```bash
pytest test/test_tensor_creation_ops.py
```

![image](https://github.com/user-attachments/assets/8b466b55-f162-4b83-8b20-71de2c1d0914)

```bash
lintrunner
```
![image](https://github.com/user-attachments/assets/5b269c50-da57-47ed-8500-4edf2c2295e4)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144070
Approved by: https://github.com/ezyang
This commit is contained in:
zeshengzong
2025-01-10 08:19:52 +00:00
committed by PyTorch MergeBot
parent 1fe3af2c68
commit 184549b2d7
3 changed files with 10 additions and 3 deletions

View File

@ -3386,6 +3386,14 @@ class TestRandomTensorCreation(TestCase):
with self.assertRaisesRegex(RuntimeError, r'normal expects all elements of std >= 0.0'):
torch.normal(input, std)
def test_normal_default_device(self, device):
try:
torch.set_default_device(device)
t = torch.normal(0, 1, (10, 10))
finally:
torch.set_default_device(None)
self.assertEqual(str(t.device), device)
# https://github.com/pytorch/pytorch/issues/126834
@xfailIfTorchDynamo
@dtypes(torch.float, torch.double, torch.half)

View File

@ -1148,7 +1148,7 @@ def set_default_device(
.. note::
This doesn't affect functions that create tensors that share the same memory as the input, like:
:func:`torch.from_numpy` and :func:`torch.frombuffer`
:func:`torch.from_numpy` and :func:`torch.frombuffer`. Using :func:`torch.Tensor.to` move tensor to desired device.
Args:
device (device or string): the device to set as default

View File

@ -31,8 +31,7 @@ def _device_constructors():
torch.linspace,
torch.logspace,
torch.nested.nested_tensor,
# This function doesn't actually take a device argument
# torch.normal,
torch.normal,
torch.ones,
torch.rand,
torch.randn,