Add shape function for aten::cross_entropy_loss (#97875)

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97875
Approved by: https://github.com/davidberard98
This commit is contained in:
Vivek Khandelwal
2023-04-12 22:11:52 +00:00
committed by PyTorch MergeBot
parent 5c38c4cfa4
commit bb4998b531
3 changed files with 75 additions and 0 deletions

View File

@ -1033,6 +1033,10 @@ def native_batch_norm(input: List[int], weight: Optional[List[int]], bias: Optio
_size = [0]
return _copy(input), _size, _size
def cross_entropy_loss(self: List[int], target: List[int], weight: Optional[List[int]] = None, reduction: int = 1, ignore_index: int = -100, label_smoothing: float = 0.) -> List[int]:
result_shape = nll_loss_forward(self, target, weight, reduction)[0]
return result_shape
"""
Currently deferring the enabling of this, as part of the propoasal to suspend
adding ops.
@ -1149,6 +1153,7 @@ add_shape_compute_mapping("aten::native_layer_norm(Tensor input, int[] normalize
add_shape_compute_mapping("aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", native_batch_norm)
add_shape_compute_mapping("aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", native_batch_norm)
add_shape_compute_mapping("aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", native_batch_norm)
add_shape_compute_mapping("aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor", cross_entropy_loss)
# add_shape_compute_mapping("aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", index_Tensor)
# TODO: migrate over all of symbolic_shape_registry_util.cpp