[TensorExpr] Python binding improvements (#59920)
Summary: Some minor quality-of-life improvements for the NNC Python bindings:
- expose `call_raw()`
- support passing integers to `call()` (for dynamic shapes)
- add implicit conversions so that `[BufferArg(x) for x in [A, B, C]]` can be written as just `[A, B, C]`
- don't silently default to "ir_eval" for an unknown mode (e.g. "LLVM")

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59920
Reviewed By: ZolotukhinM
Differential Revision: D29090904
Pulled By: jansel
fbshipit-source-id: 154ace82725ae2046cfe2e6eb324fd37f5d209a7
Committed by: Facebook GitHub Bot
Parent: 68d690ffbd
Commit: 3d90c82a5c
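To make the changes concrete, here is a minimal usage sketch distilled from the tests in this diff; it assumes the `kernel_arena_scope` and `construct_adder` helpers defined in test/test_tensorexpr_pybind.py below:

    import torch
    import torch._C._te as te

    with kernel_arena_scope():
        n = 32
        cg = construct_adder(n)  # builds an 'ir_eval' codegen computing C = A + B

        tA = torch.randn(n)
        tB = torch.randn(n)
        tC = torch.empty(n)

        # call() accepts tensors directly; BufferArg wrapping is now implicit
        cg.call([tA, tB, tC])

        # the newly exposed call_raw() takes raw data pointers instead
        cg.call_raw([tA.data_ptr(), tB.data_ptr(), tC.data_ptr()])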
test/test_tensorexpr_pybind.py

@@ -1,5 +1,6 @@
 import torch
 import numpy as np
+import torch._C._te as te
 
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.jit_utils import JitTestCase
@@ -7,6 +8,7 @@ import unittest
 
+LLVM_ENABLED = torch._C._llvm_enabled()
 
 
 class kernel_arena_scope(object):
     def __enter__(self):
         self.scope = torch._C._te.KernelScope()
@@ -14,48 +16,62 @@ class kernel_arena_scope(object):
     def __exit__(self, typ, val, traceback):
         self.scope = None
 
 
+def construct_adder(n: int, dtype=te.Dtype.Float):
+    dN = te.ExprHandle.int(n)
+    A = te.Placeholder('A', dtype, [dN])
+    B = te.Placeholder('B', dtype, [dN])
+
+    def compute(i):
+        return A.load([i]) + B.load([i])
+
+    C = te.Compute('C', [te.DimArg(dN, 'i')], compute)
+
+    loopnest = te.LoopNest([C])
+    loopnest.prepare_for_codegen()
+    stmt = te.simplify(loopnest.root_stmt())
+
+    return te.construct_codegen('ir_eval', stmt, [A, B, C])
+
+
 class TestTensorExprPyBind(JitTestCase):
     def test_simple_sum(self):
         with kernel_arena_scope():
-            dtype = torch._C._te.Dtype.Float
-            N = 32
-            dN = torch._C._te.ExprHandle.int(N)
-
-            A = torch._C._te.Placeholder('A', dtype, [dN])
-            B = torch._C._te.Placeholder('B', dtype, [dN])
-
-            def compute(i):
-                return A.load([i]) + B.load([i])
-            C = torch._C._te.Compute('C', [torch._C._te.DimArg(dN, 'i')], compute)
-
-            loopnest = torch._C._te.LoopNest([C])
-            loopnest.prepare_for_codegen()
-            stmt = torch._C._te.simplify(loopnest.root_stmt())
-
-            cg = torch._C._te.construct_codegen('ir_eval', stmt, [torch._C._te.BufferArg(x) for x in [A, B, C]])
-
-            tA = torch.rand(N) * 5
-            tB = torch.rand(N) * 6
-            tC = torch.empty(N)
+            n = 32
+            cg = construct_adder(n)
+
+            tA = torch.randn(n)
+            tB = torch.randn(n)
+            tC = torch.empty(n)
             cg.call([tA, tB, tC])
             torch.testing.assert_allclose(tA + tB, tC)
 
+    def test_call_raw(self):
+        with kernel_arena_scope():
+            n = 16
+            cg = construct_adder(n, dtype=te.Dtype.Double)
+
+            tA = torch.randn(n, dtype=torch.float64)
+            tB = torch.randn(n, dtype=torch.float64)
+            tC = torch.empty(n, dtype=torch.float64)
+            cg.call_raw([tA.data_ptr(), tB.data_ptr(), tC.data_ptr()])
+            torch.testing.assert_allclose(tA + tB, tC)
+
     def test_external_calls(self):
         with kernel_arena_scope():
-            dtype = torch._C._te.Dtype.Float
+            dtype = te.Dtype.Float
 
-            ZERO = torch._C._te.ExprHandle.int(0)
-            ONE = torch._C._te.ExprHandle.int(1)
-            FOUR = torch._C._te.ExprHandle.int(4)
-            A = torch._C._te.BufHandle('A', [ONE, FOUR], dtype)
-            B = torch._C._te.BufHandle('B', [FOUR, ONE], dtype)
-            C = torch._C._te.BufHandle('C', [ONE, ONE], dtype)
+            ONE = te.ExprHandle.int(1)
+            FOUR = te.ExprHandle.int(4)
+            A = te.BufHandle('A', [ONE, FOUR], dtype)
+            B = te.BufHandle('B', [FOUR, ONE], dtype)
+            C = te.BufHandle('C', [ONE, ONE], dtype)
 
-            s = torch._C._te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])
+            s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])
 
-            loopnest = torch._C._te.LoopNest(s, [C])
+            loopnest = te.LoopNest(s, [C])
             loopnest.prepare_for_codegen()
-            codegen = torch._C._te.construct_codegen('ir_eval', s, [torch._C._te.BufferArg(x) for x in [A, B, C]])
+            codegen = te.construct_codegen('ir_eval', s, [te.BufferArg(x) for x in [A, B, C]])
 
             tA = torch.ones(1, 4)
             tB = torch.ones(4, 1)
@@ -63,10 +79,41 @@ class TestTensorExprPyBind(JitTestCase):
             codegen.call([tA, tB, tC])
             torch.testing.assert_allclose(torch.matmul(tA, tB), tC)
 
+    def test_dynamic_shape(self):
+        with kernel_arena_scope():
+            dN = te.VarHandle("n", te.Dtype.Int)
+            A = te.Placeholder('A', te.Dtype.Double, [dN])
+            B = te.Placeholder('B', te.Dtype.Double, [dN])
+
+            def compute(i):
+                return A.load([i]) - B.load([i])
+
+            C = te.Compute('C', [te.DimArg(dN, 'i')], compute)
+
+            loopnest = te.LoopNest([C])
+            loopnest.prepare_for_codegen()
+            stmt = te.simplify(loopnest.root_stmt())
+
+            cg = te.construct_codegen(
+                'ir_eval',
+                stmt,
+                [A, B, C, dN])
+
+            def test_with_shape(n):
+                tA = torch.randn(n, dtype=torch.double)
+                tB = torch.randn(n, dtype=torch.double)
+                tC = torch.empty(n, dtype=torch.double)
+                cg.call([tA, tB, tC, n])
+                torch.testing.assert_allclose(tA - tB, tC)
+
+            test_with_shape(8)
+            test_with_shape(31)
+
+    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
     def test_kernel_with_tensor_inputs(self):
         def f(a, b, c):
             return a + b + c
 
         device, size = 'cpu', (4, 4)
         x = torch.rand(size, device=device)
         y = torch.rand(size, device=device)
@@ -94,6 +141,7 @@ graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
 
+    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
     def test_kernel_with_scalar_inputs(self):
         def f(a, b, c):
             return a + b + c
 
         x = torch.tensor(0.1, dtype=torch.float, device='cpu')
         y = torch.tensor(0.6, dtype=torch.float, device='cpu')
         z = torch.tensor(0.7, dtype=torch.float, device='cpu')
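The test_dynamic_shape test above exercises the new integer-argument support in `call()`: the shape variable `dN` is listed as a codegen argument, and a plain Python int binds it at call time. A minimal sketch of just that calling pattern, reusing the names from the test above:

    # dN was declared as te.VarHandle("n", te.Dtype.Int) and passed to construct_codegen
    cg = te.construct_codegen('ir_eval', stmt, [A, B, C, dN])

    n = 8
    tA = torch.randn(n, dtype=torch.double)
    tB = torch.randn(n, dtype=torch.double)
    tC = torch.empty(n, dtype=torch.double)
    cg.call([tA, tB, tC, n])  # the trailing int fills in the dynamic dimension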
torch/csrc/jit/tensorexpr/tensorexpr_init.cpp

@@ -47,6 +47,7 @@ ArgValue convertPyToArgValue(py::handle inp) {
     throw std::runtime_error("nyi");
   }
 }
 
 void initTensorExprBindings(PyObject* module) {
   auto m = py::handle(module).cast<py::module>();
 
@@ -653,14 +654,30 @@ void initTensorExprBindings(PyObject* module) {
   py::class_<CodeGen>(te, "CodeGen")
       .def(
          "call",
-          [](CodeGen& self, const std::vector<at::Tensor>& values) {
+          [](CodeGen& self, const py::sequence& values) {
             std::vector<CodeGen::CallArg> value_ptrs;
-            value_ptrs.reserve(values.size());
+            value_ptrs.reserve(py::len(values));
             for (const auto& value : values) {
-              value_ptrs.emplace_back(CodeGen::CallArg(value.data_ptr()));
+              if (py::isinstance<py::int_>(value)) {
+                value_ptrs.emplace_back(value.cast<int64_t>());
+              } else {
+                value_ptrs.emplace_back(value.cast<at::Tensor>().data_ptr());
+              }
             }
             self.call(value_ptrs);
           })
+      .def(
+          "call_raw",
+          [](CodeGen& self, const py::sequence& values) {
+            std::vector<void*> value_ptrs;
+            value_ptrs.reserve(py::len(values));
+            for (const auto& value : values) {
+              // Tensor.data_ptr() returns an int in python
+              value_ptrs.emplace_back(
+                  reinterpret_cast<void*>(value.cast<intptr_t>()));
+            }
+            self.call_raw(value_ptrs);
+          })
       .def(
           "get_code_text",
           [](CodeGen& self, const std::string& attr = "") {
@@ -678,6 +695,11 @@ void initTensorExprBindings(PyObject* module) {
       .def(py::init<const VarHandle&>())
       .def(py::init<const BufHandle&>());
 
+  py::implicitly_convertible<Placeholder, CodeGen::BufferArg>();
+  py::implicitly_convertible<Tensor*, CodeGen::BufferArg>();
+  py::implicitly_convertible<VarHandle, CodeGen::BufferArg>();
+  py::implicitly_convertible<BufHandle, CodeGen::BufferArg>();
+
   te.def(
       "construct_codegen",
       [](const std::string& name,
@@ -696,8 +718,11 @@ void initTensorExprBindings(PyObject* module) {
 #else
         throw std::runtime_error("PyTorch not compiled with CUDA support!");
 #endif
-        } else {
+        } else if (name == "ir_eval") {
           cg = new SimpleIREvaluator(stmt, args);
+        } else {
+          throw std::runtime_error(
+              "construct_codegen() expects 'llvm', 'cuda', or 'ir_eval'");
         }
         return cg;
       });
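With the stricter dispatch above, an unrecognized backend name now raises instead of silently falling back to the interpreter. A small hypothetical snippet showing the new behavior (the message text matches the one added in this diff):

    try:
        te.construct_codegen('LLVM', stmt, [A, B, C])  # must be lowercase 'llvm'
    except RuntimeError as e:
        print(e)  # construct_codegen() expects 'llvm', 'cuda', or 'ir_eval'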