Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
PYFMT lint grandfathered files 1 (#154261)
Lint the following files with the same rules as other files:

- test/test_fake_tensor.py
- test/test_flop_counter.py
- torch/_export/verifier.py

It was a nightmare to update tests in one of the previously skipped files without being able to lint it locally with lintrunner -a the way other files are linted. Note that these files see active development; they are not old, untouched files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154261
Approved by: https://github.com/angelayi, https://github.com/Skylion007
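For anyone reproducing this locally: once the paths are dropped from the exclude_patterns list in .lintrunner.toml (the first two hunks below), the formatters can be applied to those files directly. A minimal sketch of the workflow the commit message refers to, assuming lintrunner is installed and its linters initialized via `lintrunner init`; `-a` is short for `--apply-patches`:

```sh
# Run the configured linters on the newly un-excluded files and
# apply the suggested fixes in place.
lintrunner -a test/test_fake_tensor.py test/test_flop_counter.py torch/_export/verifier.py
```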
Committed by: PyTorch MergeBot
Parent: 5677ab9aab
Commit: 43b2716e89
@@ -1193,8 +1193,6 @@ exclude_patterns = [
     'test/quantization/fx/test_numeric_suite_fx.py',
     'test/quantization/fx/test_quantize_fx.py',
     'test/quantization/fx/test_subgraph_rewriter.py',
-    'test/test_fake_tensor.py',
-    'test/test_flop_counter.py',
     'test/test_function_schema.py',
     'test/test_functional_autograd_benchmark.py',
     'test/test_functional_optim.py',
@@ -1330,7 +1328,6 @@ exclude_patterns = [
     'torch/_export/serde/serialize.py',
     'torch/_export/serde/upgrade.py',
     'torch/_export/trace.py',
-    'torch/_export/verifier.py',
     'torch/testing/_internal/__init__.py',
     'torch/testing/_internal/autocast_test_lists.py',
     'torch/testing/_internal/autograd_function_db.py',
@@ -3731,7 +3731,7 @@ class NcclProcessGroupWithDispatchedCollectivesTests(
     @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
     def test_allgather_float8(self, float8_dtype):
         device = torch.device(f"cuda:{self.rank:d}")
-        if not sm_is_or_higher_than(device, 9, 0):
+        if not sm_is_or_higher_than(device, 9, 0):  # noqa: F821
             self.skipTest("FP8 reduction support begins with sm90 capable devices")
         store = dist.FileStore(self.file_name, self.world_size)
         dist.init_process_group(
@@ -5,23 +5,23 @@
 import contextlib
 import copy
 import dataclasses
+import gc
 import inspect
+import io
 import itertools
 import pickle
 import unittest
 import weakref
 from unittest.mock import patch
-import io
-import gc

 import numpy as np

 import torch
 import torch._dynamo
 import torch._functorch.config
 import torch._prims as prims
 import torch.testing._internal.optests as optests
 import torch.utils._pytree as pytree

 from torch import distributed as dist
 from torch._C._functorch import _add_batch_dim, get_unwrapped, is_batchedtensor
 from torch._dispatch.python import enable_python_dispatcher
@@ -32,10 +32,10 @@ from torch._subclasses.fake_tensor import (
     _CacheKeyState,
     DynamicOutputShapeException,
     extract_tensor_metadata,
-    MetadataMismatchError,
     FakeTensor,
     FakeTensorConverter,
     FakeTensorMode,
+    MetadataMismatchError,
     unset_fake_temporarily,
     UnsupportedOperatorException,
 )
@@ -56,6 +56,7 @@ from torch.testing._internal.common_device_type import (
     OpDTypes,
     ops,
 )
+from torch.testing._internal.common_dtype import all_types_complex_float8_and
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -68,15 +69,14 @@ from torch.testing._internal.common_utils import (
     TestCase,
     xfailIfTorchDynamo,
 )
-from torch.testing._internal.common_dtype import all_types_complex_float8_and
 from torch.testing._internal.custom_op_db import custom_op_db
-
 from torch.testing._internal.inductor_utils import GPU_TYPE
 from torch.testing._internal.jit_utils import RUN_CUDA
 from torch.testing._internal.two_tensor import TwoTensor
 from torch.utils._mode_utils import no_dispatch
 from torch.utils._python_dispatch import TorchDispatchMode

+
 aten = torch.ops.aten

 torch._dynamo.config.fake_tensor_cache_enabled = True
@@ -977,10 +977,12 @@ class FakeTensorTest(TestCase):
         with mode:
             x = torch.empty(2, 2, device="cpu", dtype=torch.int32)
         from torch._subclasses.fake_impls import get_fast_op_impls
+
         fast_div = get_fast_op_impls()[torch.ops.aten.div.Tensor]
         y = fast_div(mode, x, 2)
         self.assertEqual(y.dtype, torch.float32)

+
 instantiate_parametrized_tests(FakeTensorTest)


@@ -1115,7 +1117,9 @@ class FakeTensorOpInfoTest(TestCase):
 make_propagate_real_tensors_cls(FakeTensorOpInfoTest)
 instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=("cpu", "cuda"))
 instantiate_device_type_tests(
-    PropagateRealTensorsFakeTensorOpInfoTest, globals(), only_for=("cpu",)  # noqa: F821
+    PropagateRealTensorsFakeTensorOpInfoTest,  # noqa: F821
+    globals(),
+    only_for=("cpu",),
 )


@@ -1415,13 +1419,11 @@ class FakeTensorOperatorInvariants(TestCase):
             self.assertTrue("output[0]" not in str(e))
             if self.__class__.__name__.startswith("PropagateRealTensors"):
                 self.assertTrue(
-                    "Real tensor propagation found a metadata mismatch"
-                    in str(e)
+                    "Real tensor propagation found a metadata mismatch" in str(e)
                 )
             else:
                 self.assertTrue(
-                    "found mismatched tensor metadata for output"
-                    in str(e)
+                    "found mismatched tensor metadata for output" in str(e)
                 )

     # IMPORTANT!!! Always run even if CUDA is not available
@@ -1623,61 +1625,74 @@ class FakeTensorPropTest(TestCase):
     def test_torch_load_with_fake_mode(self):
         model = torch.nn.Linear(5, 10)
         sd = model.state_dict()
-        sd['tt'] = TwoTensor(torch.randn(2), torch.randn(2))
+        sd["tt"] = TwoTensor(torch.randn(2), torch.randn(2))

         def _read_tensor_and_check(key, sd_loaded, all_bytes, device):
             dtype = torch.float32
             t = sd_loaded[key]
             self.assertEqual(t.device.type, device)
             if isinstance(t, TwoTensor):
-                untyped_storage_a, untyped_storage_b = t.a.untyped_storage(), t.b.untyped_storage()
-                offset_a, offset_b = untyped_storage_a._checkpoint_offset, untyped_storage_b._checkpoint_offset
-                nbytes_a, nbytes_b = untyped_storage_a.nbytes() // 4, untyped_storage_b.nbytes() // 4
-                result_a = torch.frombuffer(all_bytes, dtype=dtype, count=nbytes_a, offset=offset_a).resize_(t.a.size())
-                result_b = torch.frombuffer(all_bytes, dtype=dtype, count=nbytes_b, offset=offset_b).resize_(t.b.size())
+                untyped_storage_a, untyped_storage_b = (
+                    t.a.untyped_storage(),
+                    t.b.untyped_storage(),
+                )
+                offset_a, offset_b = (
+                    untyped_storage_a._checkpoint_offset,
+                    untyped_storage_b._checkpoint_offset,
+                )
+                nbytes_a, nbytes_b = (
+                    untyped_storage_a.nbytes() // 4,
+                    untyped_storage_b.nbytes() // 4,
+                )
+                result_a = torch.frombuffer(
+                    all_bytes, dtype=dtype, count=nbytes_a, offset=offset_a
+                ).resize_(t.a.size())
+                result_b = torch.frombuffer(
+                    all_bytes, dtype=dtype, count=nbytes_b, offset=offset_b
+                ).resize_(t.b.size())
                 self.assertEqual(TwoTensor(result_a, result_b), sd[key])
             else:
                 untyped_storage = t.untyped_storage()
                 offset = untyped_storage._checkpoint_offset
                 nbytes = untyped_storage.nbytes() // 4
-                result = torch.frombuffer(all_bytes, dtype=dtype, count=nbytes, offset=offset).resize_(t.size())
+                result = torch.frombuffer(
+                    all_bytes, dtype=dtype, count=nbytes, offset=offset
+                ).resize_(t.size())
                 self.assertEqual(result, sd[key])

         with TemporaryFileName() as f, torch.serialization.safe_globals([TwoTensor]):
             # Create state_dict to be loaded later
             torch.save(sd, f)
-            with open(f, 'rb') as g:
+            with open(f, "rb") as g:
                 all_bytes = g.read()

             fake_mode = FakeTensorMode()
             with fake_mode:
                 sd_loaded = torch.load(f)
             for k in sd:
-                _read_tensor_and_check(k, sd_loaded, all_bytes, 'cpu')
+                _read_tensor_and_check(k, sd_loaded, all_bytes, "cpu")
             with fake_mode:
                 sd_loaded = torch.load(f, map_location="cuda")
             for k in sd:
-                _read_tensor_and_check(k, sd_loaded, all_bytes, 'cuda')
-
+                _read_tensor_and_check(k, sd_loaded, all_bytes, "cuda")

         for k in sd.keys():
-            sd[k] = sd[k].to('cuda')
+            sd[k] = sd[k].to("cuda")

         with TemporaryFileName() as f, torch.serialization.safe_globals([TwoTensor]):
             torch.save(sd, f)
-            with open(f, 'rb') as g:
+            with open(f, "rb") as g:
                 all_bytes = g.read()

             fake_mode = FakeTensorMode()
             with fake_mode:
                 sd_loaded = torch.load(f)
             for k in sd:
-                _read_tensor_and_check(k, sd_loaded, all_bytes, 'cuda')
+                _read_tensor_and_check(k, sd_loaded, all_bytes, "cuda")
             with fake_mode:
                 sd_loaded = torch.load(f, map_location="cpu")
             for k in sd:
-                _read_tensor_and_check(k, sd_loaded, all_bytes, 'cpu')
+                _read_tensor_and_check(k, sd_loaded, all_bytes, "cpu")


 make_propagate_real_tensors_cls(FakeTensorPropTest)
@@ -1994,9 +2009,9 @@ class FakeTensorDispatchCache(TestCase):
             x = torch.randn(s0, s1, s2)
             out = torch.randn(s0, s3, s4)
             kwargs = {
-                's': (s3, s4),
-                'dim': (1, s5),
-                'norm': 'ortho',
+                "s": (s3, s4),
+                "dim": (1, s5),
+                "norm": "ortho",
             }
             r = torch._C._fft.fft_hfft2(x, **kwargs, out=out)
             self.assertEqual(r.shape, out.shape)
@@ -2074,8 +2089,12 @@ class FakeTensorDispatchCache(TestCase):
             def __torch_dispatch__(cls, func, types, args, kwargs):
                 if kwargs is None:
                     kwargs = {}
-                args = pytree.tree_map_only(DifferentDeviceTensor, lambda x: x.inner_tensor, args)
-                kwargs = pytree.tree_map_only(DifferentDeviceTensor, lambda x: x.inner_tensor, kwargs)
+                args = pytree.tree_map_only(
+                    DifferentDeviceTensor, lambda x: x.inner_tensor, args
+                )
+                kwargs = pytree.tree_map_only(
+                    DifferentDeviceTensor, lambda x: x.inner_tensor, kwargs
+                )
                 # Returns unwrapped tensor
                 return func(*args, **kwargs)

@@ -2098,7 +2117,7 @@ class FakeTensorDispatchCache(TestCase):
             return torch.nn.functional.interpolate(
                 x,
                 size=[256, 256],
-                mode='bilinear',
+                mode="bilinear",
                 align_corners=False,
                 antialias=True,
             )
@@ -2108,8 +2127,13 @@ class FakeTensorDispatchCache(TestCase):
         x = fake_m.from_tensor(
             torch.randn(1, 3, 2005, 1920, requires_grad=True),
             symbolic_context=StatelessSymbolicContext(
-                dynamic_sizes=[DimDynamic.STATIC, DimDynamic.STATIC, DimDynamic.DYNAMIC, DimDynamic.DYNAMIC],
-                constraint_sizes=[None, None, None, None]
+                dynamic_sizes=[
+                    DimDynamic.STATIC,
+                    DimDynamic.STATIC,
+                    DimDynamic.DYNAMIC,
+                    DimDynamic.DYNAMIC,
+                ],
+                constraint_sizes=[None, None, None, None],
             ),
         )
         with fake_m, enable_python_dispatcher():
@@ -2126,14 +2150,14 @@ class FakeTensorDispatchCache(TestCase):

         t = torch.ByteTensor(storage)
         self.assertTrue(isinstance(t, FakeTensor))
-        self.assertEqual(t.device, torch.device('cpu'))
+        self.assertEqual(t.device, torch.device("cpu"))

     def test_meta_tensor_to_fake_cpu(self):
-        x = torch.randn(4, 4, device='meta')
+        x = torch.randn(4, 4, device="meta")
         with FakeTensorMode(allow_non_fake_inputs=True):
-            x_cpu = x.to(device='cpu')
+            x_cpu = x.to(device="cpu")
         self.assertTrue(isinstance(x_cpu, FakeTensor))
-        self.assertEqual(x_cpu.device, torch.device('cpu'))
+        self.assertEqual(x_cpu.device, torch.device("cpu"))

     def test_cache_tuple_outputs(self):
         """
@@ -2158,7 +2182,6 @@ class FakeTensorDispatchCache(TestCase):
             extract_tensor_metadata(b),
         )

-
     def test_cache_aten_index(self):
         with FakeTensorMode():
             x = torch.randn(4, 4, 4)
@@ -2178,10 +2201,16 @@ class FakeTensorDispatchCache(TestCase):
         with FakeTensorMode():
             x = torch.randn(4, 4, 4)
             idx_tensor1 = torch.tensor([True, True, False, True])
-            self.assertRaises(DynamicOutputShapeException, lambda: torch.ops.aten.index(x, [None, idx_tensor1]))
+            self.assertRaises(
+                DynamicOutputShapeException,
+                lambda: torch.ops.aten.index(x, [None, idx_tensor1]),
+            )

             idx_tensor1 = torch.tensor([1, -2, 3, -4], dtype=torch.int8)
-            self.assertRaises(DynamicOutputShapeException, lambda: torch.ops.aten.index(x, [None, idx_tensor1]))
+            self.assertRaises(
+                DynamicOutputShapeException,
+                lambda: torch.ops.aten.index(x, [None, idx_tensor1]),
+            )

     @skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching")
     def test_invoke_subgraph(self):
@@ -2335,11 +2364,14 @@ class FakeTensorDispatchCache(TestCase):
         lengths = torch.tensor([0, 2, 3, 1, 4])
         indices = torch.tensor([2, 3, 4, 6, 7, 8, 9])
         offsets = torch.cumsum(lengths, 0)
-        ep = torch.export.export(LengthsGather(), (input, lengths, indices, offsets), strict=False)
+        ep = torch.export.export(
+            LengthsGather(), (input, lengths, indices, offsets), strict=False
+        )

         FakeTensorMode.cache_clear()
         ep.run_decompositions({})
         self.assertBypasses("unrepresented symbol in output", 2)

+
 if __name__ == "__main__":
     run_tests()
@@ -8,18 +8,19 @@ import torch.nn.functional as F
 import torch.utils.flop_counter
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_CUDNN_ATTENTION,
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
     PLATFORM_SUPPORTS_FP8,
     PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
-    PLATFORM_SUPPORTS_CUDNN_ATTENTION
 )
 from torch.testing._internal.common_utils import (
     run_tests,
+    skipIfRocm,
     TEST_WITH_TORCHDYNAMO,
     TestCase,
-    skipIfRocm,
 )

+
 try:
     from torchvision import models as torchvision_models
@@ -422,7 +423,12 @@ class TestFlopCounter(TestCase):
             run_uniform_flops(backend, with_backward=True)
             for backend in ["math", "flash", "mem_efficient", "cudnn"]
         ]
-        flops_fw_bw_math, flops_fw_bw_flash, flops_fw_bw_efficient, flops_fw_bw_cudnn = flops
+        (
+            flops_fw_bw_math,
+            flops_fw_bw_flash,
+            flops_fw_bw_efficient,
+            flops_fw_bw_cudnn,
+        ) = flops
         self.assertEqual(flops_fw_math * 3, flops_fw_bw_math)
         self.assertEqual(flops_fw_math * 7 // 2, flops_fw_bw_flash)
         self.assertEqual(flops_fw_bw_flash, flops_fw_bw_efficient)
@@ -705,7 +711,9 @@ class TestFlopCounter(TestCase):
                 False,
             )

-        dense_x = torch.randn(4, 40, 4, 16, dtype=torch.bfloat16, device="cuda").transpose(1, 2)
+        dense_x = torch.randn(
+            4, 40, 4, 16, dtype=torch.bfloat16, device="cuda"
+        ).transpose(1, 2)

         with FlopCounterMode() as real_flop_counter_mode:
             torch.ops.aten._flash_attention_forward(
@@ -721,8 +729,10 @@ class TestFlopCounter(TestCase):
                 False,
             )

-        self.assertEqual(int(get_total_flops(fake_flop_counter_mode)), int(get_total_flops(real_flop_counter_mode)))
-
+        self.assertEqual(
+            int(get_total_flops(fake_flop_counter_mode)),
+            int(get_total_flops(real_flop_counter_mode)),
+        )

     def test_addmm_out(self):
         def f(x):
@@ -795,7 +805,9 @@ class TestFlopCounter(TestCase):

         called = 0

-        with self.assertRaisesRegex(ValueError, "expected each target to be OpOverloadPacket"):
+        with self.assertRaisesRegex(
+            ValueError, "expected each target to be OpOverloadPacket"
+        ):
             register_flop_formula(torch.ops.mylib.foo.default)(lambda x: x)

         @register_flop_formula(torch.ops.mylib.foo)
@@ -826,7 +838,9 @@ class TestFlopCounter(TestCase):
         with torch.inference_mode():
             mode_inference = get_flops(resnet18)

-        self.assertEqual(get_total_flops(mode_standard), get_total_flops(mode_inference))
+        self.assertEqual(
+            get_total_flops(mode_standard), get_total_flops(mode_inference)
+        )

         layer1_conv_flops_standard = mode_standard.flop_counts["ResNet.layer1"][
             torch.ops.aten.convolution
@@ -854,5 +868,6 @@ class TestFlopCounter(TestCase):

         self.assertExpectedInline(get_total_flops(mode), """860160""")

+
 if __name__ == "__main__":
     run_tests()
@@ -11,17 +11,19 @@ from torch._subclasses.fake_tensor import FakeTensor
 from torch.export.graph_signature import (
     CustomObjArgument,
     InputKind,
-    SymIntArgument,
-    SymFloatArgument,
     SymBoolArgument,
+    SymFloatArgument,
+    SymIntArgument,
     TensorArgument,
     TokenArgument,
 )
 from torch.fx import GraphModule

+
 if TYPE_CHECKING:
     from torch.export.exported_program import ExportedProgram

+
 class SpecViolationError(Exception):
     pass
@@ -43,9 +45,13 @@ def _check_val(node: torch.fx.Node) -> None:
         return True
     elif isinstance(val, (int, bool, str, float)):
         return True
-    elif isinstance(val, (torch.memory_format, torch.dtype, torch.device, torch.layout)):
+    elif isinstance(
+        val, (torch.memory_format, torch.dtype, torch.device, torch.layout)
+    ):
         return True
-    elif isinstance(val, (FakeTensor, torch.Tensor)):  # TODO(zhxchen17) Remove Tensor.
+    elif isinstance(
+        val, (FakeTensor, torch.Tensor)
+    ):  # TODO(zhxchen17) Remove Tensor.
         return True
     elif isinstance(val, (SymInt, SymFloat, SymBool)):
         return True
@@ -73,16 +79,21 @@ def _check_val(node: torch.fx.Node) -> None:
 def _check_torch_fn(node: torch.fx.Node) -> None:
     torch_fn = node.meta.get("torch_fn")
     if torch_fn is None:
-        raise SpecViolationError(f"Unable to find torch_fn metadata for node {node.name}")
+        raise SpecViolationError(
+            f"Unable to find torch_fn metadata for node {node.name}"
+        )
     if (
-        not isinstance(torch_fn, tuple) and
-        isinstance(torch_fn[0], str) and
-        isinstance(torch_fn[1], str)
+        not isinstance(torch_fn, tuple)
+        and isinstance(torch_fn[0], str)
+        and isinstance(torch_fn[1], str)
     ):
-        raise SpecViolationError(f"Node.meta {node.name} has invalid torch_fn field {torch_fn}")
+        raise SpecViolationError(
+            f"Node.meta {node.name} has invalid torch_fn field {torch_fn}"
+        )


 class _VerifierMeta(type):
-    _registry: dict[str, type['Verifier']] = {}
+    _registry: dict[str, type["Verifier"]] = {}

     def __new__(metacls, name, bases, attrs):
         if bases:
@@ -99,12 +110,15 @@ class _VerifierMeta(type):
             metacls._registry[attrs["dialect"]] = ret  # type: ignore[assignment]
         return ret


 def getattr_recursive(obj: Any, target: str) -> Any:
-    target_atoms = target.split('.')
+    target_atoms = target.split(".")
     attr_itr = obj
     for i, atom in enumerate(target_atoms):
         if not hasattr(attr_itr, atom):
-            raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
+            raise RuntimeError(
+                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
+            )
         attr_itr = getattr(attr_itr, atom)
     return attr_itr

@@ -153,7 +167,7 @@ class Verifier(metaclass=_VerifierMeta):
             torch.fx.GraphModule,
             torch.nn.parameter.Parameter,
             torch.Tensor,  # for buffer and constant tensor
-            torch.utils._pytree.TreeSpec
+            torch.utils._pytree.TreeSpec,
         )

     def check_valid_op(self, op):
@@ -211,7 +225,10 @@ class Verifier(metaclass=_VerifierMeta):
                 )

             if not isinstance(op, _allowed_op_types()):
-                if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions:
+                if (
+                    op not in _allowed_builtin_ops()
+                    and op not in _allowed_torch_functions
+                ):
                     raise SpecViolationError(
                         f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n"
                         f"Valid builtin ops: {_allowed_builtin_ops()}"
@@ -222,9 +239,7 @@ class Verifier(metaclass=_VerifierMeta):
             # All ops functional
             # TODO (tmanlaibaatar) more proper way is needed here
             if self.dialect != "TRAINING" and not is_functional(op):
-                raise SpecViolationError(
-                    f"operator '{op}' is not functional"
-                )
+                raise SpecViolationError(f"operator '{op}' is not functional")
             self.check_valid_op(op)

         for mod in gm.modules():
@@ -254,13 +269,17 @@ class Verifier(metaclass=_VerifierMeta):

                 attr = getattr_recursive(mod, node.target)
                 if isinstance(attr, torch.nn.Module):
+
                     def _is_type(name, ty):
                         return isinstance(getattr(attr, name, None), ty)

                     if type(attr).__name__ == "LoweredBackendModule":
-                        if _is_type("backend_id", str) \
-                                and _is_type("processed_bytes", bytes) \
-                                and _is_type("compile_specs", list) \
-                                and hasattr(attr, "original_module"):
+                        if (
+                            _is_type("backend_id", str)
+                            and _is_type("processed_bytes", bytes)
+                            and _is_type("compile_specs", list)
+                            and hasattr(attr, "original_module")
+                        ):
                             continue
                         else:
                             backend_id = getattr(attr, "backend_id", None)
@@ -285,7 +304,6 @@ class Verifier(metaclass=_VerifierMeta):
                         f"Valid get_attr types: {_allowed_getattr_types(is_toplevel_gm)}"
                     )

-
             elif node.op == "placeholder":
                 _check_val(node)
             # TODO(zhxchen17)
@@ -301,9 +319,7 @@ class TrainingIRVerifier(Verifier):

 def _verify_exported_program_module_call_graph(exported_program) -> None:
     module_call_graph = exported_program.module_call_graph
-    nodes = {
-        node.name for node in exported_program.graph.nodes
-    }
+    nodes = {node.name for node in exported_program.graph.nodes}
     for entry in module_call_graph:
         if entry.signature is not None:
             for arg in entry.signature.inputs:
@@ -323,7 +339,9 @@ def _verify_exported_program_signature(exported_program) -> None:
     gs = exported_program.graph_signature

     # Check every node in the signature exists in the graph
-    input_node_names = [node.name for node in exported_program.graph.nodes if node.op == "placeholder"]
+    input_node_names = [
+        node.name for node in exported_program.graph.nodes if node.op == "placeholder"
+    ]

     if len(input_node_names) != len(gs.input_specs):
         raise SpecViolationError(
@@ -332,7 +350,10 @@ def _verify_exported_program_signature(exported_program) -> None:
         )

     for input_spec, node in zip(gs.input_specs, input_node_names):
-        if isinstance(input_spec.arg, (TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument)):
+        if isinstance(
+            input_spec.arg,
+            (TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument),
+        ):
             if input_spec.arg.name != node:
                 raise SpecViolationError(
                     f"Input spec name {input_spec.arg.name} does not match node name {node}"
@@ -353,9 +374,7 @@ def _verify_exported_program_signature(exported_program) -> None:

             param = input_spec.target
             if param not in exported_program.state_dict:
-                raise SpecViolationError(
-                    f"Parameter {param} is not in the state dict."
-                )
+                raise SpecViolationError(f"Parameter {param} is not in the state dict.")

             if not isinstance(exported_program.state_dict[param], torch.nn.Parameter):
                 raise SpecViolationError(
@@ -378,10 +397,11 @@ def _verify_exported_program_signature(exported_program) -> None:
                     f"Buffer {buffer} is missing a persistence flag"
                 )

-            if input_spec.persistent is True and buffer not in exported_program.state_dict:
-                raise SpecViolationError(
-                    f"Buffer {buffer} is not in the state dict."
-                )
+            if (
+                input_spec.persistent is True
+                and buffer not in exported_program.state_dict
+            ):
+                raise SpecViolationError(f"Buffer {buffer} is not in the state dict.")

             if input_spec.persistent is False and buffer in exported_program.state_dict:
                 raise SpecViolationError(
@@ -423,9 +443,7 @@ def _verify_exported_program_signature(exported_program) -> None:
                     f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
                 )
         else:
-            raise SpecViolationError(
-                f"Unknown InputKind {input_spec.kind}."
-            )
+            raise SpecViolationError(f"Unknown InputKind {input_spec.kind}.")

     # Check outputs
     output_node = list(exported_program.graph.nodes)[-1]
@@ -446,7 +464,7 @@ def _verify_exported_program_signature(exported_program) -> None:
     num_tokens = len(gs.output_tokens)
     end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens
     mutate_nodes: list[str] = output_nodes[num_tokens:end]
-    user_output_nodes = output_nodes[end:end + len(gs.user_outputs)]
+    user_output_nodes = output_nodes[end : end + len(gs.user_outputs)]

     for mutation_node in mutate_nodes:
         if mutation_node in gs.buffers_to_mutate:
@@ -461,7 +479,8 @@ def _verify_exported_program_signature(exported_program) -> None:
                 raise SpecViolationError(
                     f"User input output {mutation_node} does not point to a user input that exists. \n"
                     f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n"
-                    f"User input nodes available: {gs.user_inputs} \n")
+                    f"User input nodes available: {gs.user_inputs} \n"
+                )
             else:
                 raise SpecViolationError(
                     f"Mutation node {mutation_node} is neither a buffer nor a user input. "