chuuhtetnaing's picture
set the preview as the webp
6c1d97a
import gradio as gr
from loadimg import load_img
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import tempfile
# Select the float32 matmul precision mode. Index 0 of the list picks "high";
# changing the index to 1 would select "highest" instead.
torch.set_float32_matmul_precision(["high", "highest"][0])

# Load the segmentation model from the "ZhengPeng7/BiRefNet" checkpoint.
# trust_remote_code=True allows transformers to run the model code that is
# shipped with that checkpoint repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
# Keep the model on the CPU; process() moves its input tensors to "cpu" as well.
birefnet.to("cpu")
# Preprocessing applied to every image before it is fed to the model:
# resize to 1024x1024, convert to a tensor, then normalize each channel
# (the mean/std values match the widely used ImageNet statistics).
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
def fn(image):
    """Load an image and run background removal on it.

    Args:
        image: Image reference handed straight to ``load_img``.

    Returns:
        Whatever ``process`` returns for the loaded RGB image, i.e. the
        ``(image_with_alpha, mask)`` tuple.
    """
    # The previous version also made a full copy of the image into an unused
    # local; that copy was never read, so it has been dropped.
    im = load_img(image, output_type="pil").convert("RGB")
    return process(im)
def parse_color(color):
    """Convert a CSS-like color string into an opaque ``(r, g, b, 255)`` tuple.

    Supported inputs:
        * ``#rrggbb`` (any trailing alpha digits are ignored),
        * ``#rgb`` / ``#rgba`` shorthand, expanded by doubling each digit,
        * ``rgb(r, g, b)`` and ``rgba(r, g, b, a)`` with integer or float
          channel values; the alpha component is ignored.

    Matching is case-insensitive and surrounding whitespace is ignored.
    Channel values are clamped to the 0-255 range.

    Args:
        color: The color string to parse.

    Returns:
        A 4-tuple ``(r, g, b, 255)``. Anything that is not a string, is in an
        unknown format, or cannot be parsed yields opaque white.
    """
    fallback = (255, 255, 255, 255)
    if not isinstance(color, str):
        return fallback

    text = color.strip().lower()
    try:
        if text.startswith("#"):
            digits = text.lstrip("#")
            # Expand shorthand such as "#0f8" -> "00ff88" before slicing.
            if len(digits) in (3, 4):
                digits = "".join(ch * 2 for ch in digits)
            channels = [int(digits[i:i + 2], 16) for i in (0, 2, 4)]
        elif text.startswith("rgb"):
            # Covers both "rgb(...)" and "rgba(...)"; only the first three
            # components are used, so a trailing alpha value is ignored.
            inner = text.partition("(")[2].partition(")")[0]
            parts = inner.replace(",", " ").replace("/", " ").split()
            if len(parts) < 3:
                return fallback
            channels = [int(float(part)) for part in parts[:3]]
        else:
            return fallback
    except (ValueError, OverflowError):
        # Malformed digits/numbers: use the same fallback as unknown formats.
        return fallback

    r, g, b = (min(255, max(0, value)) for value in channels)
    return (r, g, b, 255)
def process(image):
    """Predict a foreground mask for ``image`` and attach it as the alpha channel.

    Note that ``image`` is modified in place: the predicted mask is written
    into it with ``putalpha`` and the very same object is returned.

    Args:
        image: PIL image to segment.

    Returns:
        A tuple ``(image, mask)`` where ``image`` is the input object carrying
        the new alpha channel and ``mask`` is the predicted mask resized to the
        image's original size.
    """
    original_size = image.size

    # Preprocess and add a leading batch dimension before running the model.
    batch = transform_image(image).unsqueeze(0).to("cpu")

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        # The last element of the model output is used as the prediction.
        outputs = birefnet(batch)
        probabilities = outputs[-1].sigmoid().cpu()

    first_prediction = probabilities[0].squeeze()
    to_pil = transforms.ToPILImage()
    mask = to_pil(first_prediction).resize(original_size)

    image.putalpha(mask)
    return image, mask
