Compare commits

...

1 Commit

Author SHA1 Message Date
9cee5f7f08 Add simple testing file, evaluate feasability 2025-07-08 10:19:35 -07:00
23 changed files with 862 additions and 94 deletions

View File

@ -0,0 +1,297 @@
import sys
import time
import diffusers
from diffusers import (
AuraFlowPipeline,
AuraFlowTransformer2DModel,
GGUFQuantizationConfig,
)
import torch
from torch._higher_order_ops.invoke_subgraph import mark_compile_region
def compile_full_model(model):
model.compile(fullgraph=True, mode="reduce-overhead")
def compile_regions(model, nn_modules):
model.compile_repeated_blocks(fullgraph=True)
# for submod in model.modules():
# if isinstance(submod, nn_modules):
# print("Compiling", submod.__class__)
# submod.compile(fullgraph=True)
def compile_hierarchical(model, nn_modules):
for submod in model.modules():
if isinstance(submod, nn_modules):
submod.__class__.forward = mark_compile_region(submod.__class__.forward)
model.compile(fullgraph=True)
def auroflow_benchmark(mode):
transformer = AuraFlowTransformer2DModel.from_single_file(
"https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16,
)
pipeline = AuraFlowPipeline.from_pretrained(
"fal/AuraFlow-v0.3",
torch_dtype=torch.bfloat16,
transformer=transformer,
).to("cuda")
if mode == "full":
compile_full_model(pipeline.transformer)
elif mode == "regional":
compile_regions(
pipeline.transformer,
(
diffusers.models.transformers.auraflow_transformer_2d.AuraFlowSingleTransformerBlock,
diffusers.models.transformers.auraflow_transformer_2d.AuraFlowJointTransformerBlock,
),
)
elif mode == "hierarchical":
compile_hierarchical(
pipeline.transformer,
(
diffusers.models.transformers.auraflow_transformer_2d.AuraFlowSingleTransformerBlock,
diffusers.models.transformers.auraflow_transformer_2d.AuraFlowJointTransformerBlock,
),
)
else:
assert mode == "eager"
pipeline("A cute pony", width=512, height=512, num_inference_steps=1)
t0 = time.perf_counter()
pipeline("A cute pony", width=512, height=512, num_inference_steps=50)
t1 = time.perf_counter()
print(t1 - t0)
def wan_benchmark(mode):
import numpy as np
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import load_image
from transformers import CLIPVisionModel
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
vae = AutoencoderKLWan.from_pretrained(
model_id, subfolder="vae", torch_dtype=torch.float32
)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
# use pipe.enable_model_cpu_offload() instead of pipe.to("cuda") if you have insufficient VRAM
# pipe.enable_model_cpu_offload()
pipe.to("cuda")
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
max_area = 480 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
if mode == "full":
compile_full_model(pipe.transformer)
elif mode == "regional":
compile_regions(
pipe.transformer,
(diffusers.models.transformers.transformer_wan.WanTransformerBlock,),
)
elif mode == "hierarchical":
compile_hierarchical(
pipe.transformer,
(diffusers.models.transformers.transformer_wan.WanTransformerBlock,),
)
else:
assert mode == "eager"
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
num_inference_steps=1,
guidance_scale=5.0,
).frames[0]
t0 = time.perf_counter()
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
num_inference_steps=50,
guidance_scale=5.0,
).frames[0]
t1 = time.perf_counter()
print(t1 - t0)
def ltx_benchmark(mode):
from diffusers import LTXConditionPipeline
import torch
pipe = LTXConditionPipeline.from_pretrained(
"Lightricks/LTX-Video-0.9.7-dev", torch_dtype=torch.bfloat16
)
pipe.to("cuda")
pipe.vae.enable_tiling()
def round_to_nearest_resolution_acceptable_by_vae(height, width):
height = height - (height % pipe.vae_spatial_compression_ratio)
width = width - (width % pipe.vae_spatial_compression_ratio)
return height, width
prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
expected_height, expected_width = 512, 704
downscale_factor = 2 / 3
num_frames = 121
# Part 1. Generate video at smaller resolution
downscaled_height, downscaled_width = (
int(expected_height * downscale_factor),
int(expected_width * downscale_factor),
)
downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(
downscaled_height, downscaled_width
)
if mode == "full":
compile_full_model(pipe.transformer)
elif mode == "regional":
compile_regions(
pipe.transformer,
(diffusers.models.transformers.transformer_ltx.LTXVideoTransformerBlock,),
)
elif mode == "hierarchical":
compile_hierarchical(
pipe.transformer,
(diffusers.models.transformers.transformer_ltx.LTXVideoTransformerBlock,),
)
latents = pipe(
conditions=None,
prompt=prompt,
negative_prompt=negative_prompt,
width=downscaled_width,
height=downscaled_height,
num_frames=num_frames,
num_inference_steps=1,
generator=torch.Generator().manual_seed(0),
output_type="latent",
).frames
import time
t0 = time.time()
latents = pipe(
conditions=None,
prompt=prompt,
negative_prompt=negative_prompt,
width=downscaled_width,
height=downscaled_height,
num_frames=num_frames,
num_inference_steps=50,
generator=torch.Generator().manual_seed(0),
output_type="latent",
).frames
t1 = time.time()
print(t1 - t0)
def flux_benchmark(mode):
import diffusers
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16,
).to("cuda")
prompt = "A cat holding a sign that says hello world"
if mode == "full":
compile_full_model(pipe.transformer)
elif mode == "regional":
compile_regions(
pipe.transformer,
(
diffusers.models.transformers.transformer_flux.FluxTransformerBlock,
diffusers.models.transformers.transformer_flux.FluxSingleTransformerBlock,
),
)
elif mode == "hierarchical":
compile_hierarchical(
pipe.transformer,
(
diffusers.models.transformers.transformer_flux.FluxTransformerBlock,
diffusers.models.transformers.transformer_flux.FluxSingleTransformerBlock,
),
)
t0_0 = time.perf_counter()
pipe(
prompt,
height=1024,
width=1024,
guidance_scale=3.5,
num_inference_steps=1,
max_sequence_length=512,
)
t1_0 = time.perf_counter()
print(t1_0 - t0_0)
t0 = time.perf_counter()
pipe(
prompt,
height=1024,
width=1024,
guidance_scale=3.5,
num_inference_steps=50,
max_sequence_length=512,
)
t1 = time.perf_counter()
print(t1 - t0)
model_name = sys.argv[1]
mode = sys.argv[2]
if model_name == "auroflow":
auroflow_benchmark(mode)
elif model_name == "wan":
wan_benchmark(mode)
elif model_name == "ltx":
ltx_benchmark(mode)
elif model_name == "flux":
flux_benchmark(mode)
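
The script takes two positional arguments: a model name (auroflow, wan, ltx, flux) and a compile mode (eager, full, regional, hierarchical). The sketch below illustrates the three non-eager strategies on a toy model, assuming only the PyTorch APIs already imported above (nn.Module.compile and mark_compile_region); TinyBlock and TinyModel are made-up names for illustration, and diffusers' compile_repeated_blocks() is approximated by compiling each repeated submodule directly.

import torch
import torch.nn as nn
from torch._higher_order_ops.invoke_subgraph import mark_compile_region

class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.lin(x))

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(TinyBlock() for _ in range(4))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

def apply_compile(model, mode):
    if mode == "full":
        # one torch.compile graph over the entire model
        model.compile(fullgraph=True, mode="reduce-overhead")
    elif mode == "regional":
        # compile each repeated block separately so the compiled artifact is reused
        for block in model.blocks:
            block.compile(fullgraph=True)
    elif mode == "hierarchical":
        # mark the repeated block as a compile region, then compile the whole
        # model; the region is traced once and invoked as a subgraph
        TinyBlock.forward = mark_compile_region(TinyBlock.forward)
        model.compile(fullgraph=True)
    else:
        assert mode == "eager"

model = TinyModel()
apply_compile(model, "hierarchical")
out = model(torch.randn(2, 8))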

simple_benchmark.py Normal file (+54)
View File

@ -0,0 +1,54 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# 1 input image channel, 6 output channels, 5x5 square convolution
# kernel
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5*5 from image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, input):
# Convolution layer C1: 1 input image channel, 6 output channels,
# 5x5 square convolution, it uses RELU activation function, and
# outputs a Tensor with size (N, 6, 28, 28), where N is the size of the batch
c1 = F.relu(self.conv1(input))
# Subsampling layer S2: 2x2 grid, purely functional,
# this layer does not have any parameter, and outputs a (N, 6, 14, 14) Tensor
s2 = F.max_pool2d(c1, (2, 2))
# Convolution layer C3: 6 input channels, 16 output channels,
# 5x5 square convolution, it uses RELU activation function, and
# outputs a (N, 16, 10, 10) Tensor
c3 = F.relu(self.conv2(s2))
# Subsampling layer S4: 2x2 grid, purely functional,
# this layer does not have any parameter, and outputs a (N, 16, 5, 5) Tensor
s4 = F.max_pool2d(c3, 2)
# Flatten operation: purely functional, outputs a (N, 400) Tensor
s4 = torch.flatten(s4, 1)
# Fully connected layer F5: (N, 400) Tensor input,
# and outputs a (N, 120) Tensor, it uses RELU activation function
f5 = F.relu(self.fc1(s4))
# Fully connected layer F6: (N, 120) Tensor input,
# and outputs a (N, 84) Tensor, it uses RELU activation function
f6 = F.relu(self.fc2(f5))
# Gaussian layer OUTPUT: (N, 84) Tensor input, and
# outputs a (N, 10) Tensor
output = self.fc3(f6)
return output
# Example usage
if __name__ == "__main__":
input_ = torch.randn(1, 1, 32, 32)
net = Net()
compiled_net = torch.compile(net)
print("Function compiled!")
result = compiled_net(input_)
print("Result:", result)

test.py Normal file (+144)
View File

@ -0,0 +1,144 @@
import torch
import torch._dynamo.test_case
import torch._inductor.test_case
from torch._dynamo.testing import AotEagerAndRecordGraphs
from torch.testing._internal.opinfo.core import SampleInput
DEBUG = True
class ViewAndMutationMetaFromDynamo(torch._dynamo.test_case.TestCase):
# Note: These 4 tests are to evaluate the feasibility of the four
# fields in metadata analysis that were identified as difficult to
# extract from dynamo
# We want to run each one with different backend to check the graph
# To view the created artifact and verify we have sufficient info
def test_output_alias_info_functional_tensor(self):
def f(x):
return x[1].view(-1)
x = torch.randn(4, 4, requires_grad=True)
# backend = EagerAndRecordGraphs()
backend = AotEagerAndRecordGraphs()
compiled_f = torch.compile(f, backend=backend, fullgraph=True)
out = compiled_f(x)
assert len(backend.graphs) == 1
gm = backend.graphs[0]
def test_input_alias_info_mutations_hidden_from_autograd(self):
# From https://github.com/pytorch/pytorch/blob/1258aac1c28f2e66f54ecacaf798a0e7a24206ef/test/functorch/test_aotdispatch.py#L1457
def f(a):
a_alias = a.view(-1)
with torch.no_grad():
a_alias.mul_(2)
return a + 1
x = torch.randn(4, 4, requires_grad=True)
# backend = EagerAndRecordGraphs()
backend = AotEagerAndRecordGraphs()
compiled_f = torch.compile(f, backend=backend, fullgraph=True)
out = compiled_f(x)
assert len(backend.graphs) == 1
gm = backend.graphs[0]
# Currently fails because second collection pass doesn't pass GM
# this is okay, as first pass has parity - just need to clean up code
def test_traced_tangents(self):
# From https://github.com/pytorch/pytorch/blob/1258aac1c28f2e66f54ecacaf798a0e7a24206ef/test/functorch/test_aotdispatch.py#L6541
def fn(x):
return x.clone()
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64)
nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
# backend = EagerAndRecordGraphs()
backend = AotEagerAndRecordGraphs()
out = torch.compile(fn, backend=backend, fullgraph=True)(nt)
out_buffer = out.values()
ga, gb, gc = torch.autograd.grad(out_buffer.sum(), (a, b, c))
assert len(backend.graphs) == 1
gm = backend.graphs[0]
def test_tokens(self):
# Map of effect type (ex. _EffectType.ORDERED) to token
# FunctionalTensorMode would have populated this, so we need to validate
# that we can populate this from dynamo - should be similar to HOPs and Triton
# kernels
# From https://github.com/pytorch/pytorch/blob/1258aac1c28f2e66f54ecacaf798a0e7a24206ef/test/higher_order_ops/test_with_effects.py#L89
def f(x):
torch.ops.aten._print("moo")
res = x + x
torch.ops.aten._print("moo")
return (res,)
inputs = (torch.randn(3),)
# backend = EagerAndRecordGraphs()
backend = AotEagerAndRecordGraphs()
out = torch.compile(f, backend=backend, fullgraph=True)(inputs)
assert len(backend.graphs) == 1
gm = backend.graphs[0]
def test_tp_transform_with_uncovered_op(self):
class DummyModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc = torch.nn.Linear(3, 5)
self.bn = torch.nn.BatchNorm1d(5)
def forward(self, x):
return self.bn(self.fc(x))
inputs = (torch.randn(7, 3, requires_grad=False),)
model = DummyModel()
res = model(*inputs)
exported_program = torch.export.export(
model, inputs, strict=True
).run_decompositions()
tp_res = exported_program.module()(*inputs)
self.assertEqual(res, tp_res)
# Expect all_gather to be inserted into distributed sharded fc results
def test_chebyshev_polynomial(self):
out = torch.zeros((2,))
def f(samp: SampleInput):
torch.special.chebyshev_polynomial_u(*samp.args, samp.input, out=out)
return out
n = -5.115719318389893
inp = (torch.tensor(1.0961),)
samp = SampleInput(
n,
args=inp,
kwargs={},
)
# backend = EagerAndRecordGraphs()
backend = AotEagerAndRecordGraphs()
compiled_f = torch.compile(f, backend=backend, fullgraph=True)
out = compiled_f(samp)
plain = f(samp)
print(f"Out: {out} vs Plain: {plain}")
self.assertEqual(out, plain)
def test_baddbmm(self):
aten = torch.ops.aten
def fn(a, b, c, beta):
return aten.baddbmm(a, b, c, beta=beta)
compiled_fn = torch.compile(fn, dynamic=True)
a = torch.randn(6, 1, 100)
b = torch.randn(6, 128, 64)
c = torch.randn(6, 64, 100)
self.assertEqual(compiled_fn(a, b, c, 0.0), fn(a, b, c, 0.0))
self.assertEqual(compiled_fn(a, b, c, 1.0), fn(a, b, c, 1.0))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()
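
The tests above capture graphs with AotEagerAndRecordGraphs but only bind gm without inspecting it; when manually checking which metadata is visible from Dynamo, a small hypothetical helper like the one below can dump the recorded graphs. backend.graphs is the attribute already used in the tests; fw_graphs is assumed to hold the AOT forward graphs on this backend and may differ across PyTorch versions.

def dump_recorded_graphs(backend):
    # Dynamo-level graphs (same attribute the tests index as backend.graphs[0])
    for i, gm in enumerate(backend.graphs):
        print(f"--- dynamo graph {i} ---")
        gm.print_readable()
    # AOT forward graphs, if this backend records them
    for i, gm in enumerate(getattr(backend, "fw_graphs", [])):
        print(f"--- aot forward graph {i} ---")
        gm.print_readable()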

View File

@ -57,11 +57,11 @@ class ComptimeTests(torch._dynamo.test_case.TestCase):
self.assertExpectedInline(
FILE.getvalue().strip(),
"""\
FakeTensor(..., size=(s77,))
FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=(s77,))))
2
[FakeTensor(..., size=(s77,)), 2]
(FakeTensor(..., size=(s77,)), 2)
{'foo': FakeTensor(..., size=(s77,))}
[FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=(s77,)))), 2]
(FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=(s77,)))), 2)
{'foo': FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=(s77,))))}
range(1, 3, 1)
Employee(name='foo', id=2)
UserDefinedListVariable(mylist)
@ -160,7 +160,7 @@ def forward(self, L_x_ : torch.Tensor):
self.assertExpectedInline(
FILE.getvalue(),
"""\
- FakeTensor(..., size=(2,))
- FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=(2,))))
""",
)
@ -186,8 +186,8 @@ def forward(self, L_x_ : torch.Tensor):
self.assertExpectedInline(
FILE.getvalue(),
"""\
x = FakeTensor(..., size=(2,))
y = FakeTensor(..., size=(2,))
x = FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=(2,))))
y = FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=(2,))))
""",
)

View File

@ -2496,6 +2496,9 @@ def forward(self, primals_1, primals_2):
torch.ones(3, 3, requires_grad=True),
torch.ones(3, 3, requires_grad=True),
]
compiled_f = torch.compile(f)
actual = compiled_f(*inp_grad)
self.assertEqual(actual, f(*inp_grad))
self.verify_aot_autograd(f, inp_grad, test_mutation=True)
def test_backward_mutation_metadata(self):

View File

@ -6035,6 +6035,8 @@ def forward(self, x_1):
@functools.wraps(func)
def wrapper(*args, **kwargs):
torch._enable_functionalization(reapply_views=False)
# cpp(fake(rv))
# Python(cpp(cpp(fake(rv))))
try:
func_args = pytree.tree_map(
lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,

View File

@ -280,11 +280,15 @@ class FakeTensorTest(TestCase):
b = _add_batch_dim(x, 0, 0)
mode = FakeTensorMode()
fake_b = mode.from_tensor(b)
print(fake_b)
prims.utils.compare_tensor_meta(b, fake_b, check_strides=True)
b1 = _add_batch_dim(x, 1, 1)
b2 = _add_batch_dim(b1, 0, 2)
fake_b2 = mode.from_tensor(b2)
print(fake_b2)
print(is_batchedtensor(fake_b2))
print(is_batchedtensor(fake_b))
prims.utils.compare_tensor_meta(b2, fake_b2, check_strides=True)
self.assertTrue(is_batchedtensor(fake_b2))
fake_b1 = get_unwrapped(fake_b2)
@ -700,7 +704,12 @@ class FakeTensorTest(TestCase):
fake_x_view = mode.from_tensor(x_view)
fake_x = mode.from_tensor(x)
self.assertFalse(isinstance(fake_x_view, torch.nn.Parameter))
self.assertTrue(isinstance(fake_x, torch.nn.Parameter))
# fake_x = from_fun(fake_x)
# from torch._subclasses.functional_tensor import FunctionalTensor
# fake_x = FunctionalTensor.from_functional(fake_x)
# print(f"Fake tensor is {fake_x}, is instance? {isinstance(fake_x, torch.nn.Parameter)}? type is: {type(fake_x)}")
self.assertTrue(isinstance(fake_x, torch.nn.Parameter)) # <- Failing
def test_tolist(self):
shape_env = ShapeEnv()

View File

@ -46,6 +46,7 @@ import torch.utils._pytree as pytree
from torch import fx, Tensor
from torch._C._dynamo import guards
from torch._dynamo.exc import ShortenTraceback, TensorifyScalarRestartAnalysis
from torch._functorch._aot_autograd.functional_utils import from_fun
from torch._guards import (
CompileContext,
CompileId,
@ -55,6 +56,7 @@ from torch._guards import (
TracingContext,
)
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
from torch._utils_internal import signpost_event
from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined]
from torch.fx.experimental._backward_state import BackwardState
@ -462,7 +464,12 @@ class OutputGraph(OutputGraphGuardsState):
allow_non_fake_inputs=True if self.export else False,
export=self.export,
)
self.tracing_context: TracingContext = TracingContext(fake_mode)
functional_mode = FunctionalTensorMode(
export=self.export, _allow_token_discovery=True
)
self.tracing_context: TracingContext = TracingContext(
fake_mode, functional_mode
)
self.tracing_context.traced_code.append(f_code)
self.dynamo_compile_id: Optional[CompileId] = (
CompileContext.current_compile_id()
@ -697,6 +704,7 @@ class OutputGraph(OutputGraphGuardsState):
"""
call fn(*args) before the graph runs and turn the result into a fake input.
"""
# NOTE: Where does this get used???
example_value = fn(*args)
varname = self.new_var()
cg = PyCodegen(self.root_tx)
@ -807,6 +815,10 @@ class OutputGraph(OutputGraphGuardsState):
def fake_mode(self):
return self.tracing_context.fake_mode
@property
def functional_mode(self):
return self.tracing_context.functional_mode
@property
def shape_env(self):
return self.tracing_context.fake_mode.shape_env
@ -923,6 +935,7 @@ class OutputGraph(OutputGraphGuardsState):
# than just nn module objects, fix that.
self.nn_modules[attr_name] = attr_value
proxy = self.create_proxy("get_attr", attr_name, (), {})
# NOTE: Figure out where this gets used
set_example_value(proxy.node, attr_value)
return proxy
@ -1691,18 +1704,18 @@ class OutputGraph(OutputGraphGuardsState):
)
self.call_cleanup_hooks()
old_fake_mode = self.tracing_context.fake_mode
if not self.export:
import torch._functorch.config as _config
# if not self.export:
# import torch._functorch.config as _config
with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
# TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
backend_fake_mode = torch._subclasses.FakeTensorMode(
shape_env=old_fake_mode.shape_env,
)
# TODO(voz): Ostensibily, this should be scoped and
# restore back to old_fake_mode, but doing so currently violates
# a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
self.tracing_context.fake_mode = backend_fake_mode
# with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
# # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
# backend_fake_mode = torch._subclasses.FakeTensorMode(
# shape_env=old_fake_mode.shape_env,
# )
# # TODO(voz): Ostensibily, this should be scoped and
# # restore back to old_fake_mode, but doing so currently violates
# # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
# self.tracing_context.fake_mode = backend_fake_mode
with self.restore_global_state():
compiled_fn = self.call_user_compiler(gm, self.example_inputs())
@ -2093,6 +2106,10 @@ class OutputGraph(OutputGraphGuardsState):
for node in self.graph.nodes:
example_value = node.meta.get("example_value")
# May need to unwrap functional tensor if used
if isinstance(example_value, FunctionalTensor):
example_value = from_fun(example_value)
if (
isinstance(example_value, FakeTensor)
and example_value.item_memo is not None
@ -2103,6 +2120,7 @@ class OutputGraph(OutputGraphGuardsState):
example_value.item_memo.node._expr.name
)
):
# Likely never hit but should be!
for u in list(node.users):
u.replace_all_uses_with(guard_scalar(example_value.item_memo))
self.remove_node(u)
@ -2476,7 +2494,6 @@ class SubgraphTracer(fx.Tracer):
trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
self.prev_inst = cur_inst
# update reference to original meta if we're tracing a new code object
is_retracing = False
if tx.f_code is not self._cur_code:
@ -2524,7 +2541,6 @@ class SubgraphTracer(fx.Tracer):
),
)
]
self._maybe_preserve_original_meta(tx, rv.node)
if not is_retracing:
@ -2805,6 +2821,7 @@ class SubgraphTracer(fx.Tracer):
def track_unbacked_symbols(
self, example_value, e_proxy: Union[LazyProxy, torch.fx.Proxy]
):
# breakpoint()
# When binding the symbols in an exmaple_value, we bind the symbols
# to the proxy's associated Tracer instead of current tracer.
# This is because:
@ -2837,6 +2854,7 @@ class SubgraphTracer(fx.Tracer):
return proxy
if isinstance(example_value, torch.Tensor):
example_value = from_fun(example_value)
for i, s in enumerate(example_value.size()):
if need_bind(s):
log.debug(

View File

@ -3164,6 +3164,10 @@ class InstructionTranslatorBase(
def fake_mode(self):
return self.output.tracing_context.fake_mode
@property
def functional_mode(self):
return self.output.tracing_context.functional_mode
@contextlib.contextmanager
def strict_translation_mode(self, check_fn: Callable[[VariableTracker], bool]):
"""
@ -4035,6 +4039,10 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
def fake_mode(self):
return self.parent.fake_mode
@property
def functional_mode(self):
return self.parent.functional_mode
def run_ctx_mgr(self):
return TracingContext.current_frame(self.parent.frame_summary())

View File

@ -2791,6 +2791,13 @@ def deepcopy_to_fake_tensor(obj, fake_mode):
return wrap_fake_exception(lambda: copy.deepcopy(obj))
def deepcopy_to_functional_tensor(obj, fake_mode, functional_mode):
with torch._subclasses.functional_tensor.FunctionalCopyMode(
fake_mode, functional_mode
):
return wrap_fake_exception(lambda: copy.deepcopy(obj))
def rmse(ref, res):
"""
Calculate root mean squared error
@ -3241,9 +3248,12 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
id_to_initial_version = {}
nnmodule = None
# CURRENT ISSUE: deepcopy_to_fake_tensor doesn't work with FunctionalMode
# without this mode though, we don't have parameters of proper tensor types
if op == "call_method" and len(args) > 0 and isinstance(args[0], torch.nn.Module):
# If the first argument is nn.Module, should copy to fake mode.
args = (deepcopy_to_fake_tensor(args[0], tx.fake_mode),) + tuple(args[1:])
copied = deepcopy_to_fake_tensor(args[0], tx.fake_mode)
args = (copied,) + tuple(args[1:])
if op == "call_module":
nnmodule = tx.output.nn_modules[node.target]
@ -3256,7 +3266,12 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
nnmodule._infer_parameters(nnmodule, args)
# no matter it's lazy module or not, we should copy to fake mode.
nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode)
# Can make FunctionalMode like Fake one and handle
# or wrap the resulting deep copy
# just disable functional mode first
nnmodule = deepcopy_to_functional_tensor(
nnmodule, tx.fake_mode, tx.functional_mode
)
if node.name in ["interpolate", "is_integer", "wrapped_gradient"] or any(
isinstance(a, complex) for a in args
@ -3270,7 +3285,7 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
)
try:
with tx.fake_mode, enable_python_dispatcher():
with tx.functional_mode, tx.fake_mode, enable_python_dispatcher():
ret_val = wrap_fake_exception(
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
)
@ -3426,7 +3441,6 @@ def run_node(tracer, node, args, kwargs, nnmodule):
raise an AssertionError.
"""
op = node.op
with set_current_node(node):
def make_error_message(e):
@ -3453,7 +3467,8 @@ def run_node(tracer, node, args, kwargs, nnmodule):
return getattr(args[0], node.target)(*args[1:], **kwargs)
elif op == "call_module":
assert nnmodule is not None
return nnmodule(*args, **kwargs)
with tracer.functional_mode:
return nnmodule(*args, **kwargs)
elif op == "get_attr":
return tracer.output_graph.get_submodule(node.target)
elif op == "placeholder":

View File

@ -51,6 +51,7 @@ from torch._dynamo.utils import (
is_torch_sym,
set_feature_use,
)
from torch._functorch._aot_autograd.functional_utils import to_fun
from torch._guards import TracingContext
from torch._higher_order_ops.torchbind import call_torchbind
from torch._ops import HigherOrderOperator
@ -450,7 +451,9 @@ class VariableBuilder:
and value not in self.tx.output.side_effects
and not is_wrapper_or_member_descriptor(value)
):
vt = self.tx.output.side_effects.track_object_existing(value, vt)
vt = self.tx.output.side_effects.track_object_existing(
value, vt
) # NOTE: Shouldn't this track the mutation?
self.tx.output.variable_tracker_cache.add(value, self.source, vt)
return vt
@ -2085,11 +2088,21 @@ class VariableBuilder:
# then the relevant SubgraphTracer will lift it to being an input of
# the subgraph.
# See NOTE [HigherOrderOperator tracing design] for more details.
# Need to handle weird edge case where value is already a functional tensor...
# torch._disable_functionalization()
# if torch._is_functional_tensor(value):
# val = torch._from_functional_tensor(value)
# else:
# val = value
# breakpoint()
example_value = wrap_to_fake_tensor_and_record(
value, tx=self.tx, is_tensor=True, source=source
)
if not torch._is_functional_tensor(example_value):
with self.tx.functional_mode:
example_value = to_fun(example_value)
# NOTE: figure out what reapply_views is from keyset
# torch._enable_functionalization(reapply_views=False)
tensor_proxy = self.tx.output.root_tracer.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
type(value),
@ -2203,12 +2216,15 @@ class VariableBuilder:
# that there's not another great way to do this atm.
# This creates the right graphargs, as well as registration for guards in tensor names and shape env.
LazyVariableTracker.realize_all(VariableBuilder(self.tx, source)(tensor_value))
example_value = wrap_to_fake_tensor_and_record(
tensor_value,
tx=self.tx,
is_tensor=False,
source=source,
)
with self.tx.functional_mode:
example_value = to_fun(
wrap_to_fake_tensor_and_record(
tensor_value,
tx=self.tx,
is_tensor=False,
source=source,
)
)
proxy = self.tx.output.root_tracer.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
type(tensor_value),
@ -2435,9 +2451,12 @@ class VariableBuilder:
# TODO: Maybe the tensor-ification should be built into the source,
# rather than by special pattern match
example_value = wrap_to_fake_tensor_and_record(
wrapped_value, tx=self.tx, is_tensor=False, source=source
)
with self.tx.functional_mode:
example_value = to_fun(
wrap_to_fake_tensor_and_record(
wrapped_value, tx=self.tx, is_tensor=False, source=source
)
)
proxy = self.tx.output.root_tracer.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
type(wrapped_value),
@ -2465,7 +2484,7 @@ class VariableBuilder:
assert is_fake(example_value)
fake_tensor_value = example_value
assert fake_tensor_value.fake_mode is self.tx.fake_mode, (
assert maybe_get_fake_mode(fake_tensor_value) is self.tx.fake_mode, (
f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode"
"({self.tx.fake_mode}) from InstructionTranslator"
)
@ -2778,7 +2797,8 @@ def _wrap_fx_proxy(
# with preserve_rng_state():
# only allow_non_graph_fake in this instance because we handle the non-fake
# cases properly below.
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
with tx.functional_mode:
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
return handle_traced_output(
example_value, tx, proxy, options, subclass_type, target_cls
@ -2999,8 +3019,10 @@ def construct_tensor_variable(
# TODO: not sure about this fake mode test
if (
isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor)
and example_value.fake_mode is tx.fake_mode
):
or isinstance(
example_value, torch._subclasses.functional_tensor.FunctionalTensor
)
) and maybe_get_fake_mode(example_value) is tx.fake_mode:
if subclass_type:
tensor_type = subclass_type
elif isinstance(example_value, torch.nn.Parameter):

View File

@ -1564,7 +1564,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
from torch._higher_order_ops.utils import _maybe_fake_tracing
from torch._inductor.utils import is_pointwise_use
with tx.fake_mode:
with tx.fake_mode, tx.functional_mode:
sub_args_fake = [
leaf.node.meta["example_value"].clone()
if hasattr(leaf.node.meta["example_value"], "clone")
@ -1601,7 +1601,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
additional_inputs_proxy,
)
with tx.fake_mode:
with tx.fake_mode, tx.functional_mode:
out_meta = tuple(
inp_proxy.node.meta["example_value"].clone() for inp_proxy in xs_proxy
)
@ -1804,7 +1804,7 @@ class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
additional_inputs_proxy,
)
with tx.fake_mode:
with tx.fake_mode, tx.functional_mode:
example_carry = [
init_p.node.meta["example_value"].clone() for init_p in init_proxy
]
@ -3037,6 +3037,10 @@ class AutogradFunctionApplyVariable(VariableTracker):
"torch._dynamo.config._autograd_backward_strict_mode_conditional_banned_ops",
[],
):
# Could be a context manager to turn off counter increment
# But we don't want to record the mutations inside speculate_subgraph
# NOTE: May need special attention to Triton Hop (possibly not but maybe)
# May need to also ignore effectful python ops
(bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph(
tx,
bwd_fn,
@ -3199,7 +3203,7 @@ class AutogradFunctionApplyVariable(VariableTracker):
# (e.g, tensor.requires_grad), which would be used by downstream Dynamo tracing.
# Since there can be other ops like Triton kernels, which depends on python dispatcher, we have to enable it.
with enable_python_dispatcher():
with tx.output.fake_mode:
with tx.output.fake_mode, tx.output.functional_mode:
fake_args = (
tx.output.nn_modules[fwd_node.node.name],
tx.output.nn_modules[bwd_node.node.name],

View File

@ -23,6 +23,7 @@ from typing import Optional, TYPE_CHECKING
import torch
import torch.fx
from torch._functorch._aot_autograd.functional_utils import from_fun
from .. import graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
@ -149,6 +150,7 @@ class BaseListVariable(VariableTracker):
assert not kwargs and len(args) == 1
if isinstance(args[0], TensorVariable):
value = get_fake_value(args[0].as_proxy().node, tx)
value = from_fun(value)
if value.constant is not None and value.constant.numel() == 1:
value = variables.ConstantVariable.create(value.constant.item())
else:

View File

@ -694,6 +694,7 @@ class TensorVariable(VariableTracker):
pass
else:
try:
# breakpoint() # NOTE: set_item gets in here
result = handler_method(*args, **kwargs)
if result:
return result
@ -975,12 +976,15 @@ class TensorVariable(VariableTracker):
return wrap(tensor, sub_proxy)
if tensor.dim() == 1:
return [wrap(val, sub_proxy[i]) for i, val in enumerate(tensor)]
return [
tolist(sub_tensor, sub_proxy=sub_proxy[i])
for i, sub_tensor in enumerate(tensor)
]
# Example value may be a functional tensor, in which case the enumerate()
# calls dispatch through it, so we need to wrap them in functional mode
with tx.functional_mode:
return [wrap(val, sub_proxy[i]) for i, val in enumerate(tensor)]
with tx.functional_mode:
return [
tolist(sub_tensor, sub_proxy=sub_proxy[i])
for i, sub_tensor in enumerate(tensor)
]
tensor = self.as_proxy().node.meta["example_value"]
out = tolist(tensor, self.as_proxy())
@ -1082,13 +1086,15 @@ class TensorVariable(VariableTracker):
def method___setitem__(self, key, value):
from ..symbolic_convert import InstructionTranslator
# breakpoint()
tx = InstructionTranslator.current_tx()
proxy = tx.output.create_proxy(
"call_function",
operator.setitem,
*proxy_args_kwargs([self, key, value], {}),
)
# Run func on fake tensor - needed to set state on functional
get_fake_value(proxy.node, tx)
if config.use_graph_deduplication or config.track_nodes_for_deduplication:
tx.output.region_tracker.add_node_mutation(proxy.node, 0)

View File

@ -592,6 +592,8 @@ class TensorWithTFOverrideVariable(TensorVariable):
# This simulates shallow-copying the tensor object.
kwargs = dict(tensor_var.__dict__)
input_tensor_type = kwargs.pop("class_type")
# TODO: Figure out why this gets called with FunctionalTensor
# and resolve!
assert input_tensor_type in (torch.Tensor, torch.nn.Parameter), (
f"invalid class type {input_tensor_type} in TensorWithTFOverrideVariable.from_tensor_var"
)

View File

@ -152,10 +152,11 @@ def run_functionalized_fw_and_collect_metadata(
is_train: bool = False,
# Note: this is guaranteed to be set when running under dynamo
static_input_indices: Optional[list[int]] = None,
pre_dispatch: bool = False,
# is_export is technically only needed to avoid using functionalization V2
# during analysis
is_export: bool = False,
gm: Optional[torch.fx.GraphModule] = None,
functional_mode: Optional[FunctionalTensorMode] = None,
) -> Callable[..., ViewAndMutationMeta]:
memo: dict[Tensor, Tensor] = {}
@ -171,6 +172,21 @@ def run_functionalized_fw_and_collect_metadata(
@wraps(f)
def inner(*flat_args):
nonlocal static_input_indices
view_and_mutation_meta_no_gm = None
# Short circuit and don't run full thing if freezing, since the input placeholders won't be
if gm is not None:
# Hacky workaround - we want to check against functionalized, so let us call
# into this for now to verify correctness
to_call = run_functionalized_fw_and_collect_metadata(
f,
keep_input_mutations=keep_input_mutations,
is_train=is_train,
static_input_indices=static_input_indices,
is_export=is_export,
gm=None,
)
view_and_mutation_meta_no_gm = to_call(*flat_args)
# This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args.
assert all(isinstance(a, tuple(KNOWN_TYPES)) for a in flat_args)
@ -189,20 +205,66 @@ def run_functionalized_fw_and_collect_metadata(
# only for figuring out metadata
mode = FunctionalTensorMode(_allow_token_discovery=True, export=is_export)
suppress_pending = contextlib.nullcontext()
fake_mode = detect_fake_mode()
if fake_mode and (shape_env := fake_mode.shape_env):
suppress_pending = shape_env.ignore_fresh_unbacked_symbols()
with disable_above, mode, suppress_pending:
# precondition: The passed in function already handles unflattening inputs + flattening outputs
flat_f_args = pytree.tree_map(_to_fun, flat_args)
flat_f_outs = f(*flat_f_args)
# We didn't do any tracing, so we don't need to process the
# unbacked symbols, they will just disappear into the ether.
# Also, prevent memoization from applying.
if fake_mode:
fake_mode.epoch += 1
fake_mode.reset_nt_tensor_id_counter()
if gm is None:
# NOTE: If we don't have an fx module (assume no dynamo),
# we need to wrap the args as functional tensors and run the call ourselves
import time
start_time = time.time()
with disable_above, mode, suppress_pending:
# precondition: The passed in function already handles unflattening inputs + flattening outputs
flat_f_args = pytree.tree_map(_to_fun, flat_args)
flat_f_outs = f(*flat_f_args)
# We didn't do any tracing, so we don't need to process the
# unbacked symbols, they will just disappear into the ether.
# Also, prevent memoization from applying.
if fake_mode:
fake_mode.epoch += 1
fake_mode.reset_nt_tensor_id_counter()
elapsed_time_ms = (time.time() - start_time) * 1000
log.debug(
f"Time taken for functionalized execution: {elapsed_time_ms:.2f} ms"
)
else:
# NOTE: This makes a strong assumption that the order of f_args is analogous to args
# this may not be true...
import time
if functional_mode:
mode = functional_mode
else:
raise RuntimeError("Expected functional mode to be set")
# All `example_value` tensors on the nodes are already wrapped in
# functional tensor, so we don't need to wrap them again, only extract them
start_time = time.time()
flat_f_args = tuple(
node.meta["example_value"]
if hasattr(node, "meta") and "example_value" in node.meta
else None
for node in gm.graph.find_nodes(op="placeholder")
)
# breakpoint()
# flat_f_args = flat_f_args_maybe_tensors
# The output nodes are the last nodes (args) into the singular output node
output_args = gm.graph.find_nodes(op="output")[0].args
output_nodes = output_args[0] if len(output_args) > 0 else []
flat_f_outs = tuple(
node.meta["example_value"]
if hasattr(node, "meta") and "example_value" in node.meta
else None
for node in output_nodes
)
elapsed_time_ms = (time.time() - start_time) * 1000
log.debug(f"Time taken: {elapsed_time_ms:.2f} ms from graph")
# print(flat_f_args)
# print(flat_f_outs)
if prior_autocast_states != _get_autocast_states():
raise RuntimeError(
"AOTAutograd does not support tracing graphs that mutate the autocast state. "
@ -213,11 +275,13 @@ def run_functionalized_fw_and_collect_metadata(
# Inspect the state of the input tensor functional wrapper to detect input mutation info
# If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version
# import pdb; pdb.set_trace()s
for i, (arg, f_arg) in enumerate(zip(flat_args, flat_f_args)):
# NB: Mutation of non-contiguous tensor subclass input can result in a mismatch in
# strides between the functionalized arg inner tensors and non-functionalized arg inner
# tensors. This is a problem as the inner tensor stride change may not be reflected
# correctly in the outer tensor, so disallow this for now.
# breakpoint()
mutates_data = has_data_mutation(f_arg)
mutates_metadata = has_metadata_mutation(
f_arg, arg, check_only_storage_mutation=False
@ -738,7 +802,6 @@ from a multi-output view call"
coerce_tangent_and_suggest_memory_format(tt)[0]
for i, tt in enumerate(traced_tangents)
]
nonlocal static_input_indices
static_input_indices = static_input_indices or []
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
passed_indices = set(static_input_indices)
@ -808,6 +871,17 @@ from a multi-output view call"
static_input_indices=static_input_indices,
tokens=mode._tokens,
)
if view_and_mutation_meta_no_gm is not None:
for k, v in view_and_mutation_meta_no_gm.__dict__.items():
try:
assert (
metadata.__dict__[k] == v
), f"{k} mismatch: {metadata.__dict__[k]} vs {v}"
except RuntimeError as e:
# comparing traced tangents yields elementwise results in the nested case, so it would need .all()
log.debug(
f"~~~~Got error on KEY: {k}: {e}, \nPlease inspect manually: {metadata.__dict__[k]} vs {v}"
)
return metadata
return inner
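
The gm branch added above boils down to reading the already-functionalized example_value metadata off the Dynamo graph instead of re-running f under FunctionalTensorMode. A standalone sketch of that extraction, assuming a GraphModule whose placeholder and output nodes carry example_value meta as populated elsewhere in this patch:

import torch

def extract_example_values(gm: torch.fx.GraphModule):
    # placeholder nodes correspond to the flattened inputs
    flat_f_args = tuple(
        node.meta.get("example_value") for node in gm.graph.find_nodes(op="placeholder")
    )
    # the single output node's first arg holds the flattened outputs
    output_args = gm.graph.find_nodes(op="output")[0].args
    output_nodes = output_args[0] if output_args else []
    flat_f_outs = tuple(
        node.meta.get("example_value") if hasattr(node, "meta") else None
        for node in output_nodes
    )
    return flat_f_args, flat_f_outs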

View File

@ -39,6 +39,20 @@ def to_fun(t):
out = transform_subclass(t, lambda _, inner_t: to_fun(inner_t))
torch._mirror_autograd_meta_to(t, out) # type: ignore[attr-defined]
return out
# NOTE: This may need to be expanded to other nontraceable subclasses too;
# not sure how to do this outside of explicitly adding them
# for instance, GradTrackingTensor
# alternatively, can move this wrapping into meta_utils.py
elif torch._C._functorch.is_batchedtensor(t):
# Special case to get to the fake tensor and wrap properly
unwrapped = torch._C._functorch.get_unwrapped(t)
assert unwrapped is not None
unwrapped = to_fun(unwrapped)
return torch._C._functorch._add_batch_dim(
unwrapped,
torch._C._functorch.maybe_get_bdim(t),
torch._C._functorch.maybe_get_level(t),
)
else:
return FunctionalTensor.to_functional(t)
else:

View File

@ -956,6 +956,11 @@ class AOTConfig:
# Used only by standalone_compile.
ignore_shape_env: bool = False
precompile_backend_id: Optional[str] = None
# For testing gradual transition towards moving metadata collection
# to Dynamo.
gm: Optional[torch.fx.GraphModule] = None
# For testing gradual transition towards moving metadata collection
functional_mode: Optional[Any] = None
def __post_init__(self):
if self.pre_dispatch:

View File

@ -682,6 +682,14 @@ def _create_aot_dispatcher_function(
dynamo_timed_ctx = dynamo_timed(
"aot_collect_metadata", log_pt2_compile_event=True
)
if aot_config.gm is not None and aot_config.functional_mode is not None:
# Need both to be non-None, else both should be None;
# this is because functional_mode contains info pertinent to the
# FunctionalTensors in gm; if it is not preserved, we can't recover this
# info properly
gm, func_mode = aot_config.gm, aot_config.functional_mode
else:
gm, func_mode = None, None
with dynamo_timed_ctx, ctx:
fw_metadata = run_functionalized_fw_and_collect_metadata(
@ -689,8 +697,9 @@ def _create_aot_dispatcher_function(
static_input_indices=aot_config.static_input_indices,
keep_input_mutations=aot_config.keep_inference_input_mutations,
is_train=needs_autograd,
pre_dispatch=aot_config.pre_dispatch,
is_export=aot_config.is_export,
gm=gm,
functional_mode=func_mode,
)(*_dup_fake_script_obj(fake_flat_args))
req_subclass_dispatch = requires_subclass_dispatch(
@ -735,7 +744,6 @@ def _create_aot_dispatcher_function(
flat_fn,
keep_input_mutations=aot_config.keep_inference_input_mutations,
is_train=False,
pre_dispatch=aot_config.pre_dispatch,
static_input_indices=aot_config.static_input_indices,
)(*fake_flat_args)
else:
@ -1172,6 +1180,13 @@ def aot_module_simplified(
cache_info=None,
ignore_shape_env=ignore_shape_env,
precompile_backend_id=getattr(mod, "_backend_id", None),
gm=mod
if (
isinstance(mod, torch.fx.GraphModule)
and not torch._inductor.config.freezing
)
else None,
functional_mode=tracing_context.functional_mode if tracing_context else None,
)
fake_mode, shape_env = construct_fake_mode(full_args, aot_config)
fake_flat_args = process_inputs(

View File

@ -13,7 +13,7 @@ import unittest.mock
import weakref
from abc import abstractmethod
from collections import defaultdict
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import (
Any,
@ -823,13 +823,16 @@ class TracingContext:
"TracingContext.get() must be called within an ongoing trace."
)
def __init__(self, fake_mode):
def __init__(self, fake_mode, functional_mode=None):
self.guards_context = GuardsContext()
self.module_context = ModuleContext()
self.global_context = GlobalContext()
self.previously_inlined_functions = dict()
self.previously_cleaned_instructions = dict()
self.fake_mode = fake_mode
self.functional_mode = (
functional_mode if functional_mode is not None else nullcontext()
)
self.frame_summary_stack = []
# This is morally part of frame_summary_stack, but it is kept separate
# for clarity. As we process a frame, this variable gets updated

View File

@ -9,6 +9,7 @@ import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._dispatch.python import suspend_functionalization
from torch._functorch._aot_autograd.functional_utils import from_fun
from torch._guards import detect_fake_mode
from torch._higher_order_ops.schema import HopSchema
from torch._ops import HigherOrderOperator, OperatorBase, OpOverload
@ -362,21 +363,23 @@ def _collect_fake_inputs(inputs):
if hasattr(inp, "meta"):
val = inp.meta["example_value"]
if isinstance(val, torch.Tensor):
# Check for and unwrap common FakeTensor Wrappers in 'example_value'
if torch._C._functorch.is_batchedtensor(
val
) or torch._C._functorch.is_functionaltensor(val):
# This case is for batched or functional tensors
# Unwrap the tensors
while torch._C._functorch.is_batchedtensor(
val
) or torch._C._functorch.is_functionaltensor(val):
val = torch._C._functorch.get_unwrapped(val)
assert isinstance(val, FakeTensor)
inputs_fake.append(val)
else:
# This is the standard case of a TensorVariable
assert isinstance(val, FakeTensor)
inputs_fake.append(val)
elif isinstance(val, FunctionalTensor):
# TODO: Figure out why these may be different
# than the C subclass semantics laid out above
# NOTE: The implementation for `is_functionaltensor`
# checks the DispatchKeySet, and it is possible this has
# been disabled as is the case for collect_view_and_mutation_metadata
val = from_fun(val)
assert isinstance(val, FakeTensor)
inputs_fake.append(val)
else:
# This case is for SymInts and other non-Tensor elements
assert not isinstance(val, torch.Tensor)

View File

@ -3,14 +3,17 @@ import contextlib
import warnings
import weakref
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from contextlib import AbstractContextManager
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, cast, Optional, Union
import torch
import torch.utils._pytree as pytree
from torch._C import _functionalization_reapply_views_tls as _reapply_views
from torch._ops import _get_dispatch_mode_pre_dispatch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.meta_utils import is_sparse_any
from torch.overrides import TorchFunctionMode
from torch.utils._python_dispatch import (
_detect_infra_mode,
_disable_infra_mode,
@ -354,6 +357,7 @@ class FunctionalTensorMode(TorchDispatchMode):
super().__exit__(a, b, c)
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
# breakpoint()
if kwargs is None:
kwargs = {}
@ -568,6 +572,79 @@ def disable_functional_mode():
return _disable_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
# TODO: pull these from aot autograd
def from_fun(t):
if not isinstance(t, FunctionalTensor):
# quick sanity assert
if isinstance(t, torch.Tensor):
assert not torch._is_functional_tensor(t)
return t
torch._sync(t)
return torch._from_functional_tensor(t.elem)
def to_fun(t):
if isinstance(t, torch.Tensor):
return FunctionalTensor.to_functional(t)
return t
# Only used to allow copying a module to Functional(Fake) tensors;
# does not apply elsewhere
class FunctionalCopyMode(TorchFunctionMode):
def __init__(
self, fake_mode: FakeTensorMode, functional_mode: FunctionalTensorMode
) -> None:
self.fake_mode = fake_mode
self.functional_mode = functional_mode
def __torch_function__(
self,
func: torch._ops.OpOverload,
types: Sequence[type],
args: Sequence[object] = (),
kwargs: Optional[Mapping[str, object]] = None,
) -> FunctionalTensor:
# Disable any outer functional modes, as these will try and intercept
# the dispatch in `func`; we only want this done afterwards
is_func_clone = func == torch._C.TensorBase.clone
is_func_deepcopy = func == torch.Tensor.__deepcopy__
kwargs = kwargs if kwargs else {}
if is_func_clone or is_func_deepcopy:
# Need to run without functional for copy/clone
with disable_functional_mode():
# clone will get called in Parameter deepcopy
if is_func_clone:
assert isinstance(args[0], torch.Tensor)
fake_clone = func(
self.fake_mode.from_tensor(args[0], static_shapes=True),
**kwargs,
)
elif is_func_deepcopy:
assert len(args) == 2 and len(kwargs) == 0
tensor = cast(torch.Tensor, args[0])
memo = cast(dict[int, FunctionalTensor], args[1])
if id(tensor) in memo:
return memo[id(tensor)]
fake_clone = self.fake_mode.from_tensor(tensor, static_shapes=True)
elif func == torch._C.TensorBase.detach:
# FunctionalTensor should run without being disabled
assert isinstance(args[0], torch.Tensor)
return func(*args, **kwargs)
else:
with disable_functional_mode(), torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
with self.functional_mode:
out = to_fun(fake_clone)
if is_func_deepcopy:
tensor = cast(torch.Tensor, args[0])
memo = cast(dict[int, FunctionalTensor], args[1])
memo[id(tensor)] = out
return out
# This is similar to torch.func.functionalize, but:
# - It uses FunctionalTensorMode, and FunctionalTensor (a python subclass).
# One important advantage to using this mode is that it will let us
@ -577,21 +654,6 @@ def disable_functional_mode():
# functorch transforms, since these transforms always run above __torch_dispatch__.
# That's why this util lives here, and not in functorch.
def dispatch_functionalize(func, mode: FunctionalTensorMode = FunctionalTensorMode()):
# TODO: pull these from aot autograd
def to_fun(t):
if isinstance(t, torch.Tensor):
return FunctionalTensor.to_functional(t)
return t
def from_fun(t):
if not isinstance(t, FunctionalTensor):
# quick sanity assert
if isinstance(t, torch.Tensor):
assert not torch._is_functional_tensor(t)
return t
torch._sync(t)
return torch._from_functional_tensor(t.elem)
def inner(*args, **kwargs):
disable_above = torch._C._ExcludeDispatchKeyGuard(
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)

View File

@ -1148,7 +1148,13 @@ def _free_unbacked_symbols_with_path(
r.update(go(sub, path + (InnerTensorKey(attr),)))
elif isinstance(a, torch.Tensor):
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor
# Note: `a` may also be a FunctionalTensor wrapping a FakeTensor, so check and unwrap if needed
if isinstance(a, FunctionalTensor):
from torch._functorch._aot_autograd.functional_utils import from_fun
a = from_fun(a)
assert isinstance(a, FakeTensor)
r.update(
go(