Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| import matplotlib.pylab as plt | |
| import torch.nn.functional as F | |
| from vae import HVAE | |
| from datasets import morphomnist, ukbb, mimic, get_attr_max_min | |
| from pgm.flow_pgm import MorphoMNISTPGM, FlowPGM, ChestPGM | |
| from app_utils import ( | |
| mnist_graph, | |
| brain_graph, | |
| chest_graph, | |
| vae_preprocess, | |
| normalize, | |
| preprocess_brain, | |
| get_fig_arr, | |
| postprocess, | |
| MidpointNormalize, | |
| ) | |
# Per-dataset caches keyed by tab label. DATA[name]["test"] holds the test
# DataLoader; MODELS[name] holds "pgm"/"pgm_args"/"vae"/"vae_args" once loaded.
DATA = {name: {} for name in ("Morpho-MNIST", "Brain MRI", "Chest X-ray")}
MODELS = {name: {} for name in DATA}

# Morpho-MNIST: digit labels; list index == class id.
DIGITS = [str(digit) for digit in range(10)]

# Brain MRI: categorical labels; list index == encoded value.
MRISEQ_CAT = ["T1", "T2-FLAIR"]  # 0,1
SEX_CAT = ["female", "male"]  # 0,1
HEIGHT, WIDTH = 270, 270  # HEIGHT is the gr.Image panel height used by the UI

# Chest X-ray: note the sex encoding is the reverse of the brain one.
SEX_CAT_CHEST = ["male", "female"]  # 0,1
RACE_CAT = ["white", "asian", "black"]  # 0,1,2
FIND_CAT = ["no disease", "pleural effusion"]  # 0,1

# Prefer the first CUDA device, fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
class Hparams:
    """Plain attribute bag filled from a checkpoint's ``hparams`` mapping."""

    def update(self, dict):
        """Set one attribute per key/value pair of ``dict`` (later keys win)."""
        pairs = dict.items()
        for attr_name, attr_value in pairs:
            setattr(self, attr_name, attr_value)
def get_paths(dataset_id):
    """Return the data directory and checkpoint paths for a demo dataset.

    The dataset is chosen by substring match on ``dataset_id`` (the tab
    label), checked in the order "MNIST", "Brain", "Chest".

    Args:
        dataset_id: e.g. ``"Morpho-MNIST"``, ``"Brain MRI"`` or ``"Chest X-ray"``.

    Returns:
        ``(data_path, vae_path, pgm_path)``. For the chest dataset
        ``vae_path`` is a two-element list ``[base_vae, cf_trained_dscm]``;
        for the others it is a single path string.

    Raises:
        ValueError: if ``dataset_id`` matches none of the known datasets.
    """
    if "MNIST" in dataset_id:
        data_path = "./data/morphomnist"
        pgm_path = "./checkpoints/t_i_d/sup_pgm/checkpoint.pt"
        vae_path = "./checkpoints/t_i_d/dgauss_cond_big_beta1_dropexo/checkpoint.pt"
    elif "Brain" in dataset_id:
        data_path = "./data/ukbb_subset"
        pgm_path = "./checkpoints/m_b_v_s/sup_pgm/checkpoint.pt"
        vae_path = "./checkpoints/m_b_v_s/ukbb192_beta5_dgauss_b33/checkpoint.pt"
    elif "Chest" in dataset_id:
        data_path = "./data/mimic_subset"
        pgm_path = "./checkpoints/a_r_s_f/sup_pgm_mimic/checkpoint.pt"
        vae_path = [
            "./checkpoints/a_r_s_f/mimic_beta9_gelu_dgauss_1_lr3/checkpoint.pt",  # base vae
            "./checkpoints/a_r_s_f/mimic_dscm_lr_1e5_lagrange_lr_1_damping_10/6500_checkpoint.pt",  # cf trained DSCM
        ]
    else:
        # Without this branch the return below hit an UnboundLocalError.
        raise ValueError(f"Unknown dataset_id: {dataset_id!r}")
    return data_path, vae_path, pgm_path
def load_pgm(dataset_id, pgm_path):
    """Load the PGM checkpoint for ``dataset_id`` and cache it in ``MODELS``.

    Stores the model (with the checkpoint's EMA weights) under
    ``MODELS[dataset_id]["pgm"]`` and its hyper-parameters, with ``device``
    set to ``DEVICE``, under ``MODELS[dataset_id]["pgm_args"]``.

    Raises:
        ValueError: if ``dataset_id`` matches none of the known datasets.
    """
    # Resolve the model class first so an unknown id fails before the
    # checkpoint is read (previously it surfaced as UnboundLocalError).
    if "MNIST" in dataset_id:
        pgm_cls = MorphoMNISTPGM
    elif "Brain" in dataset_id:
        pgm_cls = FlowPGM
    elif "Chest" in dataset_id:
        pgm_cls = ChestPGM
    else:
        raise ValueError(f"Unknown dataset_id: {dataset_id!r}")
    checkpoint = torch.load(pgm_path, map_location=DEVICE)
    args = Hparams()
    args.update(checkpoint["hparams"])
    args.device = DEVICE
    pgm = pgm_cls(args).to(args.device)
    pgm.load_state_dict(checkpoint["ema_model_state_dict"])
    # The app only runs inference, so disable any training-only behaviour
    # (dropout / batch-norm statistic updates, if the model has them).
    pgm.eval()
    MODELS[dataset_id]["pgm"] = pgm
    MODELS[dataset_id]["pgm_args"] = args
