mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 00:54:56 +08:00
[JIT][NNC] Add handling of strides to dynamic shape support. (#70464)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70464 Add handling of strided input tensors to dynamic fusion. This is done with the same set of input striding specializations as https://github.com/pytorch/pytorch/pull/60684/: ``` S_ONE, // STRIDE_ONE: packed S_CONT, // STRIDE_CONTIGUOUS: stride[i + 1] * sizes[i + 1] S_TRAN_CONT, // STRIDE_TRANSPOSED_CONTIGUOUS: stride[i-1] * sizes[i-1] S_AS_ARG, // STRIDE_AS_ARG: stride passed in as runtime value ``` and then two additional specializations for a) contiguous tensor and b) channels-last tensor. channels-last is a common case and we should optimize for it. additionally, tensors natively store whether they are contiguous/channels-last contiguous, which makes it faster to check if tensors follow this pattern. Output striding will be done in a follow up. The striding is stored on both the TensorGroup node and on the guard node. The striding descriptors are stored as a vector of strings on the node for debugability and to make use of storing ivalues as attributes on nodes. As an example: ``` %8 : Double(10, 11, 12, 13, strides=[1716, 1, 143, 11], requires_grad=0, device=cpu) = prim::TensorExprGroup_0[symbolic_shape_inputs=[-37, -36, -35, -34], striding_inputs_desc=[["TENSOR_CONT_CHANNELS_LAST"]](%x, %24, %23, %22, %21)``` ``` Test Plan: Imported from OSS Reviewed By: navahgar Differential Revision: D33458649 Pulled By: eellison fbshipit-source-id: c42616d3c683d70f6258180d23d3841a31a6030d
This commit is contained in:
committed by
Facebook GitHub Bot
parent
975e7d246e
commit
39be20f259
@ -826,6 +826,12 @@ void initJITBindings(PyObject* module) {
|
||||
.def("_jit_texpr_fallback_allowed", &tensorexpr::fallbackAllowed)
|
||||
.def("_jit_texpr_set_fallback_allowed", &tensorexpr::setFallbackAllowed)
|
||||
.def("_jit_set_texpr_reductions_enabled", &setTexprReductionsEnabled)
|
||||
.def(
|
||||
"_jit_set_texpr_dynamic_shape_enabled",
|
||||
&setTensorExprDynamicShapeFusionEnabled)
|
||||
.def(
|
||||
"_jit_texpr_dynamic_shape_enabled",
|
||||
&tensorExprDynamicShapeFusionEnabled)
|
||||
.def("_jit_texpr_reductions_enabled", &texprReductionsEnabled)
|
||||
.def(
|
||||
"_jit_set_te_generate_block_code",
|
||||
|
||||
Reference in New Issue
Block a user