Mirror of https://github.com/pytorch/pytorch.git
[ez] use list initializer syntax in fill_diagonal_ (#163607)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163607
Approved by: https://github.com/Skylion007
ghstack dependencies: #163485
Committed by: PyTorch MergeBot
Parent: 5ca563ea09
Commit: b182365660
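
The change itself is small: single-element std::vector buffers that were default-constructed and then filled with push_back are now built directly with list-initializer syntax. A minimal standalone sketch of the pattern (not PyTorch code; int64_t stands in for c10::SymInt, which would require the PyTorch headers):

// Sketch of the refactor: build a one-element vector in a single step
// instead of default-constructing it and calling push_back.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  int64_t stride = 5; // stand-in for the summed per-dimension stride

  // Before: default-construct, then push_back the single element.
  std::vector<int64_t> strides_old;
  strides_old.push_back(stride);

  // After: list initializer builds the one-element vector directly.
  std::vector<int64_t> strides_new{stride};

  std::cout << (strides_old == strides_new) << '\n'; // prints 1
}

Note that braces select the std::initializer_list constructor, so std::vector<SymInt> sizes{size} is a one-element vector holding the value size, not a vector of size default-constructed elements. That is the intended meaning here, since sizes and strides each describe the single dimension of the diagonal view.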
@@ -109,26 +109,22 @@ Tensor& fill_diagonal_(Tensor& self, const Scalar& fill_value, bool wrap) {
   }
 
   auto storage_offset = self.sym_storage_offset();
-  std::vector<SymInt> sizes;
-  std::vector<SymInt> strides;
   auto size = std::min(height, width);
 
   int64_t stride = 0;
   for (const auto i : c10::irange(nDims)) {
     stride += self.stride(i);
   }
-  strides.push_back(stride);
-  sizes.push_back(size);
+  std::vector<SymInt> strides{stride};
+  std::vector<SymInt> sizes{size};
 
   auto main_diag = self.as_strided_symint(sizes, strides, storage_offset);
   main_diag.fill_(fill_value);
 
   if (wrap && nDims == 2 && height > width + 1) {
-    std::vector<SymInt> wrap_sizes;
-
     auto step = width + 1;
     auto wrap_size = ((self.numel() + step - 1) / step) - size;
-    wrap_sizes.push_back(wrap_size);
+    std::vector<SymInt> wrap_sizes{wrap_size};
 
     auto offset = self.stride(0) * (width + 1);
 
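
For context on what the touched code computes: fill_diagonal_ builds a 1-D strided view over the main diagonal, whose length is min(height, width) and whose stride is the sum of the tensor's per-dimension strides, then fills that view. A minimal sketch in plain C++ (not PyTorch code; the row-major layout and the 7.0 fill value are illustrative assumptions):

// For an H x W row-major matrix with strides {W, 1}, the main diagonal
// is a 1-D view of length min(H, W) with stride W + 1 (the sum of the
// per-dimension strides), starting at the storage offset.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t H = 3, W = 4;
  std::vector<double> storage(H * W, 0.0);

  // Row-major strides, as self.stride(i) would report them.
  const int64_t strides_2d[] = {W, 1};
  const int64_t diag_stride = strides_2d[0] + strides_2d[1]; // == W + 1
  const int64_t diag_size = std::min(H, W);

  // Equivalent of main_diag.fill_(fill_value) on the strided view.
  for (int64_t i = 0; i < diag_size; ++i)
    storage[i * diag_stride] = 7.0;

  for (int64_t r = 0; r < H; ++r) {
    for (int64_t c = 0; c < W; ++c)
      std::cout << storage[r * W + c] << ' ';
    std::cout << '\n';
  }
  // 7 0 0 0
  // 0 7 0 0
  // 0 0 7 0
}

The wrap branch in the diff extends the same idea: when a 2-D tensor is taller than it is wide, the diagonal "wraps" back to column 0 every width + 1 rows, so a second strided view starting at offset self.stride(0) * (width + 1) is filled the same way.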