Revert "Fix torch.normal ignores default_device (#144070)"

This reverts commit 184549b2d7e59acfc6e47d121e9ebb50648945b3.

Reverted https://github.com/pytorch/pytorch/pull/144070 on behalf of https://github.com/ezyang due to broken a specific use case ([comment](https://github.com/pytorch/pytorch/pull/144070#issuecomment-2590681953))
PyTorch MergeBot
2025-01-14 17:41:58 +00:00
parent 7977a3638e
commit d21738f24a
3 changed files with 3 additions and 10 deletions
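
For context, the reverted change (#144070) made the factory-style overload of `torch.normal` honor the device set via `torch.set_default_device`. A minimal sketch of the behavior in question, mirroring the removed test below (assumes a CUDA device is available; any non-default device illustrates the point):

```python
import torch

device = "cuda"  # assumption: a CUDA build/device is present
try:
    torch.set_default_device(device)
    # Factory-style overload: creates a fresh tensor rather than sampling into one.
    t = torch.normal(0, 1, (10, 10))
finally:
    torch.set_default_device(None)

# With the reverted fix the sample landed on the default device ("cuda:0");
# after this revert it is created on the CPU again.
print(t.device)
```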


@@ -3388,14 +3388,6 @@ class TestRandomTensorCreation(TestCase):
         with self.assertRaisesRegex(RuntimeError, r'normal expects all elements of std >= 0.0'):
             torch.normal(input, std)
 
-    def test_normal_default_device(self, device):
-        try:
-            torch.set_default_device(device)
-            t = torch.normal(0, 1, (10, 10))
-        finally:
-            torch.set_default_device(None)
-        self.assertEqual(str(t.device), device)
-
     # https://github.com/pytorch/pytorch/issues/126834
     @xfailIfTorchDynamo
     @dtypes(torch.float, torch.double, torch.half)


@@ -1166,7 +1166,7 @@ def set_default_device(
     .. note::
 
         This doesn't affect functions that create tensors that share the same memory as the input, like:
-        :func:`torch.from_numpy` and :func:`torch.frombuffer`. Using :func:`torch.Tensor.to` move tensor to desired device.
+        :func:`torch.from_numpy` and :func:`torch.frombuffer`
 
     Args:
         device (device or string): the device to set as default
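
The note being restored above is easy to demonstrate; the sketch below (not part of the diff) shows that a memory-sharing constructor such as `torch.from_numpy` ignores the default device, and that an explicit `torch.Tensor.to` call is needed to move the result:

```python
import numpy as np
import torch

torch.set_default_device("meta")  # "meta" chosen only to avoid needing a GPU
try:
    a = np.zeros((2, 2))
    t = torch.from_numpy(a)  # shares memory with `a`, so it stays on the CPU
    print(t.device)          # cpu, despite the non-CPU default device
finally:
    torch.set_default_device(None)

# An explicit copy moves it to the desired device (assumes CUDA is available):
# t_gpu = t.to("cuda")
```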


@@ -31,7 +31,8 @@ def _device_constructors():
         torch.linspace,
         torch.logspace,
         torch.nested.nested_tensor,
-        torch.normal,
+        # This function doesn't actually take a device argument
+        # torch.normal,
         torch.ones,
         torch.rand,
         torch.randn,