def load_vae(dataset_id, vae_path):
    """Load the HVAE for ``dataset_id`` and cache it in ``MODELS``.

    Hyper-parameters always come from the checkpoint at ``vae_path``. For the
    chest dataset ``vae_path`` is ``[base_vae_path, dscm_path]``: the
    hyper-parameters come from the base VAE, while the weights come from the
    ``"vae."``-prefixed entries of the DSCM checkpoint's EMA state dict. For
    the other datasets the EMA weights of the VAE checkpoint itself are used.

    Stores the model under ``MODELS[dataset_id]["vae"]`` and its
    hyper-parameters under ``MODELS[dataset_id]["vae_args"]``.
    """
    if "Chest" in dataset_id:
        vae_path, dscm_path = vae_path[0], vae_path[1]
    checkpoint = torch.load(vae_path, map_location=DEVICE)
    args = Hparams()
    args.update(checkpoint["hparams"])
    # backwards compatibility: fill in hparams that older checkpoints lack
    if not hasattr(args, "vae"):
        args.vae = "hierarchical"
    if not hasattr(args, "cond_prior"):
        args.cond_prior = False
    if hasattr(args, "free_bits"):
        # legacy name: copy `free_bits` into `kl_free_bits`
        args.kl_free_bits = args.free_bits
    args.device = DEVICE
    vae = HVAE(args).to(args.device)
    if "Chest" in dataset_id:
        dscm_ckpt = torch.load(dscm_path, map_location=DEVICE)
        prefix = "vae."
        # Keep only the VAE sub-module's parameters and strip the prefix.
        # A substring test ("vae." in k) followed by k[4:] would mangle any
        # key that merely contains "vae." somewhere other than the start.
        state_dict = {
            k[len(prefix):]: v
            for k, v in dscm_ckpt["ema_model_state_dict"].items()
            if k.startswith(prefix)
        }
    else:
        state_dict = checkpoint["ema_model_state_dict"]
    vae.load_state_dict(state_dict)
    # Inference only: disable any training-only behaviour in submodules.
    vae.eval()
    MODELS[dataset_id]["vae"] = vae
    MODELS[dataset_id]["vae_args"] = args
def get_dataloader(dataset_id, data_path):
    """Build the test-split DataLoader for ``dataset_id`` and cache it in ``DATA``.

    Requires ``MODELS[dataset_id]["pgm_args"]`` to already be set; its
    ``data_dir`` is overwritten with ``data_path`` before the datasets are built.
    The loader is stored as ``DATA[dataset_id]["test"]``.

    Raises:
        ValueError: if ``dataset_id`` matches none of the known datasets.
    """
    args = MODELS[dataset_id]["pgm_args"]
    args.data_dir = data_path
    if "MNIST" in dataset_id:
        datasets = morphomnist(args)
    elif "Brain" in dataset_id:
        datasets = ukbb(args)
    elif "Chest" in dataset_id:
        datasets = mimic(args)
    else:
        # Previously fell through to an UnboundLocalError on `datasets`.
        raise ValueError(f"Unknown dataset_id: {dataset_id!r}")
    # NOTE(review): within this file only `.dataset` of this loader is
    # indexed directly, so shuffle/num_workers currently have no effect.
    DATA[dataset_id]["test"] = torch.utils.data.DataLoader(
        datasets["test"], shuffle=False, batch_size=args.bs, num_workers=4
    )
def load_dataset(dataset_id):
    """Prepare the test data for ``dataset_id``.

    Reads only the hyper-parameters from the PGM checkpoint (needed by the
    dataset builders), stores them as ``MODELS[dataset_id]["pgm_args"]`` and
    then builds the test DataLoader via ``get_dataloader``.
    """
    data_path, _vae_path, pgm_path = get_paths(dataset_id)
    saved = torch.load(pgm_path, map_location=DEVICE)
    pgm_args = Hparams()
    pgm_args.update(saved["hparams"])
    pgm_args.device = DEVICE
    MODELS[dataset_id]["pgm_args"] = pgm_args
    get_dataloader(dataset_id, data_path)
