Source code for hfselect.esm_logme

from typing import Optional
from collections import defaultdict

from tqdm.auto import tqdm
import numpy as np
import torch
from torch.utils.data import SequentialSampler, DataLoader
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer

from .logme import LogME
from .model_utils import get_pooled_output
from .utils import fetch_esms, find_esm_repo_ids
from .dataset import Dataset
from .task_ranking import TaskRanking
from .esm import ESM
from hfselect import logger


class NoESMsFoundError(Exception):
    """Raised when no ESMs matching the search criteria could be found.

    Args:
        message: Optional custom error message. Defaults to the generic
            "No ESMs matching the search criteria could be found." message.
    """

    def __init__(
        self, message: str = "No ESMs matching the search criteria could be found."
    ):
        super().__init__(message)
def compute_scores(
    dataset: Dataset,
    base_model: PreTrainedModel,
    esms: list[ESM],
    tokenizer: PreTrainedTokenizer,
    batch_size: int = 128,
    device_name: str = "cpu",
) -> list[float]:
    """
    Computes the ESM-LogME scores for all ESMs.

    The target texts are embedded batch by batch with the base model. Every
    ESM transforms those base embeddings, and LogME is then fitted on each
    ESM's transformed embeddings against the target labels. An ESM that
    raises an exception while transforming embeddings is skipped for the
    remaining batches and receives a NaN score.

    Args:
        dataset: The target dataset
        base_model: The base LM used for computing embeddings. It is moved to
            ``device_name`` and put in eval mode while embeddings are computed.
            Its previous train/eval mode is restored afterwards.
        esms: List of the ESMs representing the intermediate datasets. They are
            moved to ``device_name``.
        tokenizer: The tokenizer used for tokenizing the target texts
        batch_size: Describes how many embeddings are computed and transformed in a batch
        device_name: The device name of the device for computation (e.g. "cpu", "cuda")

    Returns:
        scores: The ESM-LogME scores produced by the ESMs, in the order of
            ``esms``. The score is ``np.nan`` for ESMs whose embeddings could
            not be computed.
    """
    regression = dataset.is_regression
    label_dtype = float if regression else int

    esm_embeddings, labels, faulty_esm_indices, errors = _compute_esm_embeddings(
        dataset=dataset,
        base_model=base_model,
        esms=esms,
        tokenizer=tokenizer,
        batch_size=batch_size,
        device_name=device_name,
        label_dtype=label_dtype,
    )

    if errors:
        logger.warning(
            f"Computing embeddings failed for {len(faulty_esm_indices)} of {len(esms)} ESMs."
        )
        logger.debug(dict(errors))

    scores = []
    with tqdm(esm_embeddings, desc="Computing LogME", unit="Task") as pbar:
        for idx, features in enumerate(pbar):
            if idx in faulty_esm_indices:
                scores.append(np.nan)
                continue
            embeddings = np.vstack(features)
            scores.append(
                LogME(regression=regression).fit(
                    embeddings, labels, add_intercept=False
                )
            )
    return scores


