from typing import Optional, Union, List, Iterable
import os
from tqdm.auto import tqdm
import numpy as np
import torch
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import SequentialSampler, DataLoader
from transformers import PreTrainedModel, PreTrainedTokenizer
from .dataset import Dataset
from .model_utils import get_pooled_output
from hfselect import logger
[docs]
class InvalidEmbeddingDatasetError(Exception):
    """
    Raised to signal that the data given to an embedding dataset is invalid.
    """

    def __init__(self, message: str):
        """
        Creates the error.

        Args:
            message: Human-readable description of the problem; it is passed on
                to ``Exception`` unchanged and becomes the text of the error
        """
        super().__init__(message)
[docs]
class EmbeddingDataset(TorchDataset):
    """
    An EmbeddingDataset contains two sets of embeddings:
    A dataset embedded using a base model and the same dataset embedded by a fine-tuned model.
    It can be used to train an ESM on the transformation of the embedding space caused by fine-tuning the model.

    Attributes set by the constructor: ``x``, ``y``, ``metadata``,
    ``embedding_dim`` (size of the second axis of ``x``) and ``num_rows``
    (number of embedding pairs).
    """

    def __init__(
        self,
        x: Union[np.ndarray, List[np.ndarray]],
        y: Union[np.ndarray, List[np.ndarray]],
        metadata: Optional[dict] = None,
    ):
        """
        Creates an embedding dataset from two sets of embeddings

        Args:
            x: The embeddings before fine-tuning. A list of arrays is stacked row-wise into one array
            y: The embeddings after fine-tuning. A list of arrays is stacked row-wise into one array
            metadata: The metadata will be forwarded to the ESMConfig when an ESM is trained using the embeddings

        Raises:
            InvalidEmbeddingDatasetError: If the embeddings have fewer than two dimensions, or if
                the number of rows or the embedding dimension differs between ``x`` and ``y``
        """
        if isinstance(x, list):
            x = np.vstack(x)
        if isinstance(y, list):
            y = np.vstack(y)

        # The checks below read shape[1]; without this guard a 1-D input would
        # surface as a bare IndexError instead of the documented error type.
        if np.ndim(x) < 2 or np.ndim(y) < 2:
            raise InvalidEmbeddingDatasetError(
                "Embeddings must have at least two dimensions (rows, embedding_dim): "
                f"got {np.ndim(x)} and {np.ndim(y)}."
            )

        if len(x) != len(y):
            raise InvalidEmbeddingDatasetError(
                f"Number of base and transformed embeddings does not match: {len(x)} != {len(y)}."
            )

        if x.shape[1] != y.shape[1]:
            raise InvalidEmbeddingDatasetError(
                f"Dimension of base and transformed embeddings does not match: {x.shape[1]} != {y.shape[1]}."
            )

        self.x = x
        self.y = y
        self.metadata = metadata or {}
        self.embedding_dim = x.shape[1]
        self.num_rows = len(self.x)

    @classmethod
    def from_disk(cls, filepath: str) -> "EmbeddingDataset":
        """
        Loads an EmbeddingDataset from a local file

        Args:
            filepath: Filepath of the saved EmbeddingDataset

        Returns:
            The loaded EmbeddingDataset
        """
        # SECURITY NOTE: allow_pickle=True is needed to restore the metadata dict
        # (stored as an object array by `save`), but unpickling can execute
        # arbitrary code. Only load files from trusted sources.
        # The context manager closes the underlying archive once the arrays are read.
        with np.load(filepath, allow_pickle=True) as embeddings:
            x = embeddings["x"]
            y = embeddings["y"]
            if "metadata" in embeddings.files:
                # The dict was wrapped in a 0-d object array; .item() unwraps it.
                metadata = embeddings["metadata"].item()
            else:
                metadata = None

        # Use cls so that subclasses get an instance of their own type.
        return cls(x, y, metadata=metadata)

    def save(self, filepath: str) -> None:
        """
        Saves an EmbeddingDataset to a local file

        The parent directory is created if it does not exist yet. The arrays
        are written with ``np.savez`` under the keys "x", "y" and "metadata".

        Args:
            filepath: Filepath to save the embedding
        """
        # A bare file name has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)

        np.savez(filepath, x=self.x, y=self.y, metadata=np.array(self.metadata))

    def __getitem__(self, idx: Union[int, Iterable[int]]):
        """
        Returns a single embedding pair or a subset of the dataset

        Args:
            idx: An integer (Python or NumPy) selects one row; any other index
                (e.g. a slice or a list of integers) selects several rows

        Returns:
            A tuple ``(x[idx], y[idx])`` for an integer index, otherwise a new
            EmbeddingDataset holding the selected rows. The metadata is not
            carried over to the new dataset.
        """
        # NumPy integer scalars (e.g. np.int64 produced by index arrays) are not
        # instances of int, but must be treated as a single-row lookup as well.
        if isinstance(idx, (int, np.integer)):
            return self.x[idx], self.y[idx]

        return EmbeddingDataset(self.x[idx], self.y[idx])

    def __len__(self) -> int:
        """Returns the number of embedding pairs in the dataset"""
        return self.num_rows
