Compare commits

...

10 Commits

Author SHA1 Message Date
29d9e9c762 Fix indentation 2025-11-11 15:20:41 +00:00
6cea2f04ca Advance global MKLGenerator state before generation
* Added a local `main_stream` copy of the global VSLStream
    and changed the kernel logic to immediately advance the global
    stream by the number of generated elements upon obtaining the copy.
    This fixes an issue where previously calling these kernels in
    multiple concurrent python threads could yield identical generated
    sequences (since the advance in the global stream happened only
    after the generation step was finished).
2025-11-10 12:31:42 +00:00
6bda7aa776 Update test_variance 2025-11-10 12:31:42 +00:00
232dabb5ab Update test_distributions.py 2025-11-10 12:31:42 +00:00
0b7eafd1c9 Update the bazel build to include mklrng 2025-11-10 12:31:42 +00:00
850cd5fa03 Allow more samples for convergence in tests of exponential 2025-11-10 12:31:42 +00:00
d068f2b695 Use VSL_BRNG_PHILOX4X32X10 instead of VSL_BRNG_MCG59 2025-11-10 12:31:42 +00:00
b10537378e Link MKLGenerator seed change to CPUGenerator seed changes 2025-11-10 12:31:42 +00:00
a6950289c3 Use MKLGeneratorImpl in DistributionKernels.cpp 2025-11-10 12:31:42 +00:00
7924f740f0 Implement MKLGeneratorImpl
* Implements MKLGeneratorImpl which uses MKL/OpenRNG to generate random
    variates and keeps a consistent global state.
  * Links MKLGenerator state change to CPUGenerator state changes
2025-11-10 12:31:41 +00:00
7 changed files with 310 additions and 42 deletions

View File

