mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-21 17:48:57 +08:00
Compare commits
4 Commits
v4.35.0
...
docs-ctrl-
Author | SHA1 | Date | |
---|---|---|---|
99b11c1697 | |||
b36cffc24a | |||
e3eb1a41c9 | |||
208251a3f0 |
@ -20,7 +20,9 @@ from ...utils import logging
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://huggingface.co/ctrl/resolve/main/config.json"}
|
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
"Salesforce/ctrl": "https://huggingface.co/Salesforce/ctrl/resolve/main/config.json"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class CTRLConfig(PretrainedConfig):
|
class CTRLConfig(PretrainedConfig):
|
||||||
@ -28,7 +30,7 @@ class CTRLConfig(PretrainedConfig):
|
|||||||
This is the configuration class to store the configuration of a [`CTRLModel`] or a [`TFCTRLModel`]. It is used to
|
This is the configuration class to store the configuration of a [`CTRLModel`] or a [`TFCTRLModel`]. It is used to
|
||||||
instantiate a CTRL model according to the specified arguments, defining the model architecture. Instantiating a
|
instantiate a CTRL model according to the specified arguments, defining the model architecture. Instantiating a
|
||||||
configuration with the defaults will yield a similar configuration to that of the
|
configuration with the defaults will yield a similar configuration to that of the
|
||||||
[ctrl](https://huggingface.co/ctrl) architecture from SalesForce.
|
[Salesforce/ctrl](https://huggingface.co/Salesforce/ctrl) architecture from SalesForce.
|
||||||
|
|
||||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
documentation from [`PretrainedConfig`] for more information.
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
@ -34,7 +34,7 @@ logger = logging.get_logger(__name__)
|
|||||||
_CONFIG_FOR_DOC = "CTRLConfig"
|
_CONFIG_FOR_DOC = "CTRLConfig"
|
||||||
|
|
||||||
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"ctrl"
|
"Salesforce/ctrl"
|
||||||
# See all CTRL models at https://huggingface.co/models?filter=ctrl
|
# See all CTRL models at https://huggingface.co/models?filter=ctrl
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -374,8 +374,8 @@ class CTRLModel(CTRLPreTrainedModel):
|
|||||||
>>> from transformers import AutoTokenizer, CTRLModel
|
>>> from transformers import AutoTokenizer, CTRLModel
|
||||||
>>> import torch
|
>>> import torch
|
||||||
|
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
|
>>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
|
||||||
>>> model = CTRLModel.from_pretrained("ctrl")
|
>>> model = CTRLModel.from_pretrained("Salesforce/ctrl")
|
||||||
|
|
||||||
>>> # CTRL was trained with control codes as the first token
|
>>> # CTRL was trained with control codes as the first token
|
||||||
>>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
|
>>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
|
||||||
@ -564,8 +564,8 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
|||||||
>>> import torch
|
>>> import torch
|
||||||
>>> from transformers import AutoTokenizer, CTRLLMHeadModel
|
>>> from transformers import AutoTokenizer, CTRLLMHeadModel
|
||||||
|
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
|
>>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
|
||||||
>>> model = CTRLLMHeadModel.from_pretrained("ctrl")
|
>>> model = CTRLLMHeadModel.from_pretrained("Salesforce/ctrl")
|
||||||
|
|
||||||
>>> # CTRL was trained with control codes as the first token
|
>>> # CTRL was trained with control codes as the first token
|
||||||
>>> inputs = tokenizer("Wikipedia The llama is", return_tensors="pt")
|
>>> inputs = tokenizer("Wikipedia The llama is", return_tensors="pt")
|
||||||
@ -692,8 +692,8 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
|
|||||||
>>> import torch
|
>>> import torch
|
||||||
>>> from transformers import AutoTokenizer, CTRLForSequenceClassification
|
>>> from transformers import AutoTokenizer, CTRLForSequenceClassification
|
||||||
|
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
|
>>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
|
||||||
>>> model = CTRLForSequenceClassification.from_pretrained("ctrl")
|
>>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl")
|
||||||
|
|
||||||
>>> # CTRL was trained with control codes as the first token
|
>>> # CTRL was trained with control codes as the first token
|
||||||
>>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
|
>>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
|
||||||
@ -713,7 +713,7 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
|
|||||||
>>> torch.manual_seed(42) # doctest: +IGNORE_RESULT
|
>>> torch.manual_seed(42) # doctest: +IGNORE_RESULT
|
||||||
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
|
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
|
||||||
>>> num_labels = len(model.config.id2label)
|
>>> num_labels = len(model.config.id2label)
|
||||||
>>> model = CTRLForSequenceClassification.from_pretrained("ctrl", num_labels=num_labels)
|
>>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl", num_labels=num_labels)
|
||||||
|
|
||||||
>>> labels = torch.tensor(1)
|
>>> labels = torch.tensor(1)
|
||||||
>>> loss = model(**inputs, labels=labels).loss
|
>>> loss = model(**inputs, labels=labels).loss
|
||||||
@ -727,8 +727,10 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
|
|||||||
>>> import torch
|
>>> import torch
|
||||||
>>> from transformers import AutoTokenizer, CTRLForSequenceClassification
|
>>> from transformers import AutoTokenizer, CTRLForSequenceClassification
|
||||||
|
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
|
>>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
|
||||||
>>> model = CTRLForSequenceClassification.from_pretrained("ctrl", problem_type="multi_label_classification")
|
>>> model = CTRLForSequenceClassification.from_pretrained(
|
||||||
|
... "Salesforce/ctrl", problem_type="multi_label_classification"
|
||||||
|
... )
|
||||||
|
|
||||||
>>> # CTRL was trained with control codes as the first token
|
>>> # CTRL was trained with control codes as the first token
|
||||||
>>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
|
>>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
|
||||||
@ -745,7 +747,7 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
|
|||||||
```python
|
```python
|
||||||
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
|
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
|
||||||
>>> num_labels = len(model.config.id2label)
|
>>> num_labels = len(model.config.id2label)
|
||||||
>>> model = CTRLForSequenceClassification.from_pretrained("ctrl", num_labels=num_labels)
|
>>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl", num_labels=num_labels)
|
||||||
|
|
||||||
>>> num_labels = len(model.config.id2label)
|
>>> num_labels = len(model.config.id2label)
|
||||||
>>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to(
|
>>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to(
|
||||||
|
@ -39,11 +39,11 @@ from .configuration_ctrl import CTRLConfig
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CHECKPOINT_FOR_DOC = "ctrl"
|
_CHECKPOINT_FOR_DOC = "Salesforce/ctrl"
|
||||||
_CONFIG_FOR_DOC = "CTRLConfig"
|
_CONFIG_FOR_DOC = "CTRLConfig"
|
||||||
|
|
||||||
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"ctrl"
|
"Salesforce/ctrl"
|
||||||
# See all CTRL models at https://huggingface.co/models?filter=ctrl
|
# See all CTRL models at https://huggingface.co/models?filter=ctrl
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -264,7 +264,7 @@ class CTRLModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_lm_generate_ctrl(self):
|
def test_lm_generate_ctrl(self):
|
||||||
model = CTRLLMHeadModel.from_pretrained("ctrl")
|
model = CTRLLMHeadModel.from_pretrained("Salesforce/ctrl")
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
input_ids = torch.tensor(
|
input_ids = torch.tensor(
|
||||||
[[11859, 0, 1611, 8]], dtype=torch.long, device=torch_device
|
[[11859, 0, 1611, 8]], dtype=torch.long, device=torch_device
|
||||||
|
@ -257,7 +257,7 @@ class TFCTRLModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
|||||||
class TFCTRLModelLanguageGenerationTest(unittest.TestCase):
|
class TFCTRLModelLanguageGenerationTest(unittest.TestCase):
|
||||||
@slow
|
@slow
|
||||||
def test_lm_generate_ctrl(self):
|
def test_lm_generate_ctrl(self):
|
||||||
model = TFCTRLLMHeadModel.from_pretrained("ctrl")
|
model = TFCTRLLMHeadModel.from_pretrained("Salesforce/ctrl")
|
||||||
input_ids = tf.convert_to_tensor([[11859, 0, 1611, 8]], dtype=tf.int32) # Legal the president is
|
input_ids = tf.convert_to_tensor([[11859, 0, 1611, 8]], dtype=tf.int32) # Legal the president is
|
||||||
expected_output_ids = [
|
expected_output_ids = [
|
||||||
11859,
|
11859,
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from transformers import CTRLConfig
|
||||||
from transformers.utils import direct_transformers_import
|
from transformers.utils import direct_transformers_import
|
||||||
|
|
||||||
|
|
||||||
@ -73,7 +74,9 @@ def get_checkpoint_from_config_class(config_class):
|
|||||||
def check_config_docstrings_have_checkpoints():
|
def check_config_docstrings_have_checkpoints():
|
||||||
configs_without_checkpoint = []
|
configs_without_checkpoint = []
|
||||||
|
|
||||||
for config_class in list(CONFIG_MAPPING.values()):
|
a = [CTRLConfig]
|
||||||
|
|
||||||
|
for config_class in a:
|
||||||
# Skip deprecated models
|
# Skip deprecated models
|
||||||
if "models.deprecated" in config_class.__module__:
|
if "models.deprecated" in config_class.__module__:
|
||||||
continue
|
continue
|
||||||
|
@ -1,2 +1,4 @@
|
|||||||
docs/source/en/generation_strategies.md
|
docs/source/en/generation_strategies.md
|
||||||
|
docs/source/en/model_doc/ctrl.md
|
||||||
docs/source/en/task_summary.md
|
docs/source/en/task_summary.md
|
||||||
|
src/transformers/models/ctrl/modeling_ctrl.py
|
||||||
|
Reference in New Issue
Block a user