Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[JIT] Add a pass for annotating graph with input types derived from sample inputs. (#57076)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57076

This pass is intended to be used in conjunction with the shape propagation pass: first we use sample inputs to specify shape info for graph inputs, and then we run shape-prop to infer the shapes of intermediate values in the graph.

Differential Revision: D28048290

Test Plan: Imported from OSS

Reviewed By: astaff

Pulled By: ZolotukhinM

fbshipit-source-id: 778d772e873d59d77af9f669f45dc44b9ee5e443
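A rough sketch of the intended two-step workflow (hypothetical usage assembled from this commit's test below, not additional committed code; TensorExprKernel requires the LLVM backend, as in the tests):

import torch

# IR whose inputs carry no shape info; compiling it directly would fail.
graph = torch._C.parse_ir("""
graph(%a : Tensor, %b : Tensor):
  %c : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::mul(%a, %b)
  return (%c)
""")

# Step 1: derive concrete input types from sample inputs.
sample_inputs = [torch.rand(4, 4), torch.rand(4, 4)]
torch._C._te.annotate_input_shapes(graph, sample_inputs)

# Step 2 (future work per the summary): run shape-prop so intermediate values
# such as %c get their types inferred rather than spelled out in the IR.

# With input shapes annotated, compilation succeeds.
kernel = torch._C._te.TensorExprKernel(graph)
res = kernel.run((sample_inputs[0], sample_inputs[1]))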
committed by Facebook GitHub Bot
parent 74a4868d9a
commit 3ad3d8bd3f
@@ -83,14 +83,12 @@ graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
 """
         graph = torch._C.parse_ir(graph_str)
 
-        with kernel_arena_scope():
-            kernel = torch._C._te.TensorExprKernel(graph)
-            res1 = kernel.run((x, y, z))
-            res2 = kernel.fallback((x, y, z))
-            correct = f(x, y, z)
-            np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
-            np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
+        kernel = torch._C._te.TensorExprKernel(graph)
+        res1 = kernel.run((x, y, z))
+        res2 = kernel.fallback((x, y, z))
+        correct = f(x, y, z)
+        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
+        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
 
     @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
     def test_kernel_with_scalar_inputs(self):
@@ -111,13 +109,50 @@ graph(%a.1 : Float(requires_grad=0, device=cpu),
 """
         graph = torch._C.parse_ir(graph_str)
 
-        with kernel_arena_scope():
-            kernel = torch._C._te.TensorExprKernel(graph)
-            res1 = kernel.run((x, y, z))
-            res2 = kernel.fallback((x, y, z))
-            correct = f(x, y, z)
-            np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
-            np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
+        kernel = torch._C._te.TensorExprKernel(graph)
+        res1 = kernel.run((x, y, z))
+        res2 = kernel.fallback((x, y, z))
+        correct = f(x, y, z)
+        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
+        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
+
+    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
+    def test_kernel_shape_prop(self):
+        device, size = 'cpu', (4, 4)
+        x = torch.rand(size, device=device)
+        y = torch.rand(size, device=device)
+
+        graph_str = """
+graph(%a : Tensor, %b : Tensor):
+  %c : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::mul(%a, %b)
+  return (%c)
+"""
+        graph = torch._C.parse_ir(graph_str)
+
+        exception_thrown = False
+        try:
+            kernel = torch._C._te.TensorExprKernel(graph)
+        except RuntimeError:
+            # Graph doesn't have shape info for inputs => compilation should
+            # fail
+            exception_thrown = True
+            pass
+        assert exception_thrown
+
+        # Inject shape info and try compiling again
+        example_inputs = [torch.rand(4, 4), torch.rand(4, 4)]
+        torch._C._te.annotate_input_shapes(graph, example_inputs)
+
+        # TODO: once we have shape propagation as well we should erase type
+        # info for %c from the input IR and run shape propagation here - it
+        # should be able to reconstruct that info
+
+        # Now compilation should pass
+        kernel = torch._C._te.TensorExprKernel(graph)
+
+        res = kernel.run((x, y))
+        correct = torch.mul(x, y)
+        np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)
 
 if __name__ == '__main__':
     run_tests()
@@ -227,6 +227,18 @@ bool matmulIsSupported(const torch::jit::Node* node) {
   return true;
 }
 
+void annotateInputShapes(
+    const std::shared_ptr<Graph>& graph,
+    const std::vector<c10::optional<at::Tensor>>& example_inputs) {
+  TORCH_INTERNAL_ASSERT(graph->inputs().size() == example_inputs.size());
+  for (size_t idx = 0; idx < example_inputs.size(); idx++) {
+    if (auto t = example_inputs[idx]) {
+      auto concrete_tensor_type = tensorTypeInCurrentExecutionContext(*t);
+      graph->inputs().at(idx)->setType(concrete_tensor_type);
+    }
+  }
+}
+
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch
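Since example_inputs is a vector of c10::optional<at::Tensor>, an empty entry is simply skipped and the corresponding graph input keeps its declared type. From Python this should correspond to passing None (a sketch, assuming the pybind layer maps None to an empty optional):

# Only the first input gets a concrete tensor type; None skips the second.
torch._C._te.annotate_input_shapes(graph, [torch.rand(4, 4), None])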
@@ -2063,8 +2075,9 @@ Tensor* tensorexpr::computeOperandValue(
       return computeCat(inputs, outputShape);
     }
     default: {
-      throw std::runtime_error("Unhandled node kind");
-      return nullptr;
+      std::string msg =
+          std::string("Unhandled node kind: ") + op.toQualString();
+      throw malformed_input(msg);
     }
   }
 }
@@ -2193,7 +2206,9 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
     }
 
     default: {
-      throw std::runtime_error("Unhandled node kind");
+      std::string msg = std::string("Unhandled node kind: ") +
+          v->node()->kind().toQualString();
+      throw malformed_input(msg);
     }
   }
   return nullptr;
@@ -2463,6 +2478,11 @@ Tensor* TensorExprKernel::bindInput(const torch::jit::Value* input) {
   switch (t->kind()) {
     case TypeKind::TensorType: {
       auto tt = input->type()->cast<TensorType>();
+      if (!input->isCompleteTensor()) {
+        std::string msg = std::string("Shapes for input '%") +
+            input->debugName() + "' are unknown";
+        throw malformed_input(msg);
+      }
       Placeholder inBuffer(
           "t" + input_name_map_[input],
           ToDtype(static_cast<ScalarType>(*tt->scalarType())),
@@ -2869,6 +2889,13 @@ Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
   TORCH_INTERNAL_ASSERT(bufs_.count(v));
   const Buf* buf = bufs_.at(v);
 
+  // No shape info is present in the graph
+  if (!tt->sizes().concrete_sizes()) {
+    std::string msg =
+        std::string("Shapes for output '%") + v->debugName() + "' are unknown";
+    throw malformed_input(msg);
+  }
+
   TORCH_INTERNAL_ASSERT(tt->sizes().concrete_sizes());
   const auto sizes = *tt->sizes().concrete_sizes();
   std::vector<int64_t> default_strides = TensorType::contiguousStridesOf(sizes);
@@ -2887,13 +2914,13 @@ Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
 
   auto dims = dimsFromSizes(sizesForValue(v));
   // We need to convert the output tensor so that its values are layed
-  // so that whene viewed from the output strides the values are correct.
+  // so that when viewed from the output strides the values are correct.
   // A contiguous Tensor of size(2, 3) with values 0-5 is layed out as:
   // [0] [1] [2] [3] [4] [5]
   // The same valued tensor with strides (2, 1) would be layed out like
   // [0] [3] [1] [4] [2] [5]
   // When we are doing the re-ordering of values into the output tensor,
-  // we are iterating per-element of the input, ad we are fixed
+  // we are iterating per-element of the input, and we are fixed
   // in indexing in to the output tensor at [i, j] = val
   // `val` we want here is equal to the indices for the output
   // tensor that would have given the same position as the output
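A side illustration of the memory layout the comment above describes (not part of the commit): a (2, 3) tensor whose element stride is 1 along dim 0 and 2 along dim 1 stores the values 0-5 as [0] [3] [1] [4] [2] [5]:

import torch

t = torch.arange(6.).reshape(2, 3)  # contiguous: strides (3, 1), storage [0, 1, 2, 3, 4, 5]
u = t.t().contiguous().t()          # same values viewed with strides (1, 2)
assert u.stride() == (1, 2)
assert u.storage().tolist() == [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]
assert torch.equal(t, u)            # identical when read through their strides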
@@ -217,6 +217,10 @@ TORCH_API bool& getCatWoConditionals();
 TORCH_API c10::optional<at::Device> pickDeviceType(
     const at::ArrayRef<torch::jit::Value*>& inputs);
 
+TORCH_API void annotateInputShapes(
+    const std::shared_ptr<Graph>& graph,
+    const std::vector<c10::optional<at::Tensor>>& example_inputs);
+
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch
@@ -667,6 +667,7 @@ void initTensorExprBindings(PyObject* module) {
         }
         return cg;
       });
+  te.def("annotate_input_shapes", &tensorexpr::annotateInputShapes);
 }
 } // namespace jit
 } // namespace torch