mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-01 13:34:57 +08:00
Compare commits
1 Commits
ciflow/tru
...
lucaskabel
| Author | SHA1 | Date | |
|---|---|---|---|
| 9cee5f7f08 |
297
benchmarks/dynamo/diffusers/auroflow.py
Normal file
297
benchmarks/dynamo/diffusers/auroflow.py
Normal file
@ -0,0 +1,297 @@
|
||||
import sys
|
||||
import time
|
||||
|
||||
import diffusers
|
||||
from diffusers import (
|
||||
AuraFlowPipeline,
|
||||
AuraFlowTransformer2DModel,
|
||||
GGUFQuantizationConfig,
|
||||
)
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.invoke_subgraph import mark_compile_region
|
||||
|
||||
|
||||
def compile_full_model(model):
|
||||
model.compile(fullgraph=True, mode="reduce-overhead")
|
||||
|
||||
|
||||
def compile_regions(model, nn_modules):
|
||||
model.compile_repeated_blocks(fullgraph=True)
|
||||
# for submod in model.modules():
|
||||
# if isinstance(submod, nn_modules):
|
||||
# print("Compiling", submod.__class__)
|
||||
# submod.compile(fullgraph=True)
|
||||
|
||||
|
||||
def compile_hierarchical(model, nn_modules):
|
||||
for submod in model.modules():
|
||||
if isinstance(submod, nn_modules):
|
||||
submod.__class__.forward = mark_compile_region(submod.__class__.forward)
|
||||
model.compile(fullgraph=True)
|
||||
|
||||
|
||||
def auroflow_benchmark(mode):
|
||||
transformer = AuraFlowTransformer2DModel.from_single_file(
|
||||
"https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
|
||||
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipeline = AuraFlowPipeline.from_pretrained(
|
||||
"fal/AuraFlow-v0.3",
|
||||
torch_dtype=torch.bfloat16,
|
||||
transformer=transformer,
|
||||
).to("cuda")
|
||||
|
||||
if mode == "full":
|
||||
compile_full_model(pipeline.transformer)
|
||||
elif mode == "regional":
|
||||
compile_regions(
|
||||
pipeline.transformer,
|
||||
(
|
||||
diffusers.models.transformers.auraflow_transformer_2d.AuraFlowSingleTransformerBlock,
|
||||
diffusers.models.transformers.auraflow_transformer_2d.AuraFlowJointTransformerBlock,
|
||||
),
|
||||
)
|
||||
elif mode == "hierarchical":
|
||||
compile_hierarchical(
|
||||
pipeline.transformer,
|
||||
(
|
||||
diffusers.models.transformers.auraflow_transformer_2d.AuraFlowSingleTransformerBlock,
|
||||
diffusers.models.transformers.auraflow_transformer_2d.AuraFlowJointTransformerBlock,
|
||||
),
|
||||
)
|
||||
else:
|
||||
assert mode == "eager"
|
||||
|
||||
pipeline("A cute pony", width=512, height=512, num_inference_steps=1)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
pipeline("A cute pony", width=512, height=512, num_inference_steps=50)
|
||||
t1 = time.perf_counter()
|
||||
print(t1 - t0)
|
||||
|
||||
|
||||
def wan_benchmark(mode):
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
|
||||
from diffusers.utils import load_image
|
||||
from transformers import CLIPVisionModel
|
||||
|
||||
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
|
||||
model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(
|
||||
model_id, subfolder="image_encoder", torch_dtype=torch.float32
|
||||
)
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
model_id, subfolder="vae", torch_dtype=torch.float32
|
||||
)
|
||||
pipe = WanImageToVideoPipeline.from_pretrained(
|
||||
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
# replace this with pipe.to("cuda") if you have sufficient VRAM
|
||||
# pipe.enable_model_cpu_offload()
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
|
||||
)
|
||||
|
||||
max_area = 480 * 832
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
|
||||
prompt = (
|
||||
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
|
||||
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
|
||||
)
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
num_frames = 33
|
||||
|
||||
if mode == "full":
|
||||
compile_full_model(pipe.transformer)
|
||||
elif mode == "regional":
|
||||
compile_regions(
|
||||
pipe.transformer,
|
||||
(diffusers.models.transformers.transformer_wan.WanTransformerBlock,),
|
||||
)
|
||||
elif mode == "hierarchical":
|
||||
compile_hierarchical(
|
||||
pipe.transformer,
|
||||
(diffusers.models.transformers.transformer_wan.WanTransformerBlock,),
|
||||
)
|
||||
else:
|
||||
assert mode == "eager"
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
num_inference_steps=1,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
|
||||
t0 = time.perf_counter()
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
t1 = time.perf_counter()
|
||||
print(t1 - t0)
|
||||
|
||||
|
||||
def ltx_benchmark(mode):
|
||||
from diffusers import LTXConditionPipeline
|
||||
|
||||
import torch
|
||||
|
||||
pipe = LTXConditionPipeline.from_pretrained(
|
||||
"Lightricks/LTX-Video-0.9.7-dev", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.to("cuda")
|
||||
pipe.vae.enable_tiling()
|
||||
|
||||
def round_to_nearest_resolution_acceptable_by_vae(height, width):
|
||||
height = height - (height % pipe.vae_spatial_compression_ratio)
|
||||
width = width - (width % pipe.vae_spatial_compression_ratio)
|
||||
return height, width
|
||||
|
||||
prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
|
||||
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
expected_height, expected_width = 512, 704
|
||||
downscale_factor = 2 / 3
|
||||
num_frames = 121
|
||||
|
||||
# Part 1. Generate video at smaller resolution
|
||||
downscaled_height, downscaled_width = (
|
||||
int(expected_height * downscale_factor),
|
||||
int(expected_width * downscale_factor),
|
||||
)
|
||||
downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(
|
||||
downscaled_height, downscaled_width
|
||||
)
|
||||
|
||||
if mode == "full":
|
||||
compile_full_model(pipe.transformer)
|
||||
elif mode == "regional":
|
||||
compile_regions(
|
||||
pipe.transformer,
|
||||
(diffusers.models.transformers.transformer_ltx.LTXVideoTransformerBlock,),
|
||||
)
|
||||
elif mode == "hierarchical":
|
||||
compile_hierarchical(
|
||||
pipe.transformer,
|
||||
(diffusers.models.transformers.transformer_ltx.LTXVideoTransformerBlock,),
|
||||
)
|
||||
|
||||
latents = pipe(
|
||||
conditions=None,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=downscaled_width,
|
||||
height=downscaled_height,
|
||||
num_frames=num_frames,
|
||||
num_inference_steps=1,
|
||||
generator=torch.Generator().manual_seed(0),
|
||||
output_type="latent",
|
||||
).frames
|
||||
|
||||
import time
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
latents = pipe(
|
||||
conditions=None,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=downscaled_width,
|
||||
height=downscaled_height,
|
||||
num_frames=num_frames,
|
||||
num_inference_steps=50,
|
||||
generator=torch.Generator().manual_seed(0),
|
||||
output_type="latent",
|
||||
).frames
|
||||
t1 = time.time()
|
||||
print(t1 - t0)
|
||||
|
||||
|
||||
def flux_benchmark(mode):
|
||||
import diffusers
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
|
||||
if mode == "full":
|
||||
compile_full_model(pipe.transformer)
|
||||
elif mode == "regional":
|
||||
compile_regions(
|
||||
pipe.transformer,
|
||||
(
|
||||
diffusers.models.transformers.transformer_flux.FluxTransformerBlock,
|
||||
diffusers.models.transformers.transformer_flux.FluxSingleTransformerBlock,
|
||||
),
|
||||
)
|
||||
elif mode == "hierarchical":
|
||||
compile_hierarchical(
|
||||
pipe.transformer,
|
||||
(
|
||||
diffusers.models.transformers.transformer_flux.FluxTransformerBlock,
|
||||
diffusers.models.transformers.transformer_flux.FluxSingleTransformerBlock,
|
||||
),
|
||||
)
|
||||
|
||||
t0_0 = time.perf_counter()
|
||||
pipe(
|
||||
prompt,
|
||||
height=1024,
|
||||
width=1024,
|
||||
guidance_scale=3.5,
|
||||
num_inference_steps=1,
|
||||
max_sequence_length=512,
|
||||
)
|
||||
t1_0 = time.perf_counter()
|
||||
print(t1_0 - t0_0)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
pipe(
|
||||
prompt,
|
||||
height=1024,
|
||||
width=1024,
|
||||
guidance_scale=3.5,
|
||||
num_inference_steps=50,
|
||||
max_sequence_length=512,
|
||||
)
|
||||
t1 = time.perf_counter()
|
||||
print(t1 - t0)
|
||||
|
||||
|
||||
model_name = sys.argv[1]
|
||||
mode = sys.argv[2]
|
||||
|
||||
if model_name == "auroflow":
|
||||
auroflow_benchmark(mode)
|
||||
elif model_name == "wan":
|
||||
wan_benchmark(mode)
|
||||
elif model_name == "ltx":
|
||||
ltx_benchmark(mode)
|
||||
elif model_name == "flux":
|
||||
flux_benchmark(mode)
|
||||
54
simple_benchmark.py
Normal file
54
simple_benchmark.py
Normal file
@ -0,0 +1,54 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
# 1 input image channel, 6 output channels, 5x5 square convolution
|
||||
# kernel
|
||||
self.conv1 = nn.Conv2d(1, 6, 5)
|
||||
self.conv2 = nn.Conv2d(6, 16, 5)
|
||||
# an affine operation: y = Wx + b
|
||||
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5*5 from image dimension
|
||||
self.fc2 = nn.Linear(120, 84)
|
||||
self.fc3 = nn.Linear(84, 10)
|
||||
|
||||
def forward(self, input):
|
||||
# Convolution layer C1: 1 input image channel, 6 output channels,
|
||||
# 5x5 square convolution, it uses RELU activation function, and
|
||||
# outputs a Tensor with size (N, 6, 28, 28), where N is the size of the batch
|
||||
c1 = F.relu(self.conv1(input))
|
||||
# Subsampling layer S2: 2x2 grid, purely functional,
|
||||
# this layer does not have any parameter, and outputs a (N, 6, 14, 14) Tensor
|
||||
s2 = F.max_pool2d(c1, (2, 2))
|
||||
# Convolution layer C3: 6 input channels, 16 output channels,
|
||||
# 5x5 square convolution, it uses RELU activation function, and
|
||||
# outputs a (N, 16, 10, 10) Tensor
|
||||
c3 = F.relu(self.conv2(s2))
|
||||
# Subsampling layer S4: 2x2 grid, purely functional,
|
||||
# this layer does not have any parameter, and outputs a (N, 16, 5, 5) Tensor
|
||||
s4 = F.max_pool2d(c3, 2)
|
||||
# Flatten operation: purely functional, outputs a (N, 400) Tensor
|
||||
s4 = torch.flatten(s4, 1)
|
||||
# Fully connected layer F5: (N, 400) Tensor input,
|
||||
# and outputs a (N, 120) Tensor, it uses RELU activation function
|
||||
f5 = F.relu(self.fc1(s4))
|
||||
# Fully connected layer F6: (N, 120) Tensor input,
|
||||
# and outputs a (N, 84) Tensor, it uses RELU activation function
|
||||
f6 = F.relu(self.fc2(f5))
|
||||
# Gaussian layer OUTPUT: (N, 84) Tensor input, and
|
||||
# outputs a (N, 10) Tensor
|
||||
output = self.fc3(f6)
|
||||
return output
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
input_ = torch.randn(1, 1, 32, 32)
|
||||
net = Net()
|
||||
compiled_net = torch.compile(net)
|
||||
print("Function compiled!")
|
||||
result = compiled_net(input_)
|
||||
print("Result:", result)
|
||||
144
test.py
Normal file
144
test.py
Normal file
@ -0,0 +1,144 @@
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
import torch._inductor.test_case
|
||||
from torch._dynamo.testing import AotEagerAndRecordGraphs
|
||||
from torch.testing._internal.opinfo.core import SampleInput
|
||||
|
||||
|
||||
DEBUG = True
|
||||
|
||||
|
||||
class ViewAndMutationMetaFromDynamo(torch._dynamo.test_case.TestCase):
|
||||
# Note: These 4 tests are to evalaute the feasability for the four
|
||||
# fields in metadata analysis that were identifed as diffifuclt to import pdb; pdb.set_trace()
|
||||
# extract from dynamo
|
||||
|
||||
# We want to run each one with different backend to check the graph
|
||||
# To view the created artifact and verify we have sufficient info
|
||||
def test_output_alias_info_functional_tensor(self):
|
||||
def f(x):
|
||||
return x[1].view(-1)
|
||||
|
||||
x = torch.randn(4, 4, requires_grad=True)
|
||||
# backend = EagerAndRecordGraphs()
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
compiled_f = torch.compile(f, backend=backend, fullgraph=True)
|
||||
out = compiled_f(x)
|
||||
assert len(backend.graphs) == 1
|
||||
gm = backend.graphs[0]
|
||||
|
||||
def test_input_alias_info_mutations_hidden_from_autograd(self):
|
||||
# From https://github.com/pytorch/pytorch/blob/1258aac1c28f2e66f54ecacaf798a0e7a24206ef/test/functorch/test_aotdispatch.py#L1457
|
||||
def f(a):
|
||||
a_alias = a.view(-1)
|
||||
with torch.no_grad():
|
||||
a_alias.mul_(2)
|
||||
return a + 1
|
||||
|
||||
x = torch.randn(4, 4, requires_grad=True)
|
||||
# backend = EagerAndRecordGraphs()
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
compiled_f = torch.compile(f, backend=backend, fullgraph=True)
|
||||
out = compiled_f(x)
|
||||
assert len(backend.graphs) == 1
|
||||
gm = backend.graphs[0]
|
||||
|
||||
# Currently fails because second collection pass doesn't pass GM
|
||||
# this is okay, as first pass has parity - just need to clena up code
|
||||
def test_traced_tangents(self):
|
||||
# From ttps://github.com/pytorch/pytorch/blob/1258aac1c28f2e66f54ecacaf798a0e7a24206ef/test/functorch/test_aotdispatch.py#L6541
|
||||
def fn(x):
|
||||
return x.clone()
|
||||
|
||||
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)
|
||||
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
|
||||
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64)
|
||||
nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
|
||||
# backend = EagerAndRecordGraphs()
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
out = torch.compile(fn, backend=backend, fullgraph=True)(nt)
|
||||
out_buffer = out.values()
|
||||
ga, gb, gc = torch.autograd.grad(out_buffer.sum(), (a, b, c))
|
||||
|
||||
assert len(backend.graphs) == 1
|
||||
gm = backend.graphs[0]
|
||||
|
||||
def test_tokens(self):
|
||||
# Map of effect type (ex. _EffectType.ORDERED) to token
|
||||
# FunctionalTensorMode would have populated this, so we need to validate
|
||||
# that we can populate this from dynamo - should be similar to HOPs and Triton
|
||||
# kernels
|
||||
# From https://github.com/pytorch/pytorch/blob/1258aac1c28f2e66f54ecacaf798a0e7a24206ef/test/higher_order_ops/test_with_effects.py#L89
|
||||
def f(x):
|
||||
torch.ops.aten._print("moo")
|
||||
res = x + x
|
||||
torch.ops.aten._print("moo")
|
||||
return (res,)
|
||||
|
||||
inputs = (torch.randn(3),)
|
||||
# backend = EagerAndRecordGraphs()
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
out = torch.compile(f, backend=backend, fullgraph=True)(inputs)
|
||||
assert len(backend.graphs) == 1
|
||||
gm = backend.graphs[0]
|
||||
|
||||
def test_tp_transform_with_uncovered_op(self):
|
||||
class DummyModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.fc = torch.nn.Linear(3, 5)
|
||||
self.bn = torch.nn.BatchNorm1d(5)
|
||||
|
||||
def forward(self, x):
|
||||
return self.bn(self.fc(x))
|
||||
|
||||
inputs = (torch.randn(7, 3, requires_grad=False),)
|
||||
model = DummyModel()
|
||||
res = model(*inputs)
|
||||
exported_program = torch.export.export(
|
||||
model, inputs, strict=True
|
||||
).run_decompositions()
|
||||
tp_res = exported_program.module()(*inputs)
|
||||
self.assertEqual(res, tp_res)
|
||||
# Expect all_gather to be inserted to distributed sharded fc resutls
|
||||
|
||||
def test_chebyshev_polynomial(self):
|
||||
out = torch.zeros((2,))
|
||||
|
||||
def f(samp: SampleInput):
|
||||
torch.special.chebyshev_polynomial_u(*samp.args, samp.input, out=out)
|
||||
return out
|
||||
|
||||
n = -5.115719318389893
|
||||
inp = (torch.tensor(1.0961),)
|
||||
samp = SampleInput(
|
||||
n,
|
||||
args=inp,
|
||||
kwargs={},
|
||||
)
|
||||
# backend = EagerAndRecordGraphs()
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
compiled_f = torch.compile(f, backend=backend, fullgraph=True)
|
||||
out = compiled_f(samp)
|
||||
plain = f(samp)
|
||||
print(f"Out: {out} vs Plain: {plain}")
|
||||
self.assertEqual(out, plain)
|
||||
|
||||
def test_baddbmm(self):
|
||||
aten = torch.ops.aten
|
||||
|
||||
def fn(a, b, c, beta):
|
||||
return aten.baddbmm(a, b, c, beta=beta)
|
||||
|
||||
compiled_fn = torch.compile(fn, dynamic=True)
|
||||
a = torch.randn(6, 1, 100)
|
||||
b = torch.randn(6, 128, 64)
|
||||
c = torch.randn(6, 64, 100)
|
||||
self.assertEqual(compiled_fn(a, b, c, 0.0), fn(a, b, c, 0.0))
|
||||
self.assertEqual(compiled_fn(a, b, c, 1.0), fn(a, b, c, 1.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
@ -57,11 +57,11 @@ class ComptimeTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertExpectedInline(
|
||||
FILE.getvalue().strip(),
|
||||
"""\
|
||||
FakeTensor(..., size=(s77,))
|
||||
FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=(s77,))))
|
||||
2
|
||||
[FakeTensor(..., size=(s77,)), 2]
|
||||
(FakeTensor(..., size=(s77,)), 2)
|
||||
{'foo': FakeTensor(..., size=(s77,))}
|
||||
[FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=(s77,)))), 2]
|
||||
(FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=(s77,)))), 2)
|
||||
{'foo': FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=(s77,))))}
|
||||
range(1, 3, 1)
|
||||
Employee(name='foo', id=2)
|
||||
UserDefinedListVariable(mylist)
|
||||
@ -160,7 +160,7 @@ def forward(self, L_x_ : torch.Tensor):
|
||||
self.assertExpectedInline(
|
||||
FILE.getvalue(),
|
||||
"""\
|
||||
- FakeTensor(..., size=(2,))
|
||||
- FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=(2,))))
|
||||
""",
|
||||
)
|
||||
|
||||
@ -186,8 +186,8 @@ def forward(self, L_x_ : torch.Tensor):
|
||||
self.assertExpectedInline(
|
||||
FILE.getvalue(),
|
||||
"""\
|
||||
x = FakeTensor(..., size=(2,))
|
||||
y = FakeTensor(..., size=(2,))
|
||||
x = FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=(2,))))
|
||||
y = FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=(2,))))
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
@ -2496,6 +2496,9 @@ def forward(self, primals_1, primals_2):
|
||||
torch.ones(3, 3, requires_grad=True),
|
||||
torch.ones(3, 3, requires_grad=True),
|
||||
]
|
||||
compiled_f = torch.compile(f)
|
||||
actual = compiled_f(*inp_grad)
|
||||
self.assertEqual(actual, f(*inp_grad))
|
||||
self.verify_aot_autograd(f, inp_grad, test_mutation=True)
|
||||
|
||||
def test_backward_mutation_metadata(self):
|
||||
|
||||
@ -6035,6 +6035,8 @@ def forward(self, x_1):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
torch._enable_functionalization(reapply_views=False)
|
||||
# cpp(fake(rv))
|
||||
# Python(cpp(cpp(fake(rv))))
|
||||
try:
|
||||
func_args = pytree.tree_map(
|
||||
lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
|
||||
|
||||
@ -280,11 +280,15 @@ class FakeTensorTest(TestCase):
|
||||
b = _add_batch_dim(x, 0, 0)
|
||||
mode = FakeTensorMode()
|
||||
fake_b = mode.from_tensor(b)
|
||||
print(fake_b)
|
||||
prims.utils.compare_tensor_meta(b, fake_b, check_strides=True)
|
||||
|
||||
b1 = _add_batch_dim(x, 1, 1)
|
||||
b2 = _add_batch_dim(b1, 0, 2)
|
||||
fake_b2 = mode.from_tensor(b2)
|
||||
print(fake_b2)
|
||||
print(is_batchedtensor(fake_b2))
|
||||
print(is_batchedtensor(fake_b))
|
||||
prims.utils.compare_tensor_meta(b2, fake_b2, check_strides=True)
|
||||
self.assertTrue(is_batchedtensor(fake_b2))
|
||||
fake_b1 = get_unwrapped(fake_b2)
|
||||
@ -700,7 +704,12 @@ class FakeTensorTest(TestCase):
|
||||
fake_x_view = mode.from_tensor(x_view)
|
||||
fake_x = mode.from_tensor(x)
|
||||
self.assertFalse(isinstance(fake_x_view, torch.nn.Parameter))
|
||||
self.assertTrue(isinstance(fake_x, torch.nn.Parameter))
|
||||
# fake_x = from_fun(fake_x)
|
||||
# from torch._subclasses.functional_tensor import FunctionalTensor
|
||||
|
||||
# fake_x = FunctionalTensor.from_functional(fake_x)
|
||||
# print(f"Fake tensor is {fake_x}, is instance? {isinstance(fake_x, torch.nn.Parameter)}? type is: {type(fake_x)}")
|
||||
self.assertTrue(isinstance(fake_x, torch.nn.Parameter)) # <- Failing
|
||||
|
||||
def test_tolist(self):
|
||||
shape_env = ShapeEnv()
|
||||
|
||||
@ -46,6 +46,7 @@ import torch.utils._pytree as pytree
|
||||
from torch import fx, Tensor
|
||||
from torch._C._dynamo import guards
|
||||
from torch._dynamo.exc import ShortenTraceback, TensorifyScalarRestartAnalysis
|
||||
from torch._functorch._aot_autograd.functional_utils import from_fun
|
||||
from torch._guards import (
|
||||
CompileContext,
|
||||
CompileId,
|
||||
@ -55,6 +56,7 @@ from torch._guards import (
|
||||
TracingContext,
|
||||
)
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
|
||||
from torch._utils_internal import signpost_event
|
||||
from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined]
|
||||
from torch.fx.experimental._backward_state import BackwardState
|
||||
@ -462,7 +464,12 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
allow_non_fake_inputs=True if self.export else False,
|
||||
export=self.export,
|
||||
)
|
||||
self.tracing_context: TracingContext = TracingContext(fake_mode)
|
||||
functional_mode = FunctionalTensorMode(
|
||||
export=self.export, _allow_token_discovery=True
|
||||
)
|
||||
self.tracing_context: TracingContext = TracingContext(
|
||||
fake_mode, functional_mode
|
||||
)
|
||||
self.tracing_context.traced_code.append(f_code)
|
||||
self.dynamo_compile_id: Optional[CompileId] = (
|
||||
CompileContext.current_compile_id()
|
||||
@ -697,6 +704,7 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
"""
|
||||
call fn(*args) before the graph runs and turn the result into a fake input.
|
||||
"""
|
||||
# NOTE: Where does this get used???
|
||||
example_value = fn(*args)
|
||||
varname = self.new_var()
|
||||
cg = PyCodegen(self.root_tx)
|
||||
@ -807,6 +815,10 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
def fake_mode(self):
|
||||
return self.tracing_context.fake_mode
|
||||
|
||||
@property
|
||||
def functional_mode(self):
|
||||
return self.tracing_context.functional_mode
|
||||
|
||||
@property
|
||||
def shape_env(self):
|
||||
return self.tracing_context.fake_mode.shape_env
|
||||
@ -923,6 +935,7 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
# than just nn module objects, fix that.
|
||||
self.nn_modules[attr_name] = attr_value
|
||||
proxy = self.create_proxy("get_attr", attr_name, (), {})
|
||||
# NOTE: Figure where this gets used
|
||||
set_example_value(proxy.node, attr_value)
|
||||
return proxy
|
||||
|
||||
@ -1691,18 +1704,18 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
)
|
||||
self.call_cleanup_hooks()
|
||||
old_fake_mode = self.tracing_context.fake_mode
|
||||
if not self.export:
|
||||
import torch._functorch.config as _config
|
||||
# if not self.export:
|
||||
# import torch._functorch.config as _config
|
||||
|
||||
with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
|
||||
# TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
|
||||
backend_fake_mode = torch._subclasses.FakeTensorMode(
|
||||
shape_env=old_fake_mode.shape_env,
|
||||
)
|
||||
# TODO(voz): Ostensibily, this should be scoped and
|
||||
# restore back to old_fake_mode, but doing so currently violates
|
||||
# a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
|
||||
self.tracing_context.fake_mode = backend_fake_mode
|
||||
# with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
|
||||
# # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
|
||||
# backend_fake_mode = torch._subclasses.FakeTensorMode(
|
||||
# shape_env=old_fake_mode.shape_env,
|
||||
# )
|
||||
# # TODO(voz): Ostensibily, this should be scoped and
|
||||
# # restore back to old_fake_mode, but doing so currently violates
|
||||
# # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
|
||||
# self.tracing_context.fake_mode = backend_fake_mode
|
||||
|
||||
with self.restore_global_state():
|
||||
compiled_fn = self.call_user_compiler(gm, self.example_inputs())
|
||||
@ -2093,6 +2106,10 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
|
||||
for node in self.graph.nodes:
|
||||
example_value = node.meta.get("example_value")
|
||||
# May need to unwrap functional tensor if used
|
||||
if isinstance(example_value, FunctionalTensor):
|
||||
example_value = from_fun(example_value)
|
||||
|
||||
if (
|
||||
isinstance(example_value, FakeTensor)
|
||||
and example_value.item_memo is not None
|
||||
@ -2103,6 +2120,7 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
example_value.item_memo.node._expr.name
|
||||
)
|
||||
):
|
||||
# Likely never hit but should be!
|
||||
for u in list(node.users):
|
||||
u.replace_all_uses_with(guard_scalar(example_value.item_memo))
|
||||
self.remove_node(u)
|
||||
@ -2476,7 +2494,6 @@ class SubgraphTracer(fx.Tracer):
|
||||
|
||||
trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
|
||||
self.prev_inst = cur_inst
|
||||
|
||||
# update reference to original meta if we're tracing a new code object
|
||||
is_retracing = False
|
||||
if tx.f_code is not self._cur_code:
|
||||
@ -2524,7 +2541,6 @@ class SubgraphTracer(fx.Tracer):
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
self._maybe_preserve_original_meta(tx, rv.node)
|
||||
|
||||
if not is_retracing:
|
||||
@ -2805,6 +2821,7 @@ class SubgraphTracer(fx.Tracer):
|
||||
def track_unbacked_symbols(
|
||||
self, example_value, e_proxy: Union[LazyProxy, torch.fx.Proxy]
|
||||
):
|
||||
# breakpoint()
|
||||
# When binding the symbols in an exmaple_value, we bind the symbols
|
||||
# to the proxy's associated Tracer instead of current tracer.
|
||||
# This is because:
|
||||
@ -2837,6 +2854,7 @@ class SubgraphTracer(fx.Tracer):
|
||||
return proxy
|
||||
|
||||
if isinstance(example_value, torch.Tensor):
|
||||
example_value = from_fun(example_value)
|
||||
for i, s in enumerate(example_value.size()):
|
||||
if need_bind(s):
|
||||
log.debug(
|
||||
|
||||
@ -3164,6 +3164,10 @@ class InstructionTranslatorBase(
|
||||
def fake_mode(self):
|
||||
return self.output.tracing_context.fake_mode
|
||||
|
||||
@property
|
||||
def functional_mode(self):
|
||||
return self.output.tracing_context.functional_mode
|
||||
|
||||
@contextlib.contextmanager
|
||||
def strict_translation_mode(self, check_fn: Callable[[VariableTracker], bool]):
|
||||
"""
|
||||
@ -4035,6 +4039,10 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
def fake_mode(self):
|
||||
return self.parent.fake_mode
|
||||
|
||||
@property
|
||||
def functional_mode(self):
|
||||
return self.parent.functional_mode
|
||||
|
||||
def run_ctx_mgr(self):
|
||||
return TracingContext.current_frame(self.parent.frame_summary())
|
||||
|
||||
|
||||
@ -2791,6 +2791,13 @@ def deepcopy_to_fake_tensor(obj, fake_mode):
|
||||
return wrap_fake_exception(lambda: copy.deepcopy(obj))
|
||||
|
||||
|
||||
def deepcopy_to_functional_tensor(obj, fake_mode, functional_mode):
|
||||
with torch._subclasses.functional_tensor.FunctionalCopyMode(
|
||||
fake_mode, functional_mode
|
||||
):
|
||||
return wrap_fake_exception(lambda: copy.deepcopy(obj))
|
||||
|
||||
|
||||
def rmse(ref, res):
|
||||
"""
|
||||
Calculate root mean squared error
|
||||
@ -3241,9 +3248,12 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
|
||||
id_to_initial_version = {}
|
||||
|
||||
nnmodule = None
|
||||
# CURRENT ISSUE: deepcopy_to_fake_tensor doesn't work with FunctionalMode
|
||||
# without this mode though, we don't have parameters of proper tensor types
|
||||
if op == "call_method" and len(args) > 0 and isinstance(args[0], torch.nn.Module):
|
||||
# If the first argument is nn.Module, should copy to fake mode.
|
||||
args = (deepcopy_to_fake_tensor(args[0], tx.fake_mode),) + tuple(args[1:])
|
||||
copied = deepcopy_to_fake_tensor(args[0], tx.fake_mode)
|
||||
args = (copied,) + tuple(args[1:])
|
||||
|
||||
if op == "call_module":
|
||||
nnmodule = tx.output.nn_modules[node.target]
|
||||
@ -3256,7 +3266,12 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
|
||||
nnmodule._infer_parameters(nnmodule, args)
|
||||
|
||||
# no matter it's lazy module or not, we should copy to fake mode.
|
||||
nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode)
|
||||
# Can make FunctionalMode like Fake one and handle
|
||||
# or wrap the resulting deep copy
|
||||
# just disable functional mode first
|
||||
nnmodule = deepcopy_to_functional_tensor(
|
||||
nnmodule, tx.fake_mode, tx.functional_mode
|
||||
)
|
||||
|
||||
if node.name in ["interpolate", "is_integer", "wrapped_gradient"] or any(
|
||||
isinstance(a, complex) for a in args
|
||||
@ -3270,7 +3285,7 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
|
||||
)
|
||||
|
||||
try:
|
||||
with tx.fake_mode, enable_python_dispatcher():
|
||||
with tx.functional_mode, tx.fake_mode, enable_python_dispatcher():
|
||||
ret_val = wrap_fake_exception(
|
||||
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
|
||||
)
|
||||
@ -3426,7 +3441,6 @@ def run_node(tracer, node, args, kwargs, nnmodule):
|
||||
raise an AssertionError.
|
||||
"""
|
||||
op = node.op
|
||||
|
||||
with set_current_node(node):
|
||||
|
||||
def make_error_message(e):
|
||||
@ -3453,7 +3467,8 @@ def run_node(tracer, node, args, kwargs, nnmodule):
|
||||
return getattr(args[0], node.target)(*args[1:], **kwargs)
|
||||
elif op == "call_module":
|
||||
assert nnmodule is not None
|
||||
return nnmodule(*args, **kwargs)
|
||||
with tracer.functional_mode:
|
||||
return nnmodule(*args, **kwargs)
|
||||
elif op == "get_attr":
|
||||
return tracer.output_graph.get_submodule(node.target)
|
||||
elif op == "placeholder":
|
||||
|
||||
@ -51,6 +51,7 @@ from torch._dynamo.utils import (
|
||||
is_torch_sym,
|
||||
set_feature_use,
|
||||
)
|
||||
from torch._functorch._aot_autograd.functional_utils import to_fun
|
||||
from torch._guards import TracingContext
|
||||
from torch._higher_order_ops.torchbind import call_torchbind
|
||||
from torch._ops import HigherOrderOperator
|
||||
@ -450,7 +451,9 @@ class VariableBuilder:
|
||||
and value not in self.tx.output.side_effects
|
||||
and not is_wrapper_or_member_descriptor(value)
|
||||
):
|
||||
vt = self.tx.output.side_effects.track_object_existing(value, vt)
|
||||
vt = self.tx.output.side_effects.track_object_existing(
|
||||
value, vt
|
||||
) # NOTE: Shouldn't this track the mutation?
|
||||
|
||||
self.tx.output.variable_tracker_cache.add(value, self.source, vt)
|
||||
return vt
|
||||
@ -2085,11 +2088,21 @@ class VariableBuilder:
|
||||
# then the relevant SubgraphTracer will lift it to being an input of
|
||||
# the subgraph.
|
||||
# See NOTE [HigherOrderOperator tracing design] for more details.
|
||||
|
||||
# Need to handle weird edge case where value is already a functional tensor...
|
||||
# torch._disable_functionalization()
|
||||
# if torch._is_functional_tensor(value):
|
||||
# val = torch._from_functional_tensor(value)
|
||||
# else:
|
||||
# val = value
|
||||
# breakpoint()
|
||||
example_value = wrap_to_fake_tensor_and_record(
|
||||
value, tx=self.tx, is_tensor=True, source=source
|
||||
)
|
||||
|
||||
if not torch._is_functional_tensor(example_value):
|
||||
with self.tx.functional_mode:
|
||||
example_value = to_fun(example_value)
|
||||
# NOTE: figure out what reapply_views is from keyset
|
||||
# torch._enable_functionalization(reapply_views=False)
|
||||
tensor_proxy = self.tx.output.root_tracer.create_graph_input(
|
||||
re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
|
||||
type(value),
|
||||
@ -2203,12 +2216,15 @@ class VariableBuilder:
|
||||
# that there's not another great way to do this atm.
|
||||
# This creates the right graphargs, as well as registration for guards in tensor names and shape env.
|
||||
LazyVariableTracker.realize_all(VariableBuilder(self.tx, source)(tensor_value))
|
||||
example_value = wrap_to_fake_tensor_and_record(
|
||||
tensor_value,
|
||||
tx=self.tx,
|
||||
is_tensor=False,
|
||||
source=source,
|
||||
)
|
||||
with self.tx.functional_mode:
|
||||
example_value = to_fun(
|
||||
wrap_to_fake_tensor_and_record(
|
||||
tensor_value,
|
||||
tx=self.tx,
|
||||
is_tensor=False,
|
||||
source=source,
|
||||
)
|
||||
)
|
||||
proxy = self.tx.output.root_tracer.create_graph_input(
|
||||
re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
|
||||
type(tensor_value),
|
||||
@ -2435,9 +2451,12 @@ class VariableBuilder:
|
||||
|
||||
# TODO: Maybe the tensor-ification should be built into the source,
|
||||
# rather than by special pattern match
|
||||
example_value = wrap_to_fake_tensor_and_record(
|
||||
wrapped_value, tx=self.tx, is_tensor=False, source=source
|
||||
)
|
||||
with self.tx.functional_mode:
|
||||
example_value = to_fun(
|
||||
wrap_to_fake_tensor_and_record(
|
||||
wrapped_value, tx=self.tx, is_tensor=False, source=source
|
||||
)
|
||||
)
|
||||
proxy = self.tx.output.root_tracer.create_graph_input(
|
||||
re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
|
||||
type(wrapped_value),
|
||||
@ -2465,7 +2484,7 @@ class VariableBuilder:
|
||||
assert is_fake(example_value)
|
||||
|
||||
fake_tensor_value = example_value
|
||||
assert fake_tensor_value.fake_mode is self.tx.fake_mode, (
|
||||
assert maybe_get_fake_mode(fake_tensor_value) is self.tx.fake_mode, (
|
||||
f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode"
|
||||
"({self.tx.fake_mode}) from InstructionTranslator"
|
||||
)
|
||||
@ -2778,7 +2797,8 @@ def _wrap_fx_proxy(
|
||||
# with preserve_rng_state():
|
||||
# only allow_non_graph_fake in this instance because we handle the non-fake
|
||||
# cases properly below.
|
||||
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
|
||||
with tx.functional_mode:
|
||||
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
|
||||
|
||||
return handle_traced_output(
|
||||
example_value, tx, proxy, options, subclass_type, target_cls
|
||||
@ -2999,8 +3019,10 @@ def construct_tensor_variable(
|
||||
# TODO: not sure about this fake mode test
|
||||
if (
|
||||
isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor)
|
||||
and example_value.fake_mode is tx.fake_mode
|
||||
):
|
||||
or isinstance(
|
||||
example_value, torch._subclasses.functional_tensor.FunctionalTensor
|
||||
)
|
||||
) and maybe_get_fake_mode(example_value) is tx.fake_mode:
|
||||
if subclass_type:
|
||||
tensor_type = subclass_type
|
||||
elif isinstance(example_value, torch.nn.Parameter):
|
||||
|
||||
@ -1564,7 +1564,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||
from torch._higher_order_ops.utils import _maybe_fake_tracing
|
||||
from torch._inductor.utils import is_pointwise_use
|
||||
|
||||
with tx.fake_mode:
|
||||
with tx.fake_mode, tx.functional_mode:
|
||||
sub_args_fake = [
|
||||
leaf.node.meta["example_value"].clone()
|
||||
if hasattr(leaf.node.meta["example_value"], "clone")
|
||||
@ -1601,7 +1601,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||
additional_inputs_proxy,
|
||||
)
|
||||
|
||||
with tx.fake_mode:
|
||||
with tx.fake_mode, tx.functional_mode:
|
||||
out_meta = tuple(
|
||||
inp_proxy.node.meta["example_value"].clone() for inp_proxy in xs_proxy
|
||||
)
|
||||
@ -1804,7 +1804,7 @@ class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||
additional_inputs_proxy,
|
||||
)
|
||||
|
||||
with tx.fake_mode:
|
||||
with tx.fake_mode, tx.functional_mode:
|
||||
example_carry = [
|
||||
init_p.node.meta["example_value"].clone() for init_p in init_proxy
|
||||
]
|
||||
@ -3037,6 +3037,10 @@ class AutogradFunctionApplyVariable(VariableTracker):
|
||||
"torch._dynamo.config._autograd_backward_strict_mode_conditional_banned_ops",
|
||||
[],
|
||||
):
|
||||
# Could be a context manager to turn off counter increment
|
||||
# But we don't want to record the mutations inside speculate_subgraph
|
||||
# NOTE: May need special attention to Triton Hop (possibly not but maybe)
|
||||
# May need to also ignore effectful python ops
|
||||
(bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph(
|
||||
tx,
|
||||
bwd_fn,
|
||||
@ -3199,7 +3203,7 @@ class AutogradFunctionApplyVariable(VariableTracker):
|
||||
# (e.g, tensor.requires_grad), which would be used by downstream Dynamo tracing.
|
||||
# Since there can be other ops like Triton kernels, which depends on python dispatcher, we have to enable it.
|
||||
with enable_python_dispatcher():
|
||||
with tx.output.fake_mode:
|
||||
with tx.output.fake_mode, tx.output.functional_mode:
|
||||
fake_args = (
|
||||
tx.output.nn_modules[fwd_node.node.name],
|
||||
tx.output.nn_modules[bwd_node.node.name],
|
||||
|
||||
@ -23,6 +23,7 @@ from typing import Optional, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch._functorch._aot_autograd.functional_utils import from_fun
|
||||
|
||||
from .. import graph_break_hints, polyfills, variables
|
||||
from ..bytecode_transformation import create_call_function, create_instruction
|
||||
@ -149,6 +150,7 @@ class BaseListVariable(VariableTracker):
|
||||
assert not kwargs and len(args) == 1
|
||||
if isinstance(args[0], TensorVariable):
|
||||
value = get_fake_value(args[0].as_proxy().node, tx)
|
||||
value = from_fun(value)
|
||||
if value.constant is not None and value.constant.numel() == 1:
|
||||
value = variables.ConstantVariable.create(value.constant.item())
|
||||
else:
|
||||
|
||||
@ -694,6 +694,7 @@ class TensorVariable(VariableTracker):
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
# breakpoint() # NOTE: set_item gets in here
|
||||
result = handler_method(*args, **kwargs)
|
||||
if result:
|
||||
return result
|
||||
@ -975,12 +976,15 @@ class TensorVariable(VariableTracker):
|
||||
return wrap(tensor, sub_proxy)
|
||||
|
||||
if tensor.dim() == 1:
|
||||
return [wrap(val, sub_proxy[i]) for i, val in enumerate(tensor)]
|
||||
|
||||
return [
|
||||
tolist(sub_tensor, sub_proxy=sub_proxy[i])
|
||||
for i, sub_tensor in enumerate(tensor)
|
||||
]
|
||||
# Example value may be functional tensor, in which case,
|
||||
# enumerate() calls, so we need to wrap in functional mode
|
||||
with tx.functional_mode:
|
||||
return [wrap(val, sub_proxy[i]) for i, val in enumerate(tensor)]
|
||||
with tx.functional_mode:
|
||||
return [
|
||||
tolist(sub_tensor, sub_proxy=sub_proxy[i])
|
||||
for i, sub_tensor in enumerate(tensor)
|
||||
]
|
||||
|
||||
tensor = self.as_proxy().node.meta["example_value"]
|
||||
out = tolist(tensor, self.as_proxy())
|
||||
@ -1082,13 +1086,15 @@ class TensorVariable(VariableTracker):
|
||||
def method___setitem__(self, key, value):
|
||||
from ..symbolic_convert import InstructionTranslator
|
||||
|
||||
# breakpoint()
|
||||
tx = InstructionTranslator.current_tx()
|
||||
proxy = tx.output.create_proxy(
|
||||
"call_function",
|
||||
operator.setitem,
|
||||
*proxy_args_kwargs([self, key, value], {}),
|
||||
)
|
||||
|
||||
# Run func on fake tensor - needed to set state on functional
|
||||
get_fake_value(proxy.node, tx)
|
||||
if config.use_graph_deduplication or config.track_nodes_for_deduplication:
|
||||
tx.output.region_tracker.add_node_mutation(proxy.node, 0)
|
||||
|
||||
|
||||
@ -592,6 +592,8 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
||||
# This simulates shallow-copying the tensor object.
|
||||
kwargs = dict(tensor_var.__dict__)
|
||||
input_tensor_type = kwargs.pop("class_type")
|
||||
# TODO: Figure out why this gets called with FunctionalTensor
|
||||
# and resolve!
|
||||
assert input_tensor_type in (torch.Tensor, torch.nn.Parameter), (
|
||||
f"invalid class type {input_tensor_type} in TensorWithTFOverrideVariable.from_tensor_var"
|
||||
)
|
||||
|
||||
@ -152,10 +152,11 @@ def run_functionalized_fw_and_collect_metadata(
|
||||
is_train: bool = False,
|
||||
# Note: this is guaranteed to be set when running under dynamo
|
||||
static_input_indices: Optional[list[int]] = None,
|
||||
pre_dispatch: bool = False,
|
||||
# is_export is technically only needed to avoid using functionalization V2
|
||||
# during analysis
|
||||
is_export: bool = False,
|
||||
gm: Optional[torch.fx.GraphModule] = None,
|
||||
functional_mode: Optional[FunctionalTensorMode] = None,
|
||||
) -> Callable[..., ViewAndMutationMeta]:
|
||||
memo: dict[Tensor, Tensor] = {}
|
||||
|
||||
@ -171,6 +172,21 @@ def run_functionalized_fw_and_collect_metadata(
|
||||
|
||||
@wraps(f)
|
||||
def inner(*flat_args):
|
||||
nonlocal static_input_indices
|
||||
view_and_mutation_meta_no_gm = None
|
||||
# Short circut and don't run full thing if freezing, since the input placeholders won't be
|
||||
if gm is not None:
|
||||
# Hacky work around - we want to check against functionalized, so let us call
|
||||
# into this for now to verify correcntess
|
||||
to_call = run_functionalized_fw_and_collect_metadata(
|
||||
f,
|
||||
keep_input_mutations=keep_input_mutations,
|
||||
is_train=is_train,
|
||||
static_input_indices=static_input_indices,
|
||||
is_export=is_export,
|
||||
gm=None,
|
||||
)
|
||||
view_and_mutation_meta_no_gm = to_call(*flat_args)
|
||||
# This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args.
|
||||
assert all(isinstance(a, tuple(KNOWN_TYPES)) for a in flat_args)
|
||||
|
||||
@ -189,20 +205,66 @@ def run_functionalized_fw_and_collect_metadata(
|
||||
# only for figuring out metadata
|
||||
mode = FunctionalTensorMode(_allow_token_discovery=True, export=is_export)
|
||||
suppress_pending = contextlib.nullcontext()
|
||||
|
||||
fake_mode = detect_fake_mode()
|
||||
if fake_mode and (shape_env := fake_mode.shape_env):
|
||||
suppress_pending = shape_env.ignore_fresh_unbacked_symbols()
|
||||
with disable_above, mode, suppress_pending:
|
||||
# precondition: The passed in function already handles unflattening inputs + flattening outputs
|
||||
flat_f_args = pytree.tree_map(_to_fun, flat_args)
|
||||
flat_f_outs = f(*flat_f_args)
|
||||
# We didn't do any tracing, so we don't need to process the
|
||||
# unbacked symbols, they will just disappear into the ether.
|
||||
# Also, prevent memoization from applying.
|
||||
if fake_mode:
|
||||
fake_mode.epoch += 1
|
||||
fake_mode.reset_nt_tensor_id_counter()
|
||||
if gm is None:
|
||||
# NOTE: If we don't have fx module (assume no dynamo),
|
||||
# need to get functional tensors and do call
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
with disable_above, mode, suppress_pending:
|
||||
# precondition: The passed in function already handles unflattening inputs + flattening outputs
|
||||
flat_f_args = pytree.tree_map(_to_fun, flat_args)
|
||||
flat_f_outs = f(*flat_f_args)
|
||||
# We didn't do any tracing, so we don't need to process the
|
||||
# unbacked symbols, they will just disappear into the ether.
|
||||
# Also, prevent memoization from applying.
|
||||
if fake_mode:
|
||||
fake_mode.epoch += 1
|
||||
fake_mode.reset_nt_tensor_id_counter()
|
||||
|
||||
elapsed_time_ms = (time.time() - start_time) * 1000
|
||||
log.debug(
|
||||
f"Time taken for functionalized execution: {elapsed_time_ms:.2f} ms"
|
||||
)
|
||||
else:
|
||||
# NOTE: This makes a strong assumption that the order of f_args is analogous to args
|
||||
# this may not be true...
|
||||
|
||||
import time
|
||||
|
||||
if functional_mode:
|
||||
mode = functional_mode
|
||||
else:
|
||||
raise RuntimeError("Expected functional mode to be set")
|
||||
# All `example_value` tensors on the nodes are already wrapped in
|
||||
# functional tensor, so we don't need to wrap them again, only extract them
|
||||
start_time = time.time()
|
||||
flat_f_args = tuple(
|
||||
node.meta["example_value"]
|
||||
if hasattr(node, "meta") and "example_value" in node.meta
|
||||
else None
|
||||
for node in gm.graph.find_nodes(op="placeholder")
|
||||
)
|
||||
# breakpoint()
|
||||
# flat_f_args = flat_f_args_maybe_tensors
|
||||
# The output nodes are the last nodes (args) into the singular output node
|
||||
output_args = gm.graph.find_nodes(op="output")[0].args
|
||||
output_nodes = output_args[0] if len(output_args) > 0 else []
|
||||
flat_f_outs = tuple(
|
||||
node.meta["example_value"]
|
||||
if hasattr(node, "meta") and "example_value" in node.meta
|
||||
else None
|
||||
for node in output_nodes
|
||||
)
|
||||
elapsed_time_ms = (time.time() - start_time) * 1000
|
||||
log.debug(f"Time taken: {elapsed_time_ms:.2f} ms from graph")
|
||||
# print(flat_f_args)
|
||||
# print(flat_f_outs)
|
||||
if prior_autocast_states != _get_autocast_states():
|
||||
raise RuntimeError(
|
||||
"AOTAutograd does not support tracing graphs that mutate the autocast state. "
|
||||
@ -213,11 +275,13 @@ def run_functionalized_fw_and_collect_metadata(
|
||||
|
||||
# Inspect the state of the input tensor functional wrapper to detect input mutation info
|
||||
# If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version
|
||||
# import pdb; pdb.set_trace()s
|
||||
for i, (arg, f_arg) in enumerate(zip(flat_args, flat_f_args)):
|
||||
# NB: Mutation of non-contiguous tensor subclass input can result in a mismatch in
|
||||
# strides between the functionalized arg inner tensors and non-functionalized arg inner
|
||||
# tensors. This is a problem as the inner tensor stride change may not be reflected
|
||||
# correctly in the outer tensor, so disallow this for now.
|
||||
# breakpoint()
|
||||
mutates_data = has_data_mutation(f_arg)
|
||||
mutates_metadata = has_metadata_mutation(
|
||||
f_arg, arg, check_only_storage_mutation=False
|
||||
@ -738,7 +802,6 @@ from a multi-output view call"
|
||||
coerce_tangent_and_suggest_memory_format(tt)[0]
|
||||
for i, tt in enumerate(traced_tangents)
|
||||
]
|
||||
nonlocal static_input_indices
|
||||
static_input_indices = static_input_indices or []
|
||||
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
|
||||
passed_indices = set(static_input_indices)
|
||||
@ -808,6 +871,17 @@ from a multi-output view call"
|
||||
static_input_indices=static_input_indices,
|
||||
tokens=mode._tokens,
|
||||
)
|
||||
if view_and_mutation_meta_no_gm is not None:
|
||||
for k, v in view_and_mutation_meta_no_gm.__dict__.items():
|
||||
try:
|
||||
assert (
|
||||
metadata.__dict__[k] == v
|
||||
), f"{k} mismatch: {metadata.__dict__[k]} vs {v}"
|
||||
except RuntimeError as e:
|
||||
# traced tangents results in multiple in nested case so call all
|
||||
log.debug(
|
||||
f"~~~~Got error on KEY: {k}: {e}, \nPlease inspect manually: {metadata.__dict__[k]} vs {v}"
|
||||
)
|
||||
return metadata
|
||||
|
||||
return inner
|
||||
|
||||
@ -39,6 +39,20 @@ def to_fun(t):
|
||||
out = transform_subclass(t, lambda _, inner_t: to_fun(inner_t))
|
||||
torch._mirror_autograd_meta_to(t, out) # type: ignore[attr-defined]
|
||||
return out
|
||||
# NOTE: This may need to be expanded to other nontraceable subclasses too;
|
||||
# not sure how to do this outside of explicit adding
|
||||
# for instance, GradTrackingTensor
|
||||
# alternatively, can move this wrapping into meta_utils.py
|
||||
elif torch._C._functorch.is_batchedtensor(t):
|
||||
# Special case to get to the fake tensor and wrap properly
|
||||
unwrapped = torch._C._functorch.get_unwrapped(t)
|
||||
assert unwrapped is not None
|
||||
unwrapped = to_fun(unwrapped)
|
||||
return torch._C._functorch._add_batch_dim(
|
||||
unwrapped,
|
||||
torch._C._functorch.maybe_get_bdim(t),
|
||||
torch._C._functorch.maybe_get_level(t),
|
||||
)
|
||||
else:
|
||||
return FunctionalTensor.to_functional(t)
|
||||
else:
|
||||
|
||||
@ -956,6 +956,11 @@ class AOTConfig:
|
||||
# Used only by standalone_compile.
|
||||
ignore_shape_env: bool = False
|
||||
precompile_backend_id: Optional[str] = None
|
||||
# For testing gradual transition towards moving metadata colleciton
|
||||
# to Dynamo.
|
||||
gm: Optional[torch.fx.GraphModule] = None
|
||||
# For testing gradual transition towards moving metadata colleciton
|
||||
functional_mode: Optional[Any] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.pre_dispatch:
|
||||
|
||||
@ -682,6 +682,14 @@ def _create_aot_dispatcher_function(
|
||||
dynamo_timed_ctx = dynamo_timed(
|
||||
"aot_collect_metadata", log_pt2_compile_event=True
|
||||
)
|
||||
if aot_config.gm is not None and aot_config.functional_mode is not None:
|
||||
# Need both to be non None else both should be None
|
||||
# this is because functional_mode contains pertinent info to the
|
||||
# FunctionalTensors in gm; if it is not preserved, we can't get this
|
||||
# info properly
|
||||
gm, func_mode = aot_config.gm, aot_config.functional_mode
|
||||
else:
|
||||
gm, func_mode = None, None
|
||||
|
||||
with dynamo_timed_ctx, ctx:
|
||||
fw_metadata = run_functionalized_fw_and_collect_metadata(
|
||||
@ -689,8 +697,9 @@ def _create_aot_dispatcher_function(
|
||||
static_input_indices=aot_config.static_input_indices,
|
||||
keep_input_mutations=aot_config.keep_inference_input_mutations,
|
||||
is_train=needs_autograd,
|
||||
pre_dispatch=aot_config.pre_dispatch,
|
||||
is_export=aot_config.is_export,
|
||||
gm=gm,
|
||||
functional_mode=func_mode,
|
||||
)(*_dup_fake_script_obj(fake_flat_args))
|
||||
|
||||
req_subclass_dispatch = requires_subclass_dispatch(
|
||||
@ -735,7 +744,6 @@ def _create_aot_dispatcher_function(
|
||||
flat_fn,
|
||||
keep_input_mutations=aot_config.keep_inference_input_mutations,
|
||||
is_train=False,
|
||||
pre_dispatch=aot_config.pre_dispatch,
|
||||
static_input_indices=aot_config.static_input_indices,
|
||||
)(*fake_flat_args)
|
||||
else:
|
||||
@ -1172,6 +1180,13 @@ def aot_module_simplified(
|
||||
cache_info=None,
|
||||
ignore_shape_env=ignore_shape_env,
|
||||
precompile_backend_id=getattr(mod, "_backend_id", None),
|
||||
gm=mod
|
||||
if (
|
||||
isinstance(mod, torch.fx.GraphModule)
|
||||
and not torch._inductor.config.freezing
|
||||
)
|
||||
else None,
|
||||
functional_mode=tracing_context.functional_mode if tracing_context else None,
|
||||
)
|
||||
fake_mode, shape_env = construct_fake_mode(full_args, aot_config)
|
||||
fake_flat_args = process_inputs(
|
||||
|
||||
@ -13,7 +13,7 @@ import unittest.mock
|
||||
import weakref
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
@ -823,13 +823,16 @@ class TracingContext:
|
||||
"TracingContext.get() must be called within an ongoing trace."
|
||||
)
|
||||
|
||||
def __init__(self, fake_mode):
|
||||
def __init__(self, fake_mode, functional_mode=None):
|
||||
self.guards_context = GuardsContext()
|
||||
self.module_context = ModuleContext()
|
||||
self.global_context = GlobalContext()
|
||||
self.previously_inlined_functions = dict()
|
||||
self.previously_cleaned_instructions = dict()
|
||||
self.fake_mode = fake_mode
|
||||
self.functional_mode = (
|
||||
functional_mode if functional_mode is not None else nullcontext()
|
||||
)
|
||||
self.frame_summary_stack = []
|
||||
# This is morally part of frame_summary_stack, but it is kept separate
|
||||
# for clarity. As we process a frame, this variable gets updated
|
||||
|
||||
@ -9,6 +9,7 @@ import torch
|
||||
import torch.fx.traceback as fx_traceback
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._dispatch.python import suspend_functionalization
|
||||
from torch._functorch._aot_autograd.functional_utils import from_fun
|
||||
from torch._guards import detect_fake_mode
|
||||
from torch._higher_order_ops.schema import HopSchema
|
||||
from torch._ops import HigherOrderOperator, OperatorBase, OpOverload
|
||||
@ -362,21 +363,23 @@ def _collect_fake_inputs(inputs):
|
||||
if hasattr(inp, "meta"):
|
||||
val = inp.meta["example_value"]
|
||||
if isinstance(val, torch.Tensor):
|
||||
# Check for and unwrap common FakeTensor Wrappers in 'example_value'
|
||||
if torch._C._functorch.is_batchedtensor(
|
||||
val
|
||||
) or torch._C._functorch.is_functionaltensor(val):
|
||||
# This case is for batched or functional tensors
|
||||
# Unwrap the tensors
|
||||
while torch._C._functorch.is_batchedtensor(
|
||||
val
|
||||
) or torch._C._functorch.is_functionaltensor(val):
|
||||
val = torch._C._functorch.get_unwrapped(val)
|
||||
assert isinstance(val, FakeTensor)
|
||||
inputs_fake.append(val)
|
||||
else:
|
||||
# This is the standard case of a TensorVariable
|
||||
assert isinstance(val, FakeTensor)
|
||||
inputs_fake.append(val)
|
||||
elif isinstance(val, FunctionalTensor):
|
||||
# TODO: Figure out why these may be different
|
||||
# than the C subclass semantics laid out above
|
||||
# NOTE: The implementation for `is_functionaltensor`
|
||||
# checks the DispatchKeySet, and it is possible this has
|
||||
# been disabled as is the case for collect_view_and_mutation_metadata
|
||||
val = from_fun(val)
|
||||
assert isinstance(val, FakeTensor)
|
||||
inputs_fake.append(val)
|
||||
else:
|
||||
# This case is for SymInts and other non-Tensor elements
|
||||
assert not isinstance(val, torch.Tensor)
|
||||
|
||||
@ -3,14 +3,17 @@ import contextlib
|
||||
import warnings
|
||||
import weakref
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Callable, cast, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._C import _functionalization_reapply_views_tls as _reapply_views
|
||||
from torch._ops import _get_dispatch_mode_pre_dispatch
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch._subclasses.meta_utils import is_sparse_any
|
||||
from torch.overrides import TorchFunctionMode
|
||||
from torch.utils._python_dispatch import (
|
||||
_detect_infra_mode,
|
||||
_disable_infra_mode,
|
||||
@ -354,6 +357,7 @@ class FunctionalTensorMode(TorchDispatchMode):
|
||||
super().__exit__(a, b, c)
|
||||
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
# breakpoint()
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
@ -568,6 +572,79 @@ def disable_functional_mode():
|
||||
return _disable_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
|
||||
|
||||
|
||||
# TODO: pull these from aot autograd
|
||||
def from_fun(t):
|
||||
if not isinstance(t, FunctionalTensor):
|
||||
# quick sanity assert
|
||||
if isinstance(t, torch.Tensor):
|
||||
assert not torch._is_functional_tensor(t)
|
||||
return t
|
||||
torch._sync(t)
|
||||
return torch._from_functional_tensor(t.elem)
|
||||
|
||||
|
||||
def to_fun(t):
|
||||
if isinstance(t, torch.Tensor):
|
||||
return FunctionalTensor.to_functional(t)
|
||||
return t
|
||||
|
||||
|
||||
# Just for use to allow copying a module to Functional(Fake) tensors,
|
||||
# does not apply elsewhere
|
||||
class FunctionalCopyMode(TorchFunctionMode):
|
||||
def __init__(
|
||||
self, fake_mode: FakeTensorMode, functional_mode: FunctionalTensorMode
|
||||
) -> None:
|
||||
self.fake_mode = fake_mode
|
||||
self.functional_mode = functional_mode
|
||||
|
||||
def __torch_function__(
|
||||
self,
|
||||
func: torch._ops.OpOverload,
|
||||
types: Sequence[type],
|
||||
args: Sequence[object] = (),
|
||||
kwargs: Optional[Mapping[str, object]] = None,
|
||||
) -> FunctionalTensor:
|
||||
# Disable any outer functional modes, as these will try and intercept
|
||||
# the dispatch in `func`; we only want this done afterwards
|
||||
is_func_clone = func == torch._C.TensorBase.clone
|
||||
is_func_deepcopy = func == torch.Tensor.__deepcopy__
|
||||
kwargs = kwargs if kwargs else {}
|
||||
|
||||
if is_func_clone or is_func_deepcopy:
|
||||
# Need to run without functional for copy/clone
|
||||
with disable_functional_mode():
|
||||
# clone will get called in Parameter deepcopy
|
||||
if is_func_clone:
|
||||
assert isinstance(args[0], torch.Tensor)
|
||||
fake_clone = func(
|
||||
self.fake_mode.from_tensor(args[0], static_shapes=True),
|
||||
**kwargs,
|
||||
)
|
||||
elif is_func_deepcopy:
|
||||
assert len(args) == 2 and len(kwargs) == 0
|
||||
tensor = cast(torch.Tensor, args[0])
|
||||
memo = cast(dict[int, FunctionalTensor], args[1])
|
||||
if id(tensor) in memo:
|
||||
return memo[id(tensor)]
|
||||
fake_clone = self.fake_mode.from_tensor(tensor, static_shapes=True)
|
||||
elif func == torch._C.TensorBase.detach:
|
||||
# FunctionalTensor should run without being disabled
|
||||
assert isinstance(args[0], torch.Tensor)
|
||||
return func(*args, **kwargs)
|
||||
else:
|
||||
with disable_functional_mode(), torch._C.DisableTorchFunctionSubclass():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
with self.functional_mode:
|
||||
out = to_fun(fake_clone)
|
||||
if is_func_deepcopy:
|
||||
tensor = cast(torch.Tensor, args[0])
|
||||
memo = cast(dict[int, FunctionalTensor], args[1])
|
||||
memo[id(tensor)] = out
|
||||
return out
|
||||
|
||||
|
||||
# This is similar to torch.func.functionalize, but:
|
||||
# - It uses FunctionalTensorMode, and FunctionalTensor (a python subclass).
|
||||
# One important advantage to using this mode is that it will let us
|
||||
@ -577,21 +654,6 @@ def disable_functional_mode():
|
||||
# functorch transforms, since these transforms always run above __torch_dispatch__.
|
||||
# That's why this util lives here, and not in functorch.
|
||||
def dispatch_functionalize(func, mode: FunctionalTensorMode = FunctionalTensorMode()):
|
||||
# TODO: pull these from aot autograd
|
||||
def to_fun(t):
|
||||
if isinstance(t, torch.Tensor):
|
||||
return FunctionalTensor.to_functional(t)
|
||||
return t
|
||||
|
||||
def from_fun(t):
|
||||
if not isinstance(t, FunctionalTensor):
|
||||
# quick sanity assert
|
||||
if isinstance(t, torch.Tensor):
|
||||
assert not torch._is_functional_tensor(t)
|
||||
return t
|
||||
torch._sync(t)
|
||||
return torch._from_functional_tensor(t.elem)
|
||||
|
||||
def inner(*args, **kwargs):
|
||||
disable_above = torch._C._ExcludeDispatchKeyGuard(
|
||||
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
|
||||
|
||||
@ -1148,7 +1148,13 @@ def _free_unbacked_symbols_with_path(
|
||||
r.update(go(sub, path + (InnerTensorKey(attr),)))
|
||||
elif isinstance(a, torch.Tensor):
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
from torch._subclasses.functional_tensor import FunctionalTensor
|
||||
|
||||
# Note, a may also be functional wrapping a Fake, so check and unwrap if needed
|
||||
if isinstance(a, FunctionalTensor):
|
||||
from torch._functorch._aot_autograd.functional_utils import from_fun
|
||||
|
||||
a = from_fun(a)
|
||||
assert isinstance(a, FakeTensor)
|
||||
r.update(
|
||||
go(
|
||||
|
||||
Reference in New Issue
Block a user