# RemoveWatermark / server.py
# Uploaded by Waikul ("Upload 7 files"), commit ff90953 (verified).
import io
import os
from typing import Optional
import cv2
import numpy as np
from PIL import Image
import requests
from fastapi import FastAPI, File, Form, UploadFile, HTTPException, Request
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
app = FastAPI(title="Watermark Remover API")

# Resolve asset directories relative to this file instead of the process
# working directory: StaticFiles raises at startup when its directory does
# not exist, so a CWD-relative path breaks launching from another directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Serve static assets under /static and load HTML templates.
app.mount(
    "/static",
    StaticFiles(directory=os.path.join(BASE_DIR, "static")),
    name="static",
)
templates = Jinja2Templates(directory=os.path.join(BASE_DIR, "templates"))

# Base URL of the optional external LaMa inpainting server (only used when a
# request selects engine="lama"). A trailing slash is stripped so that
# f"{LAMA_URL}/inpaint" never contains a double slash.
LAMA_URL = os.getenv("LAMA_URL", "http://localhost:5000").rstrip("/")
def read_image_to_cv2(file: UploadFile) -> np.ndarray:
    """Decode an uploaded image into an OpenCV BGR array.

    The whole upload is read into memory and decoded with
    ``cv2.IMREAD_COLOR``, so the result always has three BGR channels
    (any alpha channel is discarded).

    Args:
        file: The uploaded file to decode.

    Returns:
        The decoded image.

    Raises:
        HTTPException: 400 if the upload is empty or is not a decodable image.
    """
    data = file.file.read()
    # cv2.imdecode raises cv2.error on an empty buffer (instead of returning
    # None), which would surface as a 500; reject it as a client error.
    if not data:
        raise HTTPException(400, detail="Empty image file")
    img_arr = np.frombuffer(data, np.uint8)
    try:
        img = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
    except cv2.error:
        img = None
    if img is None:
        raise HTTPException(400, detail="Invalid image file")
    return img
def pil_bytes_from_cv2(img: np.ndarray, fmt: str = "PNG") -> io.BytesIO:
    """Encode an OpenCV image into an in-memory file using Pillow.

    Args:
        img: Image in OpenCV channel order. Supported layouts are BGR
            (H, W, 3), BGRA (H, W, 4) and grayscale (H, W) or (H, W, 1);
            presumably 8-bit, as produced by ``cv2.imdecode``.
        fmt: Pillow format name passed to ``Image.save`` (e.g. "PNG").

    Returns:
        A ``BytesIO`` rewound to offset 0, ready to be streamed.
    """
    if img.ndim == 3 and img.shape[2] == 1:
        # Pillow expects a 2-D array for single-channel images.
        img = img[:, :, 0]
    if img.ndim == 2:
        pixels = img
    elif img.shape[2] == 4:
        # Preserve transparency; Pillow wants RGBA order.
        pixels = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
    else:
        # OpenCV stores colour as BGR, Pillow as RGB.
        pixels = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    buf = io.BytesIO()
    Image.fromarray(pixels).save(buf, format=fmt)
    buf.seek(0)
    return buf
def auto_text_mask(
    img: np.ndarray,
    delta: int = 5,
    min_area: int = 60,
    max_area: int = 10000,
) -> np.ndarray:
    """Simple heuristic mask for text/logo-like overlays.

    Stable regions are detected with MSER on the histogram-equalised
    grayscale image. Each region's convex hull is filled in. The mask is then
    cleaned with a 3x3 morphological opening and grown by one 3x3 dilation.

    Args:
        img: BGR image, or an already single-channel (H, W) grayscale image.
        delta: MSER ``delta`` parameter.
        min_area: Smallest MSER region area (in pixels) to keep.
        max_area: Largest MSER region area (in pixels) to keep.

    Returns:
        uint8 mask with the image's height and width: 255 where an overlay
        candidate was found, 0 elsewhere.
    """
    gray = img if img.ndim == 2 else cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray = cv2.equalizeHist(gray)
    # Pass MSER parameters positionally: their keyword names differ between
    # OpenCV 3.x (_delta, _min_area, _max_area) and 4.x (delta, min_area,
    # max_area), so keyword arguments fail on one of the two.
    mser = cv2.MSER_create(delta, min_area, max_area)
    regions, _ = mser.detectRegions(gray)
    mask = np.zeros(gray.shape, dtype=np.uint8)
    # Fill hulls one at a time: a single filled drawContours call over all
    # hulls would treat overlapping hulls as holes.
    for points in regions:
        hull = cv2.convexHull(points.reshape(-1, 1, 2))
        cv2.drawContours(mask, [hull], -1, 255, thickness=-1)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    # Opening removes isolated specks; dilation pads the remaining regions
    # so the inpainting covers their edges.
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=1)
    mask = cv2.dilate(mask, kernel, iterations=1)
    return mask
