mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Validate input types for torch.nn.Linear
and torch.nn.Bilinear
(#135596)
Adding validation checks to check the input types and display better error messages for the same. Fixes https://github.com/pytorch/pytorch/issues/135463 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135596 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
b897ab0540
commit
e157ce3ebb
@ -98,6 +98,13 @@ class Linear(Module):
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
# validation checks for input types
|
||||
if not isinstance(in_features, int):
|
||||
raise TypeError(f"Expected int for in_features but got {type(in_features)}")
|
||||
if not isinstance(out_features, int):
|
||||
raise TypeError(
|
||||
f"Expected int for out_features but got {type(out_features)}"
|
||||
)
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
@ -200,6 +207,19 @@ class Bilinear(Module):
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
# validation checks for input types
|
||||
if not isinstance(in1_features, int):
|
||||
raise TypeError(
|
||||
f"Expected int for in1_features but got {type(in1_features)}"
|
||||
)
|
||||
if not isinstance(in2_features, int):
|
||||
raise TypeError(
|
||||
f"Expected int for in2_features but got {type(in2_features)}"
|
||||
)
|
||||
if not isinstance(out_features, int):
|
||||
raise TypeError(
|
||||
f"Expected int for out_features but got {type(out_features)}"
|
||||
)
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.in1_features = in1_features
|
||||
|
@ -3180,6 +3180,66 @@ rnn_gru_lstm_module_info_decorators = (
|
||||
|
||||
# Start of module error inputs functions.
|
||||
|
||||
def module_error_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, training, **kwargs):
|
||||
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
samples = [
|
||||
ErrorModuleInput(
|
||||
ModuleInput(
|
||||
constructor_input=FunctionInput("10", 20),
|
||||
forward_input=FunctionInput(make_input(3, 10)),
|
||||
),
|
||||
error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
|
||||
error_type=TypeError,
|
||||
error_regex=r"Expected int for in_features but got <class 'str'>"
|
||||
),
|
||||
ErrorModuleInput(
|
||||
ModuleInput(
|
||||
constructor_input=FunctionInput(10, 20.7),
|
||||
forward_input=FunctionInput(make_input(3, 10)),
|
||||
),
|
||||
error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
|
||||
error_type=TypeError,
|
||||
error_regex=r"Expected int for out_features but got <class 'float'>"
|
||||
),
|
||||
]
|
||||
return samples
|
||||
|
||||
|
||||
def module_error_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, training, **kwargs):
|
||||
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
samples = [
|
||||
ErrorModuleInput(
|
||||
ModuleInput(
|
||||
constructor_input=FunctionInput("10", 20, 30),
|
||||
forward_input=FunctionInput(make_input(3, 10), make_input(3, 20)),
|
||||
),
|
||||
error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
|
||||
error_type=TypeError,
|
||||
error_regex=r"Expected int for in1_features but got <class 'str'>"
|
||||
),
|
||||
ErrorModuleInput(
|
||||
ModuleInput(
|
||||
constructor_input=FunctionInput(10, 20.7, 30),
|
||||
forward_input=FunctionInput(make_input(3, 10), make_input(3, 20)),
|
||||
),
|
||||
error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
|
||||
error_type=TypeError,
|
||||
error_regex=r"Expected int for in2_features but got <class 'float'>"
|
||||
),
|
||||
ErrorModuleInput(
|
||||
ModuleInput(
|
||||
constructor_input=FunctionInput(10, 20, "30"),
|
||||
forward_input=FunctionInput(make_input(3, 10), make_input(3, 20)),
|
||||
),
|
||||
error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
|
||||
error_type=TypeError,
|
||||
error_regex=r"Expected int for out_features but got <class 'str'>"
|
||||
),
|
||||
]
|
||||
return samples
|
||||
|
||||
|
||||
|
||||
def module_error_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs):
|
||||
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
samples = [
|
||||
@ -3832,12 +3892,14 @@ module_db: List[ModuleInfo] = [
|
||||
)),
|
||||
ModuleInfo(torch.nn.Linear,
|
||||
module_inputs_func=module_inputs_torch_nn_Linear,
|
||||
module_error_inputs_func=module_error_inputs_torch_nn_Linear,
|
||||
skips=(
|
||||
# No channels_last support for Linear currently.
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
|
||||
),
|
||||
ModuleInfo(torch.nn.Bilinear,
|
||||
module_inputs_func=module_inputs_torch_nn_Bilinear,
|
||||
module_error_inputs_func=module_error_inputs_torch_nn_Bilinear,
|
||||
decorators=[
|
||||
DecorateInfo(
|
||||
toleranceOverride({
|
||||
|
Reference in New Issue
Block a user