Files
pytorch/tools/codegen/static_runtime/config.py
Don Jang fe7e1bd1ce [Static Runtime] Add auto-generated out variant dispatchers (#72603)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72603

This change adds out variant dispatchers generated by the previous diff.

This diff generates 133 out variant dispatchers, which increases out variant coverage by 309% (before: 43; after this diff: 43 + 133 = 176). This number is expected to grow substantially as the script is developed further to cover more ops.

Test Plan:
**Unittest**
Confirmed that
```
buck run //caffe2/benchmarks/static_runtime:static_runtime_cpptest
```
passes.

Reviewed By: swolchok

Differential Revision: D33373928

fbshipit-source-id: 4d94d788282f3f313bb36f2f9452edecd9862246
(cherry picked from commit e4ce8b386d1fcc47b86cb9c9016a70e7a31b452c)
2022-02-28 08:39:10 +00:00

179 lines
7.6 KiB
Python

from tools.codegen.model import NativeFunctionsGroup
from typing import Dict
def func_name_base_str(g: NativeFunctionsGroup) -> str:
    """Return the ``base`` part of the functional variant's operator name.

    Reads ``g.functional.func.name.name.base`` and converts it with ``str``.
    """
    operator_name = g.functional.func.name.name
    return str(operator_name.base)
# Base op names whose static-runtime out variants are treated as hand-written;
# `is_hand_written` tests membership in this set.
is_hand_written_ops_ = frozenset({
    "add", "addmm", "all", "any", "argmin", "bmm", "clamp",
    "cumsum", "div", "fmod", "leaky_relu", "log", "mul", "pow",
    "remainder", "sigmoid", "sign", "sub", "tanh",
})
def is_hand_written(g: NativeFunctionsGroup) -> bool:
    """Return True when ``g``'s base op name is listed in ``is_hand_written_ops_``.

    NOTE(review): presumably callers use this to skip generating code for ops
    that already have hand-written implementations -- confirm at call sites.
    """
    return func_name_base_str(g) in is_hand_written_ops_
# Ops that share the scatter test inputs; `override_test_values` additionally
# rewrites their "reduce" argument when one is present.
_SCATTER_OPS = ("scatter", "scatter_add", "_scatter_reduce")

# Shared (case 0, case 1) argument overrides for every op in `_SCATTER_OPS`.
_SCATTER_CASES = (
    {
        "self": "at::randint(1, 100, {2,2,2}, torch::kInt64)",
        "index": "at::randint(0, 1, {2,2,2}, torch::kInt64)",
        "src": "at::randint(1, 100, {2,2,2}, torch::kInt64)",
    },
    {
        "self": "at::randint(1, 100, {5,5,5}, torch::kInt64)",
        "index": "at::randint(0, 1, {5,5,5}, torch::kInt64)",
        "src": "at::randint(1, 100, {5,5,5}, torch::kInt64)",
    },
)

# Hand-picked C++ argument expressions keyed by op name. Each value is a pair
# of override dicts: the first is applied when ``index == 0``, the second
# otherwise (``index == 1``). Keys appear in the order they are written into
# ``arg_map``, which matters for any key not already present there.
_TEST_VALUE_OVERRIDES = {
    "addmv": (
        {
            "self": "at::rand({2})",
            "mat": "at::rand({2, 2})",
            "vec": "at::rand({2})",
        },
        {
            "self": "at::rand({35})",
            "mat": "at::rand({35, 35})",
            "vec": "at::rand({35})",
        },
    ),
    "acosh": (
        {"self": "at::rand({2, 2, 2}) + at::ones({2, 2, 2})"},
        {"self": "at::rand({5, 5, 5}) + at::ones({5, 5, 5})"},
    ),
    "adaptive_max_pool2d_backward": (
        {
            "grad_output": "at::randint(-3, 2, {2,2,2})",
            "self": "at::randint(-3, 2, {2,2,2})",
            "indices": "at::randint(0, 1, {2,2,2}, at::kLong)",
        },
        {
            "grad_output": "at::randint(-3, 3, {3,3,3})",
            "self": "at::randint(-3, 2, {3,3,3})",
            "indices": "at::randint(0, 1, {3,3,3}, at::kLong)",
        },
    ),
    "adaptive_max_pool3d_backward": (
        {
            "grad_output": "at::randint(-3, 2, {2,2,2,2})",
            "self": "at::randint(-3, 2, {2,2,2,2})",
            "indices": "at::randint(0, 1, {2,2,2,2}, at::kLong)",
        },
        {
            "grad_output": "at::randint(-3, 3, {3,3,3,3})",
            "self": "at::randint(-3, 2, {3,3,3,3})",
            "indices": "at::randint(0, 1, {3,3,3,3}, at::kLong)",
        },
    ),
    "gather": (
        {
            "self": "at::randint(1, 100, {2,2,2}, at::kInt)",
            "dim": "1",
            "index": "at::randint(0, 1, {2,2,2}, torch::kInt64)",
            "sparse_grad": "false",
        },
        {
            "self": "at::randint(1, 100, {5,5,5}, at::kInt)",
            "dim": "1",
            "index": "at::randint(0, 4, {5,5,5}, torch::kInt64)",
            "sparse_grad": "false",
        },
    ),
    "gelu": (
        {
            "self": "at::rand({6, 6, 6})",
            "approximate": "\"tanh\"",
        },
        {
            "self": "at::rand({22, 22, 22})",
            "approximate": "\"tanh\"",
        },
    ),
    "gelu_backward": (
        {
            "grad_output": "at::rand({6, 6, 6})",
            "self": "at::rand({6, 6, 6})",
            "approximate": "\"tanh\"",
        },
        {
            "grad_output": "at::rand({22, 22, 22})",
            "self": "at::rand({22, 22, 22})",
            "approximate": "\"tanh\"",
        },
    ),
    "index_add": (
        {
            "self": "at::rand({2})",
            "dim": "0",
            "index": "at::randint(0, 1, {2}, at::kInt)",
            "source": "at::rand({2})",
            "alpha": "2",
        },
        {
            "self": "at::rand({16})",
            "dim": "0",
            "index": "at::randint(0, 10, {16}, at::kInt)",
            "source": "at::rand({16})",
            "alpha": "2",
        },
    ),
    "index_copy": (
        {
            "self": "at::rand({2})",
            "dim": "0",
            "index": "at::randint(0, 1, {2}, at::kLong)",
            "source": "at::rand({2})",
        },
        {
            "self": "at::rand({32})",
            "dim": "0",
            "index": "at::randint(0, 10, {32}, at::kLong)",
            "source": "at::rand({32})",
        },
    ),
    "linalg_cross": (
        {
            "self": "at::rand({6, 3, 6})",
            "other": "at::rand({6, 3, 6})",
            "dim": "1",
        },
        {
            "self": "at::rand({22, 3, 22})",
            "other": "at::rand({22, 3, 22})",
            "dim": "1",
        },
    ),
    "nll_loss_backward": (
        {
            "grad_output": "at::rand({})",
            "self": "at::rand({6})",
            "target": "at::randint(0, 5, {6}, torch::kInt64)",
            "weight": "at::rand({6})",
            "reduction": "1",
            "ignore_index": "1",
            "total_weight": "at::rand({})",
        },
        {
            "grad_output": "at::rand({})",
            "self": "at::rand({36})",
            "target": "at::randint(0, 11, {36}, torch::kInt64)",
            "weight": "at::rand({36})",
            "reduction": "1",
            "ignore_index": "1",
            "total_weight": "at::rand({})",
        },
    ),
    **dict.fromkeys(_SCATTER_OPS, _SCATTER_CASES),
    "special_zeta": (
        {
            "self": "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})",
            "other": "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})",
        },
        {
            "self": "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})",
            "other": "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})",
        },
    ),
    "_convert_indices_from_csr_to_coo": (
        {
            "crow_indices": "torch::tensor({1}, torch::kInt32)",
            "col_indices": "torch::tensor({0, 1, 0}, torch::kInt32)",
            "out_int32": "false",
        },
        {
            "crow_indices": "torch::tensor({0, 1}, torch::kInt32)",
            "col_indices": "torch::tensor({0, 1, 0, 2, 1, 2}, torch::kInt32)",
            "out_int32": "false",
        },
    ),
    "_convert_indices_from_coo_to_csr": (
        {
            "self": "at::randint(0, 3, {2}, at::kInt)",
            "size": "10",
            "out_int32": "false",
        },
        {
            "self": "at::randint(0, 3, {12}, at::kInt)",
            "size": "24",
            "out_int32": "false",
        },
    ),
}


def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> None:
    """Replace selected argument expressions in ``arg_map`` for special-cased ops.

    ``arg_map`` maps argument names to C++ expression strings (e.g.
    ``"at::rand({2})"``) and is mutated in place. For ops listed in
    ``_TEST_VALUE_OVERRIDES`` the hand-picked expressions for test case
    ``index`` are written into it; for any other op it is left unchanged.
    Case 1 mostly uses larger shapes than case 0.

    Args:
        arg_map: argument name -> C++ expression; updated in place.
        op_name: base name of the op being generated.
        index: which of the two test cases to produce; must be 0 or 1.

    Raises:
        AssertionError: if ``index`` is neither 0 nor 1.
    """
    assert index == 0 or index == 1
    cases = _TEST_VALUE_OVERRIDES.get(op_name)
    if cases is None:
        return
    first_case, second_case = cases
    chosen = first_case if index == 0 else second_case
    for arg_name, expression in chosen.items():
        arg_map[arg_name] = expression
    # Scatter variants only rewrite "reduce" when the schema already has it;
    # _scatter_reduce and the other scatter ops expect different reduce names.
    if op_name in _SCATTER_OPS and "reduce" in arg_map:
        arg_map["reduce"] = "\"sum\"" if op_name == "_scatter_reduce" else "\"add\""