@ -195,6 +195,13 @@ filegroup(
]),
)
filegroup(
name = "aten_mklrng_cpp",
srcs = glob([
"aten/src/ATen/mklrng/*.cpp",
]),
)
filegroup(
name = "aten_native_mkldnn_cpp",
srcs = glob(["aten/src/ATen/native/mkldnn/*.cpp"]),
@ -357,6 +364,7 @@ cc_library(
":aten_base_cpp",
":aten_base_metal",
":aten_base_vulkan",
":aten_mklrng_cpp",
":aten_native_cpp",
":aten_native_mkl_cpp",
":aten_native_mkldnn_cpp",

View File

@ -85,6 +85,7 @@ file(GLOB miopen_h "miopen/*.h")
file(GLOB miopen_cpp "miopen/*.cpp")
file(GLOB mkl_cpp "mkl/*.cpp")
file(GLOB mklrng_cpp "mklrng/*.cpp")
file(GLOB mkldnn_cpp "mkldnn/*.cpp")
file(GLOB mkldnn_xpu_h "native/mkldnn/xpu/*.h" "native/mkldnn/xpu/detail/*.h")
@ -392,6 +393,7 @@ if(USE_LIGHTWEIGHT_DISPATCH)
endif()
if(AT_MKL_ENABLED)
set(all_cpu_cpp ${all_cpu_cpp} ${mkl_cpp})
set(all_cpu_cpp ${all_cpu_cpp} ${mklrng_cpp})
endif()
if(AT_KLEIDIAI_ENABLED)
set(all_cpu_cpp ${all_cpu_cpp} ${native_kleidiai})

View File

@ -1,9 +1,14 @@
#include <ATen/CPUGeneratorImpl.h>
#include <ATen/Config.h>
#include <ATen/Utils.h>
#include <ATen/core/MT19937RNGEngine.h>
#include <c10/util/MathConstants.h>
#include <algorithm>
#if AT_MKL_ENABLED()
#include <ATen/mklrng/MKLGeneratorImpl.h>
#endif
namespace at {
namespace detail {
@ -43,6 +48,10 @@ struct CPUGeneratorImplState {
CPUGeneratorImplStateLegacy legacy_pod;
float next_float_normal_sample;
bool is_next_float_normal_sample_valid;
#if AT_MKL_ENABLED()
uint64_t mkl_seed;
uint64_t mkl_offset;
#endif
};
/**
@ -82,7 +91,15 @@ CPUGeneratorImpl::CPUGeneratorImpl(uint64_t seed_in)
: c10::GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(c10::DispatchKey::CPU)},
engine_{seed_in},
next_float_normal_sample_{std::optional<float>()},
next_double_normal_sample_{std::optional<double>()} { }
next_double_normal_sample_{std::optional<double>()} {
#if AT_MKL_ENABLED()
{
auto mkl_gen = check_generator<MKLGeneratorImpl>(detail::getDefaultMKLGenerator());
std::scoped_lock lock(mkl_gen->mutex_);
mkl_gen->set_current_seed(seed_in);
}
#endif
}
/**
* Manually seeds the engine with the seed input
@ -92,6 +109,13 @@ void CPUGeneratorImpl::set_current_seed(uint64_t seed) {
next_float_normal_sample_.reset();
next_double_normal_sample_.reset();
engine_ = mt19937(seed);
#if AT_MKL_ENABLED()
{
auto mkl_gen = check_generator<MKLGeneratorImpl>(detail::getDefaultMKLGenerator());
std::scoped_lock lock(mkl_gen->mutex_);
mkl_gen->set_current_seed(seed);
}
#endif
}
/**
@ -126,6 +150,13 @@ uint64_t CPUGeneratorImpl::current_seed() const {
uint64_t CPUGeneratorImpl::seed() {
auto random = c10::detail::getNonDeterministicRandom();
this->set_current_seed(random);
#if AT_MKL_ENABLED()
{
auto mkl_gen = check_generator<MKLGeneratorImpl>(detail::getDefaultMKLGenerator());
std::scoped_lock lock(mkl_gen->mutex_);
mkl_gen->set_current_seed(random);
}
#endif
return random;
}
@ -169,6 +200,14 @@ void CPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
this->next_double_normal_sample_ = legacy_pod->normal_is_valid
? std::optional<double>(legacy_pod->normal_y)
: std::optional<double>();
#if AT_MKL_ENABLED()
{
auto mkl_gen = check_generator<MKLGeneratorImpl>(detail::getDefaultMKLGenerator());
std::scoped_lock lock(mkl_gen->mutex_);
mkl_gen->set_current_seed(rng_state->mkl_seed);
mkl_gen->skip_ahead(rng_state->mkl_offset);
}
#endif
}
/**
@ -207,6 +246,15 @@ c10::intrusive_ptr<c10::TensorImpl> CPUGeneratorImpl::get_state() const {
accum_state->next_float_normal_sample = *(this->next_float_normal_sample_);
}
#if AT_MKL_ENABLED()
{
auto mkl_gen = check_generator<MKLGeneratorImpl>(detail::getDefaultMKLGenerator());
std::scoped_lock lock(mkl_gen->mutex_);
accum_state->mkl_seed = mkl_gen->current_seed();
accum_state->mkl_offset = mkl_gen->get_offset();
}
#endif
memcpy(rng_state, accum_state.get(), size);
return state_tensor.getIntrusivePtr();
}

View File

@ -0,0 +1,155 @@
#include <ATen/mklrng/MKLGeneratorImpl.h>
#include <ATen/Utils.h>
#include <cstdint>
#include <memory>
namespace at {
namespace detail {
/**
 * Returns the process-wide default MKL generator.
 *
 * The generator is created the first time this function is called, seeded
 * with a value obtained from c10::detail::getNonDeterministicRandom(), and
 * the same instance is handed back by every later call. It holds the global
 * running state used when no generator is passed explicitly.
 */
const Generator& getDefaultMKLGenerator() {
  static Generator default_gen = [] {
    const auto initial_seed = c10::detail::getNonDeterministicRandom();
    return createMKLGenerator(initial_seed);
  }();
  return default_gen;
}
/**
 * Utility to create a Generator handle that wraps a newly constructed
 * MKLGeneratorImpl seeded with seed_val.
 */
Generator createMKLGenerator(uint64_t seed_val) {
  Generator handle = make_generator<MKLGeneratorImpl>(seed_val);
  return handle;
}
} // namespace detail
/**
 * MKLGeneratorImpl class implementation
 */

/**
 * Constructs a generator registered as a CPU generator, records seed_in as
 * the current seed, resets the offset to zero and creates a
 * VSL_BRNG_PHILOX4X32X10 stream initialized from that seed.
 *
 * Raises a c10::Error (via TORCH_CHECK) if the stream cannot be created, so
 * that a generator is never left holding an invalid stream handle.
 */
MKLGeneratorImpl::MKLGeneratorImpl(uint64_t seed_in)
  : c10::GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(c10::DispatchKey::CPU)},
    seed_(seed_in),
    offset_(0) {
  // NOTE(review): vslNewStream takes the seed as MKL_UINT; if that type is
  // narrower than 64 bits in this build, the upper bits of seed_ are dropped
  // here -- confirm against the MKL interface in use.
  const int status = vslNewStream(&stream_, VSL_BRNG_PHILOX4X32X10, seed_);
  TORCH_CHECK(
      status == VSL_STATUS_OK,
      "MKLGeneratorImpl: vslNewStream failed with status ", status);
}
/**
* Manually seeds the engine with the seed input
* See Note [Acquire lock when using random generators]
*/
void MKLGeneratorImpl::set_current_seed(uint64_t seed) {
this->seed_ = seed;
vslDeleteStream(&stream_);
vslNewStream(&stream_, VSL_BRNG_PHILOX4X32X10, seed_);
this->offset_ = 0;
}
/**
* Gets a nondeterministic random number from /dev/urandom or time,
* seeds the MKLGeneratorImpl with it and then returns that number.
* See Note [Acquire lock when using random generators]
*/
uint64_t MKLGeneratorImpl::seed() {
auto random = c10::detail::getNonDeterministicRandom();
this->set_current_seed(static_cast<uint64_t>(random));
return random;
}
/**
* Gets the current seed of CPUGeneratorImpl.
*/
uint64_t MKLGeneratorImpl::current_seed() const {
return this->seed_;
}
/**
 * Reports the DeviceType associated with MKLGeneratorImpl, which is always
 * DeviceType::CPU. Used for type checking during run time.
 */
DeviceType MKLGeneratorImpl::device_type() {
  constexpr DeviceType kDevice = DeviceType::CPU;
  return kDevice;
}
/**
 * Writes a copy of the generator's VSLStreamStatePtr into streamCopy so the
 * caller can generate variates from its own stream
 * (each thread should receive its own stream copy).
 *
 * This generator does not keep track of the copy; the caller is responsible
 * for releasing it with vslDeleteStream. Raises a c10::Error (via
 * TORCH_CHECK) if the copy cannot be made, so callers never proceed with an
 * unset stream handle.
 * See Note [Acquire lock when using random generators]
 */
void MKLGeneratorImpl::get_stream_copy(VSLStreamStatePtr &streamCopy) {
  const int status = vslCopyStream(&streamCopy, stream_);
  TORCH_CHECK(
      status == VSL_STATUS_OK,
      "MKLGeneratorImpl: vslCopyStream failed with status ", status);
}
/**
* Progresses the internal PRNG state n steps ahead --
* used to account for variates generated by the copies
* of the stream in MKLGenerator.
* See Note [Acquire lock when using random generators]
*/
void MKLGeneratorImpl::skip_ahead(uint64_t n) {
vslSkipAheadStream(stream_, n);
this->advance_offset(n);
}
/**
 * Private clone method implementation.
 *
 * Returns a newly allocated generator that reproduces this generator's
 * state: it is seeded with the current seed and then advanced by the
 * current offset (the same seed + skip-ahead replay used when restoring
 * the MKL state elsewhere in ATen). Ownership of the returned pointer
 * passes to the caller.
 * See Note [Acquire lock when using random generators]
 */
MKLGeneratorImpl* MKLGeneratorImpl::clone_impl() const {
  // Held in a unique_ptr so the allocation is released if the replay throws.
  auto copy = std::make_unique<MKLGeneratorImpl>(seed_);
  if (offset_ > 0) {
    copy->skip_ahead(offset_);
  }
  return copy.release();
}
/**
* Public clone method implementation
* See Note [Acquire lock when using random generators]
*/
std::shared_ptr<MKLGeneratorImpl> MKLGeneratorImpl::clone() const {
return std::shared_ptr<MKLGeneratorImpl>(this->clone_impl());
}
/**
 * Setting the offset directly is not supported by the MKL generator; this
 * override always raises through TORCH_CHECK. The offset only changes via
 * skip_ahead() and set_current_seed().
 * See Note [Acquire lock when using random generators]
 */
void MKLGeneratorImpl::set_offset(uint64_t offset) {
  (void)offset; // intentionally unused: the call is always rejected
  TORCH_CHECK(false, "MKL Generator does not allow to set offset");
}
/**
* Gets the offset of RNG state.
* See Note [Acquire lock when using random generators]
*/
uint64_t MKLGeneratorImpl::get_offset() const {
return this->offset_;
}
/**
* Private method to advance the offset of RNG state.
* See Note [Acquire lock when using random generators]
*/
void MKLGeneratorImpl::advance_offset(uint64_t n) {
this->offset_ += n;
}
/**
 * Not supported by MKLGeneratorImpl: this override never returns a state
 * tensor and always raises through TORCH_CHECK. Within this change the MKL
 * seed and offset are captured by CPUGeneratorImpl::get_state() instead.
 */
c10::intrusive_ptr<c10::TensorImpl> MKLGeneratorImpl::get_state() const {
TORCH_CHECK(false, "MKL Generator does not use get_state");
}
/**
 * Not supported by MKLGeneratorImpl: the argument is ignored and the call
 * always raises through TORCH_CHECK.
 */
void MKLGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
  (void)new_state; // intentionally unused: the call is always rejected
  TORCH_CHECK(false, "MKL Generator does not use set_state");
}
} // namespace at

