[Pyrefly][Refactor] Replace dict() calls with literal dict syntax for improved readability (#157735)

I spotted 31 places that construct dictionaries from literal key/value pairs via `dict(...)` calls.

This PR refactors dictionary construction by replacing `dict(...)` calls with literal `{...}` syntax where applicable.
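For illustration, here is a minimal, self-contained sketch of the pattern applied throughout the PR (the standalone `alpha`/`beta` variables below are made up for the example; the real call sites are shown in the diff):

```python
# Old style: keys written as keyword arguments to the dict() constructor.
alpha, beta = 1.0, 0.0
options_old = dict(alpha=alpha, beta=beta)

# New style: the same mapping as a dict literal with explicit string keys.
options_new = {"alpha": alpha, "beta": beta}

# Behavior is identical; only the construction syntax changes.
assert options_old == options_new
```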

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157735
Approved by: https://github.com/ezyang, https://github.com/Skylion007
Zeina Migeed
2025-07-08 18:10:29 +00:00
committed by PyTorch MergeBot
parent 0f31445139
commit 4f5be56612
21 changed files with 348 additions and 334 deletions


@@ -211,12 +211,12 @@ inline void {{kernel_name}}(
) -> str:
buffer_size = " * ".join(map(str, size_args))
return KernelTemplate._template_from_string(self.ALLOCATE_WEIGHT_BUFFER).render(
dict(
buffer_name=buffer_name,
buffer_dtype=buffer_dtype,
buffer_size=buffer_size,
is_msvc_compiler=cpp_builder.is_msvc_cl(),
)
{
"buffer_name": buffer_name,
"buffer_dtype": buffer_dtype,
"buffer_size": buffer_size,
"is_msvc_compiler": cpp_builder.is_msvc_cl(),
}
)
def is_woq_int4(self):


@@ -1179,27 +1179,27 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
instance_definition, instance_type = self._define_gemm_instance(op, evt_name)
options = dict(
alpha=self.alpha,
beta=self.beta,
X=X,
W=W,
Y=Y,
kernel_call_signature=kernel_call_signature,
Bias=Bias,
epilogue_template=epilogue_template,
argument_template=argument_template,
should_swap_xw=should_swap_xw,
template=self,
kernel=kernel,
instance_definition=instance_definition,
instance_type=instance_type,
input_reorder=self.input_reorder,
epilogue_args=evt_args,
test_call_statement=test_call_statement,
op_conf_name=op.configuration_name(),
epilogue_visitor_tree=evt_code,
)
options = {
"alpha": self.alpha,
"beta": self.beta,
"X": X,
"W": W,
"Y": Y,
"kernel_call_signature": kernel_call_signature,
"Bias": Bias,
"epilogue_template": epilogue_template,
"argument_template": argument_template,
"should_swap_xw": should_swap_xw,
"template": self,
"kernel": kernel,
"instance_definition": instance_definition,
"instance_type": instance_type,
"input_reorder": self.input_reorder,
"epilogue_args": evt_args,
"test_call_statement": test_call_statement,
"op_conf_name": op.configuration_name(),
"epilogue_visitor_tree": evt_code,
}
options.update(dict(zip(extra_names, extra_inputs)))
res = self._template_from_string(self._get_template()).render(**options)
if inductor_cuda_config.generate_test_runner and not is_dynamic(X, W, Y, Bias):
@@ -1587,19 +1587,19 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped
before the function call.
"""
options = dict(
alpha=alpha,
beta=beta,
X=X,
W=W,
Y=Y,
Bias=Bias,
template=self,
kernel=kernel,
M="M",
N="N",
epilogue_args=epilogue_args,
)
options = {
"alpha": alpha,
"beta": beta,
"X": X,
"W": W,
"Y": Y,
"Bias": Bias,
"template": self,
"kernel": kernel,
"M": "M",
"N": "N",
"epilogue_args": epilogue_args,
}
assert epilogue_template is not None
if should_swap_xw:
@@ -1878,21 +1878,21 @@ class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate):
tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped
before the function call.
"""
options = dict(
instance_type=instance_type,
alpha=alpha,
beta=beta,
X=X,
W=W,
Y=Y,
Bias=Bias,
Meta=Meta,
template=self,
kernel=kernel,
M="M",
N="N",
epilogue_args=epilogue_args,
)
options = {
"instance_type": instance_type,
"alpha": alpha,
"beta": beta,
"X": X,
"W": W,
"Y": Y,
"Bias": Bias,
"Meta": Meta,
"template": self,
"kernel": kernel,
"M": "M",
"N": "N",
"epilogue_args": epilogue_args,
}
if epilogue_template is None:
arguments = self._template_from_string(argument_template).render(


@@ -86,12 +86,12 @@ def tma_options() -> dict[str, Any]:
def persistent_mm_options(mat1, mat2):
res = dict(
A_ROW_MAJOR=not mat1.layout.is_transposed(),
B_ROW_MAJOR=not mat2.layout.is_transposed(),
NUM_SMS=get_num_sms(),
TMA_SIZE=TMA_DESCRIPTOR_SIZE,
)
res = {
"A_ROW_MAJOR": not mat1.layout.is_transposed(),
"B_ROW_MAJOR": not mat2.layout.is_transposed(),
"NUM_SMS": get_num_sms(),
"TMA_SIZE": TMA_DESCRIPTOR_SIZE,
}
res.update(tma_options())
return res


@@ -147,12 +147,12 @@ def grouped_gemm_lowering(
choices: list[ChoiceCaller] = []
*_, layout, x, _ = mm_args(x, permute(w[0], [1, 0]), layout=layout)
kwargs = dict(
has_bias=[bias is not None for bias in b],
trans_w=True,
epilogue_creator=None,
act_mapping=dict.fromkeys(range(num_gemm), x),
)
kwargs = {
"has_bias": [bias is not None for bias in b],
"trans_w": True,
"epilogue_creator": None,
"act_mapping": dict.fromkeys(range(num_gemm), x),
}
input_nodes = [x, *w]
input_nodes.extend([bias for bias in b if bias is not None])
@@ -353,11 +353,13 @@ def register_onednn_fusion_ops():
buf, attr, scalars=scalars, algorithm=algorithm
)
kwargs = dict(
has_bias=b is not None,
trans_w=True,
epilogue_creator=None if attr == "none" else epilogue_creator,
)
kwargs = {
"has_bias": b is not None,
"trans_w": True,
"epilogue_creator": (
None if attr == "none" else epilogue_creator
),
}
if b is not None:
kwargs["input_indices"] = [2, 0, 1] # type: ignore[assignment]
CppGemmTemplate.add_choices(
@@ -416,11 +418,12 @@ def register_onednn_fusion_ops():
def epilogue_creator(buf):
return create_epilogue_with_attr(buf, attr, other=y)
kwargs = dict(
has_bias=b is not None,
trans_w=True,
epilogue_creator=epilogue_creator,
)
kwargs = {
"has_bias": b is not None,
"trans_w": True,
"epilogue_creator": epilogue_creator,
}
kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1]
CppGemmTemplate.add_choices(
choices,


@@ -166,123 +166,123 @@ Example::
""",
)
args_and_kwargs = dict(
args_and_kwargs = {
# argument name sufficies separated by double underscore will
# be removed in the final documentation string.
sum=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
prod=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
cumsum=(("dim__as_int",), ("dtype=None", "mask=None")),
cumprod=(("dim__as_int",), ("dtype=None", "mask=None")),
amin=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
amax=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
argmin=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
argmax=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
mean=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
median=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
norm=(
"sum": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
"prod": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
"cumsum": (("dim__as_int",), ("dtype=None", "mask=None")),
"cumprod": (("dim__as_int",), ("dtype=None", "mask=None")),
"amin": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
"amax": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
"argmin": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
"argmax": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
"mean": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
"median": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
"norm": (
(
"ord",
"dim",
),
("keepdim=False", "dtype=None", "mask=None"),
),
var=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
std=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
logsumexp=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
softmax=(("dim__as_int",), ("dtype=None", "mask=None")),
log_softmax=(("dim__as_int",), ("dtype=None", "mask=None")),
softmin=(("dim__as_int",), ("dtype=None", "mask=None")),
normalize=(
"var": (("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
"std": (("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
"logsumexp": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
"softmax": (("dim__as_int",), ("dtype=None", "mask=None")),
"log_softmax": (("dim__as_int",), ("dtype=None", "mask=None")),
"softmin": (("dim__as_int",), ("dtype=None", "mask=None")),
"normalize": (
(
"ord__required",
"dim__as_int",
),
("eps=1e-12", "dtype=None", "mask=None"),
),
)
}
argument_declarations = dict(
dim="""\
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
Default: None that is equivalent to ``tuple(range(input.ndim))``.""",
dim__as_int="""\
dim (int): the dimension along which {operation name} is computed.""",
ord="""\
ord (int, float, optional): the order of vector norm. Default: 2.
See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
ord__required="""\
ord (int, float): the order of vector norm. Default: 2.
See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
unbiased="""\
unbiased (bool): when True, use Bessel's correction, otherwise, compute
the uncorrected sample variance.""",
eps="""\
eps (float, optional): small value to avoid division by zero. Default: {default}.""",
keepdim="""\
keepdim (bool, optional): whether the output tensor has
:attr:`dim` retained or not. Default: {default}.""",
dtype="""\
dtype (:class:`torch.dtype`, optional): the desired data type
of returned tensor. If specified, the input tensor is
casted to :attr:`dtype` before the operation is
performed. Default: {default}.""",
mask="""\
mask (:class:`torch.Tensor`, optional): the boolean tensor
containing the binary mask of validity of input tensor
elements.
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""",
)
argument_declarations = {
"dim": """\
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
Default: None that is equivalent to ``tuple(range(input.ndim))``.""",
"dim__as_int": """\
dim (int): the dimension along which {operation name} is computed.""",
"ord": """\
ord (int, float, optional): the order of vector norm. Default: 2.
See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
"ord__required": """\
ord (int, float): the order of vector norm. Default: 2.
See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
"unbiased": """\
unbiased (bool): when True, use Bessel's correction, otherwise, compute
the uncorrected sample variance.""",
"eps": """\
eps (float, optional): small value to avoid division by zero. Default: {default}.""",
"keepdim": """\
keepdim (bool, optional): whether the output tensor has
:attr:`dim` retained or not. Default: {default}.""",
"dtype": """\
dtype (:class:`torch.dtype`, optional): the desired data type
of returned tensor. If specified, the input tensor is
casted to :attr:`dtype` before the operation is
performed. Default: {default}.""",
"mask": """\
mask (:class:`torch.Tensor`, optional): the boolean tensor
containing the binary mask of validity of input tensor
elements.
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""",
}
definitions = dict(
softmax="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Softmax of i-th element in ``x`` is
defined as ``exp(x[i])/sum(exp(x))``.""",
log_softmax="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is
defined as ``log(exp(x[i])/sum(exp(x)))``.""",
softmin="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Softmin of i-th element in ``x`` is
defined as ``exp(-x[i])/sum(exp(-x))``.""",
normalize="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Normalize of i-th element in ``x`` is
defined as ``x[i]/max(norm(x, p), eps)``.""",
cumsum="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
defined as ``sum(x[:i])``.""",
cumprod="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
defined as ``prod(x[:i])``.""",
)
definitions = {
"softmax": """\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Softmax of i-th element in ``x`` is
defined as ``exp(x[i])/sum(exp(x))``.""",
"log_softmax": """\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is
defined as ``log(exp(x[i])/sum(exp(x)))``.""",
"softmin": """\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Softmin of i-th element in ``x`` is
defined as ``exp(-x[i])/sum(exp(-x))``.""",
"normalize": """\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Normalize of i-th element in ``x`` is
defined as ``x[i]/max(norm(x, p), eps)``.""",
"cumsum": """\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
defined as ``sum(x[:i])``.""",
"cumprod": """\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
defined as ``prod(x[:i])``.""",
}
reduction_names = dict(
sum="sum",
prod="product",
amax="maximum",
amin="minimum",
argmax="argmax",
argmin="argmin",
mean="mean",
median="median",
norm="norm",
var="variance",
std="standard_deviation",
logsumexp="logsumexp",
)
reduction_names = {
"sum": "sum",
"prod": "product",
"amax": "maximum",
"amin": "minimum",
"argmax": "argmax",
"argmin": "argmin",
"mean": "mean",
"median": "median",
"norm": "norm",
"var": "variance",
"std": "standard_deviation",
"logsumexp": "logsumexp",
}
normalization_names = dict(
softmax="softmax",
log_softmax="log_softmax",
softmin="softmin",
normalize="normalize",
cumsum="cumulative_sum",
cumprod="cumulative_prod",
)
normalization_names = {
"softmax": "softmax",
"log_softmax": "log_softmax",
"softmin": "softmin",
"normalize": "normalize",
"cumsum": "cumulative_sum",
"cumprod": "cumulative_prod",
}
operation_names = {}
operation_names.update(reduction_names)


@@ -47,15 +47,15 @@ class Adafactor(Optimizer):
raise ValueError(f"Clipping threshold d should be >= 1 but is: {d}")
if not 0.0 <= weight_decay:
raise ValueError(f"weight_decay should be >= 0 but is: {weight_decay}")
defaults = dict(
lr=lr,
beta2_decay=beta2_decay,
eps=eps,
d=d,
weight_decay=weight_decay,
foreach=foreach,
maximize=maximize,
)
defaults = {
"lr": lr,
"beta2_decay": beta2_decay,
"eps": eps,
"d": d,
"weight_decay": weight_decay,
"foreach": foreach,
"maximize": maximize,
}
super().__init__(params, defaults)
def __setstate__(self, state):


@@ -50,16 +50,16 @@ class Adadelta(Optimizer):
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
rho=rho,
eps=eps,
weight_decay=weight_decay,
maximize=maximize,
capturable=capturable,
foreach=foreach,
differentiable=differentiable,
)
defaults = {
"lr": lr,
"rho": rho,
"eps": eps,
"weight_decay": weight_decay,
"maximize": maximize,
"capturable": capturable,
"foreach": foreach,
"differentiable": differentiable,
}
super().__init__(params, defaults)
def __setstate__(self, state):


@@ -54,17 +54,17 @@ class Adagrad(Optimizer):
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
defaults = dict(
lr=lr,
lr_decay=lr_decay,
eps=eps,
weight_decay=weight_decay,
initial_accumulator_value=initial_accumulator_value,
foreach=foreach,
maximize=maximize,
differentiable=differentiable,
fused=fused,
)
defaults = {
"lr": lr,
"lr_decay": lr_decay,
"eps": eps,
"weight_decay": weight_decay,
"initial_accumulator_value": initial_accumulator_value,
"foreach": foreach,
"maximize": maximize,
"differentiable": differentiable,
"fused": fused,
}
super().__init__(params, defaults)
if fused:


@@ -85,19 +85,19 @@ class Adam(Optimizer):
if betas[1].numel() != 1:
raise ValueError("Tensor betas[1] must be 1-element")
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
maximize=maximize,
foreach=foreach,
capturable=capturable,
differentiable=differentiable,
fused=fused,
decoupled_weight_decay=decoupled_weight_decay,
)
defaults = {
"lr": lr,
"betas": betas,
"eps": eps,
"weight_decay": weight_decay,
"amsgrad": amsgrad,
"maximize": maximize,
"foreach": foreach,
"capturable": capturable,
"differentiable": differentiable,
"fused": fused,
"decoupled_weight_decay": decoupled_weight_decay,
}
super().__init__(params, defaults)
if fused:


@@ -53,16 +53,16 @@ class Adamax(Optimizer):
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
foreach=foreach,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
)
defaults = {
"lr": lr,
"betas": betas,
"eps": eps,
"weight_decay": weight_decay,
"foreach": foreach,
"maximize": maximize,
"differentiable": differentiable,
"capturable": capturable,
}
super().__init__(params, defaults)
def __setstate__(self, state):


@@ -47,17 +47,17 @@ class ASGD(Optimizer):
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
lambd=lambd,
alpha=alpha,
t0=t0,
weight_decay=weight_decay,
foreach=foreach,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
)
defaults = {
"lr": lr,
"lambd": lambd,
"alpha": alpha,
"t0": t0,
"weight_decay": weight_decay,
"foreach": foreach,
"maximize": maximize,
"differentiable": differentiable,
"capturable": capturable,
}
super().__init__(params, defaults)
def __setstate__(self, state):


@@ -231,15 +231,15 @@ class LBFGS(Optimizer):
raise ValueError(f"Invalid learning rate: {lr}")
if max_eval is None:
max_eval = max_iter * 5 // 4
defaults = dict(
lr=lr,
max_iter=max_iter,
max_eval=max_eval,
tolerance_grad=tolerance_grad,
tolerance_change=tolerance_change,
history_size=history_size,
line_search_fn=line_search_fn,
)
defaults = {
"lr": lr,
"max_iter": max_iter,
"max_eval": max_eval,
"tolerance_grad": tolerance_grad,
"tolerance_change": tolerance_change,
"history_size": history_size,
"line_search_fn": line_search_fn,
}
super().__init__(params, defaults)
if len(self.param_groups) != 1:


@@ -59,18 +59,18 @@ class NAdam(Optimizer): # noqa: D101
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not 0.0 <= momentum_decay:
raise ValueError(f"Invalid momentum_decay value: {momentum_decay}")
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
momentum_decay=momentum_decay,
decoupled_weight_decay=decoupled_weight_decay,
maximize=maximize,
foreach=foreach,
capturable=capturable,
differentiable=differentiable,
)
defaults = {
"lr": lr,
"betas": betas,
"eps": eps,
"weight_decay": weight_decay,
"momentum_decay": momentum_decay,
"decoupled_weight_decay": decoupled_weight_decay,
"maximize": maximize,
"foreach": foreach,
"capturable": capturable,
"differentiable": differentiable,
}
super().__init__(params, defaults)
def __setstate__(self, state): # noqa: D105


@@ -56,17 +56,17 @@ class RAdam(Optimizer): # noqa: D101
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
maximize=maximize,
foreach=foreach,
capturable=capturable,
decoupled_weight_decay=decoupled_weight_decay,
differentiable=differentiable,
)
defaults = {
"lr": lr,
"betas": betas,
"eps": eps,
"weight_decay": weight_decay,
"maximize": maximize,
"foreach": foreach,
"capturable": capturable,
"decoupled_weight_decay": decoupled_weight_decay,
"differentiable": differentiable,
}
super().__init__(params, defaults)
def __setstate__(self, state): # noqa: D105


@@ -55,18 +55,18 @@ class RMSprop(Optimizer): # noqa: D101
if not 0.0 <= alpha:
raise ValueError(f"Invalid alpha value: {alpha}")
defaults = dict(
lr=lr,
momentum=momentum,
alpha=alpha,
eps=eps,
centered=centered,
weight_decay=weight_decay,
capturable=capturable,
foreach=foreach,
maximize=maximize,
differentiable=differentiable,
)
defaults = {
"lr": lr,
"momentum": momentum,
"alpha": alpha,
"eps": eps,
"centered": centered,
"weight_decay": weight_decay,
"capturable": capturable,
"foreach": foreach,
"maximize": maximize,
"differentiable": differentiable,
}
super().__init__(params, defaults)
def __setstate__(self, state): # noqa: D105


@@ -47,15 +47,15 @@ class Rprop(Optimizer): # noqa: D101
if not 0.0 < etas[0] < 1.0 < etas[1]:
raise ValueError(f"Invalid eta values: {etas[0]}, {etas[1]}")
defaults = dict(
lr=lr,
etas=etas,
step_sizes=step_sizes,
foreach=foreach,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
)
defaults = {
"lr": lr,
"etas": etas,
"step_sizes": step_sizes,
"foreach": foreach,
"maximize": maximize,
"differentiable": differentiable,
"capturable": capturable,
}
super().__init__(params, defaults)
def __setstate__(self, state): # noqa: D105


@@ -49,17 +49,17 @@ class SGD(Optimizer): # noqa: D101
if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
momentum=momentum,
dampening=dampening,
weight_decay=weight_decay,
nesterov=nesterov,
maximize=maximize,
foreach=foreach,
differentiable=differentiable,
fused=fused,
)
defaults = {
"lr": lr,
"momentum": momentum,
"dampening": dampening,
"weight_decay": weight_decay,
"nesterov": nesterov,
"maximize": maximize,
"foreach": foreach,
"differentiable": differentiable,
"fused": fused,
}
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super().__init__(params, defaults)


@@ -31,7 +31,12 @@ class SparseAdam(Optimizer):
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
defaults = dict(lr=lr, betas=betas, eps=eps, maximize=maximize)
defaults = {
"lr": lr,
"betas": betas,
"eps": eps,
"maximize": maximize,
}
super().__init__(params, defaults)
sparse_params = []


@@ -1109,15 +1109,15 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
return res
return None
sys_info = dict(
protocol_version=PROTOCOL_VERSION,
little_endian=sys.byteorder == "little",
type_sizes=dict(
short=SHORT_SIZE,
int=INT_SIZE,
long=LONG_SIZE,
),
)
sys_info = {
"protocol_version": PROTOCOL_VERSION,
"little_endian": sys.byteorder == "little",
"type_sizes": {
"short": SHORT_SIZE,
"int": INT_SIZE,
"long": LONG_SIZE,
},
}
pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)


@@ -598,7 +598,10 @@ def as_sparse_gradcheck(gradcheck):
and obj.requires_grad
and obj.layout in sparse_layouts
):
d = dict(layout=obj.layout, shape=obj.shape)
d = {
"layout": obj.layout,
"shape": obj.shape,
}
if not masked:
# Materialize unspecified elements with zero values
batch_dim = obj.ndim - obj.dense_dim() - obj.sparse_dim()


@@ -213,13 +213,14 @@ def get_model_info(
path_prefix = prefix
elif prefix != path_prefix:
raise Exception(f"Mismatched prefixes: {path_prefix} != {prefix}") # noqa: TRY002
zip_files.append(dict(
filename=zi.filename,
compression=zi.compress_type,
compressed_size=zi.compress_size,
file_size=zi.file_size,
))
zip_files.append(
{
"filename": zi.filename,
"compression": zi.compress_type,
"compressed_size": zi.compress_size,
"file_size": zi.file_size,
}
)
assert path_prefix is not None
version = zf.read(path_prefix + "/version").decode("utf-8").strip()
@@ -332,18 +333,20 @@ def get_model_info(
continue
extra_pickles[zi.filename] = contents
return {"model": dict(
title=title,
file_size=file_size,
version=version,
zip_files=zip_files,
interned_strings=list(interned_strings),
code_files=code_files,
model_data=model_data,
constants=constants,
extra_files_jsons=extra_files_jsons,
extra_pickles=extra_pickles,
)}
return {
"model": {
"title": title,
"file_size": file_size,
"version": version,
"zip_files": zip_files,
"interned_strings": list(interned_strings),
"code_files": code_files,
"model_data": model_data,
"constants": constants,
"extra_files_jsons": extra_files_jsons,
"extra_pickles": extra_pickles,
}
}
def get_inline_skeleton():