[TensorExpr] More python binding cleanups (#60058)

Summary:
A few more quality-of-life improvements for NNC's Python bindings (a short usage sketch follows below):
- Accept standard `torch.dtype`s (rather than `te.Dtype`)
- Make buffer and variable names optional (they do not appear to affect behavior)
- Make shapes optional
- Add a few implicit conversions (e.g. `ExprHandle` to `DimArg`) to make user code cleaner

Followup to https://github.com/pytorch/pytorch/issues/59920
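
For illustration, a minimal sketch of what the bindings look like after this change, mirroring the updated test_dynamic_shape below. The torch._C._te import path and the kernel_arena_scope helper are assumptions based on the test file; everything else uses calls visible in the diff:

    import torch
    import torch._C._te as te  # assumption: module path of the _te submodule
    from contextlib import contextmanager

    @contextmanager
    def kernel_arena_scope():
        # Assumed test helper: keeps a KernelScope alive for the block.
        scope = te.KernelScope()
        try:
            yield
        finally:
            del scope

    with kernel_arena_scope():
        dN = te.VarHandle(torch.int32)   # name is now optional
        A = te.BufHandle(torch.float64)  # torch.dtype accepted; name and shape optional
        B = te.BufHandle(torch.float64)

        def compute(i):
            return A.load(i) - B.load(i)  # scalar-index overload, no list needed

        C = te.Compute('C', [dN], compute)  # ExprHandle converts to DimArg implicitly

        loopnest = te.LoopNest([C])
        loopnest.prepare_for_codegen()
        cg = te.construct_codegen('ir_eval', loopnest.simplify(), [A, B, C, dN])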

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60058

Reviewed By: bertmaher

Differential Revision: D29151953

Pulled By: jansel

fbshipit-source-id: c8286e329eb4ee3921ca0786e17248cf6a898bd8
Author: Jason Ansel
Date: 2021-06-16 19:59:42 -07:00
Committed by: Facebook GitHub Bot
Parent: c01939a9b1
Commit: 85517a2b70
4 changed files with 53 additions and 14 deletions

---- File 1 of 4 ----

@@ -49,7 +49,7 @@ class TestTensorExprPyBind(JitTestCase):
     def test_call_raw(self):
         with kernel_arena_scope():
             n = 16
-            cg = construct_adder(n, dtype=te.Dtype.Double)
+            cg = construct_adder(n, dtype=torch.float64)

             tA = torch.randn(n, dtype=torch.float64)
             tB = torch.randn(n, dtype=torch.float64)
@@ -59,7 +59,7 @@ class TestTensorExprPyBind(JitTestCase):
     def test_external_calls(self):
         with kernel_arena_scope():
-            dtype = te.Dtype.Float
+            dtype = torch.float32

             ONE = te.ExprHandle.int(1)
             FOUR = te.ExprHandle.int(4)
@@ -81,22 +81,21 @@ class TestTensorExprPyBind(JitTestCase):
     def test_dynamic_shape(self):
         with kernel_arena_scope():
-            dN = te.VarHandle("n", te.Dtype.Int)
-            A = te.Placeholder('A', te.Dtype.Double, [dN])
-            B = te.Placeholder('B', te.Dtype.Double, [dN])
+            dN = te.VarHandle(torch.int32)
+            A = te.BufHandle(torch.float64)
+            B = te.BufHandle(torch.float64)

             def compute(i):
-                return A.load([i]) - B.load([i])
+                return A.load(i) - B.load(i)

-            C = te.Compute('C', [te.DimArg(dN, 'i')], compute)
+            C = te.Compute('C', [dN], compute)

             loopnest = te.LoopNest([C])
             loopnest.prepare_for_codegen()
-            stmt = te.simplify(loopnest.root_stmt())

             cg = te.construct_codegen(
                 'ir_eval',
-                stmt,
+                loopnest.simplify(),
                 [A, B, C, dN])

             def test_with_shape(n):
@@ -109,6 +108,14 @@ class TestTensorExprPyBind(JitTestCase):
             test_with_shape(8)
             test_with_shape(31)

+    def test_dtype_error(self):
+        with kernel_arena_scope():
+            one = te.ExprHandle.int(1)
+            te.Placeholder([one], torch.float32)  # ok
+            te.Placeholder([one])  # ok
+            self.assertRaises(TypeError,
+                              lambda: te.Placeholder([one], "float55"))
+
     @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
     def test_kernel_with_tensor_inputs(self):
         def f(a, b, c):

---- File 2 of 4 ----

@@ -247,6 +247,11 @@ class TORCH_API BufHandle : public ExprHandle {
       Dtype dtype)
       : ExprHandle(Buf::make(name_hint, dims, dtype)) {}

+  BufHandle(const std::vector<ExprHandle>& dims, Dtype dtype)
+      : ExprHandle(Buf::make("_", dims, dtype)) {}
+
+  explicit BufHandle(Dtype dtype) : ExprHandle(Buf::make("_", {}, dtype)) {}
+
   explicit BufHandle(const Buf* node) : ExprHandle(node) {}

   const Buf* node() const {
     return static_cast<const Buf*>(ExprHandle::node());

---- File 3 of 4 ----

@@ -82,6 +82,12 @@ class Placeholder {
       const std::vector<ExprHandle>& dims)
       : Placeholder(BufHandle(name, dims, dtype)) {}

+  Placeholder(const std::vector<ExprHandle>& dims, const Dtype& dtype)
+      : Placeholder(BufHandle("_", dims, dtype)) {}
+
+  explicit Placeholder(const std::vector<ExprHandle>& dims)
+      : Placeholder(BufHandle("_", dims, kFloat)) {}
+
   const Buf* data() const {
     return data_;
   }

---- File 4 of 4 ----

@@ -44,7 +44,15 @@ ArgValue convertPyToArgValue(py::handle inp) {
       throw std::runtime_error("vector conversion failed");
     }
   } else {
-    throw std::runtime_error("nyi");
+    throw std::runtime_error("conversion not yet implemented");
   }
 }
+
+Dtype parsePythonDtype(py::handle obj) {
+  if (py::isinstance(obj, py::module_::import("torch").attr("dtype"))) {
+    return Dtype(reinterpret_cast<THPDtype*>(obj.ptr())->scalar_type);
+  } else {
+    throw std::runtime_error("expected a torch.dtype instance");
+  }
+}
@@ -55,7 +63,9 @@ void initTensorExprBindings(PyObject* module) {
   auto te = m.def_submodule("_te");

   py::class_<KernelScope>(te, "KernelScope").def(py::init<>());
-  auto dtype_class = py::class_<Dtype>(te, "Dtype");
+  auto dtype_class =
+      py::class_<Dtype>(te, "Dtype").def(py::init(&parsePythonDtype));
+  py::implicitly_convertible<py::object, Dtype>();

 #define DTYPE_SINGLETON_ACCESSOR(ctype, name) \
   dtype_class.def_property_readonly_static(   \
@@ -139,21 +149,31 @@ void initTensorExprBindings(PyObject* module) {
 #undef EXPRHANDLE_CTOR

   py::class_<VarHandle, ExprHandle>(te, "VarHandle")
+      .def(py::init<Dtype>())
       .def(py::init<const std::string&, Dtype>());

   py::class_<BufHandle, ExprHandle>( // NOLINT
       te,
       "BufHandle")
       .def(
           py::init<const std::string&, const std::vector<ExprHandle>&, Dtype>())
-      .def("load", [](BufHandle& self, const std::vector<ExprHandle>& v) {
+      .def(py::init<const std::vector<ExprHandle>&, Dtype>())
+      .def(py::init<Dtype>())
+      .def(
+          "load",
+          [](BufHandle& self, const std::vector<ExprHandle>& v) {
             return Load::make(self, v);
+          })
+      .def("load", [](BufHandle& self, const ExprHandle& v) {
+        return Load::make(self, {v});
       });

   py::class_<Placeholder>(te, "Placeholder")
       .def(py::init<
           const std::string&,
           const Dtype&,
-          std::vector<ExprHandle>&>())
+          const std::vector<ExprHandle>&>())
+      .def(py::init<const std::vector<ExprHandle>&, const Dtype&>())
+      .def(py::init<const std::vector<ExprHandle>&>())
       .def(
           "load",
           [](Placeholder& self, const std::vector<ExprHandle>& v) {
@@ -183,6 +203,7 @@ void initTensorExprBindings(PyObject* module) {
   py::class_<DimArg>(te, "DimArg")
       .def(py::init<const ExprHandle&>())
       .def(py::init<const ExprHandle&, const std::string&>());
+  py::implicitly_convertible<ExprHandle, DimArg>();

   te.def(
       "Compute",