mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add scuba logging for TorchScript usage (#121936)
Summary: Infra to log live usage of TorchScript internally Test Plan: manually tested Differential Revision: D54923510 Pull Request resolved: https://github.com/pytorch/pytorch/pull/121936 Approved by: https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
4819da60ab
commit
ba9a1d96a4
@ -46,6 +46,7 @@ import torch.package._mangling as package_mangling
|
||||
from torch._awaits import _Await
|
||||
from torch._C import _Await as CAwait, Future as CFuture
|
||||
from torch._sources import fake_range, get_source_lines_and_file, parse_def
|
||||
from torch._utils_internal import log_torchscript_usage
|
||||
from torch.futures import Future
|
||||
|
||||
IS_PY39_PLUS: Final[bool] = sys.version_info >= (3, 9)
|
||||
@ -582,6 +583,7 @@ def export(fn):
|
||||
# any compiled methods and wasn't decorated with `@torch.jit.export`
|
||||
m = torch.jit.script(MyModule())
|
||||
"""
|
||||
log_torchscript_usage("export")
|
||||
fn._torchscript_modifier = FunctionModifiers.EXPORT
|
||||
return fn
|
||||
|
||||
@ -623,6 +625,7 @@ def unused(fn):
|
||||
# exception raised
|
||||
m(torch.rand(100))
|
||||
"""
|
||||
log_torchscript_usage("unused")
|
||||
if isinstance(fn, property):
|
||||
prop = fn
|
||||
setattr( # noqa: B010
|
||||
@ -710,6 +713,7 @@ def ignore(drop=False, **kwargs):
|
||||
import os
|
||||
os.remove('m.pt')
|
||||
"""
|
||||
log_torchscript_usage("ignore")
|
||||
|
||||
if callable(drop):
|
||||
# used without any args, so drop is actually a function
|
||||
|
@ -95,6 +95,11 @@ def log_export_usage(**kwargs):
|
||||
pass
|
||||
|
||||
|
||||
def log_torchscript_usage(api: str):
|
||||
_ = api
|
||||
return
|
||||
|
||||
|
||||
def justknobs_check(name: str) -> bool:
|
||||
"""
|
||||
This function can be used to killswitch functionality in FB prod,
|
||||
|
@ -19,6 +19,7 @@ import torch
|
||||
import torch._jit_internal as _jit_internal
|
||||
from torch._classes import classes
|
||||
from torch._jit_internal import _qualified_name
|
||||
from torch._utils_internal import log_torchscript_usage
|
||||
from torch.jit._builtins import _register_builtin
|
||||
from torch.jit._fuser import _graph_for, _script_method_graph_for
|
||||
|
||||
@ -1287,6 +1288,8 @@ def script(
|
||||
if not _enabled:
|
||||
return obj
|
||||
|
||||
log_torchscript_usage("script")
|
||||
|
||||
if optimize is not None:
|
||||
warnings.warn(
|
||||
"`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"
|
||||
|
@ -7,9 +7,11 @@ This module contains functionality for serializing TorchScript modules, notably:
|
||||
This is not intended to be imported directly; please use the exposed
|
||||
functionalities in `torch.jit`.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch._utils_internal import log_torchscript_usage
|
||||
from torch.jit._recursive import wrap_cpp_module
|
||||
from torch.serialization import validate_cuda_device
|
||||
|
||||
@ -73,6 +75,7 @@ def save(m, f, _extra_files=None):
|
||||
extra_files = {'foo.txt': b'bar'}
|
||||
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)
|
||||
"""
|
||||
log_torchscript_usage("save")
|
||||
if _extra_files is None:
|
||||
_extra_files = {}
|
||||
if isinstance(f, (str, os.PathLike)):
|
||||
@ -143,6 +146,7 @@ def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
|
||||
import os
|
||||
os.remove("scriptmodule.pt")
|
||||
"""
|
||||
log_torchscript_usage("load")
|
||||
if isinstance(f, (str, os.PathLike)):
|
||||
if not os.path.exists(f): # type: ignore[type-var]
|
||||
raise ValueError(f"The provided filename {f} does not exist") # type: ignore[str-bytes-safe]
|
||||
|
@ -7,6 +7,7 @@ This module contains functionality to support the JIT's tracing frontend, notabl
|
||||
This is not intended to be imported directly; please use the exposed
|
||||
functionalities in `torch.jit`.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
|
||||
import copy
|
||||
@ -25,6 +26,8 @@ from torch._jit_internal import (
|
||||
get_callable_argument_names,
|
||||
is_scripting,
|
||||
)
|
||||
|
||||
from torch._utils_internal import log_torchscript_usage
|
||||
from torch.autograd import function
|
||||
from torch.jit._script import _CachedForward, script, ScriptModule
|
||||
|
||||
@ -803,6 +806,8 @@ def trace(
|
||||
"`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"
|
||||
)
|
||||
|
||||
log_torchscript_usage("trace")
|
||||
|
||||
if isinstance(func, torch.jit.ScriptModule):
|
||||
# it is hard to trace it because the forward method on ScriptModule is already defined, so it
|
||||
# would result in an error.
|
||||
|
Reference in New Issue
Block a user