| import torch |
| import os |
|
|
| from diffusers import ( |
| DDPMScheduler, |
| StableDiffusionXLImg2ImgPipeline, |
| LTXPipeline, |
| AutoencoderKL, |
| ) |
|
|
| from hidiffusion import apply_hidiffusion |
|
|
| from mediapipe.tasks import python |
| from mediapipe.tasks.python import vision |
|
|
| from image_gen_aux import UpscaleWithModel |
|
|
# Model identifiers passed to the ``from_pretrained`` loaders below.
BASE_MODEL = "stabilityai/sdxl-turbo"
VIDEO_MODEL = "Lightricks/LTX-Video"

# Run on CUDA when torch reports it as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
class ModelHandler:
    """Load and hold the model objects used by this module.

    Constructing an instance immediately calls :meth:`load_models`, which
    fills ``base_pipe``, ``segmenter``, ``upscaler`` and ``upscaler4SD``.

    ``video_pipe`` and ``compiled_model`` start as ``None`` and are not
    assigned anywhere else in this class. :meth:`load_video_pipe` only
    returns a pipeline; storing it is left to the caller.
    """

    def __init__(self):
        # Every slot starts empty; load_models() fills the ones it handles.
        self.base_pipe = self.video_pipe = self.compiled_model = None
        self.segmenter = self.upscaler = self.upscaler4SD = None
        self.load_models()

    def load_base(self):
        """Build the SDXL img2img pipeline for ``BASE_MODEL``.

        The pipeline is loaded in float16 with the
        ``madebyollin/sdxl-vae-fp16-fix`` VAE and moved to ``device``. Its
        scheduler is then replaced by a ``DDPMScheduler`` loaded from the
        model's ``scheduler`` subfolder, and ``apply_hidiffusion`` is called
        on it.

        Returns:
            The configured ``StableDiffusionXLImg2ImgPipeline``.
        """
        half = torch.float16
        fixed_vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=half,
        )

        pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            BASE_MODEL,
            vae=fixed_vae,
            torch_dtype=half,
            variant="fp16",
            use_safetensors=True,
        ).to(device, silence_dtype_warnings=True)

        pipeline.scheduler = DDPMScheduler.from_pretrained(
            BASE_MODEL,
            subfolder="scheduler",
        )
        apply_hidiffusion(pipeline)
        return pipeline

    def load_video_pipe(self):
        """Load the ``VIDEO_MODEL`` LTX pipeline in bfloat16 on ``device``.

        Returns:
            The loaded ``LTXPipeline``. It is not stored on the instance.
        """
        video_pipeline = LTXPipeline.from_pretrained(
            VIDEO_MODEL,
            torch_dtype=torch.bfloat16,
        )
        video_pipeline.to(device)
        return video_pipeline

    def load_segmenter(self):
        """Create a MediaPipe image segmenter from the local ``.tflite`` file.

        The model is read from a path relative to the working directory, and
        the segmenter is configured with ``output_category_mask=True``.

        Returns:
            The ``vision.ImageSegmenter`` instance.
        """
        segmenter_options = vision.ImageSegmenterOptions(
            base_options=python.BaseOptions(
                model_asset_path="checkpoints/selfie_multiclass_256x256.tflite",
            ),
            output_category_mask=True,
        )
        return vision.ImageSegmenter.create_from_options(segmenter_options)

    @staticmethod
    def _upscaler_from_env(env_var, fallback):
        """Load the upscale model named by *env_var*, or *fallback* if unset."""
        repo_id = os.environ.get(env_var, fallback)
        return UpscaleWithModel.from_pretrained(repo_id).to(device)

    def load_upscaler(self):
        """Load the upscaler named by ``UPSCALE_MODEL``.

        Falls back to ``Phips/4xNomosWebPhoto_RealPLKSR`` when the variable
        is not set.

        Returns:
            The ``UpscaleWithModel`` instance moved to ``device``.
        """
        return self._upscaler_from_env(
            "UPSCALE_MODEL",
            "Phips/4xNomosWebPhoto_RealPLKSR",
        )

    def load_upscaler4SD(self):
        """Load the upscaler named by ``UPSCALE_FOR_SD_MODEL``.

        Falls back to ``Phips/1xDeJPG_realplksr_otf`` when the variable is
        not set.

        Returns:
            The ``UpscaleWithModel`` instance moved to ``device``.
        """
        return self._upscaler_from_env(
            "UPSCALE_FOR_SD_MODEL",
            "Phips/1xDeJPG_realplksr_otf",
        )

    def load_models(self):
        """Load the base pipeline, segmenter and both upscalers.

        All four loaders run, in that order, before any attribute is
        assigned, so a failing loader leaves the instance attributes as they
        were. The video pipeline is not loaded here.
        """
        loaded_base, loaded_segmenter, loaded_upscaler, loaded_upscaler_sd = (
            self.load_base(),
            self.load_segmenter(),
            self.load_upscaler(),
            self.load_upscaler4SD(),
        )

        self.base_pipe = loaded_base
        self.segmenter = loaded_segmenter
        self.upscaler = loaded_upscaler
        self.upscaler4SD = loaded_upscaler_sd
|
|
# Shared module-level handler. ModelHandler.__init__ calls load_models(), so
# the models are loaded as soon as this module is imported.
MODELS = ModelHandler()