mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[PyTorch] Split StaticModule out of test_static_runtime (#121028)
I want to use StaticModule in another (internal) test, so splitting it out. Differential Revision: [D54384817](https://our.internmc.facebook.com/intern/diff/D54384817/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/121028 Approved by: https://github.com/suo
This commit is contained in:
committed by
PyTorch MergeBot
parent
f5391dad82
commit
cac36e232e
@ -7,30 +7,9 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
from torch.testing._internal.static_module import StaticModule
|
||||
from typing import List
|
||||
|
||||
class StaticModule:
|
||||
def __init__(self, scripted):
|
||||
# this is an nn.Module
|
||||
if hasattr(scripted, "_c"):
|
||||
self.static_module = torch._C._jit_to_static_module(scripted._c)
|
||||
else:
|
||||
self.static_module = torch._C._jit_to_static_module(scripted.graph)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.static_module(*args, **kwargs)
|
||||
|
||||
def benchmark(self, args, kwargs, warmup_runs, main_runs):
|
||||
self.static_module.benchmark(args, kwargs, warmup_runs, main_runs)
|
||||
|
||||
def runAsync(self, args, kwargs):
|
||||
return self.static_module.runAsync(args, kwargs)
|
||||
|
||||
def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
|
||||
return self.static_module.benchmark_individual_ops(
|
||||
args, kwargs, warmup_runs, main_runs
|
||||
)
|
||||
|
||||
|
||||
def linear_shim(
|
||||
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
|
||||
|
Reference in New Issue
Block a user