mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Refactor out shape test into InputMetadata::maybe_reduce (#119559)
I'm going to gut this function shortly, and having it all on InputMetadata is convenient for this purpose. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/119559 Approved by: https://github.com/soulitzer
This commit is contained in:
committed by
PyTorch MergeBot
parent
c24b74efc7
commit
482345d747
@ -873,14 +873,7 @@ void validate_outputs(
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!metadata.is_same_shape(grad)) {
|
grad = metadata.maybe_reduce(i, std::move(grad), format_error);
|
||||||
if (metadata.is_expandable_to_shape(grad)) {
|
|
||||||
grad = metadata.reduce_grad(grad);
|
|
||||||
} else {
|
|
||||||
const auto message = metadata.incompatible_shape_error_message(i, grad);
|
|
||||||
TORCH_CHECK(false, format_error(message.str()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool input_is_complex =
|
bool input_is_complex =
|
||||||
isComplexType(c10::typeMetaToScalarType(metadata.options().dtype()));
|
isComplexType(c10::typeMetaToScalarType(metadata.options().dtype()));
|
||||||
|
@ -53,6 +53,22 @@ at::Tensor InputMetadata::zeros_like() const {
|
|||||||
return at::zeros_symint(shape_as_dim_vector(), options_);
|
return at::zeros_symint(shape_as_dim_vector(), options_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
at::Tensor InputMetadata::maybe_reduce(
|
||||||
|
const size_t i,
|
||||||
|
at::Tensor grad,
|
||||||
|
const std::function<std::string(const std::string&)>& format_error) const {
|
||||||
|
if (!is_same_shape(grad)) {
|
||||||
|
if (is_expandable_to_shape(grad)) {
|
||||||
|
return reduce_grad(grad);
|
||||||
|
} else {
|
||||||
|
const auto message = incompatible_shape_error_message(i, grad);
|
||||||
|
TORCH_CHECK(false, format_error(message.str()));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return grad;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
|
bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
|
||||||
if (!is_nestedness_same(grad)) {
|
if (!is_nestedness_same(grad)) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -73,6 +73,11 @@ struct TORCH_API InputMetadata {
|
|||||||
|
|
||||||
at::Tensor reduce_grad(at::Tensor& grad) const;
|
at::Tensor reduce_grad(at::Tensor& grad) const;
|
||||||
|
|
||||||
|
at::Tensor maybe_reduce(
|
||||||
|
const size_t index,
|
||||||
|
at::Tensor grad,
|
||||||
|
const std::function<std::string(const std::string&)>& format_error) const;
|
||||||
|
|
||||||
std::stringstream incompatible_shape_error_message(
|
std::stringstream incompatible_shape_error_message(
|
||||||
const size_t index,
|
const size_t index,
|
||||||
const at::Tensor& grad) const;
|
const at::Tensor& grad) const;
|
||||||
|
Reference in New Issue
Block a user