import asyncio
import hashlib
import logging
import os
from collections.abc import Callable
from enum import Enum
from typing import Literal

import anthropic
import cohere
import diskcache as dc  # type: ignore
import openai
import tiktoken
import voyageai  # type: ignore
import voyageai.error  # type: ignore
from anthropic import NOT_GIVEN, Anthropic, AsyncAnthropic, NotGiven
from anthropic.types import MessageParam
from openai import AsyncOpenAI, RateLimitError
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, computed_field

from utils.credentials import credentials

logger = logging.getLogger("uvicorn")

# AI Types


class AIModel(BaseModel):
    company: Literal["openai", "anthropic"]
    model: str

    @computed_field  # type: ignore[misc]
    @property
    def ratelimit_tpm(self) -> float:
        match self.company:
            case "openai":
                # Tier 5
                match self.model:
                    case "gpt-4o-mini":
                        return 150_000_000
                    case "gpt-4o":
                        return 30_000_000
                    case m if m.startswith("gpt-4-turbo"):
                        return 2_000_000
                    case _:
                        return 1_000_000
            case "anthropic":
                # Tier 4
                return 400_000
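
# Illustrative (the model ID here is just an example): the TPM limit is
# derived from the company/model match above.
#   AIModel(company="openai", model="gpt-4o").ratelimit_tpm  # -> 30_000_000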


class AIMessage(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str


class AIEmbeddingModel(BaseModel):
    company: Literal["openai", "cohere", "voyageai"]
    model: str

    @computed_field  # type: ignore[misc]
    @property
    def ratelimit_rpm(self) -> float:
        match self.company:
            case "openai":
                return 10000
            case "cohere":
                return 10000
            case "voyageai":
                # Documented at 300 RPM, but in practice only ~0.5 gets through
                return 0.5


class AIEmbeddingType(Enum):
    DOCUMENT = 1
    QUERY = 2


class AIRerankModel(BaseModel):
    company: Literal["cohere", "voyageai"]
    model: str

    @computed_field  # type: ignore[misc]
    @property
    def ratelimit_rpm(self) -> float:
        match self.company:
            case "cohere":
                return 10000
            case "voyageai":
                # Documented at 100 RPM, but in practice only ~1 gets through
                return 1


# Cache
os.makedirs("./data/cache", exist_ok=True)
cache = dc.Cache("./data/cache/ai_cache.db")
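
# Note: diskcache persists entries on disk across processes; get() returns
# None on a miss and set() stores any picklable value, which is all the
# helpers below rely on:
#   cache.set("k", [1.0, 2.0])
#   cache.get("k")  # -> [1.0, 2.0]; cache.get("missing") -> None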


class AIConnection:
    openai_client: AsyncOpenAI
    voyageai_client: voyageai.AsyncClient
    cohere_client: cohere.AsyncClient
    anthropic_client: AsyncAnthropic
    sync_anthropic_client: Anthropic
    # Share one global Semaphore across all threads
    cohere_ratelimit_semaphore = asyncio.Semaphore(1)
    voyageai_ratelimit_semaphore = asyncio.Semaphore(1)
    openai_ratelimit_semaphore = asyncio.Semaphore(1)
    anthropic_ratelimit_semaphore = asyncio.Semaphore(1)

    def __init__(self) -> None:
        self.openai_client = AsyncOpenAI(
            api_key=credentials.ai.openai_api_key.get_secret_value()
        )
        self.anthropic_client = AsyncAnthropic(
            api_key=credentials.ai.anthropic_api_key.get_secret_value()
        )
        self.sync_anthropic_client = Anthropic(
            api_key=credentials.ai.anthropic_api_key.get_secret_value()
        )
        self.voyageai_client = voyageai.AsyncClient(
            api_key=credentials.ai.voyageai_api_key.get_secret_value()
        )
        self.cohere_client = cohere.AsyncClient(
            api_key=credentials.ai.cohere_api_key.get_secret_value()
        )


# NOTE: API clients cannot be called from multiple event loops,
# so every asyncio event loop needs its own API connection
ai_connections: dict[asyncio.AbstractEventLoop, AIConnection] = {}


def get_ai_connection() -> AIConnection:
    event_loop = asyncio.get_event_loop()
    if event_loop not in ai_connections:
        ai_connections[event_loop] = AIConnection()
    return ai_connections[event_loop]
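
# Illustrative: inside any coroutine, get_ai_connection() returns the
# AIConnection bound to the current event loop, creating it on first use:
#   async def probe() -> None:
#       conn = get_ai_connection()  # same object for every call in this loop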


class AIError(Exception):
    """A class for AI Task Errors"""


class AIValueError(AIError, ValueError):
    """A class for AI Value Errors"""


class AITimeoutError(AIError, TimeoutError):
    """A class for AI Task Timeout Errors"""


def ai_num_tokens(model: AIModel, s: str) -> int:
    if model.company == "anthropic":
        # Counts tokens locally; doesn't actually connect to the network
        return get_ai_connection().sync_anthropic_client.count_tokens(s)
    elif model.company == "openai":
        encoding = tiktoken.encoding_for_model(model.model)
        num_tokens = len(encoding.encode(s))
        return num_tokens
    raise AIValueError(f"Unknown company: {model.company}")
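
# Illustrative: the token count feeds the rate limiter in ai_call, e.g.
#   ai_num_tokens(AIModel(company="openai", model="gpt-4o"), "hello world")
# returns a small int (tiktoken splits the string into BPE tokens).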


def get_call_cache_key(
    model: AIModel,
    messages: list[AIMessage],
) -> str:
    # Hash the model and messages, chaining each digest so that order matters
    md5_hasher = hashlib.md5()
    md5_hasher.update(model.model_dump_json().encode())
    for message in messages:
        md5_hasher.update(md5_hasher.hexdigest().encode())
        md5_hasher.update(message.model_dump_json().encode())
    key = md5_hasher.hexdigest()

    return key


async def ai_call(
    model: AIModel,
    messages: list[AIMessage],
    *,
    max_tokens: int = 4096,
    temperature: float = 0.0,
    # When using anthropic, the first message must be from the user.
    # If the first message is not from the user, this message will be prepended.
    anthropic_initial_message: str | None = "<START>",
    # If two messages of the same role are given to anthropic, they must be
    # concatenated. This is the delimiter between the concatenated messages.
    anthropic_combine_delimiter: str = "\n",
    # Throw an AITimeoutError after this many retries fail
    num_ratelimit_retries: int = 10,
    # Backoff function (receives the index of the attempt)
    backoff_algo: Callable[[int], float] = lambda i: min(2**i, 5),
) -> str:
    cache_key = get_call_cache_key(model, messages)
    cached_call = cache.get(cache_key)

    if cached_call is not None:
        return cached_call

    num_tokens_input: int = sum(
        ai_num_tokens(model, message.content) for message in messages
    )

    return_value: str | None = None
    match model.company:
        case "openai":
            for i in range(num_ratelimit_retries):
                try:
                    # Guard with ratelimit
                    async with get_ai_connection().openai_ratelimit_semaphore:
                        tpm = model.ratelimit_tpm
                        ratio = 0.95
                        expected_wait = num_tokens_input / (tpm * ratio / 60)
                        await asyncio.sleep(expected_wait)

                    def ai_message_to_openai_message_param(
                        message: AIMessage,
                    ) -> ChatCompletionMessageParam:
                        if message.role == "system":  # noqa: SIM114
                            return {"role": message.role, "content": message.content}
                        elif message.role == "user":  # noqa: SIM114
                            return {"role": message.role, "content": message.content}
                        elif message.role == "assistant":
                            return {"role": message.role, "content": message.content}

                    if i > 0:
                        logger.debug("Trying again after RateLimitError...")
                    response = (
                        await get_ai_connection().openai_client.chat.completions.create(
                            model=model.model,
                            messages=[
                                ai_message_to_openai_message_param(message)
                                for message in messages
                            ],
                            temperature=temperature,
                            max_tokens=max_tokens,
                        )
                    )
                    assert response.choices[0].message.content is not None
                    return_value = response.choices[0].message.content
                    break
                except RateLimitError:
                    logger.warning("OpenAI RateLimitError")
                    async with get_ai_connection().openai_ratelimit_semaphore:
                        await asyncio.sleep(backoff_algo(i))
            if return_value is None:
                raise AITimeoutError("Cannot overcome OpenAI RateLimitError")

        case "anthropic":
            for i in range(num_ratelimit_retries):
                try:
                    # Guard with ratelimit
                    async with get_ai_connection().anthropic_ratelimit_semaphore:
                        tpm = model.ratelimit_tpm
                        ratio = 0.95
                        expected_wait = num_tokens_input / (tpm * ratio / 60)
                        await asyncio.sleep(expected_wait)

                    def ai_message_to_anthropic_message_param(
                        message: AIMessage,
                    ) -> MessageParam:
                        if message.role == "user" or message.role == "assistant":
                            return {"role": message.role, "content": message.content}
                        elif message.role == "system":
                            raise AIValueError(
                                "system not allowed in anthropic message param"
                            )

                    if i > 0:
                        logger.debug("Trying again after RateLimitError...")

                    # Extract the system message if it exists
                    system: str | NotGiven = NOT_GIVEN
                    if len(messages) > 0 and messages[0].role == "system":
                        system = messages[0].content
                        messages = messages[1:]
                    # Insert the initial message if necessary
                    if (
                        anthropic_initial_message is not None
                        and len(messages) > 0
                        and messages[0].role != "user"
                    ):
                        messages = [
                            AIMessage(role="user", content=anthropic_initial_message)
                        ] + messages
                    # Combine consecutive messages of the same role
                    combined_messages: list[AIMessage] = []
                    for message in messages:
                        if (
                            len(combined_messages) == 0
                            or combined_messages[-1].role != message.role
                        ):
                            combined_messages.append(message)
                        else:
                            # Copy before editing
                            combined_messages[-1] = combined_messages[-1].model_copy(
                                deep=True
                            )
                            # Merge consecutive messages with the same role
                            combined_messages[-1].content += (
                                anthropic_combine_delimiter + message.content
                            )
                    # Get the response
                    response_message = (
                        await get_ai_connection().anthropic_client.messages.create(
                            model=model.model,
                            system=system,
                            messages=[
                                ai_message_to_anthropic_message_param(message)
                                for message in combined_messages
                            ],
                            temperature=temperature,
                            max_tokens=max_tokens,
                        )
                    )
                    assert isinstance(
                        response_message.content[0], anthropic.types.TextBlock
                    )
                    assert isinstance(response_message.content[0].text, str)
                    return_value = response_message.content[0].text
                    break
                except anthropic.RateLimitError as e:
                    logger.warning(f"Anthropic Error: {repr(e)}")
                    async with get_ai_connection().anthropic_ratelimit_semaphore:
                        await asyncio.sleep(backoff_algo(i))
            if return_value is None:
                raise AITimeoutError("Cannot overcome Anthropic RateLimitError")

    cache.set(cache_key, return_value)
    return return_value
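
# Illustrative usage (the model ID is a placeholder; run inside an event loop):
#   answer = await ai_call(
#       AIModel(company="openai", model="gpt-4o"),
#       [AIMessage(role="user", content="Say hello")],
#   )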


def get_embeddings_cache_key(
    model: AIEmbeddingModel, text: str, embedding_type: AIEmbeddingType
) -> str:
    key = f"{model.company}||||{model.model}||||{embedding_type.name}||||{hashlib.md5(text.encode()).hexdigest()}"
    return key


async def ai_embedding(
    model: AIEmbeddingModel,
    text: str,
    embedding_type: AIEmbeddingType,
    *,
    # Throw an AITimeoutError after this many retries fail
    num_ratelimit_retries: int = 10,
    # Backoff function (receives the index of the attempt)
    backoff_algo: Callable[[int], float] = lambda i: min(2**i, 5),
) -> list[float]:
    cache_key = get_embeddings_cache_key(model, text, embedding_type)
    cached_embedding = cache.get(cache_key)

    if cached_embedding is not None:
        return cached_embedding

    embedding: list[float] | None = None
    match model.company:
        case "openai":
            for i in range(num_ratelimit_retries):
                try:
                    async with get_ai_connection().openai_ratelimit_semaphore:
                        await asyncio.sleep(1.0 / model.ratelimit_rpm)
                    response = (
                        await get_ai_connection().openai_client.embeddings.create(
                            input=[text],
                            model=model.model,
                        )
                    )
                    embedding = response.data[0].embedding
                    break
                except openai.RateLimitError:
                    logger.warning("OpenAI RateLimitError")
                    async with get_ai_connection().openai_ratelimit_semaphore:
                        await asyncio.sleep(backoff_algo(i))
            if embedding is None:
                raise AITimeoutError("Cannot overcome OpenAI RateLimitError")
        case "cohere":
            for i in range(num_ratelimit_retries):
                try:
                    async with get_ai_connection().cohere_ratelimit_semaphore:
                        await asyncio.sleep(1.0 / model.ratelimit_rpm)
                    result = await get_ai_connection().cohere_client.embed(
                        texts=[text],
                        model=model.model,
                        input_type=(
                            "search_document"
                            if embedding_type == AIEmbeddingType.DOCUMENT
                            else "search_query"
                        ),
                    )
                    assert isinstance(result.embeddings, list)
                    embedding = result.embeddings[0]
                    break
                except cohere.errors.TooManyRequestsError:
                    logger.warning("Cohere RateLimitError")
                    async with get_ai_connection().cohere_ratelimit_semaphore:
                        await asyncio.sleep(backoff_algo(i))
            if embedding is None:
                raise AITimeoutError("Cannot overcome Cohere RateLimitError")
        case "voyageai":
            for i in range(num_ratelimit_retries):
                try:
                    async with get_ai_connection().voyageai_ratelimit_semaphore:
                        await asyncio.sleep(1.0 / model.ratelimit_rpm)
                    result = await get_ai_connection().voyageai_client.embed(
                        [text],
                        model=model.model,
                        input_type=(
                            "document"
                            if embedding_type == AIEmbeddingType.DOCUMENT
                            else "query"
                        ),
                    )
                    assert isinstance(result.embeddings, list)
                    embedding = result.embeddings[0]
                    break
                except voyageai.error.RateLimitError:
                    logger.warning("VoyageAI RateLimitError")
                    async with get_ai_connection().voyageai_ratelimit_semaphore:
                        await asyncio.sleep(backoff_algo(i))
            if embedding is None:
                raise AITimeoutError("Cannot overcome VoyageAI RateLimitError")

    cache.set(cache_key, embedding)
    return embedding
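
# Illustrative usage (the model ID is a placeholder):
#   vec = await ai_embedding(
#       AIEmbeddingModel(company="openai", model="text-embedding-3-small"),
#       "the quick brown fox",
#       AIEmbeddingType.DOCUMENT,
#   )
#   len(vec)  # embedding dimension, e.g. 1536 for this model family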


def get_rerank_cache_key(
    model: AIRerankModel, query: str, texts: list[str], top_k: int | None
) -> str:
    # Hash the query and the array of texts
    md5_hasher = hashlib.md5()
    md5_hasher.update(query.encode())
    for text in texts:
        md5_hasher.update(md5_hasher.hexdigest().encode())
        md5_hasher.update(text.encode())
    texts_hash = md5_hasher.hexdigest()

    key = f"{model.company}||||{model.model}||||{top_k}||||{texts_hash}"
    return key


# Gets the list of indices that rerank the original texts
async def ai_rerank(
    model: AIRerankModel,
    query: str,
    texts: list[str],
    *,
    top_k: int | None = None,
    # Throw an AITimeoutError after this many retries fail
    num_ratelimit_retries: int = 10,
    # Backoff function (receives the index of the attempt)
    backoff_algo: Callable[[int], float] = lambda i: min(2**i, 5),
) -> list[int]:
    cache_key = get_rerank_cache_key(model, query, texts, top_k)
    cached_reranking = cache.get(cache_key)

    if cached_reranking is not None:
        return cached_reranking

    indices: list[int] | None = None
    match model.company:
        case "cohere":
            for i in range(num_ratelimit_retries):
                try:
                    async with get_ai_connection().cohere_ratelimit_semaphore:
                        await asyncio.sleep(1 / model.ratelimit_rpm)
                    response = await get_ai_connection().cohere_client.rerank(
                        model=model.model,
                        query=query,
                        documents=texts,
                        top_n=top_k,
                    )
                    indices = [result.index for result in response.results]
                    break
                except cohere.errors.TooManyRequestsError:
                    logger.warning("Cohere RateLimitError")
                    async with get_ai_connection().cohere_ratelimit_semaphore:
                        await asyncio.sleep(backoff_algo(i))
            if indices is None:
                raise AITimeoutError("Cannot overcome Cohere RateLimitError")
        case "voyageai":
            for i in range(num_ratelimit_retries):
                try:
                    async with get_ai_connection().voyageai_ratelimit_semaphore:
                        await asyncio.sleep(1 / model.ratelimit_rpm)
                    voyageai_response = (
                        await get_ai_connection().voyageai_client.rerank(
                            query=query,
                            documents=texts,
                            model=model.model,
                            top_k=top_k,
                        )
                    )
                    indices = [
                        int(result.index) for result in voyageai_response.results
                    ]
                    break
                except voyageai.error.RateLimitError:
                    logger.warning("VoyageAI RateLimitError")
                    async with get_ai_connection().voyageai_ratelimit_semaphore:
                        await asyncio.sleep(backoff_algo(i))
            if indices is None:
                raise AITimeoutError("Cannot overcome VoyageAI RateLimitError")

    cache.set(cache_key, indices)
    return indices
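
# Illustrative usage (the model ID is a placeholder):
#   order = await ai_rerank(
#       AIRerankModel(company="cohere", model="rerank-english-v3.0"),
#       "query about foxes",
#       ["doc about dogs", "doc about foxes"],
#       top_k=1,
#   )
#   order  # -> [1]: indices into the original texts, best match first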