File size: 2,965 Bytes
1e6d6a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
"""Unified LLM client supporting multiple providers"""
import json
import re
from typing import Optional
from anthropic import Anthropic
from openai import OpenAI
from .config import LLM_PROVIDER, ANTHROPIC_API_KEY, OPENAI_API_KEY
class LLMClient:
    """Unified client for different LLM providers (Anthropic / OpenAI).

    NOTE(review): `generate`/`generate_json` are declared async but invoke the
    *synchronous* SDK clients, so they block the event loop for the duration of
    the API call — consider `AsyncAnthropic`/`AsyncOpenAI` if callers need real
    concurrency; confirm against call sites before changing.
    """

    def __init__(self, provider: Optional[str] = None):
        """Initialize LLM client.

        Args:
            provider: "anthropic" or "openai". If None, uses config default.

        Raises:
            ValueError: If the selected provider's API key is not set, or the
                provider name is unknown.
        """
        self.provider = provider or LLM_PROVIDER
        if self.provider == "anthropic":
            if not ANTHROPIC_API_KEY:
                raise ValueError("ANTHROPIC_API_KEY not set")
            self.client = Anthropic(api_key=ANTHROPIC_API_KEY)
            self.model = "claude-3-5-sonnet-20241022"
        elif self.provider == "openai":
            if not OPENAI_API_KEY:
                raise ValueError("OPENAI_API_KEY not set")
            self.client = OpenAI(api_key=OPENAI_API_KEY)
            self.model = "gpt-4o"  # or "gpt-4o-mini" for cheaper
        else:
            raise ValueError(f"Unknown provider: {self.provider}")

    async def generate(self, prompt: str, max_tokens: int = 2000) -> str:
        """Generate text from prompt.

        Args:
            prompt: The prompt to send.
            max_tokens: Maximum tokens to generate.

        Returns:
            Generated text.

        Raises:
            ValueError: If the configured provider is unsupported (defensive;
                __init__ already validates, but previously this path returned
                an implicit None on an `-> str` method).
        """
        if self.provider == "anthropic":
            response = self.client.messages.create(
                model=self.model,
                max_tokens=max_tokens,
                messages=[{"role": "user", "content": prompt}],
            )
            return response.content[0].text
        elif self.provider == "openai":
            response = self.client.chat.completions.create(
                model=self.model,
                max_tokens=max_tokens,
                messages=[{"role": "user", "content": prompt}],
            )
            return response.choices[0].message.content
        # Fail loudly instead of silently yielding None if provider state is
        # somehow corrupted after construction.
        raise ValueError(f"Unknown provider: {self.provider}")

    async def generate_json(self, prompt: str, max_tokens: int = 3000) -> dict:
        """Generate JSON from prompt.

        Args:
            prompt: The prompt to send.
            max_tokens: Maximum tokens to generate.

        Returns:
            Parsed JSON object (a dict, or a list if the model emits a
            top-level array — same as the original extraction regex allows).

        Raises:
            ValueError: If no JSON-looking span is found, or the extracted
                span is malformed. (json.JSONDecodeError subclasses
                ValueError, so existing `except ValueError` callers are
                unaffected by the wrapping below.)
        """
        text = await self.generate(prompt, max_tokens)
        # Extract JSON from response (might be wrapped in markdown fences or
        # surrounded by prose); DOTALL lets `.` cross newlines, and the greedy
        # match spans from the first brace/bracket to the last.
        json_match = re.search(r'\{.*\}|\[.*\]', text, re.DOTALL)
        if not json_match:
            raise ValueError(f"No valid JSON found in response: {text[:200]}")
        try:
            return json.loads(json_match.group())
        except json.JSONDecodeError as err:
            # Preserve a snippet of the raw response for debuggability.
            raise ValueError(
                f"Malformed JSON in response: {text[:200]}"
            ) from err
# Factory for a client configured from module defaults (builds a new instance per call)
def get_llm_client() -> LLMClient:
    """Build and return an LLMClient using the configured default provider."""
    return LLMClient(provider=None)
|