import warnings
import numpy as np
from numba import njit
# Numba-compiled numerical helpers.
@njit
def each_evidence(y_, f, fh, v, s, vh, N, D):
    """
    Compute the maximum evidence for a single class / target dimension.

    The hyper-parameters ``alpha`` and ``beta`` start at 1.0 and are
    re-estimated by fixed-point iteration until their ratio changes by less
    than 1% (relative), or 100 iterations have been performed.

    :param y_: target vector of length N
    :param f: feature matrix as passed by the caller, shape [D, N]
    :param fh: transpose of ``f``, shape [N, D]
    :param v: left singular vectors of ``f @ fh``
    :param s: singular values of ``f @ fh``
    :param vh: right singular vectors (transposed) of ``f @ fh``
    :param N: number of samples
    :param D: feature dimension
    :return: tuple ``(evidence / N, alpha, beta, m)`` where ``m`` is the
        weight vector computed in the last iteration
    """
    epsilon = 1e-5  # keeps the denominators below strictly positive
    alpha = 1.0
    beta = 1.0
    lam = alpha / beta
    # Loop-invariant projection of the targets; computed once.
    tmp = vh @ (f @ np.ascontiguousarray(y_))
    for _ in range(100):
        # should converge after at most 10 steps
        # typically converge after two or three steps
        gamma = (s / (s + lam)).sum()
        # A = v @ diag(alpha + beta * s) @ v^T and its inverse are never
        # formed explicitly; only the diagonal scaling is applied.
        m = v @ (tmp * beta / (alpha + beta * s))
        alpha_de = (m * m).sum()
        alpha = gamma / (alpha_de + epsilon)
        beta_de = ((y_ - fh @ m) ** 2).sum()
        beta = (N - gamma) / (beta_de + epsilon)
        new_lam = alpha / beta
        # Stop once the ratio alpha / beta has stabilised.
        if np.abs(new_lam - lam) / lam < 0.01:
            break
        lam = new_lam
    evidence = (
        D / 2.0 * np.log(alpha)
        + N / 2.0 * np.log(beta)
        - 0.5 * np.sum(np.log(alpha + beta * s))
        - beta / 2.0 * (beta_de + epsilon)
        - alpha / 2.0 * (alpha_de + epsilon)
        - N / 2.0 * np.log(2 * np.pi)
    )
    # The evidence is reported per sample.
    return evidence / N, alpha, beta, m
# Optional warm-up: call the function once on pseudo data (D = 20, N = 50)
# so that it gets compiled ahead of its first real use.
# f_tmp = np.random.randn(20, 50).astype(np.float64)
# each_evidence(np.random.randint(0, 2, 50).astype(np.float64), f_tmp, f_tmp.transpose(), np.eye(20, dtype=np.float64), np.ones(20, dtype=np.float64), np.eye(20, dtype=np.float64), 50, 20)
@njit
def truncated_svd(x):
    """
    Rank-truncated SVD of ``x`` obtained from the SVD of ``x^T @ x``.

    The singular values of ``x`` are taken as the square roots of the
    singular values of the Gram matrix ``x^T @ x``; only those whose square
    root exceeds 1e-10 are kept.

    :param x: 2-D matrix
    :return: tuple ``(u, s, vh)`` with ``k`` retained components, where
        ``u`` has ``k`` columns, ``s`` has shape [k, 1] and ``vh`` has ``k`` rows
    """
    _, gram_values, gram_vh = np.linalg.svd(x.transpose() @ x)
    singular = np.sqrt(gram_values)
    # x @ V equals U * Sigma, so U is recovered by dividing out Sigma below.
    projected = x @ gram_vh.transpose()
    # Number of components kept; note the threshold is applied after the sqrt.
    rank = np.sum((singular > 1e-10) * 1)
    singular = singular.reshape(-1, 1)
    singular = singular[:rank]
    gram_vh = gram_vh[:rank]
    left = projected[:, :rank] / singular.reshape(1, -1)
    return left, singular, gram_vh