def load_model(dataset_id):
    """Load the PGM and VAE for ``dataset_id`` unless both are already cached.

    This is the tab ``select`` callback, so it runs on every tab click. The
    cache check avoids re-reading both checkpoints from disk each time.
    """
    cached = MODELS[dataset_id]
    if "pgm" in cached and "vae" in cached:
        return
    _, vae_path, pgm_path = get_paths(dataset_id)
    load_pgm(dataset_id, pgm_path)
    load_vae(dataset_id, vae_path)
def counterfactual_inference(dataset_id, obs, do_pa):
    """Infer a counterfactual image and parents for ``obs`` under ``do_pa``.

    Args:
        dataset_id: key into ``MODELS``; its "pgm", "vae" and "vae_args"
            entries must be loaded.
        obs: dict with the image under ``"x"`` plus one tensor per parent,
            already batched (one row per particle) and on the model device.
        do_pa: dict of intervened parent values, batched like ``obs``.

    Returns:
        dict with ``"cf_x"`` (counterfactual image, clamped to [-1, 1]),
        ``"rec_x"`` (decoder mean under the factual parents) and ``"cf_pa"``
        (counterfactual parents from the PGM).
    """
    pgm = MODELS[dataset_id]["pgm"]
    args, vae = MODELS[dataset_id]["vae_args"], MODELS[dataset_id]["vae"]
    is_mnist = "mnist" in args.hps
    z_t = 0.1 if is_mnist else 1.0  # `t` passed to vae.abduct
    u_t = 0.1 if is_mnist else 1.0  # cf sampling temp
    # Pure inference: no autograd graph is needed for any of this.
    with torch.no_grad():
        pa = {k: v.clone() for k, v in obs.items() if k != "x"}
        cf_pa = pgm.counterfactual(obs=pa, intervention=do_pa, num_particles=1)
        _pa = vae_preprocess(args, {k: v.clone() for k, v in pa.items()})
        _cf_pa = vae_preprocess(args, {k: v.clone() for k, v in cf_pa.items()})
        z = vae.abduct(x=obs["x"], parents=_pa, t=z_t)
        if vae.cond_prior:
            z = [z_j["z"] for z_j in z]
        px_loc, px_scale = vae.forward_latents(latents=z, parents=_pa)
        cf_loc, cf_scale = vae.forward_latents(latents=z, parents=_cf_pa)
        # Standardized residual of x under the factual decoder, re-applied
        # around the counterfactual decoder's location/scale.
        u = (obs["x"] - px_loc) / px_scale.clamp(min=1e-12)
        cf_scale = cf_scale * u_t
        cf_x = torch.clamp(cf_loc + cf_scale * u, min=-1, max=1)
    return {"cf_x": cf_x, "rec_x": px_loc, "cf_pa": cf_pa}
def get_obs_item(dataset_id, idx=None):
    """Fetch one test example, choosing a random index when ``idx`` is None.

    Returns ``(idx, item)`` where ``idx`` is a plain ``int`` (Gradio may pass
    a float) and ``item`` is what the test dataset returns for that index.
    """
    test_set = DATA[dataset_id]["test"].dataset
    if idx is None:
        idx = torch.randperm(len(test_set))[0]
    position = int(idx)
    return position, test_set[position]
def get_mnist_obs(idx=None):
    """Load (lazily) one Morpho-MNIST test observation for display.

    Thickness and intensity are mapped from [-1, 1] back to their original
    ranges using the hard-coded dataset min/max.

    Returns:
        ``(idx, image, thickness, intensity, digit_label)``.
    """
    dataset_id = "Morpho-MNIST"
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    idx, obs = get_obs_item(dataset_id, idx)
    t_min, t_max = 0.87598526, 6.255515
    i_min, i_max = 66.601204, 254.90317
    image = get_fig_arr(obs["x"].clone().squeeze().numpy())
    thickness = (obs["thickness"].clone() + 1) / 2 * (t_max - t_min) + t_min
    intensity = (obs["intensity"].clone() + 1) / 2 * (i_max - i_min) + i_min
    digit = DIGITS[obs["digit"].clone().argmax(-1)]
    return (
        idx,
        image,
        float(np.round(thickness, 2)),
        float(np.round(intensity, 2)),
        digit,
    )
def get_brain_obs(idx=None, *, include_mri_seq=False):
    """Load (lazily) one brain MRI test observation for display.

    The UI binds six output components (``obs_brain``: idx, image, sex, age,
    brain volume, ventricle volume) and has no MRI-sequence control. Returning
    the sequence label unconditionally made the output count mismatch.

    Args:
        idx: test-set index, or None for a random one.
        include_mri_seq: if True, insert the MRI sequence label after the image.

    Returns:
        ``(idx, image, sex, age, brain_vol_ml, ventricle_vol_ml)``, or
        ``(idx, image, mri_seq, sex, age, brain_vol_ml, ventricle_vol_ml)``
        when ``include_mri_seq`` is True. Volumes are divided by 1000 (ml).
    """
    dataset_id = "Brain MRI"
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    idx, obs = get_obs_item(dataset_id, idx)
    x = get_fig_arr(obs["x"].clone().squeeze().numpy())
    s = SEX_CAT[int(obs["sex"].clone().item())]
    a = obs["age"].clone().item()
    b = float(np.round(obs["brain_volume"].clone().item() / 1000, 2))  # in ml
    v = float(np.round(obs["ventricle_volume"].clone().item() / 1000, 2))  # in ml
    if include_mri_seq:
        m = MRISEQ_CAT[int(obs["mri_seq"].clone().item())]
        return (idx, x, m, s, a, b, v)
    return (idx, x, s, a, b, v)
