"""Evaluation metrics for comparing predicted and target sample distributions."""
import math
from functools import partial
from typing import Optional

import ot as pot
import torch
def wasserstein(
x0: torch.Tensor,
x1: torch.Tensor,
method: Optional[str] = None,
reg: float = 0.05,
power: int = 2,
**kwargs,
) -> float:
    """Compute the `power`-Wasserstein distance between two empirical samples.

    x0, x1: sample tensors of shape (n, d); higher-dimensional inputs are
        flattened to 2D. method: "exact" (default) solves the exact EMD problem,
        while "sinkhorn" uses entropic regularization of strength `reg`.
    """
    assert power in (1, 2)
# ot_fn should take (a, b, M) as arguments where a, b are marginals and
# M is a cost matrix
if method == "exact" or method is None:
ot_fn = pot.emd2
elif method == "sinkhorn":
ot_fn = partial(pot.sinkhorn2, reg=reg)
else:
raise ValueError(f"Unknown method: {method}")
a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0])
if x0.dim() > 2:
x0 = x0.reshape(x0.shape[0], -1)
if x1.dim() > 2:
x1 = x1.reshape(x1.shape[0], -1)
M = torch.cdist(x0, x1)
if power == 2:
M = M**2
    ret = ot_fn(a, b, M.detach().cpu().numpy(), numItermax=int(1e7))
if power == 2:
ret = math.sqrt(ret)
return ret
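# Usage sketch (illustrative values, not from this project): compare two point
# clouds with the exact solver, e.g.
#   x0, x1 = torch.randn(256, 2), torch.randn(256, 2) + 1.0
#   w2 = wasserstein(x0, x1, method="exact", power=2)
# "sinkhorn" trades exactness for speed via entropic regularization of strength `reg`.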
min_var_est = 1e-8  # floor on the variance estimate used in _mmd2_and_ratio
# Consider linear time MMD with a linear kernel:
# K(f(x), f(y)) = f(x)^Tf(y)
# h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i)
# = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)]
#
# f_of_X: batch_size * k
# f_of_Y: batch_size * k
def linear_mmd2(f_of_X, f_of_Y):
    delta = f_of_X - f_of_Y
    loss = torch.mean((delta[:-1] * delta[1:]).sum(1))
    return loss
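# Note: linear_mmd2 pairs consecutive samples (i, i + 1), so it expects the two
# batches to be shuffled, equal in size, and to contain at least two samples.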
# Consider linear time MMD with a polynomial kernel:
# K(f(x), f(y)) = (alpha*f(x)^Tf(y) + c)^d
# f_of_X: batch_size * k
# f_of_Y: batch_size * k
def poly_mmd2(f_of_X, f_of_Y, d=2, alpha=1.0, c=2.0):
K_XX = alpha * (f_of_X[:-1] * f_of_X[1:]).sum(1) + c
K_XX_mean = torch.mean(K_XX.pow(d))
K_YY = alpha * (f_of_Y[:-1] * f_of_Y[1:]).sum(1) + c
K_YY_mean = torch.mean(K_YY.pow(d))
K_XY = alpha * (f_of_X[:-1] * f_of_Y[1:]).sum(1) + c
K_XY_mean = torch.mean(K_XY.pow(d))
K_YX = alpha * (f_of_Y[:-1] * f_of_X[1:]).sum(1) + c
K_YX_mean = torch.mean(K_YX.pow(d))
return K_XX_mean + K_YY_mean - K_XY_mean - K_YX_mean
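# _mix_rbf_kernel evaluates a mixture of RBF kernels,
#   K(z_i, z_j) = sum_sigma exp(-||z_i - z_j||^2 / (2 * sigma^2)),
# on the stacked samples Z = [X; Y] and returns the K_XX, K_XY, K_YY blocks
# together with the number of bandwidths.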
def _mix_rbf_kernel(X, Y, sigma_list):
assert X.size(0) == Y.size(0)
m = X.size(0)
Z = torch.cat((X, Y), 0)
ZZT = torch.mm(Z, Z.t())
diag_ZZT = torch.diag(ZZT).unsqueeze(1)
Z_norm_sqr = diag_ZZT.expand_as(ZZT)
exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t()
K = 0.0
for sigma in sigma_list:
gamma = 1.0 / (2 * sigma**2)
K += torch.exp(-gamma * exponent)
return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list)
def mix_rbf_mmd2(X, Y, sigma_list, biased=True):
K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list)
# return _mmd2(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased)
return _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased)
def mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=True):
K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list)
# return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased)
return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased)
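# Usage sketch (the bandwidths below are illustrative; in practice sigma_list is
# often chosen around the median pairwise distance of the data):
#   X, Y = torch.randn(256, 10), torch.randn(256, 10)
#   mmd = mix_rbf_mmd2(X, Y, sigma_list=[1.0, 2.0, 4.0])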
################################################################################
# Helper functions to compute variances based on kernel matrices
################################################################################
def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
m = K_XX.size(0) # assume X, Y are same shape
# Get the various sums of kernels that we'll use
# Kts drop the diagonal, but we don't need to compute them explicitly
if const_diagonal is not False:
diag_X = diag_Y = const_diagonal
sum_diag_X = sum_diag_Y = m * const_diagonal
else:
diag_X = torch.diag(K_XX) # (m,)
diag_Y = torch.diag(K_YY) # (m,)
sum_diag_X = torch.sum(diag_X)
sum_diag_Y = torch.sum(diag_Y)
Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X
Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y
K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e
Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e
Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e
K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e
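    # Biased estimator keeps the diagonal terms:
    #   MMD_b^2 = (sum K_XX + sum K_YY - 2 * sum K_XY) / m^2
    # Unbiased estimator drops the diagonals of K_XX and K_YY:
    #   MMD_u^2 = sum_{i != j} K_XX / (m (m - 1)) + sum_{i != j} K_YY / (m (m - 1)) - 2 * sum K_XY / m^2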
if biased:
mmd2 = (
(Kt_XX_sum + sum_diag_X) / (m * m)
+ (Kt_YY_sum + sum_diag_Y) / (m * m)
- 2.0 * K_XY_sum / (m * m)
)
else:
mmd2 = Kt_XX_sum / (m * (m - 1)) + Kt_YY_sum / (m * (m - 1)) - 2.0 * K_XY_sum / (m * m)
return mmd2
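# The "ratio" returned below is MMD^2 divided by the square root of its
# (clamped) variance estimate, a t-statistic-style quantity that can serve as a
# test-power surrogate or kernel-selection criterion.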
def _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
mmd2, var_est = _mmd2_and_variance(
K_XX, K_XY, K_YY, const_diagonal=const_diagonal, biased=biased
)
loss = mmd2 / torch.sqrt(torch.clamp(var_est, min=min_var_est))
return loss, mmd2, var_est
def _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
m = K_XX.size(0) # assume X, Y are same shape
# Get the various sums of kernels that we'll use
# Kts drop the diagonal, but we don't need to compute them explicitly
if const_diagonal is not False:
diag_X = diag_Y = const_diagonal
sum_diag_X = sum_diag_Y = m * const_diagonal
sum_diag2_X = sum_diag2_Y = m * const_diagonal**2
else:
diag_X = torch.diag(K_XX) # (m,)
diag_Y = torch.diag(K_YY) # (m,)
sum_diag_X = torch.sum(diag_X)
sum_diag_Y = torch.sum(diag_Y)
sum_diag2_X = diag_X.dot(diag_X)
sum_diag2_Y = diag_Y.dot(diag_Y)
Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X
Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y
K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e
K_XY_sums_1 = K_XY.sum(dim=1) # K_{XY} * e
Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e
Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e
K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e
Kt_XX_2_sum = (K_XX**2).sum() - sum_diag2_X # \| \tilde{K}_XX \|_F^2
Kt_YY_2_sum = (K_YY**2).sum() - sum_diag2_Y # \| \tilde{K}_YY \|_F^2
K_XY_2_sum = (K_XY**2).sum() # \| K_{XY} \|_F^2
if biased:
mmd2 = (
(Kt_XX_sum + sum_diag_X) / (m * m)
+ (Kt_YY_sum + sum_diag_Y) / (m * m)
- 2.0 * K_XY_sum / (m * m)
)
else:
mmd2 = Kt_XX_sum / (m * (m - 1)) + Kt_YY_sum / (m * (m - 1)) - 2.0 * K_XY_sum / (m * m)
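    # Estimate of the variance of the unbiased MMD^2 statistic, assembled from
    # the kernel row sums and squared Frobenius norms computed above.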
var_est = (
2.0
/ (m**2 * (m - 1.0) ** 2)
* (
2 * Kt_XX_sums.dot(Kt_XX_sums)
- Kt_XX_2_sum
+ 2 * Kt_YY_sums.dot(Kt_YY_sums)
- Kt_YY_2_sum
)
- (4.0 * m - 6.0) / (m**3 * (m - 1.0) ** 3) * (Kt_XX_sum**2 + Kt_YY_sum**2)
+ 4.0
* (m - 2.0)
/ (m**3 * (m - 1.0) ** 2)
* (K_XY_sums_1.dot(K_XY_sums_1) + K_XY_sums_0.dot(K_XY_sums_0))
- 4.0 * (m - 3.0) / (m**3 * (m - 1.0) ** 2) * (K_XY_2_sum)
- (8 * m - 12) / (m**5 * (m - 1)) * K_XY_sum**2
+ 8.0
/ (m**3 * (m - 1.0))
* (
1.0 / m * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
- Kt_XX_sums.dot(K_XY_sums_1)
- Kt_YY_sums.dot(K_XY_sums_0)
)
)
return mmd2, var_est
def compute_distances(pred, true):
    """Compute pointwise errors between two tensors: MSE, RMSE (L2), and MAE (L1)."""
    mse = torch.nn.functional.mse_loss(pred, true).item()
    me = math.sqrt(mse)
    mae = torch.mean(torch.abs(pred - true)).item()
    return mse, me, mae
def compute_distribution_distances(pred: torch.Tensor, true: torch.Tensor):
    """Compute distances between the predicted and target sample distributions.

    Wasserstein distances are computed on the first two coordinates only;
    the RBF MMD and the mean/median statistics use all coordinates.
    """
    NAMES = [
        "1-Wasserstein",
        "2-Wasserstein",
        "RBF_MMD",
        "Mean_MSE",
        "Mean_L2",
        "Mean_L1",
        "Median_MSE",
        "Median_L2",
        "Median_L1",
    ]
    pred_2d = pred[:, :2]
    true_2d = true[:, :2]
    w1 = wasserstein(pred_2d, true_2d, power=1)
    w2 = wasserstein(pred_2d, true_2d, power=2)
    mmd_rbf = mix_rbf_mmd2(pred, true, sigma_list=[0.01, 0.1, 1, 10, 100]).item()
    mean_dists = compute_distances(torch.mean(pred, dim=0), torch.mean(true, dim=0))
    median_dists = compute_distances(torch.median(pred, dim=0)[0], torch.median(true, dim=0)[0])
    dists = [w1, w2, mmd_rbf, *mean_dists, *median_dists]
    return NAMES, dists
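# Usage sketch (pred and true are assumed to be (n, d) tensors with d >= 2):
#   names, values = compute_distribution_distances(pred, true)
#   metrics = dict(zip(names, values))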
def compute_wasserstein_distances(pred: torch.Tensor, true: torch.Tensor):
    """Compute 1- and 2-Wasserstein distances on the first two coordinates of pred and true."""
    NAMES = [
        "1-Wasserstein",
        "2-Wasserstein",
    ]
    pred_2d = pred[:, :2]
    true_2d = true[:, :2]
    w1 = wasserstein(pred_2d, true_2d, power=1)
    w2 = wasserstein(pred_2d, true_2d, power=2)
    dists = [w1, w2]
    return NAMES, dists
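# Minimal smoke test: a hedged sketch on synthetic Gaussian data. The sample
# size, dimension, and offset below are arbitrary choices, not values used
# elsewhere in this project.
if __name__ == "__main__":
    torch.manual_seed(0)
    pred = torch.randn(128, 4)
    true = torch.randn(128, 4) + 0.5
    names, values = compute_distribution_distances(pred, true)
    for name, value in zip(names, values):
        print(f"{name}: {value:.4f}")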