Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-29 11:14:56 +08:00)

Compare commits: ciflow/tru... -> mlazos/tup... (5 commits)

Commits: 1e5f51ed2c, 2a738143b4, 965f830bc9, a097a0d3b2, 255c2b0d6c
benchmarks/dynamo/diffusers/auroflow.py (new file, 294 lines)
@@ -0,0 +1,294 @@
import sys
import time

import diffusers
from diffusers import (
    AuraFlowPipeline,
    AuraFlowTransformer2DModel,
    GGUFQuantizationConfig,
)

import torch
from torch._higher_order_ops.invoke_subgraph import mark_compile_region


def compile_full_model(model):
    model.compile(fullgraph=True, mode="reduce-overhead")


def compile_regions(model, nn_modules):
    model.compile_repeated_blocks(fullgraph=True)
    # for submod in model.modules():
    #     if isinstance(submod, nn_modules):
    #         print("Compiling", submod.__class__)
    #         submod.compile(fullgraph=True)


def compile_hierarchical(model, nn_modules):
    for submod in model.modules():
        if isinstance(submod, nn_modules):
            submod.__class__.forward = mark_compile_region(submod.__class__.forward)
    model.compile(fullgraph=True)


def auroflow_benchmark(mode):
    transformer = AuraFlowTransformer2DModel.from_single_file(
        "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
        torch_dtype=torch.bfloat16,
    )
    pipeline = AuraFlowPipeline.from_pretrained(
        "fal/AuraFlow-v0.3",
        torch_dtype=torch.bfloat16,
        transformer=transformer,
    ).to("cuda")

    if mode == "full":
        compile_full_model(pipeline.transformer)
    elif mode == "regional":
        compile_regions(
            pipeline.transformer,
            (
                diffusers.models.transformers.auraflow_transformer_2d.AuraFlowSingleTransformerBlock,
                diffusers.models.transformers.auraflow_transformer_2d.AuraFlowJointTransformerBlock,
            ),
        )
    elif mode == "hierarchical":
        compile_hierarchical(
            pipeline.transformer,
            (
                diffusers.models.transformers.auraflow_transformer_2d.AuraFlowSingleTransformerBlock,
                diffusers.models.transformers.auraflow_transformer_2d.AuraFlowJointTransformerBlock,
            ),
        )
    else:
        assert mode == "eager"

    pipeline("A cute pony", width=512, height=512, num_inference_steps=1)

    t0 = time.perf_counter()
    pipeline("A cute pony", width=512, height=512, num_inference_steps=50)
    t1 = time.perf_counter()
    print(t1 - t0)


def wan_benchmark(mode):
    import numpy as np
    from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
    from diffusers.utils import load_image
    from transformers import CLIPVisionModel

    # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
    model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
    image_encoder = CLIPVisionModel.from_pretrained(
        model_id, subfolder="image_encoder", torch_dtype=torch.float32
    )
    vae = AutoencoderKLWan.from_pretrained(
        model_id, subfolder="vae", torch_dtype=torch.float32
    )
    pipe = WanImageToVideoPipeline.from_pretrained(
        model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
    )

    # replace this with pipe.to("cuda") if you have sufficient VRAM
    # pipe.enable_model_cpu_offload()
    pipe.to("cuda")

    image = load_image(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
    )

    max_area = 480 * 832
    aspect_ratio = image.height / image.width
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
    image = image.resize((width, height))

    prompt = (
        "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
        "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
    )
    negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

    num_frames = 33

    if mode == "full":
        compile_full_model(pipe.transformer)
    elif mode == "regional":
        compile_regions(
            pipe.transformer,
            (diffusers.models.transformers.transformer_wan.WanTransformerBlock,),
        )
    elif mode == "hierarchical":
        compile_hierarchical(
            pipe.transformer,
            (diffusers.models.transformers.transformer_wan.WanTransformerBlock,),
        )
    else:
        assert mode == "eager"

    output = pipe(
        image=image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        num_inference_steps=1,
        guidance_scale=5.0,
    ).frames[0]

    t0 = time.perf_counter()
    output = pipe(
        image=image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        num_inference_steps=50,
        guidance_scale=5.0,
    ).frames[0]
    t1 = time.perf_counter()
    print(t1 - t0)


def ltx_benchmark(mode):
    from diffusers import LTXConditionPipeline

    import torch

    pipe = LTXConditionPipeline.from_pretrained(
        "Lightricks/LTX-Video-0.9.7-dev", torch_dtype=torch.bfloat16
    )
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    def round_to_nearest_resolution_acceptable_by_vae(height, width):
        height = height - (height % pipe.vae_spatial_compression_ratio)
        width = width - (width % pipe.vae_spatial_compression_ratio)
        return height, width

    prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
    negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
    expected_height, expected_width = 512, 704
    downscale_factor = 2 / 3
    num_frames = 121

    # Part 1. Generate video at smaller resolution
    downscaled_height, downscaled_width = (
        int(expected_height * downscale_factor),
        int(expected_width * downscale_factor),
    )
    downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(
        downscaled_height, downscaled_width
    )

    if mode == "full":
        compile_full_model(pipe.transformer)
    elif mode == "regional":
        compile_regions(
            pipe.transformer,
            (diffusers.models.transformers.transformer_ltx.LTXVideoTransformerBlock,),
        )
    elif mode == "hierarchical":
        compile_hierarchical(
            pipe.transformer,
            (diffusers.models.transformers.transformer_ltx.LTXVideoTransformerBlock,),
        )

    latents = pipe(
        conditions=None,
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=downscaled_width,
        height=downscaled_height,
        num_frames=num_frames,
        num_inference_steps=1,
        generator=torch.Generator().manual_seed(0),
        output_type="latent",
    ).frames

    import time

    t0 = time.time()

    latents = pipe(
        conditions=None,
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=downscaled_width,
        height=downscaled_height,
        num_frames=num_frames,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(0),
        output_type="latent",
    ).frames
    t1 = time.time()
    print(t1 - t0)


def flux_benchmark(mode):
    import diffusers
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    prompt = "A cat holding a sign that says hello world"

    if mode == "full":
        compile_full_model(pipe.transformer)
    elif mode == "regional":
        compile_regions(
            pipe.transformer,
            (
                diffusers.models.transformers.transformer_flux.FluxTransformerBlock,
                diffusers.models.transformers.transformer_flux.FluxSingleTransformerBlock,
            ),
        )
    elif mode == "hierarchical":
        compile_hierarchical(
            pipe.transformer,
            (
                diffusers.models.transformers.transformer_flux.FluxTransformerBlock,
                diffusers.models.transformers.transformer_flux.FluxSingleTransformerBlock,
            ),
        )

    pipe(
        prompt,
        height=1024,
        width=1024,
        guidance_scale=3.5,
        num_inference_steps=1,
        max_sequence_length=512,
    )

    t0 = time.perf_counter()
    pipe(
        prompt,
        height=1024,
        width=1024,
        guidance_scale=3.5,
        num_inference_steps=50,
        max_sequence_length=512,
    )
    t1 = time.perf_counter()
    print(t1 - t0)


model_name = sys.argv[1]
mode = sys.argv[2]

if model_name == "auroflow":
    auroflow_benchmark(mode)
elif model_name == "wan":
    wan_benchmark(mode)
elif model_name == "ltx":
    ltx_benchmark(mode)
elif model_name == "flux":
    flux_benchmark(mode)
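A usage note, inferred from the sys.argv dispatch above rather than from any documentation in the diff: each pipeline name is paired with a compile mode, where "full" compiles the whole transformer with mode="reduce-overhead", "regional" relies on diffusers' compile_repeated_blocks, "hierarchical" wraps the repeated block classes with mark_compile_region before compiling the full model, and "eager" skips compilation.

# Hypothetical invocations (script path and argument order assumed from the code above):
#   python benchmarks/dynamo/diffusers/auroflow.py flux regional
#   python benchmarks/dynamo/diffusers/auroflow.py wan hierarchical
#   python benchmarks/dynamo/diffusers/auroflow.py auroflow eager
# Each entry runs a 1-step warm-up call, then reports the wall time of a
# 50-step generation measured with time.perf_counter().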
benchmarks/dynamo/transformers/gemma2_tx.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import time

from transformers import AutoModelForCausalLM, AutoTokenizer

import torch


tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

messages = [
    {"role": "user", "content": "Write me a poem about Machine Learning."},
]
input_ids = tokenizer.apply_chat_template(
    messages, return_tensors="pt", return_dict=True
).to("cuda")

with torch.inference_mode():
    outputs = model.generate(**input_ids, max_new_tokens=256)

t0 = time.perf_counter()
with torch.inference_mode():
    outputs = model.generate(**input_ids, max_new_tokens=256)
t1 = time.perf_counter()

print(tokenizer.decode(outputs[0]))
print("Time :", t1 - t0)


# generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
# generation = generation[0][input_len:]

# t0 = time.perf_counter()
# with torch.inference_mode():
#     generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
#     generation = generation[0][input_len:]
# t1 = time.perf_counter()
# print("Time :", t1 - t0)
# decoded = processor.decode(generation, skip_special_tokens=True)
# print(decoded)

# # input_text = "Write me a poem about Machine Learning."
# # input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

# # outputs = model.generate(**input_ids, max_new_tokens=32)
# # print(tokenizer.decode(outputs[0]))
benchmarks/dynamo/transformers/gemma3_multi.py (new file, 61 lines)
@@ -0,0 +1,61 @@
# pip install accelerate

import time

from transformers import AutoProcessor, Gemma3ForConditionalGeneration

import torch


model_id = "google/gemma-3-4b-it"

model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto"
).eval()

processor = AutoProcessor.from_pretrained(model_id)

messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}],
    },
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
            },
            {"type": "text", "text": "Describe this image in detail."},
        ],
    },
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device, dtype=torch.bfloat16)

input_len = inputs["input_ids"].shape[-1]


with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]

t0 = time.perf_counter()
with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]
t1 = time.perf_counter()
print("Time :", t1 - t0)
decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)

# **Overall Impression:** The image is a close-up shot of a vibrant garden scene,
# focusing on a cluster of pink cosmos flowers and a busy bumblebee.
# It has a slightly soft, natural feel, likely captured in daylight.
benchmarks/dynamo/transformers/qwen3.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

import torch


model_name = "Qwen/Qwen3-8B"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map="auto"
)


model.generation_config.cache_implementation = "static"

# prepare the model input
prompt = "Give me a short introduction to large language model."
messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=True,  # Switches between thinking and non-thinking modes. Default is True.
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# conduct text completion
with torch.inference_mode():
    generated_ids = model.generate(**model_inputs, max_new_tokens=200)

import time


t0 = time.perf_counter()
with torch.inference_mode():
    generated_ids = model.generate(**model_inputs, max_new_tokens=200)
t1 = time.perf_counter()
print("Time ", t1 - t0)

output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()

# parsing thinking content
try:
    # rindex finding 151668 (</think>)
    index = len(output_ids) - output_ids[::-1].index(151668)
except ValueError:
    index = 0

thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip(
    "\n"
)
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

print("thinking content:", thinking_content)
print("content:", content)
@@ -29,4 +29,5 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
skip_guard_on_all_nn_modules_unsafe
keep_tensor_guards_unsafe
skip_guard_on_globals_unsafe
nested_compile_region
```
@ -20,7 +20,6 @@ from torch._dynamo.testing import (
|
||||
InductorAndRecordGraphs,
|
||||
normalize_gm,
|
||||
)
|
||||
from torch._higher_order_ops.invoke_subgraph import mark_compile_region
|
||||
from torch._higher_order_ops.schema import find_hop_schema
|
||||
from torch._inductor.pattern_matcher import (
|
||||
CallFunctionVarArgs,
|
||||
@ -37,6 +36,8 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
|
||||
from torch.testing._internal.triton_utils import requires_cuda, requires_gpu
|
||||
|
||||
|
||||
nested_compile_region = torch.compiler.nested_compile_region
|
||||
|
||||
if HAS_GPU:
|
||||
import triton
|
||||
|
||||
@ -48,7 +49,7 @@ class TestInvokeSubgraph(TestCase):
|
||||
return torch.mul(x, y)
|
||||
|
||||
def fn(x, y):
|
||||
return mark_compile_region(gn)(x, y)
|
||||
return nested_compile_region(gn)(x, y)
|
||||
|
||||
x = torch.randn(8, requires_grad=True)
|
||||
y = torch.randn(8, requires_grad=True)
|
||||
@ -71,7 +72,7 @@ class TestInvokeSubgraph(TestCase):
|
||||
return torch.mul(x, y)
|
||||
|
||||
def fn(x, y):
|
||||
return mark_compile_region(gn)(x, y)
|
||||
return nested_compile_region(gn)(x, y)
|
||||
|
||||
x = torch.randn(8, requires_grad=True)
|
||||
y = torch.randn(8, requires_grad=True)
|
||||
@ -91,11 +92,11 @@ class TestInvokeSubgraph(TestCase):
|
||||
self.assertEqual(y.grad, y_clone.grad)
|
||||
|
||||
def test_multiple(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def cos(x):
|
||||
return torch.cos(x)
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def sin(x):
|
||||
return torch.sin(x)
|
||||
|
||||
@ -122,7 +123,7 @@ class TestInvokeSubgraphCompile(TestCase):
|
||||
self.assertEqual(len(subgraph_attr_names), expected)
|
||||
|
||||
def test_simple(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x, y):
|
||||
return torch.mul(x, y)
|
||||
|
||||
@ -151,7 +152,7 @@ class TestInvokeSubgraphCompile(TestCase):
|
||||
super().__init__()
|
||||
self.c = 5
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def forward(self, x, y):
|
||||
return torch.mul(x, y).sin() + self.c
|
||||
|
||||
@ -182,7 +183,7 @@ class TestInvokeSubgraphCompile(TestCase):
|
||||
super().__init__()
|
||||
self.c = 5
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def forward(self, x, y):
|
||||
return torch.mul(x, y).sin() + self.c
|
||||
|
||||
@ -232,7 +233,7 @@ class TestInvokeSubgraphCompile(TestCase):
|
||||
self.c = 5
|
||||
self.register_buffer("buf", torch.ones(8, requires_grad=False))
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def forward(self, x, y):
|
||||
self.buf.add_(1)
|
||||
return torch.mul(x, y).sin() + self.c + self.buf
|
||||
@ -311,7 +312,7 @@ class GraphModule(torch.nn.Module):
|
||||
self.c = 5
|
||||
self.register_buffer("buf", torch.ones(8, requires_grad=False))
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def forward(self, x, y):
|
||||
return torch.mul(x, y).sin() * self.c * self.buf
|
||||
|
||||
@ -412,7 +413,7 @@ class GraphModule(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.register_buffer("buf", torch.ones(8, requires_grad=False))
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def forward(self, x, y):
|
||||
self.buf.add_(1)
|
||||
return torch.mul(x, y).sin() * self.buf
|
||||
@ -467,7 +468,7 @@ class GraphModule(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.register_buffer("buf", torch.ones(8, requires_grad=False))
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def forward(self, x, y):
|
||||
self.buf.add_(1)
|
||||
return torch.mul(x, y).sin() * self.buf
|
||||
@ -490,7 +491,7 @@ class GraphModule(torch.nn.Module):
|
||||
torch.compile(fn, backend="inductor", fullgraph=True)(mod, x, y)
|
||||
|
||||
def test_list(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x, y):
|
||||
return [torch.mul(x, y), torch.add(x, y)]
|
||||
|
||||
@ -516,7 +517,7 @@ class GraphModule(torch.nn.Module):
|
||||
self.assertEqual(y.grad, y_clone.grad)
|
||||
|
||||
def test_tuple_of_tuple(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x, y):
|
||||
return ((torch.mul(x, y),), torch.add(x, y))
|
||||
|
||||
@ -552,7 +553,7 @@ class GraphModule(torch.nn.Module):
|
||||
a = grad_out.view(12, 5)
|
||||
return torch.cos(torch.reshape(a, (3, 4, 5)))
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return CustomOp.apply(x)
|
||||
|
||||
@ -579,7 +580,7 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
@requires_cuda
|
||||
def test_sdpa(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(q, k, v):
|
||||
return torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
|
||||
@ -611,7 +612,7 @@ class GraphModule(torch.nn.Module):
|
||||
res.sum().backward()
|
||||
|
||||
def test_symint_from_fwd_to_bwd(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x, y):
|
||||
a = torch.sum(x, (1,), keepdim=True).view(y.shape[1], y.shape[0])
|
||||
return torch.matmul(a, y)
|
||||
@ -647,11 +648,11 @@ class GraphModule(torch.nn.Module):
|
||||
# graph passes. Without running joint graph passes, we would get an
|
||||
# error like AssertionError: should have been handled in
|
||||
# replace_random.py
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return torch.nn.functional.dropout(torch.sin(x), p=0.5)
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def hn(x):
|
||||
return torch.sin(x)
|
||||
|
||||
@ -714,7 +715,7 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
def test_dropout_checks_joint_graph_inference(self):
|
||||
# Checks that joint graph results in inductor seeds for just the inference graph
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return torch.nn.functional.dropout(torch.sin(x), p=0.5)
|
||||
|
||||
@ -753,7 +754,7 @@ class <lambda>(torch.nn.Module):
|
||||
)
|
||||
|
||||
def test_dedupe(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x, y):
|
||||
return torch.mul(x, y)
|
||||
|
||||
@ -839,7 +840,7 @@ class GraphModule(torch.nn.Module):
|
||||
)
|
||||
|
||||
def test_dce(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
x = torch.sin(x)
|
||||
# should be dce'd
|
||||
@ -875,7 +876,7 @@ class <lambda>(torch.nn.Module):
|
||||
def test_nonlocal_update(self):
|
||||
counter = 2
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x, y):
|
||||
nonlocal counter
|
||||
return (torch.mul(x, y) * counter,)
|
||||
@ -940,7 +941,7 @@ class GraphModule(torch.nn.Module):
|
||||
)
|
||||
|
||||
def test_view_to_reshape(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
x = torch.sin(x)
|
||||
x = x.view(1, 8)
|
||||
@ -980,7 +981,7 @@ class <lambda>(torch.nn.Module):
|
||||
)
|
||||
|
||||
def test_normalize_gm(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x, y):
|
||||
# Different graph give different names to intermediate nodes
|
||||
for _ in range(5):
|
||||
@ -1038,7 +1039,7 @@ class GraphModule(torch.nn.Module):
|
||||
)
|
||||
|
||||
def test_input_mutation(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x, y):
|
||||
x.add_(1)
|
||||
return torch.mul(x, y)
|
||||
@ -1053,7 +1054,7 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"torch.compile requires the `mark_compile_region` decorated function to be capturable into a single graph",
|
||||
"torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
|
||||
) as cm:
|
||||
opt_fn(x, y)
|
||||
|
||||
@ -1063,12 +1064,29 @@ class GraphModule(torch.nn.Module):
|
||||
"Encountered input mutation during higher order op tracing" in str(cause)
|
||||
)
|
||||
|
||||
def test_input_mutation_inference_mode(self):
|
||||
@torch._dynamo.config.patch(does_invoke_subgraph_support_input_mutation=True)
|
||||
def test_input_mutation_with_config_flag(self):
|
||||
@mark_compile_region
|
||||
def gn(x, y):
|
||||
x.add_(1)
|
||||
return torch.mul(x, y)
|
||||
|
||||
def fn(x, y):
|
||||
return gn(x, y)
|
||||
|
||||
x1 = torch.ones(8, requires_grad=False)
|
||||
x2 = torch.ones(8, requires_grad=False)
|
||||
y = torch.randn(8, requires_grad=False)
|
||||
|
||||
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
|
||||
self.assertEqual(fn(x1, y), opt_fn(x2, y))
|
||||
|
||||
def test_input_mutation_inference_mode(self):
|
||||
@nested_compile_region
|
||||
def gn(x, y):
|
||||
x.add_(1)
|
||||
return torch.mul(x, y)
|
||||
|
||||
def fn(x, y):
|
||||
z = torch.cos(x)
|
||||
with torch.inference_mode():
|
||||
@ -1080,7 +1098,7 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"torch.compile requires the `mark_compile_region` decorated function to be capturable into a single graph",
|
||||
"torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
|
||||
) as cm:
|
||||
opt_fn(x, y)
|
||||
|
||||
@ -1093,7 +1111,7 @@ class GraphModule(torch.nn.Module):
|
||||
def test_simple_module(self):
|
||||
mod = torch.nn.Linear(8, 8)
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return torch.cos(x), mod(x)
|
||||
|
||||
@ -1134,7 +1152,7 @@ class GraphModule(torch.nn.Module):
|
||||
opt_fn(x)
|
||||
|
||||
def test_input_output_aliasing(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x, y):
|
||||
return (x, torch.mul(x, y))
|
||||
|
||||
@ -1149,7 +1167,7 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"torch.compile requires the `mark_compile_region` decorated function to be capturable into a single graph",
|
||||
"torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
|
||||
) as cm:
|
||||
opt_fn(x, y)
|
||||
|
||||
@ -1160,7 +1178,7 @@ class GraphModule(torch.nn.Module):
|
||||
)
|
||||
|
||||
def test_input_input_aliasing(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x, y):
|
||||
return torch.mul(x, y)
|
||||
|
||||
@ -1173,7 +1191,7 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"torch.compile requires the `mark_compile_region` decorated function to be capturable into a single graph",
|
||||
"torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
|
||||
) as cm:
|
||||
opt_fn(x)
|
||||
|
||||
@ -1184,7 +1202,7 @@ class GraphModule(torch.nn.Module):
|
||||
)
|
||||
|
||||
def test_output_output_aliasing(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
z = torch.cos(x)
|
||||
return z, z.view(1, 8)
|
||||
@ -1198,7 +1216,7 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"torch.compile requires the `mark_compile_region` decorated function to be capturable into a single graph",
|
||||
"torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
|
||||
) as cm:
|
||||
opt_fn(x)
|
||||
|
||||
@ -1218,7 +1236,7 @@ class GraphModule(torch.nn.Module):
|
||||
self.a.add_(1)
|
||||
return torch.mul(x, self.a)
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return mod(x)
|
||||
|
||||
@ -1235,7 +1253,7 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"torch.compile requires the `mark_compile_region` decorated function to be capturable into a single graph",
|
||||
"torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
|
||||
) as cm:
|
||||
opt_fn(x, y)
|
||||
|
||||
@ -1246,8 +1264,8 @@ class GraphModule(torch.nn.Module):
|
||||
)
|
||||
|
||||
def test_redundant_compile_region(self):
|
||||
@mark_compile_region
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return torch.sin(x)
|
||||
|
||||
@ -1289,7 +1307,7 @@ class GraphModule(torch.nn.Module):
|
||||
)
|
||||
|
||||
def test_kwargs_only(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x, *, y):
|
||||
return x * y
|
||||
|
||||
@ -1310,7 +1328,7 @@ class GraphModule(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(8, 8)
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def helper(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
@ -1367,7 +1385,7 @@ class GraphModule(torch.nn.Module):
|
||||
class Mod(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.submod = mark_compile_region(SubMod())
|
||||
self.submod = nested_compile_region(SubMod())
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.submod(x) * self.submod(x) + x
|
||||
@ -1418,7 +1436,7 @@ class GraphModule(torch.nn.Module):
|
||||
)
|
||||
ones = torch.ones(1000, device="cuda:0", dtype=torch.float32)
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def fn(x, train):
|
||||
return F.dropout(x * weight, 0.33, train)
|
||||
|
||||
@ -1431,7 +1449,7 @@ class GraphModule(torch.nn.Module):
|
||||
weight.grad.clone()
|
||||
|
||||
def test_return_none_from_fwd(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return x * 2, None, x * 3
|
||||
|
||||
@ -1535,7 +1553,7 @@ class GraphModule(torch.nn.Module):
|
||||
)
|
||||
|
||||
def test_dynamic(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return torch.sin(x)
|
||||
|
||||
@ -1549,9 +1567,27 @@ class GraphModule(torch.nn.Module):
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_return_size(self):
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
y = x + 1
|
||||
z = x.shape
|
||||
return y, z
|
||||
|
||||
def fn(x):
|
||||
z0 = gn(x)
|
||||
z1 = gn(x)
|
||||
return z0[0] + z1[0], z0[1]
|
||||
|
||||
x = torch.randn(8, 8, requires_grad=True)
|
||||
ref = fn(x)
|
||||
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True, dynamic=True)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_complex(self):
|
||||
# Observed in Wan2.1
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return torch.sin(x)
|
||||
|
||||
@ -1566,7 +1602,7 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_pending_unbacked(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
u = x[0].item()
|
||||
return x * u
|
||||
@ -1585,7 +1621,7 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_unbacked(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x, y):
|
||||
b = x.item()
|
||||
torch._check_is_size(b)
|
||||
@ -1605,7 +1641,7 @@ class GraphModule(torch.nn.Module):
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_bwd_partitioning(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x, y):
|
||||
z = torch.matmul(x, y)
|
||||
return torch.sin(z)
|
||||
@ -1688,7 +1724,7 @@ class GraphModule(torch.nn.Module):
|
||||
)
|
||||
|
||||
def test_const_tensor(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return torch.tensor(64, dtype=torch.float32) * x
|
||||
|
||||
@ -1707,14 +1743,14 @@ class GraphModule(torch.nn.Module):
|
||||
def fn1(x):
|
||||
return torch.cos(x)
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def fn1_checkpoint(x):
|
||||
return torch.utils.checkpoint.checkpoint(fn1, x, use_reentrant=False)
|
||||
|
||||
def fn2(x):
|
||||
return torch.sin(x)
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def fn2_checkpoint(x):
|
||||
return torch.utils.checkpoint.checkpoint(fn2, x, use_reentrant=False)
|
||||
|
||||
@ -1753,7 +1789,7 @@ class GraphModule(torch.nn.Module):
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_fake_tensor_checking(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return torch.sin(x)
|
||||
|
||||
@ -1807,7 +1843,7 @@ class GraphModule(torch.nn.Module):
|
||||
Tests check that the same subgraph called with different symints use different graphs
|
||||
"""
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return torch.sin(x)
|
||||
|
||||
@ -1882,7 +1918,7 @@ class GraphModule(torch.nn.Module):
|
||||
(x,) = ctx.saved_tensors
|
||||
return x * torch.cos(grad_out)
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return CustomOp.apply(x)
|
||||
|
||||
@ -1982,7 +2018,7 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
return output
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x, y):
|
||||
o = torch.zeros_like(x)
|
||||
call_triton_add(x, y, o, 0)
|
||||
@ -2052,7 +2088,7 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
|
||||
def test_unbacked_symbol(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return torch.sin(torch.nonzero(x))
|
||||
|
||||
@ -2069,7 +2105,7 @@ class GraphModule(torch.nn.Module):
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_different_strides_in_backward(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
return torch.cos(x)
|
||||
|
||||
@ -2216,7 +2252,7 @@ class GraphModule(torch.nn.Module):
|
||||
)
|
||||
|
||||
def test_div(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
div = torch.div(1024, 256, rounding_mode="trunc")
|
||||
return div * torch.ones(64, div) * x
|
||||
@ -2292,7 +2328,7 @@ class GraphModule(torch.nn.Module):
|
||||
lib.impl("add_op", impl, "CompositeExplicitAutograd")
|
||||
lib.impl("add_op", meta, "Meta")
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(y, z):
|
||||
return torch.ops.mylib.add_op.default(y, z)
|
||||
|
||||
@ -2374,7 +2410,7 @@ class GraphModule(torch.nn.Module):
|
||||
lib.impl("add_op", impl, "CompositeExplicitAutograd")
|
||||
lib.impl("add_op", meta, "Meta")
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x, other):
|
||||
y = x.transpose(2, 3).contiguous().transpose(2, 3)
|
||||
z = y.sin().transpose(2, 3)
|
||||
@ -2402,7 +2438,7 @@ class GraphModule(torch.nn.Module):
|
||||
)
|
||||
class TestInvokeSubgraphExport(TestCase):
|
||||
def test_simple_func(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x, y):
|
||||
return torch.mul(x, y)
|
||||
|
||||
@ -2441,7 +2477,7 @@ class GraphModule(torch.nn.Module):
|
||||
)
|
||||
|
||||
def test_unbacked(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x, y):
|
||||
b = x.item()
|
||||
torch._check_is_size(b)
|
||||
@ -2466,7 +2502,7 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
def test_pending_unbacked(self):
|
||||
class M(torch.nn.Module):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(self, x):
|
||||
u = x[0].item()
|
||||
return x * u
|
||||
@ -2498,7 +2534,7 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
def test_simple_method(self):
|
||||
class M(torch.nn.Module):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(self, x, y):
|
||||
return torch.mul(x, y)
|
||||
|
||||
@ -2522,7 +2558,7 @@ class GraphModule(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.register_buffer("buf", b)
|
||||
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def forward(self, x, y):
|
||||
return x * y + self.buf
|
||||
|
||||
@ -2546,7 +2582,7 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
class NegativeTesting(TestCase):
|
||||
def test_graph_break(self):
|
||||
@mark_compile_region
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
torch._dynamo.graph_break()
|
||||
return torch.cos(x)
|
||||
@ -2558,7 +2594,7 @@ class NegativeTesting(TestCase):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"torch.compile requires the `mark_compile_region` decorated function to be capturable into a single graph",
|
||||
"torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
|
||||
):
|
||||
torch.compile(fn, backend="eager")(x)
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ assume_static_by_default = True
# with assume_static_by_default=True.
# With this flag enabled, we always compile a frame as fully static for the first time, and, if we fail
# any guards due to wobbles in shape, we recompile with *all* the wobbled shapes as being marked dynamic.
automatic_dynamic_shapes = True
automatic_dynamic_shapes = False

# Valid options: "dynamic", "unbacked"
automatic_dynamic_shapes_mark_as: Literal["dynamic", "unbacked"] = "dynamic"

@@ -623,6 +623,10 @@ wrap_top_frame = False
# record pre-graph bytecode in profile traces
record_pre_graph_bytecode_in_traces = True

# This is a short-lived flag to control whether invoke_subgraph supports
# input mutation. It will be turned to True by default.
does_invoke_subgraph_support_input_mutation = False

# HACK: this is for testing custom ops profiling only
_custom_ops_profile: Optional[Any] = None
||||
|
||||
|
||||
@ -9,9 +9,9 @@ structures across different parts of the network.
|
||||
|
||||
import logging
|
||||
import operator
|
||||
from collections import defaultdict
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Generator, Iterable
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
@ -74,12 +74,14 @@ when they are created in output_graph.
|
||||
|
||||
sub_gms: dict[str, torch.fx.GraphModule] = {}
|
||||
|
||||
for region_group in duplicated_region_groups:
|
||||
for region_group in duplicated_region_groups: # 0:1
|
||||
inds_with_external_users = _get_all_output_indices(region_group)
|
||||
region = region_group[0]
|
||||
(
|
||||
subgraph,
|
||||
external_node_usages,
|
||||
node_usage_to_tuple_elems,
|
||||
ind_to_tuple_spec,
|
||||
) = _create_subgraph(region, inds_with_external_users)
|
||||
|
||||
# Ignore regions with no args for now, could they possibly be evaluated at compile time?
|
||||
@ -100,6 +102,8 @@ when they are created in output_graph.
|
||||
region,
|
||||
get_subgraph_node,
|
||||
external_node_usages,
|
||||
node_usage_to_tuple_elems,
|
||||
ind_to_tuple_spec,
|
||||
inds_with_external_users,
|
||||
subgraph_name,
|
||||
node_to_additional_deps,
|
||||
@ -122,14 +126,18 @@ def _replace_region_with_subgraph(
|
||||
region: Region,
|
||||
get_subgraph_node: Node,
|
||||
external_node_usages: Iterable[OrderedSet[UsageIndex]],
|
||||
node_usage_to_tuple_elems,
|
||||
ind_to_tuple_spec,
|
||||
inds_with_external_users: list[int],
|
||||
subgraph_name: str,
|
||||
node_to_additional_deps: dict[Node, OrderedSet[Node]],
|
||||
node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
|
||||
) -> None:
|
||||
sub_args = []
|
||||
flattened_getitem_nodes = OrderedSet()
|
||||
for usages in external_node_usages:
|
||||
node_ind, usage_ind = next(iter(usages))
|
||||
usage = next(iter(usages))
|
||||
node_ind, usage_ind = usage
|
||||
node = region[node_ind]
|
||||
flattened_args_kwargs = _get_flat_args(node, {})
|
||||
for user_ind, node_usage_ind in usages:
|
||||
@ -140,12 +148,20 @@ def _replace_region_with_subgraph(
|
||||
"NYI: Failed to substitute region %s due to mutation", region
|
||||
)
|
||||
return
|
||||
sub_args.append(flattened_args_kwargs[usage_ind])
|
||||
|
||||
if usage in node_usage_to_tuple_elems:
|
||||
tuple_elems = [region[i] for i in node_usage_to_tuple_elems[usage]]
|
||||
flattened_getitem_nodes.update(tuple_elems)
|
||||
sub_args.extend(tuple_elems)
|
||||
else:
|
||||
sub_args.append(flattened_args_kwargs[usage_ind])
|
||||
|
||||
# Input/Output aliasing not supported in HOPs today
|
||||
# Note: we should use the nodes in the original graph (the region here)
|
||||
# because we use the original traced example values for this check
|
||||
if _has_aliasing(region, sub_args, inds_with_external_users):
|
||||
if _has_aliasing(
|
||||
region, sub_args, inds_with_external_users, flattened_getitem_nodes
|
||||
):
|
||||
return
|
||||
|
||||
invoke_args = (get_subgraph_node, subgraph_name, *sub_args)
|
||||
@ -156,16 +172,35 @@ def _replace_region_with_subgraph(
|
||||
invoke_args, # type: ignore[arg-type]
|
||||
{},
|
||||
)
|
||||
for ind, external_user_ind in enumerate(inds_with_external_users):
|
||||
|
||||
ind = 0
|
||||
flattened_output_nodes = OrderedSet()
|
||||
for external_user_ind in inds_with_external_users:
|
||||
node = region[external_user_ind]
|
||||
subgraph_output = graph.create_node(
|
||||
"call_function", operator.getitem, (invoke_subgraph_node, ind), {}
|
||||
)
|
||||
node.replace_all_uses_with(subgraph_output, propagate_meta=True)
|
||||
if _is_tuple_node(node):
|
||||
tuple_spec = ind_to_tuple_spec[external_user_ind]
|
||||
flattened_output_nodes.update(
|
||||
_replace_tuple_outputs(
|
||||
node, ind, tuple_spec, invoke_subgraph_node, graph
|
||||
)
|
||||
)
|
||||
ind += len(tuple_spec)
|
||||
else:
|
||||
subgraph_output = graph.create_node(
|
||||
"call_function", operator.getitem, (invoke_subgraph_node, ind), {}
|
||||
)
|
||||
node.replace_all_uses_with(subgraph_output, propagate_meta=True)
|
||||
ind += 1
|
||||
|
||||
# Erase in reverse topological order
|
||||
for node in reversed(region):
|
||||
graph.erase_node(node)
|
||||
if node in flattened_getitem_nodes:
|
||||
# Don't erase these, since they will still be used
|
||||
continue
|
||||
|
||||
if node not in flattened_output_nodes:
|
||||
graph.erase_node(node)
|
||||
|
||||
# Remove any nodes with additional deps
|
||||
# This is safe; we've guaranteed that there is
|
||||
# no input mutation, so all additional deps
|
||||
@ -220,15 +255,38 @@ def _get_inds_with_external_users(region: Region, inds_unique: set[int]) -> None
|
||||
inds_unique.add(ind)
|
||||
|
||||
|
||||
def _copy_nodes_and_remap_inputs(
|
||||
subgraph: torch.fx.Graph, region: Region
|
||||
) -> list[OrderedSet[UsageIndex]]:
|
||||
def _create_subgraph(
|
||||
region: Region,
|
||||
inds_with_external_users: list[int],
|
||||
) -> tuple[torch.fx.Graph, list[OrderedSet[UsageIndex]]]:
|
||||
subgraph: torch.fx.Graph = torch.fx.Graph()
|
||||
external_input_to_usages = _get_external_inputs(region)
|
||||
external_node_usages = list[OrderedSet[UsageIndex]]()
|
||||
region_to_subgraph_node = {}
|
||||
flattened_getitem_nodes = OrderedSet()
|
||||
node_usage_to_tuple_elems = {}
|
||||
|
||||
for node, usage_indices in external_input_to_usages.items():
|
||||
placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
|
||||
region_to_subgraph_node[node] = placeholder
|
||||
# We don't handle tuples as inputs today
|
||||
if _is_tuple_node(node):
|
||||
# If a node is a tuple we will possibly create multiple placeholders for them
|
||||
# and track which nodes we won't copy into the subgraph because they are flattened away
|
||||
# Later, when replacing each region with this subgraph, we will create a getitem node
|
||||
# externally which will perform the flattening on the outer nodes.
|
||||
flattened_node_indices = _get_flattened_node_indices(node, region)
|
||||
for ind in flattened_node_indices:
|
||||
placeholder = subgraph.placeholder(
|
||||
f"supgraph_input_{node.name}_flattened_{ind}"
|
||||
)
|
||||
region_to_subgraph_node[region[ind]] = placeholder
|
||||
flattened_getitem_nodes.add(region[ind])
|
||||
node_usage_to_tuple_elems[next(iter(usage_indices))] = (
|
||||
flattened_node_indices
|
||||
)
|
||||
else:
|
||||
placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
|
||||
region_to_subgraph_node[node] = placeholder
|
||||
|
||||
external_node_usages.append(usage_indices)
|
||||
|
||||
def map_arg(node: Node) -> Node:
|
||||
@ -237,29 +295,29 @@ def _copy_nodes_and_remap_inputs(
|
||||
else:
|
||||
return node
|
||||
|
||||
for node in region:
|
||||
def copy_to_subgraph(node: Node) -> Node:
|
||||
subgraph_node = subgraph.node_copy(node, lambda old: map_arg(old))
|
||||
region_to_subgraph_node[node] = subgraph_node
|
||||
return subgraph_node
|
||||
|
||||
return external_node_usages
|
||||
output_list = []
|
||||
ind_to_tuple_spec = {}
|
||||
for ind, node in enumerate(region):
|
||||
if node not in flattened_getitem_nodes:
|
||||
subgraph_node = copy_to_subgraph(node)
|
||||
if ind in inds_with_external_users:
|
||||
# flatten tuple outputs by generating a getitem node tree
|
||||
if _is_tuple_node(node):
|
||||
getitem_nodes, ind_to_tuple_spec[ind] = _create_getitem_nodes(
|
||||
node, subgraph_node, subgraph
|
||||
)
|
||||
output_list.extend(getitem_nodes)
|
||||
else:
|
||||
output_list.append(subgraph_node)
|
||||
|
||||
subgraph.output(tuple(output_list))
|
||||
|
||||
def _create_subgraph_outputs(
|
||||
subgraph: torch.fx.Graph, inds_to_output: list[int]
|
||||
) -> None:
|
||||
node_list = [n for n in subgraph.nodes if n.op not in ("placeholder", "output")]
|
||||
out_tup = tuple(node_list[ind] for ind in inds_to_output)
|
||||
subgraph.output(out_tup)
|
||||
|
||||
|
||||
def _create_subgraph(
|
||||
region: Region,
|
||||
inds_with_external_users: list[int],
|
||||
) -> tuple[torch.fx.Graph, list[OrderedSet[UsageIndex]]]:
|
||||
subgraph: torch.fx.Graph = torch.fx.Graph()
|
||||
external_node_usages = _copy_nodes_and_remap_inputs(subgraph, region)
|
||||
_create_subgraph_outputs(subgraph, inds_with_external_users)
|
||||
return subgraph, external_node_usages
|
||||
return subgraph, external_node_usages, node_usage_to_tuple_elems, ind_to_tuple_spec
|
||||
|
||||
|
||||
def _stable_topological_sort(
|
||||
@ -312,7 +370,9 @@ def _stable_topological_sort(
|
||||
pending.extend(reversed(waiting.pop(node, ())))
|
||||
|
||||
ready.update(outputs)
|
||||
assert not waiting and len(ready) == len(graph.nodes)
|
||||
if not (not waiting and len(ready) == len(graph.nodes)):
|
||||
breakpoint()
|
||||
assert not waiting and len(ready) == len(graph.nodes)
|
||||
|
||||
|
||||
def _populate_additional_deps(
|
||||
@ -384,11 +444,17 @@ def _add_mutation_dependencies(
|
||||
|
||||
|
||||
def _has_aliasing(
|
||||
region: Region, inputs: list[Node], inds_with_external_users: list[int]
|
||||
region: Region,
|
||||
inputs: list[Node],
|
||||
inds_with_external_users: list[int],
|
||||
flattened_getitem_nodes,
|
||||
) -> bool:
|
||||
input_storages: dict[StorageWeakRef, Node] = dict()
|
||||
|
||||
for node in inputs:
|
||||
if node in flattened_getitem_nodes:
|
||||
continue
|
||||
|
||||
example_value = node.meta["example_value"]
|
||||
if isinstance(example_value, torch.Tensor):
|
||||
storage = StorageWeakRef(example_value._typed_storage())
|
||||
@ -406,6 +472,9 @@ def _has_aliasing(
|
||||
output_storages: dict[StorageWeakRef, Node] = dict()
|
||||
for i in inds_with_external_users:
|
||||
out_node = region[i]
|
||||
if out_node in flattened_getitem_nodes:
|
||||
continue
|
||||
|
||||
if out_node:
|
||||
example_value = out_node.meta["example_value"]
|
||||
assert not isinstance(example_value, list)
|
||||
@ -437,3 +506,90 @@ def _has_aliasing(
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _is_tuple_node(node: Node) -> bool:
|
||||
return isinstance(node.meta["example_value"], tuple)
|
||||
|
||||
|
||||
def _get_children_getitems(node: Node) -> Generator[Node, None, None]:
|
||||
for user in node.users:
|
||||
if user.target == operator.getitem and isinstance(user.args[1], int):
|
||||
yield user
|
||||
|
||||
|
||||
def _get_flattened_node_indices(node: Node, region: Region) -> OrderedSet[int]:
|
||||
"""Returns an ordered set of indices, each reprenting a node in the region which will be flattened"""
|
||||
|
||||
flattened_node_to_ind = {n: i for i, n in enumerate(region)}
|
||||
node_indices: OrderedSet[int] = OrderedSet()
|
||||
|
||||
queue = deque(_get_children_getitems(node))
|
||||
|
||||
while queue:
|
||||
cur_node = queue.popleft()
|
||||
|
||||
if any(user in region for user in cur_node.users):
|
||||
node_indices.add(flattened_node_to_ind[cur_node])
|
||||
|
||||
for child in _get_children_getitems(cur_node):
|
||||
queue.append(child)
|
||||
|
||||
return node_indices
|
||||
|
||||
|
||||
def _create_getitem_nodes(
|
||||
node: Node, subgraph_tuple_node, subgraph: torch.fx.Graph
|
||||
) -> tuple[list[Node], tuple[Any]]:
|
||||
tup = node.meta["example_value"]
|
||||
assert isinstance(tup, tuple), "_get_getitem_children expects tuple"
|
||||
|
||||
getitem_nodes = []
|
||||
queue = deque((e, (i,), subgraph_tuple_node) for i, e in enumerate(tup))
|
||||
path_to_output_index = {}
|
||||
|
||||
while queue:
|
||||
cur_elem, path, parent = queue.popleft()
|
||||
|
||||
new_getitem_node = subgraph.create_node(
|
||||
"call_function", operator.getitem, (parent, path[-1]), {}
|
||||
)
|
||||
new_getitem_node.meta["example_value"] = cur_elem
|
||||
|
||||
path_to_output_index[path] = len(getitem_nodes)
|
||||
getitem_nodes.append(new_getitem_node)
|
||||
|
||||
if isinstance(cur_elem, tuple):
|
||||
queue.extend(
|
||||
(e, path + (i,), new_getitem_node) for i, e in enumerate(cur_elem)
|
||||
)
|
||||
|
||||
return getitem_nodes, path_to_output_index
|
||||
|
||||
|
||||
def _replace_tuple_outputs(
|
||||
node, output_index, tuple_spec, invoke_subgraph_node, graph
|
||||
) -> OrderedSet[Node]:
|
||||
assert _is_tuple_node(node), "_replace_tuple_outputs expects a tuple node"
|
||||
|
||||
queue = deque((c, (c.args[1],)) for c in _get_children_getitems(node))
|
||||
erased_nodes = OrderedSet()
|
||||
while queue:
|
||||
cur_node, path = queue.pop()
|
||||
|
||||
for c in _get_children_getitems(cur_node):
|
||||
queue.append((c, path + (c.args[1],)))
|
||||
|
||||
subgraph_output = graph.create_node(
|
||||
"call_function",
|
||||
operator.getitem,
|
||||
(invoke_subgraph_node, output_index + tuple_spec[path]),
|
||||
{},
|
||||
)
|
||||
cur_node.replace_all_uses_with(subgraph_output, propagate_meta=True)
|
||||
graph.erase_node(cur_node)
|
||||
erased_nodes.add(cur_node)
|
||||
|
||||
graph.erase_node(node)
|
||||
erased_nodes.add(node)
|
||||
return erased_nodes
|
||||
|
||||
@ -327,8 +327,17 @@ class GraphRegionTracker:
|
||||
self._is_identical,
|
||||
)
|
||||
# sort topologically
|
||||
# we need to handle edge cases where some nodes have no dependencies
|
||||
# so first we map each node to its ranking,
|
||||
ref_region = region_group[0]
|
||||
index_to_rank = {
|
||||
index: topological_ranking[n] for index, n in enumerate(ref_region)
|
||||
}
|
||||
sorted_indices = sorted(
|
||||
range(len(ref_region)), key=lambda i: index_to_rank[i]
|
||||
)
|
||||
for region in region_group:
|
||||
region.sort(key=lambda n: topological_ranking[n])
|
||||
region[:] = [region[i] for i in sorted_indices]
|
||||
|
||||
return [
|
||||
region_group for region_group in region_groups if len(region_group[0]) > 1
|
||||
|
||||
@ -1030,13 +1030,19 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
|
||||
name = OutputGraph.module_key_name(*names)
|
||||
name = get_unique_name_wrt(name, self.nn_modules, self.global_scope)
|
||||
|
||||
self.nn_modules[name] = target
|
||||
if isinstance(target, torch.nn.Module):
|
||||
|
||||
def register_leaf_name(leaf_name):
|
||||
assert self.param_name_to_source is not None
|
||||
new_source = ParamBufferSource(source, leaf_name)
|
||||
new_name = f"{name}.{leaf_name}"
|
||||
# If source is None we are installing a subgraph and
|
||||
# propagating existing parameters to the new nn module
|
||||
if source:
|
||||
new_source = ParamBufferSource(source, leaf_name)
|
||||
else:
|
||||
new_source = self.param_name_to_source[leaf_name]
|
||||
self.param_name_to_source[new_name] = new_source
|
||||
if isinstance(source, LocalSource):
|
||||
self.dynamo_flat_name_to_original_fqn[
|
||||
@ -1849,9 +1855,7 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
raise
|
||||
except exceptions_allowed_to_be_fallback as e:
|
||||
if self.has_user_defined_allowed_in_graph:
|
||||
raise BackendCompilerFailed(
|
||||
self.compiler_fn, e, inspect.currentframe()
|
||||
).with_traceback(e.__traceback__) from None
|
||||
raise BackendCompilerFailed(self.compiler_fn, e, inspect.currentframe())
|
||||
unimplemented_v2_with_warning(
|
||||
e,
|
||||
self.root_tx.f_code,
|
||||
@ -1867,9 +1871,8 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
# aborting execution.
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise BackendCompilerFailed(
|
||||
self.compiler_fn, e, inspect.currentframe()
|
||||
).with_traceback(e.__traceback__) from None
|
||||
breakpoint()
|
||||
raise BackendCompilerFailed(self.compiler_fn, e, inspect.currentframe())
|
||||
|
||||
signpost_event(
|
||||
"dynamo",
|
||||
|
||||
@ -954,7 +954,7 @@ class VariableBuilder:
|
||||
unimplemented_v2(
|
||||
gb_type="Attempted to wrap torch._higher_order_ops.invoke_subgraph",
|
||||
context="",
|
||||
explanation="Directly using invoke_subgraph is not supported. Use mark_compile_region",
|
||||
explanation="Directly using invoke_subgraph is not supported. Use nested_compile_region",
|
||||
hints=[],
|
||||
)
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH)
|
||||
|
||||
@ -818,9 +818,7 @@ class BuiltinVariable(VariableTracker):
|
||||
handlers: list[_HandlerCallback] = []
|
||||
|
||||
if any(issubclass(t, LazyVariableTracker) for t in arg_types):
|
||||
return lambda tx, args, kwargs: obj.call_function(
|
||||
tx, [v.realize() for v in args], kwargs
|
||||
)
|
||||
return lambda tx, args, kwargs: obj.call_function(tx, args, kwargs)
|
||||
|
||||
if inspect.isclass(fn) and (
|
||||
issubclass(fn, Exception)
|
||||
@ -1175,6 +1173,30 @@ class BuiltinVariable(VariableTracker):
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
key: tuple[object, ...]
|
||||
from .lazy import LazyVariableTracker
|
||||
|
||||
if (
|
||||
self.fn is operator.getitem
|
||||
and len(args) == 2
|
||||
and isinstance(args[0], LazyVariableTracker)
|
||||
and isinstance(args[1], LazyVariableTracker)
|
||||
):
|
||||
new_arg0 = args[0].realize()
|
||||
if (
|
||||
isinstance(new_arg0, ListVariable)
|
||||
and new_arg0.are_items_same
|
||||
and isinstance(args[1].peek_value(), int)
|
||||
):
|
||||
# This translates to a user code that has list[index], where
|
||||
# both have sources, and we are about to insert a EQUALS_MATCH
|
||||
# guard on the index. This will likely have large number of
|
||||
# recompilations on the `index`, so one way to avoid this is to
|
||||
# guard recursively on the list but not guard on the index.
|
||||
LazyVariableTracker.realize_all(new_arg0)
|
||||
new_arg1 = ConstantVariable(args[1].peek_value())
|
||||
args = [new_arg0, new_arg1]
|
||||
|
||||
args = [v.realize() for v in args]
|
||||
if kwargs:
|
||||
kwargs = {k: v.realize() for k, v in kwargs.items()}
|
||||
key = (self.fn, *(type(x) for x in args), True)
|
||||
|
||||
@ -823,6 +823,7 @@ def speculate_subgraph(
|
||||
context=context,
|
||||
explanation=f"Higher order ops do not support input mutation. Found in {source_target.name()}",
|
||||
hints=[
|
||||
"Set experimental flag `torch._dynamo.config.does_invoke_subgraph_support_input_mutation=True`.",
|
||||
"Consider using the debug context to change user code to avoid mutation.",
|
||||
"Please open an issue.",
|
||||
],
|
||||
@ -3397,7 +3398,7 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
|
||||
return body_name
|
||||
|
||||
@raise_hard_error_if_graph_break(
|
||||
reason="torch.compile requires the `mark_compile_region` decorated function to be capturable into a single graph",
|
||||
reason="torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
|
||||
)
|
||||
def call_function(
|
||||
self,
|
||||
@ -3405,6 +3406,10 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
self.supports_input_mutation = (
|
||||
torch._dynamo.config.does_invoke_subgraph_support_input_mutation
|
||||
)
|
||||
|
||||
# This flattens the kwargs into lifted args
|
||||
(
|
||||
p_args,
|
||||
|
||||
@ -17,6 +17,7 @@ variable tracking system.
|
||||
"""
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import inspect
|
||||
import operator
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
@ -85,6 +86,28 @@ class BaseListVariable(VariableTracker):
|
||||
def modified(self, items, **kwargs):
|
||||
return type(self)(items, **kwargs)
|
||||
|
||||
@functools.cached_property
|
||||
def are_items_same(self):
|
||||
items = self.items
|
||||
if len(items) >= 2 and all(
|
||||
isinstance(x, variables.LazyVariableTracker) for x in items
|
||||
):
|
||||
values = [x.peek_value() for x in items]
|
||||
if all(variables.ConstantVariable.is_literal(v) for v in values):
|
||||
return True
|
||||
if all(isinstance(v, torch.Tensor) for v in values):
|
||||
from torch.fx.passes.shape_prop import _extract_tensor_metadata
|
||||
|
||||
metadata0 = _extract_tensor_metadata(values[0])
|
||||
same_metedata = True
|
||||
for v in values[1:]:
|
||||
if _extract_tensor_metadata(v) != metadata0:
|
||||
return False
|
||||
|
||||
if same_metedata:
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.as_python_constant()
|
||||
|
||||
@ -1342,7 +1342,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
# nested subgraphs can have singleton outputs
|
||||
result = (result,)
|
||||
assert isinstance(result, (tuple, list)), type(result)
|
||||
assert all(
|
||||
if not all(
|
||||
isinstance(
|
||||
x,
|
||||
(
|
||||
@ -1358,7 +1358,8 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
),
|
||||
)
|
||||
for x in result
|
||||
), result
|
||||
):
|
||||
breakpoint()
|
||||
|
||||
fx_node_args = V.graph.current_node.args[0] # type: ignore[arg-type]
|
||||
if not isinstance(fx_node_args, (tuple, list)):
|
||||
|
||||
@@ -32,6 +32,7 @@ __all__ = [
    "skip_guard_on_all_nn_modules_unsafe",
    "keep_tensor_guards_unsafe",
    "skip_guard_on_globals_unsafe",
    "nested_compile_region",
]


@@ -566,3 +567,36 @@ def skip_guard_on_globals_unsafe(guard_entries):
    """

    return [not entry.is_global for entry in guard_entries]


def nested_compile_region(fn=None):
    """
    Tells **``torch.compile``** that the marked set of operations forms a nested
    compile region (which is often repeated in the full model) whose code can be
    compiled once and safely reused. ``nested_compile_region`` can also be used
    as a decorator.

    During **``torch.compile``** tracing, the compiler applies *hierarchical
    compilation* with ``nested_compile_region``: it emits optimized code for the
    marked region the first time it is encountered and re-emits (or “stamps
    out”) the previously compiled code on every subsequent invocation. This can
    substantially reduce overall compile time for deeply-stacked,
    structurally-identical components such as the transformer layers of a
    large-language-model (LLM).

    Outside a ``torch.compile`` context—i.e., in standard eager execution—the
    call is a no-op, so existing workflows remain unaffected.

    Note that ``nested_compile_region`` **does not** promise that a region will
    be compiled exactly once. If the compiler detects that new input conditions
    (shape, dtype, device, stride, globals etc.) make the cached version invalid
    to reuse, it will transparently re-compile the region. Using it is
    therefore *safe*: correctness is always preserved, and you pay the extra
    compilation cost only when required.
    """

    from torch._higher_order_ops.invoke_subgraph import (
        mark_compile_region as _mark_compile_region,
    )

    return _mark_compile_region(fn)
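A minimal usage sketch of the wrapper above (my own illustration, not part of the diff); layer stands in for a repeated transformer block and the shapes are arbitrary:

import torch
from torch.compiler import nested_compile_region

@nested_compile_region
def layer(x, w):
    # Marked region: traced once, then reused ("stamped out") at every call site.
    return torch.nn.functional.relu(x @ w)

def model(x, w):
    for _ in range(4):  # structurally identical layers share the compiled subgraph
        x = layer(x, w)
    return x

w = torch.randn(64, 64)
x = torch.randn(8, 64)
out = torch.compile(model, fullgraph=True)(x, w)

In eager execution the decorator is a no-op, so the same function can be run and benchmarked with or without torch.compile.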