def get_chest_obs(idx=None):
    """Load (lazily) one chest X-ray test observation for display.

    Returns:
        ``(idx, image, race, sex, finding, age)``. Age is mapped from
        [-1, 1] to [0, 100] via ``(a + 1) * 50`` and rounded to 1 decimal.
    """
    dataset_id = "Chest X-ray"
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    idx, obs = get_obs_item(dataset_id, idx)

    def as_array(key):
        # Detached numpy copy of one attribute with singleton dims removed.
        return obs[key].clone().squeeze().numpy()

    image = get_fig_arr(postprocess(obs["x"].clone()))
    sex = SEX_CAT_CHEST[int(as_array("sex"))]
    finding = FIND_CAT[int(as_array("finding"))]
    race = RACE_CAT[as_array("race").argmax(-1)]
    age = (as_array("age") + 1) * 50
    return (idx, image, race, sex, finding, float(np.round(age, 1)))
def infer_mnist_cf(*args):
    """Gradio callback: Morpho-MNIST counterfactual under the ticked interventions.

    Args:
        *args: ``(idx, image, thickness, intensity, digit, do_thickness,
            do_intensity, do_digit)``. The image is ignored; the observation
            is re-read from the test set by ``idx``.

    Returns:
        ``(cf_image, cf_uncertainty, effect, cf_thickness, cf_intensity,
        cf_digit)``. The figures average 32 particles; the uncertainty panel
        shows their per-pixel std.
    """
    dataset_id = "Morpho-MNIST"
    idx, _, t, i, y, do_t, do_i, do_y = args
    n_particles = 32
    # No tab-select handler in this app loads MNIST, so load on demand
    # instead of failing with a KeyError.
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    if "pgm" not in MODELS[dataset_id] or "vae" not in MODELS[dataset_id]:
        load_model(dataset_id)
    device = MODELS[dataset_id]["vae_args"].device
    # preprocess: batch every entry and replicate it once per particle
    obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
    obs["x"] = (obs["x"] - 127.5) / 127.5  # maps [0, 255] onto [-1, 1]
    for k, v in obs.items():
        obs[k] = v.view(1, 1) if len(v.shape) < 1 else v.unsqueeze(0)
        obs[k] = obs[k].to(device).float()
        if n_particles > 1:
            ndims = (1,) * 3 if k == "x" else (1,)
            obs[k] = obs[k].repeat(n_particles, *ndims)
    # intervention(s), normalized the same way as the dataset attributes
    do_pa = {}
    if do_t:
        do_pa["thickness"] = torch.tensor(
            normalize(t, x_max=6.255515, x_min=0.87598526)
        ).view(1, 1)
    if do_i:
        do_pa["intensity"] = torch.tensor(
            normalize(i, x_max=254.90317, x_min=66.601204)
        ).view(1, 1)
    if do_y:
        do_pa["digit"] = F.one_hot(torch.tensor(DIGITS.index(y)), num_classes=10).view(
            1, 10
        )
    for k, v in do_pa.items():
        do_pa[k] = v.to(device).float().repeat(n_particles, 1)
    # infer counterfactual
    out = counterfactual_inference(dataset_id, obs, do_pa)
    # avg cf particles
    cf_x = out["cf_x"].mean(0)
    cf_x_std = out["cf_x"].std(0)
    rec_x = out["rec_x"].mean(0)
    cf_t = out["cf_pa"]["thickness"].mean(0)
    cf_i = out["cf_pa"]["intensity"].mean(0)
    cf_y = out["cf_pa"]["digit"].mean(0)
    # post process: back to display units
    cf_x = postprocess(cf_x)
    cf_x_std = cf_x_std.squeeze().detach().cpu().numpy()
    rec_x = postprocess(rec_x)
    cf_t = np.round((cf_t.item() + 1) / 2 * (6.255515 - 0.87598526) + 0.87598526, 2)
    cf_i = np.round((cf_i.item() + 1) / 2 * (254.90317 - 66.601204) + 66.601204, 2)
    cf_y = DIGITS[cf_y.argmax(-1)]
    # plots: effect = counterfactual minus reconstruction, on a fixed ±255 scale
    effect = cf_x - rec_x
    effect = get_fig_arr(
        effect, cmap="RdBu_r", norm=MidpointNormalize(vmin=-255, midpoint=0, vmax=255)
    )
    cf_x = get_fig_arr(cf_x)
    cf_x_std = get_fig_arr(cf_x_std, cmap="jet")
    return (cf_x, cf_x_std, effect, cf_t, cf_i, cf_y)
