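"""Breast-MRI malignancy prediction with the ODELIA MST model.

Loads a NIfTI scan, splits it into left and right breasts, runs the model on
each side, and exports the class probabilities plus an attention map resampled
back into the original image geometry (saved as NIfTI and as a PNG overlay).
"""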
import argparse
import json
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torchio as tio
from huggingface_hub import hf_hub_download
from matplotlib.pyplot import get_cmap
from torchvision.utils import save_image

from models import MSTRegression



def minmax_norm(x):
    """Normalize a tensor to [0, 1] over all elements (global min/max)."""
    return (x - x.min()) / (x.max() - x.min())

def tensor2image(tensor, batch=0):
    """Reshape a tensor into a stack of 2D RGB/gray images for saving."""
    if tensor.ndim < 5:
        return tensor  # Already [B, C, H, W]
    # [B, C, D, H, W] -> pick batch, move depth first, flatten slices into a batch of gray images
    return torch.swapaxes(tensor[batch], 0, 1).reshape(-1, *tensor.shape[-2:])[:, None]

def tensor_cam2image(tensor, cam, batch=0, alpha=0.5, color_map=get_cmap('jet')):
    """Blend a tensor and a (Grad-)CAM heatmap into a stack of 2D RGB images."""
    img = tensor2image(tensor, batch)  # -> [B, C, H, W]
    img = torch.cat([img for _ in range(3)], dim=1) if img.shape[1] != 3 else img  # Ensure RGB -> [B, 3, H, W]
    cam_img = tensor2image(cam, batch)  # -> [B, 1, H, W]
    cam_img = cam_img[:, 0].cpu().numpy()  # -> [B, H, W]
    cam_img = torch.tensor(color_map(cam_img))  # -> [B, H, W, 4]; color_map expects values in [0.0, 1.0]
    cam_img = torch.moveaxis(cam_img, -1, 1)[:, :3]  # -> [B, 3, H, W], drop alpha channel
    overlay = (1 - alpha) * img + alpha * cam_img
    return overlay



def crop_breast_height(image, margin_top=10) -> tio.Crop:
    """Crop the height (y axis) from 512 to 256 voxels, keeping the breast region.

    The breast is localized via the brightest voxels (top 10% of intensities);
    assumes the image height is already 512 (see get_bilateral_transform).
    """
    threshold = int(np.quantile(image.data.float(), 0.9))
    foreground = image.data > threshold
    fg_rows = foreground[0].sum(axis=(0, 2))  # Foreground count per height row
    top = min(max(512 - int(torch.argwhere(fg_rows).max()) - margin_top, 0), 256)
    bottom = 256 - top
    return tio.Crop((0, 0, bottom, top, 0, 0))


def get_bilateral_transform(img: tio.ScalarImage, ref_img=None, target_spacing=(0.7, 0.7, 3), target_shape=(512, 512, 32)):
    """Resample/crop to a canonical bilateral view; also return the inverse transform."""
    ref_img = img if ref_img is None else ref_img

    # Spacing: reorient to canonical (RAS+) and resample to the target spacing
    ref_img = tio.ToCanonical()(ref_img)
    ref_img = tio.Resample(target_spacing)(ref_img)
    resample = tio.Resample(ref_img)

    # Crop: pad/crop to the target shape, then crop the height to the breast region
    ref_img = tio.CropOrPad(target_shape, padding_mode='minimum')(ref_img)
    crop_height = crop_breast_height(ref_img)

    # Process input image
    trans = tio.Compose([
        resample,
        tio.CropOrPad(target_shape, padding_mode='minimum'),
        crop_height,
    ])

    # Inverse: undo the crops and resample back into the original geometry
    trans_inv = tio.Compose([
        crop_height.inverse(),
        tio.CropOrPad(img.spatial_shape, padding_mode='minimum'),
        tio.Resample(img),
    ])
    return trans(img), trans_inv

def get_unilateral_transform(img: tio.ScalarImage, target_shape=(224, 224, 32)):
    """Prepare one breast side for the model; also return the inverse transform."""
    transform = tio.Compose([
        tio.Flip((1, 0)),  # Mirror axes 0 and 1 (a flip is its own inverse)
        tio.CropOrPad(target_shape),
        tio.ZNormalization(masking_method=lambda x: (x > x.min()) & (x < x.max())),
    ])
    inv_transform = tio.Compose([
        tio.CropOrPad(img.spatial_shape),
        tio.Flip((1, 0)),
    ])
    return transform(img), inv_transform


def run_prediction(img: tio.ScalarImage, model: MSTRegression):
    """Predict per-side probabilities and an attention map in the input geometry."""
    img_bil, bil_trans_rev = get_bilateral_transform(img)
    split_side = {
        'right': tio.Crop((256, 0, 0, 0, 0, 0)),
        'left': tio.Crop((0, 256, 0, 0, 0, 0)),
    }
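    # After ToCanonical (RAS+), higher indices along the first axis lie toward
    # the patient's right, so each 256-voxel half of the 512-wide bilateral
    # view corresponds to one breast.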

    weights, probs = {}, {}
    for side, crop in split_side.items():
        img_side = crop(img_bil)
        img_side, uni_trans_inv = get_unilateral_transform(img_side)
        img_side = img_side.data.swapaxes(1, -1)  # [C, W, H, D] -> [C, D, H, W]
        img_side = img_side.unsqueeze(0)  # Add batch dim -> [1, C, D, H, W]

        with torch.no_grad():
            device = next(model.parameters()).device
            logits, weight, weight_slice = model.forward_attention(img_side.to(device))

        # Upsample the attention weights to the model-input resolution
        weight = F.interpolate(weight.unsqueeze(1), size=img_side.shape[2:], mode='trilinear', align_corners=False).cpu()
        pred_prob = F.softmax(logits, dim=-1).cpu()
        probs[side] = pred_prob.squeeze(0)

        weight = weight.squeeze(0).swapaxes(1, -1)  # -> [C, W, H, D]
        weight = uni_trans_inv(weight)  # Undo the per-side preprocessing
        weights[side] = weight

    # Reassemble the bilateral volume and map the attention back to the input geometry
    weight = torch.concat([weights['left'], weights['right']], dim=1)  # -> [C, W, H, D]
    weight = tio.ScalarImage(tensor=weight, affine=img_bil.affine)
    weight = bil_trans_rev(weight)
    weight.set_data(minmax_norm(weight.data))
    return probs, weight

def load_model(repo_id="ODELIA-AI/MST") -> MSTRegression:
    """Download config and weights from the Hugging Face Hub and build the model."""
    config_path = hf_hub_download(repo_id=repo_id, repo_type="model", filename="model_config.json")
    with open(config_path, "r", encoding="utf-8") as fp:
        config = json.load(fp)

    hparams = config.get("hparams", {})
    model = MSTRegression(weights=False, **hparams)

    state_dict_path = hf_hub_download(repo_id=repo_id, repo_type="model", filename="state_dict.pt")
    state_dict = torch.load(state_dict_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)
    return model


if __name__ == "__main__":
    #------------ Get Arguments ----------------
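    # Example invocation (script and path names are illustrative):
    #   python predict_attention.py --path_img /path/to/scan.nii.gz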
    parser = argparse.ArgumentParser()
    parser.add_argument('--path_img', required=True, type=str, help='Path to the input NIfTI image')
    args = parser.parse_args()


    #------------ Settings/Defaults ----------------
    path_out_dir = Path.cwd()/'results/test_attention'
    path_out_dir.mkdir(parents=True, exist_ok=True)


    # ------------ Load Data ----------------
    path_img = Path(args.path_img)
    img = tio.ScalarImage(path_img)


    # ------------ Initialize Model ------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model()
    model.to(device)
    model.eval()


    # ------------ Predict ----------------
    probs, weight = run_prediction(img, model)

    img.save(path_out_dir/"input.nii.gz")
    weight.save(path_out_dir/"attention.nii.gz")
    weight = weight.data.swapaxes(1, -1).unsqueeze(0)  # -> [1, C, D, H, W]
    img = img.data.swapaxes(1, -1).unsqueeze(0)  # -> [1, C, D, H, W]
    save_image(tensor_cam2image(minmax_norm(img), minmax_norm(weight), alpha=0.5),
               path_out_dir/"overlay.png", normalize=False)

    for side in ['left', 'right']:
        print(f"{side} breast predicted probabilities: {probs[side]}")