Compare commits

...

5 Commits

Author SHA1 Message Date
1e5f51ed2c Assorted fixes 2025-06-25 14:03:13 -07:00
2a738143b4 [DONT MERGE] Diffusion models benchmarking for compile time
ghstack-source-id: 1591c4dddcf9d828191ed7bb54ec98669e074816
Pull-Request: https://github.com/pytorch/pytorch/pull/155866
2025-06-20 22:12:55 -07:00
965f830bc9 [invoke_subgraph] Add config flag to control support of input aliasing
ghstack-source-id: 79d3bf9f22aecdaa5be6d95a2cb51d6b4d1a47a0
Pull-Request: https://github.com/pytorch/pytorch/pull/156450
2025-06-20 22:12:55 -07:00
a097a0d3b2 [dynamo] Guard eagerly on list objects to avoid guard on getitem index
ghstack-source-id: 2c5a7e61f7395361508f8f4ec1f3ab8b7449385a
Pull-Request: https://github.com/pytorch/pytorch/pull/156531
2025-06-20 22:12:54 -07:00
255c2b0d6c [compile] Release nested_compile_region API
ghstack-source-id: 913ef891c853836dfff75afb943359f6c9ad12db
Pull-Request: https://github.com/pytorch/pytorch/pull/156449
2025-06-20 22:12:54 -07:00
16 changed files with 877 additions and 122 deletions

View File

@ -0,0 +1,294 @@
import sys
import time

import diffusers
from diffusers import (
    AuraFlowPipeline,
    AuraFlowTransformer2DModel,
    GGUFQuantizationConfig,
)

import torch
from torch._higher_order_ops.invoke_subgraph import mark_compile_region


def compile_full_model(model):
    model.compile(fullgraph=True, mode="reduce-overhead")


def compile_regions(model, nn_modules):
    model.compile_repeated_blocks(fullgraph=True)
    # for submod in model.modules():
    #     if isinstance(submod, nn_modules):
    #         print("Compiling", submod.__class__)
    #         submod.compile(fullgraph=True)


def compile_hierarchical(model, nn_modules):
    for submod in model.modules():
        if isinstance(submod, nn_modules):
            submod.__class__.forward = mark_compile_region(submod.__class__.forward)
    model.compile(fullgraph=True)


def auroflow_benchmark(mode):
    transformer = AuraFlowTransformer2DModel.from_single_file(
        "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
        torch_dtype=torch.bfloat16,
    )
    pipeline = AuraFlowPipeline.from_pretrained(
        "fal/AuraFlow-v0.3",
        torch_dtype=torch.bfloat16,
        transformer=transformer,
    ).to("cuda")

    if mode == "full":
        compile_full_model(pipeline.transformer)
    elif mode == "regional":
        compile_regions(
            pipeline.transformer,
            (
                diffusers.models.transformers.auraflow_transformer_2d.AuraFlowSingleTransformerBlock,
                diffusers.models.transformers.auraflow_transformer_2d.AuraFlowJointTransformerBlock,
            ),
        )
    elif mode == "hierarchical":
        compile_hierarchical(
            pipeline.transformer,
            (
                diffusers.models.transformers.auraflow_transformer_2d.AuraFlowSingleTransformerBlock,
                diffusers.models.transformers.auraflow_transformer_2d.AuraFlowJointTransformerBlock,
            ),
        )
    else:
        assert mode == "eager"

    pipeline("A cute pony", width=512, height=512, num_inference_steps=1)

    t0 = time.perf_counter()
    pipeline("A cute pony", width=512, height=512, num_inference_steps=50)
    t1 = time.perf_counter()
    print(t1 - t0)


def wan_benchmark(mode):
    import numpy as np
    from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
    from diffusers.utils import load_image
    from transformers import CLIPVisionModel

    # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
    model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
    image_encoder = CLIPVisionModel.from_pretrained(
        model_id, subfolder="image_encoder", torch_dtype=torch.float32
    )
    vae = AutoencoderKLWan.from_pretrained(
        model_id, subfolder="vae", torch_dtype=torch.float32
    )
    pipe = WanImageToVideoPipeline.from_pretrained(
        model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
    )
    # replace this with pipe.to("cuda") if you have sufficient VRAM
    # pipe.enable_model_cpu_offload()
    pipe.to("cuda")

    image = load_image(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
    )
    max_area = 480 * 832
    aspect_ratio = image.height / image.width
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
    image = image.resize((width, height))
    prompt = (
        "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
        "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
    )
    negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
    num_frames = 33

    if mode == "full":
        compile_full_model(pipe.transformer)
    elif mode == "regional":
        compile_regions(
            pipe.transformer,
            (diffusers.models.transformers.transformer_wan.WanTransformerBlock,),
        )
    elif mode == "hierarchical":
        compile_hierarchical(
            pipe.transformer,
            (diffusers.models.transformers.transformer_wan.WanTransformerBlock,),
        )
    else:
        assert mode == "eager"

    output = pipe(
        image=image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        num_inference_steps=1,
        guidance_scale=5.0,
    ).frames[0]

    t0 = time.perf_counter()
    output = pipe(
        image=image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        num_inference_steps=50,
        guidance_scale=5.0,
    ).frames[0]
    t1 = time.perf_counter()
    print(t1 - t0)


def ltx_benchmark(mode):
    from diffusers import LTXConditionPipeline
    import torch

    pipe = LTXConditionPipeline.from_pretrained(
        "Lightricks/LTX-Video-0.9.7-dev", torch_dtype=torch.bfloat16
    )
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    def round_to_nearest_resolution_acceptable_by_vae(height, width):
        height = height - (height % pipe.vae_spatial_compression_ratio)
        width = width - (width % pipe.vae_spatial_compression_ratio)
        return height, width

    prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
    negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
    expected_height, expected_width = 512, 704
    downscale_factor = 2 / 3
    num_frames = 121

    # Part 1. Generate video at smaller resolution
    downscaled_height, downscaled_width = (
        int(expected_height * downscale_factor),
        int(expected_width * downscale_factor),
    )
    downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(
        downscaled_height, downscaled_width
    )

    if mode == "full":
        compile_full_model(pipe.transformer)
    elif mode == "regional":
        compile_regions(
            pipe.transformer,
            (diffusers.models.transformers.transformer_ltx.LTXVideoTransformerBlock,),
        )
    elif mode == "hierarchical":
        compile_hierarchical(
            pipe.transformer,
            (diffusers.models.transformers.transformer_ltx.LTXVideoTransformerBlock,),
        )

    latents = pipe(
        conditions=None,
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=downscaled_width,
        height=downscaled_height,
        num_frames=num_frames,
        num_inference_steps=1,
        generator=torch.Generator().manual_seed(0),
        output_type="latent",
    ).frames

    import time

    t0 = time.time()
    latents = pipe(
        conditions=None,
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=downscaled_width,
        height=downscaled_height,
        num_frames=num_frames,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(0),
        output_type="latent",
    ).frames
    t1 = time.time()
    print(t1 - t0)


def flux_benchmark(mode):
    import diffusers
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    prompt = "A cat holding a sign that says hello world"

    if mode == "full":
        compile_full_model(pipe.transformer)
    elif mode == "regional":
        compile_regions(
            pipe.transformer,
            (
                diffusers.models.transformers.transformer_flux.FluxTransformerBlock,
                diffusers.models.transformers.transformer_flux.FluxSingleTransformerBlock,
            ),
        )
    elif mode == "hierarchical":
        compile_hierarchical(
            pipe.transformer,
            (
                diffusers.models.transformers.transformer_flux.FluxTransformerBlock,
                diffusers.models.transformers.transformer_flux.FluxSingleTransformerBlock,
            ),
        )

    pipe(
        prompt,
        height=1024,
        width=1024,
        guidance_scale=3.5,
        num_inference_steps=1,
        max_sequence_length=512,
    )

    t0 = time.perf_counter()
    pipe(
        prompt,
        height=1024,
        width=1024,
        guidance_scale=3.5,
        num_inference_steps=50,
        max_sequence_length=512,
    )
    t1 = time.perf_counter()
    print(t1 - t0)


model_name = sys.argv[1]
mode = sys.argv[2]

if model_name == "auroflow":
    auroflow_benchmark(mode)
elif model_name == "wan":
    wan_benchmark(mode)
elif model_name == "ltx":
    ltx_benchmark(mode)
elif model_name == "flux":
    flux_benchmark(mode)
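
Note: the driver at the bottom of this script reads the model name and compile mode from the command line. Assuming the file were saved as benchmark_diffusion.py (a hypothetical name; the diff does not show the actual path), a run would look like `python benchmark_diffusion.py flux regional` or `python benchmark_diffusion.py wan hierarchical`. Model names are auroflow, wan, ltx, and flux; modes are full, regional, hierarchical, and eager.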

View File

@ -0,0 +1,50 @@
import time

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

messages = [
    {"role": "user", "content": "Write me a poem about Machine Learning."},
]
input_ids = tokenizer.apply_chat_template(
    messages, return_tensors="pt", return_dict=True
).to("cuda")

with torch.inference_mode():
    outputs = model.generate(**input_ids, max_new_tokens=256)

t0 = time.perf_counter()
with torch.inference_mode():
    outputs = model.generate(**input_ids, max_new_tokens=256)
t1 = time.perf_counter()

print(tokenizer.decode(outputs[0]))
print("Time :", t1 - t0)

# generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
# generation = generation[0][input_len:]

# t0 = time.perf_counter()
# with torch.inference_mode():
#     generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
#     generation = generation[0][input_len:]
# t1 = time.perf_counter()
# print("Time :", t1 - t0)

# decoded = processor.decode(generation, skip_special_tokens=True)
# print(decoded)

# # input_text = "Write me a poem about Machine Learning."
# # input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
# # outputs = model.generate(**input_ids, max_new_tokens=32)
# # print(tokenizer.decode(outputs[0]))

View File

@ -0,0 +1,61 @@
# pip install accelerate
import time

from transformers import AutoProcessor, Gemma3ForConditionalGeneration
import torch

model_id = "google/gemma-3-4b-it"

model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto"
).eval()

processor = AutoProcessor.from_pretrained(model_id)

messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}],
    },
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
            },
            {"type": "text", "text": "Describe this image in detail."},
        ],
    },
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device, dtype=torch.bfloat16)

input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]

t0 = time.perf_counter()
with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]
t1 = time.perf_counter()
print("Time :", t1 - t0)

decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)

# **Overall Impression:** The image is a close-up shot of a vibrant garden scene,
# focusing on a cluster of pink cosmos flowers and a busy bumblebee.
# It has a slightly soft, natural feel, likely captured in daylight.

View File

@ -0,0 +1,56 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "Qwen/Qwen3-8B"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map="auto"
)

model.generation_config.cache_implementation = "static"

# prepare the model input
prompt = "Give me a short introduction to large language model."
messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=True,  # Switches between thinking and non-thinking modes. Default is True.
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# conduct text completion
with torch.inference_mode():
    generated_ids = model.generate(**model_inputs, max_new_tokens=200)

import time

t0 = time.perf_counter()
with torch.inference_mode():
    generated_ids = model.generate(**model_inputs, max_new_tokens=200)
t1 = time.perf_counter()
print("Time ", t1 - t0)

output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()

# parsing thinking content
try:
    # rindex finding 151668 (</think>)
    index = len(output_ids) - output_ids[::-1].index(151668)
except ValueError:
    index = 0

thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip(
    "\n"
)
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

print("thinking content:", thinking_content)
print("content:", content)

View File

@ -29,4 +29,5 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
skip_guard_on_all_nn_modules_unsafe
keep_tensor_guards_unsafe
skip_guard_on_globals_unsafe
nested_compile_region
```

View File

@ -20,7 +20,6 @@ from torch._dynamo.testing import (
InductorAndRecordGraphs,
normalize_gm,
)
from torch._higher_order_ops.invoke_subgraph import mark_compile_region
from torch._higher_order_ops.schema import find_hop_schema
from torch._inductor.pattern_matcher import (
CallFunctionVarArgs,
@ -37,6 +36,8 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
from torch.testing._internal.triton_utils import requires_cuda, requires_gpu
nested_compile_region = torch.compiler.nested_compile_region
if HAS_GPU:
import triton
@ -48,7 +49,7 @@ class TestInvokeSubgraph(TestCase):
return torch.mul(x, y)
def fn(x, y):
return mark_compile_region(gn)(x, y)
return nested_compile_region(gn)(x, y)
x = torch.randn(8, requires_grad=True)
y = torch.randn(8, requires_grad=True)
@ -71,7 +72,7 @@ class TestInvokeSubgraph(TestCase):
return torch.mul(x, y)
def fn(x, y):
return mark_compile_region(gn)(x, y)
return nested_compile_region(gn)(x, y)
x = torch.randn(8, requires_grad=True)
y = torch.randn(8, requires_grad=True)
@ -91,11 +92,11 @@ class TestInvokeSubgraph(TestCase):
self.assertEqual(y.grad, y_clone.grad)
def test_multiple(self):
@mark_compile_region
@nested_compile_region
def cos(x):
return torch.cos(x)
@mark_compile_region
@nested_compile_region
def sin(x):
return torch.sin(x)
@ -122,7 +123,7 @@ class TestInvokeSubgraphCompile(TestCase):
self.assertEqual(len(subgraph_attr_names), expected)
def test_simple(self):
@mark_compile_region
@nested_compile_region
def gn(x, y):
return torch.mul(x, y)
@ -151,7 +152,7 @@ class TestInvokeSubgraphCompile(TestCase):
super().__init__()
self.c = 5
@mark_compile_region
@nested_compile_region
def forward(self, x, y):
return torch.mul(x, y).sin() + self.c
@ -182,7 +183,7 @@ class TestInvokeSubgraphCompile(TestCase):
super().__init__()
self.c = 5
@mark_compile_region
@nested_compile_region
def forward(self, x, y):
return torch.mul(x, y).sin() + self.c
@ -232,7 +233,7 @@ class TestInvokeSubgraphCompile(TestCase):
self.c = 5
self.register_buffer("buf", torch.ones(8, requires_grad=False))
@mark_compile_region
@nested_compile_region
def forward(self, x, y):
self.buf.add_(1)
return torch.mul(x, y).sin() + self.c + self.buf
@ -311,7 +312,7 @@ class GraphModule(torch.nn.Module):
self.c = 5
self.register_buffer("buf", torch.ones(8, requires_grad=False))
@mark_compile_region
@nested_compile_region
def forward(self, x, y):
return torch.mul(x, y).sin() * self.c * self.buf
@ -412,7 +413,7 @@ class GraphModule(torch.nn.Module):
super().__init__()
self.register_buffer("buf", torch.ones(8, requires_grad=False))
@mark_compile_region
@nested_compile_region
def forward(self, x, y):
self.buf.add_(1)
return torch.mul(x, y).sin() * self.buf
@ -467,7 +468,7 @@ class GraphModule(torch.nn.Module):
super().__init__()
self.register_buffer("buf", torch.ones(8, requires_grad=False))
@mark_compile_region
@nested_compile_region
def forward(self, x, y):
self.buf.add_(1)
return torch.mul(x, y).sin() * self.buf
@ -490,7 +491,7 @@ class GraphModule(torch.nn.Module):
torch.compile(fn, backend="inductor", fullgraph=True)(mod, x, y)
def test_list(self):
@mark_compile_region
@nested_compile_region
def gn(x, y):
return [torch.mul(x, y), torch.add(x, y)]
@ -516,7 +517,7 @@ class GraphModule(torch.nn.Module):
self.assertEqual(y.grad, y_clone.grad)
def test_tuple_of_tuple(self):
@mark_compile_region
@nested_compile_region
def gn(x, y):
return ((torch.mul(x, y),), torch.add(x, y))
@ -552,7 +553,7 @@ class GraphModule(torch.nn.Module):
a = grad_out.view(12, 5)
return torch.cos(torch.reshape(a, (3, 4, 5)))
@mark_compile_region
@nested_compile_region
def gn(x):
return CustomOp.apply(x)
@ -579,7 +580,7 @@ class GraphModule(torch.nn.Module):
@requires_cuda
def test_sdpa(self):
@mark_compile_region
@nested_compile_region
def gn(q, k, v):
return torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
@ -611,7 +612,7 @@ class GraphModule(torch.nn.Module):
res.sum().backward()
def test_symint_from_fwd_to_bwd(self):
@mark_compile_region
@nested_compile_region
def gn(x, y):
a = torch.sum(x, (1,), keepdim=True).view(y.shape[1], y.shape[0])
return torch.matmul(a, y)
@ -647,11 +648,11 @@ class GraphModule(torch.nn.Module):
# graph passes. Without running joint graph passes, we would get an
# error like AssertionError: should have been handled in
# replace_random.py
@mark_compile_region
@nested_compile_region
def gn(x):
return torch.nn.functional.dropout(torch.sin(x), p=0.5)
@mark_compile_region
@nested_compile_region
def hn(x):
return torch.sin(x)
@ -714,7 +715,7 @@ class GraphModule(torch.nn.Module):
def test_dropout_checks_joint_graph_inference(self):
# Checks that joint graph results in inductor seeds for just the inference graph
@mark_compile_region
@nested_compile_region
def gn(x):
return torch.nn.functional.dropout(torch.sin(x), p=0.5)
@ -753,7 +754,7 @@ class <lambda>(torch.nn.Module):
)
def test_dedupe(self):
@mark_compile_region
@nested_compile_region
def gn(x, y):
return torch.mul(x, y)
@ -839,7 +840,7 @@ class GraphModule(torch.nn.Module):
)
def test_dce(self):
@mark_compile_region
@nested_compile_region
def gn(x):
x = torch.sin(x)
# should be dce'd
@ -875,7 +876,7 @@ class <lambda>(torch.nn.Module):
def test_nonlocal_update(self):
counter = 2
@mark_compile_region
@nested_compile_region
def gn(x, y):
nonlocal counter
return (torch.mul(x, y) * counter,)
@ -940,7 +941,7 @@ class GraphModule(torch.nn.Module):
)
def test_view_to_reshape(self):
@mark_compile_region
@nested_compile_region
def gn(x):
x = torch.sin(x)
x = x.view(1, 8)
@ -980,7 +981,7 @@ class <lambda>(torch.nn.Module):
)
def test_normalize_gm(self):
@mark_compile_region
@nested_compile_region
def gn(x, y):
# Different graph give different names to intermediate nodes
for _ in range(5):
@ -1038,7 +1039,7 @@ class GraphModule(torch.nn.Module):
)
def test_input_mutation(self):
@mark_compile_region
@nested_compile_region
def gn(x, y):
x.add_(1)
return torch.mul(x, y)
@ -1053,7 +1054,7 @@ class GraphModule(torch.nn.Module):
with self.assertRaisesRegex(
RuntimeError,
"torch.compile requires the `mark_compile_region` decorated function to be capturable into a single graph",
"torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
) as cm:
opt_fn(x, y)
@ -1063,12 +1064,29 @@ class GraphModule(torch.nn.Module):
"Encountered input mutation during higher order op tracing" in str(cause)
)
def test_input_mutation_inference_mode(self):
@torch._dynamo.config.patch(does_invoke_subgraph_support_input_mutation=True)
def test_input_mutation_with_config_flag(self):
@mark_compile_region
def gn(x, y):
x.add_(1)
return torch.mul(x, y)
def fn(x, y):
return gn(x, y)
x1 = torch.ones(8, requires_grad=False)
x2 = torch.ones(8, requires_grad=False)
y = torch.randn(8, requires_grad=False)
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
self.assertEqual(fn(x1, y), opt_fn(x2, y))
def test_input_mutation_inference_mode(self):
@nested_compile_region
def gn(x, y):
x.add_(1)
return torch.mul(x, y)
def fn(x, y):
z = torch.cos(x)
with torch.inference_mode():
@ -1080,7 +1098,7 @@ class GraphModule(torch.nn.Module):
with self.assertRaisesRegex(
RuntimeError,
"torch.compile requires the `mark_compile_region` decorated function to be capturable into a single graph",
"torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
) as cm:
opt_fn(x, y)
@ -1093,7 +1111,7 @@ class GraphModule(torch.nn.Module):
def test_simple_module(self):
mod = torch.nn.Linear(8, 8)
@mark_compile_region
@nested_compile_region
def gn(x):
return torch.cos(x), mod(x)
@ -1134,7 +1152,7 @@ class GraphModule(torch.nn.Module):
opt_fn(x)
def test_input_output_aliasing(self):
@mark_compile_region
@nested_compile_region
def gn(x, y):
return (x, torch.mul(x, y))
@ -1149,7 +1167,7 @@ class GraphModule(torch.nn.Module):
with self.assertRaisesRegex(
RuntimeError,
"torch.compile requires the `mark_compile_region` decorated function to be capturable into a single graph",
"torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
) as cm:
opt_fn(x, y)
@ -1160,7 +1178,7 @@ class GraphModule(torch.nn.Module):
)
def test_input_input_aliasing(self):
@mark_compile_region
@nested_compile_region
def gn(x, y):
return torch.mul(x, y)
@ -1173,7 +1191,7 @@ class GraphModule(torch.nn.Module):
with self.assertRaisesRegex(
RuntimeError,
"torch.compile requires the `mark_compile_region` decorated function to be capturable into a single graph",
"torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
) as cm:
opt_fn(x)
@ -1184,7 +1202,7 @@ class GraphModule(torch.nn.Module):
)
def test_output_output_aliasing(self):
@mark_compile_region
@nested_compile_region
def gn(x):
z = torch.cos(x)
return z, z.view(1, 8)
@ -1198,7 +1216,7 @@ class GraphModule(torch.nn.Module):
with self.assertRaisesRegex(
RuntimeError,
"torch.compile requires the `mark_compile_region` decorated function to be capturable into a single graph",
"torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
) as cm:
opt_fn(x)
@ -1218,7 +1236,7 @@ class GraphModule(torch.nn.Module):
self.a.add_(1)
return torch.mul(x, self.a)
@mark_compile_region
@nested_compile_region
def gn(x):
return mod(x)
@ -1235,7 +1253,7 @@ class GraphModule(torch.nn.Module):
with self.assertRaisesRegex(
RuntimeError,
"torch.compile requires the `mark_compile_region` decorated function to be capturable into a single graph",
"torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
) as cm:
opt_fn(x, y)
@ -1246,8 +1264,8 @@ class GraphModule(torch.nn.Module):
)
def test_redundant_compile_region(self):
@mark_compile_region
@mark_compile_region
@nested_compile_region
@nested_compile_region
def gn(x):
return torch.sin(x)
@ -1289,7 +1307,7 @@ class GraphModule(torch.nn.Module):
)
def test_kwargs_only(self):
@mark_compile_region
@nested_compile_region
def gn(x, *, y):
return x * y
@ -1310,7 +1328,7 @@ class GraphModule(torch.nn.Module):
super().__init__()
self.linear = torch.nn.Linear(8, 8)
@mark_compile_region
@nested_compile_region
def helper(self, x):
return self.linear(x)
@ -1367,7 +1385,7 @@ class GraphModule(torch.nn.Module):
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.submod = mark_compile_region(SubMod())
self.submod = nested_compile_region(SubMod())
def forward(self, x):
return x + self.submod(x) * self.submod(x) + x
@ -1418,7 +1436,7 @@ class GraphModule(torch.nn.Module):
)
ones = torch.ones(1000, device="cuda:0", dtype=torch.float32)
@mark_compile_region
@nested_compile_region
def fn(x, train):
return F.dropout(x * weight, 0.33, train)
@ -1431,7 +1449,7 @@ class GraphModule(torch.nn.Module):
weight.grad.clone()
def test_return_none_from_fwd(self):
@mark_compile_region
@nested_compile_region
def gn(x):
return x * 2, None, x * 3
@ -1535,7 +1553,7 @@ class GraphModule(torch.nn.Module):
)
def test_dynamic(self):
@mark_compile_region
@nested_compile_region
def gn(x):
return torch.sin(x)
@ -1549,9 +1567,27 @@ class GraphModule(torch.nn.Module):
res = opt_fn(x)
self.assertEqual(ref, res)
def test_return_size(self):
@nested_compile_region
def gn(x):
y = x + 1
z = x.shape
return y, z
def fn(x):
z0 = gn(x)
z1 = gn(x)
return z0[0] + z1[0], z0[1]
x = torch.randn(8, 8, requires_grad=True)
ref = fn(x)
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True, dynamic=True)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_complex(self):
# Observed in Wan2.1
@mark_compile_region
@nested_compile_region
def gn(x):
return torch.sin(x)
@ -1566,7 +1602,7 @@ class GraphModule(torch.nn.Module):
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_pending_unbacked(self):
@mark_compile_region
@nested_compile_region
def gn(x):
u = x[0].item()
return x * u
@ -1585,7 +1621,7 @@ class GraphModule(torch.nn.Module):
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_unbacked(self):
@mark_compile_region
@nested_compile_region
def gn(x, y):
b = x.item()
torch._check_is_size(b)
@ -1605,7 +1641,7 @@ class GraphModule(torch.nn.Module):
self.assertEqual(ref, res)
def test_bwd_partitioning(self):
@mark_compile_region
@nested_compile_region
def gn(x, y):
z = torch.matmul(x, y)
return torch.sin(z)
@ -1688,7 +1724,7 @@ class GraphModule(torch.nn.Module):
)
def test_const_tensor(self):
@mark_compile_region
@nested_compile_region
def gn(x):
return torch.tensor(64, dtype=torch.float32) * x
@ -1707,14 +1743,14 @@ class GraphModule(torch.nn.Module):
def fn1(x):
return torch.cos(x)
@mark_compile_region
@nested_compile_region
def fn1_checkpoint(x):
return torch.utils.checkpoint.checkpoint(fn1, x, use_reentrant=False)
def fn2(x):
return torch.sin(x)
@mark_compile_region
@nested_compile_region
def fn2_checkpoint(x):
return torch.utils.checkpoint.checkpoint(fn2, x, use_reentrant=False)
@ -1753,7 +1789,7 @@ class GraphModule(torch.nn.Module):
self.assertEqual(ref, res)
def test_fake_tensor_checking(self):
@mark_compile_region
@nested_compile_region
def gn(x):
return torch.sin(x)
@ -1807,7 +1843,7 @@ class GraphModule(torch.nn.Module):
Tests check that the same subgraph called with different symints use different graphs
"""
@mark_compile_region
@nested_compile_region
def gn(x):
return torch.sin(x)
@ -1882,7 +1918,7 @@ class GraphModule(torch.nn.Module):
(x,) = ctx.saved_tensors
return x * torch.cos(grad_out)
@mark_compile_region
@nested_compile_region
def gn(x):
return CustomOp.apply(x)
@ -1982,7 +2018,7 @@ class GraphModule(torch.nn.Module):
return output
@mark_compile_region
@nested_compile_region
def gn(x, y):
o = torch.zeros_like(x)
call_triton_add(x, y, o, 0)
@ -2052,7 +2088,7 @@ class GraphModule(torch.nn.Module):
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
def test_unbacked_symbol(self):
@mark_compile_region
@nested_compile_region
def gn(x):
return torch.sin(torch.nonzero(x))
@ -2069,7 +2105,7 @@ class GraphModule(torch.nn.Module):
self.assertEqual(ref, res)
def test_different_strides_in_backward(self):
@mark_compile_region
@nested_compile_region
def gn(x):
return torch.cos(x)
@ -2216,7 +2252,7 @@ class GraphModule(torch.nn.Module):
)
def test_div(self):
@mark_compile_region
@nested_compile_region
def gn(x):
div = torch.div(1024, 256, rounding_mode="trunc")
return div * torch.ones(64, div) * x
@ -2292,7 +2328,7 @@ class GraphModule(torch.nn.Module):
lib.impl("add_op", impl, "CompositeExplicitAutograd")
lib.impl("add_op", meta, "Meta")
@mark_compile_region
@nested_compile_region
def gn(y, z):
return torch.ops.mylib.add_op.default(y, z)
@ -2374,7 +2410,7 @@ class GraphModule(torch.nn.Module):
lib.impl("add_op", impl, "CompositeExplicitAutograd")
lib.impl("add_op", meta, "Meta")
@mark_compile_region
@nested_compile_region
def gn(x, other):
y = x.transpose(2, 3).contiguous().transpose(2, 3)
z = y.sin().transpose(2, 3)
@ -2402,7 +2438,7 @@ class GraphModule(torch.nn.Module):
)
class TestInvokeSubgraphExport(TestCase):
def test_simple_func(self):
@mark_compile_region
@nested_compile_region
def gn(x, y):
return torch.mul(x, y)
@ -2441,7 +2477,7 @@ class GraphModule(torch.nn.Module):
)
def test_unbacked(self):
@mark_compile_region
@nested_compile_region
def gn(x, y):
b = x.item()
torch._check_is_size(b)
@ -2466,7 +2502,7 @@ class GraphModule(torch.nn.Module):
def test_pending_unbacked(self):
class M(torch.nn.Module):
@mark_compile_region
@nested_compile_region
def gn(self, x):
u = x[0].item()
return x * u
@ -2498,7 +2534,7 @@ class GraphModule(torch.nn.Module):
def test_simple_method(self):
class M(torch.nn.Module):
@mark_compile_region
@nested_compile_region
def gn(self, x, y):
return torch.mul(x, y)
@ -2522,7 +2558,7 @@ class GraphModule(torch.nn.Module):
super().__init__()
self.register_buffer("buf", b)
@mark_compile_region
@nested_compile_region
def forward(self, x, y):
return x * y + self.buf
@ -2546,7 +2582,7 @@ class GraphModule(torch.nn.Module):
class NegativeTesting(TestCase):
def test_graph_break(self):
@mark_compile_region
@nested_compile_region
def gn(x):
torch._dynamo.graph_break()
return torch.cos(x)
@ -2558,7 +2594,7 @@ class NegativeTesting(TestCase):
with self.assertRaisesRegex(
RuntimeError,
"torch.compile requires the `mark_compile_region` decorated function to be capturable into a single graph",
"torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
):
torch.compile(fn, backend="eager")(x)

View File

@ -107,7 +107,7 @@ assume_static_by_default = True
# with assume_static_by_default=True.
# With this flag enabled, we always compile a frame as fully static for the first time, and, if we fail
# any guards due to wobbles in shape, we recompile with *all* the wobbled shapes as being marked dynamic.
automatic_dynamic_shapes = True
automatic_dynamic_shapes = False
# Valid options: "dynamic", "unbacked"
automatic_dynamic_shapes_mark_as: Literal["dynamic", "unbacked"] = "dynamic"
@ -623,6 +623,10 @@ wrap_top_frame = False
# record pre-graph bytecode in profile traces
record_pre_graph_bytecode_in_traces = True
# This is a short-lived flag to control whether invoke_subgraph supports
# input mutation. It will be turned to True by default.
does_invoke_subgraph_support_input_mutation = False
# HACK: this is for testing custom ops profiling only
_custom_ops_profile: Optional[Any] = None
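
A minimal sketch of how this flag is exercised, mirroring the new test_input_mutation_with_config_flag test above (illustrative only; it requires a build that includes this stack):

import torch

@torch.compiler.nested_compile_region
def gn(x, y):
    x.add_(1)  # input mutation inside the marked region
    return torch.mul(x, y)

def fn(x, y):
    return gn(x, y)

x, y = torch.ones(8), torch.randn(8)

# scoped override, as the test does; the flag can also be set globally via
# torch._dynamo.config.does_invoke_subgraph_support_input_mutation = True
with torch._dynamo.config.patch(does_invoke_subgraph_support_input_mutation=True):
    out = torch.compile(fn, backend="inductor", fullgraph=True)(x, y)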

View File

@ -9,9 +9,9 @@ structures across different parts of the network.
import logging
import operator
from collections import defaultdict
from collections import defaultdict, deque
from collections.abc import Generator, Iterable
from typing import Optional
from typing import Any, Optional
import torch
import torch.fx
@ -74,12 +74,14 @@ when they are created in output_graph.
sub_gms: dict[str, torch.fx.GraphModule] = {}
for region_group in duplicated_region_groups:
for region_group in duplicated_region_groups: # 0:1
inds_with_external_users = _get_all_output_indices(region_group)
region = region_group[0]
(
subgraph,
external_node_usages,
node_usage_to_tuple_elems,
ind_to_tuple_spec,
) = _create_subgraph(region, inds_with_external_users)
# Ignore regions with no args for now, could they possibly be evaluated at compile time?
@ -100,6 +102,8 @@ when they are created in output_graph.
region,
get_subgraph_node,
external_node_usages,
node_usage_to_tuple_elems,
ind_to_tuple_spec,
inds_with_external_users,
subgraph_name,
node_to_additional_deps,
@ -122,14 +126,18 @@ def _replace_region_with_subgraph(
region: Region,
get_subgraph_node: Node,
external_node_usages: Iterable[OrderedSet[UsageIndex]],
node_usage_to_tuple_elems,
ind_to_tuple_spec,
inds_with_external_users: list[int],
subgraph_name: str,
node_to_additional_deps: dict[Node, OrderedSet[Node]],
node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
) -> None:
sub_args = []
flattened_getitem_nodes = OrderedSet()
for usages in external_node_usages:
node_ind, usage_ind = next(iter(usages))
usage = next(iter(usages))
node_ind, usage_ind = usage
node = region[node_ind]
flattened_args_kwargs = _get_flat_args(node, {})
for user_ind, node_usage_ind in usages:
@ -140,12 +148,20 @@ def _replace_region_with_subgraph(
"NYI: Failed to substitute region %s due to mutation", region
)
return
sub_args.append(flattened_args_kwargs[usage_ind])
if usage in node_usage_to_tuple_elems:
tuple_elems = [region[i] for i in node_usage_to_tuple_elems[usage]]
flattened_getitem_nodes.update(tuple_elems)
sub_args.extend(tuple_elems)
else:
sub_args.append(flattened_args_kwargs[usage_ind])
# Input/Output aliasing not supported in HOPs today
# Note: we should use the nodes in the original graph (the region here)
# because we use the original traced example values for this check
if _has_aliasing(region, sub_args, inds_with_external_users):
if _has_aliasing(
region, sub_args, inds_with_external_users, flattened_getitem_nodes
):
return
invoke_args = (get_subgraph_node, subgraph_name, *sub_args)
@ -156,16 +172,35 @@ def _replace_region_with_subgraph(
invoke_args, # type: ignore[arg-type]
{},
)
for ind, external_user_ind in enumerate(inds_with_external_users):
ind = 0
flattened_output_nodes = OrderedSet()
for external_user_ind in inds_with_external_users:
node = region[external_user_ind]
subgraph_output = graph.create_node(
"call_function", operator.getitem, (invoke_subgraph_node, ind), {}
)
node.replace_all_uses_with(subgraph_output, propagate_meta=True)
if _is_tuple_node(node):
tuple_spec = ind_to_tuple_spec[external_user_ind]
flattened_output_nodes.update(
_replace_tuple_outputs(
node, ind, tuple_spec, invoke_subgraph_node, graph
)
)
ind += len(tuple_spec)
else:
subgraph_output = graph.create_node(
"call_function", operator.getitem, (invoke_subgraph_node, ind), {}
)
node.replace_all_uses_with(subgraph_output, propagate_meta=True)
ind += 1
# Erase in reverse topological order
for node in reversed(region):
graph.erase_node(node)
if node in flattened_getitem_nodes:
# Don't erase these, since they will still be used
continue
if node not in flattened_output_nodes:
graph.erase_node(node)
# Remove any nodes with additional deps
# This is safe; we've guaranteed that there is
# no input mutation, so all additional deps
@ -220,15 +255,38 @@ def _get_inds_with_external_users(region: Region, inds_unique: set[int]) -> None
inds_unique.add(ind)
def _copy_nodes_and_remap_inputs(
subgraph: torch.fx.Graph, region: Region
) -> list[OrderedSet[UsageIndex]]:
def _create_subgraph(
region: Region,
inds_with_external_users: list[int],
) -> tuple[torch.fx.Graph, list[OrderedSet[UsageIndex]]]:
subgraph: torch.fx.Graph = torch.fx.Graph()
external_input_to_usages = _get_external_inputs(region)
external_node_usages = list[OrderedSet[UsageIndex]]()
region_to_subgraph_node = {}
flattened_getitem_nodes = OrderedSet()
node_usage_to_tuple_elems = {}
for node, usage_indices in external_input_to_usages.items():
placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
region_to_subgraph_node[node] = placeholder
# We don't handle tuples as inputs today
if _is_tuple_node(node):
# If a node is a tuple we will possibly create multiple placeholders for them
# and track which nodes we won't copy into the subgraph because they are flattened away
# Later, when replacing each region with this subgraph, we will create a getitem node
# externally which will perform the flattening on the outer nodes.
flattened_node_indices = _get_flattened_node_indices(node, region)
for ind in flattened_node_indices:
placeholder = subgraph.placeholder(
f"supgraph_input_{node.name}_flattened_{ind}"
)
region_to_subgraph_node[region[ind]] = placeholder
flattened_getitem_nodes.add(region[ind])
node_usage_to_tuple_elems[next(iter(usage_indices))] = (
flattened_node_indices
)
else:
placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
region_to_subgraph_node[node] = placeholder
external_node_usages.append(usage_indices)
def map_arg(node: Node) -> Node:
@ -237,29 +295,29 @@ def _copy_nodes_and_remap_inputs(
else:
return node
for node in region:
def copy_to_subgraph(node: Node) -> Node:
subgraph_node = subgraph.node_copy(node, lambda old: map_arg(old))
region_to_subgraph_node[node] = subgraph_node
return subgraph_node
return external_node_usages
output_list = []
ind_to_tuple_spec = {}
for ind, node in enumerate(region):
if node not in flattened_getitem_nodes:
subgraph_node = copy_to_subgraph(node)
if ind in inds_with_external_users:
# flatten tuple outputs by generating a getitem node tree
if _is_tuple_node(node):
getitem_nodes, ind_to_tuple_spec[ind] = _create_getitem_nodes(
node, subgraph_node, subgraph
)
output_list.extend(getitem_nodes)
else:
output_list.append(subgraph_node)
subgraph.output(tuple(output_list))
def _create_subgraph_outputs(
subgraph: torch.fx.Graph, inds_to_output: list[int]
) -> None:
node_list = [n for n in subgraph.nodes if n.op not in ("placeholder", "output")]
out_tup = tuple(node_list[ind] for ind in inds_to_output)
subgraph.output(out_tup)
def _create_subgraph(
region: Region,
inds_with_external_users: list[int],
) -> tuple[torch.fx.Graph, list[OrderedSet[UsageIndex]]]:
subgraph: torch.fx.Graph = torch.fx.Graph()
external_node_usages = _copy_nodes_and_remap_inputs(subgraph, region)
_create_subgraph_outputs(subgraph, inds_with_external_users)
return subgraph, external_node_usages
return subgraph, external_node_usages, node_usage_to_tuple_elems, ind_to_tuple_spec
def _stable_topological_sort(
@ -312,7 +370,9 @@ def _stable_topological_sort(
pending.extend(reversed(waiting.pop(node, ())))
ready.update(outputs)
assert not waiting and len(ready) == len(graph.nodes)
if not (not waiting and len(ready) == len(graph.nodes)):
breakpoint()
assert not waiting and len(ready) == len(graph.nodes)
def _populate_additional_deps(
@ -384,11 +444,17 @@ def _add_mutation_dependencies(
def _has_aliasing(
region: Region, inputs: list[Node], inds_with_external_users: list[int]
region: Region,
inputs: list[Node],
inds_with_external_users: list[int],
flattened_getitem_nodes,
) -> bool:
input_storages: dict[StorageWeakRef, Node] = dict()
for node in inputs:
if node in flattened_getitem_nodes:
continue
example_value = node.meta["example_value"]
if isinstance(example_value, torch.Tensor):
storage = StorageWeakRef(example_value._typed_storage())
@ -406,6 +472,9 @@ def _has_aliasing(
output_storages: dict[StorageWeakRef, Node] = dict()
for i in inds_with_external_users:
out_node = region[i]
if out_node in flattened_getitem_nodes:
continue
if out_node:
example_value = out_node.meta["example_value"]
assert not isinstance(example_value, list)
@ -437,3 +506,90 @@ def _has_aliasing(
return True
return False
def _is_tuple_node(node: Node) -> bool:
return isinstance(node.meta["example_value"], tuple)
def _get_children_getitems(node: Node) -> Generator[Node, None, None]:
for user in node.users:
if user.target == operator.getitem and isinstance(user.args[1], int):
yield user
def _get_flattened_node_indices(node: Node, region: Region) -> OrderedSet[int]:
"""Returns an ordered set of indices, each representing a node in the region which will be flattened"""
flattened_node_to_ind = {n: i for i, n in enumerate(region)}
node_indices: OrderedSet[int] = OrderedSet()
queue = deque(_get_children_getitems(node))
while queue:
cur_node = queue.popleft()
if any(user in region for user in cur_node.users):
node_indices.add(flattened_node_to_ind[cur_node])
for child in _get_children_getitems(cur_node):
queue.append(child)
return node_indices
def _create_getitem_nodes(
node: Node, subgraph_tuple_node, subgraph: torch.fx.Graph
) -> tuple[list[Node], tuple[Any]]:
tup = node.meta["example_value"]
assert isinstance(tup, tuple), "_get_getitem_children expects tuple"
getitem_nodes = []
queue = deque((e, (i,), subgraph_tuple_node) for i, e in enumerate(tup))
path_to_output_index = {}
while queue:
cur_elem, path, parent = queue.popleft()
new_getitem_node = subgraph.create_node(
"call_function", operator.getitem, (parent, path[-1]), {}
)
new_getitem_node.meta["example_value"] = cur_elem
path_to_output_index[path] = len(getitem_nodes)
getitem_nodes.append(new_getitem_node)
if isinstance(cur_elem, tuple):
queue.extend(
(e, path + (i,), new_getitem_node) for i, e in enumerate(cur_elem)
)
return getitem_nodes, path_to_output_index
def _replace_tuple_outputs(
node, output_index, tuple_spec, invoke_subgraph_node, graph
) -> OrderedSet[Node]:
assert _is_tuple_node(node), "_replace_tuple_outputs expects a tuple node"
queue = deque((c, (c.args[1],)) for c in _get_children_getitems(node))
erased_nodes = OrderedSet()
while queue:
cur_node, path = queue.pop()
for c in _get_children_getitems(cur_node):
queue.append((c, path + (c.args[1],)))
subgraph_output = graph.create_node(
"call_function",
operator.getitem,
(invoke_subgraph_node, output_index + tuple_spec[path]),
{},
)
cur_node.replace_all_uses_with(subgraph_output, propagate_meta=True)
graph.erase_node(cur_node)
erased_nodes.add(cur_node)
graph.erase_node(node)
erased_nodes.add(node)
return erased_nodes

View File

@ -327,8 +327,17 @@ class GraphRegionTracker:
self._is_identical,
)
# sort topologically
# we need to handle edge cases where some nodes have no dependencies
# so first we map each node to its ranking,
ref_region = region_group[0]
index_to_rank = {
index: topological_ranking[n] for index, n in enumerate(ref_region)
}
sorted_indices = sorted(
range(len(ref_region)), key=lambda i: index_to_rank[i]
)
for region in region_group:
region.sort(key=lambda n: topological_ranking[n])
region[:] = [region[i] for i in sorted_indices]
return [
region_group for region_group in region_groups if len(region_group[0]) > 1
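
To make the new ordering concrete, a small standalone illustration with hypothetical toy data (plain strings stand in for fx nodes): every region in a group is reordered using the topological ranks of the reference region's nodes, so all regions share the same node ordering.

ref_region = ["c", "a", "b"]
topological_ranking = {"a": 0, "b": 1, "c": 2}
index_to_rank = {i: topological_ranking[n] for i, n in enumerate(ref_region)}
sorted_indices = sorted(range(len(ref_region)), key=lambda i: index_to_rank[i])
region_group = [["c", "a", "b"], ["z", "x", "y"]]
for region in region_group:
    region[:] = [region[i] for i in sorted_indices]
# region_group is now [["a", "b", "c"], ["x", "y", "z"]]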

View File

@ -1030,13 +1030,19 @@ class OutputGraph(OutputGraphGuardsState):
name = OutputGraph.module_key_name(*names)
name = get_unique_name_wrt(name, self.nn_modules, self.global_scope)
self.nn_modules[name] = target
if isinstance(target, torch.nn.Module):
def register_leaf_name(leaf_name):
assert self.param_name_to_source is not None
new_source = ParamBufferSource(source, leaf_name)
new_name = f"{name}.{leaf_name}"
# If source is None we are installing a subgraph and
# propagating existing parameters to the new nn module
if source:
new_source = ParamBufferSource(source, leaf_name)
else:
new_source = self.param_name_to_source[leaf_name]
self.param_name_to_source[new_name] = new_source
if isinstance(source, LocalSource):
self.dynamo_flat_name_to_original_fqn[
@ -1849,9 +1855,7 @@ class OutputGraph(OutputGraphGuardsState):
raise
except exceptions_allowed_to_be_fallback as e:
if self.has_user_defined_allowed_in_graph:
raise BackendCompilerFailed(
self.compiler_fn, e, inspect.currentframe()
).with_traceback(e.__traceback__) from None
raise BackendCompilerFailed(self.compiler_fn, e, inspect.currentframe())
unimplemented_v2_with_warning(
e,
self.root_tx.f_code,
@ -1867,9 +1871,8 @@ class OutputGraph(OutputGraphGuardsState):
# aborting execution.
raise e
except Exception as e:
raise BackendCompilerFailed(
self.compiler_fn, e, inspect.currentframe()
).with_traceback(e.__traceback__) from None
breakpoint()
raise BackendCompilerFailed(self.compiler_fn, e, inspect.currentframe())
signpost_event(
"dynamo",

View File

@ -954,7 +954,7 @@ class VariableBuilder:
unimplemented_v2(
gb_type="Attempted to wrap torch._higher_order_ops.invoke_subgraph",
context="",
explanation="Directly using invoke_subgraph is not supported. Use mark_compile_region",
explanation="Directly using invoke_subgraph is not supported. Use nested_compile_region",
hints=[],
)
self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH)

View File

@ -818,9 +818,7 @@ class BuiltinVariable(VariableTracker):
handlers: list[_HandlerCallback] = []
if any(issubclass(t, LazyVariableTracker) for t in arg_types):
return lambda tx, args, kwargs: obj.call_function(
tx, [v.realize() for v in args], kwargs
)
return lambda tx, args, kwargs: obj.call_function(tx, args, kwargs)
if inspect.isclass(fn) and (
issubclass(fn, Exception)
@ -1175,6 +1173,30 @@ class BuiltinVariable(VariableTracker):
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
key: tuple[object, ...]
from .lazy import LazyVariableTracker
if (
self.fn is operator.getitem
and len(args) == 2
and isinstance(args[0], LazyVariableTracker)
and isinstance(args[1], LazyVariableTracker)
):
new_arg0 = args[0].realize()
if (
isinstance(new_arg0, ListVariable)
and new_arg0.are_items_same
and isinstance(args[1].peek_value(), int)
):
# This translates to a user code that has list[index], where
# both have sources, and we are about to insert a EQUALS_MATCH
# guard on the index. This will likely have large number of
# recompilations on the `index`, so one way to avoid this is to
# guard recursively on the list but not guard on the index.
LazyVariableTracker.realize_all(new_arg0)
new_arg1 = ConstantVariable(args[1].peek_value())
args = [new_arg0, new_arg1]
args = [v.realize() for v in args]
if kwargs:
kwargs = {k: v.realize() for k, v in kwargs.items()}
key = (self.fn, *(type(x) for x in args), True)
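
The motivation, as a hedged standalone illustration (not taken from the PR): indexing a list argument with a varying Python int normally specializes on the index and installs an equality guard on it, so each new index value can force a recompile; guarding on the structurally identical list items instead avoids that.

import torch

@torch.compile
def f(lst, i):
    return lst[i] + 1

xs = [torch.randn(4) for _ in range(8)]
for i in range(8):
    f(xs, i)  # without this change, each new value of `i` can trigger a recompile
              # because of an equality guard on the index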

View File

@ -823,6 +823,7 @@ def speculate_subgraph(
context=context,
explanation=f"Higher order ops do not support input mutation. Found in {source_target.name()}",
hints=[
"Set experimental flag `torch._dynamo.config.does_invoke_subgraph_support_input_mutation=True`.",
"Consider using the debug context to change user code to avoid mutation.",
"Please open an issue.",
],
@ -3397,7 +3398,7 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
return body_name
@raise_hard_error_if_graph_break(
reason="torch.compile requires the `mark_compile_region` decorated function to be capturable into a single graph",
reason="torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
)
def call_function(
self,
@ -3405,6 +3406,10 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
self.supports_input_mutation = (
torch._dynamo.config.does_invoke_subgraph_support_input_mutation
)
# This flattens the kwargs into lifted args
(
p_args,

View File

@ -17,6 +17,7 @@ variable tracking system.
"""
import collections
import functools
import inspect
import operator
from typing import Optional, TYPE_CHECKING
@ -85,6 +86,28 @@ class BaseListVariable(VariableTracker):
def modified(self, items, **kwargs):
return type(self)(items, **kwargs)
@functools.cached_property
def are_items_same(self):
items = self.items
if len(items) >= 2 and all(
isinstance(x, variables.LazyVariableTracker) for x in items
):
values = [x.peek_value() for x in items]
if all(variables.ConstantVariable.is_literal(v) for v in values):
return True
if all(isinstance(v, torch.Tensor) for v in values):
from torch.fx.passes.shape_prop import _extract_tensor_metadata
metadata0 = _extract_tensor_metadata(values[0])
same_metedata = True
for v in values[1:]:
if _extract_tensor_metadata(v) != metadata0:
return False
if same_metedata:
return True
return False
@property
def value(self):
return self.as_python_constant()

View File

@ -1342,7 +1342,7 @@ class GraphLowering(torch.fx.Interpreter):
# nested subgraphs can have singleton outputs
result = (result,)
assert isinstance(result, (tuple, list)), type(result)
assert all(
if not all(
isinstance(
x,
(
@ -1358,7 +1358,8 @@ class GraphLowering(torch.fx.Interpreter):
),
)
for x in result
), result
):
breakpoint()
fx_node_args = V.graph.current_node.args[0] # type: ignore[arg-type]
if not isinstance(fx_node_args, (tuple, list)):

View File

@ -32,6 +32,7 @@ __all__ = [
"skip_guard_on_all_nn_modules_unsafe",
"keep_tensor_guards_unsafe",
"skip_guard_on_globals_unsafe",
"nested_compile_region",
]
@ -566,3 +567,36 @@ def skip_guard_on_globals_unsafe(guard_entries):
"""
return [not entry.is_global for entry in guard_entries]
def nested_compile_region(fn=None):
"""
Tells **``torch.compile``** that the marked set of operations forms a nested
compile region (which is often repeated in the full model) whose code can be
compiled once and safely reused. ``nested_compile_region`` can also be used
as a decorator.
During **``torch.compile``** tracing, the compiler applies *hierarchical
compilation* with ``nested_compile_region``: it emits optimized code for the
marked region the first time it is encountered and re-emits (or “stamps
out”) the previously compiled code on every subsequent invocation. This can
substantially reduce overall compile time for deeply-stacked,
structurally-identical components such as the transformer layers of a
large-language-model (LLM).
Outside a ``torch.compile`` context—i.e., in standard eager execution—the
call is a no-op, so existing workflows remain unaffected.
Note that ``nested_compile_region`` **does not** promise that a region will
be compiled exactly once. If the compiler detects that new input conditions
(shape, dtype, device, stride, globals etc.) make the cached version invalid
to reuse, it will transparently re-compile the region. Using it is
therefore *safe*: correctness is always preserved, and you pay the extra
compilation cost only when required.
"""
from torch._higher_order_ops.invoke_subgraph import (
mark_compile_region as _mark_compile_region,
)
return _mark_compile_region(fn)
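
For illustration, a minimal sketch of the API released here (torch.compiler.nested_compile_region): marking the forward of a repeated block so torch.compile compiles it once and stamps out the result for each instance. A sketch, not an official example from this PR.

import torch

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    @torch.compiler.nested_compile_region
    def forward(self, x):
        return torch.relu(self.linear(x))

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # structurally identical, repeated blocks benefit from hierarchical compilation
        self.blocks = torch.nn.ModuleList([Block() for _ in range(4)])

    def forward(self, x):
        for b in self.blocks:
            x = b(x)
        return x

compiled = torch.compile(Model(), fullgraph=True)
out = compiled(torch.randn(2, 8))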