[TensorExpr] PyBinds: improve QoL of pybind users. (#64886)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64886

Bind additional constructors and register implicit conversions so that plain
Python ints, lists, and torch dtypes are accepted where ExprHandle, DimArg,
Dtype, or BufferArg objects previously had to be built by hand, removing
boilerplate from pybind user code.
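
For illustration, this is the kind of boilerplate that goes away on the Python
side (a sketch mirroring the construct_adder changes in the test diff below;
`te` is `torch._C._te`, and both forms work, the second just needs fewer wrappers):

    import torch

    te = torch._C._te
    n = 32

    def compute(i):
        return A.load([i]) + B.load([i])

    # Previously required: wrap every scalar, dtype, and dim argument by hand.
    dN = te.ExprHandle.int(n)
    A = te.Placeholder('A', te.Dtype.Float, [dN])
    B = te.Placeholder('B', te.Dtype.Float, [dN])
    C = te.Compute('C', [te.DimArg(dN, 'i')], compute)

    # Now possible: plain Python ints and torch dtypes are accepted directly.
    A = te.BufHandle('A', [n], torch.float32)
    B = te.BufHandle('B', [n], torch.float32)
    C = te.Compute('C', [n], compute)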

Differential Revision: D30889193

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Pulled By: ZolotukhinM

fbshipit-source-id: 137c0c98f7f1576e1bb97c8de8a900b28407a30e
Author: Mikhail Zolotukhin
Date: 2021-09-14 00:19:57 -07:00
Committed by: Facebook GitHub Bot
Parent: caaa6efc1a
Commit: 199031c48e

2 changed files with 78 additions and 42 deletions


@@ -9,15 +9,14 @@ import unittest
 LLVM_ENABLED = torch._C._llvm_enabled()
-def construct_adder(n: int, dtype=te.Dtype.Float):
-    dN = te.ExprHandle.int(n)
-    A = te.Placeholder('A', dtype, [dN])
-    B = te.Placeholder('B', dtype, [dN])
+def construct_adder(n: int, dtype=torch.float32):
+    A = te.BufHandle('A', [n], dtype)
+    B = te.BufHandle('B', [n], dtype)
     def compute(i):
         return A.load([i]) + B.load([i])
-    C = te.Compute('C', [te.DimArg(dN, 'i')], compute)
+    C = te.Compute('C', [n], compute)
     loopnest = te.LoopNest([C])
     loopnest.prepare_for_codegen()
@@ -50,17 +49,15 @@ class TestTensorExprPyBind(JitTestCase):
     def test_external_calls(self):
         dtype = torch.float32
-        ONE = te.ExprHandle.int(1)
-        FOUR = te.ExprHandle.int(4)
-        A = te.BufHandle('A', [ONE, FOUR], dtype)
-        B = te.BufHandle('B', [FOUR, ONE], dtype)
-        C = te.BufHandle('C', [ONE, ONE], dtype)
+        A = te.BufHandle('A', [1, 4], dtype)
+        B = te.BufHandle('B', [4, 1], dtype)
+        C = te.BufHandle('C', [1, 1], dtype)
         s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])
         loopnest = te.LoopNest(s, [C])
         loopnest.prepare_for_codegen()
-        codegen = te.construct_codegen('ir_eval', s, [te.BufferArg(x) for x in [A, B, C]])
+        codegen = te.construct_codegen('ir_eval', s, [A, B, C])
         tA = torch.ones(1, 4)
         tB = torch.ones(4, 1)
@@ -97,11 +94,8 @@ class TestTensorExprPyBind(JitTestCase):
         test_with_shape(31)
     def test_dtype_error(self):
-        one = te.ExprHandle.int(1)
-        te.Placeholder([one], torch.float32)  # ok
-        te.Placeholder([one])  # ok
-        self.assertRaises(TypeError,
-                          lambda: te.Placeholder([one], "float55"))
+        te.BufHandle('a', [1], torch.float32)  # ok
+        self.assertRaises(TypeError, lambda: te.BufHandle('a', [1], "float55"))
     @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
     def test_kernel_with_tensor_inputs(self):
@@ -124,7 +118,7 @@ graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
 """
         graph = torch._C.parse_ir(graph_str)
-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res1 = kernel.run((x, y, z))
         res2 = kernel.fallback((x, y, z))
         correct = f(x, y, z)
@@ -151,7 +145,7 @@ graph(%a.1 : Float(requires_grad=0, device=cpu),
 """
         graph = torch._C.parse_ir(graph_str)
-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res1 = kernel.run((x, y, z))
         res2 = kernel.fallback((x, y, z))
         correct = f(x, y, z)
@@ -173,7 +167,7 @@ graph(%a : Tensor, %b : Tensor):
         exception_thrown = False
         try:
-            kernel = torch._C._te.TensorExprKernel(graph)
+            kernel = te.TensorExprKernel(graph)
         except RuntimeError:
             # Graph doesn't have shape info for inputs => compilation should
             # fail
@@ -187,7 +181,7 @@ graph(%a : Tensor, %b : Tensor):
         torch._C._jit_pass_propagate_shapes_on_graph(graph)
         # Now compilation should pass
-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res = kernel.run((x, y))
         correct = torch.mul(x, y)
@@ -205,7 +199,7 @@ graph(%a : Tensor, %b : Tensor):
         # shape info.
         exception_thrown = False
         try:
-            kernel = torch._C._te.TensorExprKernel(graph)
+            kernel = te.TensorExprKernel(graph)
         except RuntimeError:
             exception_thrown = True
             pass
@@ -231,7 +225,7 @@ graph(%a : Tensor, %b : Tensor):
         torch._C._jit_pass_propagate_shapes_on_graph(graph)
         # Now compilation should pass
-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         device, size = 'cpu', (4, 4)
         x = torch.rand(size, device=device)
@@ -256,7 +250,7 @@ graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
 """
         graph = torch._C.parse_ir(graph_str)
-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res1 = kernel.run((x,))
         res2 = kernel.fallback((x,))
         correct = f(x)
@@ -280,7 +274,7 @@ graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
 """
         graph = torch._C.parse_ir(graph_str)
-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res1 = kernel.run((x,))
         res2 = kernel.fallback((x,))
         correct = f(x)
@@ -306,7 +300,7 @@ graph(%a.1 : Float(3, 4, 5, strides=[20, 5, 1], requires_grad=0, device=cpu)):
 """
         graph = torch._C.parse_ir(graph_str)
-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res1 = kernel.run((x,))
         res2 = kernel.fallback((x,))
         correct = f(x)
@@ -341,7 +335,7 @@ graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
                 return te.ifThenElse(te.ExprHandle.isnan(load), te.ExprHandle.float(0.), load)
             return te.Compute2("custom_nan_to_num", get_dim_args(out_shape), compute)
-        kernel = torch._C._te.TensorExprKernel(graph, {'aten::nan_to_num' : my_custom_lowering})
+        kernel = te.TensorExprKernel(graph, {'aten::nan_to_num' : my_custom_lowering})
         res1 = kernel.run((x,))
         res2 = kernel.fallback((x,))
         correct = f(x)
@@ -367,7 +361,7 @@ graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
 """
         graph = torch._C.parse_ir(graph_str)
-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res1 = kernel.run((x,))
         res2 = kernel.fallback((x,))
         correct = f(x)
@@ -376,18 +370,15 @@ graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
     @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
     def test_alloc_in_loop(self):
-        a, tmp, b = [
-            te.Placeholder(name, te.Dtype.Float, [te.ExprHandle.int(1)])
-            for name in ["a", "tmp", "b"]]
-        t0, t100 = [te.ExprHandle.int(n) for n in [0, 100]]
+        a, tmp, b = [te.BufHandle(name, [1], torch.float32) for name in ["a", "tmp", "b"]]
         body = te.Block([
-            tmp.store([t0], a.load([t0])),
-            b.store([t0], tmp.load([t0]))
+            tmp.store([0], a.load([0])),
+            b.store([0], tmp.load([0]))
         ])
         for _ in range(4):
-            i = te.VarHandle("i", te.Dtype.Int)
-            body = te.For.make(i, t0, t100, body)
-        nest = te.LoopNest(body, [b.data()])
+            i = te.VarHandle("i", torch.int32)
+            body = te.For.make(i, 0, 100, body)
+        nest = te.LoopNest(body, [b])
         nest.prepare_for_codegen()
         f = te.construct_codegen("llvm", nest.simplify(), [a, b])
         ta, tb = [torch.ones(1) for _ in range(2)]

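Taken together, the updated tests above show how much leaner the Python-side IR
construction becomes. Below is a self-contained sketch of the same pattern,
assembled only from calls that appear in the tests; the 'ir_eval' backend is
used instead of LLVM so it runs without an LLVM build, which is the only
assumption made here:

    import torch

    te = torch._C._te

    # Buffer shapes take plain ints, and dtypes can be torch dtypes.
    a = te.BufHandle("a", [1], torch.float32)
    b = te.BufHandle("b", [1], torch.float32)

    # Integer indices and loop bounds are implicitly converted to ExprHandles.
    body = te.Block([b.store([0], a.load([0]))])
    i = te.VarHandle("i", torch.int32)
    loop = te.For.make(i, 0, 100, body)

    nest = te.LoopNest(loop, [b])
    nest.prepare_for_codegen()
    cg = te.construct_codegen("ir_eval", nest.simplify(), [a, b])

Before this change, the [1] shape, the 0/100 bounds, and both dtypes would each
have needed an explicit te.ExprHandle.int / te.Dtype wrapper, as in the removed
lines above.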

@@ -75,6 +75,13 @@ void initTensorExprBindings(PyObject* module) {
   auto expr_handle_class =
       py::class_<ExprHandle>(te, "ExprHandle")
+          .def(
+              "__str__",
+              [](const ExprHandle& self) {
+                std::stringstream ss;
+                ss << self;
+                return ss.str();
+              })
           .def(py::self + py::self)
           .def(py::self * py::self)
           .def(py::self - py::self)
@@ -124,7 +131,23 @@ void initTensorExprBindings(PyObject* module) {
           .def("trunc", [](const ExprHandle& self) { return trunc(self); })
           .def("frac", [](const ExprHandle& self) { return frac(self); })
           .def("lgamma", [](const ExprHandle& self) { return lgamma(self); })
-          .def("isnan", [](const ExprHandle& self) { return isnan(self); });
+          .def("isnan", [](const ExprHandle& self) { return isnan(self); })
+          .def(
+              "cast",
+              [](const ExprHandle& self, const Dtype& dt) {
+                return Cast::make(dt, self);
+              })
+#define EXPRHANDLE_INIT(ctype, name) \
+  .def(py::init([](ctype val) { return name##Imm::make(val); }))
+      AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_INIT)
+#undef EXPRHANDLE_INIT
+      ;
+#define EXPRHANDLE_IMPL_CONV(ctype, name) \
+  py::implicitly_convertible<ctype, ExprHandle>();
+  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_IMPL_CONV)
+#undef EXPRHANDLE_IMPL_CONV
   te.def(
       "ifThenElse",
       [](const ExprHandle& c, const ExprHandle& t, const ExprHandle& f) {
@@ -149,6 +172,13 @@ void initTensorExprBindings(PyObject* module) {
 #undef EXPRHANDLE_CTOR
   py::class_<VarHandle, ExprHandle>(te, "VarHandle")
+      .def(
+          "__str__",
+          [](const ExprHandle& self) {
+            std::stringstream ss;
+            ss << self;
+            return ss.str();
+          })
       .def(py::init<Dtype>())
       .def(py::init<const std::string&, Dtype>());
   py::class_<BufHandle, ExprHandle>( // NOLINT
@@ -163,9 +193,16 @@ void initTensorExprBindings(PyObject* module) {
           [](BufHandle& self, const std::vector<ExprHandle>& v) {
             return Load::make(self, v);
           })
-      .def("load", [](BufHandle& self, const ExprHandle& v) {
-        return Load::make(self, {v});
-      });
+      .def(
+          "load",
+          [](BufHandle& self, const ExprHandle& v) {
+            return Load::make(self, {v});
+          })
+      .def(
+          "store",
+          [](BufHandle& self,
+             const std::vector<ExprHandle>& args,
+             const ExprHandle& val) { return Store::make(self, args, val); });
   py::class_<Placeholder>(te, "Placeholder")
       .def(py::init<
@@ -196,12 +233,20 @@ void initTensorExprBindings(PyObject* module) {
       .def("buf", [](Tensor& self) { return BufHandle(self.buf()); })
      .def("stmt", &Tensor::stmt);
   py::class_<Cast, std::shared_ptr<Cast>>(te, "Cast")
-      .def_static("make", &Cast::make);
+      .def_static("make", &Cast::make)
+      .def(
+          "src_value",
+          [](CastPtr& self) { return ExprHandle(self->src_value()); })
+      .def("set_src_value", [](CastPtr& self, const ExprHandle& value) {
+        self->set_src_value(value.node());
+      });
   py::class_<DimArg>(te, "DimArg")
       .def(py::init<const ExprHandle&>())
       .def(py::init<const ExprHandle&, const std::string&>());
+  py::implicitly_convertible<ExprHandle, DimArg>();
+  py::implicitly_convertible<int32_t, DimArg>();
+  py::implicitly_convertible<int64_t, DimArg>();
   te.def(
       "Compute",
@@ -584,7 +629,7 @@ void initTensorExprBindings(PyObject* module) {
           py::return_value_policy::reference)
       .def(
           "flatten",
-          [](const std::vector<ForPtr>& loops) {
+          [](LoopNest& self, const std::vector<ForPtr>& loops) {
            ForPtr flattened = nullptr;
            LoopNest::flatten(loops, &flattened);
            return flattened;