mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Summary: att Test Plan: ci Rollback Plan: Differential Revision: D81731425 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162353 Approved by: https://github.com/yiming0416
34 lines
925 B
Python
34 lines
925 B
Python
from typing import Optional
|
|
|
|
import torch
|
|
from torch.export import ExportedProgram
|
|
|
|
|
|
class LoweredBackendModule(torch.nn.Module):
|
|
def __init__(
|
|
self,
|
|
original_exported_program: ExportedProgram,
|
|
backend_id: str,
|
|
*,
|
|
module_name: Optional[str] = None,
|
|
) -> None:
|
|
super().__init__()
|
|
self._backend_id = backend_id
|
|
self._module_name = module_name
|
|
self._original_exported_program = original_exported_program
|
|
|
|
@property
|
|
def backend_id(self) -> str:
|
|
return self._backend_id
|
|
|
|
@property
|
|
def module_name(self) -> Optional[str]:
|
|
return self._module_name
|
|
|
|
@property
|
|
def original_module(self) -> ExportedProgram:
|
|
return self._original_exported_program
|
|
|
|
def forward(self, *args, **kwargs): # type: ignore[no-untyped-def]
|
|
return torch._higher_order_ops.executorch_call_delegate(self, *args, **kwargs)
|