mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
consolidate fw and inference compile paths (#165457)
By design, the forward (fw) compile and inference compile stages are meant to share most of their code; this change consolidates the duplicated logic into a single shared helper.
Differential Revision: D84628978
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165457
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
dfc8a1c5dd
commit
fa1539594b
@ -322,83 +322,14 @@ def _aot_stage2b_inference_compile(
|
||||
fw_metadata: ViewAndMutationMeta,
|
||||
aot_config,
|
||||
) -> Callable:
|
||||
"""
|
||||
Compile the inference graph. Returns the compiled inference function.
|
||||
|
||||
Mostly this is very similar to _aot_stage2b_fw_compile.
|
||||
|
||||
Before compiling, we run pre_compile for the following wrappers:
|
||||
- FakifiedOutWrapper
|
||||
- FunctionalizedRngRuntimeWrapper
|
||||
After compiling, we run post_compile for the following wrappers:
|
||||
- EffectTokensWrapper
|
||||
- AOTDispatchSubclassWrapper
|
||||
- FunctionalizedRngRuntimeWrapper
|
||||
- FakifiedOutWrapper
|
||||
"""
|
||||
disable_amp = torch._C._is_any_autocast_enabled()
|
||||
context = torch._C._DisableAutocast if disable_amp else nullcontext
|
||||
|
||||
with context(), track_graph_compiling(aot_config, "inference"):
|
||||
fakified_out_wrapper = FakifiedOutWrapper()
|
||||
fakified_out_wrapper.pre_compile(
|
||||
fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
|
||||
)
|
||||
functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper()
|
||||
functionalized_rng_wrapper.pre_compile(
|
||||
fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
|
||||
)
|
||||
|
||||
if tracing_context := torch._guards.TracingContext.try_get():
|
||||
tracing_context.fw_metadata = _get_inner_meta(
|
||||
maybe_subclass_meta,
|
||||
fw_metadata,
|
||||
)
|
||||
|
||||
with TracingContext.report_output_strides() as fwd_output_strides:
|
||||
compiled_fw = aot_config.inference_compiler(fw_module, updated_flat_args)
|
||||
|
||||
# However, RuntimeWrapper does not expect the rng offsets in the
|
||||
# output. So, we have to create another wrapper and take out the offset. As
|
||||
# a result, we have to account for not boxed_call compilers as well.
|
||||
if not getattr(compiled_fw, "_boxed_call", False):
|
||||
compiled_fw = make_boxed_func(compiled_fw)
|
||||
|
||||
if fakified_out_wrapper.needs_post_compile:
|
||||
fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)
|
||||
|
||||
compiled_fw = EffectTokensWrapper().post_compile(
|
||||
compiled_fw,
|
||||
aot_config,
|
||||
runtime_metadata=fw_metadata,
|
||||
)
|
||||
|
||||
# Why do we need to pass in num_fw_outs_saved_for_bw?
|
||||
# See Note: [Partitioner handling for Subclasses, Part 2]
|
||||
compiled_fw = AOTDispatchSubclassWrapper(
|
||||
trace_joint=False,
|
||||
# TODO: once we use pre_compile this will be flat_fn at the top of this function
|
||||
fw_only=None,
|
||||
maybe_subclass_meta=maybe_subclass_meta,
|
||||
num_fw_outs_saved_for_bw=None,
|
||||
).post_compile(
|
||||
compiled_fw,
|
||||
aot_config, # not used
|
||||
runtime_metadata=fw_metadata,
|
||||
)
|
||||
|
||||
# Create a wrapper to set up the rng functionalize and fakified out bits
|
||||
compiled_fw = functionalized_rng_wrapper.post_compile(
|
||||
compiled_fw, aot_config, runtime_metadata=fw_metadata
|
||||
)
|
||||
|
||||
compiled_fw = fakified_out_wrapper.post_compile(
|
||||
compiled_fw,
|
||||
aot_config,
|
||||
runtime_metadata=fw_metadata,
|
||||
)
|
||||
|
||||
return compiled_fw
|
||||
return _aot_stage2b_compile_forward_or_inference(
|
||||
fw_module,
|
||||
updated_flat_args, # type: ignore[arg-type]
|
||||
maybe_subclass_meta,
|
||||
fw_metadata,
|
||||
aot_config,
|
||||
is_inference=True,
|
||||
)[1]
|
||||
|
||||
|
||||
def aot_stage2_inference(
|
||||
@ -1751,88 +1682,15 @@ def _aot_stage2b_fw_compile(
|
||||
num_fw_outs_saved_for_bw: int,
|
||||
aot_config: AOTConfig,
|
||||
) -> tuple[Optional[list[Optional[tuple[int, ...]]]], Callable]:
|
||||
"""
|
||||
Compile the forward graph. Returns:
|
||||
- the output strides of the forward graph
|
||||
- the compiled forward function
|
||||
|
||||
Before compiling, we run pre_compile for the following wrappers:
|
||||
- FakifiedOutWrapper
|
||||
- FunctionalizedRngRuntimeWrapper
|
||||
After compiling, we run post_compile for the following wrappers:
|
||||
- EffectTokensWrapper
|
||||
- AOTDispatchSubclassWrapper
|
||||
- FunctionalizedRngRuntimeWrapper
|
||||
- FakifiedOutWrapper
|
||||
"""
|
||||
with torch.no_grad():
|
||||
# AMP is already traced out in joint graph. we do not wish to reapply it accidentally
|
||||
# in the compiler.
|
||||
with track_graph_compiling(aot_config, "forward"), torch._C._DisableAutocast():
|
||||
# flat_args at this point might still be subclasses-
|
||||
# make sure to pass the unwrapped fake tensors into the compiler!
|
||||
fakified_out_wrapper = FakifiedOutWrapper()
|
||||
fakified_out_wrapper.pre_compile(
|
||||
fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
|
||||
)
|
||||
|
||||
functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper(
|
||||
return_new_outs=False
|
||||
)
|
||||
|
||||
if fw_metadata.num_graphsafe_rng_states > 0:
|
||||
index = fw_metadata.graphsafe_rng_state_index
|
||||
assert index is not None
|
||||
rng_states = [
|
||||
get_cuda_generator_meta_val(index)
|
||||
for _ in range(fw_metadata.num_graphsafe_rng_states)
|
||||
]
|
||||
adjusted_flat_args.extend(rng_states) # type: ignore[arg-type]
|
||||
|
||||
functionalized_rng_wrapper.pre_compile(
|
||||
fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
|
||||
)
|
||||
if tracing_context := torch._guards.TracingContext.try_get():
|
||||
tracing_context.fw_metadata = _get_inner_meta(
|
||||
maybe_subclass_meta, fw_metadata
|
||||
)
|
||||
|
||||
with TracingContext.report_output_strides() as fwd_output_strides:
|
||||
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
|
||||
|
||||
if not getattr(compiled_fw_func, "_boxed_call", False):
|
||||
compiled_fw_func = make_boxed_func(compiled_fw_func)
|
||||
|
||||
if fakified_out_wrapper.needs_post_compile:
|
||||
fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)
|
||||
|
||||
compiled_fw_func = EffectTokensWrapper().post_compile(
|
||||
compiled_fw_func,
|
||||
aot_config,
|
||||
runtime_metadata=fw_metadata,
|
||||
)
|
||||
|
||||
compiled_fw_func = AOTDispatchSubclassWrapper(
|
||||
fw_only=None,
|
||||
trace_joint=False,
|
||||
maybe_subclass_meta=maybe_subclass_meta,
|
||||
num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw,
|
||||
).post_compile(
|
||||
compiled_fw_func,
|
||||
aot_config, # not used
|
||||
runtime_metadata=fw_metadata,
|
||||
)
|
||||
|
||||
compiled_fw_func = functionalized_rng_wrapper.post_compile(
|
||||
compiled_fw_func, aot_config, runtime_metadata=fw_metadata
|
||||
)
|
||||
compiled_fw_func = fakified_out_wrapper.post_compile(
|
||||
compiled_fw_func,
|
||||
aot_config,
|
||||
runtime_metadata=fw_metadata,
|
||||
)
|
||||
|
||||
return fwd_output_strides, compiled_fw_func
|
||||
return _aot_stage2b_compile_forward_or_inference(
|
||||
fw_module,
|
||||
adjusted_flat_args,
|
||||
maybe_subclass_meta,
|
||||
fw_metadata,
|
||||
aot_config,
|
||||
is_inference=False,
|
||||
num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw,
|
||||
)
|
||||
|
||||
|
||||
def _aot_stage2b_bw_compile(
|
||||
@ -2150,3 +2008,132 @@ def aot_stage2_autograd(
|
||||
runtime_metadata=fw_metadata,
|
||||
)
|
||||
return compiled_fn
|
||||
|
||||
|
||||
def _aot_stage2b_compile_forward_or_inference(
|
||||
fw_module: torch.fx.GraphModule,
|
||||
adjusted_flat_args: list[Any],
|
||||
maybe_subclass_meta: Optional[SubclassMeta],
|
||||
fw_metadata: ViewAndMutationMeta,
|
||||
aot_config: AOTConfig,
|
||||
*,
|
||||
is_inference: bool,
|
||||
num_fw_outs_saved_for_bw: Optional[int] = None,
|
||||
) -> tuple[Optional[list[Optional[tuple[int, ...]]]], Callable]:
|
||||
"""
|
||||
Compile the forward or inference graph. Returns:
|
||||
- the output strides of the forward graph
|
||||
- the compiled forward/inference function
|
||||
|
||||
Args:
|
||||
fw_module: The forward graph module to compile
|
||||
adjusted_flat_args: Flattened arguments after adjustments
|
||||
maybe_subclass_meta: Metadata for tensor subclasses
|
||||
fw_metadata: View and mutation metadata
|
||||
aot_config: AOT configuration
|
||||
is_inference: If True, compile for inference; if False, compile for forward (autograd)
|
||||
num_fw_outs_saved_for_bw: Number of forward outputs saved for backward (required if not is_inference)
|
||||
|
||||
Before compiling, we run pre_compile for the following wrappers:
|
||||
- FakifiedOutWrapper
|
||||
- FunctionalizedRngRuntimeWrapper
|
||||
After compiling, we run post_compile for the following wrappers:
|
||||
- EffectTokensWrapper
|
||||
- AOTDispatchSubclassWrapper
|
||||
- FunctionalizedRngRuntimeWrapper
|
||||
- FakifiedOutWrapper
|
||||
"""
|
||||
# Validation
|
||||
if not is_inference and num_fw_outs_saved_for_bw is None:
|
||||
raise ValueError(
|
||||
"num_fw_outs_saved_for_bw must be provided when is_inference=False"
|
||||
)
|
||||
|
||||
# Determine grad context, autocast context, tracking mode, compiler
|
||||
if is_inference:
|
||||
grad_ctx: Any = nullcontext
|
||||
autocast_ctx: Any = (
|
||||
torch._C._DisableAutocast
|
||||
if torch._C._is_any_autocast_enabled()
|
||||
else nullcontext
|
||||
)
|
||||
tracking_mode: str = "inference"
|
||||
compiler: Any = aot_config.inference_compiler
|
||||
else:
|
||||
grad_ctx = torch.no_grad
|
||||
autocast_ctx = torch._C._DisableAutocast
|
||||
tracking_mode = "forward"
|
||||
compiler = aot_config.fw_compiler
|
||||
|
||||
with grad_ctx(), autocast_ctx(), track_graph_compiling(aot_config, tracking_mode):
|
||||
# Setup wrappers
|
||||
fakified_out_wrapper = FakifiedOutWrapper()
|
||||
fakified_out_wrapper.pre_compile(
|
||||
fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
|
||||
)
|
||||
|
||||
# Initialize RNG wrapper based on mode
|
||||
functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper(
|
||||
return_new_outs=is_inference
|
||||
)
|
||||
|
||||
# Add RNG states for forward mode only
|
||||
if not is_inference and fw_metadata.num_graphsafe_rng_states > 0:
|
||||
index = fw_metadata.graphsafe_rng_state_index
|
||||
assert index is not None
|
||||
rng_states = [
|
||||
get_cuda_generator_meta_val(index)
|
||||
for _ in range(fw_metadata.num_graphsafe_rng_states)
|
||||
]
|
||||
adjusted_flat_args.extend(rng_states) # type: ignore[arg-type]
|
||||
|
||||
functionalized_rng_wrapper.pre_compile(
|
||||
fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
|
||||
)
|
||||
|
||||
# Set tracing context
|
||||
if tracing_context := torch._guards.TracingContext.try_get():
|
||||
tracing_context.fw_metadata = _get_inner_meta(
|
||||
maybe_subclass_meta, fw_metadata
|
||||
)
|
||||
|
||||
with TracingContext.report_output_strides() as fwd_output_strides:
|
||||
compiled_fw_func = compiler(fw_module, adjusted_flat_args)
|
||||
|
||||
# Make boxed if needed
|
||||
if not getattr(compiled_fw_func, "_boxed_call", False):
|
||||
compiled_fw_func = make_boxed_func(compiled_fw_func)
|
||||
|
||||
# Set forward output strides if needed
|
||||
if fakified_out_wrapper.needs_post_compile:
|
||||
fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)
|
||||
|
||||
# Apply post-compile wrappers
|
||||
compiled_fw_func = EffectTokensWrapper().post_compile(
|
||||
compiled_fw_func,
|
||||
aot_config,
|
||||
runtime_metadata=fw_metadata,
|
||||
)
|
||||
|
||||
compiled_fw_func = AOTDispatchSubclassWrapper(
|
||||
fw_only=None,
|
||||
trace_joint=False,
|
||||
maybe_subclass_meta=maybe_subclass_meta,
|
||||
num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw,
|
||||
).post_compile(
|
||||
compiled_fw_func,
|
||||
aot_config,
|
||||
runtime_metadata=fw_metadata,
|
||||
)
|
||||
|
||||
compiled_fw_func = functionalized_rng_wrapper.post_compile(
|
||||
compiled_fw_func, aot_config, runtime_metadata=fw_metadata
|
||||
)
|
||||
|
||||
compiled_fw_func = fakified_out_wrapper.post_compile(
|
||||
compiled_fw_func,
|
||||
aot_config,
|
||||
runtime_metadata=fw_metadata,
|
||||
)
|
||||
|
||||
return fwd_output_strides, compiled_fw_func
|
||||
|
Reference in New Issue
Block a user