mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add support for torch.Generator
type in TorchScript (#110413)
- Add support for `torch.Generator` type in TorchScript - Add `generator` args to all `torch.nn.init` functions that call `uniform_` or `normal_` - Add support for `torch.Generator` in LTC's TorchScript backend (CC: @wconstab) CC: @eellison @davidberard98 @GlebKazantaev @behzad-a Pull Request resolved: https://github.com/pytorch/pytorch/pull/110413 Approved by: https://github.com/wconstab, https://github.com/albanD, https://github.com/glebk-cerebras, https://github.com/davidberard98
This commit is contained in:
committed by
PyTorch MergeBot
parent
3eacdaf1b3
commit
54493fe8c4
@ -1,4 +1,5 @@
|
||||
#include <ATen/autocast_mode.h>
|
||||
#include <ATen/core/Generator.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
|
||||
@ -2492,6 +2493,44 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs1{
|
||||
TORCH_SELECTIVE_SCHEMA("aten::manual_seed(int seed) -> ()"),
|
||||
[](Stack& stack) { at::manual_seed(pop(stack).toInt()); },
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA(
|
||||
"aten::Generator(*, Device? device=None, int? seed=None) -> Generator"),
|
||||
[](Stack& stack) {
|
||||
auto seed = pop(stack).toOptional<int64_t>();
|
||||
auto device = pop(stack).toOptional<c10::Device>();
|
||||
push(
|
||||
stack,
|
||||
at::make_generator_for_device(
|
||||
device.value_or(c10::Device("cpu")), seed));
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA("aten::initial_seed(Generator self) -> int"),
|
||||
[](Stack& stack) {
|
||||
auto generator = pop(stack);
|
||||
auto current_seed = generator.toGenerator().current_seed();
|
||||
push(stack, (int64_t)current_seed);
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA(
|
||||
"aten::manual_seed.generator(Generator(a!) self, int seed) -> Generator(a!)"),
|
||||
[](Stack& stack) {
|
||||
auto seed = pop(stack).toInt();
|
||||
auto generator = pop(stack);
|
||||
generator.toGenerator().set_current_seed(seed);
|
||||
push(stack, generator);
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA("aten::seed(Generator(a!) self) -> int"),
|
||||
[](Stack& stack) {
|
||||
auto generator = pop(stack);
|
||||
auto current_seed = generator.toGenerator().seed();
|
||||
push(stack, (int64_t)current_seed);
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA("aten::cuda(Tensor(a) self) -> Tensor(a|b)"),
|
||||
[](Stack& stack) {
|
||||
|
Reference in New Issue
Block a user