arbabarshad committed on
Commit 7a56e2a · 1 Parent(s): de952e4

Add remaining files from agllm-development state

.cursorignore ADDED
@@ -0,0 +1,5 @@
+ # Add directories or file patterns to ignore during indexing (e.g. foo/ or *.csv)
+ *.ipynb
+ *.xlsx
+ *.csv
+
.gitattributes ADDED
@@ -0,0 +1,40 @@
+ *.xlsx filter=lfs diff=lfs merge=lfs -text
+ db5/** filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ vector-databases-deployed/** filter=lfs diff=lfs merge=lfs -text
+ *.pdf filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,42 @@
+ # First exclude everything in agllm-data
+ agllm-data/*
+
+ # Then explicitly include the corrected and india directories and their contents
+ !agllm-data/corrected/
+ !agllm-data/corrected/**
+ !agllm-data/india/
+ !agllm-data/india/**
+
+ # Other exclusions
+ agllm/
+ db3/
+ # db5/
+ db4/
+ mlruns/
+ /agllm-data.zip
+ .env
+ vector-databases/
+ vector-databases-deployement/
+ writing/
+ # but include the analysis folder
+ !writing/*/analysis/
+ # Include specific file
+ !writing/65d4fadc59fceb1a54d1aae6/main.tex
+ *.pdf
+
+ # Chloropleth related files
+ chloropleth/
+ chloropleth_backup/
+ *.shp
+ *.png
+
+ # Vector database backups
+ vector-databases-deployed-backup/
+ *.sqlite3
+
+ # Binary files
+ *.pyc
+ __pycache__/
+ .DS_Store
+
+ vector-databases-deployed/
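
The agllm-data rules above use the exclude-then-re-include pattern: the exclusion targets the directory's contents (`agllm-data/*`) rather than the directory itself, which matters because git cannot re-include a path whose parent directory is excluded (for the same reason, the `!writing/*/analysis/` negation under the excluded `writing/` directory may not take effect). A minimal sketch, not part of the commit, for checking which rule wins for a given path via `git check-ignore -v`; the paths below are hypothetical examples:

```python
# Sketch: verify the .gitignore rules above with `git check-ignore -v`,
# which prints the matching rule for an ignored path and exits 1 otherwise.
import subprocess

def explain_ignore(path: str) -> str:
    result = subprocess.run(
        ["git", "check-ignore", "-v", path],
        capture_output=True, text=True,
    )
    if result.returncode == 0:          # ignored: "<source>:<line>:<pattern>\t<path>"
        return result.stdout.strip()
    if result.returncode == 1:          # not ignored
        return f"{path}: not ignored"
    raise RuntimeError(result.stderr)

if __name__ == "__main__":
    for p in ["agllm-data/raw.xlsx",               # expected: ignored by agllm-data/*
              "agllm-data/corrected/a.xlsx",       # expected: kept by !agllm-data/corrected/**
              "writing/65d4fadc59fceb1a54d1aae6/main.tex"]:
        print(explain_ignore(p))
```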
.vscode/settings.json ADDED
@@ -0,0 +1,2 @@
+ {
+ }
AgLLM.code-workspace ADDED
@@ -0,0 +1,8 @@
+ {
+ "folders": [
+ {
+ "path": "."
+ }
+ ],
+ "settings": {}
+ }
agllm-data/corrected/Corrected_supplemented-insect_data-2500-sorted.xlsx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6aef24ba96ca9be054b4f87344d93d3f8f59cbd280ff29d45991977482067717
+ size 1185132
agllm-data/corrected/Corrected_supplemented-insect_data-remaining.xlsx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cc953489e31d114f689783aaa50d1e23ce3fcc58618a9b22562e0cad5750af57
+ size 980300
agllm-data/corrected/Corrected_weed-prepared_results-all.xlsx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e97a0eaa24c101b940d48f1dd1e11345246e2455521b2d79a45c3da108c63195
+ size 602025
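
The three .xlsx entries above are committed as Git LFS pointer files: the repository stores only a small text stub (spec version, SHA-256 object id, byte size) while the spreadsheet itself lives in LFS storage, matching the `*.xlsx filter=lfs` rule in .gitattributes. A minimal sketch, not part of the commit, that parses that pointer format (the path argument is a placeholder):

```python
# Sketch: read a Git LFS pointer file like the ones shown above.
# A pointer is plain text with "key value" lines: version, oid sha256:<hex>, size <bytes>.
def read_lfs_pointer(pointer_path: str) -> dict:
    fields = {}
    with open(pointer_path, encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            if key:
                fields[key] = value
    return {
        "version": fields.get("version"),
        "sha256": fields.get("oid", "").removeprefix("sha256:"),  # Python 3.9+
        "size_bytes": int(fields.get("size", "0")),
    }

# Example (hypothetical path in a fresh clone before `git lfs pull`):
# info = read_lfs_pointer("agllm-data/corrected/Corrected_weed-prepared_results-all.xlsx")
# print(info["sha256"], info["size_bytes"])
```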
agllm-data/india/species.csv ADDED
@@ -0,0 +1,2 @@
+ Common Name,species,ETL,Alternate Hosts,Management,Remarks
+ Brown Marmorated Stink Bug,Halyomorpha Halys,,horticulture and field crops,"Crush any eggs nymphs or adults you find on plants. Repeat daily for several days.|Botanical extracts: Spray plants with a mixture of 50 g of neem seed extract in 2 L of water. Boil for 15 minutes let cool and spray three times at two-week intervals.|Soapy water: Spray plants with soapy water to remove insects",
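
species.csv packs several recommendations into one pipe-delimited "Management" cell. A minimal sketch, not part of the commit and assuming pandas is available, showing how such rows could be loaded and the Management field split into individual recommendations:

```python
# Sketch: load species.csv and split the pipe-delimited "Management" column.
import pandas as pd

df = pd.read_csv("agllm-data/india/species.csv")

# Empty cells (e.g. ETL above) come back as NaN; fill them before splitting.
df["Management"] = df["Management"].fillna("").str.split("|")

for _, row in df.iterrows():
    print(row["Common Name"], f"({row['species']})")
    for step in row["Management"]:
        if step:
            print("  -", step)
```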
agllm_analysis.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
agllm_development_without_secret.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
api_request_parallel_processor.py ADDED
@@ -0,0 +1,508 @@
1
+ """
2
+ API REQUEST PARALLEL PROCESSOR
3
+
4
+ Using the OpenAI API to process lots of text quickly takes some care.
5
+ If you trickle in a million API requests one by one, they'll take days to complete.
6
+ If you flood a million API requests in parallel, they'll exceed the rate limits and fail with errors.
7
+ To maximize throughput, parallel requests need to be throttled to stay under rate limits.
8
+
9
+ This script parallelizes requests to the OpenAI API while throttling to stay under rate limits.
10
+
11
+ Features:
12
+ - Streams requests from file, to avoid running out of memory for giant jobs
13
+ - Makes requests concurrently, to maximize throughput
14
+ - Throttles request and token usage, to stay under rate limits
15
+ - Retries failed requests up to {max_attempts} times, to avoid missing data
16
+ - Logs errors, to diagnose problems with requests
17
+
18
+ Example command to call script:
19
+ ```
20
+ python examples/api_request_parallel_processor.py \
21
+ --requests_filepath examples/data/example_requests_to_parallel_process.jsonl \
22
+ --save_filepath examples/data/example_requests_to_parallel_process_results.jsonl \
23
+ --request_url https://api.openai.com/v1/embeddings \
24
+ --max_requests_per_minute 1500 \
25
+ --max_tokens_per_minute 6250000 \
26
+ --token_encoding_name cl100k_base \
27
+ --max_attempts 5 \
28
+ --logging_level 20
29
+ ```
30
+
31
+ Inputs:
32
+ - requests_filepath : str
33
+ - path to the file containing the requests to be processed
34
+ - file should be a jsonl file, where each line is a json object with API parameters and an optional metadata field
35
+ - e.g., {"model": "text-embedding-ada-002", "input": "embed me", "metadata": {"row_id": 1}}
36
+ - as with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically)
37
+ - an example file is provided at examples/data/example_requests_to_parallel_process.jsonl
38
+ - the code to generate the example file is appended to the bottom of this script
39
+ - save_filepath : str, optional
40
+ - path to the file where the results will be saved
41
+ - file will be a jsonl file, where each line is an array with the original request plus the API response
42
+ - e.g., [{"model": "text-embedding-ada-002", "input": "embed me"}, {...}]
43
+ - if omitted, results will be saved to {requests_filename}_results.jsonl
44
+ - request_url : str, optional
45
+ - URL of the API endpoint to call
46
+ - if omitted, will default to "https://api.openai.com/v1/embeddings"
47
+ - api_key : str, optional
48
+ - API key to use
49
+ - if omitted, the script will attempt to read it from an environment variable {os.getenv("OPENAI_API_KEY")}
50
+ - max_requests_per_minute : float, optional
51
+ - target number of requests to make per minute (will make less if limited by tokens)
52
+ - leave headroom by setting this to 50% or 75% of your limit
53
+ - if requests are limiting you, try batching multiple embeddings or completions into one request
54
+ - if omitted, will default to 1,500
55
+ - max_tokens_per_minute : float, optional
56
+ - target number of tokens to use per minute (will use less if limited by requests)
57
+ - leave headroom by setting this to 50% or 75% of your limit
58
+ - if omitted, will default to 125,000
59
+ - token_encoding_name : str, optional
60
+ - name of the token encoding used, as defined in the `tiktoken` package
61
+ - if omitted, will default to "cl100k_base" (used by `text-embedding-ada-002`)
62
+ - max_attempts : int, optional
63
+ - number of times to retry a failed request before giving up
64
+ - if omitted, will default to 5
65
+ - logging_level : int, optional
66
+ - level of logging to use; higher numbers will log fewer messages
67
+ - 40 = ERROR; will log only when requests fail after all retries
68
+ - 30 = WARNING; will log when requests hit rate limits or other errors
69
+ - 20 = INFO; will log when requests start and the status at finish
70
+ - 10 = DEBUG; will log various things as the loop runs to see when they occur
71
+ - if omitted, will default to 20 (INFO).
72
+
73
+ The script is structured as follows:
74
+ - Imports
75
+ - Define main()
76
+ - Initialize things
77
+ - In main loop:
78
+ - Get next request if one is not already waiting for capacity
79
+ - Update available token & request capacity
80
+ - If enough capacity available, call API
81
+ - The loop pauses if a rate limit error is hit
82
+ - The loop breaks when no tasks remain
83
+ - Define dataclasses
84
+ - StatusTracker (stores script metadata counters; only one instance is created)
85
+ - APIRequest (stores API inputs, outputs, metadata; one method to call API)
86
+ - Define functions
87
+ - api_endpoint_from_url (extracts API endpoint from request URL)
88
+ - append_to_jsonl (writes to results file)
89
+ - num_tokens_consumed_from_request (bigger function to infer token usage from request)
90
+ - task_id_generator_function (yields 0, 1, 2, ...)
91
+ - Run main()
92
+ """
93
+
94
+ # imports
95
+ import aiohttp # for making API calls concurrently
96
+ import argparse # for running script from command line
97
+ import asyncio # for running API calls concurrently
98
+ import json # for saving results to a jsonl file
99
+ import logging # for logging rate limit warnings and other messages
100
+ import os # for reading API key
101
+ import re # for matching endpoint from request URL
102
+ import tiktoken # for counting tokens
103
+ import time # for sleeping after rate limit is hit
104
+ from dataclasses import (
105
+ dataclass,
106
+ field,
107
+ ) # for storing API inputs, outputs, and metadata
108
+
109
+
110
+ async def process_api_requests_from_file(
111
+ requests_filepath: str,
112
+ save_filepath: str,
113
+ request_url: str,
114
+ api_key: str,
115
+ max_requests_per_minute: float,
116
+ max_tokens_per_minute: float,
117
+ token_encoding_name: str,
118
+ max_attempts: int,
119
+ logging_level: int,
120
+ ):
121
+ """Processes API requests in parallel, throttling to stay under rate limits."""
122
+ # constants
123
+ seconds_to_pause_after_rate_limit_error = 15
124
+ seconds_to_sleep_each_loop = (
125
+ 0.001 # 1 ms limits max throughput to 1,000 requests per second
126
+ )
127
+
128
+ # initialize logging
129
+ logging.basicConfig(level=logging_level)
130
+ logging.debug(f"Logging initialized at level {logging_level}")
131
+
132
+ # infer API endpoint and construct request header
133
+ api_endpoint = api_endpoint_from_url(request_url)
134
+ request_header = {"Authorization": f"Bearer {api_key}"}
135
+ # use api-key header for Azure deployments
136
+ if '/deployments' in request_url:
137
+ request_header = {"api-key": f"{api_key}"}
138
+
139
+ # initialize trackers
140
+ queue_of_requests_to_retry = asyncio.Queue()
141
+ task_id_generator = (
142
+ task_id_generator_function()
143
+ ) # generates integer IDs of 0, 1, 2, ...
144
+ status_tracker = (
145
+ StatusTracker()
146
+ ) # single instance to track a collection of variables
147
+ next_request = None # variable to hold the next request to call
148
+
149
+ # initialize available capacity counts
150
+ available_request_capacity = max_requests_per_minute
151
+ available_token_capacity = max_tokens_per_minute
152
+ last_update_time = time.time()
153
+
154
+ # initialize flags
155
+ file_not_finished = True # after file is empty, we'll skip reading it
156
+ logging.debug(f"Initialization complete.")
157
+
158
+ # initialize file reading
159
+ with open(requests_filepath) as file:
160
+ # `requests` will provide requests one at a time
161
+ requests = file.__iter__()
162
+ logging.debug(f"File opened. Entering main loop")
163
+ async with aiohttp.ClientSession() as session: # Initialize ClientSession here
164
+ while True:
165
+ # get next request (if one is not already waiting for capacity)
166
+ if next_request is None:
167
+ if not queue_of_requests_to_retry.empty():
168
+ next_request = queue_of_requests_to_retry.get_nowait()
169
+ logging.debug(
170
+ f"Retrying request {next_request.task_id}: {next_request}"
171
+ )
172
+ elif file_not_finished:
173
+ try:
174
+ # get new request
175
+ request_json = json.loads(next(requests))
176
+ next_request = APIRequest(
177
+ task_id=next(task_id_generator),
178
+ request_json=request_json,
179
+ token_consumption=num_tokens_consumed_from_request(
180
+ request_json, api_endpoint, token_encoding_name
181
+ ),
182
+ attempts_left=max_attempts,
183
+ metadata=request_json.pop("metadata", None),
184
+ )
185
+ status_tracker.num_tasks_started += 1
186
+ status_tracker.num_tasks_in_progress += 1
187
+ logging.debug(
188
+ f"Reading request {next_request.task_id}: {next_request}"
189
+ )
190
+ except StopIteration:
191
+ # if file runs out, set flag to stop reading it
192
+ logging.debug("Read file exhausted")
193
+ file_not_finished = False
194
+
195
+ # update available capacity
196
+ current_time = time.time()
197
+ seconds_since_update = current_time - last_update_time
198
+ available_request_capacity = min(
199
+ available_request_capacity
200
+ + max_requests_per_minute * seconds_since_update / 60.0,
201
+ max_requests_per_minute,
202
+ )
203
+ available_token_capacity = min(
204
+ available_token_capacity
205
+ + max_tokens_per_minute * seconds_since_update / 60.0,
206
+ max_tokens_per_minute,
207
+ )
208
+ last_update_time = current_time
209
+
210
+ # if enough capacity available, call API
211
+ if next_request:
212
+ next_request_tokens = next_request.token_consumption
213
+ if (
214
+ available_request_capacity >= 1
215
+ and available_token_capacity >= next_request_tokens
216
+ ):
217
+ # update counters
218
+ available_request_capacity -= 1
219
+ available_token_capacity -= next_request_tokens
220
+ next_request.attempts_left -= 1
221
+
222
+ # call API
223
+ asyncio.create_task(
224
+ next_request.call_api(
225
+ session=session,
226
+ request_url=request_url,
227
+ request_header=request_header,
228
+ retry_queue=queue_of_requests_to_retry,
229
+ save_filepath=save_filepath,
230
+ status_tracker=status_tracker,
231
+ )
232
+ )
233
+ next_request = None # reset next_request to empty
234
+
235
+ # if all tasks are finished, break
236
+ if status_tracker.num_tasks_in_progress == 0:
237
+ break
238
+
239
+ # main loop sleeps briefly so concurrent tasks can run
240
+ await asyncio.sleep(seconds_to_sleep_each_loop)
241
+
242
+ # if a rate limit error was hit recently, pause to cool down
243
+ seconds_since_rate_limit_error = (
244
+ time.time() - status_tracker.time_of_last_rate_limit_error
245
+ )
246
+ if (
247
+ seconds_since_rate_limit_error
248
+ < seconds_to_pause_after_rate_limit_error
249
+ ):
250
+ remaining_seconds_to_pause = (
251
+ seconds_to_pause_after_rate_limit_error
252
+ - seconds_since_rate_limit_error
253
+ )
254
+ await asyncio.sleep(remaining_seconds_to_pause)
255
+ # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago
256
+ logging.warning(
257
+ f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
258
+ )
259
+
260
+ # after finishing, log final status
261
+ logging.info(
262
+ f"""Parallel processing complete. Results saved to {save_filepath}"""
263
+ )
264
+ if status_tracker.num_tasks_failed > 0:
265
+ logging.warning(
266
+ f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. Errors logged to {save_filepath}."
267
+ )
268
+ if status_tracker.num_rate_limit_errors > 0:
269
+ logging.warning(
270
+ f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
271
+ )
272
+
273
+
274
+ # dataclasses
275
+
276
+
277
+ @dataclass
278
+ class StatusTracker:
279
+ """Stores metadata about the script's progress. Only one instance is created."""
280
+
281
+ num_tasks_started: int = 0
282
+ num_tasks_in_progress: int = 0 # script ends when this reaches 0
283
+ num_tasks_succeeded: int = 0
284
+ num_tasks_failed: int = 0
285
+ num_rate_limit_errors: int = 0
286
+ num_api_errors: int = 0 # excluding rate limit errors, counted above
287
+ num_other_errors: int = 0
288
+ time_of_last_rate_limit_error: int = 0 # used to cool off after hitting rate limits
289
+
290
+
291
+ @dataclass
292
+ class APIRequest:
293
+ """Stores an API request's inputs, outputs, and other metadata. Contains a method to make an API call."""
294
+
295
+ task_id: int
296
+ request_json: dict
297
+ token_consumption: int
298
+ attempts_left: int
299
+ metadata: dict
300
+ result: list = field(default_factory=list)
301
+
302
+ async def call_api(
303
+ self,
304
+ session: aiohttp.ClientSession,
305
+ request_url: str,
306
+ request_header: dict,
307
+ retry_queue: asyncio.Queue,
308
+ save_filepath: str,
309
+ status_tracker: StatusTracker,
310
+ ):
311
+ """Calls the OpenAI API and saves results."""
312
+ logging.info(f"Starting request #{self.task_id}")
313
+ error = None
314
+ try:
315
+ async with session.post(
316
+ url=request_url, headers=request_header, json=self.request_json
317
+ ) as response:
318
+ response = await response.json()
319
+ if "error" in response:
320
+ logging.warning(
321
+ f"Request {self.task_id} failed with error {response['error']}"
322
+ )
323
+ status_tracker.num_api_errors += 1
324
+ error = response
325
+ if "Rate limit" in response["error"].get("message", ""):
326
+ status_tracker.time_of_last_rate_limit_error = time.time()
327
+ status_tracker.num_rate_limit_errors += 1
328
+ status_tracker.num_api_errors -= (
329
+ 1 # rate limit errors are counted separately
330
+ )
331
+
332
+ except (
333
+ Exception
334
+ ) as e: # catching naked exceptions is bad practice, but in this case we'll log & save them
335
+ logging.warning(f"Request {self.task_id} failed with Exception {e}")
336
+ status_tracker.num_other_errors += 1
337
+ error = e
338
+ if error:
339
+ self.result.append(error)
340
+ if self.attempts_left:
341
+ retry_queue.put_nowait(self)
342
+ else:
343
+ logging.error(
344
+ f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}"
345
+ )
346
+ data = (
347
+ [self.request_json, [str(e) for e in self.result], self.metadata]
348
+ if self.metadata
349
+ else [self.request_json, [str(e) for e in self.result]]
350
+ )
351
+ append_to_jsonl(data, save_filepath)
352
+ status_tracker.num_tasks_in_progress -= 1
353
+ status_tracker.num_tasks_failed += 1
354
+ else:
355
+ data = (
356
+ [self.request_json, response, self.metadata]
357
+ if self.metadata
358
+ else [self.request_json, response]
359
+ )
360
+ append_to_jsonl(data, save_filepath)
361
+ status_tracker.num_tasks_in_progress -= 1
362
+ status_tracker.num_tasks_succeeded += 1
363
+ logging.debug(f"Request {self.task_id} saved to {save_filepath}")
364
+
365
+
366
+ # functions
367
+
368
+
369
+ def api_endpoint_from_url(request_url):
370
+ """Extract the API endpoint from the request URL."""
371
+ match = re.search("^https://[^/]+/v\\d+/(.+)$", request_url)
372
+ if match is None:
373
+ # for Azure OpenAI deployment urls
374
+ match = re.search(r"^https://[^/]+/openai/deployments/[^/]+/(.+?)(\?|$)", request_url)
375
+ return match[1]
376
+
377
+
378
+ def append_to_jsonl(data, filename: str) -> None:
379
+ """Append a json payload to the end of a jsonl file."""
380
+ json_string = json.dumps(data)
381
+ with open(filename, "a") as f:
382
+ f.write(json_string + "\n")
383
+
384
+
385
+ def num_tokens_consumed_from_request(
386
+ request_json: dict,
387
+ api_endpoint: str,
388
+ token_encoding_name: str,
389
+ ):
390
+ """Count the number of tokens in the request. Only supports completion and embedding requests."""
391
+ encoding = tiktoken.get_encoding(token_encoding_name)
392
+ # if completions request, tokens = prompt + n * max_tokens
393
+ if api_endpoint.endswith("completions"):
394
+ max_tokens = request_json.get("max_tokens", 15)
395
+ n = request_json.get("n", 1)
396
+ completion_tokens = n * max_tokens
397
+
398
+ # chat completions
399
+ if api_endpoint.startswith("chat/"):
400
+ num_tokens = 0
401
+ for message in request_json["messages"]:
402
+ num_tokens += 4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
403
+ for key, value in message.items():
404
+ num_tokens += len(encoding.encode(value))
405
+ if key == "name": # if there's a name, the role is omitted
406
+ num_tokens -= 1 # role is always required and always 1 token
407
+ num_tokens += 2 # every reply is primed with <im_start>assistant
408
+ return num_tokens + completion_tokens
409
+ # normal completions
410
+ else:
411
+ prompt = request_json["prompt"]
412
+ if isinstance(prompt, str): # single prompt
413
+ prompt_tokens = len(encoding.encode(prompt))
414
+ num_tokens = prompt_tokens + completion_tokens
415
+ return num_tokens
416
+ elif isinstance(prompt, list): # multiple prompts
417
+ prompt_tokens = sum([len(encoding.encode(p)) for p in prompt])
418
+ num_tokens = prompt_tokens + completion_tokens * len(prompt)
419
+ return num_tokens
420
+ else:
421
+ raise TypeError(
422
+ 'Expecting either string or list of strings for "prompt" field in completion request'
423
+ )
424
+ # if embeddings request, tokens = input tokens
425
+ elif api_endpoint == "embeddings":
426
+ input = request_json["input"]
427
+ if isinstance(input, str): # single input
428
+ num_tokens = len(encoding.encode(input))
429
+ return num_tokens
430
+ elif isinstance(input, list): # multiple inputs
431
+ num_tokens = sum([len(encoding.encode(i)) for i in input])
432
+ return num_tokens
433
+ else:
434
+ raise TypeError(
435
+ 'Expecting either string or list of strings for "inputs" field in embedding request'
436
+ )
437
+ # more logic needed to support other API calls (e.g., edits, inserts, DALL-E)
438
+ else:
439
+ raise NotImplementedError(
440
+ f'API endpoint "{api_endpoint}" not implemented in this script'
441
+ )
442
+
443
+
444
+ def task_id_generator_function():
445
+ """Generate integers 0, 1, 2, and so on."""
446
+ task_id = 0
447
+ while True:
448
+ yield task_id
449
+ task_id += 1
450
+
451
+
452
+ # run script
453
+
454
+
455
+ if __name__ == "__main__":
456
+ # parse command line arguments
457
+ parser = argparse.ArgumentParser()
458
+ parser.add_argument("--requests_filepath")
459
+ parser.add_argument("--save_filepath", default=None)
460
+ parser.add_argument("--request_url", default="https://api.openai.com/v1/embeddings")
461
+ parser.add_argument("--api_key", default=os.getenv("OPENAI_API_KEY"))
462
+ parser.add_argument("--max_requests_per_minute", type=int, default=3_000 * 0.5)
463
+ parser.add_argument("--max_tokens_per_minute", type=int, default=250_000 * 0.5)
464
+ parser.add_argument("--token_encoding_name", default="cl100k_base")
465
+ parser.add_argument("--max_attempts", type=int, default=5)
466
+ parser.add_argument("--logging_level", default=logging.INFO)
467
+ args = parser.parse_args()
468
+
469
+ if args.save_filepath is None:
470
+ args.save_filepath = args.requests_filepath.replace(".jsonl", "_results.jsonl")
471
+
472
+ # run script
473
+ asyncio.run(
474
+ process_api_requests_from_file(
475
+ requests_filepath=args.requests_filepath,
476
+ save_filepath=args.save_filepath,
477
+ request_url=args.request_url,
478
+ api_key=args.api_key,
479
+ max_requests_per_minute=float(args.max_requests_per_minute),
480
+ max_tokens_per_minute=float(args.max_tokens_per_minute),
481
+ token_encoding_name=args.token_encoding_name,
482
+ max_attempts=int(args.max_attempts),
483
+ logging_level=int(args.logging_level),
484
+ )
485
+ )
486
+
487
+
488
+ """
489
+ APPENDIX
490
+
491
+ The example requests file at openai-cookbook/examples/data/example_requests_to_parallel_process.jsonl contains 10,000 requests to text-embedding-ada-002.
492
+
493
+ It was generated with the following code:
494
+
495
+ ```python
496
+ import json
497
+
498
+ filename = "data/example_requests_to_parallel_process.jsonl"
499
+ n_requests = 10_000
500
+ jobs = [{"model": "text-embedding-ada-002", "input": str(x) + "\n"} for x in range(n_requests)]
501
+ with open(filename, "w") as f:
502
+ for job in jobs:
503
+ json_string = json.dumps(job)
504
+ f.write(json_string + "\n")
505
+ ```
506
+
507
+ As with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically).
508
+ """
api_request_parallel_processor_anthropic.py ADDED
@@ -0,0 +1,473 @@
1
+ print("IMPORTED")
2
+ """
3
+ API REQUEST PARALLEL PROCESSOR
4
+
5
+ Using the OpenAI API to process lots of text quickly takes some care.
6
+ If you trickle in a million API requests one by one, they'll take days to complete.
7
+ If you flood a million API requests in parallel, they'll exceed the rate limits and fail with errors.
8
+ To maximize throughput, parallel requests need to be throttled to stay under rate limits.
9
+
10
+ This script parallelizes requests to the OpenAI API while throttling to stay under rate limits.
11
+
12
+ Features:
13
+ - Streams requests from file, to avoid running out of memory for giant jobs
14
+ - Makes requests concurrently, to maximize throughput
15
+ - Throttles request and token usage, to stay under rate limits
16
+ - Retries failed requests up to {max_attempts} times, to avoid missing data
17
+ - Logs errors, to diagnose problems with requests
18
+
19
+ Example command to call script:
20
+ ```
21
+ python examples/api_request_parallel_processor.py \
22
+ --requests_filepath examples/data/example_requests_to_parallel_process.jsonl \
23
+ --save_filepath examples/data/example_requests_to_parallel_process_results.jsonl \
24
+ --request_url https://api.openai.com/v1/embeddings \
25
+ --max_requests_per_minute 1500 \
26
+ --max_tokens_per_minute 6250000 \
27
+ --token_encoding_name cl100k_base \
28
+ --max_attempts 5 \
29
+ --logging_level 20
30
+ ```
31
+
32
+ Inputs:
33
+ - requests_filepath : str
34
+ - path to the file containing the requests to be processed
35
+ - file should be a jsonl file, where each line is a json object with API parameters and an optional metadata field
36
+ - e.g., {"model": "text-embedding-ada-002", "input": "embed me", "metadata": {"row_id": 1}}
37
+ - as with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically)
38
+ - an example file is provided at examples/data/example_requests_to_parallel_process.jsonl
39
+ - the code to generate the example file is appended to the bottom of this script
40
+ - save_filepath : str, optional
41
+ - path to the file where the results will be saved
42
+ - file will be a jsonl file, where each line is an array with the original request plus the API response
43
+ - e.g., [{"model": "text-embedding-ada-002", "input": "embed me"}, {...}]
44
+ - if omitted, results will be saved to {requests_filename}_results.jsonl
45
+ - request_url : str, optional
46
+ - URL of the API endpoint to call
47
+ - if omitted, will default to "https://api.openai.com/v1/embeddings"
48
+ - api_key : str, optional
49
+ - API key to use
50
+ - if omitted, the script will attempt to read it from an environment variable {os.getenv("OPENAI_API_KEY")}
51
+ - max_requests_per_minute : float, optional
52
+ - target number of requests to make per minute (will make less if limited by tokens)
53
+ - leave headroom by setting this to 50% or 75% of your limit
54
+ - if requests are limiting you, try batching multiple embeddings or completions into one request
55
+ - if omitted, will default to 1,500
56
+ - max_tokens_per_minute : float, optional
57
+ - target number of tokens to use per minute (will use less if limited by requests)
58
+ - leave headroom by setting this to 50% or 75% of your limit
59
+ - if omitted, will default to 125,000
60
+ - token_encoding_name : str, optional
61
+ - name of the token encoding used, as defined in the `tiktoken` package
62
+ - if omitted, will default to "cl100k_base" (used by `text-embedding-ada-002`)
63
+ - max_attempts : int, optional
64
+ - number of times to retry a failed request before giving up
65
+ - if omitted, will default to 5
66
+ - logging_level : int, optional
67
+ - level of logging to use; higher numbers will log fewer messages
68
+ - 40 = ERROR; will log only when requests fail after all retries
69
+ - 30 = WARNING; will log when requests hit rate limits or other errors
70
+ - 20 = INFO; will log when requests start and the status at finish
71
+ - 10 = DEBUG; will log various things as the loop runs to see when they occur
72
+ - if omitted, will default to 20 (INFO).
73
+
74
+ The script is structured as follows:
75
+ - Imports
76
+ - Define main()
77
+ - Initialize things
78
+ - In main loop:
79
+ - Get next request if one is not already waiting for capacity
80
+ - Update available token & request capacity
81
+ - If enough capacity available, call API
82
+ - The loop pauses if a rate limit error is hit
83
+ - The loop breaks when no tasks remain
84
+ - Define dataclasses
85
+ - StatusTracker (stores script metadata counters; only one instance is created)
86
+ - APIRequest (stores API inputs, outputs, metadata; one method to call API)
87
+ - Define functions
88
+ - api_endpoint_from_url (extracts API endpoint from request URL)
89
+ - append_to_jsonl (writes to results file)
90
+ - num_tokens_consumed_from_request (bigger function to infer token usage from request)
91
+ - task_id_generator_function (yields 0, 1, 2, ...)
92
+ - Run main()
93
+ """
94
+
95
+ # imports
96
+ import aiohttp # for making API calls concurrently
97
+ import argparse # for running script from command line
98
+ import asyncio # for running API calls concurrently
99
+ import json # for saving results to a jsonl file
100
+ import logging # for logging rate limit warnings and other messages
101
+ import os # for reading API key
102
+ import re # for matching endpoint from request URL
103
+ import tiktoken # for counting tokens
104
+ import time # for sleeping after rate limit is hit
105
+ from dataclasses import (
106
+ dataclass,
107
+ field,
108
+ ) # for storing API inputs, outputs, and metadata
109
+
110
+
111
+ async def process_api_requests_from_file(
112
+ requests_filepath: str,
113
+ save_filepath: str,
114
+ request_url: str,
115
+ api_key: str,
116
+ max_requests_per_minute: float,
117
+ max_tokens_per_minute: float,
118
+ token_encoding_name: str,
119
+ max_attempts: int,
120
+ logging_level: int,
121
+ ):
122
+ """Processes API requests in parallel, throttling to stay under rate limits."""
123
+ # constants
124
+ seconds_to_pause_after_rate_limit_error = 15
125
+ seconds_to_sleep_each_loop = (
126
+ 0.001 # 1 ms limits max throughput to 1,000 requests per second
127
+ )
128
+
129
+ # initialize logging
130
+ logging.basicConfig(level=logging_level)
131
+ logging.debug(f"Logging initialized at level {logging_level}")
132
+
133
+ # infer API endpoint and construct request header
134
+ api_endpoint = api_endpoint_from_url(request_url)
135
+ # request_header = {"Authorization": f"Bearer {api_key}"}
136
+ request_header = {
137
+ "x-api-key": api_key,
138
+ "anthropic-version": "2023-06-01",
139
+ "content-type": "application/json",
140
+ }
141
+ # use api-key header for Azure deployments
142
+ if '/deployments' in request_url:
143
+ request_header = {"api-key": f"{api_key}"}
144
+
145
+ # initialize trackers
146
+ queue_of_requests_to_retry = asyncio.Queue()
147
+ task_id_generator = (
148
+ task_id_generator_function()
149
+ ) # generates integer IDs of 0, 1, 2, ...
150
+ status_tracker = (
151
+ StatusTracker()
152
+ ) # single instance to track a collection of variables
153
+ next_request = None # variable to hold the next request to call
154
+
155
+ # initialize available capacity counts
156
+ available_request_capacity = max_requests_per_minute
157
+ available_token_capacity = max_tokens_per_minute
158
+ last_update_time = time.time()
159
+
160
+ # initialize flags
161
+ file_not_finished = True # after file is empty, we'll skip reading it
162
+ logging.debug(f"Initialization complete.")
163
+
164
+ # initialize file reading
165
+ with open(requests_filepath) as file:
166
+ # `requests` will provide requests one at a time
167
+ requests = file.__iter__()
168
+ logging.debug(f"File opened. Entering main loop")
169
+ async with aiohttp.ClientSession() as session: # Initialize ClientSession here
170
+ while True:
171
+ # get next request (if one is not already waiting for capacity)
172
+ if next_request is None:
173
+ if not queue_of_requests_to_retry.empty():
174
+ next_request = queue_of_requests_to_retry.get_nowait()
175
+ logging.debug(
176
+ f"Retrying request {next_request.task_id}: {next_request}"
177
+ )
178
+ elif file_not_finished:
179
+ try:
180
+ # get new request
181
+ request_json = json.loads(next(requests))
182
+ next_request = APIRequest(
183
+ task_id=next(task_id_generator),
184
+ request_json=request_json,
185
+ token_consumption=num_tokens_consumed_from_request(
186
+ request_json, api_endpoint, token_encoding_name
187
+ ),
188
+ attempts_left=max_attempts,
189
+ metadata=request_json.pop("metadata", None),
190
+ )
191
+ status_tracker.num_tasks_started += 1
192
+ status_tracker.num_tasks_in_progress += 1
193
+ logging.debug(
194
+ f"Reading request {next_request.task_id}: {next_request}"
195
+ )
196
+ except StopIteration:
197
+ # if file runs out, set flag to stop reading it
198
+ logging.debug("Read file exhausted")
199
+ file_not_finished = False
200
+
201
+ # update available capacity
202
+ current_time = time.time()
203
+ seconds_since_update = current_time - last_update_time
204
+ available_request_capacity = min(
205
+ available_request_capacity
206
+ + max_requests_per_minute * seconds_since_update / 60.0,
207
+ max_requests_per_minute,
208
+ )
209
+ available_token_capacity = min(
210
+ available_token_capacity
211
+ + max_tokens_per_minute * seconds_since_update / 60.0,
212
+ max_tokens_per_minute,
213
+ )
214
+ last_update_time = current_time
215
+
216
+ # if enough capacity available, call API
217
+ if next_request:
218
+ next_request_tokens = next_request.token_consumption
219
+ if (
220
+ available_request_capacity >= 1
221
+ and available_token_capacity >= next_request_tokens
222
+ ):
223
+ # update counters
224
+ available_request_capacity -= 1
225
+ available_token_capacity -= next_request_tokens
226
+ next_request.attempts_left -= 1
227
+
228
+ # call API
229
+ asyncio.create_task(
230
+ next_request.call_api(
231
+ session=session,
232
+ request_url=request_url,
233
+ request_header=request_header,
234
+ retry_queue=queue_of_requests_to_retry,
235
+ save_filepath=save_filepath,
236
+ status_tracker=status_tracker,
237
+ )
238
+ )
239
+ next_request = None # reset next_request to empty
240
+
241
+ # if all tasks are finished, break
242
+ if status_tracker.num_tasks_in_progress == 0:
243
+ break
244
+
245
+ # main loop sleeps briefly so concurrent tasks can run
246
+ await asyncio.sleep(seconds_to_sleep_each_loop)
247
+
248
+ # if a rate limit error was hit recently, pause to cool down
249
+ seconds_since_rate_limit_error = (
250
+ time.time() - status_tracker.time_of_last_rate_limit_error
251
+ )
252
+ if (
253
+ seconds_since_rate_limit_error
254
+ < seconds_to_pause_after_rate_limit_error
255
+ ):
256
+ remaining_seconds_to_pause = (
257
+ seconds_to_pause_after_rate_limit_error
258
+ - seconds_since_rate_limit_error
259
+ )
260
+ await asyncio.sleep(remaining_seconds_to_pause)
261
+ # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago
262
+ logging.warning(
263
+ f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
264
+ )
265
+
266
+ # after finishing, log final status
267
+ logging.info(
268
+ f"""Parallel processing complete. Results saved to {save_filepath}"""
269
+ )
270
+ if status_tracker.num_tasks_failed > 0:
271
+ logging.warning(
272
+ f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. Errors logged to {save_filepath}."
273
+ )
274
+ if status_tracker.num_rate_limit_errors > 0:
275
+ logging.warning(
276
+ f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
277
+ )
278
+
279
+
280
+ # dataclasses
281
+
282
+
283
+ @dataclass
284
+ class StatusTracker:
285
+ """Stores metadata about the script's progress. Only one instance is created."""
286
+
287
+ num_tasks_started: int = 0
288
+ num_tasks_in_progress: int = 0 # script ends when this reaches 0
289
+ num_tasks_succeeded: int = 0
290
+ num_tasks_failed: int = 0
291
+ num_rate_limit_errors: int = 0
292
+ num_api_errors: int = 0 # excluding rate limit errors, counted above
293
+ num_other_errors: int = 0
294
+ time_of_last_rate_limit_error: int = 0 # used to cool off after hitting rate limits
295
+
296
+
297
+ @dataclass
298
+ class APIRequest:
299
+ """Stores an API request's inputs, outputs, and other metadata. Contains a method to make an API call."""
300
+
301
+ task_id: int
302
+ request_json: dict
303
+ token_consumption: int
304
+ attempts_left: int
305
+ metadata: dict
306
+ result: list = field(default_factory=list)
307
+
308
+ async def call_api(
309
+ self,
310
+ session: aiohttp.ClientSession,
311
+ request_url: str,
312
+ request_header: dict,
313
+ retry_queue: asyncio.Queue,
314
+ save_filepath: str,
315
+ status_tracker: StatusTracker,
316
+ ):
317
+ """Calls the OpenAI API and saves results."""
318
+ logging.info(f"Starting request #{self.task_id}")
319
+ error = None
320
+ try:
321
+ async with session.post(
322
+ url=request_url, headers=request_header, json=self.request_json
323
+ ) as response:
324
+ response = await response.json()
325
+ if "error" in response:
326
+ logging.warning(
327
+ f"Request {self.task_id} failed with error {response['error']}"
328
+ )
329
+ status_tracker.num_api_errors += 1
330
+ error = response
331
+ if "Rate limit" in response["error"].get("message", ""):
332
+ status_tracker.time_of_last_rate_limit_error = time.time()
333
+ status_tracker.num_rate_limit_errors += 1
334
+ status_tracker.num_api_errors -= (
335
+ 1 # rate limit errors are counted separately
336
+ )
337
+
338
+ except (
339
+ Exception
340
+ ) as e: # catching naked exceptions is bad practice, but in this case we'll log & save them
341
+ logging.warning(f"Request {self.task_id} failed with Exception {e}")
342
+ status_tracker.num_other_errors += 1
343
+ error = e
344
+ if error:
345
+ self.result.append(error)
346
+ if self.attempts_left:
347
+ retry_queue.put_nowait(self)
348
+ else:
349
+ logging.error(
350
+ f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}"
351
+ )
352
+ data = (
353
+ [self.request_json, [str(e) for e in self.result], self.metadata]
354
+ if self.metadata
355
+ else [self.request_json, [str(e) for e in self.result]]
356
+ )
357
+ append_to_jsonl(data, save_filepath)
358
+ status_tracker.num_tasks_in_progress -= 1
359
+ status_tracker.num_tasks_failed += 1
360
+ else:
361
+ data = (
362
+ [self.request_json, response, self.metadata]
363
+ if self.metadata
364
+ else [self.request_json, response]
365
+ )
366
+ append_to_jsonl(data, save_filepath)
367
+ status_tracker.num_tasks_in_progress -= 1
368
+ status_tracker.num_tasks_succeeded += 1
369
+ logging.debug(f"Request {self.task_id} saved to {save_filepath}")
370
+
371
+
372
+ # functions
373
+
374
+
375
+ def api_endpoint_from_url(request_url):
376
+ print(request_url)
377
+ """Extract the API endpoint from the request URL."""
378
+ match = re.search(r"^https://[^/]+/v1/(.+)$", request_url)
379
+ if match:
380
+ return match[1]
381
+ else:
382
+ raise ValueError(f"Invalid API URL: {request_url}")
383
+
384
+
385
+ def append_to_jsonl(data, filename: str) -> None:
386
+ """Append a json payload to the end of a jsonl file."""
387
+ json_string = json.dumps(data)
388
+ with open(filename, "a") as f:
389
+ f.write(json_string + "\n")
390
+
391
+
392
+ def num_tokens_consumed_from_request(
393
+ request_json: dict,
394
+ api_endpoint: str,
395
+ token_encoding_name: str,
396
+ ):
397
+ encoding = tiktoken.get_encoding(token_encoding_name)
398
+ print(api_endpoint)
399
+ if api_endpoint == "messages":
400
+ num_tokens = 0
401
+ for message in request_json["messages"]:
402
+ num_tokens += len(encoding.encode(message["content"]))
403
+ return num_tokens
404
+ else:
405
+ raise NotImplementedError(
406
+ f'API endpoint "{api_endpoint}" not implemented in this script'
407
+ )
408
+
409
+ def task_id_generator_function():
410
+ """Generate integers 0, 1, 2, and so on."""
411
+ task_id = 0
412
+ while True:
413
+ yield task_id
414
+ task_id += 1
415
+
416
+
417
+ # run script
418
+
419
+
420
+ if __name__ == "__main__":
421
+ # parse command line arguments
422
+ parser = argparse.ArgumentParser()
423
+ parser.add_argument("--requests_filepath")
424
+ parser.add_argument("--save_filepath", default=None)
425
+ parser.add_argument("--request_url", default="https://api.openai.com/v1/embeddings")
426
+ parser.add_argument("--api_key", default=os.getenv("ANTHROPIC_API_KEY"))
427
+ parser.add_argument("--max_requests_per_minute", type=int, default=3_000 * 0.5)
428
+ parser.add_argument("--max_tokens_per_minute", type=int, default=250_000 * 0.5)
429
+ parser.add_argument("--token_encoding_name", default="cl100k_base")
430
+ parser.add_argument("--max_attempts", type=int, default=5)
431
+ parser.add_argument("--logging_level", default=logging.INFO)
432
+ args = parser.parse_args()
433
+
434
+ if args.save_filepath is None:
435
+ args.save_filepath = args.requests_filepath.replace(".jsonl", "_results.jsonl")
436
+
437
+ # run script
438
+ asyncio.run(
439
+ process_api_requests_from_file(
440
+ requests_filepath=args.requests_filepath,
441
+ save_filepath=args.save_filepath,
442
+ request_url=args.request_url,
443
+ api_key=args.api_key,
444
+ max_requests_per_minute=float(args.max_requests_per_minute),
445
+ max_tokens_per_minute=float(args.max_tokens_per_minute),
446
+ token_encoding_name=args.token_encoding_name,
447
+ max_attempts=int(args.max_attempts),
448
+ logging_level=int(args.logging_level),
449
+ )
450
+ )
451
+
452
+
453
+ """
454
+ APPENDIX
455
+
456
+ The example requests file at openai-cookbook/examples/data/example_requests_to_parallel_process.jsonl contains 10,000 requests to text-embedding-ada-002.
457
+
458
+ It was generated with the following code:
459
+
460
+ ```python
461
+ import json
462
+
463
+ filename = "data/example_requests_to_parallel_process.jsonl"
464
+ n_requests = 10_000
465
+ jobs = [{"model": "text-embedding-ada-002", "input": str(x) + "\n"} for x in range(n_requests)]
466
+ with open(filename, "w") as f:
467
+ for job in jobs:
468
+ json_string = json.dumps(job)
469
+ f.write(json_string + "\n")
470
+ ```
471
+
472
+ As with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically).
473
+ """
api_request_parallel_processor_universal.py ADDED
@@ -0,0 +1,492 @@
1
+ """
2
+ API REQUEST PARALLEL PROCESSOR
3
+
4
+ Using the OpenAI API to process lots of text quickly takes some care.
5
+ If you trickle in a million API requests one by one, they'll take days to complete.
6
+ If you flood a million API requests in parallel, they'll exceed the rate limits and fail with errors.
7
+ To maximize throughput, parallel requests need to be throttled to stay under rate limits.
8
+
9
+ This script parallelizes requests to the OpenAI API while throttling to stay under rate limits.
10
+
11
+ Features:
12
+ - Streams requests from file, to avoid running out of memory for giant jobs
13
+ - Makes requests concurrently, to maximize throughput
14
+ - Throttles request and token usage, to stay under rate limits
15
+ - Retries failed requests up to {max_attempts} times, to avoid missing data
16
+ - Logs errors, to diagnose problems with requests
17
+
18
+ Example command to call script:
19
+ ```
20
+ python examples/api_request_parallel_processor.py \
21
+ --requests_filepath examples/data/example_requests_to_parallel_process.jsonl \
22
+ --save_filepath examples/data/example_requests_to_parallel_process_results.jsonl \
23
+ --request_url https://api.openai.com/v1/embeddings \
24
+ --max_requests_per_minute 1500 \
25
+ --max_tokens_per_minute 6250000 \
26
+ --token_encoding_name cl100k_base \
27
+ --max_attempts 5 \
28
+ --logging_level 20
29
+ ```
30
+
31
+ Inputs:
32
+ - requests_filepath : str
33
+ - path to the file containing the requests to be processed
34
+ - file should be a jsonl file, where each line is a json object with API parameters and an optional metadata field
35
+ - e.g., {"model": "text-embedding-ada-002", "input": "embed me", "metadata": {"row_id": 1}}
36
+ - as with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically)
37
+ - an example file is provided at examples/data/example_requests_to_parallel_process.jsonl
38
+ - the code to generate the example file is appended to the bottom of this script
39
+ - save_filepath : str, optional
40
+ - path to the file where the results will be saved
41
+ - file will be a jsonl file, where each line is an array with the original request plus the API response
42
+ - e.g., [{"model": "text-embedding-ada-002", "input": "embed me"}, {...}]
43
+ - if omitted, results will be saved to {requests_filename}_results.jsonl
44
+ - request_url : str, optional
45
+ - URL of the API endpoint to call
46
+ - if omitted, will default to "https://api.openai.com/v1/embeddings"
47
+ - api_key : str, optional
48
+ - API key to use
49
+ - if omitted, the script will attempt to read it from an environment variable {os.getenv("OPENAI_API_KEY")}
50
+ - max_requests_per_minute : float, optional
51
+ - target number of requests to make per minute (will make less if limited by tokens)
52
+ - leave headroom by setting this to 50% or 75% of your limit
53
+ - if requests are limiting you, try batching multiple embeddings or completions into one request
54
+ - if omitted, will default to 1,500
55
+ - max_tokens_per_minute : float, optional
56
+ - target number of tokens to use per minute (will use less if limited by requests)
57
+ - leave headroom by setting this to 50% or 75% of your limit
58
+ - if omitted, will default to 125,000
59
+ - token_encoding_name : str, optional
60
+ - name of the token encoding used, as defined in the `tiktoken` package
61
+ - if omitted, will default to "cl100k_base" (used by `text-embedding-ada-002`)
62
+ - max_attempts : int, optional
63
+ - number of times to retry a failed request before giving up
64
+ - if omitted, will default to 5
65
+ - logging_level : int, optional
66
+ - level of logging to use; higher numbers will log fewer messages
67
+ - 40 = ERROR; will log only when requests fail after all retries
68
+ - 30 = WARNING; will log when requests hit rate limits or other errors
69
+ - 20 = INFO; will log when requests start and the status at finish
70
+ - 10 = DEBUG; will log various things as the loop runs to see when they occur
71
+ - if omitted, will default to 20 (INFO).
72
+
73
+ The script is structured as follows:
74
+ - Imports
75
+ - Define main()
76
+ - Initialize things
77
+ - In main loop:
78
+ - Get next request if one is not already waiting for capacity
79
+ - Update available token & request capacity
80
+ - If enough capacity available, call API
81
+ - The loop pauses if a rate limit error is hit
82
+ - The loop breaks when no tasks remain
83
+ - Define dataclasses
84
+ - StatusTracker (stores script metadata counters; only one instance is created)
85
+ - APIRequest (stores API inputs, outputs, metadata; one method to call API)
86
+ - Define functions
87
+ - api_endpoint_from_url (extracts API endpoint from request URL)
88
+ - append_to_jsonl (writes to results file)
89
+ - num_tokens_consumed_from_request (bigger function to infer token usage from request)
90
+ - task_id_generator_function (yields 0, 1, 2, ...)
91
+ - Run main()
92
+ """
93
+
94
+ # imports
95
+ import aiohttp # for making API calls concurrently
96
+ import argparse # for running script from command line
97
+ import asyncio # for running API calls concurrently
98
+ import json # for saving results to a jsonl file
99
+ import logging # for logging rate limit warnings and other messages
100
+ import os # for reading API key
101
+ import re # for matching endpoint from request URL
102
+ import tiktoken # for counting tokens
103
+ import time # for sleeping after rate limit is hit
104
+ from dataclasses import (
105
+ dataclass,
106
+ field,
107
+ ) # for storing API inputs, outputs, and metadata
108
+
109
+
110
+ async def process_api_requests_from_file(
111
+ vendor_name: str,
112
+ requests_filepath: str,
113
+ save_filepath: str,
114
+ request_url: str,
115
+ api_key: str,
116
+ max_requests_per_minute: float,
117
+ max_tokens_per_minute: float,
118
+ token_encoding_name: str,
119
+ max_attempts: int,
120
+ logging_level: int,
121
+ ):
122
+ """Processes API requests in parallel, throttling to stay under rate limits."""
123
+ # constants
124
+ seconds_to_pause_after_rate_limit_error = 15
125
+ seconds_to_sleep_each_loop = (
126
+ 0.001 # 1 ms limits max throughput to 1,000 requests per second
127
+ )
128
+
129
+ # initialize logging
130
+ logging.basicConfig(level=logging_level)
131
+ logging.debug(f"Logging initialized at level {logging_level}")
132
+
133
+ # infer API endpoint and construct request header
134
+ api_endpoint = api_endpoint_from_url(request_url, vendor_name)
135
+ request_header=None
136
+ if vendor_name=="openai":
137
+ request_header = {"Authorization": f"Bearer {api_key}"}
138
+ elif vendor_name=="anthropic":
139
+ request_header = {
140
+ "x-api-key": api_key,
141
+ "anthropic-version": "2023-06-01",
142
+ "content-type": "application/json",
143
+ }
144
+ elif vendor_name == "meta" or vendor_name == "google" :
145
+ request_header = {
146
+ "Content-Type": "application/json",
147
+ "Authorization": f"Bearer {api_key}",
148
+ }
149
+ else:
150
+ print("Error. Invalid Model Input. Exiting")
151
+ # exit()
152
+
153
+ # use api-key header for Azure deployments
154
+ # if '/deployments' in request_url:
155
+ # request_header = {"api-key": f"{api_key}"}
156
+
157
+ # initialize trackers
158
+ queue_of_requests_to_retry = asyncio.Queue()
159
+ task_id_generator = (
160
+ task_id_generator_function()
161
+ ) # generates integer IDs of 0, 1, 2, ...
162
+ status_tracker = (
163
+ StatusTracker()
164
+ ) # single instance to track a collection of variables
165
+ next_request = None # variable to hold the next request to call
166
+
167
+ # initialize available capacity counts
168
+ available_request_capacity = max_requests_per_minute
169
+ available_token_capacity = max_tokens_per_minute
170
+ last_update_time = time.time()
171
+
172
+ # initialize flags
173
+ file_not_finished = True # after file is empty, we'll skip reading it
174
+ logging.debug(f"Initialization complete.")
175
+
176
+ # initialize file reading
177
+ with open(requests_filepath) as file:
178
+ # `requests` will provide requests one at a time
179
+ requests = file.__iter__()
180
+ logging.debug(f"File opened. Entering main loop")
181
+ async with aiohttp.ClientSession() as session: # Initialize ClientSession here
182
+ while True:
183
+ # get next request (if one is not already waiting for capacity)
184
+ if next_request is None:
185
+ if not queue_of_requests_to_retry.empty():
186
+ next_request = queue_of_requests_to_retry.get_nowait()
187
+ logging.debug(
188
+ f"Retrying request {next_request.task_id}: {next_request}"
189
+ )
190
+ elif file_not_finished:
191
+ try:
192
+ # get new request
193
+ request_json = json.loads(next(requests))
194
+ next_request = APIRequest(
195
+ task_id=next(task_id_generator),
196
+ request_json=request_json,
197
+ # token_consumption=num_tokens_consumed_from_request(
198
+ # request_json, api_endpoint, token_encoding_name
199
+ # ), # just disabled the tokens consumption. Not worried about that.
200
+ token_consumption=0,
201
+ attempts_left=max_attempts,
202
+ metadata=request_json.pop("metadata", None),
203
+ )
204
+ status_tracker.num_tasks_started += 1
205
+ status_tracker.num_tasks_in_progress += 1
206
+ logging.debug(
207
+ f"Reading request {next_request.task_id}: {next_request}"
208
+ )
209
+ except StopIteration:
210
+ # if file runs out, set flag to stop reading it
211
+ logging.debug("Read file exhausted")
212
+ file_not_finished = False
213
+
214
+ # update available capacity
215
+ current_time = time.time()
216
+ seconds_since_update = current_time - last_update_time
217
+ available_request_capacity = min(
218
+ available_request_capacity
219
+ + max_requests_per_minute * seconds_since_update / 60.0,
220
+ max_requests_per_minute,
221
+ )
222
+ available_token_capacity = min(
223
+ available_token_capacity
224
+ + max_tokens_per_minute * seconds_since_update / 60.0,
225
+ max_tokens_per_minute,
226
+ )
227
+ last_update_time = current_time
228
+
229
+ # if enough capacity available, call API
230
+ if next_request:
231
+ next_request_tokens = next_request.token_consumption
232
+ if (
233
+ available_request_capacity >= 1
234
+ and available_token_capacity >= next_request_tokens
235
+ ):
236
+ # update counters
237
+ available_request_capacity -= 1
238
+ available_token_capacity -= next_request_tokens
239
+ next_request.attempts_left -= 1
240
+
241
+ # call API
242
+ asyncio.create_task(
243
+ next_request.call_api(
244
+ session=session,
245
+ request_url=request_url,
246
+ request_header=request_header,
247
+ retry_queue=queue_of_requests_to_retry,
248
+ save_filepath=save_filepath,
249
+ status_tracker=status_tracker,
250
+ )
251
+ )
252
+ next_request = None # reset next_request to empty
253
+
254
+ # if all tasks are finished, break
255
+ if status_tracker.num_tasks_in_progress == 0:
256
+ break
257
+
258
+ # main loop sleeps briefly so concurrent tasks can run
259
+ await asyncio.sleep(seconds_to_sleep_each_loop)
260
+
261
+ # if a rate limit error was hit recently, pause to cool down
262
+ seconds_since_rate_limit_error = (
263
+ time.time() - status_tracker.time_of_last_rate_limit_error
264
+ )
265
+ if (
266
+ seconds_since_rate_limit_error
267
+ < seconds_to_pause_after_rate_limit_error
268
+ ):
269
+ remaining_seconds_to_pause = (
270
+ seconds_to_pause_after_rate_limit_error
271
+ - seconds_since_rate_limit_error
272
+ )
273
+ await asyncio.sleep(remaining_seconds_to_pause)
274
+ # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago
275
+ logging.warning(
276
+ f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
277
+ )
278
+
279
+ # after finishing, log final status
280
+ logging.info(
281
+ f"""Parallel processing complete. Results saved to {save_filepath}"""
282
+ )
283
+ if status_tracker.num_tasks_failed > 0:
284
+ logging.warning(
285
+ f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. Errors logged to {save_filepath}."
286
+ )
287
+ if status_tracker.num_rate_limit_errors > 0:
288
+ logging.warning(
289
+ f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
290
+ )
291
+
292
+
293
+ # dataclasses
294
+
295
+
296
+ @dataclass
297
+ class StatusTracker:
298
+ """Stores metadata about the script's progress. Only one instance is created."""
299
+
300
+ num_tasks_started: int = 0
301
+ num_tasks_in_progress: int = 0 # script ends when this reaches 0
302
+ num_tasks_succeeded: int = 0
303
+ num_tasks_failed: int = 0
304
+ num_rate_limit_errors: int = 0
305
+ num_api_errors: int = 0 # excluding rate limit errors, counted above
306
+ num_other_errors: int = 0
307
+ time_of_last_rate_limit_error: int = 0 # used to cool off after hitting rate limits
308
+
309
+
310
+ @dataclass
311
+ class APIRequest:
312
+ """Stores an API request's inputs, outputs, and other metadata. Contains a method to make an API call."""
313
+
314
+ task_id: int
315
+ request_json: dict
316
+ token_consumption: int
317
+ attempts_left: int
318
+ metadata: dict
319
+ result: list = field(default_factory=list)
320
+
321
+ async def call_api(
322
+ self,
323
+ session: aiohttp.ClientSession,
324
+ request_url: str,
325
+ request_header: dict,
326
+ retry_queue: asyncio.Queue,
327
+ save_filepath: str,
328
+ status_tracker: StatusTracker,
329
+ ):
330
+ """Calls the OpenAI API and saves results."""
331
+ logging.info(f"Starting request #{self.task_id}")
332
+ error = None
333
+ try:
334
+ async with session.post(
335
+ url=request_url, headers=request_header, json=self.request_json
336
+ ) as response:
337
+ response = await response.json()
338
+ if "error" in response:
339
+ logging.warning(
340
+ f"Request {self.task_id} failed with error {response['error']}"
341
+ )
342
+ status_tracker.num_api_errors += 1
343
+ error = response
344
+ if "Rate limit" in response["error"].get("message", ""):
345
+ status_tracker.time_of_last_rate_limit_error = time.time()
346
+ status_tracker.num_rate_limit_errors += 1
347
+ status_tracker.num_api_errors -= (
348
+ 1 # rate limit errors are counted separately
349
+ )
350
+
351
+ except (
352
+ Exception
353
+ ) as e: # catching naked exceptions is bad practice, but in this case we'll log & save them
354
+ logging.warning(f"Request {self.task_id} failed with Exception {e}")
355
+ status_tracker.num_other_errors += 1
356
+ error = e
357
+ if error:
358
+ self.result.append(error)
359
+ if self.attempts_left:
360
+ retry_queue.put_nowait(self)
361
+ else:
362
+ logging.error(
363
+ f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}"
364
+ )
365
+ data = (
366
+ [self.request_json, [str(e) for e in self.result], self.metadata]
367
+ if self.metadata
368
+ else [self.request_json, [str(e) for e in self.result]]
369
+ )
370
+ append_to_jsonl(data, save_filepath)
371
+ status_tracker.num_tasks_in_progress -= 1
372
+ status_tracker.num_tasks_failed += 1
373
+ else:
374
+ data = (
375
+ [self.request_json, response, self.metadata]
376
+ if self.metadata
377
+ else [self.request_json, response]
378
+ )
379
+ append_to_jsonl(data, save_filepath)
380
+ status_tracker.num_tasks_in_progress -= 1
381
+ status_tracker.num_tasks_succeeded += 1
382
+ logging.debug(f"Request {self.task_id} saved to {save_filepath}")
383
+
384
+
385
+ # functions
386
+
387
+
388
+ def api_endpoint_from_url(request_url, vendor_name):
389
+ """Extract the API endpoint from the request URL."""
390
+ match=None
391
+ if vendor_name=="openai":
392
+ match = re.search("^https://[^/]+/v\\d+/(.+)$", request_url)
393
+ elif vendor_name=="anthropic":
394
+ match = re.search(r"^https://[^/]+/v1/(.+)$", request_url)
395
+ elif vendor_name == "meta" or vendor_name == "google":
396
+ match = re.search(r"^https://[^/]+/api/v1/(.+)$", request_url)
397
+ else:
398
+ print("Error. Invalid Model Input. Exiting")
399
+ # exit()
400
+ if match is None:
401
+ # for Azure OpenAI deployment urls
402
+ match = re.search(r"^https://[^/]+/openai/deployments/[^/]+/(.+?)(\?|$)", request_url)
403
+ return match[1]
404
+
405
+
406
+ def append_to_jsonl(data, filename: str) -> None:
407
+ """Append a json payload to the end of a jsonl file."""
408
+ json_string = json.dumps(data)
409
+ with open(filename, "a") as f:
410
+ f.write(json_string + "\n")
411
+
412
+
413
+
414
+ def task_id_generator_function():
415
+ """Generate integers 0, 1, 2, and so on."""
416
+ task_id = 0
417
+ while True:
418
+ yield task_id
419
+ task_id += 1
420
+
421
+
422
+ # run script
423
+
424
+
425
+ if __name__ == "__main__":
426
+ # parse command line arguments
427
+ parser = argparse.ArgumentParser()
428
+ parser.add_argument("--vendor_name", default=None)
429
+ parser.add_argument("--requests_filepath")
430
+ parser.add_argument("--save_filepath", default=None)
431
+ parser.add_argument("--request_url", default=None)
432
+ parser.add_argument("--api_key", default=None)
433
+ parser.add_argument("--max_requests_per_minute", type=int, default=3_000 * 0.5)
434
+ parser.add_argument("--max_tokens_per_minute", type=int, default=250_000 * 0.5)
435
+ parser.add_argument("--token_encoding_name", default="cl100k_base")
436
+ parser.add_argument("--max_attempts", type=int, default=5)
437
+ parser.add_argument("--logging_level", default=logging.INFO)
438
+
439
+ args = parser.parse_args()
440
+ if args.vendor_name=="openai":
441
+ args.api_key=os.getenv("OPENAI_API_KEY")
442
+ args.request_url="https://api.openai.com/v1/chat/completions"
443
+ elif args.vendor_name=="anthropic":
444
+ args.api_key=os.getenv("ANTHROPIC_API_KEY")
445
+ args.request_url="https://api.anthropic.com/v1/messages"
446
+ elif args.vendor_name == "meta" or args.vendor_name == "google" :
447
+ args.api_key = os.getenv("OPENROUTER_API_KEY")
448
+ args.request_url = "https://openrouter.ai/api/v1/chat/completions"
449
+ else:
450
+ print("Error. Invalid Model Input. Exiting")
451
+ # exit()
452
+ if args.save_filepath is None:
453
+ args.save_filepath = args.requests_filepath.replace(".jsonl", "_results.jsonl")
454
+
455
+ # run script
456
+ asyncio.run(
457
+ process_api_requests_from_file(
458
+ vendor_name=args.vendor_name,
459
+ requests_filepath=args.requests_filepath,
460
+ save_filepath=args.save_filepath,
461
+ request_url=args.request_url,
462
+ api_key=args.api_key,
463
+ max_requests_per_minute=float(args.max_requests_per_minute),
464
+ max_tokens_per_minute=float(args.max_tokens_per_minute),
465
+ token_encoding_name=args.token_encoding_name,
466
+ max_attempts=int(args.max_attempts),
467
+ logging_level=int(args.logging_level),
468
+ )
469
+ )
470
+
471
+
472
+ """
473
+ APPENDIX
474
+
475
+ The example requests file at openai-cookbook/examples/data/example_requests_to_parallel_process.jsonl contains 10,000 requests to text-embedding-ada-002.
476
+
477
+ It was generated with the following code:
478
+
479
+ ```python
480
+ import json
481
+
482
+ filename = "data/example_requests_to_parallel_process.jsonl"
483
+ n_requests = 10_000
484
+ jobs = [{"model": "text-embedding-ada-002", "input": str(x) + "\n"} for x in range(n_requests)]
485
+ with open(filename, "w") as f:
486
+ for job in jobs:
487
+ json_string = json.dumps(job)
488
+ f.write(json_string + "\n")
489
+ ```
490
+
491
+ As with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically).
492
+ """
api_request_parallel_processor_universal_SEQUENTIAL.py ADDED
@@ -0,0 +1,420 @@
1
+ """
2
+ API REQUEST PROCESSOR (SEQUENTIAL VARIANT)
3
+
4
+ Using the OpenAI API to process lots of text quickly takes some care.
5
+ If you trickle in a million API requests one by one, they'll take days to complete.
6
+ If you flood a million API requests in parallel, they'll exceed the rate limits and fail with errors.
7
+ To maximize throughput, parallel requests need to be throttled to stay under rate limits.
8
+
9
+ This script processes requests to the API sequentially, one at a time, retrying failed requests up to {max_attempts} times.
10
+
11
+ Features:
12
+ - Streams requests from file, to avoid running out of memory for giant jobs
13
+ - Makes requests one at a time, trading throughput for simplicity
14
+ - Throttles request and token usage, to stay under rate limits
15
+ - Retries failed requests up to {max_attempts} times, to avoid missing data
16
+ - Logs errors, to diagnose problems with requests
17
+
18
+ Example command to call script:
19
+ ```
20
+ python examples/api_request_parallel_processor.py \
21
+ --requests_filepath examples/data/example_requests_to_parallel_process.jsonl \
22
+ --save_filepath examples/data/example_requests_to_parallel_process_results.jsonl \
23
+ --request_url https://api.openai.com/v1/embeddings \
24
+ --max_requests_per_minute 1500 \
25
+ --max_tokens_per_minute 6250000 \
26
+ --token_encoding_name cl100k_base \
27
+ --max_attempts 5 \
28
+ --logging_level 20
29
+ ```
30
+
31
+ Inputs:
32
+ - requests_filepath : str
33
+ - path to the file containing the requests to be processed
34
+ - file should be a jsonl file, where each line is a json object with API parameters and an optional metadata field
35
+ - e.g., {"model": "text-embedding-ada-002", "input": "embed me", "metadata": {"row_id": 1}}
36
+ - as with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically)
37
+ - an example file is provided at examples/data/example_requests_to_parallel_process.jsonl
38
+ - the code to generate the example file is appended to the bottom of this script
39
+ - save_filepath : str, optional
40
+ - path to the file where the results will be saved
41
+ - file will be a jsonl file, where each line is an array with the original request plus the API response
42
+ - e.g., [{"model": "text-embedding-ada-002", "input": "embed me"}, {...}]
43
+ - if omitted, results will be saved to {requests_filename}_results.jsonl
44
+ - request_url : str, optional
45
+ - URL of the API endpoint to call
46
+ - if omitted, will default to "https://api.openai.com/v1/embeddings"
47
+ - api_key : str, optional
48
+ - API key to use
49
+ - if omitted, the script will attempt to read it from an environment variable {os.getenv("OPENAI_API_KEY")}
50
+ - max_requests_per_minute : float, optional
51
+ - target number of requests to make per minute (will make fewer if limited by tokens)
52
+ - leave headroom by setting this to 50% or 75% of your limit
53
+ - if requests are limiting you, try batching multiple embeddings or completions into one request
54
+ - if omitted, will default to 1,500
55
+ - max_tokens_per_minute : float, optional
56
+ - target number of tokens to use per minute (will use less if limited by requests)
57
+ - leave headroom by setting this to 50% or 75% of your limit
58
+ - if omitted, will default to 125,000
59
+ - token_encoding_name : str, optional
60
+ - name of the token encoding used, as defined in the `tiktoken` package
61
+ - if omitted, will default to "cl100k_base" (used by `text-embedding-ada-002`)
62
+ - max_attempts : int, optional
63
+ - number of times to retry a failed request before giving up
64
+ - if omitted, will default to 5
65
+ - logging_level : int, optional
66
+ - level of logging to use; higher numbers will log fewer messages
67
+ - 40 = ERROR; will log only when requests fail after all retries
68
+ - 30 = WARNING; will log when requests hit rate limits or other errors
69
+ - 20 = INFO; will log when requests start and the status at finish
70
+ - 10 = DEBUG; will log various things as the loop runs to see when they occur
71
+ - if omitted, will default to 20 (INFO).
72
+
73
+ The script is structured as follows:
74
+ - Imports
75
+ - Define main()
76
+ - Initialize things
77
+ - In main loop:
78
+ - Get next request if one is not already waiting for capacity
79
+ - Update available token & request capacity
80
+ - If enough capacity available, call API
81
+ - The loop pauses if a rate limit error is hit
82
+ - The loop breaks when no tasks remain
83
+ - Define dataclasses
84
+ - StatusTracker (stores script metadata counters; only one instance is created)
85
+ - APIRequest (stores API inputs, outputs, metadata; one method to call API)
86
+ - Define functions
87
+ - api_endpoint_from_url (extracts API endpoint from request URL)
88
+ - append_to_jsonl (writes to results file)
89
+ - num_tokens_consumed_from_request (bigger function to infer token usage from request)
90
+ - task_id_generator_function (yields 0, 1, 2, ...)
91
+ - Run main()
92
+ """
93
+
94
+ # imports
95
+ import aiohttp # for making API calls concurrently
96
+ import argparse # for running script from command line
97
+ import asyncio # for running API calls concurrently
98
+ import json # for saving results to a jsonl file
99
+ import logging # for logging rate limit warnings and other messages
100
+ import os # for reading API key
101
+ import re # for matching endpoint from request URL
102
+ import requests # for making synchronous HTTP calls
+ import tiktoken # for counting tokens
103
+ import time # for sleeping after rate limit is hit
104
+ from dataclasses import (
105
+ dataclass,
106
+ field,
107
+ ) # for storing API inputs, outputs, and metadata
108
+
109
+
110
+ def process_api_requests_from_file(
111
+ vendor_name: str,
112
+ requests_filepath: str,
113
+ save_filepath: str,
114
+ request_url: str,
115
+ api_key: str,
116
+ max_requests_per_minute: float,
117
+ max_tokens_per_minute: float,
118
+ token_encoding_name: str,
119
+ max_attempts: int,
120
+ logging_level: int,
121
+ ):
122
+ """Processes API requests sequentially."""
123
+ # initialize logging
124
+ logging.basicConfig(level=logging_level)
125
+ logging.debug(f"Logging initialized at level {logging_level}")
126
+
127
+ # infer API endpoint and construct request header
128
+ api_endpoint = api_endpoint_from_url(request_url, vendor_name)
129
+ request_header = None
130
+ if vendor_name == "openai":
131
+ request_header = {"Authorization": f"Bearer {api_key}"}
132
+ elif vendor_name == "anthropic":
133
+ request_header = {
134
+ "x-api-key": api_key,
135
+ "anthropic-version": "2023-06-01",
136
+ "content-type": "application/json",
137
+ }
138
+ elif vendor_name == "meta" or vendor_name == "google":
139
+ request_header = {
140
+ "Content-Type": "application/json",
141
+ "Authorization": f"Bearer {api_key}",
142
+ }
143
+ else:
144
+ print("Error. Invalid Model Input. Exiting")
145
+
146
+ # initialize trackers
147
+ task_id_generator = task_id_generator_function()
148
+ status_tracker = StatusTracker()
149
+
150
+ # process requests sequentially
151
+ with open(requests_filepath) as file, requests.Session() as session:
152
+ for line in file:
153
+ request_json = json.loads(line)
154
+ request = APIRequest(
155
+ task_id=next(task_id_generator),
156
+ request_json=request_json,
157
+ token_consumption=0,
158
+ attempts_left=max_attempts,
159
+ metadata=request_json.pop("metadata", None),
160
+ )
161
+ status_tracker.num_tasks_started += 1
162
+ logging.debug(f"Processing request {request.task_id}: {request}")
163
+
164
+ while request.attempts_left > 0:
165
+ error = None
166
+ try:
167
+ response = session.post(
168
+ url=request_url,
169
+ headers=request_header,
170
+ json=request.request_json,
171
+ ).json()
172
+ if "error" in response:
173
+ logging.warning(
174
+ f"Request {request.task_id} failed with error {response['error']}"
175
+ )
176
+ status_tracker.num_api_errors += 1
177
+ error = response
178
+ if "Rate limit" in response["error"].get("message", ""):
179
+ status_tracker.num_rate_limit_errors += 1
180
+ status_tracker.num_api_errors -= 1
181
+ except Exception as e:
182
+ logging.warning(f"Request {request.task_id} failed with Exception {e}")
183
+ status_tracker.num_other_errors += 1
184
+ error = e
185
+
186
+ if error:
187
+ request.result.append(error)
188
+ request.attempts_left -= 1
189
+ if request.attempts_left == 0:
190
+ logging.error(
191
+ f"Request {request.request_json} failed after all attempts. Saving errors: {request.result}"
192
+ )
193
+ data = (
194
+ [request.request_json, [str(e) for e in request.result], request.metadata]
195
+ if request.metadata
196
+ else [request.request_json, [str(e) for e in request.result]]
197
+ )
198
+ append_to_jsonl(data, save_filepath)
199
+ status_tracker.num_tasks_failed += 1
200
+ else:
201
+ data = (
202
+ [request.request_json, response, request.metadata]
203
+ if request.metadata
204
+ else [request.request_json, response]
205
+ )
206
+ append_to_jsonl(data, save_filepath)
207
+ status_tracker.num_tasks_succeeded += 1
208
+ logging.debug(f"Request {request.task_id} saved to {save_filepath}")
209
+ break
210
+
211
+ # after finishing, log final status
212
+ logging.info(f"Sequential processing complete. Results saved to {save_filepath}")
213
+ if status_tracker.num_tasks_failed > 0:
214
+ logging.warning(
215
+ f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. Errors logged to {save_filepath}."
216
+ )
217
+ if status_tracker.num_rate_limit_errors > 0:
218
+ logging.warning(
219
+ f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
220
+ )
221
+ # dataclasses
222
+
223
+
224
+ @dataclass
225
+ class StatusTracker:
226
+ """Stores metadata about the script's progress. Only one instance is created."""
227
+
228
+ num_tasks_started: int = 0
229
+ num_tasks_in_progress: int = 0 # script ends when this reaches 0
230
+ num_tasks_succeeded: int = 0
231
+ num_tasks_failed: int = 0
232
+ num_rate_limit_errors: int = 0
233
+ num_api_errors: int = 0 # excluding rate limit errors, counted above
234
+ num_other_errors: int = 0
235
+ time_of_last_rate_limit_error: int = 0 # used to cool off after hitting rate limits
236
+
237
+
238
+ @dataclass
239
+ class APIRequest:
240
+ """Stores an API request's inputs, outputs, and other metadata. Contains a method to make an API call."""
241
+
242
+ task_id: int
243
+ request_json: dict
244
+ token_consumption: int
245
+ attempts_left: int
246
+ metadata: dict
247
+ result: list = field(default_factory=list)
248
+
249
+ async def call_api(
250
+ self,
251
+ session: aiohttp.ClientSession,
252
+ request_url: str,
253
+ request_header: dict,
254
+ retry_queue: asyncio.Queue,
255
+ save_filepath: str,
256
+ status_tracker: StatusTracker,
257
+ ):
258
+ """Calls the OpenAI API and saves results."""
259
+ logging.info(f"Starting request #{self.task_id}")
260
+ error = None
261
+ try:
262
+ async with session.post(
263
+ url=request_url, headers=request_header, json=self.request_json
264
+ ) as response:
265
+ response = await response.json()
266
+ if "error" in response:
267
+ logging.warning(
268
+ f"Request {self.task_id} failed with error {response['error']}"
269
+ )
270
+ status_tracker.num_api_errors += 1
271
+ error = response
272
+ if "Rate limit" in response["error"].get("message", ""):
273
+ status_tracker.time_of_last_rate_limit_error = time.time()
274
+ status_tracker.num_rate_limit_errors += 1
275
+ status_tracker.num_api_errors -= (
276
+ 1 # rate limit errors are counted separately
277
+ )
278
+
279
+ except (
280
+ Exception
281
+ ) as e: # catching naked exceptions is bad practice, but in this case we'll log & save them
282
+ logging.warning(f"Request {self.task_id} failed with Exception {e}")
283
+ status_tracker.num_other_errors += 1
284
+ error = e
285
+ if error:
286
+ self.result.append(error)
287
+ if self.attempts_left:
288
+ retry_queue.put_nowait(self)
289
+ else:
290
+ logging.error(
291
+ f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}"
292
+ )
293
+ data = (
294
+ [self.request_json, [str(e) for e in self.result], self.metadata]
295
+ if self.metadata
296
+ else [self.request_json, [str(e) for e in self.result]]
297
+ )
298
+ append_to_jsonl(data, save_filepath)
299
+ status_tracker.num_tasks_in_progress -= 1
300
+ status_tracker.num_tasks_failed += 1
301
+ else:
302
+ data = (
303
+ [self.request_json, response, self.metadata]
304
+ if self.metadata
305
+ else [self.request_json, response]
306
+ )
307
+ append_to_jsonl(data, save_filepath)
308
+ status_tracker.num_tasks_in_progress -= 1
309
+ status_tracker.num_tasks_succeeded += 1
310
+ logging.debug(f"Request {self.task_id} saved to {save_filepath}")
311
+
312
+
313
+ # functions
314
+
315
+
316
+ def api_endpoint_from_url(request_url, vendor_name):
317
+ """Extract the API endpoint from the request URL."""
318
+ match=None
319
+ if vendor_name=="openai":
320
+ match = re.search("^https://[^/]+/v\\d+/(.+)$", request_url)
321
+ elif vendor_name=="anthropic":
322
+ match = re.search(r"^https://[^/]+/v1/(.+)$", request_url)
323
+ elif vendor_name == "meta" or vendor_name == "google":
324
+ match = re.search(r"^https://[^/]+/api/v1/(.+)$", request_url)
325
+ else:
326
+ print("Error. Invalid Model Input. Exiting")
327
+ # exit()
328
+ if match is None:
329
+ # for Azure OpenAI deployment urls
330
+ match = re.search(r"^https://[^/]+/openai/deployments/[^/]+/(.+?)(\?|$)", request_url)
331
+ return match[1]
332
+
333
+
334
+ def append_to_jsonl(data, filename: str) -> None:
335
+ """Append a json payload to the end of a jsonl file."""
336
+ json_string = json.dumps(data)
337
+ with open(filename, "a") as f:
338
+ f.write(json_string + "\n")
339
+
340
+
341
+
342
+ def task_id_generator_function():
343
+ """Generate integers 0, 1, 2, and so on."""
344
+ task_id = 0
345
+ while True:
346
+ yield task_id
347
+ task_id += 1
348
+
349
+
350
+ # run script
351
+
352
+
353
+ if __name__ == "__main__":
354
+ # parse command line arguments
355
+ parser = argparse.ArgumentParser()
356
+ parser.add_argument("--vendor_name", default=None)
357
+ parser.add_argument("--requests_filepath")
358
+ parser.add_argument("--save_filepath", default=None)
359
+ parser.add_argument("--request_url", default=None)
360
+ parser.add_argument("--api_key", default=None)
361
+ parser.add_argument("--max_requests_per_minute", type=int, default=3_000 * 0.5)
362
+ parser.add_argument("--max_tokens_per_minute", type=int, default=250_000 * 0.5)
363
+ parser.add_argument("--token_encoding_name", default="cl100k_base")
364
+ parser.add_argument("--max_attempts", type=int, default=5)
365
+ parser.add_argument("--logging_level", default=logging.INFO)
366
+
367
+ args = parser.parse_args()
368
+ if args.vendor_name=="openai":
369
+ args.api_key=os.getenv("OPENAI_API_KEY")
370
+ args.request_url="https://api.openai.com/v1/chat/completions"
371
+ elif args.vendor_name=="anthropic":
372
+ args.api_key=os.getenv("ANTHROPIC_API_KEY")
373
+ args.request_url="https://api.anthropic.com/v1/messages"
374
+ elif args.vendor_name == "meta" or args.vendor_name == "google" :
375
+ args.api_key = os.getenv("OPENROUTER_API_KEY")
376
+ args.request_url = "https://openrouter.ai/api/v1/chat/completions"
377
+ else:
378
+ print("Error. Invalid Model Input. Exiting")
379
+ # exit()
380
+ if args.save_filepath is None:
381
+ args.save_filepath = args.requests_filepath.replace(".jsonl", "_results.jsonl")
382
+
383
+ # run script
384
+ # this variant is synchronous, so the function is called directly rather than via asyncio.run
+ process_api_requests_from_file(
386
+ vendor_name=args.vendor_name,
387
+ requests_filepath=args.requests_filepath,
388
+ save_filepath=args.save_filepath,
389
+ request_url=args.request_url,
390
+ api_key=args.api_key,
391
+ max_requests_per_minute=float(args.max_requests_per_minute),
392
+ max_tokens_per_minute=float(args.max_tokens_per_minute),
393
+ token_encoding_name=args.token_encoding_name,
394
+ max_attempts=int(args.max_attempts),
395
+ logging_level=int(args.logging_level),
396
+ )
397
398
+
399
+
400
+ """
401
+ APPENDIX
402
+
403
+ The example requests file at openai-cookbook/examples/data/example_requests_to_parallel_process.jsonl contains 10,000 requests to text-embedding-ada-002.
404
+
405
+ It was generated with the following code:
406
+
407
+ ```python
408
+ import json
409
+
410
+ filename = "data/example_requests_to_parallel_process.jsonl"
411
+ n_requests = 10_000
412
+ jobs = [{"model": "text-embedding-ada-002", "input": str(x) + "\n"} for x in range(n_requests)]
413
+ with open(filename, "w") as f:
414
+ for job in jobs:
415
+ json_string = json.dumps(job)
416
+ f.write(json_string + "\n")
417
+ ```
418
+
419
+ As with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically).
420
+ """
app.py ADDED
@@ -0,0 +1,727 @@
1
+ import os
2
+ # https://stackoverflow.com/questions/76175046/how-to-add-prompt-to-langchain-conversationalretrievalchain-chat-over-docs-with
3
+ # again from:
4
+ # https://python.langchain.com/docs/integrations/providers/vectara/vectara_chat
5
+ from langchain.document_loaders import PyPDFDirectoryLoader
6
+ import pandas as pd
7
+ import langchain
8
+ from queue import Queue
9
+ from typing import Any, Optional
10
+ from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
11
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
12
+ from langchain.schema import LLMResult
13
+ from langchain.embeddings import HuggingFaceEmbeddings
14
+ from langchain.vectorstores import FAISS
15
+ from langchain.prompts.prompt import PromptTemplate
16
+ from anyio.from_thread import start_blocking_portal #For model callback streaming
17
+
18
+
19
+ from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
20
+ import os
21
+ from dotenv import load_dotenv
22
+
23
+ import streamlit as st
24
+ import json
25
+ from langchain.document_loaders import PyPDFLoader
26
+ from langchain.text_splitter import CharacterTextSplitter
27
+ from langchain.embeddings import OpenAIEmbeddings
28
+ from langchain.chains.question_answering import load_qa_chain
29
+ from langchain.chat_models import ChatOpenAI
30
+ # from langchain.chat_models import ChatAnthropic
31
+ from langchain_anthropic import ChatAnthropic
32
+ from langchain.vectorstores import Chroma
33
+ import chromadb
34
+
35
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
36
+ from langchain.llms import OpenAI
37
+ from langchain.chains import RetrievalQA
38
+ from langchain.document_loaders import TextLoader
39
+ from langchain.document_loaders import DirectoryLoader
40
+ from langchain_community.document_loaders import PyMuPDFLoader
41
+ from langchain.schema import Document
42
+
43
+ from langchain.memory import ConversationBufferMemory
44
+
45
+ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
46
+ from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
47
+ import gradio as gr
48
+ from langchain.memory import ConversationBufferMemory
49
+ from langchain.chains import ConversationalRetrievalChain
50
+ from langchain.chains import ConversationChain
51
+ from langchain.prompts import PromptTemplate
52
+ from langchain.chains import LLMChain
53
+ print("Started")
54
+
55
+ # Function to map UI region selection to DB metadata region
56
+ # def get_db_region(ui_region: str) -> str:
57
+ # """Maps UI region selection (e.g., 'Iowa') to the region string stored in metadata (e.g., 'United States')."""
58
+ # if ui_region == "Iowa":
59
+ # return "United States"
60
+ # # Add more mappings if needed (e.g., Africa)
61
+ # return ui_region # Default to using the UI string if no specific mapping
62
+
63
+ def get_species_list_from_db(db_name):
64
+ embedding = OpenAIEmbeddings()
65
+ vectordb_temp = Chroma(persist_directory=db_name,
66
+ embedding_function=embedding)
67
+ species_list=[]
68
+ for meta in vectordb_temp.get()["metadatas"] :
69
+ try:
70
+ matched_first_species = meta['matched_specie_0']
71
+ except KeyError:
72
+ continue
73
+ # Since each document is considered as a single chunk, the chunk_index is 0 for all
74
+ species_list.append( matched_first_species)
75
+
76
+ return species_list
77
+
78
+
79
+
80
+ # default_persist_directory = './db5' # For deployment
81
+ default_persist_directory_insects='./vector-databases-deployed/db5-agllm-data-isu-field-insects-all-species'
82
+ default_persist_directory_weeds='./vector-databases-deployed/db5-agllm-data-isu-field-weeds-all-species'
83
+
84
+ species_list_insects=get_species_list_from_db(default_persist_directory_insects)
85
+ species_list_weeds=get_species_list_from_db(default_persist_directory_weeds)
86
+ # default_persist_directory = 'vector-databases/db5-pre-completion' # For Development
87
+ csv_filepath1 = "./agllm-data/corrected/Corrected_supplemented-insect_data-2500-sorted.xlsx"
88
+ csv_filepath2 = "./agllm-data/corrected/Corrected_supplemented-insect_data-remaining.xlsx"
89
+
90
+ # India data
91
+ cv_india="./agllm-data/india/species.csv"
92
+ model_name=4
93
+ max_tokens=400
94
+ system_message = {"role": "system", "content": "You are a helpful assistant."} # TODO: double check how this plays out later.
95
+ langchain.debug=False # TODO: DOUBLE CHECK
96
+ from langchain import globals
97
+ globals.set_debug(False)
98
+
99
+ retriever_k_value=3
100
+ embedding = OpenAIEmbeddings()
101
+ print("Started....")
102
+ class ChatOpenRouter(ChatOpenAI):
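+ """ChatOpenAI subclass pointed at the OpenRouter API; the key falls back to the OPENROUTER_API_KEY environment variable."""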
103
+ openai_api_base: str
104
+ openai_api_key: str
105
+ model_name: str
106
+
107
+ def __init__(self,
108
+ model_name: str,
109
+ openai_api_key: Optional[str] = None,
110
+ openai_api_base: str = "https://openrouter.ai/api/v1",
111
+ **kwargs):
112
+ openai_api_key = openai_api_key or os.getenv('OPENROUTER_API_KEY')
113
+ super().__init__(openai_api_base=openai_api_base,
114
+ openai_api_key=openai_api_key,
115
+ model_name=model_name, **kwargs)
116
+
117
+
118
+ ######### todo: skipping the first step
119
+
120
+
121
+
122
+ # print(# Single example
123
+ # vectordb.as_retriever(k=2, search_kwargs={"filter": {"matched_specie_0": "Hypagyrtis unipunctata"}, 'k':1}).get_relevant_documents(
124
+ # "Checking if retriever is correctly initalized?"
125
+ # ))
126
+
127
+ columns = ['species', 'common name', 'order', 'family',
128
+ 'genus', 'Updated role in ecosystem', 'Proof',
129
+ 'ipm strategies', 'size of insect', 'geographical spread',
130
+ 'life cycle specifics', 'pest for plant species', 'species status',
131
+ 'distribution area', 'appearance', 'identification']
132
+
133
+ df1 = pd.read_excel(csv_filepath1, usecols=columns)
134
+ df2 = pd.read_excel(csv_filepath2, usecols=columns)
135
+ df_india = pd.read_csv(cv_india)
136
+ all_insects_data = pd.concat([df1, df2], ignore_index=True)
137
+
138
+ def get_prompt_with_vetted_info_from_specie_name(search_for_specie, mode):
139
+
140
+ def read_and_format_filtered_csv_better(dataframe_given, insect_specie):
141
+ filtered_data = dataframe_given[dataframe_given['species'] == insect_specie]
142
+ formatted_data = ""
143
+ # Format the filtered data
144
+ for index, row in filtered_data.iterrows():
145
+ row_data = [f"{col}: {row[col]}" for col in filtered_data.columns]
146
+ formatted_row = "\n".join(row_data)
147
+ formatted_data += f"{formatted_row}\n"
148
+
149
+ return formatted_data
150
+
151
+ # Use the path to your CSV file here
152
+
153
+ vetted_info=read_and_format_filtered_csv_better(all_insects_data, search_for_specie)
154
+ india_info=read_and_format_filtered_csv_better(df_india, search_for_specie)
155
+
156
+
157
+ if mode=="Farmer":
158
+ language_constraint="The language should be acustomed to the Farmers. Given question is likely to be asked by a farmer in the field will ask which will help to make decisions which are immediate and practical."
159
+ elif mode=="Researcher":
160
+ language_constraint="The language should be acustomed to a researcher. Given question is likely to be asked by a scientist which are comprehensive and aimed at exploring new knowledge or refining existing methodologies"
161
+ else:
162
+ print("No valid mode provided. Exiting")
163
+ exit()
164
+ # general_system_template = """
165
+ # In every question you are provided information about the insect/weed. Two types of information are given: first, Vetted Information (which is the same in every question) and second, some context from external documents about an insect/weed species, along with a question by the user. Answer the question according to these two types of information.
166
+ # ----
167
+ # Vetted info is as follows:
168
+ # {vetted_info}
169
+ # ----
170
+ # The context retrieved for documents about this particular question is as follows:
171
+ # {context}
172
+ # ----
173
+ # Additional Instruction:
174
+ # 1. Reference Constraint
175
+ # At the end of each answer provide the source/reference for the given data in following format:
176
+ # \n\n[enter two new lines before writing below] References:
177
+ # Vetted Information Used: Write what was used from the document for coming up with the answer above. Write exact part of lines. If nothing, write 'Nothing'.
178
+ # Documents Used: Write what was used from the document for coming up with the answer above. If nothing, write 'Nothing'. Write exact part of lines and document used.
179
+ # 2. Information Constraint:
180
+ # Only answer the question from information provided otherwise say you dont know. You have to answer in 50 words including references. Prioritize information in documents/context over vetted information. And first mention the warnings/things to be careful about.
181
+ # 3. Language constraint:
182
+ # {language_constraint}
183
+
184
+ # ----
185
+ # """.format(vetted_info=vetted_info, language_constraint=language_constraint,context="{context}", )
186
+
187
+ general_system_template = f"""
188
+ You are an AI assistant specialized in providing information about insects/weeds. Answer the user's question based on the available information or your general knowledge.
189
+
190
+ The context retrieved for this question is as follows:
191
+ {{context}}
192
+
193
+ Instructions:
194
+ 1. Evaluate the relevance of the provided context to the question.
195
+ 2. If the context contains relevant information, use it to answer the question and explicitly mention "Based on provided information" in your source.
196
+ 3. If the context does not contain relevant information, use your general knowledge to answer the question and state "Based on general knowledge" as the source.
197
+ 4. Format your response as follows:
198
+ Answer: Provide a concise answer in less than 50 words.
199
+ Source: State either "Based on provided information" or "Based on general knowledge".
200
+
201
+ 5. Language constraint:
202
+ {language_constraint}
203
+ 6. Other region (India) information:
204
+ {india_info}
205
+ 7. You have two kinds of information: the default (Iowa) information and the other-region (India) information. First ask the user which region they are interested in, and only provide information from that region.
206
+ 8. When answering a question, state which region the information comes from. If the user has selected a region or specified one in the question, answer based only on that region and say so.
207
+
208
+
209
+ Question: {{question}}
210
+ """
211
+
212
+ general_user_template = "Question:```{question}```"
213
+ messages_formatted = [
214
+ SystemMessagePromptTemplate.from_template(general_system_template),
215
+ HumanMessagePromptTemplate.from_template(general_user_template)
216
+ ]
217
+ qa_prompt = ChatPromptTemplate.from_messages( messages_formatted)
218
+ # print(qa_prompt)
219
+ return qa_prompt
220
+
221
+
222
+
223
+
224
+
225
+ # qa_prompt=get_prompt_with_vetted_info_from_specie_name("Papaipema nebris", "Researcher")
226
+ # print("First prompt is intialized as: " , qa_prompt, "\n\n")
227
+
228
+
229
+ memory = ConversationBufferMemory(memory_key="chat_history",output_key='answer', return_messages=True) # https://github.com/langchain-ai/langchain/issues/9394#issuecomment-1683538834
230
+
231
+
232
+ if model_name==4:
233
+ llm_openai = ChatOpenAI(model_name="gpt-4o-2024-08-06" , temperature=0, max_tokens=max_tokens) # TODO: NEW MODEL VERSION AVAILABLE
234
+ else:
235
+ llm_openai = ChatOpenAI(model_name="gpt-3.5-turbo-0125" , temperature=0, max_tokens=max_tokens)
236
+
237
+ specie_selector="Papaipema nebris"
238
+ filter = {
239
+ "$or": [
240
+ {"matched_specie_0": specie_selector},
241
+ {"matched_specie_1": specie_selector},
242
+ {"matched_specie_2": specie_selector},
243
+ ]
244
+ }
245
+ # retriever = vectordb.as_retriever(search_kwargs={'k':retriever_k_value, 'filter': filter})
246
+
247
+ # qa_chain = ConversationalRetrievalChain.from_llm(
248
+ # llm_openai, retriever, memory=memory, verbose=False, return_source_documents=True,\
249
+ # combine_docs_chain_kwargs={'prompt': qa_prompt}
250
+
251
+ # )
252
+ #
253
+
254
+ def initialize_qa_chain(specie_selector, application_mode, model_name, region, database_persistent_directory=default_persist_directory_insects, domain_name="Insects"):
255
+ # Add helper function for India info (kept for potential future use, but removed from RAG prompt)
256
+ def read_and_format_filtered_csv_better(dataframe_given, insect_specie):
257
+ filtered_data = dataframe_given[dataframe_given['species'] == insect_specie]
258
+ formatted_data = ""
259
+ # Format the filtered data
260
+ for index, row in filtered_data.iterrows():
261
+ row_data = [f"{col}: {row[col]}" for col in filtered_data.columns]
262
+ formatted_row = "\n".join(row_data)
263
+ formatted_data += f"{formatted_row}\n"
264
+
265
+ return formatted_data
266
+
267
+ # Get India info (potentially useful if not using RAG or for specific logic)
268
+ india_info = read_and_format_filtered_csv_better(df_india, specie_selector)
269
+ # db_region = get_db_region(region) # Map UI region to DB region - REMOVED
270
+
271
+ if model_name=="GPT-4":
272
+ chosen_llm=ChatOpenAI(model_name="gpt-4o-2024-08-06" , temperature=0, max_tokens=max_tokens)
273
+ elif model_name=="GPT-3.5":
274
+ chosen_llm=ChatOpenAI(model_name="gpt-3.5-turbo-0125" , temperature=0, max_tokens=max_tokens)
275
+ elif model_name=="Llama-3 70B":
276
+ chosen_llm = ChatOpenRouter(model_name="meta-llama/llama-3-70b-instruct", temperature=0,max_tokens=max_tokens )
277
+ elif model_name=="Llama-3 8B":
278
+ chosen_llm = ChatOpenRouter(model_name="meta-llama/llama-3-8b-instruct", temperature=0, max_tokens=max_tokens)
279
+ elif model_name=="Gemini-1.5 Pro":
280
+ chosen_llm = ChatOpenRouter(model_name="google/gemini-pro-1.5", temperature=0, max_tokens=max_tokens)
281
+ elif model_name=="Claude 3 Opus":
282
+ chosen_llm = ChatAnthropic(model_name='claude-3-opus-20240229', temperature=0, max_tokens=max_tokens)
283
+ elif model_name=="Claude 3.5 Sonnet":
284
+ chosen_llm = ChatAnthropic(model_name='claude-3-5-sonnet-20240620', temperature=0, max_tokens=max_tokens)
285
+ else:
286
+ print("No appropriate llm was selected")
287
+ exit()
288
+
289
+ if application_mode == "Farmer":
290
+ language_constraint = "The language should be customized for Farmers. The given question is likely to be asked by a farmer in the field and will help to make decisions which are immediate and practical."
291
+ elif application_mode == "Researcher":
292
+ language_constraint = "The language should be customized for a researcher. The given question is likely to be asked by a scientist and should be comprehensive, aimed at exploring new knowledge or refining existing methodologies."
293
+ else:
294
+ print("No valid mode provided. Exiting")
295
+ exit()
296
+
297
+ # RAG is always ON now
298
+ memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
299
+
300
+ # Construct the species filter part
301
+ species_filter = {
302
+ "$or": [
303
+ {"matched_specie_" + str(i): specie_selector} for i in range(11) # Generate dynamically up to 10
304
+ ]
305
+ }
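+ # e.g., for "Papaipema nebris" this expands to
+ # {"$or": [{"matched_specie_0": "Papaipema nebris"}, ..., {"matched_specie_10": "Papaipema nebris"}]}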
306
+
307
+ embedding = OpenAIEmbeddings()
308
+ vectordb = Chroma(persist_directory=database_persistent_directory,
309
+ embedding_function=embedding)
310
+
311
+ # --- Find all available regions for this species --- #
312
+ availability_message = f"Checking region availability for {specie_selector}..."
313
+ available_regions = set()
314
+
315
+ try:
316
+ # Query ChromaDB just for metadata based on species
317
+ species_docs = vectordb.get(where=species_filter, include=['metadatas'])
318
+ if species_docs and species_docs.get('metadatas'):
319
+ for meta in species_docs['metadatas']:
320
+ if 'region' in meta:
321
+ available_regions.add(meta['region'])
322
+
323
+ if available_regions:
324
+ available_regions_list = sorted(list(available_regions))
325
+ availability_message = f"Information for **{specie_selector}** is available in region(s): **{', '.join(available_regions_list)}**."
326
+ else:
327
+ available_regions_list = []
328
+ availability_message = f"No regional information found for **{specie_selector}** in the {domain_name} database."
329
+ except Exception as e:
330
+ print(f"Error checking region availability: {e}")
331
+ available_regions_list = []
332
+ availability_message = f"Could not determine region availability for {specie_selector}."
333
+
334
+ # --- Prepare context sections by region --- #
335
+ # Dictionary to hold context documents for each region
336
+ # region_contexts = {} # Unused variable, removing
337
+
338
+ # First check if selected region has information
339
+ selected_region_has_info = region in available_regions
340
+
341
+ # Create list of other available regions (excluding selected region)
342
+ other_regions = [r for r in available_regions_list if r != region]
343
+
344
+ # --- Create multi-region retrieval chain --- #
345
+ class MultiRegionRetriever:
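+ """Retrieves up to k documents for the selected region first, then for each other available region, and stores the per-region results in self.last_region_docs for prompt formatting."""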
346
+ def __init__(self, vectordb, species_filter, selected_region, other_regions, k=3):
347
+ self.vectordb = vectordb
348
+ self.species_filter = species_filter
349
+ self.selected_region = selected_region
350
+ self.other_regions = other_regions
351
+ self.k = k
352
+
353
+ def get_relevant_documents(self, query):
354
+ all_docs = []
355
+ region_docs = {}
356
+
357
+ # First get documents for selected region
358
+ # Fix illogical condition: self.selected_region == self.selected_region is always True
359
+ # Replace with a check if selected_region exists
360
+ if self.selected_region:
361
+ selected_filter = {"$and": [self.species_filter, {"region": self.selected_region}]}
362
+ selected_retriever = self.vectordb.as_retriever(search_kwargs={'k': self.k, 'filter': selected_filter})
363
+ try:
364
+ selected_docs = selected_retriever.get_relevant_documents(query)
365
+ if selected_docs:
366
+ all_docs.extend(selected_docs)
367
+ region_docs[self.selected_region] = selected_docs
368
+ except Exception as e:
369
+ print(f"Error retrieving docs for selected region {self.selected_region}: {e}")
370
+
371
+ # Then get documents for each other region
372
+ for other_region in self.other_regions:
373
+ if other_region != self.selected_region: # Skip if same as selected region
374
+ other_filter = {"$and": [self.species_filter, {"region": other_region}]}
375
+ other_retriever = self.vectordb.as_retriever(search_kwargs={'k': self.k, 'filter': other_filter})
376
+ try:
377
+ other_docs = other_retriever.get_relevant_documents(query)
378
+ if other_docs:
379
+ all_docs.extend(other_docs)
380
+ region_docs[other_region] = other_docs
381
+ except Exception as e:
382
+ print(f"Error retrieving docs for region {other_region}: {e}")
383
+
384
+ # Store the region-specific documents for formatting in the prompt
385
+ self.last_region_docs = region_docs
386
+ return all_docs
387
+
388
+ # Initialize the multi-region retriever
389
+ multi_region_retriever = MultiRegionRetriever(
390
+ vectordb=vectordb,
391
+ species_filter=species_filter,
392
+ selected_region=region,
393
+ other_regions=available_regions_list,
394
+ k=retriever_k_value
395
+ )
396
+
397
+ # Custom prompt handler that formats context by region
398
+ # Remove unused imports
399
+ # from langchain.chains.combine_documents import create_stuff_documents_chain
400
+ # from langchain.chains import create_retrieval_chain
401
+
402
+ # Updated prompt template for multi-part response with region-specific contexts
403
+ general_system_template = f"""
404
+ You are an AI assistant specialized in providing information about {domain_name.lower()} ({specie_selector}). The user is primarily interested in the '{region}' region.
405
+
406
+ The following context has been retrieved from a database organized by region:
407
+
408
+ {{context}}
409
+
410
+ Instructions:
411
+ 1. Analyze the user's question in relation to {specie_selector}.
412
+ 2. Structure your answer in the following multi-part format:
413
+
414
+ **Part 1: Selected Region Information ({region})**
415
+ If relevant information exists in the context for the selected region that answers the user's query:
416
+ Based on your selected region ({region}), for {specie_selector}, [summary of information for selected region] [1].
417
+
418
+ If no relevant information exists for the selected region:
419
+ "Based on the provided documents, there is no specific information for {specie_selector} in your selected region ({region}) regarding your question."
420
+
421
+ **Part 2: Other Regions Information** (Only include if information from other regions is available AND relevant to the query)
422
+ If you found relevant information from other regions that answers the user's query, include:
423
+
424
+ Additionally, information was found for other regions:
425
+ - In [Other Region Name]: [summary of information that directly answers the user's query] [next reference number].
426
+ - In [Another Region Name]: [summary of information that directly answers the user's query] [next reference number].
427
+
428
+ Only include regions where the information directly addresses the user's question.
429
+ Use consecutive reference numbers starting from where Part 1 left off.
430
+ If no other regions have relevant information, omit this part entirely.
431
+
432
+ **Part 3: General Knowledge** (Only include if context information is insufficient or incomplete)
433
+ If the available context does not fully address the query, add:
434
+
435
+ Based on my general knowledge as {model_name}: [Your general knowledge insights that directly address the query] [next reference number].
436
+
437
+ If the context information is sufficient, omit this part entirely.
438
+
439
+ 3. After providing all parts of your answer, include a References section ONLY for information you actually used:
440
+
441
+ References:
442
+ [1] Based on Expert Curated information about {specie_selector} in {region}
443
+ [2] Based on Expert Curated information about {specie_selector} in [Other Region Name]
444
+ [3] Based on Expert Curated information about {specie_selector} in [Another Region Name]
445
+ [x] {model_name}'s inherent knowledge
446
+
447
+ IMPORTANT:
448
+ - Only include reference numbers that correspond to information you actually used in your answer.
449
+ - Reference numbers should be sequential (1, 2, 3...) based on the order they appear in your answer.
450
+ - If you don't use information from a particular region, don't include a reference for it.
451
+ - If you don't use general knowledge, don't include a reference for it.
452
+ - Every claim with a reference marker [x] must have a corresponding entry in the References section.
453
+
454
+ 4. Apply this language constraint: {language_constraint}
455
+ 5. Keep your summaries concise and directly related to the user's question.
456
+
457
+ User Question about {specie_selector} ({domain_name}): {{question}}
458
+ """
459
+
460
+ class RegionFormattingLLMChain:
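+ """Minimal chain wrapper: retrieves region-tagged documents, formats them into region-labelled context sections, and calls the LLM on the filled-in prompt template."""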
461
+ def __init__(self, llm, prompt, retriever):
462
+ self.llm = llm
463
+ self.prompt = prompt
464
+ self.retriever = retriever
465
+
466
+ def __call__(self, inputs):
467
+ # Get documents using the multi-region retriever
468
+ docs = self.retriever.get_relevant_documents(inputs["question"])
469
+
470
+ # Get the region-specific document organization
471
+ region_docs = getattr(self.retriever, "last_region_docs", {})
472
+
473
+ # Format context with clear region sections
474
+ formatted_context = ""
475
+
476
+ # First add context for selected region if available
477
+ if region in region_docs:
478
+ formatted_context += f"--- CONTEXT FROM SELECTED REGION: {region} ---\n"
479
+ for i, doc in enumerate(region_docs[region]):
480
+ formatted_context += f"Document {i+1} from {region}:\n{doc.page_content}\n\n"
481
+
482
+ # Then add context for each other region
483
+ for other_region in [r for r in region_docs.keys() if r != region]:
484
+ formatted_context += f"--- CONTEXT FROM OTHER REGION: {other_region} ---\n"
485
+ for i, doc in enumerate(region_docs[other_region]):
486
+ formatted_context += f"Document {i+1} from {other_region}:\n{doc.page_content}\n\n"
487
+
488
+ # Replace the context placeholder with our formatted context
489
+ formatted_prompt = self.prompt.format(
490
+ context=formatted_context,
491
+ question=inputs["question"]
492
+ )
493
+
494
+ # Call the LLM with our formatted prompt
495
+ result = self.llm.invoke(formatted_prompt)
496
+
497
+ # Return the result in the expected format
498
+ return {"answer": result.content, "source_documents": docs}
499
+
500
+ # Create the custom chain
501
+ qa_chain = RegionFormattingLLMChain(
502
+ llm=chosen_llm,
503
+ prompt=general_system_template,
504
+ retriever=multi_region_retriever
505
+ )
506
+
507
+ return qa_chain, availability_message
508
+ # result = qa_chain.invoke({"question": "where are stalk borer eggs laid?"})
509
+ # print("Got the first LLM task working: ", result)
510
+
511
+
512
+ #Application Interface:
513
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
514
+ with gr.Row():
515
+ with gr.Column(scale=1):
516
+ gr.Markdown(
517
+ """
518
+ ![Logo](file/logo1.png)
519
+ """
520
+ )
521
+ with gr.Column(scale=1):
522
+ gr.Markdown(
523
+ """
524
+ ![Logo](file/logo2.png)
525
+ """
526
+ )
527
+
528
+ # Configure UI layout
529
+ chatbot = gr.Chatbot(height=600, label="AgLLM")
530
+ with gr.Row():
531
+ with gr.Column(scale=1):
532
+ with gr.Row():
533
+ domain_name = gr.Dropdown(
534
+ list(["Insects", "Weeds"]),
535
+ value="Insects",
536
+ label="Domain",
537
+ info="Select Domain",
538
+ interactive=True,
539
+ scale=1,
540
+ visible=True
541
+ )
542
+ region_selector = gr.Dropdown(
543
+ list(["United States", "India", "Africa"]), # Updated regions
544
+ value="United States", # Updated default
545
+ label="Region",
546
+ info="Select the Region",
547
+ interactive=True,
548
+ scale=1,
549
+ visible=True
550
+ )
551
+
552
+ # Model selection
553
+ specie_selector = gr.Dropdown(
554
+ list(set(species_list_insects)),
555
+ value=species_list_insects[0],
556
+ label="Species",
557
+ info="Select the Species",
558
+ interactive=True,
559
+ scale=1,
560
+ visible=True
561
+ )
562
+ with gr.Row():
563
+ model_name = gr.Dropdown(
564
+ list(["GPT-4", "GPT-3.5", "Llama-3 70B", "Llama-3 8B", "Gemini-1.5 Pro", "Claude 3 Opus", "Claude 3.5 Sonnet"]),
565
+ value="Llama-3 70B",
566
+ label="LLM",
567
+ info="Select the LLM",
568
+ interactive=True,
569
+ scale=1,
570
+ visible=True
571
+ )
572
+ application_mode = gr.Dropdown(
573
+ list(["Farmer", "Researcher"]),
574
+ value="Researcher",
575
+ label="Mode",
576
+ info="Select the Mode",
577
+ interactive=True,
578
+ scale=1,
579
+ visible=True
580
+ )
581
+ region_availability_display = gr.Markdown(value="Select species/domain to see region availability.") # Added display area
582
+
583
+ with gr.Column(scale=2):
584
+ # User input prompt text field
585
+ user_prompt_message = gr.Textbox(placeholder="Please add user prompt here", label="User prompt")
586
+ with gr.Row():
587
+ # clear = gr.Button("Clear Conversation", scale=2)
588
+ submitBtn = gr.Button("Submit", scale=8)
589
+
590
+ state = gr.State([])
591
+ qa_chain_state = gr.State(value=None)
592
+
593
+ # Handle user message
594
+ def user(user_prompt_message, history):
595
+ # print("HISTORY IS: ", history) # TODO: REMOVE IT LATER
596
+ if user_prompt_message != "":
597
+ return history + [[user_prompt_message, None]]
598
+ else:
599
+ return history + [["Invalid prompts - user prompt cannot be empty", None]]
600
+
601
+ # Chatbot logic for configuration, sending the prompts, rendering the streamed back generations, etc.
602
+ def bot(model_name, application_mode, user_prompt_message, history, messages_history, qa_chain, domain_name, region): # Removed use_rag
603
+ if qa_chain is None:
604
+ # Initial QA chain setup if not already done (uses default species for the selected domain)
605
+ initial_species = species_list_insects[0] if domain_name == "Insects" else species_list_weeds[0]
606
+ # Need to handle the tuple returned by init_qa_chain now
607
+ # Use the currently selected region for initialization if qa_chain is None
608
+ qa_chain, _ = init_qa_chain(initial_species, application_mode, model_name, domain_name, region) # Pass region
609
+
610
+ history[-1][1] = "" # Placeholder for the answer
611
+
612
+ # RAG is always ON now
613
+ result = qa_chain({"question": user_prompt_message, "chat_history": messages_history})
614
+ answer = result["answer"]
615
+ # source_documents = result.get("source_documents", []) # Keep source_documents if needed for debugging or future refinement
616
+ # formatted_response = format_response_with_source(answer, source_documents, domain_name, region) # REMOVED: Rely on LLM prompt now
617
+
618
+ history[-1][1] = answer # Assign raw LLM answer directly
619
+ return [history, messages_history]
620
+
621
+ # Helper function to format the response with source information
622
+ # def format_response_with_source(answer, source_documents, domain_name, region): # Pass region # FUNCTION NO LONGER USED
623
+ # try:
624
+ # answer_start = answer.find("Answer:")
625
+ # source_start = answer.find("Source:")
626
+ # ... (rest of the function commented out or removed) ...
627
+ # except Exception as e:
628
+ # print(f"Error parsing output or formatting source: {e}")
629
+ # formatted_response = answer # Return raw answer on error
630
+ #
631
+ # return formatted_response
632
+
633
+ # Initialize the chat history with default system message
634
+ def init_history(messages_history):
635
+ messages_history = []
636
+ messages_history += [system_message]
637
+ return messages_history
638
+
639
+ # Clean up the user input text field
640
+ def input_cleanup():
641
+ return ""
642
+
643
+ def init_qa_chain(specie_selector, application_mode, model_name, domain_name, region): # Removed use_rag
644
+ print(f"--- init_qa_chain wrapper called with domain: '{domain_name}' ---") # DIAGNOSTIC PRINT
645
+ qa_chain_instance = None
646
+ availability_msg = "Error initializing QA chain."
647
+ try:
648
+ if domain_name=="Insects":
649
+ qa_chain_instance, availability_msg = initialize_qa_chain(specie_selector, application_mode, model_name, region, default_persist_directory_insects, domain_name) # Removed use_rag
650
+ elif domain_name=="Weeds":
651
+ qa_chain_instance, availability_msg = initialize_qa_chain(specie_selector, application_mode, model_name, region, default_persist_directory_weeds, domain_name) # Removed use_rag
652
+ else:
653
+ print("No Appropriate Chain Selected")
654
+ availability_msg = "Invalid domain selected."
655
+ # Return None for chain and the message
656
+ except Exception as e:
657
+ print(f"Error in init_qa_chain wrapper: {e}")
658
+ availability_msg = f"Error initializing: {e}"
659
+
660
+ return qa_chain_instance, availability_msg # Return both chain and message
661
+
662
+ # Update QA chain AND availability message when relevant inputs change
663
+ inputs_for_qa_chain = [specie_selector, application_mode, model_name, domain_name, region_selector] # CORRECT ORDER
664
+ outputs_for_qa_chain = [qa_chain_state, region_availability_display]
665
+
666
+ # specie_selector.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
667
+ # model_name.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
668
+ # region_selector.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
669
+ # application_mode.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
670
+ # domain_name.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
671
+ specie_selector.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
672
+ model_name.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
673
+ region_selector.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
674
+ application_mode.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
675
+ domain_name.change(init_qa_chain, inputs=inputs_for_qa_chain, outputs=outputs_for_qa_chain)
676
+
677
+ #####
678
+ def update_species_list(domain):
679
+ if domain == "Insects":
680
+ return gr.Dropdown( species_list_insects, value=species_list_insects[0], label="Species", info="Select the Species", interactive=True, scale=1, visible=True )
681
+ elif domain == "Weeds":
682
+ return gr.Dropdown( species_list_weeds, value=species_list_weeds[0], label="Species", info="Select the Species", interactive=True, scale=1, visible=True )
683
+
684
+ domain_name.change(
685
+ update_species_list,
686
+ inputs=[domain_name],
687
+ outputs=[specie_selector]
688
+ )
689
+
690
+ # When the user clicks Enter and the user message is submitted
691
+ user_prompt_message.submit(
692
+ user,
693
+ [user_prompt_message, chatbot],
694
+ [chatbot],
695
+ queue=False
696
+ ).then(
697
+ bot,
698
+ [model_name, application_mode, user_prompt_message, chatbot, state, qa_chain_state, domain_name, region_selector], # Removed use_rag
699
+ [chatbot, state]
700
+ ).then(input_cleanup,
701
+ [],
702
+ [user_prompt_message],
703
+ queue=False
704
+ )
705
+
706
+ # When the user clicks the submit button
707
+ submitBtn.click(
708
+ user,
709
+ [user_prompt_message, chatbot],
710
+ [chatbot],
711
+ queue=False
712
+ ).then(
713
+ bot,
714
+ [model_name, application_mode, user_prompt_message, chatbot, state, qa_chain_state, domain_name, region_selector], # Removed use_rag
715
+ [chatbot, state]
716
+ ).then(
717
+ input_cleanup,
718
+ [],
719
+ [user_prompt_message],
720
+ queue=False
721
+ )
722
+
723
+ # When the user clicks the clear button
724
+ # clear.click(lambda: None, None, chatbot, queue=False).success(init_history, [state], [state])
725
+ if __name__ == "__main__":
726
+ # demo.launch()
727
+ demo.queue().launch(allowed_paths=["/"], share=False, show_error=True)
app_backup.py ADDED
@@ -0,0 +1,490 @@
1
+ import os
2
+ # https://stackoverflow.com/questions/76175046/how-to-add-prompt-to-langchain-conversationalretrievalchain-chat-over-docs-with
3
+ # again from:
4
+ # https://python.langchain.com/docs/integrations/providers/vectara/vectara_chat
5
+ from langchain.document_loaders import PyPDFDirectoryLoader
6
+ import pandas as pd
7
+ import langchain
8
+ from queue import Queue
9
+ from typing import Any
10
+ from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
11
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
12
+ from langchain.schema import LLMResult
13
+ from langchain.embeddings import HuggingFaceEmbeddings
14
+ from langchain.vectorstores import FAISS
15
+ from langchain.prompts.prompt import PromptTemplate
16
+ from anyio.from_thread import start_blocking_portal #For model callback streaming
17
+
18
+
19
+ from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
20
+ import os
21
+ from dotenv import load_dotenv
22
+
23
+ import streamlit as st
24
+ import json
25
+ from langchain.document_loaders import PyPDFLoader
26
+ from langchain.text_splitter import CharacterTextSplitter
27
+ from langchain.embeddings import OpenAIEmbeddings
28
+ from langchain.chains.question_answering import load_qa_chain
29
+ from langchain.chat_models import ChatOpenAI
30
+ # from langchain.chat_models import ChatAnthropic
31
+ from langchain_anthropic import ChatAnthropic
32
+ from langchain.vectorstores import Chroma
33
+ import chromadb
34
+
35
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
36
+ from langchain.llms import OpenAI
37
+ from langchain.chains import RetrievalQA
38
+ from langchain.document_loaders import TextLoader
39
+ from langchain.document_loaders import DirectoryLoader
40
+ from langchain_community.document_loaders import PyMuPDFLoader
41
+ from langchain.schema import Document
42
+
43
+ from langchain.memory import ConversationBufferMemory
44
+
45
+ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
46
+ from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
47
+ import gradio as gr
48
+ from langchain.memory import ConversationBufferMemory
49
+ from langchain.chains import ConversationalRetrievalChain
50
+ print("Started")
51
+
52
+ def get_species_list_from_db(db_name):
53
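+ # Collect the primary matched species (matched_specie_0) from every chunk's metadata in the Chroma store; the returned list may contain duplicates.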
+ embedding = OpenAIEmbeddings()
54
+ vectordb_temp = Chroma(persist_directory=db_name,
55
+ embedding_function=embedding)
56
+ species_list=[]
57
+ for meta in vectordb_temp.get()["metadatas"] :
58
+ try:
59
+ matched_first_species = meta['matched_specie_0']
60
+ except KeyError:
61
+ continue
62
+ # Since each document is considered as a single chunk, the chunk_index is 0 for all
63
+ species_list.append( matched_first_species)
64
+
65
+ return species_list
66
+
67
+
68
+
69
+ # default_persist_directory = './db5' # For deployment
70
+ default_persist_directory_insects='./vector-databases-deployed/db5-agllm-data-isu-field-insects-all-species'
71
+ default_persist_directory_weeds='./vector-databases-deployed/db5-agllm-data-isu-field-weeds-all-species'
72
+
73
+ species_list_insects=get_species_list_from_db(default_persist_directory_insects)
74
+ species_list_weeds=get_species_list_from_db(default_persist_directory_weeds)
75
+ # default_persist_directory = 'vector-databases/db5-pre-completion' # For Development
76
+ csv_filepath1 = "./agllm-data/corrected/Corrected_supplemented-insect_data-2500-sorted.xlsx"
77
+ csv_filepath2 = "./agllm-data/corrected/Corrected_supplemented-insect_data-remaining.xlsx"
78
+ model_name=4
79
+ max_tokens=400
80
+ system_message = {"role": "system", "content": "You are a helpful assistant."} # TODO: double check how this plays out later.
81
+ langchain.debug=False # TODO: DOUBLE CHECK
82
+ from langchain import globals
83
+ globals.set_debug(False)
84
+
85
+ retriever_k_value=3
86
+ embedding = OpenAIEmbeddings()
87
+ print("Started....")
88
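+ # ChatOpenAI subclass that points the client at OpenRouter's OpenAI-compatible endpoint; the key falls back to the OPENROUTER_API_KEY environment variable.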
+ class ChatOpenRouter(ChatOpenAI):
89
+ openai_api_base: str
90
+ openai_api_key: str
91
+ model_name: str
92
+
93
+ def __init__(self,
94
+ model_name: str,
95
+ openai_api_key: str = None,
96
+ openai_api_base: str = "https://openrouter.ai/api/v1",
97
+ **kwargs):
98
+ openai_api_key = openai_api_key or os.getenv('OPENROUTER_API_KEY')
99
+ super().__init__(openai_api_base=openai_api_base,
100
+ openai_api_key=openai_api_key,
101
+ model_name=model_name, **kwargs)
102
+
103
+
104
+ ######### todo: skipping the first step
105
+
106
+
107
+
108
+ # print(# Single example
109
+ # vectordb.as_retriever(k=2, search_kwargs={"filter": {"matched_specie_0": "Hypagyrtis unipunctata"}, 'k':1}).get_relevant_documents(
110
+ # "Checking if retriever is correctly initalized?"
111
+ # ))
112
+
113
+ columns = ['species', 'common name', 'order', 'family',
114
+ 'genus', 'Updated role in ecosystem', 'Proof',
115
+ 'ipm strategies', 'size of insect', 'geographical spread',
116
+ 'life cycle specifics', 'pest for plant species', 'species status',
117
+ 'distribution area', 'appearance', 'identification']
118
+
119
+ df1 = pd.read_excel(csv_filepath1, usecols=columns)
120
+ df2 = pd.read_excel(csv_filepath2, usecols=columns)
121
+
122
+ all_insects_data = pd.concat([df1, df2], ignore_index=True)
123
+
124
+ def get_prompt_with_vetted_info_from_specie_name(search_for_specie, mode):
125
+
126
+ def read_and_format_filtered_csv_better(insect_specie):
127
+ filtered_data = all_insects_data[all_insects_data['species'] == insect_specie]
128
+ formatted_data = ""
129
+ # Format the filtered data
130
+ for index, row in filtered_data.iterrows():
131
+ row_data = [f"{col}: {row[col]}" for col in filtered_data.columns]
132
+ formatted_row = "\n".join(row_data)
133
+ formatted_data += f"{formatted_row}\n"
134
+
135
+ return formatted_data
136
+
137
+ # Use the path to your CSV file here
138
+
139
+ vetted_info=read_and_format_filtered_csv_better(search_for_specie)
140
+
141
+ if mode=="Farmer":
142
+ language_constraint="The language should be acustomed to the Farmers. Given question is likely to be asked by a farmer in the field will ask which will help to make decisions which are immediate and practical."
143
+ elif mode=="Researcher":
144
+ language_constraint="The language should be acustomed to a researcher. Given question is likely to be asked by a scientist which are comprehensive and aimed at exploring new knowledge or refining existing methodologies"
145
+ else:
146
+ print("No valid mode provided. Exiting")
147
+ exit()
148
+ # general_system_template = """
149
+ # In every question you are provided information about the insect/weed. Two types of information are given: first, Vetted Information (which is the same in every question), and second, some context from external documents about an insect/weed species, along with a question from the user. Answer the question according to these two types of information.
150
+ # ----
151
+ # Vetted info is as follows:
152
+ # {vetted_info}
153
+ # ----
154
+ # The context retrieved for documents about this particular question is as follows:
155
+ # {context}
156
+ # ----
157
+ # Additional Instruction:
158
+ # 1. Reference Constraint
159
+ # At the end of each answer provide the source/reference for the given data in following format:
160
+ # \n\n[enter two new lines before writing below] References:
161
+ # Vetted Information Used: Write what was used from the document for coming up with the answer above. Write exact part of lines. If nothing, write 'Nothing'.
162
+ # Documents Used: Write what was used from the document for coming up with the answer above. If nothing, write 'Nothing'. Write exact part of lines and document used.
163
+ # 2. Information Constraint:
164
+ # Only answer the question from information provided otherwise say you dont know. You have to answer in 50 words including references. Prioritize information in documents/context over vetted information. And first mention the warnings/things to be careful about.
165
+ # 3. Language constraint:
166
+ # {language_constraint}
167
+
168
+ # ----
169
+ # """.format(vetted_info=vetted_info, language_constraint=language_constraint,context="{context}", )
170
+
171
+ general_system_template = f"""
172
+ You are an AI assistant specialized in providing information about insects/weeds. Answer the user's question based on the available information or your general knowledge.
173
+
174
+ The context retrieved for this question is as follows:
175
+ {{context}}
176
+
177
+ Instructions:
178
+ 1. Evaluate the relevance of the provided context to the question.
179
+ 2. If the context contains relevant information, use it to answer the question.
180
+ 3. If the context does not contain relevant information, use your general knowledge to answer the question.
181
+ 4. Format your response as follows:
182
+ Answer: Provide a concise answer in less than 50 words.
183
+ Reference: If you used the provided context, cite the specific information used. If you used your general knowledge, state "Based on general knowledge".
184
+
185
+ 5. Language constraint:
186
+ {language_constraint}
187
+
188
+ Question: {{question}}
189
+ """
190
+
191
+ general_user_template = "Question:```{question}```"
192
+ messages_formatted = [
193
+ SystemMessagePromptTemplate.from_template(general_system_template),
194
+ HumanMessagePromptTemplate.from_template(general_user_template)
195
+ ]
196
+ qa_prompt = ChatPromptTemplate.from_messages( messages_formatted )
197
+ # print(qa_prompt)
198
+ return qa_prompt
199
+
200
+
201
+
202
+
203
+
204
+ qa_prompt=get_prompt_with_vetted_info_from_specie_name("Papaipema nebris", "Researcher")
205
+ # print("First prompt is intialized as: " , qa_prompt, "\n\n")
206
+
207
+
208
+ memory = ConversationBufferMemory(memory_key="chat_history",output_key='answer', return_messages=True) # https://github.com/langchain-ai/langchain/issues/9394#issuecomment-1683538834
209
+
210
+
211
+ if model_name==4:
212
+ llm_openai = ChatOpenAI(model_name="gpt-4-1106-preview" , temperature=0, max_tokens=max_tokens) # TODO: NEW MODEL VERSION AVAILABLE
213
+ else:
214
+ llm_openai = ChatOpenAI(model_name="gpt-3.5-turbo-0125" , temperature=0, max_tokens=max_tokens)
215
+
216
+ specie_selector="Papaipema nebris"
217
+ filter = {
218
+ "$or": [
219
+ {"matched_specie_0": specie_selector},
220
+ {"matched_specie_1": specie_selector},
221
+ {"matched_specie_2": specie_selector},
222
+ ]
223
+ }
224
+ # retriever = vectordb.as_retriever(search_kwargs={'k':retriever_k_value, 'filter': filter})
225
+
226
+ # qa_chain = ConversationalRetrievalChain.from_llm(
227
+ # llm_openai, retriever, memory=memory, verbose=False, return_source_documents=True,\
228
+ # combine_docs_chain_kwargs={'prompt': qa_prompt}
229
+
230
+ # )
231
+ #
232
+
233
+ def initialize_qa_chain(specie_selector, application_mode, model_name="GPT-4", database_persistent_directory=default_persist_directory_insects):
234
+ if model_name=="GPT-4":
235
+ chosen_llm=ChatOpenAI(model_name="gpt-4-1106-preview" , temperature=0, max_tokens=max_tokens)
236
+ elif model_name=="GPT-3.5":
237
+ chosen_llm=ChatOpenAI(model_name="gpt-3.5-turbo-0125" , temperature=0, max_tokens=max_tokens)
238
+ elif model_name=="Llama-3 70B":
239
+ chosen_llm = ChatOpenRouter(model_name="meta-llama/llama-3-70b-instruct", temperature=0,max_tokens=max_tokens )
240
+ elif model_name=="Llama-3 8B":
241
+ chosen_llm = ChatOpenRouter(model_name="meta-llama/llama-3-8b-instruct", temperature=0, max_tokens=max_tokens)
242
+ elif model_name=="Gemini-1.5 Pro":
243
+ chosen_llm = ChatOpenRouter(model_name="google/gemini-pro-1.5", temperature=0, max_tokens=max_tokens)
244
+ elif model_name=="Claude 3 Opus":
245
+ chosen_llm = ChatAnthropic(model_name='claude-3-opus-20240229', temperature=0, max_tokens=max_tokens)
246
+
247
+ else:
248
+ print("No appropriate llm was selected")
249
+ exit()
250
+
251
+
252
+
253
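+ # Chroma metadata filter: a chunk matches when any of its matched_specie_0 through matched_specie_10 fields equals the selected species.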
+ filter = {
254
+ "$or": [
255
+ {"matched_specie_0": specie_selector},
256
+ {"matched_specie_1": specie_selector},
257
+ {"matched_specie_2": specie_selector},
258
+ {"matched_specie_3": specie_selector},
259
+ {"matched_specie_4": specie_selector},
260
+ {"matched_specie_5": specie_selector},
261
+ {"matched_specie_6": specie_selector},
262
+ {"matched_specie_7": specie_selector},
263
+ {"matched_specie_8": specie_selector},
264
+ {"matched_specie_9": specie_selector},
265
+ {"matched_specie_10": specie_selector}
266
+ ]
267
+ }
268
+
269
+ embedding = OpenAIEmbeddings()
270
+ vectordb = Chroma(persist_directory=database_persistent_directory,
271
+ embedding_function=embedding)
272
+
273
+ print("got updated retriever without metadata filtering")
274
+ retriever = vectordb.as_retriever(search_kwargs={'k':retriever_k_value, 'filter': filter})
275
+
276
+ memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
277
+ qa_prompt=get_prompt_with_vetted_info_from_specie_name(specie_selector, application_mode)
278
+ qa_chain = ConversationalRetrievalChain.from_llm(
279
+ chosen_llm, retriever, memory=memory, verbose=False, return_source_documents=True,
280
+ combine_docs_chain_kwargs={'prompt': qa_prompt}
281
+ )
282
+
283
+ return qa_chain
284
+ # result = qa_chain.invoke({"question": "where are stalk borer eggs laid?"})
285
+ # print("Got the first LLM task working: ", result)
286
+
287
+
288
+ #Application Interface:
289
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
290
+ with gr.Row():
291
+ with gr.Column(scale=1):
292
+ gr.Markdown(
293
+ """
294
+ ![Logo](file/logo1.png)
295
+ """
296
+ )
297
+ with gr.Column(scale=1):
298
+ gr.Markdown(
299
+ """
300
+ ![Logo](file/logo2.png)
301
+ """
302
+ )
303
+
304
+ # Configure UI layout
305
+ chatbot = gr.Chatbot(height=600, label="AgLLM")
306
+ with gr.Row():
307
+ with gr.Column(scale=1):
308
+ with gr.Row():
309
+ domain_name = gr.Dropdown(
310
+ list(["Insects", "Weeds"]),
311
+ value="Insects",
312
+ label="Domain",
313
+ info="Select Domain",
314
+ interactive=True,
315
+ scale=1,
316
+ visible=True
317
+ )
318
+
319
+ # Model selection
320
+ specie_selector = gr.Dropdown(
321
+ species_list_insects,
322
+ value=species_list_insects[0],
323
+ label="Species",
324
+ info="Select the Species",
325
+ interactive=True,
326
+ scale=1,
327
+ visible=True
328
+ )
329
+ with gr.Row():
330
+ model_name = gr.Dropdown(
331
+ list(["GPT-4", "GPT-3.5", "Llama-3 70B", "Llama-3 8B", "Gemini-1.5 Pro", "Claude 3 Opus"]),
332
+ value="Llama-3 70B",
333
+ label="LLM",
334
+ info="Select the LLM",
335
+ interactive=True,
336
+ scale=1,
337
+ visible=True
338
+ )
339
+ application_mode = gr.Dropdown(
340
+ list(["Farmer", "Researcher"]),
341
+ value="Researcher",
342
+ label="Mode",
343
+ info="Select the Mode",
344
+ interactive=True,
345
+ scale=1,
346
+ visible=True
347
+ )
348
+
349
+
350
+ with gr.Column(scale=2):
351
+ # User input prompt text field
352
+ user_prompt_message = gr.Textbox(placeholder="Please add user prompt here", label="User prompt")
353
+ with gr.Row():
354
+ # clear = gr.Button("Clear Conversation", scale=2)
355
+ submitBtn = gr.Button("Submit", scale=8)
356
+
357
+ state = gr.State([])
358
+ qa_chain_state = gr.State(value=None)
359
+
360
+ # Handle user message
361
+ def user(user_prompt_message, history):
362
+ # print("HISTORY IS: ", history) # TODO: REMOVE IT LATER
363
+ if user_prompt_message != "":
364
+ return history + [[user_prompt_message, None]]
365
+ else:
366
+ return history + [["Invalid prompts - user prompt cannot be empty", None]]
367
+
368
+ # Chatbot logic for configuration, sending the prompts, rendering the streamed back generations, etc.
369
+ def bot(model_name, application_mode, user_prompt_message, history, messages_history, qa_chain, domain_name):
370
+ if qa_chain == None:
371
+ qa_chain=init_qa_chain(species_list_insects[0], application_mode, model_name, domain_name)
372
+
373
+ dialog = []
374
+ bot_message = ""
375
+ history[-1][1] = "" # Placeholder for the answer
376
+
377
+ dialog = [
378
+ {"role": "user", "content": user_prompt_message},
379
+ ]
380
+ messages_history += dialog
381
+
382
+ # Queue for streamed character rendering
383
+ q = Queue()
384
+
385
+ # Async task for streamed chain results wired to callbacks we previously defined, so we don't block the UI
386
+
387
+ def task(user_prompt_message):
388
+ result = qa_chain.invoke({"question": user_prompt_message})
389
+ answer = result["answer"]
390
+
391
+ try:
392
+ answer_start = answer.find("Answer:")
393
+ reference_start = answer.find("Reference:")
394
+
395
+ if answer_start != -1 and reference_start != -1:
396
+ model_answer = answer[answer_start + len("Answer:"):reference_start].strip()
397
+ reference = answer[reference_start + len("Reference:"):].strip()
398
+ formatted_response = f"Answer:\n{model_answer}\n\nReferences:\n{reference}"
399
+ else:
400
+ formatted_response = answer
401
+ except:
402
+ print(f"Error parsing so displaying the raw output")
403
+ formatted_response = answer
404
+
405
+ return formatted_response
406
+
407
+ history[-1][1] = task(user_prompt_message)
408
+ return [history, messages_history]
409
+
410
+ # Initialize the chat history with default system message
411
+ def init_history(messages_history):
412
+ messages_history = []
413
+ messages_history += [system_message]
414
+ return messages_history
415
+
416
+ # Clean up the user input text field
417
+ def input_cleanup():
418
+ return ""
419
+
420
+ def init_qa_chain(specie_selector, application_mode, model_name, domain_name):
421
+ if domain_name=="Insects":
422
+ qa_chain = initialize_qa_chain(specie_selector, application_mode, model_name, default_persist_directory_insects)
423
+ elif domain_name=="Weeds":
424
+ qa_chain = initialize_qa_chain(specie_selector, application_mode, model_name, default_persist_directory_weeds)
425
+ else:
426
+ print("No Appropriate Chain Selected")
427
+ return qa_chain
428
+
429
+ specie_selector.change(
430
+ init_qa_chain,
431
+ inputs=[specie_selector, application_mode,model_name, domain_name ],
432
+ outputs=[qa_chain_state]
433
+ )
434
+ model_name.change(
435
+ init_qa_chain,
436
+ inputs=[specie_selector, application_mode,model_name, domain_name ],
437
+ outputs=[qa_chain_state]
438
+ )
439
+
440
+ #####
441
+ def update_species_list(domain):
442
+ if domain == "Insects":
443
+ return gr.Dropdown( species_list_insects, value=species_list_insects[0], label="Species", info="Select the Species", interactive=True, scale=1, visible=True )
444
+ elif domain == "Weeds":
445
+ return gr.Dropdown( species_list_weeds, value=species_list_weeds[0], label="Species", info="Select the Species", interactive=True, scale=1, visible=True )
446
+
447
+ domain_name.change(
448
+ update_species_list,
449
+ inputs=[domain_name],
450
+ outputs=[specie_selector]
451
+ )
452
+
453
+ # When the user clicks Enter and the user message is submitted
454
+ user_prompt_message.submit(
455
+ user,
456
+ [user_prompt_message, chatbot],
457
+ [chatbot],
458
+ queue=False
459
+ ).then(
460
+ bot,
461
+ [model_name, application_mode, user_prompt_message, chatbot, state, qa_chain_state, domain_name],
462
+ [chatbot, state]
463
+ ).then(input_cleanup,
464
+ [],
465
+ [user_prompt_message],
466
+ queue=False
467
+ )
468
+
469
+ # When the user clicks the submit button
470
+ submitBtn.click(
471
+ user,
472
+ [user_prompt_message, chatbot],
473
+ [chatbot],
474
+ queue=False
475
+ ).then(
476
+ bot,
477
+ [model_name, application_mode, user_prompt_message, chatbot, state, qa_chain_state, domain_name],
478
+ [chatbot, state]
479
+ ).then(
480
+ input_cleanup,
481
+ [],
482
+ [user_prompt_message],
483
+ queue=False
484
+ )
485
+
486
+ # When the user clicks the clear button
487
+ # clear.click(lambda: None, None, chatbot, queue=False).success(init_history, [state], [state])
488
+ if __name__ == "__main__":
489
+ # demo.launch()
490
+ demo.queue().launch(allowed_paths=["/"], share=False, show_error=True)
app_backup_2.py ADDED
@@ -0,0 +1,510 @@
1
+ import os
2
+ # https://stackoverflow.com/questions/76175046/how-to-add-prompt-to-langchain-conversationalretrievalchain-chat-over-docs-with
3
+ # again from:
4
+ # https://python.langchain.com/docs/integrations/providers/vectara/vectara_chat
5
+ from langchain.document_loaders import PyPDFDirectoryLoader
6
+ import pandas as pd
7
+ import langchain
8
+ from queue import Queue
9
+ from typing import Any
10
+ from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
11
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
12
+ from langchain.schema import LLMResult
13
+ from langchain.embeddings import HuggingFaceEmbeddings
14
+ from langchain.vectorstores import FAISS
15
+ from langchain.prompts.prompt import PromptTemplate
16
+ from anyio.from_thread import start_blocking_portal #For model callback streaming
17
+
18
+
19
+ from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
20
+ import os
21
+ from dotenv import load_dotenv
22
+
23
+ import streamlit as st
24
+ import json
25
+ from langchain.document_loaders import PyPDFLoader
26
+ from langchain.text_splitter import CharacterTextSplitter
27
+ from langchain.embeddings import OpenAIEmbeddings
28
+ from langchain.chains.question_answering import load_qa_chain
29
+ from langchain.chat_models import ChatOpenAI
30
+ # from langchain.chat_models import ChatAnthropic
31
+ from langchain_anthropic import ChatAnthropic
32
+ from langchain.vectorstores import Chroma
33
+ import chromadb
34
+
35
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
36
+ from langchain.llms import OpenAI
37
+ from langchain.chains import RetrievalQA
38
+ from langchain.document_loaders import TextLoader
39
+ from langchain.document_loaders import DirectoryLoader
40
+ from langchain_community.document_loaders import PyMuPDFLoader
41
+ from langchain.schema import Document
42
+
43
+ from langchain.memory import ConversationBufferMemory
44
+
45
+ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
46
+ from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
47
+ import gradio as gr
48
+ from langchain.memory import ConversationBufferMemory
49
+ from langchain.chains import ConversationalRetrievalChain
50
+ print("Started")
51
+
52
+ def get_species_list_from_db(db_name):
53
+ embedding = OpenAIEmbeddings()
54
+ vectordb_temp = Chroma(persist_directory=db_name,
55
+ embedding_function=embedding)
56
+ species_list=[]
57
+ for meta in vectordb_temp.get()["metadatas"] :
58
+ try:
59
+ matched_first_species = meta['matched_specie_0']
60
+ except KeyError:
61
+ continue
62
+ # Since each document is considered as a single chunk, the chunk_index is 0 for all
63
+ species_list.append( matched_first_species)
64
+
65
+ return species_list
66
+
67
+
68
+
69
+ # default_persist_directory = './db5' # For deployment
70
+ default_persist_directory_insects='./vector-databases-deployed/db5-agllm-data-isu-field-insects-all-species'
71
+ default_persist_directory_weeds='./vector-databases-deployed/db5-agllm-data-isu-field-weeds-all-species'
72
+
73
+ species_list_insects=get_species_list_from_db(default_persist_directory_insects)
74
+ species_list_weeds=get_species_list_from_db(default_persist_directory_weeds)
75
+ # default_persist_directory = 'vector-databases/db5-pre-completion' # For Development
76
+ csv_filepath1 = "./agllm-data/corrected/Corrected_supplemented-insect_data-2500-sorted.xlsx"
77
+ csv_filepath2 = "./agllm-data/corrected/Corrected_supplemented-insect_data-remaining.xlsx"
78
+ model_name=4
79
+ max_tokens=400
80
+ system_message = {"role": "system", "content": "You are a helpful assistant."} # TODO: double check how this plays out later.
81
+ langchain.debug=False # TODO: DOUBLE CHECK
82
+ from langchain import globals
83
+ globals.set_debug(False)
84
+
85
+ retriever_k_value=3
86
+ embedding = OpenAIEmbeddings()
87
+ print("Started....")
88
+ class ChatOpenRouter(ChatOpenAI):
89
+ openai_api_base: str
90
+ openai_api_key: str
91
+ model_name: str
92
+
93
+ def __init__(self,
94
+ model_name: str,
95
+ openai_api_key: str = None,
96
+ openai_api_base: str = "https://openrouter.ai/api/v1",
97
+ **kwargs):
98
+ openai_api_key = openai_api_key or os.getenv('OPENROUTER_API_KEY')
99
+ super().__init__(openai_api_base=openai_api_base,
100
+ openai_api_key=openai_api_key,
101
+ model_name=model_name, **kwargs)
102
+
103
+
104
+ ######### todo: skipping the first step
105
+
106
+
107
+
108
+ # print(# Single example
109
+ # vectordb.as_retriever(k=2, search_kwargs={"filter": {"matched_specie_0": "Hypagyrtis unipunctata"}, 'k':1}).get_relevant_documents(
110
+ # "Checking if retriever is correctly initalized?"
111
+ # ))
112
+
113
+ columns = ['species', 'common name', 'order', 'family',
114
+ 'genus', 'Updated role in ecosystem', 'Proof',
115
+ 'ipm strategies', 'size of insect', 'geographical spread',
116
+ 'life cycle specifics', 'pest for plant species', 'species status',
117
+ 'distribution area', 'appearance', 'identification']
118
+
119
+ df1 = pd.read_excel(csv_filepath1, usecols=columns)
120
+ df2 = pd.read_excel(csv_filepath2, usecols=columns)
121
+
122
+ all_insects_data = pd.concat([df1, df2], ignore_index=True)
123
+
124
+ def get_prompt_with_vetted_info_from_specie_name(search_for_specie, mode):
125
+
126
+ def read_and_format_filtered_csv_better(insect_specie):
127
+ filtered_data = all_insects_data[all_insects_data['species'] == insect_specie]
128
+ formatted_data = ""
129
+ # Format the filtered data
130
+ for index, row in filtered_data.iterrows():
131
+ row_data = [f"{col}: {row[col]}" for col in filtered_data.columns]
132
+ formatted_row = "\n".join(row_data)
133
+ formatted_data += f"{formatted_row}\n"
134
+
135
+ return formatted_data
136
+
137
+ # Use the path to your CSV file here
138
+
139
+ vetted_info=read_and_format_filtered_csv_better(search_for_specie)
140
+
141
+ if mode=="Farmer":
142
+ language_constraint="The language should be acustomed to the Farmers. Given question is likely to be asked by a farmer in the field will ask which will help to make decisions which are immediate and practical."
143
+ elif mode=="Researcher":
144
+ language_constraint="The language should be acustomed to a researcher. Given question is likely to be asked by a scientist which are comprehensive and aimed at exploring new knowledge or refining existing methodologies"
145
+ else:
146
+ print("No valid mode provided. Exiting")
147
+ exit()
148
+ # general_system_template = """
149
+ # In every question you are provided information about the insect/weed. Two types of information are given: first, Vetted Information (which is the same in every question), and second, some context from external documents about an insect/weed species, along with a question from the user. Answer the question according to these two types of information.
150
+ # ----
151
+ # Vetted info is as follows:
152
+ # {vetted_info}
153
+ # ----
154
+ # The context retrieved for documents about this particular question is as follows:
155
+ # {context}
156
+ # ----
157
+ # Additional Instruction:
158
+ # 1. Reference Constraint
159
+ # At the end of each answer provide the source/reference for the given data in following format:
160
+ # \n\n[enter two new lines before writing below] References:
161
+ # Vetted Information Used: Write what was used from the document for coming up with the answer above. Write exact part of lines. If nothing, write 'Nothing'.
162
+ # Documents Used: Write what was used from the document for coming up with the answer above. If nothing, write 'Nothing'. Write exact part of lines and document used.
163
+ # 2. Information Constraint:
164
+ # Only answer the question from information provided otherwise say you dont know. You have to answer in 50 words including references. Prioritize information in documents/context over vetted information. And first mention the warnings/things to be careful about.
165
+ # 3. Language constraint:
166
+ # {language_constraint}
167
+
168
+ # ----
169
+ # """.format(vetted_info=vetted_info, language_constraint=language_constraint,context="{context}", )
170
+
171
+ general_system_template = f"""
172
+ You are an AI assistant specialized in providing information about insects/weeds. Answer the user's question based on the available information or your general knowledge.
173
+
174
+ The context retrieved for this question is as follows:
175
+ {{context}}
176
+
177
+ Instructions:
178
+ 1. Evaluate the relevance of the provided context to the question.
179
+ 2. If the context contains relevant information, use it to answer the question and explicitly mention "Based on provided information" in your source.
180
+ 3. If the context does not contain relevant information, use your general knowledge to answer the question and state "Based on general knowledge" as the source.
181
+ 4. Format your response as follows:
182
+ Answer: Provide a concise answer in less than 50 words.
183
+ Source: State either "Based on provided information" or "Based on general knowledge".
184
+
185
+ 5. Language constraint:
186
+ {language_constraint}
187
+
188
+ Question: {{question}}
189
+ """
190
+
191
+ general_user_template = "Question:```{question}```"
192
+ messages_formatted = [
193
+ SystemMessagePromptTemplate.from_template(general_system_template),
194
+ HumanMessagePromptTemplate.from_template(general_user_template)
195
+ ]
196
+ qa_prompt = ChatPromptTemplate.from_messages( messages_formatted )
197
+ # print(qa_prompt)
198
+ return qa_prompt
199
+
200
+
201
+
202
+
203
+
204
+ qa_prompt=get_prompt_with_vetted_info_from_specie_name("Papaipema nebris", "Researcher")
205
+ # print("First prompt is intialized as: " , qa_prompt, "\n\n")
206
+
207
+
208
+ memory = ConversationBufferMemory(memory_key="chat_history",output_key='answer', return_messages=True) # https://github.com/langchain-ai/langchain/issues/9394#issuecomment-1683538834
209
+
210
+
211
+ if model_name==4:
212
+ llm_openai = ChatOpenAI(model_name="gpt-4-1106-preview" , temperature=0, max_tokens=max_tokens) # TODO: NEW MODEL VERSION AVAILABLE
213
+ else:
214
+ llm_openai = ChatOpenAI(model_name="gpt-3.5-turbo-0125" , temperature=0, max_tokens=max_tokens)
215
+
216
+ specie_selector="Papaipema nebris"
217
+ filter = {
218
+ "$or": [
219
+ {"matched_specie_0": specie_selector},
220
+ {"matched_specie_1": specie_selector},
221
+ {"matched_specie_2": specie_selector},
222
+ ]
223
+ }
224
+ # retriever = vectordb.as_retriever(search_kwargs={'k':retriever_k_value, 'filter': filter})
225
+
226
+ # qa_chain = ConversationalRetrievalChain.from_llm(
227
+ # llm_openai, retriever, memory=memory, verbose=False, return_source_documents=True,\
228
+ # combine_docs_chain_kwargs={'prompt': qa_prompt}
229
+
230
+ # )
231
+ #
232
+
233
+ def initialize_qa_chain(specie_selector, application_mode, model_name="GPT-4", database_persistent_directory=default_persist_directory_insects):
234
+ if model_name=="GPT-4":
235
+ chosen_llm=ChatOpenAI(model_name="gpt-4-1106-preview" , temperature=0, max_tokens=max_tokens)
236
+ elif model_name=="GPT-3.5":
237
+ chosen_llm=ChatOpenAI(model_name="gpt-3.5-turbo-0125" , temperature=0, max_tokens=max_tokens)
238
+ elif model_name=="Llama-3 70B":
239
+ chosen_llm = ChatOpenRouter(model_name="meta-llama/llama-3-70b-instruct", temperature=0,max_tokens=max_tokens )
240
+ elif model_name=="Llama-3 8B":
241
+ chosen_llm = ChatOpenRouter(model_name="meta-llama/llama-3-8b-instruct", temperature=0, max_tokens=max_tokens)
242
+ elif model_name=="Gemini-1.5 Pro":
243
+ chosen_llm = ChatOpenRouter(model_name="google/gemini-pro-1.5", temperature=0, max_tokens=max_tokens)
244
+ elif model_name=="Claude 3 Opus":
245
+ chosen_llm = ChatAnthropic(model_name='claude-3-opus-20240229', temperature=0, max_tokens=max_tokens)
246
+
247
+ else:
248
+ print("No appropriate llm was selected")
249
+ exit()
250
+
251
+
252
+
253
+ filter = {
254
+ "$or": [
255
+ {"matched_specie_0": specie_selector},
256
+ {"matched_specie_1": specie_selector},
257
+ {"matched_specie_2": specie_selector},
258
+ {"matched_specie_3": specie_selector},
259
+ {"matched_specie_4": specie_selector},
260
+ {"matched_specie_5": specie_selector},
261
+ {"matched_specie_6": specie_selector},
262
+ {"matched_specie_7": specie_selector},
263
+ {"matched_specie_8": specie_selector},
264
+ {"matched_specie_9": specie_selector},
265
+ {"matched_specie_10": specie_selector}
266
+ ]
267
+ }
268
+
269
+ embedding = OpenAIEmbeddings()
270
+ vectordb = Chroma(persist_directory=database_persistent_directory,
271
+ embedding_function=embedding)
272
+
273
+ # print("got updated retriever without metadata filtering")
274
+ retriever = vectordb.as_retriever(search_kwargs={'k':retriever_k_value, 'filter': filter})
275
+
276
+ memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
277
+ qa_prompt=get_prompt_with_vetted_info_from_specie_name(specie_selector, application_mode)
278
+ qa_chain = ConversationalRetrievalChain.from_llm(
279
+ chosen_llm, retriever, memory=memory, verbose=False, return_source_documents=True,
280
+ combine_docs_chain_kwargs={'prompt': qa_prompt}
281
+ )
282
+
283
+ return qa_chain
284
+ # result = qa_chain.invoke({"question": "where are stalk borer eggs laid?"})
285
+ # print("Got the first LLM task working: ", result)
286
+
287
+
288
+ #Application Interface:
289
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
290
+ with gr.Row():
291
+ with gr.Column(scale=1):
292
+ gr.Markdown(
293
+ """
294
+ ![Logo](file/logo1.png)
295
+ """
296
+ )
297
+ with gr.Column(scale=1):
298
+ gr.Markdown(
299
+ """
300
+ ![Logo](file/logo2.png)
301
+ """
302
+ )
303
+
304
+ # Configure UI layout
305
+ chatbot = gr.Chatbot(height=600, label="AgLLM")
306
+ with gr.Row():
307
+ with gr.Column(scale=1):
308
+ with gr.Row():
309
+ domain_name = gr.Dropdown(
310
+ list(["Insects", "Weeds"]),
311
+ value="Insects",
312
+ label="Domain",
313
+ info="Select Domain",
314
+ interactive=True,
315
+ scale=1,
316
+ visible=True
317
+ )
318
+
319
+ # Model selection
320
+ specie_selector = gr.Dropdown(
321
+ species_list_insects,
322
+ value=species_list_insects[0],
323
+ label="Species",
324
+ info="Select the Species",
325
+ interactive=True,
326
+ scale=1,
327
+ visible=True
328
+ )
329
+ with gr.Row():
330
+ model_name = gr.Dropdown(
331
+ list(["GPT-4", "GPT-3.5", "Llama-3 70B", "Llama-3 8B", "Gemini-1.5 Pro", "Claude 3 Opus"]),
332
+ value="Llama-3 70B",
333
+ label="LLM",
334
+ info="Select the LLM",
335
+ interactive=True,
336
+ scale=1,
337
+ visible=True
338
+ )
339
+ application_mode = gr.Dropdown(
340
+ list(["Farmer", "Researcher"]),
341
+ value="Researcher",
342
+ label="Mode",
343
+ info="Select the Mode",
344
+ interactive=True,
345
+ scale=1,
346
+ visible=True
347
+ )
348
+
349
+
350
+ with gr.Column(scale=2):
351
+ # User input prompt text field
352
+ user_prompt_message = gr.Textbox(placeholder="Please add user prompt here", label="User prompt")
353
+ with gr.Row():
354
+ # clear = gr.Button("Clear Conversation", scale=2)
355
+ submitBtn = gr.Button("Submit", scale=8)
356
+
357
+ state = gr.State([])
358
+ qa_chain_state = gr.State(value=None)
359
+
360
+ # Handle user message
361
+ def user(user_prompt_message, history):
362
+ # print("HISTORY IS: ", history) # TODO: REMOVE IT LATER
363
+ if user_prompt_message != "":
364
+ return history + [[user_prompt_message, None]]
365
+ else:
366
+ return history + [["Invalid prompts - user prompt cannot be empty", None]]
367
+
368
+ # Chatbot logic for configuration, sending the prompts, rendering the streamed back generations, etc.
369
+ def bot(model_name, application_mode, user_prompt_message, history, messages_history, qa_chain, domain_name):
370
+ if qa_chain == None:
371
+ qa_chain=init_qa_chain(species_list_insects[0], application_mode, model_name, domain_name)
372
+
373
+ dialog = []
374
+ bot_message = ""
375
+ history[-1][1] = "" # Placeholder for the answer
376
+
377
+ dialog = [
378
+ {"role": "user", "content": user_prompt_message},
379
+ ]
380
+ messages_history += dialog
381
+
382
+ # Queue for streamed character rendering
383
+ q = Queue()
384
+
385
+ # Async task for streamed chain results wired to callbacks we previously defined, so we don't block the UI
386
+
387
+ def task(user_prompt_message):
388
+ result = qa_chain.invoke({"question": user_prompt_message})
389
+ answer = result["answer"]
390
+ source_documents = result.get("source_documents", [])
391
+ print("SOURCE DOCUMENTS: ", source_documents)
392
+
393
+ try:
394
+ answer_start = answer.find("Answer:")
395
+ source_start = answer.find("Source:")
396
+
397
+ if answer_start != -1 and source_start != -1:
398
+ model_answer = answer[answer_start + len("Answer:"):source_start].strip()
399
+ source = answer[source_start + len("Source:"):].strip()
400
+
401
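+ # When the model reports that it used the retrieved context, swap the generic source line for a full ISU Extension citation that matches the selected domain.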
+ if "Based on provided information" in source:
402
+ # Extract the most relevant source document
403
+ if source_documents:
404
+ doc = source_documents[0]
405
+ species_name = doc.metadata.get('matched_specie_0', 'Unspecified species')
406
+
407
+ if domain_name == "Insects":
408
+ formatted_source = f'Iowa State University Extension and Outreach. "Field Crop Insects." Iowa State University Extension Store, June 26, 2023. https://store.extension.iastate.edu/product/13725. Information about {species_name}.'
409
+ elif domain_name == "Weeds":
410
+ formatted_source = f'Iowa State University Extension and Outreach. "Weed Identification Field Guide 2nd Edition." Iowa State University Extension Store, August 2015. https://store.extension.iastate.edu/product/13358. Information about {species_name}.'
411
+ else:
412
+ formatted_source = f"Based on provided information about {species_name}, but domain is unspecified."
413
+ else:
414
+ formatted_source = "Based on provided information, but source details are unavailable."
415
+ else:
416
+ formatted_source = source
417
+
418
+ formatted_response = f"Answer:\n{model_answer}\n\nSource:\n{formatted_source}"
419
+ else:
420
+ formatted_response = answer
421
+ except Exception as e:
422
+ print(f"Error parsing output: {e}")
423
+ formatted_response = answer
424
+
425
+ return formatted_response
426
+
427
+ history[-1][1] = task(user_prompt_message)
428
+ return [history, messages_history]
429
+
430
+ # Initialize the chat history with default system message
431
+ def init_history(messages_history):
432
+ messages_history = []
433
+ messages_history += [system_message]
434
+ return messages_history
435
+
436
+ # Clean up the user input text field
437
+ def input_cleanup():
438
+ return ""
439
+
440
+ def init_qa_chain(specie_selector, application_mode, model_name, domain_name):
441
+ if domain_name=="Insects":
442
+ qa_chain = initialize_qa_chain(specie_selector, application_mode, model_name, default_persist_directory_insects)
443
+ elif domain_name=="Weeds":
444
+ qa_chain = initialize_qa_chain(specie_selector, application_mode, model_name, default_persist_directory_weeds)
445
+ else:
446
+ print("No Appropriate Chain Selected")
447
+ return qa_chain
448
+
449
+ specie_selector.change(
450
+ init_qa_chain,
451
+ inputs=[specie_selector, application_mode,model_name, domain_name ],
452
+ outputs=[qa_chain_state]
453
+ )
454
+ model_name.change(
455
+ init_qa_chain,
456
+ inputs=[specie_selector, application_mode,model_name, domain_name ],
457
+ outputs=[qa_chain_state]
458
+ )
459
+
460
+ #####
461
+ def update_species_list(domain):
462
+ if domain == "Insects":
463
+ return gr.Dropdown( species_list_insects, value=species_list_insects[0], label="Species", info="Select the Species", interactive=True, scale=1, visible=True )
464
+ elif domain == "Weeds":
465
+ return gr.Dropdown( species_list_weeds, value=species_list_weeds[0], label="Species", info="Select the Species", interactive=True, scale=1, visible=True )
466
+
467
+ domain_name.change(
468
+ update_species_list,
469
+ inputs=[domain_name],
470
+ outputs=[specie_selector]
471
+ )
472
+
473
+ # When the user clicks Enter and the user message is submitted
474
+ user_prompt_message.submit(
475
+ user,
476
+ [user_prompt_message, chatbot],
477
+ [chatbot],
478
+ queue=False
479
+ ).then(
480
+ bot,
481
+ [model_name, application_mode, user_prompt_message, chatbot, state, qa_chain_state, domain_name],
482
+ [chatbot, state]
483
+ ).then(input_cleanup,
484
+ [],
485
+ [user_prompt_message],
486
+ queue=False
487
+ )
488
+
489
+ # When the user clicks the submit button
490
+ submitBtn.click(
491
+ user,
492
+ [user_prompt_message, chatbot],
493
+ [chatbot],
494
+ queue=False
495
+ ).then(
496
+ bot,
497
+ [model_name, application_mode, user_prompt_message, chatbot, state, qa_chain_state, domain_name],
498
+ [chatbot, state]
499
+ ).then(
500
+ input_cleanup,
501
+ [],
502
+ [user_prompt_message],
503
+ queue=False
504
+ )
505
+
506
+ # When the user clicks the clear button
507
+ # clear.click(lambda: None, None, chatbot, queue=False).success(init_history, [state], [state])
508
+ if __name__ == "__main__":
509
+ # demo.launch()
510
+ demo.queue().launch(allowed_paths=["/"], share=False, show_error=True)
app_database_prep.py ADDED
@@ -0,0 +1,234 @@
1
+ import os
2
+ import shutil
3
+
4
+ from langchain.document_loaders import PyPDFDirectoryLoader
5
+ import pandas as pd
6
+ import langchain
7
+ from queue import Queue
8
+ from typing import Any, List
9
+ from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
10
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
11
+ from langchain.schema import LLMResult
12
+ from langchain.embeddings import HuggingFaceEmbeddings
13
+ from langchain.vectorstores import FAISS
14
+ from langchain.prompts.prompt import PromptTemplate
15
+ from anyio.from_thread import start_blocking_portal #For model callback streaming
16
+
17
+ langchain.debug=True # TODO: DOUBLE CHECK
18
+
19
+ system_message = {"role": "system", "content": "You are a helpful assistant."} # TODO: double check how this plays out later.
20
+ import os
21
+ from dotenv import load_dotenv
22
+
23
+ import streamlit as st
24
+
25
+ from langchain.document_loaders import PyPDFLoader
26
+ from langchain.text_splitter import CharacterTextSplitter
27
+ from langchain.embeddings import OpenAIEmbeddings
28
+ from langchain.chains.question_answering import load_qa_chain
29
+ from langchain.chat_models import ChatOpenAI
30
+ from langchain.vectorstores import Chroma
31
+ import chromadb
32
+
33
+
34
+ ## added information in metadata:
35
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
36
+ from langchain.llms import OpenAI
37
+ from langchain.chains import RetrievalQA
38
+ from langchain.document_loaders import TextLoader
39
+ from langchain.document_loaders import DirectoryLoader
40
+ from langchain_community.document_loaders import PyMuPDFLoader
41
+ from langchain.schema import Document
42
+
43
+
44
+ # Function to process a sheet from the Excel file
45
+ def process_excel_sheet(
46
+ excel_path: str,
47
+ sheet_name: str,
48
+ region: str,
49
+ splitter: RecursiveCharacterTextSplitter
50
+ ) -> List[Document]:
51
+ """Loads data from an Excel sheet, creates Documents, splits them, and adds metadata."""
52
+ print(f"--- Processing Excel Sheet: {sheet_name} (Region: {region}) ---")
53
+ try:
54
+ df = pd.read_excel(excel_path, sheet_name=sheet_name)
55
+ print(f"Excel Data Head ({sheet_name}):\\n", df.head())
56
+ except Exception as e:
57
+ print(f"Error loading sheet '{sheet_name}' from {excel_path}: {e}")
58
+ return []
59
+
60
+ initial_documents = []
61
+ for index, row in df.iterrows():
62
+ ipm_info = str(row['IPM Info']) if pd.notna(row['IPM Info']) else ""
63
+ # Check if essential columns exist and are not empty (removed accuracy check)
64
+ if pd.isna(row['Common Name']) or pd.isna(row['Species']):
65
+ print(f"Skipping row {index+2} in sheet '{sheet_name}' due to missing essential data (Common Name or Species).")
66
+ continue
67
+
68
+ doc = Document(
69
+ page_content=ipm_info,
70
+ metadata={
71
+ "source": f"{excel_path}#sheet={sheet_name}#row={index+2}",
72
+ "common_name": row['Common Name'],
73
+ "species": row['Species'],
74
+ "matched_specie_0": row['Species'],
75
+ "region": region
76
+ }
77
+ )
78
+ initial_documents.append(doc)
79
+
80
+ if initial_documents:
81
+ print(f"First Document from {sheet_name} (before splitting):\\n", initial_documents[0])
82
+ else:
83
+ print(f"No documents created from sheet: {sheet_name}")
84
+ return [] # Return empty list if no documents were created
85
+
86
+ split_documents = []
87
+ for doc in initial_documents:
88
+ splits = splitter.split_documents([doc])
89
+ for i, split_doc in enumerate(splits, start=1):
90
+ metadata = split_doc.metadata.copy()
91
+ metadata["source"] = f"{metadata['source']}#chunk{i}"
92
+ split_doc.metadata = metadata
93
+ split_documents.append(split_doc)
94
+
95
+ if split_documents:
96
+ print(f"First Document chunk from {sheet_name}:\\n", split_documents[0])
97
+
98
+ print(f"Finished processing sheet: {sheet_name}. Found {len(split_documents)} chunks.")
99
+ print("---------------------------------------------------")
100
+ return split_documents
101
+
102
+ # --- Main Script Logic ---
103
+
104
+ # loader = DirectoryLoader('./agllm-data/', glob="./*.pdf", loader_cls=PyMuPDFLoader)
105
+ # loader = DirectoryLoader('/u/marshad/data/agllm-data/', glob='**/*.pdf', loader_cls=PyMuPDFLoader)
106
+ data_domain_identifier="agllm-data-isu-field-insects-all-species"
107
+ persist_directory = f'vector-databases-deployed/db5-{data_domain_identifier}' # was full
108
+ loader = DirectoryLoader(f'agllm-data/{data_domain_identifier}', glob='**/*.pdf', loader_cls=PyMuPDFLoader)#,# was full, loader_kwargs={'chunk_size':512})
109
+ chunk_size_input=512
110
+ metadata_raw = pd.read_csv(f"./agllm-data/{data_domain_identifier}/matched_species_results_v2.csv")
111
+ documents = loader.load()
112
+
113
+ ## Load Excel File Path (Define once)
114
+ excel_file_path = "agllm-data/PestID Species.xlsx"
115
+
116
+
117
+ ## Process PDF documents and add metadata
118
+ print("--- Processing PDF Documents ---")
119
+ pdf_documents_for_splitting = [] # Prepare list to hold docs with added metadata
120
+ for doc in documents:
121
+ # Add region for PDF docs
122
+ doc.metadata["region"] = "United States"
123
+
124
+ # Add species metadata (existing logic)
125
+ file_name_associated_with_this_doc = doc.metadata["source"].split('/')[-1]
126
+ matching_species_for_this_file_name = metadata_raw[metadata_raw["File Name"].str.lower() == file_name_associated_with_this_doc.lower()]["Species"]
127
+ # Ensure matching_species_for_this_file_name is iterable and not empty
128
+ if not matching_species_for_this_file_name.empty:
129
+ for specie_index in range(len(matching_species_for_this_file_name)):
130
+ # Check if specie_index is within bounds (although range should handle this)
131
+ if specie_index < len(matching_species_for_this_file_name):
132
+ specie_name = matching_species_for_this_file_name.iloc[specie_index]
133
+ doc.metadata["matched_specie_" + str(specie_index)] = specie_name
134
+ else:
135
+ # This case should ideally not happen with range(len(...))
136
+ print(f"Warning: Specie index {specie_index} out of bounds for file {file_name_associated_with_this_doc}")
137
+ else:
138
+ print(f"Warning: No matching species found in CSV for PDF: {file_name_associated_with_this_doc}")
139
+
140
+ pdf_documents_for_splitting.append(doc) # Add modified doc to new list
141
+
142
+
143
+ # Initialize Text Splitter
144
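+ # 512-character chunks with a 10-character overlap; the same splitter is reused for the PDF pages and the Excel-derived documents.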
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size_input, chunk_overlap=10)
145
+
146
+
147
+ # Split PDF documents
148
+ pdf_splitted_documents = []
149
+ for doc in pdf_documents_for_splitting: # Use the list with added metadata
150
+ splits = text_splitter.split_documents([doc])
151
+ for i, split_doc in enumerate(splits, start=1):
152
+ metadata = split_doc.metadata.copy()
153
+ # Update source for PDF chunks (existing logic)
154
+ source_base = metadata.get('source', 'unknown_source')
155
+ page_num = metadata.get('page', 'unknown_page')
156
+ metadata["source"] = f"{source_base}#page{page_num}#chunk{i}"
157
+ # Remove the raw page number if desired, as it's now in the source string
158
+ # metadata.pop('page', None)
159
+ split_doc.metadata = metadata
160
+ pdf_splitted_documents.append(split_doc)
161
+
162
+ print("First PDF Document chunk:\\n", pdf_splitted_documents[0] if pdf_splitted_documents else "No PDF documents processed")
163
+ print(f"Count after PDF processing: {len(pdf_splitted_documents)}")
164
+ print("---------------------------------------------------")
165
+
166
+
167
+ # Process Excel Sheets using the function
168
+ india_splitted_documents = process_excel_sheet(
169
+ excel_path=excel_file_path,
170
+ sheet_name="India",
171
+ region="India",
172
+ splitter=text_splitter
173
+ )
174
+
175
+ africa_splitted_documents = process_excel_sheet(
176
+ excel_path=excel_file_path,
177
+ sheet_name="Africa",
178
+ region="Africa",
179
+ splitter=text_splitter
180
+ )
181
+
182
+
183
+ # Combine lists from all sources
184
+ splitted_documents = pdf_splitted_documents + india_splitted_documents + africa_splitted_documents
185
+
186
+
187
+ # print(splitted_documents[0]) # Original print statement - commented out as we print chunks above
188
+ print("=== Combined Processing Done ===") # Adjusted print statement
189
+ print(f"Total documents after combining PDF, India, and Africa sources: {len(splitted_documents)}")
190
+ print("=============================")
191
+
192
+
193
+
194
+ # ONLY FOR THE FIRST TIME
195
+
196
+ # Check if the persist directory exists and delete it to ensure a fresh start
197
+ if os.path.exists(persist_directory):
198
+ print(f"Deleting existing vector database directory: {persist_directory}")
199
+ shutil.rmtree(persist_directory)
200
+ print(f"Directory deleted.")
201
+ else:
202
+ print(f"Vector database directory not found, creating a new one: {persist_directory}")
203
+
204
+ embedding = OpenAIEmbeddings()
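+ # Chroma.from_documents embeds every chunk with the OpenAI embedding model and builds the index under persist_directory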
205
+ vectordb = Chroma.from_documents(documents=splitted_documents,
206
+ embedding=embedding,
207
+ persist_directory=persist_directory)
208
+
209
+ # persist the db to disk
210
+ vectordb.persist()
211
+ vectordb = None
212
+
213
+
214
+ # Now we can load the persisted database from disk, and use it as normal.
215
+
216
+ vectordb = Chroma(persist_directory=persist_directory,
217
+ embedding_function=embedding)
218
+
219
+
220
+ print(vectordb.get())
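+ # vectordb.get() with no arguments returns the ids, documents, and metadatas of everything stored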
221
+
222
+ # Quick retrieval test against the freshly built database:
223
+
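+ # the $or filter should return only chunks whose metadata lists the selected species in matched_specie_0, _1, or _2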
224
+ specie_selector="Aphis spiraecola"
225
+ filter = {
226
+ "$or": [
227
+ {"matched_specie_0": specie_selector},
228
+ {"matched_specie_1": specie_selector},
229
+ {"matched_specie_2": specie_selector},
230
+ ]
231
+ }
232
+ answer = vectordb.as_retriever(search_kwargs={'k':10, 'filter': filter}).get_relevant_documents(
233
+ "anything else.?")
234
+ print(answer)
outdated-files/agllm-data.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e644edb3d6ae6153f2411ff768ba5047314a9fdd92f97b3d4fa25b8e28e98a5
3
+ size 4658799
outdated-files/agllm_with_evaluation.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
outdated-files/app-basic.py ADDED
@@ -0,0 +1,19 @@
1
+ # import gradio as gr
2
+
3
+ # def greet(name):
4
+ # return "Hello " + name + "!"
5
+
6
+ # demo = gr.Interface(fn=greet, inputs="textbox", outputs="textbox")
7
+
8
+ # demo.launch() # Share your demo with just 1 extra parameter
9
+
10
+
11
+ import gradio as gr
12
+
13
+ def greet(name):
14
+ return "Hello " + name + "!"
15
+
16
+ demo = gr.Interface(fn=greet, inputs="textbox", outputs="textbox")
17
+
18
+ if __name__ == "__main__":
19
+ demo.launch()
outdated-files/app-old.py ADDED
@@ -0,0 +1,334 @@
1
+ import os
2
+ # https://stackoverflow.com/questions/76175046/how-to-add-prompt-to-langchain-conversationalretrievalchain-chat-over-docs-with
3
+ # again from:
4
+ # https://python.langchain.com/docs/integrations/providers/vectara/vectara_chat
5
+ from langchain.document_loaders import PyPDFDirectoryLoader
6
+ import pandas as pd
7
+ import langchain
8
+ from queue import Queue
9
+ from typing import Any
10
+ from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
11
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
12
+ from langchain.schema import LLMResult
13
+ from langchain.embeddings import HuggingFaceEmbeddings
14
+ from langchain.vectorstores import FAISS
15
+ from langchain.prompts.prompt import PromptTemplate
16
+ from anyio.from_thread import start_blocking_portal #For model callback streaming
17
+
18
+ from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
19
+ import os
20
+ from dotenv import load_dotenv
21
+
22
+ import streamlit as st
23
+
24
+ from langchain.document_loaders import PyPDFLoader
25
+ from langchain.text_splitter import CharacterTextSplitter
26
+ from langchain.embeddings import OpenAIEmbeddings
27
+ from langchain.chains.question_answering import load_qa_chain
28
+ from langchain.chat_models import ChatOpenAI
29
+ from langchain.vectorstores import Chroma
30
+ import chromadb
31
+
32
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
33
+ from langchain.llms import OpenAI
34
+ from langchain.chains import RetrievalQA
35
+ from langchain.document_loaders import TextLoader
36
+ from langchain.document_loaders import DirectoryLoader
37
+ from langchain_community.document_loaders import PyMuPDFLoader
38
+ from langchain.schema import Document
39
+
40
+ from langchain.memory import ConversationBufferMemory
41
+
42
+ from langchain.chains import ConversationalRetrievalChain
43
+ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
44
+ from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
45
+ import gradio as gr
46
+ from langchain.memory import ConversationBufferMemory
47
+ from langchain.chains import ConversationalRetrievalChain
48
+
49
+ persist_directory = '/projects/bcjp/marshad/agllm/db5'
50
+ csv_filepath1 = "./agllm-data/corrected/Corrected_supplemented-insect_data-2500-sorted.xlsx"
51
+ csv_filepath2 = "./agllm-data/corrected/Corrected_supplemented-insect_data-remaining.xlsx"
52
+ model_name=4
53
+ max_tokens=400
54
+ system_message = {"role": "system", "content": "You are a helpful assistant."} # TODO: double check how this plays out later.
55
+ langchain.debug=True # TODO: DOUBLE CHECK
56
+ retriever_k_value=2
57
+ embedding = OpenAIEmbeddings()
58
+
59
+ ######### todo: skipping the first step
60
+
61
+ embedding = OpenAIEmbeddings()
62
+ vectordb = Chroma(persist_directory=persist_directory,
63
+ embedding_function=embedding)
64
+
65
+ retriever = vectordb.as_retriever()
66
+
67
+ print(# Single example
68
+ vectordb.as_retriever(k=2, search_kwargs={"filter": {"matched_specie_0": "Hypagyrtis unipunctata"}, 'k':1}).get_relevant_documents(
69
+ "Checking if retriever is correctly initalized?"
70
+ ))
71
+
72
+ columns = ['species', 'common name', 'order', 'family',
73
+ 'genus', 'Updated role in ecosystem', 'Proof',
74
+ 'ipm strategies', 'size of insect', 'geographical spread',
75
+ 'life cycle specifics', 'pest for plant species', 'species status',
76
+ 'distribution area', 'appearance', 'identification']
77
+
78
+ df1 = pd.read_excel(csv_filepath1, usecols=columns)
79
+ df2 = pd.read_excel(csv_filepath2, usecols=columns)
80
+
81
+ all_insects_data = pd.concat([df1, df2], ignore_index=True)
82
+
83
+ def get_prompt_with_vetted_info_from_specie_name(search_for_specie, mode):
84
+
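+ # helper: flatten the vetted spreadsheet rows for this species into "column: value" lines for the system prompt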
85
+ def read_and_format_filtered_csv_better(insect_specie):
86
+ filtered_data = all_insects_data[all_insects_data['species'] == insect_specie]
87
+ formatted_data = ""
88
+ # Format the filtered data
89
+ for index, row in filtered_data.iterrows():
90
+ row_data = [f"{col}: {row[col]}" for col in filtered_data.columns]
91
+ formatted_row = "\n".join(row_data)
92
+ formatted_data += f"{formatted_row}\n"
93
+
94
+ return formatted_data
95
+
96
+ # Use the path to your CSV file here
97
+
98
+ vetted_info=read_and_format_filtered_csv_better(search_for_specie)
99
+ if mode=="user":
100
+ language_constraint="The language should be acustomed to the end user. This question is likely asked by a farmer. So, answer things in their language. Bur for referencing information, you can use the original content. This is only for the main answer to be provided by you."
101
+ elif mode=="researcher":
102
+ language_constraint="The language should be acustomed to a researcher. This question is likely asked by an academic researcher. So you can use all the technical terms freely. And for referencing information, you can use the original content. This is only for the main answer to be provided by you."
103
+ else:
104
+ print("No valid model provided. Exiting")
105
+ exit()
106
+ general_system_template = """
107
+ In every question you are provided information about the insect. Two types of information are given: first, Vetted Information (which is the same in every question) and second, some context retrieved from external documents about an insect species, together with a question from the user. Answer the question according to these two types of information.
108
+ ----
109
+ Vetted info is as follows:
110
+ {vetted_info}
111
+ ----
112
+ The context retrieved from documents for this particular question is as follows:
113
+ {context}
114
+ ----
115
+ Additional Instruction:
116
+ 1. Reference Constraint
117
+ At the end of each answer, provide the source/reference for the given data in the following format:
118
+ \n\n[enter two new lines before writing below] References:
119
+ Vetted Information Used: Write what was used from the document for coming up with the answer above. Write exact part of lines. If nothing, write 'Nothing'.
120
+ Documents Used: Write what was used from the document for coming up with the answer above. If nothing, write 'Nothing'. Write exact part of lines and document used.
121
+ 2. Information Constraint:
122
+ Only answer the question from the information provided; otherwise say you don't know. You have to answer in 150 words including references. Prioritize information in the documents/context over vetted information. And first mention the warnings/things to be careful about.
123
+ 3. Language constraint:
124
+ {language_constraint}
125
+
126
+ ----
127
+ """.format(vetted_info=vetted_info, language_constraint=language_constraint,context="{context}", )
128
+
129
+
130
+ general_user_template = "Question:```{question}```"
131
+ messages_formatted = [
132
+ SystemMessagePromptTemplate.from_template(general_system_template),
133
+ # HumanMessagePromptTemplate.from_template(general_system_template),
134
+ HumanMessagePromptTemplate.from_template(general_user_template)
135
+ ]
136
+ qa_prompt = ChatPromptTemplate.from_messages( messages_formatted )
137
+ print(qa_prompt)
138
+ return qa_prompt
139
+ qa_prompt=get_prompt_with_vetted_info_from_specie_name("Papaipema nebris", "researcher")
140
+ print("First prompt is intialized as: " , qa_prompt, "\n\n")
141
+
142
+
143
+ memory = ConversationBufferMemory(memory_key="chat_history",output_key='answer', return_messages=True) # https://github.com/langchain-ai/langchain/issues/9394#issuecomment-1683538834
144
+
145
+
146
+ if model_name==4:
147
+ llm_openai = ChatOpenAI(model_name="gpt-4-1106-preview" , temperature=0, max_tokens=max_tokens) # TODO: NEW MODEL VERSION AVAILABLE
148
+ else:
149
+ llm_openai = ChatOpenAI(model_name="gpt-3.5-turbo-0125" , temperature=0, max_tokens=max_tokens)
150
+
151
+ specie_selector="Papaipema nebris"
152
+ filter = {
153
+ "$or": [
154
+ {"matched_specie_0": specie_selector},
155
+ {"matched_specie_1": specie_selector},
156
+ {"matched_specie_2": specie_selector},
157
+ ]
158
+ }
159
+ retriever = vectordb.as_retriever(search_kwargs={'k':retriever_k_value, 'filter': filter})
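+ # only the top retriever_k_value chunks tagged with the selected species are passed to the chain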
160
+
161
+ qa_chain = ConversationalRetrievalChain.from_llm(
162
+ llm_openai, retriever, memory=memory, verbose=False, return_source_documents=True,\
163
+ combine_docs_chain_kwargs={'prompt': qa_prompt}
164
+
165
+ )
166
+ #
167
+
168
+ def initialize_qa_chain(specie_selector, application_mode):
169
+
170
+ filter = {
171
+ "$or": [
172
+ {"matched_specie_0": specie_selector},
173
+ {"matched_specie_1": specie_selector},
174
+ {"matched_specie_2": specie_selector},
175
+ ]
176
+ }
177
+ retriever = vectordb.as_retriever(search_kwargs={'k':2, 'filter': filter})
178
+
179
+ memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
180
+ qa_prompt=get_prompt_with_vetted_info_from_specie_name(specie_selector, application_mode)
181
+ qa_chain = ConversationalRetrievalChain.from_llm(
182
+ llm_openai, retriever, memory=memory, verbose=False, return_source_documents=True,
183
+ combine_docs_chain_kwargs={'prompt': qa_prompt}
184
+ )
185
+
186
+ return qa_chain
187
+ result = qa_chain.invoke({"question": "where are stalk borer eggs laid?"})
188
+ print("Got the first LLM task working: ", result)
189
+
190
+
191
+ #Application Interface:
192
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
193
+ with gr.Row():
194
+ with gr.Column(scale=1):
195
+ gr.Markdown(
196
+ """
197
+ ![Logo](file/logo1.png)
198
+ """
199
+ )
200
+ with gr.Column(scale=1):
201
+ gr.Markdown(
202
+ """
203
+ ![Logo](file/logo2.png)
204
+ """
205
+ )
206
+
207
+ # Configure UI layout
208
+ chatbot = gr.Chatbot(height=600, label="AgLLM22")
209
+ with gr.Row():
210
+ with gr.Column(scale=1):
211
+ with gr.Row():
212
+ # Model selection
213
+ specie_selector = gr.Dropdown(
214
+ list(["Papaipema nebris", "Nomophila nearctica"]),
215
+ value="Papaipema nebris",
216
+ label="Species",
217
+ info="Select the Species",
218
+ interactive=True,
219
+ scale=2,
220
+ visible=True
221
+ )
222
+ with gr.Row():
223
+ application_mode = gr.Dropdown(
224
+ list(["user", "researcher"]),
225
+ value="researcher",
226
+ label="Mode",
227
+ info="Select the Mode",
228
+ interactive=True,
229
+ scale=2,
230
+ visible=True
231
+ )
232
+ with gr.Row():
233
+ pass
234
+ with gr.Column(scale=2):
235
+ # User input prompt text field
236
+ user_prompt_message = gr.Textbox(placeholder="Please add user prompt here", label="User prompt")
237
+ with gr.Row():
238
+ # clear = gr.Button("Clear Conversation", scale=2)
239
+ submitBtn = gr.Button("Submit", scale=8)
240
+
241
+ state = gr.State([])
242
+ qa_chain_state = gr.State(value=None)
243
+
244
+ # Handle user message
245
+ def user(user_prompt_message, history):
246
+ print("HISTORY IS: ", history) # TODO: REMOVE IT LATER
247
+ if user_prompt_message != "":
248
+ return history + [[user_prompt_message, None]]
249
+ else:
250
+ return history + [["Invalid prompts - user prompt cannot be empty", None]]
251
+
252
+ # Chatbot logic for configuration, sending the prompts, rendering the streamed back generations, etc.
253
+ def bot(application_mode, user_prompt_message, history, messages_history, qa_chain):
254
+ if qa_chain is None:
255
+ qa_chain=init_qa_chain("Papaipema nebris", application_mode)
256
+
257
+ dialog = []
258
+ bot_message = ""
259
+ history[-1][1] = "" # Placeholder for the answer
260
+
261
+ dialog = [
262
+ {"role": "user", "content": user_prompt_message},
263
+ ]
264
+ messages_history += dialog
265
+
266
+ # Queue for streamed character rendering
267
+ q = Queue()
268
+
269
+ # Async task for streamed chain results wired to callbacks we previously defined, so we don't block the UI
270
+ def task(user_prompt_message):
271
+ ret = qa_chain.invoke({"question": user_prompt_message})["answer"]
272
+ return ret
273
+
274
+ history[-1][1] = task(user_prompt_message)
275
+ return [history, messages_history]
276
+
277
+ # Initialize the chat history with default system message
278
+ def init_history(messages_history):
279
+ messages_history = []
280
+ messages_history += [system_message]
281
+ return messages_history
282
+
283
+ # Clean up the user input text field
284
+ def input_cleanup():
285
+ return ""
286
+
287
+ def init_qa_chain(specie_selector, application_mode):
288
+ qa_chain = initialize_qa_chain(specie_selector, application_mode)
289
+ return qa_chain
290
+
291
+ specie_selector.change(
292
+ init_qa_chain,
293
+ inputs=[specie_selector, application_mode],
294
+ outputs=[qa_chain_state]
295
+ )
296
+
297
+ # When the user clicks Enter and the user message is submitted
298
+ user_prompt_message.submit(
299
+ user,
300
+ [user_prompt_message, chatbot],
301
+ [chatbot],
302
+ queue=False
303
+ ).then(
304
+ bot,
305
+ [application_mode, user_prompt_message, chatbot, state, qa_chain_state],
306
+ [chatbot, state]
307
+ ).then(input_cleanup,
308
+ [],
309
+ [user_prompt_message],
310
+ queue=False
311
+ )
312
+
313
+ # When the user clicks the submit button
314
+ submitBtn.click(
315
+ user,
316
+ [user_prompt_message, chatbot],
317
+ [chatbot],
318
+ queue=False
319
+ ).then(
320
+ bot,
321
+ [application_mode, user_prompt_message, chatbot, state, qa_chain_state],
322
+ [chatbot, state]
323
+ ).then(
324
+ input_cleanup,
325
+ [],
326
+ [user_prompt_message],
327
+ queue=False
328
+ )
329
+
330
+ # When the user clicks the clear button
331
+ # clear.click(lambda: None, None, chatbot, queue=False).success(init_history, [state], [state])
332
+ if __name__ == "__main__":
333
+ # demo.launch()
334
+ demo.queue().launch(allowed_paths=["/"], server_name="0.0.0.0", share=True, debug=True)
outdated-files/dd.txt ADDED
@@ -0,0 +1 @@
1
+ input_variables=['context', 'question'] messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=['context'], template=" \n In every question you are provided information about the insect. Two types of information are: First, Vetted Information (which is same in every questinon) and Second, some context from external documents about an insect specie and a question by the user. answer the question according to these two types of informations. \n ---- \n Vetted info is as follows:\n species: Papaipema nebris\ncommon name: stalk borer\norder: Lepidoptera\nfamily: Noctuidae\ngenus: Papaipema\nUpdated role in ecosystem: Pest\nProof: The stalk borer, Papaipema nebris, is correctly classified as a pest because it damages a variety of crops by boring into their stems. This can significantly reduce crop yields and impact agricultural production. Its larvae feed inside the stalks causing the plants to wilt and often die, which characterizes its detrimental role in the ecosystem predominantly affecting agricultural systems.\nipm strategies: IPM strategies for stalk borer include crop rotation, planting trap crops, timely planting to avoid peak egg-laying periods, mechanical control through cultivation, and chemical controls with insecticides if necessary.\nsize of insect: 25-35 mm\ngeographical spread: United States, Canada\nlife cycle specifics: Stalk borers undergo complete metamorphosis, including egg, larva, pupa, and adult stages. Eggs are laid on grasses in fall, hatch in spring, and the larvae bore into stems of a variety of host plants. Pupation takes place within the host plant or in the soil.\npest for plant species: Corn, sorghum, and other grasses\nspecies status: It is native to North America.\ndistribution area: Most widespread in the corn belt of the United States.\nappearance: The adult stalk borer is a medium-sized moth with a wingspan of about 25-35 mm with distinct, dark, and light gray or brown patterns on its wings.\nidentification: Eggs are white and difficult to find, larvae are purplish or pinkish with a distinct white stripe and darker line above this, and adults are gray or brown moths with a characteristic wing pattern.\n\n ----\n The context retrieved for documents about this particular question is a as follows:\n {context}\n ----\n Additional Instruction:\n 1. Reference Constraint\n At the end of each answer provide the source/reference for the given data in following format: \n \n\n[enter two new lines before writing below] References:\n Vetted Information Used: Write what was used from the document for coming up with the answer above. Write exact part of lines. If nothing, write 'Nothing'. \n Documents Used: Write what was used from the document for coming up with the answer above. If nothing, write 'Nothing'. Write exact part of lines and document used. \n 2. Information Constraint: \n Only answer the question from information provided otherwise say you dont know. You have to answer in 150 words including references. Prioritize information in documents/context over vetted information. And first mention the warnings/things to be careful about. \n 3. Language constraint:\n The language should be acustomed to a researcher. This question is likely asked by an academic researcher. So you can use all the technical terms freely. And for referencing information, you can use the original content. 
This is only for the main answer to be provided by you.\n\n ----\n ")), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], template='Question:```{question}```'))]
outdated-files/rag-evaluation (outdated).ipynb ADDED
@@ -0,0 +1,896 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "application/vnd.databricks.v1+cell": {
7
+ "cellMetadata": {},
8
+ "inputWidgets": {},
9
+ "nuid": "42084110-295b-493a-9b3e-5d8d29ff78b3",
10
+ "showTitle": false,
11
+ "title": ""
12
+ }
13
+ },
14
+ "source": [
15
+ "# LLM RAG Evaluation with MLflow Example Notebook\n",
16
+ "\n",
17
+ "In this notebook, we will demonstrate how to evaluate various a RAG system with MLflow."
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "raw",
22
+ "metadata": {},
23
+ "source": [
24
+ "<a href=\"https://raw.githubusercontent.com/mlflow/mlflow/master/docs/source/llms/llm-evaluate/notebooks/rag-evaluation.ipynb\" class=\"notebook-download-btn\"><i class=\"fas fa-download\"></i>Download this Notebook</a>"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "markdown",
29
+ "metadata": {},
30
+ "source": [
31
+ "We need to set our OpenAI API key.\n",
32
+ "\n",
33
+ "In order to set your private key safely, please be sure to either export your key through a command-line terminal for your current instance, or, for a permanent addition to all user-based sessions, configure your favored environment management configuration file (i.e., .bashrc, .zshrc) to have the following entry:\n",
34
+ "\n",
35
+ "`OPENAI_API_KEY=<your openai API key>`\n",
36
+ "\n",
37
+ "If using Azure OpenAI, you will instead need to set\n",
38
+ "\n",
39
+ "`OPENAI_API_TYPE=\"azure\"`\n",
40
+ "\n",
41
+ "`OPENAI_API_VERSION=<YYYY-MM-DD>`\n",
42
+ "\n",
43
+ "`OPENAI_API_KEY=<https://<>.<>.<>.com>`\n",
44
+ "\n",
45
+ "`OPENAI_API_DEPLOYMENT_NAME=<deployment name>`\n"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": 4,
51
+ "metadata": {
52
+ "application/vnd.databricks.v1+cell": {
53
+ "cellMetadata": {
54
+ "byteLimit": 2048000,
55
+ "rowLimit": 10000
56
+ },
57
+ "inputWidgets": {},
58
+ "nuid": "fb946228-62fb-4d68-9732-75935c9cb401",
59
+ "showTitle": false,
60
+ "title": ""
61
+ }
62
+ },
63
+ "outputs": [],
64
+ "source": [
65
+ "import pandas as pd\n",
66
+ "\n",
67
+ "import mlflow\n",
68
+ "import os\n",
69
+ "os.environ[\"OPENAI_API_KEY\"] =\"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\"\n"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "markdown",
74
+ "metadata": {
75
+ "application/vnd.databricks.v1+cell": {
76
+ "cellMetadata": {},
77
+ "inputWidgets": {},
78
+ "nuid": "273d1345-95d7-435a-a7b6-a5f3dbb3f073",
79
+ "showTitle": false,
80
+ "title": ""
81
+ }
82
+ },
83
+ "source": [
84
+ "## Create a RAG system\n",
85
+ "\n",
86
+ "Use Langchain and Chroma to create a RAG system that answers questions based on the MLflow documentation."
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": 13,
92
+ "metadata": {
93
+ "application/vnd.databricks.v1+cell": {
94
+ "cellMetadata": {
95
+ "byteLimit": 2048000,
96
+ "rowLimit": 10000
97
+ },
98
+ "inputWidgets": {},
99
+ "nuid": "2c28d0ad-f469-46ab-a2b4-c5e8db50a729",
100
+ "showTitle": false,
101
+ "title": ""
102
+ }
103
+ },
104
+ "outputs": [],
105
+ "source": [
106
+ "from langchain.chains import RetrievalQA\n",
107
+ "from langchain.document_loaders import WebBaseLoader\n",
108
+ "from langchain.embeddings.openai import OpenAIEmbeddings\n",
109
+ "from langchain.llms import OpenAI\n",
110
+ "from langchain.text_splitter import CharacterTextSplitter\n",
111
+ "from langchain.vectorstores import Chroma\n",
112
+ "from langchain.chat_models import ChatOpenAI\n"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": 14,
118
+ "metadata": {
119
+ "application/vnd.databricks.v1+cell": {
120
+ "cellMetadata": {
121
+ "byteLimit": 2048000,
122
+ "rowLimit": 10000
123
+ },
124
+ "inputWidgets": {},
125
+ "nuid": "83a7e77e-6717-472a-86dc-02e2c356ddef",
126
+ "showTitle": false,
127
+ "title": ""
128
+ }
129
+ },
130
+ "outputs": [],
131
+ "source": [
132
+ "loader = WebBaseLoader(\"https://mlflow.org/docs/latest/index.html\")\n",
133
+ "\n",
134
+ "documents = loader.load()\n",
135
+ "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
136
+ "texts = text_splitter.split_documents(documents)\n",
137
+ "\n",
138
+ "embeddings = OpenAIEmbeddings()\n",
139
+ "docsearch = Chroma.from_documents(texts, embeddings)\n",
140
+ "\n",
141
+ "qa = RetrievalQA.from_chain_type(\n",
142
+ " llm=ChatOpenAI(model_name=\"gpt-3.5-turbo-0125\" , temperature=0),\n",
143
+ " chain_type=\"stuff\",\n",
144
+ " retriever=docsearch.as_retriever(),\n",
145
+ " return_source_documents=True,\n",
146
+ ")"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "markdown",
151
+ "metadata": {
152
+ "application/vnd.databricks.v1+cell": {
153
+ "cellMetadata": {},
154
+ "inputWidgets": {},
155
+ "nuid": "fd70bcf6-7c44-44d3-9435-567b82611e1c",
156
+ "showTitle": false,
157
+ "title": ""
158
+ }
159
+ },
160
+ "source": [
161
+ "## Evaluate the RAG system using `mlflow.evaluate()`"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "markdown",
166
+ "metadata": {
167
+ "application/vnd.databricks.v1+cell": {
168
+ "cellMetadata": {},
169
+ "inputWidgets": {},
170
+ "nuid": "de1bc359-2e40-459c-bea4-bed35a117988",
171
+ "showTitle": false,
172
+ "title": ""
173
+ }
174
+ },
175
+ "source": [
176
+ "Create a simple function that runs each input through the RAG chain"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": 15,
182
+ "metadata": {
183
+ "application/vnd.databricks.v1+cell": {
184
+ "cellMetadata": {
185
+ "byteLimit": 2048000,
186
+ "rowLimit": 10000
187
+ },
188
+ "inputWidgets": {},
189
+ "nuid": "667ec809-2bb5-4170-9937-6804386b41ec",
190
+ "showTitle": false,
191
+ "title": ""
192
+ }
193
+ },
194
+ "outputs": [],
195
+ "source": [
196
+ "def model(input_df):\n",
197
+ " answer = []\n",
198
+ " for index, row in input_df.iterrows():\n",
199
+ " answer.append(qa(row[\"questions\"]))\n",
200
+ "\n",
201
+ " return answer"
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": null,
207
+ "metadata": {},
208
+ "outputs": [],
209
+ "source": []
210
+ },
211
+ {
212
+ "cell_type": "markdown",
213
+ "metadata": {
214
+ "application/vnd.databricks.v1+cell": {
215
+ "cellMetadata": {},
216
+ "inputWidgets": {},
217
+ "nuid": "d1064306-b7f3-4b3e-825c-4353d808f21d",
218
+ "showTitle": false,
219
+ "title": ""
220
+ }
221
+ },
222
+ "source": [
223
+ "Create an eval dataset"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": 16,
229
+ "metadata": {
230
+ "application/vnd.databricks.v1+cell": {
231
+ "cellMetadata": {
232
+ "byteLimit": 2048000,
233
+ "rowLimit": 10000
234
+ },
235
+ "inputWidgets": {},
236
+ "nuid": "a5481491-e4a9-42ea-8a3f-f527faffd04d",
237
+ "showTitle": false,
238
+ "title": ""
239
+ }
240
+ },
241
+ "outputs": [],
242
+ "source": [
243
+ "eval_df = pd.DataFrame(\n",
244
+ " {\n",
245
+ " \"questions\": [\n",
246
+ " \"What is MLflow?\",\n",
247
+ " \"How to run mlflow.evaluate()?\",\n",
248
+ " \"How to log_table()?\",\n",
249
+ " \"How to load_table()?\",\n",
250
+ " ],\n",
251
+ " }\n",
252
+ ")"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "markdown",
257
+ "metadata": {
258
+ "application/vnd.databricks.v1+cell": {
259
+ "cellMetadata": {},
260
+ "inputWidgets": {},
261
+ "nuid": "9c3c8023-8feb-427a-b36d-34cd1853a5dc",
262
+ "showTitle": false,
263
+ "title": ""
264
+ }
265
+ },
266
+ "source": [
267
+ "Create a faithfulness metric"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": 17,
273
+ "metadata": {
274
+ "application/vnd.databricks.v1+cell": {
275
+ "cellMetadata": {
276
+ "byteLimit": 2048000,
277
+ "rowLimit": 10000
278
+ },
279
+ "inputWidgets": {},
280
+ "nuid": "3882b940-9c25-41ce-a301-72d8c0c90aaa",
281
+ "showTitle": false,
282
+ "title": ""
283
+ }
284
+ },
285
+ "outputs": [
286
+ {
287
+ "name": "stdout",
288
+ "output_type": "stream",
289
+ "text": [
290
+ "EvaluationMetric(name=faithfulness, greater_is_better=True, long_name=faithfulness, version=v1, metric_details=\n",
291
+ "Task:\n",
292
+ "You must return the following fields in your response in two lines, one below the other:\n",
293
+ "score: Your numerical score for the model's faithfulness based on the rubric\n",
294
+ "justification: Your reasoning about the model's faithfulness score\n",
295
+ "\n",
296
+ "You are an impartial judge. You will be given an input that was sent to a machine\n",
297
+ "learning model, and you will be given an output that the model produced. You\n",
298
+ "may also be given additional information that was used by the model to generate the output.\n",
299
+ "\n",
300
+ "Your task is to determine a numerical score called faithfulness based on the input and output.\n",
301
+ "A definition of faithfulness and a grading rubric are provided below.\n",
302
+ "You must use the grading rubric to determine your score. You must also justify your score.\n",
303
+ "\n",
304
+ "Examples could be included below for reference. Make sure to use them as references and to\n",
305
+ "understand them before completing the task.\n",
306
+ "\n",
307
+ "Input:\n",
308
+ "{input}\n",
309
+ "\n",
310
+ "Output:\n",
311
+ "{output}\n",
312
+ "\n",
313
+ "{grading_context_columns}\n",
314
+ "\n",
315
+ "Metric definition:\n",
316
+ "Faithfulness is only evaluated with the provided output and provided context, please ignore the provided input entirely when scoring faithfulness. Faithfulness assesses how much of the provided output is factually consistent with the provided context. A higher score indicates that a higher proportion of claims present in the output can be derived from the provided context. Faithfulness does not consider how much extra information from the context is not present in the output.\n",
317
+ "\n",
318
+ "Grading rubric:\n",
319
+ "Faithfulness: Below are the details for different scores:\n",
320
+ "- Score 1: None of the claims in the output can be inferred from the provided context.\n",
321
+ "- Score 2: Some of the claims in the output can be inferred from the provided context, but the majority of the output is missing from, inconsistent with, or contradictory to the provided context.\n",
322
+ "- Score 3: Half or more of the claims in the output can be inferred from the provided context.\n",
323
+ "- Score 4: Most of the claims in the output can be inferred from the provided context, with very little information that is not directly supported by the provided context.\n",
324
+ "- Score 5: All of the claims in the output are directly supported by the provided context, demonstrating high faithfulness to the provided context.\n",
325
+ "\n",
326
+ "Examples:\n",
327
+ "\n",
328
+ "Example Input:\n",
329
+ "How do I disable MLflow autologging?\n",
330
+ "\n",
331
+ "Example Output:\n",
332
+ "mlflow.autolog(disable=True) will disable autologging for all functions. In Databricks, autologging is enabled by default. \n",
333
+ "\n",
334
+ "Additional information used by the model:\n",
335
+ "key: context\n",
336
+ "value:\n",
337
+ "mlflow.autolog(log_input_examples: bool = False, log_model_signatures: bool = True, log_models: bool = True, log_datasets: bool = True, disable: bool = False, exclusive: bool = False, disable_for_unsupported_versions: bool = False, silent: bool = False, extra_tags: Optional[Dict[str, str]] = None) → None[source] Enables (or disables) and configures autologging for all supported integrations. The parameters are passed to any autologging integrations that support them. See the tracking docs for a list of supported autologging integrations. Note that framework-specific configurations set at any point will take precedence over any configurations set by this function.\n",
338
+ "\n",
339
+ "Example score: 2\n",
340
+ "Example justification: The output provides a working solution, using the mlflow.autolog() function that is provided in the context.\n",
341
+ " \n",
342
+ "\n",
343
+ "Example Input:\n",
344
+ "How do I disable MLflow autologging?\n",
345
+ "\n",
346
+ "Example Output:\n",
347
+ "mlflow.autolog(disable=True) will disable autologging for all functions.\n",
348
+ "\n",
349
+ "Additional information used by the model:\n",
350
+ "key: context\n",
351
+ "value:\n",
352
+ "mlflow.autolog(log_input_examples: bool = False, log_model_signatures: bool = True, log_models: bool = True, log_datasets: bool = True, disable: bool = False, exclusive: bool = False, disable_for_unsupported_versions: bool = False, silent: bool = False, extra_tags: Optional[Dict[str, str]] = None) → None[source] Enables (or disables) and configures autologging for all supported integrations. The parameters are passed to any autologging integrations that support them. See the tracking docs for a list of supported autologging integrations. Note that framework-specific configurations set at any point will take precedence over any configurations set by this function.\n",
353
+ "\n",
354
+ "Example score: 5\n",
355
+ "Example justification: The output provides a solution that is using the mlflow.autolog() function that is provided in the context.\n",
356
+ " \n",
357
+ "\n",
358
+ "You must return the following fields in your response in two lines, one below the other:\n",
359
+ "score: Your numerical score for the model's faithfulness based on the rubric\n",
360
+ "justification: Your reasoning about the model's faithfulness score\n",
361
+ "\n",
362
+ "Do not add additional new lines. Do not add any other fields.\n",
363
+ " )\n"
364
+ ]
365
+ }
366
+ ],
367
+ "source": [
368
+ "from mlflow.metrics.genai import EvaluationExample, faithfulness\n",
369
+ "\n",
370
+ "# Create a good and bad example for faithfulness in the context of this problem\n",
371
+ "faithfulness_examples = [\n",
372
+ " EvaluationExample(\n",
373
+ " input=\"How do I disable MLflow autologging?\",\n",
374
+ " output=\"mlflow.autolog(disable=True) will disable autologging for all functions. In Databricks, autologging is enabled by default. \",\n",
375
+ " score=2,\n",
376
+ " justification=\"The output provides a working solution, using the mlflow.autolog() function that is provided in the context.\",\n",
377
+ " grading_context={\n",
378
+ " \"context\": \"mlflow.autolog(log_input_examples: bool = False, log_model_signatures: bool = True, log_models: bool = True, log_datasets: bool = True, disable: bool = False, exclusive: bool = False, disable_for_unsupported_versions: bool = False, silent: bool = False, extra_tags: Optional[Dict[str, str]] = None) → None[source] Enables (or disables) and configures autologging for all supported integrations. The parameters are passed to any autologging integrations that support them. See the tracking docs for a list of supported autologging integrations. Note that framework-specific configurations set at any point will take precedence over any configurations set by this function.\"\n",
379
+ " },\n",
380
+ " ),\n",
381
+ " EvaluationExample(\n",
382
+ " input=\"How do I disable MLflow autologging?\",\n",
383
+ " output=\"mlflow.autolog(disable=True) will disable autologging for all functions.\",\n",
384
+ " score=5,\n",
385
+ " justification=\"The output provides a solution that is using the mlflow.autolog() function that is provided in the context.\",\n",
386
+ " grading_context={\n",
387
+ " \"context\": \"mlflow.autolog(log_input_examples: bool = False, log_model_signatures: bool = True, log_models: bool = True, log_datasets: bool = True, disable: bool = False, exclusive: bool = False, disable_for_unsupported_versions: bool = False, silent: bool = False, extra_tags: Optional[Dict[str, str]] = None) → None[source] Enables (or disables) and configures autologging for all supported integrations. The parameters are passed to any autologging integrations that support them. See the tracking docs for a list of supported autologging integrations. Note that framework-specific configurations set at any point will take precedence over any configurations set by this function.\"\n",
388
+ " },\n",
389
+ " ),\n",
390
+ "]\n",
391
+ "\n",
392
+ "faithfulness_metric = faithfulness(model=\"openai:/gpt-4\", examples=faithfulness_examples)\n",
393
+ "print(faithfulness_metric)"
394
+ ]
395
+ },
396
+ {
397
+ "cell_type": "markdown",
398
+ "metadata": {},
399
+ "source": [
400
+ "Create a relevance metric. You can see the full grading prompt by printing the metric or by accessing the `metric_details` attribute of the metric."
401
+ ]
402
+ },
403
+ {
404
+ "cell_type": "code",
405
+ "execution_count": 18,
406
+ "metadata": {},
407
+ "outputs": [
408
+ {
409
+ "name": "stdout",
410
+ "output_type": "stream",
411
+ "text": [
412
+ "EvaluationMetric(name=relevance, greater_is_better=True, long_name=relevance, version=v1, metric_details=\n",
413
+ "Task:\n",
414
+ "You must return the following fields in your response in two lines, one below the other:\n",
415
+ "score: Your numerical score for the model's relevance based on the rubric\n",
416
+ "justification: Your reasoning about the model's relevance score\n",
417
+ "\n",
418
+ "You are an impartial judge. You will be given an input that was sent to a machine\n",
419
+ "learning model, and you will be given an output that the model produced. You\n",
420
+ "may also be given additional information that was used by the model to generate the output.\n",
421
+ "\n",
422
+ "Your task is to determine a numerical score called relevance based on the input and output.\n",
423
+ "A definition of relevance and a grading rubric are provided below.\n",
424
+ "You must use the grading rubric to determine your score. You must also justify your score.\n",
425
+ "\n",
426
+ "Examples could be included below for reference. Make sure to use them as references and to\n",
427
+ "understand them before completing the task.\n",
428
+ "\n",
429
+ "Input:\n",
430
+ "{input}\n",
431
+ "\n",
432
+ "Output:\n",
433
+ "{output}\n",
434
+ "\n",
435
+ "{grading_context_columns}\n",
436
+ "\n",
437
+ "Metric definition:\n",
438
+ "Relevance encompasses the appropriateness, significance, and applicability of the output with respect to both the input and context. Scores should reflect the extent to which the output directly addresses the question provided in the input, given the provided context.\n",
439
+ "\n",
440
+ "Grading rubric:\n",
441
+ "Relevance: Below are the details for different scores:- Score 1: The output doesn't mention anything about the question or is completely irrelevant to the provided context.\n",
442
+ "- Score 2: The output provides some relevance to the question and is somehow related to the provided context.\n",
443
+ "- Score 3: The output mostly answers the question and is largely consistent with the provided context.\n",
444
+ "- Score 4: The output answers the question and is consistent with the provided context.\n",
445
+ "- Score 5: The output answers the question comprehensively using the provided context.\n",
446
+ "\n",
447
+ "Examples:\n",
448
+ "\n",
449
+ "Example Input:\n",
450
+ "How is MLflow related to Databricks?\n",
451
+ "\n",
452
+ "Example Output:\n",
453
+ "Databricks is a data engineering and analytics platform designed to help organizations process and analyze large amounts of data. Databricks is a company specializing in big data and machine learning solutions.\n",
454
+ "\n",
455
+ "Additional information used by the model:\n",
456
+ "key: context\n",
457
+ "value:\n",
458
+ "MLflow is an open-source platform for managing the end-to-end machine learning (ML) lifecycle. It was developed by Databricks, a company that specializes in big data and machine learning solutions. MLflow is designed to address the challenges that data scientists and machine learning engineers face when developing, training, and deploying machine learning models.\n",
459
+ "\n",
460
+ "Example score: 2\n",
461
+ "Example justification: The output provides relevant information about Databricks, mentioning it as a company specializing in big data and machine learning solutions. However, it doesn't directly address how MLflow is related to Databricks, which is the specific question asked in the input. Therefore, the output is only somewhat related to the provided context.\n",
462
+ " \n",
463
+ "\n",
464
+ "Example Input:\n",
465
+ "How is MLflow related to Databricks?\n",
466
+ "\n",
467
+ "Example Output:\n",
468
+ "MLflow is a product created by Databricks to enhance the efficiency of machine learning processes.\n",
469
+ "\n",
470
+ "Additional information used by the model:\n",
471
+ "key: context\n",
472
+ "value:\n",
473
+ "MLflow is an open-source platform for managing the end-to-end machine learning (ML) lifecycle. It was developed by Databricks, a company that specializes in big data and machine learning solutions. MLflow is designed to address the challenges that data scientists and machine learning engineers face when developing, training, and deploying machine learning models.\n",
474
+ "\n",
475
+ "Example score: 4\n",
476
+ "Example justification: The output provides a relevant and accurate statement about the relationship between MLflow and Databricks. While it doesn't provide extensive detail, it still offers a substantial and meaningful response. To achieve a score of 5, the response could be further improved by providing additional context or details about how MLflow specifically functions within the Databricks ecosystem.\n",
477
+ " \n",
478
+ "\n",
479
+ "You must return the following fields in your response in two lines, one below the other:\n",
480
+ "score: Your numerical score for the model's relevance based on the rubric\n",
481
+ "justification: Your reasoning about the model's relevance score\n",
482
+ "\n",
483
+ "Do not add additional new lines. Do not add any other fields.\n",
484
+ " )\n"
485
+ ]
486
+ }
487
+ ],
488
+ "source": [
489
+ "from mlflow.metrics.genai import EvaluationExample, relevance\n",
490
+ "\n",
491
+ "relevance_metric = relevance(model=\"openai:/gpt-4\")\n",
492
+ "print(relevance_metric)"
493
+ ]
494
+ },
495
+ {
496
+ "cell_type": "code",
497
+ "execution_count": 24,
498
+ "metadata": {},
499
+ "outputs": [],
500
+ "source": [
501
+ "eval_df_final=eval_df.copy(deep=True)"
502
+ ]
503
+ },
504
+ {
505
+ "cell_type": "code",
506
+ "execution_count": 25,
507
+ "metadata": {},
508
+ "outputs": [],
509
+ "source": [
510
+ "aa=model(eval_df)"
511
+ ]
512
+ },
513
+ {
514
+ "cell_type": "code",
515
+ "execution_count": 26,
516
+ "metadata": {},
517
+ "outputs": [
518
+ {
519
+ "data": {
520
+ "text/plain": [
521
+ "[{'query': 'What is MLflow?',\n",
522
+ " 'result': 'MLflow is an open-source platform designed to help machine learning practitioners and teams manage the complexities of the machine learning process. It focuses on the full lifecycle of machine learning projects, making sure that each phase is manageable, traceable, and reproducible.',\n",
523
+ " 'source_documents': [Document(page_content='MLflow: A Tool for Managing the Machine Learning Lifecycle \\nMLflow is an open-source platform, purpose-built to assist machine learning practitioners and teams in\\nhandling the complexities of the machine learning process. MLflow focuses on the full lifecycle for\\nmachine learning projects, ensuring that each phase is manageable, traceable, and reproducible.\\nIn each of the sections below, you will find overviews, guides, and step-by-step tutorials to walk you through\\nthe features of MLflow and how they can be leveraged to solve real-world MLOps problems.\\n\\nGetting Started with MLflow \\nIf this is your first time exploring MLflow, the tutorials and guides here are a great place to start. The emphasis in each of these is\\ngetting you up to speed as quickly as possible with the basic functionality, terms, APIs, and general best practices of using MLflow in order to\\nenhance your learning in area-specific guides and tutorials.\\n\\nGetting Started Guides and Quickstarts', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
524
+ " Document(page_content='MLflow: A Tool for Managing the Machine Learning Lifecycle \\nMLflow is an open-source platform, purpose-built to assist machine learning practitioners and teams in\\nhandling the complexities of the machine learning process. MLflow focuses on the full lifecycle for\\nmachine learning projects, ensuring that each phase is manageable, traceable, and reproducible.\\nIn each of the sections below, you will find overviews, guides, and step-by-step tutorials to walk you through\\nthe features of MLflow and how they can be leveraged to solve real-world MLOps problems.\\n\\nGetting Started with MLflow \\nIf this is your first time exploring MLflow, the tutorials and guides here are a great place to start. The emphasis in each of these is\\ngetting you up to speed as quickly as possible with the basic functionality, terms, APIs, and general best practices of using MLflow in order to\\nenhance your learning in area-specific guides and tutorials.\\n\\nGetting Started Guides and Quickstarts', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
525
+ " Document(page_content='MLflow: A Tool for Managing the Machine Learning Lifecycle \\nMLflow is an open-source platform, purpose-built to assist machine learning practitioners and teams in\\nhandling the complexities of the machine learning process. MLflow focuses on the full lifecycle for\\nmachine learning projects, ensuring that each phase is manageable, traceable, and reproducible.\\nIn each of the sections below, you will find overviews, guides, and step-by-step tutorials to walk you through\\nthe features of MLflow and how they can be leveraged to solve real-world MLOps problems.\\n\\nGetting Started with MLflow \\nIf this is your first time exploring MLflow, the tutorials and guides here are a great place to start. The emphasis in each of these is\\ngetting you up to speed as quickly as possible with the basic functionality, terms, APIs, and general best practices of using MLflow in order to\\nenhance your learning in area-specific guides and tutorials.\\n\\nGetting Started Guides and Quickstarts', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
526
+ " Document(page_content='MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation\\n\\n2.12.1\\n\\n\\n MLflow\\n\\nWhat is MLflow?\\nGetting Started with MLflow\\nNew Features\\nLLMs\\nModel Evaluation\\nDeep Learning\\nTraditional ML\\nDeployment\\nMLflow Tracking\\nSystem Metrics\\nMLflow Projects\\nMLflow Models\\nMLflow Model Registry\\nMLflow Recipes\\nMLflow Plugins\\nMLflow Authentication\\nCommand-Line Interface\\nSearch Runs\\nSearch Experiments\\nPython API\\nR API\\nJava API\\nREST API\\nOfficial MLflow Docker Image\\nCommunity Model Flavors\\nTutorials and Examples\\n\\n\\nContribute\\n\\n\\nDocumentation \\nMLflow: A Tool for Managing the Machine Learning Lifecycle', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'})]},\n",
527
+ " {'query': 'How to run mlflow.evaluate()?',\n",
528
+ " 'result': 'To run `mlflow.evaluate()`, you need to follow these steps:\\n\\n1. Ensure you have MLflow installed in your environment.\\n2. Import the necessary libraries, including `mlflow`.\\n3. Load your model and data.\\n4. Use the `mlflow.evaluate()` function, passing in your model, data, and any other necessary parameters.\\n5. Review the evaluation results and metrics provided by the function.\\n\\nIf you need more specific details or code examples, please refer to the MLflow documentation or tutorials for a step-by-step guide on running `mlflow.evaluate()`.',\n",
529
+ " 'source_documents': [Document(page_content='Model Evaluation \\nDive into MLflow’s robust framework for evaluating the performance of your ML models.\\nWith support for traditional ML evaluation (classification and regression tasks), as well as support for evaluating large language models (LLMs),\\nthis suite of APIs offers a simple but powerful automated approach to evaluating the quality of the model development work that you’re doing.\\nIn particular, for LLM evaluation, the mlflow.evaluate() API allows you to validate not only models, but providers and prompts.\\nBy leveraging your own datasets and using the provided default evaluation criteria for tasks such as text summarization and question answering, you can\\nget reliable metrics that allow you to focus on improving the quality of your solution, rather than spending time writing scoring code.\\nVisual insights are also available through the MLflow UI, showcasing logged outputs, auto-generated plots, and model comparison artifacts.', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
530
+ " Document(page_content='Model Evaluation \\nDive into MLflow’s robust framework for evaluating the performance of your ML models.\\nWith support for traditional ML evaluation (classification and regression tasks), as well as support for evaluating large language models (LLMs),\\nthis suite of APIs offers a simple but powerful automated approach to evaluating the quality of the model development work that you’re doing.\\nIn particular, for LLM evaluation, the mlflow.evaluate() API allows you to validate not only models, but providers and prompts.\\nBy leveraging your own datasets and using the provided default evaluation criteria for tasks such as text summarization and question answering, you can\\nget reliable metrics that allow you to focus on improving the quality of your solution, rather than spending time writing scoring code.\\nVisual insights are also available through the MLflow UI, showcasing logged outputs, auto-generated plots, and model comparison artifacts.', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
531
+ " Document(page_content='Model Evaluation \\nDive into MLflow’s robust framework for evaluating the performance of your ML models.\\nWith support for traditional ML evaluation (classification and regression tasks), as well as support for evaluating large language models (LLMs),\\nthis suite of APIs offers a simple but powerful automated approach to evaluating the quality of the model development work that you’re doing.\\nIn particular, for LLM evaluation, the mlflow.evaluate() API allows you to validate not only models, but providers and prompts.\\nBy leveraging your own datasets and using the provided default evaluation criteria for tasks such as text summarization and question answering, you can\\nget reliable metrics that allow you to focus on improving the quality of your solution, rather than spending time writing scoring code.\\nVisual insights are also available through the MLflow UI, showcasing logged outputs, auto-generated plots, and model comparison artifacts.', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
532
+ " Document(page_content='Learn how to evaluate LLMs and LLM-powered solutions with MLflow Evaluate.\\n \\n\\n Using Custom PyFunc with LLMs\\n \\n\\n Explore the nuances of packaging and deploying advanced LLMs in MLflow using custom PyFuncs. This guide delves deep\\n into managing intricate model behaviors, ensuring seamless and efficient LLM deployments.\\n \\n\\n Evaluation for RAG\\n \\n\\n Learn how to evaluate Retrieval Augmented Generation applications by leveraging LLMs to generate a evaluation dataset and evaluate it using the built-in metrics in the MLflow Evaluate API.\\n \\n\\n LLM Tracking with MLflow', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'})]},\n",
533
+ " {'query': 'How to log_table()?',\n",
534
+ " 'result': \"I don't have information on a specific function called `log_table()` in the context provided. It's possible that it might be a custom function or a feature not explicitly mentioned in the provided context. If you can provide more details or context, I may be able to assist you further.\",\n",
535
+ " 'source_documents': [Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
536
+ " Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
537
+ " Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
538
+ " Document(page_content=\"Dive into the intricacies of MLflow's LLM Tracking system. From capturing prompts to monitoring generated outputs,\\n discover how MLflow provides a holistic solution for managing LLM interactions.\", metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'})]},\n",
539
+ " {'query': 'How to load_table()?',\n",
540
+ " 'result': \"I'm not sure what `load_table()` refers to in this context. If you can provide more information or context, I might be able to help you better.\",\n",
541
+ " 'source_documents': [Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
542
+ " Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
543
+ " Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
544
+ " Document(page_content='This guide showcases the seamless end-to-end process of training a linear regression model, packaging it in a reproducible format,\\n and deploying to a Kubernetes cluster using MLflow. Explore how MLflow simplifies model deployment to production environments.\\n \\n\\n\\nNext \\n\\n\\n © MLflow Project, a Series of LF Projects, LLC. All rights reserved.', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'})]}]"
545
+ ]
546
+ },
547
+ "execution_count": 26,
548
+ "metadata": {},
549
+ "output_type": "execute_result"
550
+ }
551
+ ],
552
+ "source": [
553
+ "aa"
554
+ ]
555
+ },
556
+ {
557
+ "cell_type": "code",
558
+ "execution_count": 20,
559
+ "metadata": {},
560
+ "outputs": [
561
+ {
562
+ "data": {
563
+ "text/plain": [
564
+ "[{'query': 'What is MLflow?',\n",
565
+ " 'result': 'MLflow is an open-source platform designed to help machine learning practitioners and teams manage the complexities of the machine learning process. It focuses on the full lifecycle of machine learning projects, ensuring that each phase is manageable, traceable, and reproducible. It offers features such as tracking, projects, models, model registry, and more to assist in solving real-world MLOps problems.',\n",
566
+ " 'source_documents': [Document(page_content='MLflow: A Tool for Managing the Machine Learning Lifecycle \\nMLflow is an open-source platform, purpose-built to assist machine learning practitioners and teams in\\nhandling the complexities of the machine learning process. MLflow focuses on the full lifecycle for\\nmachine learning projects, ensuring that each phase is manageable, traceable, and reproducible.\\nIn each of the sections below, you will find overviews, guides, and step-by-step tutorials to walk you through\\nthe features of MLflow and how they can be leveraged to solve real-world MLOps problems.\\n\\nGetting Started with MLflow \\nIf this is your first time exploring MLflow, the tutorials and guides here are a great place to start. The emphasis in each of these is\\ngetting you up to speed as quickly as possible with the basic functionality, terms, APIs, and general best practices of using MLflow in order to\\nenhance your learning in area-specific guides and tutorials.\\n\\nGetting Started Guides and Quickstarts', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
567
+ " Document(page_content='MLflow: A Tool for Managing the Machine Learning Lifecycle \\nMLflow is an open-source platform, purpose-built to assist machine learning practitioners and teams in\\nhandling the complexities of the machine learning process. MLflow focuses on the full lifecycle for\\nmachine learning projects, ensuring that each phase is manageable, traceable, and reproducible.\\nIn each of the sections below, you will find overviews, guides, and step-by-step tutorials to walk you through\\nthe features of MLflow and how they can be leveraged to solve real-world MLOps problems.\\n\\nGetting Started with MLflow \\nIf this is your first time exploring MLflow, the tutorials and guides here are a great place to start. The emphasis in each of these is\\ngetting you up to speed as quickly as possible with the basic functionality, terms, APIs, and general best practices of using MLflow in order to\\nenhance your learning in area-specific guides and tutorials.\\n\\nGetting Started Guides and Quickstarts', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
568
+ " Document(page_content='MLflow: A Tool for Managing the Machine Learning Lifecycle \\nMLflow is an open-source platform, purpose-built to assist machine learning practitioners and teams in\\nhandling the complexities of the machine learning process. MLflow focuses on the full lifecycle for\\nmachine learning projects, ensuring that each phase is manageable, traceable, and reproducible.\\nIn each of the sections below, you will find overviews, guides, and step-by-step tutorials to walk you through\\nthe features of MLflow and how they can be leveraged to solve real-world MLOps problems.\\n\\nGetting Started with MLflow \\nIf this is your first time exploring MLflow, the tutorials and guides here are a great place to start. The emphasis in each of these is\\ngetting you up to speed as quickly as possible with the basic functionality, terms, APIs, and general best practices of using MLflow in order to\\nenhance your learning in area-specific guides and tutorials.\\n\\nGetting Started Guides and Quickstarts', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
569
+ " Document(page_content='MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation\\n\\n2.12.1\\n\\n\\n MLflow\\n\\nWhat is MLflow?\\nGetting Started with MLflow\\nNew Features\\nLLMs\\nModel Evaluation\\nDeep Learning\\nTraditional ML\\nDeployment\\nMLflow Tracking\\nSystem Metrics\\nMLflow Projects\\nMLflow Models\\nMLflow Model Registry\\nMLflow Recipes\\nMLflow Plugins\\nMLflow Authentication\\nCommand-Line Interface\\nSearch Runs\\nSearch Experiments\\nPython API\\nR API\\nJava API\\nREST API\\nOfficial MLflow Docker Image\\nCommunity Model Flavors\\nTutorials and Examples\\n\\n\\nContribute\\n\\n\\nDocumentation \\nMLflow: A Tool for Managing the Machine Learning Lifecycle', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'})]},\n",
570
+ " {'query': 'How to run mlflow.evaluate()?',\n",
571
+ " 'result': 'To run `mlflow.evaluate()`, you need to follow these steps:\\n\\n1. Import the necessary libraries:\\n```python\\nimport mlflow\\n```\\n\\n2. Load your model and data:\\n```python\\nmodel = load_model() # Load your model\\ndata = load_data() # Load your evaluation data\\n```\\n\\n3. Use `mlflow.evaluate()` to evaluate your model:\\n```python\\nmlflow.evaluate(model, data)\\n```\\n\\n4. Review the evaluation metrics and visual insights in the MLflow UI.\\n\\nIf you need more specific details or have a particular use case in mind, please provide additional context for a more tailored explanation.',\n",
572
+ " 'source_documents': [Document(page_content='Model Evaluation \\nDive into MLflow’s robust framework for evaluating the performance of your ML models.\\nWith support for traditional ML evaluation (classification and regression tasks), as well as support for evaluating large language models (LLMs),\\nthis suite of APIs offers a simple but powerful automated approach to evaluating the quality of the model development work that you’re doing.\\nIn particular, for LLM evaluation, the mlflow.evaluate() API allows you to validate not only models, but providers and prompts.\\nBy leveraging your own datasets and using the provided default evaluation criteria for tasks such as text summarization and question answering, you can\\nget reliable metrics that allow you to focus on improving the quality of your solution, rather than spending time writing scoring code.\\nVisual insights are also available through the MLflow UI, showcasing logged outputs, auto-generated plots, and model comparison artifacts.', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
573
+ " Document(page_content='Model Evaluation \\nDive into MLflow’s robust framework for evaluating the performance of your ML models.\\nWith support for traditional ML evaluation (classification and regression tasks), as well as support for evaluating large language models (LLMs),\\nthis suite of APIs offers a simple but powerful automated approach to evaluating the quality of the model development work that you’re doing.\\nIn particular, for LLM evaluation, the mlflow.evaluate() API allows you to validate not only models, but providers and prompts.\\nBy leveraging your own datasets and using the provided default evaluation criteria for tasks such as text summarization and question answering, you can\\nget reliable metrics that allow you to focus on improving the quality of your solution, rather than spending time writing scoring code.\\nVisual insights are also available through the MLflow UI, showcasing logged outputs, auto-generated plots, and model comparison artifacts.', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
574
+ " Document(page_content='Model Evaluation \\nDive into MLflow’s robust framework for evaluating the performance of your ML models.\\nWith support for traditional ML evaluation (classification and regression tasks), as well as support for evaluating large language models (LLMs),\\nthis suite of APIs offers a simple but powerful automated approach to evaluating the quality of the model development work that you’re doing.\\nIn particular, for LLM evaluation, the mlflow.evaluate() API allows you to validate not only models, but providers and prompts.\\nBy leveraging your own datasets and using the provided default evaluation criteria for tasks such as text summarization and question answering, you can\\nget reliable metrics that allow you to focus on improving the quality of your solution, rather than spending time writing scoring code.\\nVisual insights are also available through the MLflow UI, showcasing logged outputs, auto-generated plots, and model comparison artifacts.', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
575
+ " Document(page_content='Learn how to evaluate LLMs and LLM-powered solutions with MLflow Evaluate.\\n \\n\\n Using Custom PyFunc with LLMs\\n \\n\\n Explore the nuances of packaging and deploying advanced LLMs in MLflow using custom PyFuncs. This guide delves deep\\n into managing intricate model behaviors, ensuring seamless and efficient LLM deployments.\\n \\n\\n Evaluation for RAG\\n \\n\\n Learn how to evaluate Retrieval Augmented Generation applications by leveraging LLMs to generate a evaluation dataset and evaluate it using the built-in metrics in the MLflow Evaluate API.\\n \\n\\n LLM Tracking with MLflow', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'})]},\n",
576
+ " {'query': 'How to log_table()?',\n",
577
+ " 'result': \"I don't have information on a specific function called `log_table()` in the context provided. It's possible that it might be a custom function or a feature not covered in the provided context. If you can provide more details or context, I may be able to assist you further.\",\n",
578
+ " 'source_documents': [Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
579
+ " Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
580
+ " Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
581
+ " Document(page_content=\"Dive into the intricacies of MLflow's LLM Tracking system. From capturing prompts to monitoring generated outputs,\\n discover how MLflow provides a holistic solution for managing LLM interactions.\", metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'})]},\n",
582
+ " {'query': 'How to load_table()?',\n",
583
+ " 'result': \"I don't have information on a function called `load_table()` in the context provided. It seems to be outside the scope of the MLflow Tracking, Autologging, and Deployment Quickstart guides. If you can provide more context or clarify where `load_table()` is from, I may be able to assist you further.\",\n",
584
+ " 'source_documents': [Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
585
+ " Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
586
+ " Document(page_content='MLflow Tracking Quickstart\\n \\n\\n A great place to start to learn the fundamentals of MLflow Tracking! Learn in 5 minutes how to log, register, and load a model for inference.\\n \\n\\n Intro to MLflow Tutorial\\n \\n\\n Learn how to get started with the basics of MLflow in a step-by-step instructional tutorial that shows the critical\\n path to logging your first model\\n \\n\\n Autologging Quickstart\\n \\n\\n Short on time? This is a no-frills quickstart that shows how to leverage autologging during training and how to\\n load a model for inference\\n \\n\\n Deployment Quickstart', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'}),\n",
587
+ " Document(page_content='This guide showcases the seamless end-to-end process of training a linear regression model, packaging it in a reproducible format,\\n and deploying to a Kubernetes cluster using MLflow. Explore how MLflow simplifies model deployment to production environments.\\n \\n\\n\\nNext \\n\\n\\n © MLflow Project, a Series of LF Projects, LLC. All rights reserved.', metadata={'language': 'en', 'source': 'https://mlflow.org/docs/latest/index.html', 'title': 'MLflow: A Tool for Managing the Machine Learning Lifecycle — MLflow 2.12.1 documentation'})]}]"
588
+ ]
589
+ },
590
+ "execution_count": 20,
591
+ "metadata": {},
592
+ "output_type": "execute_result"
593
+ }
594
+ ],
595
+ "source": [
596
+ "eval_df_final[\"results\"]=model(eval_df)"
597
+ ]
598
+ },
599
+ {
600
+ "cell_type": "code",
601
+ "execution_count": 29,
602
+ "metadata": {
603
+ "application/vnd.databricks.v1+cell": {
604
+ "cellMetadata": {
605
+ "byteLimit": 2048000,
606
+ "rowLimit": 10000
607
+ },
608
+ "inputWidgets": {},
609
+ "nuid": "ea40ce52-6ac7-4c20-9669-d24f80a6cebe",
610
+ "showTitle": false,
611
+ "title": ""
612
+ }
613
+ },
614
+ "outputs": [
615
+ {
616
+ "name": "stderr",
617
+ "output_type": "stream",
618
+ "text": [
619
+ "/u/marshad/.conda/envs/agllm-env1/lib/python3.9/site-packages/mlflow/data/digest_utils.py:26: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n",
620
+ " string_columns = trimmed_df.columns[(df.applymap(type) == str).all(0)]\n",
621
+ "/u/marshad/.conda/envs/agllm-env1/lib/python3.9/site-packages/mlflow/models/evaluation/base.py:414: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n",
622
+ " data = data.applymap(_hash_array_like_element_as_bytes)\n",
623
+ "2024/04/21 17:23:37 INFO mlflow.models.evaluation.base: Evaluating the model with the default evaluator.\n",
624
+ "2024/04/21 17:23:37 INFO mlflow.models.evaluation.default_evaluator: Computing model predictions.\n",
625
+ "2024/04/21 17:23:44 INFO mlflow.models.evaluation.default_evaluator: Testing metrics on first row...\n",
626
+ "2024/04/21 17:23:44 WARNING mlflow.metrics.metric_definitions: Failed to load 'toxicity' metric (error: ModuleNotFoundError(\"No module named 'evaluate'\")), skipping metric logging.\n",
627
+ "2024/04/21 17:23:44 WARNING mlflow.metrics.metric_definitions: Failed to load flesch kincaid metric, skipping metric logging.\n",
628
+ "2024/04/21 17:23:44 WARNING mlflow.metrics.metric_definitions: Failed to load automated readability index metric, skipping metric logging.\n",
629
+ "100%|██████████| 1/1 [00:03<00:00, 3.53s/it]\n",
630
+ "100%|██████████| 1/1 [00:03<00:00, 3.72s/it]\n",
631
+ "2024/04/21 17:23:51 INFO mlflow.models.evaluation.default_evaluator: Evaluating builtin metrics: token_count\n",
632
+ "2024/04/21 17:23:51 INFO mlflow.models.evaluation.default_evaluator: Evaluating builtin metrics: toxicity\n",
633
+ "2024/04/21 17:23:51 WARNING mlflow.metrics.metric_definitions: Failed to load 'toxicity' metric (error: ModuleNotFoundError(\"No module named 'evaluate'\")), skipping metric logging.\n",
634
+ "2024/04/21 17:23:51 INFO mlflow.models.evaluation.default_evaluator: Evaluating builtin metrics: flesch_kincaid_grade_level\n",
635
+ "2024/04/21 17:23:51 WARNING mlflow.metrics.metric_definitions: Failed to load flesch kincaid metric, skipping metric logging.\n",
636
+ "2024/04/21 17:23:51 INFO mlflow.models.evaluation.default_evaluator: Evaluating builtin metrics: ari_grade_level\n",
637
+ "2024/04/21 17:23:51 WARNING mlflow.metrics.metric_definitions: Failed to load automated readability index metric, skipping metric logging.\n",
638
+ "2024/04/21 17:23:51 INFO mlflow.models.evaluation.default_evaluator: Evaluating builtin metrics: exact_match\n",
639
+ "2024/04/21 17:23:51 INFO mlflow.models.evaluation.default_evaluator: Evaluating metrics: faithfulness\n",
640
+ "100%|██████████| 4/4 [00:03<00:00, 1.02it/s]\n",
641
+ "2024/04/21 17:23:55 INFO mlflow.models.evaluation.default_evaluator: Evaluating metrics: relevance\n",
642
+ "100%|██████████| 4/4 [00:04<00:00, 1.20s/it]\n"
643
+ ]
644
+ },
645
+ {
646
+ "name": "stdout",
647
+ "output_type": "stream",
648
+ "text": [
649
+ "{'latency/mean': 1.7213420271873474, 'latency/variance': 0.04483020464773446, 'latency/p90': 1.942362928390503, 'faithfulness/v1/mean': 4.75, 'faithfulness/v1/variance': 0.1875, 'faithfulness/v1/p90': 5.0, 'relevance/v1/mean': 3.75, 'relevance/v1/variance': 1.6875, 'relevance/v1/p90': 5.0}\n"
650
+ ]
651
+ }
652
+ ],
653
+ "source": [
654
+ "results = mlflow.evaluate(\n",
655
+ " model,\n",
656
+ " eval_df,\n",
657
+ " model_type=\"question-answering\",\n",
658
+ " evaluators=\"default\",\n",
659
+ " predictions=\"result\",\n",
660
+ " extra_metrics=[faithfulness_metric, relevance_metric, mlflow.metrics.latency()],\n",
661
+ " evaluator_config={\n",
662
+ " \"col_mapping\": {\n",
663
+ " \"inputs\": \"questions\",\n",
664
+ " \"context\": \"source_documents\",\n",
665
+ " }\n",
666
+ " },\n",
667
+ ")\n",
668
+ "print(results.metrics)"
669
+ ]
670
+ },
671
+ {
672
+ "cell_type": "code",
673
+ "execution_count": null,
674
+ "metadata": {},
675
+ "outputs": [],
676
+ "source": []
677
+ },
678
+ {
679
+ "cell_type": "code",
680
+ "execution_count": 13,
681
+ "metadata": {
682
+ "application/vnd.databricks.v1+cell": {
683
+ "cellMetadata": {},
684
+ "inputWidgets": {},
685
+ "nuid": "989a0861-5153-44e6-a19d-efcae7fe6cb5",
686
+ "showTitle": false,
687
+ "title": ""
688
+ }
689
+ },
690
+ "outputs": [
691
+ {
692
+ "data": {
693
+ "application/vnd.jupyter.widget-view+json": {
694
+ "model_id": "747f65b309b94257b396eebffe814fa6",
695
+ "version_major": 2,
696
+ "version_minor": 0
697
+ },
698
+ "text/plain": [
699
+ "Downloading artifacts: 0%| | 0/1 [00:00<?, ?it/s]"
700
+ ]
701
+ },
702
+ "metadata": {},
703
+ "output_type": "display_data"
704
+ },
705
+ {
706
+ "data": {
707
+ "text/html": [
708
+ "<div>\n",
709
+ "<style scoped>\n",
710
+ " .dataframe tbody tr th:only-of-type {\n",
711
+ " vertical-align: middle;\n",
712
+ " }\n",
713
+ "\n",
714
+ " .dataframe tbody tr th {\n",
715
+ " vertical-align: top;\n",
716
+ " }\n",
717
+ "\n",
718
+ " .dataframe thead th {\n",
719
+ " text-align: right;\n",
720
+ " }\n",
721
+ "</style>\n",
722
+ "<table border=\"1\" class=\"dataframe\">\n",
723
+ " <thead>\n",
724
+ " <tr style=\"text-align: right;\">\n",
725
+ " <th></th>\n",
726
+ " <th>questions</th>\n",
727
+ " <th>outputs</th>\n",
728
+ " <th>source_documents</th>\n",
729
+ " <th>latency</th>\n",
730
+ " <th>token_count</th>\n",
731
+ " <th>toxicity/v1/score</th>\n",
732
+ " <th>flesch_kincaid_grade_level/v1/score</th>\n",
733
+ " <th>ari_grade_level/v1/score</th>\n",
734
+ " <th>faithfulness/v1/score</th>\n",
735
+ " <th>faithfulness/v1/justification</th>\n",
736
+ " <th>relevance/v1/score</th>\n",
737
+ " <th>relevance/v1/justification</th>\n",
738
+ " </tr>\n",
739
+ " </thead>\n",
740
+ " <tbody>\n",
741
+ " <tr>\n",
742
+ " <th>0</th>\n",
743
+ " <td>What is MLflow?</td>\n",
744
+ " <td>MLflow is an open-source platform, purpose-bu...</td>\n",
745
+ " <td>[{'lc_attributes': {}, 'lc_namespace': ['langc...</td>\n",
746
+ " <td>1.989822</td>\n",
747
+ " <td>53</td>\n",
748
+ " <td>0.000137</td>\n",
749
+ " <td>12.5</td>\n",
750
+ " <td>18.4</td>\n",
751
+ " <td>5</td>\n",
752
+ " <td>The output provided by the model is a direct e...</td>\n",
753
+ " <td>5</td>\n",
754
+ " <td>The output provides a comprehensive answer to ...</td>\n",
755
+ " </tr>\n",
756
+ " <tr>\n",
757
+ " <th>1</th>\n",
758
+ " <td>How to run mlflow.evaluate()?</td>\n",
759
+ " <td>The mlflow.evaluate() API allows you to valid...</td>\n",
760
+ " <td>[{'lc_attributes': {}, 'lc_namespace': ['langc...</td>\n",
761
+ " <td>1.945368</td>\n",
762
+ " <td>55</td>\n",
763
+ " <td>0.000200</td>\n",
764
+ " <td>9.1</td>\n",
765
+ " <td>12.6</td>\n",
766
+ " <td>5</td>\n",
767
+ " <td>The output provided by the model is completely...</td>\n",
768
+ " <td>4</td>\n",
769
+ " <td>The output provides a relevant and accurate ex...</td>\n",
770
+ " </tr>\n",
771
+ " <tr>\n",
772
+ " <th>2</th>\n",
773
+ " <td>How to log_table()?</td>\n",
774
+ " <td>You can log a table with MLflow using the log...</td>\n",
775
+ " <td>[{'lc_attributes': {}, 'lc_namespace': ['langc...</td>\n",
776
+ " <td>1.521511</td>\n",
777
+ " <td>32</td>\n",
778
+ " <td>0.000289</td>\n",
779
+ " <td>5.0</td>\n",
780
+ " <td>6.8</td>\n",
781
+ " <td>1</td>\n",
782
+ " <td>The output claims that you can log a table wit...</td>\n",
783
+ " <td>5</td>\n",
784
+ " <td>The output provides a comprehensive answer to ...</td>\n",
785
+ " </tr>\n",
786
+ " <tr>\n",
787
+ " <th>3</th>\n",
788
+ " <td>How to load_table()?</td>\n",
789
+ " <td>You can't load_table() with MLflow. MLflow is...</td>\n",
790
+ " <td>[{'lc_attributes': {}, 'lc_namespace': ['langc...</td>\n",
791
+ " <td>1.105279</td>\n",
792
+ " <td>27</td>\n",
793
+ " <td>0.000279</td>\n",
794
+ " <td>5.8</td>\n",
795
+ " <td>8.8</td>\n",
796
+ " <td>5</td>\n",
797
+ " <td>The output claim that \"You can't load_table() ...</td>\n",
798
+ " <td>4</td>\n",
799
+ " <td>The output provides a relevant and accurate re...</td>\n",
800
+ " </tr>\n",
801
+ " </tbody>\n",
802
+ "</table>\n",
803
+ "</div>"
804
+ ],
805
+ "text/plain": [
806
+ " questions \\\n",
807
+ "0 What is MLflow? \n",
808
+ "1 How to run mlflow.evaluate()? \n",
809
+ "2 How to log_table()? \n",
810
+ "3 How to load_table()? \n",
811
+ "\n",
812
+ " outputs \\\n",
813
+ "0 MLflow is an open-source platform, purpose-bu... \n",
814
+ "1 The mlflow.evaluate() API allows you to valid... \n",
815
+ "2 You can log a table with MLflow using the log... \n",
816
+ "3 You can't load_table() with MLflow. MLflow is... \n",
817
+ "\n",
818
+ " source_documents latency token_count \\\n",
819
+ "0 [{'lc_attributes': {}, 'lc_namespace': ['langc... 1.989822 53 \n",
820
+ "1 [{'lc_attributes': {}, 'lc_namespace': ['langc... 1.945368 55 \n",
821
+ "2 [{'lc_attributes': {}, 'lc_namespace': ['langc... 1.521511 32 \n",
822
+ "3 [{'lc_attributes': {}, 'lc_namespace': ['langc... 1.105279 27 \n",
823
+ "\n",
824
+ " toxicity/v1/score flesch_kincaid_grade_level/v1/score \\\n",
825
+ "0 0.000137 12.5 \n",
826
+ "1 0.000200 9.1 \n",
827
+ "2 0.000289 5.0 \n",
828
+ "3 0.000279 5.8 \n",
829
+ "\n",
830
+ " ari_grade_level/v1/score faithfulness/v1/score \\\n",
831
+ "0 18.4 5 \n",
832
+ "1 12.6 5 \n",
833
+ "2 6.8 1 \n",
834
+ "3 8.8 5 \n",
835
+ "\n",
836
+ " faithfulness/v1/justification relevance/v1/score \\\n",
837
+ "0 The output provided by the model is a direct e... 5 \n",
838
+ "1 The output provided by the model is completely... 4 \n",
839
+ "2 The output claims that you can log a table wit... 5 \n",
840
+ "3 The output claim that \"You can't load_table() ... 4 \n",
841
+ "\n",
842
+ " relevance/v1/justification \n",
843
+ "0 The output provides a comprehensive answer to ... \n",
844
+ "1 The output provides a relevant and accurate ex... \n",
845
+ "2 The output provides a comprehensive answer to ... \n",
846
+ "3 The output provides a relevant and accurate re... "
847
+ ]
848
+ },
849
+ "execution_count": 13,
850
+ "metadata": {},
851
+ "output_type": "execute_result"
852
+ }
853
+ ],
854
+ "source": [
855
+ "results.tables[\"eval_results_table\"]"
856
+ ]
857
+ },
858
+ {
859
+ "cell_type": "code",
860
+ "execution_count": null,
861
+ "metadata": {},
862
+ "outputs": [],
863
+ "source": []
864
+ }
865
+ ],
866
+ "metadata": {
867
+ "application/vnd.databricks.v1+notebook": {
868
+ "dashboards": [],
869
+ "language": "python",
870
+ "notebookMetadata": {
871
+ "pythonIndentUnit": 2
872
+ },
873
+ "notebookName": "LLM Evaluation Examples -- RAG",
874
+ "widgets": {}
875
+ },
876
+ "kernelspec": {
877
+ "display_name": "mlflow-dev-env",
878
+ "language": "python",
879
+ "name": "python3"
880
+ },
881
+ "language_info": {
882
+ "codemirror_mode": {
883
+ "name": "ipython",
884
+ "version": 3
885
+ },
886
+ "file_extension": ".py",
887
+ "mimetype": "text/x-python",
888
+ "name": "python",
889
+ "nbconvert_exporter": "python",
890
+ "pygments_lexer": "ipython3",
891
+ "version": "3.9.19"
892
+ }
893
+ },
894
+ "nbformat": 4,
895
+ "nbformat_minor": 0
896
+ }
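The evaluation cells above call `model(eval_df)` and pass `model` straight to `mlflow.evaluate()`, but the retrieval chain behind that wrapper is built earlier in the notebook and does not appear in this diff. Below is a minimal, hypothetical sketch of how such a wrapper could be assembled with LangChain and Chroma over the MLflow documentation page that the retrieved `source_documents` point at; the chunking parameters, model name, and helper names are assumptions for illustration, not the notebook's actual code.

```python
# Hypothetical sketch of the RetrievalQA wrapper implied by model(eval_df).
# Assumes OPENAI_API_KEY is set; chunk sizes and the chain type are guesses.
import pandas as pd
from langchain.chains import RetrievalQA
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# Index the page that appears in the retrieved source_documents above.
docs = WebBaseLoader("https://mlflow.org/docs/latest/index.html").load()
chunks = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0).split_documents(docs)
vectorstore = Chroma.from_documents(chunks, OpenAIEmbeddings())

qa = RetrievalQA.from_chain_type(
    llm=ChatOpenAI(model="gpt-3.5-turbo", temperature=0),
    chain_type="stuff",
    retriever=vectorstore.as_retriever(search_kwargs={"k": 4}),
    return_source_documents=True,  # keeps source_documents for the LLM judges
)

def model(input_df: pd.DataFrame) -> list:
    """Run the chain over a DataFrame with a 'questions' column, as eval_df has."""
    return [qa(row["questions"]) for _, row in input_df.iterrows()]
```

The dictionaries in the cell outputs above ('query', 'result', 'source_documents') match what `RetrievalQA` returns per question when `return_source_documents=True`.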
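The `mlflow.evaluate()` cell above passes `faithfulness_metric` and `relevance_metric` via `extra_metrics`, but their definitions sit earlier in the notebook and are not part of this diff. A minimal sketch of how these LLM-judged metrics are usually constructed with `mlflow.metrics.genai` follows; the judge model URI is an assumption.

```python
# Hypothetical definitions for the two GenAI metrics used in extra_metrics above.
# The judge model URI ("openai:/gpt-4") is an assumption; any supported URI works.
from mlflow.metrics.genai import faithfulness, relevance

faithfulness_metric = faithfulness(model="openai:/gpt-4")
relevance_metric = relevance(model="openai:/gpt-4")
# These plug into extra_metrics alongside mlflow.metrics.latency() in the
# evaluate() call shown above; mapping "context" to "source_documents" in
# col_mapping is what lets the faithfulness judge grade answers against the
# retrieved text.
```

On the stderr warnings above: the built-in toxicity metric is skipped because the Hugging Face `evaluate` package (and its `torch`/`transformers` dependencies) is missing from that environment, and the flesch_kincaid / ari readability metrics are typically skipped for the same reason when `textstat` is not installed; adding those optional packages restores the skipped columns.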
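Two of the evaluated questions ask about `log_table()` and `load_table()`, which the retriever cannot ground because the indexed page does not mention them. For reference only (not part of the notebook): these are MLflow Tracking APIs for logging and reloading tabular artifacts, and the per-row output shown by `results.tables["eval_results_table"]` is itself persisted as such a table under the evaluation run. The artifact file name below is an example, not something taken from this notebook.

```python
# Sketch of mlflow.log_table / mlflow.load_table (available in MLflow 2.3+).
import mlflow
import pandas as pd

qa_table = pd.DataFrame(
    {"questions": ["What is MLflow?"], "outputs": ["MLflow is an open-source platform ..."]}
)

with mlflow.start_run() as run:
    # Log a DataFrame (or dict) as a JSON table artifact on the run.
    mlflow.log_table(data=qa_table, artifact_file="qa_results.json")

# Reload it later, optionally aggregating across several runs.
loaded = mlflow.load_table("qa_results.json", run_ids=[run.info.run_id])
print(loaded)
```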
push_logs.txt ADDED
@@ -0,0 +1 @@
1
+ Uploading LFS objects: 0% (0/58), 0 B | 0 B/s, done.
question-generation-retrieval-evaluation.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements-23feb2025.txt ADDED
@@ -0,0 +1,241 @@
1
+ Package Version
2
+ ---------------------------------------- ---------
3
+ aiofiles 23.2.1
4
+ aiohappyeyeballs 2.4.0
5
+ aiohttp 3.10.5
6
+ aiosignal 1.3.1
7
+ alembic 1.13.2
8
+ altair 5.4.0
9
+ aniso8601 9.0.1
10
+ annotated-types 0.7.0
11
+ anthropic 0.34.1
12
+ anyio 4.4.0
13
+ appdirs 1.4.4
14
+ appnope 0.1.4
15
+ asgiref 3.8.1
16
+ asttokens 2.4.1
17
+ async-timeout 4.0.3
18
+ attrs 24.2.0
19
+ backcall 0.2.0
20
+ backoff 2.2.1
21
+ bcrypt 4.2.0
22
+ blinker 1.8.2
23
+ Bottleneck 1.3.7
24
+ build 1.2.1
25
+ cachetools 5.5.0
26
+ cattrs 24.1.2
27
+ certifi 2024.7.4
28
+ charset-normalizer 3.3.2
29
+ chroma-hnswlib 0.7.6
30
+ chromadb 0.5.5
31
+ click 8.1.7
32
+ click-plugins 1.1.1
33
+ cligj 0.7.2
34
+ cloudpickle 3.0.0
35
+ coloredlogs 15.0.1
36
+ comm 0.2.2
37
+ contourpy 1.2.0
38
+ cycler 0.12.1
39
+ databricks-sdk 0.31.1
40
+ dataclasses-json 0.6.7
41
+ debugpy 1.6.7
42
+ decorator 5.1.1
43
+ defusedxml 0.7.1
44
+ Deprecated 1.2.14
45
+ distro 1.9.0
46
+ docker 7.1.0
47
+ entrypoints 0.4
48
+ et-xmlfile 1.1.0
49
+ exceptiongroup 1.2.2
50
+ executing 2.0.1
51
+ fastapi 0.112.2
52
+ ffmpy 0.4.0
53
+ filelock 3.15.4
54
+ fiona 1.9.5
55
+ Flask 3.0.3
56
+ flatbuffers 24.3.25
57
+ fonttools 4.25.0
58
+ frozenlist 1.4.1
59
+ fsspec 2024.6.1
60
+ GDAL 3.6.2
61
+ geojson-rewind 1.1.0
62
+ geomet 1.1.0
63
+ geopandas 0.9.0
64
+ gitdb 4.0.11
65
+ GitPython 3.1.43
66
+ google-auth 2.34.0
67
+ googleapis-common-protos 1.63.2
68
+ gradio 4.42.0
69
+ gradio_client 1.3.0
70
+ graphene 3.3
71
+ graphql-core 3.2.3
72
+ graphql-relay 3.2.0
73
+ grpcio 1.66.0
74
+ gunicorn 23.0.0
75
+ h11 0.14.0
76
+ httpcore 1.0.5
77
+ httptools 0.6.1
78
+ httpx 0.27.0
79
+ huggingface-hub 0.24.6
80
+ humanfriendly 10.0
81
+ idna 3.8
82
+ importlib_metadata 8.0.0
83
+ importlib_resources 6.4.4
84
+ ipykernel 6.29.5
85
+ ipython 8.12.0
86
+ itsdangerous 2.2.0
87
+ jedi 0.19.1
88
+ Jinja2 3.1.4
89
+ jiter 0.5.0
90
+ joblib 1.4.2
91
+ jsonpatch 1.33
92
+ jsonpointer 3.0.0
93
+ jsonschema 4.23.0
94
+ jsonschema-specifications 2023.12.1
95
+ jupyter-client 7.3.4
96
+ jupyter_core 5.7.2
97
+ kiwisolver 1.4.4
98
+ kubernetes 30.1.0
99
+ langchain 0.2.14
100
+ langchain-anthropic 0.1.23
101
+ langchain-community 0.2.12
102
+ langchain-core 0.2.34
103
+ langchain-openai 0.1.22
104
+ langchain-text-splitters 0.2.2
105
+ langsmith 0.1.104
106
+ lightning-utilities 0.11.7
107
+ Mako 1.3.5
108
+ Markdown 3.7
109
+ markdown-it-py 3.0.0
110
+ MarkupSafe 2.1.5
111
+ marshmallow 3.22.0
112
+ matplotlib 3.9.2
113
+ matplotlib-inline 0.1.7
114
+ mdurl 0.1.2
115
+ missingno 0.5.2
116
+ mlflow 2.16.0
117
+ mlflow-skinny 2.16.0
118
+ mmh3 4.1.0
119
+ monotonic 1.6
120
+ mpmath 1.3.0
121
+ multidict 6.0.5
122
+ munkres 1.1.4
123
+ mypy-extensions 1.0.0
124
+ narwhals 1.5.5
125
+ nest_asyncio 1.6.0
126
+ networkx 3.2.1
127
+ nltk 3.9.1
128
+ numexpr 2.10.1
129
+ numpy 1.26.4
130
+ oauthlib 3.2.2
131
+ onnxruntime 1.19.0
132
+ openai 1.42.0
133
+ openpyxl 3.0.9
134
+ opentelemetry-api 1.26.0
135
+ opentelemetry-exporter-otlp-proto-common 1.26.0
136
+ opentelemetry-exporter-otlp-proto-grpc 1.26.0
137
+ opentelemetry-instrumentation 0.47b0
138
+ opentelemetry-instrumentation-asgi 0.47b0
139
+ opentelemetry-instrumentation-fastapi 0.47b0
140
+ opentelemetry-proto 1.26.0
141
+ opentelemetry-sdk 1.26.0
142
+ opentelemetry-semantic-conventions 0.47b0
143
+ opentelemetry-util-http 0.47b0
144
+ orjson 3.10.7
145
+ overrides 7.7.0
146
+ packaging 24.1
147
+ pandas 2.0.3
148
+ parso 0.8.4
149
+ patsy 0.5.6
150
+ pexpect 4.9.0
151
+ pickleshare 0.7.5
152
+ pillow 10.4.0
153
+ pip 24.2
154
+ platformdirs 4.2.2
155
+ plotly 5.24.1
156
+ posthog 3.5.2
157
+ prompt_toolkit 3.0.47
158
+ protobuf 4.25.4
159
+ psutil 5.9.0
160
+ ptyprocess 0.7.0
161
+ pure_eval 0.2.3
162
+ pyarrow 17.0.0
163
+ pyasn1 0.6.0
164
+ pyasn1_modules 0.4.0
165
+ pydantic 2.8.2
166
+ pydantic_core 2.20.1
167
+ pydeck 0.9.1
168
+ pydub 0.25.1
169
+ pygbif 0.6.4
170
+ Pygments 2.18.0
171
+ PyMuPDF 1.24.9
172
+ PyMuPDFb 1.24.9
173
+ pyogrio 0.10.0
174
+ pyparsing 3.1.2
175
+ PyPDF2 3.0.1
176
+ PyPika 0.48.9
177
+ pyproj 3.6.1
178
+ pyproject_hooks 1.1.0
179
+ python-dateutil 2.9.0
180
+ python-dotenv 1.0.1
181
+ python-multipart 0.0.9
182
+ pytz 2024.1
183
+ PyYAML 6.0.2
184
+ pyzmq 25.1.2
185
+ referencing 0.35.1
186
+ regex 2024.7.24
187
+ requests 2.32.3
188
+ requests-cache 1.2.1
189
+ requests-oauthlib 2.0.0
190
+ rich 13.7.1
191
+ rpds-py 0.20.0
192
+ rsa 4.9
193
+ Rtree 1.0.1
194
+ ruff 0.6.2
195
+ scikit-learn 1.5.1
196
+ scipy 1.13.1
197
+ seaborn 0.13.2
198
+ semantic-version 2.10.0
199
+ setuptools 72.1.0
200
+ shapely 2.0.5
201
+ shellingham 1.5.4
202
+ six 1.16.0
203
+ smmap 5.0.1
204
+ sniffio 1.3.1
205
+ SQLAlchemy 2.0.32
206
+ sqlparse 0.5.1
207
+ stack-data 0.6.2
208
+ starlette 0.38.2
209
+ statsmodels 0.14.2
210
+ streamlit 1.37.1
211
+ sympy 1.13.2
212
+ tenacity 8.5.0
213
+ threadpoolctl 3.5.0
214
+ tiktoken 0.7.0
215
+ tokenizers 0.20.0
216
+ toml 0.10.2
217
+ tomli 2.0.1
218
+ tomlkit 0.12.0
219
+ torch 2.4.0
220
+ torchmetrics 1.4.1
221
+ tornado 6.1
222
+ tqdm 4.66.5
223
+ traitlets 5.14.3
224
+ typer 0.12.5
225
+ typing_extensions 4.12.2
226
+ typing-inspect 0.9.0
227
+ tzdata 2024.1
228
+ url-normalize 1.4.3
229
+ urllib3 2.2.2
230
+ uvicorn 0.30.6
231
+ uvloop 0.20.0
232
+ watchfiles 0.23.0
233
+ wcwidth 0.2.13
234
+ websocket-client 1.8.0
235
+ websockets 12.0
236
+ Werkzeug 3.0.4
237
+ wget 3.2
238
+ wheel 0.43.0
239
+ wrapt 1.16.0
240
+ yarl 1.9.4
241
+ zipp 3.20.0
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+
2
+ chromadb
3
+ gradio
4
+ openai
5
+ langchain
6
+ langchain-anthropic
7
+ langchain_community
8
+ pandas
9
+ python-dotenv
10
+ streamlit
11
+ pypdf2
12
+ tiktoken
13
+ streamlit
14
+ langchain-openai
15
+ pymupdf
16
+ openpyxl
retriever-evaluation-tutorial.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
temp_results.csv ADDED
The diff for this file is too large to render. See raw diff