Source code for hfselect.task_ranking

from typing import Optional
from collections.abc import Sequence

import numpy as np
import pandas as pd

from .esmconfig import ESMConfig


[docs] class InvalidTaskRankingError(Exception): """ An Exception raised when the task ranking contains an error """ default_message = "The task ranking is invalid." def __init__(self, message: Optional[str] = None): super().__init__(message or self.default_message)
[docs] class TaskRanking(Sequence): """ A task ranking contains the esm configs of ranked ESMS, their scores and their ranks """ def __init__( self, esm_configs: list[ESMConfig], scores: list[float], ranks: Optional[list[int]] = None, ): """ Creates a new task ranking Args: esm_configs: The ESMConfigs of the ESMs that are ranked scores: The ESM-LogME score of the ESMs / their datasets ranks: The ranks of the ranking (advised to not change) """ self.esm_configs = esm_configs self.scores = scores self.ranks = ranks self._remove_faulty_scores() self._check_validity() self._sort() def __getitem__(self, index): if isinstance(index, int): return TaskRanking( [self.esm_configs[index]], [self.scores[index]], [self.ranks[index]] ) elif isinstance(index, slice): return TaskRanking( self.esm_configs[index], self.scores[index], self.ranks[index] ) raise Exception(f"TaskRanking could not be indexed with a type {type(index)}.") def __len__(self): return len(self.esm_configs) def __repr__(self): return "\n".join(self._format_for_output()) def __str__(self): return "\n".join(self[:10]._format_for_output(score_rounding=6)) + ( "\n..." if len(self) > 10 else "" ) def _format_for_output(self, score_rounding: Optional[int] = None): output_lines = [] for rank, esm_config, score in zip(self.ranks, self.esm_configs, self.scores): output_lines.append( f"{str(rank) + '.':<5}{esm_config.task_id:<50}" f"Score: {round(score, score_rounding) if score_rounding else score}" ) return output_lines def _sort(self): sorting_order = np.argsort(self.scores)[::-1] self.esm_configs = [self.esm_configs[idx] for idx in sorting_order] self.scores = [self.scores[idx] for idx in sorting_order] self.ranks = ( list(range(1, len(self) + 1)) if self.ranks is None else [self.ranks[idx] for idx in sorting_order] ) def _remove_faulty_scores(self): faulty_indices = [ idx for idx, score in enumerate(self.scores) if np.isnan(score) ] self.esm_configs = self._remove_indices(self.esm_configs, faulty_indices) self.scores = self._remove_indices(self.scores, faulty_indices) self.ranks = self._remove_indices(self.ranks, faulty_indices) @staticmethod def _remove_indices(list_to_clean: Optional[list], indices: list[int]): if list_to_clean is None: return if len(indices) == 0: return list_to_clean return [elem for idx, elem in enumerate(list_to_clean) if idx not in indices] def _check_validity(self): if len(self.esm_configs) != len(self.scores): raise InvalidTaskRankingError( f"Task ranking contains {len(self.esm_configs)} ESM configs but {len(self.scores)} scores." ) if self.ranks is not None and len(self.esm_configs) != len(self.ranks): raise InvalidTaskRankingError( f"Task ranking contains {len(self.esm_configs)} ESM configs but {len(self.ranks)} provided ranks." ) if len(self.esm_configs) == 0: raise InvalidTaskRankingError("Task ranking is empty.")
[docs] def to_pandas(self) -> pd.DataFrame: """ Creates a Pandas DataFrame of the ranking Returns: The resulting dataframe """ return pd.DataFrame.from_dict( { "Rank": self.ranks, "Task ID": [esm_config.task_id for esm_config in self.esm_configs], "Task Subset": [ esm_config.task_subset for esm_config in self.esm_configs ], "Text Column": [ esm_config.text_column for esm_config in self.esm_configs ], "Label Column": [ esm_config.label_column for esm_config in self.esm_configs ], "Task Split": [ esm_config.task_split for esm_config in self.esm_configs ], "Num Examples": [ esm_config.num_examples for esm_config in self.esm_configs ], "ESM Architecture": [ esm_config.esm_architecture for esm_config in self.esm_configs ], "Score": self.scores, } ).set_index("Rank")