def infer_brain_cf(*args):
    """Gradio callback: brain MRI counterfactual under the ticked interventions.

    Two argument layouts are accepted:

    * 10 values, as wired by the UI (``obs_brain + do_brain``):
      ``(idx, image, sex, age, brain_vol_ml, ventricle_vol_ml,
      do_sex, do_age, do_brain_vol, do_ventricle_vol)``.
    * 12 values, the extended layout that also carries the MRI sequence:
      ``(idx, image, mri_seq, sex, age, brain_vol_ml, ventricle_vol_ml,
      do_mri_seq, do_sex, do_age, do_brain_vol, do_ventricle_vol)``.

    Returns:
        ``(cf_image, cf_uncertainty, effect, [cf_mri_seq,] cf_sex, cf_age,
        cf_brain_vol_ml, cf_ventricle_vol_ml)``. ``cf_mri_seq`` is included
        only for the 12-value layout, so the output count always matches the
        caller's layout (the UI binds 7 outputs).
    """
    dataset_id = "Brain MRI"
    with_mri_seq = len(args) != 10
    if with_mri_seq:
        idx, _, m, s, a, b, v = args[:7]
        do_m, do_s, do_a, do_b, do_v = args[7:]
    else:
        idx, _, s, a, b, v = args[:6]
        do_s, do_a, do_b, do_v = args[6:]
        m, do_m = None, False
    n_particles = 16
    # The Brain tab is the default tab, so its `select` (which calls
    # load_model) may never fire before Submit; load on demand.
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    if "pgm" not in MODELS[dataset_id] or "vae" not in MODELS[dataset_id]:
        load_model(dataset_id)
    device = MODELS[dataset_id]["vae_args"].device
    # preprocessing: replicate every entry once per particle
    obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
    obs = preprocess_brain(MODELS[dataset_id]["vae_args"], obs)
    for k, _v in obs.items():
        if n_particles > 1:
            ndims = (1,) * 3 if k == "x" else (1,)
            obs[k] = _v.repeat(n_particles, *ndims)
    # interventions(s); UI volumes are in ml, dataset values are ml * 1000
    do_pa = {}
    if do_m:
        do_pa["mri_seq"] = torch.tensor(MRISEQ_CAT.index(m)).view(1, 1)
    if do_s:
        do_pa["sex"] = torch.tensor(SEX_CAT.index(s)).view(1, 1)
    if do_a:
        do_pa["age"] = torch.tensor(a).view(1, 1)
    if do_b:
        do_pa["brain_volume"] = torch.tensor(b * 1000).view(1, 1)
    if do_v:
        do_pa["ventricle_volume"] = torch.tensor(v * 1000).view(1, 1)
    # normalize continuous attributes
    for k in ["age", "brain_volume", "ventricle_volume"]:
        if k in do_pa:
            k_max, k_min = get_attr_max_min(k)
            do_pa[k] = (do_pa[k] - k_min) / (k_max - k_min)  # [0,1]
            do_pa[k] = 2 * do_pa[k] - 1  # [-1,1]
    for k, _v in do_pa.items():
        do_pa[k] = _v.to(device).float().repeat(n_particles, 1)
    # infer counterfactual
    out = counterfactual_inference(dataset_id, obs, do_pa)
    # avg cf particles
    cf_x = out["cf_x"].mean(0)
    cf_x_std = out["cf_x"].std(0)
    rec_x = out["rec_x"].mean(0)
    cf_s = out["cf_pa"]["sex"].mean(0)
    # post process
    cf_x = postprocess(cf_x)
    cf_x_std = cf_x_std.squeeze().detach().cpu().numpy()
    rec_x = postprocess(rec_x)
    # Binary labels are averaged over particles: round to the majority class
    # (plain int() truncates, so e.g. 0.9 would have become class 0).
    cf_s = SEX_CAT[int(round(cf_s.item()))]
    cf_ = {}
    for k in ["age", "brain_volume", "ventricle_volume"]:  # unnormalize
        k_max, k_min = get_attr_max_min(k)
        cf_[k] = (out["cf_pa"][k].mean(0).item() + 1) / 2 * (k_max - k_min) + k_min
    # plots: effect = counterfactual minus reconstruction
    effect = cf_x - rec_x
    effect = get_fig_arr(
        effect,
        cmap="RdBu_r",
        norm=MidpointNormalize(vmin=effect.min(), midpoint=0, vmax=effect.max()),
    )
    cf_x = get_fig_arr(cf_x)
    cf_x_std = get_fig_arr(cf_x_std, cmap="jet")
    outputs = [cf_x, cf_x_std, effect]
    if with_mri_seq:
        cf_m = out["cf_pa"]["mri_seq"].mean(0)
        outputs.append(MRISEQ_CAT[int(round(cf_m.item()))])
    outputs += [
        cf_s,
        np.round(cf_["age"], 1),
        np.round(cf_["brain_volume"] / 1000, 2),
        np.round(cf_["ventricle_volume"] / 1000, 2),
    ]
    return tuple(outputs)
