[tensorexpr] Add python bindings for TensorExprKernel (#54450)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54450

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D27243175

Pulled By: huiguoo

fbshipit-source-id: 820cf0d6cd1dd984d4153628e0f419d234668c82
This commit is contained in:
Hui Guo
2021-04-01 02:10:15 -07:00
committed by Facebook GitHub Bot
parent ba95e08a95
commit 50cb75edce
3 changed files with 114 additions and 1 deletions

View File

@ -1,7 +1,11 @@
import torch
import numpy as np
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.jit_utils import JitTestCase
import unittest
LLVM_ENABLED = torch._C._llvm_enabled()
class kernel_arena_scope(object):
def __enter__(self):
@ -59,5 +63,61 @@ class TestTensorExprPyBind(JitTestCase):
codegen.call([tA, tB, tC])
torch.testing.assert_allclose(torch.matmul(tA, tB), tC)
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
def test_kernel_with_tensor_inputs(self):
def f(a, b, c):
return a + b + c
device, size = 'cpu', (4, 4)
x = torch.rand(size, device=device)
y = torch.rand(size, device=device)
z = torch.rand(size, device=device)
graph_str = """
graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
%b.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
%c.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)):
%6 : int = prim::Constant[value=1]()
%7 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %6)
%3 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%7, %c.1, %6)
return (%3)
"""
graph = torch._C.parse_ir(graph_str)
with kernel_arena_scope():
kernel = torch._C._te.TensorExprKernel(graph)
res1 = kernel.run((x, y, z))
res2 = kernel.fallback((x, y, z))
correct = f(x, y, z)
np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
def test_kernel_with_scalar_inputs(self):
def f(a, b, c):
return a + b + c
x = torch.tensor(0.1, dtype=torch.float, device='cpu')
y = torch.tensor(0.6, dtype=torch.float, device='cpu')
z = torch.tensor(0.7, dtype=torch.float, device='cpu')
graph_str = """
graph(%a.1 : Float(requires_grad=0, device=cpu),
%b.1 : Float(requires_grad=0, device=cpu),
%c.1 : Float(requires_grad=0, device=cpu)):
%3 : int = prim::Constant[value=1]()
%6 : Float(requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %3)
%9 : Float(requires_grad=0, device=cpu) = aten::add(%6, %c.1, %3)
return (%9)
"""
graph = torch._C.parse_ir(graph_str)
with kernel_arena_scope():
kernel = torch._C._te.TensorExprKernel(graph)
res1 = kernel.run((x, y, z))
res2 = kernel.fallback((x, y, z))
correct = f(x, y, z)
np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
if __name__ == '__main__':
run_tests()

View File

@ -35,6 +35,10 @@ class TORCH_API TensorExprKernel {
return codegen_->getCodeText(attr);
}
const std::shared_ptr<Graph> graph() {
return graph_;
}
private:
enum ElementType {
kAllTypes = 0,

View File

@ -7,6 +7,7 @@
#endif
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/reduction.h>
@ -430,6 +431,54 @@ void initTensorExprBindings(PyObject* module) {
[](Stmt* stmt) { return IRSimplifier::simplify(stmt); },
py::return_value_policy::reference);
using TSGraph = std::shared_ptr<Graph>;
py::class_<TensorExprKernel>(te, "TensorExprKernel")
.def(py::init<const TSGraph&>())
.def(
"run",
[](TensorExprKernel& self, const py::tuple& inputs) {
Stack stack;
stack.reserve(inputs.size()); // captures?
for (auto& obj : inputs) {
stack.push_back(toTypeInferredIValue(obj));
}
auto g_inputs = self.graph()->inputs();
for (size_t i = 0; i < inputs.size(); ++i) {
if (stack[i].isTensor()) {
g_inputs[i]->setType(stack[i].type());
}
}
self.run(stack);
return createPyObjectForStack(std::move(stack));
})
.def(
"fallback",
[](TensorExprKernel& self, const py::tuple& inputs) {
Stack stack;
stack.reserve(inputs.size()); // captures?
for (auto& obj : inputs) {
stack.push_back(toTypeInferredIValue(obj));
}
auto g_inputs = self.graph()->inputs();
for (size_t i = 0; i < inputs.size(); ++i) {
if (stack[i].isTensor()) {
g_inputs[i]->setType(stack[i].type());
}
}
self.fallback(stack);
return createPyObjectForStack(std::move(stack));
})
.def(
"get_codegen_stmt",
[](TensorExprKernel& self) { return self.getCodeGenStmt(); },
py::return_value_policy::reference)
.def(
"get_code_text",
[](TensorExprKernel& self, const std::string& attr = "") {
return self.getCodeText(attr);
},
py::arg("attr") = "");
py::class_<CodeGen>(te, "CodeGen")
.def(
"call",
@ -442,7 +491,7 @@ void initTensorExprBindings(PyObject* module) {
self.call(value_ptrs);
})
.def(
"getCodeText",
"get_code_text",
[](CodeGen& self, const std::string& attr = "") {
return self.getCodeText(attr);
},