Refactor out shape test into InputMetadata::maybe_reduce (#119559)

I'm going to gut this function shortly, and having it all on
InputMetadata is convenient for this purpose.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119559
Approved by: https://github.com/soulitzer
This commit is contained in:
Edward Z. Yang
2024-02-12 08:11:16 -08:00
committed by PyTorch MergeBot
parent c24b74efc7
commit 482345d747
3 changed files with 22 additions and 8 deletions

View File

@ -873,14 +873,7 @@ void validate_outputs(
continue; continue;
} }
if (!metadata.is_same_shape(grad)) { grad = metadata.maybe_reduce(i, std::move(grad), format_error);
if (metadata.is_expandable_to_shape(grad)) {
grad = metadata.reduce_grad(grad);
} else {
const auto message = metadata.incompatible_shape_error_message(i, grad);
TORCH_CHECK(false, format_error(message.str()));
}
}
bool input_is_complex = bool input_is_complex =
isComplexType(c10::typeMetaToScalarType(metadata.options().dtype())); isComplexType(c10::typeMetaToScalarType(metadata.options().dtype()));

View File

@ -53,6 +53,22 @@ at::Tensor InputMetadata::zeros_like() const {
return at::zeros_symint(shape_as_dim_vector(), options_); return at::zeros_symint(shape_as_dim_vector(), options_);
} }
at::Tensor InputMetadata::maybe_reduce(
const size_t i,
at::Tensor grad,
const std::function<std::string(const std::string&)>& format_error) const {
if (!is_same_shape(grad)) {
if (is_expandable_to_shape(grad)) {
return reduce_grad(grad);
} else {
const auto message = incompatible_shape_error_message(i, grad);
TORCH_CHECK(false, format_error(message.str()));
}
} else {
return grad;
}
}
bool InputMetadata::is_same_shape(const at::Tensor& grad) const { bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
if (!is_nestedness_same(grad)) { if (!is_nestedness_same(grad)) {
return false; return false;

View File

@ -73,6 +73,11 @@ struct TORCH_API InputMetadata {
at::Tensor reduce_grad(at::Tensor& grad) const; at::Tensor reduce_grad(at::Tensor& grad) const;
at::Tensor maybe_reduce(
const size_t index,
at::Tensor grad,
const std::function<std::string(const std::string&)>& format_error) const;
std::stringstream incompatible_shape_error_message( std::stringstream incompatible_shape_error_message(
const size_t index, const size_t index,
const at::Tensor& grad) const; const at::Tensor& grad) const;