#include <torch/csrc/autograd/input_metadata.h>

// TODO: we may be able to move some imports from input_metadata.h to here, but
// it seems that function.h transitively depends on some of them.

namespace torch::autograd {

namespace {

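// Capture the input's shape in whichever variant representation applies:
// C++ nested tensors keep their nested_tensor_size() tensor, everything else
// keeps a small vector of SymInts.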
MetadataShape compute_variant_shape(const at::Tensor& input) {
  if (input.is_nested() && !input.unsafeGetTensorImpl()->is_python_dispatch()) {
    auto nested_size = input._nested_tensor_size();
    return MetadataShape{std::in_place_type<at::Tensor>, nested_size};
  }
  return MetadataShape{std::in_place_type<SymIntSmallVec>, input.sym_sizes()};
}

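// A tensor whose impl is routed through the Python dispatch key is a Python
// tensor subclass (e.g. a python-side NestedTensor).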
bool is_python_dispatch(const at::Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->is_python_dispatch();
}

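// A C++ NestedTensor is a nested tensor that does NOT go through Python
// dispatch (i.e. not a Python subclass implementation).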
bool is_cpp_nested_tensor(const at::Tensor& tensor) {
  return tensor.is_nested() && !is_python_dispatch(tensor);
}

} // namespace

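// Record the input's options and shape and capture the current stream on the
// input's device; the autograd engine consults this stream when accumulating
// gradients for this input.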
InputMetadata::InputMetadata(
    const at::TensorOptions& options,
    MetadataShape input_shape,
    bool is_tensor_subclass,
    bool is_nested,
    std::optional<at::ScalarType> grad_dtype)
    : options_{options},
      shape_{std::move(input_shape)},
      is_tensor_subclass_{is_tensor_subclass},
      is_nested_{is_nested},
      was_default_constructed_{false},
      grad_dtype_{grad_dtype} {
  auto device_ = options.device();
  stream_ = c10::impl::getDeviceGuardImpl(device_.type())->getStream(device_);
}

InputMetadata::InputMetadata(const at::Tensor& t)
    : InputMetadata(
          t.options(),
          compute_variant_shape(t),
          is_python_dispatch(t),
          t.is_nested(),
          t.grad_dtype()) {}

at::Tensor InputMetadata::zeros_like() const {
  TORCH_CHECK(
      !is_nested_, "Zeros is not currently supported for nested tensors.")
  return at::zeros_symint(shape_as_dim_vector(), options_);
}

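// Given the gradient produced for input `i`, decide whether it can be used
// as-is, must be reduced, or is incompatible:
//   - if its shape already matches this metadata, return it unchanged;
//   - if this metadata's shape is merely expandable (broadcastable) to the
//     gradient's shape, sum the gradient down to the input's shape (e.g. an
//     input of shape [1, 3] that broadcast to [2, 3] in the forward gets its
//     [2, 3] gradient summed back to [1, 3]);
//   - otherwise raise the error produced by format_error.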
at::Tensor InputMetadata::maybe_reduce(
    const size_t i,
    at::Tensor grad,
    const std::function<std::string(const std::string&)>& format_error) const {
  auto fail = [&]() {
    const auto message = incompatible_shape_error_message(i, grad);
    TORCH_CHECK(false, format_error(message.str()));
  };

  // Nested tensor makes my brain explode, so I've just hard-coded the logic
  // for this case, at risk of code duplication. This logic does NOT do the
  // careful size-oblivious guard handling used for the dense case below.
  if (is_nested_ || is_cpp_nested_tensor() || grad.is_nested() ||
      ::torch::autograd::is_cpp_nested_tensor(grad)) {
    if (!is_same_shape(grad)) {
      if (is_expandable_to_shape(grad)) {
        return reduce_grad(grad);
      } else {
        fail();
      }
    } else {
      return grad;
    }
  }

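  // Dense (non-nested) path: walk the trailing dimensions and decide whether
  // the gradient needs to be summed down to this shape, while avoiding
  // unnecessary guards on symbolic sizes.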
  auto shape = shape_as_dim_vector();
  auto desired = grad.sym_sizes();

  size_t ndim = shape.size();
  size_t target_dim = desired.size();
  if (ndim > target_dim) {
    fail();
  }
  bool needs_reduce = false;
  for (const auto j : c10::irange(ndim)) {
    const auto& size = shape[ndim - j - 1];
    const auto& target = desired[target_dim - j - 1];
    // The conditions here are written carefully so that we are able to
    // infer deferred runtime asserts
    if (TORCH_GUARD_OR_FALSE(size.sym_eq(1))) {
      // NB: we could short circuit this once needs_reduce is true but there's
      // no point since the reduction function will guard on this anyway
      if (!c10::guard_or_false(size.sym_eq(target), __FILE__, __LINE__)) {
        needs_reduce = true;
      }
    } else {
      if (!size.sym_eq(target).expect_true(__FILE__, __LINE__)) {
        fail();
      }
    }
  }
  if (ndim != target_dim) {
    needs_reduce = true;
  }

  if (needs_reduce) {
    return reduce_grad(grad);
  } else {
    return grad;
  }
}

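// Exact shape equality, taking nestedness into account: C++ nested tensors
// compare their nested-size tensors, everything else compares sym_sizes().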
bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
  if (!is_nestedness_same(grad)) {
    return false;
  }
  if (is_cpp_nested_tensor()) {
    return grad._nested_tensor_size().is_same_size(shape_as_tensor());
  }
  return grad.sym_sizes().equals(shape_as_dim_vector());
}

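// True if this metadata's shape can be broadcast (expanded) to the gradient's
// shape, in which case maybe_reduce will sum the gradient back down.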
bool InputMetadata::is_expandable_to_shape(const at::Tensor& grad) const {
  if (!maybe_expandable_to(grad)) {
    return false;
  }
  return at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
}

at::Tensor InputMetadata::reduce_grad(at::Tensor& grad) const {
  // reduce_grad should only be called if is_expandable_to_shape returns true.
  TORCH_INTERNAL_ASSERT(maybe_expandable_to(grad));
  return at::sum_to(std::move(grad), shape_as_dim_vector());
}

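// Build the "invalid gradient at index ..." message, printing nested sizes
// when either side is a C++ nested tensor.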
std::stringstream InputMetadata::incompatible_shape_error_message(
    const size_t index,
    const at::Tensor& grad) const {
  std::stringstream ss{};
  ss << "invalid gradient at index " << index << " - got ";
  if (::torch::autograd::is_cpp_nested_tensor(grad)) {
    ss << grad._nested_tensor_size();
  } else {
    ss << grad.sym_sizes();
  }
  ss << " but expected shape compatible with ";
  if (is_cpp_nested_tensor()) {
    ss << shape_as_tensor();
  } else {
    ss << shape_as_dim_vector();
  }
  return ss;
}

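// The shape_ variant holds an at::Tensor exactly when this metadata describes
// a C++ nested tensor; the internal assert checks that invariant.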
bool InputMetadata::is_cpp_nested_tensor() const {
  bool ret = std::holds_alternative<at::Tensor>(shape_);
  TORCH_INTERNAL_ASSERT(ret == (is_nested_ && !is_tensor_subclass_))
  return ret;
}

c10::SymIntArrayRef InputMetadata::shape_as_dim_vector() const {
  const auto& dim_shape = std::get<SymIntSmallVec>(shape_);
  return c10::SymIntArrayRef(dim_shape.data(), dim_shape.size());
}

// Danger: not thread safe, caller must protect with lock
SymIntSmallVec& InputMetadata::mutable_shape_as_dim_vector() {
  return std::get<SymIntSmallVec>(shape_);
}

bool InputMetadata::is_nestedness_same(const at::Tensor& grad) const {
  return (
      grad.is_nested() == is_nested_ &&
      ::torch::autograd::is_cpp_nested_tensor(grad) == is_cpp_nested_tensor());
}

at::Tensor InputMetadata::shape_as_tensor() const {
  return std::get<at::Tensor>(shape_);
}

bool InputMetadata::maybe_expandable_to(const at::Tensor& grad) const {
  // This is the initial step to determine whether or not the tensor
  // represented by input_metadata is expandable to grad, based on
  // is-nestedness information alone. If this function returns true, the full
  // symbolic shape check in is_expandable_to_shape is performed next. We
  // support the following 3 types of expansion:
  bool grad_is_nested = grad.is_nested();
  if (!is_nested_ && !grad_is_nested) {
    // Normal case (no NestedTensors are involved)
    // (1) plain Tensor -> plain Tensor
    return true;
  } else {
    // (2) python NT -> python NT
    // (3) plain Tensor -> python NT
    return (
        grad_is_nested && is_python_dispatch(grad) &&
        (!is_nested_ || is_tensor_subclass_));
  }
}

} // namespace torch::autograd