Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
There was a lot of strangeness in how AOTAutograd backends were previously defined. This refactor replaces that strangeness with something simple and straightforward. The improvements:

- There is no longer a footgun aot_autograd "backend" which doesn't actually work. No more mistyping `torch._dynamo.optimize("aot_autograd")` when you meant "aot_eager".
- Deleted aot_print because it's annoying and there are no uses of it anyway.
- Instead of having BOTH the backend Subgraph and AotAutogradStrategy, there is now only an aot_autograd function which takes the kwargs used to configure AOTAutograd and gives you back a compiler function that runs AOTAutograd with those kwargs. Easy.
- The primary downside is that we now eagerly populate all of the kwargs, which can get us into import cycle shenanigans. Some cycles I resolved directly (e.g., we no longer manually disable the forward function before passing it to aot_autograd; aot_autograd does that for us), but for the inductor decompositions I had to make it take a lambda so I could lazily populate the decomps later.

New code is 130 lines shorter!

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89736
Approved by: https://github.com/anjali411, https://github.com/albanD
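A minimal sketch of the resulting usage pattern (the exact import locations and kwarg names below are illustrative assumptions, not taken from this PR):

    # Hypothetical sketch: aot_autograd(...) is the new helper described above.
    # It takes AOTAutograd kwargs and returns a compiler function that dynamo
    # can use like any other backend callable.
    from functorch.compile import nop  # assumed trivial compiler, for illustration

    my_backend = aot_autograd(fw_compiler=nop)  # configure AOTAutograd via kwargs
    opt_fn = torch._dynamo.optimize(my_backend)(my_fn)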
import copy
import functools
import io
import logging
import os
import subprocess
import tempfile

from typing import Dict

import numpy as np

import torch

from ..output_graph import CompilerFn
from ..utils import identity
from .subgraph import SubGraph

log = logging.getLogger(__name__)
BACKENDS: Dict[str, CompilerFn] = dict()
_NP_DTYPE = {
    torch.float16: np.float16,
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.uint8: np.uint8,
    torch.int8: np.int8,
    torch.int16: np.int16,
    torch.int32: np.int32,
    torch.int64: np.longlong,
    torch.bool: np.bool_,
}
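# NOTE: _NP_DTYPE maps torch dtypes to the numpy dtypes expected by ONNX
# Runtime's IO-binding API; it is consumed by onnxrt_common() below when
# binding input and output tensors by data_ptr().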


def register_backend(fn):
    @functools.wraps(fn)
    def inner(gm, example_inputs, **kwargs):
        return fn(gm, example_inputs, **kwargs)

    BACKENDS[fn.__name__] = inner
    return inner


def create_backend(fn):
    @functools.wraps(fn)
    def inner(model, example_inputs=None, **kwargs):
        if model is None:
            return None

        if not isinstance(model, SubGraph):
            with tempfile.TemporaryDirectory() as tmp:
                return inner(SubGraph(model, example_inputs, tmp), **kwargs)
        else:
            assert example_inputs is None

        try:
            return fn(model, **kwargs)
        except KeyboardInterrupt:
            raise

    BACKENDS[fn.__name__] = inner
    return inner
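# Usage sketch (assumes the torch._dynamo frontend; not part of this module):
# the decorated functions below are registered by name in BACKENDS, so a string
# such as "eager" or "ts_nvfuser" passed to torch._dynamo.optimize() resolves to
# the corresponding wrapper, e.g.
#
#     import torch._dynamo as dynamo
#     compiled_fn = dynamo.optimize("eager")(fn)
#
# create_backend() additionally converts the incoming fx.GraphModule plus its
# example inputs into a SubGraph before invoking the backend body.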


@create_backend
def eager(subgraph):
    return subgraph.model


@create_backend
def ts(subgraph):
    return subgraph.scripted


def reload_jit_model(subgraph, opt_fn=identity):
    tmp = io.BytesIO()
    torch.jit.save(subgraph.scripted, tmp)
    tmp.seek(0)
    model = torch.jit.load(tmp)
    model = opt_fn(model)
    # populate cache
    for _ in range(3):
        model(*subgraph.example_inputs)
    return model


def reload_jit_model_ofi(subgraph):
    return reload_jit_model(subgraph, torch.jit.optimize_for_inference)


@create_backend
def nnc(subgraph):
    with torch.jit.fuser("fuser1"):
        return reload_jit_model(subgraph)


@create_backend
def nnc_ofi(subgraph):
    with torch.jit.fuser("fuser1"):
        return reload_jit_model_ofi(subgraph)


@create_backend
def ts_nvfuser(subgraph):
    with torch.jit.fuser("fuser2"):
        return reload_jit_model(subgraph)


@create_backend
def ts_nvfuser_ofi(subgraph):
    with torch.jit.fuser("fuser2"):
        return reload_jit_model_ofi(subgraph)


@create_backend
def onednn(subgraph):
    with torch.jit.fuser("fuser3"):
        return reload_jit_model(subgraph)


@create_backend
def ofi(subgraph):
    return torch.jit.optimize_for_inference(subgraph.scripted)


@create_backend
def static_runtime(subgraph):
    scripted = subgraph.scripted
    if hasattr(scripted, "_c"):
        static_module = torch._C._jit_to_static_module(scripted._c)
    else:
        static_module = torch._C._jit_to_static_module(scripted.graph)
    return subgraph.wrap_returns(static_module)


def onnxrt_common(subgraph, provider, onnx_filename=None):
    import onnxruntime  # type: ignore[import]

    assert provider in onnxruntime.get_available_providers()
    session = onnxruntime.InferenceSession(
        onnx_filename or subgraph.onnx_filename, providers=[provider]
    )
    input_names = subgraph.input_names
    output_names = subgraph.output_names
    create_outputs = subgraph.empty_outputs_factory()
    is_cpu = subgraph.is_cpu

    def _call(*initial_args):
        binding = session.io_binding()
        args = [a.contiguous() for a in initial_args]
        for name, value in zip(input_names, args):
            dev = value.device
            binding.bind_input(
                name,
                dev.type,
                dev.index or 0,
                _NP_DTYPE[value.dtype],
                value.size(),
                value.data_ptr(),
            )
        outputs = create_outputs()
        for name, value in zip(output_names, outputs):
            dev = value.device
            binding.bind_output(
                name,
                dev.type,
                dev.index or 0,
                _NP_DTYPE[value.dtype],
                value.size(),
                value.data_ptr(),
            )
        session.run_with_iobinding(binding)
        if is_cpu:
            binding.copy_outputs_to_cpu()
        return outputs

    return subgraph.wrap_returns(_call)
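# The IO-binding path above hands ONNX Runtime the raw data_ptr() of each input
# and preallocated output tensor, avoiding tensor<->numpy copies on every call;
# onnxrt_cpu_numpy below is the simpler (and slower) variant that round-trips
# through numpy arrays instead.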


@create_backend
def onnxrt_cpu(subgraph):
    return onnxrt_common(subgraph, provider="CPUExecutionProvider")


@create_backend
def onnxrt_cuda(subgraph):
    return onnxrt_common(subgraph, provider="CUDAExecutionProvider")


@create_backend
def onnx2tensorrt(subgraph):
    if subgraph.will_tensorrt_barf():
        # TensorRT fails violently with an abort() on this
        return None

    return onnxrt_common(subgraph, provider="TensorrtExecutionProvider")


@create_backend
def onnxrt_cpu_numpy(subgraph, provider="CPUExecutionProvider"):
    """Alternate version that integrates via numpy"""
    import onnxruntime

    assert provider in onnxruntime.get_available_providers()
    ort_session = onnxruntime.InferenceSession(
        subgraph.onnx_filename, providers=[provider]
    )

    def to_numpy(x):
        try:
            return x.numpy()
        except RuntimeError:
            return x.detach().numpy()

    def _call(*args):
        res = ort_session.run(
            None, {f"i{i}": to_numpy(arg) for i, arg in enumerate(args)}
        )
        res = [torch.from_numpy(x) for x in res]
        return res

    return subgraph.wrap_returns(_call)


@create_backend
def onnxrt(subgraph):
    if subgraph.is_cuda:
        return onnxrt_cuda(subgraph)
    else:
        return onnxrt_cpu(subgraph)


@functools.lru_cache(None)
def _init_tensorflow():
    import tensorflow as tf  # type: ignore[import]

    # prevent tensorflow from eating all the GPU memory
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    return tf


@create_backend
def onnx2tf(subgraph):
    import onnx  # type: ignore[import]
    from onnx_tf.backend import prepare  # type: ignore[import]

    tf = _init_tensorflow()
    filename = subgraph.filename("tensorflow")
    input_names = subgraph.input_names
    output_names = subgraph.output_names
    device = "/CPU:0" if subgraph.is_cpu else f"/GPU:{subgraph.device_index}"
    with tf.device(device):
        if not os.path.exists(filename):
            prepare(onnx.load(subgraph.onnx_filename)).export_graph(filename)
        tf_module = tf.saved_model.load(filename)
        tf_module = tf.function(tf_module, jit_compile=True)

    def run(*i_args):
        args = [a.contiguous() for a in i_args]
        with tf.device(device):
            outs = tf_module(
                **{
                    name: tf.experimental.dlpack.from_dlpack(
                        torch.utils.dlpack.to_dlpack(args[idx])
                    )
                    for idx, name in enumerate(input_names)
                }
            )
        return [
            torch.utils.dlpack.from_dlpack(
                tf.experimental.dlpack.to_dlpack(outs[name])
            )
            for name in output_names
        ]

    return subgraph.wrap_returns(run)


@create_backend
def taso(subgraph):
    taso_filename = subgraph.filename("taso")
    subprocess.check_call(
        [
            os.path.expanduser("~/conda/envs/taso/bin/python"),
            "-c",
            "import taso,onnx; onnx.save(taso.export_onnx(taso.optimize("
            f"taso.load_onnx('{subgraph.onnx_filename}'))), '{taso_filename}')",
        ]
    )
    return onnxrt_common(
        subgraph, provider="CUDAExecutionProvider", onnx_filename=taso_filename
    )


@create_backend
def ipex(subgraph, **kwargs):
    import intel_extension_for_pytorch as ipex  # type: ignore[import]

    inputs = subgraph.example_inputs
    model = subgraph.model
    with torch.no_grad():
        model.eval()
        if kwargs["datatype"] == "bf16":
            model = ipex.optimize(model, dtype=torch.bfloat16)
        else:
            model = ipex.optimize(model, dtype=torch.float32)
        try:
            traced_model = torch.jit.trace(model, inputs).eval()
            traced_model = torch.jit.freeze(traced_model)
            return traced_model
        except Exception:
            log.warning("JIT trace failed during the 'ipex' optimize process.")
            return model
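# Note: the "ipex" backend above expects a "datatype" kwarg ("fp32" or "bf16");
# the ipex_fp32 / ipex_bf16 wrappers near the bottom of this file supply it when
# calling BACKENDS["ipex"](gm, example_inputs, **kwargs).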


def _raise_timeout(signum, frame):
    raise TimeoutError()


@create_backend
def fx2trt(subgraph, **kwargs):
    if subgraph.will_tensorrt_barf():
        # TensorRT fails violently with an abort() on this
        return None

    from torch_tensorrt.fx.fx2trt import (  # type: ignore[import]
        InputTensorSpec,
        TRTInterpreter,
    )
    from torch_tensorrt.fx.passes.lower_basic_pass import (  # type: ignore[import]
        transform_setitem,
    )
    from torch_tensorrt.fx.tools.trt_splitter import (  # type: ignore[import]
        TRTSplitter,
        TRTSplitterSetting,
    )
    from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer  # type: ignore[import]
    from torch_tensorrt.fx.trt_module import TRTModule  # type: ignore[import]
    from torch_tensorrt.fx.utils import LowerPrecision  # type: ignore[import]

    from .normalize import normalize_ir

    try:
        model = subgraph.model
        inputs = subgraph.example_inputs
        # normalize
        model = normalize_ir(model, inputs)
        # pass rewrite
        model = transform_setitem(model, inputs)
        acc_model = acc_tracer.trace(model, inputs)
        # Split out unsupported ops
        splitter_setting = TRTSplitterSetting()
        splitter_setting.use_implicit_batch_dim = False
        splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting)
        splitter.node_support_preview()
        split_mod = splitter()
        num_piece = 0
        for name, _ in split_mod.named_children():
            print(f"graph is split into {name}")
            num_piece += 1

        # If the graph module is split into more than 8 pieces, we consider its
        # performance to be poor and fall back to the non-TRT module.
        if num_piece > 8:
            print(
                f"The graph module is split into {num_piece} pieces, which is larger "
                "than the threshold of 8. Falling back to the non-TRT module."
            )
            return None

        if "fp16_mode" in kwargs and kwargs["fp16_mode"]:
            precision = LowerPrecision.FP16
        else:
            precision = LowerPrecision.FP32

        def get_submod_inputs(mod, submod, inputs):
            acc_inputs = None

            def get_input(self, inputs):
                nonlocal acc_inputs
                acc_inputs = inputs

            handle = submod.register_forward_pre_hook(get_input)
            mod(*inputs)
            handle.remove()
            return acc_inputs

        for name, _ in split_mod.named_children():
            if "_run_on_acc" in name:
                submod = getattr(split_mod, name)
                # print("acc=", submod.code)
                # Get submodule inputs for fx2trt
                acc_inputs = get_submod_inputs(split_mod, submod, inputs)

                # fx2trt replacement
                interp = TRTInterpreter(
                    submod,
                    InputTensorSpec.from_tensors(acc_inputs),
                    explicit_batch_dimension=True,
                )
                r = interp.run(
                    max_workspace_size=20 << 30,
                    lower_precision=precision,
                    # profiling_verbosity=trt.ProfilingVerbosity.DETAILED,  # for profiling
                )
                # For profiling:
                # from fx2trt_oss.fx.tools.trt_profiler_sorted import profile_trt_module
                # profile_trt_module("", trt_mod, acc_inputs)
                trt_mod = TRTModule(*r)

                setattr(split_mod, name, trt_mod)
            else:
                submod = getattr(split_mod, name)
                # print("gpu=", submod.code)
        return subgraph.wrap_returns(split_mod)
    except Exception:
        log.exception("FX2TRT conversion error")
        return None


@create_backend
def torch2trt(subgraph):
    if subgraph.will_tensorrt_barf():
        # TensorRT fails violently with an abort() on this
        return None

    from torch2trt import torch2trt  # type: ignore[import]

    inputs = subgraph.example_inputs
    trt_mod = torch2trt(
        subgraph.model,
        inputs,
        max_batch_size=len(inputs[0]),
        strict_type_constraints=True,
    )
    return subgraph.wrap_returns(trt_mod)


@create_backend
def tensorrt(subgraph):
    if subgraph.will_tensorrt_barf():
        # TensorRT fails violently with an abort() on this
        return None

    model = onnx2tensorrt(subgraph)
    if model is None:
        model = torch2trt(subgraph)
    return model


@create_backend
def cudagraphs(subgraph):
    model = subgraph.model
    inputs = subgraph.example_inputs
    assert subgraph.is_cuda
    return subgraph.wrap_returns(cudagraphs_inner(model, inputs))


@create_backend
def cudagraphs_ts(subgraph):
    assert subgraph.is_cuda
    model = subgraph.scripted
    inputs = subgraph.example_inputs

    # warmup
    for _ in range(3):
        model(*inputs)

    return subgraph.wrap_returns(cudagraphs_inner(model, inputs))


@create_backend
def cudagraphs_ts_ofi(subgraph):
    assert subgraph.is_cuda
    model = torch.jit.optimize_for_inference(torch.jit.freeze(subgraph.scripted))
    inputs = subgraph.example_inputs

    # warmup
    for _ in range(3):
        model(*inputs)

    return subgraph.wrap_returns(cudagraphs_inner(model, inputs))


def cudagraphs_inner(model, inputs, copy_outputs=True):
    assert isinstance(inputs, (list, tuple))
    static_inputs = [torch.zeros_like(x) for x in inputs]

    # warmup
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        model(*inputs)
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        static_outputs = model(*static_inputs)
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    def run(*new_inputs):
        assert len(static_inputs) == len(new_inputs)
        for dst, src in zip(static_inputs, new_inputs):
            dst.copy_(src)
        graph.replay()
        if copy_outputs:
            return [x.clone() for x in static_outputs]
        else:
            return static_outputs

    return run
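# Standalone usage sketch for cudagraphs_inner (illustrative only: the model and
# shapes below are made up, and a CUDA device is required):
#
#     model = torch.nn.Linear(8, 8).cuda()
#     run = cudagraphs_inner(model, [torch.randn(4, 8, device="cuda")])
#     out = run(torch.randn(4, 8, device="cuda"))  # copies inputs, replays graph
#
# Each call copies the new inputs into the captured static buffers, replays the
# recorded CUDA graph, and (with copy_outputs=True) clones the static outputs so
# callers can safely hold on to the results.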


def tvm_compile(jit_mod, example_inputs, log_file=None, **kwargs):
    if jit_mod is None:
        return None
    try:
        return tvm_compile_inner(jit_mod, example_inputs, None, log_file, **kwargs)
    except Exception as e:
        if log_file and os.path.exists(log_file):
            os.unlink(log_file)
        if isinstance(e, KeyboardInterrupt):
            raise
        log.exception("tvm error")
        return None


@create_backend
def tvm(subgraph):
    return subgraph.wrap_returns(
        tvm_compile_inner(
            subgraph.scripted,
            subgraph.example_inputs,
            tuning_option=None,
            cuda=subgraph.is_cuda,
        )
    )


@create_backend
def ansor(subgraph):
    """
    WARNING: this backend takes hours or days to train and
    often produces a slower result than the default schedule.
    """
    return subgraph.wrap_returns(
        tvm_compile_inner(
            subgraph.scripted,
            subgraph.example_inputs,
            tuning_option="auto_scheduler",
            log_file=subgraph.filename("ansor"),
            cuda=subgraph.is_cuda,
        )
    )


@create_backend
def tvm_meta_schedule(subgraph):
    return subgraph.wrap_returns(
        tvm_compile_inner(
            subgraph.scripted,
            subgraph.example_inputs,
            tuning_option="meta_schedule",
            trials=20000,
            cuda=subgraph.is_cuda,
        )
    )


@functools.lru_cache(None)
def llvm_target():
    if "avx512" in open("/proc/cpuinfo").read():
        return "llvm -mcpu=skylake-avx512"
    return "llvm -mcpu=core-avx2"


def tvm_compile_inner(
    jit_mod, example_inputs, tuning_option=None, log_file=None, trials=20000, cuda=False
):
    try:
        import tvm  # type: ignore[import]
        from tvm import relay  # type: ignore[import]
        from tvm.contrib import graph_executor  # type: ignore[import]

        shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
        mod, params = relay.frontend.from_pytorch(jit_mod, shape_list)
        if cuda:
            dev = tvm.cuda(0)
            target = tvm.target.cuda()
        else:
            dev = tvm.cpu(0)
            target = tvm.target.Target(llvm_target())

        if tuning_option == "auto_scheduler":
            from tvm import auto_scheduler

            if log_file is None:
                log_file = tempfile.NamedTemporaryFile()
            if not os.path.exists(log_file):
                tasks, task_weights = auto_scheduler.extract_tasks(
                    mod["main"], params, target
                )
                for task in tasks:
                    print(task.compute_dag)
                else:
                    print("No tasks")
                if len(tasks) != 0:
                    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
                    if not os.path.exists(log_file):
                        assert trials > 0
                        tune_option = auto_scheduler.TuningOptions(
                            num_measure_trials=trials,
                            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
                            early_stopping=2000,
                        )
                        try:
                            tuner.tune(tune_option)
                        except Exception:
                            if os.path.exists(log_file):
                                os.unlink(log_file)
                            raise

            with auto_scheduler.ApplyHistoryBest(log_file):
                with tvm.transform.PassContext(
                    opt_level=3, config={"relay.backend.use_auto_scheduler": True}
                ):
                    lib = relay.build(mod, target=target, params=params)
        elif tuning_option == "meta_schedule":
            from os import path as osp

            from tvm import meta_schedule as ms

            with tempfile.TemporaryDirectory() as work_dir:
                if log_file is not None:
                    assert osp.isdir(
                        log_file
                    ), "TVM's meta_schedule requires a directory for storing log files."
                    work_dir = log_file
                # TODO(shingjan): This could be replaced by tvm.contrib.torch.optimize_torch
                # once USE_PT_TVMDSOOP is updated and turned on by default in TVM.
                database = ms.relay_integration.tune_relay(
                    mod=mod,
                    target=target,
                    work_dir=work_dir,
                    max_trials_global=20000,
                    num_trials_per_iter=64,
                    params=params,
                    strategy="evolutionary",
                )
                lib = ms.relay_integration.compile_relay(
                    database=database,
                    mod=mod,
                    target=target,
                    params=params,
                )

        elif tuning_option is None:
            # no autotuning (for debugging)
            with tvm.transform.PassContext(opt_level=10):
                lib = relay.build(mod, target=target, params=params)
        else:
            raise NotImplementedError(
                "This tuning option is invalid/not implemented for torchdynamo's TVM-related backend. "
                "The available options are None, auto_scheduler and meta_schedule."
            )
        m = graph_executor.GraphModule(lib["default"](dev))

        def to_torch_tensor(nd_tensor):
            """A helper function to transfer an NDArray to a torch.Tensor."""
            if nd_tensor.dtype == "bool":
                # DLPack does not support boolean so it can't be handled by
                # torch.utils.dlpack.from_dlpack. Workaround by going through
                # numpy, although this brings additional data copy overhead.
                return torch.from_numpy(nd_tensor.numpy())
            return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())

        def exec_tvm(*i_args):
            args = [a.contiguous() for a in i_args]
            for idx, arg in enumerate(args, 0):
                if arg.dim() != 0:
                    if arg.requires_grad:
                        arg = arg.detach()
                    m.set_input(
                        f"inp_{idx}",
                        tvm.nd.array(arg.numpy(), dev),
                    )
            m.run()
            return [
                to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())
            ]

        return exec_tvm
    except Exception:
        log.exception("tvm error")
        return jit_mod  # explicit fall back to eager


@functools.lru_cache(None)
def _init_ltc():
    try:
        import torch._lazy.extract_compiled_graph
        from torch._lazy.ts_backend import init as init_ts_backend

        # hopefully changing this line to something like _ltc_init_xla_backend
        # in the future will enable XLA
        init_ts_backend()

        return torch._lazy
    except ModuleNotFoundError as e:
        print(f"ltc backend failed. Cannot import {e.name}")
        raise


def ltc_reuse_graph(gm: torch.fx.GraphModule, example_inputs):
    ltc = _init_ltc()
    return ltc.extract_compiled_graph.extract_compiled_graph(gm, example_inputs)


def ltc_trivial(gm: torch.fx.GraphModule, example_inputs):
    ltc = _init_ltc()
    lazy_model = copy.deepcopy(gm).to(device="lazy")
    ltc.extract_compiled_graph.force_lazy_device(lazy_model)

    def ltc_model(*inputs):
        orig_device = inputs[0].device if len(inputs) > 0 else "cuda"
        lazy_inputs = tuple(inp.to(device="lazy") for inp in inputs)

        lazy_out = lazy_model(*lazy_inputs)
        out = tuple(out.to(device=orig_device) for out in lazy_out)
        return out

    return ltc_model


@create_backend
def torchxla_trivial(subgraph):
    return subgraph.model


@create_backend
def torchxla_trace_once(subgraph):
    import torch._dynamo.optimizations.torchxla_integration as integration

    model = subgraph.model
    example_inputs = subgraph.example_inputs
    return integration.extract_compiled_graph(model, example_inputs)


def ipex_fp32(gm: torch.fx.GraphModule, example_inputs):
    kwargs_ipex = {"datatype": "fp32"}
    return BACKENDS["ipex"](gm, example_inputs, **kwargs_ipex)


def ipex_bf16(gm: torch.fx.GraphModule, example_inputs):
    kwargs_ipex = {"datatype": "bf16"}
    return BACKENDS["ipex"](gm, example_inputs, **kwargs_ipex)


def fx2trt_compiler_fp16(gm: torch.fx.GraphModule, example_inputs):
    kwargs_fx2trt = {"fp16_mode": True}
    trt_compiled = BACKENDS["fx2trt"](gm, example_inputs, **kwargs_fx2trt)
    if trt_compiled is not None:
        return trt_compiled
    else:
        print(
            "FX2TRT conversion failed on the subgraph. Returning GraphModule forward instead."
        )
        return gm.forward


def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs):
    kwargs_fx2trt = {"fp16_mode": False}
    trt_compiled = BACKENDS["fx2trt"](gm, example_inputs, **kwargs_fx2trt)
    if trt_compiled is not None:
        return trt_compiled
    else:
        print(
            "FX2TRT conversion failed on the subgraph. Returning GraphModule forward instead."
        )
        return gm.forward