mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
added torch.rot90() to ATen (#8628)
Summary:
1. fixes #6271
2. implemented torch.rot90() following [numpy.rot90()](6a58e25703/numpy/lib/function_base.py (L54-L138)
)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/8628
Reviewed By: ezyang
Differential Revision: D8987860
Pulled By: weiyangfb
fbshipit-source-id: 8dac3b2a1f6d3288672977aba8b547706ce97fe9
This commit is contained in:
committed by
Facebook Github Bot
parent
2f5c0c30cd
commit
302adb7cc8
@ -161,6 +161,37 @@ static inline at::optional<int64_t> parse_as_integer(const std::string& s) {
|
||||
return (*str_end == 0) ? at::optional<int64_t>(ans) : at::nullopt;
|
||||
}
|
||||
|
||||
/*
|
||||
Parse default value of IntList declared at native_functions.yaml
|
||||
|
||||
There are two kinds of default values:
|
||||
1. IntList[2] x=1 (where size=2, value={1,1}
|
||||
2. IntList x={1,2,3} (where size=3, value={1,2,3}, note that there cannot be space after comma since native_parse.py uses ', ' to split args)
|
||||
*/
|
||||
static inline std::vector<int64_t> parse_intlist_args(const std::string& s, int64_t size) {
|
||||
size_t n = s.size();
|
||||
|
||||
if (s.empty()) return std::vector<int64_t>();
|
||||
|
||||
// case 1. s is an int (e.g., s=2)
|
||||
if (s[0] != '{') {
|
||||
return std::vector<int64_t>(size, std::stol(s));
|
||||
}
|
||||
|
||||
// case 2. s is a list of dims (e.g., s={1,2})
|
||||
|
||||
// since already checked left brace '{' above, here only checks right brace '}'
|
||||
AT_CHECK(s[n - 1] == '}', "Default value of IntList is missing right brace '}', found ", s[n - 1]);
|
||||
|
||||
auto args = std::vector<int64_t>();
|
||||
std::istringstream ss(s.substr(1, s.length() - 2)); // exclude '{' and '}'
|
||||
std::string tok;
|
||||
|
||||
while(std::getline(ss, tok, ',')) {
|
||||
args.emplace_back(std::stol(tok));
|
||||
}
|
||||
return args;
|
||||
}
|
||||
|
||||
void FunctionParameter::set_default_str(const std::string& str) {
|
||||
if (str == "None") {
|
||||
@ -189,7 +220,7 @@ void FunctionParameter::set_default_str(const std::string& str) {
|
||||
}
|
||||
} else if (type_ == ParameterType::INT_LIST) {
|
||||
if (str != "None") {
|
||||
default_intlist.assign(size, std::stoi(str));
|
||||
default_intlist = parse_intlist_args(str, size);
|
||||
}
|
||||
} else if (type_ == ParameterType::SCALARTYPE) {
|
||||
if (str == "None") {
|
||||
|
Reference in New Issue
Block a user