[ATen] Add CPU fp16 support for nll_loss and cross_entropy_loss (#123256)
Add CPU FP16 support for nll_loss and cross_entropy_loss. Resolve issue #123328.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123256
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: d59f1da62f
Commit: 6fcbeb3489
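
For orientation, here is a minimal usage sketch (illustrative, not part of the PR) of what the change enables: half-precision inputs to these losses now dispatch to real CPU kernels instead of failing with a "not implemented for 'Half'" error.

```python
import torch
import torch.nn.functional as F

# fp16 CPU tensors; before this change the CPU kernels were only
# instantiated for float/double/bfloat16, so Half inputs raised an error.
logits = torch.randn(8, 10, dtype=torch.float16, requires_grad=True)
target = torch.randint(0, 10, (8,))

loss = F.cross_entropy(logits, target)  # forward now has a Half kernel
loss.backward()                         # backward dispatches for Half too

log_probs = F.log_softmax(logits.detach(), dim=1)
print(F.nll_loss(log_probs, target))
```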
@@ -304,8 +304,12 @@ void nll_loss_forward_out_cpu_template(
     const Tensor& weight,
     int64_t reduction,
     int64_t ignore_index) {
-  AT_DISPATCH_FLOATING_TYPES_AND(
-      ScalarType::BFloat16, input.scalar_type(), "nll_loss_out_frame", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      ScalarType::BFloat16,
+      ScalarType::Half,
+      input.scalar_type(),
+      "nll_loss_out_frame",
+      [&] {
         if (target.scalar_type() == kByte) {
           nll_loss_out_frame<scalar_t, uint8_t>(
               output,
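
The hunk above swaps AT_DISPATCH_FLOATING_TYPES_AND for AT_DISPATCH_FLOATING_TYPES_AND2, which instantiates the templated lambda for one additional scalar type, ScalarType::Half, on top of float, double, and BFloat16. A rough Python analogue of what the dispatch macro does at runtime (the function below is hypothetical, for illustration only):

```python
import torch

# Hypothetical analogue: the macro expands to a switch over the runtime
# dtype, with the kernel body compiled once per supported scalar type.
def dispatch_floating_types_and2(extra1, extra2, dtype, name, body):
    supported = {torch.float32, torch.float64, extra1, extra2}
    if dtype not in supported:
        raise RuntimeError(f'"{name}" not implemented for {dtype}')
    return body(dtype)

# With extra2=torch.half, a Half input now reaches the kernel body.
dispatch_floating_types_and2(
    torch.bfloat16, torch.half, torch.half, "nll_loss_out_frame",
    lambda dt: print(f"running kernel for {dt}"))
```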
@@ -415,8 +419,9 @@ void nll_loss_backward_out_cpu_template(
     const Tensor& total_weight) {
   grad_input.zero_();

-  AT_DISPATCH_FLOATING_TYPES_AND(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       ScalarType::BFloat16,
+      ScalarType::Half,
       input.scalar_type(),
       "nll_loss_backward_out_frame",
       [&] {
@@ -262,8 +262,9 @@ void nll_loss2d_forward_out_cpu_template(
   check_inputs_nll_loss2d(input, target, weight);
   total_weight.resize_({});

-  AT_DISPATCH_FLOATING_TYPES_AND(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       ScalarType::BFloat16,
+      ScalarType::Half,
       input.scalar_type(),
       "nll_loss2d_forward_out_frame",
       [&] {
@@ -383,8 +384,9 @@ void nll_loss2d_backward_out_cpu_template(
       total_weight.numel(),
       " elements)");

-  AT_DISPATCH_FLOATING_TYPES_AND(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       ScalarType::BFloat16,
+      ScalarType::Half,
       input.scalar_type(),
       "nll_loss2d_backward_out_frame",
       [&] {
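
The same one-line macro swap covers the 2-D variants, which back nll_loss for 4-D (N, C, H, W) inputs with spatial targets. A hedged sketch of the path these kernels serve:

```python
import torch
import torch.nn.functional as F

# 4-D input with per-pixel class targets routes to the nll_loss2d kernels.
logits = torch.randn(2, 3, 4, 4, dtype=torch.float16, requires_grad=True)
target = torch.randint(0, 3, (2, 4, 4))

loss = F.nll_loss(F.log_softmax(logits, dim=1), target)
loss.backward()  # exercises the new nll_loss2d_backward Half instantiation
```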
@@ -2001,6 +2001,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
             "nn.functional.multilabel_soft_margin_loss": [4e-2, 5e-3],
             "nn.functional.local_response_norm": [1e-2, 5e-3],
             "nn.functional.poisson_nll_loss": [3e-2, 1e-3],
+            "nn.functional.nll_loss": [3e-2, 1e-3],
             "native_batch_norm": [3e-2, 1e-3],
             "dot": [3e-2, 1e-3],
             "logit": [3e-2, 1e-3],
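
The entry added for nn.functional.nll_loss loosens the fp16 output-consistency bounds for the ONNX runtime comparison; the two numbers are a relative and an absolute tolerance (the exact ordering is defined in the test file). An illustrative check in the same spirit:

```python
import torch

# Illustrative only: an fp16-grade result checked against a reference
# using the 3e-2 / 1e-3 bounds from the table above.
actual = torch.randn(16)
expected = actual * (1 + 1e-3)  # small relative perturbation
torch.testing.assert_close(actual, expected, rtol=3e-2, atol=1e-3)
```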
@@ -171,6 +171,8 @@ def mps_ops_grad_modifier(ops):
         'nn.functional.conv_transpose1d': [torch.float16],
         'nn.functional.conv_transpose2d': [torch.float16],
         'nn.functional.conv_transpose3d': [torch.float16],
+        'nn.functional.nll_loss': [torch.float16],
+        'nn.functional.cross_entropy': [torch.float16],
     }

     MACOS_13_3_XFAILLIST_GRAD = {
@@ -987,7 +989,10 @@ def mps_ops_modifier(ops):
         'nn.functional.avg_pool2d': [torch.float16],
         # input types 'tensor<f32>' and 'tensor<1xf16>' are not broadcast compatible
         # Refer to the issue please: https://github.com/pytorch/pytorch/issues/124252
-        'nn.functional.binary_cross_entropy': [torch.float16]
+        'nn.functional.binary_cross_entropy': [torch.float16],
+
+        'nn.functional.nll_loss': [torch.float16],
+        'nn.functional.cross_entropy': [torch.float16],
     }

     def addDecorator(op, d) -> None:
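
These test_mps.py tables list op/dtype pairs that are expected to fail on the MPS backend; fp16 nll_loss and cross_entropy are added because their Half paths are not enabled there yet. A simplified sketch of how such a table is consumed (the helper below is hypothetical, not the actual test_mps.py machinery):

```python
import torch

XFAILLIST_GRAD = {
    'nn.functional.nll_loss': [torch.float16],
    'nn.functional.cross_entropy': [torch.float16],
}

def expected_to_fail(op_name: str, dtype: torch.dtype) -> bool:
    # Listed op/dtype pairs get an expectedFailure decorator instead of
    # being run as hard assertions.
    return dtype in XFAILLIST_GRAD.get(op_name, [])

assert expected_to_fail('nn.functional.nll_loss', torch.float16)
```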
@@ -13009,8 +13009,7 @@ op_db: List[OpInfo] = [
         supports_out=False),
     OpInfo(
         "nn.functional.cross_entropy",
-        dtypes=floating_types_and(torch.bfloat16),
-        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
         sample_inputs_func=sample_inputs_cross_entropy,
         supports_out=False,
         supports_forward_ad=True,
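
With CPU now supporting the same floating dtypes as CUDA for this op, the separate dtypesIfCUDA override becomes redundant and a single dtypes= entry suffices. For reference, floating_types_and extends the base (float32, float64) set with the given dtypes; a quick check using the internal helper (internal API, subject to change):

```python
import torch
from torch.testing._internal.common_dtype import floating_types_and

# Expected to contain float32 and float64 plus the two extras.
print(floating_types_and(torch.float16, torch.bfloat16))
```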
@@ -13033,6 +13032,9 @@ op_db: List[OpInfo] = [
                 "test_variant_consistency_jit",
                 device_type="cuda",
             ),
+            DecorateInfo(unittest.skip("FP16 cross_entropy cases have not been enabled on MPS yet"),
+                         dtypes=(torch.half,), device_type="mps"),
+
         )
     ),
     OpInfo('nn.functional.normalize',
@@ -19427,8 +19429,7 @@ op_db: List[OpInfo] = [
     ),
     OpInfo(
         "nn.functional.nll_loss",
-        dtypes=floating_types_and(torch.bfloat16),
-        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
         supports_out=False,
         sample_inputs_func=sample_inputs_nll_loss,
         supports_forward_ad=True,
@@ -19449,6 +19450,9 @@ op_db: List[OpInfo] = [
                 "test_cow_input",
                 device_type='cuda',
             ),
+            DecorateInfo(unittest.skip("FP16 nll_loss cases have not been enabled on MPS yet"),
+                         dtypes=(torch.half,), device_type="mps"),
+
         ),
     ),
     OpInfo(
@@ -4013,16 +4013,8 @@ module_db: List[ModuleInfo] = [
         decorators=(
             # No channels_last support for loss functions.
             DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format'),
-            # Expect failures for tests that rely on torch.half implementation on CPU
-            DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", dtypes=[torch.float16], device_type='cpu'),
-            DecorateInfo(unittest.expectedFailure, "TestModule", "test_if_train_and_eval_modes_differ",
-                         dtypes=[torch.float16], device_type='cpu'),
-            DecorateInfo(unittest.expectedFailure, "TestModule", "test_save_load", dtypes=[torch.float16],
-                         device_type='cpu'),
-            DecorateInfo(unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors", dtypes=[torch.float16],
-                         device_type='cpu'),
-            DecorateInfo(unittest.expectedFailure, "TestModule", "test_multiple_device_transfer", dtypes=[torch.float16],
-                         device_type='cuda'),
+            DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-2, rtol=1e-3)}), "TestModule",
+                         "test_forward", dtypes=[torch.float16], device_type='cpu'),
             DecorateInfo(unittest.expectedFailure, "TestModule", "test_cpu_gpu_parity", dtypes=[torch.float16],
                          device_type='cuda'),),
     ),
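
Rather than xfailing the fp16 CPU module tests, the loss module now runs them under loosened comparison bounds: toleranceOverride maps a dtype to the (atol, rtol) pair the test harness applies when comparing results. A minimal sketch, assuming these internal helpers keep their current home in common_device_type:

```python
import torch
from torch.testing._internal.common_device_type import tol, toleranceOverride

# fp16 forward results on CPU are compared with atol=3e-2, rtol=1e-3
# instead of being marked as expected failures.
fp16_override = toleranceOverride({torch.float16: tol(atol=3e-2, rtol=1e-3)})
```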