[ATen] Add CPU fp16 support for nll_loss and cross_entropy_loss (#123256)

Add CPU FP16 support for nll_loss and cross_entropy_loss.
Resolve issue #123328.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123256
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/malfet
Author: xinan.lin
Date: 2024-04-18 01:03:38 -07:00
Committed by: PyTorch MergeBot
Parent: d59f1da62f
Commit: 6fcbeb3489

6 changed files with 29 additions and 20 deletions
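
Before this change, the CPU kernels for these losses dispatched over float, double, and BFloat16 only, so half-precision inputs failed with a "not implemented for 'Half'" RuntimeError. A minimal sketch of the newly enabled path (illustrative example, not part of the commit):

import torch
import torch.nn.functional as F

# fp16 logits on CPU: before this PR the forward kernel raised a
# "not implemented for 'Half'" RuntimeError; now it dispatches.
logits = torch.randn(8, 5, dtype=torch.float16, requires_grad=True)
target = torch.randint(0, 5, (8,))

loss = F.cross_entropy(logits, target)  # forward kernel now covers Half
loss.backward()                         # backward kernel is covered too
print(loss.dtype, logits.grad.dtype)    # torch.float16 torch.float16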


@@ -304,8 +304,12 @@ void nll_loss_forward_out_cpu_template(
     const Tensor& weight,
     int64_t reduction,
     int64_t ignore_index) {
-  AT_DISPATCH_FLOATING_TYPES_AND(
-      ScalarType::BFloat16, input.scalar_type(), "nll_loss_out_frame", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      ScalarType::BFloat16,
+      ScalarType::Half,
+      input.scalar_type(),
+      "nll_loss_out_frame",
+      [&] {
         if (target.scalar_type() == kByte) {
           nll_loss_out_frame<scalar_t, uint8_t>(
               output,
@@ -415,8 +419,9 @@ void nll_loss_backward_out_cpu_template(
     const Tensor& total_weight) {
   grad_input.zero_();
-  AT_DISPATCH_FLOATING_TYPES_AND(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       ScalarType::BFloat16,
+      ScalarType::Half,
       input.scalar_type(),
       "nll_loss_backward_out_frame",
       [&] {
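
The functional change in both hunks is the dispatch macro: AT_DISPATCH_FLOATING_TYPES_AND covers float and double plus one extra scalar type, while AT_DISPATCH_FLOATING_TYPES_AND2 covers two, and the lambda body is instantiated once per covered type with scalar_t bound accordingly. A rough Python analogue of what the dispatch does at runtime (a hypothetical helper for illustration; the real macro is C++ and selects a statically compiled kernel):

import torch

def dispatch_floating_types_and2(extra1, extra2, dtype, name, kernels):
    # Runtime switch over dtype; uncovered dtypes produce the familiar
    # '"<name>" not implemented for ...' error.
    allowed = {torch.float32, torch.float64, extra1, extra2}
    if dtype not in allowed:
        raise RuntimeError(f'"{name}" not implemented for {dtype}')
    return kernels[dtype]  # the kernel instantiated for this scalar type

# Before: AND(bfloat16) left float16 uncovered, so CPU fp16 raised.
# After:  AND2(bfloat16, float16) routes fp16 to a real kernel.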


@@ -262,8 +262,9 @@ void nll_loss2d_forward_out_cpu_template(
   check_inputs_nll_loss2d(input, target, weight);
   total_weight.resize_({});
-  AT_DISPATCH_FLOATING_TYPES_AND(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       ScalarType::BFloat16,
+      ScalarType::Half,
       input.scalar_type(),
       "nll_loss2d_forward_out_frame",
       [&] {
@@ -383,8 +384,9 @@ void nll_loss2d_backward_out_cpu_template(
       total_weight.numel(),
       " elements)");
-  AT_DISPATCH_FLOATING_TYPES_AND(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       ScalarType::BFloat16,
+      ScalarType::Half,
       input.scalar_type(),
       "nll_loss2d_backward_out_frame",
       [&] {


@@ -2001,6 +2001,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
             "nn.functional.multilabel_soft_margin_loss": [4e-2, 5e-3],
             "nn.functional.local_response_norm": [1e-2, 5e-3],
             "nn.functional.poisson_nll_loss": [3e-2, 1e-3],
+            "nn.functional.nll_loss": [3e-2, 1e-3],
             "native_batch_norm": [3e-2, 1e-3],
             "dot": [3e-2, 1e-3],
             "logit": [3e-2, 1e-3],


@@ -171,6 +171,8 @@ def mps_ops_grad_modifier(ops):
         'nn.functional.conv_transpose1d': [torch.float16],
         'nn.functional.conv_transpose2d': [torch.float16],
         'nn.functional.conv_transpose3d': [torch.float16],
+        'nn.functional.nll_loss': [torch.float16],
+        'nn.functional.cross_entropy': [torch.float16],
     }

     MACOS_13_3_XFAILLIST_GRAD = {

@@ -987,7 +989,10 @@ def mps_ops_modifier(ops):
         'nn.functional.avg_pool2d': [torch.float16],
         # input types 'tensor<f32>' and 'tensor<1xf16>' are not broadcast compatible
         # Refer to the issue please: https://github.com/pytorch/pytorch/issues/124252
-        'nn.functional.binary_cross_entropy': [torch.float16]
+        'nn.functional.binary_cross_entropy': [torch.float16],
+        'nn.functional.nll_loss': [torch.float16],
+        'nn.functional.cross_entropy': [torch.float16],
     }

     def addDecorator(op, d) -> None:
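
The two MPS lists above gain matching fp16 entries because the MPS backend has not enabled half precision for these ops yet (the unittest.skip decorators in the OpInfo changes below give the same reason), so the new dtype coverage from this PR applies to CPU, with CUDA already supported.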


@@ -13009,8 +13009,7 @@ op_db: List[OpInfo] = [
         supports_out=False),
     OpInfo(
         "nn.functional.cross_entropy",
-        dtypes=floating_types_and(torch.bfloat16),
-        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
         sample_inputs_func=sample_inputs_cross_entropy,
         supports_out=False,
         supports_forward_ad=True,

@@ -13033,6 +13032,9 @@ op_db: List[OpInfo] = [
             "test_variant_consistency_jit",
             device_type="cuda",
         ),
+        DecorateInfo(unittest.skip("FP16 cross_entropy cases have not been enabled on MPS yet"),
+                     dtypes=(torch.half,), device_type="mps"),
         )
     ),
     OpInfo('nn.functional.normalize',
@@ -19427,8 +19429,7 @@ op_db: List[OpInfo] = [
     ),
     OpInfo(
         "nn.functional.nll_loss",
-        dtypes=floating_types_and(torch.bfloat16),
-        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
        supports_out=False,
         sample_inputs_func=sample_inputs_nll_loss,
         supports_forward_ad=True,

@@ -19449,6 +19450,9 @@ op_db: List[OpInfo] = [
             "test_cow_input",
             device_type='cuda',
         ),
+        DecorateInfo(unittest.skip("FP16 nll_loss cases have not been enabled on MPS yet"),
+                     dtypes=(torch.half,), device_type="mps"),
         ),
     ),
     OpInfo(


@@ -4013,16 +4013,8 @@ module_db: List[ModuleInfo] = [
         decorators=(
             # No channels_last support for loss functions.
             DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format'),
-            # Expect failures for tests that rely on torch.half implementation on CPU
-            DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", dtypes=[torch.float16], device_type='cpu'),
-            DecorateInfo(unittest.expectedFailure, "TestModule", "test_if_train_and_eval_modes_differ",
-                         dtypes=[torch.float16], device_type='cpu'),
-            DecorateInfo(unittest.expectedFailure, "TestModule", "test_save_load", dtypes=[torch.float16],
-                         device_type='cpu'),
-            DecorateInfo(unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors", dtypes=[torch.float16],
-                         device_type='cpu'),
-            DecorateInfo(unittest.expectedFailure, "TestModule", "test_multiple_device_transfer", dtypes=[torch.float16],
-                         device_type='cuda'),
+            DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-2, rtol=1e-3)}), "TestModule",
+                         "test_forward", dtypes=[torch.float16], device_type='cpu'),
             DecorateInfo(unittest.expectedFailure, "TestModule", "test_cpu_gpu_parity", dtypes=[torch.float16],
                          device_type='cuda'),),
     ),
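
With the CPU fp16 kernels in place, the module tests drop the blanket expectedFailure decorators for torch.float16 on CPU and instead loosen the forward-comparison tolerance (atol=3e-2, rtol=1e-3, the same bounds added to the ONNX consistency table above), since half precision is simply less accurate than the fp32 reference. A hand-rolled check in the same spirit (illustrative, not from the test suite):

import torch
import torch.nn.functional as F

x = torch.randn(64, 10)
t = torch.randint(0, 10, (64,))

ref = F.cross_entropy(x, t)                 # fp32 reference
out = F.cross_entropy(x.half(), t).float()  # fp16 result, upcast to compare

# Same tolerances the updated tests apply for fp16 on CPU.
torch.testing.assert_close(out, ref, atol=3e-2, rtol=1e-3)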