[Static Runtime] Fix aten::index_put list conversions (#85298)

Summary: Static runtime's list construct return value is always a `GenericList`, so we cannot use the `toOptionalTensorList` method in the general case; we must convert each item individually.
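For context, the change performs that per-item conversion via `at::native::toListOfOptionalTensors` (see the diff below). A rough standalone equivalent, using the hypothetical helper name `toOptionalTensors`, might look like this:

```cpp
#include <ATen/ATen.h>
#include <ATen/core/List.h>
#include <ATen/core/ivalue.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>

// Sketch only: turn each element of a generic list (Tensor or None) into a
// c10::optional<at::Tensor>, which is what aten::index_put expects for its
// `indices` argument. IValue::toOptionalTensorList() cannot be used here
// because it assumes the list is already a typed optional-tensor list.
c10::List<c10::optional<at::Tensor>> toOptionalTensors(
    c10::ArrayRef<c10::IValue> generic_list) {
  c10::List<c10::optional<at::Tensor>> result;
  result.reserve(generic_list.size());
  for (const c10::IValue& item : generic_list) {
    if (item.isNone()) {
      result.push_back(c10::nullopt);
    } else {
      result.push_back(item.toTensor());
    }
  }
  return result;
}
```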

Test Plan: New unit test

Differential Revision: D39628979

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85298
Approved by: https://github.com/tenpercent
Author: Mike Iovine
Date: 2022-09-22 20:21:52 +00:00
Committed by: PyTorch MergeBot
parent bd854588fb
commit e4899764b2
4 changed files with 19 additions and 116 deletions


@ -2523,6 +2523,15 @@ TEST(StaticRuntime, Index_Put) {
auto indices_b = c10::List<at::Tensor>{torch::tensor({0}, at::kLong)};
std::vector<IValue> args1{a, indices_b, values_a, false};
testStaticRuntime(index_put_non_optional_str, args1);
const auto index_put_list_construct = R"JIT(
    def forward(self, a: Tensor, indices: Tensor, values: Tensor, accumulate: bool):
        indices: List[Optional[Tensor]] = [indices]
        return torch.index_put(a, indices, values, accumulate).clone()
)JIT";
std::vector<IValue> args2{a, torch::tensor({0}, at::kLong), values_a, false};
testStaticRuntime(index_put_list_construct, args2);
}
TEST(StaticRuntime, Item) {


@ -73,11 +73,7 @@ Static runtime's memory planner does two things:
1) Coalesces internal allocations for tensor storage
2) Does static analysis to figure out how to efficiently re-use memory.
For (2), there are two algorithms used. Specify which algorithm with
the `memory_planner_algorithm` field in `StaticModuleOptions`. The
algorithms are briefly described below:
### Standard Resizing (default)
### Standard Resizing
Static runtime will record the space required for each intermediate managed tensor it sees
on the first inference iteration. An intermediate tensor is *managed* if two conditions
are satisfied:
@ -103,16 +99,6 @@ will occur. This is why dynamic shapes will degrade performance. With the standa
strategy, static runtime will record the new largest tensor size in each storage group at the
end of the iteration and allocate a buffer that is possibly bigger on the next iteration.
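As a rough illustration of this resizing strategy, here is a standalone sketch (the `StorageGroup` struct and `planSlabSize` function are hypothetical names, not static runtime's actual memory planner):

```cpp
#include <cstddef>
#include <vector>

// Each storage group remembers the largest tensor size it has seen so far.
struct StorageGroup {
  size_t max_bytes = 0;
};

// At the end of an iteration, record the size observed for each group (one
// entry per group) and return the slab size to allocate on the next
// iteration. The slab only grows, which is why dynamic shapes degrade
// performance.
size_t planSlabSize(std::vector<StorageGroup>& groups,
                    const std::vector<size_t>& observed_bytes) {
  size_t total = 0;
  for (size_t i = 0; i < groups.size(); ++i) {
    if (observed_bytes[i] > groups[i].max_bytes) {
      groups[i].max_bytes = observed_bytes[i];
    }
    total += groups[i].max_bytes;
  }
  return total;
}
```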
### Precomputed Offsets Memory Planner (experimental)
This algorithm is based on [arXiv:2001.03288](https://arxiv.org/pdf/2001.03288.pdf), section 5.2 "Greedy by Size for Offset Calculation".
The paper describes the algorithm in detail, but the key considerations are:
1) This algorithm will tend to be more efficient with respect to maximum memory usage
2) This algorithm will *not* resize the tensor buffer since recomputing offsets is a quadratic operation. Therefore,
to avoid performance degradation, the model should be warmed up with the largest possible inputs.
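For reference, here is a standalone sketch of the greedy-by-size offset assignment described above (the `Allocation` struct and `assignOffsetsGreedyBySize` function are illustrative names, not static runtime's implementation):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// One entry per managed tensor: its size, the node-index range over which it
// is live, and the offset assigned into the shared buffer.
struct Allocation {
  size_t size = 0;
  size_t first_use = 0;
  size_t last_use = 0;
  size_t offset = 0;
};

bool lifetimesOverlap(const Allocation& a, const Allocation& b) {
  return a.first_use <= b.last_use && b.first_use <= a.last_use;
}

// Assigns offsets in place and returns the total buffer size required.
size_t assignOffsetsGreedyBySize(std::vector<Allocation>& allocs) {
  // Consider the largest allocations first.
  std::vector<Allocation*> order;
  for (auto& a : allocs) {
    order.push_back(&a);
  }
  std::sort(order.begin(), order.end(),
            [](const Allocation* a, const Allocation* b) { return a->size > b->size; });

  std::vector<const Allocation*> placed;
  size_t total = 0;
  for (Allocation* cur : order) {
    // Collect already-placed allocations whose lifetimes overlap with `cur`,
    // ordered by their assigned offsets.
    std::vector<const Allocation*> conflicts;
    for (const Allocation* p : placed) {
      if (lifetimesOverlap(*cur, *p)) {
        conflicts.push_back(p);
      }
    }
    std::sort(conflicts.begin(), conflicts.end(),
              [](const Allocation* a, const Allocation* b) { return a->offset < b->offset; });

    // Place `cur` in the first gap between conflicting allocations that is
    // large enough to hold it; otherwise append it after the last conflict.
    size_t candidate = 0;
    for (const Allocation* c : conflicts) {
      if (c->offset >= candidate + cur->size) {
        break;
      }
      candidate = std::max(candidate, c->offset + c->size);
    }
    cur->offset = candidate;
    total = std::max(total, candidate + cur->size);
    placed.push_back(cur);
  }
  return total;
}
```

Handling a new set of sizes would mean re-running this whole placement, which is why the buffer is not resized and warmup with the largest expected inputs matters.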
### Managed Output Tensors
`StaticRuntime` can optionally manage output tensors via the `manage_output_tensors` option in `StaticModuleOptions`.
@ -121,17 +107,7 @@ output tensors is separated from the one containing intermediate tensors. The fo
of the inference run, but the latter needs to be deallocated at the end of the run.
Under the hood, we store a refcounted pointer to the output arena in each returned `Tensor`. The arena is destroyed
only when all output tensors are destroyed.
```
auto output = runtime(args);
auto& elems = output.toTupleRef().elements();
auto tensor_1 = elems[0].toTensor();
auto tensor_2 = elems[1].toTensor();
tensor_1 = at::empty({0}); // Output buffer not deallocated yet!
tensor_2 = at::empty({0}); // This call deallocates the output buffer.
```
explicitly.
## Registering Ops
Static runtime has three op execution modes:
@ -255,68 +231,8 @@ upon `StaticModule` construction according to the out variant/native/JIT fallbac
* `prim::Loop` operations have a `BlockRunner` for the execution of the looping sub-block.
* `prim::fork` operations have `torch::jit::TaskLauncher` (`std::function<void(std::function<void()>)>`) responsible for forked graph execution.
### `Asynchronous Execution`
### Asynchronous Execution
`StaticRuntime::runAsync()` API allows execution of asynchronous operations on `TaskLauncher` passed as arguments.
`StaticRuntime::runAsync()` performs inline execution of parent graph on caller thread and asynchronous operations like `prim::fork` are executed
on the launcher passed in. In the case that no launcher is provided, the execution happens on `at::launch` inter-op thread pool.
### `prim::fork` and `aten::wait`
`prim::fork` takes the callable function/method/Module (say `fn`) and arguments to that callable `args` and `kwargs`. Since the execution of forked function `fn` happens asynchronously and fork returns immediately after creating the async task, the `fn` may not have been executed by the time the line of code after the `fork` call is reached. Thus, `aten::wait` is used to wait for the async `fn` task to be completed. `prim::fork` nodes contain the sub-graph for the forked parts of the network. Each parent graph creates a separate instance of `StaticModule` for the
forked sub-graph and `StaticRuntime` instances are created on the fly during runtime as the fork nodes are executed. The forked subgraph execution
happens asynchronously on the launcher provided to `StaticRuntime::runAsync()`, or on the `at::launch` executor by default. The `aten::wait` operator
waits on the future returned by the corresponding `prim::fork` operation.
#### Inter-op parallelism via fork/wait ops
A sample model with independent operations can be parallelized by inserting fork/wait nodes into the graph:
```python
def CNNBlock(x):
    out_1 = conv1(x)
    out_1 = conv2(out_1)
    out_1 = max_pool1(out_1)
    out_2 = conv3(x)
    out_2 = max_pool2(out_2)
    out_merged = conv4(out_1 + out_2)
    return out_merged
```
The two branches of convolution/pooling operations can be parallelized by inserting fork nodes so that both branches can
execute in parallel:
```python
def branch1(x):
    out = conv1(x)
    out = conv2(out)
    return max_pool1(out)

def branch2(x):
    out = conv3(x)
    return max_pool2(out)

def CNNBlock(x):
    fut_1 = torch.jit.fork(branch1, x)
    fut_2 = torch.jit.fork(branch2, x)
    out_merged = conv4(torch.jit.wait(fut_1) + torch.jit.wait(fut_2))
    return out_merged
```
**Execution without fork/wait operations:**
```
<CALLER THREAD>: conv1 ─> conv2 ─> max_pool1 ─> conv3 ─> max_pool2 ─> conv4
```
**Execution with fork/wait operations:**
```
<CALLER THREAD>  : fork1 ──> fork2 ──────────> wait(fut_1) ─> wait(fut_2) ─> conv4
                     |       |
                     |       |
<INTER-OP THREAD>:   |       conv3 ──────────────────> max_pool2 ─> fut_2
                     |
<INTER-OP THREAD>: conv1 ─> conv2 ─> max_pool1 ──> fut_1
```
More examples of fork/wait operations and inter-op parallelism in PyTorch can be found in
[Dynamic Parallelism in TorchScript](https://pytorch.org/tutorials/advanced/torch-script-parallelism.html).
The `StaticRuntime::runAsync()` API allows the execution of asynchronous operations on the `TaskLauncher` passed as an argument.
`StaticRuntime::runAsync()` performs inline execution of the parent graph on the caller thread, while asynchronous operations like `prim::fork` are executed
on the launcher passed in. If no launcher is provided, the execution happens via `at::launch`, i.e. on the inter-op thread pool.
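For illustration, a caller-supplied launcher compatible with the `torch::jit::TaskLauncher` type mentioned above (`std::function<void(std::function<void()>)>`) could look like the sketch below. `myLauncher` is a hypothetical name, and the commented usage assumes `runAsync` accepts the launcher as its final argument, so the exact overload should be checked against the header:

```cpp
#include <functional>
#include <thread>

// A minimal TaskLauncher-compatible callable: it receives the work item for a
// forked sub-graph and decides where to run it. Here we simply spawn a
// detached thread; a real launcher would typically hand the work to a thread
// pool.
void myLauncher(std::function<void()> work) {
  std::thread(std::move(work)).detach();
}

// Hypothetical usage (assumed signature), where `runtime` is a StaticRuntime
// and `args` holds the input IValues:
//   auto future = runtime.runAsync(args, /*kwargs=*/{}, myLauncher);
//   future->wait();
```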


@ -109,12 +109,6 @@ bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph) {
namespace {
// A CustomClass extending torch::CustomClassHolder can be typecast to an
// IValue. StaticRuntimeMetadata is created so that we can attach SR metadata
// to the IR's prim::fork nodes. Such custom classes must be registered before
// they can be used as IValues. The variable below is UNUSED but NEEDED: its
// initializer invokes the class_ constructor, which performs the class
// registration.
auto sr_metadata_registerer = torch::class_<StaticRuntimeMetadata>(
"StaticRuntime",
"StaticRuntimeMetadata");


@ -313,10 +313,13 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::index_put, aten_index_put, [](Node* n) -> SROperator {
if (n->matches(torch::schema(
"aten::index_put(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor")) ||
n->matches(torch::schema(
"aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"))) {
return [](ProcessedNode* p_node) {
const auto& self = p_node->Input(0).toTensor();
const auto& indices = p_node->Input(1).toOptionalTensorList();
const auto& indices =
at::native::toListOfOptionalTensors(p_node->Input(1).toListRef());
const auto& values = p_node->Input(2).toTensor();
const auto accumulate = p_node->Input(3).toBool();
p_node->Output(0) =
@ -324,25 +327,6 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::index_put, aten_index_put, [](Node* n) ->
};
}
if (n->matches(torch::schema(
"aten::index_put(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor"))) {
return [](ProcessedNode* p_node) {
const auto& self = p_node->Input(0).toTensor();
const auto indices = p_node->Input(1).toTensorList();
c10::List<c10::optional<at::Tensor>> opt_list_indices;
opt_list_indices.reserve(indices.size());
for (const auto& ten : indices) {
opt_list_indices.push_back(ten);
}
const auto& values = p_node->Input(2).toTensor();
const auto accumulate = p_node->Input(3).toBool();
p_node->Output(0) =
at::native::index_put(self, opt_list_indices, values, accumulate);
};
}
LogAndDumpSchema(n);
return nullptr;
});