diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
index bd2fcee5e51c..29671bc99931 100644
--- a/tools/pyi/gen_pyi.py
+++ b/tools/pyi/gen_pyi.py
@@ -177,14 +177,18 @@ blocklist = [
     "copy_",
 ]
 
-binary_ops = (
+shift_ops = (
+    "lshift",
+    "rshift",
+    "ilshift",
+    "irshift",  # inplace ops
+)
+arithmetic_ops = (
     "add",
     "sub",
     "mul",
     "div",
     "pow",
-    "lshift",
-    "rshift",
     "mod",
     "truediv",
     "matmul",
@@ -195,24 +199,26 @@ binary_ops = (
     "rtruediv",
     "rfloordiv",
     "rpow",  # reverse arithmetic
+    "iadd",
+    "idiv",
+    "imul",
+    "isub",
+    "ifloordiv",
+    "imod",  # inplace ops
+)
+logic_ops = (
     "and",
     "or",
     "xor",
     "rand",
     "ror",
-    "rxor",  # logic
-    "iadd",
+    "rxor",  # reverse logic
     "iand",
-    "idiv",
-    "ilshift",
-    "imul",
     "ior",
-    "irshift",
-    "isub",
-    "ixor",
-    "ifloordiv",
-    "imod",  # inplace ops
+    "ixor",  # inplace ops
 )
+binary_ops = shift_ops + arithmetic_ops + logic_ops
+
 symmetric_comparison_ops = ("eq", "ne")
 asymmetric_comparison_ops = ("ge", "gt", "lt", "le")
 comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops
@@ -232,14 +238,28 @@ def sig_for_ops(opname: str) -> list[str]:
     assert opname.endswith("__") and opname.startswith("__"), f"Unexpected op {opname}"
 
     name = opname[2:-2]
-    if name in binary_ops:
-        return [f"def {opname}(self, other: Any) -> Tensor: ..."]
-    elif name in comparison_ops:
-        sig = f"def {opname}(self, other: Any) -> Tensor: ..."
-        if name in symmetric_comparison_ops:
+    if name == "rpow":
+        return [  # somehow required to make mypy ci happy?
+            f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ...  # type: ignore[has-type]"
+        ]
+    elif name in arithmetic_ops:
+        return [
+            f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ..."
+        ]
+    elif name in logic_ops:
+        return [f"def {opname}(self, other: Union[Tensor, _bool]) -> Tensor: ..."]
+    elif name in shift_ops:
+        return [f"def {opname}(self, other: Union[Tensor, _int]) -> Tensor: ..."]
+    elif name in symmetric_comparison_ops:
+        return [
             # unsafe override https://github.com/python/mypy/issues/5704
-            sig += "  # type: ignore[override]"
-        return [sig]
+            f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ...  # type: ignore[override]",
+            f"def {opname}(self, other: Any) -> _bool: ...",
+        ]
+    elif name in asymmetric_comparison_ops:
+        return [
+            f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ..."
+        ]
     elif name in unary_ops:
         return [f"def {opname}(self) -> Tensor: ..."]
     elif name in to_py_type_ops:
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 8822a3840aac..359cc3c06a07 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -2291,7 +2291,8 @@ def native_batch_norm_backward(
     mean = save_mean_cast
     invstd = save_invstd_cast
     if train:
-        assert save_mean_cast is not None and save_invstd_cast is not None
+        assert mean is not None and invstd is not None
+
     else:
         assert running_mean_cast is not None and running_var_cast is not None
         mean = running_mean_cast
diff --git a/torch/_inductor/fx_passes/efficient_conv_bn_eval.py b/torch/_inductor/fx_passes/efficient_conv_bn_eval.py
index 4845142caab3..bc6ebbcd5cef 100644
--- a/torch/_inductor/fx_passes/efficient_conv_bn_eval.py
+++ b/torch/_inductor/fx_passes/efficient_conv_bn_eval.py
@@ -33,6 +33,7 @@ def efficient_conv_bn_eval(
     """
 
     assert bn.running_var is not None
+    assert bn.running_mean is not None
 
     # These lines of code are designed to deal with various cases
     # like bn without affine transform, and conv without bias
diff --git a/torch/ao/quantization/_equalize.py b/torch/ao/quantization/_equalize.py
index 08316f755552..57a4cfdead29 100644
--- a/torch/ao/quantization/_equalize.py
+++ b/torch/ao/quantization/_equalize.py
@@ -128,8 +128,7 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
             "module type not supported:", type(module1), " ", type(module2)
         )
 
-    conv1_has_bias = has_bias(module1)
-    bias = None
+    bias = get_module_bias(module1) if has_bias(module1) else None
 
     weight1 = get_module_weight(module1)
     weight2 = get_module_weight(module2)
@@ -140,9 +139,6 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
             number input channels of second arg"
         )
 
-    if conv1_has_bias:
-        bias = get_module_bias(module1)
-
     weight1_range = channel_range(weight1, output_axis)
     weight2_range = channel_range(weight2, input_axis)
 
@@ -151,7 +147,7 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
     scaling_factors = torch.sqrt(weight1_range / weight2_range)
    inverse_scaling_factors = torch.reciprocal(scaling_factors)
 
-    if conv1_has_bias:
+    if bias is not None:
         bias = bias * inverse_scaling_factors
 
     # formatting the scaling (1D) tensors to be applied on the given argument tensors
@@ -168,7 +164,7 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
     weight1 = weight1 * scaling_factors
     weight2 = weight2 * scaling_factors
 
     set_module_weight(module1, weight1)
-    if conv1_has_bias:
+    if bias is not None:
         set_module_bias(module1, bias)
     set_module_weight(module2, weight2)
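Note (not part of the patch): a minimal standalone sketch of the new per-category dispatch in sig_for_ops, using a hypothetical stub_lines helper and trimmed-down op tuples; it only replays the f-string templates from the gen_pyi.py hunk above to show which stub line each kind of dunder now gets.

# Sketch only: mirrors the new per-category dispatch from sig_for_ops
# to print the stub line emitted for a few representative dunders.
shift_ops = ("lshift", "rshift", "ilshift", "irshift")
arithmetic_ops = ("add", "sub", "mul", "div", "pow", "mod", "truediv", "matmul")
logic_ops = ("and", "or", "xor", "rand", "ror", "rxor", "iand", "ior", "ixor")


def stub_lines(opname: str) -> list[str]:
    name = opname[2:-2]  # strip the leading and trailing "__"
    if name in arithmetic_ops:
        # arithmetic accepts tensors, Python numbers, and complex scalars
        return [f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ..."]
    if name in logic_ops:
        # logic ops accept tensors and bools
        return [f"def {opname}(self, other: Union[Tensor, _bool]) -> Tensor: ..."]
    if name in shift_ops:
        # shifts accept tensors and ints
        return [f"def {opname}(self, other: Union[Tensor, _int]) -> Tensor: ..."]
    return []


for op in ("__add__", "__and__", "__lshift__"):
    print(*stub_lines(op), sep="\n")
# def __add__(self, other: Union[Tensor, Number, _complex]) -> Tensor: ...
# def __and__(self, other: Union[Tensor, _bool]) -> Tensor: ...
# def __lshift__(self, other: Union[Tensor, _int]) -> Tensor: ...

Compared with the old blanket "other: Any" signature, splitting binary_ops into shift, arithmetic, and logic tuples lets the generated stubs reject, say, a float right-hand side for __lshift__ while still accepting complex scalars for arithmetic.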