Source code for hfselect.utils

from typing import Optional
from collections import defaultdict, Counter

from tqdm.auto import tqdm
from huggingface_hub import HfApi, model_info, ModelInfo

from .esm import ESM, ESMNotInitializedError
from .esmconfig import ESMConfig, InvalidESMConfigError
from hfselect import logger


[docs] def find_esm_repo_ids( model_name: Optional[str] = None, filters: Optional[list[str]] = None ) -> list[str]: """ Finds all ESM repo IDs for the specified language model name and filters Args: model_name: The name of the base language model filters: Filters for selecting ESMs (see hf_api.list_models) Returns: A list of ESM repo IDs """ esm_infos = find_esm_model_infos(model_name, filters=filters) return [esm_info.id for esm_info in esm_infos]
[docs] def find_esm_model_infos( model_name: Optional[str] = None, filters: Optional[list[str]] = None ) -> list[ModelInfo]: """ Finds HF ModelInfos for all ESMs specified by the language model name and filters Args: model_name: The name of the base language model filters: Filters for selecting ESMs (see hf_api.list_models) Returns: A list of ESM ModelInfos """ hf_api = HfApi() if filters is None: filters = [] elif isinstance(filters, str): filters = [filters] filters.append("embedding_space_map") if model_name: # Make sure that the possibly redirected repo name is used, # e.g. google-bert/bert-base-uncased instead of bert-base-uncased model_name = model_info(model_name).id filters.append(f"base_model:{model_name}") return list(hf_api.list_models(filter=filters))
[docs] def fetch_esms( repo_ids: list[str], ) -> list[ESM]: """ Fetches ESMs by their repo IDs. Invalid ESMs are excluded from the results. This can be seen in the logs. Args: repo_ids: The HF repo IDs of the ESMs Returns: A list of ESMs """ esms = [] errors = defaultdict(list) with tqdm(repo_ids, desc="Fetching ESMs", unit="ESM") as pbar: for repo_id in pbar: try: esm = ESM.from_pretrained(repo_id) if esm.is_legacy_model: esm.convert_legacy_to_new() if not esm.is_initialized: raise ESMNotInitializedError if not esm.create_config().is_valid: raise InvalidESMConfigError esms.append(esm) except Exception as e: errors[type(e).__name__].append(repo_id) if len(errors) > 0: if len(errors) > 0: logger.warning( f"Fetching ESMs failed for {sum(map(len, errors.values()))} of {len(repo_ids)} repo IDs." ) logger.debug(errors) return esms
[docs] def fetch_esm_configs( repo_ids: list[str], ) -> list[ESMConfig]: """ Fetches ESMConfigs by their repo IDs. Invalid ESMConfigs are excluded from the results. This can be seen in the logs. Args: repo_ids: The HF repo IDs of the ESMs Returns: A list of ESMConfigs """ esm_configs = [] errors = defaultdict(list) with tqdm(repo_ids, desc="Fetching ESM Configs", unit="ESM Config") as pbar: for repo_id in pbar: try: esm_config = ESMConfig.from_pretrained(repo_id) if not esm_config.is_valid: raise InvalidESMConfigError esm_configs.append(esm_config) except Exception as e: errors[type(e).__name__].append(repo_id) if len(errors) > 0: logger.warning( f"Fetching ESM configs failed for {sum(map(len, errors.values()))} of {len(repo_ids)} repo IDs." ) logger.debug(errors) return esm_configs
[docs] def get_esm_coverage(filters: Optional[list[str]] = None) -> dict[str, int]: """ Finds out how many ESMs are available for each base model Args: filters: Filters for selecting ESMs (see hf_api.list_models) Returns: A dictionary with base model names as keys and the number of available ESMs for them as items """ model_infos = find_esm_model_infos(filters=filters) base_model_names = [_get_base_model_from_model_info(info) for info in model_infos] return dict(sorted(Counter(base_model_names).items(), key=lambda x: -x[1]))
def _get_base_model_from_model_info(info: ModelInfo) -> Optional[str]: base_model_tags = [ tag for tag in info.tags if tag.startswith("base_model:finetune:") ] if len(base_model_tags) == 0: return None return base_model_tags[0].split("base_model:finetune:")[-1]