mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[tensorexpr] Add python bindings for TensorExprKernel (#54450)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54450 Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D27243175 Pulled By: huiguoo fbshipit-source-id: 820cf0d6cd1dd984d4153628e0f419d234668c82
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ba95e08a95
commit
50cb75edce
@ -1,7 +1,11 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
import unittest
|
||||
|
||||
LLVM_ENABLED = torch._C._llvm_enabled()
|
||||
|
||||
class kernel_arena_scope(object):
|
||||
def __enter__(self):
|
||||
@ -59,5 +63,61 @@ class TestTensorExprPyBind(JitTestCase):
|
||||
codegen.call([tA, tB, tC])
|
||||
torch.testing.assert_allclose(torch.matmul(tA, tB), tC)
|
||||
|
||||
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
def test_kernel_with_tensor_inputs(self):
    # Eager-mode reference for the three-way addition encoded in the IR below.
    def reference(a, b, c):
        return a + b + c

    shape = (4, 4)
    # Three independent random operands, created in the same order as the
    # graph inputs (%a.1, %b.1, %c.1).
    args = tuple(torch.rand(shape, device='cpu') for _ in range(3))

    ir_text = """
graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %b.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %c.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %6 : int = prim::Constant[value=1]()
  %7 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %6)
  %3 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%7, %c.1, %6)
  return (%3)
    """
    ir_graph = torch._C.parse_ir(ir_text)

    with kernel_arena_scope():
        te_kernel = torch._C._te.TensorExprKernel(ir_graph)
        # Exercise the compiled path first, then the fallback path.
        compiled_out = te_kernel.run(args)
        fallback_out = te_kernel.fallback(args)
        expected = reference(*args)
        # Both execution paths must agree with eager mode.
        for actual in (compiled_out, fallback_out):
            np.testing.assert_allclose(actual.numpy(), expected.numpy(), atol=2e-3)
|
||||
|
||||
|
||||
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
def test_kernel_with_scalar_inputs(self):
    # Eager-mode reference: plain sum of the three operands.
    def reference(a, b, c):
        return a + b + c

    # Zero-dimensional float tensors, one per graph input, in input order.
    lhs = torch.tensor(0.1, dtype=torch.float, device='cpu')
    mid = torch.tensor(0.6, dtype=torch.float, device='cpu')
    rhs = torch.tensor(0.7, dtype=torch.float, device='cpu')
    operands = (lhs, mid, rhs)

    ir_text = """
graph(%a.1 : Float(requires_grad=0, device=cpu),
      %b.1 : Float(requires_grad=0, device=cpu),
      %c.1 : Float(requires_grad=0, device=cpu)):
  %3 : int = prim::Constant[value=1]()
  %6 : Float(requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %3)
  %9 : Float(requires_grad=0, device=cpu) = aten::add(%6, %c.1, %3)
  return (%9)
    """
    ir_graph = torch._C.parse_ir(ir_text)

    with kernel_arena_scope():
        te_kernel = torch._C._te.TensorExprKernel(ir_graph)
        via_run = te_kernel.run(operands)
        via_fallback = te_kernel.fallback(operands)
        expected = reference(lhs, mid, rhs)
        # The compiled kernel and the fallback must both match eager mode.
        np.testing.assert_allclose(via_run.numpy(), expected.numpy(), atol=2e-3)
        np.testing.assert_allclose(via_fallback.numpy(), expected.numpy(), atol=2e-3)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
@ -35,6 +35,10 @@ class TORCH_API TensorExprKernel {
|
||||
return codegen_->getCodeText(attr);
|
||||
}
|
||||
|
||||
const std::shared_ptr<Graph> graph() {
|
||||
return graph_;
|
||||
}
|
||||
|
||||
private:
|
||||
enum ElementType {
|
||||
kAllTypes = 0,
|
||||
|
@ -7,6 +7,7 @@
|
||||
#endif
|
||||
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
|
||||
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
|
||||
#include <torch/csrc/jit/tensorexpr/kernel.h>
|
||||
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
|
||||
#include <torch/csrc/jit/tensorexpr/loopnest.h>
|
||||
#include <torch/csrc/jit/tensorexpr/reduction.h>
|
||||
@ -430,6 +431,54 @@ void initTensorExprBindings(PyObject* module) {
|
||||
[](Stmt* stmt) { return IRSimplifier::simplify(stmt); },
|
||||
py::return_value_policy::reference);
|
||||
|
||||
using TSGraph = std::shared_ptr<Graph>;
|
||||
py::class_<TensorExprKernel>(te, "TensorExprKernel")
|
||||
.def(py::init<const TSGraph&>())
|
||||
.def(
|
||||
"run",
|
||||
[](TensorExprKernel& self, const py::tuple& inputs) {
|
||||
Stack stack;
|
||||
stack.reserve(inputs.size()); // captures?
|
||||
for (auto& obj : inputs) {
|
||||
stack.push_back(toTypeInferredIValue(obj));
|
||||
}
|
||||
auto g_inputs = self.graph()->inputs();
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (stack[i].isTensor()) {
|
||||
g_inputs[i]->setType(stack[i].type());
|
||||
}
|
||||
}
|
||||
self.run(stack);
|
||||
return createPyObjectForStack(std::move(stack));
|
||||
})
|
||||
.def(
|
||||
"fallback",
|
||||
[](TensorExprKernel& self, const py::tuple& inputs) {
|
||||
Stack stack;
|
||||
stack.reserve(inputs.size()); // captures?
|
||||
for (auto& obj : inputs) {
|
||||
stack.push_back(toTypeInferredIValue(obj));
|
||||
}
|
||||
auto g_inputs = self.graph()->inputs();
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (stack[i].isTensor()) {
|
||||
g_inputs[i]->setType(stack[i].type());
|
||||
}
|
||||
}
|
||||
self.fallback(stack);
|
||||
return createPyObjectForStack(std::move(stack));
|
||||
})
|
||||
.def(
|
||||
"get_codegen_stmt",
|
||||
[](TensorExprKernel& self) { return self.getCodeGenStmt(); },
|
||||
py::return_value_policy::reference)
|
||||
.def(
|
||||
"get_code_text",
|
||||
[](TensorExprKernel& self, const std::string& attr = "") {
|
||||
return self.getCodeText(attr);
|
||||
},
|
||||
py::arg("attr") = "");
|
||||
|
||||
py::class_<CodeGen>(te, "CodeGen")
|
||||
.def(
|
||||
"call",
|
||||
@ -442,7 +491,7 @@ void initTensorExprBindings(PyObject* module) {
|
||||
self.call(value_ptrs);
|
||||
})
|
||||
.def(
|
||||
"getCodeText",
|
||||
"get_code_text",
|
||||
[](CodeGen& self, const std::string& attr = "") {
|
||||
return self.getCodeText(attr);
|
||||
},
|
||||
|
Reference in New Issue
Block a user