# Copyright (c) 2023-2024 Philip May
# Copyright (c) 2024 Philip May, Deutsche Telekom AG
# Copyright (c) 2024 Alaeddine Abdessalem, Deutsche Telekom AG
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT
"""OpenAI specific module.
Hint:
Use pip to install the necessary dependencies for this module:
``pip install mltb2[openai]``
"""
import os
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional, Union
import tiktoken
import yaml
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from openai.types.chat import ChatCompletion
from tiktoken.core import Encoding
from tqdm import tqdm
@dataclass
class OpenAiTokenCounter:
    """Token counter for OpenAI models.

    The tokenizer is looked up once, when the instance is created, by passing
    ``model_name`` to ``tiktoken.encoding_for_model``.

    Args:
        model_name:
            The OpenAI model name. Some examples:

            * ``gpt-4``
            * ``gpt-3.5-turbo``
            * ``text-davinci-003``
            * ``text-embedding-ada-002``
        show_progress_bar: Show a progressbar during processing.
    """

    model_name: str
    encoding: Encoding = field(init=False, repr=False)
    show_progress_bar: bool = False

    def __post_init__(self) -> None:
        """Resolve the tiktoken encoding that belongs to ``model_name``."""
        self.encoding = tiktoken.encoding_for_model(self.model_name)

    def __call__(self, text: Union[str, Iterable]) -> Union[int, List[int]]:
        """Count the tokens of one text or of every text in an iterable.

        Args:
            text: A single ``str`` or an ``Iterable`` of texts.

        Returns:
            An ``int`` token count if ``text`` is a ``str``, otherwise a
            ``list`` with one token count per element of ``text``.
        """
        if not isinstance(text, str):
            # the progress bar is only used for the multi-text case
            texts = tqdm(text, disable=not self.show_progress_bar)
            return [len(self.encoding.encode(single_text)) for single_text in texts]
        return len(self.encoding.encode(text))
@dataclass
class OpenAiChatResult:
    """Result of an OpenAI chat completion.

    If you want to convert this to a ``dict`` use ``asdict(open_ai_chat_result)``
    from the ``dataclasses`` module.

    See Also:
        OpenAI API reference: `The chat completion object <https://platform.openai.com/docs/api-reference/chat/object>`_

    Args:
        content: the result of the OpenAI completion
        model: model name which has been used
        prompt_tokens: number of tokens of the prompt
        completion_tokens: number of tokens of the completion (``content``)
        total_tokens: number of total tokens (``prompt_tokens + completion_tokens``)
        finish_reason: The reason why the completion stopped.

            * ``stop``: Means the API returned the full completion without running into any token limit.
            * ``length``: Means the API stopped the completion because of running into a token limit.
            * ``content_filter``: When content was omitted due to a flag from the OpenAI content filters.
            * ``tool_calls``: When the model called a tool.
            * ``function_call`` (deprecated): When the model called a function.
        completion_args: The arguments which have been used for the completion. Examples:

            * ``model``: always set
            * ``max_tokens``: only set if ``completion_kwargs`` contained ``max_tokens``
            * ``temperature``: only set if ``completion_kwargs`` contained ``temperature``
            * ``top_p``: only set if ``completion_kwargs`` contained ``top_p``
    """

    content: Optional[str] = None
    model: Optional[str] = None
    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None
    total_tokens: Optional[int] = None
    finish_reason: Optional[str] = None
    completion_args: Optional[Dict[str, Any]] = None

    @classmethod
    def from_chat_completion(
        cls,
        chat_completion: ChatCompletion,
        completion_kwargs: Optional[Dict[str, Any]] = None,
    ) -> "OpenAiChatResult":
        """Construct this class from an OpenAI ``ChatCompletion`` object.

        Only the first entry of ``choices`` is evaluated. Fields that are
        missing in the completion stay ``None``.

        Args:
            chat_completion: The OpenAI ``ChatCompletion`` object.
            completion_kwargs: The arguments which have been used for the completion.
                A shallow copy is stored so that later changes to the passed
                ``dict`` do not alter this result.

        Returns:
            The constructed class.
        """
        chat_completion_dict = chat_completion.model_dump()

        result: Dict[str, Any] = {
            # copy instead of keeping a reference to the caller's dict
            "completion_args": None if completion_kwargs is None else dict(completion_kwargs),
            "model": chat_completion_dict.get("model"),
        }

        usage = chat_completion_dict.get("usage")
        if usage is not None:
            for key in ("prompt_tokens", "completion_tokens", "total_tokens"):
                result[key] = usage.get(key)

        choices = chat_completion_dict.get("choices")
        if choices:  # neither None nor empty
            choice = choices[0]
            result["finish_reason"] = choice.get("finish_reason")
            message = choice.get("message")
            if message is not None:
                result["content"] = message.get("content")

        return cls(**result)
