consolidate fw and inference compile paths (#165457)

By design, the fw compile and inference compile stages should share a lot of code; this change just consolidates that duplication.
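As an illustration of the change (not the actual AOTAutograd code — every name below is a simplified, hypothetical stand-in), the pattern is to keep the two thin entry points and have both delegate to one shared helper that branches on an is_inference flag for the few mode-specific knobs:

from contextlib import nullcontext
from typing import Any, Callable, Optional

def _compile_forward_or_inference(
    graph: Callable[..., Any],
    args: list[Any],
    *,
    is_inference: bool,
    num_fw_outs_saved_for_bw: Optional[int] = None,
) -> Callable[..., Any]:
    # Mode-specific choices are made once, up front; everything after this
    # point is the previously-duplicated shared pipeline.
    if not is_inference and num_fw_outs_saved_for_bw is None:
        raise ValueError("num_fw_outs_saved_for_bw is required for the forward path")
    mode = "inference" if is_inference else "forward"
    ctx = nullcontext()  # stand-in for the real no_grad / autocast-disable contexts
    with ctx:
        compiled = graph  # stand-in for invoking the backend compiler on the graph
    print(f"compiled {mode} graph over {len(args)} args")
    return compiled

# The former separate stages become thin wrappers over the shared helper:
def compile_inference(graph, args):
    return _compile_forward_or_inference(graph, args, is_inference=True)

def compile_forward(graph, args, num_fw_outs_saved_for_bw):
    return _compile_forward_or_inference(
        graph,
        args,
        is_inference=False,
        num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw,
    )

The real helper additionally threads the wrapper pre_compile/post_compile passes and output-stride reporting through the shared path, as the diff below shows.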

Differential Revision: D84628978

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165457
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
Authored by: Avik Chaudhuri, 2025-10-15 21:33:50 +00:00
Committed by: PyTorch MergeBot
parent dfc8a1c5dd
commit fa1539594b


@@ -322,83 +322,14 @@ def _aot_stage2b_inference_compile(
    fw_metadata: ViewAndMutationMeta,
    aot_config,
) -> Callable:
    """
    Compile the inference graph. Returns the compiled inference function.
    Mostly this is very similar to _aot_stage2b_fw_compile.

    Before compiling, we run pre_compile for the following wrappers:
    - FakifiedOutWrapper
    - FunctionalizedRngRuntimeWrapper

    After compiling, we run post_compile for the following wrappers:
    - EffectTokensWrapper
    - AOTDispatchSubclassWrapper
    - FunctionalizedRngRuntimeWrapper
    - FakifiedOutWrapper
    """
    disable_amp = torch._C._is_any_autocast_enabled()
    context = torch._C._DisableAutocast if disable_amp else nullcontext

    with context(), track_graph_compiling(aot_config, "inference"):
        fakified_out_wrapper = FakifiedOutWrapper()
        fakified_out_wrapper.pre_compile(
            fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
        )
        functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper()
        functionalized_rng_wrapper.pre_compile(
            fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
        )

        if tracing_context := torch._guards.TracingContext.try_get():
            tracing_context.fw_metadata = _get_inner_meta(
                maybe_subclass_meta,
                fw_metadata,
            )

        with TracingContext.report_output_strides() as fwd_output_strides:
            compiled_fw = aot_config.inference_compiler(fw_module, updated_flat_args)

        # However, RuntimeWrapper does not expect the rng offsets in the
        # output. So, we have to create another wrapper and take out the offset. As
        # a result, we have to account for not boxed_call compilers as well.
        if not getattr(compiled_fw, "_boxed_call", False):
            compiled_fw = make_boxed_func(compiled_fw)

        if fakified_out_wrapper.needs_post_compile:
            fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)

        compiled_fw = EffectTokensWrapper().post_compile(
            compiled_fw,
            aot_config,
            runtime_metadata=fw_metadata,
        )

        # Why do we need to pass in num_fw_outs_saved_for_bw?
        # See Note: [Partitioner handling for Subclasses, Part 2]
        compiled_fw = AOTDispatchSubclassWrapper(
            trace_joint=False,
            # TODO: once we use pre_compile this will be flat_fn at the top of this function
            fw_only=None,
            maybe_subclass_meta=maybe_subclass_meta,
            num_fw_outs_saved_for_bw=None,
        ).post_compile(
            compiled_fw,
            aot_config,  # not used
            runtime_metadata=fw_metadata,
        )

        # Create a wrapper to set up the rng functionalize and fakified out bits
        compiled_fw = functionalized_rng_wrapper.post_compile(
            compiled_fw, aot_config, runtime_metadata=fw_metadata
        )
        compiled_fw = fakified_out_wrapper.post_compile(
            compiled_fw,
            aot_config,
            runtime_metadata=fw_metadata,
        )

    return compiled_fw
    return _aot_stage2b_compile_forward_or_inference(
        fw_module,
        updated_flat_args,  # type: ignore[arg-type]
        maybe_subclass_meta,
        fw_metadata,
        aot_config,
        is_inference=True,
    )[1]

def aot_stage2_inference(
@@ -1751,88 +1682,15 @@ def _aot_stage2b_fw_compile(
    num_fw_outs_saved_for_bw: int,
    aot_config: AOTConfig,
) -> tuple[Optional[list[Optional[tuple[int, ...]]]], Callable]:
    """
    Compile the forward graph. Returns:
    - the output strides of the forward graph
    - the compiled forward function

    Before compiling, we run pre_compile for the following wrappers:
    - FakifiedOutWrapper
    - FunctionalizedRngRuntimeWrapper

    After compiling, we run post_compile for the following wrappers:
    - EffectTokensWrapper
    - AOTDispatchSubclassWrapper
    - FunctionalizedRngRuntimeWrapper
    - FakifiedOutWrapper
    """
    with torch.no_grad():
        # AMP is already traced out in joint graph. we do not wish to reapply it accidentally
        # in the compiler.
        with track_graph_compiling(aot_config, "forward"), torch._C._DisableAutocast():
            # flat_args at this point might still be subclasses-
            # make sure to pass the unwrapped fake tensors into the compiler!
            fakified_out_wrapper = FakifiedOutWrapper()
            fakified_out_wrapper.pre_compile(
                fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
            )
            functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper(
                return_new_outs=False
            )

            if fw_metadata.num_graphsafe_rng_states > 0:
                index = fw_metadata.graphsafe_rng_state_index
                assert index is not None
                rng_states = [
                    get_cuda_generator_meta_val(index)
                    for _ in range(fw_metadata.num_graphsafe_rng_states)
                ]
                adjusted_flat_args.extend(rng_states)  # type: ignore[arg-type]

            functionalized_rng_wrapper.pre_compile(
                fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
            )

            if tracing_context := torch._guards.TracingContext.try_get():
                tracing_context.fw_metadata = _get_inner_meta(
                    maybe_subclass_meta, fw_metadata
                )

            with TracingContext.report_output_strides() as fwd_output_strides:
                compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)

            if not getattr(compiled_fw_func, "_boxed_call", False):
                compiled_fw_func = make_boxed_func(compiled_fw_func)

            if fakified_out_wrapper.needs_post_compile:
                fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)

            compiled_fw_func = EffectTokensWrapper().post_compile(
                compiled_fw_func,
                aot_config,
                runtime_metadata=fw_metadata,
            )

            compiled_fw_func = AOTDispatchSubclassWrapper(
                fw_only=None,
                trace_joint=False,
                maybe_subclass_meta=maybe_subclass_meta,
                num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw,
            ).post_compile(
                compiled_fw_func,
                aot_config,  # not used
                runtime_metadata=fw_metadata,
            )

            compiled_fw_func = functionalized_rng_wrapper.post_compile(
                compiled_fw_func, aot_config, runtime_metadata=fw_metadata
            )
            compiled_fw_func = fakified_out_wrapper.post_compile(
                compiled_fw_func,
                aot_config,
                runtime_metadata=fw_metadata,
            )

    return fwd_output_strides, compiled_fw_func
    return _aot_stage2b_compile_forward_or_inference(
        fw_module,
        adjusted_flat_args,
        maybe_subclass_meta,
        fw_metadata,
        aot_config,
        is_inference=False,
        num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw,
    )

def _aot_stage2b_bw_compile(
@@ -2150,3 +2008,132 @@ def aot_stage2_autograd(
        runtime_metadata=fw_metadata,
    )

    return compiled_fn
def _aot_stage2b_compile_forward_or_inference(
    fw_module: torch.fx.GraphModule,
    adjusted_flat_args: list[Any],
    maybe_subclass_meta: Optional[SubclassMeta],
    fw_metadata: ViewAndMutationMeta,
    aot_config: AOTConfig,
    *,
    is_inference: bool,
    num_fw_outs_saved_for_bw: Optional[int] = None,
) -> tuple[Optional[list[Optional[tuple[int, ...]]]], Callable]:
    """
    Compile the forward or inference graph. Returns:
    - the output strides of the forward graph
    - the compiled forward/inference function

    Args:
        fw_module: The forward graph module to compile
        adjusted_flat_args: Flattened arguments after adjustments
        maybe_subclass_meta: Metadata for tensor subclasses
        fw_metadata: View and mutation metadata
        aot_config: AOT configuration
        is_inference: If True, compile for inference; if False, compile for forward (autograd)
        num_fw_outs_saved_for_bw: Number of forward outputs saved for backward
            (required if not is_inference)

    Before compiling, we run pre_compile for the following wrappers:
    - FakifiedOutWrapper
    - FunctionalizedRngRuntimeWrapper

    After compiling, we run post_compile for the following wrappers:
    - EffectTokensWrapper
    - AOTDispatchSubclassWrapper
    - FunctionalizedRngRuntimeWrapper
    - FakifiedOutWrapper
    """
    # Validation
    if not is_inference and num_fw_outs_saved_for_bw is None:
        raise ValueError(
            "num_fw_outs_saved_for_bw must be provided when is_inference=False"
        )

    # Determine grad context, autocast context, tracking mode, compiler
    if is_inference:
        grad_ctx: Any = nullcontext
        autocast_ctx: Any = (
            torch._C._DisableAutocast
            if torch._C._is_any_autocast_enabled()
            else nullcontext
        )
        tracking_mode: str = "inference"
        compiler: Any = aot_config.inference_compiler
    else:
        grad_ctx = torch.no_grad
        autocast_ctx = torch._C._DisableAutocast
        tracking_mode = "forward"
        compiler = aot_config.fw_compiler

    with grad_ctx(), autocast_ctx(), track_graph_compiling(aot_config, tracking_mode):
        # Setup wrappers
        fakified_out_wrapper = FakifiedOutWrapper()
        fakified_out_wrapper.pre_compile(
            fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
        )

        # Initialize RNG wrapper based on mode
        functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper(
            return_new_outs=is_inference
        )

        # Add RNG states for forward mode only
        if not is_inference and fw_metadata.num_graphsafe_rng_states > 0:
            index = fw_metadata.graphsafe_rng_state_index
            assert index is not None
            rng_states = [
                get_cuda_generator_meta_val(index)
                for _ in range(fw_metadata.num_graphsafe_rng_states)
            ]
            adjusted_flat_args.extend(rng_states)  # type: ignore[arg-type]

        functionalized_rng_wrapper.pre_compile(
            fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
        )

        # Set tracing context
        if tracing_context := torch._guards.TracingContext.try_get():
            tracing_context.fw_metadata = _get_inner_meta(
                maybe_subclass_meta, fw_metadata
            )

        with TracingContext.report_output_strides() as fwd_output_strides:
            compiled_fw_func = compiler(fw_module, adjusted_flat_args)

        # Make boxed if needed
        if not getattr(compiled_fw_func, "_boxed_call", False):
            compiled_fw_func = make_boxed_func(compiled_fw_func)

        # Set forward output strides if needed
        if fakified_out_wrapper.needs_post_compile:
            fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)

        # Apply post-compile wrappers
        compiled_fw_func = EffectTokensWrapper().post_compile(
            compiled_fw_func,
            aot_config,
            runtime_metadata=fw_metadata,
        )

        compiled_fw_func = AOTDispatchSubclassWrapper(
            fw_only=None,
            trace_joint=False,
            maybe_subclass_meta=maybe_subclass_meta,
            num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw,
        ).post_compile(
            compiled_fw_func,
            aot_config,
            runtime_metadata=fw_metadata,
        )

        compiled_fw_func = functionalized_rng_wrapper.post_compile(
            compiled_fw_func, aot_config, runtime_metadata=fw_metadata
        )
        compiled_fw_func = fakified_out_wrapper.post_compile(
            compiled_fw_func,
            aot_config,
            runtime_metadata=fw_metadata,
        )

    return fwd_output_strides, compiled_fw_func