mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65345 FooType::get() can return a const reference. Inconveniently, converting shared_ptr<FooType> to shared_ptr<Type> requires a copy & refcount bump, so to properly take advantage of this in unshapedType() we need to take a const Type& in isSubtypeOf(), which is good practice anyway -- don't require a shared_ptr if you don't need to take ownership. ghstack-source-id: 140044165 Test Plan: CI perf says c10::unshapedType time decreased from 2.8% to 2.2% during static runtime startup, though I expect this to be generally beneficial. Reviewed By: hlu1 Differential Revision: D31027361 fbshipit-source-id: 676feb81db9f74ad7b8651d8774f4ecb4cfa6ab8
261 lines
8.4 KiB
C++
261 lines
8.4 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <test/cpp/jit/test_utils.h>
|
|
|
|
#include <torch/csrc/jit/ir/alias_analysis.h>
|
|
#include <torch/csrc/jit/ir/irparser.h>
|
|
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
|
#include <torch/csrc/jit/runtime/custom_operator.h>
|
|
#include <torch/csrc/jit/runtime/register_ops_utils.h>
|
|
#include <torch/jit.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
TEST(CustomOperatorTest, InferredSchema) {
|
|
torch::RegisterOperators reg(
|
|
"foo::bar", [](double a, at::Tensor b) { return a + b; });
|
|
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar"));
|
|
ASSERT_EQ(ops.size(), 1);
|
|
|
|
auto& op = ops.front();
|
|
ASSERT_EQ(op->schema().name(), "foo::bar");
|
|
|
|
ASSERT_EQ(op->schema().arguments().size(), 2);
|
|
ASSERT_EQ(op->schema().arguments()[0].name(), "_0");
|
|
ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
|
|
ASSERT_EQ(op->schema().arguments()[1].name(), "_1");
|
|
ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);
|
|
|
|
ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);
|
|
|
|
Stack stack;
|
|
push(stack, 2.0f, at::ones(5));
|
|
op->getOperation()(stack);
|
|
at::Tensor output;
|
|
pop(stack, output);
|
|
|
|
ASSERT_TRUE(output.allclose(at::full(5, 3.0f)));
|
|
}
|
|
|
|
TEST(CustomOperatorTest, ExplicitSchema) {
|
|
torch::RegisterOperators reg(
|
|
"foo::bar_with_schema(float a, Tensor b) -> Tensor",
|
|
[](double a, at::Tensor b) { return a + b; });
|
|
|
|
auto& ops =
|
|
getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema"));
|
|
ASSERT_EQ(ops.size(), 1);
|
|
|
|
auto& op = ops.front();
|
|
ASSERT_EQ(op->schema().name(), "foo::bar_with_schema");
|
|
|
|
ASSERT_EQ(op->schema().arguments().size(), 2);
|
|
ASSERT_EQ(op->schema().arguments()[0].name(), "a");
|
|
ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
|
|
ASSERT_EQ(op->schema().arguments()[1].name(), "b");
|
|
ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);
|
|
|
|
ASSERT_EQ(op->schema().returns().size(), 1);
|
|
ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);
|
|
|
|
Stack stack;
|
|
push(stack, 2.0f, at::ones(5));
|
|
op->getOperation()(stack);
|
|
at::Tensor output;
|
|
pop(stack, output);
|
|
|
|
ASSERT_TRUE(output.allclose(at::full(5, 3.0f)));
|
|
}
|
|
|
|
TEST(CustomOperatorTest, ListParameters) {
|
|
// Check that lists work well.
|
|
torch::RegisterOperators reg(
|
|
"foo::lists(int[] ints, float[] floats, complex[] complexdoubles, Tensor[] tensors) -> float[]",
|
|
[](torch::List<int64_t> ints,
|
|
torch::List<double> floats,
|
|
torch::List<c10::complex<double>> complexdoubles,
|
|
torch::List<at::Tensor> tensors) { return floats; });
|
|
|
|
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists"));
|
|
ASSERT_EQ(ops.size(), 1);
|
|
|
|
auto& op = ops.front();
|
|
ASSERT_EQ(op->schema().name(), "foo::lists");
|
|
|
|
ASSERT_EQ(op->schema().arguments().size(), 4);
|
|
ASSERT_EQ(op->schema().arguments()[0].name(), "ints");
|
|
ASSERT_TRUE(
|
|
op->schema().arguments()[0].type()->isSubtypeOf(*ListType::ofInts()));
|
|
ASSERT_EQ(op->schema().arguments()[1].name(), "floats");
|
|
ASSERT_TRUE(
|
|
op->schema().arguments()[1].type()->isSubtypeOf(*ListType::ofFloats()));
|
|
ASSERT_EQ(op->schema().arguments()[2].name(), "complexdoubles");
|
|
ASSERT_TRUE(op->schema().arguments()[2].type()->isSubtypeOf(
|
|
*ListType::ofComplexDoubles()));
|
|
ASSERT_EQ(op->schema().arguments()[3].name(), "tensors");
|
|
ASSERT_TRUE(
|
|
op->schema().arguments()[3].type()->isSubtypeOf(*ListType::ofTensors()));
|
|
|
|
ASSERT_EQ(op->schema().returns().size(), 1);
|
|
ASSERT_TRUE(
|
|
op->schema().returns()[0].type()->isSubtypeOf(*ListType::ofFloats()));
|
|
|
|
Stack stack;
|
|
push(stack, c10::List<int64_t>({1, 2}));
|
|
push(stack, c10::List<double>({1.0, 2.0}));
|
|
push(
|
|
stack,
|
|
c10::List<c10::complex<double>>(
|
|
{c10::complex<double>(2.4, -5.5), c10::complex<double>(-1.3, 2)}));
|
|
push(stack, c10::List<at::Tensor>({at::ones(5)}));
|
|
op->getOperation()(stack);
|
|
c10::List<double> output;
|
|
pop(stack, output);
|
|
|
|
ASSERT_EQ(output.size(), 2);
|
|
ASSERT_EQ(output.get(0), 1.0);
|
|
ASSERT_EQ(output.get(1), 2.0);
|
|
}
|
|
|
|
TEST(CustomOperatorTest, ListParameters2) {
|
|
torch::RegisterOperators reg(
|
|
"foo::lists2(Tensor[] tensors) -> Tensor[]",
|
|
[](torch::List<at::Tensor> tensors) { return tensors; });
|
|
|
|
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
|
|
ASSERT_EQ(ops.size(), 1);
|
|
|
|
auto& op = ops.front();
|
|
ASSERT_EQ(op->schema().name(), "foo::lists2");
|
|
|
|
ASSERT_EQ(op->schema().arguments().size(), 1);
|
|
ASSERT_EQ(op->schema().arguments()[0].name(), "tensors");
|
|
ASSERT_TRUE(
|
|
op->schema().arguments()[0].type()->isSubtypeOf(*ListType::ofTensors()));
|
|
|
|
ASSERT_EQ(op->schema().returns().size(), 1);
|
|
ASSERT_TRUE(
|
|
op->schema().returns()[0].type()->isSubtypeOf(*ListType::ofTensors()));
|
|
|
|
Stack stack;
|
|
push(stack, c10::List<at::Tensor>({at::ones(5)}));
|
|
op->getOperation()(stack);
|
|
c10::List<at::Tensor> output;
|
|
pop(stack, output);
|
|
|
|
ASSERT_EQ(output.size(), 1);
|
|
ASSERT_TRUE(output.get(0).allclose(at::ones(5)));
|
|
}
|
|
|
|
TEST(CustomOperatorTest, Aliasing) {
|
|
torch::RegisterOperators reg(
|
|
"foo::aliasing", [](at::Tensor a, at::Tensor b) -> at::Tensor {
|
|
a.add_(b);
|
|
return a;
|
|
});
|
|
getAllOperatorsFor(Symbol::fromQualString("foo::aliasing"));
|
|
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
parseIR(
|
|
R"IR(
|
|
graph(%x: Tensor, %y: Tensor):
|
|
%ret : Tensor = foo::aliasing(%x, %y)
|
|
return (%ret)
|
|
)IR",
|
|
graph.get());
|
|
|
|
auto opNode = *graph->block()->nodes().begin();
|
|
|
|
AliasDb aliasDb(graph);
|
|
for (const auto input : opNode->inputs()) {
|
|
// The custom op writes to all its inputs
|
|
ASSERT_TRUE(aliasDb.writesToAlias(opNode, {input}));
|
|
// The output should be a wildcard and thus alias all inputs
|
|
ASSERT_TRUE(aliasDb.mayAlias(opNode->output(), input));
|
|
}
|
|
}
|
|
{
|
|
// DCE should not remove a custom op
|
|
auto graph = std::make_shared<Graph>();
|
|
const auto text = R"IR(
|
|
graph(%x: Tensor, %y: Tensor):
|
|
# CHECK: foo::aliasing
|
|
%ret : Tensor = foo::aliasing(%x, %y)
|
|
return (%x)
|
|
)IR";
|
|
parseIR(text, graph.get());
|
|
EliminateDeadCode(graph);
|
|
|
|
testing::FileCheck().run(text, *graph);
|
|
}
|
|
}
|
|
|
|
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
|
|
static constexpr char op_list[] = "foofoo::bar.template;foo::another";
|
|
#define TORCH_SELECTIVE_NAME_IN_SCHEMA(l, n) \
|
|
torch::detail::SelectiveStr<c10::impl::op_allowlist_contains_name_in_schema( \
|
|
l, n)>(n)
|
|
|
|
TEST(TestCustomOperator, OperatorGeneratorUndeclared) {
|
|
// Try to register an op name that does not exist in op_list.
|
|
// Expected: the op name is not registered.
|
|
torch::jit::RegisterOperators reg({OperatorGenerator(
|
|
TORCH_SELECTIVE_NAME_IN_SCHEMA(
|
|
op_list, "foofoo::not_exist(float a, Tensor b) -> Tensor"),
|
|
[](Stack& stack) {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
|
double a;
|
|
at::Tensor b;
|
|
pop(stack, a, b);
|
|
push(stack, a + b);
|
|
},
|
|
aliasAnalysisFromSchema())});
|
|
|
|
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist"));
|
|
ASSERT_EQ(ops.size(), 0);
|
|
}
|
|
|
|
TEST(TestCustomOperator, OperatorGeneratorBasic) {
|
|
// The operator should be successfully registered since its name is in the
|
|
// whitelist.
|
|
torch::jit::RegisterOperators reg({OperatorGenerator(
|
|
TORCH_SELECTIVE_NAME_IN_SCHEMA(
|
|
op_list, "foofoo::bar.template(float a, Tensor b) -> Tensor"),
|
|
[](Stack& stack) {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
|
double a;
|
|
at::Tensor b;
|
|
pop(stack, a, b);
|
|
push(stack, a + b);
|
|
},
|
|
aliasAnalysisFromSchema())});
|
|
|
|
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar"));
|
|
ASSERT_EQ(ops.size(), 1);
|
|
|
|
auto& op = ops.front();
|
|
ASSERT_EQ(op->schema().name(), "foofoo::bar");
|
|
|
|
ASSERT_EQ(op->schema().arguments().size(), 2);
|
|
ASSERT_EQ(op->schema().arguments()[0].name(), "a");
|
|
ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
|
|
ASSERT_EQ(op->schema().arguments()[1].name(), "b");
|
|
ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);
|
|
|
|
ASSERT_EQ(op->schema().returns().size(), 1);
|
|
ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);
|
|
|
|
Stack stack;
|
|
push(stack, 2.0f, at::ones(5));
|
|
op->getOperation()(stack);
|
|
at::Tensor output;
|
|
pop(stack, output);
|
|
|
|
ASSERT_TRUE(output.allclose(at::full(5, 3.0f)));
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|