#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/static/ProcessedNodeInputs.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/ops.h>
#include <torch/csrc/jit/runtime/static/passes.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>

#include <limits>

#include "deep_wide_pt.h"
#include "test_utils.h"

using namespace caffe2;
using namespace torch;
using namespace torch::jit;
using namespace torch::jit::test;
using c10::IValue;

/*
 When adding a test for an operator implemented in static runtime, there are
 several things that you need to pay attention to:

 1) If the op is an out variant, in the test script of the op, instead of:

    def forward(self, input):
        return myop(input)

    do:

    def forward(self, input):
        return myop(input).clone()

 This makes sure that the output of myop is managed by the memory planner and
 exercises the code path in the op impl that otherwise doesn't get exercised.
 The output of the model is not managed by the memory planner, because it
 needs to be returned to the client.

 2) The memory planner rounds up the size of each Tensor's storage to
 multiples of 64 bytes (the alignment requirement on AVX512). Make sure the
 sizes of the input tensors in args2 are big enough to trigger resizing.

 3) For view ops such as aten::reshape or aten::to, if you want them to be
 replaced by their copy versions with the ReplaceWithCopy pass in passes.h,
 also make sure their outputs are not returned as model outputs. The reason is
 that ReplaceWithCopy only replaces ops whose outputs are not aliases of the
 model output.
*/

C10_DECLARE_bool(static_runtime_enable_fast_math);

TEST(StaticRuntime, UnaryOps) {
  const auto aten_sum = R"JIT(
    def forward(self, input):
        return torch.sum(input).clone()
  )JIT";

  const auto aten_sum_0 = R"JIT(
    def forward(self, input):
        return torch.sum(input, 0).clone()
  )JIT";

  const auto aten_sum_1 = R"JIT(
    def forward(self, input):
        return torch.sum(input, 1).clone()
  )JIT";

  const auto aten_sum_0_true = R"JIT(
    def forward(self, input):
        return torch.sum(input, 0, True).clone()
  )JIT";

  const auto aten_sum_1_true = R"JIT(
    def forward(self, input):
        return torch.sum(input, 1, True).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({3, 3, 6});

  std::vector<IValue> args{a}, args2{b};

  // sum
  testStaticRuntime(aten_sum, args);
  testStaticRuntime(aten_sum_0, args);
  testStaticRuntime(aten_sum_1, args);
  testStaticRuntime(aten_sum_0_true, args);
  testStaticRuntime(aten_sum_1_true, args);

  testStaticRuntime(aten_sum, args, args2, false, false, false);
  testStaticRuntime(aten_sum_0, args, args2);
  testStaticRuntime(aten_sum_1, args, args2);
  testStaticRuntime(aten_sum_0_true, args, args2);
  testStaticRuntime(aten_sum_1_true, args, args2);
}
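// The test below is a minimal sketch of the pattern described in the comment
// at the top of this file: the op output is cloned so that the intermediate
// tensor is managed by the memory planner (rule 1), and the second argument
// set uses a larger input so that the 64-byte-aligned storage has to be
// resized on the second run (rule 2). The op, test name, and shapes are
// illustrative only.
TEST(StaticRuntime, ExampleOutVariantPattern) {
  const auto src = R"JIT(
    def forward(self, input):
        return torch.relu(input).clone()
  )JIT";

  std::vector<IValue> args{at::randn({2, 3})};
  // Larger input for the second run; triggers storage resizing in the planner.
  std::vector<IValue> args2{at::randn({8, 16})};

  testStaticRuntime(src, args);
  testStaticRuntime(src, args, args2);
}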
TEST(StaticRuntime, Max) {
  auto src_max_reduce = R"JIT(
    def forward(self, input):
        return torch.max(input).clone()
  )JIT";

  auto src_max_dim = R"JIT(
    def forward(self, input, dim: int):
        values, indices = torch.max(input, dim)
        return values.clone(), indices.clone()
  )JIT";

  auto src_max_dim_keepdim = R"JIT(
    def forward(self, input, dim: int):
        values, indices = torch.max(input, dim, keepdim=True)
        return values.clone(), indices.clone()
  )JIT";

  auto src_max_pointwise = R"JIT(
    def forward(self, input, other):
        return torch.max(input, other).clone()
  )JIT";

  auto input = at::randn({2, 3, 2});
  auto input_other = at::randn({2, 3, 2});
  auto large_input = at::randn({8, 9, 10});
  auto large_input_other = at::randn({8, 9, 10});

  testStaticRuntime(src_max_reduce, {input});
  testStaticRuntime(src_max_dim, {input, 1});
  testStaticRuntime(src_max_dim, {input, 1}, {large_input, 0});
  testStaticRuntime(src_max_dim_keepdim, {input, 0});
  testStaticRuntime(src_max_dim_keepdim, {input, 0}, {large_input, 2});
  testStaticRuntime(src_max_pointwise, {input, input_other});
  testStaticRuntime(
      src_max_pointwise, {input, input_other}, {large_input, large_input_other});
}

TEST(StaticRuntime, Mean) {
  const auto src_default = R"JIT(
    def forward(self, input):
        return torch.mean(input).clone()
  )JIT";

  const auto src_dtype = R"JIT(
    def forward(self, input, dtype: int):
        return torch.mean(input, dtype=dtype).clone()
  )JIT";

  const auto src_dim = R"JIT(
    def forward(self, input, dim: List[int]):
        return torch.mean(input, dim).clone()
  )JIT";

  const auto src_dim_keepdim = R"JIT(
    def forward(self, input, dim: List[int]):
        return torch.mean(input, dim, keepdim=True).clone()
  )JIT";

  const auto src_dim_dtype = R"JIT(
    def forward(self, input, dim: List[int], dtype: int):
        return torch.mean(input, dim, dtype=dtype).clone()
  )JIT";

  auto input = at::randn({2, 3, 2});
  auto large_input = at::randn({8, 7, 6, 8});

  std::vector<IValue> args_default = {input};
  std::vector<IValue> args_dtype = {input, torch::kFloat};
  std::vector<IValue> args_dim = {input, c10::List<int64_t>{0, 2}};
  std::vector<IValue> args_dim_keepdim = {input, c10::List<int64_t>{1, 2}};
  std::vector<IValue> args_dim_dtype = {input, c10::List<int64_t>{0, 1}, torch::kBFloat16};

  testStaticRuntime(src_default, args_default);
  testStaticRuntime(src_dtype, args_dtype);
  testStaticRuntime(src_dim, args_dim);
  testStaticRuntime(src_dim_keepdim, args_dim_keepdim);
  testStaticRuntime(src_dim_dtype, args_dim_dtype);

  std::vector<IValue> large_args_dim = {large_input, c10::List<int64_t>{0, 3}};
  std::vector<IValue> large_args_dim_keepdim = {large_input, c10::List<int64_t>{1, 2}};
  std::vector<IValue> large_args_dim_dtype = {large_input, c10::List<int64_t>{1, 3}, torch::kBFloat16};

  testStaticRuntime(src_dim, args_dim, large_args_dim);
  testStaticRuntime(src_dim_keepdim, args_dim_keepdim, large_args_dim_keepdim);
  testStaticRuntime(src_dim_dtype, args_dim_dtype, large_args_dim_dtype);
}

TEST(StaticRuntime, Sigmoid) {
  const auto sigmoid_script = R"JIT(
    def forward(self, inp: Tensor):
        b = torch.sigmoid(inp).clone()
        return (b)
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 2});

  std::vector<IValue> args{a}, args2{b};

  testStaticRuntime(sigmoid_script, args, /*args2=*/{}, /*use_allclose=*/true);
  testStaticRuntime(sigmoid_script, args, {args2}, /*use_allclose=*/true);

  FLAGS_static_runtime_enable_fast_math = false;
  testStaticRuntime(sigmoid_script, args, /*args2=*/{}, /*use_allclose=*/true);
  testStaticRuntime(sigmoid_script, args, {args2}, /*use_allclose=*/true);
  FLAGS_static_runtime_enable_fast_math = true;
}
TEST(StaticRuntime, Clone) {
  /*
   Clone is called twice to trigger the memory planner for the output of the
   first clone. The output of the last op (the second clone) is not managed by
   the memory planner, since it needs to be returned to the client and cannot
   be reused by the planner.
  */
  const auto clone_script_0 = R"JIT(
    def forward(self, input):
        a = torch.clone(input).clone()
        return (a * a)
  )JIT";

  // Case: clone with different sets of memory_formats
  const auto clone_script_1 = R"JIT(
    def forward(self, input: Tensor, memory_format: int):
        a = torch.clone(input, memory_format=memory_format).clone()
        return (a * a)
  )JIT";

  /*
   Case: an input with stride set to 0 (due to the expand op) calls the native
   clone instead of the out variant
  */
  const auto clone_script_2 = R"JIT(
    def forward(self, input: Tensor, other: Tensor):
        a = input.expand_as(other)
        return a.clone().clone()
  )JIT";

  /*
   Case: a sliced tensor, to exercise non-contiguous tensor storage
  */
  const auto clone_script_3 = R"JIT(
    def forward(self, input: Tensor):
        a = input[:, 0:10:2]
        return a.clone().clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({3, 2}).as_strided({3, 2}, {1, 3});
  auto b_larger = at::randn({30, 20}).as_strided({30, 20}, {1, 3});
  auto c = at::randn({1, 20, 13, 8});
  auto d = at::randn({1, 0, 3, 4});
  auto e = at::randn({2, 1});
  auto f = at::randn({2, 10});
  auto g = at::randn({3, 20});

  std::vector<IValue> args_0{b, c10::MemoryFormat::Contiguous};
  std::vector<IValue> args_1{b_larger, c10::MemoryFormat::Preserve};
  std::vector<IValue> args_2{c, c10::MemoryFormat::ChannelsLast};
  std::vector<IValue> args_3{d, c10::MemoryFormat::ChannelsLast};
  std::vector<IValue> args_4{e, a};
  std::vector<IValue> args_5{e, f};

  testStaticRuntime(clone_script_0, {a});
  testStaticRuntime(clone_script_0, {a}, {b_larger});
  testStaticRuntime(clone_script_1, args_0);
  testStaticRuntime(clone_script_1, args_1);
  testStaticRuntime(clone_script_1, args_2);
  testStaticRuntime(clone_script_1, args_3);
  testStaticRuntime(clone_script_1, args_0, args_1);
  testStaticRuntime(clone_script_1, args_3, args_2);
  testStaticRuntime(clone_script_2, args_4);
  testStaticRuntime(clone_script_2, args_4, args_5);
  testStaticRuntime(clone_script_3, {f});
  testStaticRuntime(clone_script_3, {f}, {g});
}

TEST(StaticRuntime, Clamp) {
  const auto clamp_script_1 = R"JIT(
    def forward(self, inp: Tensor, min: int, max: int):
        a = torch.clamp(inp, min, max).clone()
        return (a)
  )JIT";

  const auto clamp_script_2 = R"JIT(
    def forward(self, inp: Tensor, min: Tensor, max: Tensor):
        a = torch.clamp(inp, min, max).clone()
        return (a)
  )JIT";

  auto a = at::randn({2, 3});
  auto max_t = at::full_like(a, 1);
  auto min_t = at::full_like(a, -1);

  auto b = at::randn({4, 3, 2});
  auto max_t1 = at::full_like(b, 1);
  auto min_t1 = at::full_like(b, -1);

  testStaticRuntime(clamp_script_1, {a, -1, 1});
  testStaticRuntime(clamp_script_2, {a, min_t, max_t});

  testStaticRuntime(clamp_script_1, {a, -1, 1}, {b, -1, 1});
  testStaticRuntime(clamp_script_2, {a, min_t, max_t}, {b, max_t1, min_t1});
}

TEST(StaticRuntime, ClampMinOnly) {
  const auto src = R"JIT(
    def forward(self, inp: Tensor, min: float):
        a = torch.clamp(inp, min, None).clone()
        return (a)
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 2});

  testStaticRuntime(src, {a, 0.5});
  testStaticRuntime(src, {a, 0.5}, {b, 0.25});
}

TEST(StaticRuntime, ClampMaxOnly) {
  const auto src = R"JIT(
    def forward(self, inp: Tensor, max: float):
        a = torch.clamp(inp, None, max).clone()
        return (a)
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 2});

  testStaticRuntime(src, {a, 0.5});
  testStaticRuntime(src, {a, 0.5}, {b, 0.25});
}
TEST(StaticRuntime, ClampIntTensor) {
  const auto src = R"JIT(
    def forward(self, inp: Tensor, min: float, max: float):
        a = torch.clamp(inp, min, max).clone()
        return (a)
  )JIT";

  auto a = at::randint(0, 20, {2, 3}, at::kFloat);
  auto b = at::randint(0, 20, {4, 3, 2}, at::kFloat);
  auto min = 5.0f;
  auto max = 5.0f;
  testStaticRuntime(src, {a, min, max});
  testStaticRuntime(src, {a, min, max}, {b, min, max});
}

TEST(StaticRuntime, LenWithTuple) {
  const auto src = R"IR(
    graph(%input : int[]):
        %res : int = aten::len(%input)
        return (%res)
  )IR";

  testStaticRuntime(src, {c10::List<int64_t>(4)});
}

TEST(StaticRuntime, LenWithTensor) {
  const auto src = R"IR(
    graph(%input : Tensor):
        %res : int = aten::len(%input)
        return (%res)
  )IR";

  testStaticRuntime(src, {at::randn({2, 2, 2})});
}

TEST(StaticRuntime, LenWithStr) {
  const auto src = R"IR(
    graph(%input : str):
        %res : int = aten::len(%input)
        return (%res)
  )IR";

  testStaticRuntime(src, {"static_runtime"});
}

TEST(StaticRuntime, LenWithDict_str) {
  const auto script = R"JIT(
    def forward(self, input: Dict[str, str]):
        return len(input)
  )JIT";

  c10::Dict<std::string, std::string> dict;
  dict.insert("abc", "123");
  dict.insert("def", "456");
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_int) {
  const auto script = R"JIT(
    def forward(self, input: Dict[int, int]):
        return len(input)
  )JIT";

  c10::Dict<int64_t, int64_t> dict;
  dict.insert(0, 1);
  dict.insert(2, 3);
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_bool) {
  const auto script = R"JIT(
    def forward(self, input: Dict[bool, bool]):
        return len(input)
  )JIT";

  c10::Dict<bool, bool> dict;
  dict.insert(true, false);
  dict.insert(false, true);
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_float) {
  const auto script = R"JIT(
    def forward(self, input: Dict[float, float]):
        return len(input)
  )JIT";

  c10::Dict<double, double> dict;
  dict.insert(0.1, 0.9);
  dict.insert(0.8, 0.18);
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_complex) {
  const auto script = R"JIT(
    def forward(self, input: Dict[complex, complex]):
        return len(input)
  )JIT";

  c10::Dict<c10::complex<double>, c10::complex<double>> dict;
  dict.insert(0.1, 0.4);
  dict.insert(0.9, 0.45);
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_Tensor) {
  const auto script = R"JIT(
    def forward(self, input: Dict[Tensor, Tensor]):
        return len(input)
  )JIT";

  c10::Dict<at::Tensor, at::Tensor> dict;
  dict.insert(at::randn({1, 2}), at::randn({1, 2}));
  dict.insert(at::randn({1, 2}), at::randn({1, 2}));
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, Logit) {
  // no nnc
  const auto logit_script_1 = R"JIT(
    def forward(self, inp: Tensor):
        a = torch.logit(inp).clone()
        return (a)
  )JIT";

  // with nnc
  const auto logit_script_2 = R"JIT(
    def forward(self, inp: Tensor):
        a = torch.logit(inp, 1e-6).clone()
        return (a)
  )JIT";

  // no nnc
  const auto logit_script_3 = R"JIT(
    def forward(self, inp: Tensor, eps: float):
        a = torch.logit(inp, eps).clone()
        return (a)
  )JIT";

  auto a = at::ones({2, 3});
  double b = 1e-6;
  std::vector<IValue> args_1{a};
  std::vector<IValue> args_2({a, b});

  auto c = at::ones({4, 3, 2});

  // logit
  testStaticRuntime(logit_script_1, args_1);
  testStaticRuntime(logit_script_2, args_1);
  testStaticRuntime(logit_script_3, args_2);

  testStaticRuntime(logit_script_1, args_1, {c});
  testStaticRuntime(logit_script_2, args_1, {c});
  testStaticRuntime(logit_script_3, args_2, {c, b});
}

TEST(StaticRuntime, EmbeddingBag) {
  const std::string embedding_bag_default = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  const std::string embedding_bag_mean = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 1)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  const std::string embedding_bag_max = R"JIT( def forward(self, a: Tensor, b: Tensor, c: Tensor): x, y,
z, _ = torch.embedding_bag(a, b, c, False, 2) return (x.clone(), y.clone(), z.clone(), _.clone()) )JIT"; const std::string embedding_bag_sum_last_offset = R"JIT( def forward(self, a: Tensor, b: Tensor, c: Tensor): x, y, z, _ = torch.embedding_bag(a, b, c, False, 0, False, None, True) return (x.clone(), y.clone(), z.clone(), _.clone()) )JIT"; const std::string embedding_bag_mean_last_offset = R"JIT( def forward(self, a: Tensor, b: Tensor, c: Tensor): x, y, z, _ = torch.embedding_bag(a, b, c, False, 1, False, None, True) return (x.clone(), y.clone(), z.clone(), _.clone()) )JIT"; const std::string embedding_bag_max_last_offset = R"JIT( def forward(self, a: Tensor, b: Tensor, c: Tensor): x, y, z, _ = torch.embedding_bag(a, b, c, False, 2, False, None, True) return (x.clone(), y.clone(), z.clone(), _.clone()) )JIT"; at::Tensor weight = torch::randn({3, 11}, at::ScalarType::Float); at::Tensor input = torch::tensor({0, 1, 0, 2}); at::Tensor offset = torch::tensor({0, 2, 4}); std::vector args{weight, input, offset}; testStaticRuntime(embedding_bag_default, args); testStaticRuntime(embedding_bag_mean, args); testStaticRuntime(embedding_bag_max, args); testStaticRuntime(embedding_bag_sum_last_offset, args); testStaticRuntime(embedding_bag_mean_last_offset, args); testStaticRuntime(embedding_bag_max_last_offset, args); at::Tensor weight2 = torch::randn({10, 11}, at::ScalarType::Float); at::Tensor input2 = torch::tensor({0, 1, 0, 2, 1}); at::Tensor offset2 = torch::tensor({0, 1, 2, 3, 4, 5}); std::vector args2{weight2, input2, offset2}; testStaticRuntime(embedding_bag_default, args, args2); testStaticRuntime(embedding_bag_mean, args, args2); testStaticRuntime(embedding_bag_max, args, args2); testStaticRuntime(embedding_bag_sum_last_offset, args, args2); testStaticRuntime(embedding_bag_mean_last_offset, args, args2); testStaticRuntime(embedding_bag_max_last_offset, args, args2); } TEST(StaticRuntime, EmbeddingBagWithManagedOutput) { const std::string embedding_bag_managed_output = R"JIT( def forward(self, a: Tensor, b: Tensor, c: Tensor): # The outputs of embedding_bag become an intermediate tensors # since they are not directly returned from the graph. 
x, y, z, _ = torch.embedding_bag(a, b, c) return x + x )JIT"; at::Tensor weight = torch::randn({3, 8}, at::ScalarType::Float); at::Tensor input = torch::tensor({0, 1, 0, 2}); at::Tensor offset = torch::tensor({0, 2}); std::vector args{weight, input, offset}; at::Tensor weight2 = torch::randn({6, 8}, at::ScalarType::Float); at::Tensor input2 = torch::tensor({0, 1, 0, 2, 3, 4}); at::Tensor offset2 = torch::tensor({0, 2, 4, 5}); std::vector args2{weight2, input2, offset2}; testStaticRuntime(embedding_bag_managed_output, args); testStaticRuntime(embedding_bag_managed_output, args, args2); } TEST(StaticRuntime, EmbeddingBagWithExtraneousOutput) { const std::string embedding_bag_default_ir = R"IR( graph(%weight, %indices, %offsets): %scale_grad_by_freq : bool = prim::Constant[value=0]() %mode : int = prim::Constant[value=0]() %sparse : bool = prim::Constant[value=0]() %per_sample_weights : NoneType = prim::Constant() %include_last_offset : bool = prim::Constant[value=0]() %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset) %none : NoneType = prim::Constant() %res : Tensor = aten::clone(%y0, %none) return (%res) )IR"; auto graph = getGraphFromIR(embedding_bag_default_ir); RemoveUnnecessaryOutputs(graph); torch::jit::testing::FileCheck() .check("static_runtime::embedding_bag") ->run(*graph); const std::string embedding_bag_mean_ir = R"IR( graph(%weight, %indices, %offsets): %scale_grad_by_freq : bool = prim::Constant[value=0]() %mode : int = prim::Constant[value=1]() %sparse : bool = prim::Constant[value=0]() %per_sample_weights : NoneType = prim::Constant() %include_last_offset : bool = prim::Constant[value=0]() %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset) %none : NoneType = prim::Constant() %res : Tensor = aten::clone(%y0, %none) return (%res) )IR"; graph = getGraphFromIR(embedding_bag_mean_ir); RemoveUnnecessaryOutputs(graph); torch::jit::testing::FileCheck() .check("static_runtime::embedding_bag") ->run(*graph); const std::string embedding_bag_max_last_offset_ir = R"IR( graph(%weight, %indices, %offsets): %scale_grad_by_freq : bool = prim::Constant[value=0]() %mode : int = prim::Constant[value=2]() %sparse : bool = prim::Constant[value=0]() %per_sample_weights : NoneType = prim::Constant() %include_last_offset : bool = prim::Constant[value=1]() %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset) %none : NoneType = prim::Constant() %res : Tensor = aten::clone(%y0, %none) return (%res) )IR"; graph = getGraphFromIR(embedding_bag_max_last_offset_ir); RemoveUnnecessaryOutputs(graph); torch::jit::testing::FileCheck() .check("static_runtime::embedding_bag") ->run(*graph); const std::string embedding_bag_normal_ir = R"IR( graph(%weight, %indices, %offsets): %scale_grad_by_freq : bool = prim::Constant[value=0]() %mode : int = prim::Constant[value=0]() %sparse : bool = prim::Constant[value=0]() %per_sample_weights : NoneType = prim::Constant() %include_last_offset : bool = prim::Constant[value=0]() %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset) %none : NoneType = 
prim::Constant() %res0 : Tensor = aten::clone(%y0, %none) %res1 : Tensor = aten::clone(%y1, %none) %res2 : Tensor = aten::clone(%y2, %none) %res3 : Tensor = aten::clone(%y3, %none) return (%res0, %res1, %res2, %res3) )IR"; graph = getGraphFromIR(embedding_bag_normal_ir); RemoveUnnecessaryOutputs(graph); torch::jit::testing::FileCheck() .check_not("static_runtime::embedding_bag") ->run(*graph); at::Tensor weight = torch::randn({3, 11}, at::ScalarType::Float); at::Tensor input = torch::tensor({0, 1, 0, 2}); at::Tensor offset = torch::tensor({0, 2, 4}); std::vector args{weight, input, offset}; testStaticRuntime(embedding_bag_default_ir, args); testStaticRuntime(embedding_bag_mean_ir, args); testStaticRuntime(embedding_bag_max_last_offset_ir, args); at::Tensor weight2 = torch::randn({10, 11}, at::ScalarType::Float); at::Tensor input2 = torch::tensor({0, 1, 0, 2, 1}); at::Tensor offset2 = torch::tensor({0, 1, 2, 3, 4, 5}); std::vector args2{weight2, input2, offset2}; testStaticRuntime(embedding_bag_default_ir, args, args2); testStaticRuntime(embedding_bag_mean_ir, args, args2); testStaticRuntime(embedding_bag_max_last_offset_ir, args, args2); } TEST(StaticRuntime, EmbeddingBagWithMixedInt32Int64Input) { const std::string embedding_bag_default = R"JIT( def forward(self, a: Tensor, b: Tensor, c: Tensor): x, y, z, _ = torch.embedding_bag(a, b, c) return (x.clone(), y.clone(), z.clone(), _.clone()) )JIT"; auto weight = torch::randn({3, 11}, at::ScalarType::Float); auto input = torch::tensor({0, 1, 0, 2}, at::ScalarType::Long); auto offset = torch::tensor({0, 2, 4}, at::ScalarType::Int); std::vector args{weight, input, offset}; testStaticRuntime(embedding_bag_default, args); } TEST(StaticRuntime, LayerNorm) { const std::string layer_norm_with_weights = R"JIT( def forward(self, input: Tensor, normalized_shape: List[int], weight: Tensor, bias: Tensor): return torch.layer_norm(input, normalized_shape, weight, bias, 1e-05, False).clone() )JIT"; const std::string layer_norm_without_weights = R"JIT( def forward(self, input: Tensor, normalized_shape: List[int]): return torch.layer_norm(input, normalized_shape, None, None, 1e-05, False).clone() )JIT"; const std::string layer_norm_with_noncontiguous_input = R"JIT( def forward(self, input: Tensor, normalized_shape: List[int], weight: Tensor, bias: Tensor): input = torch.transpose(input, 1, 2) return torch.layer_norm(input, normalized_shape, weight, bias, 1e-05, False).clone() )JIT"; const auto a = torch::rand({1, 2, 2, 2}); const auto b = torch::rand({3, 2, 2, 2}); for (int normalized_size : {2, 3}) { std::vector normalized_shape(normalized_size, 2); const auto weight = torch::rand(normalized_shape); const auto bias = torch::rand(normalized_shape); std::vector args{a, normalized_shape, weight, bias}; std::vector args1{b, normalized_shape, weight, bias}; testStaticRuntime(layer_norm_with_weights, args); testStaticRuntime(layer_norm_with_weights, args, args1); testStaticRuntime(layer_norm_with_noncontiguous_input, args); args = {a, normalized_shape}; testStaticRuntime(layer_norm_without_weights, args); testStaticRuntime(layer_norm_without_weights, args, {b, normalized_shape}); } } TEST(StaticRuntime, Bmm) { const auto bmm_script = R"JIT( def forward(self, inp: Tensor, mat2: Tensor): return torch.bmm(inp, mat2).clone() )JIT"; auto a = at::randn({10, 4, 5}); auto b = at::randn({10, 5, 6}); auto c = at::randn({12, 5, 6}); auto d = at::randn({12, 6, 7}); std::vector args{a, b}; std::vector args1{c, d}; testStaticRuntime(bmm_script, args); 
testStaticRuntime(bmm_script, args1); testStaticRuntime(bmm_script, args, args1); } TEST(StaticRuntime, Addmm) { const auto addmm_script = R"JIT( def forward(self, inp: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float): return torch.addmm(inp, mat1, mat2, alpha=alpha, beta=beta).clone() )JIT"; auto inp1 = at::randn({5}); auto mat1 = at::randn({3, 4}); auto mat2 = at::randn({4, 5}); auto inp2 = at::randn({3, 7}); auto mat3 = at::randn({3, 6}); auto mat4 = at::randn({6, 7}); std::vector args{inp1, mat1, mat2, 1.0, 2.0}; std::vector args1{inp2, mat3, mat4, 2.0, 1.0}; testStaticRuntime(addmm_script, args); testStaticRuntime(addmm_script, args1); testStaticRuntime(addmm_script, args, args1); } TEST(StaticRuntime, Abs) { const auto abs_script = R"JIT( def forward(self, a): return a.abs().clone() )JIT"; auto a = at::randn({2, 3}); auto b = at::randn({4, 2, 3}); std::vector args{a}; std::vector args2{b}; testStaticRuntime(abs_script, args); testStaticRuntime(abs_script, args, args2); } TEST(StaticRuntime, Binary) { const auto add_script = R"JIT( def forward(self, a, b): c = a + b return (c.clone()) )JIT"; const auto add_script_ints = R"JIT( def forward(self, a: int, b: int): c = a + b d = c + 1 return d )JIT"; const auto add_list_script = R"JIT( def forward(self, a: List[int], b: List[int]): c = a + b return c[::] )JIT"; const auto list_construct_script = R"JIT( def forward(self, a, b): return [a, b] )JIT"; const auto list_construct_script_2 = R"JIT( def forward(self, a, b): c = a + a return [c, c] )JIT"; const auto list_construct_script_3 = R"JIT( def forward(self, a, b): c = a + a return [c, c.flatten()] )JIT"; const auto list_unpack_script = R"JIT( def forward(self, a, b): c = [a, b] x, y = c z = x + y return z.clone() )JIT"; const auto list_unpack_script_2 = R"JIT( def forward(self, a, b): c = [a, b] x, y = c z = (x, y) return z )JIT"; const auto tuple_construct_script = R"JIT( def forward(self, a, b): return (a, b) )JIT"; const auto tuple_construct_script_2 = R"JIT( def forward(self, a, b): return (a.flatten(), b) )JIT"; auto a = at::randn({2, 3}); auto b = at::ones({2, 3}); auto c = at::randn({4, 2, 3}); auto d = at::ones({4, 2, 3}); std::vector args{a, b}; testStaticRuntime(add_script, args); testStaticRuntime(add_script_ints, {1, 2}); testStaticRuntime(add_script, args, {c, d}); testStaticRuntime(list_construct_script, args); testStaticRuntime(list_construct_script_2, args); testStaticRuntime(list_construct_script_3, args); testStaticRuntime(list_unpack_script, args); testStaticRuntime(list_unpack_script_2, args); testStaticRuntime(tuple_construct_script, args); testStaticRuntime(tuple_construct_script_2, args); std::vector list_args{ c10::List{1, 2, 3}, c10::List{4, 5, 6}}; testStaticRuntime(add_list_script, list_args); } TEST(StaticRuntime, MatMul) { const auto aten_matmul = R"JIT( def forward(self, a: Tensor, b: Tensor): return torch.matmul(a, b).clone() )JIT"; // 1-D, 1-D std::vector args{at::randn({3}), at::randn({3})}; testStaticRuntime(aten_matmul, args); // 2-D, 2-D std::vector args1 = {at::randn({3, 2}), at::randn({2, 3})}; testStaticRuntime(aten_matmul, args1); // 1-D, 2-D std::vector args2 = {at::randn({3}), at::randn({3, 5})}; testStaticRuntime(aten_matmul, args2); // 2-D, 1-D std::vector args3 = {at::randn({3, 5}), at::randn({5})}; testStaticRuntime(aten_matmul, args3); // > 2-D , > 2-D std::vector args4 = {at::randn({3, 1, 4, 5}), at::randn({2, 5, 6})}; testStaticRuntime(aten_matmul, args4); testStaticRuntime(aten_matmul, args3, args4); } TEST(StaticRuntime, 
Sign) { const auto sign_tensor = R"JIT( def forward(self, input: Tensor): return torch.sign(input).clone() )JIT"; auto a = at::randn({2, 3}); auto b = at::randn({4, 3, 2}); std::vector args{a}; testStaticRuntime(sign_tensor, args); testStaticRuntime(sign_tensor, args, {b}); } TEST(StaticRuntime, Div) { const auto div_tensor = R"JIT( def forward(self, a: Tensor, b: Tensor): return torch.div(a, b).clone() )JIT"; const auto div_scalar = R"JIT( def forward(self, a: Tensor, b: int): return torch.div(a, b).clone() )JIT"; const auto div_tensor_mode = R"JIT( def forward(self, a: Tensor, b: Tensor, c: str): return torch.div(a, b, rounding_mode=c).clone() )JIT"; const auto div_scalar_mode = R"JIT( def forward(self, a: Tensor, b: float, c: str): return torch.div(a, b, rounding_mode=c).clone() )JIT"; const auto div_strided = R"JIT( def forward(self, a: Tensor, b: Tensor): a_strided = torch.transpose(a, 0, 1) b_strided = torch.transpose(b, 0, 1) return torch.div(a_strided, b_strided).clone() )JIT"; auto a = at::randn({2, 3}); auto b = at::randn({2, 3}); auto bs = at::randn({3, 2}).transpose(0, 1); auto c = at::randn({4, 3, 2}); auto d = at::randn({4, 3, 2}); auto ds = at::randn({3, 4, 2}).transpose(0, 1); std::vector args0{a, b}; testStaticRuntime(div_tensor, args0); testStaticRuntime(div_tensor, args0, {c, d}); testStaticRuntime(div_strided, args0); testStaticRuntime(div_strided, args0, {c, d}); testStaticRuntime(div_tensor, {a, bs}); testStaticRuntime(div_tensor, {a, bs}, {c, ds}); std::vector args1{a, 3}; testStaticRuntime(div_scalar, args1); testStaticRuntime(div_scalar, args1, {c, 4}); std::vector args2{a, b, "floor"}; testStaticRuntime(div_tensor_mode, args2); testStaticRuntime(div_tensor_mode, args2, {c, d, "floor"}); std::vector args3{a, 2.3, "trunc"}; testStaticRuntime(div_scalar_mode, args3); testStaticRuntime(div_scalar_mode, args3, {c, 1.5, "trunc"}); } TEST(StaticRuntime, Mul) { const auto mul_tensor = R"JIT( def forward(self, a: Tensor, b: Tensor): return torch.mul(a, b).clone() )JIT"; const auto mul_scalar = R"JIT( def forward(self, a: Tensor, b: int): return torch.mul(a, b).clone() )JIT"; const auto mul_list = R"JIT( def forward(self, a: List[int], n: int): b = a * n return b[::] )JIT"; auto a = at::randn({3, 3}); auto b = at::randn({3, 3}); auto c = at::randn({3, 3, 3}); auto d = at::randn({3, 3, 3}); std::vector tensor_args1{a, b}; std::vector tensor_args2{c, d}; testStaticRuntime(mul_tensor, tensor_args1); testStaticRuntime(mul_tensor, tensor_args1, tensor_args2); std::vector scalar_args1{a, 42}; std::vector scalar_args2{c, 42}; testStaticRuntime(mul_scalar, scalar_args1); testStaticRuntime(mul_scalar, scalar_args1, scalar_args2); std::vector list_args{c10::List{1, 2}, 3}; testStaticRuntime(mul_list, list_args); } TEST(StaticRuntime, Log) { const auto log_tensor = R"JIT( def forward(self, inp: Tensor): a = torch.log(inp).clone() return (a) )JIT"; // Ensure that the input values are valid. 
auto a = at::abs(at::randn({2, 3})); auto b = at::abs(at::randn({4, 3, 2})); std::vector args{a}; testStaticRuntime(log_tensor, args); testStaticRuntime(log_tensor, args, {b}); } TEST(StaticRuntime, Sub) { const auto sub_tensor = R"JIT( def forward(self, a: Tensor, b: Tensor): return torch.sub(a, b).clone() )JIT"; const auto sub_scalar = R"JIT( def forward(self, a: Tensor, b: int): return torch.sub(a, b).clone() )JIT"; const auto sub_tensor_alpha = R"JIT( def forward(self, a: Tensor, b: Tensor, c: float): return torch.sub(a, b, alpha=c).clone() )JIT"; const auto sub_scalar_alpha = R"JIT( def forward(self, a: Tensor, b: float, c: int): return torch.sub(a, b, alpha=c).clone() )JIT"; const auto sub_two_scalars = R"JIT( def forward(self, a: int, b: int): return (a - b - b) )JIT"; auto a = at::randn({2, 3}); auto b = at::randn({2, 3}); auto c = at::randn({4, 3, 2}); auto d = at::randn({4, 3, 2}); std::vector args0{a, b}; testStaticRuntime(sub_tensor, args0); testStaticRuntime(sub_tensor, args0, {c, d}); std::vector args1{a, 3}; testStaticRuntime(sub_scalar, args1); testStaticRuntime(sub_scalar, args1, {c, 4}); std::vector args2{a, b, 2.3}; testStaticRuntime(sub_tensor_alpha, args2); testStaticRuntime(sub_tensor_alpha, {c, d, 3.1}); std::vector args3{a, 2.3, 4}; testStaticRuntime(sub_scalar_alpha, args3); testStaticRuntime(sub_scalar_alpha, {c, 1.3, 2}); std::vector args4{1, 2}; testStaticRuntime(sub_two_scalars, args4); } TEST(StaticRuntime, NanToNum) { const auto nan_to_num_script = R"JIT( def forward(self, a: Tensor, nan: float, posinf: float, neginf: float): return torch.nan_to_num(a, nan, posinf, neginf).clone() )JIT"; const auto inf = std::numeric_limits::infinity(); const auto nan = std::numeric_limits::quiet_NaN(); auto a = torch::tensor({{1.0, nan}, {-inf, inf}}); auto b = at::randn({3, 6}); float* b_data = b.data_ptr(); b_data[0] = nan; b_data[4] = -inf; b_data[11] = inf; b_data[13] = nan; std::vector args1{a, 1.0, 2.0, -2.0}; std::vector args2{b, 1.0, 2.0, -2.0}; testStaticRuntime( nan_to_num_script, args1, /*args2*/ {}, /*use_allclose*/ true, /*use_equalnan*/ true); testStaticRuntime( nan_to_num_script, args1, args2, /*use_allclose*/ true, /*use_equalnan*/ true); } TEST(StaticRuntime, Stack) { const auto stack_dim = R"JIT( def forward(self, a: Tensor, b: Tensor, dim: int): inputs = [a] inputs.append(b) # mutation to avoid using VarStack return torch.stack(inputs, dim = dim).clone() )JIT"; const auto stack_three = R"JIT( def forward(self, a: Tensor, b: Tensor, c: Tensor): inputs = [a, b] inputs.append(c) # mutation to avoid using VarStack return torch.stack(inputs).clone() )JIT"; auto a = at::randn({2, 2}); auto b = at::randn({2, 2}); auto c = at::randn({2, 2}); auto d = at::randn({3, 3, 3}); auto e = at::randn({3, 3, 3}); auto f = at::randn({3, 3, 3}); std::vector args1_dim{a, b, 0}; std::vector args2_dim{d, e, 1}; std::vector args_dim_negative{d, e, -1}; std::vector args1_three_tensors{a, b, c}; std::vector args2_three_tensors{d, e, f}; testStaticRuntime(stack_dim, args1_dim); testStaticRuntime(stack_dim, args1_dim, args2_dim); testStaticRuntime(stack_dim, args_dim_negative); testStaticRuntime(stack_three, args1_three_tensors); testStaticRuntime(stack_three, args1_three_tensors, args2_three_tensors); } TEST(StaticRuntime, ReLU) { const auto relu_script = R"JIT( def forward(self, a: Tensor): return torch.relu(a).clone() )JIT"; auto a = at::randint(-10, 10, {2, 4}); auto b = at::randint(-10, 10, {3, 6}); std::vector args1{a}; std::vector args2{b}; testStaticRuntime(relu_script, 
args1); testStaticRuntime(relu_script, args1, args2); } TEST(StaticRuntime, Tanh) { const auto tanh_script = R"JIT( def forward(self, a): return torch.tanh(a).clone() )JIT"; auto a = at::randn({2, 2}); auto b = at::randn({3, 3, 3}); std::vector args1{a}; std::vector args2{b}; testStaticRuntime(tanh_script, args1, /*args2*/ {}, /*use_allclose*/ true); testStaticRuntime(tanh_script, args1, args2, /*use_allclose*/ true); } TEST(StaticRuntime, Norm) { const auto norm_2arg = R"JIT( def forward(self, a: Tensor, p: int): return torch.norm(a, p).clone() )JIT"; const auto norm_3arg = R"JIT( def forward(self, a: Tensor, p: int, dtype: int): return torch.norm(a, p, dtype=dtype).clone() )JIT"; const auto norm_4arg = R"JIT( def forward(self, a: Tensor, p: int, dim: List[int], keepdim: bool): return torch.norm(a, p, dim, keepdim).clone() )JIT"; const auto norm_5arg = R"JIT( def forward(self, a: Tensor, p: int, dim: List[int], keepdim: bool, dtype: int): return torch.norm(a, p, dim, keepdim, dtype=dtype).clone() )JIT"; auto a = at::randn({2, 3}); auto b = at::randn({4, 3, 5}); auto dim = std::vector({1}); auto dtype = at::ScalarType::Float; std::vector args2{a, 2}; testStaticRuntime(norm_2arg, args2); testStaticRuntime(norm_2arg, args2, {b, 2}, false, false, false); std::vector args3{a, 2, dtype}; testStaticRuntime(norm_3arg, args3); testStaticRuntime(norm_3arg, args3, {b, 2, dtype}, false, false, false); std::vector args4{a, 3, dim, false}; testStaticRuntime(norm_4arg, args4); testStaticRuntime(norm_4arg, args4, {b, 3, dim, false}); std::vector args5{a, 4, dim, true, dtype}; testStaticRuntime(norm_5arg, args5); testStaticRuntime(norm_5arg, args5, {b, 4, dim, true, dtype}); } TEST(StaticRuntime, Reshape) { const auto reshape_script_1 = R"JIT( def forward(self, a: Tensor, shape: List[int]): b = a.reshape(shape) return b + b )JIT"; const auto reshape_script_2 = R"JIT( def forward(self, a: Tensor, shape: List[int]): b = a.transpose(0, 1) return b.reshape(shape) )JIT"; const auto reshape_script_3 = R"JIT( def forward(self, inp: Tensor, shape: List[int]): a = inp + inp b = a.reshape(shape) c = a.reshape(shape) d = c + c e = d + d f = e * e g = f * f return b.reshape(shape), g )JIT"; // exercise reshape_copy and flatten_copy const auto reshape_script_4 = R"JIT( def forward(self, inp: Tensor, shape: List[int]): k = inp + inp a = k + k b = a.reshape(shape) c = a.flatten().reshape(shape) return b + c )JIT"; // exercise reshape_copy const auto reshape_script_5 = R"JIT( def forward(self, inp: Tensor, shape: List[int]): a = inp + inp b = a.reshape(shape) c = a.reshape(shape).relu() d = c + c e = d + d f = e * e g = f * f return g )JIT"; const auto reshape_inplace_script = R"JIT( def forward(self, inp: Tensor, shape: List[int]): a = inp + inp b = a.reshape(shape) c = b.sigmoid_() d = c + c e = a + a f = b + b return (d, e, f) )JIT"; // b is in_contiguous const auto reshape_incontiguous_script = R"JIT( def forward(self, a: Tensor, shape: List[int]): b = a.transpose(0, 1) c = b.reshape(shape) c = c.relu() return (c) )JIT"; auto a = at::randn({2, 3}); auto b = std::vector({3, 2}); std::vector args{a, b}; auto c = at::randn({4, 5}); auto d = std::vector({5, 1, 2, 2}); std::vector args1{c, d}; testStaticRuntime(reshape_script_1, args); testStaticRuntime(reshape_script_2, args); testStaticRuntime(reshape_script_3, args); testStaticRuntime(reshape_script_4, args); testStaticRuntime(reshape_script_5, args); testStaticRuntime(reshape_inplace_script, args); testStaticRuntime(reshape_incontiguous_script, args); 
testStaticRuntime(reshape_script_1, args, args1); testStaticRuntime(reshape_script_2, args, args1); testStaticRuntime(reshape_script_3, args, args1); testStaticRuntime(reshape_script_4, args, args1); testStaticRuntime(reshape_script_5, args, args1); testStaticRuntime(reshape_inplace_script, args, args1); testStaticRuntime(reshape_incontiguous_script, args, args1); } TEST(StaticRuntime, Repeat) { const std::string repeat = R"JIT( def forward(self, a: Tensor, repeats: List[int]): return torch.repeat(a, repeats).clone() )JIT"; auto a = at::randn({2, 3}); auto b = at::randn({4, 3}); auto c = std::vector({1, 2}); auto d = std::vector({2, 3}); std::vector args1{a, c}; std::vector args2{b, d}; testStaticRuntime(repeat, args1); testStaticRuntime(repeat, args2); testStaticRuntime(repeat, args1, args2); } TEST(StaticRuntime, Flatten) { // exercise flatten_copy const auto flatten_script_1 = R"JIT( def forward(self, a: Tensor, start_dim: int, end_dim: int): b = a * a c = torch.flatten(b, start_dim, end_dim) d = torch.relu(c) return d )JIT"; const auto flatten_script_2 = R"JIT( def forward(self, a: Tensor, start_dim: int, end_dim: int): b = a.transpose(0, 1) return torch.flatten(b, start_dim, end_dim).clone() )JIT"; auto test_flatten = [&](std::vector shape, int64_t start_dim, int64_t end_dim) { std::vector shape1(shape); if (shape1.size() > 0) { shape1[0] *= 6; } auto a = at::randn(shape); auto b = at::randn(shape1); std::vector args{a, start_dim, end_dim}; bool check_resize = shape1.size() > 0; testStaticRuntime(flatten_script_1, args); testStaticRuntime( flatten_script_1, args, {b, start_dim, end_dim}, false, /* use_allclose */ false, /* use_equalnan */ check_resize); if (shape.size() > 2) { testStaticRuntime(flatten_script_2, args); testStaticRuntime(flatten_script_2, args, {b, start_dim, end_dim}); } }; test_flatten({2, 3}, 0, 1); test_flatten({2, 1, 3}, 1, 2); test_flatten({0, 1, 3, 0}, 1, 2); test_flatten({2, 3}, 1, 1); test_flatten({}, 0, 0); } TEST(StaticRuntime, pow) { const auto pow_script_ten_sca = R"JIT( def forward(self, input : Tensor, exponent : int): return torch.pow(input, exponent).clone() )JIT"; const auto pow_script_ten_ten = R"JIT( def forward(self, input : Tensor, exponent : Tensor): return torch.pow(input, exponent).clone() )JIT"; const auto pow_script_sca_ten = R"JIT( def forward(self, input : int, exponent : Tensor): return torch.pow(input, exponent).clone() )JIT"; auto a = at::randn({2, 3}); auto b = at::randn({2, 3}); auto c = at::randn({4, 3, 2}); auto d = at::randn({4, 3, 2}); std::vector args0{a, 4}; testStaticRuntime(pow_script_ten_sca, args0); testStaticRuntime(pow_script_ten_sca, args0, {c, 4}); std::vector args1{at::abs(a), b}; testStaticRuntime(pow_script_ten_ten, args1); testStaticRuntime(pow_script_ten_ten, args1, {at::abs(c), d}); std::vector args2{5, b}; testStaticRuntime(pow_script_sca_ten, args2); testStaticRuntime(pow_script_sca_ten, args2, {3, d}); } TEST(StaticRuntime, to) { const auto to_script_dtype = R"JIT( def forward(self, input: Tensor, dtype: int, non_blocking: bool, copy: bool, memory_format: int): a = input + input return torch.to(a, dtype, non_blocking, copy, memory_format).clone() )JIT"; const auto to_script_dtype_strided = R"JIT( def forward(self, input: Tensor, dtype: int, non_blocking: bool, copy: bool, memory_format: int): b = input.permute(0, 2, 3, 1) return torch.to(b, dtype, non_blocking, copy, memory_format).clone() )JIT"; const auto to_script_prim_dtype = R"JIT( def forward(self, input:Tensor, dtype: Optional[int], non_blocking: bool, 
copy: bool): a = input + input return torch.to(a, dtype, non_blocking, copy).clone() )JIT"; const auto to_script_other = R"JIT( def forward(self, input:Tensor, other: Tensor, non_blocking: bool, copy: bool, memory_format: int): a = input + input return torch.to(a, other, non_blocking, copy, memory_format).clone() )JIT"; // if input is float tensor, b could be alias of a const auto to_script_alias = R"JIT( def forward(self, input:Tensor): a = input + input b = a.float() c = b * b return (c) )JIT"; const auto to_script_fails_managed_output_check = R"JIT( def forward(self, a, b): d = a.half() * b.half() e = d.float() return e )JIT"; const auto to_script_select_tensor_output_into_tuple = R"JIT( def forward(self, a, b): d = a.half() * b.half() e = d.float() return (d, e) )JIT"; const auto to_script_memory_planning_fail = R"JIT( def forward(self, a, b): d = a.half() * b.half() e = d.float().relu() return e )JIT"; auto test_to = [&](at::ScalarType b, bool c, bool d, c10::MemoryFormat e) { auto a = at::randn({4, 3, 1, 2}); auto other = at::randn({4, 3, 1, 2}).to(b); auto a2 = at::randn({3, 2, 2, 4}); auto a2_other = at::randn({3, 2, 2, 4}).to(b); std::vector args0{a, b, c, d, e}; std::vector args1{a, b, c, d}; std::vector args2{a, other, c, d, e}; std::vector args2WithDifferentOtherType{ a, at::randn({4, 3, 1, 2}, ScalarType::Double), c, d, e}; std::vector args3{a, std::nullopt, c, d}; std::vector args0WithInt{a, ScalarType::Int, c, d, e}; testStaticRuntime( to_script_dtype, args0, args0WithInt, /* default for use_allclose */ false, /* default for use_equalnan */ false, /* check_resize */ false); testStaticRuntime(to_script_dtype_strided, args0); testStaticRuntime(to_script_prim_dtype, args1); if (!d) { testStaticRuntime(to_script_prim_dtype, args3); } // Second set of args tests case where the `other` tensor's dtype // changes between iterations. 
testStaticRuntime( to_script_other, args2, args2WithDifferentOtherType, /* default for use_allclose */ false, /* default for use_equalnan */ false, /* check_resize */ false); testStaticRuntime(to_script_alias, {a}); testStaticRuntime(to_script_memory_planning_fail, {a, a}); testStaticRuntime(to_script_fails_managed_output_check, {a, a}); testStaticRuntime(to_script_select_tensor_output_into_tuple, {a, a}); // dynamic shapes testStaticRuntime(to_script_dtype, args0, {a2, b, c, d, e}); testStaticRuntime(to_script_dtype_strided, args0, {a2, b, c, d, e}); testStaticRuntime(to_script_prim_dtype, args1, {a2, b, c, d}); if (!d) { testStaticRuntime(to_script_prim_dtype, args3, {a2, std::nullopt, c, d}); } testStaticRuntime(to_script_other, args2, {a2, a2_other, c, d, e}); testStaticRuntime(to_script_alias, {a}, {a2}); }; for (const bool non_blocking : {false, true}) { for (const bool copy : {false, true}) { // float->float, NCHW->NHWC test_to( at::ScalarType::Float, non_blocking, copy, c10::MemoryFormat::ChannelsLast); // float->half test_to( at::ScalarType::Half, non_blocking, copy, c10::MemoryFormat::Preserve); // float->float test_to( at::ScalarType::Float, non_blocking, copy, c10::MemoryFormat::Contiguous); test_to( at::ScalarType::Bool, non_blocking, copy, c10::MemoryFormat::Contiguous); // TODO: check if fbgemm is enabled properly in this case // half->float, NCHW->NHWC test_to( at::ScalarType::Half, non_blocking, copy, c10::MemoryFormat::ChannelsLast); } } } TEST(StaticRuntime, ExpandAs) { const auto expand_as_script = R"JIT( def forward(self, input: Tensor, other:Tensor): a = input.expand_as(other) return a.clone() )JIT"; auto a = at::randn({3, 1}); auto b = at::randn({3, 2}); auto c = at::randn({4, 1}); auto d = at::randn({4, 2}); std::vector args{a, b}; std::vector args2{c, d}; testStaticRuntime(expand_as_script, args); testStaticRuntime(expand_as_script, args, args2); } TEST(StaticRuntime, Full) { const auto full_script = R"JIT( def forward(self, size: List[int], fill_value: int, dtype: Optional[int], layout: Optional[int], device: Optional[Device], pin_memory: Optional[bool]): a = torch.full(size, fill_value, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory) return (a.clone()) )JIT"; auto cpu = at::Device(DeviceType::CPU); c10::List size0{2, 5}; std::vector args{ size0, 4, at::ScalarType::Int, at::kStrided, cpu, false}; std::vector args1{ size0, 4, at::ScalarType::Float, at::kStrided, cpu, false}; c10::List size1{5, 6}; std::vector args2{ size1, 5, at::ScalarType::Float, at::kStrided, cpu, false}; testStaticRuntime(full_script, args); testStaticRuntime( full_script, args, args1, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/false); testStaticRuntime(full_script, args, args2); } TEST(StaticRuntime, FullLike) { const auto full_like_script = R"JIT( def forward(self, a: Tensor, fill_value: int, dtype: Optional[int], layout: Optional[int], device: Optional[Device], pin_memory: Optional[bool], memory_format: Optional[int]): b = torch.full_like(a, fill_value, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, memory_format=memory_format) return (b.clone()) )JIT"; auto a = at::randn({2, 3}); auto b = at::randn({3, 4, 2}); auto cpu = at::Device(DeviceType::CPU); std::vector args{ a, 4, at::ScalarType::Int, at::kStrided, cpu, false, c10::MemoryFormat::Contiguous}; std::vector args1{ a, 4, at::ScalarType::Float, at::kStrided, cpu, false, c10::MemoryFormat::Contiguous}; std::vector args2{ b, 4, at::ScalarType::Float, at::kStrided, cpu, false, 
c10::MemoryFormat::Contiguous}; testStaticRuntime(full_like_script, args); testStaticRuntime( full_like_script, args, args1, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/false); testStaticRuntime(full_like_script, args, args2); } TEST(StaticRuntime, Ones) { const auto script = R"JIT( def forward(self, size: List[int], dtype: Optional[int], layout: Optional[int], device: Optional[Device], pin_memory: Optional[bool]): a = torch.ones(size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory) return (a.clone()) )JIT"; auto dtype = at::ScalarType::Int; auto cpu = at::Device(DeviceType::CPU); c10::List size0{2, 5}; std::vector args{size0, dtype, at::kStrided, cpu, false}; c10::List size1{5, 6}; std::vector args2{size1, dtype, at::kStrided, cpu, false}; testStaticRuntime(script, args); testStaticRuntime(script, args, args2); } TEST(StaticRuntime, OnesLike) { const auto script = R"JIT( def forward(self, input: Tensor, dtype: Optional[int], layout: Optional[int], device: Optional[Device], pin_memory: Optional[bool], memory_format: Optional[int]): a = torch.ones_like(input, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, memory_format=memory_format) return (a.clone()) )JIT"; auto cpu = at::Device(DeviceType::CPU); auto input0 = at::randn({2, 5}); std::vector args{ input0, at::ScalarType::Int, at::kStrided, cpu, false, c10::MemoryFormat::Contiguous}; std::vector args1{ input0, at::ScalarType::Float, at::kStrided, cpu, false, c10::MemoryFormat::Contiguous}; auto input1 = at::randn({5, 6}); std::vector args2{ input1, at::ScalarType::Float, at::kStrided, cpu, false, c10::MemoryFormat::Contiguous}; testStaticRuntime(script, args); testStaticRuntime( script, args, args1, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/false); testStaticRuntime(script, args, args2); } TEST(StaticRuntime, Zeros) { const auto script = R"JIT( def forward(self, size: List[int], dtype: Optional[int], layout: Optional[int], device: Optional[Device], pin_memory: Optional[bool]): a = torch.zeros(size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory) return (a.clone()) )JIT"; auto cpu = at::Device(DeviceType::CPU); c10::List size0{2, 5}; std::vector args{ size0, at::ScalarType::Int, at::kStrided, cpu, false}; std::vector args1{ size0, at::ScalarType::Float, at::kStrided, cpu, false}; c10::List size1{5, 6}; std::vector args2{ size1, at::ScalarType::Float, at::kStrided, cpu, false}; testStaticRuntime(script, args); testStaticRuntime( script, args, args1, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/false); testStaticRuntime(script, args, args2); } TEST(StaticRuntime, Linear) { const auto linear_script = R"JIT( def forward(self, inp: Tensor, weights: Tensor, bias: Optional[Tensor]) -> Tensor: return torch.linear(inp, weights, bias).clone() )JIT"; auto input = at::randn({1, 2}); auto weights = at::randn({1, 2}); auto bias = at::randn({1, 1}); std::vector args{input, weights, bias}; std::vector args_no_bias{input, weights, std::nullopt}; auto input2 = at::randn({6, 3}); auto weights2 = at::randn({6, 3}); auto bias2 = at::randn({6, 6}); std::vector args2{input2, weights2, bias2}; std::vector args2_no_bias{input2, weights2, std::nullopt}; testStaticRuntime(linear_script, args); testStaticRuntime(linear_script, args_no_bias); testStaticRuntime(linear_script, args, args2); testStaticRuntime(linear_script, args, args2_no_bias); } TEST(StaticRuntime, VarCat) { const auto var_cat_script = R"JIT( def forward(self, inp1: Tensor, inp2: Tensor, 
dim: int): return torch.cat([inp1, inp2], dim).clone() )JIT"; // 2D tensors - cat dim = 0 std::vector args1 = {at::randn({4, 6}), at::randn({5, 6}), 0}; testStaticRuntime(var_cat_script, args1); // 3D tensors - cat dim = 1 std::vector args2 = {at::randn({4, 5, 6}), at::randn({4, 8, 6}), 1}; testStaticRuntime(var_cat_script, args2); // 3D tensors - cat dim = 2 std::vector args3 = {at::randn({4, 5, 6}), at::randn({4, 5, 7}), 2}; testStaticRuntime(var_cat_script, args3); // Negative dim std::vector args4 = {at::randn({4, 5, 6}), at::randn({4, 5, 7}), -1}; testStaticRuntime(var_cat_script, args4); testStaticRuntime(var_cat_script, args1, args2); } TEST(StaticRuntime, LeakyReLU) { torch::jit::Module mod = getLeakyReLUConstScriptModel(); auto inputs = torch::randn({2, 2}); // run jit graph executor std::vector input_ivalues({inputs}); at::Tensor output_1 = mod.forward(input_ivalues).toTensor(); // run static runtime std::vector input_tensors({inputs}); torch::jit::StaticModule smod(mod); at::Tensor output_2 = smod(input_tensors, {}).toTensor(); smod.runtime().check_for_memory_leak(); EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6)); } static ProcessedNodeInputs createProcessedNodeInputs( c10::ArrayRef inputs) { ProcessedNodeInputs result(inputs.size()); for (const auto idx : c10::irange(inputs.size())) { result[idx] = inputs[idx]; } return result; } static void checkProcessedNodeInputs( const ProcessedNodeInputs& io, c10::ArrayRef inputs) { ASSERT_EQ(inputs.size(), io.size()); for (const auto idx : c10::irange(inputs.size())) { EXPECT_EQ(inputs[idx], io[idx]); } } static void testProcessedNodeInputsRoundTrip(c10::ArrayRef inputs) { auto io = createProcessedNodeInputs(inputs); checkProcessedNodeInputs(io, inputs); ProcessedNodeInputs copied(io); checkProcessedNodeInputs(copied, inputs); ProcessedNodeInputs moved(std::move(io)); checkProcessedNodeInputs(moved, inputs); } TEST(ProcessedNodeInputs, Basic) { std::vector> testCases = { {}, // empty {0xABCD, 0x5a5a}, // inline {0x11, 0x22, 0x33, 0x44, 0x55}, // max inline size {0x11, 0x22, 0x33, 0x44, 0x55, 0x66}, // minimum outline size std::vector(100, 0x5a), // large outline size }; for (const auto& values : testCases) { testProcessedNodeInputsRoundTrip(values); for (const auto& values2 : testCases) { auto from = createProcessedNodeInputs(values); auto to = createProcessedNodeInputs(values2); to = from; checkProcessedNodeInputs(to, values); auto toMoveInto = createProcessedNodeInputs(values2); toMoveInto = std::move(from); checkProcessedNodeInputs(toMoveInto, values); } } } TEST(StaticRuntime, isinstance) { const auto isinstance_int_script = R"JIT( def forward(self, a: Any): return isinstance(a, int) )JIT"; const auto isinstance_tensor_script = R"JIT( def forward(self, a: Any): return isinstance(a, torch.Tensor) )JIT"; const auto isinstance_many_types_script = R"JIT( def forward(self, a: Any): return isinstance(a, (bool, int)) )JIT"; auto a = at::randn({2, 2}); auto b = at::randn({2, 2, 2}); std::vector args{a}; std::vector args2{b}; testStaticRuntime(isinstance_int_script, args); testStaticRuntime(isinstance_int_script, args, args2); testStaticRuntime(isinstance_tensor_script, args); testStaticRuntime(isinstance_tensor_script, args, args2); testStaticRuntime(isinstance_many_types_script, args); testStaticRuntime(isinstance_many_types_script, args, args2); } TEST(StaticRuntime, TypeCheck) { const auto typecheck_ir = R"IR( graph(%a.1 : Tensor, %b.1 : Tensor): %t0 : Float(2, 2, strides=[2, 1], device=cpu), %t1 : Float(3, 3, strides=[3, 1]), 
%type_matched : bool = prim::TypeCheck[types=[Float(2, 2, strides=[2, 1], device=cpu), Float(3, 3, strides=[3, 1])]](%a.1, %b.1) return (%t0, %t1, %type_matched) )IR"; auto a = at::zeros({2, 2}, at::kFloat); a.to(at::kCPU); auto b = at::ones({3, 3}, at::kFloat); auto c = at::ones({2, 2, 2}, at::kFloat); std::vector args_correct = {a, b}; std::vector args_incorrect = {a, c}; testStaticRuntime(typecheck_ir, args_correct); testStaticRuntime(typecheck_ir, args_correct, args_incorrect); } TEST(StaticRuntime, Index) { const auto index_without_none_script = R"JIT( def forward(self, a: Tensor, idx: Tensor): return a[idx].clone() )JIT"; // Index with boolean mask auto a = at::arange(4, at::kFloat).view({2, 2}); auto idx_a = torch::tensor({{0, 1}, {0, 0}}, at::kBool); std::vector args_a{a, idx_a}; // Index with tensor auto b = at::arange(27, at::kFloat).view({3, 3, 3}); auto idx_b = torch::tensor({0, 1, 2}, at::kLong); std::vector args_b{b, idx_b}; testStaticRuntime(index_without_none_script, args_a); testStaticRuntime(index_without_none_script, args_a, args_b); const auto index_with_none_script = R"JIT( def forward(self, a: Tensor, idx: Tensor, none: Optional[Tensor]): return a[idx, none].clone() )JIT"; // Index with None // When indexing with none, the shape of `f` becomes [2, 1, 2], // so the mask must be reshaped appropriately. auto f = at::arange(4, at::kFloat).view({2, 1, 2}); auto idx_f_reshape = torch::tensor({{{0, 1}}, {{0, 0}}}, at::kBool); std::vector args_f_with_none{f, idx_f_reshape}; args_f_with_none.emplace_back(); testStaticRuntime(index_with_none_script, args_f_with_none); testStaticRuntime( index_with_none_script, args_f_with_none, {IValue(b), IValue(idx_b), IValue()}); const auto index_with_two_tensors_script = R"JIT( def forward(self, a: Tensor, idx_a: Tensor, idx_b: Tensor): return a[idx_a, idx_b].clone() )JIT"; // Index with multiple tensors const auto& c = a; // 2x2 tensor auto idx_c1 = torch::tensor({0, 0}, at::kLong); auto idx_c2 = torch::tensor({0}, at::kLong); std::vector args_c{c, idx_c1, idx_c2}; const auto& d = b; // 3x3x3 tensor auto idx_d1 = torch::tensor({{0, 0, 2}, {0, 1, 1}}, at::kLong); auto idx_d2 = torch::tensor({{1, 1, 0}, {1, 0, 2}}, at::kLong); std::vector args_d{d, idx_d1, idx_d2}; testStaticRuntime(index_with_two_tensors_script, args_c, args_d); } TEST(StaticRuntime, IndexSelect) { const std::string script = R"IR( graph(%self: Tensor, %dim: int, %index: Tensor): %bias: None = prim::Constant() %ret = aten::index_select(%self, %dim, %index) %cloned = aten::clone(%ret, %bias) return (%cloned) )IR"; auto self0 = at::rand({6}); auto dim0 = 0; auto index0 = at::randint(0, 5, {6}, torch::kInt32); std::vector args{self0, dim0, index0}; testStaticRuntime(script, args); auto self1 = at::rand({128}); auto dim1 = 0; auto index1 = at::randint(0, 127, {127}, torch::kInt32); std::vector args2{self1, dim1, index1}; testStaticRuntime(script, args, args2); } TEST(StaticRuntime, ClampMin) { const auto clamp_min_int_script = R"JIT( def forward(self, a: Tensor, b: int): return torch.clamp_min(a, b).clone() )JIT"; const auto clamp_min_float_script = R"JIT( def forward(self, a: Tensor, b: float): return torch.clamp_min(a, b).clone() )JIT"; auto a = at::randn({2, 2}); auto b = at::randn({3, 3, 3}); int scalar_int = 1; float scalar_float = 3.14; std::vector args_a_int{a, scalar_int}; std::vector args_b_int{b, scalar_int}; testStaticRuntime(clamp_min_int_script, args_a_int); testStaticRuntime(clamp_min_int_script, args_a_int, args_b_int); std::vector args_a_float{a, scalar_float}; 
std::vector args_b_float{b, scalar_float}; testStaticRuntime(clamp_min_float_script, args_a_float); testStaticRuntime(clamp_min_float_script, args_a_float, args_b_float); } TEST(StaticRuntime, Argmin) { const auto argmin_script = R"JIT( def forward(self, a: Tensor): return torch.argmin(a).clone() )JIT"; const auto argmin_with_dim_script = R"JIT( def forward(self, a: Tensor, dim: int): return torch.argmin(a, dim).clone() )JIT"; const auto argmin_with_keep_dim_script = R"JIT( def forward(self, a: Tensor, dim: int): return torch.argmin(a, dim, True).clone() )JIT"; auto a = at::randn({2, 2}); auto b = at::randn({17, 2, 1}); testStaticRuntime(argmin_script, {a}); testStaticRuntime( argmin_script, {a}, {b}, /* use_allclose */ false, /* use_equalnan */ false, /* check_resize */ false); int dim_a = 0; int dim_b = 1; std::vector args_a{a, dim_a}; std::vector args_b{b, dim_b}; testStaticRuntime(argmin_with_dim_script, args_a); testStaticRuntime(argmin_with_dim_script, args_a, args_b); testStaticRuntime(argmin_with_keep_dim_script, args_a); testStaticRuntime(argmin_with_keep_dim_script, args_a, args_b); } TEST(StaticRuntime, Softmax) { const auto softmax_script = R"JIT( def forward(self, a: Tensor, dim: int): return torch.softmax(a, dim).clone() )JIT"; const auto softmax_script_with_dtype = R"JIT( def forward(self, a: Tensor, dim: int, dtype: int): return torch.softmax(a, dim, dtype=dtype).clone() )JIT"; auto a = at::randn({2, 3}); auto b = at::randn({3, 3, 3}); testStaticRuntime(softmax_script, {a, 0}); testStaticRuntime(softmax_script, {a, 1}); testStaticRuntime(softmax_script, {b, 0}); testStaticRuntime(softmax_script, {b, 1}); testStaticRuntime(softmax_script, {b, 2}); testStaticRuntime(softmax_script_with_dtype, {a, 1, at::ScalarType::Float}); testStaticRuntime(softmax_script_with_dtype, {b, 1, at::ScalarType::Float}); } TEST(StaticRuntime, GetItem_Dict) { const auto getitem_dict_tensor_script = R"JIT( def forward(self, key: Tensor): d = {key: 1} return d[key] )JIT"; const auto getitem_dict_int_script = R"JIT( def forward(self, key: int): d = {key: 1} return d[key] )JIT"; const auto getitem_dict_str_script = R"JIT( def forward(self, key: str): d = {key: 1} return d[key] )JIT"; int int_key = 0; std::string str_key = "str"; // No need to test these multiple times, args are not tensors testStaticRuntime(getitem_dict_int_script, {int_key}); testStaticRuntime(getitem_dict_str_script, {str_key}); auto a = torch::tensor({1}); auto b = torch::tensor({1, 1}); testStaticRuntime(getitem_dict_tensor_script, {a}); testStaticRuntime(getitem_dict_tensor_script, {a}, {b}); } TEST(StaticRuntime, GetItem_List) { const auto getitem_list_int_script = R"JIT( def forward(self, idx: int): lst = [1, 2, 3] return lst[idx] )JIT"; const auto getitem_list_tensor_script = R"JIT( def forward(self, tensor: Tensor, idx: int): lst = [tensor, tensor] return lst[idx] )JIT"; testStaticRuntime(getitem_list_int_script, {1}); testStaticRuntime(getitem_list_int_script, {-1}); auto a = torch::tensor({1}); auto b = torch::tensor({1, 1}); testStaticRuntime(getitem_list_tensor_script, {a, 1}); testStaticRuntime(getitem_list_tensor_script, {a, 1}, {b, -1}); } TEST(StaticRuntime, Transpose) { const auto transpose_script = R"JIT( def forward(self, a: Tensor, dim1: int, dim2: int): return torch.transpose(a, dim1, dim2).clone() )JIT"; auto a = at::randn({2, 2}); int dim1_a = 0; int dim2_a = 1; std::vector args_a{a, dim1_a, dim2_a}; auto b = at::randn({3, 3, 3}); int dim1_b = 0; int dim2_b = 2; std::vector args_b{b, dim1_b, dim2_b}; 
testStaticRuntime(transpose_script, args_a); testStaticRuntime(transpose_script, args_a, args_b); } TEST(StaticRuntime, Permute) { auto permute_script = R"JIT( def forward(self, a: Tensor, dims: List[int]): return torch.permute(a, dims).clone() )JIT"; auto a = at::randn({2, 2}); c10::List dims_a{1, 0}; std::vector args_a{a, dims_a}; auto b = at::randn({3, 3, 3}); c10::List dims_b{0, 2, 1}; std::vector args_b{b, dims_b}; testStaticRuntime(permute_script, args_a); testStaticRuntime(permute_script, args_a, args_b); permute_script = R"JIT( def forward(self, a: Tensor, dims: List[int], shape: List[int]): return torch.permute(a, dims).reshape(shape).clone() )JIT"; a = at::randn({8, 16, 4}); dims_a = {0, 2, 1}; dims_b = {-1, 16}; testStaticRuntime(permute_script, {a, dims_a, dims_b}); } TEST(StaticRuntime, Slice) { const auto slice_script = R"JIT( def forward(self, a: Tensor, dim: int, start: int, end: int, step: int): return a.slice(dim, start, end, step).clone() )JIT"; auto a = at::randn({2, 2}); int dim_a = 1; int start_a = 0; int end_a = 1; int step_a = 1; std::vector args_a{a, dim_a, start_a, end_a, step_a}; auto b = at::randn({3, 3, 3}); int dim_b = 2; int start_b = 0; int end_b = 1; int step_b = 2; std::vector args_b{b, dim_b, start_b, end_b, step_b}; testStaticRuntime(slice_script, args_a); testStaticRuntime(slice_script, args_a, args_b); const auto slice_script2 = R"JIT( def forward(self, a: Tensor, dim: int, step: int): return a.slice(dim, None, None, step).clone() )JIT"; std::vector args_c{b, dim_b, step_b}; testStaticRuntime(slice_script2, args_c); } TEST(StaticRuntime, Narrow) { const auto narrow_with_int_script = R"JIT( def forward(self, a: Tensor, dim: int, start: int, length: int): return a.narrow(dim, start, length).clone() )JIT"; auto a = at::randn({5, 5}); int dim_a = 0; int start_a_int = 3; int len_a = 2; std::vector args_a{a, dim_a, start_a_int, len_a}; auto b = at::randn({5, 5, 5}); int dim_b = 1; int start_b_int = 2; int len_b = 3; std::vector args_b{b, dim_b, start_b_int, len_b}; testStaticRuntime(narrow_with_int_script, args_a); testStaticRuntime(narrow_with_int_script, args_a, args_b); } TEST(StaticRuntime, TupleUnpack) { const auto two_tuple_unpack_script = R"JIT( def forward(self, tup: Tuple[Tensor, Tensor]): a, b = tup return (a, b) )JIT"; const auto three_tuple_unpack_script = R"JIT( def forward(self, tup: Tuple[Tensor, Tensor, Tensor]): a, b, c = tup return (a, b, c) )JIT"; auto two_tup = c10::ivalue::Tuple::create({at::randn({1}), at::randn({1})}); auto two_tup_large = c10::ivalue::Tuple::create({at::randn({2, 2}), at::randn({2, 2})}); auto three_tup = c10::ivalue::Tuple::create( {at::randn({1}), at::randn({1}), at::randn({1})}); auto three_tup_large = c10::ivalue::Tuple::create( {at::randn({2, 2}), at::randn({2, 2}), at::randn({2, 2})}); testStaticRuntime(two_tuple_unpack_script, {two_tup}); testStaticRuntime(two_tuple_unpack_script, {two_tup}, {two_tup_large}); testStaticRuntime(three_tuple_unpack_script, {three_tup}); testStaticRuntime(three_tuple_unpack_script, {three_tup}, {three_tup_large}); } TEST(StaticRuntime, Append) { const auto append_int_script = R"JIT( def forward(self, a: int): lst = [1, 2, 3] lst.append(a) return lst )JIT"; const auto append_tensor_script = R"JIT( def forward(self, a: Tensor): lst = [] lst.append(a) return lst )JIT"; std::vector args_int{1}; testStaticRuntime(append_int_script, args_int); std::vector args_tensor{at::randn({1})}; std::vector args_tensor_large{at::randn({2, 2})}; testStaticRuntime(append_tensor_script, args_tensor); 
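  // The second call below repeats the tensor-append graph with a larger 2x2
  // tensor so the same list construction runs under a different input shape.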
testStaticRuntime(append_tensor_script, args_tensor, args_tensor_large); } TEST(StaticRuntime, QuantizedLinear) { const std::string quantize_script = R"IR( graph(%input: Tensor, %weights: Tensor): %scale: float = prim::Constant[value=1.]() %zero_point: int = prim::Constant[value=1]() %bias: None = prim::Constant() %packed_params = quantized::linear_prepack(%weights, %bias) %1254 = quantized::linear(%input, %packed_params, %scale, %zero_point) %1249: Tensor = aten::dequantize(%1254) return (%1249) )IR"; at::Tensor weight = at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQInt8); at::Tensor input = at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQUInt8); at::Tensor weight_2 = at::quantize_per_tensor(torch::randn({8, 3}), 2, 3, torch::kQInt8); at::Tensor input_2 = at::quantize_per_tensor(torch::randn({9, 3}), 2, 3, torch::kQUInt8); testStaticRuntime(quantize_script, {input, weight}, {input_2, weight_2}); } TEST(StaticRuntime, QuantizedLinearDynamicFp16) { const std::string quantized_linear_dynamic_fp16_script = R"IR( graph(%input: Tensor, %weights: Tensor): %bias: None = prim::Constant() %packed_params = quantized::linear_prepack_fp16(%weights, %bias) %output = quantized::linear_dynamic_fp16(%input, %packed_params) %ret = aten::clone(%output, %bias) return (%ret) )IR"; at::Tensor weight = torch::randn({3, 2}, torch::kFloat); at::Tensor input = torch::randn({3, 2}, torch::kFloat); at::Tensor weight_2 = torch::randn({4, 3}, torch::kFloat); at::Tensor input_2 = torch::randn({5, 3}, torch::kFloat); testStaticRuntime( quantized_linear_dynamic_fp16_script, {input, weight}, {input_2, weight_2}); } TEST(StaticRuntime, QuantizedLinearReluDynamicFp16) { const std::string quantized_linear_relu_dynamic_fp16_script = R"IR( graph(%input: Tensor, %weights: Tensor): %bias: None = prim::Constant() %packed_params = quantized::linear_prepack_fp16(%weights, %bias) %output = quantized::linear_relu_dynamic_fp16(%input, %packed_params) %ret = aten::clone(%output, %bias) return (%ret) )IR"; at::Tensor weight = torch::randn({3, 2}, torch::kFloat); at::Tensor input = torch::randn({3, 2}, torch::kFloat); at::Tensor weight_2 = torch::randn({4, 3}, torch::kFloat); at::Tensor input_2 = torch::randn({5, 3}, torch::kFloat); testStaticRuntime( quantized_linear_relu_dynamic_fp16_script, {input, weight}, {input_2, weight_2}); } TEST(StaticRuntime, VarStack) { const auto var_stack_script = R"JIT( def forward(self, inp1: Tensor, inp2: Tensor, dim: int): return torch.stack([inp1, inp2], dim).clone() )JIT"; // 2D tensors - stack dim = 0 std::vector args1 = {at::randn({6, 6}), at::randn({6, 6}), 0}; testStaticRuntime(var_stack_script, args1); // 3D tensors - stack dim = 1 std::vector args2 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), 1}; testStaticRuntime(var_stack_script, args2); // 3D tensors - stack dim = 2 std::vector args3 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), 2}; testStaticRuntime(var_stack_script, args3); // Negative dim std::vector args4 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), -1}; testStaticRuntime(var_stack_script, args4); // Non-serial path std::vector args5 = {at::randn({1, 2, 3}), at::randn({1, 2, 3}), 3}; testStaticRuntime(var_stack_script, args5); // Fast path std::vector args6 = {at::randn({1}), at::randn({1}), 0}; testStaticRuntime(var_stack_script, args6); testStaticRuntime(var_stack_script, args1, args2); } TEST(StaticRuntime, FmodTensor) { const auto fmod_tensor = R"JIT( def forward(self, a: Tensor, b: Tensor): return torch.fmod(a, b).clone() )JIT"; // fmod tensor version 
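  // torch.fmod keeps the sign of the dividend (like C++ std::fmod): for
  // example fmod(-5, 3) == -2, whereas Python's -5 % 3 == 1. The tensor
  // overload below applies this elementwise.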
auto a = at::randn({2, 3}); auto b = at::randn({2, 3}); std::vector args0{a, b}; testStaticRuntime(fmod_tensor, args0); // check for dynamic shapes auto c = at::randn({4, 3, 2}); auto d = at::randn({4, 3, 2}); std::vector args1{c, d}; testStaticRuntime(fmod_tensor, args0, args1); } TEST(StaticRuntime, FmodScalar) { const auto fmod_scalar = R"JIT( def forward(self, a: Tensor, b: int): return torch.fmod(a, b).clone() )JIT"; auto a = at::randn({2, 3}); // fmod scalar version std::vector args2{a, 3}; testStaticRuntime(fmod_scalar, args2); // check for dynamic shapes auto c = at::randn({4, 3, 2}); std::vector args3{c, 4}; testStaticRuntime(fmod_scalar, args2, args3); // test int32 version a = at::randint(-100, 100, {2, 3}, at::kInt); c = at::randint(-100, 100, {4, 3, 2}, at::kInt); testStaticRuntime(fmod_scalar, {a, 3}); testStaticRuntime(fmod_scalar, {a, 3}, {c, 4}); } TEST(StaticRuntime, QEmbeddingBagBytePrepack) { const std::string embedding_bag_byte_prepack_script = R"IR( graph(%input: Tensor): %none : None = prim::Constant() %output: Tensor = quantized::embedding_bag_byte_prepack(%input) %res: Tensor = aten::clone(%output, %none) return (%res) )IR"; auto a = torch::randn({8, 16}, at::ScalarType::Float); auto b = torch::randn({8 * 2, 16 * 2}, at::ScalarType::Float); testStaticRuntime(embedding_bag_byte_prepack_script, {a}); testStaticRuntime(embedding_bag_byte_prepack_script, {a}, {b}); } TEST(StaticRuntime, QEmbeddingBagByteUnpack) { const auto src = R"IR( graph(%input: Tensor): %none : None = prim::Constant() %weight: Tensor = quantized::embedding_bag_byte_prepack(%input) %output: Tensor = quantized::embedding_bag_byte_unpack(%weight) %res: Tensor = aten::clone(%output, %none) return (%res) )IR"; auto a = torch::randn({8, 16}, at::ScalarType::Float); auto b = torch::randn({8 * 2, 16 * 2}, at::ScalarType::Float); testStaticRuntime(src, {a}); testStaticRuntime(src, {a}, {b}); } TEST(StaticRuntime, LinalgNorm_ScalarOrd) { const auto linalg_norm_ord_scalar = R"JIT( def forward(self, a: Tensor, ord: int, dim: List[int], keepdim: bool, dtype: int): return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone() )JIT"; auto a = at::randn({2, 3}); auto dim = std::vector({1}); auto dtype = at::ScalarType::Float; std::vector args0{a, 4, dim, true, dtype}; testStaticRuntime(linalg_norm_ord_scalar, args0); auto b = at::randn({3, 2, 6}); std::vector args1{b, 4, dim, true, dtype}; testStaticRuntime(linalg_norm_ord_scalar, args0, args1); } TEST(StaticRuntime, LinalgNorm_StringOrd) { const auto linalg_norm_ord_str = R"JIT( def forward(self, a: Tensor, ord: str, dim: List[int], keepdim: bool, dtype: int): return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone() )JIT"; auto a = at::randn({2, 3}); auto dim = std::vector({0, 1}); auto dtype = at::ScalarType::Float; std::vector args0{a, "fro", dim, true, dtype}; testStaticRuntime(linalg_norm_ord_str, args0); auto b = at::randn({3, 2, 17}); std::vector args1{b, "fro", dim, true, dtype}; testStaticRuntime(linalg_norm_ord_str, args0, args1); } TEST(StaticRuntime, Index_Put) { const auto index_put_str = R"JIT( def forward(self, a: Tensor, indices: Tuple[Optional[Tensor]], values: Tensor, accumulate: bool): return torch.index_put(a, indices, values, accumulate).clone() )JIT"; auto a = at::randn({2}); auto indices_a = std::make_tuple(torch::tensor({0}, at::kLong)); auto values_a = at::randn({1}); std::vector args0{a, indices_a, values_a, false}; testStaticRuntime(index_put_str, args0); const auto index_put_non_optional_str = R"JIT( def forward(self, 
a: Tensor, indices: List[Tensor], values: Tensor, accumulate: bool): return torch.index_put(a, indices, values, accumulate).clone() )JIT"; auto indices_b = c10::List{torch::tensor({0}, at::kLong)}; std::vector args1{a, indices_b, values_a, false}; testStaticRuntime(index_put_non_optional_str, args1); const auto index_put_list_construct = R"JIT( def forward(self, a: Tensor, indices: Tensor, values: Tensor, accumulate: bool): indices: List[Optional[Tensor]] = [indices] return torch.index_put(a, indices, values, accumulate).clone() )JIT"; std::vector args2{a, torch::tensor({0}, at::kLong), values_a, false}; testStaticRuntime(index_put_list_construct, args2); } TEST(StaticRuntime, Item) { const auto item_str = R"JIT( def forward(self, a: Tensor): return torch.item(a) )JIT"; auto a = at::randn({1}); std::vector args0{a}; testStaticRuntime(item_str, args0); } TEST(StaticRuntime, Tensor_Split) { const auto tensor_split_str1 = R"JIT( def forward(self, a: Tensor, sections: int, dim: int): return torch.tensor_split(a, sections, dim) )JIT"; std::vector args1{at::randn({8}), 3, 0}; const auto tensor_split_str2 = R"JIT( def forward(self, a: Tensor, sections: Tensor, dim: int): return torch.tensor_split(a, sections, dim) )JIT"; std::vector args2{at::randn({8}), torch::tensor(3), 0}; const auto tensor_split_str3 = R"JIT( def forward(self, a: Tensor, indices: List[int], dim: int): return torch.tensor_split(a, indices, dim) )JIT"; std::vector args3{at::randn({8}), c10::List({1, 6}), 0}; testStaticRuntime(tensor_split_str1, args1); testStaticRuntime(tensor_split_str2, args2); testStaticRuntime(tensor_split_str3, args3); } TEST(StaticRuntime, JIT_Aten_Cpu) { const std::string script = R"IR( graph(%a: Tensor): %1 : int = prim::Constant[value=0]() %aa: Tensor = aten::add(%a, %a, %1) %ret: Tensor = aten::cpu(%aa) return (%ret) )IR"; auto graph = std::make_shared(); std::unordered_map vmap; vmap.reserve(0); parseIR(script, graph.get(), vmap); torch::jit::StaticModule smodule(graph); auto a = at::randn({2, 4}); std::vector args0{a}; testStaticRuntime(script, args0); } TEST(StaticRuntime, JIT_Aten_Numel) { const std::string script = R"IR( graph(%a: Tensor): %1 : int = prim::Constant[value=0]() %aa: Tensor = aten::add(%a, %a, %1) %ret: int = aten::numel(%aa) return (%ret) )IR"; auto graph = std::make_shared(); std::unordered_map vmap; vmap.reserve(0); parseIR(script, graph.get(), vmap); torch::jit::StaticModule smodule(graph); auto a = at::randn({2, 4}); std::vector args0{a}; testStaticRuntime(script, args0); } TEST(StaticRuntime, JIT_Aten_List) { const auto script_str = R"IR( graph(%a: str): %ret: str[] = aten::list(%a) return (%ret) )IR"; std::string a = "abcd"; std::vector args0{a}; testStaticRuntime(script_str, args0); // Update the result of aten::list to ensure that a deep copy // took place const auto script_list = R"IR( graph(%a : int[]): %idx : int = prim::Constant[value=0]() %value : int = prim::Constant[value=42]() %res : int[] = aten::list(%a) %updated : int[] = aten::_set_item(%res, %idx, %value) return (%res, %a) )IR"; std::vector args1{c10::List{1, 2, 3}}; testStaticRuntime(script_list, args1); } TEST(StaticRuntime, JIT_Aten_Range_Length) { const std::string script = R"IR( graph(%lo: int, %hi: int, %step: int): %1 : int = prim::Constant[value=0]() %ret: int = aten::__range_length(%lo, %hi, %step) return (%ret) )IR"; auto graph = std::make_shared(); std::unordered_map vmap; vmap.reserve(0); parseIR(script, graph.get(), vmap); torch::jit::StaticModule smodule(graph); std::vector args0{0, 10, 2}; 
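  // aten::__range_length(lo, hi, step) is the trip count of range(lo, hi, step);
  // for (0, 10, 2) that is the 5 values 0, 2, 4, 6, 8.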
testStaticRuntime(script, args0); } TEST(StaticRuntime, Cat) { const std::string cat_script = R"IR( graph(%a: Tensor, %b: Tensor, %dim: int): %ten_list: Tensor[] = prim::ListConstruct(%a, %b) %1 : int = prim::Constant[value=0]() %2 : int = prim::Constant[value=1]() %3 : int = prim::Constant[value=1]() %ten_list2 : Tensor[] = aten::slice(%ten_list, %1, %2, %3) %ret: Tensor = aten::cat(%ten_list2, %dim) return (%ret) )IR"; auto graph = std::make_shared(); std::unordered_map vmap; parseIR(cat_script, graph.get(), vmap); torch::jit::StaticModule smodule(graph); ASSERT_TRUE(getNodeWithKind(smodule, "aten::cat")); auto a = at::randn({2, 4}); auto b = at::randn({3, 4}); std::vector args0{a, b, 0}; testStaticRuntime(cat_script, args0); auto c = at::randn({3, 4}); auto d = at::randn({3, 5}); std::vector args1{c, d, 1}; testStaticRuntime(cat_script, args0, args1); std::vector args_dim_negative{c, d, -1}; testStaticRuntime(cat_script, args_dim_negative); } TEST(StaticRuntime, Cumsum) { const auto cumsum_script = R"JIT( def forward(self, a: Tensor, dim: int): return torch.cumsum(a, dim).clone() )JIT"; auto a = at::randn({2, 3}); std::vector args0{a, 0}; testStaticRuntime(cumsum_script, args0); auto b = at::randn({3, 6}); std::vector args1{b, 1}; testStaticRuntime(cumsum_script, args0, args1); } TEST(StaticRuntime, CumsumDtype) { const auto cumsum_script_dtype = R"JIT( def forward(self, a: Tensor, dim: int, dtype: int): return torch.cumsum(a, dim, dtype=dtype).clone() )JIT"; auto a = at::randn({1, 2}); auto dtype = at::ScalarType::Float; std::vector args0{a, 0, dtype}; testStaticRuntime(cumsum_script_dtype, args0); auto b = at::randn({3, 6}); std::vector args1{b, 1, dtype}; testStaticRuntime(cumsum_script_dtype, args0, args1); } TEST(StaticRuntime, Nonzero) { const auto nonzero_tensor = R"JIT( def forward(self, input: Tensor): a = torch.nonzero(input).clone() return (a) )JIT"; auto a = at::randint(0, 2, {2, 3}); testStaticRuntime(nonzero_tensor, {a}); auto b = at::randint(0, 2, {4, 3, 2}); testStaticRuntime(nonzero_tensor, {a}, {b}); } TEST(StaticRuntime, SignedLog1p) { const std::string signed_log1p_script = R"IR( graph(%input): %0 : Tensor = aten::sign(%input) %1 : Tensor = aten::abs(%input) %2 : Tensor = aten::log1p(%1) %3 : Tensor = aten::mul(%0, %2) %none : NoneType = prim::Constant() %res : Tensor = aten::clone(%3, %none) return (%res) )IR"; std::vector args1 = {at::randn({2, 2})}; testStaticRuntime(signed_log1p_script, args1, {}, true); std::vector args2 = {at::randn({3, 3, 3})}; testStaticRuntime(signed_log1p_script, args1, args2, true); } TEST(StaticRuntime, RemoveImmutableInputDictLookupsWithImmutableInputDict) { const auto getitem_immutable_input_dict_script = R"JIT( def forward(self, input: Dict[int, Tensor]): a = input[0] b = input[1] c = a + b return c.clone() )JIT"; script::Module module("module"); module.define(getitem_immutable_input_dict_script); torch::jit::StaticModule smodule(module); EXPECT_FALSE(hasNodeWithKind(smodule, "aten::__getitem__")); EXPECT_TRUE(hasNodeWithKind(smodule, "static_runtime::dict_unpack")); auto a = at::randn({2, 4}); auto b = at::randn({2, 4}); c10::Dict dict( c10::IntType::get(), c10::TensorType::get()); dict.insert(0, a); dict.insert(1, b); testStaticRuntime(getitem_immutable_input_dict_script, {dict}); c10::Dict dict0( c10::IntType::get(), c10::TensorType::get()); auto a0 = at::randn({3, 4}); auto b0 = at::randn({3, 4}); dict0.insert(0, a0); dict0.insert(1, b0); testStaticRuntime(getitem_immutable_input_dict_script, {dict0}); } TEST(StaticRuntime, 
RemoveImmutableInputDictLookupsWithMutableInputDict) { const auto getitem_mutable_input_dict_script = R"JIT( def forward(self, input: Dict[int, Tensor]): a = input[0] input[1] = a b = input[1] c = a + b return c.clone() )JIT"; script::Module module("module"); module.define(getitem_mutable_input_dict_script); torch::jit::StaticModule smodule(module); EXPECT_TRUE(hasNodeWithKind(smodule, "aten::__getitem__")); EXPECT_FALSE(hasNodeWithKind(smodule, "static_runtime::dict_unpack")); } TEST(StaticRuntime, VarTupleUnpack) { const auto var_tuple_unpack_script = R"JIT( def forward(self, input_0: Tuple[Tensor, Tensor], input_1: Tuple[int, int]): a, b = input_0 c, d = input_1 res = a * c + b * d return res.clone() )JIT"; script::Module module("module"); module.define(var_tuple_unpack_script); torch::jit::StaticModule smodule(module); EXPECT_FALSE(hasNodeWithKind(smodule, "prim::TupleUnpack")); EXPECT_TRUE(hasNodeWithKind(smodule, "static_runtime::VarTupleUnpack")); auto a = at::randn({2, 2}); auto b = at::randn({3, 3, 3}); std::vector args1{ c10::ivalue::Tuple::create(a, a), c10::ivalue::Tuple::create(1, 2)}; std::vector args2{ c10::ivalue::Tuple::create(b, b), c10::ivalue::Tuple::create(1, 2)}; testStaticRuntime(var_tuple_unpack_script, args1); testStaticRuntime(var_tuple_unpack_script, args1, args2); } TEST(StaticRuntime, VarTupleUnpack_NotApplied) { const auto var_tuple_unpack_not_applied_script = R"JIT( def forward(self, input_0: Tuple[Tensor, Tensor], input_1: Tuple[int, int]): a, b = input_0 x = a + b c, d = input_1 res = a * c + b * d + x return res.clone() )JIT"; script::Module module("module"); // In this script, the optimization is not applied since there is a // computation between the TupleUnpack nodes. module.define(var_tuple_unpack_not_applied_script); torch::jit::StaticModule smodule(module); EXPECT_FALSE(hasNodeWithKind(smodule, "static_runtime::VarTupleUnpack")); EXPECT_TRUE(hasNodeWithKind(smodule, "prim::TupleUnpack")); } TEST(StaticRuntime, RemainderTensor) { const auto remainder_tensor = R"JIT( def forward(self, x, y): return torch.remainder(x, y).clone() )JIT"; std::vector args1 = { at::randint(0, 10, {2, 2}), at::randint(1, 10, {2, 2})}; std::vector args2 = { at::randint(0, 10, {3, 6}), at::randint(1, 10, {3, 6})}; // Use allclose and equalnan since outputs may be NaN. testStaticRuntime( remainder_tensor, args1, /*args2*/ {}, /*use_alloclose*/ true, /*use_equalnan*/ true); testStaticRuntime( remainder_tensor, args1, args2, /*use_allclose*/ true, /*use_equalnan*/ true); } TEST(StaticRuntime, RemainderScalar) { const auto remainder_scalar = R"JIT( def forward(self, x, y: int): return torch.remainder(x, y).clone() )JIT"; std::vector args1 = {at::randint(0, 10, {2, 2}), 4}; std::vector args2 = {at::randint(0, 10, {3, 6}), 4}; // Use allclose and equalnan since outputs may be NaN. 
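  // Unlike fmod earlier in this file, torch.remainder follows Python modulo
  // semantics: the result takes the sign of the divisor, e.g.
  // remainder(-5, 4) == 3.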
  testStaticRuntime(
      remainder_scalar,
      args1,
      /*args2*/ {},
      /*use_allclose*/ true,
      /*use_equalnan*/ true);
  testStaticRuntime(
      remainder_scalar,
      args1,
      args2,
      /*use_allclose*/ true,
      /*use_equalnan*/ true);
}

TEST(StaticRuntime, Where) {
  const auto where_script = R"JIT(
    def forward(self, x, y):
        return torch.where(x > 0, x, y).clone()
  )JIT";

  std::vector<IValue> args1 = {at::randn({2, 2}), at::randn({2, 2})};
  std::vector<IValue> args2 = {at::randn({8, 10}), at::randn({8, 10})};

  testStaticRuntime(where_script, args1);
  testStaticRuntime(where_script, args1, args2);
}

TEST(StaticRuntime, WhereBroadcast) {
  const auto where_script = R"JIT(
    def forward(self, cond_1d, x, y):
        shape = [-1] + [1] * (x.dim() - 1)
        cond = cond_1d.view(shape)
        return torch.where(cond, x, y).clone()
  )JIT";

  std::vector<IValue> args1 = {
      at::tensor({0, 1}).to(at::kBool), at::randn({2, 2}), at::randn({2, 2})};
  std::vector<IValue> args2 = {
      at::tensor({1, 0, 0}).to(at::kBool), at::randn({3, 6}), at::randn({3, 6})};

  testStaticRuntime(where_script, args1);
  testStaticRuntime(where_script, args1, args2);
}

TEST(StaticRuntime, View) {
  // Note that clone is not technically necessary here since this is not
  // an out variant, but it suppresses warnings about only having one op
  // in testStaticRuntime
  const auto src = R"IR(
    graph(%input : Tensor, %shape : int[]):
        %none : NoneType = prim::Constant()
        %view : Tensor = aten::view(%input, %shape)
        %res : Tensor = aten::clone(%view, %none)
        return (%res)
  )IR";

  std::vector<IValue> args1{at::randn({2, 2}), c10::List<int64_t>(4)};
  std::vector<IValue> args2{at::randn({2, 2, 2}), c10::List<int64_t>({4, 2})};

  testStaticRuntime(src, args1);
  testStaticRuntime(src, args1, args2);
}

TEST(StaticRuntime, Size) {
  const auto src_with_dim = R"JIT(
    def forward(self, x, dim: int):
        return x.size(dim)
  )JIT";

  const auto src_no_dim = R"JIT(
    def forward(self, x):
        return x.size()
  )JIT";

  std::vector<IValue> args1{at::randn({1}), 0};
  std::vector<IValue> args2{at::randn({1}), -1};
  std::vector<IValue> args3{at::randn({2, 4}), 1};
  std::vector<IValue> args_no_dim{at::randn({2, 4})};

  testStaticRuntime(src_with_dim, args1);
  testStaticRuntime(src_with_dim, args2);
  testStaticRuntime(src_with_dim, args1, args3);
  testStaticRuntime(src_no_dim, args_no_dim);
}

TEST(StaticRuntime, Squeeze) {
  // Note: this is a native op, not an out variant, but clone anyway
  // to silence warnings in testStaticRuntime
  const auto src = R"JIT(
    def forward(self, inp, dim: int):
        return inp.squeeze(dim).clone()
  )JIT";

  const auto a = at::randn({2, 2});
  const auto b = at::randn({3, 2, 3});

  testStaticRuntime(src, {a, 0});
  testStaticRuntime(src, {a, 1});
  testStaticRuntime(src, {a, -1}, {b, 2});
}

TEST(StaticRuntime, NumToTensorScalar) {
  const auto num_to_tensor_ir = R"IR(
    graph(%1 : int):
        %2 : NoneType = prim::Constant()
        %3 : Tensor = prim::NumToTensor(%1)
        %4 : Tensor = aten::clone(%3, %2)
        return (%4)
  )IR";

  IValue arg{5};
  std::vector<IValue> args = {arg};
  testStaticRuntime(num_to_tensor_ir, args);
}

TEST(StaticRuntime, NumToTensorFalse) {
  const auto num_to_tensor_ir = R"IR(
    graph(%1 : bool):
        %2 : NoneType = prim::Constant()
        %3 : Tensor = prim::NumToTensor(%1)
        %4 : Tensor = aten::clone(%3, %2)
        return (%4)
  )IR";

  IValue arg{false};
  std::vector<IValue> args = {arg};
  testStaticRuntime(num_to_tensor_ir, args);
}

TEST(StaticRuntime, NumToTensorTrue) {
  const auto num_to_tensor_ir = R"IR(
    graph(%1 : bool):
        %2 : NoneType = prim::Constant()
        %3 : Tensor = prim::NumToTensor(%1)
        %4 : Tensor = aten::clone(%3, %2)
        return (%4)
  )IR";

  IValue arg{true};
  std::vector<IValue> args = {arg};
  testStaticRuntime(num_to_tensor_ir, args);
}

TEST(StaticRuntime, Split) {
  const auto src = R"JIT(
    def forward(self,
inp, split_size: int, dim: int): return inp.split(split_size, dim) )JIT"; const auto a = at::randn({2, 2}); const auto b = at::randn({2, 2, 2}); testStaticRuntime(src, {a, 1, 0}); testStaticRuntime(src, {a, 1, 1}); testStaticRuntime(src, {a, 2, -1}, {b, 2, 2}); } TEST(StaticRuntime, SplitWithSizes) { const auto src = R"JIT( def forward(self, inp, split_sizes: List[int], dim: int): return inp.split(split_sizes, dim) )JIT"; const auto a = at::randn({2, 2}); const auto b = at::randn({2, 2, 2}); const auto split_sizes = c10::List{1, 1}; testStaticRuntime(src, {a, split_sizes, 0}); testStaticRuntime(src, {a, split_sizes, 1}); testStaticRuntime(src, {a, split_sizes, -1}, {b, split_sizes, 2}); } namespace { void maybe_throw(bool should_throw) { if (should_throw) { throw std::runtime_error("test exception"); } } TORCH_LIBRARY(static_runtime_tests, m) { // Conservative so this op doesn't get deleted by dead // code elimination m.def(torch::schema( "static_runtime_tests::maybe_throw(bool throw) -> ()", at::AliasAnalysisKind::CONSERVATIVE)); m.impl("maybe_throw", maybe_throw); } } // namespace TEST(StaticRuntime, ModelCrashOnFirstRun) { const auto src = R"JIT( graph(%0: Tensor, %throw: bool): %1: Tensor = aten::mul(%0, %0) static_runtime_tests::maybe_throw(%throw) %2: Tensor = aten::mul(%1, %1) %3: Tensor = aten::mul(%2, %2) return (%3) )JIT"; auto graph = getGraphFromIR(src); auto static_module = StaticModule(graph); auto& runtime = static_module.runtime(); std::vector args_crash{at::randn({1}), true}; std::vector args_no_crash{at::randn({1}), false}; EXPECT_THROW(runtime(args_crash, {}), std::runtime_error); // The run didn't finish, we didn't allocate the memory planner EXPECT_EQ(runtime.get_memory_planner(), nullptr); runtime.check_for_memory_leak(); // We guarantee that the runtime is still usable after the crash. // Run again to verify this. compareResultsWithJIT(runtime, graph, args_no_crash); EXPECT_NE(runtime.get_memory_planner(), nullptr); } TEST(StaticRuntime, ModelCrashOnSecondRun) { const auto src = R"JIT( graph(%0: Tensor, %throw: bool): %1: Tensor = aten::mul(%0, %0) static_runtime_tests::maybe_throw(%throw) %2: Tensor = aten::mul(%1, %1) %3: Tensor = aten::mul(%2, %2) return (%3) )JIT"; auto graph = getGraphFromIR(src); auto static_module = StaticModule(graph); auto& runtime = static_module.runtime(); std::vector args_crash{at::randn({1}), true}; std::vector args_no_crash{at::randn({1}), false}; runtime(args_no_crash, {}); EXPECT_NE(runtime.get_memory_planner(), nullptr); runtime.check_for_memory_leak(); EXPECT_THROW(runtime(args_crash, {}), std::runtime_error); runtime.check_for_memory_leak(); // We guarantee that the runtime is still usable after the crash. // Run again to verify this. 
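  // compareResultsWithJIT re-runs the same graph and arguments through the
  // reference JIT interpreter and checks that the already-initialized static
  // runtime still matches it after the failed run.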
compareResultsWithJIT(runtime, graph, args_no_crash); } TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrows) { const auto src = R"JIT( graph(%0: Tensor): %1: Tensor = aten::mul(%0, %0) %2: Tensor = aten::mul(%1, %1) %3: bool = prim::Constant[value=1]() %4: Tensor = static_runtime::select_tensor(%1, %2, %3) static_runtime_tests::maybe_throw(%3) return (%4) )JIT"; auto graph = getGraphFromIR(src); auto static_module = StaticModule(graph); auto& runtime = static_module.runtime(); std::vector args{at::randn({1})}; EXPECT_THROW(runtime(args), std::runtime_error); } TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrowedInputs) { const auto src = R"JIT( graph(%0: Tensor, %1: Tensor): %2: bool = prim::Constant[value=1]() %3: Tensor = static_runtime::select_tensor(%0, %1, %2) static_runtime_tests::maybe_throw(%2) return (%3) )JIT"; auto graph = getGraphFromIR(src); auto static_module = StaticModule(graph); auto& runtime = static_module.runtime(); std::vector args{at::randn({1}), at::randn({1})}; EXPECT_THROW(runtime(std::move(args)), std::runtime_error); } TEST(StaticRuntime, ReplaceWithMaybeCopy) { const std::string to = R"IR( graph(%0 : Tensor): %1: int = prim::Constant[value=4]() %2: bool = prim::Constant[value=0]() %3: None = prim::Constant() %res : Tensor = aten::to(%0, %1, %2, %2, %3) return (%res) )IR"; at::Tensor a = at::tensor({1.1, 2.2, 3.3, 4.0}, at::ScalarType::Float); std::vector args{a}; auto g = std::make_shared(); torch::jit::parseIR(to, g.get()); // Jit Interpreter. Stack stack(args); torch::jit::GraphExecutor graph_exec(g, ""); graph_exec.run(stack); ASSERT_EQ(stack.size(), 1); auto expected = stack[0].toTensor(); // Static Runtime. torch::jit::StaticModule smodule(g); auto actual = smodule(args, {}).toTensor(); smodule.runtime().check_for_memory_leak(); EXPECT_TRUE(expected.equal(actual)); // Make a fresh graph to ensure the pass works in isolation auto new_graph = std::make_shared(); torch::jit::parseIR(to, new_graph.get()); ReplaceWithMaybeCopy(new_graph); EXPECT_FALSE(hasNodeWithKind(new_graph, "aten::to")); EXPECT_TRUE( hasNodeWithKind(new_graph, "static_runtime::to_maybe_copy_out")); } TEST(StaticRuntime, Int) { const auto src = R"JIT( def forward(self, x): return int(x) + int(x) )JIT"; std::vector args{at::tensor({3.14})}; testStaticRuntime(src, args); } TEST(StaticRuntime, ReturnConstant) { const auto src = R"JIT( def forward(self): return 1 )JIT"; testStaticRuntime(src, {}); } TEST(StaticRuntime, SimpleIf) { const auto src = R"JIT( def forward(self, cond: bool, x): if cond: return torch.mul(x, 42).clone() else: return x.clone() )JIT"; std::vector args_false{false, at::randn({1})}; std::vector args_true{true, at::randn({1})}; std::vector args_big_tensor{true, at::randn({3, 3, 3})}; testStaticRuntime(src, args_false); testStaticRuntime(src, args_true); testStaticRuntime(src, args_true, args_big_tensor); } TEST(StaticRuntime, NestedIf) { const auto src = R"JIT( def forward(self, cond1: bool, cond2: bool, x): y = x * 42 if cond1: y = y * y if cond2: y += x else: if cond2: return x.clone() return y.clone() )JIT"; for (auto cond1 : {true, false}) { for (auto cond2 : {true, false}) { std::vector args1{cond1, cond2, at::randn({1})}; std::vector args2{cond1, cond2, at::randn({3, 3, 3})}; testStaticRuntime(src, args1, args2); } } } TEST(StaticRuntime, DeeplyNestedIf) { const auto src = R"JIT( def forward(self, cond1: bool, cond2: bool, cond3: bool, x): y = x * 42 if cond1: y = y * y if cond2: y += x if cond2 and cond3: y += 1 if cond2: if cond3: y += 2 else: y = y * y y += 4 else: if 
cond2: return x.clone() if cond3 or cond2: y += 42 return y.clone() )JIT"; for (auto cond1 : {true, false}) { for (auto cond2 : {true, false}) { for (auto cond3 : {true, false}) { std::vector args1{cond1, cond2, cond3, at::randn({1})}; std::vector args2{cond1, cond2, cond3, at::randn({3, 3, 3})}; testStaticRuntime(src, args1, args2); } } } } TEST(StaticRuntime, BasicForLoop) { const auto src = R"JIT( def forward(self, x, loop_max: int): y = x.clone() for i in range(loop_max): y += 1 return y )JIT"; std::vector args1{at::randn({1}), 10}; std::vector args2{at::randn({3, 3, 3}), 10}; testStaticRuntime(src, args1, args2); } TEST(StaticRuntime, BasicWhileLoop) { const auto src = R"JIT( def forward(self, x, loop_max: int): y = x.clone() loop_count = 0 while loop_count < loop_max: y += 1 loop_count += 1 return y )JIT"; std::vector args1{at::randn({1}), 10}; std::vector args2{at::randn({3, 3, 3}), 10}; testStaticRuntime(src, args1, args2); } TEST(StaticRuntime, NestedLoops) { const auto src = R"JIT( def forward(self, x, loop_max: int): y = x.clone() even: List[int] = [] odd: List[int] = [] for i in range(loop_max): if i % 2: odd.append(i) else: even.append(i) for j in range(i): y += 1 return y, even, odd )JIT"; std::vector args1{at::randn({1}), 10}; std::vector args2{at::randn({3, 3, 3}), 10}; testStaticRuntime(src, args1, args2); } TEST(StaticRuntime, TupleIndex) { const auto src = R"JIT( def forward(self, idx: int, tup: Tuple[int, int]): a = tup[idx] return a * a )JIT"; const auto tuple = c10::ivalue::Tuple::create({1, 2}); testStaticRuntime(src, {1, tuple}, {-1, tuple}); torch::jit::Module mod("module"); mod.define(src); StaticModule smod(mod); EXPECT_THROW(smod({100, tuple}), std::out_of_range); } TEST(StaticRuntime, RaiseException) { const auto src = R"IR( graph(%str: str): %none: NoneType = prim::Constant() prim::RaiseException(%str, %none) return (%none) )IR"; auto graph = getGraphFromIR(src); StaticModule smod(graph); const auto msg = "exception message"; EXPECT_THROW( { try { smod({msg}); } catch (const std::runtime_error& e) { EXPECT_STREQ(msg, e.what()); throw; } }, std::runtime_error); } TEST(StaticRuntime, Uninitialized) { const auto src = R"IR( graph(): %0: int = prim::Uninitialized() return (%0) )IR"; auto graph = getGraphFromIR(src); StaticModule smod(graph); const auto ret = smod({}); // If a and b are both uninitialized, then a != b. 
So just check that the type // is Any EXPECT_EQ(ret.type()->kind(), c10::TypeKind::AnyType); } TEST(StaticRuntime, Format) { const auto src = R"JIT( def forward(self, arg1: int, arg2: Tensor, arg3: str): a = "arg1: {}, arg2: {}, arg3: {}".format(arg1, arg2, arg3) return a[::] )JIT"; testStaticRuntime(src, {1, at::randn({3}), "str"}); } TEST(StaticRuntime, Device) { const auto src = R"JIT( def forward(self, x): return x.device, x.device )JIT"; testStaticRuntime(src, {at::tensor({1})}); } TEST(StaticRuntime, Dtype) { const auto src = R"JIT( def forward(self, x, y): return x.dtype, y.dtype )JIT"; testStaticRuntime( src, {at::tensor({1}, at::kLong), at::tensor({1}, at::kFloat)}); } TEST(StaticRuntime, Dim) { const auto src = R"JIT( def forward(self, x, y): return x.dim(), y.dim() )JIT"; testStaticRuntime(src, {at::randn({2, 2}), at::randn({1})}); } TEST(StaticRuntime, Not) { const auto src = R"JIT( def forward(self, x: bool, y: bool): return not x, not y )JIT"; testStaticRuntime(src, {true, false}); } TEST(StaticRuntime, Bool) { const auto src = R"JIT( def forward(self, x: Tensor, y: int, z: float): return bool(x), bool(y), bool(z) )JIT"; testStaticRuntime(src, {at::randn({1}), 0, 1.151}, {at::zeros({1}), 1, 0.0}); } TEST(StaticRuntime, IsCuda) { const auto src = R"JIT( def forward(self, x: Tensor, y: Tensor): return x.is_cuda, y.is_cuda )JIT"; testStaticRuntime(src, {at::randn({1}), at::randn({1})}); } TEST(StaticRuntime, ToList) { const auto src = R"JIT( graph(%x: Tensor): %type: int = prim::Constant[value=1]() %dim: int = aten::dim(%x) %ret: float[] = prim::tolist(%x, %dim, %type) return (%ret) )JIT"; testStaticRuntime(src, {at::randn({2, 2})}); } TEST(StaticRuntime, IfThenElse) { const auto src = R"IR( graph(%cond: bool, %a: Tensor, %b: Tensor): %none: NoneType = prim::Constant() %c: Tensor = prim::IfThenElse(%cond, %a, %b) %d: Tensor = aten::clone(%c, %none) return (%d) )IR"; std::vector args1{true, at::randn({1}), at::randn({1})}; std::vector args2{false, at::randn({1}), at::randn({1})}; testStaticRuntime(src, args1); testStaticRuntime(src, args2); } TEST(StaticRuntime, EmptyIfBlock) { const auto src = R"JIT( def forward(self, cond: bool, a: Tensor, b: Tensor): l = [] if cond: l.append((a + b).clone()) return l )JIT"; testStaticRuntime(src, {true, at::rand(1), at::rand({1, 2})}); testStaticRuntime(src, {false, at::rand(1), at::rand({1, 2})}); } TEST(StaticRuntime, EmptyNestedIfBlock) { const auto src = R"JIT( def forward(self, cond: bool, a: Tensor, b: Tensor): l = [] if cond: if cond: l.append((a + b).clone()) return l )JIT"; testStaticRuntime(src, {true, at::rand(1), at::rand({1, 2})}); testStaticRuntime(src, {false, at::rand(1), at::rand({1, 2})}); } TEST(StaticRuntime, StackEmpty) { const auto src = R"JIT( def forward(self): x = torch.stack([]) return x )JIT"; torch::jit::Module mod("mod"); mod.define(src); torch::jit::StaticModule smod(mod); EXPECT_THROW(smod({}), c10::Error); } TEST(StaticRuntime, ConcatEmpty) { const auto src = R"JIT( def forward(self): x = torch.concat([]) return x )JIT"; torch::jit::Module mod("mod"); mod.define(src); torch::jit::StaticModule smod(mod); EXPECT_THROW(smod({}), c10::Error); } TEST(StaticRuntime, IntImplicit) { const auto src = R"IR( graph(%a: Tensor): %y: int = aten::IntImplicit(%a) return (%y) )IR"; testStaticRuntime(src, {at::tensor({1}, at::kInt).squeeze()}); } TEST(StaticRuntime, IntImplicit_ThrowOnBadInputs) { const auto src = R"IR( graph(%a: Tensor): %y: int = aten::IntImplicit(%a) return (%y) )IR"; auto graph = getGraphFromIR(src); 
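  // aten::IntImplicit only accepts a 0-dimensional integer tensor, so the two
  // EXPECT_THROW checks below cover a 1-D tensor and a float tensor.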
torch::jit::StaticModule smod(graph); // Not 0D tensor EXPECT_THROW(smod({at::tensor({1, 2}, at::kInt)}), std::runtime_error); // Wrong dtype EXPECT_THROW( smod({at::tensor({1}, at::kFloat).squeeze()}), std::runtime_error); } TEST(StaticRuntime, Select) { const auto src = R"IR( graph(%a: Tensor, %dim: int, %index: int): %none: NoneType = prim::Constant() %b: Tensor = aten::select(%a, %dim, %index) %c: Tensor = aten::clone(%b, %none) return (%c) )IR"; testStaticRuntime(src, {at::randn({2, 2}), 0, 1}); } TEST(StaticRuntime, ReshapeAs) { const auto src = R"JIT( def forward(self, a, b): return a.reshape_as(b).clone() )JIT"; testStaticRuntime(src, {at::randn({2, 2}), at::randn({4})}); } TEST(StaticRuntime, MoveCtor) { auto mod = getDeepAndWideSciptModel(); std::vector args{ at::randn({1, 1, 32}), at::randn({1, 1, 32}), at::randn({1, 50})}; torch::jit::StaticModule smod(mod); torch::jit::StaticRuntime runtime(smod); auto expected = runtime(args); torch::jit::StaticRuntime new_runtime(std::move(runtime)); auto actual = new_runtime(args); compareResults(expected, actual); } TEST(StaticRuntime, SingleBlockIfReturnList) { const auto src = R"JIT( def forward(self, a, b, cond: bool): lst = [] if cond: lst.append(a + b) return lst )JIT"; std::vector args1{at::randn({1}), at::randn({1}), true}; std::vector args2{at::randn({42, 42}), at::randn({42, 42}), false}; testStaticRuntime(src, args1, args2); } TEST(StaticRuntime, NestedBlockIfReturnList) { const auto src = R"JIT( def forward(self, a, b, cond1: bool, cond2: bool): if cond1: lst = [] if cond2: lst.append(a + b) lst.append(a * b) return lst return [] )JIT"; std::vector args1{at::randn({1}), at::randn({1}), true, true}; std::vector args2{ at::randn({42, 42}), at::randn({42, 42}), true, false}; testStaticRuntime(src, args1, args2); } TEST(StaticRuntime, ClampNaNToNum) { const auto src1 = R"JIT( def forward(self, a): return torch.clamp(a, min=1.0, max=2.0).nan_to_num().clone() )JIT"; const auto src2 = R"JIT( def forward(self, a, nan: float): return torch.clamp(a, min=-1.0, max=2.0).nan_to_num(nan=nan).clone() )JIT"; const auto src3 = R"JIT( def forward(self, a): return torch.clamp(a, min=1.0, max=-1.0).nan_to_num().clone() )JIT"; auto a = at::tensor({ std::numeric_limits::quiet_NaN(), std::numeric_limits::infinity(), -std::numeric_limits::infinity(), 0.0f, 3.0f }); auto b = a.repeat({10, 5}); // Have to use_allclose even though all NaNs will be replaced - testStaticRuntime // also checks inputs at the end to make sure they're not changed testStaticRuntime(src1, {a}, {}, /*use_allclose=*/true, /*use_equalnan=*/true); testStaticRuntime(src1, {a}, {b}, /*use_allclose=*/true, /*use_equalnan=*/true); testStaticRuntime(src2, {a, 42.0}, {}, /*use_allclose=*/true, /*use_equalnan=*/true); testStaticRuntime(src2, {a, 2.0}, {b, 1.0}, /*use_allclose=*/true, /*use_equalnan=*/true); testStaticRuntime(src3, {a}, {}, /*use_allclose=*/true, /*use_equalnan=*/true); testStaticRuntime(src3, {a}, {b}, /*use_allclose=*/true, /*use_equalnan=*/true); // Non-NNC path testStaticRuntime(src1, {a.to(at::kDouble)}, {}, /*use_allclose=*/true, /*use_equalnan=*/true); testStaticRuntime(src1, {a.to(at::kDouble)}, {b.to(at::kDouble)}, /*use_allclose=*/true, /*use_equalnan=*/true); } TEST(StaticRuntime, IfReturningTuple) { const auto src = R"JIT( def forward(self, x, y, cond: bool, idx: int): if cond: tup = (x, y) else: tup = (x, x) return tup[idx] )JIT"; std::vector args{at::randn({3}), at::randn({3}), true, 0}; testStaticRuntime(src, args); }
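
// A minimal sketch of the recipe the tests in this file follow, using only
// the harness above (testStaticRuntime) and a stock ATen op. The test name,
// the choice of torch.add, and the tensor shapes are illustrative assumptions
// rather than existing coverage: the result is cloned as in the tests above,
// and a second, larger argument set drives a run with different input shapes.
TEST(StaticRuntime, AddCloneSketch) {
  const auto src = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return torch.add(a, b).clone()
  )JIT";

  std::vector<IValue> args{at::randn({2, 2}), at::randn({2, 2})};
  std::vector<IValue> args2{at::randn({6, 8}), at::randn({6, 8})};

  testStaticRuntime(src, args);
  testStaticRuntime(src, args, args2);
}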