mcp-generator / src /llm_client.py
visproj's picture
Upload 17 files
1e6d6a3 verified
raw
history blame
2.97 kB
"""Unified LLM client supporting multiple providers"""
import json
import re
from typing import Optional
from anthropic import Anthropic
from openai import OpenAI
from .config import LLM_PROVIDER, ANTHROPIC_API_KEY, OPENAI_API_KEY
class LLMClient:
"""Unified client for different LLM providers"""
def __init__(self, provider: str = None):
"""Initialize LLM client
Args:
provider: "anthropic" or "openai". If None, uses config default.
"""
self.provider = provider or LLM_PROVIDER
if self.provider == "anthropic":
if not ANTHROPIC_API_KEY:
raise ValueError("ANTHROPIC_API_KEY not set")
self.client = Anthropic(api_key=ANTHROPIC_API_KEY)
self.model = "claude-3-5-sonnet-20241022"
elif self.provider == "openai":
if not OPENAI_API_KEY:
raise ValueError("OPENAI_API_KEY not set")
self.client = OpenAI(api_key=OPENAI_API_KEY)
self.model = "gpt-4o" # or "gpt-4o-mini" for cheaper
else:
raise ValueError(f"Unknown provider: {self.provider}")
async def generate(self, prompt: str, max_tokens: int = 2000) -> str:
"""Generate text from prompt
Args:
prompt: The prompt to send
max_tokens: Maximum tokens to generate
Returns:
Generated text
"""
if self.provider == "anthropic":
response = self.client.messages.create(
model=self.model,
max_tokens=max_tokens,
messages=[{
"role": "user",
"content": prompt
}]
)
return response.content[0].text
elif self.provider == "openai":
response = self.client.chat.completions.create(
model=self.model,
max_tokens=max_tokens,
messages=[{
"role": "user",
"content": prompt
}]
)
return response.choices[0].message.content
async def generate_json(self, prompt: str, max_tokens: int = 3000) -> dict:
"""Generate JSON from prompt
Args:
prompt: The prompt to send
max_tokens: Maximum tokens to generate
Returns:
Parsed JSON dict
"""
text = await self.generate(prompt, max_tokens)
# Extract JSON from response (might be wrapped in markdown)
json_match = re.search(r'\{.*\}|\[.*\]', text, re.DOTALL)
if json_match:
return json.loads(json_match.group())
else:
raise ValueError(f"No valid JSON found in response: {text[:200]}")
# Global instance
def get_llm_client() -> LLMClient:
"""Get configured LLM client"""
return LLMClient()