Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Timing cache for TensorRT (#67214)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67214

This is a draft for creating a timing cache for TensorRT.

Reviewed By: yinghai, 842974287
Differential Revision: D31783757
fbshipit-source-id: 211ab68df0832120fa637304e4a7ece80d26f9b1
Committed by: Facebook GitHub Bot
Parent: 0032fa7725
Commit: 55b7387e45
@@ -46,8 +46,8 @@ def build_int8_trt(rn18):
     [InputTensorSpec(torch.Size([-1, *data.shape[1:]]), torch.float,
                      shape_ranges=[((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))], has_batch_dim=True)],
     explicit_batch_dimension=True, explicit_precision=True, logger_level=trt.Logger.VERBOSE)
-    engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True)
-    trt_mod = TRTModule(engine, input_names, output_names)
+    interpreter_result = interp.run(fp16_mode=False, int8_mode=True)
+    trt_mod = TRTModule(interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names)
     trt_res = trt_mod(data.cuda())
     print("explicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
     return trt_mod
@@ -74,8 +74,8 @@ def build_int8_trt_implicit_quant(rn18):
     shape_prop.ShapeProp(traced_rn18).propagate(data)
     traced_rn18 = NormalizeArgs(traced_rn18).transform()
     interp = TRTInterpreter(traced_rn18, InputTensorSpec.from_tensors([data]), logger_level=trt.Logger.VERBOSE)
-    engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True, strict_type_constraints=True)
-    trt_mod = TRTModule(engine, input_names, output_names)
+    interpreter_result = interp.run(fp16_mode=False, int8_mode=True, strict_type_constraints=True)
+    trt_mod = TRTModule(interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names)
     trt_res = trt_mod(data.cuda())
     print("implicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
     return trt_mod
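Both samples change in the same way: interp.run(...) now returns a TRTInterpreterResult instead of a bare tuple. A minimal sketch of the new calling convention, assuming interp and TRTModule are set up as in the samples above:

    # `interp` is a TRTInterpreter built as in the samples above (assumption).
    result = interp.run(fp16_mode=False, int8_mode=True)

    # Fields can be read by name, since TRTInterpreterResult is a NamedTuple ...
    trt_mod = TRTModule(result.engine, result.input_names, result.output_names)

    # ... or unpacked positionally, which stays close to the old tuple-return style.
    engine, input_names, output_names, serialized_cache = result

Named access is the safer pattern here: it keeps call sites working if further fields are appended to the result type later.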
@@ -6,12 +6,14 @@ import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
 from torch.fx.passes.shape_prop import TensorMetadata
+import numpy


 class TRTInterpreterResult(NamedTuple):
     engine: Any
     input_names: Sequence[str]
     output_names: Sequence[str]
+    serialized_cache: bytearray


 # Borrowed from torch2trt
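The new serialized_cache field carries the timing cache out of run() as raw bytes. A hedged sketch of how a caller might persist it; save_timing_cache and the path argument are hypothetical helpers, not part of this PR:

    from typing import Any, NamedTuple, Sequence

    class TRTInterpreterResult(NamedTuple):
        engine: Any
        input_names: Sequence[str]
        output_names: Sequence[str]
        serialized_cache: bytearray

    # Hypothetical helper (not in this PR): write the cache bytes to disk so a
    # later build can skip re-timing kernels that were already measured.
    def save_timing_cache(result: TRTInterpreterResult, path: str) -> None:
        if result.serialized_cache:  # empty bytearray means no cache was produced
            with open(path, "wb") as f:
                f.write(bytes(result.serialized_cache))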
@@ -420,6 +422,7 @@ class TRTInterpreter(torch.fx.Interpreter):
         force_fp32_output=False,
         strict_type_constraints=False,
         algorithm_selector=None,
+        timing_cache=None,
     ) -> TRTInterpreterResult:
         # For float outputs, we set their dtype to fp16 only if fp16_mode=True and
         # force_fp32_output=False.
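With the new timing_cache parameter, a call site can feed previously serialized bytes back into run(). A sketch under the assumption that an earlier build wrote a cache file; the file name is illustrative:

    # Load cache bytes from an earlier build if present; the path is hypothetical.
    try:
        with open("rn18_timing.cache", "rb") as f:
            cache_bytes = f.read()
    except FileNotFoundError:
        cache_bytes = None  # first build: run() falls back to an empty cache

    result = interp.run(fp16_mode=False, int8_mode=True, timing_cache=cache_bytes)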
@@ -437,6 +440,13 @@ class TRTInterpreter(torch.fx.Interpreter):
         self.builder.max_batch_size = max_batch_size
         builder_config = self.builder.create_builder_config()
         builder_config.max_workspace_size = max_workspace_size
+
+        cache = None
+        if timing_cache:
+            cache_file = numpy.array(timing_cache)
+            cache = builder_config.create_timing_cache(cache_file.tobytes())
+            builder_config.set_timing_cache(cache, True)
+
         if fp16_mode:
             builder_config.set_flag(trt.BuilderFlag.FP16)
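The block above wraps TensorRT's builder-config timing-cache API: create_timing_cache deserializes the caller's buffer (an empty buffer yields a fresh cache) and set_timing_cache attaches it to the build. A standalone sketch of the same calls, assuming TensorRT 8.x Python bindings:

    import tensorrt as trt

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    config = builder.create_builder_config()

    # An empty buffer makes TensorRT create a fresh, empty timing cache.
    cache = config.create_timing_cache(b"")
    # The second argument (ignore_mismatch) controls whether a cache recorded
    # under a different device or TensorRT version is still accepted; the PR
    # passes True here.
    config.set_timing_cache(cache, True)

The numpy.array(timing_cache).tobytes() round trip in the diff simply normalizes whatever buffer type the caller passed into the raw bytes TensorRT expects.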
@@ -456,7 +466,11 @@ class TRTInterpreter(torch.fx.Interpreter):

         engine = self.builder.build_engine(self.network, builder_config)
         assert engine
-        return TRTInterpreterResult(engine, self._input_names, self._output_names)
+
+        serialized_cache = bytearray(builder_config.get_timing_cache().serialize()) \
+            if builder_config.get_timing_cache() else bytearray()
+
+        return TRTInterpreterResult(engine, self._input_names, self._output_names, serialized_cache)

     def run_node(self, n):
         self._cur_node_name = str(n)
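Taken together, run() now supports an in-memory round trip: the cache serialized by one build can seed the next, so TensorRT skips re-timing tactics it has already measured. A hedged sketch, assuming two interpreters interp_a and interp_b built as in the samples (the names are illustrative):

    # First build measures kernel timings from scratch and serializes the cache.
    first = interp_a.run(fp16_mode=False, int8_mode=True)

    # Second build reuses those measurements via the new timing_cache parameter.
    second = interp_b.run(
        fp16_mode=False,
        int8_mode=True,
        timing_cache=first.serialized_cache,  # bytes flow straight back in
    )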