View File

@ -0,0 +1,46 @@
#pragma once
#include <ATen/Config.h>
#include <ATen/core/Generator.h>
#include <c10/core/GeneratorImpl.h>
#include <cstdint>
#include <mkl.h>
namespace at {
/**
 * Generator implementation that owns a VSL random stream together with the
 * seed it was created from and a running offset counting how far the stream
 * has been advanced through skip_ahead().
 *
 * The stream handle is owned by this object and released in the destructor.
 * Callers are expected to hold mutex_ while using the mutating members
 * (see Note [Acquire lock when using random generators]).
 */
struct TORCH_API MKLGeneratorImpl : public c10::GeneratorImpl {
  // Constructors
  MKLGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
  // Releases the owned VSL stream, if one was created.
  ~MKLGeneratorImpl() override {
    if (stream_ != nullptr) {
      vslDeleteStream(&stream_);
    }
  }

  // MKLGeneratorImpl methods
  std::shared_ptr<MKLGeneratorImpl> clone() const;
  // Re-creates the stream from `seed` and resets the offset.
  void set_current_seed(uint64_t seed) override;
  // Re-seeds with a nondeterministic value and returns it.
  uint64_t seed() override;
  uint64_t current_seed() const override;
  // Always raises: the offset cannot be set directly.
  void set_offset(uint64_t offset) override;
  uint64_t get_offset() const override;
  static c10::DeviceType device_type();
  // Writes a copy of the owned stream into streamCopy; the caller releases it.
  void get_stream_copy(VSLStreamStatePtr &streamCopy);
  // Advances the owned stream by n elements and adds n to the offset.
  void skip_ahead(uint64_t n);
  // Always raise: state tensors are not supported by this generator.
  void set_state(const c10::TensorImpl& new_state) override;
  c10::intrusive_ptr<c10::TensorImpl> get_state() const override;

