Timing cache for TensorRT (#67214)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67214

This is a draft for creating a timing cache for TensorRT.
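
For context, the intended round trip looks roughly like the sketch below. This is a minimal illustration, not code from this diff: the cache file path and the pre-built traced_module/input_specs objects are assumptions made for the example.

cache_path = "trt_timing.cache"  # hypothetical location for the persisted cache

try:
    with open(cache_path, "rb") as f:
        timing_cache = f.read()  # reuse tactic timings from an earlier build
except FileNotFoundError:
    timing_cache = None  # first build starts without a cache

interp = TRTInterpreter(traced_module, input_specs)
result = interp.run(fp16_mode=False, int8_mode=True, timing_cache=timing_cache)

# Persist the refreshed cache that run() now returns, for the next build.
with open(cache_path, "wb") as f:
    f.write(result.serialized_cache)

trt_mod = TRTModule(result.engine, result.input_names, result.output_names)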

Reviewed By: yinghai, 842974287

Differential Revision: D31783757

fbshipit-source-id: 211ab68df0832120fa637304e4a7ece80d26f9b1
Author: Shirong Wu (2021-10-28 11:20:22 -07:00)
Committed by: Facebook GitHub Bot
Parent: 0032fa7725
Commit: 55b7387e45
2 changed files with 19 additions and 5 deletions


@@ -46,8 +46,8 @@ def build_int8_trt(rn18):
         [InputTensorSpec(torch.Size([-1, *data.shape[1:]]), torch.float,
                          shape_ranges=[((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))], has_batch_dim=True)],
         explicit_batch_dimension=True, explicit_precision=True, logger_level=trt.Logger.VERBOSE)
-    engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True)
-    trt_mod = TRTModule(engine, input_names, output_names)
+    interpreter_result = interp.run(fp16_mode=False, int8_mode=True)
+    trt_mod = TRTModule(interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names)
     trt_res = trt_mod(data.cuda())
     print("explicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
     return trt_mod
@@ -74,8 +74,8 @@ def build_int8_trt_implicit_quant(rn18):
     shape_prop.ShapeProp(traced_rn18).propagate(data)
     traced_rn18 = NormalizeArgs(traced_rn18).transform()
     interp = TRTInterpreter(traced_rn18, InputTensorSpec.from_tensors([data]), logger_level=trt.Logger.VERBOSE)
-    engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True, strict_type_constraints=True)
-    trt_mod = TRTModule(engine, input_names, output_names)
+    interpreter_result = interp.run(fp16_mode=False, int8_mode=True, strict_type_constraints=True)
+    trt_mod = TRTModule(interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names)
     trt_res = trt_mod(data.cuda())
     print("implicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
     return trt_mod
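
(Both helpers change the same way: TRTInterpreterResult is a NamedTuple, and the second file below adds a fourth field, serialized_cache, so the old three-element unpacking of interp.run()'s result would no longer match; the examples therefore switch to attribute access.)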


@@ -6,12 +6,14 @@ import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
 from torch.fx.passes.shape_prop import TensorMetadata
+import numpy


 class TRTInterpreterResult(NamedTuple):
     engine: Any
     input_names: Sequence[str]
     output_names: Sequence[str]
+    serialized_cache: bytearray


 # Borrowed from torch2trt
@@ -420,6 +422,7 @@ class TRTInterpreter(torch.fx.Interpreter):
         force_fp32_output=False,
         strict_type_constraints=False,
         algorithm_selector=None,
+        timing_cache=None,
     ) -> TRTInterpreterResult:
         # For float outputs, we set their dtype to fp16 only if fp16_mode=True and
         # force_fp32_output=False.
@@ -437,6 +440,13 @@ class TRTInterpreter(torch.fx.Interpreter):
         self.builder.max_batch_size = max_batch_size
         builder_config = self.builder.create_builder_config()
         builder_config.max_workspace_size = max_workspace_size
+
+        cache = None
+        if timing_cache:
+            cache_file = numpy.array(timing_cache)
+            cache = builder_config.create_timing_cache(cache_file.tobytes())
+            builder_config.set_timing_cache(cache, True)
+
         if fp16_mode:
             builder_config.set_flag(trt.BuilderFlag.FP16)
@@ -456,7 +466,11 @@ class TRTInterpreter(torch.fx.Interpreter):
         engine = self.builder.build_engine(self.network, builder_config)
         assert engine
-        return TRTInterpreterResult(engine, self._input_names, self._output_names)
+
+        serialized_cache = bytearray(builder_config.get_timing_cache().serialize()) \
+            if builder_config.get_timing_cache() else bytearray()
+
+        return TRTInterpreterResult(engine, self._input_names, self._output_names, serialized_cache)

     def run_node(self, n):
         self._cur_node_name = str(n)
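
For reference, the builder-config calls used in this diff come from TensorRT's timing-cache API. A standalone sketch of that round trip follows; it is an illustration assuming a TensorRT 8.x Python install, with the network definition and engine build elided.

import tensorrt as trt

logger = trt.Logger(trt.Logger.VERBOSE)
builder = trt.Builder(logger)
config = builder.create_builder_config()

# Seed the config from previously serialized bytes (b"" starts a fresh cache).
serialized = b""
cache = config.create_timing_cache(serialized)
config.set_timing_cache(cache, True)  # second argument is ignore_mismatch, as in the diff above

# ... define a network and call builder.build_engine(network, config) ...

# After a build, the cache holds the measured tactic timings; serialize it for reuse.
refreshed = bytearray(config.get_timing_cache().serialize())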