Spaces:
Runtime error
Runtime error
| import spaces | |
| import tempfile | |
| import os | |
| from pathlib import Path | |
| import SimpleITK as sitk | |
| import numpy as np | |
| import nibabel as nib | |
| from totalsegmentator.python_api import totalsegmentator | |
| import gradio as gr | |
| from segmap import seg_map | |
| import logging | |
# Module-wide logging: DEBUG level, each record prefixed with time, logger name and level.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Bundled example inputs (served from input_examples/); their segmentations are
# read from output_examples/ instead of being recomputed.
sample_files = [f"ct{idx}.nii.gz" for idx in range(1, 4)]
def map_labels(seg_array, label_map=None):
    """Split a 2-D label slice into ``(mask, name)`` pairs for ``gr.AnnotatedImage``.

    Args:
        seg_array: Integer label array (one slice of the segmentation volume).
        label_map: Lookup from label value to display name. Defaults to
            ``seg_map`` from the ``segmap`` module (presumably label id -> organ
            name; confirm against that module).

    Returns:
        A list with one ``(boolean_mask, label_name)`` tuple per distinct
        non-zero value in ``seg_array``. Value 0 is skipped (presumably
        background).
    """
    if label_map is None:
        label_map = seg_map
    # Compute the distinct labels once; they are used for logging and the loop.
    present_labels = np.unique(seg_array)
    logger.debug("unique segs: %d", len(present_labels))
    return [
        (seg_array == seg_class, label_map[seg_class])
        for seg_class in present_labels
        if seg_class != 0
    ]
def sitk_to_numpy(img_sitk, norm=False):
    """Reorient a SimpleITK image to LPS and return its voxels as a NumPy array.

    Args:
        img_sitk: SimpleITK image to convert.
        norm: If True, min-max rescale intensities to 0..255 and cast to uint8.

    Returns:
        The array produced by ``sitk.GetArrayFromImage`` (SimpleITK orders the
        axes z, y, x). It is rescaled to uint8 when ``norm`` is True; a constant
        image becomes all zeros.
    """
    img_sitk = sitk.DICOMOrient(img_sitk, "LPS")
    img_np = sitk.GetArrayFromImage(img_sitk)
    if norm:
        # Subtract in float64: in the native integer dtype (e.g. int16)
        # ``img_np - min_val`` can wrap around for wide intensity ranges.
        img_np = img_np.astype(np.float64)
        min_val, max_val = img_np.min(), img_np.max()
        value_range = max_val - min_val
        if value_range > 0:
            img_np = ((img_np - min_val) / value_range).clip(0, 1) * 255
        else:
            # Constant (or NaN-valued) image: dividing would yield NaN, so map it to 0.
            img_np = np.zeros_like(img_np)
        img_np = img_np.astype(np.uint8)
    return img_np
def load_image(path, norm=False):
    """Read the image file at ``path`` with SimpleITK and convert it via ``sitk_to_numpy``.

    ``norm`` is forwarded unchanged; True requests the uint8 0..255 rescaling.
    """
    return sitk_to_numpy(sitk.ReadImage(path), norm)
def show_img_seg(img_np, seg_np=None, slice_idx=50):
    """Build the ``gr.AnnotatedImage`` value for one slice of the loaded volume.

    Args:
        img_np: 3-D image array, or a session-state list of such arrays (the
            last entry is used). ``None`` or an empty list means nothing is loaded.
        seg_np: Matching 3-D label array, or a state list of them (last entry
            used). ``None`` or an empty list means there is no overlay.
        slice_idx: Relative position along axis 0 as a percentage; the UI
            slider sends values in 1..100.

    Returns:
        ``None`` when there is no image. Otherwise ``(image_slice, annotations)``,
        where ``annotations`` is the ``map_labels`` output for the same slice, or
        ``[]`` when no segmentation is available.
    """
    if img_np is None or (isinstance(img_np, list) and len(img_np) == 0):
        return None
    if isinstance(img_np, list):
        img_np = img_np[-1]
    depth = img_np.shape[0]
    slice_pos = int(slice_idx * (depth / 100))
    # slice_idx == 100 maps to ``depth`` itself, one past the last slice, which
    # raised IndexError; clamp into the valid range [0, depth - 1].
    slice_pos = min(max(slice_pos, 0), depth - 1)
    img_slice = img_np[slice_pos, :, :]
    if seg_np is None or (isinstance(seg_np, list) and len(seg_np) == 0):
        return img_slice, []
    if isinstance(seg_np, list):
        seg_np = seg_np[-1]
    return img_slice, map_labels(seg_np[slice_pos, :, :])
def load_img_to_state(path, img_state, seg_state):
    """Reset the per-session state when the selected file changes.

    Both state lists are emptied in place. If ``path`` is truthy, the image is
    loaded with ``norm=True`` and appended to ``img_state``.

    Returns:
        ``(None, img_state, seg_state)``; the ``None`` goes to the image viewer output.
    """
    for state in (img_state, seg_state):
        state.clear()
    if path:
        img_state.append(load_image(path, norm=True))
    return None, img_state, seg_state
