mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Partially fixes: #66328 This PR: - adds support for `ITensorList` to the dispatcher for: - computing the dispatch key - boxing and unboxing `ITensorList` - modified the codegen for structured kernels: - codegen APIs use `ITensorList` instead of `ArrayRef<Tensor>` **Changes summary:** - Signature changes due to the different APIs: - dispatcher API (e.g. `BatchingRegistrations.cpp`) - C++ API (e.g. `TensorShape.cpp`) - Miscellaneous functions used by codegen'd functions (e.g. `FunctionalTensorWrapper.*`) - Dispatcher changes for handling `ITensorList` correctly (e.g. `DispatchKeyExtractor.h`) - Signature changes of `at::cat` due to the need for `const` inside `TensorBody.h` - Forward declarations of `ITensorList` (e.g. `MethodOperators.h`) - Codegen changes, special casing structured kernels (e.g. `gen.py`) **Short description of structured kernels special casing:** I introduced, mainly, 5 types of changes to the codegen for generating code depending on whether the kernel is structured or not: 1. Added a `structured_type_override` flag to the `argument_type` function definition of the affected APIs (mainly the dispatcher and C++ APIs). - `api/cpp.py`, `api/dispatcher.py`, `api/native.py` 2. Added a `structured_type_override` member to the signature classes (e.g. `CppSignature`), since `FunctionSchema` doesn't really know whether the function is structured or not - `api/types.py` 3. Added a `part_of_structured_group` to `NativeFunction` class, which is just a convenience function to forward to `structured_type_override` wherever needed - `model.py` 4. Appropriately changed the rest of the codegen, whenever it used either the signature classes or the `arguments` function directly 5. Added a check for `const ITensorList&` type wherever there was a check for `TensorList` Pull Request resolved: https://github.com/pytorch/pytorch/pull/73350 Approved by: https://github.com/bdhirsh
57 lines
2.0 KiB
Python
57 lines
2.0 KiB
Python
import threading
|
|
from contextlib import contextmanager
|
|
from typing import Iterator, Optional
|
|
|
|
# Simple dynamic scoping implementation. The name "parametrize" comes
|
|
# from Racket.
|
|
#
|
|
# WARNING WARNING: LOOKING TO EDIT THIS FILE? Think carefully about
|
|
# why you need to add a toggle to the global behavior of code
|
|
# generation. The parameters here should really only be used
|
|
# for "temporary" situations, where we need to temporarily change
|
|
# the codegen in some cases because we cannot conveniently update
|
|
# all call sites, and are slated to be eliminated once all call
|
|
# sites are updated. If you don't have a plan for how to get there,
|
|
# DON'T add a new entry here.
|
|
|
|
|
|
class Locals(threading.local):
    """Per-thread storage for the codegen toggles managed by ``parametrize``.

    Because this subclasses ``threading.local``, a value assigned on one
    thread is not visible from any other thread.  A thread that has never
    assigned a toggle reads the class-level default ``None``, meaning
    "not initialised yet".
    """

    # ``None`` until a ``parametrize(...)`` block sets it on this thread.
    use_const_ref_for_mutable_tensors: Optional[bool] = None
    # ``None`` until a ``parametrize(...)`` block sets it on this thread.
    use_ilistref_for_tensor_lists: Optional[bool] = None


# Single module-wide instance.  Since ``Locals`` is a ``threading.local``,
# every thread that touches it sees its own independent copy of the toggles.
_locals: Locals = Locals()
|
|
|
|
|
|
def use_const_ref_for_mutable_tensors() -> bool:
    """Return the current ``use_const_ref_for_mutable_tensors`` codegen toggle.

    The value is read from the thread-local ``_locals`` object and must have
    been set on the current thread by an enclosing ``parametrize(...)``
    context manager.

    Returns:
        The boolean passed to the innermost active ``parametrize`` call.

    Raises:
        AssertionError: if no ``parametrize`` block has initialised the
            toggle on this thread.
    """
    value = _locals.use_const_ref_for_mutable_tensors
    # Explicit raise instead of ``assert``: a bare ``assert`` is stripped under
    # ``python -O``, which would let ``None`` escape from a ``-> bool`` function
    # and be silently treated as False by callers.  AssertionError is kept so
    # the exception type callers might see is unchanged.
    if value is None:
        raise AssertionError(
            "need to initialize local.use_const_ref_for_mutable_tensors with "
            "local.parametrize"
        )
    return value
|
|
|
|
|
|
def use_ilistref_for_tensor_lists() -> bool:
    """Return the current ``use_ilistref_for_tensor_lists`` codegen toggle.

    The value is read from the thread-local ``_locals`` object and must have
    been set on the current thread by an enclosing ``parametrize(...)``
    context manager.

    Returns:
        The boolean passed to the innermost active ``parametrize`` call.

    Raises:
        AssertionError: if no ``parametrize`` block has initialised the
            toggle on this thread.
    """
    value = _locals.use_ilistref_for_tensor_lists
    # Explicit raise instead of ``assert``: a bare ``assert`` is stripped under
    # ``python -O``, which would let ``None`` escape from a ``-> bool`` function
    # and be silently treated as False by callers.  AssertionError is kept so
    # the exception type callers might see is unchanged.
    if value is None:
        raise AssertionError(
            "need to initialize local.use_ilistref_for_tensor_lists with "
            "local.parametrize"
        )
    return value
|
|
|
|
|
|
@contextmanager
def parametrize(
    *, use_const_ref_for_mutable_tensors: bool, use_ilistref_for_tensor_lists: bool
) -> Iterator[None]:
    """Temporarily set the codegen toggles for the current thread.

    Both toggles are keyword-only.  On entry they are stored on the
    thread-local ``_locals``.  On exit, whether normal or via an exception
    raised inside the ``with`` body, the values that were in effect before
    entry are restored, so nested ``parametrize`` blocks unwind correctly.

    Args:
        use_const_ref_for_mutable_tensors: value stored on ``_locals`` for
            the duration of the block.
        use_ilistref_for_tensor_lists: value stored on ``_locals`` for the
            duration of the block.

    Yields:
        None; the context manager exists only for its set/restore effect.
    """
    # Snapshot this thread's current values (``None`` if never set here).
    previous = (
        _locals.use_const_ref_for_mutable_tensors,
        _locals.use_ilistref_for_tensor_lists,
    )
    try:
        (
            _locals.use_const_ref_for_mutable_tensors,
            _locals.use_ilistref_for_tensor_lists,
        ) = (use_const_ref_for_mutable_tensors, use_ilistref_for_tensor_lists)
        yield
    finally:
        # Always put the snapshot back, even if the body raised.
        (
            _locals.use_const_ref_for_mutable_tensors,
            _locals.use_ilistref_for_tensor_lists,
        ) = previous
|