mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fix boxcox to return same result for same input in one batch (#162772)
Summary: The SIMD path is using SLEEF version of `pow` which is slightly different from `std::pow`. The fix is to use the same vectorized code (with partial load and store) for the trailing data as well to ensure consistency between results. Rollback Plan: Differential Revision: D82265247 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162772 Approved by: https://github.com/swolchok
This commit is contained in:
committed by
PyTorch MergeBot
parent
66133b1ab7
commit
49d30f9a23
@ -73,6 +73,19 @@ void box_cox_zero_lambda(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
at::vec::Vectorized<T> box_cox_nonzero_lambda_impl(
|
||||
at::vec::Vectorized<T> data,
|
||||
at::vec::Vectorized<T> lambda1,
|
||||
at::vec::Vectorized<T> lambda2,
|
||||
at::vec::Vectorized<T> k_eps) {
|
||||
auto sum = data + lambda2;
|
||||
auto max = at::vec::max(sum, k_eps);
|
||||
auto lambda_over_1 = at::vec::fast_recieprocal(lambda1);
|
||||
auto pow = max.pow(lambda1);
|
||||
return at::vec::fmsub(pow, lambda_over_1, lambda_over_1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void box_cox_nonzero_lambda(
|
||||
int64_t D,
|
||||
@ -88,21 +101,18 @@ void box_cox_nonzero_lambda(
|
||||
auto k_eps_vec = Vec(k_eps);
|
||||
for(; j + VLEN < D; j += VLEN) {
|
||||
auto data = Vec::loadu(data_ptr + j);
|
||||
auto lambda2 = Vec::loadu(lambda2_ptr + j);
|
||||
auto sum = data + lambda2;
|
||||
auto max = at::vec::max(sum, k_eps_vec);
|
||||
auto lambda1 = Vec::loadu(lambda1_ptr + j);
|
||||
auto lambda_over_1 = at::vec::fast_recieprocal(lambda1);
|
||||
auto pow = max.pow(lambda1);
|
||||
auto res = at::vec::fmsub(pow, lambda_over_1, lambda_over_1);
|
||||
auto lambda2 = Vec::loadu(lambda2_ptr + j);
|
||||
auto res = box_cox_nonzero_lambda_impl(data, lambda1, lambda2, k_eps_vec);
|
||||
res.store(out + j);
|
||||
}
|
||||
for ( ;j < D; ++j) {
|
||||
auto sum = data_ptr[j] + lambda2_ptr[j];
|
||||
auto max = std::max(sum, k_eps);
|
||||
auto lambda_over_1 = at::vec::fast_recieprocal(lambda1_ptr[j]);
|
||||
auto pow = std::pow(max, lambda1_ptr[j]);
|
||||
out[j] = pow * lambda_over_1 - lambda_over_1;
|
||||
if (j < D) {
|
||||
auto remaining = D - j;
|
||||
auto data = Vec::loadu(data_ptr + j, remaining);
|
||||
auto lambda1 = Vec::loadu(lambda1_ptr + j, remaining);
|
||||
auto lambda2 = Vec::loadu(lambda2_ptr + j, remaining);
|
||||
auto res = box_cox_nonzero_lambda_impl(data, lambda1, lambda2, k_eps_vec);
|
||||
res.store(out + j, remaining);
|
||||
}
|
||||
}
|
||||
#else
|
||||
|
Reference in New Issue
Block a user