mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00
Subconfig is a class attribute (#41308)
* delete * fix this test * fix copies * oke, more tests to fix * fix last tests on DPT * deleted accidentally
This commit is contained in:
committed by
GitHub
parent
8137dbdbbd
commit
be3fa93b29
@ -880,7 +880,6 @@ class PreTrainedConfig(PushToHubMixin):
|
||||
isinstance(getattr(self, key, None), PreTrainedConfig)
|
||||
and key in class_config_dict
|
||||
and isinstance(class_config_dict[key], dict)
|
||||
or key in self.sub_configs
|
||||
):
|
||||
# For nested configs we need to clean the diff recursively
|
||||
diff = recursive_diff_dict(value, default_config_dict, config_obj=getattr(self, key, None))
|
||||
|
@ -1219,13 +1219,13 @@ def _get_dtype(
|
||||
dtype = getattr(torch, dtype)
|
||||
config.dtype = dtype
|
||||
for sub_config_key in config.sub_configs:
|
||||
sub_config = getattr(config, sub_config_key)
|
||||
sub_config.dtype = dtype
|
||||
if (sub_config := getattr(config, sub_config_key)) is not None:
|
||||
sub_config.dtype = dtype
|
||||
elif isinstance(dtype, torch.dtype):
|
||||
config.dtype = dtype
|
||||
for sub_config_key in config.sub_configs:
|
||||
sub_config = getattr(config, sub_config_key)
|
||||
sub_config.dtype = dtype
|
||||
if (sub_config := getattr(config, sub_config_key)) is not None:
|
||||
sub_config.dtype = dtype
|
||||
elif isinstance(dtype, dict):
|
||||
for key, curr_dtype in dtype.items():
|
||||
if hasattr(config, key):
|
||||
@ -1250,8 +1250,8 @@ def _get_dtype(
|
||||
default_dtype = torch.get_default_dtype()
|
||||
config.dtype = default_dtype
|
||||
for key in config.sub_configs:
|
||||
value = getattr(config, key)
|
||||
value.dtype = default_dtype
|
||||
if (sub_config := getattr(config, key)) is not None:
|
||||
sub_config.dtype = default_dtype
|
||||
|
||||
return config, dtype, dtype_orig
|
||||
|
||||
@ -2700,34 +2700,34 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
|
||||
# We need this as some old and badly designed models use subconfigs without declaring the corresponding modules as PreTrainedModel
|
||||
for subconfig_key in self.config.sub_configs:
|
||||
subconfig = getattr(self.config, subconfig_key)
|
||||
sub_implementation = (
|
||||
requested_implementation
|
||||
if not isinstance(attn_implementation, dict)
|
||||
else attn_implementation.get(subconfig_key, subconfig._attn_implementation)
|
||||
)
|
||||
# This means we did not perform any check above for this particular subconfig -> set it in the dark if it is registered
|
||||
if (
|
||||
not hasattr(subconfig, "_attn_was_changed")
|
||||
# If it's already the same, then no need to enter here and raise warnings
|
||||
and sub_implementation != subconfig._attn_implementation
|
||||
):
|
||||
if sub_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
|
||||
raise ValueError(
|
||||
f'Specified `attn_implementation="{sub_implementation}"` is not supported for {subconfig_key}. '
|
||||
'The only possible arguments are "eager" (manual attention implementation)'
|
||||
f"or one of the following: {list(ALL_ATTENTION_FUNCTIONS.valid_keys())}"
|
||||
)
|
||||
subconfig._attn_implementation_internal = sub_implementation
|
||||
logger.warning(
|
||||
f"We set the attention implementation for the sub-config `{subconfig_key}` to `{sub_implementation}` "
|
||||
"without finding the associated sub-model. For this reason we could not check if the model supports it. "
|
||||
"You may encounter undefined behavior."
|
||||
if (subconfig := getattr(self.config, subconfig_key)) is not None:
|
||||
sub_implementation = (
|
||||
requested_implementation
|
||||
if not isinstance(attn_implementation, dict)
|
||||
else attn_implementation.get(subconfig_key, subconfig._attn_implementation)
|
||||
)
|
||||
# Unset the attribute in this case, to avoid issues in the future
|
||||
else:
|
||||
if hasattr(subconfig, "_attn_was_changed"):
|
||||
del subconfig._attn_was_changed
|
||||
# This means we did not perform any check above for this particular subconfig -> set it in the dark if it is registered
|
||||
if (
|
||||
not hasattr(subconfig, "_attn_was_changed")
|
||||
# If it's already the same, then no need to enter here and raise warnings
|
||||
and sub_implementation != subconfig._attn_implementation
|
||||
):
|
||||
if sub_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
|
||||
raise ValueError(
|
||||
f'Specified `attn_implementation="{sub_implementation}"` is not supported for {subconfig_key}. '
|
||||
'The only possible arguments are "eager" (manual attention implementation)'
|
||||
f"or one of the following: {list(ALL_ATTENTION_FUNCTIONS.valid_keys())}"
|
||||
)
|
||||
subconfig._attn_implementation_internal = sub_implementation
|
||||
logger.warning(
|
||||
f"We set the attention implementation for the sub-config `{subconfig_key}` to `{sub_implementation}` "
|
||||
"without finding the associated sub-model. For this reason we could not check if the model supports it. "
|
||||
"You may encounter undefined behavior."
|
||||
)
|
||||
# Unset the attribute in this case, to avoid issues in the future
|
||||
else:
|
||||
if hasattr(subconfig, "_attn_was_changed"):
|
||||
del subconfig._attn_was_changed
|
||||
|
||||
def enable_input_require_grads(self):
|
||||
"""
|
||||
|
@ -23,7 +23,7 @@ from ...configuration_utils import PreTrainedConfig
|
||||
from ...onnx import OnnxConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -135,6 +135,7 @@ class ConditionalDetrConfig(PreTrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "conditional_detr"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {
|
||||
"hidden_size": "d_model",
|
||||
@ -245,22 +246,6 @@ class ConditionalDetrConfig(PreTrainedConfig):
|
||||
self.focal_alpha = focal_alpha
|
||||
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
return self.d_model
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
|
||||
class ConditionalDetrOnnxConfig(OnnxConfig):
|
||||
torch_onnx_minimum_version = version.parse("1.11")
|
||||
|
@ -21,7 +21,7 @@
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -194,6 +194,7 @@ class DFineConfig(PreTrainedConfig):
|
||||
"""
|
||||
|
||||
model_type = "d_fine"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
layer_types = ["basic", "bottleneck"]
|
||||
attribute_map = {
|
||||
"hidden_size": "d_model",
|
||||
@ -396,22 +397,6 @@ class DFineConfig(PreTrainedConfig):
|
||||
)
|
||||
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
return self.d_model
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_backbone_configs(cls, backbone_config: PreTrainedConfig, **kwargs):
|
||||
"""Instantiate a [`DFineConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
|
||||
|
@ -25,7 +25,7 @@ from ...configuration_utils import PreTrainedConfig
|
||||
from ...image_transforms import corners_to_center_format
|
||||
from ...utils import is_torchdynamo_compiling, logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
from ..rt_detr.modeling_rt_detr import (
|
||||
RTDetrConvNormLayer,
|
||||
RTDetrDecoder,
|
||||
@ -213,6 +213,7 @@ class DFineConfig(PreTrainedConfig):
|
||||
"""
|
||||
|
||||
model_type = "d_fine"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
layer_types = ["basic", "bottleneck"]
|
||||
attribute_map = {
|
||||
"hidden_size": "d_model",
|
||||
@ -415,22 +416,6 @@ class DFineConfig(PreTrainedConfig):
|
||||
)
|
||||
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
return self.d_model
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_backbone_configs(cls, backbone_config: PreTrainedConfig, **kwargs):
|
||||
"""Instantiate a [`DFineConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
|
||||
|
@ -17,7 +17,7 @@
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -136,6 +136,7 @@ class DabDetrConfig(PreTrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "dab-detr"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {
|
||||
"num_attention_heads": "encoder_attention_heads",
|
||||
@ -256,13 +257,5 @@ class DabDetrConfig(PreTrainedConfig):
|
||||
self.initializer_bias_prior_prob = initializer_bias_prior_prob
|
||||
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["DabDetrConfig"]
|
||||
|
@ -17,7 +17,7 @@
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -144,6 +144,7 @@ class DeformableDetrConfig(PreTrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "deformable_detr"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
attribute_map = {
|
||||
"hidden_size": "d_model",
|
||||
"num_attention_heads": "encoder_attention_heads",
|
||||
@ -270,21 +271,5 @@ class DeformableDetrConfig(PreTrainedConfig):
|
||||
self.disable_custom_kernels = disable_custom_kernels
|
||||
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
return self.d_model
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["DeformableDetrConfig"]
|
||||
|
@ -14,12 +14,10 @@
|
||||
# limitations under the License.
|
||||
"""DepthAnything model configuration"""
|
||||
|
||||
import copy
|
||||
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto.configuration_auto import CONFIG_MAPPING
|
||||
from ..auto.configuration_auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -89,6 +87,7 @@ class DepthAnythingConfig(PreTrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "depth_anything"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -151,26 +150,5 @@ class DepthAnythingConfig(PreTrainedConfig):
|
||||
self.depth_estimation_type = depth_estimation_type
|
||||
self.max_depth = max_depth if max_depth else 1
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PreTrainedConfig.to_dict`]. Returns:
|
||||
`dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
|
||||
if output["backbone_config"] is not None:
|
||||
output["backbone_config"] = self.backbone_config.to_dict()
|
||||
|
||||
output["model_type"] = self.__class__.model_type
|
||||
return output
|
||||
|
||||
|
||||
__all__ = ["DepthAnythingConfig"]
|
||||
|
@ -23,7 +23,7 @@ from ...configuration_utils import PreTrainedConfig
|
||||
from ...onnx import OnnxConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -133,6 +133,7 @@ class DetrConfig(PreTrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "detr"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {
|
||||
"hidden_size": "d_model",
|
||||
@ -244,22 +245,6 @@ class DetrConfig(PreTrainedConfig):
|
||||
self.eos_coefficient = eos_coefficient
|
||||
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
return self.d_model
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_backbone_config(cls, backbone_config: PreTrainedConfig, **kwargs):
|
||||
"""Instantiate a [`DetrConfig`] (or a derived class) from a pre-trained backbone model configuration.
|
||||
|
@ -14,12 +14,10 @@
|
||||
# limitations under the License.
|
||||
"""DPT model configuration"""
|
||||
|
||||
import copy
|
||||
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto.configuration_auto import CONFIG_MAPPING
|
||||
from ..auto.configuration_auto import CONFIG_MAPPING, AutoConfig
|
||||
from ..bit import BitConfig
|
||||
|
||||
|
||||
@ -140,6 +138,7 @@ class DPTConfig(PreTrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "dpt"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -275,26 +274,5 @@ class DPTConfig(PreTrainedConfig):
|
||||
self.pooler_output_size = pooler_output_size if pooler_output_size else hidden_size
|
||||
self.pooler_act = pooler_act
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PreTrainedConfig.to_dict`]. Returns:
|
||||
`dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
|
||||
if output["backbone_config"] is not None:
|
||||
output["backbone_config"] = self.backbone_config.to_dict()
|
||||
|
||||
output["model_type"] = self.__class__.model_type
|
||||
return output
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["DPTConfig"]
|
||||
|
@ -23,7 +23,159 @@ from ...utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# TODO Update this
|
||||
|
||||
@dataclass
|
||||
class StructureModuleConfig:
|
||||
"""
|
||||
Args:
|
||||
sequence_dim:
|
||||
Single representation channel dimension
|
||||
pairwise_dim:
|
||||
Pair representation channel dimension
|
||||
ipa_dim:
|
||||
IPA hidden channel dimension
|
||||
resnet_dim:
|
||||
Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
|
||||
num_heads_ipa:
|
||||
Number of IPA heads
|
||||
num_qk_points:
|
||||
Number of query/key points to generate during IPA
|
||||
num_v_points:
|
||||
Number of value points to generate during IPA
|
||||
dropout_rate:
|
||||
Dropout rate used throughout the layer
|
||||
num_blocks:
|
||||
Number of structure module blocks
|
||||
num_transition_layers:
|
||||
Number of layers in the single representation transition (Alg. 23 lines 8-9)
|
||||
num_resnet_blocks:
|
||||
Number of blocks in the angle resnet
|
||||
num_angles:
|
||||
Number of angles to generate in the angle resnet
|
||||
trans_scale_factor:
|
||||
Scale of single representation transition hidden dimension
|
||||
epsilon:
|
||||
Small number used in angle resnet normalization
|
||||
inf:
|
||||
Large number used for attention masking
|
||||
"""
|
||||
|
||||
sequence_dim: int = 384
|
||||
pairwise_dim: int = 128
|
||||
ipa_dim: int = 16
|
||||
resnet_dim: int = 128
|
||||
num_heads_ipa: int = 12
|
||||
num_qk_points: int = 4
|
||||
num_v_points: int = 8
|
||||
dropout_rate: float = 0.1
|
||||
num_blocks: int = 8
|
||||
num_transition_layers: int = 1
|
||||
num_resnet_blocks: int = 2
|
||||
num_angles: int = 7
|
||||
trans_scale_factor: int = 10
|
||||
epsilon: float = 1e-8
|
||||
inf: float = 1e5
|
||||
|
||||
def to_dict(self):
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrunkConfig:
|
||||
num_blocks: int = 48
|
||||
sequence_state_dim: int = 1024
|
||||
pairwise_state_dim: int = 128
|
||||
sequence_head_width: int = 32
|
||||
pairwise_head_width: int = 32
|
||||
position_bins: int = 32
|
||||
dropout: float = 0
|
||||
layer_drop: float = 0
|
||||
cpu_grad_checkpoint: bool = False
|
||||
max_recycles: int = 4
|
||||
chunk_size: Optional[int] = 128
|
||||
structure_module: "StructureModuleConfig" = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.structure_module is None:
|
||||
self.structure_module = StructureModuleConfig()
|
||||
elif isinstance(self.structure_module, dict):
|
||||
self.structure_module = StructureModuleConfig(**self.structure_module)
|
||||
|
||||
if self.max_recycles <= 0:
|
||||
raise ValueError(f"`max_recycles` should be positive, got {self.max_recycles}.")
|
||||
if self.sequence_state_dim % self.sequence_state_dim != 0:
|
||||
raise ValueError(
|
||||
"`sequence_state_dim` should be a round multiple of `sequence_state_dim`, got"
|
||||
f" {self.sequence_state_dim} and {self.sequence_state_dim}."
|
||||
)
|
||||
if self.pairwise_state_dim % self.pairwise_state_dim != 0:
|
||||
raise ValueError(
|
||||
"`pairwise_state_dim` should be a round multiple of `pairwise_state_dim`, got"
|
||||
f" {self.pairwise_state_dim} and {self.pairwise_state_dim}."
|
||||
)
|
||||
|
||||
sequence_num_heads = self.sequence_state_dim // self.sequence_head_width
|
||||
pairwise_num_heads = self.pairwise_state_dim // self.pairwise_head_width
|
||||
|
||||
if self.sequence_state_dim != sequence_num_heads * self.sequence_head_width:
|
||||
raise ValueError(
|
||||
"`sequence_state_dim` should be equal to `sequence_num_heads * sequence_head_width, got"
|
||||
f" {self.sequence_state_dim} != {sequence_num_heads} * {self.sequence_head_width}."
|
||||
)
|
||||
if self.pairwise_state_dim != pairwise_num_heads * self.pairwise_head_width:
|
||||
raise ValueError(
|
||||
"`pairwise_state_dim` should be equal to `pairwise_num_heads * pairwise_head_width, got"
|
||||
f" {self.pairwise_state_dim} != {pairwise_num_heads} * {self.pairwise_head_width}."
|
||||
)
|
||||
if self.pairwise_state_dim % 2 != 0:
|
||||
raise ValueError(f"`pairwise_state_dim` should be even, got {self.pairwise_state_dim}.")
|
||||
|
||||
if self.dropout >= 0.4:
|
||||
raise ValueError(f"`dropout` should not be greater than 0.4, got {self.dropout}.")
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
||||
|
||||
Returns:
|
||||
`dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = asdict(self)
|
||||
output["structure_module"] = self.structure_module.to_dict()
|
||||
return output
|
||||
|
||||
|
||||
@dataclass
|
||||
class EsmFoldConfig:
|
||||
esm_type: Optional[str] = None
|
||||
fp16_esm: bool = True
|
||||
use_esm_attn_map: bool = False
|
||||
esm_ablate_pairwise: bool = False
|
||||
esm_ablate_sequence: bool = False
|
||||
esm_input_dropout: float = 0
|
||||
|
||||
embed_aa: bool = True
|
||||
bypass_lm: bool = False
|
||||
|
||||
lddt_head_hid_dim: int = 128
|
||||
trunk: "TrunkConfig" = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.trunk is None:
|
||||
self.trunk = TrunkConfig()
|
||||
elif isinstance(self.trunk, dict):
|
||||
self.trunk = TrunkConfig(**self.trunk)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
||||
|
||||
Returns:
|
||||
`dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = asdict(self)
|
||||
output["trunk"] = self.trunk.to_dict()
|
||||
return output
|
||||
|
||||
|
||||
class EsmConfig(PreTrainedConfig):
|
||||
@ -94,6 +246,7 @@ class EsmConfig(PreTrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "esm"
|
||||
sub_configs = {"esmfold_config": EsmFoldConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -153,6 +306,7 @@ class EsmConfig(PreTrainedConfig):
|
||||
if self.esmfold_config is not None and getattr(self.esmfold_config, "use_esm_attn_map", False):
|
||||
raise ValueError("The HuggingFace port of ESMFold does not support use_esm_attn_map at this time!")
|
||||
|
||||
# TODO: update ESM to inherit from PreTrainedConfig
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PreTrainedConfig.to_dict`].
|
||||
@ -166,160 +320,6 @@ class EsmConfig(PreTrainedConfig):
|
||||
return output
|
||||
|
||||
|
||||
@dataclass
|
||||
class EsmFoldConfig:
|
||||
esm_type: Optional[str] = None
|
||||
fp16_esm: bool = True
|
||||
use_esm_attn_map: bool = False
|
||||
esm_ablate_pairwise: bool = False
|
||||
esm_ablate_sequence: bool = False
|
||||
esm_input_dropout: float = 0
|
||||
|
||||
embed_aa: bool = True
|
||||
bypass_lm: bool = False
|
||||
|
||||
lddt_head_hid_dim: int = 128
|
||||
trunk: "TrunkConfig" = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.trunk is None:
|
||||
self.trunk = TrunkConfig()
|
||||
elif isinstance(self.trunk, dict):
|
||||
self.trunk = TrunkConfig(**self.trunk)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PreTrainedConfig.to_dict`].
|
||||
|
||||
Returns:
|
||||
`dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = asdict(self)
|
||||
output["trunk"] = self.trunk.to_dict()
|
||||
return output
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrunkConfig:
|
||||
num_blocks: int = 48
|
||||
sequence_state_dim: int = 1024
|
||||
pairwise_state_dim: int = 128
|
||||
sequence_head_width: int = 32
|
||||
pairwise_head_width: int = 32
|
||||
position_bins: int = 32
|
||||
dropout: float = 0
|
||||
layer_drop: float = 0
|
||||
cpu_grad_checkpoint: bool = False
|
||||
max_recycles: int = 4
|
||||
chunk_size: Optional[int] = 128
|
||||
structure_module: "StructureModuleConfig" = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.structure_module is None:
|
||||
self.structure_module = StructureModuleConfig()
|
||||
elif isinstance(self.structure_module, dict):
|
||||
self.structure_module = StructureModuleConfig(**self.structure_module)
|
||||
|
||||
if self.max_recycles <= 0:
|
||||
raise ValueError(f"`max_recycles` should be positive, got {self.max_recycles}.")
|
||||
if self.sequence_state_dim % self.sequence_state_dim != 0:
|
||||
raise ValueError(
|
||||
"`sequence_state_dim` should be a round multiple of `sequence_state_dim`, got"
|
||||
f" {self.sequence_state_dim} and {self.sequence_state_dim}."
|
||||
)
|
||||
if self.pairwise_state_dim % self.pairwise_state_dim != 0:
|
||||
raise ValueError(
|
||||
"`pairwise_state_dim` should be a round multiple of `pairwise_state_dim`, got"
|
||||
f" {self.pairwise_state_dim} and {self.pairwise_state_dim}."
|
||||
)
|
||||
|
||||
sequence_num_heads = self.sequence_state_dim // self.sequence_head_width
|
||||
pairwise_num_heads = self.pairwise_state_dim // self.pairwise_head_width
|
||||
|
||||
if self.sequence_state_dim != sequence_num_heads * self.sequence_head_width:
|
||||
raise ValueError(
|
||||
"`sequence_state_dim` should be equal to `sequence_num_heads * sequence_head_width, got"
|
||||
f" {self.sequence_state_dim} != {sequence_num_heads} * {self.sequence_head_width}."
|
||||
)
|
||||
if self.pairwise_state_dim != pairwise_num_heads * self.pairwise_head_width:
|
||||
raise ValueError(
|
||||
"`pairwise_state_dim` should be equal to `pairwise_num_heads * pairwise_head_width, got"
|
||||
f" {self.pairwise_state_dim} != {pairwise_num_heads} * {self.pairwise_head_width}."
|
||||
)
|
||||
if self.pairwise_state_dim % 2 != 0:
|
||||
raise ValueError(f"`pairwise_state_dim` should be even, got {self.pairwise_state_dim}.")
|
||||
|
||||
if self.dropout >= 0.4:
|
||||
raise ValueError(f"`dropout` should not be greater than 0.4, got {self.dropout}.")
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PreTrainedConfig.to_dict`].
|
||||
|
||||
Returns:
|
||||
`dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = asdict(self)
|
||||
output["structure_module"] = self.structure_module.to_dict()
|
||||
return output
|
||||
|
||||
|
||||
@dataclass
|
||||
class StructureModuleConfig:
|
||||
"""
|
||||
Args:
|
||||
sequence_dim:
|
||||
Single representation channel dimension
|
||||
pairwise_dim:
|
||||
Pair representation channel dimension
|
||||
ipa_dim:
|
||||
IPA hidden channel dimension
|
||||
resnet_dim:
|
||||
Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
|
||||
num_heads_ipa:
|
||||
Number of IPA heads
|
||||
num_qk_points:
|
||||
Number of query/key points to generate during IPA
|
||||
num_v_points:
|
||||
Number of value points to generate during IPA
|
||||
dropout_rate:
|
||||
Dropout rate used throughout the layer
|
||||
num_blocks:
|
||||
Number of structure module blocks
|
||||
num_transition_layers:
|
||||
Number of layers in the single representation transition (Alg. 23 lines 8-9)
|
||||
num_resnet_blocks:
|
||||
Number of blocks in the angle resnet
|
||||
num_angles:
|
||||
Number of angles to generate in the angle resnet
|
||||
trans_scale_factor:
|
||||
Scale of single representation transition hidden dimension
|
||||
epsilon:
|
||||
Small number used in angle resnet normalization
|
||||
inf:
|
||||
Large number used for attention masking
|
||||
"""
|
||||
|
||||
sequence_dim: int = 384
|
||||
pairwise_dim: int = 128
|
||||
ipa_dim: int = 16
|
||||
resnet_dim: int = 128
|
||||
num_heads_ipa: int = 12
|
||||
num_qk_points: int = 4
|
||||
num_v_points: int = 8
|
||||
dropout_rate: float = 0.1
|
||||
num_blocks: int = 8
|
||||
num_transition_layers: int = 1
|
||||
num_resnet_blocks: int = 2
|
||||
num_angles: int = 7
|
||||
trans_scale_factor: int = 10
|
||||
epsilon: float = 1e-8
|
||||
inf: float = 1e5
|
||||
|
||||
def to_dict(self):
|
||||
return asdict(self)
|
||||
|
||||
|
||||
def get_default_vocab_list():
|
||||
return (
|
||||
"<cls>",
|
||||
|
@ -17,7 +17,7 @@
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -146,6 +146,7 @@ class GroundingDinoConfig(PreTrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "grounding-dino"
|
||||
sub_configs = {"backbone_config": AutoConfig, "text_config": AutoConfig}
|
||||
attribute_map = {
|
||||
"hidden_size": "d_model",
|
||||
"num_attention_heads": "encoder_attention_heads",
|
||||
@ -286,24 +287,5 @@ class GroundingDinoConfig(PreTrainedConfig):
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
return self.d_model
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
sub_configs = {}
|
||||
backbone_config = getattr(self, "backbone_config", None)
|
||||
text_config = getattr(self, "text_config", None)
|
||||
if isinstance(backbone_config, PreTrainedConfig):
|
||||
sub_configs["backbone_config"] = type(backbone_config)
|
||||
if isinstance(text_config, PreTrainedConfig):
|
||||
sub_configs["text_config"] = type(self.text_config)
|
||||
return sub_configs
|
||||
|
||||
|
||||
__all__ = ["GroundingDinoConfig"]
|
||||
|
@ -19,7 +19,7 @@ from typing import Optional
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -128,6 +128,7 @@ class Mask2FormerConfig(PreTrainedConfig):
|
||||
"""
|
||||
|
||||
model_type = "mask2former"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
backbones_supported = ["swin"]
|
||||
attribute_map = {"hidden_size": "hidden_dim"}
|
||||
|
||||
@ -236,14 +237,6 @@ class Mask2FormerConfig(PreTrainedConfig):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_backbone_config(cls, backbone_config: PreTrainedConfig, **kwargs):
|
||||
"""Instantiate a [`Mask2FormerConfig`] (or a derived class) from a pre-trained backbone model configuration.
|
||||
|
@ -19,7 +19,7 @@ from typing import Optional
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
from ..detr import DetrConfig
|
||||
from ..swin import SwinConfig
|
||||
|
||||
@ -103,6 +103,7 @@ class MaskFormerConfig(PreTrainedConfig):
|
||||
"""
|
||||
|
||||
model_type = "maskformer"
|
||||
sub_configs = {"backbone_config": AutoConfig, "decoder_config": AutoConfig}
|
||||
attribute_map = {"hidden_size": "mask_feature_size"}
|
||||
backbones_supported = ["resnet", "swin"]
|
||||
decoders_supported = ["detr"]
|
||||
@ -200,15 +201,6 @@ class MaskFormerConfig(PreTrainedConfig):
|
||||
self.backbone_kwargs = backbone_kwargs
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
sub_configs = {}
|
||||
if self.backbone_config is not None and self.backbone_config != {}:
|
||||
sub_configs["backbone_config"] = type(self.backbone_config)
|
||||
if self.decoder_config is not None and self.decoder_config != {}:
|
||||
sub_configs["decoder_config"] = type(self.decoder_config)
|
||||
return sub_configs
|
||||
|
||||
@classmethod
|
||||
def from_backbone_and_decoder_configs(
|
||||
cls, backbone_config: PreTrainedConfig, decoder_config: PreTrainedConfig, **kwargs
|
||||
|
@ -22,7 +22,7 @@
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -146,6 +146,7 @@ class MMGroundingDinoConfig(PreTrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "mm-grounding-dino"
|
||||
sub_configs = {"backbone_config": AutoConfig, "text_config": AutoConfig}
|
||||
attribute_map = {
|
||||
"hidden_size": "d_model",
|
||||
"num_attention_heads": "encoder_attention_heads",
|
||||
@ -280,24 +281,5 @@ class MMGroundingDinoConfig(PreTrainedConfig):
|
||||
self.init_std = init_std
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
return self.d_model
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
sub_configs = {}
|
||||
backbone_config = getattr(self, "backbone_config", None)
|
||||
text_config = getattr(self, "text_config", None)
|
||||
if isinstance(backbone_config, PreTrainedConfig):
|
||||
sub_configs["backbone_config"] = type(backbone_config)
|
||||
if isinstance(text_config, PreTrainedConfig):
|
||||
sub_configs["text_config"] = type(self.text_config)
|
||||
return sub_configs
|
||||
|
||||
|
||||
__all__ = ["MMGroundingDinoConfig"]
|
||||
|
@ -17,7 +17,7 @@
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -145,6 +145,7 @@ class OmDetTurboConfig(PreTrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "omdet-turbo"
|
||||
sub_configs = {"backbone_config": AutoConfig, "text_config": AutoConfig}
|
||||
attribute_map = {
|
||||
"encoder_hidden_dim": "d_model",
|
||||
"num_attention_heads": "encoder_attention_heads",
|
||||
@ -289,16 +290,5 @@ class OmDetTurboConfig(PreTrainedConfig):
|
||||
|
||||
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
sub_configs = {}
|
||||
backbone_config = getattr(self, "backbone_config", None)
|
||||
text_config = getattr(self, "text_config", None)
|
||||
if isinstance(backbone_config, PreTrainedConfig):
|
||||
sub_configs["backbone_config"] = type(backbone_config)
|
||||
if isinstance(text_config, PreTrainedConfig):
|
||||
sub_configs["text_config"] = type(text_config)
|
||||
return sub_configs
|
||||
|
||||
|
||||
__all__ = ["OmDetTurboConfig"]
|
||||
|
@ -19,7 +19,7 @@ from typing import Optional
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -146,6 +146,7 @@ class OneFormerConfig(PreTrainedConfig):
|
||||
"""
|
||||
|
||||
model_type = "oneformer"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
attribute_map = {"hidden_size": "hidden_dim"}
|
||||
|
||||
def __init__(
|
||||
@ -273,13 +274,5 @@ class OneFormerConfig(PreTrainedConfig):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["OneFormerConfig"]
|
||||
|
@ -152,13 +152,5 @@ class PegasusConfig(PreTrainedConfig):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
return self.d_model
|
||||
|
||||
|
||||
__all__ = ["PegasusConfig"]
|
||||
|
@ -165,13 +165,5 @@ class PegasusXConfig(PreTrainedConfig):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
return self.d_model
|
||||
|
||||
|
||||
__all__ = ["PegasusXConfig"]
|
||||
|
@ -17,12 +17,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto.configuration_auto import CONFIG_MAPPING
|
||||
from ..auto.configuration_auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -92,6 +90,7 @@ class PromptDepthAnythingConfig(PreTrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "prompt_depth_anything"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -154,26 +153,5 @@ class PromptDepthAnythingConfig(PreTrainedConfig):
|
||||
self.depth_estimation_type = depth_estimation_type
|
||||
self.max_depth = max_depth if max_depth else 1
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PreTrainedConfig.to_dict`]. Returns:
|
||||
`dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
|
||||
if output["backbone_config"] is not None:
|
||||
output["backbone_config"] = self.backbone_config.to_dict()
|
||||
|
||||
output["model_type"] = self.__class__.model_type
|
||||
return output
|
||||
|
||||
|
||||
__all__ = ["PromptDepthAnythingConfig"]
|
||||
|
@ -17,7 +17,7 @@
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
from .configuration_rt_detr_resnet import RTDetrResNetConfig
|
||||
|
||||
|
||||
@ -175,6 +175,7 @@ class RTDetrConfig(PreTrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "rt_detr"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
layer_types = ["basic", "bottleneck"]
|
||||
attribute_map = {
|
||||
"hidden_size": "d_model",
|
||||
@ -335,22 +336,6 @@ class RTDetrConfig(PreTrainedConfig):
|
||||
self.eos_coefficient = eos_coefficient
|
||||
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
return self.d_model
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_backbone_configs(cls, backbone_config: PreTrainedConfig, **kwargs):
|
||||
"""Instantiate a [`RTDetrConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
|
||||
|
@ -22,7 +22,7 @@
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -185,6 +185,7 @@ class RTDetrV2Config(PreTrainedConfig):
|
||||
"""
|
||||
|
||||
model_type = "rt_detr_v2"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
layer_types = ["basic", "bottleneck"]
|
||||
attribute_map = {
|
||||
"hidden_size": "d_model",
|
||||
@ -358,14 +359,6 @@ class RTDetrV2Config(PreTrainedConfig):
|
||||
self.decoder_offset_scale = decoder_offset_scale
|
||||
self.decoder_method = decoder_method
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_backbone_configs(cls, backbone_config: PreTrainedConfig, **kwargs):
|
||||
"""Instantiate a [`RTDetrV2Config`] (or a derived class) from a pre-trained backbone model configuration and DETR model
|
||||
|
@ -25,7 +25,7 @@ from ...utils import is_torchdynamo_compiling, logging
|
||||
from ...utils.backbone_utils import (
|
||||
verify_backbone_config_arguments,
|
||||
)
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
from ..rt_detr.modeling_rt_detr import (
|
||||
RTDetrDecoder,
|
||||
RTDetrDecoderLayer,
|
||||
@ -196,6 +196,7 @@ class RTDetrV2Config(PreTrainedConfig):
|
||||
"""
|
||||
|
||||
model_type = "rt_detr_v2"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
layer_types = ["basic", "bottleneck"]
|
||||
attribute_map = {
|
||||
"hidden_size": "d_model",
|
||||
@ -369,14 +370,6 @@ class RTDetrV2Config(PreTrainedConfig):
|
||||
self.decoder_offset_scale = decoder_offset_scale
|
||||
self.decoder_method = decoder_method
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_backbone_configs(cls, backbone_config: PreTrainedConfig, **kwargs):
|
||||
"""Instantiate a [`RTDetrV2Config`] (or a derived class) from a pre-trained backbone model configuration and DETR model
|
||||
|
@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -68,6 +68,7 @@ class SuperGlueConfig(PreTrainedConfig):
|
||||
"""
|
||||
|
||||
model_type = "superglue"
|
||||
sub_configs = {"keypoint_detector_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -114,9 +115,5 @@ class SuperGlueConfig(PreTrainedConfig):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return {"keypoint_detector_config": type(self.keypoint_detector_config)}
|
||||
|
||||
|
||||
__all__ = ["SuperGlueConfig"]
|
||||
|
@ -23,7 +23,7 @@ from ...configuration_utils import PreTrainedConfig
|
||||
from ...onnx import OnnxConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -133,6 +133,7 @@ class TableTransformerConfig(PreTrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "table-transformer"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {
|
||||
"hidden_size": "d_model",
|
||||
@ -245,22 +246,6 @@ class TableTransformerConfig(PreTrainedConfig):
|
||||
self.eos_coefficient = eos_coefficient
|
||||
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
return self.d_model
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.configuration_detr.DetrOnnxConfig
|
||||
class TableTransformerOnnxConfig(OnnxConfig):
|
||||
|
@ -14,12 +14,10 @@
|
||||
# limitations under the License.
|
||||
"""TVP model configuration"""
|
||||
|
||||
import copy
|
||||
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -99,6 +97,7 @@ class TvpConfig(PreTrainedConfig):
|
||||
"""
|
||||
|
||||
model_type = "tvp"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -172,14 +171,6 @@ class TvpConfig(PreTrainedConfig):
|
||||
self.initializer_range = initializer_range
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_backbone_config(cls, backbone_config: PreTrainedConfig, **kwargs):
|
||||
"""Instantiate a [`TvpConfig`] (or a derived class) from a pre-trained backbone model configuration.
|
||||
@ -192,18 +183,5 @@ class TvpConfig(PreTrainedConfig):
|
||||
"""
|
||||
return cls(backbone_config=backbone_config, **kwargs)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PreTrainedConfig.to_dict`].
|
||||
|
||||
Returns:
|
||||
`dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
if output["backbone_config"] is not None:
|
||||
output["backbone_config"] = self.backbone_config.to_dict()
|
||||
output["model_type"] = self.__class__.model_type
|
||||
return output
|
||||
|
||||
|
||||
__all__ = ["TvpConfig"]
|
||||
|
@ -17,7 +17,7 @@
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto.configuration_auto import CONFIG_MAPPING
|
||||
from ..auto.configuration_auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -83,6 +83,7 @@ class UperNetConfig(PreTrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "upernet"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -136,13 +137,5 @@ class UperNetConfig(PreTrainedConfig):
|
||||
self.auxiliary_concat_input = auxiliary_concat_input
|
||||
self.loss_ignore_index = loss_ignore_index
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["UperNetConfig"]
|
||||
|
@ -14,13 +14,12 @@
|
||||
# limitations under the License.
|
||||
"""VitMatte model configuration"""
|
||||
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto.configuration_auto import CONFIG_MAPPING
|
||||
from ..auto.configuration_auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -78,6 +77,7 @@ class VitMatteConfig(PreTrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "vitmatte"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -122,23 +122,5 @@ class VitMatteConfig(PreTrainedConfig):
|
||||
self.convstream_hidden_sizes = convstream_hidden_sizes
|
||||
self.fusion_hidden_sizes = fusion_hidden_sizes
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PreTrainedConfig.to_dict`]. Returns:
|
||||
`dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
output["backbone_config"] = self.backbone_config.to_dict()
|
||||
output["model_type"] = self.__class__.model_type
|
||||
return output
|
||||
|
||||
|
||||
__all__ = ["VitMatteConfig"]
|
||||
|
@ -19,7 +19,7 @@ from typing import Optional
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto.configuration_auto import CONFIG_MAPPING
|
||||
from ..auto.configuration_auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -74,6 +74,7 @@ class VitPoseConfig(PreTrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "vitpose"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -122,13 +123,5 @@ class VitPoseConfig(PreTrainedConfig):
|
||||
self.scale_factor = scale_factor
|
||||
self.use_simple_decoder = use_simple_decoder
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["VitPoseConfig"]
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import logging
|
||||
from ..auto.configuration_auto import CONFIG_MAPPING
|
||||
from ..auto.configuration_auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -133,6 +133,7 @@ class ZoeDepthConfig(PreTrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "zoedepth"
|
||||
sub_configs = {"backbone_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -233,13 +234,5 @@ class ZoeDepthConfig(PreTrainedConfig):
|
||||
self.patch_transformer_intermediate_size = patch_transformer_intermediate_size
|
||||
self.patch_transformer_num_attention_heads = patch_transformer_num_attention_heads
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return (
|
||||
{"backbone_config": type(self.backbone_config)}
|
||||
if getattr(self, "backbone_config", None) is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["ZOEDEPTH_PRETRAINED_CONFIG_ARCHIVE_MAP", "ZoeDepthConfig"]
|
||||
|
@ -131,28 +131,29 @@ class ConfigTester:
|
||||
# Iterate over all sub_configs if there are any and load them with their own classes
|
||||
sub_configs = general_config_loaded.sub_configs
|
||||
for sub_config_key, sub_class in sub_configs.items():
|
||||
if sub_class.__name__ == "AutoConfig":
|
||||
sub_class = sub_class.for_model(**general_config_dict[sub_config_key]).__class__
|
||||
sub_config_loaded = sub_class.from_pretrained(tmpdirname)
|
||||
else:
|
||||
sub_config_loaded = sub_class.from_pretrained(tmpdirname)
|
||||
if general_config_dict[sub_config_key] is not None:
|
||||
if sub_class.__name__ == "AutoConfig":
|
||||
sub_class = sub_class.for_model(**general_config_dict[sub_config_key]).__class__
|
||||
sub_config_loaded = sub_class.from_pretrained(tmpdirname)
|
||||
else:
|
||||
sub_config_loaded = sub_class.from_pretrained(tmpdirname)
|
||||
|
||||
# Pop `transformers_version`, it never exists when a config is part of a general composite config
|
||||
# Verify that loading with subconfig class results in same dict as if we loaded with general composite config class
|
||||
sub_config_loaded_dict = sub_config_loaded.to_dict()
|
||||
sub_config_loaded_dict.pop("transformers_version", None)
|
||||
general_config_dict[sub_config_key].pop("transformers_version", None)
|
||||
self.parent.assertEqual(sub_config_loaded_dict, general_config_dict[sub_config_key])
|
||||
# Pop `transformers_version`, it never exists when a config is part of a general composite config
|
||||
# Verify that loading with subconfig class results in same dict as if we loaded with general composite config class
|
||||
sub_config_loaded_dict = sub_config_loaded.to_dict()
|
||||
sub_config_loaded_dict.pop("transformers_version", None)
|
||||
general_config_dict[sub_config_key].pop("transformers_version", None)
|
||||
self.parent.assertEqual(sub_config_loaded_dict, general_config_dict[sub_config_key])
|
||||
|
||||
# Verify that the loaded config type is same as in the general config
|
||||
type_from_general_config = type(getattr(general_config_loaded, sub_config_key))
|
||||
self.parent.assertTrue(isinstance(sub_config_loaded, type_from_general_config))
|
||||
# Verify that the loaded config type is same as in the general config
|
||||
type_from_general_config = type(getattr(general_config_loaded, sub_config_key))
|
||||
self.parent.assertTrue(isinstance(sub_config_loaded, type_from_general_config))
|
||||
|
||||
# Now save only the sub-config and load it back to make sure the whole load-save-load pipeline works
|
||||
with tempfile.TemporaryDirectory() as tmpdirname2:
|
||||
sub_config_loaded.save_pretrained(tmpdirname2)
|
||||
sub_config_loaded_2 = sub_class.from_pretrained(tmpdirname2)
|
||||
self.parent.assertEqual(sub_config_loaded.to_dict(), sub_config_loaded_2.to_dict())
|
||||
# Now save only the sub-config and load it back to make sure the whole load-save-load pipeline works
|
||||
with tempfile.TemporaryDirectory() as tmpdirname2:
|
||||
sub_config_loaded.save_pretrained(tmpdirname2)
|
||||
sub_config_loaded_2 = sub_class.from_pretrained(tmpdirname2)
|
||||
self.parent.assertEqual(sub_config_loaded.to_dict(), sub_config_loaded_2.to_dict())
|
||||
|
||||
def create_and_test_config_from_pretrained_custom_kwargs(self):
|
||||
"""
|
||||
|
@ -1257,7 +1257,8 @@ class ModelTesterMixin:
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
for k in config.sub_configs:
|
||||
getattr(config, k).output_attentions = True
|
||||
if getattr(config, k) is not None:
|
||||
getattr(config, k).output_attentions = True
|
||||
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
@ -1736,20 +1737,23 @@ class ModelTesterMixin:
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
for k in config.sub_configs:
|
||||
getattr(config, k).output_hidden_states = True
|
||||
if getattr(config, k) is not None:
|
||||
getattr(config, k).output_hidden_states = True
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for k in config.sub_configs:
|
||||
getattr(config, k).output_hidden_states = True
|
||||
if getattr(config, k) is not None:
|
||||
getattr(config, k).output_hidden_states = True
|
||||
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
for k in config.sub_configs:
|
||||
getattr(config, k).output_attentions = self.has_attentions
|
||||
if getattr(config, k) is not None:
|
||||
getattr(config, k).output_attentions = self.has_attentions
|
||||
|
||||
# force eager attention to support output attentions
|
||||
if self.has_attentions:
|
||||
@ -3188,13 +3192,15 @@ class ModelTesterMixin:
|
||||
# we just need to test if passing 'attn_implementation' as a dict fails or not
|
||||
attn_implementation_per_subconfig = {"": "eager"}
|
||||
for key in config.sub_configs:
|
||||
attn_implementation_per_subconfig[key] = "eager"
|
||||
if getattr(config, key) is not None:
|
||||
attn_implementation_per_subconfig[key] = "eager"
|
||||
|
||||
config._attn_implementation = attn_implementation_per_subconfig
|
||||
model = model_class(config)
|
||||
for key in config.sub_configs:
|
||||
sub_config = getattr(model.config, key)
|
||||
self.assertTrue(sub_config._attn_implementation == "eager")
|
||||
if getattr(config, key) is not None:
|
||||
sub_config = getattr(model.config, key)
|
||||
self.assertTrue(sub_config._attn_implementation == "eager")
|
||||
|
||||
for name, submodule in model.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
@ -3934,8 +3940,9 @@ class ModelTesterMixin:
|
||||
# Update config values
|
||||
update_config_headdim(config, requested_dim)
|
||||
for key in config.sub_configs:
|
||||
sub_config = getattr(config, key)
|
||||
update_config_headdim(sub_config, requested_dim)
|
||||
if getattr(config, key) is not None:
|
||||
sub_config = getattr(config, key)
|
||||
update_config_headdim(sub_config, requested_dim)
|
||||
|
||||
return config
|
||||
|
||||
@ -4119,7 +4126,10 @@ class ModelTesterMixin:
|
||||
for subconfig_key in subconfig_keys:
|
||||
# Get the subconfig from the model config
|
||||
subconfig_from_model_config = getattr(model.config, subconfig_key)
|
||||
if subconfig_from_model_config.__class__ == subconfig_from_model_internal.__class__:
|
||||
if (
|
||||
subconfig_from_model_config is not None
|
||||
and subconfig_from_model_config.__class__ == subconfig_from_model_internal.__class__
|
||||
):
|
||||
# Since some composite models have different submodels parameterized by 2 of the same config
|
||||
# class instances, we need to check against a list of matching classes, and check that at least
|
||||
# 1 is the exact object (instead of checking immediately for similar object)
|
||||
@ -4150,7 +4160,8 @@ class ModelTesterMixin:
|
||||
# sanity check to make sure everything is correctly eager
|
||||
self.assertTrue(model.config._attn_implementation == "eager")
|
||||
for subconfig_key in model.config.sub_configs:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "eager")
|
||||
if getattr(config, subconfig_key) is not None:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "eager")
|
||||
|
||||
if not all(
|
||||
submodule._can_set_attn_implementation()
|
||||
@ -4170,7 +4181,8 @@ class ModelTesterMixin:
|
||||
# Check everything was correctly changed
|
||||
self.assertTrue(model.config._attn_implementation == "sdpa")
|
||||
for subconfig_key in model.config.sub_configs:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "sdpa")
|
||||
if getattr(config, subconfig_key) is not None:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "sdpa")
|
||||
|
||||
# Check we cannot set it to random values, and it raises an error
|
||||
with self.assertRaisesRegex(ValueError, 'Specified `attn_implementation="foo"` is not supported'):
|
||||
@ -4179,7 +4191,8 @@ class ModelTesterMixin:
|
||||
# Should still be sdpa everywhere
|
||||
self.assertTrue(model.config._attn_implementation == "sdpa")
|
||||
for subconfig_key in model.config.sub_configs:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "sdpa")
|
||||
if getattr(config, subconfig_key) is not None:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "sdpa")
|
||||
|
||||
def test_can_set_attention_dynamically_composite_model(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -4198,7 +4211,8 @@ class ModelTesterMixin:
|
||||
# sanity check to make sure everything is correctly eager
|
||||
self.assertTrue(model.config._attn_implementation == "eager")
|
||||
for subconfig_key in model.config.sub_configs:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "eager")
|
||||
if getattr(config, subconfig_key) is not None:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "eager")
|
||||
|
||||
if not all(
|
||||
submodule._can_set_attn_implementation()
|
||||
@ -4213,7 +4227,8 @@ class ModelTesterMixin:
|
||||
# Check only top-most was correctly changed
|
||||
self.assertTrue(model.config._attn_implementation == "sdpa")
|
||||
for subconfig_key in model.config.sub_configs:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "eager")
|
||||
if getattr(config, subconfig_key) is not None:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "eager")
|
||||
|
||||
@require_torch
|
||||
def test_bc_torch_dtype(self):
|
||||
|
Reference in New Issue
Block a user