Mirror of https://github.com/pytorch/pytorch.git,
synced 2025-10-21 05:34:18 +08:00.
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/83154 Approved by: https://github.com/ezyang
51 lines
1.6 KiB
Python
51 lines
1.6 KiB
Python
import torch
|
|
import torch._C as _C
|
|
from torch.utils._pytree import tree_flatten
|
|
|
|
"""
|
|
This is a dispatcher (in Python)
|
|
- You can define new operations (in Python) without schemas
|
|
- It interfaces with the PyTorch dispatcher
|
|
"""
|
|
|
|
class PyDispatcher:
    """Python-side dispatcher: resolves a PyOperator to a concrete kernel
    via a DispatchKeySet and invokes it."""

    @staticmethod
    def call(operator, *args, **kwargs):
        # operator is a PyOperator. Compute the key set from the tensor
        # arguments and TLS state, then look up and run the kernel.
        keyset = compute_keyset(args, kwargs)
        return operator.lookup(keyset)(*args, **kwargs)

    @staticmethod
    def redispatch(operator, dispatch_key_set, *args, **kwargs):
        # operator is a PyOperator. Like `call`, except the caller supplies
        # the key set instead of having it recomputed.
        return operator.lookup(dispatch_key_set)(*args, **kwargs)
|
|
|
|
|
|
def compute_keyset(args, kwargs):
    """Derive the DispatchKeySet for a call from its Tensor arguments."""
    return key_extractor(get_tensors(args, kwargs))
|
|
|
|
|
|
# Note - this should maintain identical impl to the C++ dispatcher key extraction logic
# at ATen/core/dispatch/DispatchKeyExtractor.h
def key_extractor(tensors):
    """Mirror the C++ dispatch-key extraction: start from the thread-local
    include set, union in each tensor's keys, then mask out the
    thread-local exclude set."""
    keyset = _C._dispatch_tls_local_include_set()  # type: ignore[attr-defined]
    for t in tensors:
        keyset = keyset | _C._dispatch_keys(t)  # type: ignore[attr-defined]
    return keyset - _C._dispatch_tls_local_exclude_set()  # type: ignore[attr-defined]
|
|
|
|
|
|
def to_flat_tuple(args, kwargs):
    """Flatten args and kwargs into one sequence of pytree leaves,
    positional leaves first, keyword leaves after.

    NOTE: despite the name, this returns a list — `tree_flatten` produces
    lists and the two are concatenated as-is.
    """
    arg_leaves, _ = tree_flatten(args)
    kwarg_leaves, _ = tree_flatten(kwargs)
    return arg_leaves + kwarg_leaves
|
|
|
|
def get_tensors(args, kwargs):
    """Return every Tensor leaf found in args/kwargs, as a tuple
    (positional leaves first, then keyword leaves)."""
    flat_args, _ = tree_flatten(args)
    flat_kwargs, _ = tree_flatten(kwargs)
    return tuple(
        leaf for leaf in flat_args + flat_kwargs if isinstance(leaf, torch.Tensor)
    )
|