File size: 2,965 Bytes
1e6d6a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""Unified LLM client supporting multiple providers"""

import asyncio
import json
import re
from typing import Optional

from anthropic import Anthropic
from openai import OpenAI

from .config import LLM_PROVIDER, ANTHROPIC_API_KEY, OPENAI_API_KEY


class LLMClient:
    """Unified client for different LLM providers.

    Wraps either the Anthropic Messages API or the OpenAI Chat Completions
    API behind the same ``generate`` / ``generate_json`` coroutines.
    """

    # Model requested for each provider when the caller does not pass ``model``.
    DEFAULT_MODELS = {
        "anthropic": "claude-3-5-sonnet-20241022",
        "openai": "gpt-4o",  # or "gpt-4o-mini" for cheaper
    }

    def __init__(self, provider: Optional[str] = None, model: Optional[str] = None):
        """Initialize LLM client.

        Args:
            provider: "anthropic" or "openai". If None (or empty), uses the
                ``LLM_PROVIDER`` config default.
            model: Model name to request. If None (or empty), uses
                ``DEFAULT_MODELS[provider]``.

        Raises:
            ValueError: If the provider is unknown or its API key is not set.
        """
        self.provider = provider or LLM_PROVIDER

        if self.provider == "anthropic":
            if not ANTHROPIC_API_KEY:
                raise ValueError("ANTHROPIC_API_KEY not set")
            self.client = Anthropic(api_key=ANTHROPIC_API_KEY)

        elif self.provider == "openai":
            if not OPENAI_API_KEY:
                raise ValueError("OPENAI_API_KEY not set")
            self.client = OpenAI(api_key=OPENAI_API_KEY)

        else:
            raise ValueError(f"Unknown provider: {self.provider}")

        self.model = model or self.DEFAULT_MODELS[self.provider]

    @staticmethod
    async def _run_blocking(func, **kwargs):
        """Await a blocking call by running it in the event loop's default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: func(**kwargs))

    async def generate(self, prompt: str, max_tokens: int = 2000) -> str:
        """Generate text from a single user-turn prompt.

        ``self.client`` is a synchronous SDK client, so the request is run in
        the default executor rather than blocking the event loop while the
        HTTP call is in flight.

        Args:
            prompt: The prompt to send.
            max_tokens: Maximum tokens to generate.

        Returns:
            Generated text; an empty string if the response carried no text.

        Raises:
            ValueError: If ``self.provider`` is not a supported provider.
        """
        messages = [{"role": "user", "content": prompt}]

        if self.provider == "anthropic":
            response = await self._run_blocking(
                self.client.messages.create,
                model=self.model,
                max_tokens=max_tokens,
                messages=messages,
            )
            # Concatenate every text block and skip any non-text content blocks;
            # indexing content[0] would fail on an empty content list.
            return "".join(
                block.text
                for block in response.content
                if getattr(block, "type", None) == "text"
            )

        if self.provider == "openai":
            response = await self._run_blocking(
                self.client.chat.completions.create,
                model=self.model,
                max_tokens=max_tokens,
                messages=messages,
            )
            # The SDK types ``message.content`` as optional; coerce None so the
            # declared ``str`` return type holds.
            return response.choices[0].message.content or ""

        raise ValueError(f"Unknown provider: {self.provider}")

    async def generate_json(self, prompt: str, max_tokens: int = 3000) -> dict:
        """Generate JSON from prompt.

        The model's reply may wrap the JSON in prose or a markdown fence; the
        first complete JSON object or array found in the text is returned.

        Args:
            prompt: The prompt to send.
            max_tokens: Maximum tokens to generate.

        Returns:
            Parsed JSON dict (or list, if the reply's JSON is a top-level array).

        Raises:
            json.JSONDecodeError: If JSON-looking text was found but none of it
                parsed (e.g. a reply truncated by ``max_tokens``).
            ValueError: If the reply contains no ``{`` or ``[`` at all.
        """
        text = await self.generate(prompt, max_tokens)
        return self._extract_json(text)

    @staticmethod
    def _extract_json(text: str):
        """Return the first JSON object/array in *text* that decodes cleanly."""
        decoder = json.JSONDecoder()
        resume_at = 0
        last_error = None
        for start, char in enumerate(text):
            if char not in "{[" or start < resume_at:
                continue
            try:
                # raw_decode stops at the end of the value, so trailing prose
                # (even prose containing braces) does not break parsing.
                value, _end = decoder.raw_decode(text, start)
            except json.JSONDecodeError as err:
                last_error = err
                # Skip candidates nested inside the span that failed, so a
                # truncated object does not return one of its inner fragments.
                resume_at = max(err.pos, start + 1)
                continue
            return value
        if last_error is not None:
            raise last_error
        raise ValueError(f"No valid JSON found in response: {text[:200]}")


def get_llm_client(provider: Optional[str] = None) -> LLMClient:
    """Create a new configured LLM client.

    A fresh ``LLMClient`` is constructed on every call; no instance is cached
    or shared between callers.

    Args:
        provider: Optional provider override ("anthropic" or "openai").
            None falls back to the configured default provider.

    Returns:
        A new ``LLMClient``.
    """
    return LLMClient(provider)