Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 13:44:15 +08:00)
Summary: Solves https://github.com/pytorch/pytorch/issues/151925

Currently, AOTI only generates runtime asserts for unbacked symints. We should generate asserts for all `_assert_scalar` calls in the input graph. This change also factors the runtime assertion logic out into a separate function.

We need to generate runtime asserts directly in Inductor instead of just reusing the asserts from the input graph, because we reuse the same ShapeEnv as before. In particular, on subsequent graph passes we would immediately turn all of these assertions into no-ops: when we evaluated their expressions, the deferred runtime assert already recorded in the ShapeEnv would tell us "of course this expression is True". One example is below:

```
class Model(torch.nn.Module):
    def forward(self, a, b, c):
        nz = torch.nonzero(a)
        ones = a.new_ones([nz.size(0), b.size(0)])
        torch._check(ones.size(0) >= 1)
        equals = torch.add(ones, c)
        return equals

torch._dynamo.mark_dynamic(c, 0)
```

When we reuse the ShapeEnv in Inductor lowering, the check that `a` and the nonzero output have the same shape would be evaluated to True after we resolve unbacked bindings using the ShapeEnv. See test_unbacked_equals_input_size_runtime_assertion in test_aot_inductor.

In addition to the Inductor-generated runtime asserts, we also need the runtime asserts from the input graph, because some derived runtime asserts are not generated in Inductor. One example is below, where x.shape[0] needs to be a multiple of 100:

```
class Model(torch.nn.Module):
    def forward(self, x):
        y = x.reshape(100, -1).clone()
        y = y + 1
        return y

dynamic_shapes = {
    "x": {0: torch.export.Dim.DYNAMIC},
}
```

See test_aoti_runtime_asserts_backed_symint in test_aot_inductor.

Example graph:

```
def forward(self):
    arg0_1: "f32[s35]"; arg0_1, = fx_pytree.tree_flatten_spec([], self._in_spec)

    # File: /data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/73a672eb896e7996/scripts/shangdiy/__pt__/pt#link-tree/scripts/shangdiy/pt.py:11 in forward, code: y = x.reshape(100, -1).clone()
    sym_size_int: "Sym(s35)" = torch.ops.aten.sym_size.int(arg0_1, 0)

    #
    mod: "Sym(Mod(s35, 100))" = sym_size_int % 100;  sym_size_int = None
    eq_2: "Sym(Eq(Mod(s35, 100), 0))" = mod == 0;  mod = None
    _assert_scalar = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(Mod(s35, 100), 0) on node 'eq'");  eq_2 = _assert_scalar = None

    # File: /data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/73a672eb896e7996/scripts/shangdiy/__pt__/pt#link-tree/scripts/shangdiy/pt.py:11 in forward, code: y = x.reshape(100, -1).clone()
    view: "f32[100, (s35//100)]" = torch.ops.aten.reshape.default(arg0_1, [100, -1]);  arg0_1 = None
    clone: "f32[100, (s35//100)]" = torch.ops.aten.clone.default(view);  view = None

    # File: /data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/73a672eb896e7996/scripts/shangdiy/__pt__/pt#link-tree/scripts/shangdiy/pt.py:12 in forward, code: y = y + 1
    add_6: "f32[100, 1]" = torch.ops.aten.add.Tensor(clone, 1);  clone = None
    return (add_6,)
```

Generated cpp code:

```
auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 1);
auto arg0_1 = std::move(inputs[0]);
auto arg0_1_size = arg0_1.sizes();
int64_t s35 = arg0_1_size[0];
inputs.clear();
auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());

if (!((s35 % 100L) == 0L)) {
    throw std::runtime_error("Expected Eq(Mod(s35, 100), 0) to be True but received " + std::to_string(s35));
}
```

Test Plan:

```
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:test_aot_inductor -- -r aoti_runtime_asserts_backed_symint
```

Differential Revision: D73596786

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152125
Approved by: https://github.com/henrylhtsang, https://github.com/jingsh
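For a concrete picture of the user-visible effect, here is a minimal hedged sketch (not part of the PR itself) of how the backed-symint assertion from the second example would fire at AOTI run time. It assumes the torch.export and AOTI packaging APIs (torch.export.Dim.DYNAMIC, torch._inductor.aoti_compile_and_package, torch._inductor.aoti_load_package) behave as in recent PyTorch releases:

```
import torch


class Model(torch.nn.Module):
    def forward(self, x):
        # x.shape[0] must be divisible by 100 for this reshape to succeed.
        y = x.reshape(100, -1).clone()
        return y + 1


ep = torch.export.export(
    Model(),
    (torch.randn(300),),
    dynamic_shapes={"x": {0: torch.export.Dim.DYNAMIC}},
)
pkg_path = torch._inductor.aoti_compile_and_package(ep)
compiled = torch._inductor.aoti_load_package(pkg_path)

compiled(torch.randn(500))   # fine: Eq(Mod(s, 100), 0) holds
compiled(torch.randn(250))   # with this change, raises the generated runtime assertion
```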
285 lines
8.5 KiB
Python
# mypy: allow-untyped-defs
import functools
import logging
import os
import sys
import tempfile
from typing import Any, Callable, Optional, TypeVar

from typing_extensions import ParamSpec

import torch
from torch._strobelight.compile_time_profiler import StrobelightCompileTimeProfiler


_T = TypeVar("_T")
_P = ParamSpec("_P")

log = logging.getLogger(__name__)

if os.environ.get("TORCH_COMPILE_STROBELIGHT", False):
    import shutil

    if not shutil.which("strobeclient"):
        log.info(
            "TORCH_COMPILE_STROBELIGHT is true, but seems like you are not on a FB machine."
        )
    else:
        log.info("Strobelight profiler is enabled via environment variable")
        StrobelightCompileTimeProfiler.enable()

# this arbitrary-looking assortment of functionality is provided here
# to have a central place for overrideable behavior. The motivating
# use is the FB build environment, where this source file is replaced
# by an equivalent.

if torch._running_with_deploy():
    # __file__ is meaningless in the context of frozen torch used in torch deploy.
    # setting empty torch_parent should allow below functions to operate without crashing,
    # but it's unclear if there is a valid use case for them in the context of deploy.
    torch_parent = ""
else:
    if os.path.basename(os.path.dirname(__file__)) == "shared":
        torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
    else:
        torch_parent = os.path.dirname(os.path.dirname(__file__))


def get_file_path(*path_components: str) -> str:
    return os.path.join(torch_parent, *path_components)


def get_file_path_2(*path_components: str) -> str:
    return os.path.join(*path_components)


def get_writable_path(path: str) -> str:
    if os.access(path, os.W_OK):
        return path
    return tempfile.mkdtemp(suffix=os.path.basename(path))


def prepare_multiprocessing_environment(path: str) -> None:
    pass


def resolve_library_path(path: str) -> str:
    return os.path.realpath(path)


def throw_abstract_impl_not_imported_error(opname, module, context):
    if module in sys.modules:
        raise NotImplementedError(
            f"{opname}: We could not find the fake impl for this operator. "
        )
    else:
        raise NotImplementedError(
            f"{opname}: We could not find the fake impl for this operator. "
            f"The operator specified that you may need to import the '{module}' "
            f"Python module to load the fake impl. {context}"
        )


# NB! This treats "skip" kwarg specially!!
def compile_time_strobelight_meta(
    phase_name: str,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    def compile_time_strobelight_meta_inner(
        function: Callable[_P, _T],
    ) -> Callable[_P, _T]:
        @functools.wraps(function)
        def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> _T:
            if "skip" in kwargs and isinstance(skip := kwargs["skip"], int):
                kwargs["skip"] = skip + 1

            # This is not needed but we have it here to avoid having profile_compile_time
            # in stack traces when profiling is not enabled.
            if not StrobelightCompileTimeProfiler.enabled:
                return function(*args, **kwargs)

            return StrobelightCompileTimeProfiler.profile_compile_time(
                function, phase_name, *args, **kwargs
            )

        return wrapper_function

    return compile_time_strobelight_meta_inner
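
# Hedged usage sketch (editorial addition, not part of the original file): the
# decorator above is meant to wrap a compile-phase entry point. The special "skip"
# kwarg (presumably the number of stack frames the profiler should skip) is bumped
# by one so the wrapper frame itself is not attributed. The phase name and function
# below are illustrative only:
#
#   @compile_time_strobelight_meta(phase_name="backend_compile")
#   def _compile(gm, example_inputs, skip=0):
#       ...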


# Meta only, see
# https://www.internalfb.com/intern/wiki/ML_Workflow_Observability/User_Guides/Adding_instrumentation_to_your_code/
#
# This will cause an event to get logged to Scuba via the signposts API. You
# can view samples on the API at https://fburl.com/scuba/workflow_signpost/zh9wmpqs
# we log to subsystem "torch", and the category and name you provide here.
# Each of the arguments translates into a Scuba column. We're still figuring
# out local conventions in PyTorch, but category should be something like
# "dynamo" or "inductor", and name should be a specific string describing what
# kind of event happened.
#
# Killswitch is at
# https://www.internalfb.com/intern/justknobs/?name=pytorch%2Fsignpost#event
def signpost_event(category: str, name: str, parameters: dict[str, Any]):
    log.info("%s %s: %r", category, name, parameters)


def log_compilation_event(metrics):
    log.info("%s", metrics)


def upload_graph(graph):
    pass


def set_pytorch_distributed_envs_from_justknobs():
    pass


def log_export_usage(**kwargs):
    pass


def log_trace_structured_event(*args, **kwargs) -> None:
    pass


def log_cache_bypass(*args, **kwargs) -> None:
    pass


def log_torchscript_usage(api: str, **kwargs):
    _ = api
    return


def check_if_torch_exportable():
    return False


def export_training_ir_rollout_check() -> bool:
    return True


def full_aoti_runtime_assert() -> bool:
    return True


def log_torch_jit_trace_exportability(
    api: str,
    type_of_export: str,
    export_outcome: str,
    result: str,
):
    _, _, _, _ = api, type_of_export, export_outcome, result
    return


def justknobs_check(name: str, default: bool = True) -> bool:
    """
    This function can be used to killswitch functionality in FB prod,
    where you can toggle this value to False in JK without having to
    do a code push. In OSS, we always have everything turned on all
    the time, because downstream users can simply choose to not update
    PyTorch. (If more fine-grained enable/disable is needed, we could
    potentially have a map we lookup name in to toggle behavior. But
    the point is that it's all tied to source code in OSS, since there's
    no live server to query.)

    This is the bare minimum functionality I needed to do some killswitches.
    We have a more detailed plan at
    https://docs.google.com/document/d/1Ukerh9_42SeGh89J-tGtecpHBPwGlkQ043pddkKb3PU/edit
    In particular, in some circumstances it may be necessary to read in
    a knob once at process start, and then use it consistently for the
    rest of the process. Future functionality will codify these patterns
    into a better high level API.

    WARNING: Do NOT call this function at module import time, JK is not
    fork safe and you will break anyone who forks the process and then
    hits JK again.
    """
    return default
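
# Hedged usage sketch (editorial addition; the knob name below is made up): gate a
# code path on a killswitch that can be flipped off in FB prod without a code push.
# In OSS this simply returns the default.
#
#   if justknobs_check("pytorch/compiler:some_feature", default=True):
#       ...  # feature path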


def justknobs_getval_int(name: str) -> int:
    """
    Read warning on justknobs_check
    """
    return 0


def is_fb_unit_test() -> bool:
    return False


@functools.lru_cache(None)
def max_clock_rate():
    if not torch.version.hip:
        from triton.testing import nvsmi

        return nvsmi(["clocks.max.sm"])[0]
    else:
        # Manually set max-clock speeds on ROCm until equivalent nvsmi
        # functionality is available in triton.testing or via pyamdsmi enablement.
        # Required for test_snode_runtime unit tests.
        gcn_arch = str(torch.cuda.get_device_properties(0).gcnArchName.split(":", 1)[0])
        if "gfx94" in gcn_arch:
            return 1700
        elif "gfx90a" in gcn_arch:
            return 1700
        elif "gfx908" in gcn_arch:
            return 1502
        elif "gfx12" in gcn_arch:
            return 1700
        elif "gfx11" in gcn_arch:
            return 1700
        elif "gfx103" in gcn_arch:
            return 1967
        elif "gfx101" in gcn_arch:
            return 1144
        elif "gfx95" in gcn_arch:
            return 1700  # TODO: placeholder, get actual value
        else:
            return 1100


def get_mast_job_name_version() -> Optional[tuple[str, int]]:
    return None


TEST_MASTER_ADDR = "127.0.0.1"
TEST_MASTER_PORT = 29500
# USE_GLOBAL_DEPS controls whether __init__.py tries to load
# libtorch_global_deps, see Note [Global dependencies]
USE_GLOBAL_DEPS = True
# USE_RTLD_GLOBAL_WITH_LIBTORCH controls whether __init__.py tries to load
# _C.so with RTLD_GLOBAL during the call to dlopen.
USE_RTLD_GLOBAL_WITH_LIBTORCH = False
# If an op was defined in C++ and extended from Python using
# torch.library.register_fake, returns if we require that there be a
# m.set_python_module("mylib.ops") call from C++ that associates
# the C++ op with a python module.
REQUIRES_SET_PYTHON_MODULE = False
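
# Hedged illustration of the mechanism the flag above gates (editorial addition;
# the op name "mylib::my_op" and module "mylib.ops" are made up): the C++ side would
# call m.set_python_module("mylib.ops") inside its TORCH_LIBRARY block, and the
# Python module "mylib.ops" would supply the fake impl for the op, e.g.:
#
#   @torch.library.register_fake("mylib::my_op")
#   def _(x):
#       return torch.empty_like(x)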


def maybe_upload_prof_stats_to_manifold(profile_path: str) -> Optional[str]:
    print("Uploading profile stats (fb-only otherwise no-op)")
    return None


def log_chromium_event_internal(
    event: dict[str, Any],
    stack: list[str],
    logger_uuid: str,
    start_time_ns: int,
):
    return None


def record_chromium_event_internal(
    event: dict[str, Any],
):
    return None


def profiler_allow_cudagraph_cupti_lazy_reinit_cuda12():
    return True