mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Refactor out shape test into InputMetadata::maybe_reduce (#119559)
I'm going to gut this function shortly, and having it all on InputMetadata is convenient for this purpose. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/119559 Approved by: https://github.com/soulitzer
This commit is contained in:
committed by
PyTorch MergeBot
parent
c24b74efc7
commit
482345d747
@ -53,6 +53,22 @@ at::Tensor InputMetadata::zeros_like() const {
|
||||
return at::zeros_symint(shape_as_dim_vector(), options_);
|
||||
}
|
||||
|
||||
at::Tensor InputMetadata::maybe_reduce(
|
||||
const size_t i,
|
||||
at::Tensor grad,
|
||||
const std::function<std::string(const std::string&)>& format_error) const {
|
||||
if (!is_same_shape(grad)) {
|
||||
if (is_expandable_to_shape(grad)) {
|
||||
return reduce_grad(grad);
|
||||
} else {
|
||||
const auto message = incompatible_shape_error_message(i, grad);
|
||||
TORCH_CHECK(false, format_error(message.str()));
|
||||
}
|
||||
} else {
|
||||
return grad;
|
||||
}
|
||||
}
|
||||
|
||||
bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
|
||||
if (!is_nestedness_same(grad)) {
|
||||
return false;
|
||||
|
Reference in New Issue
Block a user