Revert "Temp disable MKL in DistributionKernels.cpp (#132532)"

This reverts commit 7b2664ece6a961ce9e4557be913c2cead09c7390.

Reverted https://github.com/pytorch/pytorch/pull/132532 on behalf of https://github.com/PaliC due to causing numerical instability issues internally ([comment](https://github.com/pytorch/pytorch/pull/132532#issuecomment-2272136210))
PyTorch MergeBot
2024-08-06 20:57:09 +00:00
parent 94155ce31b
commit e47b684c33
2 changed files with 3 additions and 21 deletions

aten/src/ATen/native/cpu/DistributionKernels.cpp

@@ -18,8 +18,7 @@
 #include <limits>
 #include <type_traits>
-// Disable MKL rng until https://github.com/pytorch/pytorch/issues/132395 is addressed
-#if AT_MKL_ENABLED() && 0
+#if AT_MKL_ENABLED()
 #include <mkl.h>
 #include <cpuinfo.h>
 #endif
@@ -37,8 +36,7 @@ void bernoulli_tensor_kernel(const TensorBase &self, const TensorBase &p_, std::
   templates::cpu::bernoulli_kernel(self, p_, generator);
 }
-// Disable MKL rng until https://github.com/pytorch/pytorch/issues/132395 is addressed
-#if !AT_MKL_ENABLED() || 1
+#if !AT_MKL_ENABLED()
 void bernoulli_scalar_kernel_default(const TensorBase &self, double p, std::optional<Generator> gen) {
   CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
   templates::cpu::bernoulli_kernel(self, p, generator);
@@ -106,8 +104,7 @@ static void exponential_kernel_default(TensorIteratorBase& iter, double lambda,
   templates::cpu::exponential_kernel(iter, lambda, generator);
 }
-// Disable MKL rng until https://github.com/pytorch/pytorch/issues/132395 is addressed
-#if (!AT_MKL_ENABLED() || defined(FBCODE_CAFFE2) || 1)
+#if (!AT_MKL_ENABLED() || defined(FBCODE_CAFFE2))
 void exponential_kernel(TensorIteratorBase& iter, double lambda, std::optional<Generator> gen) {
   exponential_kernel_default(iter, lambda, gen);
 }

test/distributions/test_distributions.py

@@ -1841,21 +1841,6 @@ class TestDistributions(DistributionsTestCase):
             torch.tensor([[total_count, 0], [0, total_count]], dtype=torch.float64),
         )
-    def test_multinomial_sequential_draw(self):
-        # Adapted after script mentioned in https://github.com/pytorch/pytorch/issues/132395
-        torch.manual_seed(0xDE0B6B3A764007E8)
-        prob = torch.ones(26)
-        dups_mult = 0
-        perm_counts_mult = {}
-        for _ in range(300_000):
-            p = tuple(torch.multinomial(prob, prob.numel(), replacement=False).tolist())
-            if p in perm_counts_mult:
-                dups_mult += 1
-                perm_counts_mult[p] += 1
-            else:
-                perm_counts_mult[p] = 1
-        self.assertLess(dups_mult, 10)
-
     @set_default_dtype(torch.double)
     def test_categorical_1d(self):
         p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
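
The removed test above is a regression check for the duplicate-permutation behaviour reported in https://github.com/pytorch/pytorch/issues/132395. A rough standalone sketch of the same check outside the test harness follows; the seed, draw count, and duplicate threshold are copied from the removed test, everything else is illustrative.

import torch

# Same seed and parameters as the removed test_multinomial_sequential_draw.
torch.manual_seed(0xDE0B6B3A764007E8)
prob = torch.ones(26)  # uniform weights over 26 categories

dups = 0
perm_counts = {}
for _ in range(300_000):
    # Draw a full permutation of the 26 categories without replacement.
    p = tuple(torch.multinomial(prob, prob.numel(), replacement=False).tolist())
    if p in perm_counts:
        dups += 1
        perm_counts[p] += 1
    else:
        perm_counts[p] = 1

# With 26! possible permutations, repeats across 300k draws should be
# vanishingly rare; a biased sampler shows up as a large duplicate count.
print(f"duplicate permutations: {dups}")
assert dups < 10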