mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add pyrefly suppressions (#164748)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

Step 1: delete lines in the pyrefly.toml file from the `project-excludes` field
Step 2: run pyrefly check
Step 3: add suppressions, clean up unused suppressions

Before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199
After: 0 errors (4,263 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164748
Approved by: https://github.com/oulgen
committed by: PyTorch MergeBot
parent: 5e47b4dd60
commit: b13cd141b3
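The hunks below all apply the same mechanical pattern, mostly in the quantization (torch.ao) modules visible from the class names: an inline `# pyrefly: ignore # <error-code>` comment is added directly above each expression pyrefly flags, and the directories that now typecheck are removed from the `project-excludes` list in pyrefly.toml. A minimal sketch of how such a suppression reads in ordinary code (the function and variable names here are illustrative, not taken from this PR):

    def total(values: list[int]) -> int:
        return sum(values)


    scores = (1, 2, 3)  # a tuple, so a checker would flag the annotated list[int] parameter
    result = total(
        # pyrefly: ignore # bad-argument-type
        scores,  # suppressed: fine at runtime, mismatched for the type checker
    )

The error codes seen throughout the diff (bad-argument-type, bad-assignment, bad-override, missing-attribute, no-matching-overload, not-callable, invalid-inheritance, unsupported-operation) name the specific pyrefly diagnostic being silenced, which keeps each suppression auditable and lets unused ones be cleaned up later.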
@@ -25,10 +25,6 @@ project-excludes = [
 "torch/nn/**",
 "torch/_dynamo/**",
 "torch/utils/**",
-"torch/ao/**",
-"torch/fx/**",
-"torch/distributions/**",
-"torch/onnx/**",
 # formatting issues
 "torch/linalg/__init__.py",
 "torch/package/importer.py",

@@ -470,7 +470,6 @@ def _check_input_constraints_for_graph(
 )
 elif isinstance(node_val, torch.SymInt):
 _check_symint(
-# pyrefly: ignore # bad-argument-type
 node_val,
 # pyrefly: ignore # bad-argument-type
 arg,

@@ -360,7 +360,6 @@ def trace_flex_attention(
 "call_function", flex_attention, proxy_args, {}
 )
 return track_tensor_tree(
-# pyrefly: ignore # bad-argument-type
 example_out,
 out_proxy,
 constant=None,

@@ -1080,7 +1079,6 @@ def trace_flex_attention_backward(
 name="flex_attention_backward",
 )
 return track_tensor_tree(
-# pyrefly: ignore # bad-argument-type
 example_out,
 out_proxy,
 constant=None,

@@ -899,7 +899,6 @@ def analyze_kernel_mutations(
 if op.name == "tt.call":
 assert op.fn_call_name in functions
 mutations = analyze_kernel_mutations(
-# pyrefly: ignore # bad-argument-type
 functions,
 # pyrefly: ignore # bad-argument-type
 op.fn_call_name,

@@ -3379,7 +3379,6 @@ def native_layer_norm(
 torch._check(
 input.ndim >= normalized_ndim
 and sym_eq(
-# pyrefly: ignore # bad-argument-type
 input.shape[(input.ndim - normalized_ndim) :],
 # pyrefly: ignore # bad-argument-type
 tuple(normalized_shape),

@@ -620,6 +620,7 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
 dilation=dilation,
 groups=groups,
 bias=bias,
+# pyrefly: ignore # bad-argument-type
 padding_mode=padding_mode,
 qconfig=qconfig,
 )

@@ -820,6 +821,7 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
 dilation=dilation,
 groups=groups,
 bias=bias,
+# pyrefly: ignore # bad-argument-type
 padding_mode=padding_mode,
 qconfig=qconfig,
 )

@@ -1021,6 +1023,7 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
 dilation=dilation,
 groups=groups,
 bias=bias,
+# pyrefly: ignore # bad-argument-type
 padding_mode=padding_mode,
 qconfig=qconfig,
 )

@@ -36,6 +36,7 @@ class LinearReLU(nnqat.Linear, _FusedModule):
 torch.Size([128, 30])
 """
 
+# pyrefly: ignore # bad-override
 _FLOAT_MODULE = nni.LinearReLU
 
 def __init__(

@@ -30,6 +30,7 @@ class LinearReLU(nnqd.Linear):
 torch.Size([128, 30])
 """
 
+# pyrefly: ignore # bad-override
 _FLOAT_MODULE = nni.LinearReLU
 
 def __init__(

@@ -54,6 +54,7 @@ class ConvReLU1d(nnq.Conv1d):
 dilation=dilation,
 groups=groups,
 bias=bias,
+# pyrefly: ignore # bad-argument-type
 padding_mode=padding_mode,
 device=device,
 dtype=dtype,

@@ -114,6 +114,7 @@ class _ConvNd(nn.modules.conv._ConvNd):
 assert hasattr(cls, "_FLOAT_RELU_MODULE")
 relu = cls._FLOAT_RELU_MODULE()
 modules.append(relu)
+# pyrefly: ignore # missing-attribute
 fused = cls._FLOAT_MODULE(*modules)
 fused.train(self.training)
 return fused

@@ -50,6 +50,7 @@ class Embedding(nn.Embedding):
 scale_grad_by_freq,
 sparse,
 _weight,
+# pyrefly: ignore # bad-argument-type
 **factory_kwargs,
 )
 assert qconfig, "qconfig must be provided for QAT module"
@@ -170,8 +170,11 @@ class MultiheadAttention(nn.MultiheadAttention):
 observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
 observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
 if other.in_proj_bias is None:
+# pyrefly: ignore # bad-assignment
 observed.linear_Q.bias = None
+# pyrefly: ignore # bad-assignment
 observed.linear_K.bias = None
+# pyrefly: ignore # bad-assignment
 observed.linear_V.bias = None
 else:
 observed.linear_Q.bias = nn.Parameter(

@@ -234,6 +237,7 @@ class MultiheadAttention(nn.MultiheadAttention):
 _end = _start + fp.embed_dim
 fp.in_proj_weight[_start:_end, :] = wQ
 if fp.in_proj_bias is not None:
+# pyrefly: ignore # bad-argument-type
 assert all(bQ == 0)
 fp.in_proj_bias[_start:_end] = bQ
 

@@ -241,12 +245,14 @@ class MultiheadAttention(nn.MultiheadAttention):
 _end = _start + fp.embed_dim
 fp.in_proj_weight[_start:_end, :] = wK
 if fp.in_proj_bias is not None:
+# pyrefly: ignore # bad-argument-type
 assert all(bK == 0)
 fp.in_proj_bias[_start:_end] = bK
 
 _start = _end
 fp.in_proj_weight[_start:, :] = wV
 if fp.in_proj_bias is not None:
+# pyrefly: ignore # bad-argument-type
 assert all(bV == 0)
 fp.in_proj_bias[_start:] = bV
 else:

@@ -254,8 +260,11 @@ class MultiheadAttention(nn.MultiheadAttention):
 fp.k_proj_weight = nn.Parameter(wK)
 fp.v_proj_weight = nn.Parameter(wV)
 if fp.in_proj_bias is None:
+# pyrefly: ignore # bad-assignment
 self.linear_Q.bias = None
+# pyrefly: ignore # bad-assignment
 self.linear_K.bias = None
+# pyrefly: ignore # bad-assignment
 self.linear_V.bias = None
 else:
 fp.in_proj_bias[0 : fp.embed_dim] = bQ

@@ -463,6 +472,7 @@ class MultiheadAttention(nn.MultiheadAttention):
 assert static_v.size(2) == head_dim
 v = static_v
 
+# pyrefly: ignore # missing-attribute
 src_len = k.size(1)
 
 if key_padding_mask is not None:

@@ -471,17 +481,35 @@ class MultiheadAttention(nn.MultiheadAttention):
 
 if self.add_zero_attn:
 src_len += 1
+# pyrefly: ignore # missing-attribute
 k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
+# pyrefly: ignore # missing-attribute
 if k.is_quantized:
 k_zeros = torch.quantize_per_tensor(
-k_zeros, k.q_scale(), k.q_zero_point(), k.dtype
+k_zeros,
+# pyrefly: ignore # missing-attribute
+k.q_scale(),
+# pyrefly: ignore # missing-attribute
+k.q_zero_point(),
+# pyrefly: ignore # missing-attribute
+k.dtype,
 )
+# pyrefly: ignore # no-matching-overload
 k = torch.cat([k, k_zeros], dim=1)
+# pyrefly: ignore # missing-attribute
 v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
+# pyrefly: ignore # missing-attribute
 if v.is_quantized:
 v_zeros = torch.quantize_per_tensor(
-v_zeros, v.q_scale(), v.q_zero_point(), v.dtype
+v_zeros,
+# pyrefly: ignore # missing-attribute
+v.q_scale(),
+# pyrefly: ignore # missing-attribute
+v.q_zero_point(),
+# pyrefly: ignore # missing-attribute
+v.dtype,
 )
+# pyrefly: ignore # no-matching-overload
 v = torch.cat([v, v_zeros], dim=1)
 
 if attn_mask is not None:
@@ -376,6 +376,7 @@ class _LSTMLayer(torch.nn.Module):
 bidirectional,
 split_gates=split_gates,
 )
+# pyrefly: ignore # bad-argument-type
 layer.qconfig = getattr(other, "qconfig", qconfig)
 wi = getattr(other, f"weight_ih_l{layer_idx}")
 wh = getattr(other, f"weight_hh_l{layer_idx}")

@@ -454,6 +455,7 @@ class LSTM(torch.nn.Module):
 
 if (
 not isinstance(dropout, numbers.Number)
+# pyrefly: ignore # unsupported-operation
 or not 0 <= dropout <= 1
 or isinstance(dropout, bool)
 ):

@@ -462,6 +464,7 @@ class LSTM(torch.nn.Module):
 "representing the probability of an element being "
 "zeroed"
 )
+# pyrefly: ignore # unsupported-operation
 if dropout > 0:
 warnings.warn(
 "dropout option for quantizable LSTM is ignored. "

@@ -573,6 +576,7 @@ class LSTM(torch.nn.Module):
 other.bidirectional,
 split_gates=split_gates,
 )
+# pyrefly: ignore # bad-argument-type
 observed.qconfig = getattr(other, "qconfig", qconfig)
 for idx in range(other.num_layers):
 observed.layers[idx] = _LSTMLayer.from_float(

@@ -73,6 +73,7 @@ class Conv1d(nnq.Conv1d):
 factory_kwargs = {"device": device, "dtype": dtype}
 kernel_size = _single(kernel_size)
 stride = _single(stride)
+# pyrefly: ignore # bad-assignment
 padding = padding if isinstance(padding, str) else _single(padding)
 dilation = _single(dilation)
 

@@ -119,7 +119,9 @@ class Linear(nnq.Linear):
 assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
 if type(mod) == nni.LinearReLU:
 mod = mod[0]
+# pyrefly: ignore # missing-attribute
 if mod.qconfig is not None and mod.qconfig.weight is not None:
+# pyrefly: ignore # not-callable
 weight_observer = mod.qconfig.weight()
 else:
 # We have the circular import issues if we import the qconfig in the beginning of this file:

@@ -143,6 +145,7 @@ class Linear(nnq.Linear):
 "Unsupported dtype specified for dynamic quantized Linear!"
 )
 qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
+# pyrefly: ignore # bad-argument-type
 qlinear.set_weight_bias(qweight, mod.bias)
 return qlinear
 

@@ -521,6 +521,7 @@ class LSTM(RNNBase):
 >>> output, (hn, cn) = rnn(input, (h0, c0))
 """
 
+# pyrefly: ignore # bad-override
 _FLOAT_MODULE = nn.LSTM
 
 __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}

@@ -806,6 +807,7 @@ class GRU(RNNBase):
 >>> output, hn = rnn(input, h0)
 """
 
+# pyrefly: ignore # bad-override
 _FLOAT_MODULE = nn.GRU
 
 __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
@@ -67,7 +67,9 @@ class Hardswish(torch.nn.Hardswish):
 def __init__(self, scale, zero_point, device=None, dtype=None):
 factory_kwargs = {"device": device, "dtype": dtype}
 super().__init__()
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
 
 def forward(self, input):

@@ -138,7 +140,9 @@ class LeakyReLU(torch.nn.LeakyReLU):
 ) -> None:
 factory_kwargs = {"device": device, "dtype": dtype}
 super().__init__(negative_slope, inplace)
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
 
 def forward(self, input):

@@ -226,6 +230,7 @@ class Softmax(torch.nn.Softmax):
 
 
 class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
+# pyrefly: ignore # bad-override
 _FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention
 
 def _get_name(self):

@@ -12,7 +12,9 @@ class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
 ) -> None:
 factory_kwargs = {"device": device, "dtype": dtype}
 super().__init__(num_features, eps, momentum, True, True, **factory_kwargs)
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("scale", torch.tensor(1.0, **factory_kwargs))
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("zero_point", torch.tensor(0, **factory_kwargs))
 
 @staticmethod

@@ -408,6 +408,7 @@ class Conv1d(_ConvNd):
 factory_kwargs = {"device": device, "dtype": dtype}
 kernel_size = _single(kernel_size)
 stride = _single(stride)
+# pyrefly: ignore # bad-assignment
 padding = padding if isinstance(padding, str) else _single(padding)
 dilation = _single(dilation)
 

@@ -310,6 +310,7 @@ class Linear(WeightedQuantizedModule):
 # the type mismatch in assignment. Also, mypy has an issue with
 # iterables not being implemented, so we are ignoring those too.
 if not isinstance(cls._FLOAT_MODULE, Iterable):
+# pyrefly: ignore # bad-assignment
 cls._FLOAT_MODULE = [cls._FLOAT_MODULE]
 supported_modules = ", ".join(
 [float_mod.__name__ for float_mod in cls._FLOAT_MODULE]

@@ -37,11 +37,14 @@ class LayerNorm(torch.nn.LayerNorm):
 normalized_shape,
 eps=eps,
 elementwise_affine=elementwise_affine,
+# pyrefly: ignore # bad-argument-type
 **factory_kwargs,
 )
 self.weight = weight
 self.bias = bias
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
 
 def forward(self, input):

@@ -113,7 +116,9 @@ class GroupNorm(torch.nn.GroupNorm):
 super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
 self.weight = weight
 self.bias = bias
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
 
 def forward(self, input):

@@ -175,7 +180,9 @@ class InstanceNorm1d(torch.nn.InstanceNorm1d):
 )
 self.weight = weight
 self.bias = bias
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
 
 def forward(self, input):

@@ -242,7 +249,9 @@ class InstanceNorm2d(torch.nn.InstanceNorm2d):
 )
 self.weight = weight
 self.bias = bias
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
 
 def forward(self, input):

@@ -309,7 +318,9 @@ class InstanceNorm3d(torch.nn.InstanceNorm3d):
 )
 self.weight = weight
 self.bias = bias
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
+# pyrefly: ignore # bad-argument-type
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
 
 def forward(self, input):
@@ -95,6 +95,7 @@ class Conv1d(_ConvNd, nn.Conv1d):
 and the backend should be able to fuse the ops with `*` into a quantized conv1d
 """
 weight_quant_dequant = self.get_weight()
+# pyrefly: ignore # no-matching-overload
 result = F.conv1d(
 x,
 weight_quant_dequant,

@@ -140,6 +141,7 @@ class Conv2d(_ConvNd, nn.Conv2d):
 dilation,
 groups,
 bias,
+# pyrefly: ignore # bad-argument-type
 padding_mode,
 device,
 dtype,

@@ -158,6 +160,7 @@ class Conv2d(_ConvNd, nn.Conv2d):
 and the backend should be able to fuse the ops with `*` into a quantized conv2d
 """
 weight_quant_dequant = self.get_weight()
+# pyrefly: ignore # no-matching-overload
 result = F.conv2d(
 x,
 weight_quant_dequant,

@@ -203,6 +206,7 @@ class Conv3d(_ConvNd, nn.Conv3d):
 dilation,
 groups,
 bias,
+# pyrefly: ignore # bad-argument-type
 padding_mode,
 device,
 dtype,

@@ -221,6 +225,7 @@ class Conv3d(_ConvNd, nn.Conv3d):
 and the backend should be able to fuse the ops with `*` into a quantized conv3d
 """
 weight_quant_dequant = self.get_weight()
+# pyrefly: ignore # no-matching-overload
 result = F.conv3d(
 x,
 weight_quant_dequant,

@@ -378,6 +383,7 @@ class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d):
 groups,
 bias,
 dilation,
+# pyrefly: ignore # bad-argument-type
 padding_mode,
 device,
 dtype,

@@ -459,6 +465,7 @@ class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d):
 groups,
 bias,
 dilation,
+# pyrefly: ignore # bad-argument-type
 padding_mode,
 device,
 dtype,
@@ -663,7 +663,11 @@ class LSTM(RNNBase):
 # xxx: isinstance check needs to be in conditional for TorchScript to compile
 if isinstance(orig_input, PackedSequence):
 output_packed = PackedSequence(
-output, batch_sizes, sorted_indices, unsorted_indices
+output,
+# pyrefly: ignore # bad-argument-type
+batch_sizes,
+sorted_indices,
+unsorted_indices,
 )
 return output_packed, self.permute_hidden(hidden, unsorted_indices)
 else:

@@ -823,7 +827,11 @@ class GRU(RNNBase):
 # xxx: isinstance check needs to be in conditional for TorchScript to compile
 if isinstance(orig_input, PackedSequence):
 output_packed = PackedSequence(
-output, batch_sizes, sorted_indices, unsorted_indices
+output,
+# pyrefly: ignore # bad-argument-type
+batch_sizes,
+sorted_indices,
+unsorted_indices,
 )
 return output_packed, self.permute_hidden(hidden, unsorted_indices)
 else:

@@ -42,6 +42,7 @@ class Embedding(nn.Embedding, ReferenceQuantizedModule):
 scale_grad_by_freq,
 sparse,
 _weight,
+# pyrefly: ignore # bad-argument-type
 device,
 dtype,
 )

@@ -18,6 +18,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
 "scale": 1.0,
 "zero_point": 0,
 }
+# pyrefly: ignore # bad-assignment
 self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"]
 self.weight_dtype = weight_qparams["dtype"]
 assert self.weight_qscheme in [

@@ -80,13 +81,16 @@ class ReferenceQuantizedModule(torch.nn.Module):
 self.register_buffer(
 "weight_axis", torch.tensor(0, dtype=torch.int, device=device)
 )
+# pyrefly: ignore # bad-assignment
 self.is_decomposed: bool = weight_qparams.get("is_decomposed", False)
 # store weight_axis as weight_axis_int due to some constraints of torchdynamo.export
 # for capturing `.item` operations
 self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment]
+# pyrefly: ignore # bad-assignment
 self.weight_quant_min: typing.Optional[int] = weight_qparams.get(
 "quant_min", None
 )
+# pyrefly: ignore # bad-assignment
 self.weight_quant_max: typing.Optional[int] = weight_qparams.get(
 "quant_max", None
 )

@@ -105,6 +109,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
 return _quantize_and_dequantize_weight_decomposed(
 self.weight, # type: ignore[arg-type]
 self.weight_qscheme,
+# pyrefly: ignore # bad-argument-type
 self.weight_dtype,
 self.weight_scale,
 self.weight_zero_point,

@@ -116,6 +121,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
 return _quantize_and_dequantize_weight(
 self.weight, # type: ignore[arg-type]
 self.weight_qscheme,
+# pyrefly: ignore # bad-argument-type
 self.weight_dtype,
 self.weight_scale,
 self.weight_zero_point,

@@ -131,6 +137,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
 return _quantize_weight_decomposed(
 self.weight, # type: ignore[arg-type]
 self.weight_qscheme,
+# pyrefly: ignore # bad-argument-type
 self.weight_dtype,
 self.weight_scale,
 self.weight_zero_point,

@@ -142,6 +149,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
 return _quantize_weight(
 self.weight, # type: ignore[arg-type]
 self.weight_qscheme,
+# pyrefly: ignore # bad-argument-type
 self.weight_dtype,
 self.weight_scale,
 self.weight_zero_point,
@@ -151,7 +151,9 @@ class Linear(torch.nn.Module):
 assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
 if type(mod) == nni.LinearReLU:
 mod = mod[0]
+# pyrefly: ignore # missing-attribute
 if mod.qconfig is not None and mod.qconfig.weight is not None:
+# pyrefly: ignore # not-callable
 weight_observer = mod.qconfig.weight()
 else:
 # We have the circular import issues if we import the qconfig in the beginning of this file:

@@ -185,5 +187,6 @@ class Linear(torch.nn.Module):
 col_block_size,
 dtype=dtype,
 )
+# pyrefly: ignore # bad-argument-type
 qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size)
 return qlinear

@@ -84,6 +84,7 @@ class _NSGraphMatchableSubgraphsIterator:
 if is_match:
 # navigate to the base node
 for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
+# pyrefly: ignore # bad-argument-type
 self.seen_nodes.add(cur_start_node)
 # for now, assume that there are no other nodes
 # which need to be added to the stack

@@ -94,8 +95,10 @@ class _NSGraphMatchableSubgraphsIterator:
 cur_base_op_node = cur_start_node
 break
 
+# pyrefly: ignore # bad-argument-type
 self.seen_nodes.add(cur_start_node)
 # add args of previous nodes to stack
+# pyrefly: ignore # missing-attribute
 for arg in cur_start_node.all_input_nodes:
 self._recursively_add_node_arg_to_stack(arg)
 

@@ -103,6 +106,7 @@ class _NSGraphMatchableSubgraphsIterator:
 # note: this check is done on the start_node, i.e.
 # if we are matching linear-relu in reverse, this would do the matchable
 # check on the linear
+# pyrefly: ignore # bad-argument-type
 if not self._is_matchable(cur_base_op_node):
 continue
 

@@ -116,8 +120,10 @@ class _NSGraphMatchableSubgraphsIterator:
 continue
 
 return NSSubgraph(
+# pyrefly: ignore # bad-argument-type
 start_node=cur_start_node,
 end_node=cur_end_node,
+# pyrefly: ignore # bad-argument-type
 base_op_node=cur_base_op_node,
 )
 

@@ -415,6 +415,7 @@ def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]:
 target2,
 ) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
 new_connections.append((source, target1))
+# pyrefly: ignore # bad-argument-type
 new_connections.append((source, target2))
 
 for source_to_target in (

@@ -423,6 +424,7 @@ def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]:
 quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
 ):
 for source, target in source_to_target.items(): # type:ignore[assignment]
+# pyrefly: ignore # bad-argument-type
 new_connections.append((source, target))
 
 #

@@ -95,6 +95,7 @@ class OutputProp:
 if isinstance(result, torch.Tensor): # type: ignore[possibly-undefined]
 node.traced_result = result
 
+# pyrefly: ignore # unsupported-operation
 env[node.name] = result
 
 return None

@@ -393,8 +394,10 @@ def create_submodule_from_subgraph(
 cur_name_idx += 1
 setattr(gm, mod_name, new_arg)
 new_arg_placeholder = gm.placeholder(mod_name) # type: ignore[operator]
+# pyrefly: ignore # missing-attribute
 cur_args_copy.append(new_arg_placeholder)
 elif isinstance(arg, (float, int, torch.dtype)):
+# pyrefly: ignore # missing-attribute
 cur_args_copy.append(arg)
 else:
 raise AssertionError(f"arg of type {type(arg)} not handled yet")

@@ -801,6 +804,7 @@ def create_add_loggers_graph(
 model,
 cur_subgraph_idx,
 match_name,
+# pyrefly: ignore # bad-argument-type
 maybe_subgraph,
 [qconfig_mapping],
 [node_name_to_qconfig],

@@ -857,6 +861,7 @@ def create_add_loggers_graph(
 cur_node_orig = first_node
 cur_node_copy = None
 first_node_copy = None
+# pyrefly: ignore # bad-assignment
 while cur_node_orig in subgraph_to_use:
 # TODO(future PR): make this support all possible args/kwargs
 if cur_node_orig is first_node:
@@ -404,6 +404,7 @@ def maybe_add_missing_fqns(results: NSResultsType) -> None:
 for model_name, model_results in model_name_to_results.items():
 if model_name == model_name_with_fqns:
 continue
+# pyrefly: ignore # bad-assignment
 for i in range(len(model_results)):
 fqn = ref_model_results[i]["fqn"]
 model_results[i]["fqn"] = fqn

@@ -467,6 +468,7 @@ def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tenso
 Return:
 float or tuple of floats
 """
+# pyrefly: ignore # unsupported-operation
 return torch.sqrt(((x - y) ** 2).sum() / (x**2).sum())
 
 

@@ -23,12 +23,17 @@ class SparseDLRM(DLRM_Net):
 super().__init__(**args)
 
 def forward(self, dense_x, lS_o, lS_i):
+# pyrefly: ignore # missing-attribute
 x = self.apply_mlp(dense_x, self.bot_l) # dense features
+# pyrefly: ignore # missing-attribute
 ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l) # apply embedding bag
+# pyrefly: ignore # missing-attribute
 z = self.interact_features(x, ly)
 
 z = z.to_sparse_coo()
+# pyrefly: ignore # missing-attribute
 z = torch.mm(z, self.top_l[0].weight.T).add(self.top_l[0].bias)
+# pyrefly: ignore # missing-attribute
 for layer in self.top_l[1:]:
 z = layer(z)
 

@@ -72,6 +72,7 @@ class FPGMPruner(BaseStructuredSparsifier):
 dist_matrix = self.dist_fn(t_flatten)
 
 # more similar with other filter indicates large in the sum of row
+# pyrefly: ignore # bad-argument-type
 distance = torch.sum(torch.abs(dist_matrix), 1)
 
 return distance

@@ -260,6 +260,7 @@ class BaseStructuredSparsifier(BaseSparsifier):
 module.register_parameter(
 "_bias", nn.Parameter(module.bias.detach())
 )
+# pyrefly: ignore # bad-assignment
 module.bias = None
 module.prune_bias = prune_bias
 

@@ -97,6 +97,7 @@ def _propagate_module_bias(module: nn.Module, mask: Tensor) -> Optional[Tensor]:
 if module.bias is not None:
 module.bias = nn.Parameter(cast(Tensor, module.bias)[mask])
 elif getattr(module, "_bias", None) is not None:
+# pyrefly: ignore # bad-assignment
 module.bias = nn.Parameter(cast(Tensor, module._bias)[mask])
 
 # get pruned biases to propagate to subsequent layer

@@ -170,6 +170,7 @@ class BaseSparsifier(abc.ABC):
 self.make_config_from_model(model)
 
 # TODO: Remove the configuration by reference ('module')
+# pyrefly: ignore # not-iterable
 for module_config in self.config:
 assert isinstance(module_config, dict), (
 "config elements should be dicts not modules i.e.:"

@@ -51,6 +51,7 @@ def swap_module(
 new_mod.register_forward_hook(hook_fn)
 
 # respect device affinity when swapping modules
+# pyrefly: ignore # bad-argument-type
 devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
 assert len(devices) <= 1, (
 f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"

@@ -235,6 +235,7 @@ class WeightNormSparsifier(BaseSparsifier):
 ww = self.norm_fn(getattr(module, tensor_name))
 tensor_mask = self._make_tensor_mask(
 data=ww,
+# pyrefly: ignore # missing-attribute
 input_shape=ww.shape,
 sparsity_level=sparsity_level,
 sparse_block_shape=sparse_block_shape,

@@ -24,6 +24,8 @@ from .pt2e.export_utils import (
 _move_exported_model_to_eval as move_exported_model_to_eval,
 _move_exported_model_to_train as move_exported_model_to_train,
 )
+
+# pyrefly: ignore # deprecated
 from .qconfig import * # noqa: F403
 from .qconfig_mapping import * # noqa: F403
 from .quant_type import * # noqa: F403

@@ -127,6 +127,7 @@ class AdaptiveRoundingOptimizer:
 @torch.no_grad()
 def feed_forward(self, x, weight, module):
 if isinstance(module, torch.nn.Conv1d):
+# pyrefly: ignore # no-matching-overload
 out = torch.nn.functional.conv1d(
 x,
 weight,
@@ -185,7 +185,9 @@ class FakeQuantize(FakeQuantizeBase):
 dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get(
 "dtype", dtype
 )
+# pyrefly: ignore # bad-argument-type
 assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound"
+# pyrefly: ignore # bad-argument-type
 assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound"
 observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
 observer_kwargs["is_dynamic"] = is_dynamic

@@ -1149,6 +1149,7 @@ quantized_decomposed_lib.define(
 
 class FakeQuantPerChannel(torch.autograd.Function):
 @staticmethod
+# pyrefly: ignore # bad-override
 def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max):
 if scales.dtype != torch.float32:
 scales = scales.to(torch.float32)

@@ -1171,6 +1172,7 @@ class FakeQuantPerChannel(torch.autograd.Function):
 return out
 
 @staticmethod
+# pyrefly: ignore # bad-override
 def backward(ctx, gy):
 (mask,) = ctx.saved_tensors
 return gy * mask, None, None, None, None, None

@@ -246,6 +246,7 @@ def calculate_equalization_scale(
 
 
 class EqualizationQConfig(
+# pyrefly: ignore # invalid-inheritance
 namedtuple("EqualizationQConfig", ["input_activation", "weight"])
 ):
 """

@@ -460,6 +461,7 @@ def maybe_get_next_equalization_scale(
 In this case, the node given is linear1 and we want to locate the InputEqObs.
 """
 next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
+# pyrefly: ignore # invalid-argument
 if next_inp_eq_obs:
 if (
 next_inp_eq_obs.equalization_scale.nelement() == 1

@@ -821,13 +823,18 @@ def convert_eq_obs(
 # Scale the weight nodes
 if node.op == "call_module":
 scale_weight_node(
-node, modules, equalization_scale, maybe_next_equalization_scale
+node,
+modules,
+# pyrefly: ignore # bad-argument-type
+equalization_scale,
+maybe_next_equalization_scale,
 )
 elif node.op == "call_function":
 scale_weight_functional(
 node,
 model,
 modules,
+# pyrefly: ignore # bad-argument-type
 equalization_scale,
 maybe_next_equalization_scale,
 )

@@ -223,6 +223,7 @@ class ModelReportVisualizer:
 feature_val = feature_val.item()
 
 # we add to our list of values
+# pyrefly: ignore # bad-argument-type
 tensor_table_row.append(feature_val)
 
 tensor_table.append(tensor_table_row)

@@ -283,6 +284,7 @@ class ModelReportVisualizer:
 feature_val = feature_val.item()
 
 # add value to channel specific row
+# pyrefly: ignore # bad-argument-type
 new_channel_row.append(feature_val)
 
 # add to table and increment row index counter

@@ -166,6 +166,7 @@ def _create_obs_or_fq_from_qspec(
 }
 edge_or_nodes = quantization_spec.derived_from
 obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes]
+# pyrefly: ignore # unsupported-operation
 kwargs["obs_or_fqs"] = obs_or_fqs
 return _DerivedObserverOrFakeQuantize.with_args(**kwargs)()
 elif isinstance(quantization_spec, FixedQParamsQuantizationSpec):
@ -2085,8 +2086,11 @@ def prepare(
|
|||||||
|
|
||||||
root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config)
|
root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config)
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
_update_qconfig_for_fusion(model, qconfig_mapping)
|
_update_qconfig_for_fusion(model, qconfig_mapping)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
_update_qconfig_for_fusion(model, _equalization_config)
|
_update_qconfig_for_fusion(model, _equalization_config)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
|
flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
|
||||||
# TODO: support regex as well
|
# TODO: support regex as well
|
||||||
propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
|
propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
|
||||||
@ -2094,6 +2098,7 @@ def prepare(
|
|||||||
if is_qat:
|
if is_qat:
|
||||||
module_to_qat_module = get_module_to_qat_module(backend_config)
|
module_to_qat_module = get_module_to_qat_module(backend_config)
|
||||||
_qat_swap_modules(model, module_to_qat_module)
|
_qat_swap_modules(model, module_to_qat_module)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
_update_qconfig_for_qat(qconfig_mapping, backend_config)
|
_update_qconfig_for_qat(qconfig_mapping, backend_config)
|
||||||
|
|
||||||
# mapping from fully qualified module name to module instance
|
# mapping from fully qualified module name to module instance
|
||||||
@ -2107,10 +2112,20 @@ def prepare(
|
|||||||
|
|
||||||
# fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches
|
# fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches
|
||||||
equalization_node_name_to_qconfig = _generate_node_name_to_qconfig(
|
equalization_node_name_to_qconfig = _generate_node_name_to_qconfig(
|
||||||
model, named_modules, model.graph, _equalization_config, node_name_to_scope
|
model,
|
||||||
|
named_modules,
|
||||||
|
model.graph,
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
|
_equalization_config,
|
||||||
|
node_name_to_scope,
|
||||||
)
|
)
|
||||||
node_name_to_qconfig = _generate_node_name_to_qconfig(
|
node_name_to_qconfig = _generate_node_name_to_qconfig(
|
||||||
model, named_modules, model.graph, qconfig_mapping, node_name_to_scope
|
model,
|
||||||
|
named_modules,
|
||||||
|
model.graph,
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
|
qconfig_mapping,
|
||||||
|
node_name_to_scope,
|
||||||
)
|
)
|
||||||
|
|
||||||
# match the patterns that will get quantized
|
# match the patterns that will get quantized
|
||||||
@ -2170,6 +2185,7 @@ def prepare(
|
|||||||
node_name_to_scope,
|
node_name_to_scope,
|
||||||
prepare_custom_config,
|
prepare_custom_config,
|
||||||
equalization_node_name_to_qconfig,
|
equalization_node_name_to_qconfig,
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
qconfig_mapping,
|
qconfig_mapping,
|
||||||
is_qat,
|
is_qat,
|
||||||
observed_node_names,
|
observed_node_names,
|
||||||
|
@ -720,6 +720,7 @@ def _maybe_get_custom_module_lstm_from_node_arg(
a = a.args[0][0] # type: ignore[assignment,index]
else:
a = a.args[0] # type: ignore[assignment]
# pyrefly: ignore # bad-return
return a

all_match_patterns = [
@ -280,9 +280,12 @@ class UniformQuantizationObserverBase(ObserverBase):
)
self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
if self.has_customized_qrange:
# pyrefly: ignore # bad-argument-type
validate_qmin_qmax(quant_min, quant_max)
self.quant_min, self.quant_max = calculate_qmin_qmax(
# pyrefly: ignore # bad-argument-type
quant_min,
# pyrefly: ignore # bad-argument-type
quant_max,
self.has_customized_qrange,
self.dtype,
@ -72,6 +72,7 @@ def _find_q_dq_node_for_user(
dq_node = n
break
if dq_node is None:
# pyrefly: ignore # bad-assignment
for n in user.kwargs:
if (
isinstance(n, torch.fx.Node)
@ -83,6 +83,7 @@ __all__ = [
]


# pyrefly: ignore # invalid-inheritance
class QConfig(namedtuple("QConfig", ["activation", "weight"])):
"""
Describes how to quantize a layer or a part of the network by providing
@ -120,6 +121,7 @@ class QConfig(namedtuple("QConfig", ["activation", "weight"])):
"`QConfigDynamic` is going to be deprecated in PyTorch 1.12, please use `QConfig` instead",
category=FutureWarning,
)
# pyrefly: ignore # invalid-inheritance
class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])):
"""
Describes how to dynamically quantize a layer or a part of the network by providing
@ -417,6 +417,7 @@ class X86InductorQuantizer(Quantizer):
# As we use `_need_skip_config` to skip all invalid configurations,
# we can safely assume that the all existing non-None configurations
# have the same quantization mode.
# pyrefly: ignore # bad-assignment
for qconfig in (
list(self.module_name_qconfig.values())
+ list(self.operator_type_qconfig.values())
@ -808,6 +809,7 @@ class X86InductorQuantizer(Quantizer):
)
binary_node.meta[QUANT_ANNOTATION_KEY] = (
_X86InductorQuantizationAnnotation(
# pyrefly: ignore # bad-argument-type
input_qspec_map=binary_node_input_qspec_map,
_annotated=True,
)
@ -878,6 +880,7 @@ class X86InductorQuantizer(Quantizer):
)
binary_node.meta[QUANT_ANNOTATION_KEY] = (
_X86InductorQuantizationAnnotation(
# pyrefly: ignore # bad-argument-type
input_qspec_map=binary_node_input_qspec_map,
# TODO<leslie> Remove the annotate of output in QAT when qat util support pattern matcher.
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
@ -1085,6 +1088,7 @@ class X86InductorQuantizer(Quantizer):
quantization_config
)
binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
# pyrefly: ignore # bad-argument-type
input_qspec_map=binary_node_input_qspec_map,
_annotated=True,
)
@ -1139,6 +1143,7 @@ class X86InductorQuantizer(Quantizer):
quantization_config
)
binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
# pyrefly: ignore # bad-argument-type
input_qspec_map=binary_node_input_qspec_map,
_annotated=True,
_is_output_of_quantized_pattern=True,
@ -1499,6 +1504,7 @@ class X86InductorQuantizer(Quantizer):
has_unary = unary_op is not None
seq_partition = [torch.nn.Linear, binary_op]
if has_unary:
# pyrefly: ignore # bad-argument-type
seq_partition.append(unary_op)
fused_partitions = find_sequential_partitions(gm, seq_partition)
for fused_partition in fused_partitions:
@ -376,9 +376,11 @@ def _do_annotate_conv_relu(
input_qspec_map[bias] = get_bias_qspec(quantization_config)
partition.append(bias)

# pyrefly: ignore # bad-argument-type
if _is_annotated(partition):
continue

# pyrefly: ignore # bad-argument-type
if filter_fn and any(not filter_fn(n) for n in partition):
continue

@ -389,6 +391,7 @@ def _do_annotate_conv_relu(
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
)
# pyrefly: ignore # bad-argument-type
_mark_nodes_as_annotated(partition)
annotated_partitions.append(partition)
return annotated_partitions
@ -39,6 +39,7 @@ class Bernoulli(ExponentialFamily):
validate_args (bool, optional): whether to validate arguments, None by default
"""

# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.boolean
has_enumerate_support = True
@ -56,10 +57,12 @@ class Bernoulli(ExponentialFamily):
)
if probs is not None:
is_scalar = isinstance(probs, _Number)
# pyrefly: ignore # read-only
(self.probs,) = broadcast_all(probs)
else:
assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number)
# pyrefly: ignore # read-only
(self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits
if is_scalar:
@ -137,5 +140,6 @@ class Bernoulli(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (torch.logit(self.probs),)

# pyrefly: ignore # bad-override
def _log_normalizer(self, x):
return torch.log1p(torch.exp(x))
@ -31,6 +31,7 @@ class Beta(ExponentialFamily):
(often referred to as beta)
"""

# pyrefly: ignore # bad-override
arg_constraints = {
"concentration1": constraints.positive,
"concentration0": constraints.positive,
@ -113,5 +114,6 @@ class Beta(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor, Tensor]:
return (self.concentration1, self.concentration0)

# pyrefly: ignore # bad-override
def _log_normalizer(self, x, y):
return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
@ -45,6 +45,7 @@ class Binomial(Distribution):
logits (Tensor): Event log-odds
"""

# pyrefly: ignore # bad-override
arg_constraints = {
"total_count": constraints.nonnegative_integer,
"probs": constraints.unit_interval,
@ -66,6 +67,7 @@ class Binomial(Distribution):
if probs is not None:
(
self.total_count,
# pyrefly: ignore # read-only
self.probs,
) = broadcast_all(total_count, probs)
self.total_count = self.total_count.type_as(self.probs)
@ -73,6 +75,7 @@ class Binomial(Distribution):
assert logits is not None # helps mypy
(
self.total_count,
# pyrefly: ignore # read-only
self.logits,
) = broadcast_all(total_count, logits)
self.total_count = self.total_count.type_as(self.logits)
@ -99,6 +102,7 @@ class Binomial(Distribution):
return self._param.new(*args, **kwargs)

@constraints.dependent_property(is_discrete=True, event_dim=0)
# pyrefly: ignore # bad-override
def support(self):
return constraints.integer_interval(0, self.total_count)

@ -50,6 +50,7 @@ class Categorical(Distribution):
logits (Tensor): event log probabilities (unnormalized)
"""

# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
has_enumerate_support = True

@ -66,12 +67,14 @@ class Categorical(Distribution):
if probs is not None:
if probs.dim() < 1:
raise ValueError("`probs` parameter must be at least one-dimensional.")
# pyrefly: ignore # read-only
self.probs = probs / probs.sum(-1, keepdim=True)
else:
assert logits is not None # helps mypy
if logits.dim() < 1:
raise ValueError("`logits` parameter must be at least one-dimensional.")
# Normalize
# pyrefly: ignore # read-only
self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
self._param = self.probs if probs is not None else self.logits
self._num_events = self._param.size()[-1]
@ -99,6 +102,7 @@ class Categorical(Distribution):
return self._param.new(*args, **kwargs)

@constraints.dependent_property(is_discrete=True, event_dim=0)
# pyrefly: ignore # bad-override
def support(self):
return constraints.integer_interval(0, self._num_events - 1)

@ -31,6 +31,7 @@ class Cauchy(Distribution):
scale (float or Tensor): half width at half maximum.
"""

# pyrefly: ignore # bad-override
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real
has_rsample = True
@ -47,6 +47,7 @@ class ContinuousBernoulli(ExponentialFamily):
https://arxiv.org/abs/1907.06845
"""

# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.unit_interval
_mean_carrier_measure = 0
@ -65,16 +66,19 @@ class ContinuousBernoulli(ExponentialFamily):
)
if probs is not None:
is_scalar = isinstance(probs, _Number)
# pyrefly: ignore # read-only
(self.probs,) = broadcast_all(probs)
# validate 'probs' here if necessary as it is later clamped for numerical stability
# close to 0 and 1, later on; otherwise the clamped 'probs' would always pass
if validate_args is not None:
if not self.arg_constraints["probs"].check(self.probs).all():
raise ValueError("The parameter probs has invalid values")
# pyrefly: ignore # read-only
self.probs = clamp_probs(self.probs)
else:
assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number)
# pyrefly: ignore # read-only
(self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits
if is_scalar:
@ -230,6 +234,7 @@ class ContinuousBernoulli(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (self.logits,)

# pyrefly: ignore # bad-override
def _log_normalizer(self, x):
"""computes the log normalizing constant as a function of the natural parameter"""
out_unst_reg = torch.max(
@ -22,6 +22,7 @@ def _Dirichlet_backward(x, concentration, grad_output):

class _Dirichlet(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, concentration):
x = torch._sample_dirichlet(concentration)
ctx.save_for_backward(x, concentration)
@ -29,6 +30,7 @@ class _Dirichlet(Function):

@staticmethod
@once_differentiable
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
x, concentration = ctx.saved_tensors
return _Dirichlet_backward(x, concentration, grad_output)
@ -50,6 +52,7 @@ class Dirichlet(ExponentialFamily):
(often referred to as alpha)
"""

# pyrefly: ignore # bad-override
arg_constraints = {
"concentration": constraints.independent(constraints.positive, 1)
}
@ -130,5 +133,6 @@ class Dirichlet(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (self.concentration,)

# pyrefly: ignore # bad-override
def _log_normalizer(self, x):
return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))
@ -27,6 +27,7 @@ class Exponential(ExponentialFamily):
rate (float or Tensor): rate = 1 / scale of the distribution
"""

# pyrefly: ignore # bad-override
arg_constraints = {"rate": constraints.positive}
support = constraints.nonnegative
has_rsample = True
@ -89,5 +90,6 @@ class Exponential(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (-self.rate,)

# pyrefly: ignore # bad-override
def _log_normalizer(self, x):
return -torch.log(-x)
@ -29,6 +29,7 @@ class FisherSnedecor(Distribution):
df2 (float or Tensor): degrees of freedom parameter 2
"""

# pyrefly: ignore # bad-override
arg_constraints = {"df1": constraints.positive, "df2": constraints.positive}
support = constraints.positive
has_rsample = True
@ -34,6 +34,7 @@ class Gamma(ExponentialFamily):
(often referred to as beta), rate = 1 / scale
"""

# pyrefly: ignore # bad-override
arg_constraints = {
"concentration": constraints.positive,
"rate": constraints.positive,
@ -109,6 +110,7 @@ class Gamma(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor, Tensor]:
return (self.concentration - 1, -self.rate)

# pyrefly: ignore # bad-override
def _log_normalizer(self, x, y):
return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal())

@ -35,6 +35,7 @@ class GeneralizedPareto(Distribution):
concentration (float or Tensor): Concentration parameter of the distribution
"""

# pyrefly: ignore # bad-override
arg_constraints = {
"loc": constraints.real,
"scale": constraints.positive,
@ -130,6 +131,7 @@ class GeneralizedPareto(Distribution):
concentration = self.concentration
valid = concentration < 0.5
safe_conc = torch.where(valid, concentration, 0.25)
# pyrefly: ignore # unsupported-operation
result = self.scale**2 / ((1 - safe_conc) ** 2 * (1 - 2 * safe_conc))
return torch.where(valid, result, nan)

@ -142,6 +144,7 @@ class GeneralizedPareto(Distribution):
return self.loc

@constraints.dependent_property(is_discrete=False, event_dim=0)
# pyrefly: ignore # bad-override
def support(self):
lower = self.loc
upper = torch.where(
@ -44,6 +44,7 @@ class Geometric(Distribution):
logits (Number, Tensor): the log-odds of sampling `1`.
"""

# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.nonnegative_integer

@ -58,9 +59,11 @@ class Geometric(Distribution):
"Either `probs` or `logits` must be specified, but not both."
)
if probs is not None:
# pyrefly: ignore # read-only
(self.probs,) = broadcast_all(probs)
else:
assert logits is not None # helps mypy
# pyrefly: ignore # read-only
(self.logits,) = broadcast_all(logits)
probs_or_logits = probs if probs is not None else logits
if isinstance(probs_or_logits, _Number):
@ -32,6 +32,7 @@ class Gumbel(TransformedDistribution):
"""

arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
# pyrefly: ignore # bad-override
support = constraints.real

def __init__(
@ -32,8 +32,10 @@ class HalfCauchy(TransformedDistribution):
"""

arg_constraints = {"scale": constraints.positive}
# pyrefly: ignore # bad-override
support = constraints.nonnegative
has_rsample = True
# pyrefly: ignore # bad-override
base_dist: Cauchy

def __init__(
@ -32,8 +32,10 @@ class HalfNormal(TransformedDistribution):
"""

arg_constraints = {"scale": constraints.positive}
# pyrefly: ignore # bad-override
support = constraints.nonnegative
has_rsample = True
# pyrefly: ignore # bad-override
base_dist: Normal

def __init__(
@ -91,6 +91,7 @@ class Independent(Distribution, Generic[D]):
return self.base_dist.has_enumerate_support

@constraints.dependent_property
# pyrefly: ignore # bad-override
def support(self):
result = self.base_dist.support
if self.reinterpreted_batch_ndims:
@ -38,8 +38,10 @@ class InverseGamma(TransformedDistribution):
"concentration": constraints.positive,
"rate": constraints.positive,
}
# pyrefly: ignore # bad-override
support = constraints.positive
has_rsample = True
# pyrefly: ignore # bad-override
base_dist: Gamma

def __init__(
@ -44,6 +44,7 @@ class Kumaraswamy(TransformedDistribution):
"concentration1": constraints.positive,
"concentration0": constraints.positive,
}
# pyrefly: ignore # bad-override
support = constraints.unit_interval
has_rsample = True

@ -66,6 +67,7 @@ class Kumaraswamy(TransformedDistribution):
AffineTransform(loc=1.0, scale=-1.0),
PowerTransform(exponent=self.concentration1.reciprocal()),
]
# pyrefly: ignore # bad-argument-type
super().__init__(base_dist, transforms, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
@ -28,6 +28,7 @@ class Laplace(Distribution):
scale (float or Tensor): scale of the distribution
"""

# pyrefly: ignore # bad-override
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real
has_rsample = True
@ -60,6 +60,7 @@ class LKJCholesky(Distribution):
Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
"""

# pyrefly: ignore # bad-override
arg_constraints = {"concentration": constraints.positive}
support = constraints.corr_cholesky

@ -32,8 +32,10 @@ class LogNormal(TransformedDistribution):
"""

arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
# pyrefly: ignore # bad-override
support = constraints.positive
has_rsample = True
# pyrefly: ignore # bad-override
base_dist: Normal

def __init__(
@ -36,8 +36,10 @@ class LogisticNormal(TransformedDistribution):
"""

arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
# pyrefly: ignore # bad-override
support = constraints.simplex
has_rsample = True
# pyrefly: ignore # bad-override
base_dist: Independent[Normal]

def __init__(
@ -86,6 +86,7 @@ class LowRankMultivariateNormal(Distribution):
capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
"""

# pyrefly: ignore # bad-override
arg_constraints = {
"loc": constraints.real_vector,
"cov_factor": constraints.independent(constraints.real, 2),
@ -124,6 +124,7 @@ class MixtureSameFamily(Distribution):
return new

@constraints.dependent_property
# pyrefly: ignore # bad-override
def support(self):
return MixtureSameFamilyConstraint(self._component_distribution.support)

@ -50,6 +50,7 @@ class Multinomial(Distribution):
logits (Tensor): event log probabilities (unnormalized)
"""

# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
total_count: int

@ -92,6 +93,7 @@ class Multinomial(Distribution):
return self._categorical._new(*args, **kwargs)

@constraints.dependent_property(is_discrete=True, event_dim=1)
# pyrefly: ignore # bad-override
def support(self):
return constraints.multinomial(self.total_count)

@ -123,6 +123,7 @@ class MultivariateNormal(Distribution):
the corresponding lower triangular matrices using a Cholesky decomposition.
"""

# pyrefly: ignore # bad-override
arg_constraints = {
"loc": constraints.real_vector,
"covariance_matrix": constraints.positive_definite,
@ -156,6 +157,7 @@ class MultivariateNormal(Distribution):
"with optional leading batch dimensions"
)
batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
# pyrefly: ignore # read-only
self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
elif covariance_matrix is not None:
if covariance_matrix.dim() < 2:
@ -166,6 +168,7 @@ class MultivariateNormal(Distribution):
batch_shape = torch.broadcast_shapes(
covariance_matrix.shape[:-2], loc.shape[:-1]
)
# pyrefly: ignore # read-only
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
else:
assert precision_matrix is not None # helps mypy
@ -177,6 +180,7 @@ class MultivariateNormal(Distribution):
batch_shape = torch.broadcast_shapes(
precision_matrix.shape[:-2], loc.shape[:-1]
)
# pyrefly: ignore # read-only
self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
self.loc = loc.expand(batch_shape + (-1,))

@ -33,6 +33,7 @@ class NegativeBinomial(Distribution):
logits (Tensor): Event log-odds for probabilities of success
"""

# pyrefly: ignore # bad-override
arg_constraints = {
"total_count": constraints.greater_than_eq(0),
"probs": constraints.half_open_interval(0.0, 1.0),
@ -54,6 +55,7 @@ class NegativeBinomial(Distribution):
if probs is not None:
(
self.total_count,
# pyrefly: ignore # read-only
self.probs,
) = broadcast_all(total_count, probs)
self.total_count = self.total_count.type_as(self.probs)
@ -61,6 +63,7 @@ class NegativeBinomial(Distribution):
assert logits is not None # helps mypy
(
self.total_count,
# pyrefly: ignore # read-only
self.logits,
) = broadcast_all(total_count, logits)
self.total_count = self.total_count.type_as(self.logits)
@ -31,6 +31,7 @@ class Normal(ExponentialFamily):
(often referred to as sigma)
"""

# pyrefly: ignore # bad-override
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real
has_rsample = True
@ -88,6 +89,7 @@ class Normal(ExponentialFamily):
if self._validate_args:
self._validate_sample(value)
# compute the variance
# pyrefly: ignore # unsupported-operation
var = self.scale**2
log_scale = (
math.log(self.scale)
@ -117,5 +119,6 @@ class Normal(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor, Tensor]:
return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal())

# pyrefly: ignore # bad-override
def _log_normalizer(self, x, y):
return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y)
@ -42,6 +42,7 @@ class OneHotCategorical(Distribution):
logits (Tensor): event log probabilities (unnormalized)
"""

# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
support = constraints.one_hot
has_enumerate_support = True
@ -39,6 +39,7 @@ class Pareto(TransformedDistribution):
self.scale, self.alpha = broadcast_all(scale, alpha)
base_dist = Exponential(self.alpha, validate_args=validate_args)
transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
# pyrefly: ignore # bad-argument-type
super().__init__(base_dist, transforms, validate_args=validate_args)

def expand(
@ -32,6 +32,7 @@ class Poisson(ExponentialFamily):
rate (Number, Tensor): the rate parameter
"""

# pyrefly: ignore # bad-override
arg_constraints = {"rate": constraints.nonnegative}
support = constraints.nonnegative_integer

@ -82,5 +83,6 @@ class Poisson(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (torch.log(self.rate),)

# pyrefly: ignore # bad-override
def _log_normalizer(self, x):
return torch.exp(x)
@ -40,6 +40,7 @@ class LogitRelaxedBernoulli(Distribution):
(Jang et al., 2017)
"""

# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.real

@ -57,10 +58,12 @@ class LogitRelaxedBernoulli(Distribution):
)
if probs is not None:
is_scalar = isinstance(probs, _Number)
# pyrefly: ignore # read-only
(self.probs,) = broadcast_all(probs)
else:
assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number)
# pyrefly: ignore # read-only
(self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits
if is_scalar:
@ -138,8 +141,10 @@ class RelaxedBernoulli(TransformedDistribution):
"""

arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
# pyrefly: ignore # bad-override
support = constraints.unit_interval
has_rsample = True
# pyrefly: ignore # bad-override
base_dist: LogitRelaxedBernoulli

def __init__(
@ -38,6 +38,7 @@ class ExpRelaxedCategorical(Distribution):
(Jang et al., 2017)
"""

# pyrefly: ignore # bad-override
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
support = (
constraints.real_vector
@ -127,8 +128,10 @@ class RelaxedOneHotCategorical(TransformedDistribution):
"""

arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
# pyrefly: ignore # bad-override
support = constraints.simplex
has_rsample = True
# pyrefly: ignore # bad-override
base_dist: ExpRelaxedCategorical

def __init__(
@ -31,6 +31,7 @@ class StudentT(Distribution):
scale (float or Tensor): scale of the distribution
"""

# pyrefly: ignore # bad-override
arg_constraints = {
"df": constraints.positive,
"loc": constraints.real,
@ -123,6 +123,7 @@ class TransformedDistribution(Distribution):
return new

@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def support(self):
if not self.transforms:
return self.base_dist.support
@ -226,11 +226,13 @@ class _InverseTransform(Transform):
self._inv: Transform = transform # type: ignore[assignment]

@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def domain(self):
assert self._inv is not None
return self._inv.codomain

@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def codomain(self):
assert self._inv is not None
return self._inv.domain
@ -300,6 +302,7 @@ class ComposeTransform(Transform):
return self.parts == other.parts

@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def domain(self):
if not self.parts:
return constraints.real
@ -315,6 +318,7 @@ class ComposeTransform(Transform):
return domain

@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def codomain(self):
if not self.parts:
return constraints.real
@ -434,12 +438,14 @@ class IndependentTransform(Transform):
)

@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def domain(self):
return constraints.independent(
self.base_transform.domain, self.reinterpreted_batch_ndims
)

@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def codomain(self):
return constraints.independent(
self.base_transform.codomain, self.reinterpreted_batch_ndims
@ -507,10 +513,12 @@ class ReshapeTransform(Transform):
super().__init__(cache_size=cache_size)

@constraints.dependent_property
# pyrefly: ignore # bad-override
def domain(self):
return constraints.independent(constraints.real, len(self.in_shape))

@constraints.dependent_property
# pyrefly: ignore # bad-override
def codomain(self):
return constraints.independent(constraints.real, len(self.out_shape))

@ -764,12 +772,14 @@ class AffineTransform(Transform):
return self._event_dim

@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def domain(self):
if self.event_dim == 0:
return constraints.real
return constraints.independent(constraints.real, self.event_dim)

@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
def codomain(self):
if self.event_dim == 0:
return constraints.real
@ -867,6 +877,7 @@ class CorrCholeskyTransform(Transform):
# apply stick-breaking on the squared values
# Note that y = sign(r) * sqrt(z * z1m_cumprod)
# = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
# pyrefly: ignore # unsupported-operation
z = r**2
z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
# Diagonal elements must be 1.
@ -1155,12 +1166,14 @@ class CatTransform(Transform):
return all(t.bijective for t in self.transforms)

@constraints.dependent_property
# pyrefly: ignore # bad-override
def domain(self):
return constraints.cat(
[t.domain for t in self.transforms], self.dim, self.lengths
)

@constraints.dependent_property
# pyrefly: ignore # bad-override
def codomain(self):
return constraints.cat(
[t.codomain for t in self.transforms], self.dim, self.lengths
@ -1233,10 +1246,12 @@ class StackTransform(Transform):
return all(t.bijective for t in self.transforms)

@constraints.dependent_property
# pyrefly: ignore # bad-override
def domain(self):
return constraints.stack([t.domain for t in self.transforms], self.dim)

@constraints.dependent_property
# pyrefly: ignore # bad-override
def codomain(self):
return constraints.stack([t.codomain for t in self.transforms], self.dim)

@ -79,6 +79,7 @@ class Uniform(Distribution):
return new

@constraints.dependent_property(is_discrete=False, event_dim=0)
# pyrefly: ignore # bad-override
def support(self):
return constraints.interval(self.low, self.high)

@ -92,6 +92,7 @@ def _log_modified_bessel_fn(x, order=0):
@torch.jit.script_if_tracing
def _rejection_sample(loc, concentration, proposal_r, x):
done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
# pyrefly: ignore # bad-assignment
while not done.all():
u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
u1, u2, u3 = u.unbind()
@ -100,6 +101,7 @@ def _rejection_sample(loc, concentration, proposal_r, x):
c = concentration * (proposal_r - f)
accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
if accept.any():
# pyrefly: ignore # no-matching-overload
x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
done = done | accept
return (x + math.pi + loc) % (2 * math.pi) - math.pi
@ -123,6 +125,7 @@ class VonMises(Distribution):
:param torch.Tensor concentration: concentration parameter
"""

# pyrefly: ignore # bad-override
arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
support = constraints.real
has_rsample = False
@ -160,8 +163,10 @@ class VonMises(Distribution):
@lazy_property
def _proposal_r(self) -> Tensor:
kappa = self._concentration
# pyrefly: ignore # unsupported-operation
tau = 1 + (1 + 4 * kappa**2).sqrt()
rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
# pyrefly: ignore # unsupported-operation
_proposal_r = (1 + rho**2) / (2 * rho)
# second order Taylor expansion around 0 for small kappa
_proposal_r_taylor = 1 / kappa + kappa
@ -35,6 +35,7 @@ class Weibull(TransformedDistribution):
"scale": constraints.positive,
"concentration": constraints.positive,
}
# pyrefly: ignore # bad-override
support = constraints.positive

def __init__(
@ -52,6 +53,7 @@ class Weibull(TransformedDistribution):
PowerTransform(exponent=self.concentration_reciprocal),
AffineTransform(loc=0, scale=self.scale),
]
# pyrefly: ignore # bad-argument-type
super().__init__(base_dist, transforms, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
@ -116,10 +116,13 @@ class Wishart(ExponentialFamily):
)

if scale_tril is not None:
# pyrefly: ignore # read-only
self.scale_tril = param.expand(batch_shape + (-1, -1))
elif covariance_matrix is not None:
# pyrefly: ignore # read-only
self.covariance_matrix = param.expand(batch_shape + (-1, -1))
elif precision_matrix is not None:
# pyrefly: ignore # read-only
self.precision_matrix = param.expand(batch_shape + (-1, -1))

if self.df.lt(event_shape[-1]).any():
@ -335,6 +338,7 @@ class Wishart(ExponentialFamily):
p = self._event_shape[-1] # has singleton shape
return -self.precision_matrix / 2, (nu - p - 1) / 2

# pyrefly: ignore # bad-override
def _log_normalizer(self, x, y):
p = self._event_shape[-1]
return (y + (p + 1) / 2) * (
@@ -24,11 +24,15 @@ class TensorProperties:

         if not self.is_fake:
             # only get the storage pointer for real tensors
+            # pyrefly: ignore # bad-assignment
             self.storage_ptr = tensor.untyped_storage().data_ptr()
             if self.is_contiguous:
                 # only get storage size and start/end pointers for contiguous tensors
+                # pyrefly: ignore # bad-assignment
                 self.storage_size = tensor.untyped_storage().nbytes()
+                # pyrefly: ignore # bad-assignment
                 self.start = tensor.data_ptr()
+                # pyrefly: ignore # bad-assignment
                 self.end = _end_ptr(tensor)

         # info to recover tensor
@@ -65,6 +65,7 @@ class GraphPickler(pickle.Pickler):
         self._meta_tensor_describer = MetaTensorDescriber(copy_data=False)

     @override
+    # pyrefly: ignore # bad-override
     def reducer_override(
         self, obj: object
     ) -> tuple[Callable[..., Any], tuple[Any, ...]]:
@@ -201,6 +202,7 @@ class _SymNodePickleData:
     ]:
         args = (cls(obj.node), pickler._unpickle_state)
         if isinstance(obj, torch.SymInt):
+            # pyrefly: ignore # bad-return
             return _SymNodePickleData.unpickle_sym_int, args
         else:
             raise NotImplementedError(f"Unhandled SymNode type {type(obj)}")
@@ -277,6 +279,7 @@ class _TensorPickleData:
         return FakeTensor(
             unpickle_state.fake_mode,
             make_meta_t(),
+            # pyrefly: ignore # bad-argument-type
             device,
         )

@@ -603,6 +603,7 @@ class Tracer(TracerBase):
                 in inspect.signature(self.create_proxy).parameters
             ):
                 kwargs["proxy_factory_fn"] = (
+                    # pyrefly: ignore # unsupported-operation
                     None
                     if not self.param_shapes_constant
                     else lambda node: ParameterProxy(
@@ -657,7 +657,10 @@ class Partitioner:
             # Mark bfs level
             get_bfs_level_partition(self.partitions)
             find_combination, partitions = find_partition_to_combine_based_on_size(
-                sorted_partitions, available_mem_bytes, partitions
+                sorted_partitions,
+                available_mem_bytes,
+                # pyrefly: ignore # bad-argument-type
+                partitions,
             )
             return

@@ -702,6 +705,7 @@ class Partitioner:
             non_embedding_partitions.append(partition)
             if new_partition:
                 partition = self.create_partition()
+                # pyrefly: ignore # missing-attribute
                 partition.left_mem_bytes = available_mem_bytes
                 return partition
             return None
@@ -997,6 +1001,7 @@ class Partitioner:
                     node, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec
                 )
                 if cost < min_cost:
+                    # pyrefly: ignore # bad-assignment
                     node_pair = [node, n1]
                     min_cost = cost
         return cost, node_pair # type: ignore[possibly-undefined]
@@ -30,6 +30,7 @@ def split_result_tensors(
     else:
         splits = [x.shape[0] for x in inputs]

+    # pyrefly: ignore # bad-argument-type
     return torch.split(result, splits)


@@ -171,7 +171,14 @@ class MetaTracer(torch.fx.Tracer):
         proxy_factory_fn=None,
     ):
         rv = super().create_proxy(
-            kind, target, args, kwargs, name, type_expr, proxy_factory_fn
+            kind,
+            target,
+            args,
+            kwargs,
+            name,
+            type_expr,
+            # pyrefly: ignore # bad-argument-type
+            proxy_factory_fn,
         )

         if kind == "placeholder" and target in self.meta_args:
@@ -193,6 +200,7 @@ class MetaTracer(torch.fx.Tracer):

         if kind == "call_function":
             meta_target = manual_meta_overrides.get(target, target)
+            # pyrefly: ignore # not-callable
             meta_out = meta_target(*args_metas, **kwargs_metas)
         elif kind == "call_method":
             meta_target = getattr(args_metas[0], target) # type: ignore[index]
@@ -528,9 +528,11 @@ def view_inference_rule(n: Node, symbols, constraints, counter):
         if t == -1:
             var, counter = gen_dvar(counter)
             t2_type.append(var)
+            # pyrefly: ignore # bad-argument-type
             num_constraints.append(BinConstraintD(var, Dyn, op_neq))

         else:
+            # pyrefly: ignore # bad-argument-type
             num_constraints.append(BinConstraintD(t, Dyn, op_neq))
             t2_type.append(t) # type: ignore[arg-type]

@@ -1475,6 +1477,7 @@ class ConstraintGenerator:

         all_constraints = []

+        # pyrefly: ignore # missing-attribute
         for n in graph.nodes:
             (constraints, counter) = self.generate_constraints_node(n, counter)
             all_constraints += constraints
@@ -193,6 +193,7 @@ def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]):
             assert isinstance(node.target, str)
             cur_module = modules[node.target]
             if type(cur_module) in mkldnn_map:
+                # pyrefly: ignore # index-error
                 new_module = mkldnn_map[type(cur_module)](cur_module, torch.float)
                 assert isinstance(new_module, nn.Module)
                 old_modules[new_module] = copy.deepcopy(cur_module)
@@ -263,7 +264,10 @@ def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
         )

         reset_modules(
-            submodule.graph.nodes, dict(submodule.named_modules()), old_modules
+            submodule.graph.nodes,
+            dict(submodule.named_modules()),
+            # pyrefly: ignore # bad-argument-type
+            old_modules,
         )
         no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
         return mkl_time < no_mkl_time
@@ -124,6 +124,7 @@ pytree.register_pytree_node(
     torch.Size,
     lambda xs: (list(xs), None),
     lambda xs, _: tuple(xs),
+    # pyrefly: ignore # bad-argument-type
     flatten_with_keys_fn=lambda xs: (
         [(pytree.SequenceKey(i), x) for i, x in enumerate(xs)],
         None,
@@ -306,6 +307,7 @@ def set_proxy_slot( # type: ignore[no-redef]

 def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool:
     assert isinstance(obj, (Tensor, SymNode)), type(obj)
+    # pyrefly: ignore # no-matching-overload
     return bool(get_proxy_slot(obj, tracer, False, lambda _: True))


@@ -402,6 +404,7 @@ def get_proxy_slot(
         assert isinstance(obj, py_sym_types), type(obj)
         tracker = tracer.symnode_tracker

+    # pyrefly: ignore # unsupported-operation
     if obj not in tracker:
         # Last ditch
         if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker:
@@ -413,6 +416,7 @@ def get_proxy_slot(
            )
        return default
    else:
+        # pyrefly: ignore # index-error
        value = tracker[obj]
        res = transform(value)
        return res
@@ -788,6 +792,7 @@ def fetch_object_proxy(
 def fetch_object_proxy(
     tracer: _ProxyTracer, t: Union[Tensor, _AnyScriptObjectType, PySymType]
 ) -> object:
+    # pyrefly: ignore # no-matching-overload
     return get_proxy_slot(t, tracer, t)


@@ -836,6 +841,7 @@ def _fetch_proxies_and_all_constant_flag(
     """
     f_flat_args_kwargs = [
         (
+            # pyrefly: ignore # no-matching-overload
             fetch_object_proxy(tracer, x)
             if isinstance(x, (Tensor, _AnyScriptObject))
             else x
@@ -1410,6 +1416,7 @@ class TorchFunctionMetadataMode(TorchFunctionMode):
         kwargs: Optional[dict[str, object]] = None,
     ) -> object:
         kwargs = kwargs or {}
+        # pyrefly: ignore # bad-assignment
         self.tracer.torch_fn_metadata = func
         self.tracer.torch_fn_counts[func] = self.tracer.torch_fn_counts.get(func, 0) + 1
         return func(*args, **kwargs)
@@ -1459,6 +1466,7 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
         # For autocast, the python APIs run so we don't have to run them again
         # here.
         if func is torch._C._set_grad_enabled:
+            # pyrefly: ignore # bad-argument-type
             func(*args, **kwargs)
             return node

@@ -1672,6 +1680,7 @@ class DecompositionInterpreter(fx.Interpreter):
         self.decomposition_table = decomposition_table or {}
         self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")

+    # pyrefly: ignore # bad-override
     def placeholder(
         self,
         target: str, # type: ignore[override]
@@ -1684,6 +1693,7 @@ class DecompositionInterpreter(fx.Interpreter):
         # TODO handle case where the first character of target is '*'
         return out

+    # pyrefly: ignore # bad-override
     def get_attr(
         self,
         target: str, # type: ignore[override]
@@ -1697,6 +1707,7 @@ class DecompositionInterpreter(fx.Interpreter):

     # call_function, call_method, call_module get traced automatically by the outer mode.

+    # pyrefly: ignore # bad-override
     def output(
         self,
         target: str, # type: ignore[override]
@@ -1782,14 +1793,17 @@ class _ModuleStackTracer(PythonKeyTracer):
         self.enable_attr_proxy = False
         self.submodule_paths = {}
         for name, m in self.scope_root.named_modules(remove_duplicate=False):
+            # pyrefly: ignore # unsupported-operation
             if m in self.submodule_paths:
                 log.info(
                     "Shared module found between %s and %s, AttrProxy is enabled.",
+                    # pyrefly: ignore # unsupported-operation
                     self.submodule_paths[m],
                     name,
                 )
                 self.enable_attr_proxy = True
             else:
+                # pyrefly: ignore # unsupported-operation
                 self.submodule_paths[m] = name

         self.proxy_paths: WeakKeyDictionary[_AttrProxy, str] = WeakKeyDictionary()
@@ -1815,6 +1829,7 @@ class _ModuleStackTracer(PythonKeyTracer):
             # Class is modified to be a subclass of torch.nn.Module
             # Warning: We blow away our own attributes here to mimic the base class
             # - so don't expect `self.x` to do anything useful.
+            # pyrefly: ignore # no-matching-overload
             self.__class__ = type(
                 base.__class__.__name__,
                 (self.__class__, base.__class__),
@@ -1837,6 +1852,7 @@ class _ModuleStackTracer(PythonKeyTracer):
             if not isinstance(attr_val, Module):
                 return attr_val

+            # pyrefly: ignore # index-error
             return AttrProxy(attr_val, tracer.proxy_paths[self] + "." + name)

         def get_base(self) -> Module:
@@ -1849,10 +1865,12 @@ class _ModuleStackTracer(PythonKeyTracer):
                 res = torch.nn.Sequential(
                     OrderedDict(list(self._modules.items())[idx])
                 )
+                # pyrefly: ignore # index-error
                 return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}")
             elif isinstance(self, torch.nn.ModuleList):
                 # Copied from nn/modules/container.py
                 res = torch.nn.ModuleList(list(self._modules.values())[idx])
+                # pyrefly: ignore # index-error
                 return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}")

             return super().__getitem__(idx) # type: ignore[misc]
@@ -839,6 +839,7 @@ def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr:
         factor = functools.reduce(math.gcd, map(integer_coefficient, atoms))
         if factor == 1:
             return expr
+        # pyrefly: ignore # bad-argument-type
         atoms = [div_by_factor(x, factor) for x in atoms]
         return _sympy_from_args(
             sympy.Add, atoms, sort=True, is_commutative=expr.is_commutative
@@ -1234,6 +1235,7 @@ def _free_unbacked_symbols_with_path(
                 else _symint_wrap(coeff)
             )
             # TODO: DivideByKey needs to test divisibility at runtime!
+            # pyrefly: ignore # unsupported-operation
             r[unbacked] = path + (DivideByKey(divisor),)
             if real is not None:
                 assert isinstance(real, int)
@@ -1256,6 +1258,7 @@ def _free_unbacked_symbols_with_path(
         and s.rhs == 1
         and s.lhs in pending
     ):
+        # pyrefly: ignore # unsupported-operation
         r[s.lhs] = path + (ConvertIntKey(),)
         if real is not None:
             assert type(real) is bool
@@ -2172,6 +2175,7 @@ class SubclassSymbolicContext(StatefulSymbolicContext):
     def __post_init__(self) -> None:
         super().__post_init__()
         if self.inner_contexts is None:
+            # pyrefly: ignore # bad-assignment
             self.inner_contexts = {}


@@ -2260,9 +2264,12 @@ def _fast_expand(expr: _SympyT) -> _SympyT:
     # only re-create the objects if any of the args changed to avoid expensive
     # checks when re-creating objects.
     new_args = [_fast_expand(arg) for arg in expr.args] # type: ignore[arg-type]
+    # pyrefly: ignore # missing-attribute
     if any(arg is not new_arg for arg, new_arg in zip(expr.args, new_args)):
+        # pyrefly: ignore # missing-attribute
         return _fast_expand(expr.func(*new_args))

+    # pyrefly: ignore # missing-attribute
     if expr.is_Pow:
         base: sympy.Expr
         exp: sympy.Expr
@@ -2272,9 +2279,11 @@ def _fast_expand(expr: _SympyT) -> _SympyT:
             return sympy.expand_multinomial(expr, deep=False)
         elif exp < 0:
             return S.One / sympy.expand_multinomial(S.One / expr, deep=False)
+    # pyrefly: ignore # missing-attribute
     elif expr.is_Mul:
         num: list[sympy.Expr] = []
         den: list[sympy.Expr] = []
+        # pyrefly: ignore # missing-attribute
         for arg in expr.args:
             if arg.is_Pow and arg.args[1] == -1:
                 den.append(S.One / arg) # type: ignore[operator, arg-type]
@@ -2396,6 +2405,7 @@ def _maybe_evaluate_static_worker(

     # TODO: remove this try catch (esp for unbacked_only)
     try:
+        # pyrefly: ignore # missing-attribute
         new_expr = expr.xreplace(new_shape_env)
     except RecursionError:
         log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env)
@@ -2933,13 +2943,19 @@ class DimConstraints:
             # is_integer tests though haha
             return (base - mod_reduced) / divisor

+        # pyrefly: ignore # missing-attribute
         if expr.has(Mod):
+            # pyrefly: ignore # missing-attribute
             expr = expr.replace(Mod, mod_handler)
         # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative
         # arguments should be OK.
+        # pyrefly: ignore # missing-attribute
         if expr.has(PythonMod):
+            # pyrefly: ignore # missing-attribute
             expr = expr.replace(PythonMod, mod_handler)
+        # pyrefly: ignore # missing-attribute
         if expr.has(FloorDiv):
+            # pyrefly: ignore # missing-attribute
             expr = expr.replace(FloorDiv, floor_div_handler)
         return expr

@@ -5057,6 +5073,7 @@ class ShapeEnv:

             if duck:
                 # Make sure to reuse this symbol for subsequent duck shaping
+                # pyrefly: ignore # unsupported-operation
                 self.val_to_var[val] = sympy_expr

             if isinstance(val, int):
@@ -5288,15 +5305,19 @@ class ShapeEnv:

         # Expand optional inputs, or verify invariants are upheld
         if input_contexts is None:
+            # pyrefly: ignore # bad-assignment
             input_contexts = [
+                # pyrefly: ignore # bad-argument-type
                 _create_no_constraints_context(t) if isinstance(t, Tensorlike) else None
                 for t in placeholders
             ]
         else:
             assert len(input_contexts) == len(placeholders)
+            # pyrefly: ignore # bad-assignment
             for i, (t, context) in enumerate(zip(placeholders, input_contexts)):
                 if isinstance(t, Tensorlike):
                     if context is None:
+                        # pyrefly: ignore # bad-argument-type
                         input_contexts[i] = _create_no_constraints_context(t)
                 else:
                     assert isinstance(t, (SymInt, int, SymFloat, float))
@@ -5582,6 +5603,7 @@ class ShapeEnv:
                 s = sympy.Float(val)
             input_guards.append((source, s))

+        # pyrefly: ignore # no-matching-overload
         for t, source, context in zip(placeholders, sources, input_contexts):
             if isinstance(source, str):
                 from torch._dynamo.source import LocalSource
@@ -5641,11 +5663,13 @@ class ShapeEnv:
                     )
                     track_symint(property_source, ss, constraint_size[i])
             else:
+                # pyrefly: ignore # missing-attribute
                 for i, ss in enumerate(curr_t.size()):
                     property_source = TensorPropertySource(
                         src, TensorProperty.SIZE, i
                     )
                     track_symint(property_source, ss, constraint_size[i])
+                # pyrefly: ignore # missing-attribute
                 for i, ss in enumerate(curr_t.stride()):
                     property_source = TensorPropertySource(
                         src, TensorProperty.STRIDE, i
@@ -5653,6 +5677,7 @@ class ShapeEnv:
                     track_symint(property_source, ss, constraint_stride[i])
                 track_symint(
                     TensorPropertySource(src, TensorProperty.STORAGE_OFFSET),
+                    # pyrefly: ignore # missing-attribute
                     curr_t.storage_offset(),
                 )

@@ -5698,6 +5723,7 @@ class ShapeEnv:
                 continue

             if is_dim(source):
+                # pyrefly: ignore # missing-attribute
                 self.dim_constraints.add_equality(source, expr)

         for exprs, printer, lang in zip(all_exprs, printers, langs):
@@ -5851,6 +5877,7 @@ class ShapeEnv:
                 continue
             expr = self.simplify(ra.expr)

+            # pyrefly: ignore # missing-attribute
             self.dim_constraints.add(expr)

         # 3. Every symbol must be within its value range (this handles 0/1
@@ -5867,6 +5894,7 @@ class ShapeEnv:
             verbose_expr = ""
             if r.lower not in (-sympy.oo, -int_oo):
                 if any(is_dim(source) for source in sources):
+                    # pyrefly: ignore # missing-attribute
                     self.dim_constraints.add(sympy.Ge(symbol, r.lower))
                 # Only print lower bound in simplified mode if it is not the
                 # default
@@ -5875,6 +5903,7 @@ class ShapeEnv:
                     verbose_expr = f"{r.lower} <= {rf} # {vr_sloc.lower}"
             if r.upper not in (sympy.oo, int_oo):
                 if any(is_dim(source) for source in sources):
+                    # pyrefly: ignore # missing-attribute
                     self.dim_constraints.add(sympy.Le(symbol, r.upper))
                 # nontrivial upper bound is always interesting
                 bounds.append(sympy.Le(symbol, r.upper, evaluate=False))
@@ -5943,6 +5972,7 @@ class ShapeEnv:
             else:
                 str_msg = f" - {msg_cb()}"
             error_msgs.append(str_msg)
+            # pyrefly: ignore # bad-argument-type
             debug_names.add(debug_name)
         if len(error_msgs) > 0:
             debug_names_str = ", ".join(sorted(debug_names))
@@ -6076,6 +6106,7 @@ class ShapeEnv:
         Get a list of guards, but pruned so it only provides guards that
         reference symints from the passed in input
         """
+        # pyrefly: ignore # bad-assignment
         symints = {
             s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol)
         }
@@ -6121,6 +6152,7 @@ class ShapeEnv:
             else:
                 bindings[-s] = -arg

+        # pyrefly: ignore # bad-assignment
         for t, arg in zip(placeholders, args):
             if t is None:
                 continue
@@ -6338,6 +6370,7 @@ class ShapeEnv:
         Apply symbol replacements to any symbols in the given expression.
         """
         replacements = {}
+        # pyrefly: ignore # missing-attribute
         for s in expr.free_symbols:
             r = self._find(s)

@@ -6347,6 +6380,7 @@ class ShapeEnv:
             if not r.is_Symbol or r != s:
                 replacements[s] = r
         if replacements:
+            # pyrefly: ignore # missing-attribute
             return safe_expand(expr.xreplace(replacements))
         else:
             return expr
@@ -7121,6 +7155,7 @@ class ShapeEnv:
         instructions = list(dis.Bytecode(frame.f_code))
         co_lines, offset = inspect.getsourcelines(frame.f_code)
         start, end, cur = None, None, None
+        # pyrefly: ignore # bad-assignment
         for i, instr in enumerate(instructions):
             if instr.starts_line is not None:
                 cur = instr.starts_line
@@ -8000,6 +8035,7 @@ def _suggest_fixes_for_data_dependent_error_non_strict(
         if isinstance(leaf, torch.SymInt):
             src_map[str(leaf.node.expr)].append(name)
         elif isinstance(leaf, torch.Tensor):
+            # pyrefly: ignore # bad-assignment
             for i, dim in enumerate(leaf.shape):
                 if isinstance(dim, torch.SymInt):
                     src_map[str(dim.node.expr)].append(f"{name}.shape[{i}]")
Some files were not shown because too many files have changed in this diff.
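Most hunks in this diff apply the same mechanical pattern: a `# pyrefly: ignore # <error-code>` comment is inserted on its own line directly above the statement pyrefly flags, at the same indentation as that statement. Below is a minimal sketch of that pattern on a hypothetical module; the `Pair` class and `combine` function are illustrative only and are not part of this commit.

# Hypothetical example of the suppression style added throughout this diff.
from dataclasses import dataclass


@dataclass
class Pair:
    left: int
    right: int


def combine(p: Pair, delta) -> Pair:
    # `delta` is unannotated, so a type checker may flag the addition below.
    # The suppression comment is placed the same way the hunks above place it:
    # on the line immediately before the flagged statement.
    # pyrefly: ignore # unsupported-operation
    return Pair(p.left + delta, p.right)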