Add generator parameter to rand*_like functions (#136780)

Fixes #128786
Fixes #101974
Fixes #27072

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136780
Approved by: https://github.com/Chillee, https://github.com/ezyang
Authored by Sam on 2025-01-15 21:16:50 +00:00, committed by PyTorch MergeBot
commit c7b2f7dd14 (parent d62b3979da)
7 changed files with 265 additions and 17 deletions
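
With this change, torch.rand_like, torch.randn_like, and torch.randint_like accept the same keyword-only generator argument that torch.rand/randn/randint already take, so *_like sampling can be made reproducible without touching the global RNG. A minimal usage sketch (tensor and variable names are illustrative):

    import torch

    x = torch.empty(3, 4)                    # only supplies size/dtype/device
    g = torch.Generator().manual_seed(0)     # explicit RNG state

    a = torch.rand_like(x, generator=g)      # uniform on [0, 1)
    g.manual_seed(0)
    b = torch.rand_like(x, generator=g)      # same seed -> same values
    assert torch.equal(a, b)

    n = torch.randn_like(x, generator=g)     # standard normal
    c = torch.rand_like(x)                   # no generator: global RNG, as before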

@@ -1094,13 +1094,31 @@ Tensor rand_like(
std::optional<Device> device,
std::optional<bool> pin_memory,
std::optional<c10::MemoryFormat> optional_memory_format) {
return native::rand_like(
self,
static_cast<std::optional<Generator>>(std::nullopt),
dtype,
layout,
device,
pin_memory,
optional_memory_format);
}
Tensor rand_like(
const Tensor& self,
std::optional<Generator> generator,
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device,
std::optional<bool> pin_memory,
std::optional<c10::MemoryFormat> optional_memory_format) {
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options =
TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
pin_memory);
auto result = at::empty_like(self, options, optional_memory_format);
return result.uniform_(0, 1, std::nullopt);
return result.uniform_(0, 1, std::move(generator));
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randint ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1203,13 +1221,37 @@ Tensor randint_like(
std::optional<Device> device,
std::optional<bool> pin_memory,
std::optional<c10::MemoryFormat> optional_memory_format) {
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options =
TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
pin_memory);
return native::randint_like(
self,
0,
high,
static_cast<std::optional<Generator>>(std::nullopt),
dtype,
layout,
device,
pin_memory,
optional_memory_format);
}
auto result = at::empty_like(self, options, optional_memory_format);
return result.random_(0, high, std::nullopt);
Tensor randint_like(
const Tensor& self,
int64_t high,
std::optional<Generator> generator,
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device,
std::optional<bool> pin_memory,
std::optional<c10::MemoryFormat> optional_memory_format) {
return native::randint_like(
self,
0,
high,
std::move(generator),
dtype,
layout,
device,
pin_memory,
optional_memory_format);
}
Tensor randint_like(
@@ -1221,13 +1263,35 @@ Tensor randint_like(
std::optional<Device> device,
std::optional<bool> pin_memory,
std::optional<c10::MemoryFormat> optional_memory_format) {
return native::randint_like(
self,
low,
high,
static_cast<std::optional<Generator>>(std::nullopt),
dtype,
layout,
device,
pin_memory,
optional_memory_format);
}
Tensor randint_like(
const Tensor& self,
int64_t low,
int64_t high,
std::optional<Generator> generator,
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device,
std::optional<bool> pin_memory,
std::optional<c10::MemoryFormat> optional_memory_format) {
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options =
TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
pin_memory);
auto result = at::empty_like(self, options, optional_memory_format);
return result.random_(low, high, std::nullopt);
return result.random_(low, high, std::move(generator));
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randn ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1310,13 +1374,31 @@ Tensor randn_like(
std::optional<Device> device,
std::optional<bool> pin_memory,
std::optional<c10::MemoryFormat> optional_memory_format) {
return native::randn_like(
self,
static_cast<std::optional<Generator>>(std::nullopt),
dtype,
layout,
device,
pin_memory,
optional_memory_format);
}
Tensor randn_like(
const Tensor& self,
std::optional<Generator> generator,
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device,
std::optional<bool> pin_memory,
std::optional<c10::MemoryFormat> optional_memory_format) {
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options =
TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
pin_memory);
auto result = at::empty_like(self, options, optional_memory_format);
return result.normal_(0, 1, std::nullopt);
return result.normal_(0, 1, std::move(generator));
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randperm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
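
The composite kernels above all follow the same recipe: the legacy overload forwards a std::nullopt generator to the new overload, which allocates with empty_like and then samples in place with the given generator. A rough Python restatement of that behavior (a sketch of the semantics, not the actual dispatch path):

    import torch

    def rand_like_sketch(t, generator=None):
        # mirrors empty_like(self, ...).uniform_(0, 1, generator)
        return torch.empty_like(t).uniform_(0, 1, generator=generator)

    def randint_like_sketch(t, low, high, generator=None):
        # mirrors empty_like(self, ...).random_(low, high, generator)
        return torch.empty_like(t).random_(low, high, generator=generator)

    def randn_like_sketch(t, generator=None):
        # mirrors empty_like(self, ...).normal_(0, 1, generator)
        return torch.empty_like(t).normal_(0, 1, generator=generator)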

@@ -4709,6 +4709,14 @@
    CompositeExplicitAutograd: rand_like
  autogen: rand_like.out

- func: rand_like.generator(Tensor self, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
  tags: nondeterministic_seeded
  dispatch:
    # NB: Although this composite mutates on the inside, it is
    # non-differentiable so NonFunctional doesn't apply
    CompositeExplicitAutograd: rand_like
  autogen: rand_like.generator_out

- func: randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  tags: nondeterministic_seeded
  dispatch:
@@ -4757,6 +4765,14 @@
    CompositeExplicitAutograd: randint_like
  autogen: randint_like.out

- func: randint_like.generator(Tensor self, SymInt high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
  tags: nondeterministic_seeded
  dispatch:
    # NB: Although this composite mutates on the inside, it is
    # non-differentiable so NonFunctional doesn't apply
    CompositeExplicitAutograd: randint_like
  autogen: randint_like.generator_out

- func: randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
  tags: nondeterministic_seeded
  dispatch:
@@ -4765,6 +4781,14 @@
    CompositeExplicitAutograd: randint_like
  autogen: randint_like.low_dtype_out

- func: randint_like.generator_with_low_dtype(Tensor self, SymInt low, SymInt high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
  tags: nondeterministic_seeded
  dispatch:
    # NB: Although this composite mutates on the inside, it is
    # non-differentiable so NonFunctional doesn't apply
    CompositeExplicitAutograd: randint_like
  autogen: randint_like.generator_with_low_dtype_out

- func: randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  tags: [core, nondeterministic_seeded]
  dispatch:
@@ -4805,6 +4829,14 @@
    CompositeExplicitAutograd, CompositeImplicitAutogradNestedTensor: randn_like
  autogen: randn_like.out

- func: randn_like.generator(Tensor self, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
  tags: nondeterministic_seeded
  dispatch:
    # NB: Although this composite mutates on the inside, it is
    # non-differentiable so NonFunctional doesn't apply
    CompositeExplicitAutograd, CompositeImplicitAutogradNestedTensor: randn_like
  autogen: randn_like.generator_out

- func: randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  tags: [core, nondeterministic_seeded]
  dispatch:
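
Each new entry above registers a distinct overload (rand_like.generator, randint_like.generator, randint_like.generator_with_low_dtype, randn_like.generator) plus an autogenerated out= variant, all tagged nondeterministic_seeded like the other seeded RNG ops. A quick way to poke at the registration from Python (a sketch; assumes a build that includes these schemas):

    import torch

    op = torch.ops.aten.rand_like.generator
    print(op._schema)   # rand_like.generator(Tensor self, *, Generator? generator, ...) -> Tensor

    g = torch.Generator().manual_seed(0)
    y = op(torch.empty(2, 2), generator=g)   # calls the overload directly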

@@ -1072,6 +1072,8 @@ aten::rand.names
aten::rand.names_out
aten::rand.out
aten::rand_like
aten::rand_like.generator
aten::rand_like.generator_out
aten::rand_like.out
aten::randint
aten::randint.generator
@@ -1082,6 +1084,10 @@ aten::randint.low_generator_out
aten::randint.low_out
aten::randint.out
aten::randint_like
aten::randint_like.generator
aten::randint_like.generator_out
aten::randint_like.generator_with_low_dtype
aten::randint_like.generator_with_low_dtype_out
aten::randint_like.low_dtype
aten::randint_like.low_dtype_out
aten::randint_like.out
@@ -1091,6 +1097,8 @@ aten::randn.generator_with_names_out
aten::randn.names
aten::randn.names_out
aten::randn_like
aten::randn_like.generator
aten::randn_like.generator_out
aten::randn_like.out
aten::random
aten::random.from

@@ -3660,6 +3660,116 @@ class TestRandomTensorCreation(TestCase):
        self.assertRaisesRegex(RuntimeError, regex, lambda: torch.randperm(n, device='cpu', generator=cuda_gen, out=cpu_t))
        self.assertRaisesRegex(RuntimeError, regex, lambda: torch.randperm(n, generator=cuda_gen))  # implicitly on CPU

    @dtypes(*integral_types_and(torch.uint16, torch.uint32, torch.uint64))
    def test_randint_like(self, device, dtype):
        SIZE = 100
        RANGE = (0, 6)

        def seed(generator):
            if generator is None:
                torch.manual_seed(123456)
            else:
                generator.manual_seed(123456)
            return generator

        tensor = torch.empty((SIZE, SIZE), device=device, dtype=dtype)
        gen = torch.Generator(device=device)

        # Using default generator
        generator = seed(None)
        res1 = torch.randint(*RANGE, tensor.size(), device=tensor.device, dtype=tensor.dtype,
                             layout=tensor.layout, generator=generator)
        generator = seed(None)
        res2 = torch.randint_like(tensor, *RANGE, generator=generator)
        self.assertEqual(res1, res2, exact_device=True, exact_layout=True)

        # Using explicit generator
        generator = seed(gen)
        res1 = torch.randint(*RANGE, tensor.size(), device=tensor.device, dtype=tensor.dtype,
                             layout=tensor.layout, generator=generator)
        generator = seed(gen)
        res2 = torch.randint_like(tensor, *RANGE, generator=generator)
        self.assertEqual(res1, res2, exact_device=True, exact_layout=True)

        # Default vs. explicit
        generator = seed(gen)
        res1 = torch.randint_like(tensor, *RANGE, generator=generator)
        generator = seed(None)
        res2 = torch.randint_like(tensor, *RANGE, generator=generator)
        self.assertEqual(res1, res2, exact_device=True, exact_layout=True)

    @dtypes(torch.half, torch.float, torch.bfloat16, torch.double,
            torch.complex32, torch.complex64, torch.complex128)
    def test_randn_like(self, device, dtype):
        SIZE = 100

        def seed(generator):
            if generator is None:
                torch.manual_seed(123456)
            else:
                generator.manual_seed(123456)
            return generator

        tensor = torch.empty((SIZE, SIZE), device=device, dtype=dtype)
        gen = torch.Generator(device=device)

        # Using default generator
        generator = seed(None)
        res1 = torch.randn(tensor.size(), device=tensor.device, dtype=tensor.dtype, layout=tensor.layout, generator=generator)
        generator = seed(None)
        res2 = torch.randn_like(tensor, generator=generator)
        self.assertEqual(res1, res2, exact_device=True, exact_layout=True)

        # Using explicit generator
        generator = seed(gen)
        res1 = torch.randn(tensor.size(), device=tensor.device, dtype=tensor.dtype, layout=tensor.layout, generator=generator)
        generator = seed(gen)
        res2 = torch.randn_like(tensor, generator=generator)
        self.assertEqual(res1, res2, exact_device=True, exact_layout=True)

        # Default vs. explicit
        generator = seed(gen)
        res1 = torch.randn_like(tensor, generator=generator)
        generator = seed(None)
        res2 = torch.randn_like(tensor, generator=generator)
        self.assertEqual(res1, res2, exact_device=True, exact_layout=True)

    @dtypes(torch.float, torch.double, torch.complex32, torch.complex64, torch.complex128)
    def test_rand_like(self, device, dtype):
        SIZE = 100

        def seed(generator):
            if generator is None:
                torch.manual_seed(123456)
            else:
                generator.manual_seed(123456)
            return generator

        tensor = torch.empty((SIZE, SIZE), device=device, dtype=dtype)
        gen = torch.Generator(device=device)

        # Using default generator
        generator = seed(None)
        res1 = torch.rand(tensor.size(), device=tensor.device, dtype=tensor.dtype, layout=tensor.layout, generator=generator)
        generator = seed(None)
        res2 = torch.rand_like(tensor, generator=generator)
        self.assertEqual(res1, res2, exact_device=True, exact_layout=True)

        # Using explicit generator
        generator = seed(gen)
        res1 = torch.rand(tensor.size(), device=tensor.device, dtype=tensor.dtype, layout=tensor.layout, generator=generator)
        generator = seed(gen)
        res2 = torch.rand_like(tensor, generator=generator)
        self.assertEqual(res1, res2, exact_device=True, exact_layout=True)

        # Default vs. explicit
        generator = seed(gen)
        res1 = torch.rand_like(tensor, generator=generator)
        generator = seed(None)
        res2 = torch.rand_like(tensor, generator=generator)
        self.assertEqual(res1, res2, exact_device=True, exact_layout=True)


# Class for testing *like ops, like torch.ones_like
class TestLikeTensorCreation(TestCase):
    exact_dtype = True
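
The three tests added above share one pattern: seed either the default RNG or an explicit torch.Generator, draw once through the sized constructor and once through the *_like variant, and require bitwise-equal results. Condensed to its essence (CPU shown for brevity):

    import torch

    x = torch.empty(100, 100)
    gen = torch.Generator(device="cpu")

    gen.manual_seed(123456)
    ref = torch.randn(x.size(), dtype=x.dtype, device=x.device, generator=gen)
    gen.manual_seed(123456)
    out = torch.randn_like(x, generator=gen)
    assert torch.equal(ref, out)   # same generator state -> identical samples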

@@ -65,12 +65,20 @@ _like_tensor_constructors = ordered_set(
    aten.ones_like.out,
    aten.rand_like.default,
    aten.rand_like.out,
    aten.rand_like.generator,
    aten.rand_like.generator_out,
    aten.randn_like.default,
    aten.randn_like.out,
    aten.randn_like.generator,
    aten.randn_like.generator_out,
    aten.randint_like.default,
    aten.randint_like.out,
    aten.randint_like.low_dtype,
    aten.randint_like.low_dtype_out,
    aten.randint_like.generator,
    aten.randint_like.generator_out,
    aten.randint_like.generator_with_low_dtype,
    aten.randint_like.generator_with_low_dtype_out,
    aten.zeros_like.default,
    aten.zeros_like.out,
    aten.new_empty.default,

@@ -129,6 +129,7 @@ factory_common_args = merge_dicts(
factory_like_common_args = parse_kwargs(
"""
input (Tensor): the size of :attr:`input` will determine size of the output tensor.
generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling
layout (:class:`torch.layout`, optional): the desired layout of returned tensor.
Default: if ``None``, defaults to the layout of :attr:`input`.
dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor.
@@ -8816,8 +8817,9 @@ Example::
add_docstr(
torch.rand_like,
r"""
rand_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor
"""
rand_like(input, *, generator=None, dtype=None, layout=None, device=None, requires_grad=False, \
memory_format=torch.preserve_format) -> Tensor
Returns a tensor with the same size as :attr:`input` that is filled with
random numbers from a uniform distribution on the interval :math:`[0, 1)`.
@@ -8828,6 +8830,7 @@ Args:
{input}
Keyword args:
{generator}
{dtype}
{layout}
{device}
@@ -8888,7 +8891,7 @@ Example::
add_docstr(
torch.randint_like,
"""
randint_like(input, low=0, high, \\*, dtype=None, layout=torch.strided, device=None, requires_grad=False, \
randint_like(input, low=0, high, \\*, generator=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, \
memory_format=torch.preserve_format) -> Tensor
Returns a tensor with the same shape as Tensor :attr:`input` filled with
@@ -8905,6 +8908,7 @@ Args:
high (int): One above the highest integer to be drawn from the distribution.
Keyword args:
{generator}
{dtype}
{layout}
{device}
@@ -8972,8 +8976,9 @@ Example::
add_docstr(
torch.randn_like,
r"""
randn_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor
"""
randn_like(input, *, generator=None, dtype=None, layout=None, device=None, requires_grad=False, \
memory_format=torch.preserve_format) -> Tensor
Returns a tensor with the same size as :attr:`input` that is filled with
random numbers from a normal distribution with mean 0 and variance 1. Please refer to :func:`torch.randn` for the
@@ -8984,6 +8989,7 @@ Args:
{input}
Keyword args:
{generator}
{dtype}
{layout}
{device}
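
As the updated docstrings note, generator is keyword-only for all three functions, and randint_like keeps its unusual low=0 defaulted positional form, e.g. (illustrative values):

    import torch

    x = torch.empty(4, 4, dtype=torch.int64)
    g = torch.Generator().manual_seed(0)

    a = torch.randint_like(x, 6, generator=g)      # high=6, values in [0, 6)
    b = torch.randint_like(x, 2, 6, generator=g)   # explicit low=2, high=6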

@@ -1065,9 +1065,11 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
),
torch.rad2deg: lambda input, out=None: -1,
torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
torch.randn_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch.rand_like: lambda input, dtype=None, layout=None, device=None, generator=None, requires_grad=False: -1,
torch.randint_like: (
lambda input, high, dtype=None, layout=torch.strided, device=None, generator=None, requires_grad=False: -1
),
torch.randn_like: lambda input, dtype=None, layout=None, device=None, generator=None, requires_grad=False: -1,
torch.ravel: lambda input: -1,
torch.real: lambda input, out=None: -1,
torch.vdot: lambda input, other, out=None: -1,
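
The get_testing_overrides lambdas above keep torch.overrides' signature bookkeeping in sync, so __torch_function__ handlers and override tests see the new keyword. A quick check of the recorded signature (not part of the PR, just a sketch):

    import inspect
    import torch
    from torch.overrides import get_testing_overrides

    stub = get_testing_overrides()[torch.rand_like]
    print(inspect.signature(stub))   # includes generator=None alongside dtype/layout/device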