def inpaint_opencv(img: np.ndarray, mask: np.ndarray, method: str = "telea", radius: int = 3) -> np.ndarray:
    """Inpaint the masked area of ``img`` with OpenCV.

    Args:
        img: BGR image to repair.
        mask: Mask image, either single-channel or BGR. Pixels with a value
            above 1 are inpainted. A mask of a different size is resized to
            the image size.
        method: "telea" (case-insensitive) selects ``cv2.INPAINT_TELEA``;
            any other value selects ``cv2.INPAINT_NS`` (Navier-Stokes).
        radius: Neighbourhood radius passed to ``cv2.inpaint``.

    Returns:
        The inpainted image.
    """
    flag = cv2.INPAINT_TELEA if method.lower() == "telea" else cv2.INPAINT_NS
    if mask.ndim == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    # cv2.inpaint requires the mask to match the image size exactly; an
    # uploaded mask may have been saved at a different resolution.
    height, width = img.shape[:2]
    if mask.shape[:2] != (height, width):
        # Nearest-neighbour keeps the mask hard-edged (no blended values).
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    _, mask_bin = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)
    return cv2.inpaint(img, mask_bin, radius, flag)
def call_lama_server(img: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Inpaint ``img`` by posting it and a binary mask to the LaMa server.

    Both arrays are sent as PNG files (form fields ``image`` and ``mask``)
    to ``{LAMA_URL}/inpaint`` with ``method=lama``. The response body is
    decoded as an image.

    Args:
        img: BGR image to repair.
        mask: Mask image, either single-channel or BGR. Pixels with a value
            above 1 are inpainted. A mask of a different size is resized to
            the image size.

    Returns:
        The inpainted BGR image returned by the server.

    Raises:
        HTTPException: 500 if PNG encoding fails; 502 if the request fails,
            the server answers with an error status, or its response is not
            a decodable image.
    """
    if mask.ndim == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    # Keep the mask aligned pixel-for-pixel with the image.
    height, width = img.shape[:2]
    if mask.shape[:2] != (height, width):
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    _, mask_bin = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)

    def to_png_bytes(arr: np.ndarray) -> bytes:
        # Single-channel arrays are expanded to 3 channels before encoding.
        a = arr
        if a.ndim == 2:
            a = cv2.cvtColor(a, cv2.COLOR_GRAY2BGR)
        ok, buf = cv2.imencode('.png', a)
        if not ok:
            raise HTTPException(500, detail="Encoding error")
        return buf.tobytes()

    files = {
        'image': ('image.png', to_png_bytes(img), 'image/png'),
        'mask': ('mask.png', to_png_bytes(mask_bin), 'image/png'),
    }
    data = {'method': 'lama'}
    url = f"{LAMA_URL.rstrip('/')}/inpaint"
    try:
        resp = requests.post(url, data=data, files=files, timeout=120)
        # Turn 4xx/5xx answers from the LaMa server into requests.HTTPError.
        resp.raise_for_status()
    except requests.RequestException as e:
        # Covers connection errors, timeouts, bad URLs and error statuses;
        # unrelated programming errors still propagate unchanged.
        raise HTTPException(502, detail=f"LaMa server error: {e}") from e
    nparr = np.frombuffer(resp.content, np.uint8)
    out = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if nparr.size else None
    if out is None:
        raise HTTPException(502, detail="Invalid response from LaMa server")
    return out
@app.get("/", response_class=HTMLResponse)
def index(request: Request):
    """Render the ``index.html`` template as the landing page."""
    # The request object must be present in the template context.
    template_context = {"request": request}
    return templates.TemplateResponse("index.html", template_context)
@app.get("/health")
def health():
    """Return the constant JSON body ``{"ok": true}`` (usable as a liveness check)."""
    payload = {"ok": True}
    return JSONResponse(content=payload)
@app.post("/api/remove-watermark")
def remove_watermark(
    image: UploadFile = File(...),
    mask: Optional[UploadFile] = File(None),
    engine: str = Form("opencv"),  # "opencv" | "lama"
    method: str = Form("telea"),  # opencv: telea | ns
    radius: int = Form(3),
    auto_mask: int = Form(1),  # 1=true, 0=false
):
    """Remove a watermark/overlay from the uploaded image and return a PNG.

    Form fields:
        image: The image to clean (required).
        mask: Optional mask image marking the area to inpaint. When it is
            given, ``auto_mask`` is ignored.
        engine: "opencv" (local ``inpaint_opencv``) or "lama" (remote
            ``call_lama_server``), matched case-insensitively.
        method: OpenCV algorithm name passed to ``inpaint_opencv``.
        radius: OpenCV inpainting radius passed to ``inpaint_opencv``.
        auto_mask: When no mask is uploaded, a non-zero value builds a
            heuristic text/logo mask. 0 uses an all-zero mask, so nothing is
            marked for inpainting.

    Raises:
        HTTPException: 400 for an unknown engine or an undecodable upload.
            The LaMa engine may raise 502.
    """
    engine_name = engine.strip().lower()
    # Reject typos instead of silently falling back to OpenCV.
    if engine_name not in ("opencv", "lama"):
        raise HTTPException(400, detail=f"Unknown engine: {engine!r}")

    img = read_image_to_cv2(image)
    # An HTML form with an empty file input still submits a file part with an
    # empty filename. Treat that as "no mask uploaded" rather than decoding it.
    if mask is not None and mask.filename:
        mask_img = read_image_to_cv2(mask)
    elif auto_mask:
        mask_img = auto_text_mask(img)
    else:
        mask_img = np.zeros(img.shape[:2], dtype=np.uint8)

    if engine_name == "lama":
        out = call_lama_server(img, mask_img)
    else:
        out = inpaint_opencv(img, mask_img, method=method, radius=radius)
    buf = pil_bytes_from_cv2(out, fmt="PNG")
    return StreamingResponse(buf, media_type="image/png")