Spaces:
Sleeping
Sleeping
| import io | |
| import os | |
| from typing import Optional | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import requests | |
| from fastapi import FastAPI, File, Form, UploadFile, HTTPException, Request | |
| from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
# Application object and shared module-level resources.
app = FastAPI(title="Watermark Remover API")

# Base URL of an external LaMa inpainting server. Optional: it is only
# consulted when a request selects the "lama" engine.
LAMA_URL = os.environ.get("LAMA_URL", "http://localhost:5000")

# Serve static assets under /static and render HTML pages from ./templates.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
def read_image_to_cv2(file: UploadFile) -> np.ndarray:
    """Decode an uploaded image into an OpenCV BGR array.

    The whole upload is read into memory and decoded with
    ``cv2.IMREAD_COLOR``, so the result always has three channels (an alpha
    channel is dropped, grayscale input is expanded).

    Args:
        file: The uploaded file to decode.

    Returns:
        The decoded image as a 3-channel BGR ``np.ndarray``.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded.
    """
    data = file.file.read()
    # cv2.imdecode fails an internal assertion (raising cv2.error, i.e. an
    # HTTP 500) on an empty buffer instead of returning None, so reject empty
    # uploads explicitly with the same client error as undecodable ones.
    if not data:
        raise HTTPException(400, detail="Invalid image file")
    img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(400, detail="Invalid image file")
    return img
def pil_bytes_from_cv2(img: np.ndarray, fmt: str = "PNG") -> io.BytesIO:
    """Encode an OpenCV (BGR) image into an in-memory file using Pillow.

    Args:
        img: Image in OpenCV's BGR channel order.
        fmt: Pillow format name handed to ``Image.save`` (default ``"PNG"``).

    Returns:
        A ``BytesIO`` containing the encoded image, rewound to offset 0 so it
        can be streamed immediately.
    """
    # Pillow interprets channels as RGB; OpenCV stores them as BGR.
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    out = io.BytesIO()
    Image.fromarray(rgb).save(out, format=fmt)
    out.seek(0)
    return out
def auto_text_mask(
    img: np.ndarray,
    delta: int = 5,
    min_area: int = 60,
    max_area: int = 10000,
) -> np.ndarray:
    """Build a simple heuristic mask for text/logo-like overlays.

    MSER regions are detected on the histogram-equalised grayscale image.
    Each region's convex hull is filled, then the mask is cleaned with one
    3x3 morphological opening and grown with one 3x3 dilation.

    Args:
        img: BGR image.
        delta: MSER ``delta`` parameter.
        min_area: Smallest region area, in pixels, that MSER keeps.
        max_area: Largest region area, in pixels, that MSER keeps.

    Returns:
        A ``uint8`` mask with the same height/width as ``img``: 255 where an
        overlay is suspected, 0 elsewhere.
    """
    gray = cv2.equalizeHist(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY))

    # Pass the MSER parameters positionally: their keyword names differ
    # between OpenCV releases (``_delta``/``_min_area``/``_max_area`` in older
    # builds, ``delta``/``min_area``/``max_area`` in newer ones), so keywords
    # raise TypeError on one side or the other. The positional order
    # (delta, min_area, max_area) is the same in both.
    mser = cv2.MSER_create(delta, min_area, max_area)
    regions, _ = mser.detectRegions(gray)

    mask = np.zeros(gray.shape, dtype=np.uint8)
    for pts in regions:
        hull = cv2.convexHull(pts.reshape(-1, 1, 2))
        cv2.drawContours(mask, [hull], -1, 255, thickness=-1)

    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    # Opening removes specks smaller than the kernel. The dilation then widens
    # what remains, presumably so inpainting also covers the overlay's soft
    # edges.
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=1)
    return cv2.dilate(mask, kernel, iterations=1)
def inpaint_opencv(img: np.ndarray, mask: np.ndarray, method: str = "telea", radius: int = 3) -> np.ndarray:
    """Fill the masked pixels of ``img`` using OpenCV's built-in inpainting.

    Args:
        img: BGR image to repair.
        mask: Single-channel or 3-channel (BGR) mask. Every pixel whose
            grayscale value is greater than 1 is inpainted.
        method: ``"telea"`` (case-insensitive) selects ``cv2.INPAINT_TELEA``;
            any other value selects ``cv2.INPAINT_NS``.
        radius: Neighbourhood radius passed to ``cv2.inpaint``.

    Returns:
        The image returned by ``cv2.inpaint``.
    """
    algorithm = cv2.INPAINT_NS
    if method.lower() == "telea":
        algorithm = cv2.INPAINT_TELEA

    if mask.ndim == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    # Binarise: values strictly above 1 become 255, everything else 0.
    binary = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)[1]
    return cv2.inpaint(img, binary, radius, algorithm)
