Commit · 7a56e2a
1 Parent(s): de952e4
Add remaining files from agllm-development state
- .cursorignore +5 -0
- .gitattributes +40 -0
- .gitignore +42 -0
- .vscode/settings.json +2 -0
- AgLLM.code-workspace +8 -0
- agllm-data/corrected/Corrected_supplemented-insect_data-2500-sorted.xlsx +3 -0
- agllm-data/corrected/Corrected_supplemented-insect_data-remaining.xlsx +3 -0
- agllm-data/corrected/Corrected_weed-prepared_results-all.xlsx +3 -0
- agllm-data/india/species.csv +2 -0
- agllm_analysis.ipynb +0 -0
- agllm_development_without_secret.ipynb +0 -0
- api_request_parallel_processor.py +508 -0
- api_request_parallel_processor_anthropic.py +473 -0
- api_request_parallel_processor_universal.py +492 -0
- api_request_parallel_processor_universal_SEQUENTIAL.py +420 -0
- app.py +727 -0
- app_backup.py +490 -0
- app_backup_2.py +510 -0
- app_database_prep.py +234 -0
- outdated-files/agllm-data.zip +3 -0
- outdated-files/agllm_with_evaluation.ipynb +0 -0
- outdated-files/app-basic.py +19 -0
- outdated-files/app-old.py +334 -0
- outdated-files/dd.txt +1 -0
- outdated-files/rag-evaluation (outdated).ipynb +896 -0
- push_logs.txt +1 -0
- question-generation-retrieval-evaluation.ipynb +0 -0
- requirements-23feb2025.txt +241 -0
- requirements.txt +16 -0
- retriever-evaluation-tutorial.ipynb +0 -0
- temp_results.csv +0 -0
.cursorignore
ADDED
@@ -0,0 +1,5 @@
+# Add directories or file patterns to ignore during indexing (e.g. foo/ or *.csv)
+*.ipynb
+*.xlsx
+*.csv
+
.gitattributes
ADDED
@@ -0,0 +1,40 @@
+*.xlsx filter=lfs diff=lfs merge=lfs -text
+db5/** filter=lfs diff=lfs merge=lfs -text
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+vector-databases-deployed/** filter=lfs diff=lfs merge=lfs -text
+*.pdf filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
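These rules route the repository's binary artifacts (spreadsheets, archives, model weights, and the deployed vector databases under db5/ and vector-databases-deployed/) through Git LFS. A minimal sketch of how entries in this exact format are typically produced with `git lfs track`, assuming git-lfs is installed; the patterns are taken from the file above:

```python
import subprocess

# Each `git lfs track` call appends a "<pattern> filter=lfs diff=lfs merge=lfs -text"
# line to .gitattributes, matching the entries shown above.
for pattern in ["*.xlsx", "db5/**", "vector-databases-deployed/**", "*.pdf", "*.png"]:
    subprocess.run(["git", "lfs", "track", pattern], check=True)

# Stage the updated attributes file so the rules travel with the commit.
subprocess.run(["git", "add", ".gitattributes"], check=True)
```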
.gitignore
ADDED
@@ -0,0 +1,42 @@
+# First exclude everything in agllm-data
+agllm-data/*
+
+# Then explicitly include the corrected and india directories and their contents
+!agllm-data/corrected/
+!agllm-data/corrected/**
+!agllm-data/india/
+!agllm-data/india/**
+
+# Other exclusions
+agllm/
+db3/
+# db5/
+db4/
+mlruns/
+/agllm-data.zip
+.env
+vector-databases/
+vector-databases-deployement/
+writing/
+# but include the analysis folder
+!writing/*/analysis/
+# Include specific file
+!writing/65d4fadc59fceb1a54d1aae6/main.tex
+*.pdf
+
+# Chloropleth related files
+chloropleth/
+chloropleth_backup/
+*.shp
+*.png
+
+# Vector database backups
+vector-databases-deployed-backup/
+*.sqlite3
+
+# Binary files
+*.pyc
+__pycache__/
+.DS_Store
+
+vector-databases-deployed/
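The rules above follow an exclude-then-re-include pattern: everything under agllm-data/ is ignored by `agllm-data/*`, and only the corrected/ and india/ directories are whitelisted back in. The directory rules (`!agllm-data/corrected/`) must precede the content rules (`!agllm-data/corrected/**`), since Git cannot re-include files whose parent directory stays excluded. A minimal sketch of verifying the behaviour with `git check-ignore`; the test paths are hypothetical:

```python
import subprocess

# Hypothetical paths used only to exercise the ignore rules above.
paths = [
    "agllm-data/raw/some_dump.xlsx",         # expected: ignored by agllm-data/*
    "agllm-data/corrected/some_table.xlsx",  # expected: kept by !agllm-data/corrected/**
    "agllm-data/india/species.csv",          # expected: kept by !agllm-data/india/**
]

for path in paths:
    # `git check-ignore -v` prints the matching rule and exits 0 if the path is
    # ignored, or exits 1 (with no output) if the path is not ignored.
    result = subprocess.run(
        ["git", "check-ignore", "-v", path],
        capture_output=True, text=True,
    )
    status = "ignored" if result.returncode == 0 else "not ignored"
    print(f"{path}: {status} {result.stdout.strip()}")
```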
.vscode/settings.json
ADDED
@@ -0,0 +1,2 @@
+{
+}
AgLLM.code-workspace
ADDED
@@ -0,0 +1,8 @@
+{
+    "folders": [
+        {
+            "path": "."
+        }
+    ],
+    "settings": {}
+}
agllm-data/corrected/Corrected_supplemented-insect_data-2500-sorted.xlsx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6aef24ba96ca9be054b4f87344d93d3f8f59cbd280ff29d45991977482067717
+size 1185132
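The .xlsx entries in this commit are Git LFS pointer files (the three-line version/oid/size format shown above), not the spreadsheets themselves; the real content is fetched by LFS at checkout. A minimal sketch of reading such a pointer, assuming the path from this commit:

```python
from pathlib import Path

def parse_lfs_pointer(path: str) -> dict:
    """Parse a Git LFS pointer file into its version, oid, and size fields."""
    fields = {}
    for line in Path(path).read_text().splitlines():
        if not line.strip():
            continue
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

# Usage against one of the pointer files added above (only works while the
# file is still a pointer, i.e. before `git lfs pull` replaces it).
pointer = parse_lfs_pointer(
    "agllm-data/corrected/Corrected_supplemented-insect_data-2500-sorted.xlsx"
)
print(pointer["oid"])        # e.g. "sha256:6aef24ba..."
print(int(pointer["size"]))  # size of the real spreadsheet in bytes
```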
agllm-data/corrected/Corrected_supplemented-insect_data-remaining.xlsx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc953489e31d114f689783aaa50d1e23ce3fcc58618a9b22562e0cad5750af57
+size 980300
agllm-data/corrected/Corrected_weed-prepared_results-all.xlsx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e97a0eaa24c101b940d48f1dd1e11345246e2455521b2d79a45c3da108c63195
+size 602025
agllm-data/india/species.csv
ADDED
@@ -0,0 +1,2 @@
+Common Name,species,ETL,Alternate Hosts,Management,Remarks
+Brown Marmorated Stink Bug,Halyomorpha Halys,,horticulture and field crops,"Crush any eggs nymphs or adults you find on plants. Repeat daily for several days.|Botanical extracts: Spray plants with a mixture of 50 g of neem seed extract in 2 L of water. Boil for 15 minutes let cool and spray three times at two-week intervals.|Soapy water: Spray plants with soapy water to remove insects",
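The species table above packs several recommendations into a single pipe-delimited Management cell. A minimal sketch of loading the file and splitting that column, assuming pandas is available and the path used in this commit:

```python
import pandas as pd

# Load the species table added in this commit; the Management column holds
# multiple recommendations in one cell, separated by "|".
df = pd.read_csv("agllm-data/india/species.csv")

# Split each Management cell into a list of individual recommendations.
df["Management"] = df["Management"].fillna("").str.split("|")

row = df.iloc[0]
print(row["Common Name"], "-", row["species"])
for step in row["Management"]:
    print(" *", step)
```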
agllm_analysis.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
agllm_development_without_secret.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
api_request_parallel_processor.py
ADDED
@@ -0,0 +1,508 @@
| 1 |
+
"""
|
| 2 |
+
API REQUEST PARALLEL PROCESSOR
|
| 3 |
+
|
| 4 |
+
Using the OpenAI API to process lots of text quickly takes some care.
|
| 5 |
+
If you trickle in a million API requests one by one, they'll take days to complete.
|
| 6 |
+
If you flood a million API requests in parallel, they'll exceed the rate limits and fail with errors.
|
| 7 |
+
To maximize throughput, parallel requests need to be throttled to stay under rate limits.
|
| 8 |
+
|
| 9 |
+
This script parallelizes requests to the OpenAI API while throttling to stay under rate limits.
|
| 10 |
+
|
| 11 |
+
Features:
|
| 12 |
+
- Streams requests from file, to avoid running out of memory for giant jobs
|
| 13 |
+
- Makes requests concurrently, to maximize throughput
|
| 14 |
+
- Throttles request and token usage, to stay under rate limits
|
| 15 |
+
- Retries failed requests up to {max_attempts} times, to avoid missing data
|
| 16 |
+
- Logs errors, to diagnose problems with requests
|
| 17 |
+
|
| 18 |
+
Example command to call script:
|
| 19 |
+
```
|
| 20 |
+
python examples/api_request_parallel_processor.py \
|
| 21 |
+
--requests_filepath examples/data/example_requests_to_parallel_process.jsonl \
|
| 22 |
+
--save_filepath examples/data/example_requests_to_parallel_process_results.jsonl \
|
| 23 |
+
--request_url https://api.openai.com/v1/embeddings \
|
| 24 |
+
--max_requests_per_minute 1500 \
|
| 25 |
+
--max_tokens_per_minute 6250000 \
|
| 26 |
+
--token_encoding_name cl100k_base \
|
| 27 |
+
--max_attempts 5 \
|
| 28 |
+
--logging_level 20
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
Inputs:
|
| 32 |
+
- requests_filepath : str
|
| 33 |
+
- path to the file containing the requests to be processed
|
| 34 |
+
- file should be a jsonl file, where each line is a json object with API parameters and an optional metadata field
|
| 35 |
+
- e.g., {"model": "text-embedding-ada-002", "input": "embed me", "metadata": {"row_id": 1}}
|
| 36 |
+
- as with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically)
|
| 37 |
+
- an example file is provided at examples/data/example_requests_to_parallel_process.jsonl
|
| 38 |
+
- the code to generate the example file is appended to the bottom of this script
|
| 39 |
+
- save_filepath : str, optional
|
| 40 |
+
- path to the file where the results will be saved
|
| 41 |
+
- file will be a jsonl file, where each line is an array with the original request plus the API response
|
| 42 |
+
- e.g., [{"model": "text-embedding-ada-002", "input": "embed me"}, {...}]
|
| 43 |
+
- if omitted, results will be saved to {requests_filename}_results.jsonl
|
| 44 |
+
- request_url : str, optional
|
| 45 |
+
- URL of the API endpoint to call
|
| 46 |
+
- if omitted, will default to "https://api.openai.com/v1/embeddings"
|
| 47 |
+
- api_key : str, optional
|
| 48 |
+
- API key to use
|
| 49 |
+
- if omitted, the script will attempt to read it from an environment variable {os.getenv("OPENAI_API_KEY")}
|
| 50 |
+
- max_requests_per_minute : float, optional
|
| 51 |
+
- target number of requests to make per minute (will make less if limited by tokens)
|
| 52 |
+
- leave headroom by setting this to 50% or 75% of your limit
|
| 53 |
+
- if requests are limiting you, try batching multiple embeddings or completions into one request
|
| 54 |
+
- if omitted, will default to 1,500
|
| 55 |
+
- max_tokens_per_minute : float, optional
|
| 56 |
+
- target number of tokens to use per minute (will use less if limited by requests)
|
| 57 |
+
- leave headroom by setting this to 50% or 75% of your limit
|
| 58 |
+
- if omitted, will default to 125,000
|
| 59 |
+
- token_encoding_name : str, optional
|
| 60 |
+
- name of the token encoding used, as defined in the `tiktoken` package
|
| 61 |
+
- if omitted, will default to "cl100k_base" (used by `text-embedding-ada-002`)
|
| 62 |
+
- max_attempts : int, optional
|
| 63 |
+
- number of times to retry a failed request before giving up
|
| 64 |
+
- if omitted, will default to 5
|
| 65 |
+
- logging_level : int, optional
|
| 66 |
+
- level of logging to use; higher numbers will log fewer messages
|
| 67 |
+
- 40 = ERROR; will log only when requests fail after all retries
|
| 68 |
+
- 30 = WARNING; will log when requests hit rate limits or other errors
|
| 69 |
+
- 20 = INFO; will log when requests start and the status at finish
|
| 70 |
+
- 10 = DEBUG; will log various things as the loop runs to see when they occur
|
| 71 |
+
- if omitted, will default to 20 (INFO).
|
| 72 |
+
|
| 73 |
+
The script is structured as follows:
|
| 74 |
+
- Imports
|
| 75 |
+
- Define main()
|
| 76 |
+
- Initialize things
|
| 77 |
+
- In main loop:
|
| 78 |
+
- Get next request if one is not already waiting for capacity
|
| 79 |
+
- Update available token & request capacity
|
| 80 |
+
- If enough capacity available, call API
|
| 81 |
+
- The loop pauses if a rate limit error is hit
|
| 82 |
+
- The loop breaks when no tasks remain
|
| 83 |
+
- Define dataclasses
|
| 84 |
+
- StatusTracker (stores script metadata counters; only one instance is created)
|
| 85 |
+
- APIRequest (stores API inputs, outputs, metadata; one method to call API)
|
| 86 |
+
- Define functions
|
| 87 |
+
- api_endpoint_from_url (extracts API endpoint from request URL)
|
| 88 |
+
- append_to_jsonl (writes to results file)
|
| 89 |
+
- num_tokens_consumed_from_request (bigger function to infer token usage from request)
|
| 90 |
+
- task_id_generator_function (yields 0, 1, 2, ...)
|
| 91 |
+
- Run main()
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
# imports
|
| 95 |
+
import aiohttp # for making API calls concurrently
|
| 96 |
+
import argparse # for running script from command line
|
| 97 |
+
import asyncio # for running API calls concurrently
|
| 98 |
+
import json # for saving results to a jsonl file
|
| 99 |
+
import logging # for logging rate limit warnings and other messages
|
| 100 |
+
import os # for reading API key
|
| 101 |
+
import re # for matching endpoint from request URL
|
| 102 |
+
import tiktoken # for counting tokens
|
| 103 |
+
import time # for sleeping after rate limit is hit
|
| 104 |
+
from dataclasses import (
|
| 105 |
+
dataclass,
|
| 106 |
+
field,
|
| 107 |
+
) # for storing API inputs, outputs, and metadata
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
async def process_api_requests_from_file(
|
| 111 |
+
requests_filepath: str,
|
| 112 |
+
save_filepath: str,
|
| 113 |
+
request_url: str,
|
| 114 |
+
api_key: str,
|
| 115 |
+
max_requests_per_minute: float,
|
| 116 |
+
max_tokens_per_minute: float,
|
| 117 |
+
token_encoding_name: str,
|
| 118 |
+
max_attempts: int,
|
| 119 |
+
logging_level: int,
|
| 120 |
+
):
|
| 121 |
+
"""Processes API requests in parallel, throttling to stay under rate limits."""
|
| 122 |
+
# constants
|
| 123 |
+
seconds_to_pause_after_rate_limit_error = 15
|
| 124 |
+
seconds_to_sleep_each_loop = (
|
| 125 |
+
0.001 # 1 ms limits max throughput to 1,000 requests per second
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# initialize logging
|
| 129 |
+
logging.basicConfig(level=logging_level)
|
| 130 |
+
logging.debug(f"Logging initialized at level {logging_level}")
|
| 131 |
+
|
| 132 |
+
# infer API endpoint and construct request header
|
| 133 |
+
api_endpoint = api_endpoint_from_url(request_url)
|
| 134 |
+
request_header = {"Authorization": f"Bearer {api_key}"}
|
| 135 |
+
# use api-key header for Azure deployments
|
| 136 |
+
if '/deployments' in request_url:
|
| 137 |
+
request_header = {"api-key": f"{api_key}"}
|
| 138 |
+
|
| 139 |
+
# initialize trackers
|
| 140 |
+
queue_of_requests_to_retry = asyncio.Queue()
|
| 141 |
+
task_id_generator = (
|
| 142 |
+
task_id_generator_function()
|
| 143 |
+
) # generates integer IDs of 0, 1, 2, ...
|
| 144 |
+
status_tracker = (
|
| 145 |
+
StatusTracker()
|
| 146 |
+
) # single instance to track a collection of variables
|
| 147 |
+
next_request = None # variable to hold the next request to call
|
| 148 |
+
|
| 149 |
+
# initialize available capacity counts
|
| 150 |
+
available_request_capacity = max_requests_per_minute
|
| 151 |
+
available_token_capacity = max_tokens_per_minute
|
| 152 |
+
last_update_time = time.time()
|
| 153 |
+
|
| 154 |
+
# initialize flags
|
| 155 |
+
file_not_finished = True # after file is empty, we'll skip reading it
|
| 156 |
+
logging.debug(f"Initialization complete.")
|
| 157 |
+
|
| 158 |
+
# initialize file reading
|
| 159 |
+
with open(requests_filepath) as file:
|
| 160 |
+
# `requests` will provide requests one at a time
|
| 161 |
+
requests = file.__iter__()
|
| 162 |
+
logging.debug(f"File opened. Entering main loop")
|
| 163 |
+
async with aiohttp.ClientSession() as session: # Initialize ClientSession here
|
| 164 |
+
while True:
|
| 165 |
+
# get next request (if one is not already waiting for capacity)
|
| 166 |
+
if next_request is None:
|
| 167 |
+
if not queue_of_requests_to_retry.empty():
|
| 168 |
+
next_request = queue_of_requests_to_retry.get_nowait()
|
| 169 |
+
logging.debug(
|
| 170 |
+
f"Retrying request {next_request.task_id}: {next_request}"
|
| 171 |
+
)
|
| 172 |
+
elif file_not_finished:
|
| 173 |
+
try:
|
| 174 |
+
# get new request
|
| 175 |
+
request_json = json.loads(next(requests))
|
| 176 |
+
next_request = APIRequest(
|
| 177 |
+
task_id=next(task_id_generator),
|
| 178 |
+
request_json=request_json,
|
| 179 |
+
token_consumption=num_tokens_consumed_from_request(
|
| 180 |
+
request_json, api_endpoint, token_encoding_name
|
| 181 |
+
),
|
| 182 |
+
attempts_left=max_attempts,
|
| 183 |
+
metadata=request_json.pop("metadata", None),
|
| 184 |
+
)
|
| 185 |
+
status_tracker.num_tasks_started += 1
|
| 186 |
+
status_tracker.num_tasks_in_progress += 1
|
| 187 |
+
logging.debug(
|
| 188 |
+
f"Reading request {next_request.task_id}: {next_request}"
|
| 189 |
+
)
|
| 190 |
+
except StopIteration:
|
| 191 |
+
# if file runs out, set flag to stop reading it
|
| 192 |
+
logging.debug("Read file exhausted")
|
| 193 |
+
file_not_finished = False
|
| 194 |
+
|
| 195 |
+
# update available capacity
|
| 196 |
+
current_time = time.time()
|
| 197 |
+
seconds_since_update = current_time - last_update_time
|
| 198 |
+
available_request_capacity = min(
|
| 199 |
+
available_request_capacity
|
| 200 |
+
+ max_requests_per_minute * seconds_since_update / 60.0,
|
| 201 |
+
max_requests_per_minute,
|
| 202 |
+
)
|
| 203 |
+
available_token_capacity = min(
|
| 204 |
+
available_token_capacity
|
| 205 |
+
+ max_tokens_per_minute * seconds_since_update / 60.0,
|
| 206 |
+
max_tokens_per_minute,
|
| 207 |
+
)
|
| 208 |
+
last_update_time = current_time
|
| 209 |
+
|
| 210 |
+
# if enough capacity available, call API
|
| 211 |
+
if next_request:
|
| 212 |
+
next_request_tokens = next_request.token_consumption
|
| 213 |
+
if (
|
| 214 |
+
available_request_capacity >= 1
|
| 215 |
+
and available_token_capacity >= next_request_tokens
|
| 216 |
+
):
|
| 217 |
+
# update counters
|
| 218 |
+
available_request_capacity -= 1
|
| 219 |
+
available_token_capacity -= next_request_tokens
|
| 220 |
+
next_request.attempts_left -= 1
|
| 221 |
+
|
| 222 |
+
# call API
|
| 223 |
+
asyncio.create_task(
|
| 224 |
+
next_request.call_api(
|
| 225 |
+
session=session,
|
| 226 |
+
request_url=request_url,
|
| 227 |
+
request_header=request_header,
|
| 228 |
+
retry_queue=queue_of_requests_to_retry,
|
| 229 |
+
save_filepath=save_filepath,
|
| 230 |
+
status_tracker=status_tracker,
|
| 231 |
+
)
|
| 232 |
+
)
|
| 233 |
+
next_request = None # reset next_request to empty
|
| 234 |
+
|
| 235 |
+
# if all tasks are finished, break
|
| 236 |
+
if status_tracker.num_tasks_in_progress == 0:
|
| 237 |
+
break
|
| 238 |
+
|
| 239 |
+
# main loop sleeps briefly so concurrent tasks can run
|
| 240 |
+
await asyncio.sleep(seconds_to_sleep_each_loop)
|
| 241 |
+
|
| 242 |
+
# if a rate limit error was hit recently, pause to cool down
|
| 243 |
+
seconds_since_rate_limit_error = (
|
| 244 |
+
time.time() - status_tracker.time_of_last_rate_limit_error
|
| 245 |
+
)
|
| 246 |
+
if (
|
| 247 |
+
seconds_since_rate_limit_error
|
| 248 |
+
< seconds_to_pause_after_rate_limit_error
|
| 249 |
+
):
|
| 250 |
+
remaining_seconds_to_pause = (
|
| 251 |
+
seconds_to_pause_after_rate_limit_error
|
| 252 |
+
- seconds_since_rate_limit_error
|
| 253 |
+
)
|
| 254 |
+
await asyncio.sleep(remaining_seconds_to_pause)
|
| 255 |
+
# ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago
|
| 256 |
+
logging.warning(
|
| 257 |
+
f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
# after finishing, log final status
|
| 261 |
+
logging.info(
|
| 262 |
+
f"""Parallel processing complete. Results saved to {save_filepath}"""
|
| 263 |
+
)
|
| 264 |
+
if status_tracker.num_tasks_failed > 0:
|
| 265 |
+
logging.warning(
|
| 266 |
+
f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. Errors logged to {save_filepath}."
|
| 267 |
+
)
|
| 268 |
+
if status_tracker.num_rate_limit_errors > 0:
|
| 269 |
+
logging.warning(
|
| 270 |
+
f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
# dataclasses
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
@dataclass
|
| 278 |
+
class StatusTracker:
|
| 279 |
+
"""Stores metadata about the script's progress. Only one instance is created."""
|
| 280 |
+
|
| 281 |
+
num_tasks_started: int = 0
|
| 282 |
+
num_tasks_in_progress: int = 0 # script ends when this reaches 0
|
| 283 |
+
num_tasks_succeeded: int = 0
|
| 284 |
+
num_tasks_failed: int = 0
|
| 285 |
+
num_rate_limit_errors: int = 0
|
| 286 |
+
num_api_errors: int = 0 # excluding rate limit errors, counted above
|
| 287 |
+
num_other_errors: int = 0
|
| 288 |
+
time_of_last_rate_limit_error: int = 0 # used to cool off after hitting rate limits
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
@dataclass
|
| 292 |
+
class APIRequest:
|
| 293 |
+
"""Stores an API request's inputs, outputs, and other metadata. Contains a method to make an API call."""
|
| 294 |
+
|
| 295 |
+
task_id: int
|
| 296 |
+
request_json: dict
|
| 297 |
+
token_consumption: int
|
| 298 |
+
attempts_left: int
|
| 299 |
+
metadata: dict
|
| 300 |
+
result: list = field(default_factory=list)
|
| 301 |
+
|
| 302 |
+
async def call_api(
|
| 303 |
+
self,
|
| 304 |
+
session: aiohttp.ClientSession,
|
| 305 |
+
request_url: str,
|
| 306 |
+
request_header: dict,
|
| 307 |
+
retry_queue: asyncio.Queue,
|
| 308 |
+
save_filepath: str,
|
| 309 |
+
status_tracker: StatusTracker,
|
| 310 |
+
):
|
| 311 |
+
"""Calls the OpenAI API and saves results."""
|
| 312 |
+
logging.info(f"Starting request #{self.task_id}")
|
| 313 |
+
error = None
|
| 314 |
+
try:
|
| 315 |
+
async with session.post(
|
| 316 |
+
url=request_url, headers=request_header, json=self.request_json
|
| 317 |
+
) as response:
|
| 318 |
+
response = await response.json()
|
| 319 |
+
if "error" in response:
|
| 320 |
+
logging.warning(
|
| 321 |
+
f"Request {self.task_id} failed with error {response['error']}"
|
| 322 |
+
)
|
| 323 |
+
status_tracker.num_api_errors += 1
|
| 324 |
+
error = response
|
| 325 |
+
if "Rate limit" in response["error"].get("message", ""):
|
| 326 |
+
status_tracker.time_of_last_rate_limit_error = time.time()
|
| 327 |
+
status_tracker.num_rate_limit_errors += 1
|
| 328 |
+
status_tracker.num_api_errors -= (
|
| 329 |
+
1 # rate limit errors are counted separately
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
except (
|
| 333 |
+
Exception
|
| 334 |
+
) as e: # catching naked exceptions is bad practice, but in this case we'll log & save them
|
| 335 |
+
logging.warning(f"Request {self.task_id} failed with Exception {e}")
|
| 336 |
+
status_tracker.num_other_errors += 1
|
| 337 |
+
error = e
|
| 338 |
+
if error:
|
| 339 |
+
self.result.append(error)
|
| 340 |
+
if self.attempts_left:
|
| 341 |
+
retry_queue.put_nowait(self)
|
| 342 |
+
else:
|
| 343 |
+
logging.error(
|
| 344 |
+
f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}"
|
| 345 |
+
)
|
| 346 |
+
data = (
|
| 347 |
+
[self.request_json, [str(e) for e in self.result], self.metadata]
|
| 348 |
+
if self.metadata
|
| 349 |
+
else [self.request_json, [str(e) for e in self.result]]
|
| 350 |
+
)
|
| 351 |
+
append_to_jsonl(data, save_filepath)
|
| 352 |
+
status_tracker.num_tasks_in_progress -= 1
|
| 353 |
+
status_tracker.num_tasks_failed += 1
|
| 354 |
+
else:
|
| 355 |
+
data = (
|
| 356 |
+
[self.request_json, response, self.metadata]
|
| 357 |
+
if self.metadata
|
| 358 |
+
else [self.request_json, response]
|
| 359 |
+
)
|
| 360 |
+
append_to_jsonl(data, save_filepath)
|
| 361 |
+
status_tracker.num_tasks_in_progress -= 1
|
| 362 |
+
status_tracker.num_tasks_succeeded += 1
|
| 363 |
+
logging.debug(f"Request {self.task_id} saved to {save_filepath}")
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# functions
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def api_endpoint_from_url(request_url):
|
| 370 |
+
"""Extract the API endpoint from the request URL."""
|
| 371 |
+
match = re.search("^https://[^/]+/v\\d+/(.+)$", request_url)
|
| 372 |
+
if match is None:
|
| 373 |
+
# for Azure OpenAI deployment urls
|
| 374 |
+
match = re.search(r"^https://[^/]+/openai/deployments/[^/]+/(.+?)(\?|$)", request_url)
|
| 375 |
+
return match[1]
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def append_to_jsonl(data, filename: str) -> None:
|
| 379 |
+
"""Append a json payload to the end of a jsonl file."""
|
| 380 |
+
json_string = json.dumps(data)
|
| 381 |
+
with open(filename, "a") as f:
|
| 382 |
+
f.write(json_string + "\n")
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def num_tokens_consumed_from_request(
|
| 386 |
+
request_json: dict,
|
| 387 |
+
api_endpoint: str,
|
| 388 |
+
token_encoding_name: str,
|
| 389 |
+
):
|
| 390 |
+
"""Count the number of tokens in the request. Only supports completion and embedding requests."""
|
| 391 |
+
encoding = tiktoken.get_encoding(token_encoding_name)
|
| 392 |
+
# if completions request, tokens = prompt + n * max_tokens
|
| 393 |
+
if api_endpoint.endswith("completions"):
|
| 394 |
+
max_tokens = request_json.get("max_tokens", 15)
|
| 395 |
+
n = request_json.get("n", 1)
|
| 396 |
+
completion_tokens = n * max_tokens
|
| 397 |
+
|
| 398 |
+
# chat completions
|
| 399 |
+
if api_endpoint.startswith("chat/"):
|
| 400 |
+
num_tokens = 0
|
| 401 |
+
for message in request_json["messages"]:
|
| 402 |
+
num_tokens += 4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
|
| 403 |
+
for key, value in message.items():
|
| 404 |
+
num_tokens += len(encoding.encode(value))
|
| 405 |
+
if key == "name": # if there's a name, the role is omitted
|
| 406 |
+
num_tokens -= 1 # role is always required and always 1 token
|
| 407 |
+
num_tokens += 2 # every reply is primed with <im_start>assistant
|
| 408 |
+
return num_tokens + completion_tokens
|
| 409 |
+
# normal completions
|
| 410 |
+
else:
|
| 411 |
+
prompt = request_json["prompt"]
|
| 412 |
+
if isinstance(prompt, str): # single prompt
|
| 413 |
+
prompt_tokens = len(encoding.encode(prompt))
|
| 414 |
+
num_tokens = prompt_tokens + completion_tokens
|
| 415 |
+
return num_tokens
|
| 416 |
+
elif isinstance(prompt, list): # multiple prompts
|
| 417 |
+
prompt_tokens = sum([len(encoding.encode(p)) for p in prompt])
|
| 418 |
+
num_tokens = prompt_tokens + completion_tokens * len(prompt)
|
| 419 |
+
return num_tokens
|
| 420 |
+
else:
|
| 421 |
+
raise TypeError(
|
| 422 |
+
'Expecting either string or list of strings for "prompt" field in completion request'
|
| 423 |
+
)
|
| 424 |
+
# if embeddings request, tokens = input tokens
|
| 425 |
+
elif api_endpoint == "embeddings":
|
| 426 |
+
input = request_json["input"]
|
| 427 |
+
if isinstance(input, str): # single input
|
| 428 |
+
num_tokens = len(encoding.encode(input))
|
| 429 |
+
return num_tokens
|
| 430 |
+
elif isinstance(input, list): # multiple inputs
|
| 431 |
+
num_tokens = sum([len(encoding.encode(i)) for i in input])
|
| 432 |
+
return num_tokens
|
| 433 |
+
else:
|
| 434 |
+
raise TypeError(
|
| 435 |
+
'Expecting either string or list of strings for "inputs" field in embedding request'
|
| 436 |
+
)
|
| 437 |
+
# more logic needed to support other API calls (e.g., edits, inserts, DALL-E)
|
| 438 |
+
else:
|
| 439 |
+
raise NotImplementedError(
|
| 440 |
+
f'API endpoint "{api_endpoint}" not implemented in this script'
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def task_id_generator_function():
|
| 445 |
+
"""Generate integers 0, 1, 2, and so on."""
|
| 446 |
+
task_id = 0
|
| 447 |
+
while True:
|
| 448 |
+
yield task_id
|
| 449 |
+
task_id += 1
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
# run script
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
if __name__ == "__main__":
|
| 456 |
+
# parse command line arguments
|
| 457 |
+
parser = argparse.ArgumentParser()
|
| 458 |
+
parser.add_argument("--requests_filepath")
|
| 459 |
+
parser.add_argument("--save_filepath", default=None)
|
| 460 |
+
parser.add_argument("--request_url", default="https://api.openai.com/v1/embeddings")
|
| 461 |
+
parser.add_argument("--api_key", default=os.getenv("OPENAI_API_KEY"))
|
| 462 |
+
parser.add_argument("--max_requests_per_minute", type=int, default=3_000 * 0.5)
|
| 463 |
+
parser.add_argument("--max_tokens_per_minute", type=int, default=250_000 * 0.5)
|
| 464 |
+
parser.add_argument("--token_encoding_name", default="cl100k_base")
|
| 465 |
+
parser.add_argument("--max_attempts", type=int, default=5)
|
| 466 |
+
parser.add_argument("--logging_level", default=logging.INFO)
|
| 467 |
+
args = parser.parse_args()
|
| 468 |
+
|
| 469 |
+
if args.save_filepath is None:
|
| 470 |
+
args.save_filepath = args.requests_filepath.replace(".jsonl", "_results.jsonl")
|
| 471 |
+
|
| 472 |
+
# run script
|
| 473 |
+
asyncio.run(
|
| 474 |
+
process_api_requests_from_file(
|
| 475 |
+
requests_filepath=args.requests_filepath,
|
| 476 |
+
save_filepath=args.save_filepath,
|
| 477 |
+
request_url=args.request_url,
|
| 478 |
+
api_key=args.api_key,
|
| 479 |
+
max_requests_per_minute=float(args.max_requests_per_minute),
|
| 480 |
+
max_tokens_per_minute=float(args.max_tokens_per_minute),
|
| 481 |
+
token_encoding_name=args.token_encoding_name,
|
| 482 |
+
max_attempts=int(args.max_attempts),
|
| 483 |
+
logging_level=int(args.logging_level),
|
| 484 |
+
)
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
"""
|
| 489 |
+
APPENDIX
|
| 490 |
+
|
| 491 |
+
The example requests file at openai-cookbook/examples/data/example_requests_to_parallel_process.jsonl contains 10,000 requests to text-embedding-ada-002.
|
| 492 |
+
|
| 493 |
+
It was generated with the following code:
|
| 494 |
+
|
| 495 |
+
```python
|
| 496 |
+
import json
|
| 497 |
+
|
| 498 |
+
filename = "data/example_requests_to_parallel_process.jsonl"
|
| 499 |
+
n_requests = 10_000
|
| 500 |
+
jobs = [{"model": "text-embedding-ada-002", "input": str(x) + "\n"} for x in range(n_requests)]
|
| 501 |
+
with open(filename, "w") as f:
|
| 502 |
+
for job in jobs:
|
| 503 |
+
json_string = json.dumps(job)
|
| 504 |
+
f.write(json_string + "\n")
|
| 505 |
+
```
|
| 506 |
+
|
| 507 |
+
As with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically).
|
| 508 |
+
"""
|
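For reference, a minimal sketch of preparing a requests file in the JSONL format the script above expects (one JSON object per line, with an optional metadata field that is carried through to the results file) and pointing the processor at the chat completions endpoint. The file paths, questions, and model name are illustrative, not part of the commit:

```python
import json

# Hypothetical input path and model; adjust to your own setup.
requests_path = "data/pest_questions_requests.jsonl"

questions = [
    "What is the economic threshold level for brown planthopper in rice?",
    "How can fall armyworm be managed without synthetic pesticides?",
]

with open(requests_path, "w") as f:
    for i, question in enumerate(questions):
        job = {
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": question}],
            "max_tokens": 300,
            # The processor pops this field out and appends it to each result line.
            "metadata": {"row_id": i},
        }
        f.write(json.dumps(job) + "\n")

# Then run the processor against the chat completions endpoint, e.g.:
#   python api_request_parallel_processor.py \
#       --requests_filepath data/pest_questions_requests.jsonl \
#       --request_url https://api.openai.com/v1/chat/completions \
#       --max_requests_per_minute 500 --max_tokens_per_minute 60000
```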
api_request_parallel_processor_anthropic.py
ADDED
@@ -0,0 +1,473 @@
| 1 |
+
print("IMPORTED")
|
| 2 |
+
"""
|
| 3 |
+
API REQUEST PARALLEL PROCESSOR
|
| 4 |
+
|
| 5 |
+
Using the OpenAI API to process lots of text quickly takes some care.
|
| 6 |
+
If you trickle in a million API requests one by one, they'll take days to complete.
|
| 7 |
+
If you flood a million API requests in parallel, they'll exceed the rate limits and fail with errors.
|
| 8 |
+
To maximize throughput, parallel requests need to be throttled to stay under rate limits.
|
| 9 |
+
|
| 10 |
+
This script parallelizes requests to the OpenAI API while throttling to stay under rate limits.
|
| 11 |
+
|
| 12 |
+
Features:
|
| 13 |
+
- Streams requests from file, to avoid running out of memory for giant jobs
|
| 14 |
+
- Makes requests concurrently, to maximize throughput
|
| 15 |
+
- Throttles request and token usage, to stay under rate limits
|
| 16 |
+
- Retries failed requests up to {max_attempts} times, to avoid missing data
|
| 17 |
+
- Logs errors, to diagnose problems with requests
|
| 18 |
+
|
| 19 |
+
Example command to call script:
|
| 20 |
+
```
|
| 21 |
+
python examples/api_request_parallel_processor.py \
|
| 22 |
+
--requests_filepath examples/data/example_requests_to_parallel_process.jsonl \
|
| 23 |
+
--save_filepath examples/data/example_requests_to_parallel_process_results.jsonl \
|
| 24 |
+
--request_url https://api.openai.com/v1/embeddings \
|
| 25 |
+
--max_requests_per_minute 1500 \
|
| 26 |
+
--max_tokens_per_minute 6250000 \
|
| 27 |
+
--token_encoding_name cl100k_base \
|
| 28 |
+
--max_attempts 5 \
|
| 29 |
+
--logging_level 20
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
Inputs:
|
| 33 |
+
- requests_filepath : str
|
| 34 |
+
- path to the file containing the requests to be processed
|
| 35 |
+
- file should be a jsonl file, where each line is a json object with API parameters and an optional metadata field
|
| 36 |
+
- e.g., {"model": "text-embedding-ada-002", "input": "embed me", "metadata": {"row_id": 1}}
|
| 37 |
+
- as with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically)
|
| 38 |
+
- an example file is provided at examples/data/example_requests_to_parallel_process.jsonl
|
| 39 |
+
- the code to generate the example file is appended to the bottom of this script
|
| 40 |
+
- save_filepath : str, optional
|
| 41 |
+
- path to the file where the results will be saved
|
| 42 |
+
- file will be a jsonl file, where each line is an array with the original request plus the API response
|
| 43 |
+
- e.g., [{"model": "text-embedding-ada-002", "input": "embed me"}, {...}]
|
| 44 |
+
- if omitted, results will be saved to {requests_filename}_results.jsonl
|
| 45 |
+
- request_url : str, optional
|
| 46 |
+
- URL of the API endpoint to call
|
| 47 |
+
- if omitted, will default to "https://api.openai.com/v1/embeddings"
|
| 48 |
+
- api_key : str, optional
|
| 49 |
+
- API key to use
|
| 50 |
+
- if omitted, the script will attempt to read it from an environment variable {os.getenv("OPENAI_API_KEY")}
|
| 51 |
+
- max_requests_per_minute : float, optional
|
| 52 |
+
- target number of requests to make per minute (will make less if limited by tokens)
|
| 53 |
+
- leave headroom by setting this to 50% or 75% of your limit
|
| 54 |
+
- if requests are limiting you, try batching multiple embeddings or completions into one request
|
| 55 |
+
- if omitted, will default to 1,500
|
| 56 |
+
- max_tokens_per_minute : float, optional
|
| 57 |
+
- target number of tokens to use per minute (will use less if limited by requests)
|
| 58 |
+
- leave headroom by setting this to 50% or 75% of your limit
|
| 59 |
+
- if omitted, will default to 125,000
|
| 60 |
+
- token_encoding_name : str, optional
|
| 61 |
+
- name of the token encoding used, as defined in the `tiktoken` package
|
| 62 |
+
- if omitted, will default to "cl100k_base" (used by `text-embedding-ada-002`)
|
| 63 |
+
- max_attempts : int, optional
|
| 64 |
+
- number of times to retry a failed request before giving up
|
| 65 |
+
- if omitted, will default to 5
|
| 66 |
+
- logging_level : int, optional
|
| 67 |
+
- level of logging to use; higher numbers will log fewer messages
|
| 68 |
+
- 40 = ERROR; will log only when requests fail after all retries
|
| 69 |
+
- 30 = WARNING; will log when requests hit rate limits or other errors
|
| 70 |
+
- 20 = INFO; will log when requests start and the status at finish
|
| 71 |
+
- 10 = DEBUG; will log various things as the loop runs to see when they occur
|
| 72 |
+
- if omitted, will default to 20 (INFO).
|
| 73 |
+
|
| 74 |
+
The script is structured as follows:
|
| 75 |
+
- Imports
|
| 76 |
+
- Define main()
|
| 77 |
+
- Initialize things
|
| 78 |
+
- In main loop:
|
| 79 |
+
- Get next request if one is not already waiting for capacity
|
| 80 |
+
- Update available token & request capacity
|
| 81 |
+
- If enough capacity available, call API
|
| 82 |
+
- The loop pauses if a rate limit error is hit
|
| 83 |
+
- The loop breaks when no tasks remain
|
| 84 |
+
- Define dataclasses
|
| 85 |
+
- StatusTracker (stores script metadata counters; only one instance is created)
|
| 86 |
+
- APIRequest (stores API inputs, outputs, metadata; one method to call API)
|
| 87 |
+
- Define functions
|
| 88 |
+
- api_endpoint_from_url (extracts API endpoint from request URL)
|
| 89 |
+
- append_to_jsonl (writes to results file)
|
| 90 |
+
- num_tokens_consumed_from_request (bigger function to infer token usage from request)
|
| 91 |
+
- task_id_generator_function (yields 0, 1, 2, ...)
|
| 92 |
+
- Run main()
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
# imports
|
| 96 |
+
import aiohttp # for making API calls concurrently
|
| 97 |
+
import argparse # for running script from command line
|
| 98 |
+
import asyncio # for running API calls concurrently
|
| 99 |
+
import json # for saving results to a jsonl file
|
| 100 |
+
import logging # for logging rate limit warnings and other messages
|
| 101 |
+
import os # for reading API key
|
| 102 |
+
import re # for matching endpoint from request URL
|
| 103 |
+
import tiktoken # for counting tokens
|
| 104 |
+
import time # for sleeping after rate limit is hit
|
| 105 |
+
from dataclasses import (
|
| 106 |
+
dataclass,
|
| 107 |
+
field,
|
| 108 |
+
) # for storing API inputs, outputs, and metadata
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
async def process_api_requests_from_file(
|
| 112 |
+
requests_filepath: str,
|
| 113 |
+
save_filepath: str,
|
| 114 |
+
request_url: str,
|
| 115 |
+
api_key: str,
|
| 116 |
+
max_requests_per_minute: float,
|
| 117 |
+
max_tokens_per_minute: float,
|
| 118 |
+
token_encoding_name: str,
|
| 119 |
+
max_attempts: int,
|
| 120 |
+
logging_level: int,
|
| 121 |
+
):
|
| 122 |
+
"""Processes API requests in parallel, throttling to stay under rate limits."""
|
| 123 |
+
# constants
|
| 124 |
+
seconds_to_pause_after_rate_limit_error = 15
|
| 125 |
+
seconds_to_sleep_each_loop = (
|
| 126 |
+
0.001 # 1 ms limits max throughput to 1,000 requests per second
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
# initialize logging
|
| 130 |
+
logging.basicConfig(level=logging_level)
|
| 131 |
+
logging.debug(f"Logging initialized at level {logging_level}")
|
| 132 |
+
|
| 133 |
+
# infer API endpoint and construct request header
|
| 134 |
+
api_endpoint = api_endpoint_from_url(request_url)
|
| 135 |
+
# request_header = {"Authorization": f"Bearer {api_key}"}
|
| 136 |
+
request_header = {
|
| 137 |
+
"x-api-key": api_key,
|
| 138 |
+
"anthropic-version": "2023-06-01",
|
| 139 |
+
"content-type": "application/json",
|
| 140 |
+
}
|
| 141 |
+
# use api-key header for Azure deployments
|
| 142 |
+
if '/deployments' in request_url:
|
| 143 |
+
request_header = {"api-key": f"{api_key}"}
|
| 144 |
+
|
| 145 |
+
# initialize trackers
|
| 146 |
+
queue_of_requests_to_retry = asyncio.Queue()
|
| 147 |
+
task_id_generator = (
|
| 148 |
+
task_id_generator_function()
|
| 149 |
+
) # generates integer IDs of 0, 1, 2, ...
|
| 150 |
+
status_tracker = (
|
| 151 |
+
StatusTracker()
|
| 152 |
+
) # single instance to track a collection of variables
|
| 153 |
+
next_request = None # variable to hold the next request to call
|
| 154 |
+
|
| 155 |
+
# initialize available capacity counts
|
| 156 |
+
available_request_capacity = max_requests_per_minute
|
| 157 |
+
available_token_capacity = max_tokens_per_minute
|
| 158 |
+
last_update_time = time.time()
|
| 159 |
+
|
| 160 |
+
# initialize flags
|
| 161 |
+
file_not_finished = True # after file is empty, we'll skip reading it
|
| 162 |
+
logging.debug(f"Initialization complete.")
|
| 163 |
+
|
| 164 |
+
# initialize file reading
|
| 165 |
+
with open(requests_filepath) as file:
|
| 166 |
+
# `requests` will provide requests one at a time
|
| 167 |
+
requests = file.__iter__()
|
| 168 |
+
logging.debug(f"File opened. Entering main loop")
|
| 169 |
+
async with aiohttp.ClientSession() as session: # Initialize ClientSession here
|
| 170 |
+
while True:
|
| 171 |
+
# get next request (if one is not already waiting for capacity)
|
| 172 |
+
if next_request is None:
|
| 173 |
+
if not queue_of_requests_to_retry.empty():
|
| 174 |
+
next_request = queue_of_requests_to_retry.get_nowait()
|
| 175 |
+
logging.debug(
|
| 176 |
+
f"Retrying request {next_request.task_id}: {next_request}"
|
| 177 |
+
)
|
| 178 |
+
elif file_not_finished:
|
| 179 |
+
try:
|
| 180 |
+
# get new request
|
| 181 |
+
request_json = json.loads(next(requests))
|
| 182 |
+
next_request = APIRequest(
|
| 183 |
+
task_id=next(task_id_generator),
|
| 184 |
+
request_json=request_json,
|
| 185 |
+
token_consumption=num_tokens_consumed_from_request(
|
| 186 |
+
request_json, api_endpoint, token_encoding_name
|
| 187 |
+
),
|
| 188 |
+
attempts_left=max_attempts,
|
| 189 |
+
metadata=request_json.pop("metadata", None),
|
| 190 |
+
)
|
| 191 |
+
status_tracker.num_tasks_started += 1
|
| 192 |
+
status_tracker.num_tasks_in_progress += 1
|
| 193 |
+
logging.debug(
|
| 194 |
+
f"Reading request {next_request.task_id}: {next_request}"
|
| 195 |
+
)
|
| 196 |
+
except StopIteration:
|
| 197 |
+
# if file runs out, set flag to stop reading it
|
| 198 |
+
logging.debug("Read file exhausted")
|
| 199 |
+
file_not_finished = False
|
| 200 |
+
|
| 201 |
+
# update available capacity
|
| 202 |
+
current_time = time.time()
|
| 203 |
+
seconds_since_update = current_time - last_update_time
|
| 204 |
+
available_request_capacity = min(
|
| 205 |
+
available_request_capacity
|
| 206 |
+
+ max_requests_per_minute * seconds_since_update / 60.0,
|
| 207 |
+
max_requests_per_minute,
|
| 208 |
+
)
|
| 209 |
+
available_token_capacity = min(
|
| 210 |
+
available_token_capacity
|
| 211 |
+
+ max_tokens_per_minute * seconds_since_update / 60.0,
|
| 212 |
+
max_tokens_per_minute,
|
| 213 |
+
)
|
| 214 |
+
last_update_time = current_time
|
| 215 |
+
|
| 216 |
+
# if enough capacity available, call API
|
| 217 |
+
if next_request:
|
| 218 |
+
next_request_tokens = next_request.token_consumption
|
| 219 |
+
if (
|
| 220 |
+
available_request_capacity >= 1
|
| 221 |
+
and available_token_capacity >= next_request_tokens
|
| 222 |
+
):
|
| 223 |
+
# update counters
|
| 224 |
+
available_request_capacity -= 1
|
| 225 |
+
available_token_capacity -= next_request_tokens
|
| 226 |
+
next_request.attempts_left -= 1
|
| 227 |
+
|
| 228 |
+
# call API
|
| 229 |
+
asyncio.create_task(
|
| 230 |
+
next_request.call_api(
|
| 231 |
+
session=session,
|
| 232 |
+
request_url=request_url,
|
| 233 |
+
request_header=request_header,
|
| 234 |
+
retry_queue=queue_of_requests_to_retry,
|
| 235 |
+
save_filepath=save_filepath,
|
| 236 |
+
status_tracker=status_tracker,
|
| 237 |
+
)
|
| 238 |
+
)
|
| 239 |
+
next_request = None # reset next_request to empty
|
| 240 |
+
|
| 241 |
+
# if all tasks are finished, break
|
| 242 |
+
if status_tracker.num_tasks_in_progress == 0:
|
| 243 |
+
break
|
| 244 |
+
|
| 245 |
+
# main loop sleeps briefly so concurrent tasks can run
|
| 246 |
+
await asyncio.sleep(seconds_to_sleep_each_loop)
|
| 247 |
+
|
| 248 |
+
# if a rate limit error was hit recently, pause to cool down
|
| 249 |
+
seconds_since_rate_limit_error = (
|
| 250 |
+
time.time() - status_tracker.time_of_last_rate_limit_error
|
| 251 |
+
)
|
| 252 |
+
if (
|
| 253 |
+
seconds_since_rate_limit_error
|
| 254 |
+
< seconds_to_pause_after_rate_limit_error
|
| 255 |
+
):
|
| 256 |
+
remaining_seconds_to_pause = (
|
| 257 |
+
seconds_to_pause_after_rate_limit_error
|
| 258 |
+
- seconds_since_rate_limit_error
|
| 259 |
+
)
|
| 260 |
+
await asyncio.sleep(remaining_seconds_to_pause)
|
| 261 |
+
# ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago
|
| 262 |
+
logging.warning(
|
| 263 |
+
f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# after finishing, log final status
|
| 267 |
+
logging.info(
|
| 268 |
+
f"""Parallel processing complete. Results saved to {save_filepath}"""
|
| 269 |
+
)
|
| 270 |
+
if status_tracker.num_tasks_failed > 0:
|
| 271 |
+
logging.warning(
|
| 272 |
+
f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. Errors logged to {save_filepath}."
|
| 273 |
+
)
|
| 274 |
+
if status_tracker.num_rate_limit_errors > 0:
|
| 275 |
+
logging.warning(
|
| 276 |
+
f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
# dataclasses
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
@dataclass
|
| 284 |
+
class StatusTracker:
|
| 285 |
+
"""Stores metadata about the script's progress. Only one instance is created."""
|
| 286 |
+
|
| 287 |
+
num_tasks_started: int = 0
|
| 288 |
+
num_tasks_in_progress: int = 0 # script ends when this reaches 0
|
| 289 |
+
num_tasks_succeeded: int = 0
|
| 290 |
+
num_tasks_failed: int = 0
|
| 291 |
+
num_rate_limit_errors: int = 0
|
| 292 |
+
num_api_errors: int = 0 # excluding rate limit errors, counted above
|
| 293 |
+
num_other_errors: int = 0
|
| 294 |
+
time_of_last_rate_limit_error: int = 0 # used to cool off after hitting rate limits
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
@dataclass
|
| 298 |
+
class APIRequest:
|
| 299 |
+
"""Stores an API request's inputs, outputs, and other metadata. Contains a method to make an API call."""
|
| 300 |
+
|
| 301 |
+
task_id: int
|
| 302 |
+
request_json: dict
|
| 303 |
+
token_consumption: int
|
| 304 |
+
attempts_left: int
|
| 305 |
+
metadata: dict
|
| 306 |
+
result: list = field(default_factory=list)
|
| 307 |
+
|
| 308 |
+
async def call_api(
|
| 309 |
+
self,
|
| 310 |
+
session: aiohttp.ClientSession,
|
| 311 |
+
request_url: str,
|
| 312 |
+
request_header: dict,
|
| 313 |
+
retry_queue: asyncio.Queue,
|
| 314 |
+
save_filepath: str,
|
| 315 |
+
status_tracker: StatusTracker,
|
| 316 |
+
):
|
| 317 |
+
"""Calls the OpenAI API and saves results."""
|
| 318 |
+
logging.info(f"Starting request #{self.task_id}")
|
| 319 |
+
error = None
|
| 320 |
+
try:
|
| 321 |
+
async with session.post(
|
| 322 |
+
url=request_url, headers=request_header, json=self.request_json
|
| 323 |
+
) as response:
|
| 324 |
+
response = await response.json()
|
| 325 |
+
if "error" in response:
|
| 326 |
+
logging.warning(
|
| 327 |
+
f"Request {self.task_id} failed with error {response['error']}"
|
| 328 |
+
)
|
| 329 |
+
status_tracker.num_api_errors += 1
|
| 330 |
+
error = response
|
| 331 |
+
if "Rate limit" in response["error"].get("message", ""):
|
| 332 |
+
status_tracker.time_of_last_rate_limit_error = time.time()
|
| 333 |
+
status_tracker.num_rate_limit_errors += 1
|
| 334 |
+
status_tracker.num_api_errors -= (
|
| 335 |
+
1 # rate limit errors are counted separately
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
except (
|
| 339 |
+
Exception
|
| 340 |
+
) as e: # catching naked exceptions is bad practice, but in this case we'll log & save them
|
| 341 |
+
logging.warning(f"Request {self.task_id} failed with Exception {e}")
|
| 342 |
+
status_tracker.num_other_errors += 1
|
| 343 |
+
error = e
|
| 344 |
+
if error:
|
| 345 |
+
self.result.append(error)
|
| 346 |
+
if self.attempts_left:
|
| 347 |
+
retry_queue.put_nowait(self)
|
| 348 |
+
else:
|
| 349 |
+
logging.error(
|
| 350 |
+
f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}"
|
| 351 |
+
)
|
| 352 |
+
data = (
|
| 353 |
+
[self.request_json, [str(e) for e in self.result], self.metadata]
|
| 354 |
+
if self.metadata
|
| 355 |
+
else [self.request_json, [str(e) for e in self.result]]
|
| 356 |
+
)
|
| 357 |
+
append_to_jsonl(data, save_filepath)
|
| 358 |
+
status_tracker.num_tasks_in_progress -= 1
|
| 359 |
+
status_tracker.num_tasks_failed += 1
|
| 360 |
+
else:
|
| 361 |
+
data = (
|
| 362 |
+
[self.request_json, response, self.metadata]
|
| 363 |
+
if self.metadata
|
| 364 |
+
else [self.request_json, response]
|
| 365 |
+
)
|
| 366 |
+
append_to_jsonl(data, save_filepath)
|
| 367 |
+
status_tracker.num_tasks_in_progress -= 1
|
| 368 |
+
status_tracker.num_tasks_succeeded += 1
|
| 369 |
+
logging.debug(f"Request {self.task_id} saved to {save_filepath}")
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
# functions
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def api_endpoint_from_url(request_url):
|
| 376 |
+
print(request_url)
|
| 377 |
+
"""Extract the API endpoint from the request URL."""
|
| 378 |
+
match = re.search(r"^https://[^/]+/v1/(.+)$", request_url)
|
| 379 |
+
if match:
|
| 380 |
+
return match[1]
|
| 381 |
+
else:
|
| 382 |
+
raise ValueError(f"Invalid API URL: {request_url}")
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def append_to_jsonl(data, filename: str) -> None:
|
| 386 |
+
"""Append a json payload to the end of a jsonl file."""
|
| 387 |
+
json_string = json.dumps(data)
|
| 388 |
+
with open(filename, "a") as f:
|
| 389 |
+
f.write(json_string + "\n")
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def num_tokens_consumed_from_request(
|
| 393 |
+
request_json: dict,
|
| 394 |
+
api_endpoint: str,
|
| 395 |
+
token_encoding_name: str,
|
| 396 |
+
):
|
| 397 |
+
encoding = tiktoken.get_encoding(token_encoding_name)
|
| 398 |
+
print(api_endpoint)
|
| 399 |
+
if api_endpoint == "messages":
|
| 400 |
+
num_tokens = 0
|
| 401 |
+
for message in request_json["messages"]:
|
| 402 |
+
num_tokens += len(encoding.encode(message["content"]))
|
| 403 |
+
return num_tokens
|
| 404 |
+
else:
|
| 405 |
+
raise NotImplementedError(
|
| 406 |
+
f'API endpoint "{api_endpoint}" not implemented in this script'
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
def task_id_generator_function():
|
| 410 |
+
"""Generate integers 0, 1, 2, and so on."""
|
| 411 |
+
task_id = 0
|
| 412 |
+
while True:
|
| 413 |
+
yield task_id
|
| 414 |
+
task_id += 1
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
# run script
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
if __name__ == "__main__":
|
| 421 |
+
# parse command line arguments
|
| 422 |
+
parser = argparse.ArgumentParser()
|
| 423 |
+
parser.add_argument("--requests_filepath")
|
| 424 |
+
parser.add_argument("--save_filepath", default=None)
|
| 425 |
+
parser.add_argument("--request_url", default="https://api.openai.com/v1/embeddings")
|
| 426 |
+
parser.add_argument("--api_key", default=os.getenv("ANTHROPIC_API_KEY"))
|
| 427 |
+
parser.add_argument("--max_requests_per_minute", type=int, default=3_000 * 0.5)
|
| 428 |
+
parser.add_argument("--max_tokens_per_minute", type=int, default=250_000 * 0.5)
|
| 429 |
+
parser.add_argument("--token_encoding_name", default="cl100k_base")
|
| 430 |
+
parser.add_argument("--max_attempts", type=int, default=5)
|
| 431 |
+
parser.add_argument("--logging_level", default=logging.INFO)
|
| 432 |
+
args = parser.parse_args()
|
| 433 |
+
|
| 434 |
+
if args.save_filepath is None:
|
| 435 |
+
args.save_filepath = args.requests_filepath.replace(".jsonl", "_results.jsonl")
|
| 436 |
+
|
| 437 |
+
# run script
|
| 438 |
+
asyncio.run(
|
| 439 |
+
process_api_requests_from_file(
|
| 440 |
+
requests_filepath=args.requests_filepath,
|
| 441 |
+
save_filepath=args.save_filepath,
|
| 442 |
+
request_url=args.request_url,
|
| 443 |
+
api_key=args.api_key,
|
| 444 |
+
max_requests_per_minute=float(args.max_requests_per_minute),
|
| 445 |
+
max_tokens_per_minute=float(args.max_tokens_per_minute),
|
| 446 |
+
token_encoding_name=args.token_encoding_name,
|
| 447 |
+
max_attempts=int(args.max_attempts),
|
| 448 |
+
logging_level=int(args.logging_level),
|
| 449 |
+
)
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
"""
|
| 454 |
+
APPENDIX
|
| 455 |
+
|
| 456 |
+
The example requests file at openai-cookbook/examples/data/example_requests_to_parallel_process.jsonl contains 10,000 requests to text-embedding-ada-002.
|
| 457 |
+
|
| 458 |
+
It was generated with the following code:
|
| 459 |
+
|
| 460 |
+
```python
|
| 461 |
+
import json
|
| 462 |
+
|
| 463 |
+
filename = "data/example_requests_to_parallel_process.jsonl"
|
| 464 |
+
n_requests = 10_000
|
| 465 |
+
jobs = [{"model": "text-embedding-ada-002", "input": str(x) + "\n"} for x in range(n_requests)]
|
| 466 |
+
with open(filename, "w") as f:
|
| 467 |
+
for job in jobs:
|
| 468 |
+
json_string = json.dumps(job)
|
| 469 |
+
f.write(json_string + "\n")
|
| 470 |
+
```
|
| 471 |
+
|
| 472 |
+
As with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically).
|
| 473 |
+
"""
|
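
The appendix above is carried over from the OpenAI cookbook and shows embedding requests; this Anthropic variant of the processor expects Messages-API payloads instead (its token counter only handles the "messages" endpoint and the default API key is ANTHROPIC_API_KEY). A minimal sketch of generating a compatible requests file; the file path, model id, and prompts below are illustrative placeholders, not values used elsewhere in this repo:

```python
import json

# hypothetical output path; pass it to the script via --requests_filepath
filename = "data/example_anthropic_requests.jsonl"

prompts = [
    "Summarize integrated pest management in two sentences.",
    "List three common field crop insect pests.",
]
jobs = [
    {
        "model": "claude-3-haiku-20240307",  # placeholder model id
        "max_tokens": 400,                   # required by the Anthropic Messages API
        "messages": [{"role": "user", "content": p}],
        "metadata": {"row_id": i},           # optional; the script pops this off before sending
    }
    for i, p in enumerate(prompts)
]
with open(filename, "w") as f:
    for job in jobs:
        f.write(json.dumps(job) + "\n")
```

Each line then becomes one POST to the Messages endpoint, and the metadata (if present) is echoed back alongside the response in the results file.
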
api_request_parallel_processor_universal.py
ADDED
|
@@ -0,0 +1,492 @@
| 1 |
+
"""
|
| 2 |
+
API REQUEST PARALLEL PROCESSOR
|
| 3 |
+
|
| 4 |
+
Using the OpenAI API to process lots of text quickly takes some care.
|
| 5 |
+
If you trickle in a million API requests one by one, they'll take days to complete.
|
| 6 |
+
If you flood a million API requests in parallel, they'll exceed the rate limits and fail with errors.
|
| 7 |
+
To maximize throughput, parallel requests need to be throttled to stay under rate limits.
|
| 8 |
+
|
| 9 |
+
This script parallelizes requests to the OpenAI, Anthropic, or OpenRouter (Meta/Google) APIs while throttling to stay under rate limits.
|
| 10 |
+
|
| 11 |
+
Features:
|
| 12 |
+
- Streams requests from file, to avoid running out of memory for giant jobs
|
| 13 |
+
- Makes requests concurrently, to maximize throughput
|
| 14 |
+
- Throttles request and token usage, to stay under rate limits
|
| 15 |
+
- Retries failed requests up to {max_attempts} times, to avoid missing data
|
| 16 |
+
- Logs errors, to diagnose problems with requests
|
| 17 |
+
|
| 18 |
+
Example command to call script:
|
| 19 |
+
```
|
| 20 |
+
python examples/api_request_parallel_processor.py \
|
| 21 |
+
--requests_filepath examples/data/example_requests_to_parallel_process.jsonl \
|
| 22 |
+
--save_filepath examples/data/example_requests_to_parallel_process_results.jsonl \
|
| 23 |
+
--request_url https://api.openai.com/v1/embeddings \
|
| 24 |
+
--max_requests_per_minute 1500 \
|
| 25 |
+
--max_tokens_per_minute 6250000 \
|
| 26 |
+
--token_encoding_name cl100k_base \
|
| 27 |
+
--max_attempts 5 \
|
| 28 |
+
--logging_level 20
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
Inputs:
|
| 32 |
+
- requests_filepath : str
|
| 33 |
+
- path to the file containing the requests to be processed
|
| 34 |
+
- file should be a jsonl file, where each line is a json object with API parameters and an optional metadata field
|
| 35 |
+
- e.g., {"model": "text-embedding-ada-002", "input": "embed me", "metadata": {"row_id": 1}}
|
| 36 |
+
- as with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically)
|
| 37 |
+
- an example file is provided at examples/data/example_requests_to_parallel_process.jsonl
|
| 38 |
+
- the code to generate the example file is appended to the bottom of this script
|
| 39 |
+
- save_filepath : str, optional
|
| 40 |
+
- path to the file where the results will be saved
|
| 41 |
+
- file will be a jsonl file, where each line is an array with the original request plus the API response
|
| 42 |
+
- e.g., [{"model": "text-embedding-ada-002", "input": "embed me"}, {...}]
|
| 43 |
+
- if omitted, results will be saved to {requests_filename}_results.jsonl
|
| 44 |
+
- request_url : str, optional
|
| 45 |
+
- URL of the API endpoint to call
|
| 46 |
+
- if omitted, will default to "https://api.openai.com/v1/embeddings"
|
| 47 |
+
- api_key : str, optional
|
| 48 |
+
- API key to use
|
| 49 |
+
- if omitted, the script will attempt to read it from an environment variable {os.getenv("OPENAI_API_KEY")}
|
| 50 |
+
- max_requests_per_minute : float, optional
|
| 51 |
+
- target number of requests to make per minute (will make less if limited by tokens)
|
| 52 |
+
- leave headroom by setting this to 50% or 75% of your limit
|
| 53 |
+
- if requests are limiting you, try batching multiple embeddings or completions into one request
|
| 54 |
+
- if omitted, will default to 1,500
|
| 55 |
+
- max_tokens_per_minute : float, optional
|
| 56 |
+
- target number of tokens to use per minute (will use less if limited by requests)
|
| 57 |
+
- leave headroom by setting this to 50% or 75% of your limit
|
| 58 |
+
- if omitted, will default to 125,000
|
| 59 |
+
- token_encoding_name : str, optional
|
| 60 |
+
- name of the token encoding used, as defined in the `tiktoken` package
|
| 61 |
+
- if omitted, will default to "cl100k_base" (used by `text-embedding-ada-002`)
|
| 62 |
+
- max_attempts : int, optional
|
| 63 |
+
- number of times to retry a failed request before giving up
|
| 64 |
+
- if omitted, will default to 5
|
| 65 |
+
- logging_level : int, optional
|
| 66 |
+
- level of logging to use; higher numbers will log fewer messages
|
| 67 |
+
- 40 = ERROR; will log only when requests fail after all retries
|
| 68 |
+
- 30 = WARNING; will log when requests hit rate limits or other errors
|
| 69 |
+
- 20 = INFO; will log when requests start and the status at finish
|
| 70 |
+
- 10 = DEBUG; will log various things as the loop runs to see when they occur
|
| 71 |
+
- if omitted, will default to 20 (INFO).
|
| 72 |
+
|
| 73 |
+
The script is structured as follows:
|
| 74 |
+
- Imports
|
| 75 |
+
- Define main()
|
| 76 |
+
- Initialize things
|
| 77 |
+
- In main loop:
|
| 78 |
+
- Get next request if one is not already waiting for capacity
|
| 79 |
+
- Update available token & request capacity
|
| 80 |
+
- If enough capacity available, call API
|
| 81 |
+
- The loop pauses if a rate limit error is hit
|
| 82 |
+
- The loop breaks when no tasks remain
|
| 83 |
+
- Define dataclasses
|
| 84 |
+
- StatusTracker (stores script metadata counters; only one instance is created)
|
| 85 |
+
- APIRequest (stores API inputs, outputs, metadata; one method to call API)
|
| 86 |
+
- Define functions
|
| 87 |
+
- api_endpoint_from_url (extracts API endpoint from request URL)
|
| 88 |
+
- append_to_jsonl (writes to results file)
|
| 89 |
+
- num_tokens_consumed_from_request (bigger function to infer token usage from request)
|
| 90 |
+
- task_id_generator_function (yields 0, 1, 2, ...)
|
| 91 |
+
- Run main()
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
# imports
|
| 95 |
+
import aiohttp # for making API calls concurrently
|
| 96 |
+
import argparse # for running script from command line
|
| 97 |
+
import asyncio # for running API calls concurrently
|
| 98 |
+
import json # for saving results to a jsonl file
|
| 99 |
+
import logging # for logging rate limit warnings and other messages
|
| 100 |
+
import os # for reading API key
|
| 101 |
+
import re # for matching endpoint from request URL
|
| 102 |
+
import tiktoken # for counting tokens
|
| 103 |
+
import time # for sleeping after rate limit is hit
|
| 104 |
+
from dataclasses import (
|
| 105 |
+
dataclass,
|
| 106 |
+
field,
|
| 107 |
+
) # for storing API inputs, outputs, and metadata
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
async def process_api_requests_from_file(
|
| 111 |
+
vendor_name: str,
|
| 112 |
+
requests_filepath: str,
|
| 113 |
+
save_filepath: str,
|
| 114 |
+
request_url: str,
|
| 115 |
+
api_key: str,
|
| 116 |
+
max_requests_per_minute: float,
|
| 117 |
+
max_tokens_per_minute: float,
|
| 118 |
+
token_encoding_name: str,
|
| 119 |
+
max_attempts: int,
|
| 120 |
+
logging_level: int,
|
| 121 |
+
):
|
| 122 |
+
"""Processes API requests in parallel, throttling to stay under rate limits."""
|
| 123 |
+
# constants
|
| 124 |
+
seconds_to_pause_after_rate_limit_error = 15
|
| 125 |
+
seconds_to_sleep_each_loop = (
|
| 126 |
+
0.001 # 1 ms limits max throughput to 1,000 requests per second
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
# initialize logging
|
| 130 |
+
logging.basicConfig(level=logging_level)
|
| 131 |
+
logging.debug(f"Logging initialized at level {logging_level}")
|
| 132 |
+
|
| 133 |
+
# infer API endpoint and construct request header
|
| 134 |
+
api_endpoint = api_endpoint_from_url(request_url, vendor_name)
|
| 135 |
+
request_header=None
|
| 136 |
+
if vendor_name=="openai":
|
| 137 |
+
request_header = {"Authorization": f"Bearer {api_key}"}
|
| 138 |
+
elif vendor_name=="anthropic":
|
| 139 |
+
request_header = {
|
| 140 |
+
"x-api-key": api_key,
|
| 141 |
+
"anthropic-version": "2023-06-01",
|
| 142 |
+
"content-type": "application/json",
|
| 143 |
+
}
|
| 144 |
+
elif vendor_name == "meta" or vendor_name == "google" :
|
| 145 |
+
request_header = {
|
| 146 |
+
"Content-Type": "application/json",
|
| 147 |
+
"Authorization": f"Bearer {api_key}",
|
| 148 |
+
}
|
| 149 |
+
else:
|
| 150 |
+
print("Error. Invalid Model Input. Exiting")
|
| 151 |
+
# exit()
|
| 152 |
+
|
| 153 |
+
# use api-key header for Azure deployments
|
| 154 |
+
# if '/deployments' in request_url:
|
| 155 |
+
# request_header = {"api-key": f"{api_key}"}
|
| 156 |
+
|
| 157 |
+
# initialize trackers
|
| 158 |
+
queue_of_requests_to_retry = asyncio.Queue()
|
| 159 |
+
task_id_generator = (
|
| 160 |
+
task_id_generator_function()
|
| 161 |
+
) # generates integer IDs of 0, 1, 2, ...
|
| 162 |
+
status_tracker = (
|
| 163 |
+
StatusTracker()
|
| 164 |
+
) # single instance to track a collection of variables
|
| 165 |
+
next_request = None # variable to hold the next request to call
|
| 166 |
+
|
| 167 |
+
# initialize available capacity counts
|
| 168 |
+
available_request_capacity = max_requests_per_minute
|
| 169 |
+
available_token_capacity = max_tokens_per_minute
|
| 170 |
+
last_update_time = time.time()
|
| 171 |
+
|
| 172 |
+
# initialize flags
|
| 173 |
+
file_not_finished = True # after file is empty, we'll skip reading it
|
| 174 |
+
logging.debug(f"Initialization complete.")
|
| 175 |
+
|
| 176 |
+
# initialize file reading
|
| 177 |
+
with open(requests_filepath) as file:
|
| 178 |
+
# `requests` will provide requests one at a time
|
| 179 |
+
requests = file.__iter__()
|
| 180 |
+
logging.debug(f"File opened. Entering main loop")
|
| 181 |
+
async with aiohttp.ClientSession() as session: # Initialize ClientSession here
|
| 182 |
+
while True:
|
| 183 |
+
# get next request (if one is not already waiting for capacity)
|
| 184 |
+
if next_request is None:
|
| 185 |
+
if not queue_of_requests_to_retry.empty():
|
| 186 |
+
next_request = queue_of_requests_to_retry.get_nowait()
|
| 187 |
+
logging.debug(
|
| 188 |
+
f"Retrying request {next_request.task_id}: {next_request}"
|
| 189 |
+
)
|
| 190 |
+
elif file_not_finished:
|
| 191 |
+
try:
|
| 192 |
+
# get new request
|
| 193 |
+
request_json = json.loads(next(requests))
|
| 194 |
+
next_request = APIRequest(
|
| 195 |
+
task_id=next(task_id_generator),
|
| 196 |
+
request_json=request_json,
|
| 197 |
+
# token_consumption=num_tokens_consumed_from_request(
|
| 198 |
+
# request_json, api_endpoint, token_encoding_name
|
| 199 |
+
# ), # just disabled the tokens consumption. Not worried about that.
|
| 200 |
+
token_consumption=0,
|
| 201 |
+
attempts_left=max_attempts,
|
| 202 |
+
metadata=request_json.pop("metadata", None),
|
| 203 |
+
)
|
| 204 |
+
status_tracker.num_tasks_started += 1
|
| 205 |
+
status_tracker.num_tasks_in_progress += 1
|
| 206 |
+
logging.debug(
|
| 207 |
+
f"Reading request {next_request.task_id}: {next_request}"
|
| 208 |
+
)
|
| 209 |
+
except StopIteration:
|
| 210 |
+
# if file runs out, set flag to stop reading it
|
| 211 |
+
logging.debug("Read file exhausted")
|
| 212 |
+
file_not_finished = False
|
| 213 |
+
|
| 214 |
+
# update available capacity
|
| 215 |
+
current_time = time.time()
|
| 216 |
+
seconds_since_update = current_time - last_update_time
|
| 217 |
+
available_request_capacity = min(
|
| 218 |
+
available_request_capacity
|
| 219 |
+
+ max_requests_per_minute * seconds_since_update / 60.0,
|
| 220 |
+
max_requests_per_minute,
|
| 221 |
+
)
|
| 222 |
+
available_token_capacity = min(
|
| 223 |
+
available_token_capacity
|
| 224 |
+
+ max_tokens_per_minute * seconds_since_update / 60.0,
|
| 225 |
+
max_tokens_per_minute,
|
| 226 |
+
)
|
| 227 |
+
last_update_time = current_time
|
| 228 |
+
|
| 229 |
+
# if enough capacity available, call API
|
| 230 |
+
if next_request:
|
| 231 |
+
next_request_tokens = next_request.token_consumption
|
| 232 |
+
if (
|
| 233 |
+
available_request_capacity >= 1
|
| 234 |
+
and available_token_capacity >= next_request_tokens
|
| 235 |
+
):
|
| 236 |
+
# update counters
|
| 237 |
+
available_request_capacity -= 1
|
| 238 |
+
available_token_capacity -= next_request_tokens
|
| 239 |
+
next_request.attempts_left -= 1
|
| 240 |
+
|
| 241 |
+
# call API
|
| 242 |
+
asyncio.create_task(
|
| 243 |
+
next_request.call_api(
|
| 244 |
+
session=session,
|
| 245 |
+
request_url=request_url,
|
| 246 |
+
request_header=request_header,
|
| 247 |
+
retry_queue=queue_of_requests_to_retry,
|
| 248 |
+
save_filepath=save_filepath,
|
| 249 |
+
status_tracker=status_tracker,
|
| 250 |
+
)
|
| 251 |
+
)
|
| 252 |
+
next_request = None # reset next_request to empty
|
| 253 |
+
|
| 254 |
+
# if all tasks are finished, break
|
| 255 |
+
if status_tracker.num_tasks_in_progress == 0:
|
| 256 |
+
break
|
| 257 |
+
|
| 258 |
+
# main loop sleeps briefly so concurrent tasks can run
|
| 259 |
+
await asyncio.sleep(seconds_to_sleep_each_loop)
|
| 260 |
+
|
| 261 |
+
# if a rate limit error was hit recently, pause to cool down
|
| 262 |
+
seconds_since_rate_limit_error = (
|
| 263 |
+
time.time() - status_tracker.time_of_last_rate_limit_error
|
| 264 |
+
)
|
| 265 |
+
if (
|
| 266 |
+
seconds_since_rate_limit_error
|
| 267 |
+
< seconds_to_pause_after_rate_limit_error
|
| 268 |
+
):
|
| 269 |
+
remaining_seconds_to_pause = (
|
| 270 |
+
seconds_to_pause_after_rate_limit_error
|
| 271 |
+
- seconds_since_rate_limit_error
|
| 272 |
+
)
|
| 273 |
+
await asyncio.sleep(remaining_seconds_to_pause)
|
| 274 |
+
# ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago
|
| 275 |
+
logging.warning(
|
| 276 |
+
f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
# after finishing, log final status
|
| 280 |
+
logging.info(
|
| 281 |
+
f"""Parallel processing complete. Results saved to {save_filepath}"""
|
| 282 |
+
)
|
| 283 |
+
if status_tracker.num_tasks_failed > 0:
|
| 284 |
+
logging.warning(
|
| 285 |
+
f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. Errors logged to {save_filepath}."
|
| 286 |
+
)
|
| 287 |
+
if status_tracker.num_rate_limit_errors > 0:
|
| 288 |
+
logging.warning(
|
| 289 |
+
f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# dataclasses
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
@dataclass
|
| 297 |
+
class StatusTracker:
|
| 298 |
+
"""Stores metadata about the script's progress. Only one instance is created."""
|
| 299 |
+
|
| 300 |
+
num_tasks_started: int = 0
|
| 301 |
+
num_tasks_in_progress: int = 0 # script ends when this reaches 0
|
| 302 |
+
num_tasks_succeeded: int = 0
|
| 303 |
+
num_tasks_failed: int = 0
|
| 304 |
+
num_rate_limit_errors: int = 0
|
| 305 |
+
num_api_errors: int = 0 # excluding rate limit errors, counted above
|
| 306 |
+
num_other_errors: int = 0
|
| 307 |
+
time_of_last_rate_limit_error: int = 0 # used to cool off after hitting rate limits
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
@dataclass
|
| 311 |
+
class APIRequest:
|
| 312 |
+
"""Stores an API request's inputs, outputs, and other metadata. Contains a method to make an API call."""
|
| 313 |
+
|
| 314 |
+
task_id: int
|
| 315 |
+
request_json: dict
|
| 316 |
+
token_consumption: int
|
| 317 |
+
attempts_left: int
|
| 318 |
+
metadata: dict
|
| 319 |
+
result: list = field(default_factory=list)
|
| 320 |
+
|
| 321 |
+
async def call_api(
|
| 322 |
+
self,
|
| 323 |
+
session: aiohttp.ClientSession,
|
| 324 |
+
request_url: str,
|
| 325 |
+
request_header: dict,
|
| 326 |
+
retry_queue: asyncio.Queue,
|
| 327 |
+
save_filepath: str,
|
| 328 |
+
status_tracker: StatusTracker,
|
| 329 |
+
):
|
| 330 |
+
"""Calls the OpenAI API and saves results."""
|
| 331 |
+
logging.info(f"Starting request #{self.task_id}")
|
| 332 |
+
error = None
|
| 333 |
+
try:
|
| 334 |
+
async with session.post(
|
| 335 |
+
url=request_url, headers=request_header, json=self.request_json
|
| 336 |
+
) as response:
|
| 337 |
+
response = await response.json()
|
| 338 |
+
if "error" in response:
|
| 339 |
+
logging.warning(
|
| 340 |
+
f"Request {self.task_id} failed with error {response['error']}"
|
| 341 |
+
)
|
| 342 |
+
status_tracker.num_api_errors += 1
|
| 343 |
+
error = response
|
| 344 |
+
if "Rate limit" in response["error"].get("message", ""):
|
| 345 |
+
status_tracker.time_of_last_rate_limit_error = time.time()
|
| 346 |
+
status_tracker.num_rate_limit_errors += 1
|
| 347 |
+
status_tracker.num_api_errors -= (
|
| 348 |
+
1 # rate limit errors are counted separately
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
except (
|
| 352 |
+
Exception
|
| 353 |
+
) as e: # catching naked exceptions is bad practice, but in this case we'll log & save them
|
| 354 |
+
logging.warning(f"Request {self.task_id} failed with Exception {e}")
|
| 355 |
+
status_tracker.num_other_errors += 1
|
| 356 |
+
error = e
|
| 357 |
+
if error:
|
| 358 |
+
self.result.append(error)
|
| 359 |
+
if self.attempts_left:
|
| 360 |
+
retry_queue.put_nowait(self)
|
| 361 |
+
else:
|
| 362 |
+
logging.error(
|
| 363 |
+
f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}"
|
| 364 |
+
)
|
| 365 |
+
data = (
|
| 366 |
+
[self.request_json, [str(e) for e in self.result], self.metadata]
|
| 367 |
+
if self.metadata
|
| 368 |
+
else [self.request_json, [str(e) for e in self.result]]
|
| 369 |
+
)
|
| 370 |
+
append_to_jsonl(data, save_filepath)
|
| 371 |
+
status_tracker.num_tasks_in_progress -= 1
|
| 372 |
+
status_tracker.num_tasks_failed += 1
|
| 373 |
+
else:
|
| 374 |
+
data = (
|
| 375 |
+
[self.request_json, response, self.metadata]
|
| 376 |
+
if self.metadata
|
| 377 |
+
else [self.request_json, response]
|
| 378 |
+
)
|
| 379 |
+
append_to_jsonl(data, save_filepath)
|
| 380 |
+
status_tracker.num_tasks_in_progress -= 1
|
| 381 |
+
status_tracker.num_tasks_succeeded += 1
|
| 382 |
+
logging.debug(f"Request {self.task_id} saved to {save_filepath}")
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
# functions
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def api_endpoint_from_url(request_url, vendor_name):
|
| 389 |
+
"""Extract the API endpoint from the request URL."""
|
| 390 |
+
match=None
|
| 391 |
+
if vendor_name=="openai":
|
| 392 |
+
match = re.search("^https://[^/]+/v\\d+/(.+)$", request_url)
|
| 393 |
+
elif vendor_name=="anthropic":
|
| 394 |
+
match = re.search(r"^https://[^/]+/v1/(.+)$", request_url)
|
| 395 |
+
elif vendor_name == "meta" or vendor_name == "google":
|
| 396 |
+
match = re.search(r"^https://[^/]+/api/v1/(.+)$", request_url)
|
| 397 |
+
else:
|
| 398 |
+
print("Error. Invalid Model Input. Exiting")
|
| 399 |
+
# exit()
|
| 400 |
+
if match is None:
|
| 401 |
+
# for Azure OpenAI deployment urls
|
| 402 |
+
match = re.search(r"^https://[^/]+/openai/deployments/[^/]+/(.+?)(\?|$)", request_url)
|
| 403 |
+
return match[1]
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def append_to_jsonl(data, filename: str) -> None:
|
| 407 |
+
"""Append a json payload to the end of a jsonl file."""
|
| 408 |
+
json_string = json.dumps(data)
|
| 409 |
+
with open(filename, "a") as f:
|
| 410 |
+
f.write(json_string + "\n")
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def task_id_generator_function():
|
| 415 |
+
"""Generate integers 0, 1, 2, and so on."""
|
| 416 |
+
task_id = 0
|
| 417 |
+
while True:
|
| 418 |
+
yield task_id
|
| 419 |
+
task_id += 1
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
# run script
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
if __name__ == "__main__":
|
| 426 |
+
# parse command line arguments
|
| 427 |
+
parser = argparse.ArgumentParser()
|
| 428 |
+
parser.add_argument("--vendor_name", default=None)
|
| 429 |
+
parser.add_argument("--requests_filepath")
|
| 430 |
+
parser.add_argument("--save_filepath", default=None)
|
| 431 |
+
parser.add_argument("--request_url", default=None)
|
| 432 |
+
parser.add_argument("--api_key", default=None)
|
| 433 |
+
parser.add_argument("--max_requests_per_minute", type=int, default=3_000 * 0.5)
|
| 434 |
+
parser.add_argument("--max_tokens_per_minute", type=int, default=250_000 * 0.5)
|
| 435 |
+
parser.add_argument("--token_encoding_name", default="cl100k_base")
|
| 436 |
+
parser.add_argument("--max_attempts", type=int, default=5)
|
| 437 |
+
parser.add_argument("--logging_level", default=logging.INFO)
|
| 438 |
+
|
| 439 |
+
args = parser.parse_args()
|
| 440 |
+
if args.vendor_name=="openai":
|
| 441 |
+
args.api_key=os.getenv("OPENAI_API_KEY")
|
| 442 |
+
args.request_url="https://api.openai.com/v1/chat/completions"
|
| 443 |
+
elif args.vendor_name=="anthropic":
|
| 444 |
+
args.api_key=os.getenv("ANTHROPIC_API_KEY")
|
| 445 |
+
args.request_url="https://api.anthropic.com/v1/messages"
|
| 446 |
+
elif args.vendor_name == "meta" or args.vendor_name == "google" :
|
| 447 |
+
args.api_key = os.getenv("OPENROUTER_API_KEY")
|
| 448 |
+
args.request_url = "https://openrouter.ai/api/v1/chat/completions"
|
| 449 |
+
else:
|
| 450 |
+
print("Error. Invalid Model Input. Exiting")
|
| 451 |
+
# exit()
|
| 452 |
+
if args.save_filepath is None:
|
| 453 |
+
args.save_filepath = args.requests_filepath.replace(".jsonl", "_results.jsonl")
|
| 454 |
+
|
| 455 |
+
# run script
|
| 456 |
+
asyncio.run(
|
| 457 |
+
process_api_requests_from_file(
|
| 458 |
+
vendor_name=args.vendor_name,
|
| 459 |
+
requests_filepath=args.requests_filepath,
|
| 460 |
+
save_filepath=args.save_filepath,
|
| 461 |
+
request_url=args.request_url,
|
| 462 |
+
api_key=args.api_key,
|
| 463 |
+
max_requests_per_minute=float(args.max_requests_per_minute),
|
| 464 |
+
max_tokens_per_minute=float(args.max_tokens_per_minute),
|
| 465 |
+
token_encoding_name=args.token_encoding_name,
|
| 466 |
+
max_attempts=int(args.max_attempts),
|
| 467 |
+
logging_level=int(args.logging_level),
|
| 468 |
+
)
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
"""
|
| 473 |
+
APPENDIX
|
| 474 |
+
|
| 475 |
+
The example requests file at openai-cookbook/examples/data/example_requests_to_parallel_process.jsonl contains 10,000 requests to text-embedding-ada-002.
|
| 476 |
+
|
| 477 |
+
It was generated with the following code:
|
| 478 |
+
|
| 479 |
+
```python
|
| 480 |
+
import json
|
| 481 |
+
|
| 482 |
+
filename = "data/example_requests_to_parallel_process.jsonl"
|
| 483 |
+
n_requests = 10_000
|
| 484 |
+
jobs = [{"model": "text-embedding-ada-002", "input": str(x) + "\n"} for x in range(n_requests)]
|
| 485 |
+
with open(filename, "w") as f:
|
| 486 |
+
for job in jobs:
|
| 487 |
+
json_string = json.dumps(job)
|
| 488 |
+
f.write(json_string + "\n")
|
| 489 |
+
```
|
| 490 |
+
|
| 491 |
+
As with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically).
|
| 492 |
+
"""
|
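
For this universal processor, the vendor is chosen with --vendor_name, and the "openai", "meta", and "google" options all post to chat-completions-style endpoints, so each request line should carry a chat payload rather than an embedding payload. A minimal sketch under those assumptions; the output path, model ids, and questions are placeholders:

```python
import json

# hypothetical output path; pass it to the script via --requests_filepath
filename = "data/example_universal_requests.jsonl"

questions = ["What is crop rotation?", "How is giant ragweed identified?"]
jobs = [
    {
        # placeholder model id; with --vendor_name meta or google the request goes through
        # OpenRouter, so an OpenRouter-style id (e.g. "meta-llama/llama-3-70b-instruct") fits there
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": q}],
        "metadata": {"row_id": i},  # optional; stripped off and saved next to the response
    }
    for i, q in enumerate(questions)
]
with open(filename, "w") as f:
    for job in jobs:
        f.write(json.dumps(job) + "\n")
```

The script then fills in the request URL and API key from the vendor name, e.g. `python api_request_parallel_processor_universal.py --vendor_name openai --requests_filepath data/example_universal_requests.jsonl`.
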
api_request_parallel_processor_universal_SEQUENTIAL.py
ADDED
|
@@ -0,0 +1,420 @@
| 1 |
+
"""
|
| 2 |
+
API REQUEST PARALLEL PROCESSOR
|
| 3 |
+
|
| 4 |
+
Using the OpenAI API to process lots of text quickly takes some care.
|
| 5 |
+
If you trickle in a million API requests one by one, they'll take days to complete.
|
| 6 |
+
If you flood a million API requests in parallel, they'll exceed the rate limits and fail with errors.
|
| 7 |
+
To maximize throughput, parallel requests need to be throttled to stay under rate limits.
|
| 8 |
+
|
| 9 |
+
This (sequential) variant instead sends requests to the OpenAI, Anthropic, or OpenRouter (Meta/Google) APIs one at a time, retrying failures as it goes.
|
| 10 |
+
|
| 11 |
+
Features:
|
| 12 |
+
- Streams requests from file, to avoid running out of memory for giant jobs
|
| 13 |
+
- Makes requests one at a time, in the order they appear in the file
|
| 14 |
+
- Throttles request and token usage, to stay under rate limits
|
| 15 |
+
- Retries failed requests up to {max_attempts} times, to avoid missing data
|
| 16 |
+
- Logs errors, to diagnose problems with requests
|
| 17 |
+
|
| 18 |
+
Example command to call script:
|
| 19 |
+
```
|
| 20 |
+
python examples/api_request_parallel_processor.py \
|
| 21 |
+
--requests_filepath examples/data/example_requests_to_parallel_process.jsonl \
|
| 22 |
+
--save_filepath examples/data/example_requests_to_parallel_process_results.jsonl \
|
| 23 |
+
--request_url https://api.openai.com/v1/embeddings \
|
| 24 |
+
--max_requests_per_minute 1500 \
|
| 25 |
+
--max_tokens_per_minute 6250000 \
|
| 26 |
+
--token_encoding_name cl100k_base \
|
| 27 |
+
--max_attempts 5 \
|
| 28 |
+
--logging_level 20
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
Inputs:
|
| 32 |
+
- requests_filepath : str
|
| 33 |
+
- path to the file containing the requests to be processed
|
| 34 |
+
- file should be a jsonl file, where each line is a json object with API parameters and an optional metadata field
|
| 35 |
+
- e.g., {"model": "text-embedding-ada-002", "input": "embed me", "metadata": {"row_id": 1}}
|
| 36 |
+
- as with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically)
|
| 37 |
+
- an example file is provided at examples/data/example_requests_to_parallel_process.jsonl
|
| 38 |
+
- the code to generate the example file is appended to the bottom of this script
|
| 39 |
+
- save_filepath : str, optional
|
| 40 |
+
- path to the file where the results will be saved
|
| 41 |
+
- file will be a jsonl file, where each line is an array with the original request plus the API response
|
| 42 |
+
- e.g., [{"model": "text-embedding-ada-002", "input": "embed me"}, {...}]
|
| 43 |
+
- if omitted, results will be saved to {requests_filename}_results.jsonl
|
| 44 |
+
- request_url : str, optional
|
| 45 |
+
- URL of the API endpoint to call
|
| 46 |
+
- if omitted, will default to "https://api.openai.com/v1/embeddings"
|
| 47 |
+
- api_key : str, optional
|
| 48 |
+
- API key to use
|
| 49 |
+
- if omitted, the script will attempt to read it from an environment variable {os.getenv("OPENAI_API_KEY")}
|
| 50 |
+
- max_requests_per_minute : float, optional
|
| 51 |
+
- target number of requests to make per minute (will make less if limited by tokens)
|
| 52 |
+
- leave headroom by setting this to 50% or 75% of your limit
|
| 53 |
+
- if requests are limiting you, try batching multiple embeddings or completions into one request
|
| 54 |
+
- if omitted, will default to 1,500
|
| 55 |
+
- max_tokens_per_minute : float, optional
|
| 56 |
+
- target number of tokens to use per minute (will use less if limited by requests)
|
| 57 |
+
- leave headroom by setting this to 50% or 75% of your limit
|
| 58 |
+
- if omitted, will default to 125,000
|
| 59 |
+
- token_encoding_name : str, optional
|
| 60 |
+
- name of the token encoding used, as defined in the `tiktoken` package
|
| 61 |
+
- if omitted, will default to "cl100k_base" (used by `text-embedding-ada-002`)
|
| 62 |
+
- max_attempts : int, optional
|
| 63 |
+
- number of times to retry a failed request before giving up
|
| 64 |
+
- if omitted, will default to 5
|
| 65 |
+
- logging_level : int, optional
|
| 66 |
+
- level of logging to use; higher numbers will log fewer messages
|
| 67 |
+
- 40 = ERROR; will log only when requests fail after all retries
|
| 68 |
+
- 30 = WARNING; will log when requests hit rate limits or other errors
|
| 69 |
+
- 20 = INFO; will log when requests start and the status at finish
|
| 70 |
+
- 10 = DEBUG; will log various things as the loop runs to see when they occur
|
| 71 |
+
- if omitted, will default to 20 (INFO).
|
| 72 |
+
|
| 73 |
+
The script is structured as follows:
|
| 74 |
+
- Imports
|
| 75 |
+
- Define main()
|
| 76 |
+
- Initialize things
|
| 77 |
+
- In main loop:
|
| 78 |
+
- Get next request if one is not already waiting for capacity
|
| 79 |
+
- Update available token & request capacity
|
| 80 |
+
- If enough capacity available, call API
|
| 81 |
+
- The loop pauses if a rate limit error is hit
|
| 82 |
+
- The loop breaks when no tasks remain
|
| 83 |
+
- Define dataclasses
|
| 84 |
+
- StatusTracker (stores script metadata counters; only one instance is created)
|
| 85 |
+
- APIRequest (stores API inputs, outputs, metadata; one method to call API)
|
| 86 |
+
- Define functions
|
| 87 |
+
- api_endpoint_from_url (extracts API endpoint from request URL)
|
| 88 |
+
- append_to_jsonl (writes to results file)
|
| 89 |
+
- num_tokens_consumed_from_request (bigger function to infer token usage from request)
|
| 90 |
+
- task_id_generator_function (yields 0, 1, 2, ...)
|
| 91 |
+
- Run main()
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
# imports
|
| 95 |
+
import aiohttp # for making API calls concurrently
|
| 96 |
+
import argparse # for running script from command line
|
| 97 |
+
import asyncio # for running API calls concurrently
|
| 98 |
+
import json # for saving results to a jsonl file
|
| 99 |
+
import logging # for logging rate limit warnings and other messages
|
| 100 |
+
import os # for reading API key
|
| 101 |
+
import re # for matching endpoint from request URL
|
| 102 |
+
import tiktoken # for counting tokens
|
| 103 |
+
import time  # for sleeping after rate limit is hit
import requests  # for the synchronous HTTP session (requests.Session) used in the main loop
|
| 104 |
+
from dataclasses import (
|
| 105 |
+
dataclass,
|
| 106 |
+
field,
|
| 107 |
+
) # for storing API inputs, outputs, and metadata
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
async def process_api_requests_from_file(  # defined async so the asyncio.run(...) call at the bottom of the file works
|
| 111 |
+
vendor_name: str,
|
| 112 |
+
requests_filepath: str,
|
| 113 |
+
save_filepath: str,
|
| 114 |
+
request_url: str,
|
| 115 |
+
api_key: str,
|
| 116 |
+
max_requests_per_minute: float,
|
| 117 |
+
max_tokens_per_minute: float,
|
| 118 |
+
token_encoding_name: str,
|
| 119 |
+
max_attempts: int,
|
| 120 |
+
logging_level: int,
|
| 121 |
+
):
|
| 122 |
+
"""Processes API requests sequentially."""
|
| 123 |
+
# initialize logging
|
| 124 |
+
logging.basicConfig(level=logging_level)
|
| 125 |
+
logging.debug(f"Logging initialized at level {logging_level}")
|
| 126 |
+
|
| 127 |
+
# infer API endpoint and construct request header
|
| 128 |
+
api_endpoint = api_endpoint_from_url(request_url, vendor_name)
|
| 129 |
+
request_header = None
|
| 130 |
+
if vendor_name == "openai":
|
| 131 |
+
request_header = {"Authorization": f"Bearer {api_key}"}
|
| 132 |
+
elif vendor_name == "anthropic":
|
| 133 |
+
request_header = {
|
| 134 |
+
"x-api-key": api_key,
|
| 135 |
+
"anthropic-version": "2023-06-01",
|
| 136 |
+
"content-type": "application/json",
|
| 137 |
+
}
|
| 138 |
+
elif vendor_name == "meta" or vendor_name == "google":
|
| 139 |
+
request_header = {
|
| 140 |
+
"Content-Type": "application/json",
|
| 141 |
+
"Authorization": f"Bearer {api_key}",
|
| 142 |
+
}
|
| 143 |
+
else:
|
| 144 |
+
print("Error. Invalid Model Input. Exiting")
|
| 145 |
+
|
| 146 |
+
# initialize trackers
|
| 147 |
+
task_id_generator = task_id_generator_function()
|
| 148 |
+
status_tracker = StatusTracker()
|
| 149 |
+
|
| 150 |
+
# process requests sequentially
|
| 151 |
+
with open(requests_filepath) as file, requests.Session() as session:
|
| 152 |
+
for line in file:
|
| 153 |
+
request_json = json.loads(line)
|
| 154 |
+
request = APIRequest(
|
| 155 |
+
task_id=next(task_id_generator),
|
| 156 |
+
request_json=request_json,
|
| 157 |
+
token_consumption=0,
|
| 158 |
+
attempts_left=max_attempts,
|
| 159 |
+
metadata=request_json.pop("metadata", None),
|
| 160 |
+
)
|
| 161 |
+
status_tracker.num_tasks_started += 1
|
| 162 |
+
logging.debug(f"Processing request {request.task_id}: {request}")
|
| 163 |
+
|
| 164 |
+
while request.attempts_left > 0:
|
| 165 |
+
error = None
|
| 166 |
+
try:
|
| 167 |
+
response = session.post(
|
| 168 |
+
url=request_url,
|
| 169 |
+
headers=request_header,
|
| 170 |
+
json=request.request_json,
|
| 171 |
+
).json()
|
| 172 |
+
if "error" in response:
|
| 173 |
+
logging.warning(
|
| 174 |
+
f"Request {request.task_id} failed with error {response['error']}"
|
| 175 |
+
)
|
| 176 |
+
status_tracker.num_api_errors += 1
|
| 177 |
+
error = response
|
| 178 |
+
if "Rate limit" in response["error"].get("message", ""):
|
| 179 |
+
status_tracker.num_rate_limit_errors += 1
|
| 180 |
+
status_tracker.num_api_errors -= 1
|
| 181 |
+
except Exception as e:
|
| 182 |
+
logging.warning(f"Request {request.task_id} failed with Exception {e}")
|
| 183 |
+
status_tracker.num_other_errors += 1
|
| 184 |
+
error = e
|
| 185 |
+
|
| 186 |
+
if error:
|
| 187 |
+
request.result.append(error)
|
| 188 |
+
request.attempts_left -= 1
|
| 189 |
+
if request.attempts_left == 0:
|
| 190 |
+
logging.error(
|
| 191 |
+
f"Request {request.request_json} failed after all attempts. Saving errors: {request.result}"
|
| 192 |
+
)
|
| 193 |
+
data = (
|
| 194 |
+
[request.request_json, [str(e) for e in request.result], request.metadata]
|
| 195 |
+
if request.metadata
|
| 196 |
+
else [request.request_json, [str(e) for e in request.result]]
|
| 197 |
+
)
|
| 198 |
+
append_to_jsonl(data, save_filepath)
|
| 199 |
+
status_tracker.num_tasks_failed += 1
|
| 200 |
+
else:
|
| 201 |
+
data = (
|
| 202 |
+
[request.request_json, response, request.metadata]
|
| 203 |
+
if request.metadata
|
| 204 |
+
else [request.request_json, response]
|
| 205 |
+
)
|
| 206 |
+
append_to_jsonl(data, save_filepath)
|
| 207 |
+
status_tracker.num_tasks_succeeded += 1
|
| 208 |
+
logging.debug(f"Request {request.task_id} saved to {save_filepath}")
|
| 209 |
+
break
|
| 210 |
+
|
| 211 |
+
# after finishing, log final status
|
| 212 |
+
logging.info(f"Sequential processing complete. Results saved to {save_filepath}")
|
| 213 |
+
if status_tracker.num_tasks_failed > 0:
|
| 214 |
+
logging.warning(
|
| 215 |
+
f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. Errors logged to {save_filepath}."
|
| 216 |
+
)
|
| 217 |
+
if status_tracker.num_rate_limit_errors > 0:
|
| 218 |
+
logging.warning(
|
| 219 |
+
f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
|
| 220 |
+
)
|
| 221 |
+
# dataclasses
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
@dataclass
|
| 225 |
+
class StatusTracker:
|
| 226 |
+
"""Stores metadata about the script's progress. Only one instance is created."""
|
| 227 |
+
|
| 228 |
+
num_tasks_started: int = 0
|
| 229 |
+
num_tasks_in_progress: int = 0 # script ends when this reaches 0
|
| 230 |
+
num_tasks_succeeded: int = 0
|
| 231 |
+
num_tasks_failed: int = 0
|
| 232 |
+
num_rate_limit_errors: int = 0
|
| 233 |
+
num_api_errors: int = 0 # excluding rate limit errors, counted above
|
| 234 |
+
num_other_errors: int = 0
|
| 235 |
+
time_of_last_rate_limit_error: int = 0 # used to cool off after hitting rate limits
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
@dataclass
|
| 239 |
+
class APIRequest:
|
| 240 |
+
"""Stores an API request's inputs, outputs, and other metadata. Contains a method to make an API call."""
|
| 241 |
+
|
| 242 |
+
task_id: int
|
| 243 |
+
request_json: dict
|
| 244 |
+
token_consumption: int
|
| 245 |
+
attempts_left: int
|
| 246 |
+
metadata: dict
|
| 247 |
+
result: list = field(default_factory=list)
|
| 248 |
+
|
| 249 |
+
async def call_api(
|
| 250 |
+
self,
|
| 251 |
+
session: aiohttp.ClientSession,
|
| 252 |
+
request_url: str,
|
| 253 |
+
request_header: dict,
|
| 254 |
+
retry_queue: asyncio.Queue,
|
| 255 |
+
save_filepath: str,
|
| 256 |
+
status_tracker: StatusTracker,
|
| 257 |
+
):
|
| 258 |
+
"""Calls the OpenAI API and saves results."""
|
| 259 |
+
logging.info(f"Starting request #{self.task_id}")
|
| 260 |
+
error = None
|
| 261 |
+
try:
|
| 262 |
+
async with session.post(
|
| 263 |
+
url=request_url, headers=request_header, json=self.request_json
|
| 264 |
+
) as response:
|
| 265 |
+
response = await response.json()
|
| 266 |
+
if "error" in response:
|
| 267 |
+
logging.warning(
|
| 268 |
+
f"Request {self.task_id} failed with error {response['error']}"
|
| 269 |
+
)
|
| 270 |
+
status_tracker.num_api_errors += 1
|
| 271 |
+
error = response
|
| 272 |
+
if "Rate limit" in response["error"].get("message", ""):
|
| 273 |
+
status_tracker.time_of_last_rate_limit_error = time.time()
|
| 274 |
+
status_tracker.num_rate_limit_errors += 1
|
| 275 |
+
status_tracker.num_api_errors -= (
|
| 276 |
+
1 # rate limit errors are counted separately
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
except (
|
| 280 |
+
Exception
|
| 281 |
+
) as e: # catching naked exceptions is bad practice, but in this case we'll log & save them
|
| 282 |
+
logging.warning(f"Request {self.task_id} failed with Exception {e}")
|
| 283 |
+
status_tracker.num_other_errors += 1
|
| 284 |
+
error = e
|
| 285 |
+
if error:
|
| 286 |
+
self.result.append(error)
|
| 287 |
+
if self.attempts_left:
|
| 288 |
+
retry_queue.put_nowait(self)
|
| 289 |
+
else:
|
| 290 |
+
logging.error(
|
| 291 |
+
f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}"
|
| 292 |
+
)
|
| 293 |
+
data = (
|
| 294 |
+
[self.request_json, [str(e) for e in self.result], self.metadata]
|
| 295 |
+
if self.metadata
|
| 296 |
+
else [self.request_json, [str(e) for e in self.result]]
|
| 297 |
+
)
|
| 298 |
+
append_to_jsonl(data, save_filepath)
|
| 299 |
+
status_tracker.num_tasks_in_progress -= 1
|
| 300 |
+
status_tracker.num_tasks_failed += 1
|
| 301 |
+
else:
|
| 302 |
+
data = (
|
| 303 |
+
[self.request_json, response, self.metadata]
|
| 304 |
+
if self.metadata
|
| 305 |
+
else [self.request_json, response]
|
| 306 |
+
)
|
| 307 |
+
append_to_jsonl(data, save_filepath)
|
| 308 |
+
status_tracker.num_tasks_in_progress -= 1
|
| 309 |
+
status_tracker.num_tasks_succeeded += 1
|
| 310 |
+
logging.debug(f"Request {self.task_id} saved to {save_filepath}")
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
# functions
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def api_endpoint_from_url(request_url, vendor_name):
|
| 317 |
+
"""Extract the API endpoint from the request URL."""
|
| 318 |
+
match=None
|
| 319 |
+
if vendor_name=="openai":
|
| 320 |
+
match = re.search("^https://[^/]+/v\\d+/(.+)$", request_url)
|
| 321 |
+
elif vendor_name=="anthropic":
|
| 322 |
+
match = re.search(r"^https://[^/]+/v1/(.+)$", request_url)
|
| 323 |
+
elif vendor_name == "meta" or vendor_name == "google":
|
| 324 |
+
match = re.search(r"^https://[^/]+/api/v1/(.+)$", request_url)
|
| 325 |
+
else:
|
| 326 |
+
print("Error. Invalid Model Input. Exiting")
|
| 327 |
+
# exit()
|
| 328 |
+
if match is None:
|
| 329 |
+
# for Azure OpenAI deployment urls
|
| 330 |
+
match = re.search(r"^https://[^/]+/openai/deployments/[^/]+/(.+?)(\?|$)", request_url)
|
| 331 |
+
return match[1]
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def append_to_jsonl(data, filename: str) -> None:
|
| 335 |
+
"""Append a json payload to the end of a jsonl file."""
|
| 336 |
+
json_string = json.dumps(data)
|
| 337 |
+
with open(filename, "a") as f:
|
| 338 |
+
f.write(json_string + "\n")
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def task_id_generator_function():
|
| 343 |
+
"""Generate integers 0, 1, 2, and so on."""
|
| 344 |
+
task_id = 0
|
| 345 |
+
while True:
|
| 346 |
+
yield task_id
|
| 347 |
+
task_id += 1
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
# run script
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
if __name__ == "__main__":
|
| 354 |
+
# parse command line arguments
|
| 355 |
+
parser = argparse.ArgumentParser()
|
| 356 |
+
parser.add_argument("--vendor_name", default=None)
|
| 357 |
+
parser.add_argument("--requests_filepath")
|
| 358 |
+
parser.add_argument("--save_filepath", default=None)
|
| 359 |
+
parser.add_argument("--request_url", default=None)
|
| 360 |
+
parser.add_argument("--api_key", default=None)
|
| 361 |
+
parser.add_argument("--max_requests_per_minute", type=int, default=3_000 * 0.5)
|
| 362 |
+
parser.add_argument("--max_tokens_per_minute", type=int, default=250_000 * 0.5)
|
| 363 |
+
parser.add_argument("--token_encoding_name", default="cl100k_base")
|
| 364 |
+
parser.add_argument("--max_attempts", type=int, default=5)
|
| 365 |
+
parser.add_argument("--logging_level", default=logging.INFO)
|
| 366 |
+
|
| 367 |
+
args = parser.parse_args()
|
| 368 |
+
if args.vendor_name=="openai":
|
| 369 |
+
args.api_key=os.getenv("OPENAI_API_KEY")
|
| 370 |
+
args.request_url="https://api.openai.com/v1/chat/completions"
|
| 371 |
+
elif args.vendor_name=="anthropic":
|
| 372 |
+
args.api_key=os.getenv("ANTHROPIC_API_KEY")
|
| 373 |
+
args.request_url="https://api.anthropic.com/v1/messages"
|
| 374 |
+
elif args.vendor_name == "meta" or args.vendor_name == "google" :
|
| 375 |
+
args.api_key = os.getenv("OPENROUTER_API_KEY")
|
| 376 |
+
args.request_url = "https://openrouter.ai/api/v1/chat/completions"
|
| 377 |
+
else:
|
| 378 |
+
print("Error. Invalid Model Input. Exiting")
|
| 379 |
+
# exit()
|
| 380 |
+
if args.save_filepath is None:
|
| 381 |
+
args.save_filepath = args.requests_filepath.replace(".jsonl", "_results.jsonl")
|
| 382 |
+
|
| 383 |
+
# run script
|
| 384 |
+
asyncio.run(
|
| 385 |
+
process_api_requests_from_file(
|
| 386 |
+
vendor_name=args.vendor_name,
|
| 387 |
+
requests_filepath=args.requests_filepath,
|
| 388 |
+
save_filepath=args.save_filepath,
|
| 389 |
+
request_url=args.request_url,
|
| 390 |
+
api_key=args.api_key,
|
| 391 |
+
max_requests_per_minute=float(args.max_requests_per_minute),
|
| 392 |
+
max_tokens_per_minute=float(args.max_tokens_per_minute),
|
| 393 |
+
token_encoding_name=args.token_encoding_name,
|
| 394 |
+
max_attempts=int(args.max_attempts),
|
| 395 |
+
logging_level=int(args.logging_level),
|
| 396 |
+
)
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
"""
|
| 401 |
+
APPENDIX
|
| 402 |
+
|
| 403 |
+
The example requests file at openai-cookbook/examples/data/example_requests_to_parallel_process.jsonl contains 10,000 requests to text-embedding-ada-002.
|
| 404 |
+
|
| 405 |
+
It was generated with the following code:
|
| 406 |
+
|
| 407 |
+
```python
|
| 408 |
+
import json
|
| 409 |
+
|
| 410 |
+
filename = "data/example_requests_to_parallel_process.jsonl"
|
| 411 |
+
n_requests = 10_000
|
| 412 |
+
jobs = [{"model": "text-embedding-ada-002", "input": str(x) + "\n"} for x in range(n_requests)]
|
| 413 |
+
with open(filename, "w") as f:
|
| 414 |
+
for job in jobs:
|
| 415 |
+
json_string = json.dumps(job)
|
| 416 |
+
f.write(json_string + "\n")
|
| 417 |
+
```
|
| 418 |
+
|
| 419 |
+
As with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically).
|
| 420 |
+
"""
|
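
One difference from the parallel scripts above: this sequential version retries a failed request immediately, without the cool-down pause the async loop applies after a rate-limit error. A small sketch of how a backoff could be layered onto the synchronous calls; the helper name and delay values are illustrative, not part of this repo:

```python
import time
import requests

def post_with_backoff(session: requests.Session, url: str, headers: dict, payload: dict,
                      max_attempts: int = 5, base_delay: float = 2.0):
    """POST a request, sleeping exponentially longer after each rate-limit error."""
    response = None
    for attempt in range(max_attempts):
        response = session.post(url=url, headers=headers, json=payload).json()
        if isinstance(response, dict) and "error" in response:
            err = response["error"]
            message = err.get("message", "") if isinstance(err, dict) else str(err)
            if "Rate limit" in message:
                time.sleep(base_delay * (2 ** attempt))  # wait 2 s, 4 s, 8 s, ... before retrying
                continue
        return response
    return response  # last (still failing) response after exhausting attempts
```
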
app.py
ADDED
|
@@ -0,0 +1,727 @@
| 1 |
+
import os
|
| 2 |
+
# https://stackoverflow.com/questions/76175046/how-to-add-prompt-to-langchain-conversationalretrievalchain-chat-over-docs-with
|
| 3 |
+
# again from:
|
| 4 |
+
# https://python.langchain.com/docs/integrations/providers/vectara/vectara_chat
|
| 5 |
+
from langchain.document_loaders import PyPDFDirectoryLoader
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import langchain
|
| 8 |
+
from queue import Queue
|
| 9 |
+
from typing import Any
|
| 10 |
+
from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
|
| 11 |
+
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
| 12 |
+
from langchain.schema import LLMResult
|
| 13 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 14 |
+
from langchain.vectorstores import FAISS
|
| 15 |
+
from langchain.prompts.prompt import PromptTemplate
|
| 16 |
+
from anyio.from_thread import start_blocking_portal #For model callback streaming
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
| 20 |
+
import os
|
| 21 |
+
from dotenv import load_dotenv
|
| 22 |
+
|
| 23 |
+
import streamlit as st
|
| 24 |
+
import json
|
| 25 |
+
from langchain.document_loaders import PyPDFLoader
|
| 26 |
+
from langchain.text_splitter import CharacterTextSplitter
|
| 27 |
+
from langchain.embeddings import OpenAIEmbeddings
|
| 28 |
+
from langchain.chains.question_answering import load_qa_chain
|
| 29 |
+
from langchain.chat_models import ChatOpenAI
|
| 30 |
+
# from langchain.chat_models import ChatAnthropic
|
| 31 |
+
from langchain_anthropic import ChatAnthropic
|
| 32 |
+
from langchain.vectorstores import Chroma
|
| 33 |
+
import chromadb
|
| 34 |
+
|
| 35 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 36 |
+
from langchain.llms import OpenAI
|
| 37 |
+
from langchain.chains import RetrievalQA
|
| 38 |
+
from langchain.document_loaders import TextLoader
|
| 39 |
+
from langchain.document_loaders import DirectoryLoader
|
| 40 |
+
from langchain_community.document_loaders import PyMuPDFLoader
|
| 41 |
+
from langchain.schema import Document
|
| 42 |
+
|
| 43 |
+
from langchain.memory import ConversationBufferMemory
|
| 44 |
+
|
| 45 |
+
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
|
| 46 |
+
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
|
| 47 |
+
import gradio as gr
|
| 48 |
+
from langchain.memory import ConversationBufferMemory
|
| 49 |
+
from langchain.chains import ConversationalRetrievalChain
|
| 50 |
+
from langchain.chains import ConversationChain
|
| 51 |
+
from langchain.prompts import PromptTemplate
|
| 52 |
+
from langchain.chains import LLMChain
|
| 53 |
+
print("Started")
|
| 54 |
+
|
| 55 |
+
# Function to map UI region selection to DB metadata region
|
| 56 |
+
# def get_db_region(ui_region: str) -> str:
|
| 57 |
+
# """Maps UI region selection (e.g., 'Iowa') to the region string stored in metadata (e.g., 'United States')."""
|
| 58 |
+
# if ui_region == "Iowa":
|
| 59 |
+
# return "United States"
|
| 60 |
+
# # Add more mappings if needed (e.g., Africa)
|
| 61 |
+
# return ui_region # Default to using the UI string if no specific mapping
|
| 62 |
+
|
| 63 |
+
def get_species_list_from_db(db_name):
|
| 64 |
+
embedding = OpenAIEmbeddings()
|
| 65 |
+
vectordb_temp = Chroma(persist_directory=db_name,
|
| 66 |
+
embedding_function=embedding)
|
| 67 |
+
species_list=[]
|
| 68 |
+
for meta in vectordb_temp.get()["metadatas"] :
|
| 69 |
+
try:
|
| 70 |
+
matched_first_species = meta['matched_specie_0']
|
| 71 |
+
except KeyError:
|
| 72 |
+
continue
|
| 73 |
+
# Since each document is considered as a single chunk, the chunk_index is 0 for all
|
| 74 |
+
species_list.append( matched_first_species)
|
| 75 |
+
|
| 76 |
+
return species_list
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# default_persist_directory = './db5' # For deployment
|
| 81 |
+
default_persist_directory_insects='./vector-databases-deployed/db5-agllm-data-isu-field-insects-all-species'
|
| 82 |
+
default_persist_directory_weeds='./vector-databases-deployed/db5-agllm-data-isu-field-weeds-all-species'
|
| 83 |
+
|
| 84 |
+
species_list_insects=get_species_list_from_db(default_persist_directory_insects)
|
| 85 |
+
species_list_weeds=get_species_list_from_db(default_persist_directory_weeds)
|
| 86 |
+
# default_persist_directory = 'vector-databases/db5-pre-completion' # For Development
|
| 87 |
+
csv_filepath1 = "./agllm-data/corrected/Corrected_supplemented-insect_data-2500-sorted.xlsx"
|
| 88 |
+
csv_filepath2 = "./agllm-data/corrected/Corrected_supplemented-insect_data-remaining.xlsx"
|
| 89 |
+
|
| 90 |
+
# India data
|
| 91 |
+
cv_india="./agllm-data/india/species.csv"
|
| 92 |
+
model_name=4
|
| 93 |
+
max_tokens=400
|
| 94 |
+
system_message = {"role": "system", "content": "You are a helpful assistant."} # TODO: double check how this plays out later.
|
| 95 |
+
langchain.debug=False # TODO: DOUBLE CHECK
|
| 96 |
+
from langchain import globals
|
| 97 |
+
globals.set_debug(False)
|
| 98 |
+
|
| 99 |
+
retriever_k_value=3
|
| 100 |
+
embedding = OpenAIEmbeddings()
|
| 101 |
+
print("Started....")
|
| 102 |
+
class ChatOpenRouter(ChatOpenAI):
|
| 103 |
+
openai_api_base: str
|
| 104 |
+
openai_api_key: str
|
| 105 |
+
model_name: str
|
| 106 |
+
|
| 107 |
+
def __init__(self,
|
| 108 |
+
model_name: str,
|
| 109 |
+
openai_api_key: Optional[str] = None,
|
| 110 |
+
openai_api_base: str = "https://openrouter.ai/api/v1",
|
| 111 |
+
**kwargs):
|
| 112 |
+
openai_api_key = openai_api_key or os.getenv('OPENROUTER_API_KEY')
|
| 113 |
+
super().__init__(openai_api_base=openai_api_base,
|
| 114 |
+
openai_api_key=openai_api_key,
|
| 115 |
+
model_name=model_name, **kwargs)
|
| 116 |
+
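# Minimal usage sketch (illustrative only, not part of the app's runtime flow;
# assumes OPENROUTER_API_KEY is set in the environment):
#   example_llm = ChatOpenRouter(model_name="meta-llama/llama-3-70b-instruct", temperature=0)
#   example_reply = example_llm.invoke("Give a one-line definition of integrated pest management.")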
|
| 117 |
+
|
| 118 |
+
######### todo: skipping the first step
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# print(# Single example
|
| 123 |
+
# vectordb.as_retriever(k=2, search_kwargs={"filter": {"matched_specie_0": "Hypagyrtis unipunctata"}, 'k':1}).get_relevant_documents(
|
| 124 |
+
# "Checking if retriever is correctly initalized?"
|
| 125 |
+
# ))
|
| 126 |
+
|
| 127 |
+
columns = ['species', 'common name', 'order', 'family',
|
| 128 |
+
'genus', 'Updated role in ecosystem', 'Proof',
|
| 129 |
+
'ipm strategies', 'size of insect', 'geographical spread',
|
| 130 |
+
'life cycle specifics', 'pest for plant species', 'species status',
|
| 131 |
+
'distribution area', 'appearance', 'identification']
|
| 132 |
+
|
| 133 |
+
df1 = pd.read_excel(csv_filepath1, usecols=columns)
|
| 134 |
+
df2 = pd.read_excel(csv_filepath2, usecols=columns)
|
| 135 |
+
df_india = pd.read_csv(cv_india)
|
| 136 |
+
all_insects_data = pd.concat([df1, df2], ignore_index=True)
|
| 137 |
+
|
| 138 |
+
def get_prompt_with_vetted_info_from_specie_name(search_for_specie, mode):
|
| 139 |
+
|
| 140 |
+
def read_and_format_filtered_csv_better(dataframe_given, insect_specie):
|
| 141 |
+
filtered_data = dataframe_given[dataframe_given['species'] == insect_specie]
|
| 142 |
+
formatted_data = ""
|
| 143 |
+
# Format the filtered data
|
| 144 |
+
for index, row in filtered_data.iterrows():
|
| 145 |
+
row_data = [f"{col}: {row[col]}" for col in filtered_data.columns]
|
| 146 |
+
formatted_row = "\n".join(row_data)
|
| 147 |
+
formatted_data += f"{formatted_row}\n"
|
| 148 |
+
|
| 149 |
+
return formatted_data
|
| 150 |
+
|
| 151 |
+
# Use the path to your CSV file here
|
| 152 |
+
|
| 153 |
+
vetted_info=read_and_format_filtered_csv_better(all_insects_data, search_for_specie)
|
| 154 |
+
india_info=read_and_format_filtered_csv_better(df_india, search_for_specie)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
if mode=="Farmer":
|
| 158 |
+
language_constraint="The language should be acustomed to the Farmers. Given question is likely to be asked by a farmer in the field will ask which will help to make decisions which are immediate and practical."
|
| 159 |
+
elif mode=="Researcher":
|
| 160 |
+
language_constraint="The language should be acustomed to a researcher. Given question is likely to be asked by a scientist which are comprehensive and aimed at exploring new knowledge or refining existing methodologies"
|
| 161 |
+
else:
|
| 162 |
+
print("No valid mode provided. Exiting")
|
| 163 |
+
exit()
|
| 164 |
+
# general_system_template = """
|
| 165 |
+
# In every question you are provided information about the insect/weed. The two types of information are: first, Vetted Information (which is the same in every question) and second, some context from external documents about an insect/weed species, along with a question by the user. Answer the question according to these two types of information.
|
| 166 |
+
# ----
|
| 167 |
+
# Vetted info is as follows:
|
| 168 |
+
# {vetted_info}
|
| 169 |
+
# ----
|
| 170 |
+
# The context retrieved for documents about this particular question is as follows:
|
| 171 |
+
# {context}
|
| 172 |
+
# ----
|
| 173 |
+
# Additional Instruction:
|
| 174 |
+
# 1. Reference Constraint
|
| 175 |
+
# At the end of each answer provide the source/reference for the given data in following format:
|
| 176 |
+
# \n\n[enter two new lines before writing below] References:
|
| 177 |
+
# Vetted Information Used: Write what was used from the document for coming up with the answer above. Write exact part of lines. If nothing, write 'Nothing'.
|
| 178 |
+
# Documents Used: Write what was used from the document for coming up with the answer above. If nothing, write 'Nothing'. Write exact part of lines and document used.
|
| 179 |
+
# 2. Information Constraint:
|
| 180 |
+
# Only answer the question from the information provided; otherwise say you don't know. You have to answer in 50 words including references. Prioritize information in documents/context over vetted information, and first mention the warnings/things to be careful about.
|
| 181 |
+
# 3. Language constraint:
|
| 182 |
+
# {language_constraint}
|
| 183 |
+
|
| 184 |
+
# ----
|
| 185 |
+
# """.format(vetted_info=vetted_info, language_constraint=language_constraint,context="{context}", )
|
| 186 |
+
|
| 187 |
+
general_system_template = f"""
|
| 188 |
+
You are an AI assistant specialized in providing information about insects/weeds. Answer the user's question based on the available information or your general knowledge.
|
| 189 |
+
|
| 190 |
+
The context retrieved for this question is as follows:
|
| 191 |
+
{{context}}
|
| 192 |
+
|
| 193 |
+
Instructions:
|
| 194 |
+
1. Evaluate the relevance of the provided context to the question.
|
| 195 |
+
2. If the context contains relevant information, use it to answer the question and explicitly mention "Based on provided information" in your source.
|
| 196 |
+
3. If the context does not contain relevant information, use your general knowledge to answer the question and state "Based on general knowledge" as the source.
|
| 197 |
+
4. Format your response as follows:
|
| 198 |
+
Answer: Provide a concise answer in less than 50 words.
|
| 199 |
+
Source: State either "Based on provided information" or "Based on general knowledge".
|
| 200 |
+
|
| 201 |
+
5. Language constraint:
|
| 202 |
+
{language_constraint}
|
| 203 |
+
6. Other region (India) information:
|
| 204 |
+
{india_info}
|
| 205 |
+
7. So you have two kinds of information (the default from Iowa and the other region (India) information). First ask the user which region they are interested in, and only provide information from that region.
|
| 206 |
+
8. When answering a question, state which region the information is from. So, if a region is selected by the user and specified in the question, then only answer based on that region and say so.
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
Question: {{question}}
|
| 210 |
+
"""
|
| 211 |
+
|
| 212 |
+
general_user_template = "Question:```{question}```"
|
| 213 |
+
messages_formatted = [
|
| 214 |
+
SystemMessagePromptTemplate.from_template(general_system_template),
|
| 215 |
+
HumanMessagePromptTemplate.from_template(general_user_template)
|
| 216 |
+
]
|
| 217 |
+
qa_prompt = ChatPromptTemplate.from_messages( messages_formatted)
|
| 218 |
+
# print(qa_prompt)
|
| 219 |
+
return qa_prompt
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
# qa_prompt=get_prompt_with_vetted_info_from_specie_name("Papaipema nebris", "Researcher")
|
| 226 |
+
# print("First prompt is intialized as: " , qa_prompt, "\n\n")
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
memory = ConversationBufferMemory(memory_key="chat_history",output_key='answer', return_messages=True) # https://github.com/langchain-ai/langchain/issues/9394#issuecomment-1683538834
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
if model_name==4:
|
| 233 |
+
llm_openai = ChatOpenAI(model_name="gpt-4o-2024-08-06" , temperature=0, max_tokens=max_tokens) # TODO: NEW MODEL VERSION AVAILABLE
|
| 234 |
+
else:
|
| 235 |
+
llm_openai = ChatOpenAI(model_name="gpt-3.5-turbo-0125" , temperature=0, max_tokens=max_tokens)
|
| 236 |
+
|
| 237 |
+
specie_selector="Papaipema nebris"
|
| 238 |
+
filter = {
|
| 239 |
+
"$or": [
|
| 240 |
+
{"matched_specie_0": specie_selector},
|
| 241 |
+
{"matched_specie_1": specie_selector},
|
| 242 |
+
{"matched_specie_2": specie_selector},
|
| 243 |
+
]
|
| 244 |
+
}
|
| 245 |
+
# retriever = vectordb.as_retriever(search_kwargs={'k':retriever_k_value, 'filter': filter})
|
| 246 |
+
|
| 247 |
+
# qa_chain = ConversationalRetrievalChain.from_llm(
|
| 248 |
+
# llm_openai, retriever, memory=memory, verbose=False, return_source_documents=True,\
|
| 249 |
+
# combine_docs_chain_kwargs={'prompt': qa_prompt}
|
| 250 |
+
|
| 251 |
+
# )
|
| 252 |
+
#
|
| 253 |
+
|
| 254 |
+
def initialize_qa_chain(specie_selector, application_mode, model_name, region, database_persistent_directory=default_persist_directory_insects, domain_name="Insects"):
|
| 255 |
+
# Add helper function for India info (kept for potential future use, but removed from RAG prompt)
|
| 256 |
+
def read_and_format_filtered_csv_better(dataframe_given, insect_specie):
|
| 257 |
+
filtered_data = dataframe_given[dataframe_given['species'] == insect_specie]
|
| 258 |
+
formatted_data = ""
|
| 259 |
+
# Format the filtered data
|
| 260 |
+
for index, row in filtered_data.iterrows():
|
| 261 |
+
row_data = [f"{col}: {row[col]}" for col in filtered_data.columns]
|
| 262 |
+
formatted_row = "\n".join(row_data)
|
| 263 |
+
formatted_data += f"{formatted_row}\n"
|
| 264 |
+
|
| 265 |
+
return formatted_data
|
| 266 |
+
|
| 267 |
+
# Get India info (potentially useful if not using RAG or for specific logic)
|
| 268 |
+
india_info = read_and_format_filtered_csv_better(df_india, specie_selector)
|
| 269 |
+
# db_region = get_db_region(region) # Map UI region to DB region - REMOVED
|
| 270 |
+
|
| 271 |
+
if model_name=="GPT-4":
|
| 272 |
+
chosen_llm=ChatOpenAI(model_name="gpt-4o-2024-08-06" , temperature=0, max_tokens=max_tokens)
|
| 273 |
+
elif model_name=="GPT-3.5":
|
| 274 |
+
chosen_llm=ChatOpenAI(model_name="gpt-3.5-turbo-0125" , temperature=0, max_tokens=max_tokens)
|
| 275 |
+
elif model_name=="Llama-3 70B":
|
| 276 |
+
chosen_llm = ChatOpenRouter(model_name="meta-llama/llama-3-70b-instruct", temperature=0,max_tokens=max_tokens )
|
| 277 |
+
elif model_name=="Llama-3 8B":
|
| 278 |
+
chosen_llm = ChatOpenRouter(model_name="meta-llama/llama-3-8b-instruct", temperature=0, max_tokens=max_tokens)
|
| 279 |
+
elif model_name=="Gemini-1.5 Pro":
|
| 280 |
+
chosen_llm = ChatOpenRouter(model_name="google/gemini-pro-1.5", temperature=0, max_tokens=max_tokens)
|
| 281 |
+
elif model_name=="Claude 3 Opus":
|
| 282 |
+
chosen_llm = ChatAnthropic(model_name='claude-3-opus-20240229', temperature=0, max_tokens=max_tokens)
|
| 283 |
+
elif model_name=="Claude 3.5 Sonnet":
|
| 284 |
+
chosen_llm = ChatAnthropic(model_name='claude-3-5-sonnet-20240620', temperature=0, max_tokens=max_tokens)
|
| 285 |
+
else:
|
| 286 |
+
print("No appropriate llm was selected")
|
| 287 |
+
exit()
|
| 288 |
+
|
| 289 |
+
if application_mode == "Farmer":
|
| 290 |
+
language_constraint = "The language should be customized for Farmers. The given question is likely to be asked by a farmer in the field and will help to make decisions which are immediate and practical."
|
| 291 |
+
elif application_mode == "Researcher":
|
| 292 |
+
language_constraint = "The language should be customized for a researcher. The given question is likely to be asked by a scientist and should be comprehensive, aimed at exploring new knowledge or refining existing methodologies."
|
| 293 |
+
else:
|
| 294 |
+
print("No valid mode provided. Exiting")
|
| 295 |
+
exit()
|
| 296 |
+
|
| 297 |
+
# RAG is always ON now
|
| 298 |
+
memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
|
| 299 |
+
|
| 300 |
+
# Construct the species filter part
|
| 301 |
+
species_filter = {
|
| 302 |
+
"$or": [
|
| 303 |
+
{"matched_specie_" + str(i): specie_selector} for i in range(11) # Generate dynamically up to 10
|
| 304 |
+
]
|
| 305 |
+
}
|
| 306 |
+
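# The comprehension above expands to a Chroma metadata filter of the form:
#   {"$or": [{"matched_specie_0": specie_selector}, ..., {"matched_specie_10": specie_selector}]}
# i.e. a chunk matches if the species appears in any of the matched_specie_0..10 keys.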
|
| 307 |
+
embedding = OpenAIEmbeddings()
|
| 308 |
+
vectordb = Chroma(persist_directory=database_persistent_directory,
|
| 309 |
+
embedding_function=embedding)
|
| 310 |
+
|
| 311 |
+
# --- Find all available regions for this species --- #
|
| 312 |
+
availability_message = f"Checking region availability for {specie_selector}..."
|
| 313 |
+
available_regions = set()
|
| 314 |
+
|
| 315 |
+
try:
|
| 316 |
+
# Query ChromaDB just for metadata based on species
|
| 317 |
+
species_docs = vectordb.get(where=species_filter, include=['metadatas'])
|
| 318 |
+
if species_docs and species_docs.get('metadatas'):
|
| 319 |
+
for meta in species_docs['metadatas']:
|
| 320 |
+
if 'region' in meta:
|
| 321 |
+
available_regions.add(meta['region'])
|
| 322 |
+
|
| 323 |
+
if available_regions:
|
| 324 |
+
available_regions_list = sorted(list(available_regions))
|
| 325 |
+
availability_message = f"Information for **{specie_selector}** is available in region(s): **{', '.join(available_regions_list)}**."
|
| 326 |
+
else:
|
| 327 |
+
available_regions_list = []
|
| 328 |
+
availability_message = f"No regional information found for **{specie_selector}** in the {domain_name} database."
|
| 329 |
+
except Exception as e:
|
| 330 |
+
print(f"Error checking region availability: {e}")
|
| 331 |
+
available_regions_list = []
|
| 332 |
+
availability_message = f"Could not determine region availability for {specie_selector}."
|
| 333 |
+
|
| 334 |
+
# --- Prepare context sections by region --- #
|
| 335 |
+
# Dictionary to hold context documents for each region
|
| 336 |
+
# region_contexts = {} # Unused variable, removing
|
| 337 |
+
|
| 338 |
+
# First check if selected region has information
|
| 339 |
+
selected_region_has_info = region in available_regions
|
| 340 |
+
|
| 341 |
+
# Create list of other available regions (excluding selected region)
|
| 342 |
+
other_regions = [r for r in available_regions_list if r != region]
|
| 343 |
+
|
| 344 |
+
# --- Create multi-region retrieval chain --- #
|
| 345 |
+
class MultiRegionRetriever:
|
| 346 |
+
def __init__(self, vectordb, species_filter, selected_region, other_regions, k=3):
|
| 347 |
+
self.vectordb = vectordb
|
| 348 |
+
self.species_filter = species_filter
|
| 349 |
+
self.selected_region = selected_region
|
| 350 |
+
self.other_regions = other_regions
|
| 351 |
+
self.k = k
|
| 352 |
+
|
| 353 |
+
def get_relevant_documents(self, query):
|
| 354 |
+
all_docs = []
|
| 355 |
+
region_docs = {}
|
| 356 |
+
|
| 357 |
+
# First get documents for selected region
|
| 358 |
+
# Fix illogical condition: self.selected_region == self.selected_region is always True
|
| 359 |
+
# Replace with a check if selected_region exists
|
| 360 |
+
if self.selected_region:
|
| 361 |
+
selected_filter = {"$and": [self.species_filter, {"region": self.selected_region}]}
|
| 362 |
+
selected_retriever = self.vectordb.as_retriever(search_kwargs={'k': self.k, 'filter': selected_filter})
|
| 363 |
+
try:
|
| 364 |
+
selected_docs = selected_retriever.get_relevant_documents(query)
|
| 365 |
+
if selected_docs:
|
| 366 |
+
all_docs.extend(selected_docs)
|
| 367 |
+
region_docs[self.selected_region] = selected_docs
|
| 368 |
+
except Exception as e:
|
| 369 |
+
print(f"Error retrieving docs for selected region {self.selected_region}: {e}")
|
| 370 |
+
|
| 371 |
+
# Then get documents for each other region
|
| 372 |
+
for other_region in self.other_regions:
|
| 373 |
+
if other_region != self.selected_region: # Skip if same as selected region
|
| 374 |
+
other_filter = {"$and": [self.species_filter, {"region": other_region}]}
|
| 375 |
+
other_retriever = self.vectordb.as_retriever(search_kwargs={'k': self.k, 'filter': other_filter})
|
| 376 |
+
try:
|
| 377 |
+
other_docs = other_retriever.get_relevant_documents(query)
|
| 378 |
+
if other_docs:
|
| 379 |
+
all_docs.extend(other_docs)
|
| 380 |
+
region_docs[other_region] = other_docs
|
| 381 |
+
except Exception as e:
|
| 382 |
+
print(f"Error retrieving docs for region {other_region}: {e}")
|
| 383 |
+
|
| 384 |
+
# Store the region-specific documents for formatting in the prompt
|
| 385 |
+
self.last_region_docs = region_docs
|
| 386 |
+
return all_docs
|
| 387 |
+
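# Note: MultiRegionRetriever is a plain duck-typed helper rather than a LangChain
# BaseRetriever subclass; it only needs to expose get_relevant_documents(), and it
# stashes the per-region hits in last_region_docs for the prompt formatting below.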
|
| 388 |
+
# Initialize the multi-region retriever
|
| 389 |
+
multi_region_retriever = MultiRegionRetriever(
|
| 390 |
+
vectordb=vectordb,
|
| 391 |
+
species_filter=species_filter,
|
| 392 |
+
selected_region=region,
|
| 393 |
+
other_regions=available_regions_list,
|
| 394 |
+
k=retriever_k_value
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
# Custom prompt handler that formats context by region
|
| 398 |
+
# Remove unused imports
|
| 399 |
+
# from langchain.chains.combine_documents import create_stuff_documents_chain
|
| 400 |
+
# from langchain.chains import create_retrieval_chain
|
| 401 |
+
|
| 402 |
+
# Updated prompt template for multi-part response with region-specific contexts
|
| 403 |
+
general_system_template = f"""
|
| 404 |
+
You are an AI assistant specialized in providing information about {domain_name.lower()} ({specie_selector}). The user is primarily interested in the '{region}' region.
|
| 405 |
+
|
| 406 |
+
The following context has been retrieved from a database organized by region:
|
| 407 |
+
|
| 408 |
+
{{context}}
|
| 409 |
+
|
| 410 |
+
Instructions:
|
| 411 |
+
1. Analyze the user's question in relation to {specie_selector}.
|
| 412 |
+
2. Structure your answer in the following multi-part format:
|
| 413 |
+
|
| 414 |
+
**Part 1: Selected Region Information ({region})**
|
| 415 |
+
If relevant information exists in the context for the selected region that answers the user's query:
|
| 416 |
+
Based on your selected region ({region}), for {specie_selector}, [summary of information for selected region] [1].
|
| 417 |
+
|
| 418 |
+
If no relevant information exists for the selected region:
|
| 419 |
+
"Based on the provided documents, there is no specific information for {specie_selector} in your selected region ({region}) regarding your question."
|
| 420 |
+
|
| 421 |
+
**Part 2: Other Regions Information** (Only include if information from other regions is available AND relevant to the query)
|
| 422 |
+
If you found relevant information from other regions that answers the user's query, include:
|
| 423 |
+
|
| 424 |
+
Additionally, information was found for other regions:
|
| 425 |
+
- In [Other Region Name]: [summary of information that directly answers the user's query] [next reference number].
|
| 426 |
+
- In [Another Region Name]: [summary of information that directly answers the user's query] [next reference number].
|
| 427 |
+
|
| 428 |
+
Only include regions where the information directly addresses the user's question.
|
| 429 |
+
Use consecutive reference numbers starting from where Part 1 left off.
|
| 430 |
+
If no other regions have relevant information, omit this part entirely.
|
| 431 |
+
|
| 432 |
+
**Part 3: General Knowledge** (Only include if context information is insufficient or incomplete)
|
| 433 |
+
If the available context does not fully address the query, add:
|
| 434 |
+
|
| 435 |
+
Based on my general knowledge as {model_name}: [Your general knowledge insights that directly address the query] [next reference number].
|
| 436 |
+
|
| 437 |
+
If the context information is sufficient, omit this part entirely.
|
| 438 |
+
|
| 439 |
+
3. After providing all parts of your answer, include a References section ONLY for information you actually used:
|
| 440 |
+
|
| 441 |
+
References:
|
| 442 |
+
[1] Based on Expert Curated information about {specie_selector} in {region}
|
| 443 |
+
[2] Based on Expert Curated information about {specie_selector} in [Other Region Name]
|
| 444 |
+
[3] Based on Expert Curated information about {specie_selector} in [Another Region Name]
|
| 445 |
+
[x] {model_name}'s inherent knowledge
|
| 446 |
+
|
| 447 |
+
IMPORTANT:
|
| 448 |
+
- Only include reference numbers that correspond to information you actually used in your answer.
|
| 449 |
+
- Reference numbers should be sequential (1, 2, 3...) based on the order they appear in your answer.
|
| 450 |
+
- If you don't use information from a particular region, don't include a reference for it.
|
| 451 |
+
- If you don't use general knowledge, don't include a reference for it.
|
| 452 |
+
- Every claim with a reference marker [x] must have a corresponding entry in the References section.
|
| 453 |
+
|
| 454 |
+
4. Apply this language constraint: {language_constraint}
|
| 455 |
+
5. Keep your summaries concise and directly related to the user's question.
|
| 456 |
+
|
| 457 |
+
User Question about {specie_selector} ({domain_name}): {{question}}
|
| 458 |
+
"""
|
| 459 |
+
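# Note: the doubled braces ({{context}}, {{question}}) survive the f-string as single
# braces, so {context} and {question} remain as placeholders that are later filled by
# str.format() inside RegionFormattingLLMChain.__call__ below.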
|
| 460 |
+
class RegionFormattingLLMChain:
|
| 461 |
+
def __init__(self, llm, prompt, retriever):
|
| 462 |
+
self.llm = llm
|
| 463 |
+
self.prompt = prompt
|
| 464 |
+
self.retriever = retriever
|
| 465 |
+
|
| 466 |
+
def __call__(self, inputs):
|
| 467 |
+
# Get documents using the multi-region retriever
|
| 468 |
+
docs = self.retriever.get_relevant_documents(inputs["question"])
|
| 469 |
+
|
| 470 |
+
# Get the region-specific document organization
|
| 471 |
+
region_docs = getattr(self.retriever, "last_region_docs", {})
|
| 472 |
+
|
| 473 |
+
# Format context with clear region sections
|
| 474 |
+
formatted_context = ""
|
| 475 |
+
|
| 476 |
+
# First add context for selected region if available
|
| 477 |
+
if region in region_docs:
|
| 478 |
+
formatted_context += f"--- CONTEXT FROM SELECTED REGION: {region} ---\n"
|
| 479 |
+
for i, doc in enumerate(region_docs[region]):
|
| 480 |
+
formatted_context += f"Document {i+1} from {region}:\n{doc.page_content}\n\n"
|
| 481 |
+
|
| 482 |
+
# Then add context for each other region
|
| 483 |
+
for other_region in [r for r in region_docs.keys() if r != region]:
|
| 484 |
+
formatted_context += f"--- CONTEXT FROM OTHER REGION: {other_region} ---\n"
|
| 485 |
+
for i, doc in enumerate(region_docs[other_region]):
|
| 486 |
+
formatted_context += f"Document {i+1} from {other_region}:\n{doc.page_content}\n\n"
|
| 487 |
+
|
| 488 |
+
# Replace the context placeholder with our formatted context
|
| 489 |
+
formatted_prompt = self.prompt.format(
|
| 490 |
+
context=formatted_context,
|
| 491 |
+
question=inputs["question"]
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
# Call the LLM with our formatted prompt
|
| 495 |
+
result = self.llm.invoke(formatted_prompt)
|
| 496 |
+
|
| 497 |
+
# Return the result in the expected format
|
| 498 |
+
return {"answer": result.content, "source_documents": docs}
|
| 499 |
+
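# Note: this custom chain builds its prompt from the current question and the
# region-grouped context only; the chat_history passed in by bot() later in the
# Gradio block is not injected into the prompt, so each answer is effectively single-turn.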
|
| 500 |
+
# Create the custom chain
|
| 501 |
+
qa_chain = RegionFormattingLLMChain(
|
| 502 |
+
llm=chosen_llm,
|
| 503 |
+
prompt=general_system_template,
|
| 504 |
+
retriever=multi_region_retriever
|
| 505 |
+
)
|
| 506 |
+
|
| 507 |
+
return qa_chain, availability_message
|
| 508 |
+
# result = qa_chain.invoke({"question": "where are stalk borer eggs laid?"})
|
| 509 |
+
# print("Got the first LLM task working: ", result)
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
#Application Interface:
|
| 513 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 514 |
+
with gr.Row():
|
| 515 |
+
with gr.Column(scale=1):
|
| 516 |
+
gr.Markdown(
|
| 517 |
+
"""
|
| 518 |
+

|
| 519 |
+
"""
|
| 520 |
+
)
|
| 521 |
+
with gr.Column(scale=1):
|
| 522 |
+
gr.Markdown(
|
| 523 |
+
"""
|
| 524 |
+

|
| 525 |
+
"""
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
# Configure UI layout
|
| 529 |
+
chatbot = gr.Chatbot(height=600, label="AgLLM")
|
| 530 |
+
with gr.Row():
|
| 531 |
+
with gr.Column(scale=1):
|
| 532 |
+
with gr.Row():
|
| 533 |
+
domain_name = gr.Dropdown(
|
| 534 |
+
list(["Insects", "Weeds"]),
|
| 535 |
+
value="Insects",
|
| 536 |
+
label="Domain",
|
| 537 |
+
info="Select Domain",
|
| 538 |
+
interactive=True,
|
| 539 |
+
scale=1,
|
| 540 |
+
visible=True
|
| 541 |
+
)
|
| 542 |
+
region_selector = gr.Dropdown(
|
| 543 |
+
list(["United States", "India", "Africa"]), # Updated regions
|
| 544 |
+
value="United States", # Updated default
|
| 545 |
+
label="Region",
|
| 546 |
+
info="Select the Region",
|
| 547 |
+
interactive=True,
|
| 548 |
+
scale=1,
|
| 549 |
+
visible=True
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
# Model selection
|
| 553 |
+
specie_selector = gr.Dropdown(
|
| 554 |
+
list(set(species_list_insects)),
|
| 555 |
+
value=species_list_insects[0],
|
| 556 |
+
label="Species",
|
| 557 |
+
info="Select the Species",
|
| 558 |
+
interactive=True,
|
| 559 |
+
scale=1,
|
| 560 |
+
visible=True
|
| 561 |
+
)
|
| 562 |
+
with gr.Row():
|
| 563 |
+
model_name = gr.Dropdown(
|
| 564 |
+
list(["GPT-4", "GPT-3.5", "Llama-3 70B", "Llama-3 8B", "Gemini-1.5 Pro", "Claude 3 Opus", "Claude 3.5 Sonnet"]),
|
| 565 |
+
value="Llama-3 70B",
|
| 566 |
+
label="LLM",
|
| 567 |
+
info="Select the LLM",
|
| 568 |
+
interactive=True,
|
| 569 |
+
scale=1,
|
| 570 |
+
visible=True
|
| 571 |
+
)
|
| 572 |
+
application_mode = gr.Dropdown(
|
| 573 |
+
list(["Farmer", "Researcher"]),
|
| 574 |
+
value="Researcher",
|
| 575 |
+
label="Mode",
|
| 576 |
+
info="Select the Mode",
|
| 577 |
+
interactive=True,
|
| 578 |
+
scale=1,
|
| 579 |
+
visible=True
|
| 580 |
+
)
|
| 581 |
+
region_availability_display = gr.Markdown(value="Select species/domain to see region availability.") # Added display area
|
| 582 |
+
|
| 583 |
+
with gr.Column(scale=2):
|
| 584 |
+
# User input prompt text field
|
| 585 |
+
user_prompt_message = gr.Textbox(placeholder="Please add user prompt here", label="User prompt")
|
| 586 |
+
with gr.Row():
|
| 587 |
+
# clear = gr.Button("Clear Conversation", scale=2)
|
| 588 |
+
submitBtn = gr.Button("Submit", scale=8)
|
| 589 |
+
|
| 590 |
+
state = gr.State([])
|
| 591 |
+
qa_chain_state = gr.State(value=None)
|
| 592 |
+
|
| 593 |
+
# Handle user message
|
| 594 |
+
def user(user_prompt_message, history):
|
| 595 |
+
# print("HISTORY IS: ", history) # TODO: REMOVE IT LATER
|
| 596 |
+
if user_prompt_message != "":
|
| 597 |
+
return history + [[user_prompt_message, None]]
|
| 598 |
+
else:
|
| 599 |
+
return history + [["Invalid prompts - user prompt cannot be empty", None]]
|
| 600 |
+
|
| 601 |
+
# Chatbot logic for configuration, sending the prompts, rendering the streamed back generations, etc.
|
| 602 |
+
def bot(model_name, application_mode, user_prompt_message, history, messages_history, qa_chain, domain_name, region): # Removed use_rag
|
| 603 |
+
if qa_chain is None:
|
| 604 |
+
# Initial QA chain setup if not already done (uses default species for the selected domain)
|
| 605 |
+
initial_species = species_list_insects[0] if domain_name == "Insects" else species_list_weeds[0]
|
| 606 |
+
# Need to handle the tuple returned by init_qa_chain now
|
| 607 |
+
# Use the currently selected region for initialization if qa_chain is None
|
| 608 |
+
qa_chain, _ = init_qa_chain(initial_species, application_mode, model_name, domain_name, region) # Pass region
|
| 609 |
+
|
| 610 |
+
history[-1][1] = "" # Placeholder for the answer
|
| 611 |
+
|
| 612 |
+
# RAG is always ON now
|
| 613 |
+
result = qa_chain({"question": user_prompt_message, "chat_history": messages_history})
|
| 614 |
+
answer = result["answer"]
|
| 615 |
+
# source_documents = result.get("source_documents", []) # Keep source_documents if needed for debugging or future refinement
|
| 616 |
+
# formatted_response = format_response_with_source(answer, source_documents, domain_name, region) # REMOVED: Rely on LLM prompt now
|
| 617 |
+
|
| 618 |
+
history[-1][1] = answer # Assign raw LLM answer directly
|
| 619 |
+
return [history, messages_history]
|
| 620 |
+
|
| 621 |
+
# Helper function to format the response with source information
|
| 622 |
+
# def format_response_with_source(answer, source_documents, domain_name, region): # Pass region # FUNCTION NO LONGER USED
|
| 623 |
+
# try:
|
| 624 |
+
# answer_start = answer.find("Answer:")
|
| 625 |
+
# source_start = answer.find("Source:")
|
| 626 |
+
# ... (rest of the function commented out or removed) ...
|
| 627 |
+
# except Exception as e:
|
| 628 |
+
# print(f"Error parsing output or formatting source: {e}")
|
| 629 |
+
# formatted_response = answer # Return raw answer on error
|
| 630 |
+
#
|
| 631 |
+
# return formatted_response
|
| 632 |
+
|
| 633 |
+
# Initialize the chat history with default system message
|
| 634 |
+
def init_history(messages_history):
|
| 635 |
+
messages_history = []
|
| 636 |
+
messages_history += [system_message]
|
| 637 |
+
return messages_history
|
| 638 |
+
|
| 639 |
+
# Clean up the user input text field
|
| 640 |
+
def input_cleanup():
|
| 641 |
+
return ""
|
| 642 |
+
|
| 643 |
+
def init_qa_chain(specie_selector, application_mode, model_name, domain_name, region): # Removed use_rag
|
| 644 |
+
print(f"--- init_qa_chain wrapper called with domain: '{domain_name}' ---") # DIAGNOSTIC PRINT
|
| 645 |
+
qa_chain_instance = None
|
| 646 |
+
availability_msg = "Error initializing QA chain."
|
| 647 |
+
try:
|
| 648 |
+
if domain_name=="Insects":
|
| 649 |
+
qa_chain_instance, availability_msg = initialize_qa_chain(specie_selector, application_mode, model_name, region, default_persist_directory_insects, domain_name) # Removed use_rag
|
| 650 |
+
elif domain_name=="Weeds":
|
| 651 |
+
qa_chain_instance, availability_msg = initialize_qa_chain(specie_selector, application_mode, model_name, region, default_persist_directory_weeds, domain_name) # Removed use_rag
|
| 652 |
+
else:
|
| 653 |
+
print("No Appropriate Chain Selected")
|
| 654 |
+
availability_msg = "Invalid domain selected."
|
| 655 |
+
# Return None for chain and the message
|
| 656 |
+
except Exception as e:
|
| 657 |
+
print(f"Error in init_qa_chain wrapper: {e}")
|
| 658 |
+
availability_msg = f"Error initializing: {e}"
|
| 659 |
+
|
| 660 |
+
return qa_chain_instance, availability_msg # Return both chain and message
|
| 661 |
+
|
| 662 |
+
# Update QA chain AND availability message when relevant inputs change
|
| 663 |
+
inputs_for_qa_chain = [specie_selector, application_mode, model_name, domain_name, region_selector] # CORRECT ORDER
|
| 664 |
+
outputs_for_qa_chain = [qa_chain_state, region_availability_display]
|
| 665 |
+
|
| 666 |
+
# specie_selector.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
|
| 667 |
+
# model_name.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
|
| 668 |
+
# region_selector.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
|
| 669 |
+
# application_mode.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
|
| 670 |
+
# domain_name.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
|
| 671 |
+
specie_selector.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
|
| 672 |
+
model_name.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
|
| 673 |
+
region_selector.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
|
| 674 |
+
application_mode.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
|
| 675 |
+
domain_name.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
|
| 676 |
+
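# Note: every change to species, model, region, mode, or domain re-runs init_qa_chain,
# which rebuilds the retriever and prompt and refreshes the region-availability text;
# domain changes additionally swap the species dropdown via update_species_list below.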
|
| 677 |
+
#####
|
| 678 |
+
def update_species_list(domain):
|
| 679 |
+
if domain == "Insects":
|
| 680 |
+
return gr.Dropdown( species_list_insects, value=species_list_insects[0], label="Species", info="Select the Species", interactive=True, scale=1, visible=True )
|
| 681 |
+
elif domain == "Weeds":
|
| 682 |
+
return gr.Dropdown( species_list_weeds, value=species_list_weeds[0], label="Species", info="Select the Species", interactive=True, scale=1, visible=True )
|
| 683 |
+
|
| 684 |
+
domain_name.change(
|
| 685 |
+
update_species_list,
|
| 686 |
+
inputs=[domain_name],
|
| 687 |
+
outputs=[specie_selector]
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
# When the user clicks Enter and the user message is submitted
|
| 691 |
+
user_prompt_message.submit(
|
| 692 |
+
user,
|
| 693 |
+
[user_prompt_message, chatbot],
|
| 694 |
+
[chatbot],
|
| 695 |
+
queue=False
|
| 696 |
+
).then(
|
| 697 |
+
bot,
|
| 698 |
+
[model_name, application_mode, user_prompt_message, chatbot, state, qa_chain_state, domain_name, region_selector], # Removed use_rag
|
| 699 |
+
[chatbot, state]
|
| 700 |
+
).then(input_cleanup,
|
| 701 |
+
[],
|
| 702 |
+
[user_prompt_message],
|
| 703 |
+
queue=False
|
| 704 |
+
)
|
| 705 |
+
|
| 706 |
+
# When the user clicks the submit button
|
| 707 |
+
submitBtn.click(
|
| 708 |
+
user,
|
| 709 |
+
[user_prompt_message, chatbot],
|
| 710 |
+
[chatbot],
|
| 711 |
+
queue=False
|
| 712 |
+
).then(
|
| 713 |
+
bot,
|
| 714 |
+
[model_name, application_mode, user_prompt_message, chatbot, state, qa_chain_state, domain_name, region_selector], # Removed use_rag
|
| 715 |
+
[chatbot, state]
|
| 716 |
+
).then(
|
| 717 |
+
input_cleanup,
|
| 718 |
+
[],
|
| 719 |
+
[user_prompt_message],
|
| 720 |
+
queue=False
|
| 721 |
+
)
|
| 722 |
+
|
| 723 |
+
# When the user clicks the clear button
|
| 724 |
+
# clear.click(lambda: None, None, chatbot, queue=False).success(init_history, [state], [state])
|
| 725 |
+
if __name__ == "__main__":
|
| 726 |
+
# demo.launch()
|
| 727 |
+
demo.queue().launch(allowed_paths=["/"], share=False, show_error=True)
|
app_backup.py
ADDED
|
@@ -0,0 +1,490 @@
| 1 |
+
import os
|
| 2 |
+
# https://stackoverflow.com/questions/76175046/how-to-add-prompt-to-langchain-conversationalretrievalchain-chat-over-docs-with
|
| 3 |
+
# again from:
|
| 4 |
+
# https://python.langchain.com/docs/integrations/providers/vectara/vectara_chat
|
| 5 |
+
from langchain.document_loaders import PyPDFDirectoryLoader
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import langchain
|
| 8 |
+
from queue import Queue
|
| 9 |
+
from typing import Any
|
| 10 |
+
from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
|
| 11 |
+
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
| 12 |
+
from langchain.schema import LLMResult
|
| 13 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 14 |
+
from langchain.vectorstores import FAISS
|
| 15 |
+
from langchain.prompts.prompt import PromptTemplate
|
| 16 |
+
from anyio.from_thread import start_blocking_portal #For model callback streaming
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
| 20 |
+
import os
|
| 21 |
+
from dotenv import load_dotenv
|
| 22 |
+
|
| 23 |
+
import streamlit as st
|
| 24 |
+
import json
|
| 25 |
+
from langchain.document_loaders import PyPDFLoader
|
| 26 |
+
from langchain.text_splitter import CharacterTextSplitter
|
| 27 |
+
from langchain.embeddings import OpenAIEmbeddings
|
| 28 |
+
from langchain.chains.question_answering import load_qa_chain
|
| 29 |
+
from langchain.chat_models import ChatOpenAI
|
| 30 |
+
# from langchain.chat_models import ChatAnthropic
|
| 31 |
+
from langchain_anthropic import ChatAnthropic
|
| 32 |
+
from langchain.vectorstores import Chroma
|
| 33 |
+
import chromadb
|
| 34 |
+
|
| 35 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 36 |
+
from langchain.llms import OpenAI
|
| 37 |
+
from langchain.chains import RetrievalQA
|
| 38 |
+
from langchain.document_loaders import TextLoader
|
| 39 |
+
from langchain.document_loaders import DirectoryLoader
|
| 40 |
+
from langchain_community.document_loaders import PyMuPDFLoader
|
| 41 |
+
from langchain.schema import Document
|
| 42 |
+
|
| 43 |
+
from langchain.memory import ConversationBufferMemory
|
| 44 |
+
|
| 45 |
+
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
|
| 46 |
+
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
|
| 47 |
+
import gradio as gr
|
| 48 |
+
from langchain.memory import ConversationBufferMemory
|
| 49 |
+
from langchain.chains import ConversationalRetrievalChain
|
| 50 |
+
print("Started")
|
| 51 |
+
|
| 52 |
+
def get_species_list_from_db(db_name):
|
| 53 |
+
embedding = OpenAIEmbeddings()
|
| 54 |
+
vectordb_temp = Chroma(persist_directory=db_name,
|
| 55 |
+
embedding_function=embedding)
|
| 56 |
+
species_list=[]
|
| 57 |
+
for meta in vectordb_temp.get()["metadatas"] :
|
| 58 |
+
try:
|
| 59 |
+
matched_first_species = meta['matched_specie_0']
|
| 60 |
+
except KeyError:
|
| 61 |
+
continue
|
| 62 |
+
# Since each document is considered as a single chunk, the chunk_index is 0 for all
|
| 63 |
+
species_list.append( matched_first_species)
|
| 64 |
+
|
| 65 |
+
return species_list
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# default_persist_directory = './db5' # For deployment
|
| 70 |
+
default_persist_directory_insects='./vector-databases-deployed/db5-agllm-data-isu-field-insects-all-species'
|
| 71 |
+
default_persist_directory_weeds='./vector-databases-deployed/db5-agllm-data-isu-field-weeds-all-species'
|
| 72 |
+
|
| 73 |
+
species_list_insects=get_species_list_from_db(default_persist_directory_insects)
|
| 74 |
+
species_list_weeds=get_species_list_from_db(default_persist_directory_weeds)
|
| 75 |
+
# default_persist_directory = 'vector-databases/db5-pre-completion' # For Development
|
| 76 |
+
csv_filepath1 = "./agllm-data/corrected/Corrected_supplemented-insect_data-2500-sorted.xlsx"
|
| 77 |
+
csv_filepath2 = "./agllm-data/corrected/Corrected_supplemented-insect_data-remaining.xlsx"
|
| 78 |
+
model_name=4
|
| 79 |
+
max_tokens=400
|
| 80 |
+
system_message = {"role": "system", "content": "You are a helpful assistant."} # TODO: double check how this plays out later.
|
| 81 |
+
langchain.debug=False # TODO: DOUBLE CHECK
|
| 82 |
+
from langchain import globals
|
| 83 |
+
globals.set_debug(False)
|
| 84 |
+
|
| 85 |
+
retriever_k_value=3
|
| 86 |
+
embedding = OpenAIEmbeddings()
|
| 87 |
+
print("Started....")
|
| 88 |
+
class ChatOpenRouter(ChatOpenAI):
|
| 89 |
+
openai_api_base: str
|
| 90 |
+
openai_api_key: str
|
| 91 |
+
model_name: str
|
| 92 |
+
|
| 93 |
+
def __init__(self,
|
| 94 |
+
model_name: str,
|
| 95 |
+
openai_api_key: [str] = None,
|
| 96 |
+
openai_api_base: str = "https://openrouter.ai/api/v1",
|
| 97 |
+
**kwargs):
|
| 98 |
+
openai_api_key = openai_api_key or os.getenv('OPENROUTER_API_KEY')
|
| 99 |
+
super().__init__(openai_api_base=openai_api_base,
|
| 100 |
+
openai_api_key=openai_api_key,
|
| 101 |
+
model_name=model_name, **kwargs)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
######### todo: skipping the first step
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# print(# Single example
|
| 109 |
+
# vectordb.as_retriever(k=2, search_kwargs={"filter": {"matched_specie_0": "Hypagyrtis unipunctata"}, 'k':1}).get_relevant_documents(
|
| 110 |
+
# "Checking if retriever is correctly initalized?"
|
| 111 |
+
# ))
|
| 112 |
+
|
| 113 |
+
columns = ['species', 'common name', 'order', 'family',
|
| 114 |
+
'genus', 'Updated role in ecosystem', 'Proof',
|
| 115 |
+
'ipm strategies', 'size of insect', 'geographical spread',
|
| 116 |
+
'life cycle specifics', 'pest for plant species', 'species status',
|
| 117 |
+
'distribution area', 'appearance', 'identification']
|
| 118 |
+
|
| 119 |
+
df1 = pd.read_excel(csv_filepath1, usecols=columns)
|
| 120 |
+
df2 = pd.read_excel(csv_filepath2, usecols=columns)
|
| 121 |
+
|
| 122 |
+
all_insects_data = pd.concat([df1, df2], ignore_index=True)
|
| 123 |
+
|
| 124 |
+
def get_prompt_with_vetted_info_from_specie_name(search_for_specie, mode):
|
| 125 |
+
|
| 126 |
+
def read_and_format_filtered_csv_better(insect_specie):
|
| 127 |
+
filtered_data = all_insects_data[all_insects_data['species'] == insect_specie]
|
| 128 |
+
formatted_data = ""
|
| 129 |
+
# Format the filtered data
|
| 130 |
+
for index, row in filtered_data.iterrows():
|
| 131 |
+
row_data = [f"{col}: {row[col]}" for col in filtered_data.columns]
|
| 132 |
+
formatted_row = "\n".join(row_data)
|
| 133 |
+
formatted_data += f"{formatted_row}\n"
|
| 134 |
+
|
| 135 |
+
return formatted_data
|
| 136 |
+
|
| 137 |
+
# Use the path to your CSV file here
|
| 138 |
+
|
| 139 |
+
vetted_info=read_and_format_filtered_csv_better(search_for_specie)
|
| 140 |
+
|
| 141 |
+
if mode=="Farmer":
|
| 142 |
+
language_constraint="The language should be acustomed to the Farmers. Given question is likely to be asked by a farmer in the field will ask which will help to make decisions which are immediate and practical."
|
| 143 |
+
elif mode=="Researcher":
|
| 144 |
+
language_constraint="The language should be acustomed to a researcher. Given question is likely to be asked by a scientist which are comprehensive and aimed at exploring new knowledge or refining existing methodologies"
|
| 145 |
+
else:
|
| 146 |
+
print("No valid mode provided. Exiting")
|
| 147 |
+
exit()
|
| 148 |
+
# general_system_template = """
|
| 149 |
+
# In every question you are provided information about the insect/weed. The two types of information are: first, Vetted Information (which is the same in every question) and second, some context from external documents about an insect/weed species, along with a question by the user. Answer the question according to these two types of information.
|
| 150 |
+
# ----
|
| 151 |
+
# Vetted info is as follows:
|
| 152 |
+
# {vetted_info}
|
| 153 |
+
# ----
|
| 154 |
+
# The context retrieved for documents about this particular question is as follows:
|
| 155 |
+
# {context}
|
| 156 |
+
# ----
|
| 157 |
+
# Additional Instruction:
|
| 158 |
+
# 1. Reference Constraint
|
| 159 |
+
# At the end of each answer provide the source/reference for the given data in following format:
|
| 160 |
+
# \n\n[enter two new lines before writing below] References:
|
| 161 |
+
# Vetted Information Used: Write what was used from the document for coming up with the answer above. Write exact part of lines. If nothing, write 'Nothing'.
|
| 162 |
+
# Documents Used: Write what was used from the document for coming up with the answer above. If nothing, write 'Nothing'. Write exact part of lines and document used.
|
| 163 |
+
# 2. Information Constraint:
|
| 164 |
+
# Only answer the question from the information provided; otherwise say you don't know. You have to answer in 50 words including references. Prioritize information in documents/context over vetted information, and first mention the warnings/things to be careful about.
|
| 165 |
+
# 3. Language constraint:
|
| 166 |
+
# {language_constraint}
|
| 167 |
+
|
| 168 |
+
# ----
|
| 169 |
+
# """.format(vetted_info=vetted_info, language_constraint=language_constraint,context="{context}", )
|
| 170 |
+
|
| 171 |
+
general_system_template = f"""
|
| 172 |
+
You are an AI assistant specialized in providing information about insects/weeds. Answer the user's question based on the available information or your general knowledge.
|
| 173 |
+
|
| 174 |
+
The context retrieved for this question is as follows:
|
| 175 |
+
{{context}}
|
| 176 |
+
|
| 177 |
+
Instructions:
|
| 178 |
+
1. Evaluate the relevance of the provided context to the question.
|
| 179 |
+
2. If the context contains relevant information, use it to answer the question.
|
| 180 |
+
3. If the context does not contain relevant information, use your general knowledge to answer the question.
|
| 181 |
+
4. Format your response as follows:
|
| 182 |
+
Answer: Provide a concise answer in less than 50 words.
|
| 183 |
+
Reference: If you used the provided context, cite the specific information used. If you used your general knowledge, state "Based on general knowledge".
|
| 184 |
+
|
| 185 |
+
5. Language constraint:
|
| 186 |
+
{language_constraint}
|
| 187 |
+
|
| 188 |
+
Question: {{question}}
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
general_user_template = "Question:```{question}```"
|
| 192 |
+
messages_formatted = [
|
| 193 |
+
SystemMessagePromptTemplate.from_template(general_system_template),
|
| 194 |
+
HumanMessagePromptTemplate.from_template(general_user_template)
|
| 195 |
+
]
|
| 196 |
+
qa_prompt = ChatPromptTemplate.from_messages( messages_formatted )
|
| 197 |
+
# print(qa_prompt)
|
| 198 |
+
return qa_prompt
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
qa_prompt=get_prompt_with_vetted_info_from_specie_name("Papaipema nebris", "Researcher")
|
| 205 |
+
# print("First prompt is intialized as: " , qa_prompt, "\n\n")
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
memory = ConversationBufferMemory(memory_key="chat_history",output_key='answer', return_messages=True) # https://github.com/langchain-ai/langchain/issues/9394#issuecomment-1683538834
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
if model_name==4:
|
| 212 |
+
llm_openai = ChatOpenAI(model_name="gpt-4-1106-preview" , temperature=0, max_tokens=max_tokens) # TODO: NEW MODEL VERSION AVAILABLE
|
| 213 |
+
else:
|
| 214 |
+
llm_openai = ChatOpenAI(model_name="gpt-3.5-turbo-0125" , temperature=0, max_tokens=max_tokens)
|
| 215 |
+
|
| 216 |
+
specie_selector="Papaipema nebris"
|
| 217 |
+
filter = {
|
| 218 |
+
"$or": [
|
| 219 |
+
{"matched_specie_0": specie_selector},
|
| 220 |
+
{"matched_specie_1": specie_selector},
|
| 221 |
+
{"matched_specie_2": specie_selector},
|
| 222 |
+
]
|
| 223 |
+
}
|
| 224 |
+
# retriever = vectordb.as_retriever(search_kwargs={'k':retriever_k_value, 'filter': filter})
|
| 225 |
+
|
| 226 |
+
# qa_chain = ConversationalRetrievalChain.from_llm(
|
| 227 |
+
# llm_openai, retriever, memory=memory, verbose=False, return_source_documents=True,\
|
| 228 |
+
# combine_docs_chain_kwargs={'prompt': qa_prompt}
|
| 229 |
+
|
| 230 |
+
# )
|
| 231 |
+
#
|
| 232 |
+
|
| 233 |
+
def initialize_qa_chain(specie_selector, application_mode, model_name="GPT-4", database_persistent_directory=default_persist_directory_insects):
|
| 234 |
+
if model_name=="GPT-4":
|
| 235 |
+
chosen_llm=ChatOpenAI(model_name="gpt-4-1106-preview" , temperature=0, max_tokens=max_tokens)
|
| 236 |
+
elif model_name=="GPT-3.5":
|
| 237 |
+
chosen_llm=ChatOpenAI(model_name="gpt-3.5-turbo-0125" , temperature=0, max_tokens=max_tokens)
|
| 238 |
+
elif model_name=="Llama-3 70B":
|
| 239 |
+
chosen_llm = ChatOpenRouter(model_name="meta-llama/llama-3-70b-instruct", temperature=0,max_tokens=max_tokens )
|
| 240 |
+
elif model_name=="Llama-3 8B":
|
| 241 |
+
chosen_llm = ChatOpenRouter(model_name="meta-llama/llama-3-8b-instruct", temperature=0, max_tokens=max_tokens)
|
| 242 |
+
elif model_name=="Gemini-1.5 Pro":
|
| 243 |
+
chosen_llm = ChatOpenRouter(model_name="google/gemini-pro-1.5", temperature=0, max_tokens=max_tokens)
|
| 244 |
+
elif model_name=="Claude 3 Opus":
|
| 245 |
+
chosen_llm = ChatAnthropic(model_name='claude-3-opus-20240229', temperature=0, max_tokens=max_tokens)
|
| 246 |
+
|
| 247 |
+
else:
|
| 248 |
+
print("No appropriate llm was selected")
|
| 249 |
+
exit()
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
filter = {
|
| 254 |
+
"$or": [
|
| 255 |
+
{"matched_specie_0": specie_selector},
|
| 256 |
+
{"matched_specie_1": specie_selector},
|
| 257 |
+
{"matched_specie_2": specie_selector},
|
| 258 |
+
{"matched_specie_3": specie_selector},
|
| 259 |
+
{"matched_specie_4": specie_selector},
|
| 260 |
+
{"matched_specie_5": specie_selector},
|
| 261 |
+
{"matched_specie_6": specie_selector},
|
| 262 |
+
{"matched_specie_7": specie_selector},
|
| 263 |
+
{"matched_specie_8": specie_selector},
|
| 264 |
+
{"matched_specie_9": specie_selector},
|
| 265 |
+
{"matched_specie_10": specie_selector}
|
| 266 |
+
]
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
embedding = OpenAIEmbeddings()
|
| 270 |
+
vectordb = Chroma(persist_directory=database_persistent_directory,
|
| 271 |
+
embedding_function=embedding)
|
| 272 |
+
|
| 273 |
+
print("got updated retriever without metadata filtering")
|
| 274 |
+
retriever = vectordb.as_retriever(search_kwargs={'k':retriever_k_value, 'filter': filter})
|
| 275 |
+
|
| 276 |
+
memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
|
| 277 |
+
qa_prompt=get_prompt_with_vetted_info_from_specie_name(specie_selector, application_mode)
|
| 278 |
+
qa_chain = ConversationalRetrievalChain.from_llm(
|
| 279 |
+
chosen_llm, retriever, memory=memory, verbose=False, return_source_documents=True,
|
| 280 |
+
combine_docs_chain_kwargs={'prompt': qa_prompt}
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
return qa_chain
|
| 284 |
+
# result = qa_chain.invoke({"question": "where are stalk borer eggs laid?"})
|
| 285 |
+
# print("Got the first LLM task working: ", result)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
#Application Interface:
|
| 289 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 290 |
+
with gr.Row():
|
| 291 |
+
with gr.Column(scale=1):
|
| 292 |
+
gr.Markdown(
|
| 293 |
+
"""
|
| 294 |
+

|
| 295 |
+
"""
|
| 296 |
+
)
|
| 297 |
+
with gr.Column(scale=1):
|
| 298 |
+
gr.Markdown(
|
| 299 |
+
"""
|
| 300 |
+

|
| 301 |
+
"""
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
# Configure UI layout
|
| 305 |
+
chatbot = gr.Chatbot(height=600, label="AgLLM")
|
| 306 |
+
with gr.Row():
|
| 307 |
+
with gr.Column(scale=1):
|
| 308 |
+
with gr.Row():
|
| 309 |
+
domain_name = gr.Dropdown(
|
| 310 |
+
list(["Insects", "Weeds"]),
|
| 311 |
+
value="Insects",
|
| 312 |
+
label="Domain",
|
| 313 |
+
info="Select Domain",
|
| 314 |
+
interactive=True,
|
| 315 |
+
scale=1,
|
| 316 |
+
visible=True
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
# Model selection
|
| 320 |
+
specie_selector = gr.Dropdown(
|
| 321 |
+
species_list_insects,
|
| 322 |
+
value=species_list_insects[0],
|
| 323 |
+
label="Species",
|
| 324 |
+
info="Select the Species",
|
| 325 |
+
interactive=True,
|
| 326 |
+
scale=1,
|
| 327 |
+
visible=True
|
| 328 |
+
)
|
| 329 |
+
with gr.Row():
|
| 330 |
+
model_name = gr.Dropdown(
|
| 331 |
+
list(["GPT-4", "GPT-3.5", "Llama-3 70B", "Llama-3 8B", "Gemini-1.5 Pro", "Claude 3 Opus"]),
|
| 332 |
+
value="Llama-3 70B",
|
| 333 |
+
label="LLM",
|
| 334 |
+
info="Select the LLM",
|
| 335 |
+
interactive=True,
|
| 336 |
+
scale=1,
|
| 337 |
+
visible=True
|
| 338 |
+
)
|
| 339 |
+
application_mode = gr.Dropdown(
|
| 340 |
+
list(["Farmer", "Researcher"]),
|
| 341 |
+
value="Researcher",
|
| 342 |
+
label="Mode",
|
| 343 |
+
info="Select the Mode",
|
| 344 |
+
interactive=True,
|
| 345 |
+
scale=1,
|
| 346 |
+
visible=True
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
with gr.Column(scale=2):
|
| 351 |
+
# User input prompt text field
|
| 352 |
+
user_prompt_message = gr.Textbox(placeholder="Please add user prompt here", label="User prompt")
|
| 353 |
+
with gr.Row():
|
| 354 |
+
# clear = gr.Button("Clear Conversation", scale=2)
|
| 355 |
+
submitBtn = gr.Button("Submit", scale=8)
|
| 356 |
+
|
| 357 |
+
state = gr.State([])
|
| 358 |
+
qa_chain_state = gr.State(value=None)
|
| 359 |
+
|
| 360 |
+
# Handle user message
|
| 361 |
+
def user(user_prompt_message, history):
|
| 362 |
+
# print("HISTORY IS: ", history) # TODO: REMOVE IT LATER
|
| 363 |
+
if user_prompt_message != "":
|
| 364 |
+
return history + [[user_prompt_message, None]]
|
| 365 |
+
else:
|
| 366 |
+
return history + [["Invalid prompts - user prompt cannot be empty", None]]
|
| 367 |
+
|
| 368 |
+
# Chatbot logic for configuration, sending the prompts, rendering the streamed back generations, etc.
|
| 369 |
+
def bot(model_name, application_mode, user_prompt_message, history, messages_history, qa_chain, domain_name):
|
| 370 |
+
if qa_chain == None:
|
| 371 |
+
qa_chain=init_qa_chain(species_list_insects[0], application_mode, model_name, domain_name)
|
| 372 |
+
|
| 373 |
+
dialog = []
|
| 374 |
+
bot_message = ""
|
| 375 |
+
history[-1][1] = "" # Placeholder for the answer
|
| 376 |
+
|
| 377 |
+
dialog = [
|
| 378 |
+
{"role": "user", "content": user_prompt_message},
|
| 379 |
+
]
|
| 380 |
+
messages_history += dialog
|
| 381 |
+
|
| 382 |
+
# Queue for streamed character rendering
|
| 383 |
+
q = Queue()
|
| 384 |
+
|
| 385 |
+
# Async task for streamed chain results wired to callbacks we previously defined, so we don't block the UI
|
| 386 |
+
|
| 387 |
+
def task(user_prompt_message):
|
| 388 |
+
result = qa_chain.invoke({"question": user_prompt_message})
|
| 389 |
+
answer = result["answer"]
|
| 390 |
+
|
| 391 |
+
try:
|
| 392 |
+
answer_start = answer.find("Answer:")
|
| 393 |
+
reference_start = answer.find("Reference:")
|
| 394 |
+
|
| 395 |
+
if answer_start != -1 and reference_start != -1:
|
| 396 |
+
model_answer = answer[answer_start + len("Answer:"):reference_start].strip()
|
| 397 |
+
reference = answer[reference_start + len("Reference:"):].strip()
|
| 398 |
+
formatted_response = f"Answer:\n{model_answer}\n\nReferences:\n{reference}"
|
| 399 |
+
else:
|
| 400 |
+
formatted_response = answer
|
| 401 |
+
except:
|
| 402 |
+
print(f"Error parsing so displaying the raw output")
|
| 403 |
+
formatted_response = answer
|
| 404 |
+
|
| 405 |
+
return formatted_response
|
| 406 |
+
|
| 407 |
+
history[-1][1] = task(user_prompt_message)
|
| 408 |
+
return [history, messages_history]
|
| 409 |
+
|
| 410 |
+
# Initialize the chat history with default system message
|
| 411 |
+
def init_history(messages_history):
|
| 412 |
+
messages_history = []
|
| 413 |
+
messages_history += [system_message]
|
| 414 |
+
return messages_history
|
| 415 |
+
|
| 416 |
+
# Clean up the user input text field
|
| 417 |
+
def input_cleanup():
|
| 418 |
+
return ""
|
| 419 |
+
|
| 420 |
+
def init_qa_chain(specie_selector, application_mode, model_name, domain_name):
|
| 421 |
+
if domain_name=="Insects":
|
| 422 |
+
qa_chain = initialize_qa_chain(specie_selector, application_mode, model_name, default_persist_directory_insects)
|
| 423 |
+
elif domain_name=="Weeds":
|
| 424 |
+
qa_chain = initialize_qa_chain(specie_selector, application_mode, model_name, default_persist_directory_weeds)
|
| 425 |
+
else:
|
| 426 |
+
print("No Appropriate Chain Selected")
|
| 427 |
+
return qa_chain
|
| 428 |
+
|
| 429 |
+
specie_selector.change(
|
| 430 |
+
init_qa_chain,
|
| 431 |
+
inputs=[specie_selector, application_mode,model_name, domain_name ],
|
| 432 |
+
outputs=[qa_chain_state]
|
| 433 |
+
)
|
| 434 |
+
model_name.change(
|
| 435 |
+
init_qa_chain,
|
| 436 |
+
inputs=[specie_selector, application_mode,model_name, domain_name ],
|
| 437 |
+
outputs=[qa_chain_state]
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
#####
|
| 441 |
+
def update_species_list(domain):
|
| 442 |
+
if domain == "Insects":
|
| 443 |
+
return gr.Dropdown( species_list_insects, value=species_list_insects[0], label="Species", info="Select the Species", interactive=True, scale=1, visible=True )
|
| 444 |
+
elif domain == "Weeds":
|
| 445 |
+
return gr.Dropdown( species_list_weeds, value=species_list_weeds[0], label="Species", info="Select the Species", interactive=True, scale=1, visible=True )
|
| 446 |
+
|
| 447 |
+
domain_name.change(
|
| 448 |
+
update_species_list,
|
| 449 |
+
inputs=[domain_name],
|
| 450 |
+
outputs=[specie_selector]
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
# When the user clicks Enter and the user message is submitted
|
| 454 |
+
user_prompt_message.submit(
|
| 455 |
+
user,
|
| 456 |
+
[user_prompt_message, chatbot],
|
| 457 |
+
[chatbot],
|
| 458 |
+
queue=False
|
| 459 |
+
).then(
|
| 460 |
+
bot,
|
| 461 |
+
[model_name, application_mode, user_prompt_message, chatbot, state, qa_chain_state, domain_name],
|
| 462 |
+
[chatbot, state]
|
| 463 |
+
).then(input_cleanup,
|
| 464 |
+
[],
|
| 465 |
+
[user_prompt_message],
|
| 466 |
+
queue=False
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
# When the user clicks the submit button
|
| 470 |
+
submitBtn.click(
|
| 471 |
+
user,
|
| 472 |
+
[user_prompt_message, chatbot],
|
| 473 |
+
[chatbot],
|
| 474 |
+
queue=False
|
| 475 |
+
).then(
|
| 476 |
+
bot,
|
| 477 |
+
[model_name, application_mode, user_prompt_message, chatbot, state, qa_chain_state, domain_name],
|
| 478 |
+
[chatbot, state]
|
| 479 |
+
).then(
|
| 480 |
+
input_cleanup,
|
| 481 |
+
[],
|
| 482 |
+
[user_prompt_message],
|
| 483 |
+
queue=False
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
# When the user clicks the clear button
|
| 487 |
+
# clear.click(lambda: None, None, chatbot, queue=False).success(init_history, [state], [state])
|
| 488 |
+
if __name__ == "__main__":
|
| 489 |
+
# demo.launch()
|
| 490 |
+
demo.queue().launch(allowed_paths=["/"], share=False, show_error=True)
|
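
For reference, here is a minimal sketch of the user -> bot -> cleanup event chain that both app versions wire up. This is an illustrative standalone snippet, not part of the committed files: the echo bot stands in for the retrieval chain, and only the gradio package is assumed.

import gradio as gr

# Append the user's turn with an empty slot for the bot's reply (mirrors user()).
def user(msg, history):
    text = msg if msg else "Invalid prompts - user prompt cannot be empty"
    return history + [[text, None]]

# Fill the placeholder; the real apps call the ConversationalRetrievalChain here.
def bot(history):
    history[-1][1] = f"Echo: {history[-1][0]}"
    return history

# Clear the textbox once the exchange is rendered.
def input_cleanup():
    return ""

with gr.Blocks() as sketch:
    chatbot = gr.Chatbot()
    box = gr.Textbox(placeholder="Please add user prompt here")
    box.submit(user, [box, chatbot], [chatbot], queue=False).then(
        bot, [chatbot], [chatbot]
    ).then(input_cleanup, [], [box], queue=False)

if __name__ == "__main__":
    sketch.queue().launch()
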
app_backup_2.py
ADDED
|
@@ -0,0 +1,510 @@
| 1 |
+
import os
|
| 2 |
+
# https://stackoverflow.com/questions/76175046/how-to-add-prompt-to-langchain-conversationalretrievalchain-chat-over-docs-with
|
| 3 |
+
# again from:
|
| 4 |
+
# https://python.langchain.com/docs/integrations/providers/vectara/vectara_chat
|
| 5 |
+
from langchain.document_loaders import PyPDFDirectoryLoader
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import langchain
|
| 8 |
+
from queue import Queue
|
| 9 |
+
from typing import Any
|
| 10 |
+
from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
|
| 11 |
+
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
| 12 |
+
from langchain.schema import LLMResult
|
| 13 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 14 |
+
from langchain.vectorstores import FAISS
|
| 15 |
+
from langchain.prompts.prompt import PromptTemplate
|
| 16 |
+
from anyio.from_thread import start_blocking_portal #For model callback streaming
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
| 20 |
+
import os
|
| 21 |
+
from dotenv import load_dotenv
|
| 22 |
+
|
| 23 |
+
import streamlit as st
|
| 24 |
+
import json
|
| 25 |
+
from langchain.document_loaders import PyPDFLoader
|
| 26 |
+
from langchain.text_splitter import CharacterTextSplitter
|
| 27 |
+
from langchain.embeddings import OpenAIEmbeddings
|
| 28 |
+
from langchain.chains.question_answering import load_qa_chain
|
| 29 |
+
from langchain.chat_models import ChatOpenAI
|
| 30 |
+
# from langchain.chat_models import ChatAnthropic
|
| 31 |
+
from langchain_anthropic import ChatAnthropic
|
| 32 |
+
from langchain.vectorstores import Chroma
|
| 33 |
+
import chromadb
|
| 34 |
+
|
| 35 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 36 |
+
from langchain.llms import OpenAI
|
| 37 |
+
from langchain.chains import RetrievalQA
|
| 38 |
+
from langchain.document_loaders import TextLoader
|
| 39 |
+
from langchain.document_loaders import DirectoryLoader
|
| 40 |
+
from langchain_community.document_loaders import PyMuPDFLoader
|
| 41 |
+
from langchain.schema import Document
|
| 42 |
+
|
| 43 |
+
from langchain.memory import ConversationBufferMemory
|
| 44 |
+
|
| 45 |
+
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
|
| 46 |
+
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
|
| 47 |
+
import gradio as gr
|
| 48 |
+
from langchain.memory import ConversationBufferMemory
|
| 49 |
+
from langchain.chains import ConversationalRetrievalChain
|
| 50 |
+
print("Started")
|
| 51 |
+
|
| 52 |
+
def get_species_list_from_db(db_name):
|
| 53 |
+
embedding = OpenAIEmbeddings()
|
| 54 |
+
vectordb_temp = Chroma(persist_directory=db_name,
|
| 55 |
+
embedding_function=embedding)
|
| 56 |
+
species_list=[]
|
| 57 |
+
for meta in vectordb_temp.get()["metadatas"] :
|
| 58 |
+
try:
|
| 59 |
+
matched_first_species = meta['matched_specie_0']
|
| 60 |
+
except KeyError:
|
| 61 |
+
continue
|
| 62 |
+
# Since each document is considered as a single chunk, the chunk_index is 0 for all
|
| 63 |
+
species_list.append( matched_first_species)
|
| 64 |
+
|
| 65 |
+
return species_list
|
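
Note that get_species_list_from_db appends one entry per stored chunk, so the same species can appear more than once when a species spans several documents. A minimal sketch of an optional dedupe-and-sort pass (an assumption, not something this commit does) that could run before the list feeds the dropdown:

# Optional post-processing sketch (not part of this commit): show each species
# once, in alphabetical order, in the Gradio dropdown.
def dedupe_species(species_list):
    return sorted(set(species_list))
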
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# default_persist_directory = './db5' # For deployement
|
| 70 |
+
default_persist_directory_insects='./vector-databases-deployed/db5-agllm-data-isu-field-insects-all-species'
|
| 71 |
+
default_persist_directory_weeds='./vector-databases-deployed/db5-agllm-data-isu-field-weeds-all-species'
|
| 72 |
+
|
| 73 |
+
species_list_insects=get_species_list_from_db(default_persist_directory_insects)
|
| 74 |
+
species_list_weeds=get_species_list_from_db(default_persist_directory_weeds)
|
| 75 |
+
# default_persist_directory = 'vector-databases/db5-pre-completion' # For Development
|
| 76 |
+
csv_filepath1 = "./agllm-data/corrected/Corrected_supplemented-insect_data-2500-sorted.xlsx"
|
| 77 |
+
csv_filepath2 = "./agllm-data/corrected/Corrected_supplemented-insect_data-remaining.xlsx"
|
| 78 |
+
model_name=4
|
| 79 |
+
max_tokens=400
|
| 80 |
+
system_message = {"role": "system", "content": "You are a helpful assistant."} # TODO: double check how this plays out later.
|
| 81 |
+
langchain.debug=False # TODO: DOUBLE CHECK
|
| 82 |
+
from langchain import globals
|
| 83 |
+
globals.set_debug(False)
|
| 84 |
+
|
| 85 |
+
retriever_k_value=3
|
| 86 |
+
embedding = OpenAIEmbeddings()
|
| 87 |
+
print("Started....")
|
| 88 |
+
class ChatOpenRouter(ChatOpenAI):
|
| 89 |
+
openai_api_base: str
|
| 90 |
+
openai_api_key: str
|
| 91 |
+
model_name: str
|
| 92 |
+
|
| 93 |
+
def __init__(self,
|
| 94 |
+
model_name: str,
|
| 95 |
+
openai_api_key: [str] = None,
|
| 96 |
+
openai_api_base: str = "https://openrouter.ai/api/v1",
|
| 97 |
+
**kwargs):
|
| 98 |
+
openai_api_key = openai_api_key or os.getenv('OPENROUTER_API_KEY')
|
| 99 |
+
super().__init__(openai_api_base=openai_api_base,
|
| 100 |
+
openai_api_key=openai_api_key,
|
| 101 |
+
model_name=model_name, **kwargs)
|
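
ChatOpenRouter above simply points LangChain's ChatOpenAI client at the OpenRouter endpoint. A usage sketch, assuming OPENROUTER_API_KEY is set and reusing the same Llama-3 model id that initialize_qa_chain selects later (the question text is only an example):

# Usage sketch: instantiate the OpenRouter-backed chat model directly.
if os.getenv("OPENROUTER_API_KEY"):
    demo_llm = ChatOpenRouter(model_name="meta-llama/llama-3-70b-instruct",
                              temperature=0, max_tokens=100)
    # .invoke() returns an AIMessage; .content holds the generated text.
    print(demo_llm.invoke("Name one integrated pest management tactic.").content)
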
| 102 |
+
|
| 103 |
+
|
| 104 |
+
######### todo: skipping the first step
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# print(# Single example
|
| 109 |
+
# vectordb.as_retriever(k=2, search_kwargs={"filter": {"matched_specie_0": "Hypagyrtis unipunctata"}, 'k':1}).get_relevant_documents(
|
| 110 |
+
# "Checking if retriever is correctly initalized?"
|
| 111 |
+
# ))
|
| 112 |
+
|
| 113 |
+
columns = ['species', 'common name', 'order', 'family',
|
| 114 |
+
'genus', 'Updated role in ecosystem', 'Proof',
|
| 115 |
+
'ipm strategies', 'size of insect', 'geographical spread',
|
| 116 |
+
'life cycle specifics', 'pest for plant species', 'species status',
|
| 117 |
+
'distribution area', 'appearance', 'identification']
|
| 118 |
+
|
| 119 |
+
df1 = pd.read_excel(csv_filepath1, usecols=columns)
|
| 120 |
+
df2 = pd.read_excel(csv_filepath2, usecols=columns)
|
| 121 |
+
|
| 122 |
+
all_insects_data = pd.concat([df1, df2], ignore_index=True)
|
| 123 |
+
|
| 124 |
+
def get_prompt_with_vetted_info_from_specie_name(search_for_specie, mode):
|
| 125 |
+
|
| 126 |
+
def read_and_format_filtered_csv_better(insect_specie):
|
| 127 |
+
filtered_data = all_insects_data[all_insects_data['species'] == insect_specie]
|
| 128 |
+
formatted_data = ""
|
| 129 |
+
# Format the filtered data
|
| 130 |
+
for index, row in filtered_data.iterrows():
|
| 131 |
+
row_data = [f"{col}: {row[col]}" for col in filtered_data.columns]
|
| 132 |
+
formatted_row = "\n".join(row_data)
|
| 133 |
+
formatted_data += f"{formatted_row}\n"
|
| 134 |
+
|
| 135 |
+
return formatted_data
|
| 136 |
+
|
| 137 |
+
# Use the path to your CSV file here
|
| 138 |
+
|
| 139 |
+
vetted_info=read_and_format_filtered_csv_better(search_for_specie)
|
| 140 |
+
|
| 141 |
+
if mode=="Farmer":
|
| 142 |
+
language_constraint="The language should be acustomed to the Farmers. Given question is likely to be asked by a farmer in the field will ask which will help to make decisions which are immediate and practical."
|
| 143 |
+
elif mode=="Researcher":
|
| 144 |
+
language_constraint="The language should be acustomed to a researcher. Given question is likely to be asked by a scientist which are comprehensive and aimed at exploring new knowledge or refining existing methodologies"
|
| 145 |
+
else:
|
| 146 |
+
print("No valid mode provided. Exiting")
|
| 147 |
+
exit()
|
| 148 |
+
# general_system_template = """
|
| 149 |
+
# In every question you are provided information about the insect/weed. Two types of information are provided: first, Vetted Information (which is the same in every question) and second, some context from external documents about an insect/weed species, together with a question from the user. Answer the question according to these two types of information.
|
| 150 |
+
# ----
|
| 151 |
+
# Vetted info is as follows:
|
| 152 |
+
# {vetted_info}
|
| 153 |
+
# ----
|
| 154 |
+
# The context retrieved for documents about this particular question is as follows:
|
| 155 |
+
# {context}
|
| 156 |
+
# ----
|
| 157 |
+
# Additional Instruction:
|
| 158 |
+
# 1. Reference Constraint
|
| 159 |
+
# At the end of each answer provide the source/reference for the given data in following format:
|
| 160 |
+
# \n\n[enter two new lines before writing below] References:
|
| 161 |
+
# Vetted Information Used: Write what was used from the document for coming up with the answer above. Write exact part of lines. If nothing, write 'Nothing'.
|
| 162 |
+
# Documents Used: Write what was used from the document for coming up with the answer above. If nothing, write 'Nothing'. Write exact part of lines and document used.
|
| 163 |
+
# 2. Information Constraint:
|
| 164 |
+
# Only answer the question from information provided otherwise say you dont know. You have to answer in 50 words including references. Prioritize information in documents/context over vetted information. And first mention the warnings/things to be careful about.
|
| 165 |
+
# 3. Language constraint:
|
| 166 |
+
# {language_constraint}
|
| 167 |
+
|
| 168 |
+
# ----
|
| 169 |
+
# """.format(vetted_info=vetted_info, language_constraint=language_constraint,context="{context}", )
|
| 170 |
+
|
| 171 |
+
general_system_template = f"""
|
| 172 |
+
You are an AI assistant specialized in providing information about insects/weeds. Answer the user's question based on the available information or your general knowledge.
|
| 173 |
+
|
| 174 |
+
The context retrieved for this question is as follows:
|
| 175 |
+
{{context}}
|
| 176 |
+
|
| 177 |
+
Instructions:
|
| 178 |
+
1. Evaluate the relevance of the provided context to the question.
|
| 179 |
+
2. If the context contains relevant information, use it to answer the question and explicitly mention "Based on provided information" in your source.
|
| 180 |
+
3. If the context does not contain relevant information, use your general knowledge to answer the question and state "Based on general knowledge" as the source.
|
| 181 |
+
4. Format your response as follows:
|
| 182 |
+
Answer: Provide a concise answer in less than 50 words.
|
| 183 |
+
Source: State either "Based on provided information" or "Based on general knowledge".
|
| 184 |
+
|
| 185 |
+
5. Language constraint:
|
| 186 |
+
{language_constraint}
|
| 187 |
+
|
| 188 |
+
Question: {{question}}
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
general_user_template = "Question:```{question}```"
|
| 192 |
+
messages_formatted = [
|
| 193 |
+
SystemMessagePromptTemplate.from_template(general_system_template),
|
| 194 |
+
HumanMessagePromptTemplate.from_template(general_user_template)
|
| 195 |
+
]
|
| 196 |
+
qa_prompt = ChatPromptTemplate.from_messages( messages_formatted )
|
| 197 |
+
# print(qa_prompt)
|
| 198 |
+
return qa_prompt
|
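
A small debugging sketch (illustrative only, assuming it runs inside this script after the function above) that renders the returned ChatPromptTemplate with dummy values, which makes it easy to see exactly what the system and human messages look like before the chain fills in real retrieval context:

# Render the prompt with placeholder context/question and print each message.
preview_prompt = get_prompt_with_vetted_info_from_specie_name("Papaipema nebris", "Researcher")
for message in preview_prompt.format_messages(
        context="(retrieved document chunks would appear here)",
        question="Where are stalk borer eggs laid?"):
    print(message.type, "->", message.content[:200])
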
| 199 |
+
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
qa_prompt=get_prompt_with_vetted_info_from_specie_name("Papaipema nebris", "Researcher")
|
| 205 |
+
# print("First prompt is intialized as: " , qa_prompt, "\n\n")
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
memory = ConversationBufferMemory(memory_key="chat_history",output_key='answer', return_messages=True) # https://github.com/langchain-ai/langchain/issues/9394#issuecomment-1683538834
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
if model_name==4:
|
| 212 |
+
llm_openai = ChatOpenAI(model_name="gpt-4-1106-preview" , temperature=0, max_tokens=max_tokens) # TODO: NEW MODEL VERSION AVAILABLE
|
| 213 |
+
else:
|
| 214 |
+
llm_openai = ChatOpenAI(model_name="gpt-3.5-turbo-0125" , temperature=0, max_tokens=max_tokens)
|
| 215 |
+
|
| 216 |
+
specie_selector="Papaipema nebris"
|
| 217 |
+
filter = {
|
| 218 |
+
"$or": [
|
| 219 |
+
{"matched_specie_0": specie_selector},
|
| 220 |
+
{"matched_specie_1": specie_selector},
|
| 221 |
+
{"matched_specie_2": specie_selector},
|
| 222 |
+
]
|
| 223 |
+
}
|
| 224 |
+
# retriever = vectordb.as_retriever(search_kwargs={'k':retriever_k_value, 'filter': filter})
|
| 225 |
+
|
| 226 |
+
# qa_chain = ConversationalRetrievalChain.from_llm(
|
| 227 |
+
# llm_openai, retriever, memory=memory, verbose=False, return_source_documents=True,\
|
| 228 |
+
# combine_docs_chain_kwargs={'prompt': qa_prompt}
|
| 229 |
+
|
| 230 |
+
# )
|
| 231 |
+
#
|
| 232 |
+
|
| 233 |
+
def initialize_qa_chain(specie_selector, application_mode, model_name="GPT-4", database_persistent_directory=default_persist_directory_insects):
|
| 234 |
+
if model_name=="GPT-4":
|
| 235 |
+
chosen_llm=ChatOpenAI(model_name="gpt-4-1106-preview" , temperature=0, max_tokens=max_tokens)
|
| 236 |
+
elif model_name=="GPT-3.5":
|
| 237 |
+
chosen_llm=ChatOpenAI(model_name="gpt-3.5-turbo-0125" , temperature=0, max_tokens=max_tokens)
|
| 238 |
+
elif model_name=="Llama-3 70B":
|
| 239 |
+
chosen_llm = ChatOpenRouter(model_name="meta-llama/llama-3-70b-instruct", temperature=0,max_tokens=max_tokens )
|
| 240 |
+
elif model_name=="Llama-3 8B":
|
| 241 |
+
chosen_llm = ChatOpenRouter(model_name="meta-llama/llama-3-8b-instruct", temperature=0, max_tokens=max_tokens)
|
| 242 |
+
elif model_name=="Gemini-1.5 Pro":
|
| 243 |
+
chosen_llm = ChatOpenRouter(model_name="google/gemini-pro-1.5", temperature=0, max_tokens=max_tokens)
|
| 244 |
+
elif model_name=="Claude 3 Opus":
|
| 245 |
+
chosen_llm = ChatAnthropic(model_name='claude-3-opus-20240229', temperature=0, max_tokens=max_tokens)
|
| 246 |
+
|
| 247 |
+
else:
|
| 248 |
+
print("No appropriate llm was selected")
|
| 249 |
+
exit()
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
filter = {
|
| 254 |
+
"$or": [
|
| 255 |
+
{"matched_specie_0": specie_selector},
|
| 256 |
+
{"matched_specie_1": specie_selector},
|
| 257 |
+
{"matched_specie_2": specie_selector},
|
| 258 |
+
{"matched_specie_3": specie_selector},
|
| 259 |
+
{"matched_specie_4": specie_selector},
|
| 260 |
+
{"matched_specie_5": specie_selector},
|
| 261 |
+
{"matched_specie_6": specie_selector},
|
| 262 |
+
{"matched_specie_7": specie_selector},
|
| 263 |
+
{"matched_specie_8": specie_selector},
|
| 264 |
+
{"matched_specie_9": specie_selector},
|
| 265 |
+
{"matched_specie_10": specie_selector}
|
| 266 |
+
]
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
embedding = OpenAIEmbeddings()
|
| 270 |
+
vectordb = Chroma(persist_directory=database_persistent_directory,
|
| 271 |
+
embedding_function=embedding)
|
| 272 |
+
|
| 273 |
+
# print("got updated retriever without metadata filtering")
|
| 274 |
+
retriever = vectordb.as_retriever(search_kwargs={'k':retriever_k_value, 'filter': filter})
|
| 275 |
+
|
| 276 |
+
memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
|
| 277 |
+
qa_prompt=get_prompt_with_vetted_info_from_specie_name(specie_selector, application_mode)
|
| 278 |
+
qa_chain = ConversationalRetrievalChain.from_llm(
|
| 279 |
+
chosen_llm, retriever, memory=memory, verbose=False, return_source_documents=True,
|
| 280 |
+
combine_docs_chain_kwargs={'prompt': qa_prompt}
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
return qa_chain
|
| 284 |
+
# result = qa_chain.invoke({"question": "where are stalk borer eggs laid?"})
|
| 285 |
+
# print("Got the first LLM task working: ", result)
|
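
An uncommented version of the smoke test sketched above (an illustration, assuming OPENAI_API_KEY is set and the insect vector store exists at default_persist_directory_insects):

# Build a chain for one species/mode/model and ask a single question.
smoke_chain = initialize_qa_chain("Papaipema nebris", "Researcher", "GPT-4",
                                  default_persist_directory_insects)
smoke_result = smoke_chain.invoke({"question": "Where are stalk borer eggs laid?"})
print(smoke_result["answer"])
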
| 286 |
+
|
| 287 |
+
|
| 288 |
+
#Application Interface:
|
| 289 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 290 |
+
with gr.Row():
|
| 291 |
+
with gr.Column(scale=1):
|
| 292 |
+
gr.Markdown(
|
| 293 |
+
"""
|
| 294 |
+

|
| 295 |
+
"""
|
| 296 |
+
)
|
| 297 |
+
with gr.Column(scale=1):
|
| 298 |
+
gr.Markdown(
|
| 299 |
+
"""
|
| 300 |
+

|
| 301 |
+
"""
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
# Configure UI layout
|
| 305 |
+
chatbot = gr.Chatbot(height=600, label="AgLLM")
|
| 306 |
+
with gr.Row():
|
| 307 |
+
with gr.Column(scale=1):
|
| 308 |
+
with gr.Row():
|
| 309 |
+
domain_name = gr.Dropdown(
|
| 310 |
+
list(["Insects", "Weeds"]),
|
| 311 |
+
value="Insects",
|
| 312 |
+
label="Domain",
|
| 313 |
+
info="Select Domain",
|
| 314 |
+
interactive=True,
|
| 315 |
+
scale=1,
|
| 316 |
+
visible=True
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
# Model selection
|
| 320 |
+
specie_selector = gr.Dropdown(
|
| 321 |
+
species_list_insects,
|
| 322 |
+
value=species_list_insects[0],
|
| 323 |
+
label="Species",
|
| 324 |
+
info="Select the Species",
|
| 325 |
+
interactive=True,
|
| 326 |
+
scale=1,
|
| 327 |
+
visible=True
|
| 328 |
+
)
|
| 329 |
+
with gr.Row():
|
| 330 |
+
model_name = gr.Dropdown(
|
| 331 |
+
list(["GPT-4", "GPT-3.5", "Llama-3 70B", "Llama-3 8B", "Gemini-1.5 Pro", "Claude 3 Opus"]),
|
| 332 |
+
value="Llama-3 70B",
|
| 333 |
+
label="LLM",
|
| 334 |
+
info="Select the LLM",
|
| 335 |
+
interactive=True,
|
| 336 |
+
scale=1,
|
| 337 |
+
visible=True
|
| 338 |
+
)
|
| 339 |
+
application_mode = gr.Dropdown(
|
| 340 |
+
list(["Farmer", "Researcher"]),
|
| 341 |
+
value="Researcher",
|
| 342 |
+
label="Mode",
|
| 343 |
+
info="Select the Mode",
|
| 344 |
+
interactive=True,
|
| 345 |
+
scale=1,
|
| 346 |
+
visible=True
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
with gr.Column(scale=2):
|
| 351 |
+
# User input prompt text field
|
| 352 |
+
user_prompt_message = gr.Textbox(placeholder="Please add user prompt here", label="User prompt")
|
| 353 |
+
with gr.Row():
|
| 354 |
+
# clear = gr.Button("Clear Conversation", scale=2)
|
| 355 |
+
submitBtn = gr.Button("Submit", scale=8)
|
| 356 |
+
|
| 357 |
+
state = gr.State([])
|
| 358 |
+
qa_chain_state = gr.State(value=None)
|
| 359 |
+
|
| 360 |
+
# Handle user message
|
| 361 |
+
def user(user_prompt_message, history):
|
| 362 |
+
# print("HISTORY IS: ", history) # TODO: REMOVE IT LATER
|
| 363 |
+
if user_prompt_message != "":
|
| 364 |
+
return history + [[user_prompt_message, None]]
|
| 365 |
+
else:
|
| 366 |
+
return history + [["Invalid prompts - user prompt cannot be empty", None]]
|
| 367 |
+
|
| 368 |
+
# Chatbot logic for configuration, sending the prompts, rendering the streamed back generations, etc.
|
| 369 |
+
def bot(model_name, application_mode, user_prompt_message, history, messages_history, qa_chain, domain_name):
|
| 370 |
+
if qa_chain is None:
|
| 371 |
+
qa_chain=init_qa_chain(species_list_insects[0], application_mode, model_name, domain_name)
|
| 372 |
+
|
| 373 |
+
dialog = []
|
| 374 |
+
bot_message = ""
|
| 375 |
+
history[-1][1] = "" # Placeholder for the answer
|
| 376 |
+
|
| 377 |
+
dialog = [
|
| 378 |
+
{"role": "user", "content": user_prompt_message},
|
| 379 |
+
]
|
| 380 |
+
messages_history += dialog
|
| 381 |
+
|
| 382 |
+
# Queue for streamed character rendering
|
| 383 |
+
q = Queue()
|
| 384 |
+
|
| 385 |
+
# Async task for streamed chain results wired to callbacks we previously defined, so we don't block the UI
|
| 386 |
+
|
| 387 |
+
def task(user_prompt_message):
|
| 388 |
+
result = qa_chain.invoke({"question": user_prompt_message})
|
| 389 |
+
answer = result["answer"]
|
| 390 |
+
source_documents = result.get("source_documents", [])
|
| 391 |
+
print("SOURCE DOCUMENTS: ", source_documents)
|
| 392 |
+
|
| 393 |
+
try:
|
| 394 |
+
answer_start = answer.find("Answer:")
|
| 395 |
+
source_start = answer.find("Source:")
|
| 396 |
+
|
| 397 |
+
if answer_start != -1 and source_start != -1:
|
| 398 |
+
model_answer = answer[answer_start + len("Answer:"):source_start].strip()
|
| 399 |
+
source = answer[source_start + len("Source:"):].strip()
|
| 400 |
+
|
| 401 |
+
if "Based on provided information" in source:
|
| 402 |
+
# Extract the most relevant source document
|
| 403 |
+
if source_documents:
|
| 404 |
+
doc = source_documents[0]
|
| 405 |
+
species_name = doc.metadata.get('matched_specie_0', 'Unspecified species')
|
| 406 |
+
|
| 407 |
+
if domain_name == "Insects":
|
| 408 |
+
formatted_source = f'Iowa State University Extension and Outreach. "Field Crop Insects." Iowa State University Extension Store, June 26, 2023. https://store.extension.iastate.edu/product/13725. Information about {species_name}.'
|
| 409 |
+
elif domain_name == "Weeds":
|
| 410 |
+
formatted_source = f'Iowa State University Extension and Outreach. "Weed Identification Field Guide 2nd Edition." Iowa State University Extension Store, August 2015. https://store.extension.iastate.edu/product/13358. Information about {species_name}.'
|
| 411 |
+
else:
|
| 412 |
+
formatted_source = f"Based on provided information about {species_name}, but domain is unspecified."
|
| 413 |
+
else:
|
| 414 |
+
formatted_source = "Based on provided information, but source details are unavailable."
|
| 415 |
+
else:
|
| 416 |
+
formatted_source = source
|
| 417 |
+
|
| 418 |
+
formatted_response = f"Answer:\n{model_answer}\n\nSource:\n{formatted_source}"
|
| 419 |
+
else:
|
| 420 |
+
formatted_response = answer
|
| 421 |
+
except Exception as e:
|
| 422 |
+
print(f"Error parsing output: {e}")
|
| 423 |
+
formatted_response = answer
|
| 424 |
+
|
| 425 |
+
return formatted_response
|
| 426 |
+
|
| 427 |
+
history[-1][1] = task(user_prompt_message)
|
| 428 |
+
return [history, messages_history]
|
| 429 |
+
|
| 430 |
+
# Initialize the chat history with default system message
|
| 431 |
+
def init_history(messages_history):
|
| 432 |
+
messages_history = []
|
| 433 |
+
messages_history += [system_message]
|
| 434 |
+
return messages_history
|
| 435 |
+
|
| 436 |
+
# Clean up the user input text field
|
| 437 |
+
def input_cleanup():
|
| 438 |
+
return ""
|
| 439 |
+
|
| 440 |
+
def init_qa_chain(specie_selector, application_mode, model_name, domain_name):
|
| 441 |
+
if domain_name=="Insects":
|
| 442 |
+
qa_chain = initialize_qa_chain(specie_selector, application_mode, model_name, default_persist_directory_insects)
|
| 443 |
+
elif domain_name=="Weeds":
|
| 444 |
+
qa_chain = initialize_qa_chain(specie_selector, application_mode, model_name, default_persist_directory_weeds)
|
| 445 |
+
else:
|
| 446 |
+
print("No Appropriate Chain Selected")
|
| 447 |
+
return qa_chain
|
| 448 |
+
|
| 449 |
+
specie_selector.change(
|
| 450 |
+
init_qa_chain,
|
| 451 |
+
inputs=[specie_selector, application_mode,model_name, domain_name ],
|
| 452 |
+
outputs=[qa_chain_state]
|
| 453 |
+
)
|
| 454 |
+
model_name.change(
|
| 455 |
+
init_qa_chain,
|
| 456 |
+
inputs=[specie_selector, application_mode,model_name, domain_name ],
|
| 457 |
+
outputs=[qa_chain_state]
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
#####
|
| 461 |
+
def update_species_list(domain):
|
| 462 |
+
if domain == "Insects":
|
| 463 |
+
return gr.Dropdown( species_list_insects, value=species_list_insects[0], label="Species", info="Select the Species", interactive=True, scale=1, visible=True )
|
| 464 |
+
elif domain == "Weeds":
|
| 465 |
+
return gr.Dropdown( species_list_weeds, value=species_list_weeds[0], label="Species", info="Select the Species", interactive=True, scale=1, visible=True )
|
| 466 |
+
|
| 467 |
+
domain_name.change(
|
| 468 |
+
update_species_list,
|
| 469 |
+
inputs=[domain_name],
|
| 470 |
+
outputs=[specie_selector]
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
# When the user clicks Enter and the user message is submitted
|
| 474 |
+
user_prompt_message.submit(
|
| 475 |
+
user,
|
| 476 |
+
[user_prompt_message, chatbot],
|
| 477 |
+
[chatbot],
|
| 478 |
+
queue=False
|
| 479 |
+
).then(
|
| 480 |
+
bot,
|
| 481 |
+
[model_name, application_mode, user_prompt_message, chatbot, state, qa_chain_state, domain_name],
|
| 482 |
+
[chatbot, state]
|
| 483 |
+
).then(input_cleanup,
|
| 484 |
+
[],
|
| 485 |
+
[user_prompt_message],
|
| 486 |
+
queue=False
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
# When the user clicks the submit button
|
| 490 |
+
submitBtn.click(
|
| 491 |
+
user,
|
| 492 |
+
[user_prompt_message, chatbot],
|
| 493 |
+
[chatbot],
|
| 494 |
+
queue=False
|
| 495 |
+
).then(
|
| 496 |
+
bot,
|
| 497 |
+
[model_name, application_mode, user_prompt_message, chatbot, state, qa_chain_state, domain_name],
|
| 498 |
+
[chatbot, state]
|
| 499 |
+
).then(
|
| 500 |
+
input_cleanup,
|
| 501 |
+
[],
|
| 502 |
+
[user_prompt_message],
|
| 503 |
+
queue=False
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
# When the user clicks the clear button
|
| 507 |
+
# clear.click(lambda: None, None, chatbot, queue=False).success(init_history, [state], [state])
|
| 508 |
+
if __name__ == "__main__":
|
| 509 |
+
# demo.launch()
|
| 510 |
+
demo.queue().launch(allowed_paths=["/"], share=False, show_error=True)
|
app_database_prep.py
ADDED
|
@@ -0,0 +1,234 @@
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
|
| 4 |
+
from langchain.document_loaders import PyPDFDirectoryLoader
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import langchain
|
| 7 |
+
from queue import Queue
|
| 8 |
+
from typing import Any, List
|
| 9 |
+
from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
|
| 10 |
+
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
| 11 |
+
from langchain.schema import LLMResult
|
| 12 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 13 |
+
from langchain.vectorstores import FAISS
|
| 14 |
+
from langchain.prompts.prompt import PromptTemplate
|
| 15 |
+
from anyio.from_thread import start_blocking_portal #For model callback streaming
|
| 16 |
+
|
| 17 |
+
langchain.debug=True # TODO: DOUBLE CHECK
|
| 18 |
+
|
| 19 |
+
system_message = {"role": "system", "content": "You are a helpful assistant."} # TODO: double check how this plays out later.
|
| 20 |
+
import os
|
| 21 |
+
from dotenv import load_dotenv
|
| 22 |
+
|
| 23 |
+
import streamlit as st
|
| 24 |
+
|
| 25 |
+
from langchain.document_loaders import PyPDFLoader
|
| 26 |
+
from langchain.text_splitter import CharacterTextSplitter
|
| 27 |
+
from langchain.embeddings import OpenAIEmbeddings
|
| 28 |
+
from langchain.chains.question_answering import load_qa_chain
|
| 29 |
+
from langchain.chat_models import ChatOpenAI
|
| 30 |
+
from langchain.vectorstores import Chroma
|
| 31 |
+
import chromadb
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
## added information in metadata:
|
| 35 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 36 |
+
from langchain.llms import OpenAI
|
| 37 |
+
from langchain.chains import RetrievalQA
|
| 38 |
+
from langchain.document_loaders import TextLoader
|
| 39 |
+
from langchain.document_loaders import DirectoryLoader
|
| 40 |
+
from langchain_community.document_loaders import PyMuPDFLoader
|
| 41 |
+
from langchain.schema import Document
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Function to process a sheet from the Excel file
|
| 45 |
+
def process_excel_sheet(
|
| 46 |
+
excel_path: str,
|
| 47 |
+
sheet_name: str,
|
| 48 |
+
region: str,
|
| 49 |
+
splitter: RecursiveCharacterTextSplitter
|
| 50 |
+
) -> List[Document]:
|
| 51 |
+
"""Loads data from an Excel sheet, creates Documents, splits them, and adds metadata."""
|
| 52 |
+
print(f"--- Processing Excel Sheet: {sheet_name} (Region: {region}) ---")
|
| 53 |
+
try:
|
| 54 |
+
df = pd.read_excel(excel_path, sheet_name=sheet_name)
|
| 55 |
+
print(f"Excel Data Head ({sheet_name}):\\n", df.head())
|
| 56 |
+
except Exception as e:
|
| 57 |
+
print(f"Error loading sheet '{sheet_name}' from {excel_path}: {e}")
|
| 58 |
+
return []
|
| 59 |
+
|
| 60 |
+
initial_documents = []
|
| 61 |
+
for index, row in df.iterrows():
|
| 62 |
+
ipm_info = str(row['IPM Info']) if pd.notna(row['IPM Info']) else ""
|
| 63 |
+
# Check if essential columns exist and are not empty (removed accuracy check)
|
| 64 |
+
if pd.isna(row['Common Name']) or pd.isna(row['Species']):
|
| 65 |
+
print(f"Skipping row {index+2} in sheet '{sheet_name}' due to missing essential data (Common Name or Species).")
|
| 66 |
+
continue
|
| 67 |
+
|
| 68 |
+
doc = Document(
|
| 69 |
+
page_content=ipm_info,
|
| 70 |
+
metadata={
|
| 71 |
+
"source": f"{excel_path}#sheet={sheet_name}#row={index+2}",
|
| 72 |
+
"common_name": row['Common Name'],
|
| 73 |
+
"species": row['Species'],
|
| 74 |
+
"matched_specie_0": row['Species'],
|
| 75 |
+
"region": region
|
| 76 |
+
}
|
| 77 |
+
)
|
| 78 |
+
initial_documents.append(doc)
|
| 79 |
+
|
| 80 |
+
if initial_documents:
|
| 81 |
+
print(f"First Document from {sheet_name} (before splitting):\\n", initial_documents[0])
|
| 82 |
+
else:
|
| 83 |
+
print(f"No documents created from sheet: {sheet_name}")
|
| 84 |
+
return [] # Return empty list if no documents were created
|
| 85 |
+
|
| 86 |
+
split_documents = []
|
| 87 |
+
for doc in initial_documents:
|
| 88 |
+
splits = splitter.split_documents([doc])
|
| 89 |
+
for i, split_doc in enumerate(splits, start=1):
|
| 90 |
+
metadata = split_doc.metadata.copy()
|
| 91 |
+
metadata["source"] = f"{metadata['source']}#chunk{i}"
|
| 92 |
+
split_doc.metadata = metadata
|
| 93 |
+
split_documents.append(split_doc)
|
| 94 |
+
|
| 95 |
+
if split_documents:
|
| 96 |
+
print(f"First Document chunk from {sheet_name}:\\n", split_documents[0])
|
| 97 |
+
|
| 98 |
+
print(f"Finished processing sheet: {sheet_name}. Found {len(split_documents)} chunks.")
|
| 99 |
+
print("---------------------------------------------------")
|
| 100 |
+
return split_documents
|
| 101 |
+
|
| 102 |
+
# --- Main Script Logic ---
|
| 103 |
+
|
| 104 |
+
# loader = DirectoryLoader('./agllm-data/', glob="./*.pdf", loader_cls=PyMuPDFLoader)
|
| 105 |
+
# loader = DirectoryLoader('/u/marshad/data/agllm-data/', glob='**/*.pdf', loader_cls=PyMuPDFLoader)
|
| 106 |
+
data_domain_identifier="agllm-data-isu-field-insects-all-species"
|
| 107 |
+
persist_directory = f'vector-databases-deployed/db5-{data_domain_identifier}' # was full
|
| 108 |
+
loader = DirectoryLoader(f'agllm-data/{data_domain_identifier}', glob='**/*.pdf', loader_cls=PyMuPDFLoader)  # was full; loader_kwargs={'chunk_size': 512} no longer passed
|
| 109 |
+
chunk_size_input=512
|
| 110 |
+
metadata_raw = pd.read_csv(f"./agllm-data/{data_domain_identifier}/matched_species_results_v2.csv")
|
| 111 |
+
documents = loader.load()
|
| 112 |
+
|
| 113 |
+
## Load Excel File Path (Define once)
|
| 114 |
+
excel_file_path = "agllm-data/PestID Species.xlsx"
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
## Process PDF documents and add metadata
|
| 118 |
+
print("--- Processing PDF Documents ---")
|
| 119 |
+
pdf_documents_for_splitting = [] # Prepare list to hold docs with added metadata
|
| 120 |
+
for doc in documents:
|
| 121 |
+
# Add region for PDF docs
|
| 122 |
+
doc.metadata["region"] = "United States"
|
| 123 |
+
|
| 124 |
+
# Add species metadata (existing logic)
|
| 125 |
+
file_name_associated_with_this_doc = doc.metadata["source"].split('/')[-1]
|
| 126 |
+
matching_species_for_this_file_name = metadata_raw[metadata_raw["File Name"].str.lower() == file_name_associated_with_this_doc.lower()]["Species"]
|
| 127 |
+
# Ensure matching_species_for_this_file_name is iterable and not empty
|
| 128 |
+
if not matching_species_for_this_file_name.empty:
|
| 129 |
+
for specie_index in range(len(matching_species_for_this_file_name)):
|
| 130 |
+
# Check if specie_index is within bounds (although range should handle this)
|
| 131 |
+
if specie_index < len(matching_species_for_this_file_name):
|
| 132 |
+
specie_name = matching_species_for_this_file_name.iloc[specie_index]
|
| 133 |
+
doc.metadata["matched_specie_" + str(specie_index)] = specie_name
|
| 134 |
+
else:
|
| 135 |
+
# This case should ideally not happen with range(len(...))
|
| 136 |
+
print(f"Warning: Specie index {specie_index} out of bounds for file {file_name_associated_with_this_doc}")
|
| 137 |
+
else:
|
| 138 |
+
print(f"Warning: No matching species found in CSV for PDF: {file_name_associated_with_this_doc}")
|
| 139 |
+
|
| 140 |
+
pdf_documents_for_splitting.append(doc) # Add modified doc to new list
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# Initialize Text Splitter
|
| 144 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size_input, chunk_overlap=10)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# Split PDF documents
|
| 148 |
+
pdf_splitted_documents = []
|
| 149 |
+
for doc in pdf_documents_for_splitting: # Use the list with added metadata
|
| 150 |
+
splits = text_splitter.split_documents([doc])
|
| 151 |
+
for i, split_doc in enumerate(splits, start=1):
|
| 152 |
+
metadata = split_doc.metadata.copy()
|
| 153 |
+
# Update source for PDF chunks (existing logic)
|
| 154 |
+
source_base = metadata.get('source', 'unknown_source')
|
| 155 |
+
page_num = metadata.get('page', 'unknown_page')
|
| 156 |
+
metadata["source"] = f"{source_base}#page{page_num}#chunk{i}"
|
| 157 |
+
# Remove the raw page number if desired, as it's now in the source string
|
| 158 |
+
# metadata.pop('page', None)
|
| 159 |
+
split_doc.metadata = metadata
|
| 160 |
+
pdf_splitted_documents.append(split_doc)
|
| 161 |
+
|
| 162 |
+
print("First PDF Document chunk:\\n", pdf_splitted_documents[0] if pdf_splitted_documents else "No PDF documents processed")
|
| 163 |
+
print(f"Count after PDF processing: {len(pdf_splitted_documents)}")
|
| 164 |
+
print("---------------------------------------------------")
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# Process Excel Sheets using the function
|
| 168 |
+
india_splitted_documents = process_excel_sheet(
|
| 169 |
+
excel_path=excel_file_path,
|
| 170 |
+
sheet_name="India",
|
| 171 |
+
region="India",
|
| 172 |
+
splitter=text_splitter
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
africa_splitted_documents = process_excel_sheet(
|
| 176 |
+
excel_path=excel_file_path,
|
| 177 |
+
sheet_name="Africa",
|
| 178 |
+
region="Africa",
|
| 179 |
+
splitter=text_splitter
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# Combine lists from all sources
|
| 184 |
+
splitted_documents = pdf_splitted_documents + india_splitted_documents + africa_splitted_documents
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# print(splitted_documents[0]) # Original print statement - commented out as we print chunks above
|
| 188 |
+
print("=== Combined Processing Done ===") # Adjusted print statement
|
| 189 |
+
print(f"Total documents after combining PDF, India, and Africa sources: {len(splitted_documents)}")
|
| 190 |
+
print("=============================")
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# ONLY FOR THE FIRST TIME
|
| 195 |
+
|
| 196 |
+
# Check if the persist directory exists and delete it to ensure a fresh start
|
| 197 |
+
if os.path.exists(persist_directory):
|
| 198 |
+
print(f"Deleting existing vector database directory: {persist_directory}")
|
| 199 |
+
shutil.rmtree(persist_directory)
|
| 200 |
+
print(f"Directory deleted.")
|
| 201 |
+
else:
|
| 202 |
+
print(f"Vector database directory not found, creating a new one: {persist_directory}")
|
| 203 |
+
|
| 204 |
+
embedding = OpenAIEmbeddings()
|
| 205 |
+
vectordb = Chroma.from_documents(documents=splitted_documents,
|
| 206 |
+
embedding=embedding,
|
| 207 |
+
persist_directory=persist_directory)
|
| 208 |
+
|
| 209 |
+
# persist the db to disk
|
| 210 |
+
vectordb.persist()
|
| 211 |
+
vectordb = None
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
# Now we can load the persisted database from disk, and use it as normal.
|
| 215 |
+
|
| 216 |
+
vectordb = Chroma(persist_directory=persist_directory,
|
| 217 |
+
embedding_function=embedding)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
print(vectordb.get())
|
| 221 |
+
|
| 222 |
+
#just a test script:
|
| 223 |
+
|
| 224 |
+
specie_selector="Aphis spiraecola"
|
| 225 |
+
filter = {
|
| 226 |
+
"$or": [
|
| 227 |
+
{"matched_specie_0": specie_selector},
|
| 228 |
+
{"matched_specie_1": specie_selector},
|
| 229 |
+
{"matched_specie_2": specie_selector},
|
| 230 |
+
]
|
| 231 |
+
}
|
| 232 |
+
answer = vectordb.as_retriever(search_kwargs={'k':10, 'filter': filter}).get_relevant_documents(
|
| 233 |
+
"anything else.?")
|
| 234 |
+
print(answer)
|
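
Since every chunk ingested above carries a region field ("United States" for the ISU PDFs, "India" or "Africa" for the Excel sheets), retrieval can also be scoped by geography. A sketch combining the species filter with the region metadata (species, region, and question are illustrative; this assumes the store was just rebuilt above):

# Combine species and region metadata in one Chroma "$and" filter.
region_filter = {
    "$and": [
        {"matched_specie_0": "Aphis spiraecola"},
        {"region": "United States"},
    ]
}
regional_docs = vectordb.as_retriever(
    search_kwargs={"k": 5, "filter": region_filter}
).get_relevant_documents("What does this species feed on?")
for d in regional_docs:
    print(d.metadata.get("region"), "-", d.metadata.get("source"))
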
outdated-files/agllm-data.zip
ADDED
|
@@ -0,0 +1,3 @@
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0e644edb3d6ae6153f2411ff768ba5047314a9fdd92f97b3d4fa25b8e28e98a5
|
| 3 |
+
size 4658799
|
outdated-files/agllm_with_evaluation.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
outdated-files/app-basic.py
ADDED
|
@@ -0,0 +1,19 @@
| 1 |
+
# import gradio as gr
|
| 2 |
+
|
| 3 |
+
# def greet(name):
|
| 4 |
+
# return "Hello " + name + "!"
|
| 5 |
+
|
| 6 |
+
# demo = gr.Interface(fn=greet, inputs="textbox", outputs="textbox")
|
| 7 |
+
|
| 8 |
+
# demo.launch() # Share your demo with just 1 extra parameter
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import gradio as gr
|
| 12 |
+
|
| 13 |
+
def greet(name):
|
| 14 |
+
return "Hello " + name + "!"
|
| 15 |
+
|
| 16 |
+
demo = gr.Interface(fn=greet, inputs="textbox", outputs="textbox")
|
| 17 |
+
|
| 18 |
+
if __name__ == "__main__":
|
| 19 |
+
demo.launch()
|
outdated-files/app-old.py
ADDED
|
@@ -0,0 +1,334 @@
| 1 |
+
import os
|
| 2 |
+
# https://stackoverflow.com/questions/76175046/how-to-add-prompt-to-langchain-conversationalretrievalchain-chat-over-docs-with
|
| 3 |
+
# again from:
|
| 4 |
+
# https://python.langchain.com/docs/integrations/providers/vectara/vectara_chat
|
| 5 |
+
from langchain.document_loaders import PyPDFDirectoryLoader
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import langchain
|
| 8 |
+
from queue import Queue
|
| 9 |
+
from typing import Any
|
| 10 |
+
from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
|
| 11 |
+
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
| 12 |
+
from langchain.schema import LLMResult
|
| 13 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 14 |
+
from langchain.vectorstores import FAISS
|
| 15 |
+
from langchain.prompts.prompt import PromptTemplate
|
| 16 |
+
from anyio.from_thread import start_blocking_portal #For model callback streaming
|
| 17 |
+
|
| 18 |
+
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
| 19 |
+
import os
|
| 20 |
+
from dotenv import load_dotenv
|
| 21 |
+
|
| 22 |
+
import streamlit as st
|
| 23 |
+
|
| 24 |
+
from langchain.document_loaders import PyPDFLoader
|
| 25 |
+
from langchain.text_splitter import CharacterTextSplitter
|
| 26 |
+
from langchain.embeddings import OpenAIEmbeddings
|
| 27 |
+
from langchain.chains.question_answering import load_qa_chain
|
| 28 |
+
from langchain.chat_models import ChatOpenAI
|
| 29 |
+
from langchain.vectorstores import Chroma
|
| 30 |
+
import chromadb
|
| 31 |
+
|
| 32 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 33 |
+
from langchain.llms import OpenAI
|
| 34 |
+
from langchain.chains import RetrievalQA
|
| 35 |
+
from langchain.document_loaders import TextLoader
|
| 36 |
+
from langchain.document_loaders import DirectoryLoader
|
| 37 |
+
from langchain_community.document_loaders import PyMuPDFLoader
|
| 38 |
+
from langchain.schema import Document
|
| 39 |
+
|
| 40 |
+
from langchain.memory import ConversationBufferMemory
|
| 41 |
+
|
| 42 |
+
from langchain.chains import ConversationalRetrievalChain
|
| 43 |
+
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
|
| 44 |
+
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
|
| 45 |
+
import gradio as gr
|
| 46 |
+
from langchain.memory import ConversationBufferMemory
|
| 47 |
+
from langchain.chains import ConversationalRetrievalChain
|
| 48 |
+
|
| 49 |
+
persist_directory = '/projects/bcjp/marshad/agllm/db5'
|
| 50 |
+
csv_filepath1 = "./agllm-data/corrected/Corrected_supplemented-insect_data-2500-sorted.xlsx"
|
| 51 |
+
csv_filepath2 = "./agllm-data/corrected/Corrected_supplemented-insect_data-remaining.xlsx"
|
| 52 |
+
model_name=4
|
| 53 |
+
max_tokens=400
|
| 54 |
+
system_message = {"role": "system", "content": "You are a helpful assistant."} # TODO: double check how this plays out later.
|
| 55 |
+
langchain.debug=True # TODO: DOUBLE CHECK
|
| 56 |
+
retriever_k_value=2
|
| 57 |
+
embedding = OpenAIEmbeddings()
|
| 58 |
+
|
| 59 |
+
######### todo: skipping the first step
|
| 60 |
+
|
| 61 |
+
embedding = OpenAIEmbeddings()
|
| 62 |
+
vectordb = Chroma(persist_directory=persist_directory,
|
| 63 |
+
embedding_function=embedding)
|
| 64 |
+
|
| 65 |
+
retriever = vectordb.as_retriever()
|
| 66 |
+
|
| 67 |
+
print(# Single example
|
| 68 |
+
vectordb.as_retriever(k=2, search_kwargs={"filter": {"matched_specie_0": "Hypagyrtis unipunctata"}, 'k':1}).get_relevant_documents(
|
| 69 |
+
"Checking if retriever is correctly initalized?"
|
| 70 |
+
))
|
| 71 |
+
|
| 72 |
+
columns = ['species', 'common name', 'order', 'family',
|
| 73 |
+
'genus', 'Updated role in ecosystem', 'Proof',
|
| 74 |
+
'ipm strategies', 'size of insect', 'geographical spread',
|
| 75 |
+
'life cycle specifics', 'pest for plant species', 'species status',
|
| 76 |
+
'distribution area', 'appearance', 'identification']
|
| 77 |
+
|
| 78 |
+
df1 = pd.read_excel(csv_filepath1, usecols=columns)
|
| 79 |
+
df2 = pd.read_excel(csv_filepath2, usecols=columns)
|
| 80 |
+
|
| 81 |
+
all_insects_data = pd.concat([df1, df2], ignore_index=True)
|
| 82 |
+
|
| 83 |
+
def get_prompt_with_vetted_info_from_specie_name(search_for_specie, mode):
|
| 84 |
+
|
| 85 |
+
def read_and_format_filtered_csv_better(insect_specie):
|
| 86 |
+
filtered_data = all_insects_data[all_insects_data['species'] == insect_specie]
|
| 87 |
+
formatted_data = ""
|
| 88 |
+
# Format the filtered data
|
| 89 |
+
for index, row in filtered_data.iterrows():
|
| 90 |
+
row_data = [f"{col}: {row[col]}" for col in filtered_data.columns]
|
| 91 |
+
formatted_row = "\n".join(row_data)
|
| 92 |
+
formatted_data += f"{formatted_row}\n"
|
| 93 |
+
|
| 94 |
+
return formatted_data
|
| 95 |
+
|
| 96 |
+
# Use the path to your CSV file here
|
| 97 |
+
|
| 98 |
+
vetted_info=read_and_format_filtered_csv_better(search_for_specie)
|
| 99 |
+
if mode=="user":
|
| 100 |
+
language_constraint="The language should be acustomed to the end user. This question is likely asked by a farmer. So, answer things in their language. Bur for referencing information, you can use the original content. This is only for the main answer to be provided by you."
|
| 101 |
+
elif mode=="researcher":
|
| 102 |
+
language_constraint="The language should be acustomed to a researcher. This question is likely asked by an academic researcher. So you can use all the technical terms freely. And for referencing information, you can use the original content. This is only for the main answer to be provided by you."
|
| 103 |
+
else:
|
| 104 |
+
print("No valid model provided. Exiting")
|
| 105 |
+
exit()
|
| 106 |
+
general_system_template = """
|
| 107 |
+
In every question you are provided information about the insect. Two types of information are provided: first, Vetted Information (which is the same in every question) and second, some context from external documents about an insect species, together with a question from the user. Answer the question according to these two types of information.
|
| 108 |
+
----
|
| 109 |
+
Vetted info is as follows:
|
| 110 |
+
{vetted_info}
|
| 111 |
+
----
|
| 112 |
+
The context retrieved for documents about this particular question is as follows:
|
| 113 |
+
{context}
|
| 114 |
+
----
|
| 115 |
+
Additional Instruction:
|
| 116 |
+
1. Reference Constraint
|
| 117 |
+
At the end of each answer provide the source/reference for the given data in following format:
|
| 118 |
+
\n\n[enter two new lines before writing below] References:
|
| 119 |
+
Vetted Information Used: Write what was used from the document for coming up with the answer above. Write exact part of lines. If nothing, write 'Nothing'.
|
| 120 |
+
Documents Used: Write what was used from the document for coming up with the answer above. If nothing, write 'Nothing'. Write exact part of lines and document used.
|
| 121 |
+
2. Information Constraint:
|
| 122 |
+
Only answer the question from the information provided; otherwise say you don't know. You have to answer in 150 words including references. Prioritize information in documents/context over vetted information. And first mention the warnings/things to be careful about.
|
| 123 |
+
3. Language constraint:
|
| 124 |
+
{language_constraint}
|
| 125 |
+
|
| 126 |
+
----
|
| 127 |
+
""".format(vetted_info=vetted_info, language_constraint=language_constraint,context="{context}", )
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
general_user_template = "Question:```{question}```"
|
| 131 |
+
messages_formatted = [
|
| 132 |
+
SystemMessagePromptTemplate.from_template(general_system_template),
|
| 133 |
+
# HumanMessagePromptTemplate.from_template(general_system_template),
|
| 134 |
+
HumanMessagePromptTemplate.from_template(general_user_template)
|
| 135 |
+
]
|
| 136 |
+
qa_prompt = ChatPromptTemplate.from_messages( messages_formatted )
|
| 137 |
+
print(qa_prompt)
|
| 138 |
+
return qa_prompt
|
| 139 |
+
qa_prompt=get_prompt_with_vetted_info_from_specie_name("Papaipema nebris", "researcher")
|
| 140 |
+
print("First prompt is intialized as: " , qa_prompt, "\n\n")
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
memory = ConversationBufferMemory(memory_key="chat_history",output_key='answer', return_messages=True) # https://github.com/langchain-ai/langchain/issues/9394#issuecomment-1683538834
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
if model_name==4:
|
| 147 |
+
llm_openai = ChatOpenAI(model_name="gpt-4-1106-preview" , temperature=0, max_tokens=max_tokens) # TODO: NEW MODEL VERSION AVAILABLE
|
| 148 |
+
else:
|
| 149 |
+
llm_openai = ChatOpenAI(model_name="gpt-3.5-turbo-0125" , temperature=0, max_tokens=max_tokens)
|
| 150 |
+
|
| 151 |
+
specie_selector="Papaipema nebris"
|
| 152 |
+
filter = {
|
| 153 |
+
"$or": [
|
| 154 |
+
{"matched_specie_0": specie_selector},
|
| 155 |
+
{"matched_specie_1": specie_selector},
|
| 156 |
+
{"matched_specie_2": specie_selector},
|
| 157 |
+
]
|
| 158 |
+
}
|
| 159 |
+
retriever = vectordb.as_retriever(search_kwargs={'k':retriever_k_value, 'filter': filter})
|
| 160 |
+
|
| 161 |
+
qa_chain = ConversationalRetrievalChain.from_llm(
|
| 162 |
+
llm_openai, retriever, memory=memory, verbose=False, return_source_documents=True,\
|
| 163 |
+
combine_docs_chain_kwargs={'prompt': qa_prompt}
|
| 164 |
+
|
| 165 |
+
)
|
| 166 |
+
#
|
| 167 |
+
|
| 168 |
+
def initialize_qa_chain(specie_selector, application_mode):
|
| 169 |
+
|
| 170 |
+
filter = {
|
| 171 |
+
"$or": [
|
| 172 |
+
{"matched_specie_0": specie_selector},
|
| 173 |
+
{"matched_specie_1": specie_selector},
|
| 174 |
+
{"matched_specie_2": specie_selector},
|
| 175 |
+
]
|
| 176 |
+
}
|
| 177 |
+
retriever = vectordb.as_retriever(search_kwargs={'k':2, 'filter': filter})
|
| 178 |
+
|
| 179 |
+
memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
|
| 180 |
+
qa_prompt=get_prompt_with_vetted_info_from_specie_name(specie_selector, application_mode)
|
| 181 |
+
qa_chain = ConversationalRetrievalChain.from_llm(
|
| 182 |
+
llm_openai, retriever, memory=memory, verbose=False, return_source_documents=True,
|
| 183 |
+
combine_docs_chain_kwargs={'prompt': qa_prompt}
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
return qa_chain
|
| 187 |
+
result = qa_chain.invoke({"question": "where are stalk borer eggs laid?"})
|
| 188 |
+
print("Got the first LLM task working: ", result)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
#Application Interface:
|
| 192 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 193 |
+
with gr.Row():
|
| 194 |
+
with gr.Column(scale=1):
|
| 195 |
+
gr.Markdown(
|
| 196 |
+
"""
|
| 197 |
+

|
| 198 |
+
"""
|
| 199 |
+
)
|
| 200 |
+
with gr.Column(scale=1):
|
| 201 |
+
gr.Markdown(
|
| 202 |
+
"""
|
| 203 |
+

|
| 204 |
+
"""
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
# Configure UI layout
|
| 208 |
+
chatbot = gr.Chatbot(height=600, label="AgLLM22")
|
| 209 |
+
with gr.Row():
|
| 210 |
+
with gr.Column(scale=1):
|
| 211 |
+
with gr.Row():
|
| 212 |
+
# Model selection
|
| 213 |
+
specie_selector = gr.Dropdown(
|
| 214 |
+
list(["Papaipema nebris", "Nomophila nearctica"]),
|
| 215 |
+
value="Papaipema nebris",
|
| 216 |
+
label="Species",
|
| 217 |
+
info="Select the Species",
|
| 218 |
+
interactive=True,
|
| 219 |
+
scale=2,
|
| 220 |
+
visible=True
|
| 221 |
+
)
|
| 222 |
+
with gr.Row():
|
| 223 |
+
application_mode = gr.Dropdown(
|
| 224 |
+
list(["user", "researcher"]),
|
| 225 |
+
value="researcher",
|
| 226 |
+
label="Mode",
|
| 227 |
+
info="Select the Mode",
|
| 228 |
+
interactive=True,
|
| 229 |
+
scale=2,
|
| 230 |
+
visible=True
|
| 231 |
+
)
|
| 232 |
+
with gr.Row():
|
| 233 |
+
pass
|
| 234 |
+
with gr.Column(scale=2):
|
| 235 |
+
# User input prompt text field
|
| 236 |
+
user_prompt_message = gr.Textbox(placeholder="Please add user prompt here", label="User prompt")
|
| 237 |
+
with gr.Row():
|
| 238 |
+
# clear = gr.Button("Clear Conversation", scale=2)
|
| 239 |
+
submitBtn = gr.Button("Submit", scale=8)
|
| 240 |
+
|
| 241 |
+
state = gr.State([])
|
| 242 |
+
qa_chain_state = gr.State(value=None)
|
| 243 |
+
|
| 244 |
+
# Handle user message
|
| 245 |
+
def user(user_prompt_message, history):
|
| 246 |
+
print("HISTORY IS: ", history) # TODO: REMOVE IT LATER
|
| 247 |
+
if user_prompt_message != "":
|
| 248 |
+
return history + [[user_prompt_message, None]]
|
| 249 |
+
else:
|
| 250 |
+
return history + [["Invalid prompts - user prompt cannot be empty", None]]
|
| 251 |
+
|
| 252 |
+
# Chatbot logic for configuration, sending the prompts, rendering the streamed back generations, etc.
|
| 253 |
+
def bot(application_mode, user_prompt_message, history, messages_history, qa_chain):
|
| 254 |
+
if qa_chain is None:
|
| 255 |
+
qa_chain=init_qa_chain("Papaipema nebris", application_mode)
|
| 256 |
+
|
| 257 |
+
dialog = []
|
| 258 |
+
bot_message = ""
|
| 259 |
+
history[-1][1] = "" # Placeholder for the answer
|
| 260 |
+
|
| 261 |
+
dialog = [
|
| 262 |
+
{"role": "user", "content": user_prompt_message},
|
| 263 |
+
]
|
| 264 |
+
messages_history += dialog
|
| 265 |
+
|
| 266 |
+
# Queue intended for streamed character rendering (currently unused; the chain call below runs synchronously)
|
| 267 |
+
q = Queue()
|
| 268 |
+
|
| 269 |
+
# Run the chain for the user question; this call is synchronous, so the UI waits for the full answer
|
| 270 |
+
def task(user_prompt_message):
|
| 271 |
+
ret = qa_chain.invoke({"question": user_prompt_message})["answer"]
|
| 272 |
+
return ret
|
| 273 |
+
|
| 274 |
+
history[-1][1] = task(user_prompt_message)
|
| 275 |
+
return [history, messages_history]
|
| 276 |
+
|
| 277 |
+
# Initialize the chat history with default system message
|
| 278 |
+
def init_history(messages_history):
|
| 279 |
+
messages_history = []
|
| 280 |
+
messages_history += [system_message]
|
| 281 |
+
return messages_history
|
| 282 |
+
|
| 283 |
+
# Clean up the user input text field
|
| 284 |
+
def input_cleanup():
|
| 285 |
+
return ""
|
| 286 |
+
|
| 287 |
+
def init_qa_chain(specie_selector, application_mode):
|
| 288 |
+
qa_chain = initialize_qa_chain(specie_selector, application_mode)
|
| 289 |
+
return qa_chain
|
| 290 |
+
|
| 291 |
+
specie_selector.change(
|
| 292 |
+
init_qa_chain,
|
| 293 |
+
inputs=[specie_selector, application_mode],
|
| 294 |
+
outputs=[qa_chain_state]
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
# When the user clicks Enter and the user message is submitted
|
| 298 |
+
user_prompt_message.submit(
|
| 299 |
+
user,
|
| 300 |
+
[user_prompt_message, chatbot],
|
| 301 |
+
[chatbot],
|
| 302 |
+
queue=False
|
| 303 |
+
).then(
|
| 304 |
+
bot,
|
| 305 |
+
[application_mode, user_prompt_message, chatbot, state, qa_chain_state],
|
| 306 |
+
[chatbot, state]
|
| 307 |
+
).then(input_cleanup,
|
| 308 |
+
[],
|
| 309 |
+
[user_prompt_message],
|
| 310 |
+
queue=False
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# When the user clicks the submit button
|
| 314 |
+
submitBtn.click(
|
| 315 |
+
user,
|
| 316 |
+
[user_prompt_message, chatbot],
|
| 317 |
+
[chatbot],
|
| 318 |
+
queue=False
|
| 319 |
+
).then(
|
| 320 |
+
bot,
|
| 321 |
+
[application_mode, user_prompt_message, chatbot, state, qa_chain_state],
|
| 322 |
+
[chatbot, state]
|
| 323 |
+
).then(
|
| 324 |
+
input_cleanup,
|
| 325 |
+
[],
|
| 326 |
+
[user_prompt_message],
|
| 327 |
+
queue=False
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
# When the user clicks the clear button
|
| 331 |
+
# clear.click(lambda: None, None, chatbot, queue=False).success(init_history, [state], [state])
|
| 332 |
+
if __name__ == "__main__":
|
| 333 |
+
# demo.launch()
|
| 334 |
+
demo.queue().launch(allowed_paths=["/"], server_name="0.0.0.0", share=True, debug=True)
|
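The retriever filter built around `matched_specie_0`–`matched_specie_2` is the piece of the app above that keeps retrieval scoped to the selected insect. As a quick standalone illustration, here is a minimal sketch of that idea, assuming a persisted Chroma collection whose chunks carry those metadata keys and an OpenAI embedding model; the `db5` directory name is only a guess based on the repository layout, not something this file sets up.

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma

# Assumed: a persisted Chroma collection (directory name is a guess) whose documents
# were stored with matched_specie_0..matched_specie_2 metadata keys.
vectordb = Chroma(persist_directory="db5", embedding_function=OpenAIEmbeddings())

species = "Papaipema nebris"
species_filter = {
    "$or": [
        {"matched_specie_0": species},
        {"matched_specie_1": species},
        {"matched_specie_2": species},
    ]
}

# Only chunks tagged with the selected species are eligible for similarity search.
retriever = vectordb.as_retriever(search_kwargs={"k": 2, "filter": species_filter})
docs = retriever.get_relevant_documents("where are stalk borer eggs laid?")
print([d.metadata for d in docs])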
outdated-files/dd.txt
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
input_variables=['context', 'question'] messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=['context'], template=" \n In every question you are provided information about the insect. Two types of information are: First, Vetted Information (which is same in every questinon) and Second, some context from external documents about an insect specie and a question by the user. answer the question according to these two types of informations. \n ---- \n Vetted info is as follows:\n species: Papaipema nebris\ncommon name: stalk borer\norder: Lepidoptera\nfamily: Noctuidae\ngenus: Papaipema\nUpdated role in ecosystem: Pest\nProof: The stalk borer, Papaipema nebris, is correctly classified as a pest because it damages a variety of crops by boring into their stems. This can significantly reduce crop yields and impact agricultural production. Its larvae feed inside the stalks causing the plants to wilt and often die, which characterizes its detrimental role in the ecosystem predominantly affecting agricultural systems.\nipm strategies: IPM strategies for stalk borer include crop rotation, planting trap crops, timely planting to avoid peak egg-laying periods, mechanical control through cultivation, and chemical controls with insecticides if necessary.\nsize of insect: 25-35 mm\ngeographical spread: United States, Canada\nlife cycle specifics: Stalk borers undergo complete metamorphosis, including egg, larva, pupa, and adult stages. Eggs are laid on grasses in fall, hatch in spring, and the larvae bore into stems of a variety of host plants. Pupation takes place within the host plant or in the soil.\npest for plant species: Corn, sorghum, and other grasses\nspecies status: It is native to North America.\ndistribution area: Most widespread in the corn belt of the United States.\nappearance: The adult stalk borer is a medium-sized moth with a wingspan of about 25-35 mm with distinct, dark, and light gray or brown patterns on its wings.\nidentification: Eggs are white and difficult to find, larvae are purplish or pinkish with a distinct white stripe and darker line above this, and adults are gray or brown moths with a characteristic wing pattern.\n\n ----\n The context retrieved for documents about this particular question is a as follows:\n {context}\n ----\n Additional Instruction:\n 1. Reference Constraint\n At the end of each answer provide the source/reference for the given data in following format: \n \n\n[enter two new lines before writing below] References:\n Vetted Information Used: Write what was used from the document for coming up with the answer above. Write exact part of lines. If nothing, write 'Nothing'. \n Documents Used: Write what was used from the document for coming up with the answer above. If nothing, write 'Nothing'. Write exact part of lines and document used. \n 2. Information Constraint: \n Only answer the question from information provided otherwise say you dont know. You have to answer in 150 words including references. Prioritize information in documents/context over vetted information. And first mention the warnings/things to be careful about. \n 3. Language constraint:\n The language should be acustomed to a researcher. This question is likely asked by an academic researcher. So you can use all the technical terms freely. And for referencing information, you can use the original content. 
This is only for the main answer to be provided by you.\n\n ----\n ")), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], template='Question:```{question}```'))]
|
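The single line captured in dd.txt above is the printed form of the app's ChatPromptTemplate. To see how such a template turns into actual chat messages once `context` and `question` are filled in, here is a hedged sketch using the same langchain prompt classes; the vetted-info text is abbreviated and purely illustrative.

from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)

system_template = (
    "Vetted info is as follows:\n{vetted_info}\n----\n"
    "The context retrieved for this question is as follows:\n{context}\n"
)
# Pre-fill the vetted info while leaving {context} as a placeholder, mirroring the app.
system_template = system_template.format(
    vetted_info="species: Papaipema nebris ...", context="{context}"
)

prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("Question:```{question}```"),
])

# format_messages fills the remaining placeholders and returns the chat messages,
# which is essentially what print(qa_prompt) dumped into dd.txt.
messages = prompt.format_messages(
    context="(retrieved document chunks)",
    question="where are stalk borer eggs laid?",
)
print(messages)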
outdated-files/rag-evaluation (outdated).ipynb
ADDED
|
@@ -0,0 +1,896 @@
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {
|
| 6 |
+
"application/vnd.databricks.v1+cell": {
|
| 7 |
+
"cellMetadata": {},
|
| 8 |
+
"inputWidgets": {},
|
| 9 |
+
"nuid": "42084110-295b-493a-9b3e-5d8d29ff78b3",
|
| 10 |
+
"showTitle": false,
|
| 11 |
+
"title": ""
|
| 12 |
+
}
|
| 13 |
+
},
|
| 14 |
+
"source": [
|
| 15 |
+
"# LLM RAG Evaluation with MLflow Example Notebook\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"In this notebook, we will demonstrate how to evaluate various a RAG system with MLflow."
|
| 18 |
+
]
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"cell_type": "raw",
|
| 22 |
+
"metadata": {},
|
| 23 |
+
"source": [
|
| 24 |
+
"<a href=\"https://raw.githubusercontent.com/mlflow/mlflow/master/docs/source/llms/llm-evaluate/notebooks/rag-evaluation.ipynb\" class=\"notebook-download-btn\"><i class=\"fas fa-download\"></i>Download this Notebook</a>"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "markdown",
|
| 29 |
+
"metadata": {},
|
| 30 |
+
"source": [
|
| 31 |
+
"We need to set our OpenAI API key.\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"In order to set your private key safely, please be sure to either export your key through a command-line terminal for your current instance, or, for a permanent addition to all user-based sessions, configure your favored environment management configuration file (i.e., .bashrc, .zshrc) to have the following entry:\n",
|
| 34 |
+
"\n",
|
| 35 |
+
"`OPENAI_API_KEY=<your openai API key>`\n",
|
| 36 |
+
"\n",
|
| 37 |
+
"If using Azure OpenAI, you will instead need to set\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"`OPENAI_API_TYPE=\"azure\"`\n",
|
| 40 |
+
"\n",
|
| 41 |
+
"`OPENAI_API_VERSION=<YYYY-MM-DD>`\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"`OPENAI_API_KEY=<https://<>.<>.<>.com>`\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"`OPENAI_API_DEPLOYMENT_NAME=<deployment name>`\n"
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"cell_type": "code",
|
| 50 |
+
"execution_count": 4,
|
| 51 |
+
"metadata": {
|
| 52 |
+
"application/vnd.databricks.v1+cell": {
|
| 53 |
+
"cellMetadata": {
|
| 54 |
+
"byteLimit": 2048000,
|
| 55 |
+
"rowLimit": 10000
|
| 56 |
+
},
|
| 57 |
+
"inputWidgets": {},
|
| 58 |
+
"nuid": "fb946228-62fb-4d68-9732-75935c9cb401",
|
| 59 |
+
"showTitle": false,
|
| 60 |
+
"title": ""
|
| 61 |
+
}
|
| 62 |
+
},
|
| 63 |
+
"outputs": [],
|
| 64 |
+
"source": [
|
| 65 |
+
"import pandas as pd\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"import mlflow\n",
|
| 68 |
+
"import os\n",
|
| 69 |
+
"os.environ[\"OPENAI_API_KEY\"] =\"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\"\n"
|
| 70 |
+
]
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"cell_type": "markdown",
|
| 74 |
+
"metadata": {
|
| 75 |
+
"application/vnd.databricks.v1+cell": {
|
| 76 |
+
"cellMetadata": {},
|
| 77 |
+
"inputWidgets": {},
|
| 78 |
+
"nuid": "273d1345-95d7-435a-a7b6-a5f3dbb3f073",
|
| 79 |
+
"showTitle": false,
|
| 80 |
+
"title": ""
|
| 81 |
+
}
|
| 82 |
+
},
|
| 83 |
+
"source": [
|
| 84 |
+
"## Create a RAG system\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"Use Langchain and Chroma to create a RAG system that answers questions based on the MLflow documentation."
|
| 87 |
+
]
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"cell_type": "code",
|
| 91 |
+
"execution_count": 13,
|
| 92 |
+
"metadata": {
|
| 93 |
+
"application/vnd.databricks.v1+cell": {
|
| 94 |
+
"cellMetadata": {
|
| 95 |
+
"byteLimit": 2048000,
|
| 96 |
+
"rowLimit": 10000
|
| 97 |
+
},
|
| 98 |
+
"inputWidgets": {},
|
| 99 |
+
"nuid": "2c28d0ad-f469-46ab-a2b4-c5e8db50a729",
|
| 100 |
+
"showTitle": false,
|
| 101 |
+
"title": ""
|
| 102 |
+
}
|
| 103 |
+
},
|
| 104 |
+
"outputs": [],
|
| 105 |
+
"source": [
|
| 106 |
+
"from langchain.chains import RetrievalQA\n",
|
| 107 |
+
"from langchain.document_loaders import WebBaseLoader\n",
|
| 108 |
+
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
| 109 |
+
"from langchain.llms import OpenAI\n",
|
| 110 |
+
"from langchain.text_splitter import CharacterTextSplitter\n",
|
| 111 |
+
"from langchain.vectorstores import Chroma\n",
|
| 112 |
+
"from langchain.chat_models import ChatOpenAI\n"
|
| 113 |
+
]
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"cell_type": "code",
|
| 117 |
+
"execution_count": 14,
|
| 118 |
+
"metadata": {
|
| 119 |
+
"application/vnd.databricks.v1+cell": {
|
| 120 |
+
"cellMetadata": {
|
| 121 |
+
"byteLimit": 2048000,
|
| 122 |
+
"rowLimit": 10000
|
| 123 |
+
},
|
| 124 |
+
"inputWidgets": {},
|
| 125 |
+
"nuid": "83a7e77e-6717-472a-86dc-02e2c356ddef",
|
| 126 |
+
"showTitle": false,
|
| 127 |
+
"title": ""
|
| 128 |
+
}
|
| 129 |
+
},
|
| 130 |
+
"outputs": [],
|
| 131 |
+
"source": [
|
| 132 |
+
"loader = WebBaseLoader(\"https://mlflow.org/docs/latest/index.html\")\n",
|
| 133 |
+
"\n",
|
| 134 |
+
"documents = loader.load()\n",
|
| 135 |
+
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
|
| 136 |
+
"texts = text_splitter.split_documents(documents)\n",
|
| 137 |
+
"\n",
|
| 138 |
+
"embeddings = OpenAIEmbeddings()\n",
|
| 139 |
+
"docsearch = Chroma.from_documents(texts, embeddings)\n",
|
| 140 |
+
"\n",
|
| 141 |
+
"qa = RetrievalQA.from_chain_type(\n",
|
| 142 |
+
" llm=ChatOpenAI(model_name=\"gpt-3.5-turbo-0125\" , temperature=0),\n",
|
| 143 |
+
" chain_type=\"stuff\",\n",
|
| 144 |
+
" retriever=docsearch.as_retriever(),\n",
|
| 145 |
+
" return_source_documents=True,\n",
|
| 146 |
+
")"
|
| 147 |
+
]
|
| 148 |
+
},
|
| 149 |
+
{
|
| 150 |
+
"cell_type": "markdown",
|
| 151 |
+
"metadata": {
|
| 152 |
+
"application/vnd.databricks.v1+cell": {
|
| 153 |
+
"cellMetadata": {},
|
| 154 |
+
"inputWidgets": {},
|
| 155 |
+
"nuid": "fd70bcf6-7c44-44d3-9435-567b82611e1c",
|
| 156 |
+
"showTitle": false,
|
| 157 |
+
"title": ""
|
| 158 |
+
}
|
| 159 |
+
},
|
| 160 |
+
"source": [
|
| 161 |
+
"## Evaluate the RAG system using `mlflow.evaluate()`"
|
| 162 |
+
]
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"cell_type": "markdown",
|
| 166 |
+
"metadata": {
|
| 167 |
+
"application/vnd.databricks.v1+cell": {
|
| 168 |
+
"cellMetadata": {},
|
| 169 |
+
"inputWidgets": {},
|
| 170 |
+
"nuid": "de1bc359-2e40-459c-bea4-bed35a117988",
|
| 171 |
+
"showTitle": false,
|
| 172 |
+
"title": ""
|
| 173 |
+
}
|
| 174 |
+
},
|
| 175 |
+
"source": [
|
| 176 |
+
"Create a simple function that runs each input through the RAG chain"
|
| 177 |
+
]
|
| 178 |
+
},
|
| 179 |
+
{
|
| 180 |
+
"cell_type": "code",
|
| 181 |
+
"execution_count": 15,
|
| 182 |
+
"metadata": {
|
| 183 |
+
"application/vnd.databricks.v1+cell": {
|
| 184 |
+
"cellMetadata": {
|
| 185 |
+
"byteLimit": 2048000,
|
| 186 |
+
"rowLimit": 10000
|
| 187 |
+
},
|
| 188 |
+
"inputWidgets": {},
|
| 189 |
+
"nuid": "667ec809-2bb5-4170-9937-6804386b41ec",
|
| 190 |
+
"showTitle": false,
|
| 191 |
+
"title": ""
|
| 192 |
+
}
|
| 193 |
+
},
|
| 194 |
+
"outputs": [],
|
| 195 |
+
"source": [
|
| 196 |
+
"def model(input_df):\n",
|
| 197 |
+
" answer = []\n",
|
| 198 |
+
" for index, row in input_df.iterrows():\n",
|
| 199 |
+
" answer.append(qa(row[\"questions\"]))\n",
|
| 200 |
+
"\n",
|
| 201 |
+
" return answer"
|
| 202 |
+
]
|
| 203 |
+
},
|
| 204 |
+
{
|
| 205 |
+
"cell_type": "code",
|
| 206 |
+
"execution_count": null,
|
| 207 |
+
"metadata": {},
|
| 208 |
+
"outputs": [],
|
| 209 |
+
"source": []
|
| 210 |
+
},
|
| 211 |
+
{
|
| 212 |
+
"cell_type": "markdown",
|
| 213 |
+
"metadata": {
|
| 214 |
+
"application/vnd.databricks.v1+cell": {
|
| 215 |
+
"cellMetadata": {},
|
| 216 |
+
"inputWidgets": {},
|
| 217 |
+
"nuid": "d1064306-b7f3-4b3e-825c-4353d808f21d",
|
| 218 |
+
"showTitle": false,
|
| 219 |
+
"title": ""
|
| 220 |
+
}
|
| 221 |
+
},
|
| 222 |
+
"source": [
|
| 223 |
+
"Create an eval dataset"
|
| 224 |
+
]
|
| 225 |
+
},
|
| 226 |
+
{
|
| 227 |
+
"cell_type": "code",
|
| 228 |
+
"execution_count": 16,
|
| 229 |
+
"metadata": {
|
| 230 |
+
"application/vnd.databricks.v1+cell": {
|
| 231 |
+
"cellMetadata": {
|
| 232 |
+
"byteLimit": 2048000,
|
| 233 |
+
"rowLimit": 10000
|
| 234 |
+
},
|
| 235 |
+
"inputWidgets": {},
|
| 236 |
+
"nuid": "a5481491-e4a9-42ea-8a3f-f527faffd04d",
|
| 237 |
+
"showTitle": false,
|
| 238 |
+
"title": ""
|
| 239 |
+
}
|
| 240 |
+
},
|
| 241 |
+
"outputs": [],
|
| 242 |
+
"source": [
|
| 243 |
+
"eval_df = pd.DataFrame(\n",
|
| 244 |
+
" {\n",
|
| 245 |
+
" \"questions\": [\n",
|
| 246 |
+
" \"What is MLflow?\",\n",
|
| 247 |
+
" \"How to run mlflow.evaluate()?\",\n",
|
| 248 |
+
" \"How to log_table()?\",\n",
|
| 249 |
+
" \"How to load_table()?\",\n",
|
| 250 |
+
" ],\n",
|
| 251 |
+
" }\n",
|
| 252 |
+
")"
|
| 253 |
+
]
|
| 254 |
+
},
|
| 255 |
+
{
|
| 256 |
+
"cell_type": "markdown",
|
| 257 |
+
"metadata": {
|
| 258 |
+
"application/vnd.databricks.v1+cell": {
|
| 259 |
+
"cellMetadata": {},
|
| 260 |
+
"inputWidgets": {},
|
| 261 |
+
"nuid": "9c3c8023-8feb-427a-b36d-34cd1853a5dc",
|
| 262 |
+
"showTitle": false,
|
| 263 |
+
"title": ""
|
| 264 |
+
}
|
| 265 |
+
},
|
| 266 |
+
"source": [
|
| 267 |
+
"Create a faithfulness metric"
|
| 268 |
+
]
|
| 269 |
+
},
|
| 270 |
+
{
|
| 271 |
+
"cell_type": "code",
|
| 272 |
+
"execution_count": 17,
|
| 273 |
+
"metadata": {
|
| 274 |
+
"application/vnd.databricks.v1+cell": {
|
| 275 |
+
"cellMetadata": {
|
| 276 |
+
"byteLimit": 2048000,
|
| 277 |
+
"rowLimit": 10000
|
| 278 |
+
},
|
| 279 |
+
"inputWidgets": {},
|
| 280 |
+
"nuid": "3882b940-9c25-41ce-a301-72d8c0c90aaa",
|
| 281 |
+
"showTitle": false,
|
| 282 |
+
"title": ""
|
| 283 |
+
}
|
| 284 |
+
},
|
| 285 |
+
"outputs": [
|
| 286 |
+
{
|
| 287 |
+
"name": "stdout",
|
| 288 |
+
"output_type": "stream",
|
| 289 |
+
"text": [
|
| 290 |
+
"EvaluationMetric(name=faithfulness, greater_is_better=True, long_name=faithfulness, version=v1, metric_details=\n",
|
| 291 |
+
"Task:\n",
|
| 292 |
+
"You must return the following fields in your response in two lines, one below the other:\n",
|
| 293 |
+
"score: Your numerical score for the model's faithfulness based on the rubric\n",
|
| 294 |
+
"justification: Your reasoning about the model's faithfulness score\n",
|
| 295 |
+
"\n",
|
| 296 |
+
"You are an impartial judge. You will be given an input that was sent to a machine\n",
|
| 297 |
+
"learning model, and you will be given an output that the model produced. You\n",
|
| 298 |
+
"may also be given additional information that was used by the model to generate the output.\n",
|
| 299 |
+
"\n",
|
| 300 |
+
"Your task is to determine a numerical score called faithfulness based on the input and output.\n",
|
| 301 |
+
"A definition of faithfulness and a grading rubric are provided below.\n",
|
| 302 |
+
"You must use the grading rubric to determine your score. You must also justify your score.\n",
|
| 303 |
+
"\n",
|
| 304 |
+
"Examples could be included below for reference. Make sure to use them as references and to\n",
|
| 305 |
+
"understand them before completing the task.\n",
|
| 306 |
+
"\n",
|
| 307 |
+
"Input:\n",
|
| 308 |
+
"{input}\n",
|
| 309 |
+
"\n",
|
| 310 |
+
"Output:\n",
|
| 311 |
+
"{output}\n",
|
| 312 |
+
"\n",
|
| 313 |
+
"{grading_context_columns}\n",
|
| 314 |
+
"\n",
|
| 315 |
+
"Metric definition:\n",
|
| 316 |
+
"Faithfulness is only evaluated with the provided output and provided context, please ignore the provided input entirely when scoring faithfulness. Faithfulness assesses how much of the provided output is factually consistent with the provided context. A higher score indicates that a higher proportion of claims present in the output can be derived from the provided context. Faithfulness does not consider how much extra information from the context is not present in the output.\n",
|
| 317 |
+
"\n",
|
| 318 |
+
"Grading rubric:\n",
|
| 319 |
+
"Faithfulness: Below are the details for different scores:\n",
|
| 320 |
+
"- Score 1: None of the claims in the output can be inferred from the provided context.\n",
|
| 321 |
+
"- Score 2: Some of the claims in the output can be inferred from the provided context, but the majority of the output is missing from, inconsistent with, or contradictory to the provided context.\n",
|
| 322 |
+
"- Score 3: Half or more of the claims in the output can be inferred from the provided context.\n",
|
| 323 |
+
"- Score 4: Most of the claims in the output can be inferred from the provided context, with very little information that is not directly supported by the provided context.\n",
|
| 324 |
+
"- Score 5: All of the claims in the output are directly supported by the provided context, demonstrating high faithfulness to the provided context.\n",
|
| 325 |
+
"\n",
|
| 326 |
+
"Examples:\n",
|
| 327 |
+
"\n",
|
| 328 |
+
"Example Input:\n",
|
| 329 |
+
"How do I disable MLflow autologging?\n",
|
| 330 |
+
"\n",
|
| 331 |
+
"Example Output:\n",
|
| 332 |
+
"mlflow.autolog(disable=True) will disable autologging for all functions. In Databricks, autologging is enabled by default. \n",
|
| 333 |
+
"\n",
|
| 334 |
+
"Additional information used by the model:\n",
|
| 335 |
+
"key: context\n",
|
| 336 |
+
"value:\n",
|
| 337 |
+
"mlflow.autolog(log_input_examples: bool = False, log_model_signatures: bool = True, log_models: bool = True, log_datasets: bool = True, disable: bool = False, exclusive: bool = False, disable_for_unsupported_versions: bool = False, silent: bool = False, extra_tags: Optional[Dict[str, str]] = None) → None[source] Enables (or disables) and configures autologging for all supported integrations. The parameters are passed to any autologging integrations that support them. See the tracking docs for a list of supported autologging integrations. Note that framework-specific configurations set at any point will take precedence over any configurations set by this function.\n",
|
| 338 |
+
"\n",
|
| 339 |
+
"Example score: 2\n",
|
| 340 |
+
"Example justification: The output provides a working solution, using the mlflow.autolog() function that is provided in the context.\n",
|
| 341 |
+
" \n",
|
| 342 |
+
"\n",
|
| 343 |
+
"Example Input:\n",
|
| 344 |
+
"How do I disable MLflow autologging?\n",
|
| 345 |
+
"\n",
|
| 346 |
+
"Example Output:\n",
|
| 347 |
+
"mlflow.autolog(disable=True) will disable autologging for all functions.\n",
|
| 348 |
+
"\n",
|
| 349 |
+
"Additional information used by the model:\n",
|
| 350 |
+
"key: context\n",
|
| 351 |
+
"value:\n",
|
| 352 |
+
"mlflow.autolog(log_input_examples: bool = False, log_model_signatures: bool = True, log_models: bool = True, log_datasets: bool = True, disable: bool = False, exclusive: bool = False, disable_for_unsupported_versions: bool = False, silent: bool = False, extra_tags: Optional[Dict[str, str]] = None) → None[source] Enables (or disables) and configures autologging for all supported integrations. The parameters are passed to any autologging integrations that support them. See the tracking docs for a list of supported autologging integrations. Note that framework-specific configurations set at any point will take precedence over any configurations set by this function.\n",
|
| 353 |
+
"\n",
|
| 354 |
+
"Example score: 5\n",
|
| 355 |
+
"Example justification: The output provides a solution that is using the mlflow.autolog() function that is provided in the context.\n",
|
| 356 |
+
" \n",
|
| 357 |
+
"\n",
|
| 358 |
+
"You must return the following fields in your response in two lines, one below the other:\n",
|
| 359 |
+
"score: Your numerical score for the model's faithfulness based on the rubric\n",
|
| 360 |
+
"justification: Your reasoning about the model's faithfulness score\n",
|
| 361 |
+
"\n",
|
| 362 |
+
"Do not add additional new lines. Do not add any other fields.\n",
|
| 363 |
+
" )\n"
|
| 364 |
+
]
|
| 365 |
+
}
|
| 366 |
+
],
|
| 367 |
+
"source": [
|
| 368 |
+
"from mlflow.metrics.genai import EvaluationExample, faithfulness\n",
|
| 369 |
+
"\n",
|
| 370 |
+
"# Create a good and bad example for faithfulness in the context of this problem\n",
|
| 371 |
+
"faithfulness_examples = [\n",
|
| 372 |
+
" EvaluationExample(\n",
|
| 373 |
+
" input=\"How do I disable MLflow autologging?\",\n",
|
| 374 |
+
" output=\"mlflow.autolog(disable=True) will disable autologging for all functions. In Databricks, autologging is enabled by default. \",\n",
|
| 375 |
+
" score=2,\n",
|
| 376 |
+
" justification=\"The output provides a working solution, using the mlflow.autolog() function that is provided in the context.\",\n",
|
| 377 |
+
" grading_context={\n",
|
| 378 |
+
" \"context\": \"mlflow.autolog(log_input_examples: bool = False, log_model_signatures: bool = True, log_models: bool = True, log_datasets: bool = True, disable: bool = False, exclusive: bool = False, disable_for_unsupported_versions: bool = False, silent: bool = False, extra_tags: Optional[Dict[str, str]] = None) → None[source] Enables (or disables) and configures autologging for all supported integrations. The parameters are passed to any autologging integrations that support them. See the tracking docs for a list of supported autologging integrations. Note that framework-specific configurations set at any point will take precedence over any configurations set by this function.\"\n",
|
| 379 |
+
" },\n",
|
| 380 |
+
" ),\n",
|
| 381 |
+
" EvaluationExample(\n",
|
| 382 |
+
" input=\"How do I disable MLflow autologging?\",\n",
|
| 383 |
+
" output=\"mlflow.autolog(disable=True) will disable autologging for all functions.\",\n",
|
| 384 |
+
" score=5,\n",
|
| 385 |
+
" justification=\"The output provides a solution that is using the mlflow.autolog() function that is provided in the context.\",\n",
|
| 386 |
+
" grading_context={\n",
|
| 387 |
+
" \"context\": \"mlflow.autolog(log_input_examples: bool = False, log_model_signatures: bool = True, log_models: bool = True, log_datasets: bool = True, disable: bool = False, exclusive: bool = False, disable_for_unsupported_versions: bool = False, silent: bool = False, extra_tags: Optional[Dict[str, str]] = None) → None[source] Enables (or disables) and configures autologging for all supported integrations. The parameters are passed to any autologging integrations that support them. See the tracking docs for a list of supported autologging integrations. Note that framework-specific configurations set at any point will take precedence over any configurations set by this function.\"\n",
|
| 388 |
+
" },\n",
|
| 389 |
+
" ),\n",
|
| 390 |
+
"]\n",
|
| 391 |
+
"\n",
|
| 392 |
+
"faithfulness_metric = faithfulness(model=\"openai:/gpt-4\", examples=faithfulness_examples)\n",
|
| 393 |
+
"print(faithfulness_metric)"
|
| 394 |
+
]
|
| 395 |
+
},
|
| 396 |
+
{
|
| 397 |
+
"cell_type": "markdown",
|
| 398 |
+
"metadata": {},
|
| 399 |
+
"source": [
|
| 400 |
+
"Create a relevance metric. You can see the full grading prompt by printing the metric or by accessing the `metric_details` attribute of the metric."
|
| 401 |
+
]
|
| 402 |
+
},
|
| 403 |
+
{
|
| 404 |
+
"cell_type": "code",
|
| 405 |
+
"execution_count": 18,
|
| 406 |
+
"metadata": {},
|
| 407 |
+
"outputs": [
|
| 408 |
+
{
|
| 409 |
+
"name": "stdout",
|
| 410 |
+
"output_type": "stream",
|
| 411 |
+
"text": [
|
| 412 |
+
"EvaluationMetric(name=relevance, greater_is_better=True, long_name=relevance, version=v1, metric_details=\n",
|
| 413 |
+
"Task:\n",
|
| 414 |
+
"You must return the following fields in your response in two lines, one below the other:\n",
|
| 415 |
+
"score: Your numerical score for the model's relevance based on the rubric\n",
|
| 416 |
+
"justification: Your reasoning about the model's relevance score\n",
|
| 417 |
+
"\n",
|
| 418 |
+
"You are an impartial judge. You will be given an input that was sent to a machine\n",
|
| 419 |
+
"learning model, and you will be given an output that the model produced. You\n",
|
| 420 |
+
"may also be given additional information that was used by the model to generate the output.\n",
|
| 421 |
+
"\n",
|
| 422 |
+
"Your task is to determine a numerical score called relevance based on the input and output.\n",
|
| 423 |
+
"A definition of relevance and a grading rubric are provided below.\n",
|
| 424 |
+
"You must use the grading rubric to determine your score. You must also justify your score.\n",
|
| 425 |
+
"\n",
|
| 426 |
+
"Examples could be included below for reference. Make sure to use them as references and to\n",
|
| 427 |
+
"understand them before completing the task.\n",
|
| 428 |
+
"\n",
|
| 429 |
+
"Input:\n",
|
| 430 |
+
"{input}\n",
|
| 431 |
+
"\n",
|
| 432 |
+
"Output:\n",
|
| 433 |
+
"{output}\n",
|
| 434 |
+
"\n",
|
| 435 |
+
"{grading_context_columns}\n",
|
| 436 |
+
"\n",
|
| 437 |
+
"Metric definition:\n",
|
| 438 |
+
"Relevance encompasses the appropriateness, significance, and applicability of the output with respect to both the input and context. Scores should reflect the extent to which the output directly addresses the question provided in the input, given the provided context.\n",
|
| 439 |
+
"\n",
|
| 440 |
+
"Grading rubric:\n",
|
| 441 |
+
"Relevance: Below are the details for different scores:- Score 1: The output doesn't mention anything about the question or is completely irrelevant to the provided context.\n",
|
| 442 |
+
"- Score 2: The output provides some relevance to the question and is somehow related to the provided context.\n",
|
| 443 |
+
"- Score 3: The output mostly answers the question and is largely consistent with the provided context.\n",
|
| 444 |
+
"- Score 4: The output answers the question and is consistent with the provided context.\n",
|
| 445 |
+
"- Score 5: The output answers the question comprehensively using the provided context.\n",
|
| 446 |
+
"\n",
|
| 447 |
+
"Examples:\n",
|
| 448 |
+
"\n",
|
| 449 |
+
"Example Input:\n",
|
| 450 |
+
"How is MLflow related to Databricks?\n",
|
| 451 |
+
"\n",
|
| 452 |
+
"Example Output:\n",
|
| 453 |
+
"Databricks is a data engineering and analytics platform designed to help organizations process and analyze large amounts of data. Databricks is a company specializing in big data and machine learning solutions.\n",
|
| 454 |
+
"\n",
|
| 455 |
+
"Additional information used by the model:\n",
|
| 456 |
+
"key: context\n",
|
| 457 |
+
"value:\n",
|
| 458 |
+
"MLflow is an open-source platform for managing the end-to-end machine learning (ML) lifecycle. It was developed by Databricks, a company that specializes in big data and machine learning solutions. MLflow is designed to address the challenges that data scientists and machine learning engineers face when developing, training, and deploying machine learning models.\n",
|
| 459 |
+
"\n",
|
| 460 |
+
"Example score: 2\n",
|
| 461 |
+
"Example justification: The output provides relevant information about Databricks, mentioning it as a company specializing in big data and machine learning solutions. However, it doesn't directly address how MLflow is related to Databricks, which is the specific question asked in the input. Therefore, the output is only somewhat related to the provided context.\n",
|
| 462 |
+
" \n",
|
| 463 |
+
"\n",
|
| 464 |
+
"Example Input:\n",
|
| 465 |
+
"How is MLflow related to Databricks?\n",
|
| 466 |
+
"\n",
|
| 467 |
+
"Example Output:\n",
|
| 468 |
+
"MLflow is a product created by Databricks to enhance the efficiency of machine learning processes.\n",
|
| 469 |
+
"\n",
|
| 470 |
+
"Additional information used by the model:\n",
|
| 471 |
+
"key: context\n",
|
| 472 |
+
"value:\n",
|
| 473 |
+
"MLflow is an open-source platform for managing the end-to-end machine learning (ML) lifecycle. It was developed by Databricks, a company that specializes in big data and machine learning solutions. MLflow is designed to address the challenges that data scientists and machine learning engineers face when developing, training, and deploying machine learning models.\n",
|
| 474 |
+
"\n",
|
| 475 |
+
"Example score: 4\n",
|
| 476 |
+
"Example justification: The output provides a relevant and accurate statement about the relationship between MLflow and Databricks. While it doesn't provide extensive detail, it still offers a substantial and meaningful response. To achieve a score of 5, the response could be further improved by providing additional context or details about how MLflow specifically functions within the Databricks ecosystem.\n",
|
| 477 |
+
" \n",
|
| 478 |
+
"\n",
|
| 479 |
+
"You must return the following fields in your response in two lines, one below the other:\n",
|
| 480 |
+
"score: Your numerical score for the model's relevance based on the rubric\n",
|
| 481 |
+
"justification: Your reasoning about the model's relevance score\n",
|
| 482 |
+
"\n",
|
| 483 |
+
"Do not add additional new lines. Do not add any other fields.\n",
|
| 484 |
+
" )\n"
|
| 485 |
+
]
|
| 486 |
+
}
|
| 487 |
+
],
|
| 488 |
+
"source": [
|
| 489 |
+
"from mlflow.metrics.genai import EvaluationExample, relevance\n",
|
| 490 |
+
"\n",
|
| 491 |
+
"relevance_metric = relevance(model=\"openai:/gpt-4\")\n",
|
| 492 |
+
"print(relevance_metric)"
|
| 493 |
+
]
|
| 494 |
+
},
|
| 495 |
+
{
|
| 496 |
+
"cell_type": "code",
|
| 497 |
+
"execution_count": 24,
|
| 498 |
+
"metadata": {},
|
| 499 |
+
"outputs": [],
|
| 500 |
+
"source": [
|
| 501 |
+
"eval_df_final=eval_df.copy(deep=True)"
|
| 502 |
+
]
|
| 503 |
+
},
|
| 504 |
+
{
|
| 505 |
+
"cell_type": "code",
|
| 506 |
+
"execution_count": 25,
|
| 507 |
+
"metadata": {},
|
| 508 |
+
"outputs": [],
|
| 509 |
+
"source": [
|
| 510 |
+
"aa=model(eval_df)"
|
| 511 |
+
]
|
| 512 |
+
},
|
| 513 |
+
{
|
| 514 |
+
"cell_type": "code",
|
| 515 |
+
"execution_count": 26,
|
| 516 |
+
"metadata": {},
|
| 517 |
+
"outputs": [
|
| 518 |
+
{
|
| 519 |
+
"data": {
|
| 520 |
+
"text/plain": [
|
| 521 |
+
"[{'query': 'What is MLflow?',\n",
|
| 522 |
+
" 'result': 'MLflow is an open-source platform designed to help machine learning practitioners and teams manage the complexities of the machine learning process. It focuses on the full lifecycle of machine learning projects, making sure that each phase is manageable, traceable, and reproducible.',\n",
|
| 523 |
+
" 'source_documents': [Document(page_content='MLflow: A Tool for Managing the Machine Learning Lifecycle \\nMLflow is an open-source platform, purpose-built to assist machine learning practitioners and teams in\\nhandling the complexities of the machine learning process. MLflow focuses on the full lifecycle for\\nmachine learning projects, ensuring that each phase is manageable, traceable, and reproducible.\\nIn each of the sections below, you will find overviews, guides, and step-by-step tutorials to walk you through\\nthe features of MLflow and how they can be leveraged to solve real-world MLOps problems.\\n\\nGetting Started with MLflow \\nIf this is your first time exploring MLflow, the tutorials and guides here are a great place to start. The emphasis in each of these is\\ngetting you up to speed as quickly as possible with the basic functionality, terms, APIs, and general best practices of using MLflow in order to\\nenhance your learning in area-specific guides and tutorials.\\n\\nGetting Started Guides and Quickstarts', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 524 |
+
" Document(page_content='MLflow: A Tool for Managing the Machine Learning Lifecycle \\nMLflow is an open-source platform, purpose-built to assist machine learning practitioners and teams in\\nhandling the complexities of the machine learning process. MLflow focuses on the full lifecycle for\\nmachine learning projects, ensuring that each phase is manageable, traceable, and reproducible.\\nIn each of the sections below, you will find overviews, guides, and step-by-step tutorials to walk you through\\nthe features of MLflow and how they can be leveraged to solve real-world MLOps problems.\\n\\nGetting Started with MLflow \\nIf this is your first time exploring MLflow, the tutorials and guides here are a great place to start. The emphasis in each of these is\\ngetting you up to speed as quickly as possible with the basic functionality, terms, APIs, and general best practices of using MLflow in order to\\nenhance your learning in area-specific guides and tutorials.\\n\\nGetting Started Guides and Quickstarts', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 525 |
+
" Document(page_content='MLflow: A Tool for Managing the Machine Learning Lifecycle \\nMLflow is an open-source platform, purpose-built to assist machine learning practitioners and teams in\\nhandling the complexities of the machine learning process. MLflow focuses on the full lifecycle for\\nmachine learning projects, ensuring that each phase is manageable, traceable, and reproducible.\\nIn each of the sections below, you will find overviews, guides, and step-by-step tutorials to walk you through\\nthe features of MLflow and how they can be leveraged to solve real-world MLOps problems.\\n\\nGetting Started with MLflow \\nIf this is your first time exploring MLflow, the tutorials and guides here are a great place to start. The emphasis in each of these is\\ngetting you up to speed as quickly as possible with the basic functionality, terms, APIs, and general best practices of using MLflow in order to\\nenhance your learning in area-specific guides and tutorials.\\n\\nGetting Started Guides and Quickstarts', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 526 |
+
" Document(page_content='MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation\\n\\n2.12.1\\n\\n\\n MLflow\\n\\nWhat is MLflow?\\nGetting Started with MLflow\\nNew Features\\nLLMs\\nModel Evaluation\\nDeep Learning\\nTraditional ML\\nDeployment\\nMLflow Tracking\\nSystem Metrics\\nMLflow Projects\\nMLflow Models\\nMLflow Model Registry\\nMLflow Recipes\\nMLflow Plugins\\nMLflow Authentication\\nCommand-Line Interface\\nSearch Runs\\nSearch Experiments\\nPython API\\nR API\\nJava API\\nREST API\\nOfficial MLflow Docker Image\\nCommunity Model Flavors\\nTutorials and Examples\\n\\n\\nContribute\\n\\n\\nDocumentation \\nMLflow: A Tool for Managing the Machine Learning Lifecycle', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'})]},\n",
|
| 527 |
+
" {'query': 'How to run mlflow.evaluate()?',\n",
|
| 528 |
+
" 'result': 'To run `mlflow.evaluate()`, you need to follow these steps:\\n\\n1. Ensure you have MLflow installed in your environment.\\n2. Import the necessary libraries, including `mlflow`.\\n3. Load your model and data.\\n4. Use the `mlflow.evaluate()` function, passing in your model, data, and any other necessary parameters.\\n5. Review the evaluation results and metrics provided by the function.\\n\\nIf you need more specific details or code examples, please refer to the MLflow documentation or tutorials for a step-by-step guide on running `mlflow.evaluate()`.',\n",
|
| 529 |
+
" 'source_documents': [Document(page_content='Model Evaluation \\nDive into MLflow’s robust framework for evaluating the performance of your ML models.\\nWith support for traditional ML evaluation (classification and regression tasks), as well as support for evaluating large language models (LLMs),\\nthis suite of APIs offers a simple but powerful automated approach to evaluating the quality of the model development work that you’re doing.\\nIn particular, for LLM evaluation, the mlflow.evaluate() API allows you to validate not only models, but providers and prompts.\\nBy leveraging your own datasets and using the provided default evaluation criteria for tasks such as text summarization and question answering, you can\\nget reliable metrics that allow you to focus on improving the quality of your solution, rather than spending time writing scoring code.\\nVisual insights are also available through the MLflow UI, showcasing logged outputs, auto-generated plots, and model comparison artifacts.', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 530 |
+
" Document(page_content='Model Evaluation \\nDive into MLflow’s robust framework for evaluating the performance of your ML models.\\nWith support for traditional ML evaluation (classification and regression tasks), as well as support for evaluating large language models (LLMs),\\nthis suite of APIs offers a simple but powerful automated approach to evaluating the quality of the model development work that you’re doing.\\nIn particular, for LLM evaluation, the mlflow.evaluate() API allows you to validate not only models, but providers and prompts.\\nBy leveraging your own datasets and using the provided default evaluation criteria for tasks such as text summarization and question answering, you can\\nget reliable metrics that allow you to focus on improving the quality of your solution, rather than spending time writing scoring code.\\nVisual insights are also available through the MLflow UI, showcasing logged outputs, auto-generated plots, and model comparison artifacts.', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 531 |
+
" Document(page_content='Model Evaluation \\nDive into MLflow’s robust framework for evaluating the performance of your ML models.\\nWith support for traditional ML evaluation (classification and regression tasks), as well as support for evaluating large language models (LLMs),\\nthis suite of APIs offers a simple but powerful automated approach to evaluating the quality of the model development work that you’re doing.\\nIn particular, for LLM evaluation, the mlflow.evaluate() API allows you to validate not only models, but providers and prompts.\\nBy leveraging your own datasets and using the provided default evaluation criteria for tasks such as text summarization and question answering, you can\\nget reliable metrics that allow you to focus on improving the quality of your solution, rather than spending time writing scoring code.\\nVisual insights are also available through the MLflow UI, showcasing logged outputs, auto-generated plots, and model comparison artifacts.', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 532 |
+
" Document(page_content='Learn how to evaluate LLMs and LLM-powered solutions with MLflow Evaluate.\\n \\n\\n Using Custom PyFunc with LLMs\\n \\n\\n Explore the nuances of packaging and deploying advanced LLMs in MLflow using custom PyFuncs. This guide delves deep\\n into managing intricate model behaviors, ensuring seamless and efficient LLM deployments.\\n \\n\\n Evaluation for RAG\\n \\n\\n Learn how to evaluate Retrieval Augmented Generation applications by leveraging LLMs to generate a evaluation dataset and evaluate it using the built-in metrics in the MLflow Evaluate API.\\n \\n\\n LLM Tracking with MLflow', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'})]},\n",
|
| 533 |
+
" {'query': 'How to log_table()?',\n",
|
| 534 |
+
" 'result': \"I don't have information on a specific function called `log_table()` in the context provided. It's possible that it might be a custom function or a feature not explicitly mentioned in the provided context. If you can provide more details or context, I may be able to assist you further.\",\n",
|
| 535 |
+
" 'source_documents': [Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 536 |
+
" Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 537 |
+
" Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 538 |
+
" Document(page_content=\"Dive into the intricacies of MLflow's LLM Tracking system. From capturing prompts to monitoring generated outputs,\\n discover how MLflow provides a holistic solution for managing LLM interactions.\", metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'})]},\n",
|
| 539 |
+
" {'query': 'How to load_table()?',\n",
|
| 540 |
+
" 'result': \"I'm not sure what `load_table()` refers to in this context. If you can provide more information or context, I might be able to help you better.\",\n",
|
| 541 |
+
" 'source_documents': [Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 542 |
+
" Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 543 |
+
" Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 544 |
+
" Document(page_content='This guide showcases the seamless end-to-end process of training a linear regression model, packaging it in a reproducible format,\\n and deploying to a Kubernetes cluster using MLflow. Explore how MLflow simplifies model deployment to production environments.\\n \\n\\n\\nNext \\n\\n\\n © MLflow Project, a Series of LF Projects, LLC. All rights reserved.', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'})]}]"
|
| 545 |
+
]
|
| 546 |
+
},
|
| 547 |
+
"execution_count": 26,
|
| 548 |
+
"metadata": {},
|
| 549 |
+
"output_type": "execute_result"
|
| 550 |
+
}
|
| 551 |
+
],
|
| 552 |
+
"source": [
|
| 553 |
+
"aa"
|
| 554 |
+
]
|
| 555 |
+
},
|
| 556 |
+
{
|
| 557 |
+
"cell_type": "code",
|
| 558 |
+
"execution_count": 20,
|
| 559 |
+
"metadata": {},
|
| 560 |
+
"outputs": [
|
| 561 |
+
{
|
| 562 |
+
"data": {
|
| 563 |
+
"text/plain": [
|
| 564 |
+
"[{'query': 'What is MLflow?',\n",
|
| 565 |
+
" 'result': 'MLflow is an open-source platform designed to help machine learning practitioners and teams manage the complexities of the machine learning process. It focuses on the full lifecycle of machine learning projects, ensuring that each phase is manageable, traceable, and reproducible. It offers features such as tracking, projects, models, model registry, and more to assist in solving real-world MLOps problems.',\n",
|
| 566 |
+
" 'source_documents': [Document(page_content='MLflow: A Tool for Managing the Machine Learning Lifecycle \\nMLflow is an open-source platform, purpose-built to assist machine learning practitioners and teams in\\nhandling the complexities of the machine learning process. MLflow focuses on the full lifecycle for\\nmachine learning projects, ensuring that each phase is manageable, traceable, and reproducible.\\nIn each of the sections below, you will find overviews, guides, and step-by-step tutorials to walk you through\\nthe features of MLflow and how they can be leveraged to solve real-world MLOps problems.\\n\\nGetting Started with MLflow \\nIf this is your first time exploring MLflow, the tutorials and guides here are a great place to start. The emphasis in each of these is\\ngetting you up to speed as quickly as possible with the basic functionality, terms, APIs, and general best practices of using MLflow in order to\\nenhance your learning in area-specific guides and tutorials.\\n\\nGetting Started Guides and Quickstarts', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 567 |
+
" Document(page_content='MLflow: A Tool for Managing the Machine Learning Lifecycle \\nMLflow is an open-source platform, purpose-built to assist machine learning practitioners and teams in\\nhandling the complexities of the machine learning process. MLflow focuses on the full lifecycle for\\nmachine learning projects, ensuring that each phase is manageable, traceable, and reproducible.\\nIn each of the sections below, you will find overviews, guides, and step-by-step tutorials to walk you through\\nthe features of MLflow and how they can be leveraged to solve real-world MLOps problems.\\n\\nGetting Started with MLflow \\nIf this is your first time exploring MLflow, the tutorials and guides here are a great place to start. The emphasis in each of these is\\ngetting you up to speed as quickly as possible with the basic functionality, terms, APIs, and general best practices of using MLflow in order to\\nenhance your learning in area-specific guides and tutorials.\\n\\nGetting Started Guides and Quickstarts', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 568 |
+
" Document(page_content='MLflow: A Tool for Managing the Machine Learning Lifecycle \\nMLflow is an open-source platform, purpose-built to assist machine learning practitioners and teams in\\nhandling the complexities of the machine learning process. MLflow focuses on the full lifecycle for\\nmachine learning projects, ensuring that each phase is manageable, traceable, and reproducible.\\nIn each of the sections below, you will find overviews, guides, and step-by-step tutorials to walk you through\\nthe features of MLflow and how they can be leveraged to solve real-world MLOps problems.\\n\\nGetting Started with MLflow \\nIf this is your first time exploring MLflow, the tutorials and guides here are a great place to start. The emphasis in each of these is\\ngetting you up to speed as quickly as possible with the basic functionality, terms, APIs, and general best practices of using MLflow in order to\\nenhance your learning in area-specific guides and tutorials.\\n\\nGetting Started Guides and Quickstarts', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 569 |
+
" Document(page_content='MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation\\n\\n2.12.1\\n\\n\\n MLflow\\n\\nWhat is MLflow?\\nGetting Started with MLflow\\nNew Features\\nLLMs\\nModel Evaluation\\nDeep Learning\\nTraditional ML\\nDeployment\\nMLflow Tracking\\nSystem Metrics\\nMLflow Projects\\nMLflow Models\\nMLflow Model Registry\\nMLflow Recipes\\nMLflow Plugins\\nMLflow Authentication\\nCommand-Line Interface\\nSearch Runs\\nSearch Experiments\\nPython API\\nR API\\nJava API\\nREST API\\nOfficial MLflow Docker Image\\nCommunity Model Flavors\\nTutorials and Examples\\n\\n\\nContribute\\n\\n\\nDocumentation \\nMLflow: A Tool for Managing the Machine Learning Lifecycle', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'})]},\n",
|
| 570 |
+
" {'query': 'How to run mlflow.evaluate()?',\n",
|
| 571 |
+
" 'result': 'To run `mlflow.evaluate()`, you need to follow these steps:\\n\\n1. Import the necessary libraries:\\n```python\\nimport mlflow\\n```\\n\\n2. Load your model and data:\\n```python\\nmodel = load_model() # Load your model\\ndata = load_data() # Load your evaluation data\\n```\\n\\n3. Use `mlflow.evaluate()` to evaluate your model:\\n```python\\nmlflow.evaluate(model, data)\\n```\\n\\n4. Review the evaluation metrics and visual insights in the MLflow UI.\\n\\nIf you need more specific details or have a particular use case in mind, please provide additional context for a more tailored explanation.',\n",
|
| 572 |
+
" 'source_documents': [Document(page_content='Model Evaluation \\nDive into MLflow’s robust framework for evaluating the performance of your ML models.\\nWith support for traditional ML evaluation (classification and regression tasks), as well as support for evaluating large language models (LLMs),\\nthis suite of APIs offers a simple but powerful automated approach to evaluating the quality of the model development work that you’re doing.\\nIn particular, for LLM evaluation, the mlflow.evaluate() API allows you to validate not only models, but providers and prompts.\\nBy leveraging your own datasets and using the provided default evaluation criteria for tasks such as text summarization and question answering, you can\\nget reliable metrics that allow you to focus on improving the quality of your solution, rather than spending time writing scoring code.\\nVisual insights are also available through the MLflow UI, showcasing logged outputs, auto-generated plots, and model comparison artifacts.', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 573 |
+
" Document(page_content='Model Evaluation \\nDive into MLflow’s robust framework for evaluating the performance of your ML models.\\nWith support for traditional ML evaluation (classification and regression tasks), as well as support for evaluating large language models (LLMs),\\nthis suite of APIs offers a simple but powerful automated approach to evaluating the quality of the model development work that you’re doing.\\nIn particular, for LLM evaluation, the mlflow.evaluate() API allows you to validate not only models, but providers and prompts.\\nBy leveraging your own datasets and using the provided default evaluation criteria for tasks such as text summarization and question answering, you can\\nget reliable metrics that allow you to focus on improving the quality of your solution, rather than spending time writing scoring code.\\nVisual insights are also available through the MLflow UI, showcasing logged outputs, auto-generated plots, and model comparison artifacts.', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 574 |
+
" Document(page_content='Model Evaluation \\nDive into MLflow’s robust framework for evaluating the performance of your ML models.\\nWith support for traditional ML evaluation (classification and regression tasks), as well as support for evaluating large language models (LLMs),\\nthis suite of APIs offers a simple but powerful automated approach to evaluating the quality of the model development work that you’re doing.\\nIn particular, for LLM evaluation, the mlflow.evaluate() API allows you to validate not only models, but providers and prompts.\\nBy leveraging your own datasets and using the provided default evaluation criteria for tasks such as text summarization and question answering, you can\\nget reliable metrics that allow you to focus on improving the quality of your solution, rather than spending time writing scoring code.\\nVisual insights are also available through the MLflow UI, showcasing logged outputs, auto-generated plots, and model comparison artifacts.', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 575 |
+
" Document(page_content='Learn how to evaluate LLMs and LLM-powered solutions with MLflow Evaluate.\\n \\n\\n Using Custom PyFunc with LLMs\\n \\n\\n Explore the nuances of packaging and deploying advanced LLMs in MLflow using custom PyFuncs. This guide delves deep\\n into managing intricate model behaviors, ensuring seamless and efficient LLM deployments.\\n \\n\\n Evaluation for RAG\\n \\n\\n Learn how to evaluate Retrieval Augmented Generation applications by leveraging LLMs to generate a evaluation dataset and evaluate it using the built-in metrics in the MLflow Evaluate API.\\n \\n\\n LLM Tracking with MLflow', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'})]},\n",
|
| 576 |
+
" {'query': 'How to log_table()?',\n",
|
| 577 |
+
" 'result': \"I don't have information on a specific function called `log_table()` in the context provided. It's possible that it might be a custom function or a feature not covered in the provided context. If you can provide more details or context, I may be able to assist you further.\",\n",
|
| 578 |
+
" 'source_documents': [Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 579 |
+
" Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 580 |
+
" Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 581 |
+
" Document(page_content=\"Dive into the intricacies of MLflow's LLM Tracking system. From capturing prompts to monitoring generated outputs,\\n discover how MLflow provides a holistic solution for managing LLM interactions.\", metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'})]},\n",
|
| 582 |
+
" {'query': 'How to load_table()?',\n",
|
| 583 |
+
" 'result': \"I don't have information on a function called `load_table()` in the context provided. It seems to be outside the scope of the MLflow Tracking, Autologging, and Deployment Quickstart guides. If you can provide more context or clarify where `load_table()` is from, I may be able to assist you further.\",\n",
|
| 584 |
+
" 'source_documents': [Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 585 |
+
" Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 586 |
+
" Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
|
| 587 |
+
" Document(page_content='This guide showcases the seamless end-to-end process of training a linear regression model, packaging it in a reproducible format,\\n and deploying to a Kubernetes cluster using MLflow. Explore how MLflow simplifies model deployment to production environments.\\n \\n\\n\\nNext \\n\\n\\n © MLflow Project, a Series of LF Projects, LLC. All rights reserved.', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'})]}]"
|
| 588 |
+
]
|
| 589 |
+
},
|
| 590 |
+
"execution_count": 20,
|
| 591 |
+
"metadata": {},
|
| 592 |
+
"output_type": "execute_result"
|
| 593 |
+
}
|
| 594 |
+
],
|
| 595 |
+
"source": [
|
| 596 |
+
"eval_df_final[\"results\"]=model(eval_df)"
|
| 597 |
+
]
|
| 598 |
+
},
|
| 599 |
+
{
|
| 600 |
+
"cell_type": "code",
|
| 601 |
+
"execution_count": 29,
|
| 602 |
+
"metadata": {
|
| 603 |
+
"application/vnd.databricks.v1+cell": {
|
| 604 |
+
"cellMetadata": {
|
| 605 |
+
"byteLimit": 2048000,
|
| 606 |
+
"rowLimit": 10000
|
| 607 |
+
},
|
| 608 |
+
"inputWidgets": {},
|
| 609 |
+
"nuid": "ea40ce52-6ac7-4c20-9669-d24f80a6cebe",
|
| 610 |
+
"showTitle": false,
|
| 611 |
+
"title": ""
|
| 612 |
+
}
|
| 613 |
+
},
|
| 614 |
+
"outputs": [
|
| 615 |
+
{
|
| 616 |
+
"name": "stderr",
|
| 617 |
+
"output_type": "stream",
|
| 618 |
+
"text": [
|
| 619 |
+
"/u/marshad/.conda/envs/agllm-env1/lib/python3.9/site-packages/mlflow/data/digest_utils.py:26: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n",
|
| 620 |
+
" string_columns = trimmed_df.columns[(df.applymap(type) == str).all(0)]\n",
|
| 621 |
+
"/u/marshad/.conda/envs/agllm-env1/lib/python3.9/site-packages/mlflow/models/evaluation/base.py:414: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n",
|
| 622 |
+
" data = data.applymap(_hash_array_like_element_as_bytes)\n",
|
| 623 |
+
"2024/04/21 17:23:37 INFO mlflow.models.evaluation.base: Evaluating the model with the default evaluator.\n",
|
| 624 |
+
"2024/04/21 17:23:37 INFO mlflow.models.evaluation.default_evaluator: Computing model predictions.\n",
|
| 625 |
+
"2024/04/21 17:23:44 INFO mlflow.models.evaluation.default_evaluator: Testing metrics on first row...\n",
|
| 626 |
+
"2024/04/21 17:23:44 WARNING mlflow.metrics.metric_definitions: Failed to load 'toxicity' metric (error: ModuleNotFoundError(\"No module named 'evaluate'\")), skipping metric logging.\n",
|
| 627 |
+
"2024/04/21 17:23:44 WARNING mlflow.metrics.metric_definitions: Failed to load flesch kincaid metric, skipping metric logging.\n",
|
| 628 |
+
"2024/04/21 17:23:44 WARNING mlflow.metrics.metric_definitions: Failed to load automated readability index metric, skipping metric logging.\n",
|
| 629 |
+
"100%|██████████| 1/1 [00:03<00:00, 3.53s/it]\n",
|
| 630 |
+
"100%|██████████| 1/1 [00:03<00:00, 3.72s/it]\n",
|
| 631 |
+
"2024/04/21 17:23:51 INFO mlflow.models.evaluation.default_evaluator: Evaluating builtin metrics: token_count\n",
|
| 632 |
+
"2024/04/21 17:23:51 INFO mlflow.models.evaluation.default_evaluator: Evaluating builtin metrics: toxicity\n",
|
| 633 |
+
"2024/04/21 17:23:51 WARNING mlflow.metrics.metric_definitions: Failed to load 'toxicity' metric (error: ModuleNotFoundError(\"No module named 'evaluate'\")), skipping metric logging.\n",
|
| 634 |
+
"2024/04/21 17:23:51 INFO mlflow.models.evaluation.default_evaluator: Evaluating builtin metrics: flesch_kincaid_grade_level\n",
|
| 635 |
+
"2024/04/21 17:23:51 WARNING mlflow.metrics.metric_definitions: Failed to load flesch kincaid metric, skipping metric logging.\n",
|
| 636 |
+
"2024/04/21 17:23:51 INFO mlflow.models.evaluation.default_evaluator: Evaluating builtin metrics: ari_grade_level\n",
|
| 637 |
+
"2024/04/21 17:23:51 WARNING mlflow.metrics.metric_definitions: Failed to load automated readability index metric, skipping metric logging.\n",
|
| 638 |
+
"2024/04/21 17:23:51 INFO mlflow.models.evaluation.default_evaluator: Evaluating builtin metrics: exact_match\n",
|
| 639 |
+
"2024/04/21 17:23:51 INFO mlflow.models.evaluation.default_evaluator: Evaluating metrics: faithfulness\n",
|
| 640 |
+
"100%|██████████| 4/4 [00:03<00:00, 1.02it/s]\n",
|
| 641 |
+
"2024/04/21 17:23:55 INFO mlflow.models.evaluation.default_evaluator: Evaluating metrics: relevance\n",
|
| 642 |
+
"100%|██████████| 4/4 [00:04<00:00, 1.20s/it]\n"
|
| 643 |
+
]
|
| 644 |
+
},
|
| 645 |
+
{
|
| 646 |
+
"name": "stdout",
|
| 647 |
+
"output_type": "stream",
|
| 648 |
+
"text": [
|
| 649 |
+
"{'latency/mean': 1.7213420271873474, 'latency/variance': 0.04483020464773446, 'latency/p90': 1.942362928390503, 'faithfulness/v1/mean': 4.75, 'faithfulness/v1/variance': 0.1875, 'faithfulness/v1/p90': 5.0, 'relevance/v1/mean': 3.75, 'relevance/v1/variance': 1.6875, 'relevance/v1/p90': 5.0}\n"
|
| 650 |
+
]
|
| 651 |
+
}
|
| 652 |
+
],
|
| 653 |
+
"source": [
|
| 654 |
+
"results = mlflow.evaluate(\n",
|
| 655 |
+
" model,\n",
|
| 656 |
+
" eval_df,\n",
|
| 657 |
+
" model_type=\"question-answering\",\n",
|
| 658 |
+
" evaluators=\"default\",\n",
|
| 659 |
+
" predictions=\"result\",\n",
|
| 660 |
+
" extra_metrics=[faithfulness_metric, relevance_metric, mlflow.metrics.latency()],\n",
|
| 661 |
+
" evaluator_config={\n",
|
| 662 |
+
" \"col_mapping\": {\n",
|
| 663 |
+
" \"inputs\": \"questions\",\n",
|
| 664 |
+
" \"context\": \"source_documents\",\n",
|
| 665 |
+
" }\n",
|
| 666 |
+
" },\n",
|
| 667 |
+
")\n",
|
| 668 |
+
"print(results.metrics)"
|
| 669 |
+
]
|
| 670 |
+
},
|
| 671 |
+
{
|
| 672 |
+
"cell_type": "code",
|
| 673 |
+
"execution_count": null,
|
| 674 |
+
"metadata": {},
|
| 675 |
+
"outputs": [],
|
| 676 |
+
"source": []
|
| 677 |
+
},
|
| 678 |
+
{
|
| 679 |
+
"cell_type": "code",
|
| 680 |
+
"execution_count": 13,
|
| 681 |
+
"metadata": {
|
| 682 |
+
"application/vnd.databricks.v1+cell": {
|
| 683 |
+
"cellMetadata": {},
|
| 684 |
+
"inputWidgets": {},
|
| 685 |
+
"nuid": "989a0861-5153-44e6-a19d-efcae7fe6cb5",
|
| 686 |
+
"showTitle": false,
|
| 687 |
+
"title": ""
|
| 688 |
+
}
|
| 689 |
+
},
|
| 690 |
+
"outputs": [
|
| 691 |
+
{
|
| 692 |
+
"data": {
|
| 693 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 694 |
+
"model_id": "747f65b309b94257b396eebffe814fa6",
|
| 695 |
+
"version_major": 2,
|
| 696 |
+
"version_minor": 0
|
| 697 |
+
},
|
| 698 |
+
"text/plain": [
|
| 699 |
+
"Downloading artifacts: 0%| | 0/1 [00:00<?, ?it/s]"
|
| 700 |
+
]
|
| 701 |
+
},
|
| 702 |
+
"metadata": {},
|
| 703 |
+
"output_type": "display_data"
|
| 704 |
+
},
|
| 705 |
+
{
|
| 706 |
+
"data": {
|
| 707 |
+
"text/html": [
|
| 708 |
+
"<div>\n",
|
| 709 |
+
"<style scoped>\n",
|
| 710 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 711 |
+
" vertical-align: middle;\n",
|
| 712 |
+
" }\n",
|
| 713 |
+
"\n",
|
| 714 |
+
" .dataframe tbody tr th {\n",
|
| 715 |
+
" vertical-align: top;\n",
|
| 716 |
+
" }\n",
|
| 717 |
+
"\n",
|
| 718 |
+
" .dataframe thead th {\n",
|
| 719 |
+
" text-align: right;\n",
|
| 720 |
+
" }\n",
|
| 721 |
+
"</style>\n",
|
| 722 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 723 |
+
" <thead>\n",
|
| 724 |
+
" <tr style=\"text-align: right;\">\n",
|
| 725 |
+
" <th></th>\n",
|
| 726 |
+
" <th>questions</th>\n",
|
| 727 |
+
" <th>outputs</th>\n",
|
| 728 |
+
" <th>source_documents</th>\n",
|
| 729 |
+
" <th>latency</th>\n",
|
| 730 |
+
" <th>token_count</th>\n",
|
| 731 |
+
" <th>toxicity/v1/score</th>\n",
|
| 732 |
+
" <th>flesch_kincaid_grade_level/v1/score</th>\n",
|
| 733 |
+
" <th>ari_grade_level/v1/score</th>\n",
|
| 734 |
+
" <th>faithfulness/v1/score</th>\n",
|
| 735 |
+
" <th>faithfulness/v1/justification</th>\n",
|
| 736 |
+
" <th>relevance/v1/score</th>\n",
|
| 737 |
+
" <th>relevance/v1/justification</th>\n",
|
| 738 |
+
" </tr>\n",
|
| 739 |
+
" </thead>\n",
|
| 740 |
+
" <tbody>\n",
|
| 741 |
+
" <tr>\n",
|
| 742 |
+
" <th>0</th>\n",
|
| 743 |
+
" <td>What is MLflow?</td>\n",
|
| 744 |
+
" <td>MLflow is an open-source platform, purpose-bu...</td>\n",
|
| 745 |
+
" <td>[{'lc_attributes': {}, 'lc_namespace': ['langc...</td>\n",
|
| 746 |
+
" <td>1.989822</td>\n",
|
| 747 |
+
" <td>53</td>\n",
|
| 748 |
+
" <td>0.000137</td>\n",
|
| 749 |
+
" <td>12.5</td>\n",
|
| 750 |
+
" <td>18.4</td>\n",
|
| 751 |
+
" <td>5</td>\n",
|
| 752 |
+
" <td>The output provided by the model is a direct e...</td>\n",
|
| 753 |
+
" <td>5</td>\n",
|
| 754 |
+
" <td>The output provides a comprehensive answer to ...</td>\n",
|
| 755 |
+
" </tr>\n",
|
| 756 |
+
" <tr>\n",
|
| 757 |
+
" <th>1</th>\n",
|
| 758 |
+
" <td>How to run mlflow.evaluate()?</td>\n",
|
| 759 |
+
" <td>The mlflow.evaluate() API allows you to valid...</td>\n",
|
| 760 |
+
" <td>[{'lc_attributes': {}, 'lc_namespace': ['langc...</td>\n",
|
| 761 |
+
" <td>1.945368</td>\n",
|
| 762 |
+
" <td>55</td>\n",
|
| 763 |
+
" <td>0.000200</td>\n",
|
| 764 |
+
" <td>9.1</td>\n",
|
| 765 |
+
" <td>12.6</td>\n",
|
| 766 |
+
" <td>5</td>\n",
|
| 767 |
+
" <td>The output provided by the model is completely...</td>\n",
|
| 768 |
+
" <td>4</td>\n",
|
| 769 |
+
" <td>The output provides a relevant and accurate ex...</td>\n",
|
| 770 |
+
" </tr>\n",
|
| 771 |
+
" <tr>\n",
|
| 772 |
+
" <th>2</th>\n",
|
| 773 |
+
" <td>How to log_table()?</td>\n",
|
| 774 |
+
" <td>You can log a table with MLflow using the log...</td>\n",
|
| 775 |
+
" <td>[{'lc_attributes': {}, 'lc_namespace': ['langc...</td>\n",
|
| 776 |
+
" <td>1.521511</td>\n",
|
| 777 |
+
" <td>32</td>\n",
|
| 778 |
+
" <td>0.000289</td>\n",
|
| 779 |
+
" <td>5.0</td>\n",
|
| 780 |
+
" <td>6.8</td>\n",
|
| 781 |
+
" <td>1</td>\n",
|
| 782 |
+
" <td>The output claims that you can log a table wit...</td>\n",
|
| 783 |
+
" <td>5</td>\n",
|
| 784 |
+
" <td>The output provides a comprehensive answer to ...</td>\n",
|
| 785 |
+
" </tr>\n",
|
| 786 |
+
" <tr>\n",
|
| 787 |
+
" <th>3</th>\n",
|
| 788 |
+
" <td>How to load_table()?</td>\n",
|
| 789 |
+
" <td>You can't load_table() with MLflow. MLflow is...</td>\n",
|
| 790 |
+
" <td>[{'lc_attributes': {}, 'lc_namespace': ['langc...</td>\n",
|
| 791 |
+
" <td>1.105279</td>\n",
|
| 792 |
+
" <td>27</td>\n",
|
| 793 |
+
" <td>0.000279</td>\n",
|
| 794 |
+
" <td>5.8</td>\n",
|
| 795 |
+
" <td>8.8</td>\n",
|
| 796 |
+
" <td>5</td>\n",
|
| 797 |
+
" <td>The output claim that \"You can't load_table() ...</td>\n",
|
| 798 |
+
" <td>4</td>\n",
|
| 799 |
+
" <td>The output provides a relevant and accurate re...</td>\n",
|
| 800 |
+
" </tr>\n",
|
| 801 |
+
" </tbody>\n",
|
| 802 |
+
"</table>\n",
|
| 803 |
+
"</div>"
|
| 804 |
+
],
|
| 805 |
+
"text/plain": [
|
| 806 |
+
" questions \\\n",
|
| 807 |
+
"0 What is MLflow? \n",
|
| 808 |
+
"1 How to run mlflow.evaluate()? \n",
|
| 809 |
+
"2 How to log_table()? \n",
|
| 810 |
+
"3 How to load_table()? \n",
|
| 811 |
+
"\n",
|
| 812 |
+
" outputs \\\n",
|
| 813 |
+
"0 MLflow is an open-source platform, purpose-bu... \n",
|
| 814 |
+
"1 The mlflow.evaluate() API allows you to valid... \n",
|
| 815 |
+
"2 You can log a table with MLflow using the log... \n",
|
| 816 |
+
"3 You can't load_table() with MLflow. MLflow is... \n",
|
| 817 |
+
"\n",
|
| 818 |
+
" source_documents latency token_count \\\n",
|
| 819 |
+
"0 [{'lc_attributes': {}, 'lc_namespace': ['langc... 1.989822 53 \n",
|
| 820 |
+
"1 [{'lc_attributes': {}, 'lc_namespace': ['langc... 1.945368 55 \n",
|
| 821 |
+
"2 [{'lc_attributes': {}, 'lc_namespace': ['langc... 1.521511 32 \n",
|
| 822 |
+
"3 [{'lc_attributes': {}, 'lc_namespace': ['langc... 1.105279 27 \n",
|
| 823 |
+
"\n",
|
| 824 |
+
" toxicity/v1/score flesch_kincaid_grade_level/v1/score \\\n",
|
| 825 |
+
"0 0.000137 12.5 \n",
|
| 826 |
+
"1 0.000200 9.1 \n",
|
| 827 |
+
"2 0.000289 5.0 \n",
|
| 828 |
+
"3 0.000279 5.8 \n",
|
| 829 |
+
"\n",
|
| 830 |
+
" ari_grade_level/v1/score faithfulness/v1/score \\\n",
|
| 831 |
+
"0 18.4 5 \n",
|
| 832 |
+
"1 12.6 5 \n",
|
| 833 |
+
"2 6.8 1 \n",
|
| 834 |
+
"3 8.8 5 \n",
|
| 835 |
+
"\n",
|
| 836 |
+
" faithfulness/v1/justification relevance/v1/score \\\n",
|
| 837 |
+
"0 The output provided by the model is a direct e... 5 \n",
|
| 838 |
+
"1 The output provided by the model is completely... 4 \n",
|
| 839 |
+
"2 The output claims that you can log a table wit... 5 \n",
|
| 840 |
+
"3 The output claim that \"You can't load_table() ... 4 \n",
|
| 841 |
+
"\n",
|
| 842 |
+
" relevance/v1/justification \n",
|
| 843 |
+
"0 The output provides a comprehensive answer to ... \n",
|
| 844 |
+
"1 The output provides a relevant and accurate ex... \n",
|
| 845 |
+
"2 The output provides a comprehensive answer to ... \n",
|
| 846 |
+
"3 The output provides a relevant and accurate re... "
|
| 847 |
+
]
|
| 848 |
+
},
|
| 849 |
+
"execution_count": 13,
|
| 850 |
+
"metadata": {},
|
| 851 |
+
"output_type": "execute_result"
|
| 852 |
+
}
|
| 853 |
+
],
|
| 854 |
+
"source": [
|
| 855 |
+
"results.tables[\"eval_results_table\"]"
|
| 856 |
+
]
|
| 857 |
+
},
|
| 858 |
+
{
|
| 859 |
+
"cell_type": "code",
|
| 860 |
+
"execution_count": null,
|
| 861 |
+
"metadata": {},
|
| 862 |
+
"outputs": [],
|
| 863 |
+
"source": []
|
| 864 |
+
}
|
| 865 |
+
],
|
| 866 |
+
"metadata": {
|
| 867 |
+
"application/vnd.databricks.v1+notebook": {
|
| 868 |
+
"dashboards": [],
|
| 869 |
+
"language": "python",
|
| 870 |
+
"notebookMetadata": {
|
| 871 |
+
"pythonIndentUnit": 2
|
| 872 |
+
},
|
| 873 |
+
"notebookName": "LLM Evaluation Examples -- RAG",
|
| 874 |
+
"widgets": {}
|
| 875 |
+
},
|
| 876 |
+
"kernelspec": {
|
| 877 |
+
"display_name": "mlflow-dev-env",
|
| 878 |
+
"language": "python",
|
| 879 |
+
"name": "python3"
|
| 880 |
+
},
|
| 881 |
+
"language_info": {
|
| 882 |
+
"codemirror_mode": {
|
| 883 |
+
"name": "ipython",
|
| 884 |
+
"version": 3
|
| 885 |
+
},
|
| 886 |
+
"file_extension": ".py",
|
| 887 |
+
"mimetype": "text/x-python",
|
| 888 |
+
"name": "python",
|
| 889 |
+
"nbconvert_exporter": "python",
|
| 890 |
+
"pygments_lexer": "ipython3",
|
| 891 |
+
"version": "3.9.19"
|
| 892 |
+
}
|
| 893 |
+
},
|
| 894 |
+
"nbformat": 4,
|
| 895 |
+
"nbformat_minor": 0
|
| 896 |
+
}
|
push_logs.txt
ADDED
|
@@ -0,0 +1 @@
| 1 |
+
Uploading LFS objects: 0% (0/58), 0 B | 0 B/s, done.
|
question-generation-retrieval-evaluation.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
requirements-23feb2025.txt
ADDED
|
@@ -0,0 +1,241 @@
| 1 |
+
Package Version
|
| 2 |
+
---------------------------------------- ---------
|
| 3 |
+
aiofiles 23.2.1
|
| 4 |
+
aiohappyeyeballs 2.4.0
|
| 5 |
+
aiohttp 3.10.5
|
| 6 |
+
aiosignal 1.3.1
|
| 7 |
+
alembic 1.13.2
|
| 8 |
+
altair 5.4.0
|
| 9 |
+
aniso8601 9.0.1
|
| 10 |
+
annotated-types 0.7.0
|
| 11 |
+
anthropic 0.34.1
|
| 12 |
+
anyio 4.4.0
|
| 13 |
+
appdirs 1.4.4
|
| 14 |
+
appnope 0.1.4
|
| 15 |
+
asgiref 3.8.1
|
| 16 |
+
asttokens 2.4.1
|
| 17 |
+
async-timeout 4.0.3
|
| 18 |
+
attrs 24.2.0
|
| 19 |
+
backcall 0.2.0
|
| 20 |
+
backoff 2.2.1
|
| 21 |
+
bcrypt 4.2.0
|
| 22 |
+
blinker 1.8.2
|
| 23 |
+
Bottleneck 1.3.7
|
| 24 |
+
build 1.2.1
|
| 25 |
+
cachetools 5.5.0
|
| 26 |
+
cattrs 24.1.2
|
| 27 |
+
certifi 2024.7.4
|
| 28 |
+
charset-normalizer 3.3.2
|
| 29 |
+
chroma-hnswlib 0.7.6
|
| 30 |
+
chromadb 0.5.5
|
| 31 |
+
click 8.1.7
|
| 32 |
+
click-plugins 1.1.1
|
| 33 |
+
cligj 0.7.2
|
| 34 |
+
cloudpickle 3.0.0
|
| 35 |
+
coloredlogs 15.0.1
|
| 36 |
+
comm 0.2.2
|
| 37 |
+
contourpy 1.2.0
|
| 38 |
+
cycler 0.12.1
|
| 39 |
+
databricks-sdk 0.31.1
|
| 40 |
+
dataclasses-json 0.6.7
|
| 41 |
+
debugpy 1.6.7
|
| 42 |
+
decorator 5.1.1
|
| 43 |
+
defusedxml 0.7.1
|
| 44 |
+
Deprecated 1.2.14
|
| 45 |
+
distro 1.9.0
|
| 46 |
+
docker 7.1.0
|
| 47 |
+
entrypoints 0.4
|
| 48 |
+
et-xmlfile 1.1.0
|
| 49 |
+
exceptiongroup 1.2.2
|
| 50 |
+
executing 2.0.1
|
| 51 |
+
fastapi 0.112.2
|
| 52 |
+
ffmpy 0.4.0
|
| 53 |
+
filelock 3.15.4
|
| 54 |
+
fiona 1.9.5
|
| 55 |
+
Flask 3.0.3
|
| 56 |
+
flatbuffers 24.3.25
|
| 57 |
+
fonttools 4.25.0
|
| 58 |
+
frozenlist 1.4.1
|
| 59 |
+
fsspec 2024.6.1
|
| 60 |
+
GDAL 3.6.2
|
| 61 |
+
geojson-rewind 1.1.0
|
| 62 |
+
geomet 1.1.0
|
| 63 |
+
geopandas 0.9.0
|
| 64 |
+
gitdb 4.0.11
|
| 65 |
+
GitPython 3.1.43
|
| 66 |
+
google-auth 2.34.0
|
| 67 |
+
googleapis-common-protos 1.63.2
|
| 68 |
+
gradio 4.42.0
|
| 69 |
+
gradio_client 1.3.0
|
| 70 |
+
graphene 3.3
|
| 71 |
+
graphql-core 3.2.3
|
| 72 |
+
graphql-relay 3.2.0
|
| 73 |
+
grpcio 1.66.0
|
| 74 |
+
gunicorn 23.0.0
|
| 75 |
+
h11 0.14.0
|
| 76 |
+
httpcore 1.0.5
|
| 77 |
+
httptools 0.6.1
|
| 78 |
+
httpx 0.27.0
|
| 79 |
+
huggingface-hub 0.24.6
|
| 80 |
+
humanfriendly 10.0
|
| 81 |
+
idna 3.8
|
| 82 |
+
importlib_metadata 8.0.0
|
| 83 |
+
importlib_resources 6.4.4
|
| 84 |
+
ipykernel 6.29.5
|
| 85 |
+
ipython 8.12.0
|
| 86 |
+
itsdangerous 2.2.0
|
| 87 |
+
jedi 0.19.1
|
| 88 |
+
Jinja2 3.1.4
|
| 89 |
+
jiter 0.5.0
|
| 90 |
+
joblib 1.4.2
|
| 91 |
+
jsonpatch 1.33
|
| 92 |
+
jsonpointer 3.0.0
|
| 93 |
+
jsonschema 4.23.0
|
| 94 |
+
jsonschema-specifications 2023.12.1
|
| 95 |
+
jupyter-client 7.3.4
|
| 96 |
+
jupyter_core 5.7.2
|
| 97 |
+
kiwisolver 1.4.4
|
| 98 |
+
kubernetes 30.1.0
|
| 99 |
+
langchain 0.2.14
|
| 100 |
+
langchain-anthropic 0.1.23
|
| 101 |
+
langchain-community 0.2.12
|
| 102 |
+
langchain-core 0.2.34
|
| 103 |
+
langchain-openai 0.1.22
|
| 104 |
+
langchain-text-splitters 0.2.2
|
| 105 |
+
langsmith 0.1.104
|
| 106 |
+
lightning-utilities 0.11.7
|
| 107 |
+
Mako 1.3.5
|
| 108 |
+
Markdown 3.7
|
| 109 |
+
markdown-it-py 3.0.0
|
| 110 |
+
MarkupSafe 2.1.5
|
| 111 |
+
marshmallow 3.22.0
|
| 112 |
+
matplotlib 3.9.2
|
| 113 |
+
matplotlib-inline 0.1.7
|
| 114 |
+
mdurl 0.1.2
|
| 115 |
+
missingno 0.5.2
|
| 116 |
+
mlflow 2.16.0
|
| 117 |
+
mlflow-skinny 2.16.0
|
| 118 |
+
mmh3 4.1.0
|
| 119 |
+
monotonic 1.6
|
| 120 |
+
mpmath 1.3.0
|
| 121 |
+
multidict 6.0.5
|
| 122 |
+
munkres 1.1.4
|
| 123 |
+
mypy-extensions 1.0.0
|
| 124 |
+
narwhals 1.5.5
|
| 125 |
+
nest_asyncio 1.6.0
|
| 126 |
+
networkx 3.2.1
|
| 127 |
+
nltk 3.9.1
|
| 128 |
+
numexpr 2.10.1
|
| 129 |
+
numpy 1.26.4
|
| 130 |
+
oauthlib 3.2.2
|
| 131 |
+
onnxruntime 1.19.0
|
| 132 |
+
openai 1.42.0
|
| 133 |
+
openpyxl 3.0.9
|
| 134 |
+
opentelemetry-api 1.26.0
|
| 135 |
+
opentelemetry-exporter-otlp-proto-common 1.26.0
|
| 136 |
+
opentelemetry-exporter-otlp-proto-grpc 1.26.0
|
| 137 |
+
opentelemetry-instrumentation 0.47b0
|
| 138 |
+
opentelemetry-instrumentation-asgi 0.47b0
|
| 139 |
+
opentelemetry-instrumentation-fastapi 0.47b0
|
| 140 |
+
opentelemetry-proto 1.26.0
|
| 141 |
+
opentelemetry-sdk 1.26.0
|
| 142 |
+
opentelemetry-semantic-conventions 0.47b0
|
| 143 |
+
opentelemetry-util-http 0.47b0
|
| 144 |
+
orjson 3.10.7
|
| 145 |
+
overrides 7.7.0
|
| 146 |
+
packaging 24.1
|
| 147 |
+
pandas 2.0.3
|
| 148 |
+
parso 0.8.4
|
| 149 |
+
patsy 0.5.6
|
| 150 |
+
pexpect 4.9.0
|
| 151 |
+
pickleshare 0.7.5
|
| 152 |
+
pillow 10.4.0
|
| 153 |
+
pip 24.2
|
| 154 |
+
platformdirs 4.2.2
|
| 155 |
+
plotly 5.24.1
|
| 156 |
+
posthog 3.5.2
|
| 157 |
+
prompt_toolkit 3.0.47
|
| 158 |
+
protobuf 4.25.4
|
| 159 |
+
psutil 5.9.0
|
| 160 |
+
ptyprocess 0.7.0
|
| 161 |
+
pure_eval 0.2.3
|
| 162 |
+
pyarrow 17.0.0
|
| 163 |
+
pyasn1 0.6.0
|
| 164 |
+
pyasn1_modules 0.4.0
|
| 165 |
+
pydantic 2.8.2
|
| 166 |
+
pydantic_core 2.20.1
|
| 167 |
+
pydeck 0.9.1
|
| 168 |
+
pydub 0.25.1
|
| 169 |
+
pygbif 0.6.4
|
| 170 |
+
Pygments 2.18.0
|
| 171 |
+
PyMuPDF 1.24.9
|
| 172 |
+
PyMuPDFb 1.24.9
|
| 173 |
+
pyogrio 0.10.0
|
| 174 |
+
pyparsing 3.1.2
|
| 175 |
+
PyPDF2 3.0.1
|
| 176 |
+
PyPika 0.48.9
|
| 177 |
+
pyproj 3.6.1
|
| 178 |
+
pyproject_hooks 1.1.0
|
| 179 |
+
python-dateutil 2.9.0
|
| 180 |
+
python-dotenv 1.0.1
|
| 181 |
+
python-multipart 0.0.9
|
| 182 |
+
pytz 2024.1
|
| 183 |
+
PyYAML 6.0.2
|
| 184 |
+
pyzmq 25.1.2
|
| 185 |
+
referencing 0.35.1
|
| 186 |
+
regex 2024.7.24
|
| 187 |
+
requests 2.32.3
|
| 188 |
+
requests-cache 1.2.1
|
| 189 |
+
requests-oauthlib 2.0.0
|
| 190 |
+
rich 13.7.1
|
| 191 |
+
rpds-py 0.20.0
|
| 192 |
+
rsa 4.9
|
| 193 |
+
Rtree 1.0.1
|
| 194 |
+
ruff 0.6.2
|
| 195 |
+
scikit-learn 1.5.1
|
| 196 |
+
scipy 1.13.1
|
| 197 |
+
seaborn 0.13.2
|
| 198 |
+
semantic-version 2.10.0
|
| 199 |
+
setuptools 72.1.0
|
| 200 |
+
shapely 2.0.5
|
| 201 |
+
shellingham 1.5.4
|
| 202 |
+
six 1.16.0
|
| 203 |
+
smmap 5.0.1
|
| 204 |
+
sniffio 1.3.1
|
| 205 |
+
SQLAlchemy 2.0.32
|
| 206 |
+
sqlparse 0.5.1
|
| 207 |
+
stack-data 0.6.2
|
| 208 |
+
starlette 0.38.2
|
| 209 |
+
statsmodels 0.14.2
|
| 210 |
+
streamlit 1.37.1
|
| 211 |
+
sympy 1.13.2
|
| 212 |
+
tenacity 8.5.0
|
| 213 |
+
threadpoolctl 3.5.0
|
| 214 |
+
tiktoken 0.7.0
|
| 215 |
+
tokenizers 0.20.0
|
| 216 |
+
toml 0.10.2
|
| 217 |
+
tomli 2.0.1
|
| 218 |
+
tomlkit 0.12.0
|
| 219 |
+
torch 2.4.0
|
| 220 |
+
torchmetrics 1.4.1
|
| 221 |
+
tornado 6.1
|
| 222 |
+
tqdm 4.66.5
|
| 223 |
+
traitlets 5.14.3
|
| 224 |
+
typer 0.12.5
|
| 225 |
+
typing_extensions 4.12.2
|
| 226 |
+
typing-inspect 0.9.0
|
| 227 |
+
tzdata 2024.1
|
| 228 |
+
url-normalize 1.4.3
|
| 229 |
+
urllib3 2.2.2
|
| 230 |
+
uvicorn 0.30.6
|
| 231 |
+
uvloop 0.20.0
|
| 232 |
+
watchfiles 0.23.0
|
| 233 |
+
wcwidth 0.2.13
|
| 234 |
+
websocket-client 1.8.0
|
| 235 |
+
websockets 12.0
|
| 236 |
+
Werkzeug 3.0.4
|
| 237 |
+
wget 3.2
|
| 238 |
+
wheel 0.43.0
|
| 239 |
+
wrapt 1.16.0
|
| 240 |
+
yarl 1.9.4
|
| 241 |
+
zipp 3.20.0
|
requirements.txt
ADDED
|
@@ -0,0 +1,16 @@
| 1 |
+
|
| 2 |
+
chromadb
|
| 3 |
+
gradio
|
| 4 |
+
openai
|
| 5 |
+
langchain
|
| 6 |
+
langchain-anthropic
|
| 7 |
+
langchain_community
|
| 8 |
+
pandas
|
| 9 |
+
python-dotenv
|
| 10 |
+
streamlit
|
| 11 |
+
pypdf2
|
| 12 |
+
tiktoken
|
| 13 |
+
streamlit
|
| 14 |
+
langchain-openai
|
| 15 |
+
pymupdf
|
| 16 |
+
openpyxl
|
retriever-evaluation-tutorial.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
temp_results.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|