PYFMT lint grandfathered files 1 (#154261)

lint:
- test/test_fake_tensor.py
- test/test_flop_counter.py
- torch/_export/verifier.py

These files are now linted with the same rules as all other files. It was a nightmare to update tests
in one of the previously skipped files while being unable to lint it locally with lintrunner -a the way
other files are linted. Note that these files are under active development; they are not old, untouched files.
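
Most of this diff is mechanical reformatting. PYFMT applies black-style formatting plus usort-style import sorting, so the hunks below can be reproduced locally. A minimal sketch of one such transformation, assuming the black package is installed (not part of this commit):

    import black

    # black normalizes string quotes to double quotes, as in the
    # test_fake_tensor.py kwargs hunk below.
    src = "kwargs = {'s': (3, 4), 'dim': (1, 5), 'norm': 'ortho'}\n"
    print(black.format_str(src, mode=black.Mode()))
    # kwargs = {"s": (3, 4), "dim": (1, 5), "norm": "ortho"}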

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154261
Approved by: https://github.com/angelayi, https://github.com/Skylion007
Author: Laith Sakka
Date: 2025-05-23 15:06:44 -07:00
Committed by: PyTorch MergeBot
Parent: 5677ab9aab
Commit: 43b2716e89

5 changed files with 157 additions and 94 deletions

.lintrunner.toml

@@ -1193,8 +1193,6 @@ exclude_patterns = [
'test/quantization/fx/test_numeric_suite_fx.py',
'test/quantization/fx/test_quantize_fx.py',
'test/quantization/fx/test_subgraph_rewriter.py',
'test/test_fake_tensor.py',
'test/test_flop_counter.py',
'test/test_function_schema.py',
'test/test_functional_autograd_benchmark.py',
'test/test_functional_optim.py',
@@ -1330,7 +1328,6 @@ exclude_patterns = [
'torch/_export/serde/serialize.py',
'torch/_export/serde/upgrade.py',
'torch/_export/trace.py',
'torch/_export/verifier.py',
'torch/testing/_internal/__init__.py',
'torch/testing/_internal/autocast_test_lists.py',
'torch/testing/_internal/autograd_function_db.py',
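
A quick way to confirm the three files are no longer excluded from PYFMT (a hypothetical check, not part of the commit; assumes the standard lintrunner layout of [[linter]] tables and Python 3.11+ for tomllib):

    import tomllib

    with open(".lintrunner.toml", "rb") as f:
        config = tomllib.load(f)

    pyfmt = next(entry for entry in config["linter"] if entry["code"] == "PYFMT")
    for path in (
        "test/test_fake_tensor.py",
        "test/test_flop_counter.py",
        "torch/_export/verifier.py",
    ):
        assert path not in pyfmt.get("exclude_patterns", []), path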

test/distributed/test_c10d_nccl.py

@@ -3731,7 +3731,7 @@ class NcclProcessGroupWithDispatchedCollectivesTests(
@parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
def test_allgather_float8(self, float8_dtype):
device = torch.device(f"cuda:{self.rank:d}")
if not sm_is_or_higher_than(device, 9, 0):
if not sm_is_or_higher_than(device, 9, 0): # noqa: F821
self.skipTest("FP8 reduction support begins with sm90 capable devices")
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
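
The # noqa: F821 added above silences flake8's undefined-name check for sm_is_or_higher_than on that one line. A runnable sketch of the effect (assumes flake8 is installed; the file contents are illustrative):

    import tempfile

    from flake8.api import legacy as flake8

    for src in (
        "x = sm_is_or_higher_than(device, 9, 0)\n",
        "x = sm_is_or_higher_than(device, 9, 0)  # noqa: F821\n",
    ):
        with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
            f.write(src)
        report = flake8.get_style_guide(select=["F821"]).check_files([f.name])
        print(report.total_errors)  # nonzero for the first file, 0 with the noqa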

test/test_fake_tensor.py

@@ -5,23 +5,23 @@
import contextlib
import copy
import dataclasses
import gc
import inspect
import io
import itertools
import pickle
import unittest
import weakref
from unittest.mock import patch
import io
import gc
import numpy as np
import torch
import torch._dynamo
import torch._functorch.config
import torch._prims as prims
import torch.testing._internal.optests as optests
import torch.utils._pytree as pytree
from torch import distributed as dist
from torch._C._functorch import _add_batch_dim, get_unwrapped, is_batchedtensor
from torch._dispatch.python import enable_python_dispatcher
@@ -32,10 +32,10 @@ from torch._subclasses.fake_tensor import (
_CacheKeyState,
DynamicOutputShapeException,
extract_tensor_metadata,
MetadataMismatchError,
FakeTensor,
FakeTensorConverter,
FakeTensorMode,
MetadataMismatchError,
unset_fake_temporarily,
UnsupportedOperatorException,
)
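
The reordering above follows usort's case-insensitive alphabetization within an import block; a minimal sketch of the rule, using the names from the block above (not part of the diff):

    names = [
        "_CacheKeyState",
        "DynamicOutputShapeException",
        "extract_tensor_metadata",
        "MetadataMismatchError",
        "FakeTensor",
        "FakeTensorConverter",
        "FakeTensorMode",
        "unset_fake_temporarily",
        "UnsupportedOperatorException",
    ]
    print(sorted(names, key=str.lower))
    # FakeTensor, FakeTensorConverter, and FakeTensorMode now precede
    # MetadataMismatchError, matching the new import order.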
@@ -56,6 +56,7 @@ from torch.testing._internal.common_device_type import (
OpDTypes,
ops,
)
from torch.testing._internal.common_dtype import all_types_complex_float8_and
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@@ -68,15 +69,14 @@ from torch.testing._internal.common_utils import (
TestCase,
xfailIfTorchDynamo,
)
from torch.testing._internal.common_dtype import all_types_complex_float8_and
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.inductor_utils import GPU_TYPE
from torch.testing._internal.jit_utils import RUN_CUDA
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
aten = torch.ops.aten
torch._dynamo.config.fake_tensor_cache_enabled = True
@@ -977,10 +977,12 @@ class FakeTensorTest(TestCase):
with mode:
x = torch.empty(2, 2, device="cpu", dtype=torch.int32)
from torch._subclasses.fake_impls import get_fast_op_impls
fast_div = get_fast_op_impls()[torch.ops.aten.div.Tensor]
y = fast_div(mode, x, 2)
self.assertEqual(y.dtype, torch.float32)
instantiate_parametrized_tests(FakeTensorTest)
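
For context, FakeTensorMode propagates shapes and dtypes without allocating real storage, which is why the test above expects float32 from true division of an int32 tensor; a minimal standalone sketch (not part of the diff):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode():
        x = torch.empty(2, 2, device="cpu", dtype=torch.int32)
        y = x / 2  # aten.div.Tensor promotes int32 to float32
    print(y.dtype)  # torch.float32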
@@ -1115,7 +1117,9 @@ class FakeTensorOpInfoTest(TestCase):
make_propagate_real_tensors_cls(FakeTensorOpInfoTest)
instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=("cpu", "cuda"))
instantiate_device_type_tests(
PropagateRealTensorsFakeTensorOpInfoTest, globals(), only_for=("cpu",) # noqa: F821
PropagateRealTensorsFakeTensorOpInfoTest, # noqa: F821
globals(),
only_for=("cpu",),
)
@@ -1415,13 +1419,11 @@ class FakeTensorOperatorInvariants(TestCase):
self.assertTrue("output[0]" not in str(e))
if self.__class__.__name__.startswith("PropagateRealTensors"):
self.assertTrue(
"Real tensor propagation found a metadata mismatch"
in str(e)
"Real tensor propagation found a metadata mismatch" in str(e)
)
else:
self.assertTrue(
"found mismatched tensor metadata for output"
in str(e)
"found mismatched tensor metadata for output" in str(e)
)
# IMPORTANT!!! Always run even if CUDA is not available
@@ -1623,61 +1625,74 @@ class FakeTensorPropTest(TestCase):
def test_torch_load_with_fake_mode(self):
model = torch.nn.Linear(5, 10)
sd = model.state_dict()
sd['tt'] = TwoTensor(torch.randn(2), torch.randn(2))
sd["tt"] = TwoTensor(torch.randn(2), torch.randn(2))
def _read_tensor_and_check(key, sd_loaded, all_bytes, device):
dtype = torch.float32
t = sd_loaded[key]
self.assertEqual(t.device.type, device)
if isinstance(t, TwoTensor):
untyped_storage_a, untyped_storage_b = t.a.untyped_storage(), t.b.untyped_storage()
offset_a, offset_b = untyped_storage_a._checkpoint_offset, untyped_storage_b._checkpoint_offset
nbytes_a, nbytes_b = untyped_storage_a.nbytes() // 4, untyped_storage_b.nbytes() // 4
result_a = torch.frombuffer(all_bytes, dtype=dtype, count=nbytes_a, offset=offset_a).resize_(t.a.size())
result_b = torch.frombuffer(all_bytes, dtype=dtype, count=nbytes_b, offset=offset_b).resize_(t.b.size())
untyped_storage_a, untyped_storage_b = (
t.a.untyped_storage(),
t.b.untyped_storage(),
)
offset_a, offset_b = (
untyped_storage_a._checkpoint_offset,
untyped_storage_b._checkpoint_offset,
)
nbytes_a, nbytes_b = (
untyped_storage_a.nbytes() // 4,
untyped_storage_b.nbytes() // 4,
)
result_a = torch.frombuffer(
all_bytes, dtype=dtype, count=nbytes_a, offset=offset_a
).resize_(t.a.size())
result_b = torch.frombuffer(
all_bytes, dtype=dtype, count=nbytes_b, offset=offset_b
).resize_(t.b.size())
self.assertEqual(TwoTensor(result_a, result_b), sd[key])
else:
untyped_storage = t.untyped_storage()
offset = untyped_storage._checkpoint_offset
nbytes = untyped_storage.nbytes() // 4
result = torch.frombuffer(all_bytes, dtype=dtype, count=nbytes, offset=offset).resize_(t.size())
result = torch.frombuffer(
all_bytes, dtype=dtype, count=nbytes, offset=offset
).resize_(t.size())
self.assertEqual(result, sd[key])
with TemporaryFileName() as f, torch.serialization.safe_globals([TwoTensor]):
# Create state_dict to be loaded later
torch.save(sd, f)
with open(f, 'rb') as g:
with open(f, "rb") as g:
all_bytes = g.read()
fake_mode = FakeTensorMode()
with fake_mode:
sd_loaded = torch.load(f)
for k in sd:
_read_tensor_and_check(k, sd_loaded, all_bytes, 'cpu')
_read_tensor_and_check(k, sd_loaded, all_bytes, "cpu")
with fake_mode:
sd_loaded = torch.load(f, map_location="cuda")
for k in sd:
_read_tensor_and_check(k, sd_loaded, all_bytes, 'cuda')
_read_tensor_and_check(k, sd_loaded, all_bytes, "cuda")
for k in sd.keys():
sd[k] = sd[k].to('cuda')
sd[k] = sd[k].to("cuda")
with TemporaryFileName() as f, torch.serialization.safe_globals([TwoTensor]):
torch.save(sd, f)
with open(f, 'rb') as g:
with open(f, "rb") as g:
all_bytes = g.read()
fake_mode = FakeTensorMode()
with fake_mode:
sd_loaded = torch.load(f)
for k in sd:
_read_tensor_and_check(k, sd_loaded, all_bytes, 'cuda')
_read_tensor_and_check(k, sd_loaded, all_bytes, "cuda")
with fake_mode:
sd_loaded = torch.load(f, map_location="cpu")
for k in sd:
_read_tensor_and_check(k, sd_loaded, all_bytes, 'cpu')
_read_tensor_and_check(k, sd_loaded, all_bytes, "cpu")
make_propagate_real_tensors_cls(FakeTensorPropTest)
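
The checks above lean on torch.frombuffer, which reinterprets raw bytes as a tensor without copying; a minimal sketch (not part of the diff):

    import torch

    t = torch.arange(6, dtype=torch.float32)
    raw = bytearray(t.numpy().tobytes())
    # Skip the first float (offset=4 bytes), read four floats, then reshape.
    view = torch.frombuffer(raw, dtype=torch.float32, count=4, offset=4)
    print(view.resize_((2, 2)))  # tensor([[1., 2.], [3., 4.]])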
@@ -1994,9 +2009,9 @@ class FakeTensorDispatchCache(TestCase):
x = torch.randn(s0, s1, s2)
out = torch.randn(s0, s3, s4)
kwargs = {
's': (s3, s4),
'dim': (1, s5),
'norm': 'ortho',
"s": (s3, s4),
"dim": (1, s5),
"norm": "ortho",
}
r = torch._C._fft.fft_hfft2(x, **kwargs, out=out)
self.assertEqual(r.shape, out.shape)
@@ -2074,8 +2089,12 @@ class FakeTensorDispatchCache(TestCase):
def __torch_dispatch__(cls, func, types, args, kwargs):
if kwargs is None:
kwargs = {}
args = pytree.tree_map_only(DifferentDeviceTensor, lambda x: x.inner_tensor, args)
kwargs = pytree.tree_map_only(DifferentDeviceTensor, lambda x: x.inner_tensor, kwargs)
args = pytree.tree_map_only(
DifferentDeviceTensor, lambda x: x.inner_tensor, args
)
kwargs = pytree.tree_map_only(
DifferentDeviceTensor, lambda x: x.inner_tensor, kwargs
)
# Returns unwrapped tensor
return func(*args, **kwargs)
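
pytree.tree_map_only, used above to unwrap the subclass, applies a function only to leaves of the given type and leaves every other leaf untouched; a minimal sketch (not part of the diff):

    import torch.utils._pytree as pytree

    tree = {"a": 1, "b": (2.0, 3)}
    print(pytree.tree_map_only(int, lambda x: x * 10, tree))
    # {'a': 10, 'b': (2.0, 30)}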
@@ -2098,7 +2117,7 @@ class FakeTensorDispatchCache(TestCase):
return torch.nn.functional.interpolate(
x,
size=[256, 256],
mode='bilinear',
mode="bilinear",
align_corners=False,
antialias=True,
)
@@ -2108,8 +2127,13 @@ class FakeTensorDispatchCache(TestCase):
x = fake_m.from_tensor(
torch.randn(1, 3, 2005, 1920, requires_grad=True),
symbolic_context=StatelessSymbolicContext(
dynamic_sizes=[DimDynamic.STATIC, DimDynamic.STATIC, DimDynamic.DYNAMIC, DimDynamic.DYNAMIC],
constraint_sizes=[None, None, None, None]
dynamic_sizes=[
DimDynamic.STATIC,
DimDynamic.STATIC,
DimDynamic.DYNAMIC,
DimDynamic.DYNAMIC,
],
constraint_sizes=[None, None, None, None],
),
)
with fake_m, enable_python_dispatcher():
@@ -2126,14 +2150,14 @@ class FakeTensorDispatchCache(TestCase):
t = torch.ByteTensor(storage)
self.assertTrue(isinstance(t, FakeTensor))
self.assertEqual(t.device, torch.device('cpu'))
self.assertEqual(t.device, torch.device("cpu"))
def test_meta_tensor_to_fake_cpu(self):
x = torch.randn(4, 4, device='meta')
x = torch.randn(4, 4, device="meta")
with FakeTensorMode(allow_non_fake_inputs=True):
x_cpu = x.to(device='cpu')
x_cpu = x.to(device="cpu")
self.assertTrue(isinstance(x_cpu, FakeTensor))
self.assertEqual(x_cpu.device, torch.device('cpu'))
self.assertEqual(x_cpu.device, torch.device("cpu"))
def test_cache_tuple_outputs(self):
"""
@@ -2158,7 +2182,6 @@ class FakeTensorDispatchCache(TestCase):
extract_tensor_metadata(b),
)
def test_cache_aten_index(self):
with FakeTensorMode():
x = torch.randn(4, 4, 4)
@@ -2178,10 +2201,16 @@ class FakeTensorDispatchCache(TestCase):
with FakeTensorMode():
x = torch.randn(4, 4, 4)
idx_tensor1 = torch.tensor([True, True, False, True])
self.assertRaises(DynamicOutputShapeException, lambda: torch.ops.aten.index(x, [None, idx_tensor1]))
self.assertRaises(
DynamicOutputShapeException,
lambda: torch.ops.aten.index(x, [None, idx_tensor1]),
)
idx_tensor1 = torch.tensor([1, -2, 3, -4], dtype=torch.int8)
self.assertRaises(DynamicOutputShapeException, lambda: torch.ops.aten.index(x, [None, idx_tensor1]))
self.assertRaises(
DynamicOutputShapeException,
lambda: torch.ops.aten.index(x, [None, idx_tensor1]),
)
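
The expected DynamicOutputShapeException above arises because boolean masks produce data-dependent output shapes, which a plain FakeTensorMode cannot represent; a minimal sketch (not part of the diff):

    import torch
    from torch._subclasses.fake_tensor import (
        DynamicOutputShapeException,
        FakeTensorMode,
    )

    with FakeTensorMode():
        x = torch.randn(4)
        mask = torch.tensor([True, False, True, False])
        try:
            x[mask]
        except DynamicOutputShapeException:
            print("output shape depends on tensor data")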
@skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching")
def test_invoke_subgraph(self):
@@ -2335,11 +2364,14 @@ class FakeTensorDispatchCache(TestCase):
lengths = torch.tensor([0, 2, 3, 1, 4])
indices = torch.tensor([2, 3, 4, 6, 7, 8, 9])
offsets = torch.cumsum(lengths, 0)
ep = torch.export.export(LengthsGather(), (input, lengths, indices, offsets), strict=False)
ep = torch.export.export(
LengthsGather(), (input, lengths, indices, offsets), strict=False
)
FakeTensorMode.cache_clear()
ep.run_decompositions({})
self.assertBypasses("unrepresented symbol in output", 2)
if __name__ == "__main__":
run_tests()

test/test_flop_counter.py

@@ -8,18 +8,19 @@ import torch.nn.functional as F
import torch.utils.flop_counter
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_CUDNN_ATTENTION,
PLATFORM_SUPPORTS_FLASH_ATTENTION,
PLATFORM_SUPPORTS_FP8,
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
PLATFORM_SUPPORTS_CUDNN_ATTENTION
)
from torch.testing._internal.common_utils import (
run_tests,
skipIfRocm,
TEST_WITH_TORCHDYNAMO,
TestCase,
skipIfRocm,
)
try:
from torchvision import models as torchvision_models
@@ -422,7 +423,12 @@ class TestFlopCounter(TestCase):
run_uniform_flops(backend, with_backward=True)
for backend in ["math", "flash", "mem_efficient", "cudnn"]
]
flops_fw_bw_math, flops_fw_bw_flash, flops_fw_bw_efficient, flops_fw_bw_cudnn = flops
(
flops_fw_bw_math,
flops_fw_bw_flash,
flops_fw_bw_efficient,
flops_fw_bw_cudnn,
) = flops
self.assertEqual(flops_fw_math * 3, flops_fw_bw_math)
self.assertEqual(flops_fw_math * 7 // 2, flops_fw_bw_flash)
self.assertEqual(flops_fw_bw_flash, flops_fw_bw_efficient)
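
FlopCounterMode, exercised throughout this file, counts FLOPs by intercepting dispatched ops; a minimal sketch (not part of the diff):

    import torch
    from torch.utils.flop_counter import FlopCounterMode

    with FlopCounterMode(display=False) as mode:
        torch.mm(torch.randn(64, 128), torch.randn(128, 32))
    print(mode.get_total_flops())  # 2 * 64 * 128 * 32 = 524288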
@@ -705,7 +711,9 @@ class TestFlopCounter(TestCase):
False,
)
dense_x = torch.randn(4, 40, 4, 16, dtype=torch.bfloat16, device="cuda").transpose(1, 2)
dense_x = torch.randn(
4, 40, 4, 16, dtype=torch.bfloat16, device="cuda"
).transpose(1, 2)
with FlopCounterMode() as real_flop_counter_mode:
torch.ops.aten._flash_attention_forward(
@@ -721,8 +729,10 @@ class TestFlopCounter(TestCase):
False,
)
self.assertEqual(int(get_total_flops(fake_flop_counter_mode)), int(get_total_flops(real_flop_counter_mode)))
self.assertEqual(
int(get_total_flops(fake_flop_counter_mode)),
int(get_total_flops(real_flop_counter_mode)),
)
def test_addmm_out(self):
def f(x):
@@ -795,7 +805,9 @@ class TestFlopCounter(TestCase):
called = 0
with self.assertRaisesRegex(ValueError, "expected each target to be OpOverloadPacket"):
with self.assertRaisesRegex(
ValueError, "expected each target to be OpOverloadPacket"
):
register_flop_formula(torch.ops.mylib.foo.default)(lambda x: x)
@register_flop_formula(torch.ops.mylib.foo)
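
register_flop_formula, tested above, attaches a FLOP formula to a custom operator (an OpOverloadPacket, not an individual OpOverload). A hypothetical sketch with a made-up demo_flops namespace (not part of the diff; the registered formula receives tensor shapes rather than tensors):

    import torch
    from torch.library import Library
    from torch.utils.flop_counter import FlopCounterMode, register_flop_formula

    lib = Library("demo_flops", "DEF")  # hypothetical test-only namespace
    lib.define("double(Tensor x) -> Tensor")
    lib.impl("double", lambda x: x * 2, "CPU")

    @register_flop_formula(torch.ops.demo_flops.double)
    def double_flops(x_shape, out_shape=None, **kwargs) -> int:
        n = 1
        for d in x_shape:  # one multiply per element
            n *= d
        return n

    with FlopCounterMode(display=False) as mode:
        torch.ops.demo_flops.double(torch.randn(4, 8))
    print(mode.get_total_flops())  # 32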
@@ -826,7 +838,9 @@ class TestFlopCounter(TestCase):
with torch.inference_mode():
mode_inference = get_flops(resnet18)
self.assertEqual(get_total_flops(mode_standard), get_total_flops(mode_inference))
self.assertEqual(
get_total_flops(mode_standard), get_total_flops(mode_inference)
)
layer1_conv_flops_standard = mode_standard.flop_counts["ResNet.layer1"][
torch.ops.aten.convolution
@@ -854,5 +868,6 @@ class TestFlopCounter(TestCase):
self.assertExpectedInline(get_total_flops(mode), """860160""")
if __name__ == "__main__":
run_tests()

torch/_export/verifier.py

@@ -11,17 +11,19 @@ from torch._subclasses.fake_tensor import FakeTensor
from torch.export.graph_signature import (
CustomObjArgument,
InputKind,
SymIntArgument,
SymFloatArgument,
SymBoolArgument,
SymFloatArgument,
SymIntArgument,
TensorArgument,
TokenArgument,
)
from torch.fx import GraphModule
if TYPE_CHECKING:
from torch.export.exported_program import ExportedProgram
class SpecViolationError(Exception):
pass
@@ -43,9 +45,13 @@ def _check_val(node: torch.fx.Node) -> None:
return True
elif isinstance(val, (int, bool, str, float)):
return True
elif isinstance(val, (torch.memory_format, torch.dtype, torch.device, torch.layout)):
elif isinstance(
val, (torch.memory_format, torch.dtype, torch.device, torch.layout)
):
return True
elif isinstance(val, (FakeTensor, torch.Tensor)): # TODO(zhxchen17) Remove Tensor.
elif isinstance(
val, (FakeTensor, torch.Tensor)
): # TODO(zhxchen17) Remove Tensor.
return True
elif isinstance(val, (SymInt, SymFloat, SymBool)):
return True
@@ -73,16 +79,21 @@ def _check_val(node: torch.fx.Node) -> None:
def _check_torch_fn(node: torch.fx.Node) -> None:
torch_fn = node.meta.get("torch_fn")
if torch_fn is None:
raise SpecViolationError(f"Unable to find torch_fn metadata for node {node.name}")
raise SpecViolationError(
f"Unable to find torch_fn metadata for node {node.name}"
)
if (
not isinstance(torch_fn, tuple) and
isinstance(torch_fn[0], str) and
isinstance(torch_fn[1], str)
not isinstance(torch_fn, tuple)
and isinstance(torch_fn[0], str)
and isinstance(torch_fn[1], str)
):
raise SpecViolationError(f"Node.meta {node.name} has invalid torch_fn field {torch_fn}")
raise SpecViolationError(
f"Node.meta {node.name} has invalid torch_fn field {torch_fn}"
)
class _VerifierMeta(type):
_registry: dict[str, type['Verifier']] = {}
_registry: dict[str, type["Verifier"]] = {}
def __new__(metacls, name, bases, attrs):
if bases:
@@ -99,12 +110,15 @@ class _VerifierMeta(type):
metacls._registry[attrs["dialect"]] = ret # type: ignore[assignment]
return ret
def getattr_recursive(obj: Any, target: str) -> Any:
target_atoms = target.split('.')
target_atoms = target.split(".")
attr_itr = obj
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
raise RuntimeError(
f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
)
attr_itr = getattr(attr_itr, atom)
return attr_itr
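
getattr_recursive, shown above, walks dotted attribute paths the way torch.fx targets are written; a minimal usage sketch (not part of the diff):

    import torch
    from torch._export.verifier import getattr_recursive

    model = torch.nn.Sequential(torch.nn.Linear(2, 2))
    print(getattr_recursive(model, "0.weight").shape)  # torch.Size([2, 2])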
@@ -153,7 +167,7 @@ class Verifier(metaclass=_VerifierMeta):
torch.fx.GraphModule,
torch.nn.parameter.Parameter,
torch.Tensor, # for buffer and constant tensor
torch.utils._pytree.TreeSpec
torch.utils._pytree.TreeSpec,
)
def check_valid_op(self, op):
@@ -211,7 +225,10 @@ class Verifier(metaclass=_VerifierMeta):
)
if not isinstance(op, _allowed_op_types()):
if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions:
if (
op not in _allowed_builtin_ops()
and op not in _allowed_torch_functions
):
raise SpecViolationError(
f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n"
f"Valid builtin ops: {_allowed_builtin_ops()}"
@@ -222,9 +239,7 @@ class Verifier(metaclass=_VerifierMeta):
# All ops functional
# TODO (tmanlaibaatar) more proper way is needed here
if self.dialect != "TRAINING" and not is_functional(op):
raise SpecViolationError(
f"operator '{op}' is not functional"
)
raise SpecViolationError(f"operator '{op}' is not functional")
self.check_valid_op(op)
for mod in gm.modules():
@@ -254,13 +269,17 @@ class Verifier(metaclass=_VerifierMeta):
attr = getattr_recursive(mod, node.target)
if isinstance(attr, torch.nn.Module):
def _is_type(name, ty):
return isinstance(getattr(attr, name, None), ty)
if type(attr).__name__ == "LoweredBackendModule":
if _is_type("backend_id", str) \
and _is_type("processed_bytes", bytes) \
and _is_type("compile_specs", list) \
and hasattr(attr, "original_module"):
if (
_is_type("backend_id", str)
and _is_type("processed_bytes", bytes)
and _is_type("compile_specs", list)
and hasattr(attr, "original_module")
):
continue
else:
backend_id = getattr(attr, "backend_id", None)
@@ -285,7 +304,6 @@ class Verifier(metaclass=_VerifierMeta):
f"Valid get_attr types: {_allowed_getattr_types(is_toplevel_gm)}"
)
elif node.op == "placeholder":
_check_val(node)
# TODO(zhxchen17)
@@ -301,9 +319,7 @@ class TrainingIRVerifier(Verifier):
def _verify_exported_program_module_call_graph(exported_program) -> None:
module_call_graph = exported_program.module_call_graph
nodes = {
node.name for node in exported_program.graph.nodes
}
nodes = {node.name for node in exported_program.graph.nodes}
for entry in module_call_graph:
if entry.signature is not None:
for arg in entry.signature.inputs:
@@ -323,7 +339,9 @@ def _verify_exported_program_signature(exported_program) -> None:
gs = exported_program.graph_signature
# Check every node in the signature exists in the graph
input_node_names = [node.name for node in exported_program.graph.nodes if node.op == "placeholder"]
input_node_names = [
node.name for node in exported_program.graph.nodes if node.op == "placeholder"
]
if len(input_node_names) != len(gs.input_specs):
raise SpecViolationError(
@@ -332,7 +350,10 @@ def _verify_exported_program_signature(exported_program) -> None:
)
for input_spec, node in zip(gs.input_specs, input_node_names):
if isinstance(input_spec.arg, (TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument)):
if isinstance(
input_spec.arg,
(TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument),
):
if input_spec.arg.name != node:
raise SpecViolationError(
f"Input spec name {input_spec.arg.name} does not match node name {node}"
@@ -353,9 +374,7 @@ def _verify_exported_program_signature(exported_program) -> None:
param = input_spec.target
if param not in exported_program.state_dict:
raise SpecViolationError(
f"Parameter {param} is not in the state dict."
)
raise SpecViolationError(f"Parameter {param} is not in the state dict.")
if not isinstance(exported_program.state_dict[param], torch.nn.Parameter):
raise SpecViolationError(
@@ -378,10 +397,11 @@ def _verify_exported_program_signature(exported_program) -> None:
f"Buffer {buffer} is missing a persistence flag"
)
if input_spec.persistent is True and buffer not in exported_program.state_dict:
raise SpecViolationError(
f"Buffer {buffer} is not in the state dict."
)
if (
input_spec.persistent is True
and buffer not in exported_program.state_dict
):
raise SpecViolationError(f"Buffer {buffer} is not in the state dict.")
if input_spec.persistent is False and buffer in exported_program.state_dict:
raise SpecViolationError(
@@ -423,9 +443,7 @@ def _verify_exported_program_signature(exported_program) -> None:
f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
)
else:
raise SpecViolationError(
f"Unknown InputKind {input_spec.kind}."
)
raise SpecViolationError(f"Unknown InputKind {input_spec.kind}.")
# Check outputs
output_node = list(exported_program.graph.nodes)[-1]
@@ -446,7 +464,7 @@ def _verify_exported_program_signature(exported_program) -> None:
num_tokens = len(gs.output_tokens)
end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens
mutate_nodes: list[str] = output_nodes[num_tokens:end]
user_output_nodes = output_nodes[end:end + len(gs.user_outputs)]
user_output_nodes = output_nodes[end : end + len(gs.user_outputs)]
for mutation_node in mutate_nodes:
if mutation_node in gs.buffers_to_mutate:
@@ -461,7 +479,8 @@ def _verify_exported_program_signature(exported_program) -> None:
raise SpecViolationError(
f"User input output {mutation_node} does not point to a user input that exists. \n"
f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n"
f"User input nodes available: {gs.user_inputs} \n")
f"User input nodes available: {gs.user_inputs} \n"
)
else:
raise SpecViolationError(
f"Mutation node {mutation_node} is neither a buffer nor a user input. "