import argparse import json from pathlib import Path from huggingface_hub import hf_hub_download import torch import numpy as np import torch.nn.functional as F import torchio as tio from torchvision.utils import save_image from matplotlib.pyplot import get_cmap from models import MSTRegression def minmax_norm(x): """Normalizes input to [0, 1] for each batch and channel""" return (x - x.min()) / (x.max() - x.min()) def tensor2image(tensor, batch=0): """Transform tensor into shape of multiple 2D RGB/gray images. """ return (tensor if tensor.ndim<5 else torch.swapaxes(tensor[batch], 0, 1).reshape(-1, *tensor.shape[-2:])[:,None]) def tensor_cam2image(tensor, cam, batch=0, alpha=0.5, color_map=get_cmap('jet')): """Transform a tensor and a (grad) cam into multiple 2D RGB images.""" img = tensor2image(tensor, batch) # -> [B, C, H, W] img = torch.cat([img for _ in range(3)], dim=1) if img.shape[1]!=3 else img # Ensure RGB [B, 3, H, W] cam_img = tensor2image(cam, batch) # -> [B, 1, H, W] cam_img = cam_img[:,0].cpu().numpy() # -> [B, H, W] cam_img = torch.tensor(color_map(cam_img)) # -> [B, H, W, 4], color_map expects input to be [0.0, 1.0] cam_img = torch.moveaxis(cam_img, -1, 1)[:, :3] # -> [B, 3, H, W] overlay = (1-alpha)*img + alpha*cam_img return overlay def crop_breast_height(image, margin_top=10) -> tio.Crop: """Crop height to 256 and try to cover breast based on intensity localization""" threshold = int(np.quantile(image.data.float(), 0.9)) foreground = image.data>threshold fg_rows = foreground[0].sum(axis=(0, 2)) top = min(max(512-int(torch.argwhere(fg_rows).max()) - margin_top, 0), 256) bottom = 256-top return tio.Crop((0,0, bottom, top, 0, 0)) def get_bilateral_transform(img:tio.ScalarImage, ref_img=None, target_spacing = (0.7, 0.7, 3), target_shape = (512, 512, 32)): # -------- Settings -------------- ref_img = img if ref_img is None else ref_img # Spacing ref_img = tio.ToCanonical()(ref_img) ref_img = tio.Resample(target_spacing)(ref_img) resample = tio.Resample(ref_img) # Crop ref_img = tio.CropOrPad(target_shape, padding_mode='minimum')(ref_img) crop_height = crop_breast_height(ref_img) # Process input image trans = tio.Compose([ resample, tio.CropOrPad(target_shape, padding_mode='minimum'), crop_height, ]) trans_inv = tio.Compose([ crop_height.inverse(), tio.CropOrPad(img.spatial_shape, padding_mode='minimum'), tio.Resample(img), ]) return trans(img), trans_inv def get_unilateral_transform(img: tio.ScalarImage, target_shape=(224, 224, 32)): transform = tio.Compose([ tio.Flip((1,0)), tio.CropOrPad(target_shape), tio.ZNormalization(masking_method=lambda x:(x>x.min()) & (x [1, C, H, W, D] with torch.no_grad(): device = next(model.parameters()).device logits, weight, weight_slice = model.forward_attention(img_side.to(device)) weight = F.interpolate(weight.unsqueeze(1), size=img_side.shape[2:], mode='trilinear', align_corners=False).cpu() # pred_prob = model.logits2probabilities(logits).cpu() pred_prob = F.softmax(logits, dim=-1).cpu() probs[side] = pred_prob.squeeze(0) weight = weight.squeeze(0).swapaxes(1,-1) # ->[C, W, H, D] weight = uni_trans_inv(weight) weights[side] = weight weight = torch.concat([weights['left'], weights['right']], dim=1) # C, W, H, D weight = tio.ScalarImage(tensor=weight, affine=img_bil.affine) weight = bil_trans_rev(weight) weight.set_data(minmax_norm(weight.data)) return probs, weight def load_model(repo_id= "ODELIA-AI/MST") -> MSTRegression: # Download config and state dict config_path = hf_hub_download(repo_id=repo_id, repo_type="model", filename="model_config.json") with open(config_path, "r", encoding="utf-8") as fp: config = json.load(fp) hparams = config.get("hparams", {}) model = MSTRegression(weights=False, **hparams) state_dict_path = hf_hub_download(repo_id=repo_id, repo_type="model", filename="state_dict.pt") state_dict = torch.load(state_dict_path, map_location="cpu") model.load_state_dict(state_dict, strict=True) return model if __name__ == "__main__": #------------ Get Arguments ---------------- parser = argparse.ArgumentParser() parser.add_argument('--path_img', default='/home/homesOnMaster/gfranzes/Documents/datasets/ODELIA/UKA/data/UKA_2/Sub_1.nii.gz', type=str) args = parser.parse_args() #------------ Settings/Defaults ---------------- path_out_dir = Path().cwd()/'results/test_attention' path_out_dir.mkdir(parents=True, exist_ok=True) # ------------ Load Data ---------------- path_img = Path(args.path_img) img = tio.ScalarImage(path_img) # ------------ Initialize Model ------------ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = load_model() model.to(device) model.eval() # ------------ Predict ---------------- probs, weight = run_prediction(img, model) img.save(path_out_dir/f"input.nii.gz") weight.save(path_out_dir/f"attention.nii.gz") weight = weight.data.swapaxes(1,-1).unsqueeze(0) # C, D, H, W img = img.data.swapaxes(1,-1).unsqueeze(0) # C, D, H, W save_image(tensor_cam2image(minmax_norm(img), minmax_norm(weight), alpha=0.5), path_out_dir/f"overlay.png", normalize=False) for side in ['left', 'right']: print(f"{side} breast predicted probabilities: {probs[side]}")