Add pyrefly suppressions (#164748)

Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete entries from the `project-excludes` field in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions and clean up unused suppressions (a minimal example of an added suppression is sketched below)
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:

0 errors (4,263 ignored)
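For illustration, a minimal, self-contained sketch of the suppression format used throughout this diff. The function and variable names here are made up for the example; only the comment syntax (`# pyrefly: ignore # <error-code>`, placed on the line directly above the flagged expression) is taken from the changes below:

# hypothetical example; only the suppression comment syntax mirrors this commit
def scale(value: int) -> int:
    return value * 2

raw: object = 21
result = scale(
    # pyrefly: ignore # bad-argument-type
    raw,  # without the suppression, pyrefly flags this argument as bad-argument-type
)
print(result)  # 42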

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164748
Approved by: https://github.com/oulgen
Authored by: Maggie Moss
Date: 2025-10-07 17:31:18 +00:00
Committed by: PyTorch MergeBot
parent 5e47b4dd60
commit b13cd141b3
139 changed files with 504 additions and 30 deletions

View File

@ -25,10 +25,6 @@ project-excludes = [
"torch/nn/**",
"torch/_dynamo/**",
"torch/utils/**",
"torch/ao/**",
"torch/fx/**",
"torch/distributions/**",
"torch/onnx/**",
# formatting issues
"torch/linalg/__init__.py",
"torch/package/importer.py",

View File

@ -470,7 +470,6 @@ def _check_input_constraints_for_graph(
)
elif isinstance(node_val, torch.SymInt):
_check_symint(
# pyrefly: ignore # bad-argument-type
node_val,
# pyrefly: ignore # bad-argument-type
arg,

View File

@ -360,7 +360,6 @@ def trace_flex_attention(
"call_function", flex_attention, proxy_args, {}
)
return track_tensor_tree(
# pyrefly: ignore # bad-argument-type
example_out,
out_proxy,
constant=None,
@ -1080,7 +1079,6 @@ def trace_flex_attention_backward(
name="flex_attention_backward",
)
return track_tensor_tree(
# pyrefly: ignore # bad-argument-type
example_out,
out_proxy,
constant=None,

View File

@ -899,7 +899,6 @@ def analyze_kernel_mutations(
if op.name == "tt.call":
assert op.fn_call_name in functions
mutations = analyze_kernel_mutations(
# pyrefly: ignore # bad-argument-type
functions,
# pyrefly: ignore # bad-argument-type
op.fn_call_name,

View File

@ -3379,7 +3379,6 @@ def native_layer_norm(
torch._check(
input.ndim >= normalized_ndim
and sym_eq(
# pyrefly: ignore # bad-argument-type
input.shape[(input.ndim - normalized_ndim) :],
# pyrefly: ignore # bad-argument-type
tuple(normalized_shape),

View File

@ -620,6 +620,7 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
dilation=dilation,
groups=groups,
bias=bias,
# pyrefly: ignore # bad-argument-type
padding_mode=padding_mode,
qconfig=qconfig,
)
@ -820,6 +821,7 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
dilation=dilation,
groups=groups,
bias=bias,
# pyrefly: ignore # bad-argument-type
padding_mode=padding_mode,
qconfig=qconfig,
)
@ -1021,6 +1023,7 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
dilation=dilation,
groups=groups,
bias=bias,
# pyrefly: ignore # bad-argument-type
padding_mode=padding_mode,
qconfig=qconfig,
)

View File

@ -36,6 +36,7 @@ class LinearReLU(nnqat.Linear, _FusedModule):
torch.Size([128, 30])
"""
# pyrefly: ignore # bad-override
_FLOAT_MODULE = nni.LinearReLU
def __init__(

View File

@ -30,6 +30,7 @@ class LinearReLU(nnqd.Linear):
torch.Size([128, 30])
"""
# pyrefly: ignore # bad-override
_FLOAT_MODULE = nni.LinearReLU
def __init__(

View File

@ -54,6 +54,7 @@ class ConvReLU1d(nnq.Conv1d):
dilation=dilation,
groups=groups,
bias=bias,
# pyrefly: ignore # bad-argument-type
padding_mode=padding_mode,
device=device,
dtype=dtype,

View File

@ -114,6 +114,7 @@ class _ConvNd(nn.modules.conv._ConvNd):
assert hasattr(cls, "_FLOAT_RELU_MODULE")
relu = cls._FLOAT_RELU_MODULE()
modules.append(relu)
# pyrefly: ignore # missing-attribute
fused = cls._FLOAT_MODULE(*modules)
fused.train(self.training)
return fused

View File

@ -50,6 +50,7 @@ class Embedding(nn.Embedding):
scale_grad_by_freq,
sparse,
_weight,
# pyrefly: ignore # bad-argument-type
**factory_kwargs,
)
assert qconfig, "qconfig must be provided for QAT module"

View File

@ -170,8 +170,11 @@ class MultiheadAttention(nn.MultiheadAttention):
observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
if other.in_proj_bias is None:
# pyrefly: ignore # bad-assignment
observed.linear_Q.bias = None
# pyrefly: ignore # bad-assignment
observed.linear_K.bias = None
# pyrefly: ignore # bad-assignment
observed.linear_V.bias = None
else:
observed.linear_Q.bias = nn.Parameter(
@ -234,6 +237,7 @@ class MultiheadAttention(nn.MultiheadAttention):
_end = _start + fp.embed_dim
fp.in_proj_weight[_start:_end, :] = wQ
if fp.in_proj_bias is not None:
# pyrefly: ignore # bad-argument-type
assert all(bQ == 0)
fp.in_proj_bias[_start:_end] = bQ
@ -241,12 +245,14 @@ class MultiheadAttention(nn.MultiheadAttention):
_end = _start + fp.embed_dim
fp.in_proj_weight[_start:_end, :] = wK
if fp.in_proj_bias is not None:
# pyrefly: ignore # bad-argument-type
assert all(bK == 0)
fp.in_proj_bias[_start:_end] = bK
_start = _end
fp.in_proj_weight[_start:, :] = wV
if fp.in_proj_bias is not None:
# pyrefly: ignore # bad-argument-type
assert all(bV == 0)
fp.in_proj_bias[_start:] = bV
else:
@ -254,8 +260,11 @@ class MultiheadAttention(nn.MultiheadAttention):
fp.k_proj_weight = nn.Parameter(wK)
fp.v_proj_weight = nn.Parameter(wV)
if fp.in_proj_bias is None:
# pyrefly: ignore # bad-assignment
self.linear_Q.bias = None
# pyrefly: ignore # bad-assignment
self.linear_K.bias = None
# pyrefly: ignore # bad-assignment
self.linear_V.bias = None
else:
fp.in_proj_bias[0 : fp.embed_dim] = bQ
@ -463,6 +472,7 @@ class MultiheadAttention(nn.MultiheadAttention):
assert static_v.size(2) == head_dim
v = static_v
# pyrefly: ignore # missing-attribute
src_len = k.size(1)
if key_padding_mask is not None:
@ -471,17 +481,35 @@ class MultiheadAttention(nn.MultiheadAttention):
if self.add_zero_attn:
src_len += 1
# pyrefly: ignore # missing-attribute
k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
# pyrefly: ignore # missing-attribute
if k.is_quantized:
k_zeros = torch.quantize_per_tensor(
k_zeros, k.q_scale(), k.q_zero_point(), k.dtype
k_zeros,
# pyrefly: ignore # missing-attribute
k.q_scale(),
# pyrefly: ignore # missing-attribute
k.q_zero_point(),
# pyrefly: ignore # missing-attribute
k.dtype,
)
# pyrefly: ignore # no-matching-overload
k = torch.cat([k, k_zeros], dim=1)
# pyrefly: ignore # missing-attribute
v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
# pyrefly: ignore # missing-attribute
if v.is_quantized:
v_zeros = torch.quantize_per_tensor(
v_zeros, v.q_scale(), v.q_zero_point(), v.dtype
v_zeros,
# pyrefly: ignore # missing-attribute
v.q_scale(),
# pyrefly: ignore # missing-attribute
v.q_zero_point(),
# pyrefly: ignore # missing-attribute
v.dtype,
)
# pyrefly: ignore # no-matching-overload
v = torch.cat([v, v_zeros], dim=1)
if attn_mask is not None:

View File

@ -376,6 +376,7 @@ class _LSTMLayer(torch.nn.Module):
bidirectional,
split_gates=split_gates,
)
# pyrefly: ignore # bad-argument-type
layer.qconfig = getattr(other, "qconfig", qconfig)
wi = getattr(other, f"weight_ih_l{layer_idx}")
wh = getattr(other, f"weight_hh_l{layer_idx}")
@ -454,6 +455,7 @@ class LSTM(torch.nn.Module):
if (
not isinstance(dropout, numbers.Number)
# pyrefly: ignore # unsupported-operation
or not 0 <= dropout <= 1
or isinstance(dropout, bool)
):
@ -462,6 +464,7 @@ class LSTM(torch.nn.Module):
"representing the probability of an element being "
"zeroed"
)
# pyrefly: ignore # unsupported-operation
if dropout > 0:
warnings.warn(
"dropout option for quantizable LSTM is ignored. "
@ -573,6 +576,7 @@ class LSTM(torch.nn.Module):
other.bidirectional,
split_gates=split_gates,
)
# pyrefly: ignore # bad-argument-type
observed.qconfig = getattr(other, "qconfig", qconfig)
for idx in range(other.num_layers):
observed.layers[idx] = _LSTMLayer.from_float(

View File

@ -73,6 +73,7 @@ class Conv1d(nnq.Conv1d):
factory_kwargs = {"device": device, "dtype": dtype}
kernel_size = _single(kernel_size)
stride = _single(stride)
# pyrefly: ignore # bad-assignment
padding = padding if isinstance(padding, str) else _single(padding)
dilation = _single(dilation)

View File

@ -119,7 +119,9 @@ class Linear(nnq.Linear):
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
if type(mod) == nni.LinearReLU:
mod = mod[0]
# pyrefly: ignore # missing-attribute
if mod.qconfig is not None and mod.qconfig.weight is not None:
# pyrefly: ignore # not-callable
weight_observer = mod.qconfig.weight()
else:
# We have the circular import issues if we import the qconfig in the beginning of this file:
@ -143,6 +145,7 @@ class Linear(nnq.Linear):
"Unsupported dtype specified for dynamic quantized Linear!"
)
qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
# pyrefly: ignore # bad-argument-type
qlinear.set_weight_bias(qweight, mod.bias)
return qlinear

View File

@ -521,6 +521,7 @@ class LSTM(RNNBase):
>>> output, (hn, cn) = rnn(input, (h0, c0))
"""
# pyrefly: ignore # bad-override
_FLOAT_MODULE = nn.LSTM
__overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
@ -806,6 +807,7 @@ class GRU(RNNBase):
>>> output, hn = rnn(input, h0)
"""
# pyrefly: ignore # bad-override
_FLOAT_MODULE = nn.GRU
__overloads__ = {"forward": ["forward_packed", "forward_tensor"]}

View File

@ -67,7 +67,9 @@ class Hardswish(torch.nn.Hardswish):
def __init__(self, scale, zero_point, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
# pyrefly: ignore # bad-argument-type
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
# pyrefly: ignore # bad-argument-type
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
@ -138,7 +140,9 @@ class LeakyReLU(torch.nn.LeakyReLU):
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(negative_slope, inplace)
# pyrefly: ignore # bad-argument-type
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
# pyrefly: ignore # bad-argument-type
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
@ -226,6 +230,7 @@ class Softmax(torch.nn.Softmax):
class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
# pyrefly: ignore # bad-override
_FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention
def _get_name(self):

View File

@ -12,7 +12,9 @@ class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(num_features, eps, momentum, True, True, **factory_kwargs)
# pyrefly: ignore # bad-argument-type
self.register_buffer("scale", torch.tensor(1.0, **factory_kwargs))
# pyrefly: ignore # bad-argument-type
self.register_buffer("zero_point", torch.tensor(0, **factory_kwargs))
@staticmethod

View File

@ -408,6 +408,7 @@ class Conv1d(_ConvNd):
factory_kwargs = {"device": device, "dtype": dtype}
kernel_size = _single(kernel_size)
stride = _single(stride)
# pyrefly: ignore # bad-assignment
padding = padding if isinstance(padding, str) else _single(padding)
dilation = _single(dilation)

View File

@ -310,6 +310,7 @@ class Linear(WeightedQuantizedModule):
# the type mismatch in assignment. Also, mypy has an issue with
# iterables not being implemented, so we are ignoring those too.
if not isinstance(cls._FLOAT_MODULE, Iterable):
# pyrefly: ignore # bad-assignment
cls._FLOAT_MODULE = [cls._FLOAT_MODULE]
supported_modules = ", ".join(
[float_mod.__name__ for float_mod in cls._FLOAT_MODULE]

View File

@ -37,11 +37,14 @@ class LayerNorm(torch.nn.LayerNorm):
normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
# pyrefly: ignore # bad-argument-type
**factory_kwargs,
)
self.weight = weight
self.bias = bias
# pyrefly: ignore # bad-argument-type
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
# pyrefly: ignore # bad-argument-type
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
@ -113,7 +116,9 @@ class GroupNorm(torch.nn.GroupNorm):
super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
self.weight = weight
self.bias = bias
# pyrefly: ignore # bad-argument-type
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
# pyrefly: ignore # bad-argument-type
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
@ -175,7 +180,9 @@ class InstanceNorm1d(torch.nn.InstanceNorm1d):
)
self.weight = weight
self.bias = bias
# pyrefly: ignore # bad-argument-type
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
# pyrefly: ignore # bad-argument-type
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
@ -242,7 +249,9 @@ class InstanceNorm2d(torch.nn.InstanceNorm2d):
)
self.weight = weight
self.bias = bias
# pyrefly: ignore # bad-argument-type
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
# pyrefly: ignore # bad-argument-type
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
@ -309,7 +318,9 @@ class InstanceNorm3d(torch.nn.InstanceNorm3d):
)
self.weight = weight
self.bias = bias
# pyrefly: ignore # bad-argument-type
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
# pyrefly: ignore # bad-argument-type
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):

View File

@ -95,6 +95,7 @@ class Conv1d(_ConvNd, nn.Conv1d):
and the backend should be able to fuse the ops with `*` into a quantized conv1d
"""
weight_quant_dequant = self.get_weight()
# pyrefly: ignore # no-matching-overload
result = F.conv1d(
x,
weight_quant_dequant,
@ -140,6 +141,7 @@ class Conv2d(_ConvNd, nn.Conv2d):
dilation,
groups,
bias,
# pyrefly: ignore # bad-argument-type
padding_mode,
device,
dtype,
@ -158,6 +160,7 @@ class Conv2d(_ConvNd, nn.Conv2d):
and the backend should be able to fuse the ops with `*` into a quantized conv2d
"""
weight_quant_dequant = self.get_weight()
# pyrefly: ignore # no-matching-overload
result = F.conv2d(
x,
weight_quant_dequant,
@ -203,6 +206,7 @@ class Conv3d(_ConvNd, nn.Conv3d):
dilation,
groups,
bias,
# pyrefly: ignore # bad-argument-type
padding_mode,
device,
dtype,
@ -221,6 +225,7 @@ class Conv3d(_ConvNd, nn.Conv3d):
and the backend should be able to fuse the ops with `*` into a quantized conv3d
"""
weight_quant_dequant = self.get_weight()
# pyrefly: ignore # no-matching-overload
result = F.conv3d(
x,
weight_quant_dequant,
@ -378,6 +383,7 @@ class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d):
groups,
bias,
dilation,
# pyrefly: ignore # bad-argument-type
padding_mode,
device,
dtype,
@ -459,6 +465,7 @@ class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d):
groups,
bias,
dilation,
# pyrefly: ignore # bad-argument-type
padding_mode,
device,
dtype,

View File

@ -663,7 +663,11 @@ class LSTM(RNNBase):
# xxx: isinstance check needs to be in conditional for TorchScript to compile
if isinstance(orig_input, PackedSequence):
output_packed = PackedSequence(
output, batch_sizes, sorted_indices, unsorted_indices
output,
# pyrefly: ignore # bad-argument-type
batch_sizes,
sorted_indices,
unsorted_indices,
)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
@ -823,7 +827,11 @@ class GRU(RNNBase):
# xxx: isinstance check needs to be in conditional for TorchScript to compile
if isinstance(orig_input, PackedSequence):
output_packed = PackedSequence(
output, batch_sizes, sorted_indices, unsorted_indices
output,
# pyrefly: ignore # bad-argument-type
batch_sizes,
sorted_indices,
unsorted_indices,
)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:

View File

@ -42,6 +42,7 @@ class Embedding(nn.Embedding, ReferenceQuantizedModule):
scale_grad_by_freq,
sparse,
_weight,
# pyrefly: ignore # bad-argument-type
device,
dtype,
)

View File

@ -18,6 +18,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
"scale": 1.0,
"zero_point": 0,
}
# pyrefly: ignore # bad-assignment
self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"]
self.weight_dtype = weight_qparams["dtype"]
assert self.weight_qscheme in [
@ -80,13 +81,16 @@ class ReferenceQuantizedModule(torch.nn.Module):
self.register_buffer(
"weight_axis", torch.tensor(0, dtype=torch.int, device=device)
)
# pyrefly: ignore # bad-assignment
self.is_decomposed: bool = weight_qparams.get("is_decomposed", False)
# store weight_axis as weight_axis_int due to some constraints of torchdynamo.export
# for capturing `.item` operations
self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment]
# pyrefly: ignore # bad-assignment
self.weight_quant_min: typing.Optional[int] = weight_qparams.get(
"quant_min", None
)
# pyrefly: ignore # bad-assignment
self.weight_quant_max: typing.Optional[int] = weight_qparams.get(
"quant_max", None
)
@ -105,6 +109,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
return _quantize_and_dequantize_weight_decomposed(
self.weight, # type: ignore[arg-type]
self.weight_qscheme,
# pyrefly: ignore # bad-argument-type
self.weight_dtype,
self.weight_scale,
self.weight_zero_point,
@ -116,6 +121,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
return _quantize_and_dequantize_weight(
self.weight, # type: ignore[arg-type]
self.weight_qscheme,
# pyrefly: ignore # bad-argument-type
self.weight_dtype,
self.weight_scale,
self.weight_zero_point,
@ -131,6 +137,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
return _quantize_weight_decomposed(
self.weight, # type: ignore[arg-type]
self.weight_qscheme,
# pyrefly: ignore # bad-argument-type
self.weight_dtype,
self.weight_scale,
self.weight_zero_point,
@ -142,6 +149,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
return _quantize_weight(
self.weight, # type: ignore[arg-type]
self.weight_qscheme,
# pyrefly: ignore # bad-argument-type
self.weight_dtype,
self.weight_scale,
self.weight_zero_point,

View File

@ -151,7 +151,9 @@ class Linear(torch.nn.Module):
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
if type(mod) == nni.LinearReLU:
mod = mod[0]
# pyrefly: ignore # missing-attribute
if mod.qconfig is not None and mod.qconfig.weight is not None:
# pyrefly: ignore # not-callable
weight_observer = mod.qconfig.weight()
else:
# We have the circular import issues if we import the qconfig in the beginning of this file:
@ -185,5 +187,6 @@ class Linear(torch.nn.Module):
col_block_size,
dtype=dtype,
)
# pyrefly: ignore # bad-argument-type
qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size)
return qlinear

View File

@ -84,6 +84,7 @@ class _NSGraphMatchableSubgraphsIterator:
if is_match:
# navigate to the base node
for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
# pyrefly: ignore # bad-argument-type
self.seen_nodes.add(cur_start_node)
# for now, assume that there are no other nodes
# which need to be added to the stack
@ -94,8 +95,10 @@ class _NSGraphMatchableSubgraphsIterator:
cur_base_op_node = cur_start_node
break
# pyrefly: ignore # bad-argument-type
self.seen_nodes.add(cur_start_node)
# add args of previous nodes to stack
# pyrefly: ignore # missing-attribute
for arg in cur_start_node.all_input_nodes:
self._recursively_add_node_arg_to_stack(arg)
@ -103,6 +106,7 @@ class _NSGraphMatchableSubgraphsIterator:
# note: this check is done on the start_node, i.e.
# if we are matching linear-relu in reverse, this would do the matchable
# check on the linear
# pyrefly: ignore # bad-argument-type
if not self._is_matchable(cur_base_op_node):
continue
@ -116,8 +120,10 @@ class _NSGraphMatchableSubgraphsIterator:
continue
return NSSubgraph(
# pyrefly: ignore # bad-argument-type
start_node=cur_start_node,
end_node=cur_end_node,
# pyrefly: ignore # bad-argument-type
base_op_node=cur_base_op_node,
)

View File

@ -415,6 +415,7 @@ def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]:
target2,
) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
new_connections.append((source, target1))
# pyrefly: ignore # bad-argument-type
new_connections.append((source, target2))
for source_to_target in (
@ -423,6 +424,7 @@ def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]:
quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
):
for source, target in source_to_target.items(): # type:ignore[assignment]
# pyrefly: ignore # bad-argument-type
new_connections.append((source, target))
#

View File

@ -95,6 +95,7 @@ class OutputProp:
if isinstance(result, torch.Tensor): # type: ignore[possibly-undefined]
node.traced_result = result
# pyrefly: ignore # unsupported-operation
env[node.name] = result
return None
@ -393,8 +394,10 @@ def create_submodule_from_subgraph(
cur_name_idx += 1
setattr(gm, mod_name, new_arg)
new_arg_placeholder = gm.placeholder(mod_name) # type: ignore[operator]
# pyrefly: ignore # missing-attribute
cur_args_copy.append(new_arg_placeholder)
elif isinstance(arg, (float, int, torch.dtype)):
# pyrefly: ignore # missing-attribute
cur_args_copy.append(arg)
else:
raise AssertionError(f"arg of type {type(arg)} not handled yet")
@ -801,6 +804,7 @@ def create_add_loggers_graph(
model,
cur_subgraph_idx,
match_name,
# pyrefly: ignore # bad-argument-type
maybe_subgraph,
[qconfig_mapping],
[node_name_to_qconfig],
@ -857,6 +861,7 @@ def create_add_loggers_graph(
cur_node_orig = first_node
cur_node_copy = None
first_node_copy = None
# pyrefly: ignore # bad-assignment
while cur_node_orig in subgraph_to_use:
# TODO(future PR): make this support all possible args/kwargs
if cur_node_orig is first_node:

View File

@ -404,6 +404,7 @@ def maybe_add_missing_fqns(results: NSResultsType) -> None:
for model_name, model_results in model_name_to_results.items():
if model_name == model_name_with_fqns:
continue
# pyrefly: ignore # bad-assignment
for i in range(len(model_results)):
fqn = ref_model_results[i]["fqn"]
model_results[i]["fqn"] = fqn
@ -467,6 +468,7 @@ def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tenso
Return:
float or tuple of floats
"""
# pyrefly: ignore # unsupported-operation
return torch.sqrt(((x - y) ** 2).sum() / (x**2).sum())

View File

@ -23,12 +23,17 @@ class SparseDLRM(DLRM_Net):
super().__init__(**args)
def forward(self, dense_x, lS_o, lS_i):
# pyrefly: ignore # missing-attribute
x = self.apply_mlp(dense_x, self.bot_l) # dense features
# pyrefly: ignore # missing-attribute
ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l) # apply embedding bag
# pyrefly: ignore # missing-attribute
z = self.interact_features(x, ly)
z = z.to_sparse_coo()
# pyrefly: ignore # missing-attribute
z = torch.mm(z, self.top_l[0].weight.T).add(self.top_l[0].bias)
# pyrefly: ignore # missing-attribute
for layer in self.top_l[1:]:
z = layer(z)

View File

@ -72,6 +72,7 @@ class FPGMPruner(BaseStructuredSparsifier):
dist_matrix = self.dist_fn(t_flatten)
# more similar with other filter indicates large in the sum of row
# pyrefly: ignore # bad-argument-type
distance = torch.sum(torch.abs(dist_matrix), 1)
return distance

View File

@ -260,6 +260,7 @@ class BaseStructuredSparsifier(BaseSparsifier):
module.register_parameter(
"_bias", nn.Parameter(module.bias.detach())
)
# pyrefly: ignore # bad-assignment
module.bias = None
module.prune_bias = prune_bias

View File

@ -97,6 +97,7 @@ def _propagate_module_bias(module: nn.Module, mask: Tensor) -> Optional[Tensor]:
if module.bias is not None:
module.bias = nn.Parameter(cast(Tensor, module.bias)[mask])
elif getattr(module, "_bias", None) is not None:
# pyrefly: ignore # bad-assignment
module.bias = nn.Parameter(cast(Tensor, module._bias)[mask])
# get pruned biases to propagate to subsequent layer

View File

@ -170,6 +170,7 @@ class BaseSparsifier(abc.ABC):
self.make_config_from_model(model)
# TODO: Remove the configuration by reference ('module')
# pyrefly: ignore # not-iterable
for module_config in self.config:
assert isinstance(module_config, dict), (
"config elements should be dicts not modules i.e.:"

View File

@ -51,6 +51,7 @@ def swap_module(
new_mod.register_forward_hook(hook_fn)
# respect device affinity when swapping modules
# pyrefly: ignore # bad-argument-type
devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
assert len(devices) <= 1, (
f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"

View File

@ -235,6 +235,7 @@ class WeightNormSparsifier(BaseSparsifier):
ww = self.norm_fn(getattr(module, tensor_name))
tensor_mask = self._make_tensor_mask(
data=ww,
# pyrefly: ignore # missing-attribute
input_shape=ww.shape,
sparsity_level=sparsity_level,
sparse_block_shape=sparse_block_shape,

View File

@ -24,6 +24,8 @@ from .pt2e.export_utils import (
_move_exported_model_to_eval as move_exported_model_to_eval,
_move_exported_model_to_train as move_exported_model_to_train,
)
# pyrefly: ignore # deprecated
from .qconfig import * # noqa: F403
from .qconfig_mapping import * # noqa: F403
from .quant_type import * # noqa: F403

View File

@ -127,6 +127,7 @@ class AdaptiveRoundingOptimizer:
@torch.no_grad()
def feed_forward(self, x, weight, module):
if isinstance(module, torch.nn.Conv1d):
# pyrefly: ignore # no-matching-overload
out = torch.nn.functional.conv1d(
x,
weight,

View File

@ -185,7 +185,9 @@ class FakeQuantize(FakeQuantizeBase):
dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get(
"dtype", dtype
)
# pyrefly: ignore # bad-argument-type
assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound"
# pyrefly: ignore # bad-argument-type
assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound"
observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
observer_kwargs["is_dynamic"] = is_dynamic

View File

@ -1149,6 +1149,7 @@ quantized_decomposed_lib.define(
class FakeQuantPerChannel(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max):
if scales.dtype != torch.float32:
scales = scales.to(torch.float32)
@ -1171,6 +1172,7 @@ class FakeQuantPerChannel(torch.autograd.Function):
return out
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, gy):
(mask,) = ctx.saved_tensors
return gy * mask, None, None, None, None, None

View File

@ -246,6 +246,7 @@ def calculate_equalization_scale(
class EqualizationQConfig(
# pyrefly: ignore # invalid-inheritance
namedtuple("EqualizationQConfig", ["input_activation", "weight"])
):
"""
@ -460,6 +461,7 @@ def maybe_get_next_equalization_scale(
In this case, the node given is linear1 and we want to locate the InputEqObs.
"""
next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
# pyrefly: ignore # invalid-argument
if next_inp_eq_obs:
if (
next_inp_eq_obs.equalization_scale.nelement() == 1
@ -821,13 +823,18 @@ def convert_eq_obs(
# Scale the weight nodes
if node.op == "call_module":
scale_weight_node(
node, modules, equalization_scale, maybe_next_equalization_scale
node,
modules,
# pyrefly: ignore # bad-argument-type
equalization_scale,
maybe_next_equalization_scale,
)
elif node.op == "call_function":
scale_weight_functional(
node,
model,
modules,
# pyrefly: ignore # bad-argument-type
equalization_scale,
maybe_next_equalization_scale,
)

View File

@ -223,6 +223,7 @@ class ModelReportVisualizer:
feature_val = feature_val.item()
# we add to our list of values
# pyrefly: ignore # bad-argument-type
tensor_table_row.append(feature_val)
tensor_table.append(tensor_table_row)
@ -283,6 +284,7 @@ class ModelReportVisualizer:
feature_val = feature_val.item()
# add value to channel specific row
# pyrefly: ignore # bad-argument-type
new_channel_row.append(feature_val)
# add to table and increment row index counter

View File

@ -166,6 +166,7 @@ def _create_obs_or_fq_from_qspec(
}
edge_or_nodes = quantization_spec.derived_from
obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes]
# pyrefly: ignore # unsupported-operation
kwargs["obs_or_fqs"] = obs_or_fqs
return _DerivedObserverOrFakeQuantize.with_args(**kwargs)()
elif isinstance(quantization_spec, FixedQParamsQuantizationSpec):
@ -2085,8 +2086,11 @@ def prepare(
root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config)
# pyrefly: ignore # bad-argument-type
_update_qconfig_for_fusion(model, qconfig_mapping)
# pyrefly: ignore # bad-argument-type
_update_qconfig_for_fusion(model, _equalization_config)
# pyrefly: ignore # bad-argument-type
flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
# TODO: support regex as well
propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
@ -2094,6 +2098,7 @@ def prepare(
if is_qat:
module_to_qat_module = get_module_to_qat_module(backend_config)
_qat_swap_modules(model, module_to_qat_module)
# pyrefly: ignore # bad-argument-type
_update_qconfig_for_qat(qconfig_mapping, backend_config)
# mapping from fully qualified module name to module instance
@ -2107,10 +2112,20 @@ def prepare(
# fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches
equalization_node_name_to_qconfig = _generate_node_name_to_qconfig(
model, named_modules, model.graph, _equalization_config, node_name_to_scope
model,
named_modules,
model.graph,
# pyrefly: ignore # bad-argument-type
_equalization_config,
node_name_to_scope,
)
node_name_to_qconfig = _generate_node_name_to_qconfig(
model, named_modules, model.graph, qconfig_mapping, node_name_to_scope
model,
named_modules,
model.graph,
# pyrefly: ignore # bad-argument-type
qconfig_mapping,
node_name_to_scope,
)
# match the patterns that will get quantized
@ -2170,6 +2185,7 @@ def prepare(
node_name_to_scope,
prepare_custom_config,
equalization_node_name_to_qconfig,
# pyrefly: ignore # bad-argument-type
qconfig_mapping,
is_qat,
observed_node_names,

View File

@ -720,6 +720,7 @@ def _maybe_get_custom_module_lstm_from_node_arg(
a = a.args[0][0] # type: ignore[assignment,index]
else:
a = a.args[0] # type: ignore[assignment]
# pyrefly: ignore # bad-return
return a
all_match_patterns = [

View File

@ -280,9 +280,12 @@ class UniformQuantizationObserverBase(ObserverBase):
)
self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
if self.has_customized_qrange:
# pyrefly: ignore # bad-argument-type
validate_qmin_qmax(quant_min, quant_max)
self.quant_min, self.quant_max = calculate_qmin_qmax(
# pyrefly: ignore # bad-argument-type
quant_min,
# pyrefly: ignore # bad-argument-type
quant_max,
self.has_customized_qrange,
self.dtype,

View File

@ -72,6 +72,7 @@ def _find_q_dq_node_for_user(
dq_node = n
break
if dq_node is None:
# pyrefly: ignore # bad-assignment
for n in user.kwargs:
if (
isinstance(n, torch.fx.Node)

View File

@ -83,6 +83,7 @@ __all__ = [
]
# pyrefly: ignore # invalid-inheritance
class QConfig(namedtuple("QConfig", ["activation", "weight"])):
"""
Describes how to quantize a layer or a part of the network by providing
@ -120,6 +121,7 @@ class QConfig(namedtuple("QConfig", ["activation", "weight"])):
"`QConfigDynamic` is going to be deprecated in PyTorch 1.12, please use `QConfig` instead",
category=FutureWarning,
)
# pyrefly: ignore # invalid-inheritance
class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])):
"""
Describes how to dynamically quantize a layer or a part of the network by providing

View File

@ -417,6 +417,7 @@ class X86InductorQuantizer(Quantizer):
# As we use `_need_skip_config` to skip all invalid configurations,
# we can safely assume that the all existing non-None configurations
# have the same quantization mode.
# pyrefly: ignore # bad-assignment
for qconfig in (
list(self.module_name_qconfig.values())
+ list(self.operator_type_qconfig.values())
@ -808,6 +809,7 @@ class X86InductorQuantizer(Quantizer):
)
binary_node.meta[QUANT_ANNOTATION_KEY] = (
_X86InductorQuantizationAnnotation(
# pyrefly: ignore # bad-argument-type
input_qspec_map=binary_node_input_qspec_map,
_annotated=True,
)
@ -878,6 +880,7 @@ class X86InductorQuantizer(Quantizer):
)
binary_node.meta[QUANT_ANNOTATION_KEY] = (
_X86InductorQuantizationAnnotation(
# pyrefly: ignore # bad-argument-type
input_qspec_map=binary_node_input_qspec_map,
# TODO<leslie> Remove the annotate of output in QAT when qat util support pattern matcher.
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
@ -1085,6 +1088,7 @@ class X86InductorQuantizer(Quantizer):
quantization_config
)
binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
# pyrefly: ignore # bad-argument-type
input_qspec_map=binary_node_input_qspec_map,
_annotated=True,
)
@ -1139,6 +1143,7 @@ class X86InductorQuantizer(Quantizer):
quantization_config
)
binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
# pyrefly: ignore # bad-argument-type
input_qspec_map=binary_node_input_qspec_map,
_annotated=True,
_is_output_of_quantized_pattern=True,
@ -1499,6 +1504,7 @@ class X86InductorQuantizer(Quantizer):
has_unary = unary_op is not None
seq_partition = [torch.nn.Linear, binary_op]
if has_unary:
# pyrefly: ignore # bad-argument-type
seq_partition.append(unary_op)
fused_partitions = find_sequential_partitions(gm, seq_partition)
for fused_partition in fused_partitions:

View File

@ -376,9 +376,11 @@ def _do_annotate_conv_relu(
input_qspec_map[bias] = get_bias_qspec(quantization_config)
partition.append(bias)
# pyrefly: ignore # bad-argument-type
if _is_annotated(partition):
continue
# pyrefly: ignore # bad-argument-type
if filter_fn and any(not filter_fn(n) for n in partition):
continue
@ -389,6 +391,7 @@ def _do_annotate_conv_relu(
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
)
# pyrefly: ignore # bad-argument-type
_mark_nodes_as_annotated(partition)
annotated_partitions.append(partition)
return annotated_partitions

View File

@ -39,6 +39,7 @@ class Bernoulli(ExponentialFamily):
validate_args (bool, optional): whether to validate arguments, None by default
"""
# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.boolean
has_enumerate_support = True
@ -56,10 +57,12 @@ class Bernoulli(ExponentialFamily):
)
if probs is not None:
is_scalar = isinstance(probs, _Number)
# pyrefly: ignore # read-only
(self.probs,) = broadcast_all(probs)
else:
assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number)
# pyrefly: ignore # read-only
(self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits
if is_scalar:
@ -137,5 +140,6 @@ class Bernoulli(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (torch.logit(self.probs),)
# pyrefly: ignore # bad-override
def _log_normalizer(self, x):
return torch.log1p(torch.exp(x))

View File

@ -31,6 +31,7 @@ class Beta(ExponentialFamily):
(often referred to as beta)
"""
# pyrefly: ignore # bad-override
arg_constraints = {
"concentration1": constraints.positive,
"concentration0": constraints.positive,
@ -113,5 +114,6 @@ class Beta(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor, Tensor]:
return (self.concentration1, self.concentration0)
# pyrefly: ignore # bad-override
def _log_normalizer(self, x, y):
return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)

View File

@ -45,6 +45,7 @@ class Binomial(Distribution):
logits (Tensor): Event log-odds
"""
# pyrefly: ignore # bad-override
arg_constraints = {
"total_count": constraints.nonnegative_integer,
"probs": constraints.unit_interval,
@ -66,6 +67,7 @@ class Binomial(Distribution):
if probs is not None:
(
self.total_count,
# pyrefly: ignore # read-only
self.probs,
) = broadcast_all(total_count, probs)
self.total_count = self.total_count.type_as(self.probs)
@ -73,6 +75,7 @@ class Binomial(Distribution):
assert logits is not None # helps mypy
(
self.total_count,
# pyrefly: ignore # read-only
self.logits,
) = broadcast_all(total_count, logits)
self.total_count = self.total_count.type_as(self.logits)
@ -99,6 +102,7 @@ class Binomial(Distribution):
return self._param.new(*args, **kwargs)
@constraints.dependent_property(is_discrete=True, event_dim=0)
# pyrefly: ignore # bad-override
def support(self):
return constraints.integer_interval(0, self.total_count)

View File

@ -50,6 +50,7 @@ class Categorical(Distribution):
logits (Tensor): event log probabilities (unnormalized)
"""
# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
has_enumerate_support = True
@ -66,12 +67,14 @@ class Categorical(Distribution):
if probs is not None:
if probs.dim() < 1:
raise ValueError("`probs` parameter must be at least one-dimensional.")
# pyrefly: ignore # read-only
self.probs = probs / probs.sum(-1, keepdim=True)
else:
assert logits is not None # helps mypy
if logits.dim() < 1:
raise ValueError("`logits` parameter must be at least one-dimensional.")
# Normalize
# pyrefly: ignore # read-only
self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
self._param = self.probs if probs is not None else self.logits
self._num_events = self._param.size()[-1]
@ -99,6 +102,7 @@ class Categorical(Distribution):
return self._param.new(*args, **kwargs)
@constraints.dependent_property(is_discrete=True, event_dim=0)
# pyrefly: ignore # bad-override
def support(self):
return constraints.integer_interval(0, self._num_events - 1)

View File

@ -31,6 +31,7 @@ class Cauchy(Distribution):
scale (float or Tensor): half width at half maximum.
"""
# pyrefly: ignore # bad-override
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real
has_rsample = True

View File

@ -47,6 +47,7 @@ class ContinuousBernoulli(ExponentialFamily):
https://arxiv.org/abs/1907.06845
"""
# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.unit_interval
_mean_carrier_measure = 0
@ -65,16 +66,19 @@ class ContinuousBernoulli(ExponentialFamily):
)
if probs is not None:
is_scalar = isinstance(probs, _Number)
# pyrefly: ignore # read-only
(self.probs,) = broadcast_all(probs)
# validate 'probs' here if necessary as it is later clamped for numerical stability
# close to 0 and 1, later on; otherwise the clamped 'probs' would always pass
if validate_args is not None:
if not self.arg_constraints["probs"].check(self.probs).all():
raise ValueError("The parameter probs has invalid values")
# pyrefly: ignore # read-only
self.probs = clamp_probs(self.probs)
else:
assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number)
# pyrefly: ignore # read-only
(self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits
if is_scalar:
@ -230,6 +234,7 @@ class ContinuousBernoulli(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (self.logits,)
# pyrefly: ignore # bad-override
def _log_normalizer(self, x):
"""computes the log normalizing constant as a function of the natural parameter"""
out_unst_reg = torch.max(

View File

@ -22,6 +22,7 @@ def _Dirichlet_backward(x, concentration, grad_output):
class _Dirichlet(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, concentration):
x = torch._sample_dirichlet(concentration)
ctx.save_for_backward(x, concentration)
@ -29,6 +30,7 @@ class _Dirichlet(Function):
@staticmethod
@once_differentiable
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
x, concentration = ctx.saved_tensors
return _Dirichlet_backward(x, concentration, grad_output)
@ -50,6 +52,7 @@ class Dirichlet(ExponentialFamily):
(often referred to as alpha)
"""
# pyrefly: ignore # bad-override
arg_constraints = {
"concentration": constraints.independent(constraints.positive, 1)
}
@ -130,5 +133,6 @@ class Dirichlet(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (self.concentration,)
# pyrefly: ignore # bad-override
def _log_normalizer(self, x):
return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))

View File

@ -27,6 +27,7 @@ class Exponential(ExponentialFamily):
rate (float or Tensor): rate = 1 / scale of the distribution
"""
# pyrefly: ignore # bad-override
arg_constraints = {"rate": constraints.positive}
support = constraints.nonnegative
has_rsample = True
@ -89,5 +90,6 @@ class Exponential(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (-self.rate,)
# pyrefly: ignore # bad-override
def _log_normalizer(self, x):
return -torch.log(-x)

View File

@ -29,6 +29,7 @@ class FisherSnedecor(Distribution):
df2 (float or Tensor): degrees of freedom parameter 2
"""
# pyrefly: ignore # bad-override
arg_constraints = {"df1": constraints.positive, "df2": constraints.positive}
support = constraints.positive
has_rsample = True

View File

@ -34,6 +34,7 @@ class Gamma(ExponentialFamily):
(often referred to as beta), rate = 1 / scale
"""
# pyrefly: ignore # bad-override
arg_constraints = {
"concentration": constraints.positive,
"rate": constraints.positive,
@ -109,6 +110,7 @@ class Gamma(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor, Tensor]:
return (self.concentration - 1, -self.rate)
# pyrefly: ignore # bad-override
def _log_normalizer(self, x, y):
return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal())

View File

@ -35,6 +35,7 @@ class GeneralizedPareto(Distribution):
concentration (float or Tensor): Concentration parameter of the distribution
"""
# pyrefly: ignore # bad-override
arg_constraints = {
"loc": constraints.real,
"scale": constraints.positive,
@ -130,6 +131,7 @@ class GeneralizedPareto(Distribution):
concentration = self.concentration
valid = concentration < 0.5
safe_conc = torch.where(valid, concentration, 0.25)
# pyrefly: ignore # unsupported-operation
result = self.scale**2 / ((1 - safe_conc) ** 2 * (1 - 2 * safe_conc))
return torch.where(valid, result, nan)
@ -142,6 +144,7 @@ class GeneralizedPareto(Distribution):
return self.loc
@constraints.dependent_property(is_discrete=False, event_dim=0)
# pyrefly: ignore # bad-override
def support(self):
lower = self.loc
upper = torch.where(

View File

@ -44,6 +44,7 @@ class Geometric(Distribution):
logits (Number, Tensor): the log-odds of sampling `1`.
"""
# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.nonnegative_integer
@ -58,9 +59,11 @@ class Geometric(Distribution):
"Either `probs` or `logits` must be specified, but not both."
)
if probs is not None:
# pyrefly: ignore # read-only
(self.probs,) = broadcast_all(probs)
else:
assert logits is not None # helps mypy
# pyrefly: ignore # read-only
(self.logits,) = broadcast_all(logits)
probs_or_logits = probs if probs is not None else logits
if isinstance(probs_or_logits, _Number):

View File

@ -32,6 +32,7 @@ class Gumbel(TransformedDistribution):
"""
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
# pyrefly: ignore # bad-override
support = constraints.real
def __init__(

View File

@ -32,8 +32,10 @@ class HalfCauchy(TransformedDistribution):
"""
arg_constraints = {"scale": constraints.positive}
# pyrefly: ignore # bad-override
support = constraints.nonnegative
has_rsample = True
# pyrefly: ignore # bad-override
base_dist: Cauchy
def __init__(

View File

@ -32,8 +32,10 @@ class HalfNormal(TransformedDistribution):
"""
arg_constraints = {"scale": constraints.positive}
# pyrefly: ignore # bad-override
support = constraints.nonnegative
has_rsample = True
# pyrefly: ignore # bad-override
base_dist: Normal
def __init__(

View File

@ -91,6 +91,7 @@ class Independent(Distribution, Generic[D]):
return self.base_dist.has_enumerate_support
@constraints.dependent_property
# pyrefly: ignore # bad-override
def support(self):
result = self.base_dist.support
if self.reinterpreted_batch_ndims:

View File

@ -38,8 +38,10 @@ class InverseGamma(TransformedDistribution):
"concentration": constraints.positive,
"rate": constraints.positive,
}
# pyrefly: ignore # bad-override
support = constraints.positive
has_rsample = True
# pyrefly: ignore # bad-override
base_dist: Gamma
def __init__(

View File

@ -44,6 +44,7 @@ class Kumaraswamy(TransformedDistribution):
"concentration1": constraints.positive,
"concentration0": constraints.positive,
}
# pyrefly: ignore # bad-override
support = constraints.unit_interval
has_rsample = True
@ -66,6 +67,7 @@ class Kumaraswamy(TransformedDistribution):
AffineTransform(loc=1.0, scale=-1.0),
PowerTransform(exponent=self.concentration1.reciprocal()),
]
# pyrefly: ignore # bad-argument-type
super().__init__(base_dist, transforms, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):

View File

@ -28,6 +28,7 @@ class Laplace(Distribution):
scale (float or Tensor): scale of the distribution
"""
# pyrefly: ignore # bad-override
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real
has_rsample = True

View File

@ -60,6 +60,7 @@ class LKJCholesky(Distribution):
Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
"""
# pyrefly: ignore # bad-override
arg_constraints = {"concentration": constraints.positive}
support = constraints.corr_cholesky

View File

@ -32,8 +32,10 @@ class LogNormal(TransformedDistribution):
"""
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
# pyrefly: ignore # bad-override
support = constraints.positive
has_rsample = True
# pyrefly: ignore # bad-override
base_dist: Normal
def __init__(

View File

@ -36,8 +36,10 @@ class LogisticNormal(TransformedDistribution):
"""
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
# pyrefly: ignore # bad-override
support = constraints.simplex
has_rsample = True
# pyrefly: ignore # bad-override
base_dist: Independent[Normal]
def __init__(

View File

@ -86,6 +86,7 @@ class LowRankMultivariateNormal(Distribution):
capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
"""
# pyrefly: ignore # bad-override
arg_constraints = {
"loc": constraints.real_vector,
"cov_factor": constraints.independent(constraints.real, 2),

View File

@ -124,6 +124,7 @@ class MixtureSameFamily(Distribution):
return new
@constraints.dependent_property
# pyrefly: ignore # bad-override
def support(self):
return MixtureSameFamilyConstraint(self._component_distribution.support)

View File

@ -50,6 +50,7 @@ class Multinomial(Distribution):
logits (Tensor): event log probabilities (unnormalized)
"""
# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
total_count: int
@ -92,6 +93,7 @@ class Multinomial(Distribution):
return self._categorical._new(*args, **kwargs)
@constraints.dependent_property(is_discrete=True, event_dim=1)
# pyrefly: ignore # bad-override
def support(self):
return constraints.multinomial(self.total_count)

View File

@ -123,6 +123,7 @@ class MultivariateNormal(Distribution):
the corresponding lower triangular matrices using a Cholesky decomposition.
"""
# pyrefly: ignore # bad-override
arg_constraints = {
"loc": constraints.real_vector,
"covariance_matrix": constraints.positive_definite,
@ -156,6 +157,7 @@ class MultivariateNormal(Distribution):
"with optional leading batch dimensions"
)
batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
# pyrefly: ignore # read-only
self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
elif covariance_matrix is not None:
if covariance_matrix.dim() < 2:
@ -166,6 +168,7 @@ class MultivariateNormal(Distribution):
batch_shape = torch.broadcast_shapes(
covariance_matrix.shape[:-2], loc.shape[:-1]
)
# pyrefly: ignore # read-only
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
else:
assert precision_matrix is not None # helps mypy
@ -177,6 +180,7 @@ class MultivariateNormal(Distribution):
batch_shape = torch.broadcast_shapes(
precision_matrix.shape[:-2], loc.shape[:-1]
)
# pyrefly: ignore # read-only
self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
self.loc = loc.expand(batch_shape + (-1,))

View File

@ -33,6 +33,7 @@ class NegativeBinomial(Distribution):
logits (Tensor): Event log-odds for probabilities of success
"""
# pyrefly: ignore # bad-override
arg_constraints = {
"total_count": constraints.greater_than_eq(0),
"probs": constraints.half_open_interval(0.0, 1.0),
@ -54,6 +55,7 @@ class NegativeBinomial(Distribution):
if probs is not None:
(
self.total_count,
# pyrefly: ignore # read-only
self.probs,
) = broadcast_all(total_count, probs)
self.total_count = self.total_count.type_as(self.probs)
@ -61,6 +63,7 @@ class NegativeBinomial(Distribution):
assert logits is not None # helps mypy
(
self.total_count,
# pyrefly: ignore # read-only
self.logits,
) = broadcast_all(total_count, logits)
self.total_count = self.total_count.type_as(self.logits)

View File

@ -31,6 +31,7 @@ class Normal(ExponentialFamily):
(often referred to as sigma)
"""
# pyrefly: ignore # bad-override
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real
has_rsample = True
@ -88,6 +89,7 @@ class Normal(ExponentialFamily):
if self._validate_args:
self._validate_sample(value)
# compute the variance
# pyrefly: ignore # unsupported-operation
var = self.scale**2
log_scale = (
math.log(self.scale)
@ -117,5 +119,6 @@ class Normal(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor, Tensor]:
return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal())
# pyrefly: ignore # bad-override
def _log_normalizer(self, x, y):
return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y)

View File

@ -42,6 +42,7 @@ class OneHotCategorical(Distribution):
logits (Tensor): event log probabilities (unnormalized)
"""
# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
support = constraints.one_hot
has_enumerate_support = True

View File

@ -39,6 +39,7 @@ class Pareto(TransformedDistribution):
self.scale, self.alpha = broadcast_all(scale, alpha)
base_dist = Exponential(self.alpha, validate_args=validate_args)
transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
# pyrefly: ignore # bad-argument-type
super().__init__(base_dist, transforms, validate_args=validate_args)
def expand(

View File

@ -32,6 +32,7 @@ class Poisson(ExponentialFamily):
rate (Number, Tensor): the rate parameter
"""
# pyrefly: ignore # bad-override
arg_constraints = {"rate": constraints.nonnegative}
support = constraints.nonnegative_integer
@ -82,5 +83,6 @@ class Poisson(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (torch.log(self.rate),)
# pyrefly: ignore # bad-override
def _log_normalizer(self, x):
return torch.exp(x)

View File

@ -40,6 +40,7 @@ class LogitRelaxedBernoulli(Distribution):
(Jang et al., 2017)
"""
# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.real
@ -57,10 +58,12 @@ class LogitRelaxedBernoulli(Distribution):
)
if probs is not None:
is_scalar = isinstance(probs, _Number)
# pyrefly: ignore # read-only
(self.probs,) = broadcast_all(probs)
else:
assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number)
# pyrefly: ignore # read-only
(self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits
if is_scalar:
@ -138,8 +141,10 @@ class RelaxedBernoulli(TransformedDistribution):
"""
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
# pyrefly: ignore # bad-override
support = constraints.unit_interval
has_rsample = True
# pyrefly: ignore # bad-override
base_dist: LogitRelaxedBernoulli
def __init__(

View File

@ -38,6 +38,7 @@ class ExpRelaxedCategorical(Distribution):
(Jang et al., 2017)
"""
# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
support = (
constraints.real_vector
@ -127,8 +128,10 @@ class RelaxedOneHotCategorical(TransformedDistribution):
"""
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
# pyrefly: ignore # bad-override
support = constraints.simplex
has_rsample = True
# pyrefly: ignore # bad-override
base_dist: ExpRelaxedCategorical
def __init__(

View File

@ -31,6 +31,7 @@ class StudentT(Distribution):
scale (float or Tensor): scale of the distribution
"""
# pyrefly: ignore # bad-override
arg_constraints = {
"df": constraints.positive,
"loc": constraints.real,

View File

@ -123,6 +123,7 @@ class TransformedDistribution(Distribution):
return new
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def support(self):
if not self.transforms:
return self.base_dist.support

View File

@ -226,11 +226,13 @@ class _InverseTransform(Transform):
self._inv: Transform = transform # type: ignore[assignment]
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def domain(self):
assert self._inv is not None
return self._inv.codomain
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def codomain(self):
assert self._inv is not None
return self._inv.domain
@ -300,6 +302,7 @@ class ComposeTransform(Transform):
return self.parts == other.parts
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def domain(self):
if not self.parts:
return constraints.real
@ -315,6 +318,7 @@ class ComposeTransform(Transform):
return domain
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def codomain(self):
if not self.parts:
return constraints.real
@ -434,12 +438,14 @@ class IndependentTransform(Transform):
)
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def domain(self):
return constraints.independent(
self.base_transform.domain, self.reinterpreted_batch_ndims
)
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def codomain(self):
return constraints.independent(
self.base_transform.codomain, self.reinterpreted_batch_ndims
@ -507,10 +513,12 @@ class ReshapeTransform(Transform):
super().__init__(cache_size=cache_size)
@constraints.dependent_property
# pyrefly: ignore # bad-override
def domain(self):
return constraints.independent(constraints.real, len(self.in_shape))
@constraints.dependent_property
# pyrefly: ignore # bad-override
def codomain(self):
return constraints.independent(constraints.real, len(self.out_shape))
@ -764,12 +772,14 @@ class AffineTransform(Transform):
return self._event_dim
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def domain(self):
if self.event_dim == 0:
return constraints.real
return constraints.independent(constraints.real, self.event_dim)
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def codomain(self):
if self.event_dim == 0:
return constraints.real
@ -867,6 +877,7 @@ class CorrCholeskyTransform(Transform):
# apply stick-breaking on the squared values
# Note that y = sign(r) * sqrt(z * z1m_cumprod)
# = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
# pyrefly: ignore # unsupported-operation
z = r**2
z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
# Diagonal elements must be 1.
@ -1155,12 +1166,14 @@ class CatTransform(Transform):
return all(t.bijective for t in self.transforms)
@constraints.dependent_property
# pyrefly: ignore # bad-override
def domain(self):
return constraints.cat(
[t.domain for t in self.transforms], self.dim, self.lengths
)
@constraints.dependent_property
# pyrefly: ignore # bad-override
def codomain(self):
return constraints.cat(
[t.codomain for t in self.transforms], self.dim, self.lengths
@ -1233,10 +1246,12 @@ class StackTransform(Transform):
return all(t.bijective for t in self.transforms)
@constraints.dependent_property
# pyrefly: ignore # bad-override
def domain(self):
return constraints.stack([t.domain for t in self.transforms], self.dim)
@constraints.dependent_property
# pyrefly: ignore # bad-override
def codomain(self):
return constraints.stack([t.codomain for t in self.transforms], self.dim)

View File

@ -79,6 +79,7 @@ class Uniform(Distribution):
return new
@constraints.dependent_property(is_discrete=False, event_dim=0)
# pyrefly: ignore # bad-override
def support(self):
return constraints.interval(self.low, self.high)

View File

@ -92,6 +92,7 @@ def _log_modified_bessel_fn(x, order=0):
@torch.jit.script_if_tracing
def _rejection_sample(loc, concentration, proposal_r, x):
done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
# pyrefly: ignore # bad-assignment
while not done.all():
u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
u1, u2, u3 = u.unbind()
@ -100,6 +101,7 @@ def _rejection_sample(loc, concentration, proposal_r, x):
c = concentration * (proposal_r - f)
accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
if accept.any():
# pyrefly: ignore # no-matching-overload
x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
done = done | accept
return (x + math.pi + loc) % (2 * math.pi) - math.pi
@ -123,6 +125,7 @@ class VonMises(Distribution):
:param torch.Tensor concentration: concentration parameter
"""
# pyrefly: ignore # bad-override
arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
support = constraints.real
has_rsample = False
@ -160,8 +163,10 @@ class VonMises(Distribution):
@lazy_property
def _proposal_r(self) -> Tensor:
kappa = self._concentration
# pyrefly: ignore # unsupported-operation
tau = 1 + (1 + 4 * kappa**2).sqrt()
rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
# pyrefly: ignore # unsupported-operation
_proposal_r = (1 + rho**2) / (2 * rho)
# second order Taylor expansion around 0 for small kappa
_proposal_r_taylor = 1 / kappa + kappa

View File

@ -35,6 +35,7 @@ class Weibull(TransformedDistribution):
"scale": constraints.positive,
"concentration": constraints.positive,
}
# pyrefly: ignore # bad-override
support = constraints.positive
def __init__(
@ -52,6 +53,7 @@ class Weibull(TransformedDistribution):
PowerTransform(exponent=self.concentration_reciprocal),
AffineTransform(loc=0, scale=self.scale),
]
# pyrefly: ignore # bad-argument-type
super().__init__(base_dist, transforms, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):

View File

@ -116,10 +116,13 @@ class Wishart(ExponentialFamily):
)
if scale_tril is not None:
# pyrefly: ignore # read-only
self.scale_tril = param.expand(batch_shape + (-1, -1))
elif covariance_matrix is not None:
# pyrefly: ignore # read-only
self.covariance_matrix = param.expand(batch_shape + (-1, -1))
elif precision_matrix is not None:
# pyrefly: ignore # read-only
self.precision_matrix = param.expand(batch_shape + (-1, -1))
if self.df.lt(event_shape[-1]).any():
@ -335,6 +338,7 @@ class Wishart(ExponentialFamily):
p = self._event_shape[-1] # has singleton shape
return -self.precision_matrix / 2, (nu - p - 1) / 2
# pyrefly: ignore # bad-override
def _log_normalizer(self, x, y):
p = self._event_shape[-1]
return (y + (p + 1) / 2) * (

View File

@ -24,11 +24,15 @@ class TensorProperties:
if not self.is_fake:
# only get the storage pointer for real tensors
# pyrefly: ignore # bad-assignment
self.storage_ptr = tensor.untyped_storage().data_ptr()
if self.is_contiguous:
# only get storage size and start/end pointers for contiguous tensors
# pyrefly: ignore # bad-assignment
self.storage_size = tensor.untyped_storage().nbytes()
# pyrefly: ignore # bad-assignment
self.start = tensor.data_ptr()
# pyrefly: ignore # bad-assignment
self.end = _end_ptr(tensor)
# info to recover tensor

View File

@ -65,6 +65,7 @@ class GraphPickler(pickle.Pickler):
self._meta_tensor_describer = MetaTensorDescriber(copy_data=False)
@override
# pyrefly: ignore # bad-override
def reducer_override(
self, obj: object
) -> tuple[Callable[..., Any], tuple[Any, ...]]:
@ -201,6 +202,7 @@ class _SymNodePickleData:
]:
args = (cls(obj.node), pickler._unpickle_state)
if isinstance(obj, torch.SymInt):
# pyrefly: ignore # bad-return
return _SymNodePickleData.unpickle_sym_int, args
else:
raise NotImplementedError(f"Unhandled SymNode type {type(obj)}")
@ -277,6 +279,7 @@ class _TensorPickleData:
return FakeTensor(
unpickle_state.fake_mode,
make_meta_t(),
# pyrefly: ignore # bad-argument-type
device,
)

View File

@ -603,6 +603,7 @@ class Tracer(TracerBase):
in inspect.signature(self.create_proxy).parameters
):
kwargs["proxy_factory_fn"] = (
# pyrefly: ignore # unsupported-operation
None
if not self.param_shapes_constant
else lambda node: ParameterProxy(

View File

@ -657,7 +657,10 @@ class Partitioner:
# Mark bfs level
get_bfs_level_partition(self.partitions)
find_combination, partitions = find_partition_to_combine_based_on_size(
sorted_partitions, available_mem_bytes, partitions
sorted_partitions,
available_mem_bytes,
# pyrefly: ignore # bad-argument-type
partitions,
)
return
@ -702,6 +705,7 @@ class Partitioner:
non_embedding_partitions.append(partition)
if new_partition:
partition = self.create_partition()
# pyrefly: ignore # missing-attribute
partition.left_mem_bytes = available_mem_bytes
return partition
return None
@ -997,6 +1001,7 @@ class Partitioner:
node, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec
)
if cost < min_cost:
# pyrefly: ignore # bad-assignment
node_pair = [node, n1]
min_cost = cost
return cost, node_pair # type: ignore[possibly-undefined]

View File

@ -30,6 +30,7 @@ def split_result_tensors(
else:
splits = [x.shape[0] for x in inputs]
# pyrefly: ignore # bad-argument-type
return torch.split(result, splits)

View File

@ -171,7 +171,14 @@ class MetaTracer(torch.fx.Tracer):
proxy_factory_fn=None,
):
rv = super().create_proxy(
kind, target, args, kwargs, name, type_expr, proxy_factory_fn
kind,
target,
args,
kwargs,
name,
type_expr,
# pyrefly: ignore # bad-argument-type
proxy_factory_fn,
)
if kind == "placeholder" and target in self.meta_args:
@ -193,6 +200,7 @@ class MetaTracer(torch.fx.Tracer):
if kind == "call_function":
meta_target = manual_meta_overrides.get(target, target)
# pyrefly: ignore # not-callable
meta_out = meta_target(*args_metas, **kwargs_metas)
elif kind == "call_method":
meta_target = getattr(args_metas[0], target) # type: ignore[index]

View File

@ -528,9 +528,11 @@ def view_inference_rule(n: Node, symbols, constraints, counter):
if t == -1:
var, counter = gen_dvar(counter)
t2_type.append(var)
# pyrefly: ignore # bad-argument-type
num_constraints.append(BinConstraintD(var, Dyn, op_neq))
else:
# pyrefly: ignore # bad-argument-type
num_constraints.append(BinConstraintD(t, Dyn, op_neq))
t2_type.append(t) # type: ignore[arg-type]
@ -1475,6 +1477,7 @@ class ConstraintGenerator:
all_constraints = []
# pyrefly: ignore # missing-attribute
for n in graph.nodes:
(constraints, counter) = self.generate_constraints_node(n, counter)
all_constraints += constraints

View File

@ -193,6 +193,7 @@ def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]):
assert isinstance(node.target, str)
cur_module = modules[node.target]
if type(cur_module) in mkldnn_map:
# pyrefly: ignore # index-error
new_module = mkldnn_map[type(cur_module)](cur_module, torch.float)
assert isinstance(new_module, nn.Module)
old_modules[new_module] = copy.deepcopy(cur_module)
@ -263,7 +264,10 @@ def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
)
reset_modules(
submodule.graph.nodes, dict(submodule.named_modules()), old_modules
submodule.graph.nodes,
dict(submodule.named_modules()),
# pyrefly: ignore # bad-argument-type
old_modules,
)
no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
return mkl_time < no_mkl_time

View File

@ -124,6 +124,7 @@ pytree.register_pytree_node(
torch.Size,
lambda xs: (list(xs), None),
lambda xs, _: tuple(xs),
# pyrefly: ignore # bad-argument-type
flatten_with_keys_fn=lambda xs: (
[(pytree.SequenceKey(i), x) for i, x in enumerate(xs)],
None,
@ -306,6 +307,7 @@ def set_proxy_slot( # type: ignore[no-redef]
def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool:
assert isinstance(obj, (Tensor, SymNode)), type(obj)
# pyrefly: ignore # no-matching-overload
return bool(get_proxy_slot(obj, tracer, False, lambda _: True))
@ -402,6 +404,7 @@ def get_proxy_slot(
assert isinstance(obj, py_sym_types), type(obj)
tracker = tracer.symnode_tracker
# pyrefly: ignore # unsupported-operation
if obj not in tracker:
# Last ditch
if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker:
@ -413,6 +416,7 @@ def get_proxy_slot(
)
return default
else:
# pyrefly: ignore # index-error
value = tracker[obj]
res = transform(value)
return res
@ -788,6 +792,7 @@ def fetch_object_proxy(
def fetch_object_proxy(
tracer: _ProxyTracer, t: Union[Tensor, _AnyScriptObjectType, PySymType]
) -> object:
# pyrefly: ignore # no-matching-overload
return get_proxy_slot(t, tracer, t)
@ -836,6 +841,7 @@ def _fetch_proxies_and_all_constant_flag(
"""
f_flat_args_kwargs = [
(
# pyrefly: ignore # no-matching-overload
fetch_object_proxy(tracer, x)
if isinstance(x, (Tensor, _AnyScriptObject))
else x
@ -1410,6 +1416,7 @@ class TorchFunctionMetadataMode(TorchFunctionMode):
kwargs: Optional[dict[str, object]] = None,
) -> object:
kwargs = kwargs or {}
# pyrefly: ignore # bad-assignment
self.tracer.torch_fn_metadata = func
self.tracer.torch_fn_counts[func] = self.tracer.torch_fn_counts.get(func, 0) + 1
return func(*args, **kwargs)
@ -1459,6 +1466,7 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
# For autocast, the python APIs run so we don't have to run them again
# here.
if func is torch._C._set_grad_enabled:
# pyrefly: ignore # bad-argument-type
func(*args, **kwargs)
return node
@ -1672,6 +1680,7 @@ class DecompositionInterpreter(fx.Interpreter):
self.decomposition_table = decomposition_table or {}
self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")
# pyrefly: ignore # bad-override
def placeholder(
self,
target: str, # type: ignore[override]
@ -1684,6 +1693,7 @@ class DecompositionInterpreter(fx.Interpreter):
# TODO handle case where the first character of target is '*'
return out
# pyrefly: ignore # bad-override
def get_attr(
self,
target: str, # type: ignore[override]
@ -1697,6 +1707,7 @@ class DecompositionInterpreter(fx.Interpreter):
# call_function, call_method, call_module get traced automatically by the outer mode.
# pyrefly: ignore # bad-override
def output(
self,
target: str, # type: ignore[override]
@ -1782,14 +1793,17 @@ class _ModuleStackTracer(PythonKeyTracer):
self.enable_attr_proxy = False
self.submodule_paths = {}
for name, m in self.scope_root.named_modules(remove_duplicate=False):
# pyrefly: ignore # unsupported-operation
if m in self.submodule_paths:
log.info(
"Shared module found between %s and %s, AttrProxy is enabled.",
# pyrefly: ignore # unsupported-operation
self.submodule_paths[m],
name,
)
self.enable_attr_proxy = True
else:
# pyrefly: ignore # unsupported-operation
self.submodule_paths[m] = name
self.proxy_paths: WeakKeyDictionary[_AttrProxy, str] = WeakKeyDictionary()
@ -1815,6 +1829,7 @@ class _ModuleStackTracer(PythonKeyTracer):
# Class is modified to be a subclass of torch.nn.Module
# Warning: We blow away our own attributes here to mimic the base class
# - so don't expect `self.x` to do anything useful.
# pyrefly: ignore # no-matching-overload
self.__class__ = type(
base.__class__.__name__,
(self.__class__, base.__class__),
@ -1837,6 +1852,7 @@ class _ModuleStackTracer(PythonKeyTracer):
if not isinstance(attr_val, Module):
return attr_val
# pyrefly: ignore # index-error
return AttrProxy(attr_val, tracer.proxy_paths[self] + "." + name)
def get_base(self) -> Module:
@ -1849,10 +1865,12 @@ class _ModuleStackTracer(PythonKeyTracer):
res = torch.nn.Sequential(
OrderedDict(list(self._modules.items())[idx])
)
# pyrefly: ignore # index-error
return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}")
elif isinstance(self, torch.nn.ModuleList):
# Copied from nn/modules/container.py
res = torch.nn.ModuleList(list(self._modules.values())[idx])
# pyrefly: ignore # index-error
return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}")
return super().__getitem__(idx) # type: ignore[misc]

View File

@ -839,6 +839,7 @@ def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr:
factor = functools.reduce(math.gcd, map(integer_coefficient, atoms))
if factor == 1:
return expr
# pyrefly: ignore # bad-argument-type
atoms = [div_by_factor(x, factor) for x in atoms]
return _sympy_from_args(
sympy.Add, atoms, sort=True, is_commutative=expr.is_commutative
@ -1234,6 +1235,7 @@ def _free_unbacked_symbols_with_path(
else _symint_wrap(coeff)
)
# TODO: DivideByKey needs to test divisibility at runtime!
# pyrefly: ignore # unsupported-operation
r[unbacked] = path + (DivideByKey(divisor),)
if real is not None:
assert isinstance(real, int)
@ -1256,6 +1258,7 @@ def _free_unbacked_symbols_with_path(
and s.rhs == 1
and s.lhs in pending
):
# pyrefly: ignore # unsupported-operation
r[s.lhs] = path + (ConvertIntKey(),)
if real is not None:
assert type(real) is bool
@ -2172,6 +2175,7 @@ class SubclassSymbolicContext(StatefulSymbolicContext):
def __post_init__(self) -> None:
super().__post_init__()
if self.inner_contexts is None:
# pyrefly: ignore # bad-assignment
self.inner_contexts = {}
@ -2260,9 +2264,12 @@ def _fast_expand(expr: _SympyT) -> _SympyT:
# only re-create the objects if any of the args changed to avoid expensive
# checks when re-creating objects.
new_args = [_fast_expand(arg) for arg in expr.args] # type: ignore[arg-type]
# pyrefly: ignore # missing-attribute
if any(arg is not new_arg for arg, new_arg in zip(expr.args, new_args)):
# pyrefly: ignore # missing-attribute
return _fast_expand(expr.func(*new_args))
# pyrefly: ignore # missing-attribute
if expr.is_Pow:
base: sympy.Expr
exp: sympy.Expr
@ -2272,9 +2279,11 @@ def _fast_expand(expr: _SympyT) -> _SympyT:
return sympy.expand_multinomial(expr, deep=False)
elif exp < 0:
return S.One / sympy.expand_multinomial(S.One / expr, deep=False)
# pyrefly: ignore # missing-attribute
elif expr.is_Mul:
num: list[sympy.Expr] = []
den: list[sympy.Expr] = []
# pyrefly: ignore # missing-attribute
for arg in expr.args:
if arg.is_Pow and arg.args[1] == -1:
den.append(S.One / arg) # type: ignore[operator, arg-type]
@ -2396,6 +2405,7 @@ def _maybe_evaluate_static_worker(
# TODO: remove this try catch (esp for unbacked_only)
try:
# pyrefly: ignore # missing-attribute
new_expr = expr.xreplace(new_shape_env)
except RecursionError:
log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env)
@ -2933,13 +2943,19 @@ class DimConstraints:
# is_integer tests though haha
return (base - mod_reduced) / divisor
# pyrefly: ignore # missing-attribute
if expr.has(Mod):
# pyrefly: ignore # missing-attribute
expr = expr.replace(Mod, mod_handler)
# 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative
# arguments should be OK.
# pyrefly: ignore # missing-attribute
if expr.has(PythonMod):
# pyrefly: ignore # missing-attribute
expr = expr.replace(PythonMod, mod_handler)
# pyrefly: ignore # missing-attribute
if expr.has(FloorDiv):
# pyrefly: ignore # missing-attribute
expr = expr.replace(FloorDiv, floor_div_handler)
return expr
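The negative-argument example in the comment above generalizes; a standalone sanity check of the identity behind the floor-division rewrite (plain Python, not part of the patch):

# Python guarantees base == divisor * (base // divisor) + (base % divisor),
# so (base - base % divisor) / divisor reproduces floor division exactly,
# including for negative operands such as the 7 // -3 case noted above.
for base, divisor in [(7, -3), (-7, 3), (7, 3), (-7, -3)]:
    assert (base - base % divisor) / divisor == base // divisor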
@ -5057,6 +5073,7 @@ class ShapeEnv:
if duck:
# Make sure to reuse this symbol for subsequent duck shaping
# pyrefly: ignore # unsupported-operation
self.val_to_var[val] = sympy_expr
if isinstance(val, int):
@ -5288,15 +5305,19 @@ class ShapeEnv:
# Expand optional inputs, or verify invariants are upheld
if input_contexts is None:
# pyrefly: ignore # bad-assignment
input_contexts = [
# pyrefly: ignore # bad-argument-type
_create_no_constraints_context(t) if isinstance(t, Tensorlike) else None
for t in placeholders
]
else:
assert len(input_contexts) == len(placeholders)
# pyrefly: ignore # bad-assignment
for i, (t, context) in enumerate(zip(placeholders, input_contexts)):
if isinstance(t, Tensorlike):
if context is None:
# pyrefly: ignore # bad-argument-type
input_contexts[i] = _create_no_constraints_context(t)
else:
assert isinstance(t, (SymInt, int, SymFloat, float))
@ -5582,6 +5603,7 @@ class ShapeEnv:
s = sympy.Float(val)
input_guards.append((source, s))
# pyrefly: ignore # no-matching-overload
for t, source, context in zip(placeholders, sources, input_contexts):
if isinstance(source, str):
from torch._dynamo.source import LocalSource
@ -5641,11 +5663,13 @@ class ShapeEnv:
)
track_symint(property_source, ss, constraint_size[i])
else:
# pyrefly: ignore # missing-attribute
for i, ss in enumerate(curr_t.size()):
property_source = TensorPropertySource(
src, TensorProperty.SIZE, i
)
track_symint(property_source, ss, constraint_size[i])
# pyrefly: ignore # missing-attribute
for i, ss in enumerate(curr_t.stride()):
property_source = TensorPropertySource(
src, TensorProperty.STRIDE, i
@ -5653,6 +5677,7 @@ class ShapeEnv:
track_symint(property_source, ss, constraint_stride[i])
track_symint(
TensorPropertySource(src, TensorProperty.STORAGE_OFFSET),
# pyrefly: ignore # missing-attribute
curr_t.storage_offset(),
)
@ -5698,6 +5723,7 @@ class ShapeEnv:
continue
if is_dim(source):
# pyrefly: ignore # missing-attribute
self.dim_constraints.add_equality(source, expr)
for exprs, printer, lang in zip(all_exprs, printers, langs):
@ -5851,6 +5877,7 @@ class ShapeEnv:
continue
expr = self.simplify(ra.expr)
# pyrefly: ignore # missing-attribute
self.dim_constraints.add(expr)
# 3. Every symbol must be within its value range (this handles 0/1
@ -5867,6 +5894,7 @@ class ShapeEnv:
verbose_expr = ""
if r.lower not in (-sympy.oo, -int_oo):
if any(is_dim(source) for source in sources):
# pyrefly: ignore # missing-attribute
self.dim_constraints.add(sympy.Ge(symbol, r.lower))
# Only print lower bound in simplified mode if it is not the
# default
@ -5875,6 +5903,7 @@ class ShapeEnv:
verbose_expr = f"{r.lower} <= {rf} # {vr_sloc.lower}"
if r.upper not in (sympy.oo, int_oo):
if any(is_dim(source) for source in sources):
# pyrefly: ignore # missing-attribute
self.dim_constraints.add(sympy.Le(symbol, r.upper))
# nontrivial upper bound is always interesting
bounds.append(sympy.Le(symbol, r.upper, evaluate=False))
@ -5943,6 +5972,7 @@ class ShapeEnv:
else:
str_msg = f" - {msg_cb()}"
error_msgs.append(str_msg)
# pyrefly: ignore # bad-argument-type
debug_names.add(debug_name)
if len(error_msgs) > 0:
debug_names_str = ", ".join(sorted(debug_names))
@ -6076,6 +6106,7 @@ class ShapeEnv:
Get a list of guards, but pruned so it only provides guards that
reference symints from the passed in input
"""
# pyrefly: ignore # bad-assignment
symints = {
s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol)
}
@ -6121,6 +6152,7 @@ class ShapeEnv:
else:
bindings[-s] = -arg
# pyrefly: ignore # bad-assignment
for t, arg in zip(placeholders, args):
if t is None:
continue
@ -6338,6 +6370,7 @@ class ShapeEnv:
Apply symbol replacements to any symbols in the given expression.
"""
replacements = {}
# pyrefly: ignore # missing-attribute
for s in expr.free_symbols:
r = self._find(s)
@ -6347,6 +6380,7 @@ class ShapeEnv:
if not r.is_Symbol or r != s:
replacements[s] = r
if replacements:
# pyrefly: ignore # missing-attribute
return safe_expand(expr.xreplace(replacements))
else:
return expr
@ -7121,6 +7155,7 @@ class ShapeEnv:
instructions = list(dis.Bytecode(frame.f_code))
co_lines, offset = inspect.getsourcelines(frame.f_code)
start, end, cur = None, None, None
# pyrefly: ignore # bad-assignment
for i, instr in enumerate(instructions):
if instr.starts_line is not None:
cur = instr.starts_line
@ -8000,6 +8035,7 @@ def _suggest_fixes_for_data_dependent_error_non_strict(
if isinstance(leaf, torch.SymInt):
src_map[str(leaf.node.expr)].append(name)
elif isinstance(leaf, torch.Tensor):
# pyrefly: ignore # bad-assignment
for i, dim in enumerate(leaf.shape):
if isinstance(dim, torch.SymInt):
src_map[str(dim.node.expr)].append(f"{name}.shape[{i}]")

Some files were not shown because too many files have changed in this diff.