Files
pytorch/test/load_torchscript_model.py
2024-11-04 18:30:29 +00:00

14 lines
311 B
Python

import sys
import torch
def main(argv: "list[str] | None" = None) -> int:
    """Smoke-test loading a TorchScript archive and its original eager module.

    ``argv[1]`` is the path of a TorchScript archive, loaded with
    ``torch.jit.load`` and printed (not executed). A second file at
    ``argv[1] + ".orig"`` is loaded with ``torch.load`` and called once on a
    random ``(2, 28 * 28)`` float tensor; its output is discarded, so the check
    is only that loading and the forward call do not raise.

    Args:
        argv: Command-line arguments including the program name. Defaults to
            ``sys.argv``.

    Returns:
        0 on success, 2 when the model path argument is missing.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        # Fail with a usage message instead of an opaque IndexError.
        prog = argv[0] if argv else "load_torchscript_model.py"
        print(f"usage: {prog} MODEL_PATH", file=sys.stderr)
        return 2
    model_path = argv[1]

    script_mod = torch.jit.load(model_path)
    # The ".orig" object is called like a module below, so it is presumably a
    # whole pickled nn.Module rather than a plain state_dict. Unpickling
    # arbitrary classes requires weights_only=False (weights_only=True is the
    # default since PyTorch 2.6). NOTE: this runs pickle, which can execute
    # arbitrary code -- only use on trusted, locally produced files.
    mod = torch.load(model_path + ".orig", weights_only=False)
    print(script_mod)
    # Input width 28 * 28 -- presumably flattened 28x28 images; confirm
    # against the model this script is paired with.
    inp = torch.rand(2, 28 * 28)
    _ = mod(inp)
    return 0


if __name__ == "__main__":
    sys.exit(main())