[JIT] Add a pass for annotating graph with input types derived from sample inputs. (#57076)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57076

This pass is intended to be used in conjunction with the shape propagation
pass: first we use sample inputs to specify shape info for the graph inputs,
and then we run shape-prop to infer the shapes of intermediate values in the
graph.
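
For reference, the intended flow on the Python side looks roughly like this
(a minimal sketch mirroring the new test added in this PR; the output type is
spelled out in the IR because shape propagation is not wired up yet):

    import torch

    graph_str = """
graph(%a : Tensor, %b : Tensor):
  %c : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::mul(%a, %b)
  return (%c)
    """
    graph = torch._C.parse_ir(graph_str)

    # Inject concrete input types derived from sample inputs.
    example_inputs = [torch.rand(4, 4), torch.rand(4, 4)]
    torch._C._te.annotate_input_shapes(graph, example_inputs)

    # With input shapes known, the kernel compiles.
    kernel = torch._C._te.TensorExprKernel(graph)
    res = kernel.run((example_inputs[0], example_inputs[1]))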

Differential Revision: D28048290

Test Plan: Imported from OSS

Reviewed By: astaff

Pulled By: ZolotukhinM

fbshipit-source-id: 778d772e873d59d77af9f669f45dc44b9ee5e443
Author:    Mikhail Zolotukhin
Date:      2021-05-03 20:00:00 -07:00
Committer: Facebook GitHub Bot
Parent:    74a4868d9a
Commit:    3ad3d8bd3f

4 changed files with 86 additions and 19 deletions

View File

@@ -83,14 +83,12 @@ graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
 """
         graph = torch._C.parse_ir(graph_str)
-        with kernel_arena_scope():
-            kernel = torch._C._te.TensorExprKernel(graph)
-            res1 = kernel.run((x, y, z))
-            res2 = kernel.fallback((x, y, z))
-            correct = f(x, y, z)
-            np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
-            np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
+        kernel = torch._C._te.TensorExprKernel(graph)
+        res1 = kernel.run((x, y, z))
+        res2 = kernel.fallback((x, y, z))
+        correct = f(x, y, z)
+        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
+        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

     @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
     def test_kernel_with_scalar_inputs(self):
@@ -111,13 +109,50 @@ graph(%a.1 : Float(requires_grad=0, device=cpu),
 """
         graph = torch._C.parse_ir(graph_str)
-        with kernel_arena_scope():
-            kernel = torch._C._te.TensorExprKernel(graph)
-            res1 = kernel.run((x, y, z))
-            res2 = kernel.fallback((x, y, z))
-            correct = f(x, y, z)
-            np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
-            np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
+        kernel = torch._C._te.TensorExprKernel(graph)
+        res1 = kernel.run((x, y, z))
+        res2 = kernel.fallback((x, y, z))
+        correct = f(x, y, z)
+        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
+        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
+
+    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
+    def test_kernel_shape_prop(self):
+        device, size = 'cpu', (4, 4)
+        x = torch.rand(size, device=device)
+        y = torch.rand(size, device=device)
+
+        graph_str = """
+graph(%a : Tensor, %b : Tensor):
+  %c : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::mul(%a, %b)
+  return (%c)
+        """
+        graph = torch._C.parse_ir(graph_str)
+
+        exception_thrown = False
+        try:
+            kernel = torch._C._te.TensorExprKernel(graph)
+        except RuntimeError:
+            # Graph doesn't have shape info for inputs => compilation should
+            # fail
+            exception_thrown = True
+            pass
+        assert exception_thrown
+
+        # Inject shape info and try compiling again
+        example_inputs = [torch.rand(4, 4), torch.rand(4, 4)]
+        torch._C._te.annotate_input_shapes(graph, example_inputs)
+
+        # TODO: once we have shape propagation as well we should erase type
+        # info for %c from the input IR and run shape propagation here - it
+        # should be able to reconstruct that info
+
+        # Now compilation should pass
+        kernel = torch._C._te.TensorExprKernel(graph)
+
+        res = kernel.run((x, y))
+        correct = torch.mul(x, y)
+        np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)

 if __name__ == '__main__':
     run_tests()
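
The new test drives the pass through hand-written IR; the same entry point
should work on a graph obtained from TorchScript. A hypothetical example
(assuming a plain scripted function whose graph inputs line up one-to-one
with the sample inputs):

    import torch

    @torch.jit.script
    def f(a, b):
        return a * b

    torch._C._te.annotate_input_shapes(
        f.graph, [torch.rand(4, 4), torch.rand(4, 4)])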

View File

@@ -227,6 +227,18 @@ bool matmulIsSupported(const torch::jit::Node* node) {
   return true;
 }

+void annotateInputShapes(
+    const std::shared_ptr<Graph>& graph,
+    const std::vector<c10::optional<at::Tensor>>& example_inputs) {
+  TORCH_INTERNAL_ASSERT(graph->inputs().size() == example_inputs.size());
+  for (size_t idx = 0; idx < example_inputs.size(); idx++) {
+    if (auto t = example_inputs[idx]) {
+      auto concrete_tensor_type = tensorTypeInCurrentExecutionContext(*t);
+      graph->inputs().at(idx)->setType(concrete_tensor_type);
+    }
+  }
+}
+
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch
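
Note that the sample inputs are c10::optional<at::Tensor>s: a c10::nullopt
entry (None from Python) is simply skipped, so graph inputs that are not
tensors (e.g. scalars) can be left unannotated. A hypothetical call:

    # Annotate only the first input; the None entry is skipped.
    torch._C._te.annotate_input_shapes(graph, [torch.rand(4, 4), None])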
@@ -2063,8 +2075,9 @@ Tensor* tensorexpr::computeOperandValue(
       return computeCat(inputs, outputShape);
     }
     default: {
-      throw std::runtime_error("Unhandled node kind");
-      return nullptr;
+      std::string msg =
+          std::string("Unhandled node kind: ") + op.toQualString();
+      throw malformed_input(msg);
     }
   }
 }
@@ -2193,7 +2206,9 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
     }
     default: {
-      throw std::runtime_error("Unhandled node kind");
+      std::string msg = std::string("Unhandled node kind: ") +
+          v->node()->kind().toQualString();
+      throw malformed_input(msg);
     }
   }
   return nullptr;
@@ -2463,6 +2478,11 @@ Tensor* TensorExprKernel::bindInput(const torch::jit::Value* input) {
   switch (t->kind()) {
     case TypeKind::TensorType: {
       auto tt = input->type()->cast<TensorType>();
+      if (!input->isCompleteTensor()) {
+        std::string msg = std::string("Shapes for input '%") +
+            input->debugName() + "' are unknown";
+        throw malformed_input(msg);
+      }
       Placeholder inBuffer(
           "t" + input_name_map_[input],
           ToDtype(static_cast<ScalarType>(*tt->scalarType())),
@@ -2869,6 +2889,13 @@ Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
   TORCH_INTERNAL_ASSERT(bufs_.count(v));
   const Buf* buf = bufs_.at(v);

+  // No shape info is present in the graph
+  if (!tt->sizes().concrete_sizes()) {
+    std::string msg =
+        std::string("Shapes for output '%") + v->debugName() + "' are unknown";
+    throw malformed_input(msg);
+  }
+
   TORCH_INTERNAL_ASSERT(tt->sizes().concrete_sizes());
   const auto sizes = *tt->sizes().concrete_sizes();
   std::vector<int64_t> default_strides = TensorType::contiguousStridesOf(sizes);
@@ -2887,13 +2914,13 @@ Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
   auto dims = dimsFromSizes(sizesForValue(v));
   // We need to convert the output tensor so that its values are layed
-  // so that whene viewed from the output strides the values are correct.
+  // so that when viewed from the output strides the values are correct.
   // A contiguous Tensor of size(2, 3) with values 0-5 is layed out as:
   // [0] [1] [2] [3] [4] [5]
   // The same valued tensor with strides (2, 1) would be layed out like
   // [0] [3] [1] [4] [2] [5]
   // When we are doing the re-ordering of values into the output tensor,
-  // we are iterating per-element of the input, ad we are fixed
+  // we are iterating per-element of the input, and we are fixed
   // in indexing in to the output tensor at [i, j] = val
   // `val` we want here is equal to the indices for the output
   // tensor that would have given the same position as the output
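
The layout described in the comment above can be reproduced directly in
PyTorch; a small sketch, using column-major strides on a (2, 3) tensor to
produce the interleaving shown:

    import torch

    t = torch.arange(6).reshape(2, 3)   # storage: [0] [1] [2] [3] [4] [5]
    s = torch.empty_strided((2, 3), (1, 2), dtype=t.dtype)
    s.copy_(t)                          # same values, strided layout
    # The raw storage order is now [0] [3] [1] [4] [2] [5]:
    print(s.as_strided((6,), (1,)))     # tensor([0, 3, 1, 4, 2, 5])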

View File

@@ -217,6 +217,10 @@ TORCH_API bool& getCatWoConditionals();
 TORCH_API c10::optional<at::Device> pickDeviceType(
     const at::ArrayRef<torch::jit::Value*>& inputs);

+TORCH_API void annotateInputShapes(
+    const std::shared_ptr<Graph>& graph,
+    const std::vector<c10::optional<at::Tensor>>& example_inputs);
+
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch

View File

@@ -667,6 +667,7 @@ void initTensorExprBindings(PyObject* module) {
         }
         return cg;
       });
+  te.def("annotate_input_shapes", &tensorexpr::annotateInputShapes);
 }
 } // namespace jit
 } // namespace torch