def infer_chest_cf(*args):
    """Gradio callback: chest X-ray counterfactual under the ticked interventions.

    Args:
        *args: ``(idx, image, race, sex, finding, age, do_race, do_sex,
            do_finding, do_age)``. The image is ignored; the observation is
            re-read from the test set by ``idx``.

    Returns:
        ``(cf_image, cf_uncertainty, effect, cf_race, cf_sex, cf_finding,
        cf_age)``. The figures average 16 particles; the uncertainty panel
        shows their per-pixel std.
    """
    dataset_id = "Chest X-ray"
    idx, _, r, s, f, a = args[:6]
    do_r, do_s, do_f, do_a = args[6:]
    n_particles = 16
    # The tab `select` handler may not have run yet (e.g. after a reload);
    # load on demand instead of failing with a KeyError.
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    if "pgm" not in MODELS[dataset_id] or "vae" not in MODELS[dataset_id]:
        load_model(dataset_id)
    device = MODELS[dataset_id]["vae_args"].device
    # preprocessing: move to device and replicate once per particle
    obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
    for k, v in obs.items():
        obs[k] = v.to(device).float()
        if n_particles > 1:
            ndims = (1,) * 3 if k == "x" else (1,)
            obs[k] = obs[k].repeat(n_particles, *ndims)
    # intervention(s); age maps [0, 100] onto [-1, 1]
    do_pa = {}
    if do_s:
        do_pa["sex"] = torch.tensor(SEX_CAT_CHEST.index(s)).view(1, 1)
    if do_f:
        do_pa["finding"] = torch.tensor(FIND_CAT.index(f)).view(1, 1)
    if do_r:
        do_pa["race"] = F.one_hot(
            torch.tensor(RACE_CAT.index(r)), num_classes=3
        ).view(1, 3)
    if do_a:
        do_pa["age"] = torch.tensor(a / 100 * 2 - 1).view(1, 1)
    for k, v in do_pa.items():
        do_pa[k] = v.to(device).float().repeat(n_particles, 1)
    # infer counterfactual; no autograd graph is needed (the original
    # no_grad only wrapped the construction of the constant tensors above)
    with torch.no_grad():
        out = counterfactual_inference(dataset_id, obs, do_pa)
    # avg cf particles
    cf_x = out["cf_x"].mean(0)
    cf_x_std = out["cf_x"].std(0)
    rec_x = out["rec_x"].mean(0)
    cf_r = out["cf_pa"]["race"].mean(0)
    cf_s = out["cf_pa"]["sex"].mean(0)
    cf_f = out["cf_pa"]["finding"].mean(0)
    cf_a = out["cf_pa"]["age"].mean(0)
    # post process
    cf_x = postprocess(cf_x)
    cf_x_std = cf_x_std.squeeze().detach().cpu().numpy()
    rec_x = postprocess(rec_x)
    cf_r = RACE_CAT[cf_r.argmax(-1)]
    # Binary labels are averaged over particles: round to the majority class,
    # consistent with the argmax used for race (int() would truncate).
    cf_s = SEX_CAT_CHEST[int(round(cf_s.item()))]
    cf_f = FIND_CAT[int(round(cf_f.item()))]
    cf_a = (cf_a.item() + 1) * 50
    # plots: effect = counterfactual minus reconstruction
    effect = cf_x - rec_x
    effect = get_fig_arr(
        effect,
        cmap="RdBu_r",
        norm=MidpointNormalize(vmin=effect.min(), midpoint=0, vmax=effect.max()),
    )
    cf_x = get_fig_arr(cf_x)
    cf_x_std = get_fig_arr(cf_x_std, cmap="jet")
    return (cf_x, cf_x_std, effect, cf_r, cf_s, cf_f, np.round(cf_a, 1))
