symintify unbind_backward and tensor_split (#86357)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86357
Approved by: https://github.com/albanD
Author: anjali411
Date: 2022-10-09 13:35:57 +00:00
Committed by: PyTorch MergeBot
Parent: a6c0442cce
Commit: c89d286af6
4 changed files with 10 additions and 11 deletions
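
For context: "symintify" here means replacing hard-coded int64_t shape arithmetic with c10::SymInt, so tensor_split and unbind's backward can be traced with symbolic (not-yet-known) sizes. A minimal sketch of what this unblocks, assuming the proxy-tensor tracer (make_fx with tracing_mode="symbolic") that the updated test suites exercise; the exact import path at this commit is an assumption:

    # Hypothetical repro: trace tensor_split under symbolic shapes.
    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        a, b, c = torch.tensor_split(x, 3, dim=0)
        return a.sum() + b.sum() + c.sum()

    # Sizes are recorded as symbols (s0, s0 // 3, ...) rather than the
    # concrete 9, so the traced graph is reusable across input lengths.
    gm = make_fx(f, tracing_mode="symbolic")(torch.randn(9))
    print(gm.graph)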

@@ -743,14 +743,14 @@ std::vector<Tensor> tensor_split(const Tensor& self, int64_t sections, int64_t dim) {
   TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims");
   int64_t dim_ = maybe_wrap_dim(dim, self.dim());
   TORCH_CHECK(sections > 0, "number of sections must be larger than 0, got ", sections);
-  const auto dim_size = self.size(dim_);
+  const auto dim_size = self.sym_size(dim_);
   std::vector<Tensor> splits(sections);
-  int64_t min_split_size = dim_size / sections;
-  int64_t num_splits_one_extra = dim_size % sections;
-  int64_t start_idx = 0;
+  auto min_split_size = dim_size / sections;
+  auto num_splits_one_extra = dim_size % sections;
+  c10::SymInt start_idx = 0;
   for (const auto split_idx : c10::irange(sections)) {
-    int64_t split_size = (split_idx < num_splits_one_extra) ? (min_split_size + 1) : min_split_size;
-    splits[split_idx] = at::slice(self, dim_, start_idx, start_idx + split_size);
+    auto split_size = (num_splits_one_extra > split_idx) ? (min_split_size + 1) : min_split_size;
+    splits[split_idx] = at::slice_symint(self, dim_, start_idx, start_idx + split_size);
     start_idx += split_size;
   }
   return splits;
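
The sectioning arithmetic itself is unchanged; only its types are. In plain Python (an illustration, not code from this PR), the bounds tensor_split produces are:

    # Illustration of tensor_split's sectioning rule: the first
    # dim_size % sections chunks get one extra element.
    def split_bounds(dim_size: int, sections: int):
        min_split, extra = divmod(dim_size, sections)
        bounds, start = [], 0
        for i in range(sections):
            size = min_split + 1 if i < extra else min_split
            bounds.append((start, start + size))
            start += size
        return bounds

    assert split_bounds(7, 3) == [(0, 3), (3, 5), (5, 7)]

Note the flipped comparison in the hunk (num_splits_one_extra > split_idx rather than split_idx < num_splits_one_extra): presumably this keeps the c10::SymInt on the left-hand side so its overloaded comparison operator is the one selected.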

@@ -904,7 +904,6 @@ symbolic_aot_autograd_failures = {
     xfail('nn.functional.feature_alpha_dropout', 'with_train'),  # Cannot call numel() on tensor with symbol...
     xfail('nn.functional.fractional_max_pool2d', ''),  # rand() received an invalid combination of arguments - g...
     xfail('nn.functional.fractional_max_pool3d', ''),  # rand() received an invalid combination of arguments - g...
-    xfail('nn.functional.glu', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.grid_sample', ''),  # prims::arange() Expected a value of type 'number' for argument...
     xfail('nn.functional.group_norm', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.hinge_embedding_loss', ''),  # aten.zeros_like.default - couldn't find symbolic meta...

@@ -1196,7 +1196,6 @@ symbolic_tensor_failures = {
     xfail('nn.functional.feature_alpha_dropout', 'with_train'),  # Tensors of type TensorImpl do not have numel
     xfail('nn.functional.fractional_max_pool2d', ''),  # argument 'size' must be tuple of ints, but found element of t...
     xfail('nn.functional.fractional_max_pool3d', ''),  # argument 'size' must be tuple of ints, but found element of t...
-    xfail('nn.functional.glu', ''),  # aten.glu.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.grid_sample', ''),  # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos...
     xfail('nn.functional.group_norm', ''),  # 'torch._C.SymIntNode' and 'int'
     xfail('nn.functional.hinge_embedding_loss', ''),  # aten.empty_like.default - couldn't find symbolic meta function/deco...

@@ -831,18 +831,19 @@ Tensor logcumsumexp_backward(
 }

 Tensor unbind_backward(const variable_list& grads, int64_t dim) {
-  IntArrayRef sizes;
+  c10::SymIntArrayRef sizes;
   at::TensorOptions o;
   for (const auto& v : grads) {
     if (v.defined()) {
-      sizes = v.sizes();
+      sizes = v.sym_sizes();
       o = static_cast<Tensor>(v).options();
       break;
     }
   }
   auto grads_tensors = fmap(grads, [&](const Variable& v) {
     return (
-        v.defined() ? static_cast<Tensor>(v) : at::zeros({}, o).expand(sizes));
+        v.defined() ? static_cast<Tensor>(v)
+                    : at::zeros({}, o).expand_symint(sizes));
   });
   return at::stack(grads_tensors, dim);
 }
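
Behaviorally, unbind_backward is unchanged: gradients of the unbound slices are stacked back along dim, and slices with no incoming gradient contribute zeros expanded to the (now possibly symbolic) slice shape. A small illustration from the Python side, not part of this PR:

    import torch

    x = torch.randn(3, 4, requires_grad=True)
    a, b, c = x.unbind(dim=0)
    # Only `b` reaches the loss, so the gradients for `a` and `c` are
    # undefined and get filled with zeros before stacking.
    b.sum().backward()
    assert torch.equal(x.grad[1], torch.ones(4))
    assert torch.equal(x.grad[0], torch.zeros(4))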