Compare commits

...

4 Commits

Author SHA1 Message Date
99b11c1697 Add modeling file 2023-09-18 13:14:21 +02:00
b36cffc24a Slow doc tests 2023-09-18 11:50:18 +02:00
e3eb1a41c9 Fixup 2023-09-15 12:34:01 -04:00
208251a3f0 moved ctrl to Salesforce/ctrl
redirects should theoretically work, but still updating those repo references for clarity
2023-09-15 14:02:43 +02:00
7 changed files with 27 additions and 18 deletions

View File

@ -20,7 +20,9 @@ from ...utils import logging
logger = logging.get_logger(__name__)
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://huggingface.co/ctrl/resolve/main/config.json"}
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"Salesforce/ctrl": "https://huggingface.co/Salesforce/ctrl/resolve/main/config.json"
}
class CTRLConfig(PretrainedConfig):
@ -28,7 +30,7 @@ class CTRLConfig(PretrainedConfig):
This is the configuration class to store the configuration of a [`CTRLModel`] or a [`TFCTRLModel`]. It is used to
instantiate a CTRL model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the
[ctrl](https://huggingface.co/ctrl) architecture from Salesforce.
[Salesforce/ctrl](https://huggingface.co/Salesforce/ctrl) architecture from Salesforce.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.

View File

@ -34,7 +34,7 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "CTRLConfig"
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
"ctrl"
"Salesforce/ctrl"
# See all CTRL models at https://huggingface.co/models?filter=ctrl
]
@ -374,8 +374,8 @@ class CTRLModel(CTRLPreTrainedModel):
>>> from transformers import AutoTokenizer, CTRLModel
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
>>> model = CTRLModel.from_pretrained("ctrl")
>>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
>>> model = CTRLModel.from_pretrained("Salesforce/ctrl")
>>> # CTRL was trained with control codes as the first token
>>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
@ -564,8 +564,8 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
>>> import torch
>>> from transformers import AutoTokenizer, CTRLLMHeadModel
>>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
>>> model = CTRLLMHeadModel.from_pretrained("ctrl")
>>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
>>> model = CTRLLMHeadModel.from_pretrained("Salesforce/ctrl")
>>> # CTRL was trained with control codes as the first token
>>> inputs = tokenizer("Wikipedia The llama is", return_tensors="pt")
@ -692,8 +692,8 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
>>> import torch
>>> from transformers import AutoTokenizer, CTRLForSequenceClassification
>>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
>>> model = CTRLForSequenceClassification.from_pretrained("ctrl")
>>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
>>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl")
>>> # CTRL was trained with control codes as the first token
>>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
@ -713,7 +713,7 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
>>> torch.manual_seed(42) # doctest: +IGNORE_RESULT
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
>>> num_labels = len(model.config.id2label)
>>> model = CTRLForSequenceClassification.from_pretrained("ctrl", num_labels=num_labels)
>>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl", num_labels=num_labels)
>>> labels = torch.tensor(1)
>>> loss = model(**inputs, labels=labels).loss
@ -727,8 +727,10 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
>>> import torch
>>> from transformers import AutoTokenizer, CTRLForSequenceClassification
>>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
>>> model = CTRLForSequenceClassification.from_pretrained("ctrl", problem_type="multi_label_classification")
>>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
>>> model = CTRLForSequenceClassification.from_pretrained(
... "Salesforce/ctrl", problem_type="multi_label_classification"
... )
>>> # CTRL was trained with control codes as the first token
>>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
@ -745,7 +747,7 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
```python
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
>>> num_labels = len(model.config.id2label)
>>> model = CTRLForSequenceClassification.from_pretrained("ctrl", num_labels=num_labels)
>>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl", num_labels=num_labels)
>>> num_labels = len(model.config.id2label)
>>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to(

View File

@ -39,11 +39,11 @@ from .configuration_ctrl import CTRLConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "ctrl"
_CHECKPOINT_FOR_DOC = "Salesforce/ctrl"
_CONFIG_FOR_DOC = "CTRLConfig"
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
"ctrl"
"Salesforce/ctrl"
# See all CTRL models at https://huggingface.co/models?filter=ctrl
]

View File

@ -264,7 +264,7 @@ class CTRLModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_ctrl(self):
model = CTRLLMHeadModel.from_pretrained("ctrl")
model = CTRLLMHeadModel.from_pretrained("Salesforce/ctrl")
model.to(torch_device)
input_ids = torch.tensor(
[[11859, 0, 1611, 8]], dtype=torch.long, device=torch_device

View File

@ -257,7 +257,7 @@ class TFCTRLModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
class TFCTRLModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_ctrl(self):
model = TFCTRLLMHeadModel.from_pretrained("ctrl")
model = TFCTRLLMHeadModel.from_pretrained("Salesforce/ctrl")
input_ids = tf.convert_to_tensor([[11859, 0, 1611, 8]], dtype=tf.int32) # Legal the president is
expected_output_ids = [
11859,

View File

@ -16,6 +16,7 @@
import inspect
import re
from transformers import CTRLConfig
from transformers.utils import direct_transformers_import
@ -73,7 +74,9 @@ def get_checkpoint_from_config_class(config_class):
def check_config_docstrings_have_checkpoints():
configs_without_checkpoint = []
for config_class in list(CONFIG_MAPPING.values()):
a = [CTRLConfig]
for config_class in a:
# Skip deprecated models
if "models.deprecated" in config_class.__module__:
continue

View File

@ -1,2 +1,4 @@
docs/source/en/generation_strategies.md
docs/source/en/model_doc/ctrl.md
docs/source/en/task_summary.md
src/transformers/models/ctrl/modeling_ctrl.py