Source code for hfselect.model_utils

import torch
from transformers import (
    PreTrainedModel,
    AutoTokenizer,
    BertModel,
    BertConfig,
    BertForSequenceClassification,
    BertTokenizer,
    RobertaModel,
    RobertaConfig,
    RobertaForSequenceClassification,
    RobertaTokenizer,
    DistilBertModel,
    DistilBertConfig,
    DistilBertForSequenceClassification,
    DistilBertTokenizer,
)

# Keyword arguments that switch off both attention-weight and hidden-state
# outputs. Presumably forwarded when building the sequence-classification
# models listed in MODELS -- the call site is not visible in this section.
SEQUENCE_CLASSIFICATION_MODEL_ARGS = dict(
    output_attentions=False,
    output_hidden_states=False,
)

# Registry of supported language models, keyed by a short model identifier.
# Every entry carries the same eight fields:
#   base_model                    - transformers class of the bare encoder
#   sequence_classification_model - transformers class with a classification head
#   tokenizer                     - transformers tokenizer class
#   config                        - transformers config class
#   pretrained_name               - checkpoint identifier (presumably a Hugging
#                                   Face Hub name -- confirm)
#   base_model_attribute_name     - presumably the attribute under which the
#                                   classification model exposes its encoder;
#                                   verify against callers
#   embedding_size                - presumably the width of the pooled embedding
#                                   produced by the checkpoint -- confirm
#   pooling_method                - name of the pooling strategy; how it is
#                                   consumed is not visible in this section
MODELS = {
    # --- BERT checkpoints ---
    "bert": {
        "base_model": BertModel,
        "sequence_classification_model": BertForSequenceClassification,
        "tokenizer": BertTokenizer,
        "config": BertConfig,
        "pretrained_name": "bert-base-uncased",
        "base_model_attribute_name": "bert",
        "embedding_size": 768,
        "pooling_method": "bert_pooling",
    },
    "bert_multilang": {
        "base_model": BertModel,
        "sequence_classification_model": BertForSequenceClassification,
        "tokenizer": BertTokenizer,
        "config": BertConfig,
        "pretrained_name": "bert-base-multilingual-uncased",
        "base_model_attribute_name": "bert",
        "embedding_size": 768,
        "pooling_method": "bert_pooling",
    },
    # NOTE: unlike the other keys, this one uses a hyphen rather than an underscore.
    "bert-tiny": {
        "base_model": BertModel,
        "sequence_classification_model": BertForSequenceClassification,
        "tokenizer": BertTokenizer,
        "config": BertConfig,
        "pretrained_name": "prajjwal1/bert-tiny",
        "base_model_attribute_name": "bert",
        "embedding_size": 128,
        "pooling_method": "bert_pooling",
    },
    # --- RoBERTa ---
    "roberta-base": {
        "base_model": RobertaModel,
        "sequence_classification_model": RobertaForSequenceClassification,
        "tokenizer": RobertaTokenizer,
        "config": RobertaConfig,
        "pretrained_name": "roberta-base",
        "base_model_attribute_name": "roberta",
        "embedding_size": 768,
        "pooling_method": "roberta_pooling",
    },
    # --- DistilBERT ---
    "distilbert": {
        "base_model": DistilBertModel,
        "sequence_classification_model": DistilBertForSequenceClassification,
        "tokenizer": DistilBertTokenizer,
        "config": DistilBertConfig,
        "pretrained_name": "distilbert-base-uncased",
        "base_model_attribute_name": "distilbert",
        "embedding_size": 768,
        "pooling_method": "distilbert_pooling",
    },
    # --- sentence-transformers checkpoints (tokenizer resolved via AutoTokenizer) ---
    # NOTE(review): the checkpoint name below refers to an MPNet model, yet the
    # entry uses the Bert* model/config classes -- confirm this pairing is intended.
    "sentence_mpnet": {
        "base_model": BertModel,
        "sequence_classification_model": BertForSequenceClassification,
        "tokenizer": AutoTokenizer,
        "config": BertConfig,
        "pretrained_name": "sentence-transformers/all-mpnet-base-v2",
        "base_model_attribute_name": "bert",
        "embedding_size": 768,
        "pooling_method": "sbert_pooling",
    },
    "sentence_minilm": {
        "base_model": BertModel,
        "sequence_classification_model": BertForSequenceClassification,
        "tokenizer": AutoTokenizer,
        "config": BertConfig,
        "pretrained_name": "sentence-transformers/all-MiniLM-L6-v2",
        "base_model_attribute_name": "bert",
        "embedding_size": 384,
        "pooling_method": "sbert_pooling",
    },
    "sentence_minilm_multilang": {
        "base_model": BertModel,
        "sequence_classification_model": BertForSequenceClassification,
        "tokenizer": AutoTokenizer,
        "config": BertConfig,
        "pretrained_name": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
        "base_model_attribute_name": "bert",
        "embedding_size": 384,
        "pooling_method": "sbert_pooling",
    },
}


def get_pooled_output(
    base_model: PreTrainedModel, input_ids: torch.Tensor, attention_mask: torch.Tensor
):
    """
    Compute one embedding per input text with a language model

    The pooling strategy depends on the type of ``base_model``:

    * ``BertModel`` whose ``name_or_path`` contains ``"sentence-transformers"``:
      attention-mask-weighted mean of the token embeddings
    * any other ``BertModel``: the second element of the model output
    * ``DistilBertModel`` / ``RobertaModel``: position 0 of the first element
      of the model output

    Args:
        base_model: The language model
        input_ids: The input IDs of the texts (after tokenization)
        attention_mask: The attention masks of the texts (after tokenization)

    Returns:
        The embeddings of the texts

    Raises:
        NotImplementedError: If ``base_model`` is none of the supported model types
    """
    if isinstance(base_model, BertModel):
        if "sentence-transformers" not in base_model.name_or_path:
            # Plain BERT: take output element 1 (presumably the pooler output).
            return base_model(input_ids, attention_mask=attention_mask)[1]

        # Sentence-transformers checkpoint: mean-pool the token embeddings,
        # weighting each position by its attention-mask value.
        hidden_states = base_model(input_ids, attention_mask=attention_mask)[0]
        mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        masked_sum = torch.sum(hidden_states * mask, 1)
        # Clamp the denominator so an all-zero mask cannot divide by zero.
        mask_total = torch.clamp(mask.sum(1), min=1e-9)
        return masked_sum / mask_total

    if isinstance(base_model, DistilBertModel):
        # Position 0 of the first output (presumably the CLS token embedding).
        return base_model(input_ids, attention_mask=attention_mask)[0][:, 0]

    if isinstance(base_model, RobertaModel):
        # Same selection as for DistilBERT, written with an explicit last axis.
        return base_model(input_ids, attention_mask=attention_mask)[0][:, 0, :]

    raise NotImplementedError(
        "Method for getting CLS embeddings is not implemented for this model."
    )