Add support for torch.Generator type in TorchScript (#110413)

- Add support for `torch.Generator` type in TorchScript
- Add `generator` args to all `torch.nn.init` functions that call `uniform_` or `normal_`
- Add support for `torch.Generator` in LTC's TorchScript backend (CC: @wconstab)

CC: @eellison @davidberard98 @GlebKazantaev @behzad-a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110413
Approved by: https://github.com/wconstab, https://github.com/albanD, https://github.com/glebk-cerebras, https://github.com/davidberard98
This commit is contained in:
Antonio Kim
2023-11-13 23:18:14 +00:00
committed by PyTorch MergeBot
parent 3eacdaf1b3
commit 54493fe8c4
39 changed files with 656 additions and 179 deletions

View File

@ -1,4 +1,5 @@
#include <ATen/autocast_mode.h>
#include <ATen/core/Generator.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
@ -2492,6 +2493,44 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs1{
TORCH_SELECTIVE_SCHEMA("aten::manual_seed(int seed) -> ()"),
[](Stack& stack) { at::manual_seed(pop(stack).toInt()); },
aliasAnalysisFromSchema()),
// aten::Generator: TorchScript constructor for torch.Generator.
// Both args are keyword-only (the `*` in the schema); they are popped in
// reverse schema order — seed first, then device.
OperatorGeneratorArgs(
    TORCH_SELECTIVE_SCHEMA(
        "aten::Generator(*, Device? device=None, int? seed=None) -> Generator"),
    [](Stack& stack) {
      auto seed = pop(stack).toOptional<int64_t>();
      auto device = pop(stack).toOptional<c10::Device>();
      // Build a generator for the requested device, defaulting to CPU.
      // An absent (nullopt) seed lets the backend pick its default seed.
      push(
          stack,
          at::make_generator_for_device(
              device.value_or(c10::Device("cpu")), seed));
    },
    aliasAnalysisFromSchema()),
// aten::initial_seed: returns the seed used to initialize the generator's
// current RNG state (Generator.initial_seed() semantics). Read-only.
OperatorGeneratorArgs(
    TORCH_SELECTIVE_SCHEMA("aten::initial_seed(Generator self) -> int"),
    [](Stack& stack) {
      auto generator = pop(stack);
      auto current_seed = generator.toGenerator().current_seed();
      // current_seed() yields uint64_t; narrow explicitly (named cast, not
      // a C-style cast) to the int64_t that TorchScript's `int` expects.
      push(stack, static_cast<int64_t>(current_seed));
    },
    aliasAnalysisFromSchema()),
// aten::manual_seed.generator: seeds an existing generator in place and
// returns that same generator (note the `a!` alias/mutation annotation).
OperatorGeneratorArgs(
    TORCH_SELECTIVE_SCHEMA(
        "aten::manual_seed.generator(Generator(a!) self, int seed) -> Generator(a!)"),
    [](Stack& stack) {
      // Stack arguments arrive in reverse schema order: seed, then self.
      const auto new_seed = pop(stack).toInt();
      auto self = pop(stack);
      self.toGenerator().set_current_seed(new_seed);
      // Hand the mutated generator back to the caller.
      push(stack, self);
    },
    aliasAnalysisFromSchema()),
// aten::seed: reseeds the generator (Generator.seed() semantics — the `a!`
// annotation marks `self` as mutated) and returns the newly chosen seed.
OperatorGeneratorArgs(
    TORCH_SELECTIVE_SCHEMA("aten::seed(Generator(a!) self) -> int"),
    [](Stack& stack) {
      auto generator = pop(stack);
      auto current_seed = generator.toGenerator().seed();
      // seed() yields uint64_t; narrow explicitly (named cast, not a
      // C-style cast) to the int64_t that TorchScript's `int` expects.
      push(stack, static_cast<int64_t>(current_seed));
    },
    aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::cuda(Tensor(a) self) -> Tensor(a|b)"),
[](Stack& stack) {