# Example call on pseudo data:
# truncated_svd(np.random.randn(20, 10).astype(np.float64))
class LogME:
    """
    LogME scorer: estimates how well a feature matrix can fit target labels.

    After :meth:`fit`, the per-class / per-dimension hyper-parameters are
    available as ``alphas`` and ``betas`` and the weight vectors as ``ms``
    (array of shape [C, D]), which :meth:`predict` uses for inference.
    """

    def __init__(self, regression=False):
        """
        :param regression: whether the targets are regression labels
            (otherwise integer class labels are expected)
        """
        self.regression = regression
        self.fitted = False
        self.reset()

    def reset(self):
        """Clear all parameters learned by a previous call to :meth:`fit`."""
        self.num_dim = 0
        self.alphas = []  # alpha for each class / dimension
        self.betas = []  # beta for each class / dimension
        # after fitting, self.ms.shape --> [C, D]
        self.ms = []  # m for each class / dimension
        # whether the last fit prepended a constant column to the features
        self._add_intercept = False

    def _fit_icml(self, f: np.ndarray, y: np.ndarray):
        """
        LogME calculation proposed in the ICML 2021 paper
        "LogME: Practical Assessment of Pre-trained Models for Transfer Learning"
        at http://proceedings.mlr.press/v139/you21b.html

        :param f: [N, D] feature matrix
        :param y: labels, [N] for classification or [N, C] for regression
        :return: mean evidence over all classes / dimensions
        """
        fh = f
        f = f.transpose()
        D, N = f.shape
        v, s, vh = np.linalg.svd(f @ fh, full_matrices=True)

        evidences = []
        self.num_dim = y.shape[1] if self.regression else int(y.max() + 1)
        for i in range(self.num_dim):
            # one-vs-rest indicator for classification, i-th column for regression
            y_ = y[:, i] if self.regression else (y == i).astype(np.float64)
            evidence, alpha, beta, m = each_evidence(y_, f, fh, v, s, vh, N, D)
            evidences.append(evidence)
            self.alphas.append(alpha)
            self.betas.append(beta)
            self.ms.append(m)
        self.ms = np.stack(self.ms)
        return np.mean(evidences)

    def _fit_fixed_point(self, f: np.ndarray, y: np.ndarray):
        """
        LogME calculation proposed in the arxiv 2021 paper
        "Ranking and Tuning Pre-trained Models: A New Paradigm of Exploiting Model Hubs"
        at https://arxiv.org/abs/2110.10545

        :param f: [N, D] feature matrix
        :param y: labels, [N] for classification or [N, C] for regression
        :return: mean evidence over all classes / dimensions
        """
        N, D = f.shape  # k = min(N, D)
        if N > D:  # direct SVD may be expensive
            u, s, vh = truncated_svd(f)
        else:
            u, s, vh = np.linalg.svd(f, full_matrices=False)
        # u.shape = N x k
        # s.shape = k
        # vh.shape = k x D
        s = s.reshape(-1, 1)
        sigma = s**2

        evidences = []
        self.num_dim = y.shape[1] if self.regression else int(y.max() + 1)
        for i in range(self.num_dim):
            # one-vs-rest indicator for classification, i-th column for regression
            y_ = y[:, i] if self.regression else (y == i).astype(np.float64)
            y_ = y_.reshape(-1, 1)
            # x has shape [k, 1], but actually x should have shape [N, 1]
            x = u.T @ y_
            x2 = x**2
            # if k < N, we compute sum of xi for 0 singular values directly
            res_x2 = (y_**2).sum() - x2.sum()

            alpha, beta = 1.0, 1.0
            for _ in range(50):
                t = alpha / beta
                gamma = (sigma / (sigma + t)).sum()
                m2 = (sigma * x2 / ((t + sigma) ** 2)).sum()
                res2 = (x2 / ((1 + sigma / t) ** 2)).sum() + res_x2
                alpha = gamma / (m2 + 1e-5)
                beta = (N - gamma) / (res2 + 1e-5)
                t_ = alpha / beta
                # stop once the ratio alpha / beta has stabilised
                if abs(t_ - t) / t <= 1e-3:
                    break

            evidence = (
                D / 2.0 * np.log(alpha)
                + N / 2.0 * np.log(beta)
                - 0.5 * np.sum(np.log(alpha + beta * sigma))
                - beta / 2.0 * res2
                - alpha / 2.0 * m2
                - N / 2.0 * np.log(2 * np.pi)
            )
            evidence /= N
            # m2 and res2 above were computed with ratio t, so m uses t as well
            m = 1.0 / (t + sigma) * s * x
            m = (vh.T @ m).reshape(-1)
            evidences.append(evidence)
            self.alphas.append(alpha)
            self.betas.append(beta)
            self.ms.append(m)
        self.ms = np.stack(self.ms)
        return np.mean(evidences)

    # Algorithm used by fit(); _fit_icml is the alternative implementation.
    _fit = _fit_fixed_point

    def fit(self, f: np.ndarray, y: np.ndarray, add_intercept: bool = False):
        """
        :param f: [N, F], feature matrix from pre-trained model
        :param y: target labels.
            For classification, y has shape [N] with element in [0, C_t).
            For regression, y has shape [N, C] with C regression-labels
        :param add_intercept: if True, a column of ones is prepended to ``f``
            before fitting; :meth:`predict` then does the same automatically
        :return: LogME score (how well f can fit y directly)
        """
        if self.fitted:
            warnings.warn("re-fitting for new data. old parameters cleared.")
            self.reset()
        else:
            self.fitted = True
        # remembered so that predict() can build matching features
        self._add_intercept = bool(add_intercept)
        if add_intercept:
            f = np.hstack((np.ones((f.shape[0], 1)), f))
        f = f.astype(np.float64)
        if self.regression:
            y = y.astype(np.float64)
            if len(y.shape) == 1:
                y = y.reshape(-1, 1)
        return self._fit(f, y)

    def predict(self, f: np.ndarray):
        """
        :param f: [N, F], feature matrix
        :return: prediction. For regression the raw outputs of shape [N, C];
            for classification the arg-max class index per row, shape [N]
        :raises RuntimeError: if called before :meth:`fit`
        """
        if not self.fitted:
            raise RuntimeError("not fitted, please call fit first")
        f = f.astype(np.float64)
        # Mirror the intercept column added during fit. Features that already
        # carry it (width equal to the weight dimension) are left untouched.
        if self._add_intercept and f.shape[1] + 1 == self.ms.shape[1]:
            f = np.hstack((np.ones((f.shape[0], 1)), f))
        logits = f @ self.ms.T
        if self.regression:
            return logits
        return np.argmax(logits, axis=-1)