def _save_temp_image(img, suffix, image_format):
    """Write ``img`` to a freshly created temporary file and return its path."""
    # tempfile.mktemp only returns a name and is deprecated: another process
    # can claim that name before the file is written. NamedTemporaryFile
    # creates the file itself; delete=False keeps it on disk after the handle
    # is closed so it can still be served for download.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        img.save(tmp, image_format)
    return tmp.name


def process_file(f, bg_color):
    """Remove the background of an uploaded image and build all outputs.

    Args:
        f: Path of the uploaded image, or ``None`` when nothing was uploaded.
        bg_color: Background color string, parsed with ``parse_color``.

    Returns:
        A 5-tuple matching the click handler's outputs:
        ``(with_bg_rgb, bg_png_path, bg_jpeg_path, transparent_img,
        trans_png_path)`` -- the composited preview image, the paths of its
        PNG and JPEG files, the transparent image, and the path of its PNG.

    Raises:
        gr.Error: If no image was provided.
    """
    if f is None:
        # Surface a readable message in the UI instead of a loader traceback.
        raise gr.Error("Please upload an image first.")

    im = load_img(f, output_type="pil").convert("RGB")
    # The mask is not needed here; only the image carrying the alpha channel.
    transparent_img, _ = process(im)

    # With background color
    rgba_color = parse_color(bg_color)
    background = Image.new("RGBA", im.size, rgba_color)
    with_bg = Image.alpha_composite(background, transparent_img)
    # JPEG has no alpha channel, so an RGB version is needed for it.
    with_bg_rgb = with_bg.convert("RGB")
    bg_png_path = _save_temp_image(with_bg, ".png", "PNG")
    bg_jpeg_path = _save_temp_image(with_bg_rgb, ".jpeg", "JPEG")

    # Transparent (no background)
    trans_png_path = _save_temp_image(transparent_img, ".png", "PNG")

    return (with_bg_rgb, bg_png_path, bg_jpeg_path,
            transparent_img, trans_png_path)
# Custom CSS passed to gr.Blocks below:
# - adds spacing underneath the page title (h1),
# - removes the minimum height of elements tagged with the "small-file" class
#   (used on the download File components) so they render compactly,
# - hides the SVG icons inside those elements.
css = """
.gradio-container h1 {
margin-bottom: 24px;
}
.small-file, .small-file * {
min-height: 0 !important;
height: auto !important;
}
.small-file svg {
display: none !important;
}
"""
# UI definition: a title, an input column, and two output columns.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been flattened -- confirm the placement of gr.Examples.
with gr.Blocks(css=css, title="Background Remover") as background_remover_app:
    gr.Markdown("<h1 style='text-align: center;'>Background Remover</h1>")
    with gr.Row():
        # Input side: image upload, background color and the submit button.
        with gr.Column(scale=2):
            gr.Markdown("### Input")
            # type="filepath" means process_file receives a path, not an array.
            input_image = gr.Image(label="Upload an image", type="filepath")
            color_picker = gr.ColorPicker(label="Background Color", value="#ffffff")
            submit_btn = gr.Button("Submit", variant="primary")
        # Output side: two equally sized result columns next to each other.
        with gr.Column(scale=2):
            with gr.Row():
                # Result composited onto the chosen background color.
                with gr.Column(scale=1):
                    gr.Markdown("### Output With Background Color")
                    bg_preview = gr.Image(label="Preview")
                    bg_png = gr.File(label="Download PNG", elem_classes="small-file")
                    bg_jpeg = gr.File(label="Download JPEG", elem_classes="small-file")
                # Result with the background left transparent.
                with gr.Column(scale=1):
                    gr.Markdown("### Output With Transparent Background")
                    trans_preview = gr.Image(label="Preview")
                    trans_png = gr.File(label="Download PNG", elem_classes="small-file")
    # Example input that pre-fills the image and color fields.
    gr.Examples(
        examples=[["butterfly.jpg", "#ffffff"]],
        inputs=[input_image, color_picker]
    )
    # The order of `outputs` must match the tuple returned by process_file.
    submit_btn.click(
        fn=process_file,
        inputs=[input_image, color_picker],
        outputs=[bg_preview, bg_png, bg_jpeg, trans_preview, trans_png]
    )
# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link.
    background_remover_app.launch(share=True)