mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
symintify unbind_backward and tensor_split (#86357)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86357 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
a6c0442cce
commit
c89d286af6
@ -743,14 +743,14 @@ std::vector<Tensor> tensor_split(const Tensor& self, int64_t sections, int64_t d
|
||||
TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims");
|
||||
int64_t dim_ = maybe_wrap_dim(dim, self.dim());
|
||||
TORCH_CHECK(sections > 0, "number of sections must be larger than 0, got ", sections);
|
||||
const auto dim_size = self.size(dim_);
|
||||
const auto dim_size = self.sym_size(dim_);
|
||||
std::vector<Tensor> splits(sections);
|
||||
int64_t min_split_size = dim_size / sections;
|
||||
int64_t num_splits_one_extra = dim_size % sections;
|
||||
int64_t start_idx = 0;
|
||||
auto min_split_size = dim_size / sections;
|
||||
auto num_splits_one_extra = dim_size % sections;
|
||||
c10::SymInt start_idx = 0;
|
||||
for (const auto split_idx : c10::irange(sections)) {
|
||||
int64_t split_size = (split_idx < num_splits_one_extra) ? (min_split_size + 1) : min_split_size;
|
||||
splits[split_idx] = at::slice(self, dim_, start_idx, start_idx + split_size);
|
||||
auto split_size = (num_splits_one_extra > split_idx) ? (min_split_size + 1) : min_split_size;
|
||||
splits[split_idx] = at::slice_symint(self, dim_, start_idx, start_idx + split_size);
|
||||
start_idx += split_size;
|
||||
}
|
||||
return splits;
|
||||
|
@ -904,7 +904,6 @@ symbolic_aot_autograd_failures = {
|
||||
xfail('nn.functional.feature_alpha_dropout', 'with_train'), # Cannot call numel() on tensor with symbol...
|
||||
xfail('nn.functional.fractional_max_pool2d', ''), # rand() received an invalid combination of arguments - g...
|
||||
xfail('nn.functional.fractional_max_pool3d', ''), # rand() received an invalid combination of arguments - g...
|
||||
xfail('nn.functional.glu', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
|
||||
xfail('nn.functional.grid_sample', ''), # prims::arange() Expected a value of type 'number' for argument...
|
||||
xfail('nn.functional.group_norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
|
||||
xfail('nn.functional.hinge_embedding_loss', ''), # aten.zeros_like.default - couldn't find symbolic meta...
|
||||
|
@ -1196,7 +1196,6 @@ symbolic_tensor_failures = {
|
||||
xfail('nn.functional.feature_alpha_dropout', 'with_train'), # Tensors of type TensorImpl do not have numel
|
||||
xfail('nn.functional.fractional_max_pool2d', ''), # argument 'size' must be tuple of ints, but found element of t...
|
||||
xfail('nn.functional.fractional_max_pool3d', ''), # argument 'size' must be tuple of ints, but found element of t...
|
||||
xfail('nn.functional.glu', ''), # aten.glu.default - couldn't find symbolic meta function/decomposition
|
||||
xfail('nn.functional.grid_sample', ''), # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos...
|
||||
xfail('nn.functional.group_norm', ''), # 'torch._C.SymIntNode' and 'int'
|
||||
xfail('nn.functional.hinge_embedding_loss', ''), # aten.empty_like.default - couldn't find symbolic meta function/deco...
|
||||
|
@ -831,18 +831,19 @@ Tensor logcumsumexp_backward(
|
||||
}
|
||||
|
||||
Tensor unbind_backward(const variable_list& grads, int64_t dim) {
|
||||
IntArrayRef sizes;
|
||||
c10::SymIntArrayRef sizes;
|
||||
at::TensorOptions o;
|
||||
for (const auto& v : grads) {
|
||||
if (v.defined()) {
|
||||
sizes = v.sizes();
|
||||
sizes = v.sym_sizes();
|
||||
o = static_cast<Tensor>(v).options();
|
||||
break;
|
||||
}
|
||||
}
|
||||
auto grads_tensors = fmap(grads, [&](const Variable& v) {
|
||||
return (
|
||||
v.defined() ? static_cast<Tensor>(v) : at::zeros({}, o).expand(sizes));
|
||||
v.defined() ? static_cast<Tensor>(v)
|
||||
: at::zeros({}, o).expand_symint(sizes));
|
||||
});
|
||||
return at::stack(grads_tensors, dim);
|
||||
}
|
||||
|
Reference in New Issue
Block a user