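# Configuration helpers for the static runtime out-variant code generator:
# marks ops whose out variants are hand-written (and therefore skipped by
# codegen) and overrides the default argument values used in the generated
# C++ tests for specific ops.
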
from tools.codegen.model import NativeFunctionsGroup

from typing import Dict


def func_name_base_str(g: NativeFunctionsGroup) -> str:
    return str(g.functional.func.name.name.base)


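# Ops in this set already have hand-written out-variant implementations in
# static runtime, so the code generator skips them.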
is_hand_written_ops_ = frozenset(("add", "addmm", "all", "any", "argmin", "bmm", "clamp",
                                  "cumsum", "div", "fmod", "leaky_relu", "log", "mul", "pow",
                                  "remainder", "sigmoid", "sign", "sub", "tanh"))


def is_hand_written(g: NativeFunctionsGroup) -> bool:
    name_base = func_name_base_str(g)
    return name_base in is_hand_written_ops_


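# Overrides the default C++ argument expressions used by the generated
# out-variant tests. Each op gets two test cases; `index` selects which one,
# and `arg_map` maps argument names to C++ source strings.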
def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> None:
    assert index == 0 or index == 1
    if op_name == "addmv":
        if index == 0:
            arg_map["self"] = "at::rand({2})"
            arg_map["mat"] = "at::rand({2, 2})"
            arg_map["vec"] = "at::rand({2})"
        else:
            arg_map["self"] = "at::rand({35})"
            arg_map["mat"] = "at::rand({35, 35})"
            arg_map["vec"] = "at::rand({35})"
        return
    if op_name == "acosh":
        if index == 0:
            arg_map["self"] = "at::rand({2, 2, 2}) + at::ones({2, 2, 2})"
        else:
            arg_map["self"] = "at::rand({5, 5, 5}) + at::ones({5, 5, 5})"
        return
if op_name == "adaptive_max_pool2d_backward":
|
|
if index == 0:
|
|
arg_map["grad_output"] = "at::randint(-3, 2, {2,2,2})"
|
|
arg_map["self"] = "at::randint(-3, 2, {2,2,2})"
|
|
arg_map["indices"] = "at::randint(0, 1, {2,2,2}, at::kLong)"
|
|
else:
|
|
arg_map["grad_output"] = "at::randint(-3, 3, {3,3,3})"
|
|
arg_map["self"] = "at::randint(-3, 2, {3,3,3})"
|
|
arg_map["indices"] = "at::randint(0, 1, {3,3,3}, at::kLong)"
|
|
return
|
|
if op_name == "adaptive_max_pool3d_backward":
|
|
if index == 0:
|
|
arg_map["grad_output"] = "at::randint(-3, 2, {2,2,2,2})"
|
|
arg_map["self"] = "at::randint(-3, 2, {2,2,2,2})"
|
|
arg_map["indices"] = "at::randint(0, 1, {2,2,2,2}, at::kLong)"
|
|
else:
|
|
arg_map["grad_output"] = "at::randint(-3, 3, {3,3,3,3})"
|
|
arg_map["self"] = "at::randint(-3, 2, {3,3,3,3})"
|
|
arg_map["indices"] = "at::randint(0, 1, {3,3,3,3}, at::kLong)"
|
|
return
|
|
if op_name == "gather":
|
|
if index == 0:
|
|
arg_map["self"] = "at::randint(1, 100, {2,2,2}, at::kInt)"
|
|
arg_map["dim"] = "1"
|
|
arg_map["index"] = "at::randint(0, 1, {2,2,2}, torch::kInt64)"
|
|
arg_map["sparse_grad"] = "false"
|
|
else:
|
|
arg_map["self"] = "at::randint(1, 100, {5,5,5}, at::kInt)"
|
|
arg_map["dim"] = "1"
|
|
arg_map["index"] = "at::randint(0, 4, {5,5,5}, torch::kInt64)"
|
|
arg_map["sparse_grad"] = "false"
|
|
return
|
|
if op_name == "gelu":
|
|
if index == 0:
|
|
arg_map["self"] = "at::rand({6, 6, 6})"
|
|
arg_map["approximate"] = "\"tanh\""
|
|
else:
|
|
arg_map["self"] = "at::rand({22, 22, 22})"
|
|
arg_map["approximate"] = "\"tanh\""
|
|
return
|
|
if op_name == "gelu_backward":
|
|
if index == 0:
|
|
arg_map["grad_output"] = "at::rand({6, 6, 6})"
|
|
arg_map["self"] = "at::rand({6, 6, 6})"
|
|
arg_map["approximate"] = "\"tanh\""
|
|
else:
|
|
arg_map["grad_output"] = "at::rand({22, 22, 22})"
|
|
arg_map["self"] = "at::rand({22, 22, 22})"
|
|
arg_map["approximate"] = "\"tanh\""
|
|
return
|
|
if op_name == "index_add":
|
|
if index == 0:
|
|
arg_map["self"] = "at::rand({2})"
|
|
arg_map["dim"] = "0"
|
|
arg_map["index"] = "at::randint(0, 1, {2}, at::kInt)"
|
|
arg_map["source"] = "at::rand({2})"
|
|
arg_map["alpha"] = "2"
|
|
else:
|
|
arg_map["self"] = "at::rand({16})"
|
|
arg_map["dim"] = "0"
|
|
arg_map["index"] = "at::randint(0, 10, {16}, at::kInt)"
|
|
arg_map["source"] = "at::rand({16})"
|
|
arg_map["alpha"] = "2"
|
|
return
|
|
if op_name == "index_copy":
|
|
if index == 0:
|
|
arg_map["self"] = "at::rand({2})"
|
|
arg_map["dim"] = "0"
|
|
arg_map["index"] = "at::randint(0, 1, {2}, at::kLong)"
|
|
arg_map["source"] = "at::rand({2})"
|
|
else:
|
|
arg_map["self"] = "at::rand({32})"
|
|
arg_map["dim"] = "0"
|
|
arg_map["index"] = "at::randint(0, 10, {32}, at::kLong)"
|
|
arg_map["source"] = "at::rand({32})"
|
|
return
|
|
if op_name == "linalg_cross":
|
|
if index == 0:
|
|
arg_map["self"] = "at::rand({6, 3, 6})"
|
|
arg_map["other"] = "at::rand({6, 3, 6})"
|
|
arg_map["dim"] = "1"
|
|
else:
|
|
arg_map["self"] = "at::rand({22, 3, 22})"
|
|
arg_map["other"] = "at::rand({22, 3, 22})"
|
|
arg_map["dim"] = "1"
|
|
return
|
|
if op_name == "nll_loss_backward":
|
|
if index == 0:
|
|
arg_map["grad_output"] = "at::rand({})"
|
|
arg_map["self"] = "at::rand({6})"
|
|
arg_map["target"] = "at::randint(0, 5, {6}, torch::kInt64)"
|
|
arg_map["weight"] = "at::rand({6})"
|
|
arg_map["reduction"] = "1"
|
|
arg_map["ignore_index"] = "1"
|
|
arg_map["total_weight"] = "at::rand({})"
|
|
else:
|
|
arg_map["grad_output"] = "at::rand({})"
|
|
arg_map["self"] = "at::rand({36})"
|
|
arg_map["target"] = "at::randint(0, 11, {36}, torch::kInt64)"
|
|
arg_map["weight"] = "at::rand({36})"
|
|
arg_map["reduction"] = "1"
|
|
arg_map["ignore_index"] = "1"
|
|
arg_map["total_weight"] = "at::rand({})"
|
|
return
|
|
    if op_name in ["scatter", "scatter_add", "_scatter_reduce"]:
        if index == 0:
            arg_map["self"] = "at::randint(1, 100, {2,2,2}, torch::kInt64)"
            arg_map["index"] = "at::randint(0, 1, {2,2,2}, torch::kInt64)"
            arg_map["src"] = "at::randint(1, 100, {2,2,2}, torch::kInt64)"
        else:
            arg_map["self"] = "at::randint(1, 100, {5,5,5}, torch::kInt64)"
            arg_map["index"] = "at::randint(0, 1, {5,5,5}, torch::kInt64)"
            arg_map["src"] = "at::randint(1, 100, {5,5,5}, torch::kInt64)"
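        # Only some of these ops take a string "reduce" argument; when they do,
        # substitute a reduction name the op accepts.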
if "reduce" in arg_map:
|
|
arg_map["reduce"] = "\"sum\"" if op_name == "_scatter_reduce" else "\"add\""
|
|
return
|
|
if op_name == "special_zeta":
|
|
if index == 0:
|
|
arg_map["self"] = "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})"
|
|
arg_map["other"] = "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})"
|
|
else:
|
|
arg_map["self"] = "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})"
|
|
arg_map["other"] = "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})"
|
|
return
|
|
    if op_name == "_convert_indices_from_csr_to_coo":
        if index == 0:
            arg_map["crow_indices"] = "torch::tensor({1}, torch::kInt32)"
            arg_map["col_indices"] = "torch::tensor({0, 1, 0}, torch::kInt32)"
            arg_map["out_int32"] = "false"
        else:
            arg_map["crow_indices"] = "torch::tensor({0, 1}, torch::kInt32)"
            arg_map["col_indices"] = "torch::tensor({0, 1, 0, 2, 1, 2}, torch::kInt32)"
            arg_map["out_int32"] = "false"
        return
    if op_name == "_convert_indices_from_coo_to_csr":
        if index == 0:
            arg_map["self"] = "at::randint(0, 3, {2}, at::kInt)"
            arg_map["size"] = "10"
            arg_map["out_int32"] = "false"
        else:
            arg_map["self"] = "at::randint(0, 3, {12}, at::kInt)"
            arg_map["size"] = "24"
            arg_map["out_int32"] = "false"
        return