"""Gradio app that removes image backgrounds with the BiRefNet segmentation model.

The script loads the model once at import time, exposes helpers that turn an
input image into a transparent cut-out (optionally composited over a solid
background color), and wires those helpers into a small Gradio UI.
"""

import tempfile

import gradio as gr
import torch
from loadimg import load_img
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

torch.set_float32_matmul_precision("high")

# The model is loaded once at import time and kept on the CPU.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cpu")

# Preprocessing applied to every image before it is fed to the model.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def fn(image):
    """Load *image*, convert it to RGB and run background removal on it.

    Args:
        image: Anything ``load_img`` accepts.

    Returns:
        The value returned by ``process``: an ``(image, mask)`` tuple.
    """
    im = load_img(image, output_type="pil")
    im = im.convert("RGB")
    return process(im)


def parse_color(color):
    """Convert a CSS-like color string into an opaque ``(r, g, b, 255)`` tuple.

    Supported forms are ``#rrggbb`` (plus the ``#rgb`` / ``#rgba`` shorthand),
    ``rgb(r, g, b)`` and ``rgba(r, g, b, a)``. The alpha component of the
    input is ignored; the result is always fully opaque.

    Args:
        color: Color string, or ``None`` / empty.

    Returns:
        An ``(r, g, b, 255)`` tuple with each channel clamped to 0-255.
        Missing, unrecognized or malformed input falls back to white.
    """
    white = (255, 255, 255, 255)
    if not color:
        return white

    color = color.strip()
    try:
        if color.startswith("#"):
            hex_color = color.lstrip("#")
            # Expand shorthand such as "#fff" to "ffffff".
            if len(hex_color) in (3, 4):
                hex_color = "".join(ch * 2 for ch in hex_color)
            r, g, b = (int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
        elif color.startswith("rgb"):
            # Covers both "rgb(...)" and "rgba(...)"; only the first three
            # components are used.
            inner = color[color.index("(") + 1:color.rindex(")")]
            parts = inner.split(",")[:3]
            r, g, b = (int(float(part.strip())) for part in parts)
        else:
            return white
    except ValueError:
        # Malformed value: use the same fallback as for unknown formats.
        return white

    r, g, b = (max(0, min(255, channel)) for channel in (r, g, b))
    return (r, g, b, 255)


def process(image):
    """Predict a foreground mask for *image* and attach it as the alpha channel.

    Args:
        image: A PIL image (callers in this file pass RGB images).

    Returns:
        ``(image, mask)`` where ``image`` is the same object that was passed
        in, now carrying ``mask`` as its alpha channel, and ``mask`` is the
        predicted mask resized to the image's original size.

    Note:
        ``putalpha`` modifies *image* in place, so the caller's object is
        changed as well.
    """
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to("cpu")

    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()

    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    # The model works on 1024x1024 input; scale the mask back to the source size.
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    return image, mask


def _save_temp_image(image, suffix, image_format):
    """Save *image* to a new temporary file and return the file's path.

    The file is created atomically (unlike the deprecated ``tempfile.mktemp``)
    and is intentionally not deleted on close, so it can be served afterwards.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        image.save(tmp, image_format)
    return tmp.name


def process_file(f, bg_color):
    """Remove the background of an uploaded image and build all UI outputs.

    Args:
        f: The uploaded image (a file path, as configured on the input widget).
        bg_color: Background color string understood by ``parse_color``.

    Returns:
        A 5-tuple ``(with_bg_rgb, bg_png_path, bg_jpeg_path, transparent_img,
        trans_png_path)``: the composited preview image, paths to its PNG and
        JPEG files, the transparent cut-out image, and the path to its PNG.

    Raises:
        gr.Error: If no image was provided.
    """
    if f is None:
        raise gr.Error("Please upload an image first.")

    im = load_img(f, output_type="pil")
    im = im.convert("RGB")
    transparent_img, _mask = process(im)

    # With background color
    rgba_color = parse_color(bg_color)
    background = Image.new("RGBA", transparent_img.size, rgba_color)
    with_bg = Image.alpha_composite(background, transparent_img)
    # JPEG has no alpha channel, so an RGB copy is used for it and the preview.
    with_bg_rgb = with_bg.convert("RGB")
    bg_png_path = _save_temp_image(with_bg, ".png", "PNG")
    bg_jpeg_path = _save_temp_image(with_bg_rgb, ".jpeg", "JPEG")

    # Transparent (no background)
    trans_png_path = _save_temp_image(transparent_img, ".png", "PNG")

    return (with_bg_rgb, bg_png_path, bg_jpeg_path, transparent_img, trans_png_path)


css = """
.gradio-container h1 { margin-bottom: 24px; }
.small-file, .small-file * { min-height: 0 !important; height: auto !important; }
.small-file svg { display: none !important; }
"""

# NOTE(review): the widget nesting below was reconstructed from a flattened
# source; confirm the layout (in particular where the examples sit).
with gr.Blocks(css=css, title="Background Remover") as background_remover_app:
    # Rendered as a top-level heading, which the `h1` CSS rule above targets.
    gr.Markdown("# Background Remover")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### Input")
            input_image = gr.Image(label="Upload an image", type="filepath")
            color_picker = gr.ColorPicker(label="Background Color", value="#ffffff")
            submit_btn = gr.Button("Submit", variant="primary")

        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Output With Background Color")
                    bg_preview = gr.Image(label="Preview")
                    bg_png = gr.File(label="Download PNG", elem_classes="small-file")
                    bg_jpeg = gr.File(label="Download JPEG", elem_classes="small-file")

                with gr.Column(scale=1):
                    gr.Markdown("### Output With Transparent Background")
                    trans_preview = gr.Image(label="Preview")
                    trans_png = gr.File(label="Download PNG", elem_classes="small-file")

    gr.Examples(
        examples=[["butterfly.jpg", "#ffffff"]],
        inputs=[input_image, color_picker],
    )

    submit_btn.click(
        fn=process_file,
        inputs=[input_image, color_picker],
        outputs=[bg_preview, bg_png, bg_jpeg, trans_preview, trans_png],
    )

if __name__ == "__main__":
    background_remover_app.launch(share=True)