mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[TensorExpr] More python binding cleanups (#60058)
Summary: A few more quality of life improvements for NNC's python bindings: - Use standard `torch.dtype`s (rather than `te.Dtype`) - Make names optional (they don't seem to matter) - Make shapes optional - A few implicit conversions to make code cleaner Followup to https://github.com/pytorch/pytorch/issues/59920 Pull Request resolved: https://github.com/pytorch/pytorch/pull/60058 Reviewed By: bertmaher Differential Revision: D29151953 Pulled By: jansel fbshipit-source-id: c8286e329eb4ee3921ca0786e17248cf6a898bd8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c01939a9b1
commit
85517a2b70
@ -49,7 +49,7 @@ class TestTensorExprPyBind(JitTestCase):
|
||||
def test_call_raw(self):
|
||||
with kernel_arena_scope():
|
||||
n = 16
|
||||
cg = construct_adder(n, dtype=te.Dtype.Double)
|
||||
cg = construct_adder(n, dtype=torch.float64)
|
||||
|
||||
tA = torch.randn(n, dtype=torch.float64)
|
||||
tB = torch.randn(n, dtype=torch.float64)
|
||||
@ -59,7 +59,7 @@ class TestTensorExprPyBind(JitTestCase):
|
||||
|
||||
def test_external_calls(self):
|
||||
with kernel_arena_scope():
|
||||
dtype = te.Dtype.Float
|
||||
dtype = torch.float32
|
||||
|
||||
ONE = te.ExprHandle.int(1)
|
||||
FOUR = te.ExprHandle.int(4)
|
||||
@ -81,22 +81,21 @@ class TestTensorExprPyBind(JitTestCase):
|
||||
|
||||
def test_dynamic_shape(self):
|
||||
with kernel_arena_scope():
|
||||
dN = te.VarHandle("n", te.Dtype.Int)
|
||||
A = te.Placeholder('A', te.Dtype.Double, [dN])
|
||||
B = te.Placeholder('B', te.Dtype.Double, [dN])
|
||||
dN = te.VarHandle(torch.int32)
|
||||
A = te.BufHandle(torch.float64)
|
||||
B = te.BufHandle(torch.float64)
|
||||
|
||||
def compute(i):
|
||||
return A.load([i]) - B.load([i])
|
||||
return A.load(i) - B.load(i)
|
||||
|
||||
C = te.Compute('C', [te.DimArg(dN, 'i')], compute)
|
||||
C = te.Compute('C', [dN], compute)
|
||||
|
||||
loopnest = te.LoopNest([C])
|
||||
loopnest.prepare_for_codegen()
|
||||
stmt = te.simplify(loopnest.root_stmt())
|
||||
|
||||
cg = te.construct_codegen(
|
||||
'ir_eval',
|
||||
stmt,
|
||||
loopnest.simplify(),
|
||||
[A, B, C, dN])
|
||||
|
||||
def test_with_shape(n):
|
||||
@ -109,6 +108,14 @@ class TestTensorExprPyBind(JitTestCase):
|
||||
test_with_shape(8)
|
||||
test_with_shape(31)
|
||||
|
||||
def test_dtype_error(self):
|
||||
with kernel_arena_scope():
|
||||
one = te.ExprHandle.int(1)
|
||||
te.Placeholder([one], torch.float32) # ok
|
||||
te.Placeholder([one]) # ok
|
||||
self.assertRaises(TypeError,
|
||||
lambda: te.Placeholder([one], "float55"))
|
||||
|
||||
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
|
||||
def test_kernel_with_tensor_inputs(self):
|
||||
def f(a, b, c):
|
||||
|
@ -247,6 +247,11 @@ class TORCH_API BufHandle : public ExprHandle {
|
||||
Dtype dtype)
|
||||
: ExprHandle(Buf::make(name_hint, dims, dtype)) {}
|
||||
|
||||
BufHandle(const std::vector<ExprHandle>& dims, Dtype dtype)
|
||||
: ExprHandle(Buf::make("_", dims, dtype)) {}
|
||||
|
||||
explicit BufHandle(Dtype dtype) : ExprHandle(Buf::make("_", {}, dtype)) {}
|
||||
|
||||
explicit BufHandle(const Buf* node) : ExprHandle(node) {}
|
||||
const Buf* node() const {
|
||||
return static_cast<const Buf*>(ExprHandle::node());
|
||||
|
@ -82,6 +82,12 @@ class Placeholder {
|
||||
const std::vector<ExprHandle>& dims)
|
||||
: Placeholder(BufHandle(name, dims, dtype)) {}
|
||||
|
||||
Placeholder(const std::vector<ExprHandle>& dims, const Dtype& dtype)
|
||||
: Placeholder(BufHandle("_", dims, dtype)) {}
|
||||
|
||||
explicit Placeholder(const std::vector<ExprHandle>& dims)
|
||||
: Placeholder(BufHandle("_", dims, kFloat)) {}
|
||||
|
||||
const Buf* data() const {
|
||||
return data_;
|
||||
}
|
||||
|
@ -44,7 +44,15 @@ ArgValue convertPyToArgValue(py::handle inp) {
|
||||
throw std::runtime_error("vector conversion failed");
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("nyi");
|
||||
throw std::runtime_error("conversion not yet implemented");
|
||||
}
|
||||
}
|
||||
|
||||
Dtype parsePythonDtype(py::handle obj) {
|
||||
if (py::isinstance(obj, py::module_::import("torch").attr("dtype"))) {
|
||||
return Dtype(reinterpret_cast<THPDtype*>(obj.ptr())->scalar_type);
|
||||
} else {
|
||||
throw std::runtime_error("expected a torch.dtype instance");
|
||||
}
|
||||
}
|
||||
|
||||
@ -55,7 +63,9 @@ void initTensorExprBindings(PyObject* module) {
|
||||
auto te = m.def_submodule("_te");
|
||||
py::class_<KernelScope>(te, "KernelScope").def(py::init<>());
|
||||
|
||||
auto dtype_class = py::class_<Dtype>(te, "Dtype");
|
||||
auto dtype_class =
|
||||
py::class_<Dtype>(te, "Dtype").def(py::init(&parsePythonDtype));
|
||||
py::implicitly_convertible<py::object, Dtype>();
|
||||
|
||||
#define DTYPE_SINGLETON_ACCESSOR(ctype, name) \
|
||||
dtype_class.def_property_readonly_static( \
|
||||
@ -139,21 +149,31 @@ void initTensorExprBindings(PyObject* module) {
|
||||
#undef EXPRHANDLE_CTOR
|
||||
|
||||
py::class_<VarHandle, ExprHandle>(te, "VarHandle")
|
||||
.def(py::init<Dtype>())
|
||||
.def(py::init<const std::string&, Dtype>());
|
||||
py::class_<BufHandle, ExprHandle>( // NOLINT
|
||||
te,
|
||||
"BufHandle")
|
||||
.def(
|
||||
py::init<const std::string&, const std::vector<ExprHandle>&, Dtype>())
|
||||
.def("load", [](BufHandle& self, const std::vector<ExprHandle>& v) {
|
||||
return Load::make(self, v);
|
||||
.def(py::init<const std::vector<ExprHandle>&, Dtype>())
|
||||
.def(py::init<Dtype>())
|
||||
.def(
|
||||
"load",
|
||||
[](BufHandle& self, const std::vector<ExprHandle>& v) {
|
||||
return Load::make(self, v);
|
||||
})
|
||||
.def("load", [](BufHandle& self, const ExprHandle& v) {
|
||||
return Load::make(self, {v});
|
||||
});
|
||||
|
||||
py::class_<Placeholder>(te, "Placeholder")
|
||||
.def(py::init<
|
||||
const std::string&,
|
||||
const Dtype&,
|
||||
std::vector<ExprHandle>&>())
|
||||
const std::vector<ExprHandle>&>())
|
||||
.def(py::init<const std::vector<ExprHandle>&, const Dtype&>())
|
||||
.def(py::init<const std::vector<ExprHandle>&>())
|
||||
.def(
|
||||
"load",
|
||||
[](Placeholder& self, const std::vector<ExprHandle>& v) {
|
||||
@ -183,6 +203,7 @@ void initTensorExprBindings(PyObject* module) {
|
||||
py::class_<DimArg>(te, "DimArg")
|
||||
.def(py::init<const ExprHandle&>())
|
||||
.def(py::init<const ExprHandle&, const std::string&>());
|
||||
py::implicitly_convertible<ExprHandle, DimArg>();
|
||||
|
||||
te.def(
|
||||
"Compute",
|
||||
|
Reference in New Issue
Block a user