"""Distribution distance metrics.

Wasserstein distances are computed with the POT library; the maximum mean
discrepancy (MMD) estimators are standard biased/unbiased kernel two-sample
statistics.
"""

import math
from functools import partial
from typing import Optional, Union

import ot as pot
import torch


def wasserstein(
    x0: torch.Tensor,
    x1: torch.Tensor,
    method: Optional[str] = None,
    reg: float = 0.05,
    power: int = 2,
    **kwargs,
) -> float:
    """Compute the Wasserstein-1 or Wasserstein-2 distance between two sample batches.

    The batches are treated as empirical distributions with uniform weights.
    ``method="exact"`` (or ``None``) uses the exact EMD solver; ``"sinkhorn"``
    uses entropic regularization with strength ``reg``. Extra ``kwargs`` are
    forwarded to the POT solver.
    """
    assert power == 1 or power == 2
    # ot_fn takes (a, b, M): the two marginals and a cost matrix.
    if method == "exact" or method is None:
        ot_fn = pot.emd2
    elif method == "sinkhorn":
        ot_fn = partial(pot.sinkhorn2, reg=reg)
    else:
        raise ValueError(f"Unknown method: {method}")

    a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0])
    # Flatten trailing feature dimensions so cdist sees 2D inputs.
    if x0.dim() > 2:
        x0 = x0.reshape(x0.shape[0], -1)
    if x1.dim() > 2:
        x1 = x1.reshape(x1.shape[0], -1)
    M = torch.cdist(x0, x1)
    if power == 2:
        M = M**2
    ret = ot_fn(a, b, M.detach().cpu().numpy(), numItermax=int(1e7), **kwargs)
    if power == 2:
        ret = math.sqrt(ret)
    return ret
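
# Example (sketch): typical usage with two hypothetical sample batches; the
# names, shapes, and shift below are illustrative only.
#
#   x0 = torch.randn(128, 2)            # source samples
#   x1 = torch.randn(128, 2) + 1.0      # shifted target samples
#   w2 = wasserstein(x0, x1)                          # exact 2-Wasserstein
#   w2_reg = wasserstein(x0, x1, method="sinkhorn")   # entropic approximation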


# Floor applied to the variance estimate when normalizing MMD^2 in
# _mmd2_and_ratio.
min_var_est = 1e-8


def linear_mmd2(f_of_X, f_of_Y):
    """Linear-time MMD^2 estimate with the linear kernel, pairing consecutive samples."""
    delta = f_of_X - f_of_Y
    return torch.mean((delta[:-1] * delta[1:]).sum(1))


def poly_mmd2(f_of_X, f_of_Y, d=2, alpha=1.0, c=2.0):
    """Linear-time MMD^2 estimate with the polynomial kernel (alpha * <x, y> + c) ** d."""
    K_XX = alpha * (f_of_X[:-1] * f_of_X[1:]).sum(1) + c
    K_XX_mean = torch.mean(K_XX.pow(d))

    K_YY = alpha * (f_of_Y[:-1] * f_of_Y[1:]).sum(1) + c
    K_YY_mean = torch.mean(K_YY.pow(d))

    K_XY = alpha * (f_of_X[:-1] * f_of_Y[1:]).sum(1) + c
    K_XY_mean = torch.mean(K_XY.pow(d))

    K_YX = alpha * (f_of_Y[:-1] * f_of_X[1:]).sum(1) + c
    K_YX_mean = torch.mean(K_YX.pow(d))

    return K_XX_mean + K_YY_mean - K_XY_mean - K_YX_mean
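
# Example (sketch): the linear-time estimators above pair consecutive rows, so
# both inputs need the same batch size; shapes and values are illustrative.
#
#   fx = torch.randn(128, 16)
#   fy = torch.randn(128, 16)
#   lin = linear_mmd2(fx, fy)
#   poly = poly_mmd2(fx, fy, d=2, alpha=1.0, c=2.0)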


def _mix_rbf_kernel(X, Y, sigma_list):
    assert X.size(0) == Y.size(0)
    m = X.size(0)

    # Pairwise squared distances via ||z_i||^2 - 2 <z_i, z_j> + ||z_j||^2.
    Z = torch.cat((X, Y), 0)
    ZZT = torch.mm(Z, Z.t())
    diag_ZZT = torch.diag(ZZT).unsqueeze(1)
    Z_norm_sqr = diag_ZZT.expand_as(ZZT)
    exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t()

    # Mixture of RBF kernels, one per bandwidth in sigma_list.
    K = 0.0
    for sigma in sigma_list:
        gamma = 1.0 / (2 * sigma**2)
        K += torch.exp(-gamma * exponent)

    # Blocks K(X, X), K(X, Y), K(Y, Y), plus the number of kernels mixed.
    return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list)


def mix_rbf_mmd2(X, Y, sigma_list, biased=True):
    K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list)
    return _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased)


def mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=True):
    K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list)
    return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased)
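
# Example (sketch): mixture-RBF MMD^2 between two hypothetical batches; the
# inputs and bandwidths are illustrative, not prescribed by this module.
#
#   X = torch.randn(256, 8)
#   Y = torch.randn(256, 8) * 1.5
#   mmd2 = mix_rbf_mmd2(X, Y, sigma_list=[0.1, 1.0, 10.0])
#   ratio, mmd2_u, var = mix_rbf_mmd2_and_ratio(X, Y, sigma_list=[0.1, 1.0, 10.0])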


def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    m = K_XX.size(0)  # assumes X and Y have the same sample size

    # The unbiased estimate excludes diagonal (self-similarity) terms. For
    # kernels with a constant diagonal, the caller may pass that constant
    # instead of materializing the diagonals.
    if const_diagonal is not False:
        diag_X = diag_Y = const_diagonal
        sum_diag_X = sum_diag_Y = m * const_diagonal
    else:
        diag_X = torch.diag(K_XX)
        diag_Y = torch.diag(K_YY)
        sum_diag_X = torch.sum(diag_X)
        sum_diag_Y = torch.sum(diag_Y)

    Kt_XX_sums = K_XX.sum(dim=1) - diag_X  # row sums with diagonal removed
    Kt_YY_sums = K_YY.sum(dim=1) - diag_Y
    K_XY_sums_0 = K_XY.sum(dim=0)

    Kt_XX_sum = Kt_XX_sums.sum()
    Kt_YY_sum = Kt_YY_sums.sum()
    K_XY_sum = K_XY_sums_0.sum()

    if biased:
        mmd2 = (
            (Kt_XX_sum + sum_diag_X) / (m * m)
            + (Kt_YY_sum + sum_diag_Y) / (m * m)
            - 2.0 * K_XY_sum / (m * m)
        )
    else:
        mmd2 = (
            Kt_XX_sum / (m * (m - 1))
            + Kt_YY_sum / (m * (m - 1))
            - 2.0 * K_XY_sum / (m * m)
        )

    return mmd2
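
# For biased=False, _mmd2 computes the diagonal-free estimator
#
#   MMD^2 = sum_{i != j} k(x_i, x_j) / (m (m - 1))
#         + sum_{i != j} k(y_i, y_j) / (m (m - 1))
#         - 2 sum_{i, j} k(x_i, y_j) / m^2
#
# while biased=True keeps the self-similarity terms and divides every block
# by m^2.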


def _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    mmd2, var_est = _mmd2_and_variance(
        K_XX, K_XY, K_YY, const_diagonal=const_diagonal, biased=biased
    )
    # Normalize MMD^2 by its estimated standard deviation (clamped away from
    # zero) to obtain a t-statistic-style ratio.
    loss = mmd2 / torch.sqrt(torch.clamp(var_est, min=min_var_est))
    return loss, mmd2, var_est


def _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    m = K_XX.size(0)  # assumes X and Y have the same sample size

    if const_diagonal is not False:
        diag_X = diag_Y = const_diagonal
        sum_diag_X = sum_diag_Y = m * const_diagonal
        sum_diag2_X = sum_diag2_Y = m * const_diagonal**2
    else:
        diag_X = torch.diag(K_XX)
        diag_Y = torch.diag(K_YY)
        sum_diag_X = torch.sum(diag_X)
        sum_diag_Y = torch.sum(diag_Y)
        sum_diag2_X = diag_X.dot(diag_X)
        sum_diag2_Y = diag_Y.dot(diag_Y)

    Kt_XX_sums = K_XX.sum(dim=1) - diag_X
    Kt_YY_sums = K_YY.sum(dim=1) - diag_Y
    K_XY_sums_0 = K_XY.sum(dim=0)
    K_XY_sums_1 = K_XY.sum(dim=1)

    Kt_XX_sum = Kt_XX_sums.sum()
    Kt_YY_sum = Kt_YY_sums.sum()
    K_XY_sum = K_XY_sums_0.sum()

    Kt_XX_2_sum = (K_XX**2).sum() - sum_diag2_X
    Kt_YY_2_sum = (K_YY**2).sum() - sum_diag2_Y
    K_XY_2_sum = (K_XY**2).sum()

    if biased:
        mmd2 = (
            (Kt_XX_sum + sum_diag_X) / (m * m)
            + (Kt_YY_sum + sum_diag_Y) / (m * m)
            - 2.0 * K_XY_sum / (m * m)
        )
    else:
        mmd2 = (
            Kt_XX_sum / (m * (m - 1))
            + Kt_YY_sum / (m * (m - 1))
            - 2.0 * K_XY_sum / (m * m)
        )

    # Estimate of the asymptotic variance of the unbiased MMD^2 statistic.
    var_est = (
        2.0
        / (m**2 * (m - 1.0) ** 2)
        * (
            2 * Kt_XX_sums.dot(Kt_XX_sums)
            - Kt_XX_2_sum
            + 2 * Kt_YY_sums.dot(Kt_YY_sums)
            - Kt_YY_2_sum
        )
        - (4.0 * m - 6.0) / (m**3 * (m - 1.0) ** 3) * (Kt_XX_sum**2 + Kt_YY_sum**2)
        + 4.0
        * (m - 2.0)
        / (m**3 * (m - 1.0) ** 2)
        * (K_XY_sums_1.dot(K_XY_sums_1) + K_XY_sums_0.dot(K_XY_sums_0))
        - 4.0 * (m - 3.0) / (m**3 * (m - 1.0) ** 2) * K_XY_2_sum
        - (8 * m - 12) / (m**5 * (m - 1)) * K_XY_sum**2
        + 8.0
        / (m**3 * (m - 1.0))
        * (
            1.0 / m * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
            - Kt_XX_sums.dot(K_XY_sums_1)
            - Kt_YY_sums.dot(K_XY_sums_0)
        )
    )
    return mmd2, var_est


def compute_distances(pred, true):
    """Compute MSE, L2 (root MSE), and L1 (MAE) distances between two vectors."""
    mse = torch.nn.functional.mse_loss(pred, true).item()
    me = math.sqrt(mse)
    mae = torch.mean(torch.abs(pred - true)).item()
    return mse, me, mae


def compute_distribution_distances(pred: torch.Tensor, true: Union[torch.Tensor, list]):
    """Compute distances between two empirical distributions.

    The Wasserstein distances use only the first two feature dimensions;
    the RBF MMD and the mean/median statistics use all dimensions.
    """
    NAMES = [
        "1-Wasserstein",
        "2-Wasserstein",
        "RBF_MMD",
        "Mean_MSE",
        "Mean_L2",
        "Mean_L1",
        "Median_MSE",
        "Median_L2",
        "Median_L1",
    ]
    pred_2d = pred[:, :2]
    true_2d = true[:, :2]
    w1 = wasserstein(pred_2d, true_2d, power=1)
    w2 = wasserstein(pred_2d, true_2d, power=2)

    mmd_rbf = mix_rbf_mmd2(pred, true, sigma_list=[0.01, 0.1, 1, 10, 100]).item()
    mean_dists = compute_distances(torch.mean(pred, dim=0), torch.mean(true, dim=0))
    median_dists = compute_distances(torch.median(pred, dim=0)[0], torch.median(true, dim=0)[0])
    dists = [w1, w2, mmd_rbf, *mean_dists, *median_dists]
    assert len(NAMES) == len(dists)
    return NAMES, dists
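
# Example (sketch): comparing hypothetical model samples against data; the
# inputs are illustrative only.
#
#   pred = torch.randn(500, 2)   # model samples
#   true = torch.randn(500, 2)   # data samples
#   names, values = compute_distribution_distances(pred, true)
#   print(dict(zip(names, values)))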


def compute_wasserstein_distances(pred: torch.Tensor, true: Union[torch.Tensor, list]):
    """Compute only the Wasserstein distances between two empirical distributions.

    As in compute_distribution_distances, only the first two feature
    dimensions are used.
    """
    NAMES = [
        "1-Wasserstein",
        "2-Wasserstein",
    ]
    pred_2d = pred[:, :2]
    true_2d = true[:, :2]
    w1 = wasserstein(pred_2d, true_2d, power=1)
    w2 = wasserstein(pred_2d, true_2d, power=2)

    dists = [w1, w2]
    return NAMES, dists