mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Partially addresses #123062: UFMT formatting on test/distributions, test/error_messages, and test/forward_backward_compatability. Pull Request resolved: https://github.com/pytorch/pytorch/pull/123527 Approved by: https://github.com/huydhn
77 lines
1.6 KiB
Python
77 lines
1.6 KiB
Python
import torch
|
|
|
|
|
|
def check_error(desc, fn, *required_substrings):
    """Run ``fn`` and verify that it raises with an informative message.

    The raised exception's message is printed under a banner containing
    ``desc`` so the output can be inspected by eye. The message must contain
    every string in ``required_substrings``.

    Args:
        desc: Human-readable description of the case. It is used in the
            printed banner and in failure messages.
        fn: Zero-argument callable that is expected to raise.
        *required_substrings: Strings that must all appear in the message.

    Raises:
        AssertionError: If ``fn`` does not raise, or if its message is
            missing any of ``required_substrings``.
    """
    try:
        fn()
    except Exception as e:
        # Prefer the first positional argument, which is the message for most
        # exceptions. Fall back to str(e) for exceptions raised without
        # arguments. Coerce non-string payloads (e.g. KeyError(1)) so the
        # substring checks below cannot fail with a TypeError.
        error_message = str(e.args[0]) if e.args else str(e)
        print("=" * 80)
        print(desc)
        print("-" * 80)
        print(error_message)
        print("")
        # Raise explicitly rather than using ``assert``, so the checks still
        # run under ``python -O`` and the failure names what was missing.
        missing = [sub for sub in required_substrings if sub not in error_message]
        if missing:
            raise AssertionError(
                f"error message for ({desc}) is missing {missing!r}: "
                f"{error_message!r}"
            )
        return
    raise AssertionError(f"given function ({desc}) didn't raise an error")
|
|
|
|
|
|
# Each entry is (description, callable expected to raise, substrings that must
# appear in its error message). Entries are checked in order, and the first
# failing check stops the script.
_STORAGE_ERROR_CASES = (
    ("Wrong argument types", lambda: torch.FloatStorage(object()), ("object",)),
    (
        "Unknown keyword argument",
        lambda: torch.FloatStorage(content=1234.0),
        ("keyword",),
    ),
    (
        "Invalid types inside a sequence",
        lambda: torch.FloatStorage(["a", "b"]),
        ("list", "str"),
    ),
    ("Invalid size type", lambda: torch.FloatStorage(1.5), ("float",)),
    (
        "Invalid offset",
        lambda: torch.FloatStorage(torch.FloatStorage(2), 4),
        ("2", "4"),
    ),
    (
        "Negative offset",
        lambda: torch.FloatStorage(torch.FloatStorage(2), -1),
        ("2", "-1"),
    ),
    (
        "Invalid size",
        lambda: torch.FloatStorage(torch.FloatStorage(3), 1, 5),
        ("2", "1", "5"),
    ),
    (
        "Negative size",
        lambda: torch.FloatStorage(torch.FloatStorage(3), 1, -5),
        ("2", "1", "-5"),
    ),
    (
        "Invalid index type",
        lambda: torch.FloatStorage(10)["first item"],
        ("str",),
    ),
)

for _desc, _fn, _substrings in _STORAGE_ERROR_CASES:
    check_error(_desc, _fn, *_substrings)
|
|
|
|
|
|
def assign():
    """Slice-assign a ``str`` into a fresh 10-element FloatStorage.

    This is written as a named function because an assignment statement
    cannot appear inside a lambda. It is expected to raise.
    """
    storage = torch.FloatStorage(10)
    storage[1:-1] = "1"
|
|
|
|
|
|
# Each entry is (description, callable expected to raise, substrings that must
# appear in its error message). Entries are checked in order, and the first
# failing check stops the script.
_STORAGE_MUTATION_CASES = (
    ("Invalid value type", assign, ("str",)),
    (
        "resize_ with invalid type",
        lambda: torch.FloatStorage(10).resize_(1.5),
        ("float",),
    ),
    (
        "fill_ with invalid type",
        lambda: torch.IntStorage(10).fill_("asdf"),
        ("str",),
    ),
)

for _desc, _fn, _substrings in _STORAGE_MUTATION_CASES:
    check_error(_desc, _fn, *_substrings)

# TODO: frombuffer
|