r3gm commited on
Commit
44d8a6c
·
verified ·
1 Parent(s): 5ddf1f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -23
app.py CHANGED
@@ -49,6 +49,7 @@ from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
49
  import torch
50
  import re
51
  import time
 
52
  from PIL import ImageFile
53
  from utils import (
54
  download_things,
@@ -175,7 +176,11 @@ class GuiSD:
175
  self.last_load = datetime.now()
176
  self.inventory = []
177
 
178
- def update_storage_models(self, storage_floor_gb=30, required_inventory_for_purge=3):
 
 
 
 
179
  while get_used_storage_gb() > storage_floor_gb:
180
  if len(self.inventory) < required_inventory_for_purge:
181
  break
@@ -200,30 +205,48 @@ class GuiSD:
200
 
201
  def load_new_model(self, model_name, vae_model, task, controlnet_model, progress=gr.Progress(track_tqdm=True)):
202
 
203
- # download link model > model_name
204
- if model_name.startswith("http"):
205
- yield f"Downloading model: {model_name}"
206
- model_name = download_things(DIRECTORY_MODELS, model_name, HF_TOKEN, CIVITAI_API_KEY)
207
- if not model_name:
208
- raise ValueError("Error retrieving model information from URL")
209
 
210
- if IS_ZERO_GPU:
211
- self.update_storage_models()
 
 
 
212
 
213
- vae_model = vae_model if vae_model != "None" else None
214
- model_type = get_model_type(model_name)
215
- dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
216
-
217
- if not os.path.exists(model_name):
218
- logger.debug(f"model_name={model_name}, vae_model={vae_model}, task={task}, controlnet_model={controlnet_model}")
219
- _ = download_diffuser_repo(
220
- repo_name=model_name,
221
- model_type=model_type,
222
- revision="main",
223
- token=True,
224
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
- self.update_inventory(model_name)
 
 
 
 
227
 
228
  for i in range(68):
229
  if not self.status_loading:
@@ -797,7 +820,7 @@ with gr.Blocks(theme=args.theme, css=CSS, fill_width=True, fill_height=False) as
797
  prompt_gui = gr.Textbox(lines=5, placeholder="Enter prompt", label="Prompt")
798
 
799
  with gr.Accordion("Negative prompt", open=False, visible=True):
800
- neg_prompt_gui = gr.Textbox(lines=3, placeholder="Enter Neg prompt", label="Negative prompt", value="bad anatomy, ((many hands, bad hands, missing fingers)), anatomical nonsense, ugly, deformed, bad proportions, bad shadow, extra limbs, missing limbs, floating limbs, disconnected limbs, malformed hands, poorly drawn, mutation, mutated hands and fingers, extra legs, interlocked fingers, extra arms, disfigured face, long neck, asymmetrical eyes, lowres, extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry, duplicate, bad composition, text, worst quality, normal quality, low quality, very displeasing, monochrome, grayscale, black and white, desaturated, low contrast, muted tones, washed out, unfinished, incomplete, draft, logo, backlighting")
801
 
802
  with gr.Row(equal_height=False):
803
  set_params_gui = gr.Button(value="↙️", variant="secondary", size="sm")
 
49
  import torch
50
  import re
51
  import time
52
+ import threading
53
  from PIL import ImageFile
54
  from utils import (
55
  download_things,
 
176
  self.last_load = datetime.now()
177
  self.inventory = []
178
 
179
+ # Avoid duplicate downloads
180
+ self.active_downloads = set()
181
+ self.download_lock = threading.Lock()
182
+
183
+ def update_storage_models(self, storage_floor_gb=30, required_inventory_for_purge=4):
184
  while get_used_storage_gb() > storage_floor_gb:
185
  if len(self.inventory) < required_inventory_for_purge:
186
  break
 
205
 
206
  def load_new_model(self, model_name, vae_model, task, controlnet_model, progress=gr.Progress(track_tqdm=True)):
207
 
208
+ lock_key = model_name
 
 
 
 
 
209
 
210
+ while True:
211
+ with self.download_lock:
212
+ if lock_key not in self.active_downloads:
213
+ self.active_downloads.add(lock_key)
214
+ break
215
 
216
+ yield f"Waiting for existing download to finish: {model_name}..."
217
+ time.sleep(1)
218
+
219
+ try:
220
+ # download link model > model_name
221
+ if model_name.startswith("http"):
222
+ yield f"Downloading model: {model_name}"
223
+ model_name = download_things(DIRECTORY_MODELS, model_name, HF_TOKEN, CIVITAI_API_KEY)
224
+ if not model_name:
225
+ raise ValueError("Error retrieving model information from URL")
226
+
227
+ if IS_ZERO_GPU:
228
+ self.update_storage_models()
229
+
230
+ vae_model = vae_model if vae_model != "None" else None
231
+ model_type = get_model_type(model_name)
232
+ dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
233
+
234
+ if not os.path.exists(model_name):
235
+ logger.debug(f"model_name={model_name}, vae_model={vae_model}, task={task}, controlnet_model={controlnet_model}")
236
+ _ = download_diffuser_repo(
237
+ repo_name=model_name,
238
+ model_type=model_type,
239
+ revision="main",
240
+ token=True,
241
+ )
242
+
243
+ self.update_inventory(model_name)
244
 
245
+ except Exception as e:
246
+ raise e
247
+ finally:
248
+ with self.download_lock:
249
+ self.active_downloads.discard(lock_key)
250
 
251
  for i in range(68):
252
  if not self.status_loading:
 
820
  prompt_gui = gr.Textbox(lines=5, placeholder="Enter prompt", label="Prompt")
821
 
822
  with gr.Accordion("Negative prompt", open=False, visible=True):
823
+ neg_prompt_gui = gr.Textbox(lines=3, placeholder="Enter Neg prompt", label="Negative prompt", value="bad anatomy, ((many hands, bad hands, missing fingers)), anatomical nonsense, ugly, deformed, bad proportions, bad shadow, extra limbs, missing limbs, floating limbs, disconnected limbs, malformed hands, poorly drawn, mutation, mutated hands and fingers, extra legs, interlocked fingers, extra arms, disfigured face, long neck, asymmetrical eyes, lowres, extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry, duplicate, bad composition, text, worst quality, normal quality, low quality, very displeasing, desaturated, low contrast, muted tones, washed out, unfinished, incomplete, draft, logo, backlighting")
824
 
825
  with gr.Row(equal_height=False):
826
  set_params_gui = gr.Button(value="↙️", variant="secondary", size="sm")