added torch.rot90() to ATen (#8628)

Summary:
1. fixes #6271
2. implemented torch.rot90() following [numpy.rot90()](6a58e25703/numpy/lib/function_base.py (L54-L138))
Pull Request resolved: https://github.com/pytorch/pytorch/pull/8628

Reviewed By: ezyang

Differential Revision: D8987860

Pulled By: weiyangfb

fbshipit-source-id: 8dac3b2a1f6d3288672977aba8b547706ce97fe9
This commit is contained in:
Wei Yang
2018-07-25 14:59:35 -07:00
committed by Facebook Github Bot
parent 2f5c0c30cd
commit 302adb7cc8
14 changed files with 204 additions and 29 deletions

View File

@ -161,6 +161,37 @@ static inline at::optional<int64_t> parse_as_integer(const std::string& s) {
return (*str_end == 0) ? at::optional<int64_t>(ans) : at::nullopt;
}
/*
Parse default value of IntList declared at native_functions.yaml
There are two kinds of default values:
1. IntList[2] x=1 (where size=2, value={1,1}
2. IntList x={1,2,3} (where size=3, value={1,2,3}, note that there cannot be space after comma since native_parse.py uses ', ' to split args)
*/
static inline std::vector<int64_t> parse_intlist_args(const std::string& s, int64_t size) {
size_t n = s.size();
if (s.empty()) return std::vector<int64_t>();
// case 1. s is an int (e.g., s=2)
if (s[0] != '{') {
return std::vector<int64_t>(size, std::stol(s));
}
// case 2. s is a list of dims (e.g., s={1,2})
// since already checked left brace '{' above, here only checks right brace '}'
AT_CHECK(s[n - 1] == '}', "Default value of IntList is missing right brace '}', found ", s[n - 1]);
auto args = std::vector<int64_t>();
std::istringstream ss(s.substr(1, s.length() - 2)); // exclude '{' and '}'
std::string tok;
while(std::getline(ss, tok, ',')) {
args.emplace_back(std::stol(tok));
}
return args;
}
void FunctionParameter::set_default_str(const std::string& str) {
if (str == "None") {
@ -189,7 +220,7 @@ void FunctionParameter::set_default_str(const std::string& str) {
}
} else if (type_ == ParameterType::INT_LIST) {
if (str != "None") {
default_intlist.assign(size, std::stoi(str));
default_intlist = parse_intlist_args(str, size);
}
} else if (type_ == ParameterType::SCALARTYPE) {
if (str == "None") {