Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-13 13:54:36 +08:00)

Compare commits: sy_invoke_ ... ciflow/tru (10 commits)

| SHA1 |
|---|
| 29d9e9c762 |
| 6cea2f04ca |
| 6bda7aa776 |
| 232dabb5ab |
| 0b7eafd1c9 |
| 850cd5fa03 |
| d068f2b695 |
| b10537378e |
| a6950289c3 |
| 7924f740f0 |
@@ -195,6 +195,13 @@ filegroup(
    ]),
)

filegroup(
    name = "aten_mklrng_cpp",
    srcs = glob([
        "aten/src/ATen/mklrng/*.cpp",
    ]),
)

filegroup(
    name = "aten_native_mkldnn_cpp",
    srcs = glob(["aten/src/ATen/native/mkldnn/*.cpp"]),
@@ -357,6 +364,7 @@ cc_library(
    ":aten_base_cpp",
    ":aten_base_metal",
    ":aten_base_vulkan",
    ":aten_mklrng_cpp",
    ":aten_native_cpp",
    ":aten_native_mkl_cpp",
    ":aten_native_mkldnn_cpp",

@@ -85,6 +85,7 @@ file(GLOB miopen_h "miopen/*.h")
file(GLOB miopen_cpp "miopen/*.cpp")

file(GLOB mkl_cpp "mkl/*.cpp")
file(GLOB mklrng_cpp "mklrng/*.cpp")
file(GLOB mkldnn_cpp "mkldnn/*.cpp")

file(GLOB mkldnn_xpu_h "native/mkldnn/xpu/*.h" "native/mkldnn/xpu/detail/*.h")
@@ -392,6 +393,7 @@ if(USE_LIGHTWEIGHT_DISPATCH)
endif()
if(AT_MKL_ENABLED)
  set(all_cpu_cpp ${all_cpu_cpp} ${mkl_cpp})
  set(all_cpu_cpp ${all_cpu_cpp} ${mklrng_cpp})
endif()
if(AT_KLEIDIAI_ENABLED)
  set(all_cpu_cpp ${all_cpu_cpp} ${native_kleidiai})
@@ -1,9 +1,14 @@
#include <ATen/CPUGeneratorImpl.h>
#include <ATen/Config.h>
#include <ATen/Utils.h>
#include <ATen/core/MT19937RNGEngine.h>
#include <c10/util/MathConstants.h>
#include <algorithm>

#if AT_MKL_ENABLED()
#include <ATen/mklrng/MKLGeneratorImpl.h>
#endif

namespace at {

namespace detail {
@@ -43,6 +48,10 @@ struct CPUGeneratorImplState {
  CPUGeneratorImplStateLegacy legacy_pod;
  float next_float_normal_sample;
  bool is_next_float_normal_sample_valid;
#if AT_MKL_ENABLED()
  uint64_t mkl_seed;
  uint64_t mkl_offset;
#endif
};

/**
@@ -82,7 +91,15 @@ CPUGeneratorImpl::CPUGeneratorImpl(uint64_t seed_in)
  : c10::GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(c10::DispatchKey::CPU)},
    engine_{seed_in},
    next_float_normal_sample_{std::optional<float>()},
    next_double_normal_sample_{std::optional<double>()} { }
    next_double_normal_sample_{std::optional<double>()} {
#if AT_MKL_ENABLED()
  {
    auto mkl_gen = check_generator<MKLGeneratorImpl>(detail::getDefaultMKLGenerator());
    std::scoped_lock lock(mkl_gen->mutex_);
    mkl_gen->set_current_seed(seed_in);
  }
#endif
}

/**
 * Manually seeds the engine with the seed input
@@ -92,6 +109,13 @@ void CPUGeneratorImpl::set_current_seed(uint64_t seed) {
  next_float_normal_sample_.reset();
  next_double_normal_sample_.reset();
  engine_ = mt19937(seed);
#if AT_MKL_ENABLED()
  {
    auto mkl_gen = check_generator<MKLGeneratorImpl>(detail::getDefaultMKLGenerator());
    std::scoped_lock lock(mkl_gen->mutex_);
    mkl_gen->set_current_seed(seed);
  }
#endif
}

/**
@@ -126,6 +150,13 @@ uint64_t CPUGeneratorImpl::current_seed() const {
uint64_t CPUGeneratorImpl::seed() {
  auto random = c10::detail::getNonDeterministicRandom();
  this->set_current_seed(random);
#if AT_MKL_ENABLED()
  {
    auto mkl_gen = check_generator<MKLGeneratorImpl>(detail::getDefaultMKLGenerator());
    std::scoped_lock lock(mkl_gen->mutex_);
    mkl_gen->set_current_seed(random);
  }
#endif
  return random;
}

@@ -169,6 +200,14 @@ void CPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
  this->next_double_normal_sample_ = legacy_pod->normal_is_valid
      ? std::optional<double>(legacy_pod->normal_y)
      : std::optional<double>();
#if AT_MKL_ENABLED()
  {
    auto mkl_gen = check_generator<MKLGeneratorImpl>(detail::getDefaultMKLGenerator());
    std::scoped_lock lock(mkl_gen->mutex_);
    mkl_gen->set_current_seed(rng_state->mkl_seed);
    mkl_gen->skip_ahead(rng_state->mkl_offset);
  }
#endif
}

/**
@@ -207,6 +246,15 @@ c10::intrusive_ptr<c10::TensorImpl> CPUGeneratorImpl::get_state() const {
    accum_state->next_float_normal_sample = *(this->next_float_normal_sample_);
  }

#if AT_MKL_ENABLED()
  {
    auto mkl_gen = check_generator<MKLGeneratorImpl>(detail::getDefaultMKLGenerator());
    std::scoped_lock lock(mkl_gen->mutex_);
    accum_state->mkl_seed = mkl_gen->current_seed();
    accum_state->mkl_offset = mkl_gen->get_offset();
  }
#endif

  memcpy(rng_state, accum_state.get(), size);
  return state_tensor.getIntrusivePtr();
}
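The get_state/set_state hunks above serialize the MKL side of the generator as nothing more than a (mkl_seed, mkl_offset) pair; because Philox4x32-10 is counter-based, set_state can rebuild the exact stream position by re-seeding and skipping ahead. A minimal standalone sketch of that reconstruction, assuming an MKL installation (the seed and offset values are made up for illustration, and this is not PyTorch code):

```cpp
// Rebuild an MKL VSL Philox stream from a serialized (seed, offset) pair --
// the same recipe CPUGeneratorImpl::set_state uses via set_current_seed()
// followed by skip_ahead(). Link against MKL.
#include <mkl.h>
#include <cstdint>
#include <cstdio>

int main() {
  const std::uint64_t seed = 42;     // hypothetical serialized seed
  const std::uint64_t offset = 1000; // hypothetical count of variates already drawn

  VSLStreamStatePtr stream;
  vslNewStream(&stream, VSL_BRNG_PHILOX4X32X10, seed); // re-seed
  vslSkipAheadStream(stream, offset);                  // fast-forward past consumed variates

  double next[4];
  vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, stream, 4, next, 0.0, 1.0);
  std::printf("%f %f %f %f\n", next[0], next[1], next[2], next[3]);

  vslDeleteStream(&stream);
  return 0;
}
```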
aten/src/ATen/mklrng/MKLGeneratorImpl.cpp (new file, 155 lines)
@ -0,0 +1,155 @@
|
||||
#include <ATen/mklrng/MKLGeneratorImpl.h>
|
||||
#include <ATen/Utils.h>
|
||||
#include <cstdint>
|
||||
|
||||
namespace at {
|
||||
namespace detail {
|
||||
|
||||
/**
|
||||
* PyTorch maintains a collection of default generators that get
|
||||
* initialized once. The purpose of these default generators is to
|
||||
* maintain a global running state of the pseudo random number generation,
|
||||
* when a user does not explicitly mention any generator.
|
||||
* getDefaultMKLGenerator gets the default generator for a particular
|
||||
* device.
|
||||
*/
|
||||
const Generator& getDefaultMKLGenerator() {
|
||||
static auto gen = createMKLGenerator(c10::detail::getNonDeterministicRandom());
|
||||
return gen;
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility to create an MKLGeneratorImpl. Returns a shared_ptr
|
||||
*/
|
||||
Generator createMKLGenerator(uint64_t seed_val) {
|
||||
return make_generator<MKLGeneratorImpl>(seed_val);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/**
|
||||
* MKLGeneratorImpl class implementation
|
||||
*/
|
||||
MKLGeneratorImpl::MKLGeneratorImpl(uint64_t seed_in)
|
||||
: c10::GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(c10::DispatchKey::CPU)},
|
||||
seed_(seed_in),
|
||||
offset_(0) {
|
||||
vslNewStream(&stream_, VSL_BRNG_PHILOX4X32X10, seed_);
|
||||
}
|
||||
|
||||
/**
|
||||
* Manually seeds the engine with the seed input
|
||||
* See Note [Acquire lock when using random generators]
|
||||
*/
|
||||
void MKLGeneratorImpl::set_current_seed(uint64_t seed) {
|
||||
this->seed_ = seed;
|
||||
vslDeleteStream(&stream_);
|
||||
vslNewStream(&stream_, VSL_BRNG_PHILOX4X32X10, seed_);
|
||||
this->offset_ = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a nondeterministic random number from /dev/urandom or time,
|
||||
* seeds the MKLGeneratorImpl with it and then returns that number.
|
||||
* See Note [Acquire lock when using random generators]
|
||||
*/
|
||||
uint64_t MKLGeneratorImpl::seed() {
|
||||
auto random = c10::detail::getNonDeterministicRandom();
|
||||
this->set_current_seed(static_cast<uint64_t>(random));
|
||||
return random;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the current seed of CPUGeneratorImpl.
|
||||
*/
|
||||
uint64_t MKLGeneratorImpl::current_seed() const {
|
||||
return this->seed_;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the DeviceType of MKLGeneratorImpl.
|
||||
* Used for type checking during run time.
|
||||
*/
|
||||
DeviceType MKLGeneratorImpl::device_type() {
|
||||
return DeviceType::CPU;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the copy of VSLStreamStatePtr in MKLGenerator
|
||||
* to be used for variate generation in a thread-safe way
|
||||
* (each thread should receive its own stream copy).
|
||||
* See Note [Acquire lock when using random generators]
|
||||
*/
|
||||
void MKLGeneratorImpl::get_stream_copy(VSLStreamStatePtr &streamCopy) {
|
||||
vslCopyStream(&streamCopy, stream_);
|
||||
}
|
||||
|
||||
/**
|
||||
* Progresses the internal PRNG state n steps ahead --
|
||||
* used to account for variates generated by the copies
|
||||
* of the stream in MKLGenerator.
|
||||
* See Note [Acquire lock when using random generators]
|
||||
*/
|
||||
void MKLGeneratorImpl::skip_ahead(uint64_t n) {
|
||||
vslSkipAheadStream(stream_, n);
|
||||
this->advance_offset(n);
|
||||
}
|
||||
|
||||
/**
|
||||
* Private clone method implementation
|
||||
* See Note [Acquire lock when using random generators]
|
||||
*/
|
||||
MKLGeneratorImpl* MKLGeneratorImpl::clone_impl() const {
|
||||
auto gen = new MKLGeneratorImpl();
|
||||
return gen;
|
||||
}
|
||||
|
||||
/**
|
||||
* Public clone method implementation
|
||||
* See Note [Acquire lock when using random generators]
|
||||
*/
|
||||
std::shared_ptr<MKLGeneratorImpl> MKLGeneratorImpl::clone() const {
|
||||
return std::shared_ptr<MKLGeneratorImpl>(this->clone_impl());
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the offset of RNG state.
|
||||
* See Note [Acquire lock when using random generators]
|
||||
*/
|
||||
void MKLGeneratorImpl::set_offset(uint64_t offset) {
|
||||
TORCH_CHECK(false, "MKL Generator does not allow to set offset");
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the offset of RNG state.
|
||||
* See Note [Acquire lock when using random generators]
|
||||
*/
|
||||
uint64_t MKLGeneratorImpl::get_offset() const {
|
||||
return this->offset_;
|
||||
}
|
||||
|
||||
/**
|
||||
* Private method to advance the offset of RNG state.
|
||||
* See Note [Acquire lock when using random generators]
|
||||
*/
|
||||
void MKLGeneratorImpl::advance_offset(uint64_t n) {
|
||||
this->offset_ += n;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the current internal state of MKLGeneratorImpl. The internal
|
||||
* state is returned as a CPU byte tensor.
|
||||
*/
|
||||
c10::intrusive_ptr<c10::TensorImpl> MKLGeneratorImpl::get_state() const {
|
||||
TORCH_CHECK(false, "MKL Generator does not use get_state");
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the internal state of MKLGeneratorImpl. The new internal state
|
||||
* must be a strided CPU byte tensor.
|
||||
*/
|
||||
void MKLGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
|
||||
TORCH_CHECK(false, "MKL Generator does not use set_state");
|
||||
}
|
||||
|
||||
} // namespace at
|
||||
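MKLGeneratorImpl::skip_ahead and get_stream_copy above rely on a VSL contract: a copied stream that is skipped ahead by k continues the original sequence from element k. A small self-contained illustration of that contract, assuming an MKL installation (this file is not part of the PR):

```cpp
// Illustrative only: a copy of a Philox VSL stream, skipped ahead by 8,
// should continue the sequence at the original stream's 9th raw output.
// Link against MKL.
#include <mkl.h>
#include <cstdio>

int main() {
  VSLStreamStatePtr original, copy;
  vslNewStream(&original, VSL_BRNG_PHILOX4X32X10, 2024);
  vslCopyStream(&copy, original);   // snapshot of the current state
  vslSkipAheadStream(copy, 8);      // jump over the first 8 raw outputs

  unsigned int full[12], tail[4];
  viRngUniformBits(VSL_RNG_METHOD_UNIFORMBITS_STD, original, 12, full);
  viRngUniformBits(VSL_RNG_METHOD_UNIFORMBITS_STD, copy, 4, tail);

  // Expected: tail[i] == full[8 + i] for i in 0..3.
  for (int i = 0; i < 4; ++i) {
    std::printf("%u %u\n", full[8 + i], tail[i]);
  }

  vslDeleteStream(&copy);
  vslDeleteStream(&original);
  return 0;
}
```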
aten/src/ATen/mklrng/MKLGeneratorImpl.h (new file, 46 lines)
@@ -0,0 +1,46 @@
#pragma once

#include <ATen/Config.h>
#include <ATen/core/Generator.h>
#include <c10/core/GeneratorImpl.h>
#include <cstdint>

#include <mkl.h>

namespace at {

struct TORCH_API MKLGeneratorImpl : public c10::GeneratorImpl {
  // Constructors
  MKLGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
  ~MKLGeneratorImpl() override = default;

  // MKLGeneratorImpl methods
  std::shared_ptr<MKLGeneratorImpl> clone() const;
  void set_current_seed(uint64_t seed) override;
  uint64_t seed() override;
  uint64_t current_seed() const override;
  void set_offset(uint64_t offset) override;
  uint64_t get_offset() const override;
  static c10::DeviceType device_type();
  void get_stream_copy(VSLStreamStatePtr &streamCopy);
  void skip_ahead(uint64_t n);
  void set_state(const c10::TensorImpl& new_state) override;
  c10::intrusive_ptr<c10::TensorImpl> get_state() const override;

 private:
  MKLGeneratorImpl* clone_impl() const override;
  void advance_offset(uint64_t n);
  VSLStreamStatePtr stream_;
  uint64_t seed_;
  uint64_t offset_;
};

namespace detail {

TORCH_API const Generator& getDefaultMKLGenerator();
TORCH_API Generator
createMKLGenerator(uint64_t seed_val = default_rng_seed_val);

} // namespace detail

} // namespace at
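A usage sketch of the header above (the draw_uniform helper is hypothetical, not code from the PR): a CPU kernel that needs n MKL variates takes the generator lock once, snapshots the stream, and reserves n variates in the shared state before generating.

```cpp
// Hypothetical consumer of MKLGeneratorImpl, following
// Note [Acquire lock when using random generators].
#include <ATen/Utils.h>
#include <ATen/mklrng/MKLGeneratorImpl.h>
#include <cstdint>
#include <mutex>

void draw_uniform(double* out, std::int64_t n) {
  auto gen = at::check_generator<at::MKLGeneratorImpl>(
      at::detail::getDefaultMKLGenerator());

  VSLStreamStatePtr stream;
  {
    // Hold the lock only while touching the shared generator state.
    std::lock_guard<std::mutex> lock(gen->mutex_);
    gen->get_stream_copy(stream); // private copy for this kernel
    gen->skip_ahead(n);           // advance the global state past our n draws
  }

  vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, stream,
               static_cast<MKL_INT>(n), out, 0.0, 1.0);
  vslDeleteStream(&stream);
}
```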
@ -18,9 +18,9 @@
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
|
||||
// Disable MKL rng until https://github.com/pytorch/pytorch/issues/132395 is addressed
|
||||
#if AT_MKL_ENABLED() && defined(FBCODE_CAFFE2)
|
||||
#if AT_MKL_ENABLED()
|
||||
#include <mkl.h>
|
||||
#include <ATen/mklrng/MKLGeneratorImpl.h>
|
||||
#include <cpuinfo.h>
|
||||
#endif
|
||||
|
||||
@ -37,8 +37,7 @@ void bernoulli_tensor_kernel(const TensorBase &self, const TensorBase &p_, std::
|
||||
templates::cpu::bernoulli_kernel(self, p_, generator);
|
||||
}
|
||||
|
||||
// Disable MKL rng until https://github.com/pytorch/pytorch/issues/132395 is addressed
|
||||
#if !AT_MKL_ENABLED() || (AT_MKL_ENABLED() && !defined(FBCODE_CAFFE2))
|
||||
#if !AT_MKL_ENABLED()
|
||||
void bernoulli_scalar_kernel_default(const TensorBase &self, double p, std::optional<Generator> gen) {
|
||||
CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
|
||||
templates::cpu::bernoulli_kernel(self, p, generator);
|
||||
@ -49,13 +48,6 @@ void bernoulli_scalar_kernel(const TensorBase &self, double p, std::optional<Gen
|
||||
}
|
||||
#else
|
||||
void bernoulli_scalar_kernel(const TensorBase &self, double p, std::optional<Generator> gen) {
|
||||
CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
|
||||
int64_t seed;
|
||||
{
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(generator->mutex_);
|
||||
seed = generator->random();
|
||||
}
|
||||
int64_t n = self.numel();
|
||||
bool contig = self.is_contiguous();
|
||||
|
||||
@ -71,15 +63,28 @@ void bernoulli_scalar_kernel(const TensorBase &self, double p, std::optional<Gen
|
||||
scalar_t *self_ptr = self.data_ptr<scalar_t>();
|
||||
int *sample_int_ptr = tmp_int_tensor.data_ptr<int>();
|
||||
|
||||
auto mklGenerator = check_generator<MKLGeneratorImpl>(detail::getDefaultMKLGenerator());
|
||||
VSLStreamStatePtr main_stream;
|
||||
|
||||
// Get a local copy of the global stream and immediately advance the global
|
||||
// state before the generation step to avoid multiple threads using the same state.
|
||||
{
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(mklGenerator->mutex_);
|
||||
mklGenerator->get_stream_copy(main_stream);
|
||||
mklGenerator->skip_ahead(n);
|
||||
}
|
||||
|
||||
auto sample = [&](int64_t begin, int64_t end) {
|
||||
int64_t len = end - begin;
|
||||
if (len > 0) {
|
||||
VSLStreamStatePtr stream;
|
||||
vslNewStream(&stream, VSL_BRNG_MCG31, seed);
|
||||
vslSkipAheadStream(stream, begin);
|
||||
viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, stream, len,
|
||||
VSLStreamStatePtr sample_stream;
|
||||
vslCopyStream(&sample_stream, main_stream);
|
||||
vslSkipAheadStream(sample_stream, begin);
|
||||
|
||||
viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, sample_stream, len,
|
||||
sample_int_ptr + begin, p);
|
||||
vslDeleteStream(&stream);
|
||||
vslDeleteStream(&sample_stream);
|
||||
|
||||
// vectorized copy if using buffer and contiguous, i.e., being non-int
|
||||
// type and contiguous
|
||||
@ -92,6 +97,7 @@ void bernoulli_scalar_kernel(const TensorBase &self, double p, std::optional<Gen
|
||||
};
|
||||
|
||||
parallel_for(0, n, /* grain_size= */ 800, sample);
|
||||
vslDeleteStream(&main_stream);
|
||||
|
||||
// copy_ if using buffer and non contiguous
|
||||
if (!contig) {
|
||||
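The comment above ("Get a local copy of the global stream and immediately advance the global state...") is also what makes the result independent of how at::parallel_for splits the range: every chunk rebuilds its absolute position with vslSkipAheadStream(begin) on its own copy. A standalone sketch of that invariant, assuming an MKL installation (chunk sizes and p are arbitrary; this is not the kernel itself):

```cpp
// Standalone sketch: filling a Bernoulli buffer chunk by chunk. Each chunk
// copies the shared stream and skips ahead to its start index, so the output
// does not depend on how [0, n) is partitioned. Link against MKL.
#include <mkl.h>
#include <cstdio>

void fill_bernoulli(VSLStreamStatePtr main_stream, int* out, int n, int chunk) {
  for (int begin = 0; begin < n; begin += chunk) {
    const int len = (begin + chunk < n) ? chunk : n - begin;
    VSLStreamStatePtr s;
    vslCopyStream(&s, main_stream);  // private copy of the shared state
    vslSkipAheadStream(s, begin);    // jump to this chunk's absolute position
    viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, s, len, out + begin, 0.25);
    vslDeleteStream(&s);
  }
}

int main() {
  VSLStreamStatePtr stream;
  vslNewStream(&stream, VSL_BRNG_PHILOX4X32X10, 7);

  int a[16], b[16];
  fill_bernoulli(stream, a, 16, 4);   // four chunks of 4
  fill_bernoulli(stream, b, 16, 16);  // one chunk of 16
  for (int i = 0; i < 16; ++i) {
    std::printf("%d %d\n", a[i], b[i]); // rows are expected to match
  }

  vslDeleteStream(&stream);
  return 0;
}
```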
@ -106,27 +112,15 @@ void exponential_kernel_default(TensorIteratorBase& iter, double lambda, std::op
|
||||
templates::cpu::exponential_kernel(iter, lambda, generator);
|
||||
}
|
||||
|
||||
// Disable MKL rng until https://github.com/pytorch/pytorch/issues/132395 is addressed
|
||||
#if (!AT_MKL_ENABLED() || defined(FBCODE_CAFFE2) || 1)
|
||||
#if (!AT_MKL_ENABLED() || defined(FBCODE_CAFFE2))
|
||||
void exponential_kernel(TensorIteratorBase& iter, double lambda, std::optional<Generator> gen) {
|
||||
exponential_kernel_default(iter, lambda, gen);
|
||||
}
|
||||
#else
|
||||
void exponential_kernel(TensorIteratorBase &iter, double lambda, std::optional<Generator> gen) {
|
||||
TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
|
||||
|
||||
Tensor self = iter.tensor(0);
|
||||
if (lambda > 0 && !std::isinf(lambda) && !std::isnan(lambda)) {
|
||||
CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
|
||||
int64_t seed;
|
||||
{
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(generator->mutex_);
|
||||
if (self.scalar_type() == at::kDouble)
|
||||
seed = generator->random64();
|
||||
else
|
||||
seed = generator->random();
|
||||
}
|
||||
int64_t n = self.numel();
|
||||
bool contig = self.is_contiguous();
|
||||
|
||||
@ -158,23 +152,35 @@ void exponential_kernel(TensorIteratorBase &iter, double lambda, std::optional<G
|
||||
// Variance: V[X+eps] = 1/lambda**2
|
||||
auto eps = std::numeric_limits<tmp_scalar_t>::min();
|
||||
|
||||
auto mklGenerator = check_generator<MKLGeneratorImpl>(detail::getDefaultMKLGenerator());
|
||||
VSLStreamStatePtr main_stream;
|
||||
|
||||
// Get a local copy of the global stream and immediately advance the global
|
||||
// state before the generation step to avoid multiple threads using the same state.
|
||||
{
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(mklGenerator->mutex_);
|
||||
mklGenerator->get_stream_copy(main_stream);
|
||||
mklGenerator->skip_ahead(n);
|
||||
}
|
||||
|
||||
auto sample = [&](int64_t begin, int64_t end) {
|
||||
int64_t len = end - begin;
|
||||
if (len > 0) {
|
||||
VSLStreamStatePtr stream;
|
||||
VSLStreamStatePtr sample_stream;
|
||||
vslCopyStream(&sample_stream, main_stream);
|
||||
vslSkipAheadStream(sample_stream, begin);
|
||||
|
||||
if constexpr (std::is_same_v<scalar_t, double>) {
|
||||
vslNewStream(&stream, VSL_BRNG_MCG31, seed);
|
||||
vslSkipAheadStream(stream, begin);
|
||||
vdRngExponential(VSL_RNG_METHOD_EXPONENTIAL_ICDF, stream, len,
|
||||
vdRngExponential(VSL_RNG_METHOD_EXPONENTIAL_ICDF, sample_stream, len,
|
||||
(double *)(sample_ptr + begin), eps, 1./lambda);
|
||||
vslDeleteStream(&stream);
|
||||
vslDeleteStream(&sample_stream);
|
||||
} else {
|
||||
vslNewStream(&stream, VSL_BRNG_MCG31, seed);
|
||||
vslSkipAheadStream(stream, begin);
|
||||
vsRngExponential(VSL_RNG_METHOD_EXPONENTIAL_ICDF, stream, len,
|
||||
vsRngExponential(VSL_RNG_METHOD_EXPONENTIAL_ICDF, sample_stream, len,
|
||||
(float *) (sample_ptr + begin), eps, 1./lambda);
|
||||
vslDeleteStream(&stream);
|
||||
vslDeleteStream(&sample_stream);
|
||||
}
|
||||
|
||||
// vectorized copy if using buffer and contiguous
|
||||
if (!is_df && contig) {
|
||||
scalar_t *self_seg = self_ptr + begin;
|
||||
@ -185,6 +191,7 @@ void exponential_kernel(TensorIteratorBase &iter, double lambda, std::optional<G
|
||||
};
|
||||
|
||||
parallel_for(0, n, /* grain_size= */ 800, sample);
|
||||
vslDeleteStream(&main_stream);
|
||||
|
||||
// copy_ if using buffer and non contiguous
|
||||
if (!contig) {
|
||||
|
||||
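In the exponential hunk above, eps = std::numeric_limits<tmp_scalar_t>::min() shifts every sample by the smallest positive normal value of the buffer dtype, presumably so that an exact zero cannot be produced; as the in-code comment notes, this leaves the variance unchanged and perturbs the mean only negligibly:

```latex
X \sim \mathrm{Exp}(\lambda):\qquad
\mathbb{E}[X+\varepsilon] = \tfrac{1}{\lambda} + \varepsilon,
\qquad
\operatorname{Var}[X+\varepsilon] = \operatorname{Var}[X] = \tfrac{1}{\lambda^{2}}.
```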
@ -382,6 +382,20 @@ coverage_ignore_functions = [
|
||||
# torch.ao.quantization.backend_config.tensorrt
|
||||
"get_tensorrt_backend_config",
|
||||
"get_tensorrt_backend_config_dict",
|
||||
# torch.ao.quantization.backend_config.utils
|
||||
"entry_to_pretty_str",
|
||||
"get_fused_module_classes",
|
||||
"get_fuser_method_mapping",
|
||||
"get_fusion_pattern_to_extra_inputs_getter",
|
||||
"get_fusion_pattern_to_root_node_getter",
|
||||
"get_module_to_qat_module",
|
||||
"get_pattern_to_dtype_configs",
|
||||
"get_pattern_to_input_type_to_index",
|
||||
"get_qat_module_classes",
|
||||
"get_root_module_to_quantized_reference_module",
|
||||
"pattern_to_human_readable",
|
||||
"remove_boolean_dispatch_from_name",
|
||||
# torch.ao.quantization.backend_config.x86
|
||||
"get_x86_backend_config",
|
||||
# torch.ao.quantization.fuse_modules
|
||||
"fuse_known_modules",
|
||||
@ -412,6 +426,25 @@ coverage_ignore_functions = [
|
||||
"insert_observers_for_model",
|
||||
"prepare",
|
||||
"propagate_dtypes_for_known_nodes",
|
||||
# torch.ao.quantization.fx.utils
|
||||
"all_node_args_except_first",
|
||||
"all_node_args_have_no_tensors",
|
||||
"assert_and_get_unique_device",
|
||||
"collect_producer_nodes",
|
||||
"create_getattr_from_value",
|
||||
"create_node_from_old_node_preserve_meta",
|
||||
"get_custom_module_class_keys",
|
||||
"get_linear_prepack_op_for_dtype",
|
||||
"get_new_attr_name_with_prefix",
|
||||
"get_non_observable_arg_indexes_and_types",
|
||||
"get_qconv_prepack_op",
|
||||
"get_skipped_module_name_and_classes",
|
||||
"graph_module_from_producer_nodes",
|
||||
"maybe_get_next_module",
|
||||
"node_arg_is_bias",
|
||||
"node_arg_is_weight",
|
||||
"return_arg_list",
|
||||
# torch.ao.quantization.pt2e.graph_utils
|
||||
"bfs_trace_with_node_process",
|
||||
"find_sequential_partitions",
|
||||
"get_equivalent_types",
|
||||
@ -827,10 +860,80 @@ coverage_ignore_functions = [
|
||||
"get_latency_of_one_partition",
|
||||
"get_latency_of_partitioned_graph",
|
||||
"get_partition_to_latency_mapping",
|
||||
# torch.fx.experimental.proxy_tensor
|
||||
"decompose",
|
||||
"disable_autocast_cache",
|
||||
"disable_proxy_modes_tracing",
|
||||
"dispatch_trace",
|
||||
"extract_val",
|
||||
"fake_signature",
|
||||
"fetch_sym_proxy",
|
||||
"fetch_object_proxy",
|
||||
"get_innermost_proxy_mode",
|
||||
"get_isolated_graphmodule",
|
||||
"get_proxy_slot",
|
||||
"get_torch_dispatch_modes",
|
||||
"has_proxy_slot",
|
||||
"is_sym_node",
|
||||
"maybe_handle_decomp",
|
||||
"proxy_call",
|
||||
"set_meta",
|
||||
"set_original_aten_op",
|
||||
"set_proxy_slot",
|
||||
"snapshot_fake",
|
||||
"thunkify",
|
||||
"track_tensor",
|
||||
"track_tensor_tree",
|
||||
"wrap_key",
|
||||
"wrapper_and_args_for_make_fx",
|
||||
# torch.fx.experimental.recording
|
||||
"record_shapeenv_event",
|
||||
"replay_shape_env_events",
|
||||
"shape_env_check_state_equal",
|
||||
# torch.fx.experimental.sym_node
|
||||
"ceil_impl",
|
||||
"floor_ceil_helper",
|
||||
"floor_impl",
|
||||
"method_to_operator",
|
||||
"sympy_is_channels_last_contiguous_2d",
|
||||
"sympy_is_channels_last_contiguous_3d",
|
||||
"sympy_is_channels_last_strides_2d",
|
||||
"sympy_is_channels_last_strides_3d",
|
||||
"sympy_is_channels_last_strides_generic",
|
||||
"sympy_is_contiguous",
|
||||
"sympy_is_contiguous_generic",
|
||||
"to_node",
|
||||
"wrap_node",
|
||||
"sym_sqrt",
|
||||
# torch.fx.experimental.symbolic_shapes
|
||||
"bind_symbols",
|
||||
"cast_symbool_to_symint_guardless",
|
||||
"create_contiguous",
|
||||
"error",
|
||||
"eval_guards",
|
||||
"eval_is_non_overlapping_and_dense",
|
||||
"expect_true",
|
||||
"find_symbol_binding_fx_nodes",
|
||||
"free_symbols",
|
||||
"free_unbacked_symbols",
|
||||
"fx_placeholder_targets",
|
||||
"fx_placeholder_vals",
|
||||
"guard_bool",
|
||||
"guard_float",
|
||||
"guard_int",
|
||||
"guard_scalar",
|
||||
"has_hint",
|
||||
"has_symbolic_sizes_strides",
|
||||
"is_channels_last_contiguous_2d",
|
||||
"is_channels_last_contiguous_3d",
|
||||
"is_channels_last_strides_2d",
|
||||
"is_channels_last_strides_3d",
|
||||
"is_contiguous",
|
||||
"is_non_overlapping_and_dense_indicator",
|
||||
"is_nested_int",
|
||||
"is_symbol_binding_fx_node",
|
||||
"is_symbolic",
|
||||
# torch.fx.experimental.unification.core
|
||||
"reify",
|
||||
# torch.fx.experimental.unification.match
|
||||
"edge",
|
||||
@ -868,6 +971,24 @@ coverage_ignore_functions = [
|
||||
"reverse_dict",
|
||||
# torch.fx.experimental.unification.multipledispatch.variadic
|
||||
"isvariadic",
|
||||
# torch.fx.experimental.unification.unification_tools
|
||||
"assoc",
|
||||
"assoc_in",
|
||||
"dissoc",
|
||||
"first",
|
||||
"get_in",
|
||||
"getter",
|
||||
"groupby",
|
||||
"itemfilter",
|
||||
"itemmap",
|
||||
"keyfilter",
|
||||
"keymap",
|
||||
"merge",
|
||||
"merge_with",
|
||||
"update_in",
|
||||
"valfilter",
|
||||
"valmap",
|
||||
# torch.fx.experimental.unification.utils
|
||||
"freeze",
|
||||
"hashable",
|
||||
"raises",
|
||||
|
||||
@ -12,37 +12,6 @@ These APIs are experimental and subject to change without notice.
|
||||
.. autoclass:: torch.fx.experimental.sym_node.DynamicInt
|
||||
```
|
||||
|
||||
## torch.fx.experimental.sym_node
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.fx.experimental.sym_node
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.fx.experimental.sym_node
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
is_channels_last_contiguous_2d
|
||||
is_channels_last_contiguous_3d
|
||||
is_channels_last_strides_2d
|
||||
is_channels_last_strides_3d
|
||||
is_contiguous
|
||||
is_non_overlapping_and_dense_indicator
|
||||
method_to_operator
|
||||
sympy_is_channels_last_contiguous_2d
|
||||
sympy_is_channels_last_contiguous_3d
|
||||
sympy_is_channels_last_strides_2d
|
||||
sympy_is_channels_last_strides_3d
|
||||
sympy_is_channels_last_strides_generic
|
||||
sympy_is_contiguous
|
||||
sympy_is_contiguous_generic
|
||||
```
|
||||
|
||||
## torch.fx.experimental.symbolic_shapes
|
||||
|
||||
```{eval-rst}
|
||||
@ -100,25 +69,6 @@ These APIs are experimental and subject to change without notice.
|
||||
rebind_unbacked
|
||||
resolve_unbacked_bindings
|
||||
is_accessor_node
|
||||
cast_symbool_to_symint_guardless
|
||||
create_contiguous
|
||||
error
|
||||
eval_guards
|
||||
eval_is_non_overlapping_and_dense
|
||||
find_symbol_binding_fx_nodes
|
||||
free_symbols
|
||||
free_unbacked_symbols
|
||||
fx_placeholder_targets
|
||||
fx_placeholder_vals
|
||||
guard_bool
|
||||
guard_float
|
||||
guard_int
|
||||
guard_scalar
|
||||
has_hint
|
||||
has_symbolic_sizes_strides
|
||||
is_nested_int
|
||||
is_symbol_binding_fx_node
|
||||
is_symbolic
|
||||
```
|
||||
|
||||
## torch.fx.experimental.proxy_tensor
|
||||
@ -141,46 +91,4 @@ These APIs are experimental and subject to change without notice.
|
||||
get_proxy_mode
|
||||
maybe_enable_thunkify
|
||||
maybe_disable_thunkify
|
||||
decompose
|
||||
disable_autocast_cache
|
||||
disable_proxy_modes_tracing
|
||||
extract_val
|
||||
fake_signature
|
||||
fetch_object_proxy
|
||||
fetch_sym_proxy
|
||||
has_proxy_slot
|
||||
is_sym_node
|
||||
maybe_handle_decomp
|
||||
proxy_call
|
||||
set_meta
|
||||
set_original_aten_op
|
||||
set_proxy_slot
|
||||
snapshot_fake
|
||||
```
|
||||
|
||||
## torch.fx.experimental.unification.unification_tools
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.fx.experimental.unification.unification_tools
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.fx.experimental.unification.unification_tools
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
assoc
|
||||
assoc_in
|
||||
dissoc
|
||||
first
|
||||
keyfilter
|
||||
keymap
|
||||
merge
|
||||
merge_with
|
||||
update_in
|
||||
valfilter
|
||||
valmap
|
||||
|
||||
@ -1134,6 +1134,7 @@ The set of leaf modules can be customized by overriding
|
||||
.. py:module:: torch.fx.experimental.refinement_types
|
||||
.. py:module:: torch.fx.experimental.rewriter
|
||||
.. py:module:: torch.fx.experimental.schema_type_annotation
|
||||
.. py:module:: torch.fx.experimental.sym_node
|
||||
.. py:module:: torch.fx.experimental.unification.core
|
||||
.. py:module:: torch.fx.experimental.unification.dispatch
|
||||
.. py:module:: torch.fx.experimental.unification.match
|
||||
@ -1143,6 +1144,7 @@ The set of leaf modules can be customized by overriding
|
||||
.. py:module:: torch.fx.experimental.unification.multipledispatch.dispatcher
|
||||
.. py:module:: torch.fx.experimental.unification.multipledispatch.utils
|
||||
.. py:module:: torch.fx.experimental.unification.multipledispatch.variadic
|
||||
.. py:module:: torch.fx.experimental.unification.unification_tools
|
||||
.. py:module:: torch.fx.experimental.unification.utils
|
||||
.. py:module:: torch.fx.experimental.unification.variable
|
||||
.. py:module:: torch.fx.experimental.unify_refinements
|
||||
|
||||
@ -134,23 +134,6 @@ Quantization to work with this as well.
|
||||
ObservationType
|
||||
```
|
||||
|
||||
## torch.ao.quantization.backend_config.utils
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.ao.quantization.backend_config.utils
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
entry_to_pretty_str
|
||||
pattern_to_human_readable
|
||||
remove_boolean_dispatch_from_name
|
||||
|
||||
```
|
||||
|
||||
## torch.ao.quantization.fx.custom_config
|
||||
|
||||
This module contains a few CustomConfig classes that's used in both eager mode and FX graph mode quantization
|
||||
@ -171,30 +154,6 @@ This module contains a few CustomConfig classes that's used in both eager mode a
|
||||
StandaloneModuleConfigEntry
|
||||
```
|
||||
|
||||
## torch.ao.quantization.fx.utils
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.ao.quantization.fx.utils
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
all_node_args_except_first
|
||||
all_node_args_have_no_tensors
|
||||
collect_producer_nodes
|
||||
create_getattr_from_value
|
||||
create_node_from_old_node_preserve_meta
|
||||
graph_module_from_producer_nodes
|
||||
maybe_get_next_module
|
||||
node_arg_is_bias
|
||||
node_arg_is_weight
|
||||
return_arg_list
|
||||
```
|
||||
|
||||
## torch.ao.quantization.quantizer
|
||||
|
||||
```{eval-rst}
|
||||
|
||||
@ -260,7 +260,6 @@ select = [
|
||||
"TRY401", # verbose-log-message
|
||||
"UP",
|
||||
"YTT",
|
||||
"S101",
|
||||
]
|
||||
|
||||
[tool.ruff.lint.pyupgrade]
|
||||
@ -340,39 +339,6 @@ keep-runtime-typing = true
|
||||
"tools/linter/**" = [
|
||||
"LOG015" # please fix
|
||||
]
|
||||
"benchmarks/**" = [
|
||||
"S101"
|
||||
]
|
||||
"test/**" = [
|
||||
"S101"
|
||||
]
|
||||
"torchgen/**" = [
|
||||
"S101"
|
||||
]
|
||||
"torch/**" = [
|
||||
"S101"
|
||||
]
|
||||
"tools/**" = [
|
||||
"S101"
|
||||
]
|
||||
"setup.py" = [
|
||||
"S101"
|
||||
]
|
||||
"functorch/**" = [
|
||||
"S101"
|
||||
]
|
||||
"docs/**" = [
|
||||
"S101"
|
||||
]
|
||||
"android/**" = [
|
||||
"S101"
|
||||
]
|
||||
".github/**" = [
|
||||
"S101"
|
||||
]
|
||||
".ci/**" = [
|
||||
"S101"
|
||||
]
|
||||
|
||||
[tool.codespell]
|
||||
ignore-words = "tools/linter/dictionary.txt"
|
||||
|
||||
@ -3402,14 +3402,15 @@ class TestDistributions(DistributionsTestCase):
|
||||
self.assertEqual((1.0 / lambd), mean, atol=2e-2, rtol=2e-2)
|
||||
self.assertEqual((1.0 / lambd) ** 2, var, atol=2e-2, rtol=2e-2)
|
||||
|
||||
set_rng_seed(2) # see Note [Randomized statistical tests]
|
||||
for dtype in [torch.float, torch.double, torch.bfloat16, torch.float16]:
|
||||
for lambd in [0.2, 0.5, 1.0, 1.5, 2.0, 5.0]:
|
||||
sample_len = 50000
|
||||
sample_len = 50_000
|
||||
mean_var(lambd, torch.rand(sample_len, dtype=dtype))
|
||||
|
||||
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
|
||||
def test_exponential_sample(self):
|
||||
set_rng_seed(1) # see Note [Randomized statistical tests]
|
||||
set_rng_seed(2) # see Note [Randomized statistical tests]
|
||||
for rate in [1e-5, 1.0, 10.0]:
|
||||
self._check_sampler_sampler(
|
||||
Exponential(rate),
|
||||
@ -3569,7 +3570,7 @@ class TestDistributions(DistributionsTestCase):
|
||||
|
||||
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
|
||||
def test_pareto_sample(self):
|
||||
set_rng_seed(1) # see Note [Randomized statistical tests]
|
||||
set_rng_seed(3) # see Note [Randomized statistical tests]
|
||||
for scale, alpha in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
|
||||
self._check_sampler_sampler(
|
||||
Pareto(scale, alpha),
|
||||
@ -7043,6 +7044,7 @@ class TestJit(DistributionsTestCase):
|
||||
)
|
||||
|
||||
def test_variance(self):
|
||||
set_rng_seed(3) # see Note [Randomized statistical tests]
|
||||
for Dist, keys, values, sample in self._examples():
|
||||
if Dist in [Cauchy, HalfCauchy]:
|
||||
continue # infinite variance
|
||||
|
||||
@ -1556,48 +1556,6 @@ class GraphModule(torch.nn.Module):
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_unbacked_expr(self):
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return x + 1
|
||||
|
||||
def fn(c):
|
||||
d = torch.concat([c, c], dim=0)
|
||||
d = gn(d)
|
||||
return d
|
||||
|
||||
c = torch.randn((64, 32))
|
||||
torch._dynamo.decorators.mark_unbacked(c, 0)
|
||||
|
||||
ref = fn(c)
|
||||
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
|
||||
res = opt_fn(c)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_grad_acc(self):
|
||||
mod1 = torch.nn.Linear(8, 8)
|
||||
mod2 = torch.nn.Linear(8, 8)
|
||||
mod3 = torch.nn.Linear(8, 8)
|
||||
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return mod1(x) - mod2(x)
|
||||
|
||||
def fn(c):
|
||||
d = gn(c) - mod3(c)
|
||||
return d * 2
|
||||
|
||||
c = torch.randn((8, 8), requires_grad=True)
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
|
||||
res = opt_fn(c)
|
||||
res.sum().backward()
|
||||
|
||||
# The gradient addition node between mod1 and mode2 will be in the subgraph
|
||||
# The gradient addition node for mod3 is not in the subgraph.
|
||||
# print(backend.bw_graphs[0].print_readable())
|
||||
|
||||
def test_complex(self):
|
||||
# Observed in Wan2.1
|
||||
@nested_compile_region
|
||||
|
||||
@ -7522,38 +7522,6 @@ class AOTInductorTestsTemplate:
|
||||
eager_outputs = model(*example_inputs)
|
||||
torch.testing.assert_close(eager_outputs, compiled_outputs)
|
||||
|
||||
@requires_gpu
|
||||
def test_mixed_device_1(self):
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("Mixed-device test requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Buffers are on CPU
|
||||
self.register_buffer(
|
||||
"index", torch.tensor([1, 4, 1, 7], device="cpu", dtype=torch.int64)
|
||||
)
|
||||
self.register_buffer(
|
||||
"src", torch.ones(4, device="cpu", dtype=torch.int64)
|
||||
)
|
||||
|
||||
def forward(self, matrix, vector):
|
||||
# Inputs are on CUDA
|
||||
# 1. Operation on CPU tensors
|
||||
z = torch.zeros((vector.shape[0],), device="cpu", dtype=torch.int64)
|
||||
scatter_result = z.scatter_add(0, self.index, self.src)
|
||||
|
||||
# 2. Move result to CUDA and continue on CUDA
|
||||
v = vector + scatter_result.to(vector.dtype).to(GPU_TYPE)
|
||||
return torch.matmul(matrix, v)
|
||||
|
||||
example_inputs = (
|
||||
torch.randn(10, 10, device=self.device),
|
||||
torch.randn(10, device=self.device),
|
||||
)
|
||||
self.check_model(Model(), example_inputs, move_model_to_device=False)
|
||||
|
||||
|
||||
class AOTInductorLoggingTest(LoggingTestCase):
|
||||
@make_logging_test(dynamic=logging.DEBUG)
|
||||
|
||||
@ -218,7 +218,6 @@ def check_model(
|
||||
dynamic_shapes=None,
|
||||
atol=None,
|
||||
rtol=None,
|
||||
move_model_to_device=True,
|
||||
):
|
||||
with (
|
||||
torch.no_grad(),
|
||||
@ -230,7 +229,7 @@ def check_model(
|
||||
),
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
if not isinstance(model, types.FunctionType) and move_model_to_device:
|
||||
if not isinstance(model, types.FunctionType):
|
||||
model = model.to(self.device)
|
||||
|
||||
# For non mixed device inputs with default "cpu",set the device manually.
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# Owner(s): ["oncall: pt2"]
|
||||
import functools
|
||||
import re
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
@ -231,33 +230,6 @@ class PallasTestsMixin:
|
||||
self.assertIn("import jax.numpy as jnp", code)
|
||||
self.assertIn("from jax.experimental import pallas as pl", code)
|
||||
|
||||
def test_jax_jit_wrapper_is_emitted(self):
|
||||
"""Ensure generated Pallas code wraps pl.pallas_call in jax.jit."""
|
||||
|
||||
key = "cuda_backend" if self.DEVICE == "cuda" else "cpu_backend"
|
||||
|
||||
@torch.compile(backend="inductor", options={key: "pallas"})
|
||||
def pallas_fn(a, b):
|
||||
return a + b
|
||||
|
||||
_, (code,) = run_and_get_code(
|
||||
pallas_fn,
|
||||
torch.randn(32, device=self.DEVICE),
|
||||
torch.randn(32, device=self.DEVICE),
|
||||
)
|
||||
|
||||
kernel_match = re.search(r"def (pallas_[A-Za-z0-9_]+)_kernel", code)
|
||||
self.assertIsNotNone(kernel_match)
|
||||
kernel_name = kernel_match.group(1)
|
||||
wrapper_name = f"{kernel_name}_jit_wrapper"
|
||||
self.assertIn(wrapper_name, code)
|
||||
start = code.index(f"def {wrapper_name}")
|
||||
end = code.index(f"def {kernel_name}_main", start)
|
||||
wrapper_block = code[start:end]
|
||||
|
||||
self.assertIn("jax.jit", code)
|
||||
self.assertNotIn("torch.", wrapper_block)
|
||||
|
||||
def test_2d_tensor(self):
|
||||
"""Test with 2D tensors (though current implementation flattens)."""
|
||||
|
||||
|
||||
@ -1,236 +1,245 @@
|
||||
{
|
||||
"EndToEndLSTM (__main__.RNNTest)": 190.48799641927084,
|
||||
"MultiheadAttention (__main__.ModulesTest)": 141.2663370768229,
|
||||
"test__adaptive_avg_pool2d (__main__.CPUReproTests)": 82.87333234151204,
|
||||
"test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 70.6538565499442,
|
||||
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 123.34033711751302,
|
||||
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 171.25450134277344,
|
||||
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 119.71899922688802,
|
||||
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 69.35733322870163,
|
||||
"test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.64533233642578,
|
||||
"test_aot_autograd_symbolic_exhaustive_masked_norm_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.672952016194664,
|
||||
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 138.04000091552734,
|
||||
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 172.1344985961914,
|
||||
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 114.02050018310547,
|
||||
"test_aot_autograd_symbolic_exhaustive_ormqr_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.25642830984933,
|
||||
"test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 65.3350003560384,
|
||||
"test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 120.95249938964844,
|
||||
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_False_cpu (__main__.AssociativeScanTests)": 86.97774887084961,
|
||||
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_True_cpu (__main__.AssociativeScanTests)": 100.90774917602539,
|
||||
"test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 1144.3935089111328,
|
||||
"test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 222.58500061035156,
|
||||
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 501.10033162434894,
|
||||
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 517.1875050862631,
|
||||
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 113.88125228881836,
|
||||
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 235.77350616455078,
|
||||
"test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 74.6155014038086,
|
||||
"test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 66.63325119018555,
|
||||
"test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 216.2968317667643,
|
||||
"test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 153.0915012359619,
|
||||
"test_cat_2k_args (__main__.TestTEFuserDynamic)": 108.80471753561869,
|
||||
"test_cat_2k_args (__main__.TestTEFuserStatic)": 102.20949847949669,
|
||||
"test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 311.7026621500651,
|
||||
"test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 395.0001729329427,
|
||||
"test_collect_callgrind (__main__.TestBenchmarkUtils)": 348.6218566894531,
|
||||
"test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 98.71574974060059,
|
||||
"test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 97.68499946594238,
|
||||
"test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 65.0557508468628,
|
||||
"test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 65.86899948120117,
|
||||
"test_comprehensive_gradient_cuda_complex64 (__main__.TestDecompCUDA)": 97.15880012512207,
|
||||
"test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 103.20700073242188,
|
||||
"test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 102.74033610026042,
|
||||
"test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 460.4286702473958,
|
||||
"test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 435.62066650390625,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 287.3090057373047,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 265.1860008239746,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1235.7365112304688,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 68.20825004577637,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1281.2615051269531,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 71.90750026702881,
|
||||
"test_comprehensive_linalg_householder_product_cuda_complex64 (__main__.TestDecompCUDA)": 79.04633331298828,
|
||||
"test_comprehensive_linalg_lu_factor_ex_cuda_complex128 (__main__.TestDecompCUDA)": 68.10879821777344,
|
||||
"test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 71.43025207519531,
|
||||
"test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 68.94575023651123,
|
||||
"test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.93649864196777,
|
||||
"test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 72.46275043487549,
|
||||
"test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 64.10650062561035,
|
||||
"test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 67.03124904632568,
|
||||
"test_comprehensive_linalg_svd_cuda_float64 (__main__.TestDecompCUDA)": 64.32800025939942,
|
||||
"test_comprehensive_linalg_vector_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 96.41353665865384,
|
||||
"test_comprehensive_linalg_vector_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 100.17661388103778,
|
||||
"test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 110.95025062561035,
|
||||
"test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 108.06550025939941,
|
||||
"test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 104.24150085449219,
|
||||
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 63.453749656677246,
|
||||
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 61.739999771118164,
|
||||
"test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 69.96549987792969,
|
||||
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 113.65749931335449,
|
||||
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 106.57500076293945,
|
||||
"test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 117.54049682617188,
|
||||
"test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 116.19766489664714,
|
||||
"test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 272.48475646972656,
|
||||
"test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 248.12175369262695,
|
||||
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 79.66900062561035,
|
||||
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 81.52649879455566,
|
||||
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 79.29400062561035,
|
||||
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 82.40349960327148,
|
||||
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 128.42924880981445,
|
||||
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 125.03675079345703,
|
||||
"test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1264.9732360839844,
|
||||
"test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1250.7332458496094,
|
||||
"test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1255.0684814453125,
|
||||
"test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 574.4627532958984,
|
||||
"test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 581.7282485961914,
|
||||
"test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 65.052001953125,
|
||||
"test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 61.19200134277344,
|
||||
"test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 63.16874885559082,
|
||||
"test_comprehensive_ormqr_cpu_complex64 (__main__.TestDecompCPU)": 62.39250183105469,
|
||||
"test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 113.32574844360352,
|
||||
"test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 113.91499900817871,
|
||||
"test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 74.42549800872803,
|
||||
"test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 76.1560001373291,
|
||||
"test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 66.76750087738037,
|
||||
"test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 70.69724941253662,
|
||||
"test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 69.87625026702881,
|
||||
"test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 80.2542495727539,
|
||||
"test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 69.0419979095459,
|
||||
"test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 117.03342655726841,
|
||||
"test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 289.50213841029574,
|
||||
"test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 67.38800048828125,
|
||||
"test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 145.27399444580078,
|
||||
"test_conv3d_binary_dynamic_shapes_cpu (__main__.TestDynamicPatternMatcherGenericCPU)": 66.9245999654134,
|
||||
"test_conv3d_cuda (__main__.AOTInductorTestABICompatibleGpu)": 151.91099548339844,
|
||||
"test_conv_bn_fuse_cpu (__main__.CpuTests)": 92.79549789428711,
|
||||
"test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.60149955749512,
|
||||
"test_conv_transpose_with_output_size_and_no_batch_dim_ConvTranspose3d_cuda (__main__.TestConvolutionNNDeviceTypeCUDA)": 69.27724676392972,
|
||||
"test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 76.24971498761859,
|
||||
"test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 81.93449974060059,
|
||||
"test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 78.87700080871582,
|
||||
"test_count_nonzero_all (__main__.TestBool)": 631.2585144042969,
|
||||
"test_diff_hyperparams_sharding_strategy_str_full_shard (__main__.TestFSDPUseOrigParamsMultipleParamGroups)": 61.042999267578125,
|
||||
"test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 84.49850082397461,
|
||||
"test_dtensor_op_db_nn_functional_poisson_nll_loss_cpu_float32 (__main__.TestLocalDTensorOpsCPU)": 93.03299713134766,
|
||||
"test_eager_sequence_nr_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 228.46711820714614,
|
||||
"test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 286.29998779296875,
|
||||
"test_fail_arithmetic_ops.py (__main__.TestTyping)": 68.43842806134906,
|
||||
"test_fail_random.py (__main__.TestTyping)": 74.83523060725285,
|
||||
"test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 72.84900093078613,
|
||||
"test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 75.86675071716309,
|
||||
"test_fuse_large_params_cpu (__main__.CpuTests)": 151.4199981689453,
|
||||
"test_fuse_large_params_cuda (__main__.GPUTests)": 60.351999282836914,
|
||||
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 158.3622828892299,
|
||||
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 149.6796646118164,
|
||||
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 139.97800064086914,
|
||||
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 114.8385009765625,
|
||||
"test_grad_nn_Transformer_cpu_float64 (__main__.TestModuleCPU)": 84.69736822027909,
|
||||
"test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 84.62700080871582,
|
||||
"test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 89.197998046875,
|
||||
"test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 96.46900177001953,
|
||||
"test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 187.83824920654297,
|
||||
"test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 110.49449920654297,
|
||||
"test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 124.90424919128418,
|
||||
"test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 518.4157485961914,
|
||||
"test_indirect_device_assert (__main__.TritonCodeGenTests)": 304.6440022786458,
|
||||
"test_inductor_dynamic_shapes_broadcasting_dynamic_shapes (__main__.DynamicShapesReproTests)": 143.82052836698645,
|
||||
"test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 77.4985705784389,
|
||||
"test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 76.06225109100342,
|
||||
"test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 138.9222858973912,
|
||||
"test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 120.62233225504558,
|
||||
"test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 148.1219940185547,
|
||||
"test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 109.34200286865234,
|
||||
"test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 119.36233266194661,
|
||||
"test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 127.95700073242188,
|
||||
"test_list_clearing_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 61.64850175380707,
|
||||
"test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 105.3174296787807,
|
||||
"test_low_memory_max_pool_dilation_1_dim_3_cpu_halide (__main__.HalideCpuTests)": 585.9210001627604,
|
||||
"test_low_memory_max_pool_dilation_2_dim_3_cpu_halide (__main__.HalideCpuTests)": 504.3250020345052,
|
||||
"test_lstm_cpu (__main__.TestMkldnnCPU)": 86.21566645304362,
|
||||
"test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 129.277715410505,
|
||||
"test_max_autotune_addmm_max_autotune_gemm_backends_CK_x_shape2 (__main__.TestCKBackend)": 64.24800109863281,
|
||||
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_False (__main__.TestCKBackend)": 77.23899841308594,
|
||||
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_True (__main__.TestCKBackend)": 65.15649795532227,
|
||||
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 62.579833984375,
|
||||
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.6555004119873,
|
||||
"test_pattern_matcher_multi_user_cpu (__main__.CpuTritonTests)": 142.21566772460938,
|
||||
"test_proper_exit (__main__.TestDataLoader)": 267.74214717320035,
|
||||
"test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 266.6539971487863,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 101.97100067138672,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.3346659342448,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 81.50300216674805,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 104.61333465576172,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 99.41133371988933,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 73.37100219726562,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 95.30900065104167,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 96.61750030517578,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 79.33600234985352,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 101.2393315633138,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 103.18400192260742,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 75.4114990234375,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 96.52833302815755,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 99.72700119018555,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 100.61966705322266,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 102.2750015258789,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 95.17449951171875,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.96749877929688,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 106.44049835205078,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 101.7173334757487,
|
||||
"test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 531.5236612955729,
|
||||
"test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1077.4210205078125,
|
||||
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 812.0880126953125,
|
||||
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1347.9365234375,
|
||||
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 88.93533070882161,
|
||||
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 269.01949310302734,
|
||||
"test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 131.99799601236978,
|
||||
"test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 232.36275100708008,
|
||||
"test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 69.80400085449219,
|
||||
"test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 134.3415012359619,
|
||||
"test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 67.51749992370605,
|
||||
"test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 91.21066792805989,
|
||||
"test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 170.97775268554688,
|
||||
"test_quick_core_backward_std_cpu_float64 (__main__.TestDecompCPU)": 61.608266321818036,
|
||||
"test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 110.62575149536133,
|
||||
"test_register_spills_cuda (__main__.BenchmarkFusionGpuTest)": 63.59499969482422,
|
||||
"test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 88.68299865722656,
|
||||
"test_rnn_decomp_module_nn_LSTM_train_mode_cuda_float32 (__main__.TestDecompCUDA)": 91.50320053100586,
|
||||
"test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 66.10774898529053,
|
||||
"test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 66.20533180236816,
|
||||
"test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 243.1092529296875,
|
||||
"test_save_load_large_string_attribute (__main__.TestSaveLoad)": 105.01200103759766,
|
||||
"test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 107.93685695103237,
|
||||
"test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 142.38899993896484,
|
||||
"test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 119.90166600545247,
|
||||
"test_sort_bool_cpu (__main__.CpuTritonTests)": 346.2856750488281,
|
||||
"test_sort_dynamic_shape_with_check_cuda (__main__.TestInductorDynamicCUDA)": 423.09974098205566,
|
||||
"test_sort_stable_cuda (__main__.GPUTests)": 117.61659927368164,
|
||||
"test_sort_transpose_cpu (__main__.CpuTritonTests)": 378.31200154622394,
|
||||
"test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 222.822007894516,
|
||||
"test_terminate_handler_on_crash (__main__.TestTorch)": 143.31728431156702,
|
||||
"test_terminate_signal (__main__.ForkTest)": 168.20485967184817,
|
||||
"test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 168.19242484867573,
|
||||
"test_terminate_signal (__main__.SpawnTest)": 172.16428443363733,
|
||||
"test_thnn_conv_strided_padded_dilated (__main__.TestConvolutionNN)": 93.30639710426331,
|
||||
"test_train_parity_multi_group (__main__.TestFullyShard1DTrainingCore)": 163.89743041992188,
|
||||
"test_train_parity_with_activation_checkpointing (__main__.TestFullyShard1DTrainingCompose)": 60.47671399797712,
|
||||
"test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 63.39550018310547,
|
||||
"test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 173.53924942016602,
|
||||
"test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 175.3212537765503,
|
||||
"test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 122.20649909973145,
|
||||
"test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 99.9885025024414,
|
||||
"test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 71.64024829864502,
|
||||
"test_view_ops (__main__.TestViewOpsWithLocalTensor)": 73.45887422561646,
|
||||
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 95.75249862670898,
|
||||
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 61.858001708984375,
|
||||
"test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 65.11023766653878,
|
||||
"test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 66.35274982452393,
|
||||
"test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 61.196499824523926,
|
||||
"test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 73.75380906604585,
|
||||
"test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 73.64649868011475,
|
||||
"test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 75.09799966358003,
|
||||
"test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 70.51450157165527,
|
||||
"test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 66.21433276221866,
|
||||
"test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 73.20024871826172,
|
||||
"test_vmapvjpvjp_linalg_lstsq_cuda_float32 (__main__.TestOperatorsCUDA)": 88.1349983215332,
|
||||
"test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 76.89924907684326,
|
||||
"test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 77.32975196838379,
|
||||
"test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 120.09600067138672
|
||||
"EndToEndLSTM (__main__.RNNTest)": 207.89400227864584,
|
||||
"MultiheadAttention (__main__.ModulesTest)": 141.1396687825521,
|
||||
"test_AllenaiLongformerBase_repro_cpu_halide (__main__.HalideCpuTests)": 214.02366638183594,
|
||||
"test__adaptive_avg_pool2d (__main__.CPUReproTests)": 77.26125049591064,
|
||||
"test_adaptive_max_pool2d1_cpu_halide (__main__.HalideCpuTests)": 116.37000020345052,
|
||||
"test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 69.25722334120009,
|
||||
"test_after_aot_gpu_runtime_error (__main__.MinifierIsolateTests)": 65.84466807047527,
|
||||
"test_alexnet_prefix_cpu_halide (__main__.HalideCpuTests)": 178.41399637858072,
|
||||
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.55014337812151,
|
||||
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 122.18047623407273,
|
||||
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 192.6405719575428,
|
||||
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 111.27904801141648,
|
||||
"test_aot_autograd_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 60.906999588012695,
|
||||
"test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 62.244998931884766,
|
||||
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 150.04100036621094,
|
||||
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 191.85050201416016,
|
||||
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 111.9276631673177,
|
||||
"test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.31450271606445,
|
||||
"test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 125.24066416422527,
|
||||
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_False_cpu (__main__.AssociativeScanTests)": 86.47783279418945,
|
||||
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_True_cpu (__main__.AssociativeScanTests)": 100.46250025431316,
|
||||
"test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 1031.0534973144531,
|
||||
"test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 239.67400105794272,
|
||||
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 495.0447726779514,
|
||||
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 490.18524169921875,
|
||||
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 144.06477737426758,
|
||||
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 342.20416259765625,
|
||||
"test_avg_pool3d_backward_cpu_halide (__main__.HalideCpuTests)": 62.01366678873698,
|
||||
"test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 71.07200050354004,
|
||||
"test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 73.9221674601237,
|
||||
"test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 226.0122528076172,
|
||||
"test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 144.97249857584634,
|
||||
"test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 303.20537185668945,
|
||||
"test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 386.0518798828125,
|
||||
"test_collect_callgrind (__main__.TestBenchmarkUtils)": 291.2442270914714,
|
||||
"test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 95.87866719563802,
|
||||
"test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 98.38716634114583,
|
||||
"test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 69.08016649881999,
|
||||
"test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 69.88233311971028,
|
||||
"test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 104.17599995930989,
|
||||
"test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 97.41800308227539,
|
||||
"test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 474.6719970703125,
|
||||
"test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 440.4375,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 293.3983332316081,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 238.7328338623047,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1218.4906717936199,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 68.73516782124837,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1156.0123494466145,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 72.13916714986165,
|
||||
"test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 74.90450032552083,
|
||||
"test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 70.42100016276042,
|
||||
"test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.98883310953777,
|
||||
"test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 73.34433364868164,
|
||||
"test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 61.38016573588053,
|
||||
"test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 67.52783330281575,
|
||||
"test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 111.06333287556966,
|
||||
"test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 110.19833374023438,
|
||||
"test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 113.10083134969075,
|
||||
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 63.23766644795736,
|
||||
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 70.18666712443034,
|
||||
"test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 62.61399841308594,
|
||||
"test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float64 (__main__.TestDecompCPU)": 67.7816670735677,
|
||||
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 121.6183344523112,
|
||||
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 107.30266698201497,
|
||||
"test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 130.8143310546875,
|
||||
"test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 127.27633412679036,
|
||||
"test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 303.55183664957684,
|
||||
"test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 234.41216532389322,
|
||||
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 85.3436673482259,
|
||||
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 80.9688326517741,
|
||||
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 82.55149968465169,
|
||||
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 82.37966791788737,
|
||||
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 129.88233184814453,
|
||||
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 129.4015007019043,
|
||||
"test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1282.3826497395833,
|
||||
"test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1270.64599609375,
|
||||
"test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1297.9046630859375,
|
||||
"test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 545.2034962972006,
|
||||
"test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 572.5616760253906,
|
||||
"test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 64.40316645304362,
|
||||
"test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 64.68383344014485,
|
||||
"test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 61.48333422342936,
|
||||
"test_comprehensive_ormqr_cpu_complex64 (__main__.TestDecompCPU)": 61.959999084472656,
|
||||
"test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 105.79100036621094,
|
||||
"test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 122.34666570027669,
|
||||
"test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 68.7205015818278,
|
||||
"test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 74.2183329264323,
|
||||
"test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 66.86883227030437,
|
||||
"test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 77.48183314005534,
|
||||
"test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 79.1564998626709,
|
||||
"test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 160.41250228881836,
|
||||
"test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 79.10633341471355,
|
||||
"test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 60.106833140055336,
|
||||
"test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 221.3586196899414,
|
||||
"test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 504.3203754425049,
|
||||
"test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 78.03233337402344,
|
||||
"test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 152.302001953125,
|
||||
"test_conv3d_cuda (__main__.AOTInductorTestABICompatibleGpu)": 152.99433390299478,
|
||||
"test_conv_bn_fuse_cpu (__main__.CpuTests)": 96.25399971008301,
|
||||
"test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 75.70275068283081,
|
||||
"test_conv_transpose_with_output_size_and_no_batch_dim_ConvTranspose3d_cuda (__main__.TestConvolutionNNDeviceTypeCUDA)": 139.14399747674665,
|
||||
"test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 72.7847490310669,
|
||||
"test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 91.59966786702473,
|
||||
"test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 87.57833353678386,
|
||||
"test_count_nonzero_all (__main__.TestBool)": 664.9986343383789,
|
||||
"test_cp_flex_attention_document_mask (__main__.CPFlexAttentionTest)": 78.31500244140625,
|
||||
"test_ddp_uneven_inputs (__main__.TestDistBackendWithSpawn)": 385.24249792099,
|
||||
"test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 84.70466740926106,
|
||||
"test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestLocalDTensorOpsCPU)": 685.0679931640625,
|
||||
"test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestMultiThreadedDTensorOpsCPU)": 86.26266733805339,
|
||||
"test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 292.93699645996094,
|
||||
"test_error_detection_and_propagation (__main__.NcclErrorHandlingTest)": 66.84199905395508,
|
||||
"test_fail_arithmetic_ops.py (__main__.TestTyping)": 69.56212568283081,
|
||||
"test_fail_creation_ops.py (__main__.TestTyping)": 69.80560022989908,
|
||||
"test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 73.36666552225749,
|
||||
"test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 90.40366744995117,
|
||||
"test_fuse_large_params_cpu (__main__.CpuTests)": 132.73199844360352,
|
||||
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 150.16662406921387,
|
||||
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 159.28499794006348,
|
||||
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 165.19283294677734,
|
||||
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 151.12366739908853,
|
||||
"test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 84.61699930826823,
|
||||
"test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 110.00600179036458,
|
||||
"test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 122.3759994506836,
|
||||
"test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 190.89249674479166,
|
||||
"test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 149.6598358154297,
|
||||
"test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 146.07766723632812,
|
||||
"test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 532.8139902750651,
|
||||
"test_graph_partition_refcount_cuda (__main__.GPUTests)": 69.78400001525878,
|
||||
"test_graph_partition_refcount_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 267.04988850487604,
|
||||
"test_graph_partition_refcount_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 273.54955800374347,
|
||||
"test_grid_sampler_2d_cpu_halide (__main__.HalideCpuTests)": 195.84733072916666,
|
||||
"test_indirect_device_assert (__main__.TritonCodeGenTests)": 326.0143330891927,
|
||||
"test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 66.96037435531616,
|
||||
"test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 77.44933319091797,
|
||||
"test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 126.81488884819879,
|
||||
"test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 118.70199839274089,
|
||||
"test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 129.20266723632812,
|
||||
"test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 97.18800099690755,
|
||||
"test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 130.3183339436849,
|
||||
"test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 140.43233235677084,
|
||||
"test_list_clearing_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 293.122774971856,
|
||||
"test_lobpcg_ortho_cuda_float64 (__main__.TestLinalgCUDA)": 63.835832277933754,
|
||||
"test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 106.77049922943115,
|
||||
"test_lstm_cpu (__main__.TestMkldnnCPU)": 100.89649963378906,
|
||||
"test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 140.07424926757812,
|
||||
"test_max_autotune_addmm_max_autotune_gemm_backends_CK_x_shape2 (__main__.TestCKBackend)": 72.90299733479817,
|
||||
"test_max_autotune_addmm_search_space_EXHAUSTIVE_dynamic_True (__main__.TestMaxAutotuneSubproc)": 82.62433369954427,
|
||||
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_False (__main__.TestCKBackend)": 87.51499938964844,
|
||||
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_True_use_aoti_True (__main__.TestCKBackend)": 71.22416591644287,
|
||||
"test_max_pool2d2_cpu_halide (__main__.HalideCpuTests)": 424.50966389973956,
|
||||
"test_max_pool2d3_cpu_halide (__main__.HalideCpuTests)": 134.14600626627603,
|
||||
"test_max_pool2d5_cpu_halide (__main__.HalideCpuTests)": 358.88099161783856,
|
||||
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 63.58866712782118,
|
||||
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 62.68674945831299,
|
||||
"test_memory_format_operators_cuda (__main__.TestTorchDeviceTypeCUDA)": 65.85794713936355,
|
||||
"test_ordered_distribute_all_combination (__main__.DistributeWithDeviceOrderTest)": 103.6923344930013,
|
||||
"test_ordered_redistribute_with_partial (__main__.DistributeWithDeviceOrderTest)": 187.6953328450521,
|
||||
"test_ordered_redistribute_with_partial (__main__.DistributeWithDeviceOrderTestWithLocalTensor)": 370.27442932128906,
|
||||
"test_proper_exit (__main__.TestDataLoader)": 227.83111148410373,
|
||||
"test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 227.1901126437717,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 105.52099990844727,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 106.50249862670898,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 92.52400207519531,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 111.75499725341797,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 107.40500259399414,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 83.80450057983398,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 107.46599833170573,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 96.65650177001953,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 83.4114990234375,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 107.47100067138672,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 108.55533345540364,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 89.23666381835938,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 105.13900375366211,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 100.14550018310547,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 107.33649826049805,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 102.08150100708008,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 97.59600067138672,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 104.82933553059895,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 114.43099721272786,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 110.40333302815755,
|
||||
"test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 567.2765197753906,
|
||||
"test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1032.5083312988281,
|
||||
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 852.7170003255209,
|
||||
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1361.954854329427,
|
||||
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 77.385498046875,
|
||||
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 265.0193354288737,
|
||||
"test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 115.31749725341797,
|
||||
"test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 245.27666727701822,
|
||||
"test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 71.75300216674805,
|
||||
"test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 141.8895009358724,
|
||||
"test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 71.15749994913737,
|
||||
"test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 90.59066772460938,
|
||||
"test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 173.73916625976562,
|
||||
"test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 110.65066655476888,
|
||||
"test_register_spills_cuda (__main__.BenchmarkFusionCudaTest)": 99.21799850463867,
|
||||
"test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 90.86299896240234,
|
||||
"test_rosenbrock_sparse_with_lrsched_False_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 66.57050196329753,
|
||||
"test_rosenbrock_sparse_with_lrsched_True_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 69.65149958928426,
|
||||
"test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 78.13350168863933,
|
||||
"test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 76.85255601671007,
|
||||
"test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 333.04866282145184,
|
||||
"test_save_load_large_string_attribute (__main__.TestSaveLoad)": 146.96599833170572,
|
||||
"test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 160.4881100124783,
|
||||
"test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 124.10055626763238,
|
||||
"test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 117.38410907321506,
|
||||
"test_sort_dynamic_shape_with_check_cuda (__main__.TestInductorDynamicCUDA)": 710.2327779134115,
|
||||
"test_sort_stable_cpu (__main__.CpuTritonTests)": 1324.4399820963542,
|
||||
"test_sort_stable_cuda (__main__.GPUTests)": 76.83109970092774,
|
||||
"test_split_cumsum_cpu (__main__.CpuTritonTests)": 88.58433532714844,
|
||||
"test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 160.1271684964498,
|
||||
"test_tensor_split (__main__.TestVmapOperators)": 79.18955569393519,
|
||||
"test_terminate_handler_on_crash (__main__.TestTorch)": 111.30388899644215,
|
||||
"test_terminate_signal (__main__.ForkTest)": 132.3458870516883,
|
||||
"test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 132.2043343567186,
|
||||
"test_terminate_signal (__main__.SpawnTest)": 136.1005539894104,
|
||||
"test_torchvision_smoke (__main__.TestTensorBoardPytorchGraph)": 76.20899939537048,
|
||||
"test_train_parity_multi_group_unshard_async_op (__main__.TestFullyShard1DTrainingCore)": 63.82099969046457,
|
||||
"test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 61.925000508626304,
|
||||
"test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 60.89849980672201,
|
||||
"test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 66.88233375549316,
|
||||
"test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 144.9854990641276,
|
||||
"test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 144.4044977823893,
|
||||
"test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 108.19166437784831,
|
||||
"test_unary_ops (__main__.TestTEFuserDynamic)": 96.32655514611139,
|
||||
"test_unary_ops (__main__.TestTEFuserStatic)": 105.33362591266632,
|
||||
"test_upsample_bicubic2d_cpu_halide (__main__.HalideCpuTests)": 97.8336664835612,
|
||||
"test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 82.86566925048828,
|
||||
"test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 68.26500002543132,
|
||||
"test_views1_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 97.1120007832845,
|
||||
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 88.24766794840495,
|
||||
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 65.41266759236653,
|
||||
"test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 74.75533294677734,
|
||||
"test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 73.52500089009602,
|
||||
"test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 73.85466639200847,
|
||||
"test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 98.39650090535481,
|
||||
"test_vmapjvpvjp_nn_functional_conv2d_cpu_float32 (__main__.TestOperatorsCPU)": 61.39695285615467,
|
||||
"test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 77.88249842325847,
|
||||
"test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 73.0695006052653,
|
||||
"test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 81.86250114440918,
|
||||
"test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 98.63116455078125,
|
||||
"test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 94.85683314005534,
|
||||
"test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 173.00183614095053
|
||||
}
|
||||
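The block above is a snippet of per-test runtime data: each key is a test name plus its suite, and each value is a duration in seconds. As a minimal, hypothetical sketch only (the file name used below is illustrative, not the actual path in the repository), data in this shape can be inspected with a few lines of Python, for example to list the slowest entries:

import json

# Hypothetical path; substitute the real test-times file from the repository.
with open("test-times.json") as f:
    times = json.load(f)  # {"test_name (suite)": seconds, ...}

# Print the ten slowest tests recorded in this snapshot.
for name, seconds in sorted(times.items(), key=lambda kv: kv[1], reverse=True)[:10]:
    print(f"{seconds:9.1f}s  {name}")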
@ -221,9 +221,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
"""
)

for device in V.graph.device_types:
if device != "meta":
self.add_device_include(device)
self.add_device_include(self.device)

if V.graph.aot_mode:
if config.aot_inductor.dynamic_linkage:
@ -1425,13 +1423,11 @@ class CppWrapperCpu(PythonWrapperCodegen):
src_is_tensor,
reduce,
kwargs,
device,
):
reduce = self._get_scatter_reduce_enum(reduce)

# call the ABI shim function instead of the ATen one
self.add_device_include(device)
cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, device)
cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device)
# TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
inputs_wrapped = [str(x) for x in inputs]

@ -708,14 +708,11 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
src_is_tensor,
reduce,
kwargs,
device,
):
reduce = self._get_scatter_reduce_enum(reduce)

# call the ABI shim function instead of the ATen one
self.add_device_include(device)
cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, device)

cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device)
# TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor()

@ -287,7 +287,6 @@ class PallasKernel(SIMDKernel):
code = IndentedBuffer()
code.splice(
"""
import functools
import torch
import jax
import jax.numpy as jnp
@ -302,9 +301,6 @@ class PallasKernel(SIMDKernel):
kernel_params = [a.name for a in arg_defs]

kernel_name = name or "<KERNEL_NAME>"
interpret_literal = (
"True" if V.graph.get_current_device_or_throw().type == "cpu" else "False"
)
code.writeline(f"def {kernel_name}_kernel({', '.join(kernel_params)}):")
with code.indent():
# Emit compute (CSE) and store lines; they reference *_ptr[...] directly
@ -313,22 +309,16 @@ class PallasKernel(SIMDKernel):
for line in self.stores._lines:
code.writeline(str(line))

jit_wrapper_name = f"{kernel_name}_jit_wrapper"
code.writeline("@functools.partial(jax.jit, static_argnums=(0, 1))")
code.writeline(f"def {jit_wrapper_name}(out_shape, out_dtype, *kernel_refs):")
with code.indent():
code.writeline("out_spec = jax.ShapeDtypeStruct(out_shape, out_dtype)")
code.writeline("return pl.pallas_call(")
code.writeline(f" {kernel_name}_kernel,")
code.writeline(" out_shape=out_spec,")
code.writeline(f" interpret={interpret_literal},")
code.writeline(" grid=(1,),")
code.writeline(")(*kernel_refs)")

# Host entry: convert torch tensors <-> jax, call pallas_call and copy back
main_name = f"{kernel_name}_main"
code.writeline(f"def {main_name}({', '.join(kernel_params)}, stream=None):")
with code.indent():
# Determine interpret statically based on codegen device
interpret_literal = (
"True"
if V.graph.get_current_device_or_throw().type == "cpu"
else "False"
)
# Identify inputs (in_ptr*) and output (out_ptr*)
input_params = [
p for p in kernel_params if p.startswith(("in_ptr", "in_out_ptr"))
@ -347,9 +337,9 @@ class PallasKernel(SIMDKernel):
for inp in input_params:
code.writeline(f"{inp}_jax = jax.dlpack.from_dlpack({inp})")

# Get output metadata from PyTorch tensor
code.writeline("# Prepare output metadata from PyTorch tensor")
code.writeline("# Map PyTorch dtype to JAX dtype")
# Get output spec from PyTorch tensor
code.writeline("# Prepare output spec from PyTorch tensor")
code.writeline("# Map PyTorch dtype to JAX dtype string")
code.writeline("_torch_dtype_to_jax = {")
code.writeline(
" torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16,"
@ -359,14 +349,21 @@ class PallasKernel(SIMDKernel):
)
code.writeline(" torch.uint8: jnp.uint8, torch.bool: jnp.bool_,")
code.writeline("}")
code.writeline(f"out_shape = tuple({output_param}.shape)")
code.writeline(f"out_dtype = _torch_dtype_to_jax[{output_param}.dtype]")
code.writeline(
f"out_spec = jax.ShapeDtypeStruct({output_param}.shape, _torch_dtype_to_jax[{output_param}.dtype])"
)

call_args = ["out_shape", "out_dtype"] + [
f"{inp}_jax" for inp in input_params
]
call_arg_str = ", ".join(call_args)
code.writeline(f"res = {jit_wrapper_name}({call_arg_str})")
# Call pallas
# Pass interpret=True on CPU, False otherwise (single call, no duplication)
code.writeline("compiled = pl.pallas_call(")
code.writeline(f" lambda *refs: {kernel_name}_kernel(*refs),")
code.writeline(" out_shape=out_spec,")
code.writeline(f" interpret={interpret_literal},")
code.writeline(" grid=(1,),")
code.writeline(")")

jax_input_args = ", ".join([f"{inp}_jax" for inp in input_params])
code.writeline(f"res = compiled({jax_input_args})")

# Copy result back
code.writeline("# Copy result back into the provided torch output tensor")

@ -971,7 +971,6 @@ class ScatterFallbackLine(WrapperLine):
else:
(x, index) = (t.codegen_reference() for t in node.inputs)
src = node.constant_args[1]
device = d.type if (d := node.get_device()) else V.graph.device_type
self.wrapper._generate_scatter_fallback(
x,
[x, node.constant_args[0], index, src],
@ -980,7 +979,6 @@ class ScatterFallbackLine(WrapperLine):
node.src_is_tensor,
node.kwargs["reduce"],
node.codegen_kwargs(),
device,
)

def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
@ -1634,7 +1632,6 @@ class PythonWrapperCodegen(CodeGen):
src_is_tensor,
reduce,
kwargs,
device,
):
line = f"{python_kernel_name}({','.join(map(str, inputs))}"
if python_kernel_name.startswith("aten.scatter_reduce"):

@ -195,12 +195,10 @@ def get_new_attr_name_with_prefix(prefix: str) -> Callable:
def collect_producer_nodes(node: Node) -> Optional[list[Node]]:
r"""Starting from a target node, trace back until we hit input or
getattr node. This is used to extract the chain of operators
starting from getattr to the target node, for example::

def forward(self, x):
observed = self.observer(self.weight)
return F.linear(x, observed)

starting from getattr to the target node, for example
def forward(self, x):
observed = self.observer(self.weight)
return F.linear(x, observed)
collect_producer_nodes(observed) will either return a list of nodes that
produces the observed node or None if we can't extract a self contained
graph without free variables(inputs of the forward function).