def remove_openai_tokens(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Remove OpenAI special tokens from the messages.

    These tokens are ``<|im_start|>`` and ``<|im_end|>`` and they can cause problems when passed to the OpenAI API.

    The passed ``messages`` and the message dicts they contain are not modified;
    a new list with new dicts is returned. A ``content`` value that is not a
    ``str`` (for example ``None``) is passed through unchanged.

    Args:
        messages: The OpenAI messages.

    Returns:
        The messages without OpenAI special tokens.

    Raises:
        KeyError: If a message has no ``content`` key.
    """
    result = []
    for message in messages:
        # copy each dict: copying only the list would still share (and mutate)
        # the caller's message dicts
        cleaned = dict(message)
        content = cleaned["content"]
        if isinstance(content, str):
            cleaned["content"] = content.replace("<|im_start|>", "").replace("<|im_end|>", "")
        result.append(cleaned)
    return result
@dataclass
class OpenAiChat:
    """Tool to interact with OpenAI chat models.

    This can also be constructed with :meth:`~OpenAiChat.from_yaml`.

    See Also:
        OpenAI API reference: `Create chat completion <https://platform.openai.com/docs/api-reference/chat/create>`_

    Args:
        api_key: The OpenAI API key.
        model: The OpenAI model name.
    """

    model: str
    client: Union[OpenAI, AzureOpenAI] = field(init=False, repr=False)
    async_client: Union[AsyncOpenAI, AsyncAzureOpenAI] = field(init=False, repr=False)
    api_key: Optional[str] = None

    def __post_init__(self) -> None:
        """Create the synchronous and the asynchronous OpenAI client from ``api_key``."""
        self.client = OpenAI(api_key=self.api_key)
        self.async_client = AsyncOpenAI(api_key=self.api_key)

    @classmethod
    def from_yaml(cls, yaml_file, api_key: Optional[str] = None, **kwargs):
        """Construct this class from a yaml file.

        The ``api_key`` is resolved with this priority:
        method parameter > yaml file > environment variable ``OPENAI_API_KEY``.

        Args:
            yaml_file: The yaml file.
            api_key: The OpenAI API key.
            kwargs: extra kwargs to override parameters

        Returns:
            The constructed class.
        """
        with open(yaml_file, "r") as file:
            # safe_load returns None for an empty document, fall back to an empty dict
            completion_kwargs = yaml.safe_load(file) or {}

        # set api_key according to this priority:
        # method parameter > yaml > environment variable
        api_key = api_key or completion_kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
        completion_kwargs["api_key"] = api_key

        if kwargs:
            completion_kwargs.update(kwargs)
        return cls(**completion_kwargs)

    def _prepare_request(
        self,
        prompt: Union[str, List[Dict[str, str]]],
        completion_kwargs: Optional[Dict[str, Any]],
        clean_openai_tokens: bool,
    ) -> tuple:
        """Validate the arguments and build ``(messages, request_kwargs)`` for a completion call.

        Shared by :meth:`create_completions` and :meth:`create_completions_async`.
        Neither ``prompt`` nor ``completion_kwargs`` is modified.

        Raises:
            ValueError: If a message lacks ``role`` / ``content``, has an unsupported
                role, or if ``completion_kwargs`` contains ``model`` or ``messages``.
        """
        if isinstance(prompt, list):
            for message in prompt:
                if "role" not in message or "content" not in message:
                    raise ValueError(
                        "If prompt is a list of messages, each message must have a 'role' and 'content' key!"
                    )
                if message["role"] not in ["system", "user", "assistant", "tool"]:
                    raise ValueError(
                        "If prompt is a list of messages, each message must have a 'role' key with one of the values "
                        "'system', 'user', 'assistant' or 'tool'!"
                    )

        if completion_kwargs is not None:
            # check keys of completion_kwargs
            if "model" in completion_kwargs:
                raise ValueError(
                    "'model' can not be set via 'completion_kwargs'! Please set the 'model' in the initializer."
                )
            if "messages" in completion_kwargs:
                raise ValueError(
                    "'messages' can not be set via 'completion_kwargs'! Please set the 'prompt' argument."
                )

        # Work on a copy: writing "model" into the caller's dict would make a
        # second call with the same dict fail the "model" check above.
        request_kwargs: Dict[str, Any] = dict(completion_kwargs) if completion_kwargs is not None else {}
        request_kwargs["model"] = self.model

        messages = [{"role": "user", "content": prompt}] if isinstance(prompt, str) else prompt
        if clean_openai_tokens:
            # hand over copies of the message dicts so the caller's prompt stays untouched
            messages = remove_openai_tokens([dict(message) for message in messages])

        return messages, request_kwargs

    def create_completions(
        self,
        prompt: Union[str, List[Dict[str, str]]],
        completion_kwargs: Optional[Dict[str, Any]] = None,
        clean_openai_tokens: bool = False,
    ) -> OpenAiChatResult:
        """Create a model response for the given prompt (chat conversation).

        Args:
            prompt: The prompt for the model. Either a single user message as ``str``
                or a list of message dicts with the keys ``role`` and ``content``.
            completion_kwargs: Keyword arguments for the OpenAI completion.

                - ``model`` can not be set via ``completion_kwargs``! Please set the ``model`` in the initializer.
                - ``messages`` can not be set via ``completion_kwargs``! Please set the ``prompt`` argument.

                Also see:

                - ``openai.resources.chat.completions.Completions.create()``
                - OpenAI API reference: `Create chat completion <https://platform.openai.com/docs/api-reference/chat/create>`_
            clean_openai_tokens: Remove OpenAI special tokens from the prompt.

        Returns:
            The result of the OpenAI completion.

        Raises:
            ValueError: If ``prompt`` or ``completion_kwargs`` are invalid.
        """
        messages, request_kwargs = self._prepare_request(prompt, completion_kwargs, clean_openai_tokens)
        chat_completion = self.client.chat.completions.create(
            messages=messages,  # type: ignore[arg-type]
            **request_kwargs,
        )
        return OpenAiChatResult.from_chat_completion(chat_completion, completion_kwargs=request_kwargs)

    async def create_completions_async(
        self,
        prompt: Union[str, List[Dict[str, str]]],
        completion_kwargs: Optional[Dict[str, Any]] = None,
        clean_openai_tokens: bool = False,
    ) -> OpenAiChatResult:
        """Asynchronously create a model response for the given prompt (chat conversation).

        Args:
            prompt: The prompt for the model. Either a single user message as ``str``
                or a list of message dicts with the keys ``role`` and ``content``.
            completion_kwargs: Keyword arguments for the OpenAI completion.

                - ``model`` can not be set via ``completion_kwargs``! Please set the ``model`` in the initializer.
                - ``messages`` can not be set via ``completion_kwargs``! Please set the ``prompt`` argument.

                Also see:

                - ``openai.resources.chat.completions.Completions.create()``
                - OpenAI API reference: `Create chat completion <https://platform.openai.com/docs/api-reference/chat/create>`_
            clean_openai_tokens: Remove OpenAI special tokens from the prompt.

        Returns:
            The result of the OpenAI completion.

        Raises:
            ValueError: If ``prompt`` or ``completion_kwargs`` are invalid.
        """
        messages, request_kwargs = self._prepare_request(prompt, completion_kwargs, clean_openai_tokens)
        chat_completion = await self.async_client.chat.completions.create(
            messages=messages,  # type: ignore[arg-type]
            **request_kwargs,
        )
        return OpenAiChatResult.from_chat_completion(chat_completion, completion_kwargs=request_kwargs)
# There is a limitation with Python dataclasses when it comes to defining a subclass with positional arguments while
# the parent class already defines keyword arguments (positional arguments cannot follow keyword arguments).
# A workaround is described here: https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
@dataclass
class _OpenAiAzureChatBase:
    """Dataclass base that declares only the required ``azure_endpoint`` field."""

    azure_endpoint: str
@dataclass
class OpenAiAzureChat(OpenAiChat, _OpenAiAzureChatBase):
    """Tool to interact with Azure OpenAI chat models.

    This can also be constructed with :meth:`~OpenAiAzureChat.from_yaml`.

    See Also:
        * OpenAI API reference: `Create chat completion <https://platform.openai.com/docs/api-reference/chat/create>`_
        * `Quickstart: Get started generating text using Azure OpenAI Service <https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart?tabs=command-line&pivots=programming-language-python>`_

    Args:
        api_key: The OpenAI API key.
        model: The OpenAI model name.
        api_version: The OpenAI API version.
            A common value for this is ``2023-05-15``.
        azure_endpoint: The Azure endpoint.
        azure_ad_token: The Azure AD token.
    """

    api_version: Optional[str] = None
    api_key: Optional[str] = None
    azure_ad_token: Optional[str] = None

    def __post_init__(self) -> None:
        """Create the synchronous and the asynchronous Azure OpenAI client."""
        self.client = AzureOpenAI(
            api_key=self.api_key,
            api_version=self.api_version,
            azure_endpoint=self.azure_endpoint,
            azure_ad_token=self.azure_ad_token,
        )
        self.async_client = AsyncAzureOpenAI(
            api_key=self.api_key,
            api_version=self.api_version,
            azure_endpoint=self.azure_endpoint,
            azure_ad_token=self.azure_ad_token,
        )

    @classmethod
    def from_yaml(cls, yaml_file, api_key: Optional[str] = None, azure_ad_token: Optional[str] = None, **kwargs):
        """Construct this class from a yaml file.

        The ``azure_ad_token`` is resolved with this priority:
        method parameter > yaml key ``azure_ad_token`` > environment variable ``AZURE_AD_TOKEN``.
        Everything else, including the ``api_key`` handling, is delegated to
        ``OpenAiChat.from_yaml``.

        Args:
            yaml_file: The yaml file.
            api_key: The OpenAI API key.
            azure_ad_token: Azure AD token
            kwargs: extra kwargs to override parameters

        Returns:
            The constructed class.
        """
        with open(yaml_file, "r") as file:
            # safe_load returns None for an empty document, fall back to an empty dict
            completion_kwargs = yaml.safe_load(file) or {}

        # set azure_ad_token according to this priority:
        # method parameter > yaml > environment variable
        # The yaml key must be the lowercase field name: all yaml keys are passed
        # to the constructor, so any other spelling could never be accepted there.
        # The resolved value is handed to the parent as an override, so a miss
        # here would replace the yaml value.
        azure_ad_token = azure_ad_token or completion_kwargs.get("azure_ad_token") or os.getenv("AZURE_AD_TOKEN")
        return super().from_yaml(yaml_file, api_key=api_key, azure_ad_token=azure_ad_token, **kwargs)