isa / app.py
ondrejbiza's picture
V1 works locally.
9d5d768
raw
history blame
6.24 kB
import functools
import os
from absl import flags
from clu import checkpoint
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
import tensorflow as tf
from huggingface_hub import snapshot_download
from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale import get_config
from invariant_slot_attention.lib import input_pipeline
from invariant_slot_attention.lib import preprocessing
from invariant_slot_attention.lib import utils
def load_model(config, checkpoint_dir):
    """Build the model described by `config` and restore its checkpoint.

    The model is first initialised on a dummy input of shape
    [1, 1, 128, 128, 3] so that a `utils.TrainState` with the right parameter
    structure exists; the checkpoint in `checkpoint_dir` is then restored into
    that template.

    Args:
      config: Experiment config; only `config.model` is read here.
      checkpoint_dir: Directory passed to `clu.checkpoint.MultihostCheckpoint`.

    Returns:
      `(model, state, rng)`: the model module, the restored train state, and a
      PRNG key derived from the fixed seed 42 for the caller to use.
    """
    rng = jax.random.PRNGKey(42)
    # The second key is never used, but the split must stay: it advances `rng`,
    # which is stored in the TrainState and returned to the caller.
    rng, _ = jax.random.split(rng)

    model = utils.build_model_from_config(config.model)

    def init_model(rng):
        """Initialise `model` and return `(state_vars, params)`."""
        _, init_rng, model_rng, dropout_rng = jax.random.split(rng, num=4)
        init_inputs = jnp.ones([1, 1, 128, 128, 3], jnp.float32)
        initial_vars = model.init(
            {"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
            video=init_inputs, conditioning=None,
            padding_mask=jnp.ones(init_inputs.shape[:-1], jnp.int32))
        # Split into state variables (e.g. for batchnorm stats) and model params.
        # `pop()` on a FrozenDict performs a deep copy and returns
        # `(remaining, popped)`.
        # NOTE(review): newer Flax releases may return a plain dict from
        # `init`, whose `pop()` returns only the value and would break this
        # unpacking -- confirm the pinned Flax version.
        state_vars, initial_params = initial_vars.pop("params")  # pytype: disable=attribute-error
        # Filter out intermediates (we don't want to store these in the TrainState).
        state_vars = utils.filter_key_from_frozen_dict(
            state_vars, key="intermediates")
        return state_vars, initial_params

    state_vars, initial_params = init_model(rng)
    # This state only serves as the structure that the checkpoint restore is
    # applied to.
    state = utils.TrainState(
        step=42, opt_state=None, params=initial_params, rng=rng,
        variables=state_vars)
    state = checkpoint.MultihostCheckpoint(checkpoint_dir).restore(state)
    return model, state, rng
def load_image(name):
    """Load `images/<name>.png`, crop a 192x192 window and resize it to 128x128.

    Returns:
      `(model_input, display)`:
        - `model_input` is a float32 JAX array of shape (128, 128, 3) with
          values scaled to [0, 1].
        - `display` is the uint8 NumPy array of the resized image with all of
          its original channels, intended for showing in the UI.
    """
    # The `with` block closes the underlying file. `crop()` returns a new image
    # that does not depend on that file.
    with Image.open(f"images/{name}.png") as src:
        img = src.crop((64, 29, 64 + 192, 29 + 192))
    img = img.resize((128, 128))
    display = np.array(img)
    # For RGB/RGBA sources, convert("RGB") yields exactly the first three
    # channels. Unlike slicing [:, :, :3], it also handles grayscale and
    # palette PNGs.
    rgb = np.asarray(img.convert("RGB")) / 255.
    model_input = jnp.array(rgb, dtype=jnp.float32)
    return model_input, display
# Fetch the pretrained weights from the Hugging Face Hub and restore them into
# a model built from the clevr_with_masks `equiv_transl_scale` config.
download_path = snapshot_download(repo_id="ondrejbiza/isa")
checkpoint_dir = os.path.join(
    download_path,
    "clevr_isa_ts",
    "checkpoints",
)
model, state, rng = load_model(
    config=get_config(),
    checkpoint_dir=checkpoint_dir,
)
# `init_rng` is later handed to `model.apply` as the "state_init" RNG stream.
rng, init_rng = jax.random.split(rng, 2)
from flax import linen as nn
from typing import Callable
class DecoderWrapper(nn.Module):
    """Run only the decoder of the slot model as a standalone Flax module.

    Used to decode hand-edited slot states without running the encoder.
    """

    # Zero-argument factory that constructs the decoder module
    # (this file passes `model.decoder`).
    decoder: Callable[[], nn.Module]

    @nn.compact
    def __call__(self, slots, train=False):
        decoder_module = self.decoder()
        return decoder_module(slots, train)
decoder_model = DecoderWrapper(decoder=model.decoder)

# Module-level per-slot buffers (11 slots) that the Gradio callbacks read and
# overwrite in place: slot features, 2-D positions, 2-D scales, and 128x128
# per-slot segmentation probabilities.
slots, pos, scale = (
    np.zeros((11, width), dtype=np.float32) for width in (64, 2, 2)
)
probs = np.zeros((11, 128, 128), dtype=np.float32)
with gr.Blocks() as demo:
    # ---- Layout -----------------------------------------------------------
    with gr.Row():
        with gr.Column():
            gr_choose_image = gr.Dropdown(
                [f"img{i}" for i in range(1, 9)],
                label="CLEVR Image",
                info="Start by picking an image from the CLEVR dataset.",
            )
            # Input image, later replaced by the rendered reconstruction.
            gr_image_1 = gr.Image(type="numpy")
            # Segmentation mask of the currently selected slot.
            gr_image_2 = gr.Image(type="numpy")
        with gr.Column():
            gr_slot_slider = gr.Slider(1, 11, value=1, step=1, label="Slot")
            # NOTE(review): variable names and labels are crossed.
            #   - `gr_x_slider` (label "y") edits pos[:, 0].
            #   - `gr_y_slider` (label "x") edits pos[:, 1].
            #   - `gr_sx_slider` (label "height") edits scale[:, 0].
            #   - `gr_sy_slider` (label "width") edits scale[:, 1].
            # This is consistent if column 0 is the vertical axis; confirm
            # against the model's coordinate convention.
            gr_y_slider = gr.Slider(-1, 1, value=0, step=0.01, label="x")
            gr_x_slider = gr.Slider(-1, 1, value=0, step=0.01, label="y")
            gr_sy_slider = gr.Slider(0.01, 1, value=0.1, step=0.01, label="width")
            gr_sx_slider = gr.Slider(0.01, 1, value=0.1, step=0.01, label="height")
            gr_button = gr.Button("Render")

    # ---- Helpers ----------------------------------------------------------
    def _slot_index(idx):
        """Map the 1-based slot slider value to a 0-based integer row index.

        Gradio may deliver slider values as floats; rounding before `int`
        keeps NumPy indexing valid either way.
        """
        return int(round(idx)) - 1

    def _mask_image(idx):
        """Return slot `idx`'s segmentation probabilities scaled to uint8."""
        return (probs[idx] * 255).astype(np.uint8)

    def _slot_outputs(idx):
        """Return `(mask, pos0, pos1, scale0, scale1)` for 0-based slot `idx`."""
        return (_mask_image(idx), float(pos[idx, 0]), float(pos[idx, 1]),
                float(scale[idx, 0]), float(scale[idx, 1]))

    # ---- Callbacks (names and registration order unchanged) --------------
    def update_image_and_segmentation(name, idx):
        """Encode image `name` and load its slots into the shared buffers.

        Returns the display image followed by the selected slot's mask,
        position and scale values.
        """
        idx = _slot_index(idx)
        img_input, img = load_image(name)
        out = model.apply(
            {"params": state.params, **state.variables},
            video=img_input[None, None],  # two leading axes for the `video` input
            rngs={"state_init": init_rng},
            train=False)
        # Softmax across slots (axis 0) of the per-slot logit maps.
        probs[:] = nn.softmax(
            out["outputs"]["segmentation_logits"][0, 0, :, :, :, 0], axis=0)
        # Each slot state is laid out as [features..., pos (2), scale (2)].
        slot_states = out["states"][0, 0]
        slots[:] = slot_states[:, :-4]
        pos[:] = slot_states[:, -4:-2]
        scale[:] = slot_states[:, -2:]
        return (img, *_slot_outputs(idx))

    gr_choose_image.change(
        fn=update_image_and_segmentation,
        inputs=[gr_choose_image, gr_slot_slider],
        outputs=[gr_image_1, gr_image_2, gr_x_slider, gr_y_slider,
                 gr_sx_slider, gr_sy_slider],
    )

    def update_sliders(idx):
        """Show the mask, position and scale of the newly selected slot."""
        return _slot_outputs(_slot_index(idx))

    gr_slot_slider.change(
        fn=update_sliders,
        inputs=gr_slot_slider,
        outputs=[gr_image_2, gr_x_slider, gr_y_slider, gr_sx_slider,
                 gr_sy_slider],
    )

    # The four setters below write slider edits straight into the shared
    # position/scale buffers. `render` reads those buffers, so edits show up
    # on the next "Render" click.
    def update_pos_x(idx, val):
        pos[_slot_index(idx), 0] = val

    def update_pos_y(idx, val):
        pos[_slot_index(idx), 1] = val

    def update_scale_x(idx, val):
        scale[_slot_index(idx), 0] = val

    def update_scale_y(idx, val):
        scale[_slot_index(idx), 1] = val

    gr_x_slider.change(
        fn=update_pos_x,
        inputs=[gr_slot_slider, gr_x_slider],
    )
    gr_y_slider.change(
        fn=update_pos_y,
        inputs=[gr_slot_slider, gr_y_slider],
    )
    gr_sx_slider.change(
        fn=update_scale_x,
        inputs=[gr_slot_slider, gr_sx_slider],
    )
    gr_sy_slider.change(
        fn=update_scale_y,
        inputs=[gr_slot_slider, gr_sy_slider],
    )

    def render(idx):
        """Decode the current (possibly edited) slots.

        Refreshes the shared `probs` buffer and returns the decoded image and
        the selected slot's new mask, both as uint8 arrays.
        """
        idx = _slot_index(idx)
        # Re-assemble slot states in the same [features, pos, scale] layout
        # that `update_image_and_segmentation` split them from.
        slot_states = jnp.array(np.concatenate([slots, pos, scale], axis=-1))
        # NOTE(review): the wrapper receives the full model's params; this
        # presumably works because the wrapper's auto-named decoder submodule
        # matches the name used inside the full model. Confirm.
        out = decoder_model.apply(
            {"params": state.params, **state.variables},
            slots=slot_states[None, None],
            train=False,
        )
        probs[:] = nn.softmax(
            out["segmentation_logits"][0, 0, :, :, :, 0], axis=0)
        image = np.clip(np.array(out["video"][0, 0]), 0, 1)
        return (image * 255).astype(np.uint8), _mask_image(idx)

    gr_button.click(
        fn=render,
        inputs=gr_slot_slider,
        outputs=[gr_image_1, gr_image_2],
    )
if __name__ == "__main__":
    # Launch the Gradio server only when this file is run as a script (as
    # `python app.py` does). Importing the module, or running it under
    # Gradio's reload mode, then no longer starts a second server.
    demo.launch()