def call_lama_server(img: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Inpaint ``img`` by delegating to an external LaMa HTTP server.

    The image and a binarised mask are POSTed as PNG files (fields ``image``
    and ``mask``) to ``{LAMA_URL}/inpaint``, with the form field
    ``method=lama``. The response body is then decoded as an image.

    Args:
        img: BGR image to repair.
        mask: Single-channel or 3-channel (BGR) mask. Pixels whose grayscale
            value is greater than 1 are inpainted.

    Returns:
        The server's result decoded as a 3-channel BGR array.

    Raises:
        HTTPException: 500 if PNG encoding fails; 502 if the server cannot be
            reached, replies with an HTTP error status, or returns an empty or
            undecodable body.
    """
    if mask.ndim == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    _, mask_bin = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)

    def to_png_bytes(arr: np.ndarray) -> bytes:
        # Single-channel input (the mask) is expanded to 3 channels before
        # encoding. NOTE(review): presumably the LaMa server expects 3-channel
        # PNGs — confirm.
        if arr.ndim == 2:
            arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
        ok, buf = cv2.imencode(".png", arr)
        if not ok:
            raise HTTPException(500, detail="Encoding error")
        return buf.tobytes()

    files = {
        "image": ("image.png", to_png_bytes(img), "image/png"),
        "mask": ("mask.png", to_png_bytes(mask_bin), "image/png"),
    }
    data = {"method": "lama"}
    # Tolerate a trailing slash in LAMA_URL instead of producing "//inpaint".
    url = f"{LAMA_URL.rstrip('/')}/inpaint"
    try:
        resp = requests.post(url, data=data, files=files, timeout=120)
        resp.raise_for_status()
    except requests.RequestException as e:
        # Connection errors, timeouts and 4xx/5xx statuses (raise_for_status)
        # are all RequestException subclasses; anything else is a bug here and
        # should not be disguised as an upstream failure.
        raise HTTPException(502, detail=f"LaMa server error: {e}") from e

    content = resp.content
    # cv2.imdecode asserts on an empty buffer (cv2.error -> HTTP 500), so an
    # empty reply is reported as a bad upstream response instead.
    if not content:
        raise HTTPException(502, detail="Invalid response from LaMa server")
    out = cv2.imdecode(np.frombuffer(content, np.uint8), cv2.IMREAD_COLOR)
    if out is None:
        raise HTTPException(502, detail="Invalid response from LaMa server")
    return out
def index(request: Request):
    """Render ``index.html`` from the ``templates`` directory.

    The current ``request`` is placed in the template context, as the
    ``TemplateResponse`` API in use here expects.

    NOTE(review): no route decorator is visible for this function in this
    view of the file. Presumably it serves ``GET /`` (``HTMLResponse`` is
    imported but otherwise unused); confirm it is actually registered.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
def health():
    """Health-check handler: unconditionally answers ``{"ok": true}`` as JSON.

    NOTE(review): no route decorator is visible for this function in this
    view of the file; confirm it is registered as an endpoint.
    """
    status = {"ok": True}
    return JSONResponse(content=status)
def remove_watermark(
    image: UploadFile = File(...),
    mask: Optional[UploadFile] = File(None),
    engine: str = Form("opencv"),  # "opencv" | "lama"
    method: str = Form("telea"),  # opencv: telea | ns
    radius: int = Form(3),
    auto_mask: int = Form(1),  # 1=true, 0=false
):
    """Remove a watermark/overlay from an uploaded image and return a PNG.

    The mask is chosen in this order:
      1. the uploaded ``mask``, if one was provided;
      2. otherwise ``auto_text_mask(image)`` when ``auto_mask`` is non-zero;
      3. otherwise an all-zero mask, so nothing is inpainted.

    Args:
        image: Image to clean.
        mask: Optional user-drawn mask. If its height/width differ from
            ``image`` it is resized to match, using nearest-neighbour so no
            new pixel values are introduced.
        engine: ``"lama"`` (case-insensitive, surrounding whitespace ignored)
            uses the external LaMa server; any other value uses OpenCV.
        method: OpenCV algorithm name forwarded to ``inpaint_opencv``.
        radius: OpenCV inpainting radius forwarded to ``inpaint_opencv``.
        auto_mask: Non-zero enables automatic mask detection when no mask is
            uploaded.

    Returns:
        A ``StreamingResponse`` with the result encoded as ``image/png``.

    Raises:
        HTTPException: 400 for undecodable uploads (from ``read_image_to_cv2``);
            500/502 from the LaMa path (from ``call_lama_server``).

    NOTE(review): no route decorator is visible for this handler in this view
    of the file; confirm it is registered under the path the frontend posts to.
    """
    img = read_image_to_cv2(image)
    height, width = img.shape[:2]

    if mask is not None:
        mask_img = read_image_to_cv2(mask)
        # cv2.inpaint asserts that mask and image sizes match, so a mask drawn
        # at a different resolution would otherwise surface as an HTTP 500.
        # Nearest-neighbour keeps the mask hard-edged (no blended values).
        if mask_img.shape[:2] != (height, width):
            mask_img = cv2.resize(
                mask_img, (width, height), interpolation=cv2.INTER_NEAREST
            )
    elif auto_mask:
        mask_img = auto_text_mask(img)
    else:
        mask_img = np.zeros((height, width), dtype=np.uint8)

    # Compare case-insensitively, consistent with how `method` is handled.
    if engine.strip().lower() == "lama":
        out = call_lama_server(img, mask_img)
    else:
        out = inpaint_opencv(img, mask_img, method=method, radius=radius)

    buf = pil_bytes_from_cv2(out, fmt="PNG")
    return StreamingResponse(buf, media_type="image/png")