From 85517a2b700a5abc0b38f53ce8c99404cd67db79 Mon Sep 17 00:00:00 2001
From: Jason Ansel
Date: Wed, 16 Jun 2021 19:59:42 -0700
Subject: [PATCH] [TensorExpr] More python binding cleanups (#60058)

Summary:
A few more quality of life improvements for NNC's python bindings:
- Use standard `torch.dtype`s (rather than `te.Dtype`)
- Make names optional (they don't seem to matter)
- Make shapes optional
- A few implicit conversions to make code cleaner

Followup to https://github.com/pytorch/pytorch/issues/59920

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60058

Reviewed By: bertmaher

Differential Revision: D29151953

Pulled By: jansel

fbshipit-source-id: c8286e329eb4ee3921ca0786e17248cf6a898bd8
---
 test/test_tensorexpr_pybind.py                | 25 +++++++++------
 torch/csrc/jit/tensorexpr/expr.h              |  5 +++
 torch/csrc/jit/tensorexpr/tensor.h            |  6 ++++
 torch/csrc/jit/tensorexpr/tensorexpr_init.cpp | 31 ++++++++++++++++---
 4 files changed, 53 insertions(+), 14 deletions(-)

diff --git a/test/test_tensorexpr_pybind.py b/test/test_tensorexpr_pybind.py
index a3efb8416a37..cc4551515bb4 100644
--- a/test/test_tensorexpr_pybind.py
+++ b/test/test_tensorexpr_pybind.py
@@ -49,7 +49,7 @@ class TestTensorExprPyBind(JitTestCase):
     def test_call_raw(self):
         with kernel_arena_scope():
             n = 16
-            cg = construct_adder(n, dtype=te.Dtype.Double)
+            cg = construct_adder(n, dtype=torch.float64)
 
             tA = torch.randn(n, dtype=torch.float64)
             tB = torch.randn(n, dtype=torch.float64)
@@ -59,7 +59,7 @@ class TestTensorExprPyBind(JitTestCase):
 
     def test_external_calls(self):
         with kernel_arena_scope():
-            dtype = te.Dtype.Float
+            dtype = torch.float32
 
             ONE = te.ExprHandle.int(1)
             FOUR = te.ExprHandle.int(4)
@@ -81,22 +81,21 @@ class TestTensorExprPyBind(JitTestCase):
 
     def test_dynamic_shape(self):
         with kernel_arena_scope():
-            dN = te.VarHandle("n", te.Dtype.Int)
-            A = te.Placeholder('A', te.Dtype.Double, [dN])
-            B = te.Placeholder('B', te.Dtype.Double, [dN])
+            dN = te.VarHandle(torch.int32)
+            A = te.BufHandle(torch.float64)
+            B = te.BufHandle(torch.float64)
 
             def compute(i):
-                return A.load([i]) - B.load([i])
+                return A.load(i) - B.load(i)
 
-            C = te.Compute('C', [te.DimArg(dN, 'i')], compute)
+            C = te.Compute('C', [dN], compute)
 
             loopnest = te.LoopNest([C])
             loopnest.prepare_for_codegen()
-            stmt = te.simplify(loopnest.root_stmt())
 
             cg = te.construct_codegen(
                 'ir_eval',
-                stmt,
+                loopnest.simplify(),
                 [A, B, C, dN])
 
             def test_with_shape(n):
@@ -109,6 +108,14 @@ class TestTensorExprPyBind(JitTestCase):
             test_with_shape(8)
             test_with_shape(31)
 
+    def test_dtype_error(self):
+        with kernel_arena_scope():
+            one = te.ExprHandle.int(1)
+            te.Placeholder([one], torch.float32)  # ok
+            te.Placeholder([one])  # ok
+            self.assertRaises(TypeError,
+                              lambda: te.Placeholder([one], "float55"))
+
     @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
     def test_kernel_with_tensor_inputs(self):
         def f(a, b, c):
diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h
index 6e36be55fe71..2f96d8c4de0a 100644
--- a/torch/csrc/jit/tensorexpr/expr.h
+++ b/torch/csrc/jit/tensorexpr/expr.h
@@ -247,6 +247,11 @@ class TORCH_API BufHandle : public ExprHandle {
       Dtype dtype)
       : ExprHandle(Buf::make(name_hint, dims, dtype)) {}
 
+  BufHandle(const std::vector<ExprHandle>& dims, Dtype dtype)
+      : ExprHandle(Buf::make("_", dims, dtype)) {}
+
+  explicit BufHandle(Dtype dtype) : ExprHandle(Buf::make("_", {}, dtype)) {}
+
   explicit BufHandle(const Buf* node) : ExprHandle(node) {}
   const Buf* node() const {
     return static_cast<const Buf*>(ExprHandle::node());
diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h
index 436f3052db9e..95c98af0bdce 100644
--- a/torch/csrc/jit/tensorexpr/tensor.h
+++ b/torch/csrc/jit/tensorexpr/tensor.h
@@ -82,6 +82,12 @@ class Placeholder {
       const std::vector<ExprHandle>& dims)
       : Placeholder(BufHandle(name, dims, dtype)) {}
 
+  Placeholder(const std::vector<ExprHandle>& dims, const Dtype& dtype)
+      : Placeholder(BufHandle("_", dims, dtype)) {}
+
+  explicit Placeholder(const std::vector<ExprHandle>& dims)
+      : Placeholder(BufHandle("_", dims, kFloat)) {}
+
   const Buf* data() const {
     return data_;
   }
diff --git a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
index c247ecb9b720..c1e5fc6aa4f0 100644
--- a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
+++ b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
@@ -44,7 +44,15 @@ ArgValue convertPyToArgValue(py::handle inp) {
       throw std::runtime_error("vector conversion failed");
     }
   } else {
-    throw std::runtime_error("nyi");
+    throw std::runtime_error("conversion not yet implemented");
+  }
+}
+
+Dtype parsePythonDtype(py::handle obj) {
+  if (py::isinstance(obj, py::module_::import("torch").attr("dtype"))) {
+    return Dtype(reinterpret_cast<THPDtype*>(obj.ptr())->scalar_type);
+  } else {
+    throw std::runtime_error("expected a torch.dtype instance");
   }
 }
 
@@ -55,7 +63,9 @@ void initTensorExprBindings(PyObject* module) {
   auto te = m.def_submodule("_te");
   py::class_<KernelScope>(te, "KernelScope").def(py::init<>());
 
-  auto dtype_class = py::class_<Dtype>(te, "Dtype");
+  auto dtype_class =
+      py::class_<Dtype>(te, "Dtype").def(py::init(&parsePythonDtype));
+  py::implicitly_convertible<py::object, Dtype>();
 
 #define DTYPE_SINGLETON_ACCESSOR(ctype, name) \
   dtype_class.def_property_readonly_static(   \
@@ -139,21 +149,31 @@
 #undef EXPRHANDLE_CTOR
 
   py::class_<VarHandle>(te, "VarHandle")
+      .def(py::init<Dtype>())
       .def(py::init<const std::string&, Dtype>());
   py::class_<BufHandle>( // NOLINT
       te,
       "BufHandle")
       .def(
           py::init<const std::string&, const std::vector<ExprHandle>&, Dtype>())
-      .def("load", [](BufHandle& self, const std::vector<ExprHandle>& v) {
-        return Load::make(self, v);
+      .def(py::init<const std::vector<ExprHandle>&, Dtype>())
+      .def(py::init<Dtype>())
+      .def(
+          "load",
+          [](BufHandle& self, const std::vector<ExprHandle>& v) {
+            return Load::make(self, v);
+          })
+      .def("load", [](BufHandle& self, const ExprHandle& v) {
+        return Load::make(self, {v});
       });
 
   py::class_<Placeholder>(te, "Placeholder")
       .def(py::init<
            const std::string&,
            const Dtype&,
-           std::vector<ExprHandle>&>())
+           const std::vector<ExprHandle>&>())
+      .def(py::init<const std::vector<ExprHandle>&, const Dtype&>())
+      .def(py::init<const std::vector<ExprHandle>&>())
       .def(
           "load",
          [](Placeholder& self, const std::vector<ExprHandle>& v) {
@@ -183,6 +203,7 @@
   py::class_<DimArg>(te, "DimArg")
       .def(py::init<const ExprHandle&>())
       .def(py::init<const ExprHandle&, const std::string&>());
+  py::implicitly_convertible<ExprHandle, DimArg>();
 
   te.def(
       "Compute",
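
Usage note (not part of the patch): a minimal sketch of the cleaned-up bindings,
mirroring test_dynamic_shape above; the variable names below are illustrative only.

    with kernel_arena_scope():
        dN = te.VarHandle(torch.int32)     # plain torch.dtype accepted, name optional
        A = te.BufHandle(torch.float64)    # name and shape optional
        B = te.BufHandle(torch.float64)
        C = te.Compute('C', [dN], lambda i: A.load(i) - B.load(i))  # dN converts to DimArg implicitly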