diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 64509816e09c..20fb3ff4d96d 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -46,6 +46,7 @@ import torch.package._mangling as package_mangling from torch._awaits import _Await from torch._C import _Await as CAwait, Future as CFuture from torch._sources import fake_range, get_source_lines_and_file, parse_def +from torch._utils_internal import log_torchscript_usage from torch.futures import Future IS_PY39_PLUS: Final[bool] = sys.version_info >= (3, 9) @@ -582,6 +583,7 @@ def export(fn): # any compiled methods and wasn't decorated with `@torch.jit.export` m = torch.jit.script(MyModule()) """ + log_torchscript_usage("export") fn._torchscript_modifier = FunctionModifiers.EXPORT return fn @@ -623,6 +625,7 @@ def unused(fn): # exception raised m(torch.rand(100)) """ + log_torchscript_usage("unused") if isinstance(fn, property): prop = fn setattr( # noqa: B010 @@ -710,6 +713,7 @@ def ignore(drop=False, **kwargs): import os os.remove('m.pt') """ + log_torchscript_usage("ignore") if callable(drop): # used without any args, so drop is actually a function diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 1f85b3e7ce61..43d4cfee2b6d 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -95,6 +95,11 @@ def log_export_usage(**kwargs): pass +def log_torchscript_usage(api: str): + _ = api + return + + def justknobs_check(name: str) -> bool: """ This function can be used to killswitch functionality in FB prod, diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 5e29c43d4530..2d087bcdd593 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -19,6 +19,7 @@ import torch import torch._jit_internal as _jit_internal from torch._classes import classes from torch._jit_internal import _qualified_name +from torch._utils_internal import log_torchscript_usage from torch.jit._builtins import _register_builtin from torch.jit._fuser import _graph_for, _script_method_graph_for @@ -1287,6 +1288,8 @@ def script( if not _enabled: return obj + log_torchscript_usage("script") + if optimize is not None: warnings.warn( "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead" diff --git a/torch/jit/_serialization.py b/torch/jit/_serialization.py index 00b9254a263c..514f23cb76d3 100644 --- a/torch/jit/_serialization.py +++ b/torch/jit/_serialization.py @@ -7,9 +7,11 @@ This module contains functionality for serializing TorchScript modules, notably: This is not intended to be imported directly; please use the exposed functionalities in `torch.jit`. """ + import os import torch +from torch._utils_internal import log_torchscript_usage from torch.jit._recursive import wrap_cpp_module from torch.serialization import validate_cuda_device @@ -73,6 +75,7 @@ def save(m, f, _extra_files=None): extra_files = {'foo.txt': b'bar'} torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files) """ + log_torchscript_usage("save") if _extra_files is None: _extra_files = {} if isinstance(f, (str, os.PathLike)): @@ -143,6 +146,7 @@ def load(f, map_location=None, _extra_files=None, _restore_shapes=False): import os os.remove("scriptmodule.pt") """ + log_torchscript_usage("load") if isinstance(f, (str, os.PathLike)): if not os.path.exists(f): # type: ignore[type-var] raise ValueError(f"The provided filename {f} does not exist") # type: ignore[str-bytes-safe] diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 23fe78201f10..d4651dab3655 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -7,6 +7,7 @@ This module contains functionality to support the JIT's tracing frontend, notabl This is not intended to be imported directly; please use the exposed functionalities in `torch.jit`. """ + import contextlib import copy @@ -25,6 +26,8 @@ from torch._jit_internal import ( get_callable_argument_names, is_scripting, ) + +from torch._utils_internal import log_torchscript_usage from torch.autograd import function from torch.jit._script import _CachedForward, script, ScriptModule @@ -803,6 +806,8 @@ def trace( "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead" ) + log_torchscript_usage("trace") + if isinstance(func, torch.jit.ScriptModule): # it is hard to trace it because the forward method on ScriptModule is already defined, so it # would result in an error.