Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[TensorExpr] PyBinds: improve QoL of pybind users. (#64886)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64886

Bind methods for implicit conversions and constructors to avoid boilerplate code.

Test Plan: Imported from OSS

Differential Revision: D30889193

Reviewed By: jbschlosser

Pulled By: ZolotukhinM

fbshipit-source-id: 137c0c98f7f1576e1bb97c8de8a900b28407a30e
commit 199031c48e
parent caaa6efc1a
committed by Facebook GitHub Bot
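In user code, the effect of these changes is that plain Python ints and torch dtypes are accepted where te.ExprHandle / te.Dtype objects were previously required, and BufHandle objects are accepted directly where te.BufferArg wrappers were needed. Below is a minimal sketch of the simplified API, adapted from the updated test_external_calls test in the diff that follows; the call/assert tail follows the pattern of the surrounding test file and assumes a build with the ir_eval backend and the nnc_aten_matmul external function available:

import torch

te = torch._C._te

dtype = torch.float32
# Dims are plain Python ints; they convert implicitly to ExprHandle.
A = te.BufHandle('A', [1, 4], dtype)
B = te.BufHandle('B', [4, 1], dtype)
C = te.BufHandle('C', [1, 1], dtype)

# C = A @ B via an external call into ATen.
s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

loopnest = te.LoopNest(s, [C])
loopnest.prepare_for_codegen()
# BufHandles pass straight through; no te.BufferArg(...) wrapping needed.
codegen = te.construct_codegen('ir_eval', s, [A, B, C])

tA = torch.ones(1, 4)
tB = torch.ones(4, 1)
tC = torch.empty(1, 1)
codegen.call([tA, tB, tC])
assert torch.allclose(tC, torch.matmul(tA, tB))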
@@ -9,15 +9,14 @@ import unittest
 LLVM_ENABLED = torch._C._llvm_enabled()


-def construct_adder(n: int, dtype=te.Dtype.Float):
-    dN = te.ExprHandle.int(n)
-    A = te.Placeholder('A', dtype, [dN])
-    B = te.Placeholder('B', dtype, [dN])
+def construct_adder(n: int, dtype=torch.float32):
+    A = te.BufHandle('A', [n], dtype)
+    B = te.BufHandle('B', [n], dtype)

     def compute(i):
         return A.load([i]) + B.load([i])

-    C = te.Compute('C', [te.DimArg(dN, 'i')], compute)
+    C = te.Compute('C', [n], compute)

     loopnest = te.LoopNest([C])
     loopnest.prepare_for_codegen()
@@ -50,17 +49,15 @@ class TestTensorExprPyBind(JitTestCase):
     def test_external_calls(self):
         dtype = torch.float32

-        ONE = te.ExprHandle.int(1)
-        FOUR = te.ExprHandle.int(4)
-        A = te.BufHandle('A', [ONE, FOUR], dtype)
-        B = te.BufHandle('B', [FOUR, ONE], dtype)
-        C = te.BufHandle('C', [ONE, ONE], dtype)
+        A = te.BufHandle('A', [1, 4], dtype)
+        B = te.BufHandle('B', [4, 1], dtype)
+        C = te.BufHandle('C', [1, 1], dtype)

         s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

         loopnest = te.LoopNest(s, [C])
         loopnest.prepare_for_codegen()
-        codegen = te.construct_codegen('ir_eval', s, [te.BufferArg(x) for x in [A, B, C]])
+        codegen = te.construct_codegen('ir_eval', s, [A, B, C])

         tA = torch.ones(1, 4)
         tB = torch.ones(4, 1)
@@ -97,11 +94,8 @@ class TestTensorExprPyBind(JitTestCase):
         test_with_shape(31)

     def test_dtype_error(self):
-        one = te.ExprHandle.int(1)
-        te.Placeholder([one], torch.float32)  # ok
-        te.Placeholder([one])  # ok
-        self.assertRaises(TypeError,
-                          lambda: te.Placeholder([one], "float55"))
+        te.BufHandle('a', [1], torch.float32)  # ok
+        self.assertRaises(TypeError, lambda: te.BufHandle('a', [1], "float55"))

     @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
     def test_kernel_with_tensor_inputs(self):
@@ -124,7 +118,7 @@ graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
 """
         graph = torch._C.parse_ir(graph_str)

-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res1 = kernel.run((x, y, z))
         res2 = kernel.fallback((x, y, z))
         correct = f(x, y, z)
@@ -151,7 +145,7 @@ graph(%a.1 : Float(requires_grad=0, device=cpu),
 """
         graph = torch._C.parse_ir(graph_str)

-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res1 = kernel.run((x, y, z))
         res2 = kernel.fallback((x, y, z))
         correct = f(x, y, z)
@@ -173,7 +167,7 @@ graph(%a : Tensor, %b : Tensor):

         exception_thrown = False
         try:
-            kernel = torch._C._te.TensorExprKernel(graph)
+            kernel = te.TensorExprKernel(graph)
         except RuntimeError:
             # Graph doesn't have shape info for inputs => compilation should
             # fail
@@ -187,7 +181,7 @@ graph(%a : Tensor, %b : Tensor):
         torch._C._jit_pass_propagate_shapes_on_graph(graph)

         # Now compilation should pass
-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)

         res = kernel.run((x, y))
         correct = torch.mul(x, y)
@@ -205,7 +199,7 @@ graph(%a : Tensor, %b : Tensor):
         # shape info.
         exception_thrown = False
         try:
-            kernel = torch._C._te.TensorExprKernel(graph)
+            kernel = te.TensorExprKernel(graph)
         except RuntimeError:
             exception_thrown = True
             pass
@@ -231,7 +225,7 @@ graph(%a : Tensor, %b : Tensor):
         torch._C._jit_pass_propagate_shapes_on_graph(graph)

         # Now compilation should pass
-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)

         device, size = 'cpu', (4, 4)
         x = torch.rand(size, device=device)
@@ -256,7 +250,7 @@ graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
 """
         graph = torch._C.parse_ir(graph_str)

-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res1 = kernel.run((x,))
         res2 = kernel.fallback((x,))
         correct = f(x)
@@ -280,7 +274,7 @@ graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
 """
         graph = torch._C.parse_ir(graph_str)

-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res1 = kernel.run((x,))
         res2 = kernel.fallback((x,))
         correct = f(x)
@@ -306,7 +300,7 @@ graph(%a.1 : Float(3, 4, 5, strides=[20, 5, 1], requires_grad=0, device=cpu)):
 """
         graph = torch._C.parse_ir(graph_str)

-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res1 = kernel.run((x,))
         res2 = kernel.fallback((x,))
         correct = f(x)
@@ -341,7 +335,7 @@ graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
                 return te.ifThenElse(te.ExprHandle.isnan(load), te.ExprHandle.float(0.), load)
             return te.Compute2("custom_nan_to_num", get_dim_args(out_shape), compute)

-        kernel = torch._C._te.TensorExprKernel(graph, {'aten::nan_to_num' : my_custom_lowering})
+        kernel = te.TensorExprKernel(graph, {'aten::nan_to_num' : my_custom_lowering})
         res1 = kernel.run((x,))
         res2 = kernel.fallback((x,))
         correct = f(x)
@@ -367,7 +361,7 @@ graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
 """
         graph = torch._C.parse_ir(graph_str)

-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res1 = kernel.run((x,))
         res2 = kernel.fallback((x,))
         correct = f(x)
@@ -376,18 +370,15 @@ graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):

     @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
     def test_alloc_in_loop(self):
-        a, tmp, b = [
-            te.Placeholder(name, te.Dtype.Float, [te.ExprHandle.int(1)])
-            for name in ["a", "tmp", "b"]]
-        t0, t100 = [te.ExprHandle.int(n) for n in [0, 100]]
+        a, tmp, b = [te.BufHandle(name, [1], torch.float32) for name in ["a", "tmp", "b"]]
         body = te.Block([
-            tmp.store([t0], a.load([t0])),
-            b.store([t0], tmp.load([t0]))
+            tmp.store([0], a.load([0])),
+            b.store([0], tmp.load([0]))
         ])
         for _ in range(4):
-            i = te.VarHandle("i", te.Dtype.Int)
-            body = te.For.make(i, t0, t100, body)
-        nest = te.LoopNest(body, [b.data()])
+            i = te.VarHandle("i", torch.int32)
+            body = te.For.make(i, 0, 100, body)
+        nest = te.LoopNest(body, [b])
         nest.prepare_for_codegen()
         f = te.construct_codegen("llvm", nest.simplify(), [a, b])
         ta, tb = [torch.ones(1) for _ in range(2)]
@@ -75,6 +75,13 @@ void initTensorExprBindings(PyObject* module) {

   auto expr_handle_class =
       py::class_<ExprHandle>(te, "ExprHandle")
+          .def(
+              "__str__",
+              [](const ExprHandle& self) {
+                std::stringstream ss;
+                ss << self;
+                return ss.str();
+              })
           .def(py::self + py::self)
           .def(py::self * py::self)
           .def(py::self - py::self)
@@ -124,7 +131,23 @@ void initTensorExprBindings(PyObject* module) {
       .def("trunc", [](const ExprHandle& self) { return trunc(self); })
       .def("frac", [](const ExprHandle& self) { return frac(self); })
       .def("lgamma", [](const ExprHandle& self) { return lgamma(self); })
-      .def("isnan", [](const ExprHandle& self) { return isnan(self); });
+      .def("isnan", [](const ExprHandle& self) { return isnan(self); })
+      .def(
+          "cast",
+          [](const ExprHandle& self, const Dtype& dt) {
+            return Cast::make(dt, self);
+          })
+#define EXPRHANDLE_INIT(ctype, name) \
+  .def(py::init([](ctype val) { return name##Imm::make(val); }))
+          AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_INIT)
+#undef EXPRHANDLE_INIT
+      ;
+
+#define EXPRHANDLE_IMPL_CONV(ctype, name) \
+  py::implicitly_convertible<ctype, ExprHandle>();
+  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_IMPL_CONV)
+#undef EXPRHANDLE_IMPL_CONV

   te.def(
       "ifThenElse",
       [](const ExprHandle& c, const ExprHandle& t, const ExprHandle& f) {
@@ -149,6 +172,13 @@ void initTensorExprBindings(PyObject* module) {
 #undef EXPRHANDLE_CTOR

   py::class_<VarHandle, ExprHandle>(te, "VarHandle")
+      .def(
+          "__str__",
+          [](const ExprHandle& self) {
+            std::stringstream ss;
+            ss << self;
+            return ss.str();
+          })
       .def(py::init<Dtype>())
       .def(py::init<const std::string&, Dtype>());
   py::class_<BufHandle, ExprHandle>( // NOLINT
@@ -163,9 +193,16 @@ void initTensorExprBindings(PyObject* module) {
           [](BufHandle& self, const std::vector<ExprHandle>& v) {
             return Load::make(self, v);
           })
-      .def("load", [](BufHandle& self, const ExprHandle& v) {
-        return Load::make(self, {v});
-      });
+      .def(
+          "load",
+          [](BufHandle& self, const ExprHandle& v) {
+            return Load::make(self, {v});
+          })
+      .def(
+          "store",
+          [](BufHandle& self,
+             const std::vector<ExprHandle>& args,
+             const ExprHandle& val) { return Store::make(self, args, val); });

   py::class_<Placeholder>(te, "Placeholder")
       .def(py::init<
@@ -196,12 +233,20 @@ void initTensorExprBindings(PyObject* module) {
       .def("buf", [](Tensor& self) { return BufHandle(self.buf()); })
       .def("stmt", &Tensor::stmt);
   py::class_<Cast, std::shared_ptr<Cast>>(te, "Cast")
-      .def_static("make", &Cast::make);
+      .def_static("make", &Cast::make)
+      .def(
+          "src_value",
+          [](CastPtr& self) { return ExprHandle(self->src_value()); })
+      .def("set_src_value", [](CastPtr& self, const ExprHandle& value) {
+        self->set_src_value(value.node());
+      });

   py::class_<DimArg>(te, "DimArg")
       .def(py::init<const ExprHandle&>())
       .def(py::init<const ExprHandle&, const std::string&>());
   py::implicitly_convertible<ExprHandle, DimArg>();
+  py::implicitly_convertible<int32_t, DimArg>();
+  py::implicitly_convertible<int64_t, DimArg>();

   te.def(
       "Compute",
@@ -584,7 +629,7 @@ void initTensorExprBindings(PyObject* module) {
           py::return_value_policy::reference)
       .def(
           "flatten",
-          [](const std::vector<ForPtr>& loops) {
+          [](LoopNest& self, const std::vector<ForPtr>& loops) {
             ForPtr flattened = nullptr;
             LoopNest::flatten(loops, &flattened);
             return flattened;
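Taken together, the binding changes above (EXPRHANDLE_INIT constructors, EXPRHANDLE_IMPL_CONV registrations, and the new cast and __str__ methods) should be visible from Python roughly as in the sketch below. This is a hypothetical snippet grounded in the hunks above, not code taken from the PR's tests:

import torch

te = torch._C._te

# VarHandle now accepts a torch dtype (see the test_alloc_in_loop hunk).
x = te.VarHandle('x', torch.int32)

# Python scalars should convert implicitly to ExprHandle via the
# EXPRHANDLE_INIT constructors and EXPRHANDLE_IMPL_CONV registrations.
e = x + 1

# The new 'cast' binding wraps Cast::make(dt, self).
c = e.cast(te.Dtype.Float)

# __str__ is now bound on ExprHandle (and VarHandle), so printing shows the IR.
print(x, e, c)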