"""Run the MST model on a single breast MRI volume: download the weights from
the Hugging Face Hub, predict per-breast probabilities, and save the attention
map as a NIfTI volume plus a 2D slice overlay."""

import argparse
import json
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torchio as tio
from huggingface_hub import hf_hub_download
from matplotlib.pyplot import get_cmap
from torchvision.utils import save_image

from models import MSTRegression


def minmax_norm(x):
    """Normalize the input to [0, 1] over the whole tensor."""
    return (x - x.min()) / (x.max() - x.min())
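
# Edge case note: a constant tensor would make the denominator zero; the images
# and attention maps normalized in this script are assumed to be non-constant.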


def tensor2image(tensor, batch=0):
    """Reshape a 5D (B, C, D, H, W) tensor into a stack of 2D single-channel slices; pass 4D inputs through unchanged."""
    return (tensor if tensor.ndim < 5
            else torch.swapaxes(tensor[batch], 0, 1).reshape(-1, *tensor.shape[-2:])[:, None])


def tensor_cam2image(tensor, cam, batch=0, alpha=0.5, color_map=get_cmap('jet')):
    """Blend a tensor with a (Grad-)CAM heatmap into a stack of 2D RGB images."""
    img = tensor2image(tensor, batch)
    # Replicate single-channel slices to three channels so they can be blended with the RGB heatmap
    img = torch.cat([img for _ in range(3)], dim=1) if img.shape[1] != 3 else img
    cam_img = tensor2image(cam, batch)
    cam_img = cam_img[:, 0].cpu().numpy()
    cam_img = torch.tensor(color_map(cam_img))  # the colormap returns RGBA, channels-last
    cam_img = torch.moveaxis(cam_img, -1, 1)[:, :3]  # to channels-first, drop the alpha channel
    overlay = (1 - alpha) * img + alpha * cam_img
    return overlay
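
# Shape note: for the single-channel (1, 1, D, H, W) volumes used below, tensor
# and cam blend into a (D, 3, H, W) stack, one RGB slice per depth index, which
# torchvision's save_image then tiles into a grid.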


def crop_breast_height(image, margin_top=10) -> tio.Crop:
    """Crop the height to 256 voxels, positioned to cover the breast via intensity-based localization."""
    threshold = int(np.quantile(image.data.float(), 0.9))
    foreground = image.data > threshold
    fg_rows = foreground[0].sum(axis=(0, 2))  # foreground voxel count per height row
    top = min(max(512 - int(torch.argwhere(fg_rows).max()) - margin_top, 0), 256)
    bottom = 256 - top
    return tio.Crop((0, 0, bottom, top, 0, 0))
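
# Note: tio.Crop bounds are (w_ini, w_fin, h_ini, h_fin, d_ini, d_fin), so the
# returned transform removes `bottom` voxels from the low end and `top` voxels
# from the high end of the height axis; the hard-coded 512/256 assume the
# (512, 512, 32) bilateral target shape used below.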


def get_bilateral_transform(img: tio.ScalarImage, ref_img=None, target_spacing=(0.7, 0.7, 3), target_shape=(512, 512, 32)):
    """Build the bilateral preprocessing; returns the transformed image and the inverse transform."""
    ref_img = img if ref_img is None else ref_img

    # Reorient the reference to canonical (RAS+) orientation and resample it to the target spacing
    ref_img = tio.ToCanonical()(ref_img)
    ref_img = tio.Resample(target_spacing)(ref_img)
    resample = tio.Resample(ref_img)

    # Pad/crop to the target shape, then localize the breast along the height axis
    ref_img = tio.CropOrPad(target_shape, padding_mode='minimum')(ref_img)
    crop_height = crop_breast_height(ref_img)

    trans = tio.Compose([
        resample,
        tio.CropOrPad(target_shape, padding_mode='minimum'),
        crop_height,
    ])

    trans_inv = tio.Compose([
        crop_height.inverse(),
        tio.CropOrPad(img.spatial_shape, padding_mode='minimum'),
        tio.Resample(img),
    ])
    return trans(img), trans_inv
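
# Usage sketch: `trans_inv` maps images (or attention tensors) from the
# preprocessed bilateral grid back onto the original geometry; resampling is
# lossy, so the round trip is only approximate:
#   img_bil, trans_inv = get_bilateral_transform(img)
#   restored = trans_inv(img_bil)  # back on (approximately) img's grid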


def get_unilateral_transform(img: tio.ScalarImage, target_shape=(224, 224, 32)):
    """Flip and crop/pad one breast to the model input shape; returns the image and the spatial inverse."""
    transform = tio.Compose([
        tio.Flip((1, 0)),
        tio.CropOrPad(target_shape),
        tio.ZNormalization(masking_method=lambda x: (x > x.min()) & (x < x.max())),
    ])
    # Note: the inverse only undoes the spatial operations, not the intensity normalization
    inv_transform = tio.Compose([
        tio.CropOrPad(img.spatial_shape),
        tio.Flip((1, 0)),
    ])
    return transform(img), inv_transform


def run_prediction(img: tio.ScalarImage, model: MSTRegression):
    """Predict per-breast class probabilities and an attention map in the input image space."""
    img_bil, bil_trans_rev = get_bilateral_transform(img)
    # Split the bilateral image into the two breasts (256 voxels each along the width axis)
    split_side = {
        'right': tio.Crop((256, 0, 0, 0, 0, 0)),
        'left': tio.Crop((0, 256, 0, 0, 0, 0)),
    }

    weights, probs = {}, {}
    for side, crop in split_side.items():
        img_side = crop(img_bil)
        img_side, uni_trans_inv = get_unilateral_transform(img_side)
        img_side = img_side.data.swapaxes(1, -1)  # (C, W, H, D) -> (C, D, H, W)
        img_side = img_side.unsqueeze(0)  # add a batch dimension

        with torch.no_grad():
            device = next(model.parameters()).device
            logits, weight, weight_slice = model.forward_attention(img_side.to(device))  # weight_slice is unused here

        # Upsample the attention map to the model input resolution
        weight = F.interpolate(weight.unsqueeze(1), size=img_side.shape[2:], mode='trilinear', align_corners=False).cpu()

        pred_prob = F.softmax(logits, dim=-1).cpu()
        probs[side] = pred_prob.squeeze(0)

        # Map the attention back into the unilateral image geometry
        weight = weight.squeeze(0).swapaxes(1, -1)
        weight = uni_trans_inv(weight)
        weights[side] = weight

    # Reassemble both sides and undo the bilateral preprocessing
    weight = torch.concat([weights['left'], weights['right']], dim=1)
    weight = tio.ScalarImage(tensor=weight, affine=img_bil.affine)
    weight = bil_trans_rev(weight)
    weight.set_data(minmax_norm(weight.data))
    return probs, weight
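
# Usage sketch: `probs` holds one probability vector per side and the attention
# map is a tio.ScalarImage aligned with the input volume (the class semantics
# are defined by the MST training setup, not by this script):
#   probs, attention = run_prediction(img, model)
#   probs['left']  # 1D tensor of class probabilities for the left breast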


def load_model(repo_id="ODELIA-AI/MST") -> MSTRegression:
    """Instantiate MSTRegression from the config and weights hosted on the Hugging Face Hub."""
    config_path = hf_hub_download(repo_id=repo_id, repo_type="model", filename="model_config.json")
    with open(config_path, "r", encoding="utf-8") as fp:
        config = json.load(fp)

    hparams = config.get("hparams", {})
    model = MSTRegression(weights=False, **hparams)

    state_dict_path = hf_hub_download(repo_id=repo_id, repo_type="model", filename="state_dict.pt")
    state_dict = torch.load(state_dict_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)
    return model
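
# Usage sketch (requires Hugging Face Hub access; any repo with the same
# model_config.json / state_dict.pt layout should work):
#   model = load_model()                  # default: ODELIA-AI/MST
#   model = load_model("my-org/my-mst")   # hypothetical alternative repo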


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--path_img', default='/home/homesOnMaster/gfranzes/Documents/datasets/ODELIA/UKA/data/UKA_2/Sub_1.nii.gz', type=str,
                        help='Path to the input MRI volume (e.g. a NIfTI file)')
    args = parser.parse_args()

    path_out_dir = Path.cwd() / 'results/test_attention'
    path_out_dir.mkdir(parents=True, exist_ok=True)

    path_img = Path(args.path_img)
    img = tio.ScalarImage(path_img)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model()
    model.to(device)
    model.eval()

    probs, weight = run_prediction(img, model)

    # Save the input and attention volumes, plus a 2D slice overlay for quick inspection
    img.save(path_out_dir / "input.nii.gz")
    weight.save(path_out_dir / "attention.nii.gz")
    weight = weight.data.swapaxes(1, -1).unsqueeze(0)
    img = img.data.swapaxes(1, -1).unsqueeze(0)
    save_image(tensor_cam2image(minmax_norm(img), minmax_norm(weight), alpha=0.5),
               path_out_dir / "overlay.png", normalize=False)

    for side in ['left', 'right']:
        print(f"{side} breast predicted probabilities: {probs[side]}")