[BE][PYFMT] migrate PYFMT for test/[i-z]*/ to ruff format (#144556)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144556
Approved by: https://github.com/ezyang

Committed by: PyTorch MergeBot
Parent: 19ce1beb05
Commit: 775788f93b
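
Most of the churn below follows one mechanical pattern: black wrapped a long assert by parenthesizing the condition and leaving the message after the closing parenthesis, while ruff format keeps the condition on one line and parenthesizes the message instead. A minimal sketch of the pattern (the variable name and message are made up for illustration):

    observers = ["obs"]  # hypothetical value so both asserts pass

    # before (black): condition wrapped, message trails the closing paren
    assert (
        len(observers) == 1
    ), "Expected to have 1 observer"

    # after (ruff format): condition on one line, message parenthesized
    assert len(observers) == 1, (
        "Expected to have 1 observer"
    )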
@@ -802,7 +802,8 @@ class AddedAttributesTest(JitBackendTestCase):
         # Attach bundled inputs which adds several attributes and functions to the model
         self.lowered_module = (
             torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
-                lowered_module, input  # noqa: F821
+                lowered_module,  # noqa: F821
+                input,
             )
         )
         post_bundled = self.lowered_module(
@@ -279,23 +279,23 @@ class TestModuleContainers(JitTestCase):
                 self.moduledict = CustomModuleDict({"submod": self.submod})
 
             def forward(self, inputs):
-                assert (
-                    self.modulelist[0] is self.submod
-                ), "__getitem__ failing for ModuleList"
+                assert self.modulelist[0] is self.submod, (
+                    "__getitem__ failing for ModuleList"
+                )
                 assert len(self.modulelist) == 1, "__len__ failing for ModuleList"
                 for module in self.modulelist:
                     assert module is self.submod, "__iter__ failing for ModuleList"
 
-                assert (
-                    self.sequential[0] is self.submod
-                ), "__getitem__ failing for Sequential"
+                assert self.sequential[0] is self.submod, (
+                    "__getitem__ failing for Sequential"
+                )
                 assert len(self.sequential) == 1, "__len__ failing for Sequential"
                 for module in self.sequential:
                     assert module is self.submod, "__iter__ failing for Sequential"
 
-                assert (
-                    self.moduledict["submod"] is self.submod
-                ), "__getitem__ failing for ModuleDict"
+                assert self.moduledict["submod"] is self.submod, (
+                    "__getitem__ failing for ModuleDict"
+                )
                 assert len(self.moduledict) == 1, "__len__ failing for ModuleDict"
 
                 # note: unable to index moduledict with a string variable currently
@@ -439,9 +439,9 @@ class TestModuleContainers(JitTestCase):
                 self.moduledict = CustomModuleDict()
 
             def forward(self, inputs):
-                assert (
-                    "submod" not in self.moduledict
-                ), "__contains__ fails for ModuleDict"
+                assert "submod" not in self.moduledict, (
+                    "__contains__ fails for ModuleDict"
+                )
                 return inputs
 
         m = MyModule()
@@ -405,9 +405,7 @@ class TestPythonBuiltinOP(JitTestCase):
             def f():
                 x = torch.ones(10, 9, 8, 7, 6)
                 return x{indices}.shape
-            """.format(
-                indices=indices
-            )
+            """.format(indices=indices)
             )
             test_str = test_str.replace(r"'", r"")
             scope = {}
@@ -139,9 +139,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -160,9 +158,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -181,9 +177,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -202,9 +196,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
            ):
                 torch.jit.script(M())
 
@@ -223,9 +215,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -244,9 +234,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -265,9 +253,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -286,9 +272,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -307,9 +291,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -328,9 +310,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -351,9 +331,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -960,8 +960,9 @@ class TestTracer(JitTestCase):
         V = Variable
         a, b = V(torch.rand(1)), V(torch.rand(1))
         ge = torch.jit.trace(foo, (a, b))
-        a, b = V(torch.rand(1), requires_grad=True), V(
-            torch.rand(1), requires_grad=True
+        a, b = (
+            V(torch.rand(1), requires_grad=True),
+            V(torch.rand(1), requires_grad=True),
         )
         (r,) = ge(a, b)
         da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True)
@@ -396,9 +396,7 @@ class TestUnion(JitTestCase):
 
         with self.assertRaisesRegex(
             RuntimeError,
-            "only int, float, "
-            "complex, Tensor, device and string keys "
-            "are supported",
+            "only int, float, complex, Tensor, device and string keys are supported",
         ):
             torch.jit.script(fn)
 
@@ -602,9 +600,7 @@ class TestUnion(JitTestCase):
 
         with self.assertRaisesRegex(
             RuntimeError,
-            "y is set to type str"
-            " in the true branch and type int "
-            "in the false branch",
+            "y is set to type str in the true branch and type int in the false branch",
         ):
             torch.jit.script(fn)
 
@@ -622,9 +618,7 @@ class TestUnion(JitTestCase):
 
         with self.assertRaisesRegex(
             RuntimeError,
-            "previously had type "
-            "str but is now being assigned to a"
-            " value of type int",
+            "previously had type str but is now being assigned to a value of type int",
         ):
             torch.jit.script(fn)
 
@@ -729,8 +723,7 @@ class TestUnion(JitTestCase):
             template,
             "Union[List[str], List[torch.Tensor]]",
             lhs["list_literal_empty"],
-            "there are multiple possible List type "
-            "candidates in the Union annotation",
+            "there are multiple possible List type candidates in the Union annotation",
         )
 
         self._assert_passes(
@@ -902,8 +895,7 @@ class TestUnion(JitTestCase):
             template,
             "Union[Dict[str, torch.Tensor], Dict[str, int]]",
             lhs["dict_literal_of_mixed"],
-            "none of those dict types can hold the "
-            "types of the given keys and values",
+            "none of those dict types can hold the types of the given keys and values",
         )
 
         # TODO: String frontend does not support tuple unpacking
 
@@ -406,9 +406,7 @@ class TestUnion(JitTestCase):
 
         with self.assertRaisesRegex(
             RuntimeError,
-            "only int, float, "
-            "complex, Tensor, device and string keys "
-            "are supported",
+            "only int, float, complex, Tensor, device and string keys are supported",
         ):
             torch.jit.script(fn)
 
@@ -612,9 +610,7 @@ class TestUnion(JitTestCase):
 
         with self.assertRaisesRegex(
             RuntimeError,
-            "y is set to type str"
-            " in the true branch and type int "
-            "in the false branch",
+            "y is set to type str in the true branch and type int in the false branch",
         ):
             torch.jit.script(fn)
 
@@ -632,9 +628,7 @@ class TestUnion(JitTestCase):
 
         with self.assertRaisesRegex(
             RuntimeError,
-            "previously had type "
-            "str but is now being assigned to a"
-            " value of type int",
+            "previously had type str but is now being assigned to a value of type int",
         ):
             torch.jit.script(fn)
 
@@ -739,8 +733,7 @@ class TestUnion(JitTestCase):
             template,
             "List[str] | List[torch.Tensor]",
             lhs["list_literal_empty"],
-            "there are multiple possible List type "
-            "candidates in the Union annotation",
+            "there are multiple possible List type candidates in the Union annotation",
         )
 
         self._assert_passes(
@@ -906,8 +899,7 @@ class TestUnion(JitTestCase):
             template,
             "Dict[str, torch.Tensor] | Dict[str, int]",
             lhs["dict_literal_of_mixed"],
-            "none of those dict types can hold the "
-            "types of the given keys and values",
+            "none of those dict types can hold the types of the given keys and values",
         )
 
         # TODO: String frontend does not support tuple unpacking
 
@@ -135,12 +135,14 @@ class TestWarn(JitTestCase):
         bar()
 
         FileCheck().check_count(
-            str="UserWarning: I am warning you from foo", count=1, exactly=True
+            str="UserWarning: I am warning you from foo",
+            count=1,
+            exactly=True,
         ).check_count(
-            str="UserWarning: I am warning you from bar", count=1, exactly=True
-        ).run(
-            f.getvalue()
-        )
+            str="UserWarning: I am warning you from bar",
+            count=1,
+            exactly=True,
+        ).run(f.getvalue())
 
 
 if __name__ == "__main__":
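
The TestWarn hunk above shows the other recurring change: when a call no longer fits on one line, ruff format gives each keyword argument its own line with a trailing comma, while short trailing calls such as .run(...) collapse back onto a single line. A rough stand-in for the fluent FileCheck-style API (hypothetical class, for illustration only):

    class Chain:
        # stand-in for a fluent builder such as FileCheck
        def check_count(self, str="", count=0, exactly=False):
            return self

        def run(self, text):
            return None

    Chain().check_count(
        str="UserWarning: I am warning you from foo",
        count=1,
        exactly=True,
    ).run("captured output")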
@@ -42,12 +42,12 @@ class LazyGeneratorTest(TestCase):
 
         torch._lazy.mark_step()
 
-        assert torch.allclose(
-            cpu_t1, lazy_t1.to("cpu")
-        ), f"Expected {cpu_t1}, got {lazy_t1.to('cpu')}"
-        assert torch.allclose(
-            cpu_t2, lazy_t2.to("cpu")
-        ), f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}"
+        assert torch.allclose(cpu_t1, lazy_t1.to("cpu")), (
+            f"Expected {cpu_t1}, got {lazy_t1.to('cpu')}"
+        )
+        assert torch.allclose(cpu_t2, lazy_t2.to("cpu")), (
+            f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}"
+        )
 
     @skipIfTorchDynamo("Torch Dynamo does not support torch.Generator type")
     def test_generator_causes_multiple_compiles(self):
@@ -69,29 +69,29 @@ class LazyGeneratorTest(TestCase):
         torch._lazy.mark_step()
 
         uncached_compile = metrics.counter_value("UncachedCompile")
-        assert (
-            uncached_compile == 1
-        ), f"Expected 1 uncached compiles, got {uncached_compile}"
+        assert uncached_compile == 1, (
+            f"Expected 1 uncached compiles, got {uncached_compile}"
+        )
 
         t = generate_tensor(2)
         torch._lazy.mark_step()
 
         uncached_compile = metrics.counter_value("UncachedCompile")
-        assert (
-            uncached_compile == 2
-        ), f"Expected 2 uncached compiles, got {uncached_compile}"
+        assert uncached_compile == 2, (
+            f"Expected 2 uncached compiles, got {uncached_compile}"
+        )
 
         t = generate_tensor(1)  # noqa: F841
         torch._lazy.mark_step()
 
         uncached_compile = metrics.counter_value("UncachedCompile")
-        assert (
-            uncached_compile == 2
-        ), f"Expected 2 uncached compiles, got {uncached_compile}"
+        assert uncached_compile == 2, (
+            f"Expected 2 uncached compiles, got {uncached_compile}"
+        )
         cached_compile = metrics.counter_value("CachedCompile")
-        assert (
-            cached_compile == 1
-        ), f"Expected 1 cached compile, got {cached_compile}"
+        assert cached_compile == 1, (
+            f"Expected 1 cached compile, got {cached_compile}"
+        )
 
         metrics.reset()
 
@@ -486,17 +486,9 @@ class TestLiteScriptModule(TestCase):
             "Traceback of TorchScript"
         ).check("self.b.forwardError").check_next(
             "~~~~~~~~~~~~~~~~~~~ <--- HERE"
-        ).check(
-            "return self.call"
-        ).check_next(
-            "~~~~~~~~~ <--- HERE"
-        ).check(
+        ).check("return self.call").check_next("~~~~~~~~~ <--- HERE").check(
             "return torch.ones"
-        ).check_next(
-            "~~~~~~~~~~ <--- HERE"
-        ).run(
-            str(exp)
-        )
+        ).check_next("~~~~~~~~~~ <--- HERE").run(str(exp))
 
 
 class TestLiteScriptQuantizedModule(QuantizationLiteTestCase):
@@ -25,7 +25,7 @@ class TestNnModule(torch.nn.Module):
             torch.nn.ReLU(True),
             # state size. (ngf) x 32 x 32
             torch.nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
-            torch.nn.Tanh()
+            torch.nn.Tanh(),
             # state size. (nc) x 64 x 64
         )
 
@@ -721,9 +721,9 @@ class TestExecutionTrace(TestCase):
         with tempfile.TemporaryDirectory() as temp_dir:
             fp_name = os.path.join(temp_dir, "test.et.json")
 
-            os.environ[
-                "ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_DATA"
-            ] = "aten::gather"
+            os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_DATA"] = (
+                "aten::gather"
+            )
             et = ExecutionTraceObserver()
             et.register_callback(fp_name)
             et.set_extra_resource_collection(True)
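
The os.environ hunk above illustrates how ruff format handles an over-long subscript assignment: instead of splitting the subscript key across lines as black did, it keeps the subscript intact and parenthesizes the right-hand side. A runnable sketch of both shapes, with the key and value taken from the hunk:

    import os

    # before (black): the subscript key was split across lines
    os.environ[
        "ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_DATA"
    ] = "aten::gather"

    # after (ruff format): subscript intact, right-hand side parenthesized
    os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_DATA"] = (
        "aten::gather"
    )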
@@ -1468,7 +1468,7 @@ class TestProfiler(TestCase):
         cats = {e.get("cat", None) for e in j["traceEvents"]}
         self.assertTrue(
             "cuda_sync" in cats,
-            "Expected to find cuda_sync event" f" found = {cats}",
+            f"Expected to find cuda_sync event found = {cats}",
         )
 
         print("Testing enable_cuda_sync_events in _ExperimentalConfig")
@@ -47,6 +47,6 @@ class AOMigrationTestCase(TestCase):
         new_dict = getattr(new_location, dict_name)
         assert old_dict == new_dict, f"Dicts don't match: {dict_name}"
         for key in new_dict.keys():
-            assert (
-                old_dict[key] == new_dict[key]
-            ), f"Dicts don't match: {dict_name} for key {key}"
+            assert old_dict[key] == new_dict[key], (
+                f"Dicts don't match: {dict_name} for key {key}"
+            )
@@ -426,7 +426,6 @@ instantiate_device_type_tests(TestFloat4Dtype, globals())
 
 
 class TestFloat8DtypeCPUOnly(TestCase):
-
     """
     Test of mul implementation
 
@@ -248,9 +248,9 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
                 + cls._FLOAT_MODULE.__name__
             )
         if not qconfig:
-            assert hasattr(
-                mod, "qconfig"
-            ), "Input float module must have qconfig defined"
+            assert hasattr(mod, "qconfig"), (
+                "Input float module must have qconfig defined"
+            )
             assert mod.qconfig, "Input float module must have a valid qconfig"
             qconfig = mod.qconfig
         conv, bn = mod[0], mod[1]
@@ -99,17 +99,17 @@ class OnDevicePTQUtils:
         ):
             raise ValueError("Quantized weight must be produced.")
         fp_weight = weight.inputsAt(0).node()
-        assert (
-            fp_weight.kind() == "prim::GetAttr"
-        ), "Weight must be an attribute of the module."
+        assert fp_weight.kind() == "prim::GetAttr", (
+            "Weight must be an attribute of the module."
+        )
         fp_weight_name = fp_weight.s("name")
         return fp_weight_name
 
     @staticmethod
     def is_per_channel_quantized_packed_param(node):
-        assert (
-            node.kind() == "quantized::linear_prepack"
-        ), "Node must corresponds to linear_prepack."
+        assert node.kind() == "quantized::linear_prepack", (
+            "Node must corresponds to linear_prepack."
+        )
         weight = node.inputsAt(0).node()
         assert (
             weight.kind() != "aten::quantize_per_tensor"
@@ -124,11 +124,7 @@ class TestQuantizeJitPasses(QuantizationTestCase):
             "aten::dequantize"
         ).check_not("aten::quantize_per_channel").check("aten::dequantize").check_next(
             "aten::conv2d"
-        ).check_next(
-            "aten::quantize_per_tensor"
-        ).check_next(
-            "aten::dequantize"
-        ).run(
+        ).check_next("aten::quantize_per_tensor").check_next("aten::dequantize").run(
             freezed.graph
         )
 
@@ -670,9 +666,9 @@ class TestQuantizeJitPasses(QuantizationTestCase):
         }
         assert len(activation_dtypes) == 1, "Expected to have 1 activation dtype"
         assert len(weight_dtypes) == 1, "Expected to have 1 weight dtype"
-        assert next(iter(activation_dtypes)) != next(
-            iter(weight_dtypes)
-        ), "Expected activation dtype to "
+        assert next(iter(activation_dtypes)) != next(iter(weight_dtypes)), (
+            "Expected activation dtype to "
+        )
         " be different from wegiht dtype"
 
     def test_insert_observers_for_reused_weight(self):
@@ -706,9 +702,9 @@ class TestQuantizeJitPasses(QuantizationTestCase):
         conv2_observers = attrs_with_prefix(m.conv2, "_observer_")
         assert len(conv1_observers) == 1, "Expected to have 1 observer submodules"
         assert len(conv2_observers) == 1, "Expected to have 1 observer submodules"
-        assert (
-            conv1_observers == conv2_observers
-        ), "Expect conv1 and conv2 to have same observers since the class type is shared"
+        assert conv1_observers == conv2_observers, (
+            "Expect conv1 and conv2 to have same observers since the class type is shared"
+        )
 
     def test_insert_observers_for_general_ops(self):
         """Make sure we skip observers for ops that doesn't require
@@ -734,13 +730,9 @@ class TestQuantizeJitPasses(QuantizationTestCase):
             'prim::GetAttr[name="conv"]'
         ).check("prim::CallMethod").check(
             'Observer = prim::GetAttr[name="_observer_'
-        ).check(
-            "aten::flatten"
-        ).check_not(
+        ).check("aten::flatten").check_not(
             'Observer = prim::GetAttr[name="_observer_'
-        ).run(
-            m.graph
-        )
+        ).run(m.graph)
 
     # TODO: this is too long, split this to test_insert_observers.py and remove
     # insrt_observers prefix
@@ -770,17 +762,11 @@ class TestQuantizeJitPasses(QuantizationTestCase):
             'prim::GetAttr[name="conv1"]'
         ).check("prim::CallMethod").check(
             'Observer = prim::GetAttr[name="_observer_'
-        ).check(
-            "aten::flatten"
-        ).check_not(
+        ).check("aten::flatten").check_not(
             'Observer = prim::GetAttr[name="_observer_'
-        ).check(
-            'prim::GetAttr[name="conv2"]'
-        ).check(
+        ).check('prim::GetAttr[name="conv2"]').check(
             'Observer = prim::GetAttr[name="_observer_'
-        ).run(
-            m.graph
-        )
+        ).run(m.graph)
 
     def test_insert_observers_propagate_observed_in_submodule(self):
         """Make sure we propagate observed property through general ops"""
@@ -809,17 +795,11 @@ class TestQuantizeJitPasses(QuantizationTestCase):
             'prim::GetAttr[name="conv1"]'
         ).check("prim::CallMethod").check(
             'Observer = prim::GetAttr[name="_observer_'
-        ).check(
-            "prim::CallMethod"
-        ).check_not(
+        ).check("prim::CallMethod").check_not(
             'Observer = prim::GetAttr[name="_observer_'
-        ).check(
-            'prim::GetAttr[name="conv2"]'
-        ).check(
+        ).check('prim::GetAttr[name="conv2"]').check(
             'Observer = prim::GetAttr[name="_observer_'
-        ).run(
-            m.graph
-        )
+        ).run(m.graph)
 
     def test_insert_observers_propagate_observed_for_function(self):
         def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
@@ -1055,9 +1035,9 @@ class TestQuantizeJitPasses(QuantizationTestCase):
 
         m(data)
         m = convert_jit(m, debug=True)
-        assert (
-            len(m._modules._c.items()) == 1
-        ), "Expected to have single submodule of conv"
+        assert len(m._modules._c.items()) == 1, (
+            "Expected to have single submodule of conv"
+        )
         # make sure the quantized model is executable
         m(data)
         quant_func = (
@@ -1088,17 +1068,17 @@ class TestQuantizeJitPasses(QuantizationTestCase):
         qconfig_dict = {"": qconfig}
         m = prepare_jit(m, qconfig_dict)
         # observers for input, output and value between conv1/conv2
-        assert (
-            len(attrs_with_prefix(m, "_observer_")) == 3
-        ), "Expected to have 3 obervers"
+        assert len(attrs_with_prefix(m, "_observer_")) == 3, (
+            "Expected to have 3 obervers"
+        )
         # observer for weight
-        assert (
-            len(attrs_with_prefix(m.conv1, "_observer_")) == 1
-        ), "Expected to have 1 obervers"
+        assert len(attrs_with_prefix(m.conv1, "_observer_")) == 1, (
+            "Expected to have 1 obervers"
+        )
         # observer for weight
-        assert (
-            len(attrs_with_prefix(m.conv2, "_observer_")) == 1
-        ), "Expected to have 1 obervers"
+        assert len(attrs_with_prefix(m.conv2, "_observer_")) == 1, (
+            "Expected to have 1 obervers"
+        )
 
         data = torch.randn(1, 3, 10, 10, dtype=torch.float)
         m(data)
@@ -1107,15 +1087,15 @@ class TestQuantizeJitPasses(QuantizationTestCase):
         assert m.conv1._c._type() == m.conv2._c._type()
 
         # check all observers have been removed
-        assert (
-            len(attrs_with_prefix(m, "_observer_")) == 0
-        ), "Expected to have 0 obervers"
-        assert (
-            len(attrs_with_prefix(m.conv1, "_observer_")) == 0
-        ), "Expected to have 0 obervers"
-        assert (
-            len(attrs_with_prefix(m.conv2, "_observer_")) == 0
-        ), "Expected to have 0 obervers"
+        assert len(attrs_with_prefix(m, "_observer_")) == 0, (
+            "Expected to have 0 obervers"
+        )
+        assert len(attrs_with_prefix(m.conv1, "_observer_")) == 0, (
+            "Expected to have 0 obervers"
+        )
+        assert len(attrs_with_prefix(m.conv2, "_observer_")) == 0, (
+            "Expected to have 0 obervers"
+        )
 
         quant_func = (
             "aten::quantize_per_channel"
@@ -1334,11 +1314,7 @@ class TestQuantizeJitPasses(QuantizationTestCase):
             "aten::avg_pool2d"
         ).check("aten::q_scale").check_next("aten::q_zero_point").check_next(
             "prim::dtype"
-        ).check_next(
-            "aten::quantize_per_tensor"
-        ).check(
-            "aten::dequantize"
-        ).run(
+        ).check_next("aten::quantize_per_tensor").check("aten::dequantize").run(
             model.graph
         )
 
@@ -1757,9 +1733,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
             "aten::relu"
         ).check_not(f"quantized::conv{dim}d(").check_not(
             "quantized::relu("
-        ).run(
-            m.graph
-        )
+        ).run(m.graph)
 
     @skipIfNoFBGEMM
     def test_quantized_add_alpha(self):
 
@@ -1910,9 +1884,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
             "aten::relu("
         ).check_not("aten::relu_(").check_not("quantized::add(").check_not(
             "quantized::relu("
-        ).run(
-            m.graph
-        )
+        ).run(m.graph)
 
     @skipIfNoFBGEMM
     def test_quantized_add(self):
 
@@ -2119,9 +2091,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
             "aten::relu("
         ).check_not("aten::relu_(").check_not("quantized::add(").check_not(
             "quantized::relu("
-        ).run(
-            m.graph
-        )
+        ).run(m.graph)
 
     @skipIfNoFBGEMM
     def test_quantized_add_scalar_relu(self):
 
@@ -2205,11 +2175,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
             "aten::relu("
         ).check_not("aten::relu_(").check_not(
             "quantized::add_scalar("
-        ).check_not(
-            "quantized::relu("
-        ).run(
-            m.graph
-        )
+        ).check_not("quantized::relu(").run(m.graph)
 
     @skipIfNoFBGEMM
     def test_quantized_cat(self):
 
@@ -2544,9 +2510,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
             "aten::relu("
         ).check_not("aten::relu_(").check_not("quantized::mul(").check_not(
             "quantized::relu("
-        ).run(
-            m.graph
-        )
+        ).run(m.graph)
 
     @skipIfNoFBGEMM
     def test_quantized_mul_scalar_relu(self):
 
@@ -2629,11 +2593,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
             "aten::relu("
         ).check_not("aten::relu_(").check_not(
             "quantized::mul_scalar("
-        ).check_not(
-            "quantized::relu("
-        ).run(
-            m.graph
-        )
+        ).check_not("quantized::relu(").run(m.graph)
 
     @override_qengines
     def test_hardswish(self):
 
@@ -3103,9 +3063,7 @@ class TestQuantizeDynamicJitPasses(QuantizationTestCase):
             'Observer = prim::GetAttr[name="_observer_'
         ).check("prim::CallMethod").check_not(
             'Observer = prim::GetAttr[name="_observer_'
-        ).run(
-            m.graph
-        )
+        ).run(m.graph)
 
     def test_insert_quant_dequant_linear_dynamic(self):
         class M(torch.nn.Module):
@@ -3126,9 +3084,9 @@ class TestQuantizeDynamicJitPasses(QuantizationTestCase):
             else default_dynamic_qconfig
         )
         m = quantize_dynamic_jit(m, {"": qconfig}, debug=True)
-        assert (
-            len(m._modules._c.items()) == 2
-        ), "Expected to have two submodule of linear"
+        assert len(m._modules._c.items()) == 2, (
+            "Expected to have two submodule of linear"
+        )
 
         wt_quant_func = (
             "aten::quantize_per_channel"
@@ -3141,21 +3099,11 @@ class TestQuantizeDynamicJitPasses(QuantizationTestCase):
             act_quant_func
         ).check_next("aten::dequantize").check(
             "aten::_choose_qparams_per_tensor"
-        ).check_next(
-            act_quant_func
-        ).check_next(
-            "aten::dequantize"
-        ).check(
+        ).check_next(act_quant_func).check_next("aten::dequantize").check(
             wt_quant_func
-        ).check_next(
-            "aten::dequantize"
-        ).check_not(
-            wt_quant_func
-        ).check(
+        ).check_next("aten::dequantize").check_not(wt_quant_func).check(
             "return"
-        ).run(
-            m.graph
-        )
+        ).run(m.graph)
 
     @override_qengines
     def test_dynamic_multi_op(self):
@@ -254,16 +254,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                         maxpool_node = node
                 input_act = maxpool_node.args[0]
                 assert isinstance(input_act, Node)
-                maxpool_node.meta[
-                    "quantization_annotation"
-                ] = QuantizationAnnotation(
-                    input_qspec_map={
-                        input_act: act_qspec,
-                    },
-                    output_qspec=SharedQuantizationSpec(
-                        (input_act, maxpool_node)
-                    ),
-                    _annotated=True,
+                maxpool_node.meta["quantization_annotation"] = (
+                    QuantizationAnnotation(
+                        input_qspec_map={
+                            input_act: act_qspec,
+                        },
+                        output_qspec=SharedQuantizationSpec(
+                            (input_act, maxpool_node)
+                        ),
+                        _annotated=True,
+                    )
                 )
 
             def validate(self, model: torch.fx.GraphModule) -> None:
@@ -339,9 +339,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             def derive_qparams_fn(
                 obs_or_fqs: list[ObserverOrFakeQuantize],
             ) -> tuple[Tensor, Tensor]:
-                assert (
-                    len(obs_or_fqs) == 2
-                ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
+                assert len(obs_or_fqs) == 2, (
+                    f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
+                )
                 act_obs_or_fq = obs_or_fqs[0]
                 weight_obs_or_fq = obs_or_fqs[1]
                 act_scale, act_zp = act_obs_or_fq.calculate_qparams()
 
@@ -442,9 +442,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             def derive_qparams_fn(
                 obs_or_fqs: list[ObserverOrFakeQuantize],
             ) -> tuple[Tensor, Tensor]:
-                assert (
-                    len(obs_or_fqs) == 1
-                ), f"Expecting one weight obs/fq, got: {len(obs_or_fqs)}"
+                assert len(obs_or_fqs) == 1, (
+                    f"Expecting one weight obs/fq, got: {len(obs_or_fqs)}"
+                )
                 weight_obs_or_fq = obs_or_fqs[0]
                 (
                     weight_scale,
@@ -748,16 +748,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                     (first_input_node, cat_node)
                 )
                 for input_node in input_nodes[1:]:
-                    input_qspec_map[
-                        input_node
-                    ] = share_qparams_with_input_act0_qspec
+                    input_qspec_map[input_node] = (
+                        share_qparams_with_input_act0_qspec
+                    )
 
-                cat_node.meta[
-                    "quantization_annotation"
-                ] = QuantizationAnnotation(
-                    input_qspec_map=input_qspec_map,
-                    output_qspec=share_qparams_with_input_act0_qspec,
-                    _annotated=True,
+                cat_node.meta["quantization_annotation"] = (
+                    QuantizationAnnotation(
+                        input_qspec_map=input_qspec_map,
+                        output_qspec=share_qparams_with_input_act0_qspec,
+                        _annotated=True,
+                    )
                 )
 
             def validate(self, model: torch.fx.GraphModule) -> None:
@@ -783,9 +783,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         obs_ins0 = getattr(m, input0.target)
         obs_ins1 = getattr(m, input1.target)
         assert obs_ins0 == obs_ins1
-        assert (
-            len(conv_output_obs) == 2
-        ), "expecting two observer that follows conv2d ops"
+        assert len(conv_output_obs) == 2, (
+            "expecting two observer that follows conv2d ops"
+        )
         # checking that the output observers for the two convs are shared as well
         assert conv_output_obs[0] == conv_output_obs[1]
 
@@ -850,9 +850,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         obs_ins2 = getattr(m, output_obs.target)
         assert obs_ins0 == obs_ins2, "input observer does not match output"
 
-        assert (
-            len(conv_output_obs) == 2
-        ), "expecting two observer that follows conv2d ops"
+        assert len(conv_output_obs) == 2, (
+            "expecting two observer that follows conv2d ops"
+        )
         # checking that the output observers for the two convs are shared as well
         assert conv_output_obs[0] == conv_output_obs[1]
 
@@ -967,16 +967,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                     (first_input_node, cat_node)
                 )
                 for input_node in input_nodes[1:]:
-                    input_qspec_map[
-                        input_node
-                    ] = share_qparams_with_input_act0_qspec
+                    input_qspec_map[input_node] = (
+                        share_qparams_with_input_act0_qspec
+                    )
 
-                cat_node.meta[
-                    "quantization_annotation"
-                ] = QuantizationAnnotation(
-                    input_qspec_map=input_qspec_map,
-                    output_qspec=share_qparams_with_input_act0_qspec,
-                    _annotated=True,
+                cat_node.meta["quantization_annotation"] = (
+                    QuantizationAnnotation(
+                        input_qspec_map=input_qspec_map,
+                        output_qspec=share_qparams_with_input_act0_qspec,
+                        _annotated=True,
+                    )
                 )
 
             def validate(self, model: torch.fx.GraphModule) -> None:
@@ -1063,16 +1063,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                 share_qparams_with_input_act1_qspec = SharedQuantizationSpec(
                     (second_input_node, cat_node)
                 )
-                input_qspec_map[
-                    first_input_node
-                ] = share_qparams_with_input_act1_qspec
+                input_qspec_map[first_input_node] = (
+                    share_qparams_with_input_act1_qspec
+                )
 
-                cat_node.meta[
-                    "quantization_annotation"
-                ] = QuantizationAnnotation(
-                    input_qspec_map=input_qspec_map,
-                    output_qspec=share_qparams_with_input_act1_qspec,
-                    _annotated=True,
+                cat_node.meta["quantization_annotation"] = (
+                    QuantizationAnnotation(
+                        input_qspec_map=input_qspec_map,
+                        output_qspec=share_qparams_with_input_act1_qspec,
+                        _annotated=True,
+                    )
                 )
 
             def validate(self, model: torch.fx.GraphModule) -> None:
@@ -1121,17 +1121,17 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                 share_qparams_with_input_act1_qspec = SharedQuantizationSpec(
                     (second_input_node, add_node)
                 )
-                input_qspec_map[
-                    first_input_node
-                ] = share_qparams_with_input_act1_qspec
+                input_qspec_map[first_input_node] = (
+                    share_qparams_with_input_act1_qspec
+                )
 
-                add_node.meta[
-                    "quantization_annotation"
-                ] = QuantizationAnnotation(
-                    input_qspec_map=input_qspec_map,
-                    output_qspec=share_qparams_with_input_act1_qspec,
-                    allow_implicit_sharing=False,
-                    _annotated=True,
+                add_node.meta["quantization_annotation"] = (
+                    QuantizationAnnotation(
+                        input_qspec_map=input_qspec_map,
+                        output_qspec=share_qparams_with_input_act1_qspec,
+                        allow_implicit_sharing=False,
+                        _annotated=True,
+                    )
                 )
 
             def validate(self, model: torch.fx.GraphModule) -> None:
@@ -277,9 +277,9 @@ class PT2EQATTestCase(QuantizationTestCase):
 
        # Verify: conv literal args
        if expected_conv_literal_args is not None:
-            assert (
-                len(expected_conv_literal_args) == 6
-            ), "wrong num conv args, bad test setup"
+            assert len(expected_conv_literal_args) == 6, (
+                "wrong num conv args, bad test setup"
+            )
             for i in range(6):
                 if i + 3 < len(conv_node.args):
                     self.assertEqual(
@@ -157,9 +157,9 @@ async def run1(coroutine_id):
         gpuid = coroutine_id % GPUS
     else:
         gpu_assignments = args.gpus.split(":")
-        assert args.nproc == len(
-            gpu_assignments
-        ), "Please specify GPU assignment for each process, separated by :"
+        assert args.nproc == len(gpu_assignments), (
+            "Please specify GPU assignment for each process, separated by :"
+        )
         gpuid = gpu_assignments[coroutine_id]
 
     while progress < len(ALL_TESTS):
@@ -1,8 +1,7 @@
 # Owner(s): ["module: dynamo"]
 
-""" Test functions for limits module.
+"""Test functions for limits module."""
 
-"""
 import functools
 import warnings
 from unittest import expectedFailure as xfail, skipIf
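
The docstring hunks in this part of the diff are another ruff format normalization: a docstring that fits on one line is collapsed onto one line, padding inside the quotes is stripped, and the stray closing quotes that black had tolerated are folded away. A sketch of the before/after shapes written as plain string literals:

    # before: padded, spread over three lines
    doc_before = """ Test functions for limits module.

    """

    # after: collapsed to a single line
    doc_after = """Test functions for limits module."""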
@@ -4104,6 +4104,7 @@ class TestIO(TestCase):
 def test_decimal_period_separator():
     pass
 
+
 def test_decimal_comma_separator():
     with CommaDecimalPointLocale():
         pass
@@ -6786,7 +6787,10 @@ class TestWritebackIfCopy(TestCase):
 class TestArange(TestCase):
     def test_infinite(self):
         assert_raises(
-            (RuntimeError, ValueError), np.arange, 0, np.inf  # "unsupported range",
+            (RuntimeError, ValueError),
+            np.arange,
+            0,
+            np.inf,  # "unsupported range",
         )
 
     def test_nan_step(self):
@@ -2733,10 +2733,18 @@ class TestMoveaxis(TestCase):
         assert_raises(np.AxisError, np.moveaxis, x, 3, 0)  # 'source.*out of bounds',
         assert_raises(np.AxisError, np.moveaxis, x, -4, 0)  # 'source.*out of bounds',
         assert_raises(
-            np.AxisError, np.moveaxis, x, 0, 5  # 'destination.*out of bounds',
+            np.AxisError,
+            np.moveaxis,
+            x,
+            0,
+            5,  # 'destination.*out of bounds',
         )
         assert_raises(
-            ValueError, np.moveaxis, x, [0, 0], [0, 1]  # 'repeated axis in `source`',
+            ValueError,
+            np.moveaxis,
+            x,
+            [0, 0],
+            [0, 1],  # 'repeated axis in `source`',
         )
         assert_raises(
             ValueError,  # 'repeated axis in `destination`',
@@ -3,6 +3,7 @@
 """
 Test the scalar constructors, which also do type-coercion
 """
+
 import functools
 from unittest import skipIf as skipif
 
@@ -3,6 +3,7 @@
 """
 Test the scalar constructors, which also do type-coercion
 """
+
 import fractions
 import functools
 import types
@@ -1,8 +1,7 @@
 # Owner(s): ["module: dynamo"]
 
-""" Test printing of scalar types.
+"""Test printing of scalar types."""
 
-"""
 import functools
 from unittest import skipIf as skipif
 
@@ -811,7 +811,10 @@ class TestBlock(TestCase):
         assert_raises_regex(ValueError, msg, block, [[1], 2])
         assert_raises_regex(ValueError, msg, block, [[], 2])
         assert_raises_regex(
-            ValueError, msg, block, [[[1], [2]], [[3, 4]], [5]]  # missing brackets
+            ValueError,
+            msg,
+            block,
+            [[[1], [2]], [[3, 4]], [5]],  # missing brackets
         )
 
     def test_empty_lists(self, block):
@@ -5,6 +5,7 @@
 Copied from fftpack.helper by Pearu Peterson, October 2005
 
 """
+
 from torch.testing._internal.common_utils import (
     run_tests,
     TEST_WITH_TORCHDYNAMO,
@@ -372,7 +372,7 @@ class TestFFTThreadSafe(TestCase):
             assert_allclose(
                 q.get(timeout=5),
                 expected,
-                atol=2e-14
+                atol=2e-14,
                 # msg="Function returned wrong value in multithreaded context",
             )
 
@@ -1,8 +1,7 @@
 # Owner(s): ["module: dynamo"]
 
-"""Test functions for 1D array set operations.
+"""Test functions for 1D array set operations."""
 
-"""
 from unittest import expectedFailure as xfail, skipIf
 
 import numpy
@@ -2881,7 +2881,8 @@ class TestPercentile(TestCase):
         np.testing.assert_equal(res.dtype, arr.dtype)
 
 H_F_TYPE_CODES = [
-    (int_type, np.float64) for int_type in "Bbhil"  # np.typecodes["AllInteger"]
+    (int_type, np.float64)
+    for int_type in "Bbhil"  # np.typecodes["AllInteger"]
 ] + [
     (np.float16, np.float16),
     (np.float32, np.float32),
@@ -505,8 +505,7 @@ class TestHistogramOptimBinNums(TestCase):
             assert_equal(
                 len(a),
                 numbins,
-                err_msg=f"For the {estimator} estimator "
-                f"with datasize of {testlen}",
+                err_msg=f"For the {estimator} estimator with datasize of {testlen}",
             )
 
     def test_small(self):
 
@@ -552,8 +551,7 @@ class TestHistogramOptimBinNums(TestCase):
             assert_equal(
                 len(a),
                 expbins,
-                err_msg=f"For the {estimator} estimator "
-                f"with datasize of {testlen}",
+                err_msg=f"For the {estimator} estimator with datasize of {testlen}",
             )
 
     def test_incorrect_methods(self):
|
@ -1,8 +1,7 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
"""Test functions for matrix module
|
||||
"""Test functions for matrix module"""
|
||||
|
||||
"""
|
||||
import functools
|
||||
from unittest import expectedFailure as xfail, skipIf as skipif
|
||||
|
||||
|
@@ -1,7 +1,6 @@
 # Owner(s): ["module: dynamo"]
-""" Test functions for linalg module
+"""Test functions for linalg module"""
 
-"""
 import functools
 import itertools
 import os
@@ -1,7 +1,7 @@
 # Owner(s): ["module: dynamo"]
 
-"""Light smoke test switching between numpy to pytorch random streams.
-"""
+"""Light smoke test switching between numpy to pytorch random streams."""
+
 from contextlib import contextmanager
 from functools import partial
 
@@ -9,6 +9,7 @@ The goal is to validate on numpy, and tests should work when replacing
 by
 >>> import torch._numpy as np
 """
+
 import operator
 from unittest import skipIf as skip, SkipTest
 
@@ -39,12 +39,9 @@ USE_BLACK_FILELIST = re.compile(
         # test/**
         # test/[a-h]*/**
         # test/[i-j]*/**
-        "test/j*/**",
         # test/[k-m]*/**
-        "test/[k-m]*/**",
         # test/optim/**
-        # "test/[p-z]*/**",
-        "test/[p-z]*/**",
+        # test/[p-z]*/**,
         # torch/**
         # torch/_[a-c]*/**
         # torch/_[e-h]*/**
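
The final hunk edits the formatter gate itself: USE_BLACK_FILELIST is a regex built from glob-style patterns, and paths that no longer match fall through to ruff format, which is why the migrated test/[i-z]* entries are deleted or commented out. A simplified sketch of how such a gate dispatches (the helper function and pattern here are illustrative, not the real adapter code):

    import re

    # pretend leftover entry; the real list lives in the pyfmt linter adapter
    USE_BLACK_FILELIST = re.compile(r"test/optim/.*")

    def formatter_for(path: str) -> str:
        # paths still matching the filelist keep black; everything else gets ruff format
        return "black" if USE_BLACK_FILELIST.fullmatch(path) else "ruff format"

    print(formatter_for("test/optim/test_sgd.py"))  # -> black
    print(formatter_for("test/jit/test_warn.py"))   # -> ruff format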