Add scuba logging for TorchScript usage (#121936)

Summary: Infra to log live usage of TorchScript internally

Test Plan: manually tested

Differential Revision: D54923510

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121936
Approved by: https://github.com/zhxchen17
This commit is contained in:
Yanan Cao (PyTorch)
2024-03-19 17:38:27 +00:00
committed by PyTorch MergeBot
parent 4819da60ab
commit ba9a1d96a4
5 changed files with 21 additions and 0 deletions

View File

@ -46,6 +46,7 @@ import torch.package._mangling as package_mangling
from torch._awaits import _Await
from torch._C import _Await as CAwait, Future as CFuture
from torch._sources import fake_range, get_source_lines_and_file, parse_def
from torch._utils_internal import log_torchscript_usage
from torch.futures import Future
IS_PY39_PLUS: Final[bool] = sys.version_info >= (3, 9)
@ -582,6 +583,7 @@ def export(fn):
# any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule())
"""
log_torchscript_usage("export")
fn._torchscript_modifier = FunctionModifiers.EXPORT
return fn
@ -623,6 +625,7 @@ def unused(fn):
# exception raised
m(torch.rand(100))
"""
log_torchscript_usage("unused")
if isinstance(fn, property):
prop = fn
setattr( # noqa: B010
@ -710,6 +713,7 @@ def ignore(drop=False, **kwargs):
import os
os.remove('m.pt')
"""
log_torchscript_usage("ignore")
if callable(drop):
# used without any args, so drop is actually a function

View File

@ -95,6 +95,11 @@ def log_export_usage(**kwargs):
pass
def log_torchscript_usage(api: str):
    """Open-source no-op stub for internal TorchScript usage logging.

    ``api`` names the TorchScript entry point being invoked (callers in this
    change pass "script", "trace", "export", "unused", "ignore", "save",
    "load"); the OSS build intentionally discards it and returns None.
    Internal builds replace this with a real logging implementation.
    """
    del api
    return None
def justknobs_check(name: str) -> bool:
"""
This function can be used to killswitch functionality in FB prod,

View File

@ -19,6 +19,7 @@ import torch
import torch._jit_internal as _jit_internal
from torch._classes import classes
from torch._jit_internal import _qualified_name
from torch._utils_internal import log_torchscript_usage
from torch.jit._builtins import _register_builtin
from torch.jit._fuser import _graph_for, _script_method_graph_for
@ -1287,6 +1288,8 @@ def script(
if not _enabled:
return obj
log_torchscript_usage("script")
if optimize is not None:
warnings.warn(
"`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"

View File

@ -7,9 +7,11 @@ This module contains functionality for serializing TorchScript modules, notably:
This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import os
import torch
from torch._utils_internal import log_torchscript_usage
from torch.jit._recursive import wrap_cpp_module
from torch.serialization import validate_cuda_device
@ -73,6 +75,7 @@ def save(m, f, _extra_files=None):
extra_files = {'foo.txt': b'bar'}
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)
"""
log_torchscript_usage("save")
if _extra_files is None:
_extra_files = {}
if isinstance(f, (str, os.PathLike)):
@ -143,6 +146,7 @@ def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
import os
os.remove("scriptmodule.pt")
"""
log_torchscript_usage("load")
if isinstance(f, (str, os.PathLike)):
if not os.path.exists(f): # type: ignore[type-var]
raise ValueError(f"The provided filename {f} does not exist") # type: ignore[str-bytes-safe]

View File

@ -7,6 +7,7 @@ This module contains functionality to support the JIT's tracing frontend, notabl
This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import contextlib
import copy
@ -25,6 +26,8 @@ from torch._jit_internal import (
get_callable_argument_names,
is_scripting,
)
from torch._utils_internal import log_torchscript_usage
from torch.autograd import function
from torch.jit._script import _CachedForward, script, ScriptModule
@ -803,6 +806,8 @@ def trace(
"`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"
)
log_torchscript_usage("trace")
if isinstance(func, torch.jit.ScriptModule):
# it is hard to trace it because the forward method on ScriptModule is already defined, so it
# would result in an error.