mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 16:44:58 +08:00
Summary: ## Rationale While most of the `torch.Generator` properties and methods are implemented as thin wrappers around the corresponding `at::Generator` methods, `torch.Generator.get_state()` and `torch.Generator.set_state()` are implemented in legacy Torch code and are not dispatched through the `c10::GeneratorImpl` interface. This is poorly structured and makes implementing generators for new backends (e.g. `XLAGeneratorImpl` for the XLA backend) inconvenient. This pull request therefore moves these generator state APIs to c10 and ATen. ## What is being refactored? * Interfaces - Added `c10::GeneratorImpl::set_state` and `c10::GeneratorImpl::get_state` for setting and getting the internal state of a random number generator. - `at::Generator::set_state` and `at::Generator::get_state` wrap the above-mentioned APIs, since `at::Generator` is essentially a PIMPL. - Added helper function `at::detail::check_rng_state` for checking the validity of a new RNG state tensor. * CPU Generator - Renamed and moved `THTensor_(setRNGState)` and `THTensor_(getRNGState)` to `CPUGeneratorImpl::set_state` and `CPUGeneratorImpl::get_state`. - Renamed and moved `THGeneratorState` and `THGeneratorStateNew` to `CPUGeneratorStateLegacy` and `CPUGeneratorState`. * CUDA Generator - Renamed and moved `THCRandom_setRNGState` and `THCRandom_getRNGState` to `CUDAGeneratorImpl::set_state` and `CUDAGeneratorImpl::get_state`. * PyTorch Bindings - `THPGenerator_setState` and `THPGenerator_getState` now simply forward to `at::Generator::set_state` and `at::Generator::get_state`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/49589 Reviewed By: H-Huang Differential Revision: D25785774 Pulled By: pbelevich fbshipit-source-id: 8ed79209c4ffb1a0ae8b19952ac8871ac9e0255f
69 lines
2.5 KiB
C++
69 lines
2.5 KiB
C++
#include <torch/extension.h>
|
|
#include <torch/library.h>
|
|
#include <ATen/Generator.h>
|
|
#include <ATen/Tensor.h>
|
|
#include <ATen/native/DistributionTemplates.h>
|
|
#include <ATen/native/cpu/DistributionTemplates.h>
|
|
#include <memory>
|
|
|
|
using namespace at;
|
|
|
|
// Number of TestCPUGenerator objects currently alive: incremented by the
// generator's constructor, decremented by its destructor, and reported to
// Python through getInstanceCount(). Internal linkage keeps it file-local.
namespace {
size_t instance_count{0};
}  // namespace
|
|
|
|
struct TestCPUGenerator : public c10::GeneratorImpl {
|
|
TestCPUGenerator(uint64_t value) : c10::GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(DispatchKey::CustomRNGKeyId)}, value_(value) {
|
|
++instance_count;
|
|
}
|
|
~TestCPUGenerator() {
|
|
--instance_count;
|
|
}
|
|
uint32_t random() { return static_cast<uint32_t>(value_); }
|
|
uint64_t random64() { return value_; }
|
|
void set_current_seed(uint64_t seed) override { throw std::runtime_error("not implemented"); }
|
|
uint64_t current_seed() const override { throw std::runtime_error("not implemented"); }
|
|
uint64_t seed() override { throw std::runtime_error("not implemented"); }
|
|
void set_state(const c10::TensorImpl& new_state) override { throw std::runtime_error("not implemented"); }
|
|
c10::intrusive_ptr<c10::TensorImpl> get_state() const override { throw std::runtime_error("not implemented"); }
|
|
TestCPUGenerator* clone_impl() const override { throw std::runtime_error("not implemented"); }
|
|
|
|
static DeviceType device_type() { return DeviceType::CPU; }
|
|
|
|
uint64_t value_;
|
|
};
|
|
|
|
// Kernel for `aten::random_` under CustomRNGKeyId. It forwards to the generic
// random_impl template, using the CPU RandomKernel and TestCPUGenerator as the
// generator type. It returns the result of random_impl, which is bound to the
// `self` reference.
Tensor& random_(Tensor& self, c10::optional<Generator> generator) {
  namespace tmpl = at::native::templates;
  return tmpl::random_impl<tmpl::cpu::RandomKernel, TestCPUGenerator>(self, generator);
}
|
|
|
|
// Kernel for `aten::random_.from` under CustomRNGKeyId. It forwards to the
// generic random_from_to_impl template, using the CPU RandomFromToKernel and
// TestCPUGenerator as the generator type.
//
// `to` is optional. Its empty-case semantics are whatever random_from_to_impl
// defines; this wrapper passes it through unchanged. The parameter is spelled
// c10::optional, like every other optional in this file, rather than relying
// on `using namespace at;` to resolve a bare `optional`.
Tensor& random_from_to(Tensor& self, int64_t from, c10::optional<int64_t> to, c10::optional<Generator> generator) {
  return at::native::templates::random_from_to_impl<native::templates::cpu::RandomFromToKernel, TestCPUGenerator>(self, from, to, generator);
}
|
|
|
|
// Kernel for `aten::random_.to`. It is `random_from_to` with the lower bound
// fixed at 0 and `to` passed through as the upper bound.
Tensor& random_to(Tensor& self, int64_t to, c10::optional<Generator> generator) {
  constexpr int64_t kLowerBound = 0;
  return random_from_to(self, kLowerBound, to, generator);
}
|
|
|
|
// Factory exposed to Python. It wraps a new TestCPUGenerator whose draws all
// return `value` in an at::Generator handle.
Generator createTestCPUGenerator(uint64_t value) {
  Generator gen = at::make_generator<TestCPUGenerator>(value);
  return gen;
}
|
|
|
|
Generator identity(Generator g) {
|
|
return g;
|
|
}
|
|
|
|
size_t getInstanceCount() {
|
|
return instance_count;
|
|
}
|
|
|
|
// Registers this file's random_ kernels for the `aten` operators under the
// CustomRNGKeyId dispatch key, the key that TestCPUGenerator is tagged with.
TORCH_LIBRARY_IMPL(aten, CustomRNGKeyId, lib) {
  // random_(self, from, to?, *, generator?)
  lib.impl("aten::random_.from", random_from_to);
  // random_(self, to, *, generator?)
  lib.impl("aten::random_.to", random_to);
  // random_(self, *, generator?)
  lib.impl("aten::random_", random_);
}
|
|
|
|
// Python bindings for the extension: a generator factory, the live-instance
// counter, and the identity pass-through.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
  ext.def("createTestCPUGenerator", &createTestCPUGenerator);
  ext.def("getInstanceCount", &getInstanceCount);
  ext.def("identity", &identity);
}
|