Revert "[MPS] Expand fused forloop to bfloat16 (#141104)"

This reverts commit 9a729390420570cd2528ce2e9947e3eab209660b.

Reverted https://github.com/pytorch/pytorch/pull/141104 on behalf of https://github.com/malfet due to Want to add test script to the commit message ([comment](https://github.com/pytorch/pytorch/pull/141104#issuecomment-2492659931))
Committed by: PyTorch MergeBot
Date: 2024-11-22 01:03:43 +00:00
Parent: e8de8f3969
Commit: 989888236e
3 changed files with 33 additions and 56 deletions
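The stated reason for the revert is that the bfloat16 change should land with a test script recorded in its commit message. Purely as an illustration of what such a script might look like (the shapes, learning rate, and tolerances below are assumptions, not the script the reviewer had in mind), a fused-vs-unfused comparison on MPS could be sketched as:

# Illustrative sketch only: compare one fused optimizer step on MPS against the
# non-fused implementation. Shapes, lr, and tolerances are assumptions.
import torch

def one_adam_step(fused: bool, dtype: torch.dtype) -> torch.Tensor:
    torch.manual_seed(0)
    param = torch.rand(64, 64, device="mps", dtype=dtype, requires_grad=True)
    opt = torch.optim.Adam([param], lr=1e-2, fused=fused)
    (param * param).sum().backward()
    opt.step()
    return param.detach().float().cpu()

# After this revert only float16/float32 are exercised; bfloat16 is unsupported.
for dtype in (torch.float16, torch.float32):
    torch.testing.assert_close(
        one_adam_step(True, dtype), one_adam_step(False, dtype),
        rtol=1e-3, atol=1e-3,
    )
    print(f"{dtype}: fused and non-fused Adam agree")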

@@ -1,12 +1,5 @@
 #include <metal_stdlib>
-using metal::max;
-#if __METAL_VERSION__ >= 310
-bfloat max(bfloat a, bfloat b) {
-  return a > b ? a : b;
-}
-#endif
 #define kmaxThreadGroups 32
 #define kmaxTensors 32
 #define chunk_size 65536
@@ -88,28 +81,26 @@ inline void adam_math_amsgrad(
   if (weight_decay != 0) {
     switch (adam_mode) {
       case ADAM_MODE::ORIGINAL:
-        grad += T(param * weight_decay);
+        grad += param * weight_decay;
         break;
       case ADAM_MODE::ADAMW:
-        param -= T(lr * weight_decay * param);
+        param -= lr * weight_decay * param;
         break;
     }
   }
-  exp_avg = T(beta1 * exp_avg + (1 - beta1) * grad);
-  exp_avg_sq = T(beta2 * exp_avg_sq + (1 - beta2) * grad * grad);
+  exp_avg = beta1 * exp_avg + (1 - beta1) * grad;
+  exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;
   const float casted_state_steps = static_cast<float>(state_steps);
-  const auto bias_correction1 =
-      1 - metal::precise::pow(beta1, casted_state_steps);
-  const auto step_size = lr / bias_correction1;
-  const auto bias_correction2 =
-      1 - metal::precise::pow(beta2, casted_state_steps);
-  const auto bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
-  max_exp_avg_sq = max(max_exp_avg_sq, exp_avg_sq);
+  const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps);
+  const T step_size = lr / bias_correction1;
+  const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps);
+  const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
+  max_exp_avg_sq = metal::max(max_exp_avg_sq, exp_avg_sq);
-  const auto denom =
+  const T denom =
       (metal::precise::sqrt(max_exp_avg_sq) / bias_correction2_sqrt) + eps;
-  param -= T(step_size * exp_avg / denom);
+  param -= step_size * exp_avg / denom;
   grad = grad_;
 }
@@ -136,26 +127,24 @@ inline void adam_math(
   if (weight_decay != 0) {
     switch (adam_mode) {
      case ADAM_MODE::ORIGINAL:
-        grad += T(param * weight_decay);
+        grad += param * weight_decay;
        break;
      case ADAM_MODE::ADAMW:
-        param -= T(lr * weight_decay * param);
+        param -= lr * weight_decay * param;
        break;
    }
   }
-  exp_avg = T(beta1 * exp_avg + (1 - beta1) * grad);
-  exp_avg_sq = T(beta2 * exp_avg_sq + (1 - beta2) * grad * grad);
+  exp_avg = beta1 * exp_avg + (1 - beta1) * grad;
+  exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;
   const float casted_state_steps = static_cast<float>(state_steps);
-  const auto bias_correction1 =
-      1 - metal::precise::pow(beta1, casted_state_steps);
-  const auto step_size = lr / bias_correction1;
-  const auto bias_correction2 =
-      1 - metal::precise::pow(beta2, casted_state_steps);
-  const auto bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
-  const auto denom =
+  const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps);
+  const T step_size = lr / bias_correction1;
+  const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps);
+  const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
+  const T denom =
       (metal::precise::sqrt(exp_avg_sq) / bias_correction2_sqrt) + eps;
-  param -= T(step_size * exp_avg / denom);
+  param -= step_size * exp_avg / denom;
   grad = grad_;
 }
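A note for readers skimming the revert: neither Adam hunk changes the math, only the explicit casts and float-typed intermediates that the bfloat16 change had introduced. In the kernels' own notation (exp_avg = $m$, exp_avg_sq = $v$, max_exp_avg_sq = $v^{\max}$, state_steps = $t$), adam_math_amsgrad computes

\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,g_t, &
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2, \\
v^{\max}_t &= \max\!\left(v^{\max}_{t-1},\, v_t\right), &
\theta_t &= \theta_{t-1} - \frac{\mathrm{lr}}{1-\beta_1^{\,t}}\cdot
\frac{m_t}{\sqrt{v^{\max}_t}\,/\,\sqrt{1-\beta_2^{\,t}} + \epsilon},
\end{aligned}

after $g_t$ has been adjusted for weight decay as in the switch above (ORIGINAL adds $\lambda\,\theta_{t-1}$ to the gradient, ADAMW decays the parameter directly). adam_math is identical except that $v_t$ replaces $v^{\max}_t$ in the denominator.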
@@ -306,11 +295,6 @@ REGISTER_ADAM_OPS_QUART(float, float);
 REGISTER_ADAM_OPS_QUART(float, half);
 REGISTER_ADAM_OPS_QUART(half, float);
 REGISTER_ADAM_OPS_QUART(half, half);
-#if __METAL_VERSION__ >= 310
-REGISTER_ADAM_OPS_QUART(float, bfloat);
-REGISTER_ADAM_OPS_QUART(bfloat, bfloat);
-REGISTER_ADAM_OPS_QUART(bfloat, float);
-#endif
 
 template <typename T>
 inline void sgd_momentum_math(
@@ -326,22 +310,22 @@ inline void sgd_momentum_math(
     const uint8_t is_first_step) {
   auto grad_ = grad;
   if (maximize) {
-    grad_ *= T(-1.0);
+    grad_ *= -1.0;
   }
   if (weight_decay != 0) {
-    grad_ += T(weight_decay * param);
+    grad_ += weight_decay * param;
   }
   momentum_buffer = is_first_step
       ? grad_
-      : T(momentum * momentum_buffer + (1 - dampening) * grad_);
+      : (momentum * momentum_buffer + (1 - dampening) * grad_);
   if (nesterov) {
-    grad_ += T(momentum * momentum_buffer);
+    grad_ += momentum * momentum_buffer;
   } else {
     grad_ = momentum_buffer;
   }
-  param -= T(lr * grad_);
+  param -= lr * grad_;
 }
 
 template <typename T>
@@ -353,13 +337,13 @@ inline void sgd_math(
     const uint8_t maximize) {
   auto grad_ = grad;
   if (maximize) {
-    grad_ *= T(-1.0);
+    grad_ *= -1.0;
   }
   if (weight_decay != 0) {
-    grad_ += T(weight_decay * param);
+    grad_ += weight_decay * param;
   }
-  param -= T(lr * grad_);
+  param -= lr * grad_;
 }
 
 template <typename T>
@@ -460,7 +444,3 @@ REGISTER_FUSED_SGD_OP(float);
 REGISTER_FUSED_SGD_OP(half);
 REGISTER_FUSED_SGD_MOMENTUM_OP(float);
 REGISTER_FUSED_SGD_MOMENTUM_OP(half);
-#if __METAL_VERSION__ >= 310
-REGISTER_FUSED_SGD_OP(bfloat);
-REGISTER_FUSED_SGD_MOMENTUM_OP(bfloat);
-#endif
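The SGD kernels touched above are likewise mathematically unchanged. Writing momentum_buffer = $b$, momentum = $\mu$, dampening = $\tau$, and weight_decay = $\lambda$ (with the sign of $g_t$ already flipped when maximize is set), sgd_momentum_math implements

\begin{aligned}
g'_t &= \pm g_t + \lambda\,\theta_{t-1}, \\
b_t &= \begin{cases} g'_t & \text{on the first step,} \\ \mu\,b_{t-1} + (1-\tau)\,g'_t & \text{otherwise,}\end{cases} \\
\theta_t &= \theta_{t-1} - \mathrm{lr}\cdot\begin{cases} g'_t + \mu\,b_t & \text{nesterov,} \\ b_t & \text{otherwise,}\end{cases}
\end{aligned}

and sgd_math is the momentum-free case $\theta_t = \theta_{t-1} - \mathrm{lr}\, g'_t$. The deleted registration block only removed the bfloat instantiations of these templates; the float and half combinations remain.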

@@ -131,9 +131,9 @@ static void multi_tensor_apply_for_fused_optimizer(const std::string& kernel_nam
   TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth");
   for (const auto& d : c10::irange(depth)) {
-    const auto scalar_type = tensor_lists[d][0].scalar_type();
-    TORCH_CHECK(scalar_type == kFloat || scalar_type == kHalf || scalar_type == kBFloat16,
-                "Only float, bfloat and half are supported");
+    TORCH_CHECK(tensor_lists[d][0].scalar_type() == at::ScalarType::Float ||
+                    tensor_lists[d][0].scalar_type() == at::ScalarType::Half,
+                "Only float and half are supported");
   }
 
   id<MTLDevice> device = MPSDevice::getInstance()->device();
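With the bfloat16 branch of this dtype check gone, handing bfloat16 tensors to the fused MPS path should again trip the restored TORCH_CHECK. A hedged sketch of what that looks like from Python (assuming no earlier Python-side validation intercepts the call first):

# Illustrative: after the revert, bfloat16 params on MPS are expected to be
# rejected by the fused kernels with the TORCH_CHECK message shown above.
import torch

param = torch.rand(8, device="mps", dtype=torch.bfloat16, requires_grad=True)
param.grad = torch.rand_like(param)
opt = torch.optim.Adam([param], lr=1e-3, fused=True)
try:
    opt.step()
except RuntimeError as err:
    # Expected to mention "Only float and half are supported".
    print(err)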

@@ -1027,11 +1027,8 @@ class TestOptimRenewed(TestCase):
         if _get_device_type(device) == "mps" and dtype not in (
             torch.float16,
             torch.float32,
-            torch.bfloat16,
         ):
-            self.skipTest(
-                "MPS supports only torch.float16, torch.float32 and torch.bfloat16"
-            )
+            self.skipTest("MPS supports only torch.float16 and torch.float32")
         self._test_derived_optimizers(device, dtype, optim_info, "fused")
 
     @optims(
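Finally, the test hunk above narrows the MPS dtype allow-list back to float16/float32, so bfloat16 parametrizations of the fused test are skipped rather than run. A trivial restatement of that filter (the function name and anything beyond the dtypes are made up for illustration):

import torch

# Post-revert allow-list, taken from the hunk above.
MPS_FUSED_TEST_DTYPES = (torch.float16, torch.float32)

def skips_on_mps(device_type: str, dtype: torch.dtype) -> bool:
    # Mirrors the guard in TestOptimRenewed: fused tests for other dtypes
    # call self.skipTest() instead of running _test_derived_optimizers.
    return device_type == "mps" and dtype not in MPS_FUSED_TEST_DTYPES

assert skips_on_mps("mps", torch.bfloat16)
assert not skips_on_mps("mps", torch.float32)
assert not skips_on_mps("cuda", torch.bfloat16)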