import asyncio
import logging
import time
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from io import BytesIO
from typing import Literal, Optional, cast, get_args
from uuid import UUID, uuid4

import anthropic
import openai
import tiktoken
from anthropic import NOT_GIVEN, Anthropic, AsyncAnthropic, NotGiven
from anthropic.types import MessageParam
from deepgram import DeepgramClient  # type: ignore
from elevenlabs.client import AsyncElevenLabs
from openai import AsyncOpenAI, RateLimitError
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, Field
from pyht import AsyncClient as AsyncPlayHtClient  # type: ignore
from pyht import TTSOptions

from config import credentials

logger = logging.getLogger("uvicorn")


class KVStore(ABC):
    @abstractmethod
    async def get(self, key: str) -> Optional[str]: ...

    @abstractmethod
    async def set(self, key: str, value: str) -> None: ...


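# For reference, a minimal in-memory implementation of the KVStore interface.
# This is an illustrative sketch only (no persistence, no eviction); the real
# backing store is not shown in this module.
class InMemoryKVStore(KVStore):
    def __init__(self) -> None:
        self._data: dict[str, str] = {}

    async def get(self, key: str) -> Optional[str]:
        return self._data.get(key)

    async def set(self, key: str, value: str) -> None:
        self._data[key] = value

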
class AIConnection:
    openai_client: AsyncOpenAI
    anthropic_client: AsyncAnthropic
    sync_anthropic_client: Anthropic
    eleven_labs_client: AsyncElevenLabs
    deepgram_client: DeepgramClient
    play_ht_client: AsyncPlayHtClient
    # Class-level semaphores: one shared semaphore per provider, used to
    # coordinate rate-limit backoff across all AIConnection instances
    openai_ratelimit_semaphore = asyncio.Semaphore(1)
    anthropic_ratelimit_semaphore = asyncio.Semaphore(1)
46
47 def __init__(self) -> None:
48 self.openai_client = AsyncOpenAI(
49 api_key=credentials.ai.openai_api_key.get_secret_value()
50 )
51 self.anthropic_client = AsyncAnthropic(
52 api_key=credentials.ai.anthropic_api_key.get_secret_value()
53 )
54 self.sync_anthropic_client = Anthropic(
55 api_key=credentials.ai.anthropic_api_key.get_secret_value()
56 )
57 self.eleven_labs_client = AsyncElevenLabs(
58 api_key=credentials.ai.elevenlabs_api_key.get_secret_value()
59 )
60 self.deepgram_client = DeepgramClient(
61 credentials.ai.deepgram_api_key.get_secret_value()
62 )
63 self.play_ht_client = AsyncPlayHtClient(
64 credentials.ai.playht_user_id.get_secret_value(),
65 credentials.ai.playht_api_key.get_secret_value(),
66 )
67
68
# NOTE: API clients cannot be called from multiple event loops,
# so every asyncio event loop needs its own AIConnection.
ai_connections: dict[asyncio.AbstractEventLoop, AIConnection] = {}


def get_ai_connection() -> AIConnection:
    event_loop = asyncio.get_event_loop()
    if event_loop not in ai_connections:
        ai_connections[event_loop] = AIConnection()
    return ai_connections[event_loop]


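# Example usage (sketch): called from inside a coroutine, each event loop
# lazily receives its own set of clients:
#
#     async def handler() -> None:
#         client = get_ai_connection().openai_client
#         ...
#

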
class TaskOutput(BaseModel):
    id: UUID = Field(default_factory=uuid4)


class AIModel(BaseModel):
    company: Literal["openai", "anthropic"]
    model: str


class AIMessage(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str


class AIError(Exception):
    """Base class for errors raised by AI tasks."""


class AIModerationError(AIError):
    pass


def ai_num_tokens(model: AIModel, s: str) -> int:
    if model.company == "anthropic":
        # Local tokenizer; doesn't actually connect to the network
        return get_ai_connection().sync_anthropic_client.count_tokens(s)
    elif model.company == "openai":
        encoding = tiktoken.encoding_for_model(model.model)
        num_tokens = len(encoding.encode(s))
        return num_tokens
    else:
        raise ValueError(f"Unknown company: {model.company}")


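# Example (sketch): count tokens in a prompt, assuming "gpt-4o" is a model
# name tiktoken knows about:
#
#     n = ai_num_tokens(AIModel(company="openai", model="gpt-4o"), "Hello!")
#

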
async def ai_call(
    model: AIModel,
    messages: list[AIMessage],
    *,
    max_tokens: int = 4096,
    temperature: float = 0.0,
    num_ratelimit_retries: int = 10,
    # When using anthropic, the first message must be from the user.
    # If the first message is not from the user, this message will be
    # prepended to the messages.
    anthropic_initial_message: str | None = "<START>",
    # If two messages of the same role are given to anthropic, they must be
    # concatenated. This is the delimiter between the concatenated messages.
    anthropic_combine_delimiter: str = "\n",
) -> str:
    if model.company == "openai":
        for i in range(num_ratelimit_retries):
            try:

                def ai_message_to_openai_message_param(
                    message: AIMessage,
                ) -> ChatCompletionMessageParam:
                    if message.role == "system":  # noqa: SIM114
                        return {"role": message.role, "content": message.content}
                    elif message.role == "user":  # noqa: SIM114
                        return {"role": message.role, "content": message.content}
                    elif message.role == "assistant":
                        return {"role": message.role, "content": message.content}

                if i > 0:
                    logger.debug("Trying again after RateLimitError...")
                response = (
                    await get_ai_connection().openai_client.chat.completions.create(
                        model=model.model,
                        messages=[
                            ai_message_to_openai_message_param(message)
                            for message in messages
                        ],
                        temperature=temperature,
                        max_tokens=max_tokens,
                    )
                )
                if response.choices[0].message.content is None:
                    raise RuntimeError("OpenAI returned nothing")
                return response.choices[0].message.content
            except RateLimitError:
                logger.warning("OpenAI RateLimitError")
                async with get_ai_connection().openai_ratelimit_semaphore:
                    await asyncio.sleep(1)
        raise TimeoutError("Cannot overcome OpenAI RateLimitError")

    elif model.company == "anthropic":
        for i in range(num_ratelimit_retries):
            try:

                def ai_message_to_anthropic_message_param(
                    message: AIMessage,
                ) -> MessageParam:
                    if message.role == "user" or message.role == "assistant":
                        return {"role": message.role, "content": message.content}
                    elif message.role == "system":
                        raise RuntimeError(
                            "system not allowed in anthropic message param"
                        )

                if i > 0:
                    logger.debug("Trying again after RateLimitError...")

                # Extract the system message if it exists
                system: str | NotGiven = NOT_GIVEN
                if len(messages) > 0 and messages[0].role == "system":
                    system = messages[0].content
                    messages = messages[1:]
                # Insert the initial user message if necessary
                if (
                    anthropic_initial_message is not None
                    and len(messages) > 0
                    and messages[0].role != "user"
                ):
                    messages = [
                        AIMessage(role="user", content=anthropic_initial_message)
                    ] + messages
                # Combine consecutive messages of the same role
                combined_messages: list[AIMessage] = []
                for message in messages:
                    if (
                        len(combined_messages) == 0
                        or combined_messages[-1].role != message.role
                    ):
                        combined_messages.append(message)
                    else:
                        # Copy before editing
                        combined_messages[-1] = combined_messages[-1].model_copy(
                            deep=True
                        )
                        # Merge consecutive messages with the same role
                        combined_messages[-1].content += (
                            anthropic_combine_delimiter + message.content
                        )
                # Get the response
                response_message = (
                    await get_ai_connection().anthropic_client.messages.create(
                        model=model.model,
                        system=system,
                        messages=[
                            ai_message_to_anthropic_message_param(message)
                            for message in combined_messages
                        ],
                        temperature=temperature,
                        max_tokens=max_tokens,
                    )
                )
                return response_message.content[0].text
            except anthropic.RateLimitError as e:
                logger.warning(f"Anthropic Error: {repr(e)}")
                async with get_ai_connection().anthropic_ratelimit_semaphore:
                    await asyncio.sleep(1)
        raise TimeoutError("Cannot overcome Anthropic RateLimitError")
    else:
        raise ValueError(f"Unknown company: {model.company}")


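# Example usage (sketch; the model name below is illustrative):
#
#     reply = await ai_call(
#         AIModel(company="anthropic", model="claude-3-5-sonnet-20240620"),
#         [
#             AIMessage(role="system", content="You are terse."),
#             AIMessage(role="user", content="Summarize this call flow."),
#         ],
#         temperature=0.2,
#     )
#

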
async def ai_stt(buffer: BytesIO) -> str:
    try:
        response = await get_ai_connection().openai_client.audio.transcriptions.create(
            model="whisper-1",
            file=buffer,
        )
        return response.text
    except openai.BadRequestError as e:
        # Return empty string for audio that's too short
        if e.code == "audio_too_short":
            return ""
        else:
            raise


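# Example usage (sketch): the OpenAI SDK infers the audio format from the
# file name, so the caller is assumed to give the BytesIO a name:
#
#     buffer = BytesIO(wav_bytes)
#     buffer.name = "audio.wav"
#     transcript = await ai_stt(buffer)
#

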
class AIVoiceModel(BaseModel):
    company: Literal["openai", "elevenlabs", "playht"]
    voice: str
    speed: float = 1.0


async def ai_tts(
    transcript: str,
    *,
    voice_model: Optional[AIVoiceModel] = None,
    low_latency: bool = False,
) -> AsyncIterator[bytes]:
    if voice_model is None:
        voice_model = AIVoiceModel(company="openai", voice="nova")

    async def log_bytes_iterator(
        bytes_generator: AsyncIterator[bytes],
    ) -> AsyncIterator[bytes]:
        # t1 = stream requested, t2 = first chunk received, t3 = stream done
        t1: float = time.time()
        t2: Optional[float] = None
        async for chunk in bytes_generator:
            if t2 is None:
                t2 = time.time()
            yield chunk
        if t2 is None:
            t2 = time.time()
        t3: float = time.time()
        logger.debug(
            f"TTS Latency ({t2 - t1:.3f}): {repr(transcript)} (low_latency={low_latency})"
        )
        logger.debug(f"TTS ({t3 - t2:.3f}): {repr(transcript)}")

    bytes_generator: AsyncIterator[bytes]
    match voice_model.company:
        case "openai":
            # Typing
            openai_voice_type = Literal[
                "alloy", "echo", "fable", "onyx", "nova", "shimmer"
            ]

            def voice_to_openai_voice(voice: str) -> openai_voice_type:
                if voice not in get_args(openai_voice_type):
                    raise ValueError(
                        f"voice must be one of {get_args(openai_voice_type)}, received {voice}"
                    )
                return cast(openai_voice_type, voice)

            # Run it
            response = await get_ai_connection().openai_client.audio.speech.with_raw_response.create(
                model="tts-1",
                voice=voice_to_openai_voice(voice_model.voice),
                input=transcript,
                response_format="aac",
                speed=voice_model.speed,
            )
            bytes_generator = response.http_response.aiter_bytes(chunk_size=2048)
        case "playht":
            options = TTSOptions(voice=voice_model.voice, speed=voice_model.speed)
            try:
                # Collapse double spaces before synthesis
                bytes_generator = get_ai_connection().play_ht_client.tts(
                    transcript.strip().replace("  ", " "), options
                )
            except Exception as e:
                logger.error(f"Error occurred in Play.ht TTS: {e}")
                raise
        case "elevenlabs":
            if voice_model.speed != 1:
                logger.warning(
                    f"elevenlabs does not support speed change to {voice_model.speed}. Ignoring."
                )
            bytes_generator = await get_ai_connection().eleven_labs_client.generate(
                text=transcript.strip().replace("  ", " "),
                voice=voice_model.voice,
                optimize_streaming_latency=3 if low_latency else None,
                model="eleven_turbo_v2",
                stream=True,
            )
    return log_bytes_iterator(bytes_generator=bytes_generator)
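

# Example usage (sketch): stream synthesized audio to a local file; the output
# path is illustrative, and the default OpenAI voice emits AAC:
#
#     async def save_tts() -> None:
#         audio = await ai_tts("Hello there!", low_latency=True)
#         with open("out.aac", "wb") as f:
#             async for chunk in audio:
#                 f.write(chunk)
#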