diff --git a/torch/_functorch/_aot_autograd/graph_compile.py b/torch/_functorch/_aot_autograd/graph_compile.py
index aac28cbabe61..4fc9d8c2e79d 100644
--- a/torch/_functorch/_aot_autograd/graph_compile.py
+++ b/torch/_functorch/_aot_autograd/graph_compile.py
@@ -322,83 +322,14 @@ def _aot_stage2b_inference_compile(
     fw_metadata: ViewAndMutationMeta,
     aot_config,
 ) -> Callable:
-    """
-    Compile the inference graph. Returns the compiled inference function.
-
-    Mostly this is very similar to _aot_stage2b_fw_compile.
-
-    Before compiling, we run pre_compile for the following wrappers:
-    - FakifiedOutWrapper
-    - FunctionalizedRngRuntimeWrapper
-    After compiling, we run post_compile for the following wrappers:
-    - EffectTokensWrapper
-    - AOTDispatchSubclassWrapper
-    - FunctionalizedRngRuntimeWrapper
-    - FakifiedOutWrapper
-    """
-    disable_amp = torch._C._is_any_autocast_enabled()
-    context = torch._C._DisableAutocast if disable_amp else nullcontext
-
-    with context(), track_graph_compiling(aot_config, "inference"):
-        fakified_out_wrapper = FakifiedOutWrapper()
-        fakified_out_wrapper.pre_compile(
-            fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
-        )
-        functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper()
-        functionalized_rng_wrapper.pre_compile(
-            fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
-        )
-
-        if tracing_context := torch._guards.TracingContext.try_get():
-            tracing_context.fw_metadata = _get_inner_meta(
-                maybe_subclass_meta,
-                fw_metadata,
-            )
-
-        with TracingContext.report_output_strides() as fwd_output_strides:
-            compiled_fw = aot_config.inference_compiler(fw_module, updated_flat_args)
-
-        # However, RuntimeWrapper does not expect the rng offsets in the
-        # output. So, we have to create another wrapper and take out the offset. As
-        # a result, we have to account for not boxed_call compilers as well.
-        if not getattr(compiled_fw, "_boxed_call", False):
-            compiled_fw = make_boxed_func(compiled_fw)
-
-        if fakified_out_wrapper.needs_post_compile:
-            fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)
-
-        compiled_fw = EffectTokensWrapper().post_compile(
-            compiled_fw,
-            aot_config,
-            runtime_metadata=fw_metadata,
-        )
-
-        # Why do we need to pass in num_fw_outs_saved_for_bw?
-        # See Note: [Partitioner handling for Subclasses, Part 2]
-        compiled_fw = AOTDispatchSubclassWrapper(
-            trace_joint=False,
-            # TODO: once we use pre_compile this will be flat_fn at the top of this function
-            fw_only=None,
-            maybe_subclass_meta=maybe_subclass_meta,
-            num_fw_outs_saved_for_bw=None,
-        ).post_compile(
-            compiled_fw,
-            aot_config,  # not used
-            runtime_metadata=fw_metadata,
-        )
-
-        # Create a wrapper to set up the rng functionalize and fakified out bits
-        compiled_fw = functionalized_rng_wrapper.post_compile(
-            compiled_fw, aot_config, runtime_metadata=fw_metadata
-        )
-
-        compiled_fw = fakified_out_wrapper.post_compile(
-            compiled_fw,
-            aot_config,
-            runtime_metadata=fw_metadata,
-        )
-
-        return compiled_fw
+    return _aot_stage2b_compile_forward_or_inference(
+        fw_module,
+        updated_flat_args,  # type: ignore[arg-type]
+        maybe_subclass_meta,
+        fw_metadata,
+        aot_config,
+        is_inference=True,
+    )[1]
 
 
 def aot_stage2_inference(
@@ -1751,88 +1682,15 @@ def _aot_stage2b_fw_compile(
     num_fw_outs_saved_for_bw: int,
     aot_config: AOTConfig,
 ) -> tuple[Optional[list[Optional[tuple[int, ...]]]], Callable]:
-    """
-    Compile the forward graph. Returns:
-    - the output strides of the forward graph
-    - the compiled forward function
-
-    Before compiling, we run pre_compile for the following wrappers:
-    - FakifiedOutWrapper
-    - FunctionalizedRngRuntimeWrapper
-    After compiling, we run post_compile for the following wrappers:
-    - EffectTokensWrapper
-    - AOTDispatchSubclassWrapper
-    - FunctionalizedRngRuntimeWrapper
-    - FakifiedOutWrapper
-    """
-    with torch.no_grad():
-        # AMP is already traced out in joint graph. we do not wish to reapply it accidentally
-        # in the compiler.
-        with track_graph_compiling(aot_config, "forward"), torch._C._DisableAutocast():
-            # flat_args at this point might still be subclasses-
-            # make sure to pass the unwrapped fake tensors into the compiler!
-            fakified_out_wrapper = FakifiedOutWrapper()
-            fakified_out_wrapper.pre_compile(
-                fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
-            )
-
-            functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper(
-                return_new_outs=False
-            )
-
-            if fw_metadata.num_graphsafe_rng_states > 0:
-                index = fw_metadata.graphsafe_rng_state_index
-                assert index is not None
-                rng_states = [
-                    get_cuda_generator_meta_val(index)
-                    for _ in range(fw_metadata.num_graphsafe_rng_states)
-                ]
-                adjusted_flat_args.extend(rng_states)  # type: ignore[arg-type]
-
-            functionalized_rng_wrapper.pre_compile(
-                fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
-            )
-            if tracing_context := torch._guards.TracingContext.try_get():
-                tracing_context.fw_metadata = _get_inner_meta(
-                    maybe_subclass_meta, fw_metadata
-                )
-
-            with TracingContext.report_output_strides() as fwd_output_strides:
-                compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
-
-            if not getattr(compiled_fw_func, "_boxed_call", False):
-                compiled_fw_func = make_boxed_func(compiled_fw_func)
-
-            if fakified_out_wrapper.needs_post_compile:
-                fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)
-
-            compiled_fw_func = EffectTokensWrapper().post_compile(
-                compiled_fw_func,
-                aot_config,
-                runtime_metadata=fw_metadata,
-            )
-
-            compiled_fw_func = AOTDispatchSubclassWrapper(
-                fw_only=None,
-                trace_joint=False,
-                maybe_subclass_meta=maybe_subclass_meta,
-                num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw,
-            ).post_compile(
-                compiled_fw_func,
-                aot_config,  # not used
-                runtime_metadata=fw_metadata,
-            )
-
-            compiled_fw_func = functionalized_rng_wrapper.post_compile(
-                compiled_fw_func, aot_config, runtime_metadata=fw_metadata
-            )
-            compiled_fw_func = fakified_out_wrapper.post_compile(
-                compiled_fw_func,
-                aot_config,
-                runtime_metadata=fw_metadata,
-            )
-
-            return fwd_output_strides, compiled_fw_func
+    return _aot_stage2b_compile_forward_or_inference(
+        fw_module,
+        adjusted_flat_args,
+        maybe_subclass_meta,
+        fw_metadata,
+        aot_config,
+        is_inference=False,
+        num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw,
+    )
 
 
 def _aot_stage2b_bw_compile(
@@ -2150,3 +2008,132 @@ def aot_stage2_autograd(
         runtime_metadata=fw_metadata,
     )
     return compiled_fn
+
+
+def _aot_stage2b_compile_forward_or_inference(
+    fw_module: torch.fx.GraphModule,
+    adjusted_flat_args: list[Any],
+    maybe_subclass_meta: Optional[SubclassMeta],
+    fw_metadata: ViewAndMutationMeta,
+    aot_config: AOTConfig,
+    *,
+    is_inference: bool,
+    num_fw_outs_saved_for_bw: Optional[int] = None,
+) -> tuple[Optional[list[Optional[tuple[int, ...]]]], Callable]:
+    """
+    Compile the forward or inference graph. Returns:
+    - the output strides of the forward graph
+    - the compiled forward/inference function
+
+    Args:
+        fw_module: The forward graph module to compile
+        adjusted_flat_args: Flattened arguments after adjustments
+        maybe_subclass_meta: Metadata for tensor subclasses
+        fw_metadata: View and mutation metadata
+        aot_config: AOT configuration
+        is_inference: If True, compile for inference; if False, compile for forward (autograd)
+        num_fw_outs_saved_for_bw: Number of forward outputs saved for backward (required if not is_inference)
+
+    Before compiling, we run pre_compile for the following wrappers:
+    - FakifiedOutWrapper
+    - FunctionalizedRngRuntimeWrapper
+    After compiling, we run post_compile for the following wrappers:
+    - EffectTokensWrapper
+    - AOTDispatchSubclassWrapper
+    - FunctionalizedRngRuntimeWrapper
+    - FakifiedOutWrapper
+    """
+    # Validation
+    if not is_inference and num_fw_outs_saved_for_bw is None:
+        raise ValueError(
+            "num_fw_outs_saved_for_bw must be provided when is_inference=False"
+        )
+
+    # Determine grad context, autocast context, tracking mode, compiler
+    if is_inference:
+        grad_ctx: Any = nullcontext
+        autocast_ctx: Any = (
+            torch._C._DisableAutocast
+            if torch._C._is_any_autocast_enabled()
+            else nullcontext
+        )
+        tracking_mode: str = "inference"
+        compiler: Any = aot_config.inference_compiler
+    else:
+        grad_ctx = torch.no_grad
+        autocast_ctx = torch._C._DisableAutocast
+        tracking_mode = "forward"
+        compiler = aot_config.fw_compiler
+
+    with grad_ctx(), autocast_ctx(), track_graph_compiling(aot_config, tracking_mode):
+        # Setup wrappers
+        fakified_out_wrapper = FakifiedOutWrapper()
+        fakified_out_wrapper.pre_compile(
+            fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
+        )
+
+        # Initialize RNG wrapper based on mode
+        functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper(
+            return_new_outs=is_inference
+        )
+
+        # Add RNG states for forward mode only
+        if not is_inference and fw_metadata.num_graphsafe_rng_states > 0:
+            index = fw_metadata.graphsafe_rng_state_index
+            assert index is not None
+            rng_states = [
+                get_cuda_generator_meta_val(index)
+                for _ in range(fw_metadata.num_graphsafe_rng_states)
+            ]
+            adjusted_flat_args.extend(rng_states)  # type: ignore[arg-type]
+
+        functionalized_rng_wrapper.pre_compile(
+            fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
+        )
+
+        # Set tracing context
+        if tracing_context := torch._guards.TracingContext.try_get():
+            tracing_context.fw_metadata = _get_inner_meta(
+                maybe_subclass_meta, fw_metadata
+            )
+
+        with TracingContext.report_output_strides() as fwd_output_strides:
+            compiled_fw_func = compiler(fw_module, adjusted_flat_args)
+
+        # Make boxed if needed
+        if not getattr(compiled_fw_func, "_boxed_call", False):
+            compiled_fw_func = make_boxed_func(compiled_fw_func)
+
+        # Set forward output strides if needed
+        if fakified_out_wrapper.needs_post_compile:
+            fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)
+
+        # Apply post-compile wrappers
+        compiled_fw_func = EffectTokensWrapper().post_compile(
+            compiled_fw_func,
+            aot_config,
+            runtime_metadata=fw_metadata,
+        )
+
+        compiled_fw_func = AOTDispatchSubclassWrapper(
+            fw_only=None,
+            trace_joint=False,
+            maybe_subclass_meta=maybe_subclass_meta,
+            num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw,
+        ).post_compile(
+            compiled_fw_func,
+            aot_config,
+            runtime_metadata=fw_metadata,
+        )
+
+        compiled_fw_func = functionalized_rng_wrapper.post_compile(
+            compiled_fw_func, aot_config, runtime_metadata=fw_metadata
+        )
+
+        compiled_fw_func = fakified_out_wrapper.post_compile(
+            compiled_fw_func,
+            aot_config,
+            runtime_metadata=fw_metadata,
+        )
+
+        return fwd_output_strides, compiled_fw_func
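The core of the consolidation above is the mode dispatch at the top of _aot_stage2b_compile_forward_or_inference: the gradient context, autocast context, tracking label, and compiler are each chosen once from is_inference, and the remaining pre_compile/compile/post_compile pipeline is shared. The following is a minimal, self-contained sketch of that context-selection pattern, separate from the patch itself; no_grad_stub and compile_in_mode are hypothetical names, and the stub merely stands in for torch.no_grad() so the sketch runs without torch.

    from contextlib import contextmanager, nullcontext

    @contextmanager
    def no_grad_stub():
        # Hypothetical stand-in for torch.no_grad() in this sketch.
        yield

    def compile_in_mode(compile_fn, *, is_inference: bool):
        # Select the gradient context by mode, then run the compiler under
        # it in a single `with` statement, mirroring the helper's grad_ctx
        # selection: inference needs no grad guard, forward compilation
        # runs with gradients disabled.
        grad_ctx = nullcontext if is_inference else no_grad_stub
        with grad_ctx():
            return compile_fn()

    print(compile_in_mode(lambda: "compiled", is_inference=True))   # compiled
    print(compile_in_mode(lambda: "compiled", is_inference=False))  # compiled

Keeping the per-mode differences confined to this one dispatch block is what lets the two thin wrappers (_aot_stage2b_inference_compile and _aot_stage2b_fw_compile) shrink to simple delegating calls.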