Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132352 Approved by: https://github.com/ezyang ghstack dependencies: #132335, #132351
38 lines · 715 B · Python
from typing import Any
|
|
|
|
import torch
|
|
|
|
|
|
@torch.jit.script
class MyScriptClass:
    """Minimal TorchScript class that stores a single value in ``foo``.

    Intended to be scripted: the decorator compiles the class when this
    module is imported.
    """

    def __init__(self, x: torch.Tensor) -> None:
        # TorchScript treats an unannotated parameter as ``torch.Tensor``, so
        # the explicit annotation matches the type the original inferred.
        # The attribute is assigned directly in __init__, where TorchScript
        # discovers a class's instance attributes.
        self.foo = x

    def set_foo(self, x: torch.Tensor) -> None:
        """Replace the stored value with ``x``."""
        self.foo = x
|
|
|
|
|
|
@torch.jit.script
def uses_script_class(x: torch.Tensor) -> torch.Tensor:
    """Wrap ``x`` in a ``MyScriptClass`` and hand back its ``foo`` value.

    Intended to be scripted; it exercises constructing a TorchScript class
    from inside a scripted function. The annotations spell out the types
    TorchScript would otherwise infer (``Tensor`` in, ``Tensor`` out).
    """
    return MyScriptClass(x).foo
|
|
|
|
|
|
class IdListFeature:
    """Plain (non-scripted) container whose ``id_list`` is a 1x1 tensor of ones."""

    def __init__(self) -> None:
        self.id_list = torch.ones(1, 1)

    def returns_self(self) -> "IdListFeature":
        """Return an ``IdListFeature`` (annotated via a string forward reference).

        NOTE(review): despite its name, this constructs and returns a *new*
        instance rather than ``self``. Confirm with callers before changing.
        """
        replacement = IdListFeature()
        return replacement
|
|
|
|
|
|
class UsesIdListFeature(torch.nn.Module):
    """Module that unwraps ``IdListFeature`` inputs and passes anything else through."""

    def forward(self, feature: Any):
        """Return ``feature.id_list`` for an ``IdListFeature``, otherwise ``feature`` unchanged."""
        # Keep the positive isinstance check so the attribute access happens
        # inside the narrowed branch; every other input falls through untouched.
        if isinstance(feature, IdListFeature):
            return feature.id_list
        return feature
|