This PR adds an `executor_parameters` keyword argument to `torch._prims.executor.execute`. For now there are two knobs:

* `use_python_fusion_cache: bool = True` - whether to use `lru_cache` when constructing the fusion object.
* `allow_single_op_fusion: bool = True` - whether to allow fusions with a single callable.

Behavior can be controlled by passing a dict with custom values as the `executor_parameters` argument, as sketched below.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84482
Approved by: https://github.com/jjsjann123, https://github.com/ngimel
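A minimal sketch of overriding the knobs per call, assuming `gm` is a `GraphModule` produced by `make_fx`, `args` are its inputs, and the dict keys match the knob names above:

```python
from torch._prims.executor import execute

# Unspecified knobs keep their defaults; only the listed keys are overridden.
result = execute(
    gm,
    *args,
    executor="nvfuser",
    executor_parameters={
        "use_python_fusion_cache": False,  # construct the fusion fresh instead of reusing a cached one
        "allow_single_op_fusion": True,    # permit fusions that contain a single callable
    },
)
```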
from typing import Callable, Optional

from torch._prims.context import NvfuserPrimsMode, TorchRefsMode
from torch._prims.nvfuser_executor import nvfuser_execute, nvfuser_execute_partitioned
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import make_fx, wrapper_and_args_for_make_fx

def execute(
    gm: GraphModule,
    *args,
    executor: str = "aten",
    executor_parameters: Optional[dict] = None,
):
    """
    Prototype ATen executor.

    Just executes the context's graph.
    """
    if executor == "aten":
        return gm.forward(*args)
    elif executor == "nvfuser":
        return nvfuser_execute_partitioned(
            gm, *args, executor_parameters=executor_parameters
        )
    elif executor == "strictly_nvfuser":
        return nvfuser_execute(gm, *args, executor_parameters=executor_parameters)

    msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser, strictly_nvfuser.".format(
        executor
    )
    raise ValueError(msg)

def make_traced(fn: Callable):
    """
    Returns a function that, when called, will
    trace its torch operations to prims and then
    execute those prims on the requested trace executor
    (possibly lowering them to that trace executor first).

    Only supports the torch operations defined in _torch_to_reference_map
    in context.py, and only operations called with positional args. All args
    must be tensors.
    In the near future all these restrictions will be lifted.

    Example usage:

      def foo(a, b):
          return torch.add(a, b)

      traced_foo = make_traced(foo)

      a = torch.randn((1, 2, 3, 4, 5), device='cuda')
      b = torch.randn((1, 2, 3, 4, 5), device='cuda')
      result = traced_foo(a, b, executor='nvfuser')

    Executor may be 'aten', 'nvfuser', or 'strictly_nvfuser'.
    """
    def _traced(*args, executor="aten", **kwargs):
        # TODO: caching
        wrapped, all_args = wrapper_and_args_for_make_fx(fn, args, kwargs)

        with NvfuserPrimsMode(), TorchRefsMode():
            gm = make_fx(wrapped)(all_args)
        return execute(gm, all_args, executor=executor)

    return _traced
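For illustration only, a hypothetical sketch of how a `use_python_fusion_cache` knob can gate an `lru_cache` around fusion construction; `build_fusion_uncached`, `get_fusion`, and the cache key are made-up stand-ins, not the actual `nvfuser_executor` implementation:

```python
from functools import lru_cache

_DEFAULT_PARAMS = {"use_python_fusion_cache": True, "allow_single_op_fusion": True}


def build_fusion_uncached(key):
    # Hypothetical placeholder for the real fusion-construction logic.
    return ("fusion", key)


@lru_cache(maxsize=1024)
def build_fusion_cached(key):
    return build_fusion_uncached(key)


def get_fusion(key, executor_parameters=None):
    # Merge caller overrides over the defaults.
    params = {**_DEFAULT_PARAMS, **(executor_parameters or {})}
    if params["use_python_fusion_cache"]:
        # Reuse a previously constructed fusion for an identical key.
        return build_fusion_cached(key)
    # Caching disabled: construct a fresh fusion on every call.
    return build_fusion_uncached(key)
```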