royeis committed on
Commit 3903517 · 1 Parent(s): c47cc24

extract generate_with_updates to global scope

Files changed (1):
  app.py +247 -219
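
This commit moves the `@spaces.GPU`-decorated generator out of `create_interface` and into module scope; it now talks to the UI only through queues it receives as arguments and through string-keyed update dicts. Below is a minimal, self-contained sketch of that producer/consumer pattern, not the app's code: `fake_generate` is a hypothetical stand-in for the app's `generate_baseline_only`, and the dict keys mirror the ones introduced in the diff.

```python
import queue
import threading
import time


def fake_generate(prompt, tokens_callback, stop_event):
    """Hypothetical stand-in for generate_baseline_only: streams partial text."""
    text = ""
    count = 0
    for word in f"thinking about {prompt} </think> the answer".split():
        if stop_event.is_set():
            break
        count += 1
        text += word + " "
        tokens_callback(text, count)  # push each partial text with a token count
        time.sleep(0.01)
    return text, count


def generate_with_updates(prompt, tokens_queue, stop_event):
    """Module-level generator: references no UI components, only string keys,
    so it can be decorated (e.g. with @spaces.GPU) and live at global scope."""

    def tokens_updater(text, token_count):
        tokens_queue.put(("NORMAL_UPDATE", text, token_count))

    worker = threading.Thread(
        target=fake_generate, args=(prompt, tokens_updater, stop_event)
    )
    worker.start()

    # Drain the queue while the worker runs, yielding string-keyed updates.
    while worker.is_alive() or not tokens_queue.empty():
        updates = {}
        while not tokens_queue.empty():
            _kind, text, count = tokens_queue.get_nowait()
            updates["thinking"] = text
            updates["tokens"] = str(count)
        if updates:
            yield updates
        time.sleep(0.05)

    yield {"status": "**Generation complete!**", "progress": 100}


if __name__ == "__main__":
    q, stop = queue.Queue(), threading.Event()
    for update in generate_with_updates("demo", q, stop):
        print(update)
```

A thin wrapper inside `create_interface` (`generate_wrapper` in the diff) then translates these string keys into actual Gradio component updates.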
app.py CHANGED
@@ -367,6 +367,224 @@ def initialize_app():
 # Load model automatically when script starts
 initialize_app()
 
+# Global function for resetting UI state
+def reset_ui():
+    """Reset the UI elements for a new generation"""
+    global baseline_think_tag_detected, baseline_progress_frozen
+    global baseline_pre_think_content, baseline_post_think_content
+
+    # Reset progress tracking for monotonic behavior
+    reset_progress_tracking()
+
+    baseline_think_tag_detected = False
+    baseline_progress_frozen = False
+    baseline_pre_think_content = ""
+    baseline_post_think_content = ""
+
+    return {
+        "status": "**Starting generation...**",
+        "progress": 0,
+        "thinking": "",
+        "answer": "",
+        "tokens": "",
+        "generate_btn_text": "Generating...",
+        "generate_btn_interactive": False,
+        "stop_btn_interactive": True
+    }
+
+@spaces.GPU(duration=240)
+def generate_with_updates(prompt, baseline_progress_queue, baseline_tokens_queue, stop_generation):
+    """Wrapper around generation function that handles real-time updates"""
+    # Check if model is loaded
+    if not model_loaded_successfully:
+        yield {
+            "status": f"**Cannot generate: {model_loading_error}**"
+        }
+        return
+
+    # Use default values
+    max_tokens = 2048
+
+    # Reset UI first
+    yield reset_ui()
+
+    # Start generation in a separate thread to allow for UI updates
+    baseline_result = ""
+    baseline_token_count = 0
+    generation_error = None
+    generation_thread = None
+
+    def baseline_progress_updater(prog_value):
+        """Update the baseline progress via the queue"""
+        baseline_progress_queue.put(prog_value)
+
+    def baseline_tokens_updater(text, token_count):
+        """Update the baseline generated text via the queue"""
+        global baseline_think_tag_detected, baseline_progress_frozen, baseline_pre_think_content, baseline_post_think_content
+
+        # Check if </think> tag appears in the text
+        if not baseline_think_tag_detected and "</think>" in text:
+            baseline_think_tag_detected = True
+            baseline_progress_frozen = True
+
+            # Split content at </think>
+            parts = text.split("</think>", 1)
+            baseline_pre_think_content = parts[0] + "</think>"
+            baseline_post_think_content = parts[1] if len(parts) > 1 else ""
+
+            # Signal content split with token count
+            baseline_tokens_queue.put(("THINK_TAG_DETECTED", baseline_pre_think_content, baseline_post_think_content, token_count))
+        elif baseline_think_tag_detected:
+            # Update post-think content
+            if "</think>" in text:
+                parts = text.split("</think>", 1)
+                baseline_post_think_content = parts[1] if len(parts) > 1 else ""
+                baseline_tokens_queue.put(("POST_THINK_UPDATE", baseline_post_think_content))
+            else:
+                baseline_tokens_queue.put(("NORMAL_UPDATE", text))
+        else:
+            # Normal pre-think streaming with token count
+            baseline_tokens_queue.put(("NORMAL_UPDATE", text, token_count))
+
+    def run_generation():
+        nonlocal baseline_result, baseline_token_count, generation_error
+        try:
+            # Baseline-only generation
+            baseline_result, baseline_token_count = generate_baseline_only(
+                prompt=prompt,
+                max_new_tokens=max_tokens,
+                baseline_progress_callback=baseline_progress_updater,
+                baseline_tokens_callback=baseline_tokens_updater,
+                stop_event=stop_generation
+            )
+        except Exception as e:
+            generation_error = str(e)
+
+    # Start the generation thread
+    generation_thread = threading.Thread(target=run_generation)
+    generation_thread.start()
+
+    # Monitor queues for updates while generation is running
+    baseline_current_text = ""
+    baseline_thinking_tokens = 0
+    baseline_last_progress = 0
+
+    while generation_thread.is_alive() or not baseline_tokens_queue.empty() or not baseline_progress_queue.empty():
+        updates = {}
+
+        # Check baseline tokens queue
+        try:
+            while not baseline_tokens_queue.empty():
+                token_update = baseline_tokens_queue.get_nowait()
+
+                if isinstance(token_update, tuple):
+                    update_type = token_update[0]
+
+                    if update_type == "THINK_TAG_DETECTED":
+                        # </think> tag detected - split content
+                        pre_content = token_update[1]
+                        post_content = token_update[2]
+                        thinking_token_count = token_update[3]
+
+                        updates["thinking"] = pre_content
+                        updates["answer"] = post_content
+                        updates["progress"] = 100.0  # Freeze at 100%
+
+                        # Use actual token count (before </think>)
+                        baseline_thinking_tokens = thinking_token_count
+                        updates["tokens"] = f"{baseline_thinking_tokens}"
+
+                    elif update_type == "POST_THINK_UPDATE":
+                        # Update only the final answer
+                        post_content = token_update[1]
+                        updates["answer"] = post_content
+                        # Don't update token count - frozen at thinking tokens
+
+                    elif update_type == "NORMAL_UPDATE":
+                        # Normal text update
+                        baseline_current_text = token_update[1]
+                        if not baseline_think_tag_detected:
+                            updates["thinking"] = baseline_current_text
+                            # Update thinking token count with actual token count if available
+                            if len(token_update) > 2:
+                                baseline_thinking_tokens = token_update[2]
+                            else:
+                                # Fallback to word count for backward compatibility
+                                baseline_thinking_tokens = len(baseline_current_text.split())
+                            updates["tokens"] = f"{baseline_thinking_tokens}"
+                        else:
+                            # This shouldn't happen, but handle it gracefully
+                            updates["answer"] = baseline_current_text
+                else:
+                    # Backward compatibility - treat as normal text
+                    baseline_current_text = token_update
+                    updates["thinking"] = baseline_current_text
+                    if not baseline_think_tag_detected:
+                        baseline_thinking_tokens = len(baseline_current_text.split())
+                        updates["tokens"] = f"{baseline_thinking_tokens}"
+
+        except queue.Empty:
+            pass
+
+        # Check baseline progress queue
+        try:
+            while not baseline_progress_queue.empty():
+                baseline_last_progress = baseline_progress_queue.get_nowait()
+                updates["progress"] = baseline_last_progress
+        except queue.Empty:
+            pass
+
+        # If there are any updates, yield them
+        if updates:
+            yield updates
+
+        # Sleep briefly to prevent excessive CPU usage
+        time.sleep(0.05)
+
+    # Final update
+    final_updates = {
+        "status": "**Generation complete!**" if not generation_error else f"**Error: {generation_error}**",
+        "progress": 100,
+        "generate_btn_text": "Generate",
+        "generate_btn_interactive": True,
+        "stop_btn_interactive": True
+    }
+
+    if not generation_error:
+        # Handle baseline final display
+        if baseline_think_tag_detected:
+            # Split result for final display
+            if "</think>" in baseline_result:
+                parts = baseline_result.split("</think>", 1)
+                final_updates["thinking"] = parts[0] + "</think>"
+                final_updates["answer"] = parts[1] if len(parts) > 1 else ""
+                # Use actual token count from generation
+                if baseline_thinking_tokens > 0:
+                    final_updates["tokens"] = f"{baseline_thinking_tokens}"
+                else:
+                    # Fallback: use actual token count for thinking part
+                    thinking_text = parts[0] + "</think>"
+                    thinking_token_count = len(tokenizer.encode(thinking_text, add_special_tokens=False))
+                    final_updates["tokens"] = f"{thinking_token_count}"
+            else:
+                final_updates["thinking"] = baseline_result
+                # Use actual token count
+                if baseline_thinking_tokens > 0:
+                    final_updates["tokens"] = f"{baseline_thinking_tokens}"
+                else:
+                    total_token_count = len(tokenizer.encode(baseline_result, add_special_tokens=False))
+                    final_updates["tokens"] = f"{total_token_count}"
+        else:
+            final_updates["thinking"] = baseline_result
+            # Use actual token count
+            if baseline_thinking_tokens > 0:
+                final_updates["tokens"] = f"{baseline_thinking_tokens}"
+            else:
+                total_token_count = len(tokenizer.encode(baseline_result, add_special_tokens=False))
+                final_updates["tokens"] = f"{total_token_count}"
+
+    yield final_updates
+
 # Create the Gradio interface
 def create_interface():
     # Create custom theme with light green progress bars
@@ -577,237 +795,47 @@ def create_interface():
         baseline_tokens_queue = queue.Queue()
         stop_generation = threading.Event()
 
-        def baseline_progress_updater(prog_value):
-            """Update the baseline progress via the queue"""
-            baseline_progress_queue.put(prog_value)
-
-        def baseline_tokens_updater(text, token_count):
-            """Update the baseline generated text via the queue"""
-            global baseline_think_tag_detected, baseline_progress_frozen, baseline_pre_think_content, baseline_post_think_content
-
-            # Check if </think> tag appears in the text
-            if not baseline_think_tag_detected and "</think>" in text:
-                baseline_think_tag_detected = True
-                baseline_progress_frozen = True
-
-                # Split content at </think>
-                parts = text.split("</think>", 1)
-                baseline_pre_think_content = parts[0] + "</think>"
-                baseline_post_think_content = parts[1] if len(parts) > 1 else ""
-
-                # Signal content split with token count
-                baseline_tokens_queue.put(("THINK_TAG_DETECTED", baseline_pre_think_content, baseline_post_think_content, token_count))
-            elif baseline_think_tag_detected:
-                # Update post-think content
-                if "</think>" in text:
-                    parts = text.split("</think>", 1)
-                    baseline_post_think_content = parts[1] if len(parts) > 1 else ""
-                    baseline_tokens_queue.put(("POST_THINK_UPDATE", baseline_post_think_content))
-                else:
-                    baseline_tokens_queue.put(("NORMAL_UPDATE", text))
-            else:
-                # Normal pre-think streaming with token count
-                baseline_tokens_queue.put(("NORMAL_UPDATE", text, token_count))
-
         def stop_generation_fn():
             """Stop the generation process"""
             stop_generation.set()
             return "Generation stopped"
 
-        def reset_ui():
-            """Reset the UI elements for a new generation"""
-            global baseline_think_tag_detected, baseline_progress_frozen
-            global baseline_pre_think_content, baseline_post_think_content
-
-            # Reset progress tracking for monotonic behavior
-            reset_progress_tracking()
-
-            baseline_think_tag_detected = False
-            baseline_progress_frozen = False
-            baseline_pre_think_content = ""
-            baseline_post_think_content = ""
-            stop_generation.clear()
-
-            # Clear all queues
-            while not baseline_progress_queue.empty():
-                baseline_progress_queue.get()
-            while not baseline_tokens_queue.empty():
-                baseline_tokens_queue.get()
-
-            return {
-                generation_status: "**Starting generation...**",
-                baseline_progress_bar: 0,
-                baseline_thinking_output: "",
-                baseline_answer_output: "",
-                baseline_tokens_count: "",
-                generate_btn: gr.Button("Generating...", variant="secondary", interactive=False),
-                stop_btn: gr.Button("Stop", variant="stop", interactive=True)
-            }
-
-        @spaces.GPU(duration=240)
-        def generate_with_updates(prompt):
-            """Wrapper around generation function that handles real-time updates"""
-            # Check if model is loaded
-            if not model_loaded_successfully:
-                yield {
-                    generation_status: f"**Cannot generate: {model_loading_error}**"
-                }
-                return
-
-            # Use default values
-            max_tokens = 2048
-
-            # Reset UI first
-            yield reset_ui()
-
-            # Start generation in a separate thread to allow for UI updates
-            baseline_result = ""
-            baseline_token_count = 0
-            generation_error = None
-            generation_thread = None
-
-            def run_generation():
-                nonlocal baseline_result, baseline_token_count, generation_error
-                try:
-                    # Baseline-only generation
-                    baseline_result, baseline_token_count = generate_baseline_only(
-                        prompt=prompt,
-                        max_new_tokens=max_tokens,
-                        baseline_progress_callback=baseline_progress_updater,
-                        baseline_tokens_callback=baseline_tokens_updater,
-                        stop_event=stop_generation
+        def generate_wrapper(prompt):
+            """Wrapper to adapt the global generate_with_updates function for Gradio"""
+            # Process updates from the global function and map to UI components
+            for update_dict in generate_with_updates(prompt, baseline_progress_queue, baseline_tokens_queue, stop_generation):
+                gradio_updates = {}
+
+                # Map the string keys to actual Gradio components
+                if "status" in update_dict:
+                    gradio_updates[generation_status] = update_dict["status"]
+                if "progress" in update_dict:
+                    gradio_updates[baseline_progress_bar] = update_dict["progress"]
+                if "thinking" in update_dict:
+                    gradio_updates[baseline_thinking_output] = update_dict["thinking"]
+                if "answer" in update_dict:
+                    gradio_updates[baseline_answer_output] = update_dict["answer"]
+                if "tokens" in update_dict:
+                    gradio_updates[baseline_tokens_count] = update_dict["tokens"]
+                if "generate_btn_text" in update_dict:
+                    gradio_updates[generate_btn] = gr.Button(
+                        update_dict["generate_btn_text"],
+                        variant="secondary" if "Generating" in update_dict["generate_btn_text"] else "primary",
+                        interactive=update_dict.get("generate_btn_interactive", True)
+                    )
+                if "stop_btn_interactive" in update_dict:
+                    gradio_updates[stop_btn] = gr.Button(
+                        "Stop",
+                        variant="stop",
+                        interactive=update_dict["stop_btn_interactive"]
                     )
-                except Exception as e:
-                    generation_error = str(e)
-
-            # Start the generation thread
-            generation_thread = threading.Thread(target=run_generation)
-            generation_thread.start()
-
-            # Monitor queues for updates while generation is running
-            baseline_current_text = ""
-            baseline_thinking_tokens = 0
-            baseline_last_progress = 0
-
-            while generation_thread.is_alive() or not baseline_tokens_queue.empty() or not baseline_progress_queue.empty():
-                updates = {}
-
-                # Check baseline tokens queue
-                try:
-                    while not baseline_tokens_queue.empty():
-                        token_update = baseline_tokens_queue.get_nowait()
-
-                        if isinstance(token_update, tuple):
-                            update_type = token_update[0]
-
-                            if update_type == "THINK_TAG_DETECTED":
-                                # </think> tag detected - split content
-                                pre_content = token_update[1]
-                                post_content = token_update[2]
-                                thinking_token_count = token_update[3]
-
-                                updates[baseline_thinking_output] = pre_content
-                                updates[baseline_answer_output] = post_content
-                                updates[baseline_progress_bar] = 100.0  # Freeze at 100%
-
-                                # Use actual token count (before </think>)
-                                baseline_thinking_tokens = thinking_token_count
-                                updates[baseline_tokens_count] = f"{baseline_thinking_tokens}"
-
-                            elif update_type == "POST_THINK_UPDATE":
-                                # Update only the final answer
-                                post_content = token_update[1]
-                                updates[baseline_answer_output] = post_content
-                                # Don't update token count - frozen at thinking tokens
-
-                            elif update_type == "NORMAL_UPDATE":
-                                # Normal text update
-                                baseline_current_text = token_update[1]
-                                if not baseline_think_tag_detected:
-                                    updates[baseline_thinking_output] = baseline_current_text
-                                    # Update thinking token count with actual token count if available
-                                    if len(token_update) > 2:
-                                        baseline_thinking_tokens = token_update[2]
-                                    else:
-                                        # Fallback to word count for backward compatibility
-                                        baseline_thinking_tokens = len(baseline_current_text.split())
-                                    updates[baseline_tokens_count] = f"{baseline_thinking_tokens}"
-                                else:
-                                    # This shouldn't happen, but handle it gracefully
-                                    updates[baseline_answer_output] = baseline_current_text
-                        else:
-                            # Backward compatibility - treat as normal text
-                            baseline_current_text = token_update
-                            updates[baseline_thinking_output] = baseline_current_text
-                            if not baseline_think_tag_detected:
-                                baseline_thinking_tokens = len(baseline_current_text.split())
-                                updates[baseline_tokens_count] = f"{baseline_thinking_tokens}"
-
-                except queue.Empty:
-                    pass
-
-                # Check baseline progress queue
-                try:
-                    while not baseline_progress_queue.empty():
-                        baseline_last_progress = baseline_progress_queue.get_nowait()
-                        updates[baseline_progress_bar] = baseline_last_progress
-                except queue.Empty:
-                    pass
-
-                # If there are any updates, yield them
-                if updates:
-                    yield updates
 
-                # Sleep briefly to prevent excessive CPU usage
-                time.sleep(0.05)
-
-            # Final update
-            final_updates = {
-                generation_status: "**Generation complete!**" if not generation_error else f"**Error: {generation_error}**",
-                baseline_progress_bar: 100,
-                generate_btn: gr.Button("Generate", variant="primary", interactive=True),
-                stop_btn: gr.Button("Stop", variant="stop", interactive=True)
-            }
-
-            if not generation_error:
-                # Handle baseline final display
-                if baseline_think_tag_detected:
-                    # Split result for final display
-                    if "</think>" in baseline_result:
-                        parts = baseline_result.split("</think>", 1)
-                        final_updates[baseline_thinking_output] = parts[0] + "</think>"
-                        final_updates[baseline_answer_output] = parts[1] if len(parts) > 1 else ""
-                        # Use actual token count from generation
-                        if baseline_thinking_tokens > 0:
-                            final_updates[baseline_tokens_count] = f"{baseline_thinking_tokens}"
-                        else:
-                            # Fallback: use actual token count for thinking part
-                            thinking_text = parts[0] + "</think>"
-                            thinking_token_count = len(tokenizer.encode(thinking_text, add_special_tokens=False))
-                            final_updates[baseline_tokens_count] = f"{thinking_token_count}"
-                    else:
-                        final_updates[baseline_thinking_output] = baseline_result
-                        # Use actual token count
-                        if baseline_thinking_tokens > 0:
-                            final_updates[baseline_tokens_count] = f"{baseline_thinking_tokens}"
-                        else:
-                            total_token_count = len(tokenizer.encode(baseline_result, add_special_tokens=False))
-                            final_updates[baseline_tokens_count] = f"{total_token_count}"
-                else:
-                    final_updates[baseline_thinking_output] = baseline_result
-                    # Use actual token count
-                    if baseline_thinking_tokens > 0:
-                        final_updates[baseline_tokens_count] = f"{baseline_thinking_tokens}"
-                    else:
-                        total_token_count = len(tokenizer.encode(baseline_result, add_special_tokens=False))
-                        final_updates[baseline_tokens_count] = f"{total_token_count}"
-
-            yield final_updates
+                yield gradio_updates
 
         # Connect the buttons to the handlers
         if model_loaded_successfully:
             generate_btn.click(
-                generate_with_updates,
+                generate_wrapper,
                 inputs=[prompt],
                 outputs=[
                     generation_status,
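
Design note: passing the queues and stop event into `generate_with_updates` as arguments, and yielding string keys instead of component references, is what allows the function to live at module scope. The mapping step that `generate_wrapper` performs reduces to a dictionary translation; a hedged sketch, assuming a `key_to_component` dict built from the interface's components:

```python
# Hypothetical reduction of generate_wrapper's mapping step: translate
# string-keyed updates from the global generator into component-keyed
# updates that Gradio can consume. `key_to_component` is an assumption,
# standing in for the if-chain in the actual diff.
def map_updates(update_dict, key_to_component):
    return {
        component: update_dict[key]
        for key, component in key_to_component.items()
        if key in update_dict
    }
```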