# Gradio UI: one tab per dataset ("Brain MRI", "Chest X-ray"). Each tab has a
# row of four image panels (observation, counterfactual, counterfactual
# uncertainty, direct causal effect), the intervention controls, the
# New/Reset/Submit buttons and a causal-graph image, followed by the event
# wiring that connects them to the callbacks defined above.
# NOTE(review): indentation was reconstructed from a flattened paste, so the
# exact Row/Column nesting of the intervention controls is inferred — confirm.
# NOTE(review): `.style(...)` is the Gradio 3.x styling API (presumably gone in
# Gradio 4) — confirm the pinned gradio version.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    with gr.Tabs():
        with gr.TabItem("Brain MRI") as brain_tab:
            # Hidden field holding the tab label; passed as dataset_id to
            # load_model when this tab is selected.
            brain_id = gr.Textbox(value=brain_tab.label, visible=False)
            with gr.Row().style(equal_height=True):
                # Hidden test-set index of the displayed observation.
                idx_brain = gr.Number(value=0, visible=False)
                with gr.Column(scale=1, min_width=200):
                    x_brain = gr.Image(label="Observation", interactive=False).style(
                        height=HEIGHT
                    )
                with gr.Column(scale=1, min_width=200):
                    cf_x_brain = gr.Image(
                        label="Counterfactual", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    cf_x_std_brain = gr.Image(
                        label="Counterfactual Uncertainty", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    effect_brain = gr.Image(
                        label="Direct Causal Effect", interactive=False
                    ).style(height=HEIGHT)
            with gr.Row():
                with gr.Column(scale=2.55):
                    gr.Markdown(
                        "**Intervention**"
                        # + 20 * " "
                        # + "[arXiv paper](https://arxiv.org/abs/2306.15764)   |   [GitHub code](https://github.com/biomedia-mira/causal-gen)"
                        # + "  |   Hint: try 90% zoom"
                    )
                    # Value controls start non-interactive; ticking the matching
                    # do(...) checkbox enables them (see the .change handlers).
                    with gr.Row():
                        with gr.Column(min_width=200):
                            do_a = gr.Checkbox(label="do(age)", value=False)
                            a = gr.Slider(
                                label="\u00A0",
                                value=50,
                                minimum=44,
                                maximum=73,
                                step=1,
                                interactive=False,
                            )
                        with gr.Column(min_width=200):
                            do_s = gr.Checkbox(label="do(sex)", value=False)
                            s = gr.Radio(
                                ["female", "male"], label="", interactive=False
                            )
                    with gr.Row():
                        with gr.Column(min_width=200):
                            # Volumes are shown in ml (callbacks divide/multiply by 1000).
                            do_b = gr.Checkbox(label="do(brain volume)", value=False)
                            b = gr.Slider(
                                label="\u00A0",
                                value=1000,
                                minimum=850,
                                maximum=1550,
                                step=20,
                                interactive=False,
                            )
                        with gr.Column(min_width=200):
                            do_v = gr.Checkbox(
                                label="do(ventricle volume)", value=False
                            )
                            v = gr.Slider(
                                label="\u00A0",
                                value=40,
                                minimum=10,
                                maximum=125,
                                step=2,
                                interactive=False,
                            )
                    with gr.Row():
                        new_brain = gr.Button("New Observation")
                        reset_brain = gr.Button("Reset", variant="stop")
                        submit_brain = gr.Button("Submit", variant="primary")
                with gr.Column(scale=1):
                    # gr.Markdown("### ")
                    causal_graph_brain = gr.Image(
                        label="Causal Graph", interactive=False
                    ).style(height=340)
        with gr.TabItem("Chest X-ray") as chest_tab:
            # Hidden field holding the tab label (dataset_id for load_model).
            chest_id = gr.Textbox(value=chest_tab.label, visible=False)
            with gr.Row().style(equal_height=True):
                # Hidden test-set index of the displayed observation.
                idx_chest = gr.Number(value=0, visible=False)
                with gr.Column(scale=1, min_width=200):
                    x_chest = gr.Image(label="Observation", interactive=False).style(
                        height=HEIGHT
                    )
                with gr.Column(scale=1, min_width=200):
                    cf_x_chest = gr.Image(
                        label="Counterfactual", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    cf_x_std_chest = gr.Image(
                        label="Counterfactual Uncertainty", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    effect_chest = gr.Image(
                        label="Direct Causal Effect", interactive=False
                    ).style(height=HEIGHT)
            with gr.Row():
                with gr.Column(scale=2.55):
                    gr.Markdown(
                        "**Intervention**"
                        # + 20 * " "
                        # + "[arXiv paper](https://arxiv.org/abs/2306.15764)   |   [GitHub code](https://github.com/biomedia-mira/causal-gen)"
                        # + "  |   Hint: try 90% zoom"
                    )
                    with gr.Row().style(equal_height=True):
                        with gr.Column(min_width=200):
                            do_a_chest = gr.Checkbox(label="do(age)", value=False)
                            # NOTE(review): unlike every other value control, this
                            # slider is not created with interactive=False, so it is
                            # editable before do(age) is ticked — confirm intent.
                            a_chest = gr.Slider(
                                label="\u00A0", minimum=18, maximum=98, step=1
                            )
                        with gr.Column(min_width=200):
                            do_s_chest = gr.Checkbox(label="do(sex)", value=False)
                            s_chest = gr.Radio(
                                SEX_CAT_CHEST, label="", interactive=False
                            )
                    with gr.Row():
                        with gr.Column(min_width=200):
                            do_r_chest = gr.Checkbox(label="do(race)", value=False)
                            r_chest = gr.Radio(RACE_CAT, label="", interactive=False)
                        with gr.Column(min_width=200):
                            do_f_chest = gr.Checkbox(label="do(disease)", value=False)
                            f_chest = gr.Radio(FIND_CAT, label="", interactive=False)
                    with gr.Row():
                        new_chest = gr.Button("New Observation")
                        reset_chest = gr.Button("Reset", variant="stop")
                        submit_chest = gr.Button("Submit", variant="primary")
                with gr.Column(scale=1):
                    # gr.Markdown("### ")
                    causal_graph_chest = gr.Image(
                        label="Causal Graph", interactive=False
                    ).style(height=345)
    # morphomnist
    # do = [do_t, do_i, do_y]
    # obs = [idx, x, t, i, y]
    # cf_out = [cf_x, cf_x_std, effect]
    # brain
    # NOTE(review): obs_brain has 6 components and do_brain 4, so the brain
    # callbacks must match: get_brain_obs -> 6 values, infer_brain_cf -> 10
    # inputs / 7 outputs. Keep those functions in sync with these lists.
    do_brain = [do_s, do_a, do_b, do_v]  # intervention checkboxes
    obs_brain = [idx_brain, x_brain, s, a, b, v]  # observed image/attributes
    cf_out_brain = [cf_x_brain, cf_x_std_brain, effect_brain]  # counterfactual outputs
    # chest (order must match get_chest_obs / infer_chest_cf argument order)
    do_chest = [do_r_chest, do_s_chest, do_f_chest, do_a_chest]
    obs_chest = [idx_chest, x_chest, r_chest, s_chest, f_chest, a_chest]
    cf_out_chest = [cf_x_chest, cf_x_std_chest, effect_chest]
    # on start: load new observations & causal graph
    demo.load(fn=get_brain_obs, inputs=None, outputs=obs_brain)
    demo.load(fn=get_chest_obs, inputs=None, outputs=obs_chest)
    demo.load(fn=brain_graph, inputs=do_brain, outputs=causal_graph_brain)
    demo.load(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)
    # on tab select: load models
    # NOTE(review): `select` fires on a tab click, not on initial page load, so
    # the default (Brain MRI) tab's models may still be unloaded at Submit.
    brain_tab.select(fn=load_model, inputs=brain_id, outputs=None)
    chest_tab.select(fn=load_model, inputs=chest_id, outputs=None)
    # "new" button: load new observations
    new_chest.click(fn=get_chest_obs, inputs=None, outputs=obs_chest)
    new_brain.click(fn=get_brain_obs, inputs=None, outputs=obs_brain)
    # "new" button: reset causal graphs
    new_brain.click(fn=brain_graph, inputs=do_brain, outputs=causal_graph_brain)
    new_chest.click(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)
    # "new" button: reset cf output panels
    # (the lambdas use no loop variables, so late binding is not an issue)
    for _k, _v in zip(
        [new_brain, new_chest], [cf_out_brain, cf_out_chest]
    ):
        _k.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=_v)
    # "reset" button: reload current observations
    reset_brain.click(fn=get_brain_obs, inputs=idx_brain, outputs=obs_brain)
    reset_chest.click(fn=get_chest_obs, inputs=idx_chest, outputs=obs_chest)
    # "reset" button: deselect intervention checkboxes
    reset_brain.click(
        fn=lambda: (gr.update(value=False),) * len(do_brain),
        inputs=None,
        outputs=do_brain,
    )
    reset_chest.click(
        fn=lambda: (gr.update(value=False),) * len(do_chest),
        inputs=None,
        outputs=do_chest,
    )
    # "reset" button: reset cf output panels (and close all matplotlib figures)
    for _k, _v in zip(
        [reset_brain, reset_chest], [cf_out_brain, cf_out_chest]
    ):
        _k.click(fn=lambda: plt.close("all"), inputs=None, outputs=None)
        _k.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=_v)
    # enable brain interventions when checkbox is selected & update graph
    for _k, _v in zip(do_brain, [s, a, b, v]):
        _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
        _k.change(brain_graph, inputs=do_brain, outputs=causal_graph_brain)
    # enable chest interventions when checkbox is selected & update graph
    for _k, _v in zip(do_chest, [r_chest, s_chest, f_chest, a_chest]):
        _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
        _k.change(chest_graph, inputs=do_chest, outputs=causal_graph_chest)
    # "submit" button: infer countefactuals
    # outputs = the three cf panels followed by the attribute controls, which
    # are overwritten with the counterfactual attribute values.
    submit_brain.click(
        fn=infer_brain_cf,
        inputs=obs_brain + do_brain,
        outputs=cf_out_brain + [s, a, b, v],
    )
    submit_chest.click(
        fn=infer_chest_cf,
        inputs=obs_chest + do_chest,
        outputs=cf_out_chest + [r_chest, s_chest, f_chest, a_chest],
    )
if __name__ == "__main__":
    # Route events through Gradio's request queue, then start serving.
    demo.queue().launch()