[docs]
def create_embedding_dataset(
    dataset: Dataset,
    base_model: PreTrainedModel,
    tuned_model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    device_name: str = "cpu",
    output_path: Optional[str] = None,
    batch_size: int = 128,
) -> EmbeddingDataset:
    """
    Embeds one dataset with both a base model and a fine-tuned model and pairs the results

    Both models are moved to the requested device and switched to eval mode.
    The dataset is read in sequential order, batch by batch, and the pooled
    outputs of both models are collected as NumPy arrays.

    Args:
        dataset: The dataset to be embedded
        base_model: The base model
        tuned_model: The fine-tuned model
        tokenizer: The tokenizer handed to the dataset's collate function
        device_name: Name of the device used for computation (e.g. "cpu", "cuda")
        output_path: If given, the resulting EmbeddingDataset is also saved to this path
        batch_size: Number of samples embedded per batch

    Returns:
        The resulting EmbeddingDataset
    """
    device = torch.device(device_name)
    models = (base_model, tuned_model)
    for model in models:
        model.to(device)
    for model in models:
        model.eval()

    def collate(samples):
        # Bind the tokenizer so the DataLoader can call the collate function with one argument.
        return dataset.collate_fn(samples, tokenizer=tokenizer)

    def embed(model, input_ids, attention_mask):
        # Pooled model output, brought back to the CPU as a NumPy array.
        pooled = get_pooled_output(model, input_ids, attention_mask)
        return pooled.cpu().numpy()

    loader = DataLoader(
        dataset,
        sampler=SequentialSampler(dataset),
        batch_size=batch_size,
        collate_fn=collate,
    )

    x_batches = []
    y_batches = []
    with tqdm(loader, desc="Computing embedding dataset", unit="batch") as progress:
        for batch in progress:
            # Each batch is expected to hold exactly three tensors; the third one
            # (presumably the labels -- confirm against the collate function) is unused.
            input_ids, attention_mask, _ = [t.to(device) for t in batch]

            with torch.no_grad():
                base_out = embed(base_model, input_ids, attention_mask)
                tuned_out = embed(tuned_model, input_ids, attention_mask)

            x_batches.append(base_out)
            y_batches.append(tuned_out)

    # Entries of the dataset metadata win over "base_model_name" on a key collision.
    metadata = {
        "base_model_name": base_model.config.name_or_path,
        **dataset.metadata,
    }
    result = EmbeddingDataset(x_batches, y_batches, metadata=metadata)

    if not output_path:
        return result

    if os.path.isfile(output_path):
        logger.warning(f"Overwriting embeddings dataset at path: {output_path}")
    result.save(output_path)

    return result