mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/105928 Approved by: https://github.com/albanD
48 lines
1.9 KiB
Python
48 lines
1.9 KiB
Python
import torch
|
|
|
|
|
|
class WrapperModule:
    """Wrap an instance of ``wrapped_type`` for framework-overhead benchmarking.

    On construction, ``wrapped_type`` is instantiated with
    ``module_config.pt_fn`` and ``module_config.num_params`` inputs are created,
    each a single-element tensor from ``torch.randn(1)``. When
    ``module_config.graph_mode`` is true, the instance is traced with
    ``torch.jit.trace`` on those inputs and can optionally be saved to disk.

    Args:
        wrapped_type: Type to be wrapped. It must be constructible as
            ``wrapped_type(pt_fn)`` and provide a ``forward`` method taking
            ``module_config.num_params`` positional tensor arguments. In graph
            mode it must also be traceable by ``torch.jit.trace``.
        module_config: Object exposing ``pt_fn`` (the function passed to
            ``wrapped_type``), ``graph_mode`` (whether to trace the instance),
            and ``num_params`` (number of arguments ``forward`` takes).
        debug: If true and the wrapped module is a ``torch.jit.ScriptModule``
            (e.g. after tracing), print its graph and generated code.
        save: In graph mode, save the traced module to
            ``"<wrapped type name>_<pt_fn name>.pt"`` (a path relative to the
            current working directory). Ignored when graph mode is disabled.
    """

    def __init__(self, wrapped_type, module_config, debug, save=False):
        pt_fn = module_config.pt_fn
        self.module = wrapped_type(pt_fn)
        self.module_name = wrapped_type.__name__
        # One single-element random tensor per positional forward() argument.
        self.tensor_inputs = [
            torch.randn(1) for _ in range(module_config.num_params)
        ]

        if module_config.graph_mode:
            # torch.jit.trace documents example_inputs as a tuple of tensors.
            self.module = torch.jit.trace(self.module, tuple(self.tensor_inputs))
            if save:
                file_name = f"{self.module_name}_{pt_fn.__name__}.pt"
                torch.jit.save(self.module, file_name)
                print(f"Generated graph is saved in {file_name}")

        print(
            f"Benchmarking module {self.module_name} with fn {pt_fn.__name__}: "
            f"Graph mode:{module_config.graph_mode}"
        )

        # Only scripted/traced modules carry a TorchScript graph and code.
        if debug and isinstance(self.module, torch.jit.ScriptModule):
            print(self.module.graph)
            print(self.module.code)

    def forward(self, niters):
        """Call the wrapped module's ``forward`` ``niters`` times.

        Calls run under ``torch.no_grad()`` with the inputs created at
        construction; their outputs are discarded.

        Args:
            niters: Number of ``forward`` invocations.
        """
        with torch.no_grad():
            for _ in range(niters):
                self.module.forward(*self.tensor_inputs)