Add pyrefly suppressions (#164748)

Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: remove the directories to be checked from the `project-excludes` field in pyrefly.toml
step 2: run `pyrefly check`
step 3: add suppressions and clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after: 0 errors (4,263 ignored)
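For context, each suppression added here is a `# pyrefly: ignore # <error-kind>` comment placed on the line above the code pyrefly flags. A minimal sketch (the function and variable names are made up for illustration; the comment form and the `bad-argument-type` error kind are the ones used throughout the diff):

```python
from typing import Optional


def scale(value: int) -> int:
    return value * 2


# The annotation admits None, so pyrefly reports bad-argument-type for the
# call below even though the variable holds an int at runtime.
maybe_value: Optional[int] = 3

# pyrefly: ignore # bad-argument-type
result = scale(maybe_value)
print(result)  # 6
```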

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164748
Approved by: https://github.com/oulgen
Author: Maggie Moss
Committed by: PyTorch MergeBot on 2025-10-07 17:31:18 +00:00
parent 5e47b4dd60
commit b13cd141b3
139 changed files with 504 additions and 30 deletions

View File

@@ -25,10 +25,6 @@ project-excludes = [
 "torch/nn/**",
 "torch/_dynamo/**",
 "torch/utils/**",
-"torch/ao/**",
-"torch/fx/**",
-"torch/distributions/**",
-"torch/onnx/**",
 # formatting issues
 "torch/linalg/__init__.py",
 "torch/package/importer.py",

View File

@@ -470,7 +470,6 @@ def _check_input_constraints_for_graph(
 )
 elif isinstance(node_val, torch.SymInt):
 _check_symint(
-# pyrefly: ignore # bad-argument-type
 node_val,
 # pyrefly: ignore # bad-argument-type
 arg,

View File

@@ -360,7 +360,6 @@ def trace_flex_attention(
 "call_function", flex_attention, proxy_args, {}
 )
 return track_tensor_tree(
-# pyrefly: ignore # bad-argument-type
 example_out,
 out_proxy,
 constant=None,
@@ -1080,7 +1079,6 @@ def trace_flex_attention_backward(
 name="flex_attention_backward",
 )
 return track_tensor_tree(
-# pyrefly: ignore # bad-argument-type
 example_out,
 out_proxy,
 constant=None,

View File

@@ -899,7 +899,6 @@ def analyze_kernel_mutations(
 if op.name == "tt.call":
 assert op.fn_call_name in functions
 mutations = analyze_kernel_mutations(
-# pyrefly: ignore # bad-argument-type
 functions,
 # pyrefly: ignore # bad-argument-type
 op.fn_call_name,

View File

@@ -3379,7 +3379,6 @@ def native_layer_norm(
 torch._check(
 input.ndim >= normalized_ndim
 and sym_eq(
-# pyrefly: ignore # bad-argument-type
 input.shape[(input.ndim - normalized_ndim) :],
 # pyrefly: ignore # bad-argument-type
 tuple(normalized_shape),

View File

@@ -620,6 +620,7 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
 dilation=dilation,
 groups=groups,
 bias=bias,
+# pyrefly: ignore # bad-argument-type
 padding_mode=padding_mode,
 qconfig=qconfig,
 )
@@ -820,6 +821,7 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
 dilation=dilation,
 groups=groups,
 bias=bias,
+# pyrefly: ignore # bad-argument-type
 padding_mode=padding_mode,
 qconfig=qconfig,
 )
@@ -1021,6 +1023,7 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
 dilation=dilation,
 groups=groups,
 bias=bias,
+# pyrefly: ignore # bad-argument-type
 padding_mode=padding_mode,
 qconfig=qconfig,
 )

View File

@@ -36,6 +36,7 @@ class LinearReLU(nnqat.Linear, _FusedModule):
 torch.Size([128, 30])
 """
+# pyrefly: ignore # bad-override
 _FLOAT_MODULE = nni.LinearReLU
 def __init__(

View File

@@ -30,6 +30,7 @@ class LinearReLU(nnqd.Linear):
 torch.Size([128, 30])
 """
+# pyrefly: ignore # bad-override
 _FLOAT_MODULE = nni.LinearReLU
 def __init__(

View File

@@ -54,6 +54,7 @@ class ConvReLU1d(nnq.Conv1d):
 dilation=dilation,
 groups=groups,
 bias=bias,
+# pyrefly: ignore # bad-argument-type
 padding_mode=padding_mode,
 device=device,
 dtype=dtype,

View File

@@ -114,6 +114,7 @@ class _ConvNd(nn.modules.conv._ConvNd):
 assert hasattr(cls, "_FLOAT_RELU_MODULE")
 relu = cls._FLOAT_RELU_MODULE()
 modules.append(relu)
+# pyrefly: ignore # missing-attribute
 fused = cls._FLOAT_MODULE(*modules)
 fused.train(self.training)
 return fused

View File

@@ -50,6 +50,7 @@ class Embedding(nn.Embedding):
 scale_grad_by_freq,
 sparse,
 _weight,
+# pyrefly: ignore # bad-argument-type
 **factory_kwargs,
 )
 assert qconfig, "qconfig must be provided for QAT module"

View File

@@ -170,8 +170,11 @@ class MultiheadAttention(nn.MultiheadAttention):
 observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
 observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
 if other.in_proj_bias is None:
+# pyrefly: ignore # bad-assignment
 observed.linear_Q.bias = None
+# pyrefly: ignore # bad-assignment
 observed.linear_K.bias = None
+# pyrefly: ignore # bad-assignment
 observed.linear_V.bias = None
 else:
 observed.linear_Q.bias = nn.Parameter(
@@ -234,6 +237,7 @@ class MultiheadAttention(nn.MultiheadAttention):
 _end = _start + fp.embed_dim
 fp.in_proj_weight[_start:_end, :] = wQ
 if fp.in_proj_bias is not None:
+# pyrefly: ignore # bad-argument-type
 assert all(bQ == 0)
 fp.in_proj_bias[_start:_end] = bQ
@@ -241,12 +245,14 @@ class MultiheadAttention(nn.MultiheadAttention):
 _end = _start + fp.embed_dim
 fp.in_proj_weight[_start:_end, :] = wK
 if fp.in_proj_bias is not None:
+# pyrefly: ignore # bad-argument-type
 assert all(bK == 0)
 fp.in_proj_bias[_start:_end] = bK
 _start = _end
 fp.in_proj_weight[_start:, :] = wV
 if fp.in_proj_bias is not None:
+# pyrefly: ignore # bad-argument-type
 assert all(bV == 0)
 fp.in_proj_bias[_start:] = bV
 else:
@@ -254,8 +260,11 @@ class MultiheadAttention(nn.MultiheadAttention):
 fp.k_proj_weight = nn.Parameter(wK)
 fp.v_proj_weight = nn.Parameter(wV)
 if fp.in_proj_bias is None:
+# pyrefly: ignore # bad-assignment
 self.linear_Q.bias = None
+# pyrefly: ignore # bad-assignment
 self.linear_K.bias = None
+# pyrefly: ignore # bad-assignment
 self.linear_V.bias = None
 else:
 fp.in_proj_bias[0 : fp.embed_dim] = bQ
@@ -463,6 +472,7 @@ class MultiheadAttention(nn.MultiheadAttention):
 assert static_v.size(2) == head_dim
 v = static_v
+# pyrefly: ignore # missing-attribute
 src_len = k.size(1)
 if key_padding_mask is not None:
@@ -471,17 +481,35 @@ class MultiheadAttention(nn.MultiheadAttention):
 if self.add_zero_attn:
 src_len += 1
+# pyrefly: ignore # missing-attribute
 k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
+# pyrefly: ignore # missing-attribute
 if k.is_quantized:
 k_zeros = torch.quantize_per_tensor(
-k_zeros, k.q_scale(), k.q_zero_point(), k.dtype
+k_zeros,
+# pyrefly: ignore # missing-attribute
+k.q_scale(),
+# pyrefly: ignore # missing-attribute
+k.q_zero_point(),
+# pyrefly: ignore # missing-attribute
+k.dtype,
 )
+# pyrefly: ignore # no-matching-overload
 k = torch.cat([k, k_zeros], dim=1)
+# pyrefly: ignore # missing-attribute
 v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
+# pyrefly: ignore # missing-attribute
 if v.is_quantized:
 v_zeros = torch.quantize_per_tensor(
-v_zeros, v.q_scale(), v.q_zero_point(), v.dtype
+v_zeros,
+# pyrefly: ignore # missing-attribute
+v.q_scale(),
+# pyrefly: ignore # missing-attribute
+v.q_zero_point(),
+# pyrefly: ignore # missing-attribute
+v.dtype,
 )
+# pyrefly: ignore # no-matching-overload
 v = torch.cat([v, v_zeros], dim=1)
 if attn_mask is not None:

View File

@@ -376,6 +376,7 @@ class _LSTMLayer(torch.nn.Module):
 bidirectional,
 split_gates=split_gates,
 )
+# pyrefly: ignore # bad-argument-type
 layer.qconfig = getattr(other, "qconfig", qconfig)
 wi = getattr(other, f"weight_ih_l{layer_idx}")
 wh = getattr(other, f"weight_hh_l{layer_idx}")
@@ -454,6 +455,7 @@ class LSTM(torch.nn.Module):
 if (
 not isinstance(dropout, numbers.Number)
+# pyrefly: ignore # unsupported-operation
 or not 0 <= dropout <= 1
 or isinstance(dropout, bool)
 ):
@@ -462,6 +464,7 @@ class LSTM(torch.nn.Module):
 "representing the probability of an element being "
 "zeroed"
 )
+# pyrefly: ignore # unsupported-operation
 if dropout > 0:
 warnings.warn(
 "dropout option for quantizable LSTM is ignored. "
@@ -573,6 +576,7 @@ class LSTM(torch.nn.Module):
 other.bidirectional,
 split_gates=split_gates,
 )
+# pyrefly: ignore # bad-argument-type
 observed.qconfig = getattr(other, "qconfig", qconfig)
 for idx in range(other.num_layers):
 observed.layers[idx] = _LSTMLayer.from_float(

View File

@@ -73,6 +73,7 @@ class Conv1d(nnq.Conv1d):
 factory_kwargs = {"device": device, "dtype": dtype}
 kernel_size = _single(kernel_size)
 stride = _single(stride)
+# pyrefly: ignore # bad-assignment
 padding = padding if isinstance(padding, str) else _single(padding)
 dilation = _single(dilation)

View File

@@ -119,7 +119,9 @@ class Linear(nnq.Linear):
 assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
 if type(mod) == nni.LinearReLU:
 mod = mod[0]
+# pyrefly: ignore # missing-attribute
 if mod.qconfig is not None and mod.qconfig.weight is not None:
+# pyrefly: ignore # not-callable
 weight_observer = mod.qconfig.weight()
 else:
 # We have the circular import issues if we import the qconfig in the beginning of this file:
@@ -143,6 +145,7 @@ class Linear(nnq.Linear):
 "Unsupported dtype specified for dynamic quantized Linear!"
 )
 qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
+# pyrefly: ignore # bad-argument-type
 qlinear.set_weight_bias(qweight, mod.bias)
 return qlinear

View File

@@ -521,6 +521,7 @@ class LSTM(RNNBase):
 >>> output, (hn, cn) = rnn(input, (h0, c0))
 """
+# pyrefly: ignore # bad-override
 _FLOAT_MODULE = nn.LSTM
 __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
@@ -806,6 +807,7 @@ class GRU(RNNBase):
 >>> output, hn = rnn(input, h0)
 """
+# pyrefly: ignore # bad-override
 _FLOAT_MODULE = nn.GRU
 __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}

View File

@@ -67,7 +67,9 @@ class Hardswish(torch.nn.Hardswish):
 def __init__(self, scale, zero_point, device=None, dtype=None):
 factory_kwargs = {"device": device, "dtype": dtype}
 super().__init__()
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
 def forward(self, input):
@@ -138,7 +140,9 @@ class LeakyReLU(torch.nn.LeakyReLU):
 ) -> None:
 factory_kwargs = {"device": device, "dtype": dtype}
 super().__init__(negative_slope, inplace)
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
 def forward(self, input):
@@ -226,6 +230,7 @@ class Softmax(torch.nn.Softmax):
 class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
+# pyrefly: ignore # bad-override
 _FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention
 def _get_name(self):

View File

@@ -12,7 +12,9 @@ class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
 ) -> None:
 factory_kwargs = {"device": device, "dtype": dtype}
 super().__init__(num_features, eps, momentum, True, True, **factory_kwargs)
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("scale", torch.tensor(1.0, **factory_kwargs))
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("zero_point", torch.tensor(0, **factory_kwargs))
 @staticmethod

View File

@@ -408,6 +408,7 @@ class Conv1d(_ConvNd):
 factory_kwargs = {"device": device, "dtype": dtype}
 kernel_size = _single(kernel_size)
 stride = _single(stride)
+# pyrefly: ignore # bad-assignment
 padding = padding if isinstance(padding, str) else _single(padding)
 dilation = _single(dilation)

View File

@@ -310,6 +310,7 @@ class Linear(WeightedQuantizedModule):
 # the type mismatch in assignment. Also, mypy has an issue with
 # iterables not being implemented, so we are ignoring those too.
 if not isinstance(cls._FLOAT_MODULE, Iterable):
+# pyrefly: ignore # bad-assignment
 cls._FLOAT_MODULE = [cls._FLOAT_MODULE]
 supported_modules = ", ".join(
 [float_mod.__name__ for float_mod in cls._FLOAT_MODULE]

View File

@@ -37,11 +37,14 @@ class LayerNorm(torch.nn.LayerNorm):
 normalized_shape,
 eps=eps,
 elementwise_affine=elementwise_affine,
+# pyrefly: ignore # bad-argument-type
 **factory_kwargs,
 )
 self.weight = weight
 self.bias = bias
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
 def forward(self, input):
@@ -113,7 +116,9 @@ class GroupNorm(torch.nn.GroupNorm):
 super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
 self.weight = weight
 self.bias = bias
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
 def forward(self, input):
@@ -175,7 +180,9 @@ class InstanceNorm1d(torch.nn.InstanceNorm1d):
 )
 self.weight = weight
 self.bias = bias
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
 def forward(self, input):
@@ -242,7 +249,9 @@ class InstanceNorm2d(torch.nn.InstanceNorm2d):
 )
 self.weight = weight
 self.bias = bias
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
 def forward(self, input):
@@ -309,7 +318,9 @@ class InstanceNorm3d(torch.nn.InstanceNorm3d):
 )
 self.weight = weight
 self.bias = bias
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
 def forward(self, input):

View File

@@ -95,6 +95,7 @@ class Conv1d(_ConvNd, nn.Conv1d):
 and the backend should be able to fuse the ops with `*` into a quantized conv1d
 """
 weight_quant_dequant = self.get_weight()
+# pyrefly: ignore # no-matching-overload
 result = F.conv1d(
 x,
 weight_quant_dequant,
@@ -140,6 +141,7 @@ class Conv2d(_ConvNd, nn.Conv2d):
 dilation,
 groups,
 bias,
+# pyrefly: ignore # bad-argument-type
 padding_mode,
 device,
 dtype,
@@ -158,6 +160,7 @@ class Conv2d(_ConvNd, nn.Conv2d):
 and the backend should be able to fuse the ops with `*` into a quantized conv2d
 """
 weight_quant_dequant = self.get_weight()
+# pyrefly: ignore # no-matching-overload
 result = F.conv2d(
 x,
 weight_quant_dequant,
@@ -203,6 +206,7 @@ class Conv3d(_ConvNd, nn.Conv3d):
 dilation,
 groups,
 bias,
+# pyrefly: ignore # bad-argument-type
 padding_mode,
 device,
 dtype,
@@ -221,6 +225,7 @@ class Conv3d(_ConvNd, nn.Conv3d):
 and the backend should be able to fuse the ops with `*` into a quantized conv3d
 """
 weight_quant_dequant = self.get_weight()
+# pyrefly: ignore # no-matching-overload
 result = F.conv3d(
 x,
 weight_quant_dequant,
@@ -378,6 +383,7 @@ class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d):
 groups,
 bias,
 dilation,
+# pyrefly: ignore # bad-argument-type
 padding_mode,
 device,
 dtype,
@@ -459,6 +465,7 @@ class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d):
 groups,
 bias,
 dilation,
+# pyrefly: ignore # bad-argument-type
 padding_mode,
 device,
 dtype,

View File

@@ -663,7 +663,11 @@ class LSTM(RNNBase):
 # xxx: isinstance check needs to be in conditional for TorchScript to compile
 if isinstance(orig_input, PackedSequence):
 output_packed = PackedSequence(
-output, batch_sizes, sorted_indices, unsorted_indices
+output,
+# pyrefly: ignore # bad-argument-type
+batch_sizes,
+sorted_indices,
+unsorted_indices,
 )
 return output_packed, self.permute_hidden(hidden, unsorted_indices)
 else:
@@ -823,7 +827,11 @@ class GRU(RNNBase):
 # xxx: isinstance check needs to be in conditional for TorchScript to compile
 if isinstance(orig_input, PackedSequence):
 output_packed = PackedSequence(
-output, batch_sizes, sorted_indices, unsorted_indices
+output,
+# pyrefly: ignore # bad-argument-type
+batch_sizes,
+sorted_indices,
+unsorted_indices,
 )
 return output_packed, self.permute_hidden(hidden, unsorted_indices)
 else:

View File

@@ -42,6 +42,7 @@ class Embedding(nn.Embedding, ReferenceQuantizedModule):
 scale_grad_by_freq,
 sparse,
 _weight,
+# pyrefly: ignore # bad-argument-type
 device,
 dtype,
 )

View File

@@ -18,6 +18,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
 "scale": 1.0,
 "zero_point": 0,
 }
+# pyrefly: ignore # bad-assignment
 self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"]
 self.weight_dtype = weight_qparams["dtype"]
 assert self.weight_qscheme in [
@@ -80,13 +81,16 @@ class ReferenceQuantizedModule(torch.nn.Module):
 self.register_buffer(
 "weight_axis", torch.tensor(0, dtype=torch.int, device=device)
 )
+# pyrefly: ignore # bad-assignment
 self.is_decomposed: bool = weight_qparams.get("is_decomposed", False)
 # store weight_axis as weight_axis_int due to some constraints of torchdynamo.export
 # for capturing `.item` operations
 self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment]
+# pyrefly: ignore # bad-assignment
 self.weight_quant_min: typing.Optional[int] = weight_qparams.get(
 "quant_min", None
 )
+# pyrefly: ignore # bad-assignment
 self.weight_quant_max: typing.Optional[int] = weight_qparams.get(
 "quant_max", None
 )
@@ -105,6 +109,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
 return _quantize_and_dequantize_weight_decomposed(
 self.weight, # type: ignore[arg-type]
 self.weight_qscheme,
+# pyrefly: ignore # bad-argument-type
 self.weight_dtype,
 self.weight_scale,
 self.weight_zero_point,
@@ -116,6 +121,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
 return _quantize_and_dequantize_weight(
 self.weight, # type: ignore[arg-type]
 self.weight_qscheme,
+# pyrefly: ignore # bad-argument-type
 self.weight_dtype,
 self.weight_scale,
 self.weight_zero_point,
@@ -131,6 +137,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
 return _quantize_weight_decomposed(
 self.weight, # type: ignore[arg-type]
 self.weight_qscheme,
+# pyrefly: ignore # bad-argument-type
 self.weight_dtype,
 self.weight_scale,
 self.weight_zero_point,
@@ -142,6 +149,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
 return _quantize_weight(
 self.weight, # type: ignore[arg-type]
 self.weight_qscheme,
+# pyrefly: ignore # bad-argument-type
 self.weight_dtype,
 self.weight_scale,
 self.weight_zero_point,

View File

@@ -151,7 +151,9 @@ class Linear(torch.nn.Module):
 assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
 if type(mod) == nni.LinearReLU:
 mod = mod[0]
+# pyrefly: ignore # missing-attribute
 if mod.qconfig is not None and mod.qconfig.weight is not None:
+# pyrefly: ignore # not-callable
 weight_observer = mod.qconfig.weight()
 else:
 # We have the circular import issues if we import the qconfig in the beginning of this file:
@@ -185,5 +187,6 @@ class Linear(torch.nn.Module):
 col_block_size,
 dtype=dtype,
 )
+# pyrefly: ignore # bad-argument-type
 qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size)
 return qlinear

View File

@@ -84,6 +84,7 @@ class _NSGraphMatchableSubgraphsIterator:
 if is_match:
 # navigate to the base node
 for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
+# pyrefly: ignore # bad-argument-type
 self.seen_nodes.add(cur_start_node)
 # for now, assume that there are no other nodes
 # which need to be added to the stack
@@ -94,8 +95,10 @@ class _NSGraphMatchableSubgraphsIterator:
 cur_base_op_node = cur_start_node
 break
+# pyrefly: ignore # bad-argument-type
 self.seen_nodes.add(cur_start_node)
 # add args of previous nodes to stack
+# pyrefly: ignore # missing-attribute
 for arg in cur_start_node.all_input_nodes:
 self._recursively_add_node_arg_to_stack(arg)
@@ -103,6 +106,7 @@ class _NSGraphMatchableSubgraphsIterator:
 # note: this check is done on the start_node, i.e.
 # if we are matching linear-relu in reverse, this would do the matchable
 # check on the linear
+# pyrefly: ignore # bad-argument-type
 if not self._is_matchable(cur_base_op_node):
 continue
@@ -116,8 +120,10 @@ class _NSGraphMatchableSubgraphsIterator:
 continue
 return NSSubgraph(
+# pyrefly: ignore # bad-argument-type
 start_node=cur_start_node,
 end_node=cur_end_node,
+# pyrefly: ignore # bad-argument-type
 base_op_node=cur_base_op_node,
 )

View File

@@ -415,6 +415,7 @@ def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]:
 target2,
 ) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
 new_connections.append((source, target1))
+# pyrefly: ignore # bad-argument-type
 new_connections.append((source, target2))
 for source_to_target in (
@@ -423,6 +424,7 @@ def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]:
 quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
 ):
 for source, target in source_to_target.items(): # type:ignore[assignment]
+# pyrefly: ignore # bad-argument-type
 new_connections.append((source, target))
 #

View File

@@ -95,6 +95,7 @@ class OutputProp:
 if isinstance(result, torch.Tensor): # type: ignore[possibly-undefined]
 node.traced_result = result
+# pyrefly: ignore # unsupported-operation
 env[node.name] = result
 return None
@@ -393,8 +394,10 @@ def create_submodule_from_subgraph(
 cur_name_idx += 1
 setattr(gm, mod_name, new_arg)
 new_arg_placeholder = gm.placeholder(mod_name) # type: ignore[operator]
+# pyrefly: ignore # missing-attribute
 cur_args_copy.append(new_arg_placeholder)
 elif isinstance(arg, (float, int, torch.dtype)):
+# pyrefly: ignore # missing-attribute
 cur_args_copy.append(arg)
 else:
 raise AssertionError(f"arg of type {type(arg)} not handled yet")
@@ -801,6 +804,7 @@ def create_add_loggers_graph(
 model,
 cur_subgraph_idx,
 match_name,
+# pyrefly: ignore # bad-argument-type
 maybe_subgraph,
 [qconfig_mapping],
 [node_name_to_qconfig],
@@ -857,6 +861,7 @@ def create_add_loggers_graph(
 cur_node_orig = first_node
 cur_node_copy = None
 first_node_copy = None
+# pyrefly: ignore # bad-assignment
 while cur_node_orig in subgraph_to_use:
 # TODO(future PR): make this support all possible args/kwargs
 if cur_node_orig is first_node:

View File

@@ -404,6 +404,7 @@ def maybe_add_missing_fqns(results: NSResultsType) -> None:
 for model_name, model_results in model_name_to_results.items():
 if model_name == model_name_with_fqns:
 continue
+# pyrefly: ignore # bad-assignment
 for i in range(len(model_results)):
 fqn = ref_model_results[i]["fqn"]
 model_results[i]["fqn"] = fqn
@@ -467,6 +468,7 @@ def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tenso
 Return:
 float or tuple of floats
 """
+# pyrefly: ignore # unsupported-operation
 return torch.sqrt(((x - y) ** 2).sum() / (x**2).sum())

View File

@@ -23,12 +23,17 @@ class SparseDLRM(DLRM_Net):
 super().__init__(**args)
 def forward(self, dense_x, lS_o, lS_i):
+# pyrefly: ignore # missing-attribute
 x = self.apply_mlp(dense_x, self.bot_l) # dense features
+# pyrefly: ignore # missing-attribute
 ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l) # apply embedding bag
+# pyrefly: ignore # missing-attribute
 z = self.interact_features(x, ly)
 z = z.to_sparse_coo()
+# pyrefly: ignore # missing-attribute
 z = torch.mm(z, self.top_l[0].weight.T).add(self.top_l[0].bias)
+# pyrefly: ignore # missing-attribute
 for layer in self.top_l[1:]:
 z = layer(z)

View File

@@ -72,6 +72,7 @@ class FPGMPruner(BaseStructuredSparsifier):
 dist_matrix = self.dist_fn(t_flatten)
 # more similar with other filter indicates large in the sum of row
+# pyrefly: ignore # bad-argument-type
 distance = torch.sum(torch.abs(dist_matrix), 1)
 return distance

View File

@@ -260,6 +260,7 @@ class BaseStructuredSparsifier(BaseSparsifier):
 module.register_parameter(
 "_bias", nn.Parameter(module.bias.detach())
 )
+# pyrefly: ignore # bad-assignment
 module.bias = None
 module.prune_bias = prune_bias

View File

@@ -97,6 +97,7 @@ def _propagate_module_bias(module: nn.Module, mask: Tensor) -> Optional[Tensor]:
 if module.bias is not None:
 module.bias = nn.Parameter(cast(Tensor, module.bias)[mask])
 elif getattr(module, "_bias", None) is not None:
+# pyrefly: ignore # bad-assignment
 module.bias = nn.Parameter(cast(Tensor, module._bias)[mask])
 # get pruned biases to propagate to subsequent layer

View File

@@ -170,6 +170,7 @@ class BaseSparsifier(abc.ABC):
 self.make_config_from_model(model)
 # TODO: Remove the configuration by reference ('module')
+# pyrefly: ignore # not-iterable
 for module_config in self.config:
 assert isinstance(module_config, dict), (
 "config elements should be dicts not modules i.e.:"

View File

@@ -51,6 +51,7 @@ def swap_module(
 new_mod.register_forward_hook(hook_fn)
 # respect device affinity when swapping modules
+# pyrefly: ignore # bad-argument-type
 devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
 assert len(devices) <= 1, (
 f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"

View File

@@ -235,6 +235,7 @@ class WeightNormSparsifier(BaseSparsifier):
 ww = self.norm_fn(getattr(module, tensor_name))
 tensor_mask = self._make_tensor_mask(
 data=ww,
+# pyrefly: ignore # missing-attribute
 input_shape=ww.shape,
 sparsity_level=sparsity_level,
 sparse_block_shape=sparse_block_shape,

View File

@@ -24,6 +24,8 @@ from .pt2e.export_utils import (
 _move_exported_model_to_eval as move_exported_model_to_eval,
 _move_exported_model_to_train as move_exported_model_to_train,
 )
+# pyrefly: ignore # deprecated
 from .qconfig import * # noqa: F403
 from .qconfig_mapping import * # noqa: F403
 from .quant_type import * # noqa: F403

View File

@@ -127,6 +127,7 @@ class AdaptiveRoundingOptimizer:
 @torch.no_grad()
 def feed_forward(self, x, weight, module):
 if isinstance(module, torch.nn.Conv1d):
+# pyrefly: ignore # no-matching-overload
 out = torch.nn.functional.conv1d(
 x,
 weight,

View File

@@ -185,7 +185,9 @@ class FakeQuantize(FakeQuantizeBase):
 dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get(
 "dtype", dtype
 )
+# pyrefly: ignore # bad-argument-type
 assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound"
+# pyrefly: ignore # bad-argument-type
 assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound"
 observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
 observer_kwargs["is_dynamic"] = is_dynamic

View File

@@ -1149,6 +1149,7 @@ quantized_decomposed_lib.define(
 class FakeQuantPerChannel(torch.autograd.Function):
 @staticmethod
+# pyrefly: ignore # bad-override
 def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max):
 if scales.dtype != torch.float32:
 scales = scales.to(torch.float32)
@@ -1171,6 +1172,7 @@ class FakeQuantPerChannel(torch.autograd.Function):
 return out
 @staticmethod
+# pyrefly: ignore # bad-override
 def backward(ctx, gy):
 (mask,) = ctx.saved_tensors
 return gy * mask, None, None, None, None, None

View File

@@ -246,6 +246,7 @@ def calculate_equalization_scale(
 class EqualizationQConfig(
+# pyrefly: ignore # invalid-inheritance
 namedtuple("EqualizationQConfig", ["input_activation", "weight"])
 ):
 """
@@ -460,6 +461,7 @@ def maybe_get_next_equalization_scale(
 In this case, the node given is linear1 and we want to locate the InputEqObs.
 """
 next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
+# pyrefly: ignore # invalid-argument
 if next_inp_eq_obs:
 if (
 next_inp_eq_obs.equalization_scale.nelement() == 1
@@ -821,13 +823,18 @@ def convert_eq_obs(
 # Scale the weight nodes
 if node.op == "call_module":
 scale_weight_node(
-node, modules, equalization_scale, maybe_next_equalization_scale
+node,
+modules,
+# pyrefly: ignore # bad-argument-type
+equalization_scale,
+maybe_next_equalization_scale,
 )
 elif node.op == "call_function":
 scale_weight_functional(
 node,
 model,
 modules,
+# pyrefly: ignore # bad-argument-type
 equalization_scale,
 maybe_next_equalization_scale,
 )

View File

@@ -223,6 +223,7 @@ class ModelReportVisualizer:
 feature_val = feature_val.item()
 # we add to our list of values
+# pyrefly: ignore # bad-argument-type
 tensor_table_row.append(feature_val)
 tensor_table.append(tensor_table_row)
@@ -283,6 +284,7 @@ class ModelReportVisualizer:
 feature_val = feature_val.item()
 # add value to channel specific row
+# pyrefly: ignore # bad-argument-type
 new_channel_row.append(feature_val)
 # add to table and increment row index counter

View File

@@ -166,6 +166,7 @@ def _create_obs_or_fq_from_qspec(
 }
 edge_or_nodes = quantization_spec.derived_from
 obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes]
+# pyrefly: ignore # unsupported-operation
 kwargs["obs_or_fqs"] = obs_or_fqs
 return _DerivedObserverOrFakeQuantize.with_args(**kwargs)()
 elif isinstance(quantization_spec, FixedQParamsQuantizationSpec):
@@ -2085,8 +2086,11 @@ def prepare(
 root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config)
+# pyrefly: ignore # bad-argument-type
 _update_qconfig_for_fusion(model, qconfig_mapping)
+# pyrefly: ignore # bad-argument-type
 _update_qconfig_for_fusion(model, _equalization_config)
+# pyrefly: ignore # bad-argument-type
 flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
 # TODO: support regex as well
 propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
@@ -2094,6 +2098,7 @@ def prepare(
 if is_qat:
 module_to_qat_module = get_module_to_qat_module(backend_config)
 _qat_swap_modules(model, module_to_qat_module)
+# pyrefly: ignore # bad-argument-type
 _update_qconfig_for_qat(qconfig_mapping, backend_config)
 # mapping from fully qualified module name to module instance
@@ -2107,10 +2112,20 @@ def prepare(
 # fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches
 equalization_node_name_to_qconfig = _generate_node_name_to_qconfig(
-model, named_modules, model.graph, _equalization_config, node_name_to_scope
+model,
+named_modules,
+model.graph,
+# pyrefly: ignore # bad-argument-type
+_equalization_config,
+node_name_to_scope,
 )
 node_name_to_qconfig = _generate_node_name_to_qconfig(
-model, named_modules, model.graph, qconfig_mapping, node_name_to_scope
+model,
+named_modules,
+model.graph,
+# pyrefly: ignore # bad-argument-type
+qconfig_mapping,
+node_name_to_scope,
 )
 # match the patterns that will get quantized
@@ -2170,6 +2185,7 @@ def prepare(
 node_name_to_scope,
 prepare_custom_config,
 equalization_node_name_to_qconfig,
+# pyrefly: ignore # bad-argument-type
 qconfig_mapping,
 is_qat,
 observed_node_names,

View File

@@ -720,6 +720,7 @@ def _maybe_get_custom_module_lstm_from_node_arg(
 a = a.args[0][0] # type: ignore[assignment,index]
 else:
 a = a.args[0] # type: ignore[assignment]
+# pyrefly: ignore # bad-return
 return a
 all_match_patterns = [

View File

@@ -280,9 +280,12 @@ class UniformQuantizationObserverBase(ObserverBase):
 )
 self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
 if self.has_customized_qrange:
+# pyrefly: ignore # bad-argument-type
 validate_qmin_qmax(quant_min, quant_max)
 self.quant_min, self.quant_max = calculate_qmin_qmax(
+# pyrefly: ignore # bad-argument-type
 quant_min,
+# pyrefly: ignore # bad-argument-type
 quant_max,
 self.has_customized_qrange,
 self.dtype,

View File

@@ -72,6 +72,7 @@ def _find_q_dq_node_for_user(
 dq_node = n
 break
 if dq_node is None:
+# pyrefly: ignore # bad-assignment
 for n in user.kwargs:
 if (
 isinstance(n, torch.fx.Node)

View File

@@ -83,6 +83,7 @@ __all__ = [
 ]
+# pyrefly: ignore # invalid-inheritance
 class QConfig(namedtuple("QConfig", ["activation", "weight"])):
 """
 Describes how to quantize a layer or a part of the network by providing
@@ -120,6 +121,7 @@ class QConfig(namedtuple("QConfig", ["activation", "weight"])):
 "`QConfigDynamic` is going to be deprecated in PyTorch 1.12, please use `QConfig` instead",
 category=FutureWarning,
 )
+# pyrefly: ignore # invalid-inheritance
 class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])):
 """
 Describes how to dynamically quantize a layer or a part of the network by providing

View File

@@ -417,6 +417,7 @@ class X86InductorQuantizer(Quantizer):
 # As we use `_need_skip_config` to skip all invalid configurations,
 # we can safely assume that the all existing non-None configurations
 # have the same quantization mode.
+# pyrefly: ignore # bad-assignment
 for qconfig in (
 list(self.module_name_qconfig.values())
 + list(self.operator_type_qconfig.values())
@@ -808,6 +809,7 @@ class X86InductorQuantizer(Quantizer):
 )
 binary_node.meta[QUANT_ANNOTATION_KEY] = (
 _X86InductorQuantizationAnnotation(
+# pyrefly: ignore # bad-argument-type
 input_qspec_map=binary_node_input_qspec_map,
 _annotated=True,
 )
@@ -878,6 +880,7 @@ class X86InductorQuantizer(Quantizer):
 )
 binary_node.meta[QUANT_ANNOTATION_KEY] = (
 _X86InductorQuantizationAnnotation(
+# pyrefly: ignore # bad-argument-type
 input_qspec_map=binary_node_input_qspec_map,
 # TODO<leslie> Remove the annotate of output in QAT when qat util support pattern matcher.
 output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
@@ -1085,6 +1088,7 @@ class X86InductorQuantizer(Quantizer):
 quantization_config
 )
 binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+# pyrefly: ignore # bad-argument-type
 input_qspec_map=binary_node_input_qspec_map,
 _annotated=True,
 )
@@ -1139,6 +1143,7 @@ class X86InductorQuantizer(Quantizer):
 quantization_config
 )
 binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+# pyrefly: ignore # bad-argument-type
 input_qspec_map=binary_node_input_qspec_map,
 _annotated=True,
 _is_output_of_quantized_pattern=True,
@@ -1499,6 +1504,7 @@ class X86InductorQuantizer(Quantizer):
 has_unary = unary_op is not None
 seq_partition = [torch.nn.Linear, binary_op]
 if has_unary:
+# pyrefly: ignore # bad-argument-type
 seq_partition.append(unary_op)
 fused_partitions = find_sequential_partitions(gm, seq_partition)
 for fused_partition in fused_partitions:

View File

@@ -376,9 +376,11 @@ def _do_annotate_conv_relu(
 input_qspec_map[bias] = get_bias_qspec(quantization_config)
 partition.append(bias)
+# pyrefly: ignore # bad-argument-type
 if _is_annotated(partition):
 continue
+# pyrefly: ignore # bad-argument-type
 if filter_fn and any(not filter_fn(n) for n in partition):
 continue
@@ -389,6 +391,7 @@ def _do_annotate_conv_relu(
 output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
 _annotated=True,
 )
+# pyrefly: ignore # bad-argument-type
 _mark_nodes_as_annotated(partition)
 annotated_partitions.append(partition)
 return annotated_partitions

View File

@@ -39,6 +39,7 @@ class Bernoulli(ExponentialFamily):
 validate_args (bool, optional): whether to validate arguments, None by default
 """
+# pyrefly: ignore # bad-override
 arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
 support = constraints.boolean
 has_enumerate_support = True
@@ -56,10 +57,12 @@ class Bernoulli(ExponentialFamily):
 )
 if probs is not None:
 is_scalar = isinstance(probs, _Number)
+# pyrefly: ignore # read-only
 (self.probs,) = broadcast_all(probs)
 else:
 assert logits is not None # helps mypy
 is_scalar = isinstance(logits, _Number)
+# pyrefly: ignore # read-only
 (self.logits,) = broadcast_all(logits)
 self._param = self.probs if probs is not None else self.logits
 if is_scalar:
@@ -137,5 +140,6 @@ class Bernoulli(ExponentialFamily):
 def _natural_params(self) -> tuple[Tensor]:
 return (torch.logit(self.probs),)
+# pyrefly: ignore # bad-override
 def _log_normalizer(self, x):
 return torch.log1p(torch.exp(x))

View File

@@ -31,6 +31,7 @@ class Beta(ExponentialFamily):
 (often referred to as beta)
 """
+# pyrefly: ignore # bad-override
 arg_constraints = {
 "concentration1": constraints.positive,
 "concentration0": constraints.positive,
@@ -113,5 +114,6 @@ class Beta(ExponentialFamily):
 def _natural_params(self) -> tuple[Tensor, Tensor]:
 return (self.concentration1, self.concentration0)
+# pyrefly: ignore # bad-override
 def _log_normalizer(self, x, y):
 return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)

View File

@@ -45,6 +45,7 @@ class Binomial(Distribution):
 logits (Tensor): Event log-odds
 """
+# pyrefly: ignore # bad-override
 arg_constraints = {
 "total_count": constraints.nonnegative_integer,
 "probs": constraints.unit_interval,
@@ -66,6 +67,7 @@ class Binomial(Distribution):
 if probs is not None:
 (
 self.total_count,
+# pyrefly: ignore # read-only
 self.probs,
 ) = broadcast_all(total_count, probs)
 self.total_count = self.total_count.type_as(self.probs)
@@ -73,6 +75,7 @@ class Binomial(Distribution):
 assert logits is not None # helps mypy
 (
 self.total_count,
+# pyrefly: ignore # read-only
 self.logits,
 ) = broadcast_all(total_count, logits)
 self.total_count = self.total_count.type_as(self.logits)
@@ -99,6 +102,7 @@ class Binomial(Distribution):
 return self._param.new(*args, **kwargs)
 @constraints.dependent_property(is_discrete=True, event_dim=0)
+# pyrefly: ignore # bad-override
 def support(self):
 return constraints.integer_interval(0, self.total_count)


@ -50,6 +50,7 @@ class Categorical(Distribution):
logits (Tensor): event log probabilities (unnormalized) logits (Tensor): event log probabilities (unnormalized)
""" """
# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
has_enumerate_support = True has_enumerate_support = True
@ -66,12 +67,14 @@ class Categorical(Distribution):
if probs is not None: if probs is not None:
if probs.dim() < 1: if probs.dim() < 1:
raise ValueError("`probs` parameter must be at least one-dimensional.") raise ValueError("`probs` parameter must be at least one-dimensional.")
# pyrefly: ignore # read-only
self.probs = probs / probs.sum(-1, keepdim=True) self.probs = probs / probs.sum(-1, keepdim=True)
else: else:
assert logits is not None # helps mypy assert logits is not None # helps mypy
if logits.dim() < 1: if logits.dim() < 1:
raise ValueError("`logits` parameter must be at least one-dimensional.") raise ValueError("`logits` parameter must be at least one-dimensional.")
# Normalize # Normalize
# pyrefly: ignore # read-only
self.logits = logits - logits.logsumexp(dim=-1, keepdim=True) self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
self._param = self.probs if probs is not None else self.logits self._param = self.probs if probs is not None else self.logits
self._num_events = self._param.size()[-1] self._num_events = self._param.size()[-1]
@ -99,6 +102,7 @@ class Categorical(Distribution):
return self._param.new(*args, **kwargs) return self._param.new(*args, **kwargs)
@constraints.dependent_property(is_discrete=True, event_dim=0) @constraints.dependent_property(is_discrete=True, event_dim=0)
# pyrefly: ignore # bad-override
def support(self): def support(self):
return constraints.integer_interval(0, self._num_events - 1) return constraints.integer_interval(0, self._num_events - 1)


@ -31,6 +31,7 @@ class Cauchy(Distribution):
scale (float or Tensor): half width at half maximum. scale (float or Tensor): half width at half maximum.
""" """
# pyrefly: ignore # bad-override
arg_constraints = {"loc": constraints.real, "scale": constraints.positive} arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real support = constraints.real
has_rsample = True has_rsample = True


@ -47,6 +47,7 @@ class ContinuousBernoulli(ExponentialFamily):
https://arxiv.org/abs/1907.06845 https://arxiv.org/abs/1907.06845
""" """
# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.unit_interval support = constraints.unit_interval
_mean_carrier_measure = 0 _mean_carrier_measure = 0
@ -65,16 +66,19 @@ class ContinuousBernoulli(ExponentialFamily):
) )
if probs is not None: if probs is not None:
is_scalar = isinstance(probs, _Number) is_scalar = isinstance(probs, _Number)
# pyrefly: ignore # read-only
(self.probs,) = broadcast_all(probs) (self.probs,) = broadcast_all(probs)
# validate 'probs' here if necessary as it is later clamped for numerical stability # validate 'probs' here if necessary as it is later clamped for numerical stability
# close to 0 and 1, later on; otherwise the clamped 'probs' would always pass # close to 0 and 1, later on; otherwise the clamped 'probs' would always pass
if validate_args is not None: if validate_args is not None:
if not self.arg_constraints["probs"].check(self.probs).all(): if not self.arg_constraints["probs"].check(self.probs).all():
raise ValueError("The parameter probs has invalid values") raise ValueError("The parameter probs has invalid values")
# pyrefly: ignore # read-only
self.probs = clamp_probs(self.probs) self.probs = clamp_probs(self.probs)
else: else:
assert logits is not None # helps mypy assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number) is_scalar = isinstance(logits, _Number)
# pyrefly: ignore # read-only
(self.logits,) = broadcast_all(logits) (self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits self._param = self.probs if probs is not None else self.logits
if is_scalar: if is_scalar:
@ -230,6 +234,7 @@ class ContinuousBernoulli(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]: def _natural_params(self) -> tuple[Tensor]:
return (self.logits,) return (self.logits,)
# pyrefly: ignore # bad-override
def _log_normalizer(self, x): def _log_normalizer(self, x):
"""computes the log normalizing constant as a function of the natural parameter""" """computes the log normalizing constant as a function of the natural parameter"""
out_unst_reg = torch.max( out_unst_reg = torch.max(


@ -22,6 +22,7 @@ def _Dirichlet_backward(x, concentration, grad_output):
class _Dirichlet(Function): class _Dirichlet(Function):
@staticmethod @staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, concentration): def forward(ctx, concentration):
x = torch._sample_dirichlet(concentration) x = torch._sample_dirichlet(concentration)
ctx.save_for_backward(x, concentration) ctx.save_for_backward(x, concentration)
@ -29,6 +30,7 @@ class _Dirichlet(Function):
@staticmethod @staticmethod
@once_differentiable @once_differentiable
# pyrefly: ignore # bad-override
def backward(ctx, grad_output): def backward(ctx, grad_output):
x, concentration = ctx.saved_tensors x, concentration = ctx.saved_tensors
return _Dirichlet_backward(x, concentration, grad_output) return _Dirichlet_backward(x, concentration, grad_output)
@ -50,6 +52,7 @@ class Dirichlet(ExponentialFamily):
(often referred to as alpha) (often referred to as alpha)
""" """
# pyrefly: ignore # bad-override
arg_constraints = { arg_constraints = {
"concentration": constraints.independent(constraints.positive, 1) "concentration": constraints.independent(constraints.positive, 1)
} }
@ -130,5 +133,6 @@ class Dirichlet(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]: def _natural_params(self) -> tuple[Tensor]:
return (self.concentration,) return (self.concentration,)
# pyrefly: ignore # bad-override
def _log_normalizer(self, x): def _log_normalizer(self, x):
return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1)) return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))


@ -27,6 +27,7 @@ class Exponential(ExponentialFamily):
rate (float or Tensor): rate = 1 / scale of the distribution rate (float or Tensor): rate = 1 / scale of the distribution
""" """
# pyrefly: ignore # bad-override
arg_constraints = {"rate": constraints.positive} arg_constraints = {"rate": constraints.positive}
support = constraints.nonnegative support = constraints.nonnegative
has_rsample = True has_rsample = True
@ -89,5 +90,6 @@ class Exponential(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]: def _natural_params(self) -> tuple[Tensor]:
return (-self.rate,) return (-self.rate,)
# pyrefly: ignore # bad-override
def _log_normalizer(self, x): def _log_normalizer(self, x):
return -torch.log(-x) return -torch.log(-x)


@ -29,6 +29,7 @@ class FisherSnedecor(Distribution):
df2 (float or Tensor): degrees of freedom parameter 2 df2 (float or Tensor): degrees of freedom parameter 2
""" """
# pyrefly: ignore # bad-override
arg_constraints = {"df1": constraints.positive, "df2": constraints.positive} arg_constraints = {"df1": constraints.positive, "df2": constraints.positive}
support = constraints.positive support = constraints.positive
has_rsample = True has_rsample = True


@ -34,6 +34,7 @@ class Gamma(ExponentialFamily):
(often referred to as beta), rate = 1 / scale (often referred to as beta), rate = 1 / scale
""" """
# pyrefly: ignore # bad-override
arg_constraints = { arg_constraints = {
"concentration": constraints.positive, "concentration": constraints.positive,
"rate": constraints.positive, "rate": constraints.positive,
@ -109,6 +110,7 @@ class Gamma(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor, Tensor]: def _natural_params(self) -> tuple[Tensor, Tensor]:
return (self.concentration - 1, -self.rate) return (self.concentration - 1, -self.rate)
# pyrefly: ignore # bad-override
def _log_normalizer(self, x, y): def _log_normalizer(self, x, y):
return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal()) return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal())


@ -35,6 +35,7 @@ class GeneralizedPareto(Distribution):
concentration (float or Tensor): Concentration parameter of the distribution concentration (float or Tensor): Concentration parameter of the distribution
""" """
# pyrefly: ignore # bad-override
arg_constraints = { arg_constraints = {
"loc": constraints.real, "loc": constraints.real,
"scale": constraints.positive, "scale": constraints.positive,
@ -130,6 +131,7 @@ class GeneralizedPareto(Distribution):
concentration = self.concentration concentration = self.concentration
valid = concentration < 0.5 valid = concentration < 0.5
safe_conc = torch.where(valid, concentration, 0.25) safe_conc = torch.where(valid, concentration, 0.25)
# pyrefly: ignore # unsupported-operation
result = self.scale**2 / ((1 - safe_conc) ** 2 * (1 - 2 * safe_conc)) result = self.scale**2 / ((1 - safe_conc) ** 2 * (1 - 2 * safe_conc))
return torch.where(valid, result, nan) return torch.where(valid, result, nan)
@ -142,6 +144,7 @@ class GeneralizedPareto(Distribution):
return self.loc return self.loc
@constraints.dependent_property(is_discrete=False, event_dim=0) @constraints.dependent_property(is_discrete=False, event_dim=0)
# pyrefly: ignore # bad-override
def support(self): def support(self):
lower = self.loc lower = self.loc
upper = torch.where( upper = torch.where(


@ -44,6 +44,7 @@ class Geometric(Distribution):
logits (Number, Tensor): the log-odds of sampling `1`. logits (Number, Tensor): the log-odds of sampling `1`.
""" """
# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.nonnegative_integer support = constraints.nonnegative_integer
@ -58,9 +59,11 @@ class Geometric(Distribution):
"Either `probs` or `logits` must be specified, but not both." "Either `probs` or `logits` must be specified, but not both."
) )
if probs is not None: if probs is not None:
# pyrefly: ignore # read-only
(self.probs,) = broadcast_all(probs) (self.probs,) = broadcast_all(probs)
else: else:
assert logits is not None # helps mypy assert logits is not None # helps mypy
# pyrefly: ignore # read-only
(self.logits,) = broadcast_all(logits) (self.logits,) = broadcast_all(logits)
probs_or_logits = probs if probs is not None else logits probs_or_logits = probs if probs is not None else logits
if isinstance(probs_or_logits, _Number): if isinstance(probs_or_logits, _Number):


@ -32,6 +32,7 @@ class Gumbel(TransformedDistribution):
""" """
arg_constraints = {"loc": constraints.real, "scale": constraints.positive} arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
# pyrefly: ignore # bad-override
support = constraints.real support = constraints.real
def __init__( def __init__(


@ -32,8 +32,10 @@ class HalfCauchy(TransformedDistribution):
""" """
arg_constraints = {"scale": constraints.positive} arg_constraints = {"scale": constraints.positive}
# pyrefly: ignore # bad-override
support = constraints.nonnegative support = constraints.nonnegative
has_rsample = True has_rsample = True
# pyrefly: ignore # bad-override
base_dist: Cauchy base_dist: Cauchy
def __init__( def __init__(


@ -32,8 +32,10 @@ class HalfNormal(TransformedDistribution):
""" """
arg_constraints = {"scale": constraints.positive} arg_constraints = {"scale": constraints.positive}
# pyrefly: ignore # bad-override
support = constraints.nonnegative support = constraints.nonnegative
has_rsample = True has_rsample = True
# pyrefly: ignore # bad-override
base_dist: Normal base_dist: Normal
def __init__( def __init__(


@ -91,6 +91,7 @@ class Independent(Distribution, Generic[D]):
return self.base_dist.has_enumerate_support return self.base_dist.has_enumerate_support
@constraints.dependent_property @constraints.dependent_property
# pyrefly: ignore # bad-override
def support(self): def support(self):
result = self.base_dist.support result = self.base_dist.support
if self.reinterpreted_batch_ndims: if self.reinterpreted_batch_ndims:


@ -38,8 +38,10 @@ class InverseGamma(TransformedDistribution):
"concentration": constraints.positive, "concentration": constraints.positive,
"rate": constraints.positive, "rate": constraints.positive,
} }
# pyrefly: ignore # bad-override
support = constraints.positive support = constraints.positive
has_rsample = True has_rsample = True
# pyrefly: ignore # bad-override
base_dist: Gamma base_dist: Gamma
def __init__( def __init__(


@ -44,6 +44,7 @@ class Kumaraswamy(TransformedDistribution):
"concentration1": constraints.positive, "concentration1": constraints.positive,
"concentration0": constraints.positive, "concentration0": constraints.positive,
} }
# pyrefly: ignore # bad-override
support = constraints.unit_interval support = constraints.unit_interval
has_rsample = True has_rsample = True
@ -66,6 +67,7 @@ class Kumaraswamy(TransformedDistribution):
AffineTransform(loc=1.0, scale=-1.0), AffineTransform(loc=1.0, scale=-1.0),
PowerTransform(exponent=self.concentration1.reciprocal()), PowerTransform(exponent=self.concentration1.reciprocal()),
] ]
# pyrefly: ignore # bad-argument-type
super().__init__(base_dist, transforms, validate_args=validate_args) super().__init__(base_dist, transforms, validate_args=validate_args)
def expand(self, batch_shape, _instance=None): def expand(self, batch_shape, _instance=None):


@ -28,6 +28,7 @@ class Laplace(Distribution):
scale (float or Tensor): scale of the distribution scale (float or Tensor): scale of the distribution
""" """
# pyrefly: ignore # bad-override
arg_constraints = {"loc": constraints.real, "scale": constraints.positive} arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real support = constraints.real
has_rsample = True has_rsample = True


@ -60,6 +60,7 @@ class LKJCholesky(Distribution):
Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008 Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
""" """
# pyrefly: ignore # bad-override
arg_constraints = {"concentration": constraints.positive} arg_constraints = {"concentration": constraints.positive}
support = constraints.corr_cholesky support = constraints.corr_cholesky


@ -32,8 +32,10 @@ class LogNormal(TransformedDistribution):
""" """
arg_constraints = {"loc": constraints.real, "scale": constraints.positive} arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
# pyrefly: ignore # bad-override
support = constraints.positive support = constraints.positive
has_rsample = True has_rsample = True
# pyrefly: ignore # bad-override
base_dist: Normal base_dist: Normal
def __init__( def __init__(


@ -36,8 +36,10 @@ class LogisticNormal(TransformedDistribution):
""" """
arg_constraints = {"loc": constraints.real, "scale": constraints.positive} arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
# pyrefly: ignore # bad-override
support = constraints.simplex support = constraints.simplex
has_rsample = True has_rsample = True
# pyrefly: ignore # bad-override
base_dist: Independent[Normal] base_dist: Independent[Normal]
def __init__( def __init__(


@ -86,6 +86,7 @@ class LowRankMultivariateNormal(Distribution):
capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
""" """
# pyrefly: ignore # bad-override
arg_constraints = { arg_constraints = {
"loc": constraints.real_vector, "loc": constraints.real_vector,
"cov_factor": constraints.independent(constraints.real, 2), "cov_factor": constraints.independent(constraints.real, 2),


@ -124,6 +124,7 @@ class MixtureSameFamily(Distribution):
return new return new
@constraints.dependent_property @constraints.dependent_property
# pyrefly: ignore # bad-override
def support(self): def support(self):
return MixtureSameFamilyConstraint(self._component_distribution.support) return MixtureSameFamilyConstraint(self._component_distribution.support)


@ -50,6 +50,7 @@ class Multinomial(Distribution):
logits (Tensor): event log probabilities (unnormalized) logits (Tensor): event log probabilities (unnormalized)
""" """
# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
total_count: int total_count: int
@ -92,6 +93,7 @@ class Multinomial(Distribution):
return self._categorical._new(*args, **kwargs) return self._categorical._new(*args, **kwargs)
@constraints.dependent_property(is_discrete=True, event_dim=1) @constraints.dependent_property(is_discrete=True, event_dim=1)
# pyrefly: ignore # bad-override
def support(self): def support(self):
return constraints.multinomial(self.total_count) return constraints.multinomial(self.total_count)


@ -123,6 +123,7 @@ class MultivariateNormal(Distribution):
the corresponding lower triangular matrices using a Cholesky decomposition. the corresponding lower triangular matrices using a Cholesky decomposition.
""" """
# pyrefly: ignore # bad-override
arg_constraints = { arg_constraints = {
"loc": constraints.real_vector, "loc": constraints.real_vector,
"covariance_matrix": constraints.positive_definite, "covariance_matrix": constraints.positive_definite,
@ -156,6 +157,7 @@ class MultivariateNormal(Distribution):
"with optional leading batch dimensions" "with optional leading batch dimensions"
) )
batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1]) batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
# pyrefly: ignore # read-only
self.scale_tril = scale_tril.expand(batch_shape + (-1, -1)) self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
elif covariance_matrix is not None: elif covariance_matrix is not None:
if covariance_matrix.dim() < 2: if covariance_matrix.dim() < 2:
@ -166,6 +168,7 @@ class MultivariateNormal(Distribution):
batch_shape = torch.broadcast_shapes( batch_shape = torch.broadcast_shapes(
covariance_matrix.shape[:-2], loc.shape[:-1] covariance_matrix.shape[:-2], loc.shape[:-1]
) )
# pyrefly: ignore # read-only
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1)) self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
else: else:
assert precision_matrix is not None # helps mypy assert precision_matrix is not None # helps mypy
@ -177,6 +180,7 @@ class MultivariateNormal(Distribution):
batch_shape = torch.broadcast_shapes( batch_shape = torch.broadcast_shapes(
precision_matrix.shape[:-2], loc.shape[:-1] precision_matrix.shape[:-2], loc.shape[:-1]
) )
# pyrefly: ignore # read-only
self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1)) self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
self.loc = loc.expand(batch_shape + (-1,)) self.loc = loc.expand(batch_shape + (-1,))


@ -33,6 +33,7 @@ class NegativeBinomial(Distribution):
logits (Tensor): Event log-odds for probabilities of success logits (Tensor): Event log-odds for probabilities of success
""" """
# pyrefly: ignore # bad-override
arg_constraints = { arg_constraints = {
"total_count": constraints.greater_than_eq(0), "total_count": constraints.greater_than_eq(0),
"probs": constraints.half_open_interval(0.0, 1.0), "probs": constraints.half_open_interval(0.0, 1.0),
@ -54,6 +55,7 @@ class NegativeBinomial(Distribution):
if probs is not None: if probs is not None:
( (
self.total_count, self.total_count,
# pyrefly: ignore # read-only
self.probs, self.probs,
) = broadcast_all(total_count, probs) ) = broadcast_all(total_count, probs)
self.total_count = self.total_count.type_as(self.probs) self.total_count = self.total_count.type_as(self.probs)
@ -61,6 +63,7 @@ class NegativeBinomial(Distribution):
assert logits is not None # helps mypy assert logits is not None # helps mypy
( (
self.total_count, self.total_count,
# pyrefly: ignore # read-only
self.logits, self.logits,
) = broadcast_all(total_count, logits) ) = broadcast_all(total_count, logits)
self.total_count = self.total_count.type_as(self.logits) self.total_count = self.total_count.type_as(self.logits)


@ -31,6 +31,7 @@ class Normal(ExponentialFamily):
(often referred to as sigma) (often referred to as sigma)
""" """
# pyrefly: ignore # bad-override
arg_constraints = {"loc": constraints.real, "scale": constraints.positive} arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real support = constraints.real
has_rsample = True has_rsample = True
@ -88,6 +89,7 @@ class Normal(ExponentialFamily):
if self._validate_args: if self._validate_args:
self._validate_sample(value) self._validate_sample(value)
# compute the variance # compute the variance
# pyrefly: ignore # unsupported-operation
var = self.scale**2 var = self.scale**2
log_scale = ( log_scale = (
math.log(self.scale) math.log(self.scale)
@ -117,5 +119,6 @@ class Normal(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor, Tensor]: def _natural_params(self) -> tuple[Tensor, Tensor]:
return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal()) return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal())
# pyrefly: ignore # bad-override
def _log_normalizer(self, x, y): def _log_normalizer(self, x, y):
return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y) return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y)


@ -42,6 +42,7 @@ class OneHotCategorical(Distribution):
logits (Tensor): event log probabilities (unnormalized) logits (Tensor): event log probabilities (unnormalized)
""" """
# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
support = constraints.one_hot support = constraints.one_hot
has_enumerate_support = True has_enumerate_support = True


@ -39,6 +39,7 @@ class Pareto(TransformedDistribution):
self.scale, self.alpha = broadcast_all(scale, alpha) self.scale, self.alpha = broadcast_all(scale, alpha)
base_dist = Exponential(self.alpha, validate_args=validate_args) base_dist = Exponential(self.alpha, validate_args=validate_args)
transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)] transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
# pyrefly: ignore # bad-argument-type
super().__init__(base_dist, transforms, validate_args=validate_args) super().__init__(base_dist, transforms, validate_args=validate_args)
def expand( def expand(


@ -32,6 +32,7 @@ class Poisson(ExponentialFamily):
rate (Number, Tensor): the rate parameter rate (Number, Tensor): the rate parameter
""" """
# pyrefly: ignore # bad-override
arg_constraints = {"rate": constraints.nonnegative} arg_constraints = {"rate": constraints.nonnegative}
support = constraints.nonnegative_integer support = constraints.nonnegative_integer
@ -82,5 +83,6 @@ class Poisson(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]: def _natural_params(self) -> tuple[Tensor]:
return (torch.log(self.rate),) return (torch.log(self.rate),)
# pyrefly: ignore # bad-override
def _log_normalizer(self, x): def _log_normalizer(self, x):
return torch.exp(x) return torch.exp(x)


@ -40,6 +40,7 @@ class LogitRelaxedBernoulli(Distribution):
(Jang et al., 2017) (Jang et al., 2017)
""" """
# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.real support = constraints.real
@ -57,10 +58,12 @@ class LogitRelaxedBernoulli(Distribution):
) )
if probs is not None: if probs is not None:
is_scalar = isinstance(probs, _Number) is_scalar = isinstance(probs, _Number)
# pyrefly: ignore # read-only
(self.probs,) = broadcast_all(probs) (self.probs,) = broadcast_all(probs)
else: else:
assert logits is not None # helps mypy assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number) is_scalar = isinstance(logits, _Number)
# pyrefly: ignore # read-only
(self.logits,) = broadcast_all(logits) (self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits self._param = self.probs if probs is not None else self.logits
if is_scalar: if is_scalar:
@ -138,8 +141,10 @@ class RelaxedBernoulli(TransformedDistribution):
""" """
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
# pyrefly: ignore # bad-override
support = constraints.unit_interval support = constraints.unit_interval
has_rsample = True has_rsample = True
# pyrefly: ignore # bad-override
base_dist: LogitRelaxedBernoulli base_dist: LogitRelaxedBernoulli
def __init__( def __init__(


@ -38,6 +38,7 @@ class ExpRelaxedCategorical(Distribution):
(Jang et al., 2017) (Jang et al., 2017)
""" """
# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
support = ( support = (
constraints.real_vector constraints.real_vector
@ -127,8 +128,10 @@ class RelaxedOneHotCategorical(TransformedDistribution):
""" """
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
# pyrefly: ignore # bad-override
support = constraints.simplex support = constraints.simplex
has_rsample = True has_rsample = True
# pyrefly: ignore # bad-override
base_dist: ExpRelaxedCategorical base_dist: ExpRelaxedCategorical
def __init__( def __init__(


@ -31,6 +31,7 @@ class StudentT(Distribution):
scale (float or Tensor): scale of the distribution scale (float or Tensor): scale of the distribution
""" """
# pyrefly: ignore # bad-override
arg_constraints = { arg_constraints = {
"df": constraints.positive, "df": constraints.positive,
"loc": constraints.real, "loc": constraints.real,


@ -123,6 +123,7 @@ class TransformedDistribution(Distribution):
return new return new
@constraints.dependent_property(is_discrete=False) @constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def support(self): def support(self):
if not self.transforms: if not self.transforms:
return self.base_dist.support return self.base_dist.support


@ -226,11 +226,13 @@ class _InverseTransform(Transform):
self._inv: Transform = transform # type: ignore[assignment] self._inv: Transform = transform # type: ignore[assignment]
@constraints.dependent_property(is_discrete=False) @constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def domain(self): def domain(self):
assert self._inv is not None assert self._inv is not None
return self._inv.codomain return self._inv.codomain
@constraints.dependent_property(is_discrete=False) @constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def codomain(self): def codomain(self):
assert self._inv is not None assert self._inv is not None
return self._inv.domain return self._inv.domain
@ -300,6 +302,7 @@ class ComposeTransform(Transform):
return self.parts == other.parts return self.parts == other.parts
@constraints.dependent_property(is_discrete=False) @constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def domain(self): def domain(self):
if not self.parts: if not self.parts:
return constraints.real return constraints.real
@ -315,6 +318,7 @@ class ComposeTransform(Transform):
return domain return domain
@constraints.dependent_property(is_discrete=False) @constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def codomain(self): def codomain(self):
if not self.parts: if not self.parts:
return constraints.real return constraints.real
@ -434,12 +438,14 @@ class IndependentTransform(Transform):
) )
@constraints.dependent_property(is_discrete=False) @constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def domain(self): def domain(self):
return constraints.independent( return constraints.independent(
self.base_transform.domain, self.reinterpreted_batch_ndims self.base_transform.domain, self.reinterpreted_batch_ndims
) )
@constraints.dependent_property(is_discrete=False) @constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def codomain(self): def codomain(self):
return constraints.independent( return constraints.independent(
self.base_transform.codomain, self.reinterpreted_batch_ndims self.base_transform.codomain, self.reinterpreted_batch_ndims
@ -507,10 +513,12 @@ class ReshapeTransform(Transform):
super().__init__(cache_size=cache_size) super().__init__(cache_size=cache_size)
@constraints.dependent_property @constraints.dependent_property
# pyrefly: ignore # bad-override
def domain(self): def domain(self):
return constraints.independent(constraints.real, len(self.in_shape)) return constraints.independent(constraints.real, len(self.in_shape))
@constraints.dependent_property @constraints.dependent_property
# pyrefly: ignore # bad-override
def codomain(self): def codomain(self):
return constraints.independent(constraints.real, len(self.out_shape)) return constraints.independent(constraints.real, len(self.out_shape))
@ -764,12 +772,14 @@ class AffineTransform(Transform):
return self._event_dim return self._event_dim
@constraints.dependent_property(is_discrete=False) @constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def domain(self): def domain(self):
if self.event_dim == 0: if self.event_dim == 0:
return constraints.real return constraints.real
return constraints.independent(constraints.real, self.event_dim) return constraints.independent(constraints.real, self.event_dim)
@constraints.dependent_property(is_discrete=False) @constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def codomain(self): def codomain(self):
if self.event_dim == 0: if self.event_dim == 0:
return constraints.real return constraints.real
@ -867,6 +877,7 @@ class CorrCholeskyTransform(Transform):
# apply stick-breaking on the squared values # apply stick-breaking on the squared values
# Note that y = sign(r) * sqrt(z * z1m_cumprod) # Note that y = sign(r) * sqrt(z * z1m_cumprod)
# = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod) # = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
# pyrefly: ignore # unsupported-operation
z = r**2 z = r**2
z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1) z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
# Diagonal elements must be 1. # Diagonal elements must be 1.
@ -1155,12 +1166,14 @@ class CatTransform(Transform):
return all(t.bijective for t in self.transforms) return all(t.bijective for t in self.transforms)
@constraints.dependent_property @constraints.dependent_property
# pyrefly: ignore # bad-override
def domain(self): def domain(self):
return constraints.cat( return constraints.cat(
[t.domain for t in self.transforms], self.dim, self.lengths [t.domain for t in self.transforms], self.dim, self.lengths
) )
@constraints.dependent_property @constraints.dependent_property
# pyrefly: ignore # bad-override
def codomain(self): def codomain(self):
return constraints.cat( return constraints.cat(
[t.codomain for t in self.transforms], self.dim, self.lengths [t.codomain for t in self.transforms], self.dim, self.lengths
@ -1233,10 +1246,12 @@ class StackTransform(Transform):
return all(t.bijective for t in self.transforms) return all(t.bijective for t in self.transforms)
@constraints.dependent_property @constraints.dependent_property
# pyrefly: ignore # bad-override
def domain(self): def domain(self):
return constraints.stack([t.domain for t in self.transforms], self.dim) return constraints.stack([t.domain for t in self.transforms], self.dim)
@constraints.dependent_property @constraints.dependent_property
# pyrefly: ignore # bad-override
def codomain(self): def codomain(self):
return constraints.stack([t.codomain for t in self.transforms], self.dim) return constraints.stack([t.codomain for t in self.transforms], self.dim)


@ -79,6 +79,7 @@ class Uniform(Distribution):
return new return new
@constraints.dependent_property(is_discrete=False, event_dim=0) @constraints.dependent_property(is_discrete=False, event_dim=0)
# pyrefly: ignore # bad-override
def support(self): def support(self):
return constraints.interval(self.low, self.high) return constraints.interval(self.low, self.high)


@ -92,6 +92,7 @@ def _log_modified_bessel_fn(x, order=0):
@torch.jit.script_if_tracing @torch.jit.script_if_tracing
def _rejection_sample(loc, concentration, proposal_r, x): def _rejection_sample(loc, concentration, proposal_r, x):
done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device) done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
# pyrefly: ignore # bad-assignment
while not done.all(): while not done.all():
u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device) u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
u1, u2, u3 = u.unbind() u1, u2, u3 = u.unbind()
@ -100,6 +101,7 @@ def _rejection_sample(loc, concentration, proposal_r, x):
c = concentration * (proposal_r - f) c = concentration * (proposal_r - f)
accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0) accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
if accept.any(): if accept.any():
# pyrefly: ignore # no-matching-overload
x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x) x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
done = done | accept done = done | accept
return (x + math.pi + loc) % (2 * math.pi) - math.pi return (x + math.pi + loc) % (2 * math.pi) - math.pi
@ -123,6 +125,7 @@ class VonMises(Distribution):
:param torch.Tensor concentration: concentration parameter :param torch.Tensor concentration: concentration parameter
""" """
# pyrefly: ignore # bad-override
arg_constraints = {"loc": constraints.real, "concentration": constraints.positive} arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
support = constraints.real support = constraints.real
has_rsample = False has_rsample = False
@ -160,8 +163,10 @@ class VonMises(Distribution):
@lazy_property @lazy_property
def _proposal_r(self) -> Tensor: def _proposal_r(self) -> Tensor:
kappa = self._concentration kappa = self._concentration
# pyrefly: ignore # unsupported-operation
tau = 1 + (1 + 4 * kappa**2).sqrt() tau = 1 + (1 + 4 * kappa**2).sqrt()
rho = (tau - (2 * tau).sqrt()) / (2 * kappa) rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
# pyrefly: ignore # unsupported-operation
_proposal_r = (1 + rho**2) / (2 * rho) _proposal_r = (1 + rho**2) / (2 * rho)
# second order Taylor expansion around 0 for small kappa # second order Taylor expansion around 0 for small kappa
_proposal_r_taylor = 1 / kappa + kappa _proposal_r_taylor = 1 / kappa + kappa


@ -35,6 +35,7 @@ class Weibull(TransformedDistribution):
"scale": constraints.positive, "scale": constraints.positive,
"concentration": constraints.positive, "concentration": constraints.positive,
} }
# pyrefly: ignore # bad-override
support = constraints.positive support = constraints.positive
def __init__( def __init__(
@ -52,6 +53,7 @@ class Weibull(TransformedDistribution):
PowerTransform(exponent=self.concentration_reciprocal), PowerTransform(exponent=self.concentration_reciprocal),
AffineTransform(loc=0, scale=self.scale), AffineTransform(loc=0, scale=self.scale),
] ]
# pyrefly: ignore # bad-argument-type
super().__init__(base_dist, transforms, validate_args=validate_args) super().__init__(base_dist, transforms, validate_args=validate_args)
def expand(self, batch_shape, _instance=None): def expand(self, batch_shape, _instance=None):


@ -116,10 +116,13 @@ class Wishart(ExponentialFamily):
) )
if scale_tril is not None: if scale_tril is not None:
# pyrefly: ignore # read-only
self.scale_tril = param.expand(batch_shape + (-1, -1)) self.scale_tril = param.expand(batch_shape + (-1, -1))
elif covariance_matrix is not None: elif covariance_matrix is not None:
# pyrefly: ignore # read-only
self.covariance_matrix = param.expand(batch_shape + (-1, -1)) self.covariance_matrix = param.expand(batch_shape + (-1, -1))
elif precision_matrix is not None: elif precision_matrix is not None:
# pyrefly: ignore # read-only
self.precision_matrix = param.expand(batch_shape + (-1, -1)) self.precision_matrix = param.expand(batch_shape + (-1, -1))
if self.df.lt(event_shape[-1]).any(): if self.df.lt(event_shape[-1]).any():
@ -335,6 +338,7 @@ class Wishart(ExponentialFamily):
p = self._event_shape[-1] # has singleton shape p = self._event_shape[-1] # has singleton shape
return -self.precision_matrix / 2, (nu - p - 1) / 2 return -self.precision_matrix / 2, (nu - p - 1) / 2
# pyrefly: ignore # bad-override
def _log_normalizer(self, x, y): def _log_normalizer(self, x, y):
p = self._event_shape[-1] p = self._event_shape[-1]
return (y + (p + 1) / 2) * ( return (y + (p + 1) / 2) * (


@ -24,11 +24,15 @@ class TensorProperties:
if not self.is_fake: if not self.is_fake:
# only get the storage pointer for real tensors # only get the storage pointer for real tensors
# pyrefly: ignore # bad-assignment
self.storage_ptr = tensor.untyped_storage().data_ptr() self.storage_ptr = tensor.untyped_storage().data_ptr()
if self.is_contiguous: if self.is_contiguous:
# only get storage size and start/end pointers for contiguous tensors # only get storage size and start/end pointers for contiguous tensors
# pyrefly: ignore # bad-assignment
self.storage_size = tensor.untyped_storage().nbytes() self.storage_size = tensor.untyped_storage().nbytes()
# pyrefly: ignore # bad-assignment
self.start = tensor.data_ptr() self.start = tensor.data_ptr()
# pyrefly: ignore # bad-assignment
self.end = _end_ptr(tensor) self.end = _end_ptr(tensor)
# info to recover tensor # info to recover tensor


@ -65,6 +65,7 @@ class GraphPickler(pickle.Pickler):
self._meta_tensor_describer = MetaTensorDescriber(copy_data=False) self._meta_tensor_describer = MetaTensorDescriber(copy_data=False)
@override @override
# pyrefly: ignore # bad-override
def reducer_override( def reducer_override(
self, obj: object self, obj: object
) -> tuple[Callable[..., Any], tuple[Any, ...]]: ) -> tuple[Callable[..., Any], tuple[Any, ...]]:
@ -201,6 +202,7 @@ class _SymNodePickleData:
]: ]:
args = (cls(obj.node), pickler._unpickle_state) args = (cls(obj.node), pickler._unpickle_state)
if isinstance(obj, torch.SymInt): if isinstance(obj, torch.SymInt):
# pyrefly: ignore # bad-return
return _SymNodePickleData.unpickle_sym_int, args return _SymNodePickleData.unpickle_sym_int, args
else: else:
raise NotImplementedError(f"Unhandled SymNode type {type(obj)}") raise NotImplementedError(f"Unhandled SymNode type {type(obj)}")
@ -277,6 +279,7 @@ class _TensorPickleData:
return FakeTensor( return FakeTensor(
unpickle_state.fake_mode, unpickle_state.fake_mode,
make_meta_t(), make_meta_t(),
# pyrefly: ignore # bad-argument-type
device, device,
) )


@ -603,6 +603,7 @@ class Tracer(TracerBase):
in inspect.signature(self.create_proxy).parameters in inspect.signature(self.create_proxy).parameters
): ):
kwargs["proxy_factory_fn"] = ( kwargs["proxy_factory_fn"] = (
# pyrefly: ignore # unsupported-operation
None None
if not self.param_shapes_constant if not self.param_shapes_constant
else lambda node: ParameterProxy( else lambda node: ParameterProxy(


@ -657,7 +657,10 @@ class Partitioner:
# Mark bfs level # Mark bfs level
get_bfs_level_partition(self.partitions) get_bfs_level_partition(self.partitions)
find_combination, partitions = find_partition_to_combine_based_on_size( find_combination, partitions = find_partition_to_combine_based_on_size(
sorted_partitions, available_mem_bytes, partitions sorted_partitions,
available_mem_bytes,
# pyrefly: ignore # bad-argument-type
partitions,
) )
return return
@ -702,6 +705,7 @@ class Partitioner:
non_embedding_partitions.append(partition) non_embedding_partitions.append(partition)
if new_partition: if new_partition:
partition = self.create_partition() partition = self.create_partition()
# pyrefly: ignore # missing-attribute
partition.left_mem_bytes = available_mem_bytes partition.left_mem_bytes = available_mem_bytes
return partition return partition
return None return None
@ -997,6 +1001,7 @@ class Partitioner:
node, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec node, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec
) )
if cost < min_cost: if cost < min_cost:
# pyrefly: ignore # bad-assignment
node_pair = [node, n1] node_pair = [node, n1]
min_cost = cost min_cost = cost
return cost, node_pair # type: ignore[possibly-undefined] return cost, node_pair # type: ignore[possibly-undefined]


@ -30,6 +30,7 @@ def split_result_tensors(
else: else:
splits = [x.shape[0] for x in inputs] splits = [x.shape[0] for x in inputs]
# pyrefly: ignore # bad-argument-type
return torch.split(result, splits) return torch.split(result, splits)


@ -171,7 +171,14 @@ class MetaTracer(torch.fx.Tracer):
proxy_factory_fn=None, proxy_factory_fn=None,
): ):
rv = super().create_proxy( rv = super().create_proxy(
kind, target, args, kwargs, name, type_expr, proxy_factory_fn kind,
target,
args,
kwargs,
name,
type_expr,
# pyrefly: ignore # bad-argument-type
proxy_factory_fn,
) )
if kind == "placeholder" and target in self.meta_args: if kind == "placeholder" and target in self.meta_args:
@ -193,6 +200,7 @@ class MetaTracer(torch.fx.Tracer):
if kind == "call_function": if kind == "call_function":
meta_target = manual_meta_overrides.get(target, target) meta_target = manual_meta_overrides.get(target, target)
# pyrefly: ignore # not-callable
meta_out = meta_target(*args_metas, **kwargs_metas) meta_out = meta_target(*args_metas, **kwargs_metas)
elif kind == "call_method": elif kind == "call_method":
meta_target = getattr(args_metas[0], target) # type: ignore[index] meta_target = getattr(args_metas[0], target) # type: ignore[index]


@ -528,9 +528,11 @@ def view_inference_rule(n: Node, symbols, constraints, counter):
if t == -1: if t == -1:
var, counter = gen_dvar(counter) var, counter = gen_dvar(counter)
t2_type.append(var) t2_type.append(var)
# pyrefly: ignore # bad-argument-type
num_constraints.append(BinConstraintD(var, Dyn, op_neq)) num_constraints.append(BinConstraintD(var, Dyn, op_neq))
else: else:
# pyrefly: ignore # bad-argument-type
num_constraints.append(BinConstraintD(t, Dyn, op_neq)) num_constraints.append(BinConstraintD(t, Dyn, op_neq))
t2_type.append(t) # type: ignore[arg-type] t2_type.append(t) # type: ignore[arg-type]
@ -1475,6 +1477,7 @@ class ConstraintGenerator:
all_constraints = [] all_constraints = []
# pyrefly: ignore # missing-attribute
for n in graph.nodes: for n in graph.nodes:
(constraints, counter) = self.generate_constraints_node(n, counter) (constraints, counter) = self.generate_constraints_node(n, counter)
all_constraints += constraints all_constraints += constraints


@ -193,6 +193,7 @@ def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]):
assert isinstance(node.target, str) assert isinstance(node.target, str)
cur_module = modules[node.target] cur_module = modules[node.target]
if type(cur_module) in mkldnn_map: if type(cur_module) in mkldnn_map:
# pyrefly: ignore # index-error
new_module = mkldnn_map[type(cur_module)](cur_module, torch.float) new_module = mkldnn_map[type(cur_module)](cur_module, torch.float)
assert isinstance(new_module, nn.Module) assert isinstance(new_module, nn.Module)
old_modules[new_module] = copy.deepcopy(cur_module) old_modules[new_module] = copy.deepcopy(cur_module)
@ -263,7 +264,10 @@ def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
) )
reset_modules( reset_modules(
submodule.graph.nodes, dict(submodule.named_modules()), old_modules submodule.graph.nodes,
dict(submodule.named_modules()),
# pyrefly: ignore # bad-argument-type
old_modules,
) )
no_mkl_time = benchmark(lambda: submodule(*sample_inputs)) no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
return mkl_time < no_mkl_time return mkl_time < no_mkl_time


@ -124,6 +124,7 @@ pytree.register_pytree_node(
torch.Size, torch.Size,
lambda xs: (list(xs), None), lambda xs: (list(xs), None),
lambda xs, _: tuple(xs), lambda xs, _: tuple(xs),
# pyrefly: ignore # bad-argument-type
flatten_with_keys_fn=lambda xs: ( flatten_with_keys_fn=lambda xs: (
[(pytree.SequenceKey(i), x) for i, x in enumerate(xs)], [(pytree.SequenceKey(i), x) for i, x in enumerate(xs)],
None, None,
@ -306,6 +307,7 @@ def set_proxy_slot( # type: ignore[no-redef]
def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool: def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool:
assert isinstance(obj, (Tensor, SymNode)), type(obj) assert isinstance(obj, (Tensor, SymNode)), type(obj)
# pyrefly: ignore # no-matching-overload
return bool(get_proxy_slot(obj, tracer, False, lambda _: True)) return bool(get_proxy_slot(obj, tracer, False, lambda _: True))
@ -402,6 +404,7 @@ def get_proxy_slot(
assert isinstance(obj, py_sym_types), type(obj) assert isinstance(obj, py_sym_types), type(obj)
tracker = tracer.symnode_tracker tracker = tracer.symnode_tracker
# pyrefly: ignore # unsupported-operation
if obj not in tracker: if obj not in tracker:
# Last ditch # Last ditch
if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker: if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker:
@ -413,6 +416,7 @@ def get_proxy_slot(
) )
return default return default
else: else:
# pyrefly: ignore # index-error
value = tracker[obj] value = tracker[obj]
res = transform(value) res = transform(value)
return res return res
@ -788,6 +792,7 @@ def fetch_object_proxy(
def fetch_object_proxy( def fetch_object_proxy(
tracer: _ProxyTracer, t: Union[Tensor, _AnyScriptObjectType, PySymType] tracer: _ProxyTracer, t: Union[Tensor, _AnyScriptObjectType, PySymType]
) -> object: ) -> object:
# pyrefly: ignore # no-matching-overload
return get_proxy_slot(t, tracer, t) return get_proxy_slot(t, tracer, t)
@ -836,6 +841,7 @@ def _fetch_proxies_and_all_constant_flag(
""" """
f_flat_args_kwargs = [ f_flat_args_kwargs = [
( (
# pyrefly: ignore # no-matching-overload
fetch_object_proxy(tracer, x) fetch_object_proxy(tracer, x)
if isinstance(x, (Tensor, _AnyScriptObject)) if isinstance(x, (Tensor, _AnyScriptObject))
else x else x
@ -1410,6 +1416,7 @@ class TorchFunctionMetadataMode(TorchFunctionMode):
kwargs: Optional[dict[str, object]] = None, kwargs: Optional[dict[str, object]] = None,
) -> object: ) -> object:
kwargs = kwargs or {} kwargs = kwargs or {}
# pyrefly: ignore # bad-assignment
self.tracer.torch_fn_metadata = func self.tracer.torch_fn_metadata = func
self.tracer.torch_fn_counts[func] = self.tracer.torch_fn_counts.get(func, 0) + 1 self.tracer.torch_fn_counts[func] = self.tracer.torch_fn_counts.get(func, 0) + 1
return func(*args, **kwargs) return func(*args, **kwargs)
@ -1459,6 +1466,7 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
# For autocast, the python APIs run so we don't have to run them again # For autocast, the python APIs run so we don't have to run them again
# here. # here.
if func is torch._C._set_grad_enabled: if func is torch._C._set_grad_enabled:
# pyrefly: ignore # bad-argument-type
func(*args, **kwargs) func(*args, **kwargs)
return node return node
@ -1672,6 +1680,7 @@ class DecompositionInterpreter(fx.Interpreter):
self.decomposition_table = decomposition_table or {} self.decomposition_table = decomposition_table or {}
self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real") self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")
# pyrefly: ignore # bad-override
def placeholder( def placeholder(
self, self,
target: str, # type: ignore[override] target: str, # type: ignore[override]
@ -1684,6 +1693,7 @@ class DecompositionInterpreter(fx.Interpreter):
# TODO handle case where the first character of target is '*' # TODO handle case where the first character of target is '*'
return out return out
# pyrefly: ignore # bad-override
def get_attr( def get_attr(
self, self,
target: str, # type: ignore[override] target: str, # type: ignore[override]
@ -1697,6 +1707,7 @@ class DecompositionInterpreter(fx.Interpreter):
# call_function, call_method, call_module get traced automatically by the outer mode. # call_function, call_method, call_module get traced automatically by the outer mode.
# pyrefly: ignore # bad-override
def output( def output(
self, self,
target: str, # type: ignore[override] target: str, # type: ignore[override]
@ -1782,14 +1793,17 @@ class _ModuleStackTracer(PythonKeyTracer):
self.enable_attr_proxy = False self.enable_attr_proxy = False
self.submodule_paths = {} self.submodule_paths = {}
for name, m in self.scope_root.named_modules(remove_duplicate=False): for name, m in self.scope_root.named_modules(remove_duplicate=False):
# pyrefly: ignore # unsupported-operation
if m in self.submodule_paths: if m in self.submodule_paths:
log.info( log.info(
"Shared module found between %s and %s, AttrProxy is enabled.", "Shared module found between %s and %s, AttrProxy is enabled.",
# pyrefly: ignore # unsupported-operation
self.submodule_paths[m], self.submodule_paths[m],
name, name,
) )
self.enable_attr_proxy = True self.enable_attr_proxy = True
else: else:
# pyrefly: ignore # unsupported-operation
self.submodule_paths[m] = name self.submodule_paths[m] = name
self.proxy_paths: WeakKeyDictionary[_AttrProxy, str] = WeakKeyDictionary() self.proxy_paths: WeakKeyDictionary[_AttrProxy, str] = WeakKeyDictionary()
@ -1815,6 +1829,7 @@ class _ModuleStackTracer(PythonKeyTracer):
# Class is modified to be a subclass of torch.nn.Module # Class is modified to be a subclass of torch.nn.Module
# Warning: We blow away our own attributes here to mimic the base class # Warning: We blow away our own attributes here to mimic the base class
# - so don't expect `self.x` to do anything useful. # - so don't expect `self.x` to do anything useful.
# pyrefly: ignore # no-matching-overload
self.__class__ = type( self.__class__ = type(
base.__class__.__name__, base.__class__.__name__,
(self.__class__, base.__class__), (self.__class__, base.__class__),
@ -1837,6 +1852,7 @@ class _ModuleStackTracer(PythonKeyTracer):
if not isinstance(attr_val, Module): if not isinstance(attr_val, Module):
return attr_val return attr_val
# pyrefly: ignore # index-error
return AttrProxy(attr_val, tracer.proxy_paths[self] + "." + name) return AttrProxy(attr_val, tracer.proxy_paths[self] + "." + name)
def get_base(self) -> Module: def get_base(self) -> Module:
@ -1849,10 +1865,12 @@ class _ModuleStackTracer(PythonKeyTracer):
res = torch.nn.Sequential( res = torch.nn.Sequential(
OrderedDict(list(self._modules.items())[idx]) OrderedDict(list(self._modules.items())[idx])
) )
# pyrefly: ignore # index-error
return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}") return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}")
elif isinstance(self, torch.nn.ModuleList): elif isinstance(self, torch.nn.ModuleList):
# Copied from nn/modules/container.py # Copied from nn/modules/container.py
res = torch.nn.ModuleList(list(self._modules.values())[idx]) res = torch.nn.ModuleList(list(self._modules.values())[idx])
# pyrefly: ignore # index-error
return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}") return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}")
return super().__getitem__(idx) # type: ignore[misc] return super().__getitem__(idx) # type: ignore[misc]


@@ -839,6 +839,7 @@ def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr:
         factor = functools.reduce(math.gcd, map(integer_coefficient, atoms))
         if factor == 1:
             return expr
+        # pyrefly: ignore # bad-argument-type
         atoms = [div_by_factor(x, factor) for x in atoms]
         return _sympy_from_args(
             sympy.Add, atoms, sort=True, is_commutative=expr.is_commutative
@@ -1234,6 +1235,7 @@ def _free_unbacked_symbols_with_path(
             else _symint_wrap(coeff)
         )
         # TODO: DivideByKey needs to test divisibility at runtime!
+        # pyrefly: ignore # unsupported-operation
         r[unbacked] = path + (DivideByKey(divisor),)
         if real is not None:
             assert isinstance(real, int)
@@ -1256,6 +1258,7 @@ def _free_unbacked_symbols_with_path(
         and s.rhs == 1
         and s.lhs in pending
     ):
+        # pyrefly: ignore # unsupported-operation
         r[s.lhs] = path + (ConvertIntKey(),)
         if real is not None:
             assert type(real) is bool
@@ -2172,6 +2175,7 @@ class SubclassSymbolicContext(StatefulSymbolicContext):
     def __post_init__(self) -> None:
         super().__post_init__()
         if self.inner_contexts is None:
+            # pyrefly: ignore # bad-assignment
             self.inner_contexts = {}
@@ -2260,9 +2264,12 @@ def _fast_expand(expr: _SympyT) -> _SympyT:
     # only re-create the objects if any of the args changed to avoid expensive
     # checks when re-creating objects.
     new_args = [_fast_expand(arg) for arg in expr.args] # type: ignore[arg-type]
+    # pyrefly: ignore # missing-attribute
     if any(arg is not new_arg for arg, new_arg in zip(expr.args, new_args)):
+        # pyrefly: ignore # missing-attribute
         return _fast_expand(expr.func(*new_args))
+    # pyrefly: ignore # missing-attribute
     if expr.is_Pow:
         base: sympy.Expr
         exp: sympy.Expr
@@ -2272,9 +2279,11 @@ def _fast_expand(expr: _SympyT) -> _SympyT:
                 return sympy.expand_multinomial(expr, deep=False)
             elif exp < 0:
                 return S.One / sympy.expand_multinomial(S.One / expr, deep=False)
+    # pyrefly: ignore # missing-attribute
     elif expr.is_Mul:
         num: list[sympy.Expr] = []
         den: list[sympy.Expr] = []
+        # pyrefly: ignore # missing-attribute
         for arg in expr.args:
             if arg.is_Pow and arg.args[1] == -1:
                 den.append(S.One / arg) # type: ignore[operator, arg-type]
@@ -2396,6 +2405,7 @@ def _maybe_evaluate_static_worker(
     # TODO: remove this try catch (esp for unbacked_only)
     try:
+        # pyrefly: ignore # missing-attribute
         new_expr = expr.xreplace(new_shape_env)
     except RecursionError:
         log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env)
@@ -2933,13 +2943,19 @@ class DimConstraints:
             # is_integer tests though haha
             return (base - mod_reduced) / divisor
+        # pyrefly: ignore # missing-attribute
         if expr.has(Mod):
+            # pyrefly: ignore # missing-attribute
             expr = expr.replace(Mod, mod_handler)
         # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative
         # arguments should be OK.
+        # pyrefly: ignore # missing-attribute
         if expr.has(PythonMod):
+            # pyrefly: ignore # missing-attribute
             expr = expr.replace(PythonMod, mod_handler)
+        # pyrefly: ignore # missing-attribute
         if expr.has(FloorDiv):
+            # pyrefly: ignore # missing-attribute
             expr = expr.replace(FloorDiv, floor_div_handler)
         return expr
@@ -5057,6 +5073,7 @@ class ShapeEnv:
             if duck:
                 # Make sure to reuse this symbol for subsequent duck shaping
+                # pyrefly: ignore # unsupported-operation
                 self.val_to_var[val] = sympy_expr
             if isinstance(val, int):
@@ -5288,15 +5305,19 @@ class ShapeEnv:
         # Expand optional inputs, or verify invariants are upheld
         if input_contexts is None:
+            # pyrefly: ignore # bad-assignment
             input_contexts = [
+                # pyrefly: ignore # bad-argument-type
                 _create_no_constraints_context(t) if isinstance(t, Tensorlike) else None
                 for t in placeholders
             ]
         else:
             assert len(input_contexts) == len(placeholders)
+            # pyrefly: ignore # bad-assignment
             for i, (t, context) in enumerate(zip(placeholders, input_contexts)):
                 if isinstance(t, Tensorlike):
                     if context is None:
+                        # pyrefly: ignore # bad-argument-type
                         input_contexts[i] = _create_no_constraints_context(t)
                 else:
                     assert isinstance(t, (SymInt, int, SymFloat, float))
@@ -5582,6 +5603,7 @@ class ShapeEnv:
                 s = sympy.Float(val)
                 input_guards.append((source, s))
+        # pyrefly: ignore # no-matching-overload
         for t, source, context in zip(placeholders, sources, input_contexts):
             if isinstance(source, str):
                 from torch._dynamo.source import LocalSource
@@ -5641,11 +5663,13 @@ class ShapeEnv:
                         )
                         track_symint(property_source, ss, constraint_size[i])
                 else:
+                    # pyrefly: ignore # missing-attribute
                     for i, ss in enumerate(curr_t.size()):
                         property_source = TensorPropertySource(
                             src, TensorProperty.SIZE, i
                         )
                         track_symint(property_source, ss, constraint_size[i])
+                    # pyrefly: ignore # missing-attribute
                     for i, ss in enumerate(curr_t.stride()):
                         property_source = TensorPropertySource(
                             src, TensorProperty.STRIDE, i
@@ -5653,6 +5677,7 @@ class ShapeEnv:
                         track_symint(property_source, ss, constraint_stride[i])
                     track_symint(
                         TensorPropertySource(src, TensorProperty.STORAGE_OFFSET),
+                        # pyrefly: ignore # missing-attribute
                         curr_t.storage_offset(),
                     )
@@ -5698,6 +5723,7 @@ class ShapeEnv:
                 continue
             if is_dim(source):
+                # pyrefly: ignore # missing-attribute
                 self.dim_constraints.add_equality(source, expr)
         for exprs, printer, lang in zip(all_exprs, printers, langs):
@@ -5851,6 +5877,7 @@ class ShapeEnv:
                 continue
             expr = self.simplify(ra.expr)
+            # pyrefly: ignore # missing-attribute
            self.dim_constraints.add(expr)
         # 3. Every symbol must be within its value range (this handles 0/1
@@ -5867,6 +5894,7 @@ class ShapeEnv:
             verbose_expr = ""
             if r.lower not in (-sympy.oo, -int_oo):
                 if any(is_dim(source) for source in sources):
+                    # pyrefly: ignore # missing-attribute
                     self.dim_constraints.add(sympy.Ge(symbol, r.lower))
                 # Only print lower bound in simplified mode if it is not the
                 # default
@@ -5875,6 +5903,7 @@ class ShapeEnv:
                 verbose_expr = f"{r.lower} <= {rf} # {vr_sloc.lower}"
             if r.upper not in (sympy.oo, int_oo):
                 if any(is_dim(source) for source in sources):
+                    # pyrefly: ignore # missing-attribute
                     self.dim_constraints.add(sympy.Le(symbol, r.upper))
                 # nontrivial upper bound is always interesting
                 bounds.append(sympy.Le(symbol, r.upper, evaluate=False))
@@ -5943,6 +5972,7 @@ class ShapeEnv:
                 else:
                     str_msg = f" - {msg_cb()}"
                 error_msgs.append(str_msg)
+                # pyrefly: ignore # bad-argument-type
                 debug_names.add(debug_name)
         if len(error_msgs) > 0:
             debug_names_str = ", ".join(sorted(debug_names))
@@ -6076,6 +6106,7 @@ class ShapeEnv:
         Get a list of guards, but pruned so it only provides guards that
         reference symints from the passed in input
         """
+        # pyrefly: ignore # bad-assignment
         symints = {
             s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol)
         }
@@ -6121,6 +6152,7 @@ class ShapeEnv:
                     else:
                         bindings[-s] = -arg
+        # pyrefly: ignore # bad-assignment
         for t, arg in zip(placeholders, args):
             if t is None:
                 continue
@@ -6338,6 +6370,7 @@ class ShapeEnv:
         Apply symbol replacements to any symbols in the given expression.
         """
         replacements = {}
+        # pyrefly: ignore # missing-attribute
         for s in expr.free_symbols:
             r = self._find(s)
@@ -6347,6 +6380,7 @@ class ShapeEnv:
             if not r.is_Symbol or r != s:
                 replacements[s] = r
         if replacements:
+            # pyrefly: ignore # missing-attribute
             return safe_expand(expr.xreplace(replacements))
         else:
             return expr
@@ -7121,6 +7155,7 @@ class ShapeEnv:
         instructions = list(dis.Bytecode(frame.f_code))
         co_lines, offset = inspect.getsourcelines(frame.f_code)
         start, end, cur = None, None, None
+        # pyrefly: ignore # bad-assignment
         for i, instr in enumerate(instructions):
             if instr.starts_line is not None:
                 cur = instr.starts_line
@@ -8000,6 +8035,7 @@ def _suggest_fixes_for_data_dependent_error_non_strict(
             if isinstance(leaf, torch.SymInt):
                 src_map[str(leaf.node.expr)].append(name)
             elif isinstance(leaf, torch.Tensor):
+                # pyrefly: ignore # bad-assignment
                 for i, dim in enumerate(leaf.shape):
                     if isinstance(dim, torch.SymInt):
                         src_map[str(dim.node.expr)].append(f"{name}.shape[{i}]")

Some files were not shown because too many files have changed in this diff.
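
For readers skimming the hunks above: the fix applied throughout is a one-line suppression comment placed directly above the expression pyrefly reports, with the error class named after the second `#`. A minimal sketch of the same pattern (the function below is hypothetical and not part of this change; whether pyrefly actually reports an error on such a line depends on the surrounding types):

    def first_free_symbol(expr):
        # pyrefly: ignore # missing-attribute
        return next(iter(expr.free_symbols))

The comment appears to be scoped to the line immediately following it, which is consistent with each suppressed call site in the diff carrying its own comment.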