mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add shape function for aten::cross_entropy_loss
(#97875)
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/97875 Approved by: https://github.com/davidberard98
This commit is contained in:
committed by
PyTorch MergeBot
parent
5c38c4cfa4
commit
bb4998b531
@ -1033,6 +1033,10 @@ def native_batch_norm(input: List[int], weight: Optional[List[int]], bias: Optio
|
||||
_size = [0]
|
||||
return _copy(input), _size, _size
|
||||
|
||||
def cross_entropy_loss(self: List[int], target: List[int], weight: Optional[List[int]] = None, reduction: int = 1, ignore_index: int = -100, label_smoothing: float = 0.) -> List[int]:
|
||||
result_shape = nll_loss_forward(self, target, weight, reduction)[0]
|
||||
return result_shape
|
||||
|
||||
"""
|
||||
Currently deferring the enabling of this, as part of the propoasal to suspend
|
||||
adding ops.
|
||||
@ -1149,6 +1153,7 @@ add_shape_compute_mapping("aten::native_layer_norm(Tensor input, int[] normalize
|
||||
add_shape_compute_mapping("aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", native_batch_norm)
|
||||
add_shape_compute_mapping("aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", native_batch_norm)
|
||||
add_shape_compute_mapping("aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", native_batch_norm)
|
||||
add_shape_compute_mapping("aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor", cross_entropy_loss)
|
||||
# add_shape_compute_mapping("aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", index_Tensor)
|
||||
|
||||
# TODO: migrate over all of symbolic_shape_registry_util.cpp
|
||||
|
Reference in New Issue
Block a user