[Static Runtime] Fix aten::index_put list conversions (#85298)
Summary: Apparently static runtime's list construct return value is always a `GenericList`, so we cannot use the `toOptionalTensorList` method in the general case -- we must convert each item individually.

Test Plan: New unit test

Differential Revision: D39628979

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85298
Approved by: https://github.com/tenpercent
Committed by: PyTorch MergeBot
Parent: bd854588fb
Commit: e4899764b2
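As a rough illustration of the per-item conversion the summary describes, here is a minimal sketch (the helper name `toOptionalTensors` is hypothetical; the actual change below uses `at::native::toListOfOptionalTensors`):

```cpp
#include <ATen/core/List.h>
#include <ATen/core/ivalue.h>

// Sketch: a ListConstruct in static runtime yields a GenericList, so the
// IValue cannot simply be reinterpreted as an optional-tensor list. Instead,
// walk the generic elements and convert each one individually.
c10::List<c10::optional<at::Tensor>> toOptionalTensors(const c10::IValue& list_ivalue) {
  c10::List<c10::optional<at::Tensor>> out;
  const auto elems = list_ivalue.toListRef();  // ArrayRef<IValue> over the GenericList
  out.reserve(elems.size());
  for (const auto& elem : elems) {
    out.push_back(elem.toTensor());  // convert each item on its own
  }
  return out;
}
```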
@@ -2523,6 +2523,15 @@ TEST(StaticRuntime, Index_Put) {
  auto indices_b = c10::List<at::Tensor>{torch::tensor({0}, at::kLong)};
  std::vector<IValue> args1{a, indices_b, values_a, false};
  testStaticRuntime(index_put_non_optional_str, args1);

  const auto index_put_list_construct = R"JIT(
    def forward(self, a: Tensor, indices: Tensor, values: Tensor, accumulate: bool):
        indices: List[Optional[Tensor]] = [indices]
        return torch.index_put(a, indices, values, accumulate).clone()
  )JIT";

  std::vector<IValue> args2{a, torch::tensor({0}, at::kLong), values_a, false};
  testStaticRuntime(index_put_list_construct, args2);
}

TEST(StaticRuntime, Item) {
@@ -73,11 +73,7 @@ Static runtime's memory planner does two things:
1) Coalesces internal allocations for tensor storage
2) Does static analysis to figure out how to efficiently re-use memory.

For (2), there are two algorithms used. Specify which algorithm with
the `memory_planner_algorithm` field in `StaticModuleOptions`. The
algorithms are briefly described below:

### Standard Resizing (default)
### Standard Resizing
Static runtime will record the space required for each intermediate managed tensor it sees
on the first inference iteration. An intermediate tensor is *managed* if two conditions
are satisfied:
@@ -103,16 +99,6 @@ will occur. This is why dynamic shapes will degrade performance. With the standa
strategy, static runtime will record the new largest tensor size in each storage group at the
end of the iteration and allocate a buffer that is possibly bigger on the next iteration.
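To make the resizing behavior concrete, here is a minimal sketch of the bookkeeping under assumed, simplified data structures (the names are hypothetical; this is not static runtime's actual planner):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Each storage group remembers the largest tensor storage it has had to hold.
struct StorageGroup {
  size_t max_bytes = 0;
};

// The coalesced buffer allocated on the next iteration is the sum of the
// per-group maxima, so it can only grow ("possibly bigger") over time.
class NaiveResizingPlanner {
 public:
  explicit NaiveResizingPlanner(size_t num_groups) : groups_(num_groups) {}

  // Record the sizes actually used by each group at the end of an iteration.
  void recordIteration(const std::vector<size_t>& bytes_per_group) {
    for (size_t i = 0; i < groups_.size() && i < bytes_per_group.size(); ++i) {
      groups_[i].max_bytes = std::max(groups_[i].max_bytes, bytes_per_group[i]);
    }
  }

  // Size of the single buffer to allocate before the next iteration.
  size_t bufferBytes() const {
    size_t total = 0;
    for (const auto& g : groups_) {
      total += g.max_bytes;
    }
    return total;
  }

 private:
  std::vector<StorageGroup> groups_;
};
```

Because each group only ever grows, one unusually large input permanently enlarges the coalesced buffer on subsequent iterations.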
### Precomputed Offsets Memory Planner (experimental)
This algorithm is based on [arXiv:2001.03288](https://arxiv.org/pdf/2001.03288.pdf), section 5.2 "Greedy by Size for Offset Calculation".

The paper describes the algorithm in detail, but the key considerations are:

1) This algorithm will tend to be more efficient with respect to maximum memory usage
2) This algorithm will *not* resize the tensor buffer since recomputing offsets is a quadratic operation. Therefore,
to avoid performance degradation, the model should be warmed up with the largest possible inputs.
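For reference, a minimal sketch of the "Greedy by Size for Offset Calculation" idea (hypothetical types; not the static runtime implementation): sort allocations by size, then give each one the lowest offset that does not collide, in both memory range and lifetime, with anything already placed.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// One managed tensor: bytes needed plus the [first_use, last_use] node range
// during which its storage must stay alive.
struct Allocation {
  size_t size = 0;
  size_t first_use = 0;
  size_t last_use = 0;
  size_t offset = 0;  // filled in by planOffsets
};

bool lifetimesOverlap(const Allocation& a, const Allocation& b) {
  return a.first_use <= b.last_use && b.first_use <= a.last_use;
}

// Greedy by size: place the largest allocations first, each at the smallest
// offset that avoids every already-placed allocation whose lifetime overlaps.
// Returns the total arena size implied by the chosen offsets.
size_t planOffsets(std::vector<Allocation>& allocs) {
  std::vector<Allocation*> order;
  order.reserve(allocs.size());
  for (auto& a : allocs) {
    order.push_back(&a);
  }
  std::sort(order.begin(), order.end(), [](const Allocation* a, const Allocation* b) {
    return a->size > b->size;
  });

  std::vector<const Allocation*> placed;
  size_t arena = 0;
  for (Allocation* cur : order) {
    size_t offset = 0;
    bool moved = true;
    while (moved) {
      moved = false;
      for (const Allocation* p : placed) {
        const bool memory_overlap =
            offset < p->offset + p->size && p->offset < offset + cur->size;
        if (lifetimesOverlap(*cur, *p) && memory_overlap) {
          offset = p->offset + p->size;  // bump past the conflicting block
          moved = true;
        }
      }
    }
    cur->offset = offset;
    placed.push_back(cur);
    arena = std::max(arena, offset + cur->size);
  }
  return arena;
}
```

The nested scan over already-placed allocations is what makes recomputing offsets quadratic, which is why the buffer is not resized after warm-up.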

### Managed Output Tensors

`StaticRuntime` can optionally manage output tensors via the `manage_output_tensors` option in `StaticModuleOptions`.
@@ -121,17 +107,7 @@ output tensors is separated from the one containing intermediate tensors. The fo
of the inference run, but the latter needs to be deallocated at the end of the run.

Under the hood, we store a refcounted pointer to the output arena in each returned `Tensor`. The arena is destroyed
only when all output tensors are destroyed.

```
auto output = runtime(args);
auto& elems = output.toTupleRef().elements();
auto tensor_1 = elems[0].toTensor();
auto tensor_2 = elems[1].toTensor();

tensor_1 = at::empty({0}); // Output buffer not deallocated yet!
tensor_2 = at::empty({0}); // This call deallocates the output buffer.
```
explicitly.

## Registering Ops
Static runtime has three op execution modes:
@@ -255,68 +231,8 @@ upon `StaticModule` construction according to the out variant/native/JIT fallbac
* `prim::Loop` operations have a `BlockRunner` for the execution of the looping sub-block.
* `prim::fork` operations have `torch::jit::TaskLauncher` (`std::function<void(std::function<void()>)>`) responsible for forked graph execution.

### `Asynchronous Execution`
### Asynchronous Execution

`StaticRuntime::runAsync()` API allows execution of asynchronous operations on `TaskLauncher` passed as arguments.
`StaticRuntime::runAsync()` performs inline execution of parent graph on caller thread and asynchronous operations like `prim::fork` are executed
on the launcher passed in. In the case that no launcher is provided, the execution happens on `at::launch` inter-op thread pool.

### `prim::fork and aten::wait`

`prim::fork` takes a callable function/method/Module (say `fn`) and the arguments to that callable, `args` and `kwargs`. Since the execution of the forked function `fn` happens asynchronously and fork returns immediately after creating the async task, `fn` may not have been executed by the time the line of code after the `fork` call is reached. Thus, `aten::wait` is used to wait for the async `fn` task to be completed. `prim::fork` nodes contain the sub-graph for the forked parts of the network. Each parent graph creates a separate instance of `StaticModule` for the
forked sub-graph, and `StaticRuntime` instances are created on the fly during runtime as the fork nodes are executed. The forked sub-graph execution
happens asynchronously on the launcher provided during `StaticRuntime::runAsync()`, or on the `at::launch` executor by default. The `aten::wait` operator
waits on the future returned by the corresponding `prim::fork` operation.

#### Inter-op parallelism via fork/wait ops

A sample model with independent operations can be parallelized by inserting fork/wait nodes in the graph.

```python
def CNNBlock(x):
    out_1 = conv1(x)
    out_1 = conv2(out_1)
    out_1 = max_pool1(out_1)

    out_2 = conv3(x)
    out_2 = max_pool2(out_2)

    out_merged = conv4(out_1 + out_2)
    return out_merged
```
The two branches of (conv, conv, pool) operations can be parallelized by inserting fork nodes such that the execution of both branches can
happen in parallel:

```python
def branch1(x):
    out = conv1(x)
    out = conv2(out)
    return max_pool1(out)

def branch2(x):
    out = conv3(x)
    return max_pool2(out)

def CNNBlock(x):
    fut_1 = torch.jit.fork(branch1, x)
    fut_2 = torch.jit.fork(branch2, x)

    out_merged = conv4(torch.jit.wait(fut_1) + torch.jit.wait(fut_2))
    return out_merged
```
**Execution without fork/wait operations:**
```
<CALLER THREAD>: conv1 ─> conv2 ─> max_pool1 ─> conv3 ─> max_pool2 ─> conv4
```

**Execution with fork/wait operations:**
```
<CALLER THREAD>  : fork1 ──> fork2 ──────────> wait(fut_1) ─> wait(fut_2) ─> conv4
                     |         |
                     |         |
<INTER-OP THREAD>:   |         conv3 ──────────────────> max_pool2 ─> fut_2
                     |
<INTER-OP THREAD>: conv1 ─> conv2 ─> max_pool1 ──> fut_1
```
More examples for fork/wait operations and inter-op parallelism in PyTorch can be found at
[Dynamic Parallelism in TorchScript](https://pytorch.org/tutorials/advanced/torch-script-parallelism.html)
The `StaticRuntime::runAsync()` API allows the execution of asynchronous operations on the `TaskLauncher` passed as arguments.
`StaticRuntime::runAsync()` performs inline execution of the parent graph on the caller thread. Asynchronous operations like `prim::fork` are executed
on the launcher passed in. In the case that no launcher is provided, the execution happens via `at::launch`, i.e. on the inter-op thread pool.
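A rough usage sketch follows. It is hedged: only the `TaskLauncher` type (`std::function<void(std::function<void()>)>`) is stated above, so the exact `runAsync` overload, the header path, and the thread-per-task launcher here are assumptions for illustration.

```cpp
#include <torch/csrc/jit/runtime/static/impl.h>

#include <functional>
#include <thread>
#include <utility>
#include <vector>

// Hypothetical custom launcher: run each forked task on a detached thread
// instead of the default at::launch inter-op pool.
torch::jit::TaskLauncher makeDetachedThreadLauncher() {
  return [](std::function<void()> task) {
    std::thread(std::move(task)).detach();
  };
}

// The parent graph still runs inline on the calling thread; prim::fork
// sub-graphs are handed to the launcher (assumed args/kwargs/launcher overload).
c10::IValue runModelAsync(torch::jit::StaticRuntime& runtime,
                          const std::vector<c10::IValue>& args) {
  return runtime.runAsync(args, /*kwargs=*/{}, makeDetachedThreadLauncher());
}
```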

@@ -109,12 +109,6 @@ bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph) {

namespace {

// A CustomClass extending torch::CustomClassHolder can be typecast to an
// IValue. StaticRuntimeMetadata is created so that we can attach SR metadata
// to the IR's prim::fork nodes. Such a CustomClass needs to be registered
// before it can be used as an IValue. Below is an UNUSED VARIABLE, but it is
// NEEDED to invoke the class_ constructor necessary for class registration.
auto sr_metadata_registerer = torch::class_<StaticRuntimeMetadata>(
    "StaticRuntime",
    "StaticRuntimeMetadata");

@@ -313,10 +313,13 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::index_put, aten_index_put, [](Node* n) -> SROperator {
  if (n->matches(torch::schema(
          "aten::index_put(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor")) ||
      n->matches(torch::schema(
          "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"))) {
    return [](ProcessedNode* p_node) {
      const auto& self = p_node->Input(0).toTensor();
      const auto& indices = p_node->Input(1).toOptionalTensorList();
      const auto& indices =
          at::native::toListOfOptionalTensors(p_node->Input(1).toListRef());
      const auto& values = p_node->Input(2).toTensor();
      const auto accumulate = p_node->Input(3).toBool();
      p_node->Output(0) =

@@ -324,25 +327,6 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::index_put, aten_index_put, [](Node* n) ->
    };
  }

  if (n->matches(torch::schema(
          "aten::index_put(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor"))) {
    return [](ProcessedNode* p_node) {
      const auto& self = p_node->Input(0).toTensor();
      const auto indices = p_node->Input(1).toTensorList();

      c10::List<c10::optional<at::Tensor>> opt_list_indices;
      opt_list_indices.reserve(indices.size());
      for (const auto& ten : indices) {
        opt_list_indices.push_back(ten);
      }

      const auto& values = p_node->Input(2).toTensor();
      const auto accumulate = p_node->Input(3).toBool();
      p_node->Output(0) =
          at::native::index_put(self, opt_list_indices, values, accumulate);
    };
  }

  LogAndDumpSchema(n);
  return nullptr;
});