def save_seg(seg, path):
    """Return a downloadable path for the segmentation of the input at ``path``.

    For the bundled sample inputs (names in ``sample_files``), the precomputed
    file in ``output_examples/`` is returned and nothing is written. Otherwise
    ``seg`` is written next to the input as ``<name>_seg.nii.gz``.

    Args:
        seg: SimpleITK segmentation image.
        path: Path of the input volume.

    Returns:
        Path of the segmentation file to offer for download.
    """
    # "<name>.nii.gz" -> "<name>" (two suffixes stripped; "<name>.nii" also -> "<name>").
    base_name = Path(Path(path).stem).stem
    if Path(path).name in sample_files:
        return os.path.join("output_examples", f"{base_name}_seg.nii.gz")
    # Write to a new file rather than over ``path`` so the uploaded input
    # volume is not replaced by its segmentation.
    seg_path = os.path.join(os.path.dirname(path), f"{base_name}_seg.nii.gz")
    sitk.WriteImage(seg, seg_path)
    return seg_path
@spaces.GPU
def run_inference(path):
    """Run TotalSegmentator (``fast=True``) on the NIfTI file at ``path``.

    ``@spaces.GPU`` requests a GPU for this call on Hugging Face ZeroGPU
    hardware, where at least one such decorated function is required. Outside
    ZeroGPU the decorator has no effect.

    Returns:
        The segmentation as a SimpleITK image. It is round-tripped through a
        temporary ``.nii.gz`` file to convert from nibabel to SimpleITK.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        input_nib = nib.load(path)
        output_nib = totalsegmentator(input_nib, fast=True)
        output_path = os.path.join(temp_dir, "totalseg_output.nii.gz")
        nib.save(output_nib, output_path)
        # ReadImage loads the voxels into memory, so the returned image stays
        # valid after the temporary directory is removed.
        return sitk.ReadImage(output_path)
def inference_wrapper(input_file, img_state, seg_state, slice_slider=50):
    """Segment the selected file and render the current slice with its overlay.

    Args:
        input_file: Value of the ``gr.File`` input. With ``type="filepath"`` this
            is a path string; objects exposing the path as ``.name`` are also
            accepted.
        img_state: Session list holding the image volume; filled here if empty.
        seg_state: Session list that is reset to hold only the new segmentation.
        slice_slider: Relative slice position (percent) to display.

    Returns:
        ``(annotated_image_value, seg_state, seg_path)`` for the viewer, the
        segmentation state and the download component.

    Raises:
        gr.Error: If no file has been selected.
    """
    if input_file is None:
        raise gr.Error("Please upload a CT or MR image (.nii/.nii.gz) first.")
    # gr.File(type="filepath") delivers a plain str, which has no ``.name``.
    if isinstance(input_file, (str, os.PathLike)):
        path = os.fspath(input_file)
    else:
        path = input_file.name
    file_name = Path(path).name
    # Load the display image before anything is written to disk, normalised the
    # same way as load_img_to_state so the viewer always receives uint8 data.
    if not img_state:
        img_state.append(load_image(path, norm=True))
    if file_name in sample_files:
        # Bundled examples ship with precomputed segmentations.
        seg_sitk = sitk.ReadImage(os.path.join("output_examples", f"{Path(Path(file_name).stem).stem}_seg.nii.gz"))
    else:
        seg_sitk = run_inference(path)
    seg_path = save_seg(seg_sitk, path)
    # Keep only the latest segmentation so repeated runs do not accumulate volumes.
    seg_state.clear()
    seg_state.append(sitk_to_numpy(seg_sitk))
    return show_img_seg(img_state[-1], seg_state[-1], slice_slider), seg_state, seg_path
# Gradio UI: upload/example picker, run/clear buttons, slice slider, annotated
# viewer and a download link for the segmentation.
with gr.Blocks(title="TotalSegmentator") as interface:
    gr.Markdown("# TotalSegmentator: Segmentation of 117 Classes in CT and MR Images")
    gr.Markdown("""
- **GitHub:** https://github.com/wasserth/TotalSegmentator
- **Please Note:** This tool is intended for research purposes only and can segment 117 classes in CT/MRI images
- Supports both CT and MR imaging modalities
- Credit: adapted from `DiGuaQiu/MRSegmentator-Gradio`
""")
    # Per-session state: lists whose last element is the current image / segmentation volume.
    img_state = gr.State([])
    seg_state = gr.State([])
    with gr.Accordion(label='Upload CT Scan (nifti file) then click on Generate Segmentation to run TotalSegmentator', open=True):
        with gr.Row():
            with gr.Column():
                # ".nii" added so uncompressed NIfTI uploads match the advertised formats;
                # ".gz" is kept for suffix-based matching of ".nii.gz".
                file_input = gr.File(
                    type="filepath",
                    label="Upload a CT or MR Image (.nii/.nii.gz)",
                    file_types=[".nii", ".nii.gz", ".gz"],
                )
                gr.Examples(["input_examples/" + example for example in sample_files], file_input)
                with gr.Row():
                    infer_button = gr.Button("Generate Segmentations", variant="primary")
                    clear_button = gr.ClearButton()
            with gr.Column():
                slice_slider = gr.Slider(1, 100, value=50, step=2, label="Select (relative) Slice")
                img_viewer = gr.AnnotatedImage(label="Image Viewer")
                download_seg = gr.File(label="Download Segmentation", interactive=False)
    file_input.change(
        load_img_to_state,
        inputs=[file_input, img_state, seg_state],
        outputs=[img_viewer, img_state, seg_state],
    )
    slice_slider.change(show_img_seg, inputs=[img_state, seg_state, slice_slider], outputs=[img_viewer])
    infer_button.click(
        inference_wrapper,
        inputs=[file_input, img_state, seg_state, slice_slider],
        outputs=[img_viewer, seg_state, download_seg],
    )
    clear_button.add([file_input, img_viewer, img_state, seg_state, download_seg])
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server with debug output.
    interface.queue().launch(debug=True)