Spaces:
Sleeping
Sleeping
| import torch.nn as nn | |
| import pretrainedmodels | |
| from torchvision.models import densenet121 | |
| from layers import Flatten | |
| import torch | |
| import torchvision.transforms as transforms | |
| from pathlib import Path | |
| from constant import IMAGENET_MEAN, IMAGENET_STD | |
| import os | |
| import sys | |
# Make the ``chestXray14`` directory that sits next to this file's directory
# (i.e. ``<script_dir>/../chestXray14``) importable.
# NOTE(review): this runs *after* ``from layers import ...`` and
# ``from constant import ...`` above, so it cannot be what makes those two
# imports resolve; move it above them if they live in chestXray14.
script_dir = os.path.dirname(os.path.abspath(__file__))
chestxray14_dir = os.path.join(script_dir, '..', 'chestXray14')
# Backward-compatible alias: the path was originally stored under this
# (misleading, copy-pasted) name.
yolov9 = chestxray14_dir
# Guard so re-executing this module (e.g. importlib.reload) does not keep
# appending duplicate entries to sys.path.
if chestxray14_dir not in sys.path:
    sys.path.append(chestxray14_dir)
class ChexNet(nn.Module):
    """DenseNet-121 backbone + small head producing 14 logits.

    ``self.backbone`` is the ``features`` part of a DenseNet-121 and
    ``self.head`` is global average pooling -> ``Flatten`` -> ``Linear(1024, 14)``.
    ``forward`` returns raw logits; ``predict`` returns sigmoid probabilities
    for a single image. (Presumably the 14 outputs are the ChestX-ray14
    finding labels — confirm against the label list used by the caller.)
    """

    # Preprocessing used by ``predict``: resize to 224x224, convert to a float
    # tensor and normalise with IMAGENET_MEAN / IMAGENET_STD.
    tfm = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
    ])

    def __init__(self, trained=False, model_name='20180525-222635'):
        """Build the network.

        Args:
            trained: if True, build a randomly initialised backbone and load
                fine-tuned weights via ``load_model``; otherwise initialise the
                backbone from pretrained weights via ``load_pretrained``.
            model_name: forwarded to ``load_model`` (which currently does not
                use it to locate the checkpoint).
        """
        super().__init__()
        # NOTE: nothing in this class freezes parameters (no requires_grad
        # changes), so backbone and head are both trainable as built.
        if trained:
            self.load_model(model_name)
        else:
            self.load_pretrained()

    @staticmethod
    def _make_head():
        """Return a fresh head: AdaptiveAvgPool2d(1) -> Flatten -> Linear(1024, 14)."""
        # Layer order/indices must stay fixed: checkpoint keys are head.2.*.
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
            nn.Linear(1024, 14)
        )

    def load_model(self, model_name, weights_path='chexnet.h5'):
        """Build the architecture and load a saved ``state_dict`` into it.

        Args:
            model_name: kept for API compatibility; not used to locate the
                checkpoint.
            weights_path: checkpoint file, resolved relative to the current
                working directory. Defaults to ``'chexnet.h5'``, the file the
                original code loaded. Despite the ``.h5`` suffix it is read with
                ``torch.load``, so it must be a PyTorch-serialised state_dict.
        """
        self.backbone = densenet121(pretrained=False).features
        self.head = self._make_head()
        # map_location='cpu' lets a checkpoint saved from CUDA tensors load on
        # a CPU-only host; load_state_dict copies the values into this module's
        # parameters, so the resulting model is the same either way.
        # NOTE(review): torch.load unpickles the file — only load trusted files.
        state_dict = torch.load(weights_path, map_location='cpu')
        self.load_state_dict(state_dict)

    def load_pretrained(self, torch=False):
        """Build the architecture with a pretrained DenseNet-121 backbone.

        Args:
            torch: if True, take backbone weights from torchvision's
                ``densenet121(pretrained=True)``; otherwise from the
                ``pretrainedmodels`` package's default ``densenet121`` weights.
                (This name shadows the ``torch`` module inside this method; it
                is kept because callers may pass it by keyword.)

        The head is always freshly initialised.
        """
        if torch:
            self.backbone = densenet121(pretrained=True).features
        else:
            self.backbone = pretrainedmodels.__dict__['densenet121']().features
        self.head = self._make_head()

    def forward(self, x):
        """Return the head's raw logits for input batch ``x``."""
        return self.head(self.backbone(x))

    def predict(self, image):
        """Return per-output sigmoid probabilities for a single image.

        Args:
            image: a PIL image. Non-RGB images (e.g. grayscale ``'L'`` or
                ``'RGBA'``) are converted to RGB first.

        Returns:
            1-D numpy array with one probability per output of the head.

        The model is run in eval mode (BatchNorm uses its running statistics
        and does not update them); the caller's train/eval mode is restored
        afterwards.
        """
        # Normalize(IMAGENET_MEAN, IMAGENET_STD) applies one statistic per
        # channel (presumably 3 entries); 1- or 4-channel inputs would fail.
        if hasattr(image, 'convert') and getattr(image, 'mode', 'RGB') != 'RGB':
            image = image.convert('RGB')
        batch = self.tfm(image).unsqueeze(0)  # add batch dimension
        # Put the input on the same device as the model's parameters.
        batch = batch.to(next(self.parameters()).device)
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probs = torch.sigmoid(self(batch))
        finally:
            self.train(was_training)
        return probs.cpu().numpy()[0]