# Source code for hfselect.trainers

from typing import Optional, Union
from abc import ABC, abstractmethod
import os
from datetime import datetime
import time
import json

from tqdm.auto import tqdm
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import RandomSampler, DataLoader
from transformers import (
    get_linear_schedule_with_warmup,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from .esm import ESM
from .esmconfig import ESMConfig
from .embedding_dataset import EmbeddingDataset, create_embedding_dataset
from .dataset import Dataset
from hfselect import logger


class Trainer(ABC):
    """
    An abstract trainer class.

    Holds the model, optimizer, scheduler and computation device, and
    accumulates a running training loss (``total_loss`` over
    ``num_train_examples``) that subclasses update in ``_train_step``.
    """

    def __init__(
        self,
        model: Optional[nn.Module] = None,
        optimizer: Optional["torch.optim.Optimizer"] = None,
        learning_rate: float = 0.001,
        weight_decay: float = 0.01,
        device_name: str = "cpu",
    ):
        """
        Initializes the shared trainer state.

        Args:
            model: The model to train; may be created later by a subclass
            optimizer: The optimizer; may be created later via
                ``_create_optimizer``
            learning_rate: The learning rate used by ``_create_optimizer``
            weight_decay: The weight decay used by ``_create_optimizer``
            device_name: The requested device (e.g. "cpu", "cuda"). A non-CPU
                device is only used when CUDA is available; otherwise the
                trainer falls back to "cpu".
        """
        self.model = model
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.optimizer = optimizer
        self.scheduler = None

        # Honor a non-CPU device only if CUDA is actually available.
        if device_name != "cpu" and torch.cuda.is_available():
            self.device = torch.device(device_name)
        else:
            self.device = "cpu"

        self.total_loss = 0
        self.num_train_examples = 0

    @abstractmethod
    def _train_step(self, *args, **kwargs):
        pass

    @abstractmethod
    def _create_model(self):
        pass

    def _create_optimizer(self, model: nn.Module) -> AdamW:
        """Creates an AdamW optimizer over ``model``'s parameters."""
        return AdamW(
            model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )

    @staticmethod
    def _create_scheduler(
        optimizer: "torch.optim.Optimizer", num_train_steps: int
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        """Creates a linear-decay LR schedule without warmup over ``num_train_steps``."""
        return get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=0, num_training_steps=num_train_steps
        )

    def reset_loss(self):
        """
        Resets the accumulated loss and example count, e.g. at the start of an epoch.

        Returns:
            None
        """
        self.total_loss = 0
        self.num_train_examples = 0

    @property
    def avg_loss(self) -> float:
        """
        The average loss per training example.

        Returns:
            ``total_loss / num_train_examples``, or 0.0 if no examples have
            been recorded yet (avoids a ZeroDivisionError before training or
            right after ``reset_loss``).
        """
        if self.num_train_examples == 0:
            return 0.0
        return self.total_loss / self.num_train_examples
class ESMTrainer(Trainer):
    """
    A trainer class that fabricates ESMs.

    The ESM is fitted with an MSE loss so that, applied to the first tensor of
    each embedding pair ("standard" embeddings), it reproduces the second
    tensor of the pair ("transferred" embeddings).
    """

    def __init__(
        self,
        model: Optional[nn.Module] = None,
        optimizer: Optional["torch.optim.Optimizer"] = None,
        weight_decay: float = 0.01,
        learning_rate: float = 0.01,
        device_name: str = "cpu",
    ):
        """
        Creates an ESMTrainer

        Args:
            model: The underlying model to be used in the ESM
            optimizer: The optimizer for training the ESM
            weight_decay: The weight decay for training the ESM
            learning_rate: The learning rate for training the ESM
            device_name: The device name of the device for computation (e.g. "cpu", "cuda")
        """
        super().__init__(
            model=model,
            optimizer=optimizer,
            weight_decay=weight_decay,
            learning_rate=learning_rate,
            device_name=device_name,
        )
        self.loss_fct = nn.MSELoss()

    def _create_model(
        self,
        architecture: Optional[Union[str, dict[str, Union[str, tuple[str]]]]] = None,
        embedding_dim: Optional[int] = None,
    ) -> ESM:
        """Creates a new, untrained ESM."""
        return ESM(architecture=architecture, embedding_dim=embedding_dim)

    def _train_step(self, embeddings_batch: tuple[torch.Tensor, torch.Tensor]) -> float:
        """
        Runs one optimization step on one batch of embedding pairs.

        Args:
            embeddings_batch: A (standard_embeddings, transferred_embeddings) pair

        Returns:
            The batch's mean MSE loss
        """
        self.model.train()
        embeddings_batch = tuple(b.to(self.device) for b in embeddings_batch)
        b_standard_embeddings, b_transferred_embeddings = embeddings_batch

        self.model.zero_grad()
        outputs = self.model(b_standard_embeddings.float())
        loss = self.loss_fct(outputs, b_transferred_embeddings.float())

        batch_loss = loss.item()
        batch_len = len(b_standard_embeddings)
        # nn.MSELoss() defaults to reduction="mean", so the batch loss is
        # already an average. Weight it by the batch length so that
        # total_loss / num_train_examples is a true per-example average,
        # even when the last batch is smaller.
        self.total_loss += batch_loss * batch_len
        self.num_train_examples += batch_len

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()
        return batch_loss

    def train_with_embeddings(
        self,
        embedding_dataset: EmbeddingDataset,
        architecture: Optional[
            Union[str, dict[str, Union[str, tuple[str]]]]
        ] = "linear",
        output_dir: Optional[str] = None,
        num_epochs: int = 10,
        batch_size: int = 32,
        reset_model: bool = True,
        verbose: int = 1,
    ) -> ESM:
        """
        Trains an ESM using an EmbeddingDataset dataset. The ESM is fitted to the embedding pairs in the dataset.

        Args:
            embedding_dataset: The embeddings of the same dataset embedded by a base model and a fine-tuned model
            architecture: The desired architecture of the ESM
            output_dir: If a directory is specified, the ESM (and a train_info.json
                with timing and loss statistics) will be saved there after training
            num_epochs: The number of epochs for training the ESM
            batch_size: The batch size for training the ESM
            reset_model: If True (default), a fresh ESM and a fresh optimizer are
                created. If set to False, the existing model is trained further
                across multiple calls of the function.
            verbose: 0 hides all progress bars, 1 shows an epoch-level progress
                bar, and 2 additionally shows a per-batch progress bar for each epoch.

        Returns:
            The resulting ESM
        """
        if self.model is None or reset_model:
            self.model = self._create_model(
                architecture=architecture, embedding_dim=embedding_dataset.embedding_dim
            )
            # A new model has new parameters. Any existing optimizer still
            # references the previous model's parameters and would never
            # update this one, so it has to be rebuilt as well.
            self.optimizer = None
        self.model.to(self.device)

        if self.optimizer is None:
            self.optimizer = self._create_optimizer(model=self.model)

        dataloader = DataLoader(
            embedding_dataset,
            sampler=RandomSampler(embedding_dataset),
            batch_size=batch_size,
        )
        self.scheduler = self._create_scheduler(
            optimizer=self.optimizer, num_train_steps=len(dataloader) * num_epochs
        )

        epoch_train_durations = []
        epoch_avg_losses = []
        with tqdm(
            range(num_epochs), desc="Training ESM", unit="epoch", disable=verbose < 1
        ) as epoch_pbar:
            for epoch_i in epoch_pbar:
                self.reset_loss()
                start_time = time.perf_counter()
                with tqdm(
                    dataloader,
                    desc=f"Training: Epoch {epoch_i + 1} / {num_epochs}",
                    unit="batch",
                    disable=verbose < 2,
                ) as batch_pbar:
                    for batch in batch_pbar:
                        self._train_step(batch)
                        # Running per-example average over the current epoch.
                        avg_train_loss = self.avg_loss
                        epoch_pbar.set_postfix(avg_train_loss=avg_train_loss)
                        batch_pbar.set_postfix(avg_train_loss=avg_train_loss)
                epoch_train_durations.append(time.perf_counter() - start_time)
                epoch_avg_losses.append(self.avg_loss)

        self.model.config = ESMConfig(
            esm_num_epochs=num_epochs,
            esm_learning_rate=self.learning_rate,
            esm_weight_decay=self.weight_decay,
            esm_batch_size=batch_size,
            esm_architecture=architecture,
            esm_embedding_dim=embedding_dataset.embedding_dim,
        )
        self.model.config.update(embedding_dataset.metadata)

        if output_dir:
            if os.path.isdir(output_dir):
                logger.warning(f"Overwriting ESM at path: {output_dir}")
            self.model.save_pretrained(output_dir)

            train_info_dict = {
                "training_completed_timestamp": datetime.now().strftime(
                    "%m/%d/%Y, %H:%M:%S"
                ),
                "num_epochs": num_epochs,
                "num_train_examples": len(embedding_dataset),
                "epoch_train_durations": epoch_train_durations,
                "epoch_avg_losses": epoch_avg_losses,
            }
            with open(os.path.join(output_dir, "train_info.json"), "w") as f:
                json.dump(train_info_dict, f)

        return self.model

    def train_with_models(
        self,
        dataset: Dataset,
        base_model: PreTrainedModel,
        tuned_model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        architecture: Optional[
            Union[str, dict[str, Union[str, tuple[str]]]]
        ] = "linear",
        model_output_dir: Optional[str] = None,
        embeddings_output_filepath: Optional[str] = None,
        num_epochs: int = 10,
        train_batch_size: int = 32,
        embeddings_batch_size: int = 128,
        device_name: str = "cpu",
        reset_model: bool = True,
        verbose: int = 1,
    ) -> ESM:
        """
        Trains an ESM using a dataset, a base language model and a fine-tuned language model.

        Internally, an EmbeddingDataset is created. Then train_with_embeddings is
        called, and the ESM is fitted to the embedding pairs in the dataset.

        Args:
            dataset: The dataset used for fine-tuning the language model
            base_model: The base language model
            tuned_model: The fine-tuned language model
            tokenizer: The tokenizer for processing input texts
            architecture: The desired architecture of the ESM
            model_output_dir: If a directory is specified, the ESM will be saved locally after training
            embeddings_output_filepath: If a filepath is specified, the EmbeddingDataset will be saved locally
            num_epochs: The number of epochs for training the ESM
            train_batch_size: The batch size for training the ESM
            embeddings_batch_size: The batch size for creating the EmbeddingDataset
            device_name: The device name (e.g. "cpu", "cuda") used for computing the
                embeddings. The ESM itself is trained on the device this trainer was
                constructed with.
            reset_model: Forwarded to train_with_embeddings
            verbose: Forwarded to train_with_embeddings

        Returns:
            The resulting ESM
        """
        embedding_dataset = create_embedding_dataset(
            dataset=dataset,
            base_model=base_model,
            tuned_model=tuned_model,
            tokenizer=tokenizer,
            batch_size=embeddings_batch_size,
            device_name=device_name,
        )

        if embeddings_output_filepath:
            embedding_dataset.save(embeddings_output_filepath)

        return self.train_with_embeddings(
            embedding_dataset=embedding_dataset,
            architecture=architecture,
            output_dir=model_output_dir,
            num_epochs=num_epochs,
            batch_size=train_batch_size,
            reset_model=reset_model,
            verbose=verbose,
        )