def _compute_esm_embeddings(
    dataset: Dataset,
    base_model: PreTrainedModel,
    esms: list[ESM],
    tokenizer: PreTrainedTokenizer,
    batch_size: int,
    device_name: str,
    label_dtype: type,
) -> tuple:
    """
    Embeds the target texts with the base model and transforms them with every ESM.

    Returns:
        A tuple ``(esm_embeddings, labels, faulty_esm_indices, errors)``:

        - ``esm_embeddings[i]`` holds the per-batch transformed embeddings
          (numpy arrays) of ``esms[i]``. It is emptied if that ESM failed.
        - ``labels`` are all target labels concatenated in dataset order.
        - ``faulty_esm_indices`` is the set of indices of ESMs that raised.
        - ``errors`` maps exception class names to the repo IDs of the ESMs
          that raised them.
    """
    dataloader = DataLoader(
        dataset,
        sampler=SequentialSampler(dataset),
        batch_size=batch_size,
        collate_fn=lambda x: dataset.collate_fn(x, tokenizer=tokenizer),
    )

    device = torch.device(device_name)
    base_model.to(device)
    for esm in esms:
        esm.to(device)

    # Collect label batches and concatenate once at the end; repeated
    # np.append would copy the whole array on every batch.
    label_batches = []
    esm_embeddings = [[] for _ in esms]
    faulty_esm_indices = set()
    errors = defaultdict(list)

    # Disable dropout etc. so that embeddings are deterministic, but give the
    # caller back the model in the mode it was passed in.
    was_training = base_model.training
    base_model.eval()
    try:
        with tqdm(dataloader, desc="Computing embeddings", unit="batch") as pbar:
            for batch in pbar:
                batch = tuple(t.to(device) for t in batch)
                b_input_ids, b_input_mask, b_labels = batch
                label_batches.append(b_labels.detach().cpu().numpy().flatten())

                with torch.no_grad():
                    batch_base_embeddings = get_pooled_output(
                        base_model, b_input_ids, b_input_mask
                    )
                    for i, esm in enumerate(esms):
                        if i in faulty_esm_indices:
                            continue
                        try:
                            batch_transformed_embeddings = (
                                esm(batch_base_embeddings).cpu().numpy()
                            )
                        except Exception as e:
                            # Best-effort: a broken ESM must not abort scoring
                            # of the others. It is recorded and gets a NaN score.
                            faulty_esm_indices.add(i)
                            esm_embeddings[i] = []  # release partial results
                            errors[type(e).__name__].append(_esm_repo_id(esm))
                            continue
                        esm_embeddings[i].append(batch_transformed_embeddings)
    finally:
        base_model.train(was_training)

    # Seeding with an empty array of label_dtype reproduces the dtype promotion
    # of the previous incremental np.append and handles an empty dataloader.
    labels = np.concatenate([np.zeros(0, label_dtype), *label_batches], axis=0)
    return esm_embeddings, labels, faulty_esm_indices, errors


def _esm_repo_id(esm: ESM) -> Optional[str]:
    """Best-effort lookup of an ESM's repo ID for error reporting; never raises."""
    config = getattr(esm, "config", None)
    try:
        return config.get("repo_id")
    except (AttributeError, TypeError):
        return getattr(config, "repo_id", None)
def compute_task_ranking(
    dataset: Dataset,
    model_name: str,
    esms: Optional[list[ESM]] = None,
    esm_repo_ids: Optional[list[str]] = None,
    batch_size: int = 128,
    device_name: str = "cpu",
) -> TaskRanking:
    """
    Computes a task ranking by first computing scores and then ranking the intermediate datasets by their scores.

    ESMs are resolved in this order:

    1. ``esms`` if given.
    2. Otherwise the ESMs fetched from ``esm_repo_ids``.
    3. Otherwise the ESMs fetched from the repo IDs found for ``model_name``.

    Args:
        dataset: The target dataset
        model_name: The name of the base LM used for computing embeddings
        esms: List of the ESMs representing the intermediate datasets
        esm_repo_ids: List of the HF repo IDs of the ESMs representing the intermediate datasets
        batch_size: Describes how many embeddings are computed and transformed in a batch
        device_name: The device name of the device for computation (e.g. "cpu", "cuda")

    Returns:
        task_ranking: A task ranking of the intermediate tasks. ESMs whose
            embeddings could not be computed are passed to TaskRanking with a
            NaN score. The documented contract is that intermediate datasets
            with invalid ESMs are excluded from the ranking.

    Raises:
        NoESMsFoundError: If no ESMs are available after resolution.
    """
    if esms is None:
        if esm_repo_ids is None:
            esm_repo_ids = find_esm_repo_ids(model_name=model_name)
        esms = fetch_esms(esm_repo_ids)

    if not esms:
        logger.error(
            "No ESMs matching the search criteria could be found. "
            "You can use get_esm_coverage to find out which base models have valid ESMs."
        )
        raise NoESMsFoundError()

    base_model = AutoModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    scores = compute_scores(
        dataset=dataset,
        base_model=base_model,
        tokenizer=tokenizer,
        esms=esms,
        batch_size=batch_size,
        device_name=device_name,
    )

    return TaskRanking([esm.create_config() for esm in esms], scores)