|
|
"""Unified LLM client supporting multiple providers"""
|
|
|
|
|
|
import asyncio
import functools
import json
import re
from typing import Optional, Union

from anthropic import Anthropic
from openai import OpenAI

from .config import LLM_PROVIDER, ANTHROPIC_API_KEY, OPENAI_API_KEY
|
|
|
|
|
|
|
|
|
class LLMClient:
    """Unified client for different LLM providers.

    Wraps either the Anthropic or the OpenAI SDK client behind a common
    asynchronous ``generate`` / ``generate_json`` interface.
    """

    # Body of a ```json fenced block; models often wrap structured output in one.
    _JSON_FENCE_RE = re.compile(r"```json\s*(.*?)```", re.DOTALL | re.IGNORECASE)
    # Legacy heuristic kept as a last resort: first opening bracket to the
    # last closing bracket of the same kind.
    _GREEDY_JSON_RE = re.compile(r"\{.*\}|\[.*\]", re.DOTALL)

    def __init__(self, provider: Optional[str] = None):
        """Initialize LLM client.

        Args:
            provider: "anthropic" or "openai". If None (or empty), uses the
                config default ``LLM_PROVIDER``.

        Raises:
            ValueError: If the selected provider's API key is not set, or the
                provider name is unknown.
        """
        self.provider = provider or LLM_PROVIDER

        if self.provider == "anthropic":
            if not ANTHROPIC_API_KEY:
                raise ValueError("ANTHROPIC_API_KEY not set")
            self.client = Anthropic(api_key=ANTHROPIC_API_KEY)
            self.model = "claude-3-5-sonnet-20241022"
        elif self.provider == "openai":
            if not OPENAI_API_KEY:
                raise ValueError("OPENAI_API_KEY not set")
            self.client = OpenAI(api_key=OPENAI_API_KEY)
            self.model = "gpt-4o"
        else:
            raise ValueError(f"Unknown provider: {self.provider}")

    async def generate(self, prompt: str, max_tokens: int = 2000) -> str:
        """Generate text from prompt.

        Args:
            prompt: The prompt to send as a single user message.
            max_tokens: Maximum tokens to generate.

        Returns:
            Generated text ("" if the model returned no text content).

        Raises:
            ValueError: If ``self.provider`` is not a supported provider.
        """
        # The SDK clients created in __init__ are synchronous; calling them
        # directly inside a coroutine would block the event loop for the whole
        # HTTP round-trip. Run the blocking call in the default executor.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, functools.partial(self._generate_sync, prompt, max_tokens)
        )

    def _generate_sync(self, prompt: str, max_tokens: int) -> str:
        """Blocking provider call backing ``generate``."""
        messages = [{"role": "user", "content": prompt}]

        if self.provider == "anthropic":
            response = self.client.messages.create(
                model=self.model,
                max_tokens=max_tokens,
                messages=messages,
            )
            # A response may hold zero or several content blocks, some of which
            # are not text; concatenate only the text blocks instead of
            # assuming content[0] exists and is text.
            return "".join(
                block.text
                for block in response.content
                if getattr(block, "type", "text") == "text"
            )

        if self.provider == "openai":
            response = self.client.chat.completions.create(
                model=self.model,
                max_tokens=max_tokens,
                messages=messages,
            )
            # ``content`` is None e.g. for refusals or tool-call-only replies.
            return response.choices[0].message.content or ""

        # Previously this fell through and returned None silently.
        raise ValueError(f"Unknown provider: {self.provider}")

    async def generate_json(
        self, prompt: str, max_tokens: int = 3000
    ) -> Union[dict, list]:
        """Generate JSON from prompt.

        Args:
            prompt: The prompt to send.
            max_tokens: Maximum tokens to generate.

        Returns:
            The first JSON object or array found in the model's reply.

        Raises:
            ValueError: If no JSON object/array can be found
                (``json.JSONDecodeError``, a ``ValueError`` subclass, when a
                bracketed span is found but cannot be parsed).
        """
        text = await self.generate(prompt, max_tokens)
        return self._extract_json(text)

    @classmethod
    def _extract_json(cls, text: str) -> Union[dict, list]:
        """Return the first complete JSON object/array embedded in *text*.

        Strategy, in order:
          1. The body of a ```json fenced block, if present.
          2. The whole text.
        For each, decode starting at the first ``{`` and at the first ``[``
        (earliest first), accepting the first value that parses completely;
        trailing prose or a second JSON value is ignored. If nothing decodes,
        fall back to the original greedy first-to-last-bracket heuristic.
        """
        decoder = json.JSONDecoder()
        fence = cls._JSON_FENCE_RE.search(text)
        candidates = [fence.group(1), text] if fence else [text]

        for candidate in candidates:
            starts = sorted(
                pos
                for pos in (candidate.find("{"), candidate.find("["))
                if pos != -1
            )
            for start in starts:
                try:
                    value, _ = decoder.raw_decode(candidate[start:])
                except json.JSONDecodeError:
                    continue
                return value

        json_match = cls._GREEDY_JSON_RE.search(text)
        if json_match:
            return json.loads(json_match.group())
        raise ValueError(f"No valid JSON found in response: {text[:200]}")
|
|
|
|
|
|
|
|
|
|
|
|
def get_llm_client() -> LLMClient:
    """Return a new LLMClient for the provider chosen in the configuration.

    A fresh client is constructed on every call; no instance is cached.
    """
    configured_client = LLMClient()
    return configured_client
|
|
|
|