mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Static Runtime] Add auto-generated out variant dispatchers (#72603)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72603 This change adds out variant dispatchers generated by the previous diff. The number of the out variant dispatchers generated by this diff is 133, which increases the out variant coverage by 309% (current: 43, this diff: 133 + 43 = 176). This number is expected to increase a lot as we develop this script further to cover more ops. Test Plan: **Unittest** Confirmed ``` buck run //caffe2/benchmarks/static_runtime:static_runtime_cpptest ``` is passing. Reviewed By: swolchok Differential Revision: D33373928 fbshipit-source-id: 4d94d788282f3f313bb36f2f9452edecd9862246 (cherry picked from commit e4ce8b386d1fcc47b86cb9c9016a70e7a31b452c)
This commit is contained in:
committed by
PyTorch MergeBot
parent
94501ff91e
commit
fe7e1bd1ce
@ -6,4 +6,5 @@ list(APPEND STATIC_RUNTIME_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt.cc
|
||||
list(APPEND STATIC_RUNTIME_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/test_utils.cc)
|
||||
list(APPEND STATIC_RUNTIME_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/test_static_runtime.cc)
|
||||
list(APPEND STATIC_RUNTIME_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/test_static_module.cc)
|
||||
list(APPEND STATIC_RUNTIME_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/test_generated_ops.cc)
|
||||
set(STATIC_RUNTIME_TEST_SRCS ${STATIC_RUNTIME_TEST_SRCS} PARENT_SCOPE)
|
||||
|
4445
benchmarks/static_runtime/test_generated_ops.cc
Normal file
4445
benchmarks/static_runtime/test_generated_ops.cc
Normal file
File diff suppressed because it is too large
Load Diff
@ -2160,7 +2160,7 @@ TEST(StaticRuntime, Where) {
|
||||
)JIT";
|
||||
|
||||
std::vector<IValue> args1 = {at::randn({2, 2}), at::randn({2, 2})};
|
||||
std::vector<IValue> args2 = {at::randn({3, 6}), at::randn({3, 6})};
|
||||
std::vector<IValue> args2 = {at::randn({8, 10}), at::randn({8, 10})};
|
||||
|
||||
testStaticRuntime(where_script, args1);
|
||||
testStaticRuntime(where_script, args1, args2);
|
||||
|
@ -373,6 +373,7 @@ core_sources_full_mobile = core_sources_full_mobile_no_backend_interface + [
|
||||
|
||||
core_sources_full = core_sources_full_mobile + [
|
||||
"torch/csrc/jit/runtime/static/fusion.cpp",
|
||||
"torch/csrc/jit/runtime/static/generated_ops.cpp",
|
||||
"torch/csrc/jit/runtime/static/impl.cpp",
|
||||
"torch/csrc/jit/runtime/static/memory_planner.cpp",
|
||||
"torch/csrc/jit/runtime/static/native_ops.cpp",
|
||||
|
@ -32,20 +32,6 @@ def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> N
|
||||
else:
|
||||
arg_map["self"] = "at::rand({5, 5, 5}) + at::ones({5, 5, 5})"
|
||||
return
|
||||
if op_name == "index_add":
|
||||
if index == 0:
|
||||
arg_map["self"] = "at::rand({2})"
|
||||
arg_map["dim"] = "0"
|
||||
arg_map["index"] = "at::randint(0, 1, {2}, at::kInt)"
|
||||
arg_map["source"] = "at::rand({2})"
|
||||
arg_map["alpha"] = "2"
|
||||
else:
|
||||
arg_map["self"] = "at::rand({16})"
|
||||
arg_map["dim"] = "0"
|
||||
arg_map["index"] = "at::randint(0, 10, {16}, at::kInt)"
|
||||
arg_map["source"] = "at::rand({16})"
|
||||
arg_map["alpha"] = "2"
|
||||
return
|
||||
if op_name == "adaptive_max_pool2d_backward":
|
||||
if index == 0:
|
||||
arg_map["grad_output"] = "at::randint(-3, 2, {2,2,2})"
|
||||
@ -78,6 +64,60 @@ def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> N
|
||||
arg_map["index"] = "at::randint(0, 4, {5,5,5}, torch::kInt64)"
|
||||
arg_map["sparse_grad"] = "false"
|
||||
return
|
||||
if op_name == "gelu":
|
||||
if index == 0:
|
||||
arg_map["self"] = "at::rand({6, 6, 6})"
|
||||
arg_map["approximate"] = "\"tanh\""
|
||||
else:
|
||||
arg_map["self"] = "at::rand({22, 22, 22})"
|
||||
arg_map["approximate"] = "\"tanh\""
|
||||
return
|
||||
if op_name == "gelu_backward":
|
||||
if index == 0:
|
||||
arg_map["grad_output"] = "at::rand({6, 6, 6})"
|
||||
arg_map["self"] = "at::rand({6, 6, 6})"
|
||||
arg_map["approximate"] = "\"tanh\""
|
||||
else:
|
||||
arg_map["grad_output"] = "at::rand({22, 22, 22})"
|
||||
arg_map["self"] = "at::rand({22, 22, 22})"
|
||||
arg_map["approximate"] = "\"tanh\""
|
||||
return
|
||||
if op_name == "index_add":
|
||||
if index == 0:
|
||||
arg_map["self"] = "at::rand({2})"
|
||||
arg_map["dim"] = "0"
|
||||
arg_map["index"] = "at::randint(0, 1, {2}, at::kInt)"
|
||||
arg_map["source"] = "at::rand({2})"
|
||||
arg_map["alpha"] = "2"
|
||||
else:
|
||||
arg_map["self"] = "at::rand({16})"
|
||||
arg_map["dim"] = "0"
|
||||
arg_map["index"] = "at::randint(0, 10, {16}, at::kInt)"
|
||||
arg_map["source"] = "at::rand({16})"
|
||||
arg_map["alpha"] = "2"
|
||||
return
|
||||
if op_name == "index_copy":
|
||||
if index == 0:
|
||||
arg_map["self"] = "at::rand({2})"
|
||||
arg_map["dim"] = "0"
|
||||
arg_map["index"] = "at::randint(0, 1, {2}, at::kLong)"
|
||||
arg_map["source"] = "at::rand({2})"
|
||||
else:
|
||||
arg_map["self"] = "at::rand({32})"
|
||||
arg_map["dim"] = "0"
|
||||
arg_map["index"] = "at::randint(0, 10, {32}, at::kLong)"
|
||||
arg_map["source"] = "at::rand({32})"
|
||||
return
|
||||
if op_name == "linalg_cross":
|
||||
if index == 0:
|
||||
arg_map["self"] = "at::rand({6, 3, 6})"
|
||||
arg_map["other"] = "at::rand({6, 3, 6})"
|
||||
arg_map["dim"] = "1"
|
||||
else:
|
||||
arg_map["self"] = "at::rand({22, 3, 22})"
|
||||
arg_map["other"] = "at::rand({22, 3, 22})"
|
||||
arg_map["dim"] = "1"
|
||||
return
|
||||
if op_name == "nll_loss_backward":
|
||||
if index == 0:
|
||||
arg_map["grad_output"] = "at::rand({})"
|
||||
|
2772
torch/csrc/jit/runtime/static/generated_ops.cpp
Normal file
2772
torch/csrc/jit/runtime/static/generated_ops.cpp
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user