Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)

Compare commits: v4.50.2...docs-ctrl- (4 commits)
Commits:
- 99b11c1697
- b36cffc24a
- e3eb1a41c9
- 208251a3f0
@@ -20,7 +20,9 @@ from ...utils import logging
 
 logger = logging.get_logger(__name__)
 
-CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://huggingface.co/ctrl/resolve/main/config.json"}
+CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+    "Salesforce/ctrl": "https://huggingface.co/Salesforce/ctrl/resolve/main/config.json"
+}
 
 
 class CTRLConfig(PretrainedConfig):
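Not part of the diff: a minimal sketch of how the renamed checkpoint resolves, assuming network access to the Hugging Face Hub.

from transformers import CTRLConfig

# Fetches https://huggingface.co/Salesforce/ctrl/resolve/main/config.json under the hood.
config = CTRLConfig.from_pretrained("Salesforce/ctrl")
print(config.model_type)  # "ctrl"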
@@ -28,7 +30,7 @@ class CTRLConfig(PretrainedConfig):
     This is the configuration class to store the configuration of a [`CTRLModel`] or a [`TFCTRLModel`]. It is used to
     instantiate a CTRL model according to the specified arguments, defining the model architecture. Instantiating a
     configuration with the defaults will yield a similar configuration to that of the
-    [ctrl](https://huggingface.co/ctrl) architecture from SalesForce.
+    [Salesforce/ctrl](https://huggingface.co/Salesforce/ctrl) architecture from SalesForce.
 
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
@@ -34,7 +34,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "CTRLConfig"
 
 CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "ctrl"
+    "Salesforce/ctrl"
     # See all CTRL models at https://huggingface.co/models?filter=ctrl
 ]
 
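Not part of the diff: the filter URL in the comment above can also be queried programmatically. A small sketch, assuming the huggingface_hub client is installed.

from huggingface_hub import list_models

# Mirrors https://huggingface.co/models?filter=ctrl
for model_info in list_models(filter="ctrl"):
    print(model_info.id)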
@@ -374,8 +374,8 @@ class CTRLModel(CTRLPreTrainedModel):
         >>> from transformers import AutoTokenizer, CTRLModel
         >>> import torch
 
-        >>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
-        >>> model = CTRLModel.from_pretrained("ctrl")
+        >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
+        >>> model = CTRLModel.from_pretrained("Salesforce/ctrl")
 
         >>> # CTRL was trained with control codes as the first token
         >>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
@@ -564,8 +564,8 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
         >>> import torch
         >>> from transformers import AutoTokenizer, CTRLLMHeadModel
 
-        >>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
-        >>> model = CTRLLMHeadModel.from_pretrained("ctrl")
+        >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
+        >>> model = CTRLLMHeadModel.from_pretrained("Salesforce/ctrl")
 
         >>> # CTRL was trained with control codes as the first token
         >>> inputs = tokenizer("Wikipedia The llama is", return_tensors="pt")
@@ -692,8 +692,8 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
         >>> import torch
         >>> from transformers import AutoTokenizer, CTRLForSequenceClassification
 
-        >>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
-        >>> model = CTRLForSequenceClassification.from_pretrained("ctrl")
+        >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
+        >>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl")
 
         >>> # CTRL was trained with control codes as the first token
        >>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
@@ -713,7 +713,7 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
         >>> torch.manual_seed(42)  # doctest: +IGNORE_RESULT
         >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
         >>> num_labels = len(model.config.id2label)
-        >>> model = CTRLForSequenceClassification.from_pretrained("ctrl", num_labels=num_labels)
+        >>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl", num_labels=num_labels)
 
         >>> labels = torch.tensor(1)
         >>> loss = model(**inputs, labels=labels).loss
@@ -727,8 +727,10 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
         >>> import torch
         >>> from transformers import AutoTokenizer, CTRLForSequenceClassification
 
-        >>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
-        >>> model = CTRLForSequenceClassification.from_pretrained("ctrl", problem_type="multi_label_classification")
+        >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
+        >>> model = CTRLForSequenceClassification.from_pretrained(
+        ...     "Salesforce/ctrl", problem_type="multi_label_classification"
+        ... )
 
         >>> # CTRL was trained with control codes as the first token
         >>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
@@ -745,7 +747,7 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
         ```python
         >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
         >>> num_labels = len(model.config.id2label)
-        >>> model = CTRLForSequenceClassification.from_pretrained("ctrl", num_labels=num_labels)
+        >>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl", num_labels=num_labels)
 
         >>> num_labels = len(model.config.id2label)
         >>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to(
@@ -39,11 +39,11 @@ from .configuration_ctrl import CTRLConfig
 
 logger = logging.get_logger(__name__)
 
-_CHECKPOINT_FOR_DOC = "ctrl"
+_CHECKPOINT_FOR_DOC = "Salesforce/ctrl"
 _CONFIG_FOR_DOC = "CTRLConfig"
 
 TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "ctrl"
+    "Salesforce/ctrl"
     # See all CTRL models at https://huggingface.co/models?filter=ctrl
 ]
 
@@ -264,7 +264,7 @@ class CTRLModelLanguageGenerationTest(unittest.TestCase):
 
     @slow
     def test_lm_generate_ctrl(self):
-        model = CTRLLMHeadModel.from_pretrained("ctrl")
+        model = CTRLLMHeadModel.from_pretrained("Salesforce/ctrl")
         model.to(torch_device)
         input_ids = torch.tensor(
             [[11859, 0, 1611, 8]], dtype=torch.long, device=torch_device
@@ -257,7 +257,7 @@ class TFCTRLModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
 class TFCTRLModelLanguageGenerationTest(unittest.TestCase):
     @slow
     def test_lm_generate_ctrl(self):
-        model = TFCTRLLMHeadModel.from_pretrained("ctrl")
+        model = TFCTRLLMHeadModel.from_pretrained("Salesforce/ctrl")
         input_ids = tf.convert_to_tensor([[11859, 0, 1611, 8]], dtype=tf.int32)  # Legal the president is
         expected_output_ids = [
             11859,
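Not part of the diff: per the inline comment, the hard-coded ids in these tests correspond to the prompt "Legal the president is". A small sketch to reproduce them with the updated checkpoint, assuming Hub access.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
# Per the test comment, this prompt should encode to [11859, 0, 1611, 8].
print(tokenizer("Legal the president is").input_ids)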
@@ -16,6 +16,7 @@
 import inspect
 import re
 
+from transformers import CTRLConfig
 from transformers.utils import direct_transformers_import
 
 
@@ -73,7 +74,9 @@ def get_checkpoint_from_config_class(config_class):
 def check_config_docstrings_have_checkpoints():
     configs_without_checkpoint = []
 
-    for config_class in list(CONFIG_MAPPING.values()):
+    a = [CTRLConfig]
+
+    for config_class in a:
         # Skip deprecated models
         if "models.deprecated" in config_class.__module__:
             continue
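Not part of the diff: check_config_docstrings_have_checkpoints verifies that config docstrings link to a checkpoint, and the hunk above temporarily restricts the loop to CTRLConfig. A simplified, standalone sketch of that idea (not the repository's exact implementation), using an assumed regex for Hub links.

import re

from transformers import CTRLConfig

# A config docstring is expected to contain at least one link of the form
# [Salesforce/ctrl](https://huggingface.co/Salesforce/ctrl).
CHECKPOINT_LINK = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")

links = CHECKPOINT_LINK.findall(CTRLConfig.__doc__ or "")
print(links)  # e.g. [('Salesforce/ctrl', 'https://huggingface.co/Salesforce/ctrl')]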
@@ -1,2 +1,4 @@
 docs/source/en/generation_strategies.md
+docs/source/en/model_doc/ctrl.md
 docs/source/en/task_summary.md
+src/transformers/models/ctrl/modeling_ctrl.py