mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
I want to use StaticModule in another (internal) test, so splitting it out. Differential Revision: [D54384817](https://our.internmc.facebook.com/intern/diff/D54384817/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/121028 Approved by: https://github.com/suo
27 lines
866 B
Python
27 lines
866 B
Python
# Owner(s): ["module: unknown"]
|
|
|
|
import torch
|
|
|
|
|
|
class StaticModule:
    """Thin Python wrapper around a TorchScript object converted for the static runtime.

    The converted object is produced by ``torch._C._jit_to_static_module`` and
    stored as ``self.static_module``; every public method below simply forwards
    its arguments to the method of the same name on that object and returns
    whatever it returns.
    """

    def __init__(self, scripted):
        """Convert ``scripted`` into a static-runtime module.

        Args:
            scripted: a TorchScript object. If it has a ``_c`` attribute
                (a scripted ``nn.Module``), ``scripted._c`` is converted;
                otherwise ``scripted.graph`` is converted instead.
        """
        if hasattr(scripted, "_c"):
            # this is an nn.Module: convert its underlying C++ module
            self.static_module = torch._C._jit_to_static_module(scripted._c)
        else:
            # No ``_c`` attribute: convert the object's graph directly
            # (presumably a scripted function -- NOTE(review): confirm with callers).
            self.static_module = torch._C._jit_to_static_module(scripted.graph)

    def __call__(self, *args, **kwargs):
        """Run the static module on ``args``/``kwargs`` and return its result."""
        return self.static_module(*args, **kwargs)

    def benchmark(self, args, kwargs, warmup_runs, main_runs):
        """Benchmark the static module via its ``benchmark`` method.

        Args:
            args: positional inputs, forwarded unchanged.
            kwargs: keyword inputs, forwarded unchanged.
            warmup_runs: run count forwarded unchanged (presumably un-timed
                warm-up iterations -- confirm against the binding).
            main_runs: run count forwarded unchanged (presumably the measured
                iterations -- confirm against the binding).

        Returns:
            Whatever the underlying ``benchmark`` returns (may be ``None``).
            The result is passed through so this method behaves like the
            other forwarding methods on this class.
        """
        return self.static_module.benchmark(args, kwargs, warmup_runs, main_runs)

    def runAsync(self, args, kwargs):
        """Forward ``args``/``kwargs`` to ``static_module.runAsync`` and return its result.

        The camelCase name mirrors the underlying method and is kept for
        compatibility with existing callers.
        """
        return self.static_module.runAsync(args, kwargs)

    def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
        """Forward to ``static_module.benchmark_individual_ops`` and return its result.

        Arguments have the same meaning as in :meth:`benchmark` and are passed
        through unchanged.
        """
        return self.static_module.benchmark_individual_ops(
            args, kwargs, warmup_runs, main_runs
        )
|