from typing import Callable

import torch

from torch.fx import GraphModule
from torch._prims.utils import TensorMeta, getnvFuserDtype
from torch._prims.context import PrimContext

if torch.cuda.is_available():
    from torch._C._nvfuser import Fusion, FusionDefinition  # type: ignore[import]
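
# NOTE: the nvFuser bindings above are imported only when CUDA is available,
# so this module still imports on CPU-only builds; execute() re-checks CUDA
# availability before taking the nvfuser path.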


def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
    """
    Prototype trace executor.

    Executes the context's graph either directly ("aten") or by lowering it
    to nvFuser ("nvfuser").
    """

    if executor == "aten":
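        # fx.GraphModule recompiles the traced graph into a callable object,
        # so forward(*args) replays the recorded prim calls eagerly.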
        gm = GraphModule({}, ctx.graph)
        return gm.forward(*args, **kwargs)
    elif executor == "nvfuser":
        if not torch.cuda.is_available():
            raise RuntimeError(
                "Attempting to use nvFuser trace executor but CUDA is not available!"
            )

        # PROTOTYPE nvfuser executor
        # Only accepts tensor inputs and single tensor outputs
        # Does not handle kwargs
        # Does not support reusing the same ctx to execute!
        assert len(kwargs) == 0
        # TODO: make this a proper trace -> trace transform that
        # doesn't mutate the context
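        # The nvFuser lowerings expect the FusionDefinition as their first
        # argument (see the node rewrite below), so a placeholder for it is
        # spliced in as the graph's first node.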
        graph_fd = ctx.graph.placeholder("fd")
        ctx.graph._root.append(graph_fd)

        fusion = Fusion()
        with FusionDefinition(fusion) as fd:
            # Transforms graph to call nvfuser lowerings
            nv_args = [fd]
            for arg in args:
                if isinstance(arg, torch.Tensor):
                    x = fd.define_tensor(arg.ndim, getnvFuserDtype(arg.dtype))
                    fd.add_input(x)
                    nv_args.append(x)
                else:
                    nv_args.append(arg)
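
            # The handles appended above are nvFuser's symbolic tensors; the
            # real tensor data is only supplied to fusion.execute below.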
            for x in ctx.graph.nodes:
                if x.op == "call_function":
                    x.target = x.target.impl_nvfuser
                    x.args = (graph_fd,) + x.args

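            # Replaying the rewritten graph records the computation into the
            # FusionDefinition instead of running it; only fusion.execute
            # performs real work.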
            gm = GraphModule({}, ctx.graph)
            out = gm.forward(*nv_args)
            fd.add_output(out)

        return fusion.execute(
            tuple(arg for arg in args if isinstance(arg, torch.Tensor))
        )[0]

    msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser.".format(
        executor
    )
    raise ValueError(msg)


def make_traced(fn: Callable):
    """
    Returns a function that, when called, will trace its torch operations
    to prims and then execute those prims on the requested trace executor
    (possibly lowering them to that trace executor first).

    Only supports the torch operations defined in _torch_to_reference_map
    in context.py and operations with positional args. All args must
    be tensors and the function must return a single tensor. In the
    near future all these restrictions will be lifted.

    Example usage:

      def foo(a, b):
          return torch.add(a, b)

      traced_foo = make_traced(foo)

      a = torch.randn((1, 2, 3, 4, 5), device='cuda')
      b = torch.randn((1, 2, 3, 4, 5), device='cuda')
      result = traced_foo(a, b, executor='nvfuser')

    Executor may be either 'aten' or 'nvfuser'.
    """

    def _traced(*args, executor="aten"):
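        # Each call traces fn from scratch with a fresh context: execute()
        # mutates the traced graph on the nvfuser path, so contexts are
        # deliberately single-use.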
        ctx = PrimContext()
        with ctx:
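            # TensorMeta carries each tensor's metadata (shape, dtype, device)
            # so tracing can proceed without touching the real data.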
            placeholders = []
            for arg in args:
                if isinstance(arg, torch.Tensor):
                    placeholders.append(ctx.placeholder(TensorMeta(arg)))
                else:
                    placeholders.append(ctx.placeholder(arg))

            result = fn(*placeholders)
            ctx.output(result)

        return execute(ctx, *args, executor=executor)

    return _traced