 private:
  MKLGeneratorImpl* clone_impl() const override;
  void advance_offset(uint64_t n);

  // Owned VSL stream handle; null until the constructor creates the stream.
  VSLStreamStatePtr stream_ = nullptr;
  // Seed the current stream was created from.
  uint64_t seed_;
  // Number of elements the stream has been advanced by since seeding.
  uint64_t offset_;
};
namespace detail {
// Returns a reference to the default MKL generator.
TORCH_API const Generator& getDefaultMKLGenerator();
// Creates a Generator backed by an MKLGeneratorImpl seeded with seed_val
// (default_rng_seed_val when omitted).
TORCH_API Generator
createMKLGenerator(uint64_t seed_val = default_rng_seed_val);
} // namespace detail
} // namespace at

View File

@ -18,9 +18,9 @@
#include <limits>
#include <type_traits>
// Disable MKL rng until https://github.com/pytorch/pytorch/issues/132395 is addressed
#if AT_MKL_ENABLED() && defined(FBCODE_CAFFE2)
#if AT_MKL_ENABLED()
#include <mkl.h>
#include <ATen/mklrng/MKLGeneratorImpl.h>
#include <cpuinfo.h>
#endif
@ -37,8 +37,7 @@ void bernoulli_tensor_kernel(const TensorBase &self, const TensorBase &p_, std::
templates::cpu::bernoulli_kernel(self, p_, generator);
}
// Disable MKL rng until https://github.com/pytorch/pytorch/issues/132395 is addressed
#if !AT_MKL_ENABLED() || (AT_MKL_ENABLED() && !defined(FBCODE_CAFFE2))
#if !AT_MKL_ENABLED()
void bernoulli_scalar_kernel_default(const TensorBase &self, double p, std::optional<Generator> gen) {
CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
templates::cpu::bernoulli_kernel(self, p, generator);
@ -49,13 +48,6 @@ void bernoulli_scalar_kernel(const TensorBase &self, double p, std::optional<Gen
}
#else
void bernoulli_scalar_kernel(const TensorBase &self, double p, std::optional<Generator> gen) {
CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
int64_t seed;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(generator->mutex_);
seed = generator->random();
}
int64_t n = self.numel();
bool contig = self.is_contiguous();
@ -71,15 +63,28 @@ void bernoulli_scalar_kernel(const TensorBase &self, double p, std::optional<Gen
scalar_t *self_ptr = self.data_ptr<scalar_t>();
int *sample_int_ptr = tmp_int_tensor.data_ptr<int>();
auto mklGenerator = check_generator<MKLGeneratorImpl>(detail::getDefaultMKLGenerator());
VSLStreamStatePtr main_stream;
// Get a local copy of the global stream and immediately advance the global
// state before the generation step to avoid multiple threads using the same state.
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(mklGenerator->mutex_);
mklGenerator->get_stream_copy(main_stream);
mklGenerator->skip_ahead(n);
}
auto sample = [&](int64_t begin, int64_t end) {
int64_t len = end - begin;
if (len > 0) {
VSLStreamStatePtr stream;
vslNewStream(&stream, VSL_BRNG_MCG31, seed);
vslSkipAheadStream(stream, begin);
viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, stream, len,
VSLStreamStatePtr sample_stream;
vslCopyStream(&sample_stream, main_stream);
vslSkipAheadStream(sample_stream, begin);
viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, sample_stream, len,
sample_int_ptr + begin, p);
vslDeleteStream(&stream);
vslDeleteStream(&sample_stream);
// vectorized copy if using buffer and contiguous, i.e., being non-int
// type and contiguous
@ -92,6 +97,7 @@ void bernoulli_scalar_kernel(const TensorBase &self, double p, std::optional<Gen
};
parallel_for(0, n, /* grain_size= */ 800, sample);
vslDeleteStream(&main_stream);
// copy_ if using buffer and non contiguous
if (!contig) {
@ -106,27 +112,15 @@ void exponential_kernel_default(TensorIteratorBase& iter, double lambda, std::op
templates::cpu::exponential_kernel(iter, lambda, generator);
}
// Disable MKL rng until https://github.com/pytorch/pytorch/issues/132395 is addressed
#if (!AT_MKL_ENABLED() || defined(FBCODE_CAFFE2) || 1)
#if (!AT_MKL_ENABLED() || defined(FBCODE_CAFFE2))
void exponential_kernel(TensorIteratorBase& iter, double lambda, std::optional<Generator> gen) {
exponential_kernel_default(iter, lambda, gen);
}
#else
void exponential_kernel(TensorIteratorBase &iter, double lambda, std::optional<Generator> gen) {
TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
Tensor self = iter.tensor(0);
if (lambda > 0 && !std::isinf(lambda) && !std::isnan(lambda)) {
CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
int64_t seed;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(generator->mutex_);
if (self.scalar_type() == at::kDouble)
seed = generator->random64();
else
seed = generator->random();
}
int64_t n = self.numel();
bool contig = self.is_contiguous();
@ -158,23 +152,35 @@ void exponential_kernel(TensorIteratorBase &iter, double lambda, std::optional<G
// Variance: V[X+eps] = 1/lambda**2
auto eps = std::numeric_limits<tmp_scalar_t>::min();
auto mklGenerator = check_generator<MKLGeneratorImpl>(detail::getDefaultMKLGenerator());
VSLStreamStatePtr main_stream;
// Get a local copy of the global stream and immediately advance the global
// state before the generation step to avoid multiple threads using the same state.
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(mklGenerator->mutex_);
mklGenerator->get_stream_copy(main_stream);
mklGenerator->skip_ahead(n);
}
auto sample = [&](int64_t begin, int64_t end) {
int64_t len = end - begin;
if (len > 0) {
VSLStreamStatePtr stream;
VSLStreamStatePtr sample_stream;
vslCopyStream(&sample_stream, main_stream);
vslSkipAheadStream(sample_stream, begin);
if constexpr (std::is_same_v<scalar_t, double>) {
vslNewStream(&stream, VSL_BRNG_MCG31, seed);
vslSkipAheadStream(stream, begin);
vdRngExponential(VSL_RNG_METHOD_EXPONENTIAL_ICDF, stream, len,
vdRngExponential(VSL_RNG_METHOD_EXPONENTIAL_ICDF, sample_stream, len,
(double *)(sample_ptr + begin), eps, 1./lambda);
vslDeleteStream(&stream);
vslDeleteStream(&sample_stream);
} else {
vslNewStream(&stream, VSL_BRNG_MCG31, seed);
vslSkipAheadStream(stream, begin);
vsRngExponential(VSL_RNG_METHOD_EXPONENTIAL_ICDF, stream, len,
vsRngExponential(VSL_RNG_METHOD_EXPONENTIAL_ICDF, sample_stream, len,
(float *) (sample_ptr + begin), eps, 1./lambda);
vslDeleteStream(&stream);
vslDeleteStream(&sample_stream);
}
// vectorized copy if using buffer and contiguous
if (!is_df && contig) {
scalar_t *self_seg = self_ptr + begin;
@ -185,6 +191,7 @@ void exponential_kernel(TensorIteratorBase &iter, double lambda, std::optional<G
};
parallel_for(0, n, /* grain_size= */ 800, sample);
vslDeleteStream(&main_stream);
// copy_ if using buffer and non contiguous
if (!contig) {

View File

@ -3402,14 +3402,15 @@ class TestDistributions(DistributionsTestCase):
self.assertEqual((1.0 / lambd), mean, atol=2e-2, rtol=2e-2)
self.assertEqual((1.0 / lambd) ** 2, var, atol=2e-2, rtol=2e-2)
set_rng_seed(2) # see Note [Randomized statistical tests]
for dtype in [torch.float, torch.double, torch.bfloat16, torch.float16]:
for lambd in [0.2, 0.5, 1.0, 1.5, 2.0, 5.0]:
sample_len = 50000
sample_len = 50_000
mean_var(lambd, torch.rand(sample_len, dtype=dtype))
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_exponential_sample(self):
set_rng_seed(1) # see Note [Randomized statistical tests]
set_rng_seed(2) # see Note [Randomized statistical tests]
for rate in [1e-5, 1.0, 10.0]:
self._check_sampler_sampler(
Exponential(rate),
@ -3569,7 +3570,7 @@ class TestDistributions(DistributionsTestCase):
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_pareto_sample(self):
set_rng_seed(1) # see Note [Randomized statistical tests]
set_rng_seed(3) # see Note [Randomized statistical tests]
for scale, alpha in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
self._check_sampler_sampler(
Pareto(scale, alpha),
@ -7043,6 +7044,7 @@ class TestJit(DistributionsTestCase):
)
def test_variance(self):
set_rng_seed(3) # see Note [Randomized statistical tests]
for Dist, keys, values, sample in self._examples():
if Dist in [Cauchy, HalfCauchy]:
continue # infinite variance