| | import gradio as gr |
| | from loadimg import load_img |
| | from transformers import AutoModelForImageSegmentation |
| | import torch |
| | from torchvision import transforms |
| | from PIL import Image |
| | import tempfile |
| |
|
# Let float32 matmuls use faster, slightly lower-precision kernels where the
# backend supports them ("high"); PyTorch's default setting is "highest".
torch.set_float32_matmul_precision("high")

# trust_remote_code=True lets transformers execute model code downloaded from
# the Hub repository -- only acceptable because the repo is trusted.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cpu")

# Pre-processing applied to every image before inference: resize to a fixed
# 1024x1024 input, convert to a float tensor, then normalise each channel with
# the standard ImageNet mean / std.
_MODEL_INPUT_SIZE = (1024, 1024)
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
transform_image = transforms.Compose(
    [
        transforms.Resize(_MODEL_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)
| |
|
def fn(image):
    """Remove the background from *image*.

    Args:
        image: Any source accepted by ``load_img`` (e.g. a file path).

    Returns:
        The ``(rgba_image, mask)`` tuple produced by ``process`` for the
        RGB-converted input.

    NOTE(review): not wired to any component of the UI defined in this file;
    ``process_file`` is the handler actually used.
    """
    im = load_img(image, output_type="pil").convert("RGB")
    # The previous ``origin = im.copy()`` was never used; dropping it avoids a
    # needless full-resolution copy on every call.
    return process(im)
| |
|
def parse_color(color):
    """Convert a colour string (e.g. from ``gr.ColorPicker``) to an opaque RGBA tuple.

    Accepted forms (case-insensitive, surrounding whitespace ignored):

    * hex: ``#rgb``, ``#rgba``, ``#rrggbb`` or ``#rrggbbaa``
    * functional: ``rgb(r, g, b)`` or ``rgba(r, g, b, a)`` with int or float
      channel values

    Any alpha component is ignored: the result is always fully opaque because it
    is used as a solid background fill. Channels are clamped to 0-255.

    Args:
        color: The colour string; ``None`` or ``""`` is allowed.

    Returns:
        ``(r, g, b, 255)``. Empty, unrecognised or malformed input falls back to
        white ``(255, 255, 255, 255)``, the same default used for unknown formats.
    """
    white = (255, 255, 255, 255)
    if not color:
        return white
    text = color.strip().lower()
    try:
        if text.startswith("#"):
            digits = text[1:]
            if len(digits) in (3, 4):
                # Short form: each nibble is doubled, e.g. "f80" -> "ff8800".
                digits = "".join(ch * 2 for ch in digits[:3])
            elif len(digits) not in (6, 8):
                return white
            channels = [int(digits[i:i + 2], 16) for i in (0, 2, 4)]
        elif text.startswith("rgb"):
            # Covers both "rgb(" and "rgba("; only the first three values matter.
            inner = text[text.index("(") + 1:text.rindex(")")]
            parts = inner.split(",")
            if len(parts) < 3:
                return white
            channels = [int(float(p)) for p in parts[:3]]
        else:
            return white
    except (ValueError, OverflowError):
        # ValueError: bad digits / missing parens / nan; OverflowError: inf.
        return white
    r, g, b = (max(0, min(255, c)) for c in channels)
    return (r, g, b, 255)
| |
|
def process(image):
    """Predict a foreground mask for *image* with BiRefNet and attach it as alpha.

    Args:
        image: PIL image; callers in this file pass RGB images.

    Returns:
        ``(rgba, mask)`` where ``mask`` is the sigmoid of the model's last output
        converted to a single-channel PIL image and resized to ``image.size``,
        and ``rgba`` is a *copy* of *image* with that mask as its alpha channel.
        The caller's *image* is left unmodified (``putalpha`` works in place, so
        the original version silently converted the caller's image to RGBA).
    """
    # Send inputs to wherever the model's weights live, so relocating
    # ``birefnet`` (e.g. to a GPU) needs no change here.
    device = next(birefnet.parameters()).device
    input_images = transform_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # Take the last element of the model output -- presumably the final
        # prediction map among several side outputs; TODO confirm.
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(image.size)
    result = image.copy()
    result.putalpha(mask)
    return result, mask
| |
|
def _save_temp_image(img, suffix, fmt):
    """Write *img* to a new, uniquely named temp file and return its path.

    ``NamedTemporaryFile(delete=False)`` creates the file atomically, unlike the
    deprecated ``tempfile.mktemp`` which only returns a name and is race-prone.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        img.save(tmp, fmt)
    return tmp.name


def process_file(f, bg_color):
    """Gradio click handler: cut out the subject of *f* and build all outputs.

    Args:
        f: Image source from the ``gr.Image(type="filepath")`` input, or
            ``None`` when nothing has been uploaded.
        bg_color: Colour string from ``gr.ColorPicker``, parsed by
            ``parse_color``.

    Returns:
        ``(with_bg_rgb, bg_png_path, bg_jpeg_path, transparent_img,
        trans_png_path)``: the cut-out composited over *bg_color* as an RGB
        image, PNG and JPEG files of that composite, the RGBA cut-out, and a PNG
        of the cut-out -- in the order of the five output components.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message.
    """
    if f is None:
        raise gr.Error("Please upload an image first.")
    im = load_img(f, output_type="pil").convert("RGB")
    transparent_img, _mask = process(im)

    background = Image.new("RGBA", transparent_img.size, parse_color(bg_color))
    with_bg = Image.alpha_composite(background, transparent_img)
    with_bg_rgb = with_bg.convert("RGB")  # JPEG cannot store an alpha channel

    # NOTE(review): these temp files are never deleted here; rely on OS / Gradio
    # cache cleanup or add an explicit cleanup policy.
    bg_png_path = _save_temp_image(with_bg, ".png", "PNG")
    bg_jpeg_path = _save_temp_image(with_bg_rgb, ".jpeg", "JPEG")
    trans_png_path = _save_temp_image(transparent_img, ".png", "PNG")

    return (with_bg_rgb, bg_png_path, bg_jpeg_path,
            transparent_img, trans_png_path)
| |
|
# Page-level CSS passed to gr.Blocks below:
#  - adds 24px of spacing under the page's <h1> title;
#  - the "small-file" class (set on the gr.File download widgets) removes their
#    minimum/fixed height and hides their SVG icons so they take up less space.
css = """
.gradio-container h1 {
    margin-bottom: 24px;
}
.small-file, .small-file * {
    min-height: 0 !important;
    height: auto !important;
}
.small-file svg {
    display: none !important;
}
"""
| |
|
# Two-column layout: inputs on the left, the two result panels on the right.
with gr.Blocks(css=css, title="Background Remover") as background_remover_app:
    gr.Markdown("<h1 style='text-align: center;'>Background Remover</h1>")

    with gr.Row():
        # Source image, fill colour and the button that triggers processing.
        with gr.Column(scale=2):
            gr.Markdown("### Input")
            input_image = gr.Image(label="Upload an image", type="filepath")
            color_picker = gr.ColorPicker(label="Background Color", value="#ffffff")
            submit_btn = gr.Button("Submit", variant="primary")

        # Equivalent to a Row nested inside the Column: two panels side by side.
        with gr.Column(scale=2), gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Output With Background Color")
                bg_preview = gr.Image(label="Preview")
                bg_png = gr.File(label="Download PNG", elem_classes="small-file")
                bg_jpeg = gr.File(label="Download JPEG", elem_classes="small-file")

            with gr.Column(scale=1):
                gr.Markdown("### Output With Transparent Background")
                trans_preview = gr.Image(label="Preview")
                trans_png = gr.File(label="Download PNG", elem_classes="small-file")

    handler_inputs = [input_image, color_picker]
    # Order must match the tuple returned by process_file.
    handler_outputs = [bg_preview, bg_png, bg_jpeg, trans_preview, trans_png]

    gr.Examples(examples=[["butterfly.jpg", "#ffffff"]], inputs=handler_inputs)

    submit_btn.click(fn=process_file, inputs=handler_inputs, outputs=handler_outputs)

if __name__ == "__main__":
    # share=True asks Gradio to also create a temporary public share link.
    background_remover_app.launch(share=True)
| |
|