Add support for torch.Generator type in TorchScript (#110413)

- Add support for `torch.Generator` type in TorchScript
- Add `generator` args to all `torch.nn.init` functions that call `uniform_` or `normal_`
- Add support for `torch.Generator` in LTC's TorchScript backend (CC: @wconstab)

CC: @eellison @davidberard98 @GlebKazantaev @behzad-a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110413
Approved by: https://github.com/wconstab, https://github.com/albanD, https://github.com/glebk-cerebras, https://github.com/davidberard98
This commit is contained in:
Antonio Kim
2023-11-06 21:26:57 +00:00
committed by PyTorch MergeBot
parent 7b99b3efb1
commit 27e31ab6e8
39 changed files with 650 additions and 179 deletions

View File

@ -933,10 +933,10 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
distance_function=None, margin=1.0,
swap=False, reduction='mean': -1),
torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
torch.nn.init.uniform_: lambda tensor, a=0., b=1.: -1,
torch.nn.init.normal_: lambda tensor, mean=0., std=1.: -1,
torch.nn.init.uniform_: lambda tensor, a=0., b=1., generator=None: -1,
torch.nn.init.normal_: lambda tensor, mean=0., std=1., generator=None: -1,
torch.nn.init.constant_: lambda tensor, val: -1,
torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode='fan_in', nonlinearity='leaky_relu': -1,
torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None: -1,
torch.nonzero: lambda input, as_tuple=False: -1,
torch.nonzero_static: lambda input, *, size, fill_value=-1: -1,
torch.argwhere: lambda input: -1,