diff --git a/pyrefly.toml b/pyrefly.toml index 8abdef644671..d4146bf88d4a 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -25,10 +25,6 @@ project-excludes = [ "torch/nn/**", "torch/_dynamo/**", "torch/utils/**", - "torch/ao/**", - "torch/fx/**", - "torch/distributions/**", - "torch/onnx/**", # formatting issues "torch/linalg/__init__.py", "torch/package/importer.py", diff --git a/torch/_export/utils.py b/torch/_export/utils.py index f8f247b56a52..939160a48815 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -470,7 +470,6 @@ def _check_input_constraints_for_graph( ) elif isinstance(node_val, torch.SymInt): _check_symint( - # pyrefly: ignore # bad-argument-type node_val, # pyrefly: ignore # bad-argument-type arg, diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index 34218efe3a31..bdbf56e357b4 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -360,7 +360,6 @@ def trace_flex_attention( "call_function", flex_attention, proxy_args, {} ) return track_tensor_tree( - # pyrefly: ignore # bad-argument-type example_out, out_proxy, constant=None, @@ -1080,7 +1079,6 @@ def trace_flex_attention_backward( name="flex_attention_backward", ) return track_tensor_tree( - # pyrefly: ignore # bad-argument-type example_out, out_proxy, constant=None, diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 5a5a04c27c31..27be6dca3e50 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -899,7 +899,6 @@ def analyze_kernel_mutations( if op.name == "tt.call": assert op.fn_call_name in functions mutations = analyze_kernel_mutations( - # pyrefly: ignore # bad-argument-type functions, # pyrefly: ignore # bad-argument-type op.fn_call_name, diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index e40218dd6a58..0b6869622303 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -3379,7 +3379,6 @@ def native_layer_norm( torch._check( input.ndim >= normalized_ndim and sym_eq( - # pyrefly: ignore # bad-argument-type input.shape[(input.ndim - normalized_ndim) :], # pyrefly: ignore # bad-argument-type tuple(normalized_shape), diff --git a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py index 6671e317b6b0..8fb64a101ff6 100644 --- a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -620,6 +620,7 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule): dilation=dilation, groups=groups, bias=bias, + # pyrefly: ignore # bad-argument-type padding_mode=padding_mode, qconfig=qconfig, ) @@ -820,6 +821,7 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule): dilation=dilation, groups=groups, bias=bias, + # pyrefly: ignore # bad-argument-type padding_mode=padding_mode, qconfig=qconfig, ) @@ -1021,6 +1023,7 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule): dilation=dilation, groups=groups, bias=bias, + # pyrefly: ignore # bad-argument-type padding_mode=padding_mode, qconfig=qconfig, ) diff --git a/torch/ao/nn/intrinsic/qat/modules/linear_relu.py b/torch/ao/nn/intrinsic/qat/modules/linear_relu.py index 8446468dddcf..075e1411a45f 100644 --- a/torch/ao/nn/intrinsic/qat/modules/linear_relu.py +++ b/torch/ao/nn/intrinsic/qat/modules/linear_relu.py @@ -36,6 +36,7 @@ class LinearReLU(nnqat.Linear, _FusedModule): torch.Size([128, 30]) """ + # pyrefly: ignore # bad-override 
_FLOAT_MODULE = nni.LinearReLU def __init__( diff --git a/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py b/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py index a9566b268f08..2458c32ae4f3 100644 --- a/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py +++ b/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py @@ -30,6 +30,7 @@ class LinearReLU(nnqd.Linear): torch.Size([128, 30]) """ + # pyrefly: ignore # bad-override _FLOAT_MODULE = nni.LinearReLU def __init__( diff --git a/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py b/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py index 8172004d95fc..c8024c5b4c58 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py +++ b/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py @@ -54,6 +54,7 @@ class ConvReLU1d(nnq.Conv1d): dilation=dilation, groups=groups, bias=bias, + # pyrefly: ignore # bad-argument-type padding_mode=padding_mode, device=device, dtype=dtype, diff --git a/torch/ao/nn/qat/modules/conv.py b/torch/ao/nn/qat/modules/conv.py index 4a193fa6763c..846ba2b4fd26 100644 --- a/torch/ao/nn/qat/modules/conv.py +++ b/torch/ao/nn/qat/modules/conv.py @@ -114,6 +114,7 @@ class _ConvNd(nn.modules.conv._ConvNd): assert hasattr(cls, "_FLOAT_RELU_MODULE") relu = cls._FLOAT_RELU_MODULE() modules.append(relu) + # pyrefly: ignore # missing-attribute fused = cls._FLOAT_MODULE(*modules) fused.train(self.training) return fused diff --git a/torch/ao/nn/qat/modules/embedding_ops.py b/torch/ao/nn/qat/modules/embedding_ops.py index 13fd7a5983fb..d199e6d46f74 100644 --- a/torch/ao/nn/qat/modules/embedding_ops.py +++ b/torch/ao/nn/qat/modules/embedding_ops.py @@ -50,6 +50,7 @@ class Embedding(nn.Embedding): scale_grad_by_freq, sparse, _weight, + # pyrefly: ignore # bad-argument-type **factory_kwargs, ) assert qconfig, "qconfig must be provided for QAT module" diff --git a/torch/ao/nn/quantizable/modules/activation.py b/torch/ao/nn/quantizable/modules/activation.py index d9f5e4ff4c86..23ef0f6a8d37 100644 --- a/torch/ao/nn/quantizable/modules/activation.py +++ b/torch/ao/nn/quantizable/modules/activation.py @@ -170,8 +170,11 @@ class MultiheadAttention(nn.MultiheadAttention): observed.linear_K.weight = nn.Parameter(other.k_proj_weight) observed.linear_V.weight = nn.Parameter(other.v_proj_weight) if other.in_proj_bias is None: + # pyrefly: ignore # bad-assignment observed.linear_Q.bias = None + # pyrefly: ignore # bad-assignment observed.linear_K.bias = None + # pyrefly: ignore # bad-assignment observed.linear_V.bias = None else: observed.linear_Q.bias = nn.Parameter( @@ -234,6 +237,7 @@ class MultiheadAttention(nn.MultiheadAttention): _end = _start + fp.embed_dim fp.in_proj_weight[_start:_end, :] = wQ if fp.in_proj_bias is not None: + # pyrefly: ignore # bad-argument-type assert all(bQ == 0) fp.in_proj_bias[_start:_end] = bQ @@ -241,12 +245,14 @@ class MultiheadAttention(nn.MultiheadAttention): _end = _start + fp.embed_dim fp.in_proj_weight[_start:_end, :] = wK if fp.in_proj_bias is not None: + # pyrefly: ignore # bad-argument-type assert all(bK == 0) fp.in_proj_bias[_start:_end] = bK _start = _end fp.in_proj_weight[_start:, :] = wV if fp.in_proj_bias is not None: + # pyrefly: ignore # bad-argument-type assert all(bV == 0) fp.in_proj_bias[_start:] = bV else: @@ -254,8 +260,11 @@ class MultiheadAttention(nn.MultiheadAttention): fp.k_proj_weight = nn.Parameter(wK) fp.v_proj_weight = nn.Parameter(wV) if fp.in_proj_bias is None: + # pyrefly: ignore # bad-assignment self.linear_Q.bias = None + # 
pyrefly: ignore # bad-assignment self.linear_K.bias = None + # pyrefly: ignore # bad-assignment self.linear_V.bias = None else: fp.in_proj_bias[0 : fp.embed_dim] = bQ @@ -463,6 +472,7 @@ class MultiheadAttention(nn.MultiheadAttention): assert static_v.size(2) == head_dim v = static_v + # pyrefly: ignore # missing-attribute src_len = k.size(1) if key_padding_mask is not None: @@ -471,17 +481,35 @@ class MultiheadAttention(nn.MultiheadAttention): if self.add_zero_attn: src_len += 1 + # pyrefly: ignore # missing-attribute k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:]) + # pyrefly: ignore # missing-attribute if k.is_quantized: k_zeros = torch.quantize_per_tensor( - k_zeros, k.q_scale(), k.q_zero_point(), k.dtype + k_zeros, + # pyrefly: ignore # missing-attribute + k.q_scale(), + # pyrefly: ignore # missing-attribute + k.q_zero_point(), + # pyrefly: ignore # missing-attribute + k.dtype, ) + # pyrefly: ignore # no-matching-overload k = torch.cat([k, k_zeros], dim=1) + # pyrefly: ignore # missing-attribute v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:]) + # pyrefly: ignore # missing-attribute if v.is_quantized: v_zeros = torch.quantize_per_tensor( - v_zeros, v.q_scale(), v.q_zero_point(), v.dtype + v_zeros, + # pyrefly: ignore # missing-attribute + v.q_scale(), + # pyrefly: ignore # missing-attribute + v.q_zero_point(), + # pyrefly: ignore # missing-attribute + v.dtype, ) + # pyrefly: ignore # no-matching-overload v = torch.cat([v, v_zeros], dim=1) if attn_mask is not None: diff --git a/torch/ao/nn/quantizable/modules/rnn.py b/torch/ao/nn/quantizable/modules/rnn.py index ad32cf174c62..7588d83e746a 100644 --- a/torch/ao/nn/quantizable/modules/rnn.py +++ b/torch/ao/nn/quantizable/modules/rnn.py @@ -376,6 +376,7 @@ class _LSTMLayer(torch.nn.Module): bidirectional, split_gates=split_gates, ) + # pyrefly: ignore # bad-argument-type layer.qconfig = getattr(other, "qconfig", qconfig) wi = getattr(other, f"weight_ih_l{layer_idx}") wh = getattr(other, f"weight_hh_l{layer_idx}") @@ -454,6 +455,7 @@ class LSTM(torch.nn.Module): if ( not isinstance(dropout, numbers.Number) + # pyrefly: ignore # unsupported-operation or not 0 <= dropout <= 1 or isinstance(dropout, bool) ): @@ -462,6 +464,7 @@ class LSTM(torch.nn.Module): "representing the probability of an element being " "zeroed" ) + # pyrefly: ignore # unsupported-operation if dropout > 0: warnings.warn( "dropout option for quantizable LSTM is ignored. 
" @@ -573,6 +576,7 @@ class LSTM(torch.nn.Module): other.bidirectional, split_gates=split_gates, ) + # pyrefly: ignore # bad-argument-type observed.qconfig = getattr(other, "qconfig", qconfig) for idx in range(other.num_layers): observed.layers[idx] = _LSTMLayer.from_float( diff --git a/torch/ao/nn/quantized/dynamic/modules/conv.py b/torch/ao/nn/quantized/dynamic/modules/conv.py index a079f31f62e4..1f8a65fe9d66 100644 --- a/torch/ao/nn/quantized/dynamic/modules/conv.py +++ b/torch/ao/nn/quantized/dynamic/modules/conv.py @@ -73,6 +73,7 @@ class Conv1d(nnq.Conv1d): factory_kwargs = {"device": device, "dtype": dtype} kernel_size = _single(kernel_size) stride = _single(stride) + # pyrefly: ignore # bad-assignment padding = padding if isinstance(padding, str) else _single(padding) dilation = _single(dilation) diff --git a/torch/ao/nn/quantized/dynamic/modules/linear.py b/torch/ao/nn/quantized/dynamic/modules/linear.py index 0faaf62cedb5..6fa5ee65f3b5 100644 --- a/torch/ao/nn/quantized/dynamic/modules/linear.py +++ b/torch/ao/nn/quantized/dynamic/modules/linear.py @@ -119,7 +119,9 @@ class Linear(nnq.Linear): assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" if type(mod) == nni.LinearReLU: mod = mod[0] + # pyrefly: ignore # missing-attribute if mod.qconfig is not None and mod.qconfig.weight is not None: + # pyrefly: ignore # not-callable weight_observer = mod.qconfig.weight() else: # We have the circular import issues if we import the qconfig in the beginning of this file: @@ -143,6 +145,7 @@ class Linear(nnq.Linear): "Unsupported dtype specified for dynamic quantized Linear!" ) qlinear = cls(mod.in_features, mod.out_features, dtype=dtype) + # pyrefly: ignore # bad-argument-type qlinear.set_weight_bias(qweight, mod.bias) return qlinear diff --git a/torch/ao/nn/quantized/dynamic/modules/rnn.py b/torch/ao/nn/quantized/dynamic/modules/rnn.py index 10db59aafbf7..cdfecc95d723 100644 --- a/torch/ao/nn/quantized/dynamic/modules/rnn.py +++ b/torch/ao/nn/quantized/dynamic/modules/rnn.py @@ -521,6 +521,7 @@ class LSTM(RNNBase): >>> output, (hn, cn) = rnn(input, (h0, c0)) """ + # pyrefly: ignore # bad-override _FLOAT_MODULE = nn.LSTM __overloads__ = {"forward": ["forward_packed", "forward_tensor"]} @@ -806,6 +807,7 @@ class GRU(RNNBase): >>> output, hn = rnn(input, h0) """ + # pyrefly: ignore # bad-override _FLOAT_MODULE = nn.GRU __overloads__ = {"forward": ["forward_packed", "forward_tensor"]} diff --git a/torch/ao/nn/quantized/modules/activation.py b/torch/ao/nn/quantized/modules/activation.py index 15b4d36e8b44..67b69eb7390c 100644 --- a/torch/ao/nn/quantized/modules/activation.py +++ b/torch/ao/nn/quantized/modules/activation.py @@ -67,7 +67,9 @@ class Hardswish(torch.nn.Hardswish): def __init__(self, scale, zero_point, device=None, dtype=None): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() + # pyrefly: ignore # bad-argument-type self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + # pyrefly: ignore # bad-argument-type self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) def forward(self, input): @@ -138,7 +140,9 @@ class LeakyReLU(torch.nn.LeakyReLU): ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__(negative_slope, inplace) + # pyrefly: ignore # bad-argument-type self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + # pyrefly: ignore # bad-argument-type self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) def forward(self, input): 
@@ -226,6 +230,7 @@ class Softmax(torch.nn.Softmax): class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention): + # pyrefly: ignore # bad-override _FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention def _get_name(self): diff --git a/torch/ao/nn/quantized/modules/batchnorm.py b/torch/ao/nn/quantized/modules/batchnorm.py index 069db116a064..bd426038657c 100644 --- a/torch/ao/nn/quantized/modules/batchnorm.py +++ b/torch/ao/nn/quantized/modules/batchnorm.py @@ -12,7 +12,9 @@ class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm): ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__(num_features, eps, momentum, True, True, **factory_kwargs) + # pyrefly: ignore # bad-argument-type self.register_buffer("scale", torch.tensor(1.0, **factory_kwargs)) + # pyrefly: ignore # bad-argument-type self.register_buffer("zero_point", torch.tensor(0, **factory_kwargs)) @staticmethod diff --git a/torch/ao/nn/quantized/modules/conv.py b/torch/ao/nn/quantized/modules/conv.py index 592c5893d113..1bec74975f8a 100644 --- a/torch/ao/nn/quantized/modules/conv.py +++ b/torch/ao/nn/quantized/modules/conv.py @@ -408,6 +408,7 @@ class Conv1d(_ConvNd): factory_kwargs = {"device": device, "dtype": dtype} kernel_size = _single(kernel_size) stride = _single(stride) + # pyrefly: ignore # bad-assignment padding = padding if isinstance(padding, str) else _single(padding) dilation = _single(dilation) diff --git a/torch/ao/nn/quantized/modules/linear.py b/torch/ao/nn/quantized/modules/linear.py index 9042833f5e30..54d3adf19e00 100644 --- a/torch/ao/nn/quantized/modules/linear.py +++ b/torch/ao/nn/quantized/modules/linear.py @@ -310,6 +310,7 @@ class Linear(WeightedQuantizedModule): # the type mismatch in assignment. Also, mypy has an issue with # iterables not being implemented, so we are ignoring those too. 
if not isinstance(cls._FLOAT_MODULE, Iterable): + # pyrefly: ignore # bad-assignment cls._FLOAT_MODULE = [cls._FLOAT_MODULE] supported_modules = ", ".join( [float_mod.__name__ for float_mod in cls._FLOAT_MODULE] diff --git a/torch/ao/nn/quantized/modules/normalization.py b/torch/ao/nn/quantized/modules/normalization.py index 4db2ac6e928f..d5c6c4d41c5c 100644 --- a/torch/ao/nn/quantized/modules/normalization.py +++ b/torch/ao/nn/quantized/modules/normalization.py @@ -37,11 +37,14 @@ class LayerNorm(torch.nn.LayerNorm): normalized_shape, eps=eps, elementwise_affine=elementwise_affine, + # pyrefly: ignore # bad-argument-type **factory_kwargs, ) self.weight = weight self.bias = bias + # pyrefly: ignore # bad-argument-type self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + # pyrefly: ignore # bad-argument-type self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) def forward(self, input): @@ -113,7 +116,9 @@ class GroupNorm(torch.nn.GroupNorm): super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs) self.weight = weight self.bias = bias + # pyrefly: ignore # bad-argument-type self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + # pyrefly: ignore # bad-argument-type self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) def forward(self, input): @@ -175,7 +180,9 @@ class InstanceNorm1d(torch.nn.InstanceNorm1d): ) self.weight = weight self.bias = bias + # pyrefly: ignore # bad-argument-type self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + # pyrefly: ignore # bad-argument-type self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) def forward(self, input): @@ -242,7 +249,9 @@ class InstanceNorm2d(torch.nn.InstanceNorm2d): ) self.weight = weight self.bias = bias + # pyrefly: ignore # bad-argument-type self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + # pyrefly: ignore # bad-argument-type self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) def forward(self, input): @@ -309,7 +318,9 @@ class InstanceNorm3d(torch.nn.InstanceNorm3d): ) self.weight = weight self.bias = bias + # pyrefly: ignore # bad-argument-type self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + # pyrefly: ignore # bad-argument-type self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) def forward(self, input): diff --git a/torch/ao/nn/quantized/reference/modules/conv.py b/torch/ao/nn/quantized/reference/modules/conv.py index de2ea9c6da8d..1e9cbceb7c12 100644 --- a/torch/ao/nn/quantized/reference/modules/conv.py +++ b/torch/ao/nn/quantized/reference/modules/conv.py @@ -95,6 +95,7 @@ class Conv1d(_ConvNd, nn.Conv1d): and the backend should be able to fuse the ops with `*` into a quantized conv1d """ weight_quant_dequant = self.get_weight() + # pyrefly: ignore # no-matching-overload result = F.conv1d( x, weight_quant_dequant, @@ -140,6 +141,7 @@ class Conv2d(_ConvNd, nn.Conv2d): dilation, groups, bias, + # pyrefly: ignore # bad-argument-type padding_mode, device, dtype, @@ -158,6 +160,7 @@ class Conv2d(_ConvNd, nn.Conv2d): and the backend should be able to fuse the ops with `*` into a quantized conv2d """ weight_quant_dequant = self.get_weight() + # pyrefly: ignore # no-matching-overload result = F.conv2d( x, weight_quant_dequant, @@ -203,6 +206,7 @@ class Conv3d(_ConvNd, nn.Conv3d): dilation, groups, bias, + # pyrefly: ignore # bad-argument-type padding_mode, device, dtype, @@ -221,6 +225,7 @@ 
class Conv3d(_ConvNd, nn.Conv3d): and the backend should be able to fuse the ops with `*` into a quantized conv3d """ weight_quant_dequant = self.get_weight() + # pyrefly: ignore # no-matching-overload result = F.conv3d( x, weight_quant_dequant, @@ -378,6 +383,7 @@ class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d): groups, bias, dilation, + # pyrefly: ignore # bad-argument-type padding_mode, device, dtype, @@ -459,6 +465,7 @@ class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d): groups, bias, dilation, + # pyrefly: ignore # bad-argument-type padding_mode, device, dtype, diff --git a/torch/ao/nn/quantized/reference/modules/rnn.py b/torch/ao/nn/quantized/reference/modules/rnn.py index adb1356cb3d3..f7b9c447e9aa 100644 --- a/torch/ao/nn/quantized/reference/modules/rnn.py +++ b/torch/ao/nn/quantized/reference/modules/rnn.py @@ -663,7 +663,11 @@ class LSTM(RNNBase): # xxx: isinstance check needs to be in conditional for TorchScript to compile if isinstance(orig_input, PackedSequence): output_packed = PackedSequence( - output, batch_sizes, sorted_indices, unsorted_indices + output, + # pyrefly: ignore # bad-argument-type + batch_sizes, + sorted_indices, + unsorted_indices, ) return output_packed, self.permute_hidden(hidden, unsorted_indices) else: @@ -823,7 +827,11 @@ class GRU(RNNBase): # xxx: isinstance check needs to be in conditional for TorchScript to compile if isinstance(orig_input, PackedSequence): output_packed = PackedSequence( - output, batch_sizes, sorted_indices, unsorted_indices + output, + # pyrefly: ignore # bad-argument-type + batch_sizes, + sorted_indices, + unsorted_indices, ) return output_packed, self.permute_hidden(hidden, unsorted_indices) else: diff --git a/torch/ao/nn/quantized/reference/modules/sparse.py b/torch/ao/nn/quantized/reference/modules/sparse.py index 7e4bdb9b02c7..23d85ac46d09 100644 --- a/torch/ao/nn/quantized/reference/modules/sparse.py +++ b/torch/ao/nn/quantized/reference/modules/sparse.py @@ -42,6 +42,7 @@ class Embedding(nn.Embedding, ReferenceQuantizedModule): scale_grad_by_freq, sparse, _weight, + # pyrefly: ignore # bad-argument-type device, dtype, ) diff --git a/torch/ao/nn/quantized/reference/modules/utils.py b/torch/ao/nn/quantized/reference/modules/utils.py index 653e688c4d17..aaa13274678b 100644 --- a/torch/ao/nn/quantized/reference/modules/utils.py +++ b/torch/ao/nn/quantized/reference/modules/utils.py @@ -18,6 +18,7 @@ class ReferenceQuantizedModule(torch.nn.Module): "scale": 1.0, "zero_point": 0, } + # pyrefly: ignore # bad-assignment self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"] self.weight_dtype = weight_qparams["dtype"] assert self.weight_qscheme in [ @@ -80,13 +81,16 @@ class ReferenceQuantizedModule(torch.nn.Module): self.register_buffer( "weight_axis", torch.tensor(0, dtype=torch.int, device=device) ) + # pyrefly: ignore # bad-assignment self.is_decomposed: bool = weight_qparams.get("is_decomposed", False) # store weight_axis as weight_axis_int due to some constraints of torchdynamo.export # for capturing `.item` operations self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment] + # pyrefly: ignore # bad-assignment self.weight_quant_min: typing.Optional[int] = weight_qparams.get( "quant_min", None ) + # pyrefly: ignore # bad-assignment self.weight_quant_max: typing.Optional[int] = weight_qparams.get( "quant_max", None ) @@ -105,6 +109,7 @@ class ReferenceQuantizedModule(torch.nn.Module): return _quantize_and_dequantize_weight_decomposed( self.weight, # type: 
ignore[arg-type] self.weight_qscheme, + # pyrefly: ignore # bad-argument-type self.weight_dtype, self.weight_scale, self.weight_zero_point, @@ -116,6 +121,7 @@ class ReferenceQuantizedModule(torch.nn.Module): return _quantize_and_dequantize_weight( self.weight, # type: ignore[arg-type] self.weight_qscheme, + # pyrefly: ignore # bad-argument-type self.weight_dtype, self.weight_scale, self.weight_zero_point, @@ -131,6 +137,7 @@ class ReferenceQuantizedModule(torch.nn.Module): return _quantize_weight_decomposed( self.weight, # type: ignore[arg-type] self.weight_qscheme, + # pyrefly: ignore # bad-argument-type self.weight_dtype, self.weight_scale, self.weight_zero_point, @@ -142,6 +149,7 @@ class ReferenceQuantizedModule(torch.nn.Module): return _quantize_weight( self.weight, # type: ignore[arg-type] self.weight_qscheme, + # pyrefly: ignore # bad-argument-type self.weight_dtype, self.weight_scale, self.weight_zero_point, diff --git a/torch/ao/nn/sparse/quantized/dynamic/linear.py b/torch/ao/nn/sparse/quantized/dynamic/linear.py index 6da18e151012..5ae9a9227dbb 100644 --- a/torch/ao/nn/sparse/quantized/dynamic/linear.py +++ b/torch/ao/nn/sparse/quantized/dynamic/linear.py @@ -151,7 +151,9 @@ class Linear(torch.nn.Module): assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" if type(mod) == nni.LinearReLU: mod = mod[0] + # pyrefly: ignore # missing-attribute if mod.qconfig is not None and mod.qconfig.weight is not None: + # pyrefly: ignore # not-callable weight_observer = mod.qconfig.weight() else: # We have the circular import issues if we import the qconfig in the beginning of this file: @@ -185,5 +187,6 @@ class Linear(torch.nn.Module): col_block_size, dtype=dtype, ) + # pyrefly: ignore # bad-argument-type qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size) return qlinear diff --git a/torch/ao/ns/fx/graph_matcher.py b/torch/ao/ns/fx/graph_matcher.py index 1f9c873971a3..ea65dff4eb54 100644 --- a/torch/ao/ns/fx/graph_matcher.py +++ b/torch/ao/ns/fx/graph_matcher.py @@ -84,6 +84,7 @@ class _NSGraphMatchableSubgraphsIterator: if is_match: # navigate to the base node for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1): + # pyrefly: ignore # bad-argument-type self.seen_nodes.add(cur_start_node) # for now, assume that there are no other nodes # which need to be added to the stack @@ -94,8 +95,10 @@ class _NSGraphMatchableSubgraphsIterator: cur_base_op_node = cur_start_node break + # pyrefly: ignore # bad-argument-type self.seen_nodes.add(cur_start_node) # add args of previous nodes to stack + # pyrefly: ignore # missing-attribute for arg in cur_start_node.all_input_nodes: self._recursively_add_node_arg_to_stack(arg) @@ -103,6 +106,7 @@ class _NSGraphMatchableSubgraphsIterator: # note: this check is done on the start_node, i.e. 
# if we are matching linear-relu in reverse, this would do the matchable # check on the linear + # pyrefly: ignore # bad-argument-type if not self._is_matchable(cur_base_op_node): continue @@ -116,8 +120,10 @@ class _NSGraphMatchableSubgraphsIterator: continue return NSSubgraph( + # pyrefly: ignore # bad-argument-type start_node=cur_start_node, end_node=cur_end_node, + # pyrefly: ignore # bad-argument-type base_op_node=cur_base_op_node, ) diff --git a/torch/ao/ns/fx/mappings.py b/torch/ao/ns/fx/mappings.py index a8ca955d22fa..49aa0a2c995a 100644 --- a/torch/ao/ns/fx/mappings.py +++ b/torch/ao/ns/fx/mappings.py @@ -415,6 +415,7 @@ def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]: target2, ) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items(): new_connections.append((source, target1)) + # pyrefly: ignore # bad-argument-type new_connections.append((source, target2)) for source_to_target in ( @@ -423,6 +424,7 @@ def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]: quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS, ): for source, target in source_to_target.items(): # type:ignore[assignment] + # pyrefly: ignore # bad-argument-type new_connections.append((source, target)) # diff --git a/torch/ao/ns/fx/n_shadows_utils.py b/torch/ao/ns/fx/n_shadows_utils.py index 5d8b569036ff..3b2453d8cc28 100644 --- a/torch/ao/ns/fx/n_shadows_utils.py +++ b/torch/ao/ns/fx/n_shadows_utils.py @@ -95,6 +95,7 @@ class OutputProp: if isinstance(result, torch.Tensor): # type: ignore[possibly-undefined] node.traced_result = result + # pyrefly: ignore # unsupported-operation env[node.name] = result return None @@ -393,8 +394,10 @@ def create_submodule_from_subgraph( cur_name_idx += 1 setattr(gm, mod_name, new_arg) new_arg_placeholder = gm.placeholder(mod_name) # type: ignore[operator] + # pyrefly: ignore # missing-attribute cur_args_copy.append(new_arg_placeholder) elif isinstance(arg, (float, int, torch.dtype)): + # pyrefly: ignore # missing-attribute cur_args_copy.append(arg) else: raise AssertionError(f"arg of type {type(arg)} not handled yet") @@ -801,6 +804,7 @@ def create_add_loggers_graph( model, cur_subgraph_idx, match_name, + # pyrefly: ignore # bad-argument-type maybe_subgraph, [qconfig_mapping], [node_name_to_qconfig], @@ -857,6 +861,7 @@ def create_add_loggers_graph( cur_node_orig = first_node cur_node_copy = None first_node_copy = None + # pyrefly: ignore # bad-assignment while cur_node_orig in subgraph_to_use: # TODO(future PR): make this support all possible args/kwargs if cur_node_orig is first_node: diff --git a/torch/ao/ns/fx/utils.py b/torch/ao/ns/fx/utils.py index b6357120dc14..d1c5f062a6c1 100644 --- a/torch/ao/ns/fx/utils.py +++ b/torch/ao/ns/fx/utils.py @@ -404,6 +404,7 @@ def maybe_add_missing_fqns(results: NSResultsType) -> None: for model_name, model_results in model_name_to_results.items(): if model_name == model_name_with_fqns: continue + # pyrefly: ignore # bad-assignment for i in range(len(model_results)): fqn = ref_model_results[i]["fqn"] model_results[i]["fqn"] = fqn @@ -467,6 +468,7 @@ def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tenso Return: float or tuple of floats """ + # pyrefly: ignore # unsupported-operation return torch.sqrt(((x - y) ** 2).sum() / (x**2).sum()) diff --git a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py index 2bfaac1cef49..1c0dfd502bea 100644 --- 
a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py @@ -23,12 +23,17 @@ class SparseDLRM(DLRM_Net): super().__init__(**args) def forward(self, dense_x, lS_o, lS_i): + # pyrefly: ignore # missing-attribute x = self.apply_mlp(dense_x, self.bot_l) # dense features + # pyrefly: ignore # missing-attribute ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l) # apply embedding bag + # pyrefly: ignore # missing-attribute z = self.interact_features(x, ly) z = z.to_sparse_coo() + # pyrefly: ignore # missing-attribute z = torch.mm(z, self.top_l[0].weight.T).add(self.top_l[0].bias) + # pyrefly: ignore # missing-attribute for layer in self.top_l[1:]: z = layer(z) diff --git a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py index a716d91bbb8e..b2e44a5ed249 100644 --- a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py @@ -72,6 +72,7 @@ class FPGMPruner(BaseStructuredSparsifier): dist_matrix = self.dist_fn(t_flatten) # more similar with other filter indicates large in the sum of row + # pyrefly: ignore # bad-argument-type distance = torch.sum(torch.abs(dist_matrix), 1) return distance diff --git a/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py b/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py index 826ad95bf63b..33ecf08b79ed 100644 --- a/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py +++ b/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py @@ -260,6 +260,7 @@ class BaseStructuredSparsifier(BaseSparsifier): module.register_parameter( "_bias", nn.Parameter(module.bias.detach()) ) + # pyrefly: ignore # bad-assignment module.bias = None module.prune_bias = prune_bias diff --git a/torch/ao/pruning/_experimental/pruner/prune_functions.py b/torch/ao/pruning/_experimental/pruner/prune_functions.py index d5e4b7823dc4..143a1f844ba6 100644 --- a/torch/ao/pruning/_experimental/pruner/prune_functions.py +++ b/torch/ao/pruning/_experimental/pruner/prune_functions.py @@ -97,6 +97,7 @@ def _propagate_module_bias(module: nn.Module, mask: Tensor) -> Optional[Tensor]: if module.bias is not None: module.bias = nn.Parameter(cast(Tensor, module.bias)[mask]) elif getattr(module, "_bias", None) is not None: + # pyrefly: ignore # bad-assignment module.bias = nn.Parameter(cast(Tensor, module._bias)[mask]) # get pruned biases to propagate to subsequent layer diff --git a/torch/ao/pruning/sparsifier/base_sparsifier.py b/torch/ao/pruning/sparsifier/base_sparsifier.py index 73d4c283da63..a4ffac985631 100644 --- a/torch/ao/pruning/sparsifier/base_sparsifier.py +++ b/torch/ao/pruning/sparsifier/base_sparsifier.py @@ -170,6 +170,7 @@ class BaseSparsifier(abc.ABC): self.make_config_from_model(model) # TODO: Remove the configuration by reference ('module') + # pyrefly: ignore # not-iterable for module_config in self.config: assert isinstance(module_config, dict), ( "config elements should be dicts not modules i.e.:" diff --git a/torch/ao/pruning/sparsifier/utils.py b/torch/ao/pruning/sparsifier/utils.py index 47185aeea527..de7e400757bc 100644 --- a/torch/ao/pruning/sparsifier/utils.py +++ b/torch/ao/pruning/sparsifier/utils.py @@ -51,6 +51,7 @@ def swap_module( new_mod.register_forward_hook(hook_fn) # respect device affinity when swapping modules + # pyrefly: ignore # bad-argument-type devices = {p.device for p in chain(mod.parameters(), 
mod.buffers())} assert len(devices) <= 1, ( f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" diff --git a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py index c3541ac83ca3..a78659094ac9 100644 --- a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py +++ b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py @@ -235,6 +235,7 @@ class WeightNormSparsifier(BaseSparsifier): ww = self.norm_fn(getattr(module, tensor_name)) tensor_mask = self._make_tensor_mask( data=ww, + # pyrefly: ignore # missing-attribute input_shape=ww.shape, sparsity_level=sparsity_level, sparse_block_shape=sparse_block_shape, diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index 3bf5d82f1909..be6de56c717f 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -24,6 +24,8 @@ from .pt2e.export_utils import ( _move_exported_model_to_eval as move_exported_model_to_eval, _move_exported_model_to_train as move_exported_model_to_train, ) + +# pyrefly: ignore # deprecated from .qconfig import * # noqa: F403 from .qconfig_mapping import * # noqa: F403 from .quant_type import * # noqa: F403 diff --git a/torch/ao/quantization/experimental/adaround_optimization.py b/torch/ao/quantization/experimental/adaround_optimization.py index 9082d6c0f99c..7ce4047ca1ba 100644 --- a/torch/ao/quantization/experimental/adaround_optimization.py +++ b/torch/ao/quantization/experimental/adaround_optimization.py @@ -127,6 +127,7 @@ class AdaptiveRoundingOptimizer: @torch.no_grad() def feed_forward(self, x, weight, module): if isinstance(module, torch.nn.Conv1d): + # pyrefly: ignore # no-matching-overload out = torch.nn.functional.conv1d( x, weight, diff --git a/torch/ao/quantization/fake_quantize.py b/torch/ao/quantization/fake_quantize.py index 7c7ef597df07..c47629428d61 100644 --- a/torch/ao/quantization/fake_quantize.py +++ b/torch/ao/quantization/fake_quantize.py @@ -185,7 +185,9 @@ class FakeQuantize(FakeQuantizeBase): dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get( "dtype", dtype ) + # pyrefly: ignore # bad-argument-type assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound" + # pyrefly: ignore # bad-argument-type assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound" observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max}) observer_kwargs["is_dynamic"] = is_dynamic diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py index 1c4517b93c7f..160e9aa3afef 100644 --- a/torch/ao/quantization/fx/_decomposed.py +++ b/torch/ao/quantization/fx/_decomposed.py @@ -1149,6 +1149,7 @@ quantized_decomposed_lib.define( class FakeQuantPerChannel(torch.autograd.Function): @staticmethod + # pyrefly: ignore # bad-override def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max): if scales.dtype != torch.float32: scales = scales.to(torch.float32) @@ -1171,6 +1172,7 @@ class FakeQuantPerChannel(torch.autograd.Function): return out @staticmethod + # pyrefly: ignore # bad-override def backward(ctx, gy): (mask,) = ctx.saved_tensors return gy * mask, None, None, None, None, None diff --git a/torch/ao/quantization/fx/_equalize.py b/torch/ao/quantization/fx/_equalize.py index 822d261ffc32..71563c236aab 100644 --- a/torch/ao/quantization/fx/_equalize.py +++ b/torch/ao/quantization/fx/_equalize.py @@ -246,6 +246,7 @@ def calculate_equalization_scale( class EqualizationQConfig( + # 
pyrefly: ignore # invalid-inheritance namedtuple("EqualizationQConfig", ["input_activation", "weight"]) ): """ @@ -460,6 +461,7 @@ def maybe_get_next_equalization_scale( In this case, the node given is linear1 and we want to locate the InputEqObs. """ next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules) + # pyrefly: ignore # invalid-argument if next_inp_eq_obs: if ( next_inp_eq_obs.equalization_scale.nelement() == 1 @@ -821,13 +823,18 @@ def convert_eq_obs( # Scale the weight nodes if node.op == "call_module": scale_weight_node( - node, modules, equalization_scale, maybe_next_equalization_scale + node, + modules, + # pyrefly: ignore # bad-argument-type + equalization_scale, + maybe_next_equalization_scale, ) elif node.op == "call_function": scale_weight_functional( node, model, modules, + # pyrefly: ignore # bad-argument-type equalization_scale, maybe_next_equalization_scale, ) diff --git a/torch/ao/quantization/fx/_model_report/model_report_visualizer.py b/torch/ao/quantization/fx/_model_report/model_report_visualizer.py index 60bc65cd1500..6d446bac7fe8 100644 --- a/torch/ao/quantization/fx/_model_report/model_report_visualizer.py +++ b/torch/ao/quantization/fx/_model_report/model_report_visualizer.py @@ -223,6 +223,7 @@ class ModelReportVisualizer: feature_val = feature_val.item() # we add to our list of values + # pyrefly: ignore # bad-argument-type tensor_table_row.append(feature_val) tensor_table.append(tensor_table_row) @@ -283,6 +284,7 @@ class ModelReportVisualizer: feature_val = feature_val.item() # add value to channel specific row + # pyrefly: ignore # bad-argument-type new_channel_row.append(feature_val) # add to table and increment row index counter diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index ff85e4505f03..acac09bd3258 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -166,6 +166,7 @@ def _create_obs_or_fq_from_qspec( } edge_or_nodes = quantization_spec.derived_from obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes] + # pyrefly: ignore # unsupported-operation kwargs["obs_or_fqs"] = obs_or_fqs return _DerivedObserverOrFakeQuantize.with_args(**kwargs)() elif isinstance(quantization_spec, FixedQParamsQuantizationSpec): @@ -2085,8 +2086,11 @@ def prepare( root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config) + # pyrefly: ignore # bad-argument-type _update_qconfig_for_fusion(model, qconfig_mapping) + # pyrefly: ignore # bad-argument-type _update_qconfig_for_fusion(model, _equalization_config) + # pyrefly: ignore # bad-argument-type flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping) # TODO: support regex as well propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict()) @@ -2094,6 +2098,7 @@ def prepare( if is_qat: module_to_qat_module = get_module_to_qat_module(backend_config) _qat_swap_modules(model, module_to_qat_module) + # pyrefly: ignore # bad-argument-type _update_qconfig_for_qat(qconfig_mapping, backend_config) # mapping from fully qualified module name to module instance @@ -2107,10 +2112,20 @@ def prepare( # fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches equalization_node_name_to_qconfig = _generate_node_name_to_qconfig( - model, named_modules, model.graph, _equalization_config, node_name_to_scope + model, + named_modules, + model.graph, + # pyrefly: ignore # bad-argument-type + _equalization_config, + node_name_to_scope, ) node_name_to_qconfig = 
_generate_node_name_to_qconfig( - model, named_modules, model.graph, qconfig_mapping, node_name_to_scope + model, + named_modules, + model.graph, + # pyrefly: ignore # bad-argument-type + qconfig_mapping, + node_name_to_scope, ) # match the patterns that will get quantized @@ -2170,6 +2185,7 @@ def prepare( node_name_to_scope, prepare_custom_config, equalization_node_name_to_qconfig, + # pyrefly: ignore # bad-argument-type qconfig_mapping, is_qat, observed_node_names, diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index 6e68bfd4648e..4010312a0b4d 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -720,6 +720,7 @@ def _maybe_get_custom_module_lstm_from_node_arg( a = a.args[0][0] # type: ignore[assignment,index] else: a = a.args[0] # type: ignore[assignment] + # pyrefly: ignore # bad-return return a all_match_patterns = [ diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index d8975cc3571d..2f404dcd1a42 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -280,9 +280,12 @@ class UniformQuantizationObserverBase(ObserverBase): ) self.has_customized_qrange = (quant_min is not None) and (quant_max is not None) if self.has_customized_qrange: + # pyrefly: ignore # bad-argument-type validate_qmin_qmax(quant_min, quant_max) self.quant_min, self.quant_max = calculate_qmin_qmax( + # pyrefly: ignore # bad-argument-type quant_min, + # pyrefly: ignore # bad-argument-type quant_max, self.has_customized_qrange, self.dtype, diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index 3a90bb953f17..63cd49cb6983 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -72,6 +72,7 @@ def _find_q_dq_node_for_user( dq_node = n break if dq_node is None: + # pyrefly: ignore # bad-assignment for n in user.kwargs: if ( isinstance(n, torch.fx.Node) diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index 94dfdb8c7626..623fd12434ab 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -83,6 +83,7 @@ __all__ = [ ] +# pyrefly: ignore # invalid-inheritance class QConfig(namedtuple("QConfig", ["activation", "weight"])): """ Describes how to quantize a layer or a part of the network by providing @@ -120,6 +121,7 @@ class QConfig(namedtuple("QConfig", ["activation", "weight"])): "`QConfigDynamic` is going to be deprecated in PyTorch 1.12, please use `QConfig` instead", category=FutureWarning, ) +# pyrefly: ignore # invalid-inheritance class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])): """ Describes how to dynamically quantize a layer or a part of the network by providing diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index a8637e1668c1..8ca3e91af97e 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -417,6 +417,7 @@ class X86InductorQuantizer(Quantizer): # As we use `_need_skip_config` to skip all invalid configurations, # we can safely assume that the all existing non-None configurations # have the same quantization mode. 
+ # pyrefly: ignore # bad-assignment for qconfig in ( list(self.module_name_qconfig.values()) + list(self.operator_type_qconfig.values()) @@ -808,6 +809,7 @@ class X86InductorQuantizer(Quantizer): ) binary_node.meta[QUANT_ANNOTATION_KEY] = ( _X86InductorQuantizationAnnotation( + # pyrefly: ignore # bad-argument-type input_qspec_map=binary_node_input_qspec_map, _annotated=True, ) @@ -878,6 +880,7 @@ class X86InductorQuantizer(Quantizer): ) binary_node.meta[QUANT_ANNOTATION_KEY] = ( _X86InductorQuantizationAnnotation( + # pyrefly: ignore # bad-argument-type input_qspec_map=binary_node_input_qspec_map, # TODO Remove the annotate of output in QAT when qat util support pattern matcher. output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] @@ -1085,6 +1088,7 @@ class X86InductorQuantizer(Quantizer): quantization_config ) binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + # pyrefly: ignore # bad-argument-type input_qspec_map=binary_node_input_qspec_map, _annotated=True, ) @@ -1139,6 +1143,7 @@ class X86InductorQuantizer(Quantizer): quantization_config ) binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + # pyrefly: ignore # bad-argument-type input_qspec_map=binary_node_input_qspec_map, _annotated=True, _is_output_of_quantized_pattern=True, @@ -1499,6 +1504,7 @@ class X86InductorQuantizer(Quantizer): has_unary = unary_op is not None seq_partition = [torch.nn.Linear, binary_op] if has_unary: + # pyrefly: ignore # bad-argument-type seq_partition.append(unary_op) fused_partitions = find_sequential_partitions(gm, seq_partition) for fused_partition in fused_partitions: diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py index dec59bb02df3..d6cd477ea047 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py @@ -376,9 +376,11 @@ def _do_annotate_conv_relu( input_qspec_map[bias] = get_bias_qspec(quantization_config) partition.append(bias) + # pyrefly: ignore # bad-argument-type if _is_annotated(partition): continue + # pyrefly: ignore # bad-argument-type if filter_fn and any(not filter_fn(n) for n in partition): continue @@ -389,6 +391,7 @@ def _do_annotate_conv_relu( output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] _annotated=True, ) + # pyrefly: ignore # bad-argument-type _mark_nodes_as_annotated(partition) annotated_partitions.append(partition) return annotated_partitions diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py index 74fedc929783..44446b4ffda8 100644 --- a/torch/distributions/bernoulli.py +++ b/torch/distributions/bernoulli.py @@ -39,6 +39,7 @@ class Bernoulli(ExponentialFamily): validate_args (bool, optional): whether to validate arguments, None by default """ + # pyrefly: ignore # bad-override arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.boolean has_enumerate_support = True @@ -56,10 +57,12 @@ class Bernoulli(ExponentialFamily): ) if probs is not None: is_scalar = isinstance(probs, _Number) + # pyrefly: ignore # read-only (self.probs,) = broadcast_all(probs) else: assert logits is not None # helps mypy is_scalar = isinstance(logits, _Number) + # pyrefly: ignore # read-only (self.logits,) = broadcast_all(logits) self._param = self.probs if probs is not None else self.logits if is_scalar: @@ -137,5 +140,6 @@ class 
Bernoulli(ExponentialFamily): def _natural_params(self) -> tuple[Tensor]: return (torch.logit(self.probs),) + # pyrefly: ignore # bad-override def _log_normalizer(self, x): return torch.log1p(torch.exp(x)) diff --git a/torch/distributions/beta.py b/torch/distributions/beta.py index e06a28ca5aa4..0cab53613079 100644 --- a/torch/distributions/beta.py +++ b/torch/distributions/beta.py @@ -31,6 +31,7 @@ class Beta(ExponentialFamily): (often referred to as beta) """ + # pyrefly: ignore # bad-override arg_constraints = { "concentration1": constraints.positive, "concentration0": constraints.positive, @@ -113,5 +114,6 @@ class Beta(ExponentialFamily): def _natural_params(self) -> tuple[Tensor, Tensor]: return (self.concentration1, self.concentration0) + # pyrefly: ignore # bad-override def _log_normalizer(self, x, y): return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y) diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py index 90461784c06d..b400b9861407 100644 --- a/torch/distributions/binomial.py +++ b/torch/distributions/binomial.py @@ -45,6 +45,7 @@ class Binomial(Distribution): logits (Tensor): Event log-odds """ + # pyrefly: ignore # bad-override arg_constraints = { "total_count": constraints.nonnegative_integer, "probs": constraints.unit_interval, @@ -66,6 +67,7 @@ class Binomial(Distribution): if probs is not None: ( self.total_count, + # pyrefly: ignore # read-only self.probs, ) = broadcast_all(total_count, probs) self.total_count = self.total_count.type_as(self.probs) @@ -73,6 +75,7 @@ class Binomial(Distribution): assert logits is not None # helps mypy ( self.total_count, + # pyrefly: ignore # read-only self.logits, ) = broadcast_all(total_count, logits) self.total_count = self.total_count.type_as(self.logits) @@ -99,6 +102,7 @@ class Binomial(Distribution): return self._param.new(*args, **kwargs) @constraints.dependent_property(is_discrete=True, event_dim=0) + # pyrefly: ignore # bad-override def support(self): return constraints.integer_interval(0, self.total_count) diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py index 1c8fed2636ad..7e083f802206 100644 --- a/torch/distributions/categorical.py +++ b/torch/distributions/categorical.py @@ -50,6 +50,7 @@ class Categorical(Distribution): logits (Tensor): event log probabilities (unnormalized) """ + # pyrefly: ignore # bad-override arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} has_enumerate_support = True @@ -66,12 +67,14 @@ class Categorical(Distribution): if probs is not None: if probs.dim() < 1: raise ValueError("`probs` parameter must be at least one-dimensional.") + # pyrefly: ignore # read-only self.probs = probs / probs.sum(-1, keepdim=True) else: assert logits is not None # helps mypy if logits.dim() < 1: raise ValueError("`logits` parameter must be at least one-dimensional.") # Normalize + # pyrefly: ignore # read-only self.logits = logits - logits.logsumexp(dim=-1, keepdim=True) self._param = self.probs if probs is not None else self.logits self._num_events = self._param.size()[-1] @@ -99,6 +102,7 @@ class Categorical(Distribution): return self._param.new(*args, **kwargs) @constraints.dependent_property(is_discrete=True, event_dim=0) + # pyrefly: ignore # bad-override def support(self): return constraints.integer_interval(0, self._num_events - 1) diff --git a/torch/distributions/cauchy.py b/torch/distributions/cauchy.py index 84c1d34bda79..39b9885b237c 100644 --- a/torch/distributions/cauchy.py +++ 
b/torch/distributions/cauchy.py @@ -31,6 +31,7 @@ class Cauchy(Distribution): scale (float or Tensor): half width at half maximum. """ + # pyrefly: ignore # bad-override arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.real has_rsample = True diff --git a/torch/distributions/continuous_bernoulli.py b/torch/distributions/continuous_bernoulli.py index 14d0d6a9c177..d949a19d3f77 100644 --- a/torch/distributions/continuous_bernoulli.py +++ b/torch/distributions/continuous_bernoulli.py @@ -47,6 +47,7 @@ class ContinuousBernoulli(ExponentialFamily): https://arxiv.org/abs/1907.06845 """ + # pyrefly: ignore # bad-override arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.unit_interval _mean_carrier_measure = 0 @@ -65,16 +66,19 @@ class ContinuousBernoulli(ExponentialFamily): ) if probs is not None: is_scalar = isinstance(probs, _Number) + # pyrefly: ignore # read-only (self.probs,) = broadcast_all(probs) # validate 'probs' here if necessary as it is later clamped for numerical stability # close to 0 and 1, later on; otherwise the clamped 'probs' would always pass if validate_args is not None: if not self.arg_constraints["probs"].check(self.probs).all(): raise ValueError("The parameter probs has invalid values") + # pyrefly: ignore # read-only self.probs = clamp_probs(self.probs) else: assert logits is not None # helps mypy is_scalar = isinstance(logits, _Number) + # pyrefly: ignore # read-only (self.logits,) = broadcast_all(logits) self._param = self.probs if probs is not None else self.logits if is_scalar: @@ -230,6 +234,7 @@ class ContinuousBernoulli(ExponentialFamily): def _natural_params(self) -> tuple[Tensor]: return (self.logits,) + # pyrefly: ignore # bad-override def _log_normalizer(self, x): """computes the log normalizing constant as a function of the natural parameter""" out_unst_reg = torch.max( diff --git a/torch/distributions/dirichlet.py b/torch/distributions/dirichlet.py index 414ad6efe47e..0f2a656ac21d 100644 --- a/torch/distributions/dirichlet.py +++ b/torch/distributions/dirichlet.py @@ -22,6 +22,7 @@ def _Dirichlet_backward(x, concentration, grad_output): class _Dirichlet(Function): @staticmethod + # pyrefly: ignore # bad-override def forward(ctx, concentration): x = torch._sample_dirichlet(concentration) ctx.save_for_backward(x, concentration) @@ -29,6 +30,7 @@ class _Dirichlet(Function): @staticmethod @once_differentiable + # pyrefly: ignore # bad-override def backward(ctx, grad_output): x, concentration = ctx.saved_tensors return _Dirichlet_backward(x, concentration, grad_output) @@ -50,6 +52,7 @@ class Dirichlet(ExponentialFamily): (often referred to as alpha) """ + # pyrefly: ignore # bad-override arg_constraints = { "concentration": constraints.independent(constraints.positive, 1) } @@ -130,5 +133,6 @@ class Dirichlet(ExponentialFamily): def _natural_params(self) -> tuple[Tensor]: return (self.concentration,) + # pyrefly: ignore # bad-override def _log_normalizer(self, x): return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1)) diff --git a/torch/distributions/exponential.py b/torch/distributions/exponential.py index d15cb1f7a258..3630d4158b37 100644 --- a/torch/distributions/exponential.py +++ b/torch/distributions/exponential.py @@ -27,6 +27,7 @@ class Exponential(ExponentialFamily): rate (float or Tensor): rate = 1 / scale of the distribution """ + # pyrefly: ignore # bad-override arg_constraints = {"rate": constraints.positive} support = constraints.nonnegative 
has_rsample = True @@ -89,5 +90,6 @@ class Exponential(ExponentialFamily): def _natural_params(self) -> tuple[Tensor]: return (-self.rate,) + # pyrefly: ignore # bad-override def _log_normalizer(self, x): return -torch.log(-x) diff --git a/torch/distributions/fishersnedecor.py b/torch/distributions/fishersnedecor.py index 4755bd0d8bde..a329d68c61e6 100644 --- a/torch/distributions/fishersnedecor.py +++ b/torch/distributions/fishersnedecor.py @@ -29,6 +29,7 @@ class FisherSnedecor(Distribution): df2 (float or Tensor): degrees of freedom parameter 2 """ + # pyrefly: ignore # bad-override arg_constraints = {"df1": constraints.positive, "df2": constraints.positive} support = constraints.positive has_rsample = True diff --git a/torch/distributions/gamma.py b/torch/distributions/gamma.py index 9df91ebee640..67086674714c 100644 --- a/torch/distributions/gamma.py +++ b/torch/distributions/gamma.py @@ -34,6 +34,7 @@ class Gamma(ExponentialFamily): (often referred to as beta), rate = 1 / scale """ + # pyrefly: ignore # bad-override arg_constraints = { "concentration": constraints.positive, "rate": constraints.positive, @@ -109,6 +110,7 @@ class Gamma(ExponentialFamily): def _natural_params(self) -> tuple[Tensor, Tensor]: return (self.concentration - 1, -self.rate) + # pyrefly: ignore # bad-override def _log_normalizer(self, x, y): return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal()) diff --git a/torch/distributions/generalized_pareto.py b/torch/distributions/generalized_pareto.py index 4ee0a54b608f..218faacfdb60 100644 --- a/torch/distributions/generalized_pareto.py +++ b/torch/distributions/generalized_pareto.py @@ -35,6 +35,7 @@ class GeneralizedPareto(Distribution): concentration (float or Tensor): Concentration parameter of the distribution """ + # pyrefly: ignore # bad-override arg_constraints = { "loc": constraints.real, "scale": constraints.positive, @@ -130,6 +131,7 @@ class GeneralizedPareto(Distribution): concentration = self.concentration valid = concentration < 0.5 safe_conc = torch.where(valid, concentration, 0.25) + # pyrefly: ignore # unsupported-operation result = self.scale**2 / ((1 - safe_conc) ** 2 * (1 - 2 * safe_conc)) return torch.where(valid, result, nan) @@ -142,6 +144,7 @@ class GeneralizedPareto(Distribution): return self.loc @constraints.dependent_property(is_discrete=False, event_dim=0) + # pyrefly: ignore # bad-override def support(self): lower = self.loc upper = torch.where( diff --git a/torch/distributions/geometric.py b/torch/distributions/geometric.py index b5ceac39e94e..4fb6b534b1bf 100644 --- a/torch/distributions/geometric.py +++ b/torch/distributions/geometric.py @@ -44,6 +44,7 @@ class Geometric(Distribution): logits (Number, Tensor): the log-odds of sampling `1`. """ + # pyrefly: ignore # bad-override arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.nonnegative_integer @@ -58,9 +59,11 @@ class Geometric(Distribution): "Either `probs` or `logits` must be specified, but not both." 
) if probs is not None: + # pyrefly: ignore # read-only (self.probs,) = broadcast_all(probs) else: assert logits is not None # helps mypy + # pyrefly: ignore # read-only (self.logits,) = broadcast_all(logits) probs_or_logits = probs if probs is not None else logits if isinstance(probs_or_logits, _Number): diff --git a/torch/distributions/gumbel.py b/torch/distributions/gumbel.py index 6d097c9324e2..8057d9718de6 100644 --- a/torch/distributions/gumbel.py +++ b/torch/distributions/gumbel.py @@ -32,6 +32,7 @@ class Gumbel(TransformedDistribution): """ arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + # pyrefly: ignore # bad-override support = constraints.real def __init__( diff --git a/torch/distributions/half_cauchy.py b/torch/distributions/half_cauchy.py index 572ae080ac3e..a2848f3d0cdd 100644 --- a/torch/distributions/half_cauchy.py +++ b/torch/distributions/half_cauchy.py @@ -32,8 +32,10 @@ class HalfCauchy(TransformedDistribution): """ arg_constraints = {"scale": constraints.positive} + # pyrefly: ignore # bad-override support = constraints.nonnegative has_rsample = True + # pyrefly: ignore # bad-override base_dist: Cauchy def __init__( diff --git a/torch/distributions/half_normal.py b/torch/distributions/half_normal.py index 21e1b9d2c506..0aac8852e6e1 100644 --- a/torch/distributions/half_normal.py +++ b/torch/distributions/half_normal.py @@ -32,8 +32,10 @@ class HalfNormal(TransformedDistribution): """ arg_constraints = {"scale": constraints.positive} + # pyrefly: ignore # bad-override support = constraints.nonnegative has_rsample = True + # pyrefly: ignore # bad-override base_dist: Normal def __init__( diff --git a/torch/distributions/independent.py b/torch/distributions/independent.py index b66406681bb8..b901a7caab58 100644 --- a/torch/distributions/independent.py +++ b/torch/distributions/independent.py @@ -91,6 +91,7 @@ class Independent(Distribution, Generic[D]): return self.base_dist.has_enumerate_support @constraints.dependent_property + # pyrefly: ignore # bad-override def support(self): result = self.base_dist.support if self.reinterpreted_batch_ndims: diff --git a/torch/distributions/inverse_gamma.py b/torch/distributions/inverse_gamma.py index de432a34434e..1be089e5331e 100644 --- a/torch/distributions/inverse_gamma.py +++ b/torch/distributions/inverse_gamma.py @@ -38,8 +38,10 @@ class InverseGamma(TransformedDistribution): "concentration": constraints.positive, "rate": constraints.positive, } + # pyrefly: ignore # bad-override support = constraints.positive has_rsample = True + # pyrefly: ignore # bad-override base_dist: Gamma def __init__( diff --git a/torch/distributions/kumaraswamy.py b/torch/distributions/kumaraswamy.py index 53c09ab9870d..03fe9d6e3712 100644 --- a/torch/distributions/kumaraswamy.py +++ b/torch/distributions/kumaraswamy.py @@ -44,6 +44,7 @@ class Kumaraswamy(TransformedDistribution): "concentration1": constraints.positive, "concentration0": constraints.positive, } + # pyrefly: ignore # bad-override support = constraints.unit_interval has_rsample = True @@ -66,6 +67,7 @@ class Kumaraswamy(TransformedDistribution): AffineTransform(loc=1.0, scale=-1.0), PowerTransform(exponent=self.concentration1.reciprocal()), ] + # pyrefly: ignore # bad-argument-type super().__init__(base_dist, transforms, validate_args=validate_args) def expand(self, batch_shape, _instance=None): diff --git a/torch/distributions/laplace.py b/torch/distributions/laplace.py index 0d50712fb26f..01f51edc0546 100644 --- a/torch/distributions/laplace.py +++ 
b/torch/distributions/laplace.py @@ -28,6 +28,7 @@ class Laplace(Distribution): scale (float or Tensor): scale of the distribution """ + # pyrefly: ignore # bad-override arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.real has_rsample = True diff --git a/torch/distributions/lkj_cholesky.py b/torch/distributions/lkj_cholesky.py index f3fc4b20751e..3f1e6b98b6fe 100644 --- a/torch/distributions/lkj_cholesky.py +++ b/torch/distributions/lkj_cholesky.py @@ -60,6 +60,7 @@ class LKJCholesky(Distribution): Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008 """ + # pyrefly: ignore # bad-override arg_constraints = {"concentration": constraints.positive} support = constraints.corr_cholesky diff --git a/torch/distributions/log_normal.py b/torch/distributions/log_normal.py index 2c6dbc6bf55c..675c58ab2e64 100644 --- a/torch/distributions/log_normal.py +++ b/torch/distributions/log_normal.py @@ -32,8 +32,10 @@ class LogNormal(TransformedDistribution): """ arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + # pyrefly: ignore # bad-override support = constraints.positive has_rsample = True + # pyrefly: ignore # bad-override base_dist: Normal def __init__( diff --git a/torch/distributions/logistic_normal.py b/torch/distributions/logistic_normal.py index 729e3a67419f..14ef668a72cf 100644 --- a/torch/distributions/logistic_normal.py +++ b/torch/distributions/logistic_normal.py @@ -36,8 +36,10 @@ class LogisticNormal(TransformedDistribution): """ arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + # pyrefly: ignore # bad-override support = constraints.simplex has_rsample = True + # pyrefly: ignore # bad-override base_dist: Independent[Normal] def __init__( diff --git a/torch/distributions/lowrank_multivariate_normal.py b/torch/distributions/lowrank_multivariate_normal.py index 968e4634ba62..27270ee59cec 100644 --- a/torch/distributions/lowrank_multivariate_normal.py +++ b/torch/distributions/lowrank_multivariate_normal.py @@ -86,6 +86,7 @@ class LowRankMultivariateNormal(Distribution): capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor """ + # pyrefly: ignore # bad-override arg_constraints = { "loc": constraints.real_vector, "cov_factor": constraints.independent(constraints.real, 2), diff --git a/torch/distributions/mixture_same_family.py b/torch/distributions/mixture_same_family.py index 3fe47a4b4c6b..3ab2be132e9b 100644 --- a/torch/distributions/mixture_same_family.py +++ b/torch/distributions/mixture_same_family.py @@ -124,6 +124,7 @@ class MixtureSameFamily(Distribution): return new @constraints.dependent_property + # pyrefly: ignore # bad-override def support(self): return MixtureSameFamilyConstraint(self._component_distribution.support) diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py index 41d8ded53fd6..58e7f6734b0d 100644 --- a/torch/distributions/multinomial.py +++ b/torch/distributions/multinomial.py @@ -50,6 +50,7 @@ class Multinomial(Distribution): logits (Tensor): event log probabilities (unnormalized) """ + # pyrefly: ignore # bad-override arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} total_count: int @@ -92,6 +93,7 @@ class Multinomial(Distribution): return self._categorical._new(*args, **kwargs) @constraints.dependent_property(is_discrete=True, event_dim=1) + # pyrefly: ignore # bad-override def support(self): return constraints.multinomial(self.total_count) diff --git 
a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py index c15a84815b06..1cf701901819 100644 --- a/torch/distributions/multivariate_normal.py +++ b/torch/distributions/multivariate_normal.py @@ -123,6 +123,7 @@ class MultivariateNormal(Distribution): the corresponding lower triangular matrices using a Cholesky decomposition. """ + # pyrefly: ignore # bad-override arg_constraints = { "loc": constraints.real_vector, "covariance_matrix": constraints.positive_definite, @@ -156,6 +157,7 @@ class MultivariateNormal(Distribution): "with optional leading batch dimensions" ) batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1]) + # pyrefly: ignore # read-only self.scale_tril = scale_tril.expand(batch_shape + (-1, -1)) elif covariance_matrix is not None: if covariance_matrix.dim() < 2: @@ -166,6 +168,7 @@ class MultivariateNormal(Distribution): batch_shape = torch.broadcast_shapes( covariance_matrix.shape[:-2], loc.shape[:-1] ) + # pyrefly: ignore # read-only self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1)) else: assert precision_matrix is not None # helps mypy @@ -177,6 +180,7 @@ class MultivariateNormal(Distribution): batch_shape = torch.broadcast_shapes( precision_matrix.shape[:-2], loc.shape[:-1] ) + # pyrefly: ignore # read-only self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1)) self.loc = loc.expand(batch_shape + (-1,)) diff --git a/torch/distributions/negative_binomial.py b/torch/distributions/negative_binomial.py index f28222f92f78..a743c318f419 100644 --- a/torch/distributions/negative_binomial.py +++ b/torch/distributions/negative_binomial.py @@ -33,6 +33,7 @@ class NegativeBinomial(Distribution): logits (Tensor): Event log-odds for probabilities of success """ + # pyrefly: ignore # bad-override arg_constraints = { "total_count": constraints.greater_than_eq(0), "probs": constraints.half_open_interval(0.0, 1.0), @@ -54,6 +55,7 @@ class NegativeBinomial(Distribution): if probs is not None: ( self.total_count, + # pyrefly: ignore # read-only self.probs, ) = broadcast_all(total_count, probs) self.total_count = self.total_count.type_as(self.probs) @@ -61,6 +63,7 @@ class NegativeBinomial(Distribution): assert logits is not None # helps mypy ( self.total_count, + # pyrefly: ignore # read-only self.logits, ) = broadcast_all(total_count, logits) self.total_count = self.total_count.type_as(self.logits) diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py index 626358d14795..cc391f4afacc 100644 --- a/torch/distributions/normal.py +++ b/torch/distributions/normal.py @@ -31,6 +31,7 @@ class Normal(ExponentialFamily): (often referred to as sigma) """ + # pyrefly: ignore # bad-override arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.real has_rsample = True @@ -88,6 +89,7 @@ class Normal(ExponentialFamily): if self._validate_args: self._validate_sample(value) # compute the variance + # pyrefly: ignore # unsupported-operation var = self.scale**2 log_scale = ( math.log(self.scale) @@ -117,5 +119,6 @@ class Normal(ExponentialFamily): def _natural_params(self) -> tuple[Tensor, Tensor]: return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal()) + # pyrefly: ignore # bad-override def _log_normalizer(self, x, y): return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y) diff --git a/torch/distributions/one_hot_categorical.py b/torch/distributions/one_hot_categorical.py index 8edb6da0b8dd..aec55b95f786 100644 --- 
a/torch/distributions/one_hot_categorical.py +++ b/torch/distributions/one_hot_categorical.py @@ -42,6 +42,7 @@ class OneHotCategorical(Distribution): logits (Tensor): event log probabilities (unnormalized) """ + # pyrefly: ignore # bad-override arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} support = constraints.one_hot has_enumerate_support = True diff --git a/torch/distributions/pareto.py b/torch/distributions/pareto.py index bbca7e0cba35..8b1df0dbb9bb 100644 --- a/torch/distributions/pareto.py +++ b/torch/distributions/pareto.py @@ -39,6 +39,7 @@ class Pareto(TransformedDistribution): self.scale, self.alpha = broadcast_all(scale, alpha) base_dist = Exponential(self.alpha, validate_args=validate_args) transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)] + # pyrefly: ignore # bad-argument-type super().__init__(base_dist, transforms, validate_args=validate_args) def expand( diff --git a/torch/distributions/poisson.py b/torch/distributions/poisson.py index d3fb4446baf4..04524ec56d93 100644 --- a/torch/distributions/poisson.py +++ b/torch/distributions/poisson.py @@ -32,6 +32,7 @@ class Poisson(ExponentialFamily): rate (Number, Tensor): the rate parameter """ + # pyrefly: ignore # bad-override arg_constraints = {"rate": constraints.nonnegative} support = constraints.nonnegative_integer @@ -82,5 +83,6 @@ class Poisson(ExponentialFamily): def _natural_params(self) -> tuple[Tensor]: return (torch.log(self.rate),) + # pyrefly: ignore # bad-override def _log_normalizer(self, x): return torch.exp(x) diff --git a/torch/distributions/relaxed_bernoulli.py b/torch/distributions/relaxed_bernoulli.py index 16ad4219627e..fd6e4226603c 100644 --- a/torch/distributions/relaxed_bernoulli.py +++ b/torch/distributions/relaxed_bernoulli.py @@ -40,6 +40,7 @@ class LogitRelaxedBernoulli(Distribution): (Jang et al., 2017) """ + # pyrefly: ignore # bad-override arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.real @@ -57,10 +58,12 @@ class LogitRelaxedBernoulli(Distribution): ) if probs is not None: is_scalar = isinstance(probs, _Number) + # pyrefly: ignore # read-only (self.probs,) = broadcast_all(probs) else: assert logits is not None # helps mypy is_scalar = isinstance(logits, _Number) + # pyrefly: ignore # read-only (self.logits,) = broadcast_all(logits) self._param = self.probs if probs is not None else self.logits if is_scalar: @@ -138,8 +141,10 @@ class RelaxedBernoulli(TransformedDistribution): """ arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} + # pyrefly: ignore # bad-override support = constraints.unit_interval has_rsample = True + # pyrefly: ignore # bad-override base_dist: LogitRelaxedBernoulli def __init__( diff --git a/torch/distributions/relaxed_categorical.py b/torch/distributions/relaxed_categorical.py index 47314be9e44a..c5492d69b706 100644 --- a/torch/distributions/relaxed_categorical.py +++ b/torch/distributions/relaxed_categorical.py @@ -38,6 +38,7 @@ class ExpRelaxedCategorical(Distribution): (Jang et al., 2017) """ + # pyrefly: ignore # bad-override arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} support = ( constraints.real_vector @@ -127,8 +128,10 @@ class RelaxedOneHotCategorical(TransformedDistribution): """ arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} + # pyrefly: ignore # bad-override support = constraints.simplex has_rsample = True + # pyrefly: ignore # 
bad-override base_dist: ExpRelaxedCategorical def __init__( diff --git a/torch/distributions/studentT.py b/torch/distributions/studentT.py index d98554f413c0..ef84b5bdc879 100644 --- a/torch/distributions/studentT.py +++ b/torch/distributions/studentT.py @@ -31,6 +31,7 @@ class StudentT(Distribution): scale (float or Tensor): scale of the distribution """ + # pyrefly: ignore # bad-override arg_constraints = { "df": constraints.positive, "loc": constraints.real, diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py index 1724b586b5a7..1cc427f069f5 100644 --- a/torch/distributions/transformed_distribution.py +++ b/torch/distributions/transformed_distribution.py @@ -123,6 +123,7 @@ class TransformedDistribution(Distribution): return new @constraints.dependent_property(is_discrete=False) + # pyrefly: ignore # bad-override def support(self): if not self.transforms: return self.base_dist.support diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py index 9584bb0b342d..9fdf6911c10c 100644 --- a/torch/distributions/transforms.py +++ b/torch/distributions/transforms.py @@ -226,11 +226,13 @@ class _InverseTransform(Transform): self._inv: Transform = transform # type: ignore[assignment] @constraints.dependent_property(is_discrete=False) + # pyrefly: ignore # bad-override def domain(self): assert self._inv is not None return self._inv.codomain @constraints.dependent_property(is_discrete=False) + # pyrefly: ignore # bad-override def codomain(self): assert self._inv is not None return self._inv.domain @@ -300,6 +302,7 @@ class ComposeTransform(Transform): return self.parts == other.parts @constraints.dependent_property(is_discrete=False) + # pyrefly: ignore # bad-override def domain(self): if not self.parts: return constraints.real @@ -315,6 +318,7 @@ class ComposeTransform(Transform): return domain @constraints.dependent_property(is_discrete=False) + # pyrefly: ignore # bad-override def codomain(self): if not self.parts: return constraints.real @@ -434,12 +438,14 @@ class IndependentTransform(Transform): ) @constraints.dependent_property(is_discrete=False) + # pyrefly: ignore # bad-override def domain(self): return constraints.independent( self.base_transform.domain, self.reinterpreted_batch_ndims ) @constraints.dependent_property(is_discrete=False) + # pyrefly: ignore # bad-override def codomain(self): return constraints.independent( self.base_transform.codomain, self.reinterpreted_batch_ndims @@ -507,10 +513,12 @@ class ReshapeTransform(Transform): super().__init__(cache_size=cache_size) @constraints.dependent_property + # pyrefly: ignore # bad-override def domain(self): return constraints.independent(constraints.real, len(self.in_shape)) @constraints.dependent_property + # pyrefly: ignore # bad-override def codomain(self): return constraints.independent(constraints.real, len(self.out_shape)) @@ -764,12 +772,14 @@ class AffineTransform(Transform): return self._event_dim @constraints.dependent_property(is_discrete=False) + # pyrefly: ignore # bad-override def domain(self): if self.event_dim == 0: return constraints.real return constraints.independent(constraints.real, self.event_dim) @constraints.dependent_property(is_discrete=False) + # pyrefly: ignore # bad-override def codomain(self): if self.event_dim == 0: return constraints.real @@ -867,6 +877,7 @@ class CorrCholeskyTransform(Transform): # apply stick-breaking on the squared values # Note that y = sign(r) * sqrt(z * z1m_cumprod) # = (sign(r) * sqrt(z)) * 
sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod) + # pyrefly: ignore # unsupported-operation z = r**2 z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1) # Diagonal elements must be 1. @@ -1155,12 +1166,14 @@ class CatTransform(Transform): return all(t.bijective for t in self.transforms) @constraints.dependent_property + # pyrefly: ignore # bad-override def domain(self): return constraints.cat( [t.domain for t in self.transforms], self.dim, self.lengths ) @constraints.dependent_property + # pyrefly: ignore # bad-override def codomain(self): return constraints.cat( [t.codomain for t in self.transforms], self.dim, self.lengths @@ -1233,10 +1246,12 @@ class StackTransform(Transform): return all(t.bijective for t in self.transforms) @constraints.dependent_property + # pyrefly: ignore # bad-override def domain(self): return constraints.stack([t.domain for t in self.transforms], self.dim) @constraints.dependent_property + # pyrefly: ignore # bad-override def codomain(self): return constraints.stack([t.codomain for t in self.transforms], self.dim) diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py index b6e7c2640cfc..fc3cac86770a 100644 --- a/torch/distributions/uniform.py +++ b/torch/distributions/uniform.py @@ -79,6 +79,7 @@ class Uniform(Distribution): return new @constraints.dependent_property(is_discrete=False, event_dim=0) + # pyrefly: ignore # bad-override def support(self): return constraints.interval(self.low, self.high) diff --git a/torch/distributions/von_mises.py b/torch/distributions/von_mises.py index 4f96a23cf55b..9112d9f5be3c 100644 --- a/torch/distributions/von_mises.py +++ b/torch/distributions/von_mises.py @@ -92,6 +92,7 @@ def _log_modified_bessel_fn(x, order=0): @torch.jit.script_if_tracing def _rejection_sample(loc, concentration, proposal_r, x): done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device) + # pyrefly: ignore # bad-assignment while not done.all(): u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device) u1, u2, u3 = u.unbind() @@ -100,6 +101,7 @@ def _rejection_sample(loc, concentration, proposal_r, x): c = concentration * (proposal_r - f) accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0) if accept.any(): + # pyrefly: ignore # no-matching-overload x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x) done = done | accept return (x + math.pi + loc) % (2 * math.pi) - math.pi @@ -123,6 +125,7 @@ class VonMises(Distribution): :param torch.Tensor concentration: concentration parameter """ + # pyrefly: ignore # bad-override arg_constraints = {"loc": constraints.real, "concentration": constraints.positive} support = constraints.real has_rsample = False @@ -160,8 +163,10 @@ class VonMises(Distribution): @lazy_property def _proposal_r(self) -> Tensor: kappa = self._concentration + # pyrefly: ignore # unsupported-operation tau = 1 + (1 + 4 * kappa**2).sqrt() rho = (tau - (2 * tau).sqrt()) / (2 * kappa) + # pyrefly: ignore # unsupported-operation _proposal_r = (1 + rho**2) / (2 * rho) # second order Taylor expansion around 0 for small kappa _proposal_r_taylor = 1 / kappa + kappa diff --git a/torch/distributions/weibull.py b/torch/distributions/weibull.py index aec5e6b8cd1c..0c7c3762b774 100644 --- a/torch/distributions/weibull.py +++ b/torch/distributions/weibull.py @@ -35,6 +35,7 @@ class Weibull(TransformedDistribution): "scale": constraints.positive, "concentration": constraints.positive, } + # pyrefly: ignore # bad-override support = constraints.positive def __init__( @@ -52,6 +53,7 @@ class 
Weibull(TransformedDistribution): PowerTransform(exponent=self.concentration_reciprocal), AffineTransform(loc=0, scale=self.scale), ] + # pyrefly: ignore # bad-argument-type super().__init__(base_dist, transforms, validate_args=validate_args) def expand(self, batch_shape, _instance=None): diff --git a/torch/distributions/wishart.py b/torch/distributions/wishart.py index c5865b6b43c4..5aaa3ddc9d09 100644 --- a/torch/distributions/wishart.py +++ b/torch/distributions/wishart.py @@ -116,10 +116,13 @@ class Wishart(ExponentialFamily): ) if scale_tril is not None: + # pyrefly: ignore # read-only self.scale_tril = param.expand(batch_shape + (-1, -1)) elif covariance_matrix is not None: + # pyrefly: ignore # read-only self.covariance_matrix = param.expand(batch_shape + (-1, -1)) elif precision_matrix is not None: + # pyrefly: ignore # read-only self.precision_matrix = param.expand(batch_shape + (-1, -1)) if self.df.lt(event_shape[-1]).any(): @@ -335,6 +338,7 @@ class Wishart(ExponentialFamily): p = self._event_shape[-1] # has singleton shape return -self.precision_matrix / 2, (nu - p - 1) / 2 + # pyrefly: ignore # bad-override def _log_normalizer(self, x, y): p = self._event_shape[-1] return (y + (p + 1) / 2) * ( diff --git a/torch/export/pt2_archive/_package_weights.py b/torch/export/pt2_archive/_package_weights.py index 05643aa929b2..f9b369abc88e 100644 --- a/torch/export/pt2_archive/_package_weights.py +++ b/torch/export/pt2_archive/_package_weights.py @@ -24,11 +24,15 @@ class TensorProperties: if not self.is_fake: # only get the storage pointer for real tensors + # pyrefly: ignore # bad-assignment self.storage_ptr = tensor.untyped_storage().data_ptr() if self.is_contiguous: # only get storage size and start/end pointers for contiguous tensors + # pyrefly: ignore # bad-assignment self.storage_size = tensor.untyped_storage().nbytes() + # pyrefly: ignore # bad-assignment self.start = tensor.data_ptr() + # pyrefly: ignore # bad-assignment self.end = _end_ptr(tensor) # info to recover tensor diff --git a/torch/fx/_graph_pickler.py b/torch/fx/_graph_pickler.py index f913a24bc736..62d6fc0a36f8 100644 --- a/torch/fx/_graph_pickler.py +++ b/torch/fx/_graph_pickler.py @@ -65,6 +65,7 @@ class GraphPickler(pickle.Pickler): self._meta_tensor_describer = MetaTensorDescriber(copy_data=False) @override + # pyrefly: ignore # bad-override def reducer_override( self, obj: object ) -> tuple[Callable[..., Any], tuple[Any, ...]]: @@ -201,6 +202,7 @@ class _SymNodePickleData: ]: args = (cls(obj.node), pickler._unpickle_state) if isinstance(obj, torch.SymInt): + # pyrefly: ignore # bad-return return _SymNodePickleData.unpickle_sym_int, args else: raise NotImplementedError(f"Unhandled SymNode type {type(obj)}") @@ -277,6 +279,7 @@ class _TensorPickleData: return FakeTensor( unpickle_state.fake_mode, make_meta_t(), + # pyrefly: ignore # bad-argument-type device, ) diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index b7109832cf77..6e67fa56d168 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -603,6 +603,7 @@ class Tracer(TracerBase): in inspect.signature(self.create_proxy).parameters ): kwargs["proxy_factory_fn"] = ( + # pyrefly: ignore # unsupported-operation None if not self.param_shapes_constant else lambda node: ParameterProxy( diff --git a/torch/fx/experimental/accelerator_partitioner.py b/torch/fx/experimental/accelerator_partitioner.py index c29d05f511a7..153e54400ee1 100644 --- a/torch/fx/experimental/accelerator_partitioner.py +++ 
b/torch/fx/experimental/accelerator_partitioner.py @@ -657,7 +657,10 @@ class Partitioner: # Mark bfs level get_bfs_level_partition(self.partitions) find_combination, partitions = find_partition_to_combine_based_on_size( - sorted_partitions, available_mem_bytes, partitions + sorted_partitions, + available_mem_bytes, + # pyrefly: ignore # bad-argument-type + partitions, ) return @@ -702,6 +705,7 @@ class Partitioner: non_embedding_partitions.append(partition) if new_partition: partition = self.create_partition() + # pyrefly: ignore # missing-attribute partition.left_mem_bytes = available_mem_bytes return partition return None @@ -997,6 +1001,7 @@ class Partitioner: node, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec ) if cost < min_cost: + # pyrefly: ignore # bad-assignment node_pair = [node, n1] min_cost = cost return cost, node_pair # type: ignore[possibly-undefined] diff --git a/torch/fx/experimental/merge_matmul.py b/torch/fx/experimental/merge_matmul.py index c6a51918f930..80912ec87e7a 100644 --- a/torch/fx/experimental/merge_matmul.py +++ b/torch/fx/experimental/merge_matmul.py @@ -30,6 +30,7 @@ def split_result_tensors( else: splits = [x.shape[0] for x in inputs] + # pyrefly: ignore # bad-argument-type return torch.split(result, splits) diff --git a/torch/fx/experimental/meta_tracer.py b/torch/fx/experimental/meta_tracer.py index bc00be5ee7ae..5f437cc0a686 100644 --- a/torch/fx/experimental/meta_tracer.py +++ b/torch/fx/experimental/meta_tracer.py @@ -171,7 +171,14 @@ class MetaTracer(torch.fx.Tracer): proxy_factory_fn=None, ): rv = super().create_proxy( - kind, target, args, kwargs, name, type_expr, proxy_factory_fn + kind, + target, + args, + kwargs, + name, + type_expr, + # pyrefly: ignore # bad-argument-type + proxy_factory_fn, ) if kind == "placeholder" and target in self.meta_args: @@ -193,6 +200,7 @@ class MetaTracer(torch.fx.Tracer): if kind == "call_function": meta_target = manual_meta_overrides.get(target, target) + # pyrefly: ignore # not-callable meta_out = meta_target(*args_metas, **kwargs_metas) elif kind == "call_method": meta_target = getattr(args_metas[0], target) # type: ignore[index] diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py index e4951aab15cb..9e0f8f98768c 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py @@ -528,9 +528,11 @@ def view_inference_rule(n: Node, symbols, constraints, counter): if t == -1: var, counter = gen_dvar(counter) t2_type.append(var) + # pyrefly: ignore # bad-argument-type num_constraints.append(BinConstraintD(var, Dyn, op_neq)) else: + # pyrefly: ignore # bad-argument-type num_constraints.append(BinConstraintD(t, Dyn, op_neq)) t2_type.append(t) # type: ignore[arg-type] @@ -1475,6 +1477,7 @@ class ConstraintGenerator: all_constraints = [] + # pyrefly: ignore # missing-attribute for n in graph.nodes: (constraints, counter) = self.generate_constraints_node(n, counter) all_constraints += constraints diff --git a/torch/fx/experimental/optimization.py b/torch/fx/experimental/optimization.py index 3e406b57a96d..73a7805f0478 100644 --- a/torch/fx/experimental/optimization.py +++ b/torch/fx/experimental/optimization.py @@ -193,6 +193,7 @@ def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]): assert isinstance(node.target, str) cur_module = modules[node.target] if type(cur_module) in mkldnn_map: + # pyrefly: 
ignore # index-error new_module = mkldnn_map[type(cur_module)](cur_module, torch.float) assert isinstance(new_module, nn.Module) old_modules[new_module] = copy.deepcopy(cur_module) @@ -263,7 +264,10 @@ def gen_mkl_autotuner(example_inputs, iters=10, warmup=1): ) reset_modules( - submodule.graph.nodes, dict(submodule.named_modules()), old_modules + submodule.graph.nodes, + dict(submodule.named_modules()), + # pyrefly: ignore # bad-argument-type + old_modules, ) no_mkl_time = benchmark(lambda: submodule(*sample_inputs)) return mkl_time < no_mkl_time diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index b4daffa46291..5f9e8aec4bff 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -124,6 +124,7 @@ pytree.register_pytree_node( torch.Size, lambda xs: (list(xs), None), lambda xs, _: tuple(xs), + # pyrefly: ignore # bad-argument-type flatten_with_keys_fn=lambda xs: ( [(pytree.SequenceKey(i), x) for i, x in enumerate(xs)], None, @@ -306,6 +307,7 @@ def set_proxy_slot( # type: ignore[no-redef] def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool: assert isinstance(obj, (Tensor, SymNode)), type(obj) + # pyrefly: ignore # no-matching-overload return bool(get_proxy_slot(obj, tracer, False, lambda _: True)) @@ -402,6 +404,7 @@ def get_proxy_slot( assert isinstance(obj, py_sym_types), type(obj) tracker = tracer.symnode_tracker + # pyrefly: ignore # unsupported-operation if obj not in tracker: # Last ditch if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker: @@ -413,6 +416,7 @@ def get_proxy_slot( ) return default else: + # pyrefly: ignore # index-error value = tracker[obj] res = transform(value) return res @@ -788,6 +792,7 @@ def fetch_object_proxy( def fetch_object_proxy( tracer: _ProxyTracer, t: Union[Tensor, _AnyScriptObjectType, PySymType] ) -> object: + # pyrefly: ignore # no-matching-overload return get_proxy_slot(t, tracer, t) @@ -836,6 +841,7 @@ def _fetch_proxies_and_all_constant_flag( """ f_flat_args_kwargs = [ ( + # pyrefly: ignore # no-matching-overload fetch_object_proxy(tracer, x) if isinstance(x, (Tensor, _AnyScriptObject)) else x @@ -1410,6 +1416,7 @@ class TorchFunctionMetadataMode(TorchFunctionMode): kwargs: Optional[dict[str, object]] = None, ) -> object: kwargs = kwargs or {} + # pyrefly: ignore # bad-assignment self.tracer.torch_fn_metadata = func self.tracer.torch_fn_counts[func] = self.tracer.torch_fn_counts.get(func, 0) + 1 return func(*args, **kwargs) @@ -1459,6 +1466,7 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode): # For autocast, the python APIs run so we don't have to run them again # here. if func is torch._C._set_grad_enabled: + # pyrefly: ignore # bad-argument-type func(*args, **kwargs) return node @@ -1672,6 +1680,7 @@ class DecompositionInterpreter(fx.Interpreter): self.decomposition_table = decomposition_table or {} self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real") + # pyrefly: ignore # bad-override def placeholder( self, target: str, # type: ignore[override] @@ -1684,6 +1693,7 @@ class DecompositionInterpreter(fx.Interpreter): # TODO handle case where the first character of target is '*' return out + # pyrefly: ignore # bad-override def get_attr( self, target: str, # type: ignore[override] @@ -1697,6 +1707,7 @@ class DecompositionInterpreter(fx.Interpreter): # call_function, call_method, call_module get traced automatically by the outer mode. 
+ # pyrefly: ignore # bad-override def output( self, target: str, # type: ignore[override] @@ -1782,14 +1793,17 @@ class _ModuleStackTracer(PythonKeyTracer): self.enable_attr_proxy = False self.submodule_paths = {} for name, m in self.scope_root.named_modules(remove_duplicate=False): + # pyrefly: ignore # unsupported-operation if m in self.submodule_paths: log.info( "Shared module found between %s and %s, AttrProxy is enabled.", + # pyrefly: ignore # unsupported-operation self.submodule_paths[m], name, ) self.enable_attr_proxy = True else: + # pyrefly: ignore # unsupported-operation self.submodule_paths[m] = name self.proxy_paths: WeakKeyDictionary[_AttrProxy, str] = WeakKeyDictionary() @@ -1815,6 +1829,7 @@ class _ModuleStackTracer(PythonKeyTracer): # Class is modified to be a subclass of torch.nn.Module # Warning: We blow away our own attributes here to mimic the base class # - so don't expect `self.x` to do anything useful. + # pyrefly: ignore # no-matching-overload self.__class__ = type( base.__class__.__name__, (self.__class__, base.__class__), @@ -1837,6 +1852,7 @@ class _ModuleStackTracer(PythonKeyTracer): if not isinstance(attr_val, Module): return attr_val + # pyrefly: ignore # index-error return AttrProxy(attr_val, tracer.proxy_paths[self] + "." + name) def get_base(self) -> Module: @@ -1849,10 +1865,12 @@ class _ModuleStackTracer(PythonKeyTracer): res = torch.nn.Sequential( OrderedDict(list(self._modules.items())[idx]) ) + # pyrefly: ignore # index-error return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}") elif isinstance(self, torch.nn.ModuleList): # Copied from nn/modules/container.py res = torch.nn.ModuleList(list(self._modules.values())[idx]) + # pyrefly: ignore # index-error return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}") return super().__getitem__(idx) # type: ignore[misc] diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 34ab1d74c307..77c4c482ae91 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -839,6 +839,7 @@ def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr: factor = functools.reduce(math.gcd, map(integer_coefficient, atoms)) if factor == 1: return expr + # pyrefly: ignore # bad-argument-type atoms = [div_by_factor(x, factor) for x in atoms] return _sympy_from_args( sympy.Add, atoms, sort=True, is_commutative=expr.is_commutative @@ -1234,6 +1235,7 @@ def _free_unbacked_symbols_with_path( else _symint_wrap(coeff) ) # TODO: DivideByKey needs to test divisibility at runtime! + # pyrefly: ignore # unsupported-operation r[unbacked] = path + (DivideByKey(divisor),) if real is not None: assert isinstance(real, int) @@ -1256,6 +1258,7 @@ def _free_unbacked_symbols_with_path( and s.rhs == 1 and s.lhs in pending ): + # pyrefly: ignore # unsupported-operation r[s.lhs] = path + (ConvertIntKey(),) if real is not None: assert type(real) is bool @@ -2172,6 +2175,7 @@ class SubclassSymbolicContext(StatefulSymbolicContext): def __post_init__(self) -> None: super().__post_init__() if self.inner_contexts is None: + # pyrefly: ignore # bad-assignment self.inner_contexts = {} @@ -2260,9 +2264,12 @@ def _fast_expand(expr: _SympyT) -> _SympyT: # only re-create the objects if any of the args changed to avoid expensive # checks when re-creating objects. 
new_args = [_fast_expand(arg) for arg in expr.args] # type: ignore[arg-type] + # pyrefly: ignore # missing-attribute if any(arg is not new_arg for arg, new_arg in zip(expr.args, new_args)): + # pyrefly: ignore # missing-attribute return _fast_expand(expr.func(*new_args)) + # pyrefly: ignore # missing-attribute if expr.is_Pow: base: sympy.Expr exp: sympy.Expr @@ -2272,9 +2279,11 @@ def _fast_expand(expr: _SympyT) -> _SympyT: return sympy.expand_multinomial(expr, deep=False) elif exp < 0: return S.One / sympy.expand_multinomial(S.One / expr, deep=False) + # pyrefly: ignore # missing-attribute elif expr.is_Mul: num: list[sympy.Expr] = [] den: list[sympy.Expr] = [] + # pyrefly: ignore # missing-attribute for arg in expr.args: if arg.is_Pow and arg.args[1] == -1: den.append(S.One / arg) # type: ignore[operator, arg-type] @@ -2396,6 +2405,7 @@ def _maybe_evaluate_static_worker( # TODO: remove this try catch (esp for unbacked_only) try: + # pyrefly: ignore # missing-attribute new_expr = expr.xreplace(new_shape_env) except RecursionError: log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env) @@ -2933,13 +2943,19 @@ class DimConstraints: # is_integer tests though haha return (base - mod_reduced) / divisor + # pyrefly: ignore # missing-attribute if expr.has(Mod): + # pyrefly: ignore # missing-attribute expr = expr.replace(Mod, mod_handler) # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative # arguments should be OK. + # pyrefly: ignore # missing-attribute if expr.has(PythonMod): + # pyrefly: ignore # missing-attribute expr = expr.replace(PythonMod, mod_handler) + # pyrefly: ignore # missing-attribute if expr.has(FloorDiv): + # pyrefly: ignore # missing-attribute expr = expr.replace(FloorDiv, floor_div_handler) return expr @@ -5057,6 +5073,7 @@ class ShapeEnv: if duck: # Make sure to reuse this symbol for subsequent duck shaping + # pyrefly: ignore # unsupported-operation self.val_to_var[val] = sympy_expr if isinstance(val, int): @@ -5288,15 +5305,19 @@ class ShapeEnv: # Expand optional inputs, or verify invariants are upheld if input_contexts is None: + # pyrefly: ignore # bad-assignment input_contexts = [ + # pyrefly: ignore # bad-argument-type _create_no_constraints_context(t) if isinstance(t, Tensorlike) else None for t in placeholders ] else: assert len(input_contexts) == len(placeholders) + # pyrefly: ignore # bad-assignment for i, (t, context) in enumerate(zip(placeholders, input_contexts)): if isinstance(t, Tensorlike): if context is None: + # pyrefly: ignore # bad-argument-type input_contexts[i] = _create_no_constraints_context(t) else: assert isinstance(t, (SymInt, int, SymFloat, float)) @@ -5582,6 +5603,7 @@ class ShapeEnv: s = sympy.Float(val) input_guards.append((source, s)) + # pyrefly: ignore # no-matching-overload for t, source, context in zip(placeholders, sources, input_contexts): if isinstance(source, str): from torch._dynamo.source import LocalSource @@ -5641,11 +5663,13 @@ class ShapeEnv: ) track_symint(property_source, ss, constraint_size[i]) else: + # pyrefly: ignore # missing-attribute for i, ss in enumerate(curr_t.size()): property_source = TensorPropertySource( src, TensorProperty.SIZE, i ) track_symint(property_source, ss, constraint_size[i]) + # pyrefly: ignore # missing-attribute for i, ss in enumerate(curr_t.stride()): property_source = TensorPropertySource( src, TensorProperty.STRIDE, i @@ -5653,6 +5677,7 @@ class ShapeEnv: track_symint(property_source, ss, constraint_stride[i]) track_symint( TensorPropertySource(src, 
TensorProperty.STORAGE_OFFSET), + # pyrefly: ignore # missing-attribute curr_t.storage_offset(), ) @@ -5698,6 +5723,7 @@ class ShapeEnv: continue if is_dim(source): + # pyrefly: ignore # missing-attribute self.dim_constraints.add_equality(source, expr) for exprs, printer, lang in zip(all_exprs, printers, langs): @@ -5851,6 +5877,7 @@ class ShapeEnv: continue expr = self.simplify(ra.expr) + # pyrefly: ignore # missing-attribute self.dim_constraints.add(expr) # 3. Every symbol must be within its value range (this handles 0/1 @@ -5867,6 +5894,7 @@ class ShapeEnv: verbose_expr = "" if r.lower not in (-sympy.oo, -int_oo): if any(is_dim(source) for source in sources): + # pyrefly: ignore # missing-attribute self.dim_constraints.add(sympy.Ge(symbol, r.lower)) # Only print lower bound in simplified mode if it is not the # default @@ -5875,6 +5903,7 @@ class ShapeEnv: verbose_expr = f"{r.lower} <= {rf} # {vr_sloc.lower}" if r.upper not in (sympy.oo, int_oo): if any(is_dim(source) for source in sources): + # pyrefly: ignore # missing-attribute self.dim_constraints.add(sympy.Le(symbol, r.upper)) # nontrivial upper bound is always interesting bounds.append(sympy.Le(symbol, r.upper, evaluate=False)) @@ -5943,6 +5972,7 @@ class ShapeEnv: else: str_msg = f" - {msg_cb()}" error_msgs.append(str_msg) + # pyrefly: ignore # bad-argument-type debug_names.add(debug_name) if len(error_msgs) > 0: debug_names_str = ", ".join(sorted(debug_names)) @@ -6076,6 +6106,7 @@ class ShapeEnv: Get a list of guards, but pruned so it only provides guards that reference symints from the passed in input """ + # pyrefly: ignore # bad-assignment symints = { s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol) } @@ -6121,6 +6152,7 @@ class ShapeEnv: else: bindings[-s] = -arg + # pyrefly: ignore # bad-assignment for t, arg in zip(placeholders, args): if t is None: continue @@ -6338,6 +6370,7 @@ class ShapeEnv: Apply symbol replacements to any symbols in the given expression. 
""" replacements = {} + # pyrefly: ignore # missing-attribute for s in expr.free_symbols: r = self._find(s) @@ -6347,6 +6380,7 @@ class ShapeEnv: if not r.is_Symbol or r != s: replacements[s] = r if replacements: + # pyrefly: ignore # missing-attribute return safe_expand(expr.xreplace(replacements)) else: return expr @@ -7121,6 +7155,7 @@ class ShapeEnv: instructions = list(dis.Bytecode(frame.f_code)) co_lines, offset = inspect.getsourcelines(frame.f_code) start, end, cur = None, None, None + # pyrefly: ignore # bad-assignment for i, instr in enumerate(instructions): if instr.starts_line is not None: cur = instr.starts_line @@ -8000,6 +8035,7 @@ def _suggest_fixes_for_data_dependent_error_non_strict( if isinstance(leaf, torch.SymInt): src_map[str(leaf.node.expr)].append(name) elif isinstance(leaf, torch.Tensor): + # pyrefly: ignore # bad-assignment for i, dim in enumerate(leaf.shape): if isinstance(dim, torch.SymInt): src_map[str(dim.node.expr)].append(f"{name}.shape[{i}]") diff --git a/torch/fx/experimental/unification/multipledispatch/__init__.py b/torch/fx/experimental/unification/multipledispatch/__init__.py index bb7304069243..b7d633ac1cee 100644 --- a/torch/fx/experimental/unification/multipledispatch/__init__.py +++ b/torch/fx/experimental/unification/multipledispatch/__init__.py @@ -1,5 +1,5 @@ from .core import dispatch -from .dispatcher import ( +from .dispatcher import ( # pyrefly: ignore # deprecated Dispatcher, halt_ordering, MDNotImplementedError, diff --git a/torch/fx/experimental/unification/multipledispatch/dispatcher.py b/torch/fx/experimental/unification/multipledispatch/dispatcher.py index 28eac85b0180..f1b229291887 100644 --- a/torch/fx/experimental/unification/multipledispatch/dispatcher.py +++ b/torch/fx/experimental/unification/multipledispatch/dispatcher.py @@ -238,6 +238,7 @@ class Dispatcher: "To use a variadic union type place the desired types " "inside of a tuple, e.g., [(int, str)]" ) + # pyrefly: ignore # bad-specialization new_signature.append(Variadic[typ[0]]) else: new_signature.append(typ) diff --git a/torch/fx/experimental/unification/unification_tools.py b/torch/fx/experimental/unification/unification_tools.py index a47d900273f5..f29bc8b52550 100644 --- a/torch/fx/experimental/unification/unification_tools.py +++ b/torch/fx/experimental/unification/unification_tools.py @@ -298,6 +298,7 @@ def update_in(d, keys, func, default=None, factory=dict): rv = inner = factory() rv.update(d) + # pyrefly: ignore # not-iterable for key in ks: if k in d: d = d[k] diff --git a/torch/fx/graph.py b/torch/fx/graph.py index fa2052f60ec7..49eaadb483a9 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -1836,6 +1836,7 @@ class Graph: "a str is expected" ) if node.op in ["get_attr", "call_module"]: + # pyrefly: ignore # missing-attribute target_atoms = node.target.split(".") m_itr = self.owning_module for i, atom in enumerate(target_atoms): diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index f4496338fffc..3f5c15375528 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -533,6 +533,7 @@ class GraphModule(torch.nn.Module): self.graph._tracer_cls and "" not in self.graph._tracer_cls.__qualname__ ): + # pyrefly: ignore # bad-assignment self._tracer_cls = self.graph._tracer_cls self._tracer_extras = {} diff --git a/torch/fx/node.py b/torch/fx/node.py index dbd6ed93ef26..321cbfbf2f3b 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -59,6 +59,7 @@ Argument = Optional[ BaseArgumentTypes, ] ] +# pyrefly: ignore # 
invalid-annotation ArgumentT = TypeVar("ArgumentT", bound=Argument) _P = ParamSpec("_P") _R = TypeVar("_R") diff --git a/torch/fx/operator_schemas.py b/torch/fx/operator_schemas.py index 618a0fa8b413..81ca1402a6c7 100644 --- a/torch/fx/operator_schemas.py +++ b/torch/fx/operator_schemas.py @@ -120,6 +120,7 @@ def _torchscript_schema_to_signature_impl( # which makes it hard to do type annotation kind = Parameter.POSITIONAL_ONLY # type: ignore[assignment] # This renders all previous arguments to positional only + # pyrefly: ignore # bad-assignment for idx, p in enumerate(parameters): assert p.kind == Parameter.POSITIONAL_OR_KEYWORD parameters[idx] = Parameter( @@ -128,6 +129,7 @@ def _torchscript_schema_to_signature_impl( default=p.default, annotation=p.annotation, ) + # pyrefly: ignore # missing-attribute parameters.append( Parameter(name=name, kind=kind, default=default, annotation=arg_type) ) @@ -141,6 +143,7 @@ def _torchscript_schema_to_signature_impl( else: return_type = tuple(return_types) + # pyrefly: ignore # bad-argument-type return inspect.Signature(parameters, return_annotation=return_type) diff --git a/torch/fx/passes/_tensorify_python_scalars.py b/torch/fx/passes/_tensorify_python_scalars.py index dd8edb50e161..5d80d47ea2ba 100644 --- a/torch/fx/passes/_tensorify_python_scalars.py +++ b/torch/fx/passes/_tensorify_python_scalars.py @@ -164,9 +164,13 @@ def tensorify_python_scalars( c = float(expr) node = graph.call_function( - torch.ops.aten.scalar_tensor.default, (c,), {"dtype": dtype} + torch.ops.aten.scalar_tensor.default, + # pyrefly: ignore # unbound-name + (c,), + {"dtype": dtype}, ) with fake_mode: + # pyrefly: ignore # unbound-name node.meta["val"] = torch.ops.aten.scalar_tensor.default(c, dtype=dtype) expr_to_tensor_proxy[expr] = MetaProxy( node, @@ -219,19 +223,25 @@ def tensorify_python_scalars( expr_to_sym_proxy[s] = MetaProxy( node, tracer=tracer, fake_mode=fake_mode ) + # pyrefly: ignore # bad-argument-type elif (sym_expr := _get_sym_val(node)) is not None: if sym_expr not in expr_to_sym_proxy and not isinstance( sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom) ): expr_to_sym_proxy[sym_expr] = MetaProxy( - node, tracer=tracer, fake_mode=fake_mode + # pyrefly: ignore # bad-argument-type + node, + tracer=tracer, + fake_mode=fake_mode, ) # Specialize all dimensions that contain symfloats. 
Here's # an example test that requires this: # PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=4 python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoCUDA.test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 # noqa: B950 + # pyrefly: ignore # missing-attribute val = node.meta.get("val") if isinstance(val, FakeTensor): + # pyrefly: ignore # bad-assignment for dim in val.shape: if isinstance(dim, torch.SymInt): for s in dim.node.expr.free_symbols: @@ -248,13 +258,17 @@ def tensorify_python_scalars( should_restart = True # Look for functions to convert + # pyrefly: ignore # missing-attribute if node.op == "call_function" and ( + # pyrefly: ignore # missing-attribute replacement_op := SUPPORTED_OPS.get(node.target) ): args: list[Any] = [] transform = False + # pyrefly: ignore # missing-attribute compute_dtype = get_computation_dtype(node.meta["val"].dtype) + # pyrefly: ignore # missing-attribute for a in node.args: if ( isinstance(a, fx.Node) @@ -291,6 +305,7 @@ def tensorify_python_scalars( if transform: replacement_proxy = replacement_op(*args) + # pyrefly: ignore # missing-attribute if compute_dtype != node.meta["val"].dtype: replacement_proxy = ( torch.ops.prims.convert_element_type.default( @@ -299,7 +314,9 @@ def tensorify_python_scalars( ) ) + # pyrefly: ignore # missing-attribute node.replace_all_uses_with(replacement_proxy.node) + # pyrefly: ignore # bad-argument-type graph.erase_node(node) metrics_context = get_metrics_context() @@ -308,13 +325,16 @@ def tensorify_python_scalars( "tensorify_float_success", True, overwrite=True ) else: + # pyrefly: ignore # missing-attribute for a in node.args: if ( isinstance(a, fx.Node) and "val" in a.meta and isinstance(zf := a.meta["val"], torch.SymFloat) ): + # pyrefly: ignore # missing-attribute failed_tensorify_ops.update(str(node.target)) + # pyrefly: ignore # missing-attribute log.info("Failed to tensorify %s", str(node.target)) # Now do one more pass that specializes all symfloats we didn't manage diff --git a/torch/fx/passes/graph_drawer.py b/torch/fx/passes/graph_drawer.py index a5445a6851fa..313766d51028 100644 --- a/torch/fx/passes/graph_drawer.py +++ b/torch/fx/passes/graph_drawer.py @@ -437,11 +437,14 @@ if HAS_PYDOT: ) current_graph = buf_name_to_subgraph.get(buf_name) # type: ignore[assignment] + # pyrefly: ignore # missing-attribute current_graph.add_node(dot_node) def get_module_params_or_buffers(): for pname, ptensor in chain( - leaf_module.named_parameters(), leaf_module.named_buffers() + leaf_module.named_parameters(), + # pyrefly: ignore # bad-argument-type + leaf_module.named_buffers(), ): pname1 = node.name + "." 
+ pname label1 = ( diff --git a/torch/fx/passes/infra/pass_base.py b/torch/fx/passes/infra/pass_base.py index 957b8145f995..ef8e79e57869 100644 --- a/torch/fx/passes/infra/pass_base.py +++ b/torch/fx/passes/infra/pass_base.py @@ -11,6 +11,7 @@ __all__ = ["PassResult", "PassBase"] @compatibility(is_backward_compatible=False) +# pyrefly: ignore # invalid-inheritance class PassResult(namedtuple("PassResult", ["graph_module", "modified"])): """ Result of a pass: diff --git a/torch/fx/passes/infra/pass_manager.py b/torch/fx/passes/infra/pass_manager.py index 4077e74360f5..826e998f5c9c 100644 --- a/torch/fx/passes/infra/pass_manager.py +++ b/torch/fx/passes/infra/pass_manager.py @@ -31,6 +31,7 @@ def pass_result_wrapper(fn: Callable) -> Callable: wrapped_fn (Callable[Module, PassResult]) """ if fn is None: + # pyrefly: ignore # bad-return return None @wraps(fn) @@ -273,6 +274,7 @@ class PassManager: logger.debug("Running pass '%s'", fn_name) try: + # pyrefly: ignore # not-callable res = fn(module) if not isinstance(res, PassResult) and not hasattr( diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index 19e101a5c120..f05982f1adea 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -358,6 +358,7 @@ def insert_deferred_runtime_asserts( ): # this guards against deleting calls like item() that produce new untracked symbols def has_new_untracked_symbols(): + # pyrefly: ignore # missing-attribute for symbol in sym_expr.free_symbols: if symbol not in expr_to_proxy: return True @@ -373,6 +374,7 @@ def insert_deferred_runtime_asserts( assert resolved_unbacked_bindings is not None def has_new_unbacked_bindings(): + # pyrefly: ignore # missing-attribute for key in resolved_unbacked_bindings.keys(): if key not in expr_to_proxy: return True diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py index 9c301c462c6c..52fbbaeaa1be 100644 --- a/torch/fx/passes/split_module.py +++ b/torch/fx/passes/split_module.py @@ -350,7 +350,9 @@ def split_module( assert all(v is not None for v in autocast_exits.values()), "autocast must exit" + # pyrefly: ignore # bad-assignment autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()} + # pyrefly: ignore # bad-assignment grad_regions = {k: sorted(v) for k, v in grad_regions.items()} if _LOGGER.isEnabledFor(logging.DEBUG): @@ -415,7 +417,9 @@ def split_module( for regions_mapping in [autocast_regions, grad_regions]: for node, regions in regions_mapping.items(): assert len(regions) > 0 + # pyrefly: ignore # index-error partitions[str(regions[0])].environment[node] = node + # pyrefly: ignore # index-error for r in regions[1:]: partition = partitions[str(r)] new_node = partition.graph.create_node( @@ -515,6 +519,7 @@ def split_module( for node in reversed(regions_mapping): regions = regions_mapping[node] assert len(regions) > 0 + # pyrefly: ignore # index-error for r in regions[:-1]: partition = partitions[str(r)] exit_node = autocast_exits[node] diff --git a/torch/fx/passes/utils/common.py b/torch/fx/passes/utils/common.py index 17362c9eec12..3924a93d22cf 100644 --- a/torch/fx/passes/utils/common.py +++ b/torch/fx/passes/utils/common.py @@ -64,6 +64,7 @@ def lift_subgraph_as_module( for name in target_name_parts[:-1]: if not hasattr(curr, name): + # pyrefly: ignore # missing-attribute curr.add_module(name, HolderModule({})) curr = getattr(curr, name) diff --git a/torch/jit/_serialization.py b/torch/jit/_serialization.py index a8d520ea877d..c719a01708c9 100644 --- 
a/torch/jit/_serialization.py +++ b/torch/jit/_serialization.py @@ -166,7 +166,6 @@ def load(f, map_location=None, _extra_files=None, _restore_shapes=False): cu = torch._C.CompilationUnit() if isinstance(f, (str, os.PathLike)): cpp_module = torch._C.import_ir_module( - # pyrefly: ignore # no-matching-overload, bad-argument-count cu, # pyrefly: ignore # no-matching-overload os.fspath(f), @@ -177,7 +176,6 @@ def load(f, map_location=None, _extra_files=None, _restore_shapes=False): ) # type: ignore[call-arg] else: cpp_module = torch._C.import_ir_module_from_buffer( - # pyrefly: ignore # missing-attribute, bad-argument-count cu, # pyrefly: ignore # missing-attribute f.read(), diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index ee83d0f346bd..df0bf2cd1a22 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -37,7 +37,7 @@ from ._internal.torchscript_exporter._type_utils import ( ) from ._internal.torchscript_exporter.utils import ( # Deprecated members that are excluded from __all__ register_custom_op_symbolic, - select_model_mode_for_export, + select_model_mode_for_export, # pyrefly: ignore # deprecated unregister_custom_op_symbolic, ) from .errors import OnnxExporterError diff --git a/torch/onnx/_internal/exporter/_analysis.py b/torch/onnx/_internal/exporter/_analysis.py index 53860413526e..1ff8506283bd 100644 --- a/torch/onnx/_internal/exporter/_analysis.py +++ b/torch/onnx/_internal/exporter/_analysis.py @@ -122,6 +122,7 @@ def _format_model_info(model_info: ModelInfo) -> str: target_to_nodes = defaultdict(list) for node, _ in model_info.dispatch_failures: + # pyrefly: ignore # index-error target_to_nodes[str(node.target)].append(node) target_to_messages = {} diff --git a/torch/onnx/_internal/exporter/_building.py b/torch/onnx/_internal/exporter/_building.py index 64319ac427fe..dbe38f81680c 100644 --- a/torch/onnx/_internal/exporter/_building.py +++ b/torch/onnx/_internal/exporter/_building.py @@ -267,6 +267,7 @@ def _get_or_create_constant( # float representation of complex numbers if isinstance(arg, complex): # Convert the complex number to a float + # pyrefly: ignore # bad-assignment arg = (arg.real, arg.imag) if isinstance(arg, list): diff --git a/torch/onnx/_internal/exporter/_capture_strategies.py b/torch/onnx/_internal/exporter/_capture_strategies.py index b7f9016ae6d0..dc2f39990fec 100644 --- a/torch/onnx/_internal/exporter/_capture_strategies.py +++ b/torch/onnx/_internal/exporter/_capture_strategies.py @@ -47,6 +47,7 @@ def _patch_dynamo_unsupported_functions(): # Replace torch.jit.isinstance with isinstance jit_isinstance = torch.jit.isinstance + # pyrefly: ignore # bad-assignment torch.jit.isinstance = isinstance logger.info("Replaced torch.jit.isinstance with isinstance to allow dynamo tracing") try: diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index 7e7f206c80fd..f46601eed261 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -132,8 +132,10 @@ class TorchTensor(ir.Tensor): # view the tensor as that dtype so that it is convertible to NumPy, # and then view it back to the proper dtype (using ml_dtypes obtained by # calling dtype.numpy()). 
+ # pyrefly: ignore # missing-attribute if self.dtype == ir.DataType.BFLOAT16: return ( + # pyrefly: ignore # missing-attribute self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy()) ) if self.dtype in { @@ -142,9 +144,11 @@ class TorchTensor(ir.Tensor): ir.DataType.FLOAT8E5M2, ir.DataType.FLOAT8E5M2FNUZ, }: + # pyrefly: ignore # missing-attribute return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy()) if self.dtype == ir.DataType.FLOAT4E2M1: return _type_casting.unpack_float4x2_as_uint8(self.raw).view( + # pyrefly: ignore # missing-attribute self.dtype.numpy() ) @@ -168,6 +172,7 @@ class TorchTensor(ir.Tensor): if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor): raise TypeError( + # pyrefly: ignore # missing-attribute f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor " "with a tensor backed by real data using ONNXProgram.apply_weights() " "or save the model without initializers by setting include_initializers=False." @@ -238,6 +243,7 @@ def _set_shape_type( if isinstance(dim, int): dims.append(dim) else: + # pyrefly: ignore # bad-argument-type dims.append(str(dim.node)) # If the dtype is set already (e.g. by the onnx_symbolic ops), @@ -1212,6 +1218,7 @@ def _exported_program_to_onnx_program( # so we need to get them from the name_* apis. for name, torch_tensor in itertools.chain( exported_program.named_parameters(), + # pyrefly: ignore # bad-argument-type exported_program.named_buffers(), exported_program.constants.items(), ): diff --git a/torch/onnx/_internal/exporter/_isolated.py b/torch/onnx/_internal/exporter/_isolated.py index 461590ec9eb4..141c8ad754cf 100644 --- a/torch/onnx/_internal/exporter/_isolated.py +++ b/torch/onnx/_internal/exporter/_isolated.py @@ -26,6 +26,7 @@ def _call_function_and_return_exception( """Call function and return a exception if there is one.""" try: + # pyrefly: ignore # bad-argument-type return func(*args, **kwargs) except Exception as e: return e diff --git a/torch/onnx/_internal/exporter/_onnx_program.py b/torch/onnx/_internal/exporter/_onnx_program.py index 17f646c93375..0afd81ce4e23 100644 --- a/torch/onnx/_internal/exporter/_onnx_program.py +++ b/torch/onnx/_internal/exporter/_onnx_program.py @@ -157,6 +157,7 @@ def _to_ort_value(input: torch.Tensor | int | float | str | bool) -> ort.OrtValu int: np.int64, float: np.float32, } + # pyrefly: ignore # no-matching-overload dtype = dtype_mapping.get(type(input), None) return ort.OrtValue.ortvalue_from_numpy(np.array(input, dtype=dtype)) @@ -252,6 +253,7 @@ ONNXProgram( run_options = ort.RunOptions() run_options.log_severity_level = 3 # 3: Error logger.debug("Running the inference session with %s arguments.", len(ort_input)) + # pyrefly: ignore # missing-attribute outputs = self._inference_session.run_with_ort_values( None, ort_input, run_options=run_options ) diff --git a/torch/onnx/_internal/exporter/_registration.py b/torch/onnx/_internal/exporter/_registration.py index 0dd23819af11..f4c7cfbf5127 100644 --- a/torch/onnx/_internal/exporter/_registration.py +++ b/torch/onnx/_internal/exporter/_registration.py @@ -64,8 +64,11 @@ class OnnxDecompMeta: if isinstance(self.onnx_function, onnxscript.OnnxFunction): signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined] self.onnx_function, + # pyrefly: ignore # missing-attribute self.onnx_function.function_ir.domain, + # pyrefly: ignore # missing-attribute self.onnx_function.name, + # pyrefly: ignore # missing-attribute 
                opset_version=self.onnx_function.opset.version,
            )
        else:
diff --git a/torch/onnx/_internal/exporter/_schemas.py b/torch/onnx/_internal/exporter/_schemas.py
index 3aa8b0e0c7e2..9ee51980cf5d 100644
--- a/torch/onnx/_internal/exporter/_schemas.py
+++ b/torch/onnx/_internal/exporter/_schemas.py
@@ -541,6 +541,7 @@ class OpSignature:
             if (
                 return_param_name := _get_type_constraint_name(return_type_i)
             ) in type_constraints:
+                # pyrefly: ignore # index-error
                 type_constraint = type_constraints[return_param_name]
             else:
                 return_param_name = f"TReturn{i}"
@@ -553,6 +554,7 @@ class OpSignature:
                 type_constraints[return_param_name] = type_constraint
             outputs.append(
                 Parameter(
+                    # pyrefly: ignore # bad-argument-type
                     name=return_param_name,
                     type_constraint=type_constraint,
                     required=True,
diff --git a/torch/onnx/_internal/exporter/_tensors.py b/torch/onnx/_internal/exporter/_tensors.py
index 2fdafacbe06f..6a40893e6b6c 100644
--- a/torch/onnx/_internal/exporter/_tensors.py
+++ b/torch/onnx/_internal/exporter/_tensors.py
@@ -30,13 +30,16 @@ class SymbolicTensor(ir.Value):

     @property
     def rank(self) -> int | None:
+        # pyrefly: ignore # missing-attribute
         if self.shape is None:
             return None
+        # pyrefly: ignore # bad-argument-type
         return len(self.shape)

     # TODO: Implement indexing

     def __mod__(self, other):
+        # pyrefly: ignore # missing-attribute
         if self.dtype in {
             ir.DataType.FLOAT,
             ir.DataType.DOUBLE,
diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py
index 1ea9a4161f43..9be57d88a635 100644
--- a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py
+++ b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py
@@ -50,8 +50,10 @@ def aten_group_norm(
     c = op21.Shape(input, start=1, end=2)
     if weight is None:
+        # pyrefly: ignore # missing-attribute
         weight = op21.ConstantOfShape(c, value=ir.tensor(1.0, dtype=input.dtype))
     if bias is None:
+        # pyrefly: ignore # missing-attribute
         bias = op21.ConstantOfShape(c, value=ir.tensor(0.0, dtype=input.dtype))

     return op21.GroupNormalization(
         input, weight, bias, epsilon=eps, num_groups=num_groups
@@ -80,6 +82,7 @@ def aten_rms_norm(

     # Create weight tensor if not provided
     if weight is None:
+        # pyrefly: ignore # missing-attribute
         weight = op23.Constant(value=ir.tensor(1.0, dtype=input.dtype))

     return op23.RMSNormalization(input, weight, axis=axis, epsilon=eps)
@@ -128,6 +131,7 @@ def aten_scaled_dot_product_attention_23(
     assert (not is_causal) or (is_causal and attn_mask is None), (
         "is_causal and attn_mask cannot be set at the same time"
     )
+    # pyrefly: ignore # missing-attribute
     assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, (
         "only 4D query, key, and value are supported"
     )
@@ -136,12 +140,15 @@ def aten_scaled_dot_product_attention_23(
     if dropout_p == 0:
         if enable_gqa:
             assert (
+                # pyrefly: ignore # index-error
                 query.shape[1] > key.shape[1] == value.shape[1]
+                # pyrefly: ignore # index-error
                 and query.shape[1] % key.shape[1] == 0
             ), (
                 "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0"
             )
         else:
+            # pyrefly: ignore # index-error
             assert query.shape[1] == key.shape[1] == value.shape[1], (
                 "SDPA (MHA) requires q_num_heads = kv_num_heads"
             )
@@ -202,7 +209,9 @@ def _attention_repeat_kv_for_group_query(
     """
     assert (
+        # pyrefly: ignore # missing-attribute
         query.shape[1] > key.shape[1] == value.shape[1]
+        # pyrefly: ignore # missing-attribute
         and query.shape[1] % key.shape[1] == 0
     ), (
         "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0"
diff --git a/torch/onnx/_internal/torchscript_exporter/_type_utils.py b/torch/onnx/_internal/torchscript_exporter/_type_utils.py
index 81bcaeef1107..d4c1382d2931 100644
--- a/torch/onnx/_internal/torchscript_exporter/_type_utils.py
+++ b/torch/onnx/_internal/torchscript_exporter/_type_utils.py
@@ -153,6 +153,7 @@ class JitScalarType(enum.IntEnum):
         """
         if dtype not in _DTYPE_TO_SCALAR_TYPE:
             raise errors.OnnxExporterError(f"Unknown dtype: {dtype}")
+        # pyrefly: ignore # index-error
         return _DTYPE_TO_SCALAR_TYPE[dtype]

     @classmethod
diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py b/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py
index cd7763bf41ec..bcd36a6ac41b 100644
--- a/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py
+++ b/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py
@@ -364,6 +364,7 @@ def parse_args(
             fn_name = None
         args = [
             _parse_arg(arg, arg_desc, arg_name, fn_name)  # type: ignore[method-assign]
+            # pyrefly: ignore # no-matching-overload
            for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names)
         ]
         # only support _outputs in kwargs
@@ -1800,22 +1801,27 @@ def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kw
     if require_cast:
         for input in inputs:
+            # pyrefly: ignore # missing-attribute
             if input.isCompleteTensor():
                 input_scalar_type = _type_utils.JitScalarType.from_value(input)
                 if input_scalar_type != dtype_0:
                     raise errors.SymbolicValueError(
                         f"Inputs of {op_name} must have same dtype."
                         f"Got {dtype_0.scalar_name()} and {input_scalar_type.scalar_name()}",
+                        # pyrefly: ignore # bad-argument-type
                         input,
                     )
         for i, input in enumerate(inputs):
+            # pyrefly: ignore # missing-attribute
             if input.isCompleteTensor() and not _is_fp(input):
                 inputs[i] = g.op(
                     "Cast",
+                    # pyrefly: ignore # bad-argument-type
                     input,
                     to_i=target_float_t.onnx_type(),
                 )

+    # pyrefly: ignore # bad-argument-type
     self = g.op(op_name, *inputs, **kwargs)

     if require_cast:
diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset10.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset10.py
index 6b36396250b4..6bb09ef3ec2a 100644
--- a/torch/onnx/_internal/torchscript_exporter/symbolic_opset10.py
+++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset10.py
@@ -205,6 +205,7 @@ def _adjust_attributes_of_max_pool(
     else:
         strides = stride  # type: ignore[assignment]

+    # pyrefly: ignore # bad-return
     return (kernel_shape, strides, pads, dilation)

@@ -381,6 +382,7 @@ def _adjust_attributes_of_avg_pool(
     else:
         strides = stride  # type: ignore[assignment]

+    # pyrefly: ignore # bad-return
     return (kernel_shape, strides, pads)

@@ -709,6 +711,7 @@ def fake_quantize_per_tensor_affine(
             "Non-constant scale not supported",
             inputs,
         )
+    # pyrefly: ignore # missing-attribute
     scale = scale.float().data  # Avoid exporter generating double type
     if quant_min == 0:
         zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset11.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset11.py
index f437e2670768..858e81766446 100644
--- a/torch/onnx/_internal/torchscript_exporter/symbolic_opset11.py
+++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset11.py
@@ -219,6 +219,7 @@ def index_put(
     if len(indices_list) > 1:
         for idx_ in range(len(indices_list)):
             if symbolic_helper._is_bool(indices_list[idx_]):
+                # pyrefly: ignore # unsupported-operation
                 indices_list[idx_] = g.op("NonZero", indices_list[idx_])
         index = indices_list[0]

@@ -819,6 +820,7 @@ def arange(g: jit_utils.GraphContext, *args):
             "Constant",
             value_t=torch.tensor(1, dtype=type_.dtype()),
         )
+        # pyrefly: ignore # bad-argument-type
         return g.op("Range", start_default, end, delta_default)
     elif len(args) == 4 or len(args) == 7:
         if len(args) == 4:
@@ -830,6 +832,7 @@ def arange(g: jit_utils.GraphContext, *args):
            _, end, start, step = symbolic_helper._arange_cast_helper(
                g, start=args[0], end=args[1], step=args[2], dtype=dtype
            )
+        # pyrefly: ignore # bad-argument-type
         return g.op("Range", start, end, step)
     elif len(args) == 6:
         # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
@@ -841,6 +844,7 @@ def arange(g: jit_utils.GraphContext, *args):
             "Constant",
             value_t=torch.tensor(1, dtype=type_.dtype()),
         )
+        # pyrefly: ignore # bad-argument-type
         return g.op("Range", start, end, delta_default)
     else:
         return symbolic_helper._unimplemented(
diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset12.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset12.py
index 431660409717..822e14556768 100644
--- a/torch/onnx/_internal/torchscript_exporter/symbolic_opset12.py
+++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset12.py
@@ -363,7 +363,10 @@ def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
         concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0)

         cond_out = loop_context.op(
-            "Cast", loop_condition, _C_onnx.TensorProtoDataType.BOOL
+            "Cast",
+            loop_condition,
+            # pyrefly: ignore # bad-argument-type
+            _C_onnx.TensorProtoDataType.BOOL,
         )
         utils._add_output_to_block(loop_block, cond_out)
         utils._add_output_to_block(loop_block, concat)
diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset13.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset13.py
index e9da6a426f7f..9deb479a7ceb 100644
--- a/torch/onnx/_internal/torchscript_exporter/symbolic_opset13.py
+++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset13.py
@@ -96,6 +96,7 @@ def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=No
         split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value")
         if split_val.dim() > 0:
+            # pyrefly: ignore # bad-argument-type
             return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs)

     split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size")
@@ -112,6 +113,7 @@ def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=No
     if leftover:
         splits.append(leftover)
     splits = g.op("Constant", value_t=torch.tensor(splits))
+    # pyrefly: ignore # bad-argument-type
     return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)

@@ -186,6 +188,7 @@ def tensor_split(
         splits = g.op(
             "Constant", value_t=torch.tensor(splits + leftover, dtype=torch.long)
         )
+        # pyrefly: ignore # bad-argument-type
         return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)

     if (
@@ -311,6 +314,7 @@ def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=
         return symbolic_helper._unbind_helper(
             g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs
         )
+    # pyrefly: ignore # bad-argument-type
     return g.op("Where", condition, self, other)

diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset14.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset14.py
index 5675f362893e..3e6752506bd4 100644
--- a/torch/onnx/_internal/torchscript_exporter/symbolic_opset14.py
+++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset14.py
@@ -167,6 +167,7 @@ def scaled_dot_product_attention(
     # NOTE: onnx-script has different logic here, because the attribute perms in
     # transpose needs list of ints
     key_shape_builtin = symbolic_helper._get_tensor_rank(key)
+    # pyrefly: ignore # no-matching-overload
     key_transposed_axes = list(range(key_shape_builtin))
     key_transposed_axes[-1], key_transposed_axes[-2] = (
         key_transposed_axes[-2],
@@ -176,7 +177,9 @@ def scaled_dot_product_attention(
     # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
     # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
+    # pyrefly: ignore # bad-argument-type
     query_scaled = g.op("Mul", query, g.op("Sqrt", scale))
+    # pyrefly: ignore # bad-argument-type
     key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale))
     mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled)
@@ -190,6 +193,7 @@ def scaled_dot_product_attention(
         # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
         const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
         const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
+        # pyrefly: ignore # bad-argument-type
         attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
         attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
@@ -203,6 +207,7 @@ def scaled_dot_product_attention(
         _type_utils.JitScalarType.HALF,
         _type_utils.JitScalarType.BFLOAT16,
     ):
+        # pyrefly: ignore # bad-argument-type
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
         attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
     else:
diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset17.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset17.py
index e8ea41e64306..42acd954520a 100644
--- a/torch/onnx/_internal/torchscript_exporter/symbolic_opset17.py
+++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset17.py
@@ -176,6 +176,7 @@ def stft(
     )

     # Get window and make sure it's the same size as `win_length` or `n_fft`
+    # pyrefly: ignore # bad-argument-type
     n_win = symbolic_helper._get_tensor_dim_size(window, dim=0)
     if n_win is not None:
         win_length_default = win_length if win_length else n_fft
@@ -189,6 +190,7 @@ def stft(
             left, right = _compute_edge_sizes(n_fft, n_win)
             left_win = g.op("Constant", value_t=torch.zeros(left))
             right_win = g.op("Constant", value_t=torch.zeros(right))
+            # pyrefly: ignore # bad-argument-type
             window = g.op("Concat", left_win, window, right_win, axis_i=0)

     # Create window, if needed
@@ -212,7 +214,10 @@ def stft(
         assert torch_window.shape[0] == n_fft
         window = g.op("Constant", value_t=torch_window)
         window = g.op(
-            "Cast", window, to_i=_type_utils.JitScalarType.from_value(signal).onnx_type()
+            "Cast",
+            # pyrefly: ignore # bad-argument-type
+            window,
+            to_i=_type_utils.JitScalarType.from_value(signal).onnx_type(),
         )

     # Run STFT
diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset18.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset18.py
index 6a5ac408fb1b..f8ff787df9c0 100644
--- a/torch/onnx/_internal/torchscript_exporter/symbolic_opset18.py
+++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset18.py
@@ -151,6 +151,7 @@ def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):

 @_onnx_symbolic("aten::maximum")
 @symbolic_helper.quantized_args(True, True)
 def maximum(g: jit_utils.GraphContext, input, other):
+    # pyrefly: ignore # no-matching-overload
     return max(g, input, dim_or_y=other)

@@ -163,6 +164,7 @@ def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):

 @_onnx_symbolic("aten::minimum")
 @symbolic_helper.quantized_args(True, True)
 def minimum(g: jit_utils.GraphContext, input, other):
+    # pyrefly: ignore # no-matching-overload
     return min(g, input, dim_or_y=other)

diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py
index 65657f6a91c2..9b7aba64ef31 100644
--- a/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py
+++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py
@@ -1051,6 +1051,7 @@ def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=No
     leftover = size % split_size
     if leftover:
         splits.append(leftover)
+    # pyrefly: ignore # bad-argument-type
     return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs)

@@ -1068,6 +1069,7 @@ def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs
         return symbolic_helper._onnx_opset_unsupported_detailed(
             "split_with_sizes", 9, 11, "Dynamic number of outputs not supported", self
         )
+    # pyrefly: ignore # bad-argument-type
     return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=_outputs)

@@ -1705,6 +1707,7 @@ def _adaptive_pool(name, type, tuple_fn, fn=None):
         k = [int(dim[i] / output_size[i]) for i in range(0, len(dim))]
         # call max_poolxd_with_indices to get indices in the output
         if type == "MaxPool":
+            # pyrefly: ignore # not-callable
             return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False)
         output = g.op(type, input, kernel_shape_i=tuple_fn(k), strides_i=tuple_fn(k))
         return output
@@ -1759,6 +1762,7 @@ def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value):
         )

     padding = _convert_padding_node(padding)
+    # pyrefly: ignore # bad-argument-type
     paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding)
     return symbolic_helper._op_with_optional_float_cast(
         g, "Pad", input, pads_i=paddings, mode_s=mode, value_f=value, opset_before=11
@@ -1812,6 +1816,7 @@ def _pad_circular(g: jit_utils.GraphContext, input: _C.Value, pad: _C.Value):
 def reflection_pad(g: jit_utils.GraphContext, input, padding):
     mode = "reflect"
     padding = _convert_padding_node(padding)
+    # pyrefly: ignore # bad-argument-type
     paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding)
     return symbolic_helper._op_with_optional_float_cast(
         g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11
@@ -1824,6 +1829,7 @@ def reflection_pad(g: jit_utils.GraphContext, input, padding):
 def replication_pad(g: jit_utils.GraphContext, input, padding):
     mode = "edge"
     padding = _convert_padding_node(padding)
+    # pyrefly: ignore # bad-argument-type
     paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding)
     return symbolic_helper._op_with_optional_float_cast(
         g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11
@@ -2204,6 +2210,7 @@ def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=
         return symbolic_helper._unbind_helper(
             g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs
         )
+    # pyrefly: ignore # bad-argument-type
     return g.op("Where", condition, self, other)

@@ -2379,6 +2386,7 @@ def _convolution_mode(
         "group_i": groups,
     }

+    # pyrefly: ignore # bad-argument-type
     n = g.op("Conv", *args, **kwargs)

     if (
@@ -2723,10 +2731,12 @@ def native_layer_norm(
     # variance = e((x - e(x))^2), and (x - e(x)) is the numerator in the layer_norm formula
     if g.opset < 18:
+        # pyrefly: ignore # no-matching-overload
         variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes)
     else:
         variance = g.op(
             "ReduceMean",
+            # pyrefly: ignore # no-matching-overload
             pow(g, numerator, two_cst),
             g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)),
         )
@@ -3065,10 +3075,12 @@ def pairwise_distance(g: jit_utils.GraphContext, input1, input2, p, eps, keepdim
     )
     summation = symbolic_helper._reducesum_helper(
         g,
+        # pyrefly: ignore # no-matching-overload
         pow(g, sub(g, input1, input2), p),
         axes_i=[-1],
         keepdims_i=symbolic_helper._parse_arg(keepdim, "i"),
     )
+    # pyrefly: ignore # no-matching-overload
     return pow(g, summation, inv_p)

@@ -3178,6 +3190,7 @@ def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):

 @_onnx_symbolic("aten::maximum")
 @symbolic_helper.quantized_args(True, True)
 def maximum(g: jit_utils.GraphContext, input, other):
+    # pyrefly: ignore # no-matching-overload
     return max(g, input, dim_or_y=other)

@@ -3190,6 +3203,7 @@ def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):

 @_onnx_symbolic("aten::minimum")
 @symbolic_helper.quantized_args(True, True)
 def minimum(g: jit_utils.GraphContext, input, other):
+    # pyrefly: ignore # no-matching-overload
     return min(g, input, dim_or_y=other)

@@ -3486,6 +3500,7 @@ def zeros_like(
             input, _type_utils.JitScalarType.FLOAT
         )
     else:
+        # pyrefly: ignore # bad-argument-type
         scalar_type = _type_utils.JitScalarType(dtype)
     return g.op(
         "ConstantOfShape",
@@ -3545,6 +3560,7 @@ def ones_like(
             input, _type_utils.JitScalarType.FLOAT
         )
     else:
+        # pyrefly: ignore # bad-argument-type
         scalar_type = _type_utils.JitScalarType(dtype)
     return g.op(
         "ConstantOfShape",
@@ -5534,6 +5550,7 @@ def linalg_matrix_norm(
             g, g.op("Abs", self), axes_i=[dim[0]], keepdims_i=keepdim
         )
         if ord_value > 0:
+            # pyrefly: ignore # no-matching-overload
             result, _indices = max(
                 g,
                 sum,
@@ -5541,6 +5558,7 @@ def linalg_matrix_norm(
                 keepdim=keepdim,
             )
         else:
+            # pyrefly: ignore # no-matching-overload
             result, _indices = min(
                 g,
                 sum,
@@ -5904,7 +5922,9 @@ def as_strided(g: jit_utils.GraphContext, self, sizes, strides, offset=None):
         else:
             ind = g.op("Add", ind, tmp_ind)
     if offset:
+        # pyrefly: ignore # bad-argument-type
         ind = g.op("Add", ind, g.op("Constant", torch.tensor([offset])))
+    # pyrefly: ignore # bad-argument-type
     return g.op("Gather", self_1d, ind)

@@ -6187,6 +6207,7 @@ def _euclidean_dist(g: jit_utils.GraphContext, x1, x2):
     assert rank is not None
     x1_norm = symbolic_helper._reducesum_helper(
         g,
+        # pyrefly: ignore # no-matching-overload
         pow(g, x1, symbolic_helper._generate_wrapped_number(g, 2.0)),
         axes_i=[-1],
         keepdims_i=True,
@@ -6194,6 +6215,7 @@ def _euclidean_dist(g: jit_utils.GraphContext, x1, x2):
     x1_pad = ones_like(g, x1_norm)
     x2_norm = symbolic_helper._reducesum_helper(
         g,
+        # pyrefly: ignore # no-matching-overload
         pow(g, x2, symbolic_helper._generate_wrapped_number(g, 2.0)),
         axes_i=[-1],
         keepdims_i=True,
diff --git a/torch/onnx/_internal/torchscript_exporter/verification.py b/torch/onnx/_internal/torchscript_exporter/verification.py
index 3bf8cba1c8d6..f8e2d37ba737 100644
--- a/torch/onnx/_internal/torchscript_exporter/verification.py
+++ b/torch/onnx/_internal/torchscript_exporter/verification.py
@@ -239,7 +239,7 @@ def _compare_onnx_pytorch_outputs_in_np(
             if acceptable_error_percentage:
                 error_percentage = 1 - np.sum(
                     np.isclose(ort_out, pt_out, rtol=options.rtol, atol=options.atol)
-                ) / np.prod(ort_out.shape)
+                ) / np.prod(ort_out.shape)  # pyrefly: ignore # missing-attribute
                 if error_percentage <= acceptable_error_percentage:
                     warnings.warn(
                         f"Suppressed AssertionError:\n{e}.\n"
@@ -247,8 +247,10 @@
                         f"within acceptable range {acceptable_error_percentage}."
                     )
                     continue
+            # pyrefly: ignore # missing-attribute
            if ort_out.dtype == np.uint8 or ort_out.dtype == np.int8:
                warnings.warn("ONNX output is quantized")
+            # pyrefly: ignore # missing-attribute
            if pt_out.dtype == np.uint8 or pt_out.dtype == np.int8:
                warnings.warn("PyTorch output is quantized")
            raise
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 6b1d752bb04e..c50676eda781 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -5,4 +5,5 @@ from __future__ import annotations

 __all__: list[str] = []

+# pyrefly: ignore # deprecated
 from torch.onnx._internal.torchscript_exporter.utils import *  # noqa: F401,F403
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py
index 6f12315e78c0..2425a253a914 100644
--- a/torch/optim/optimizer.py
+++ b/torch/optim/optimizer.py
@@ -958,7 +958,6 @@ class Optimizer:
         r"""Make a deep copy of value, casting all tensors to device of param."""
         if isinstance(value, torch.Tensor):
             return Optimizer._process_value_according_to_param_policy(
-                # pyrefly: ignore # bad-argument-type
                 param,
                 value,
                 # pyrefly: ignore # bad-argument-type