# Copyright (c) 2023-2025 Philip May
# Copyright (c) 2024-2025 Philip May, Deutsche Telekom AG
# Copyright (c) 2024 Alaeddine Abdessalem, Deutsche Telekom AG
# Copyright (c) 2025 Sijun John Tu, Deutsche Telekom AG
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT
"""OpenAI specific module.
Hint:
Use pip to install the necessary dependencies for this module:
``pip install mltb2[openai]``
"""
import os
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Any, Optional, Union, cast
import tiktoken
import yaml
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from openai.lib.azure import AzureADTokenProvider
from openai.types.chat import ChatCompletion
from tiktoken.core import Encoding
from tqdm import tqdm
@dataclass
class OpenAiTokenCounter:
"""Count OpenAI tokens.
Args:
model_name:
The OpenAI model name. Some examples:
* ``gpt-4``
* ``gpt-3.5-turbo``
* ``text-davinci-003``
* ``text-embedding-ada-002``
        show_progress_bar: Show a progress bar during processing.
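
    Example:
        A minimal usage sketch (the model name below is only an example)::

            counter = OpenAiTokenCounter("gpt-4", show_progress_bar=True)
            token_counts = counter(["first text", "second text"])  # list of token counts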
"""
model_name: str
encoding: Encoding = field(init=False, repr=False)
show_progress_bar: bool = False
def __post_init__(self) -> None:
"""Do post init."""
self.encoding = tiktoken.encoding_for_model(self.model_name)
def __call__(self, text: Union[str, Iterable]) -> Union[int, list[int]]:
"""Count tokens for text.
Args:
text: The text for which the tokens are to be counted.
Returns:
            The number of tokens if ``text`` is a ``str``.
            If ``text`` is an ``Iterable``, a ``list`` with the number of tokens for each element.
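
        Example:
            Counting tokens for a single string and for an iterable of strings::

                counter = OpenAiTokenCounter("gpt-4")
                counter("Hello world!")      # -> int
                counter(["Hello", "world"])  # -> list[int]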
"""
if isinstance(text, str):
tokenized_text = self.encoding.encode(text)
return len(tokenized_text)
else:
counts = []
for t in tqdm(text, disable=not self.show_progress_bar):
tokenized_text = self.encoding.encode(t)
counts.append(len(tokenized_text))
return counts
@dataclass
class OpenAiChatResult:
"""Result of an OpenAI chat completion.
If you want to convert this to a ``dict`` use ``asdict(open_ai_chat_result)``
from the ``dataclasses`` module.
See Also:
OpenAI API reference: `The chat completion object <https://platform.openai.com/docs/api-reference/chat/object>`_
Args:
content: the result of the OpenAI completion
model: model name which has been used
prompt_tokens: number of tokens of the prompt
completion_tokens: number of tokens of the completion (``content``)
        total_tokens: total number of tokens (``prompt_tokens + completion_tokens``)
finish_reason: The reason why the completion stopped.
* ``stop``: Means the API returned the full completion without running into any token limit.
* ``length``: Means the API stopped the completion because of running into a token limit.
* ``content_filter``: When content was omitted due to a flag from the OpenAI content filters.
* ``tool_calls``: When the model called a tool.
* ``function_call`` (deprecated): When the model called a function.
completion_args: The arguments which have been used for the completion. Examples:
* ``model``: always set
* ``max_tokens``: only set if ``completion_kwargs`` contained ``max_tokens``
* ``temperature``: only set if ``completion_kwargs`` contained ``temperature``
* ``top_p``: only set if ``completion_kwargs`` contained ``top_p``
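
    Example:
        A sketch of converting a result to a plain ``dict``
        (``open_ai_chat`` is assumed to be an :class:`OpenAiChat` instance)::

            from dataclasses import asdict

            result = open_ai_chat.create_completions("Hello!")
            result_dict = asdict(result)
            if result.finish_reason == "length":
                print("The completion ran into a token limit.")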
"""
content: Optional[str] = None
model: Optional[str] = None
prompt_tokens: Optional[int] = None
completion_tokens: Optional[int] = None
total_tokens: Optional[int] = None
finish_reason: Optional[str] = None
completion_args: Optional[dict[str, Any]] = None
@classmethod
def from_chat_completion(
cls,
chat_completion: ChatCompletion,
completion_kwargs: Optional[dict[str, Any]] = None,
):
"""Construct this class from an OpenAI ``ChatCompletion`` object.
Args:
chat_completion: The OpenAI ``ChatCompletion`` object.
completion_kwargs: The arguments which have been used for the completion.
Returns:
The constructed class.
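
        Example:
            A sketch using a raw ``openai`` client directly
            (the model name and prompt are only illustrative)::

                client = OpenAI()
                chat_completion = client.chat.completions.create(
                    model="gpt-4", messages=[{"role": "user", "content": "Hello!"}]
                )
                result = OpenAiChatResult.from_chat_completion(chat_completion)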
"""
result = {}
result["completion_args"] = completion_kwargs
chat_completion_dict = chat_completion.model_dump()
result["model"] = chat_completion_dict.get("model")
usage = chat_completion_dict.get("usage")
if usage is not None:
result["prompt_tokens"] = usage.get("prompt_tokens")
result["completion_tokens"] = usage.get("completion_tokens")
result["total_tokens"] = usage.get("total_tokens")
choices = chat_completion_dict.get("choices")
if choices is not None and len(choices) > 0:
choice = choices[0]
result["finish_reason"] = choice.get("finish_reason")
message = choice.get("message")
if message is not None:
result["content"] = message.get("content")
return cls(**result) # type: ignore[arg-type]
def remove_openai_tokens(messages: list[dict[str, str]]) -> list[dict[str, str]]:
"""Remove OpenAI special tokens from the messages.
These tokens are ``<|im_start|>`` and ``<|im_end|>`` and they can cause problems when passed to the OpenAI API.
Args:
messages: The OpenAI messages.
Returns:
The messages without OpenAI special tokens.
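
    Example:
        A short sketch::

            messages = [{"role": "user", "content": "<|im_start|>Hello<|im_end|>"}]
            cleaned_messages = remove_openai_tokens(messages)
            # cleaned_messages[0]["content"] == "Hello"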
"""
    # build new message dicts so the caller's messages are not modified in place
    result = []
    for message in messages:
        cleaned_message = dict(message)
        cleaned_message["content"] = cleaned_message["content"].replace("<|im_start|>", "")
        cleaned_message["content"] = cleaned_message["content"].replace("<|im_end|>", "")
        result.append(cleaned_message)
return result
@dataclass
class OpenAiChat:
"""Tool to interact with OpenAI chat models.
    This can also be constructed with :meth:`~OpenAiChat.from_yaml`.
See Also:
OpenAI API reference: `Create chat completion <https://platform.openai.com/docs/api-reference/chat/create>`_
Args:
api_key: The OpenAI API key.
model: The OpenAI model name.
        base_url: Change the base URL.
            This is useful for OpenAI-compatible APIs such as Perplexity.
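
    Example:
        A minimal usage sketch (the model name and API key are only placeholders)::

            chat = OpenAiChat(model="gpt-4", api_key="sk-...")
            result = chat.create_completions("Say hello!")
            print(result.content)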
"""
model: str
client: Union[OpenAI, AzureOpenAI] = field(init=False, repr=False)
async_client: Union[AsyncOpenAI, AsyncAzureOpenAI] = field(init=False, repr=False)
api_key: Optional[str] = None
base_url: Optional[str] = None
def __post_init__(self) -> None:
"""Do post init."""
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
self.async_client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
@classmethod
def from_yaml(cls, yaml_file, api_key: Optional[str] = None, **kwargs):
"""Construct this class from a yaml file.
If the ``api_key`` is not set in the yaml file,
it will be loaded from the environment variable ``OPENAI_API_KEY``.
Args:
yaml_file: The yaml file.
api_key: The OpenAI API key.
            kwargs: Extra keyword arguments that override parameters loaded from the yaml file.
Returns:
The constructed class.
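
        Example:
            A sketch of a suitable yaml file (the values are only placeholders;
            the keys correspond to the constructor arguments)::

                model: gpt-4
                api_key: sk-...

            and how to load it::

                chat = OpenAiChat.from_yaml("chat.yaml")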
"""
with open(yaml_file, "r") as file:
completion_kwargs = yaml.safe_load(file)
# set api_key according to this priority:
# method parameter > yaml > environment variable
api_key = api_key or completion_kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
completion_kwargs["api_key"] = api_key
if kwargs:
completion_kwargs.update(kwargs)
return cls(**completion_kwargs)
def create_completions(
self,
prompt: Union[str, list[dict[str, str]]],
completion_kwargs: Optional[dict[str, Any]] = None,
clean_openai_tokens: bool = False,
) -> OpenAiChatResult:
"""Create a model response for the given prompt (chat conversation).
Args:
prompt: The prompt for the model.
completion_kwargs: Keyword arguments for the OpenAI completion.
- ``model`` can not be set via ``completion_kwargs``! Please set the ``model`` in the initializer.
- ``messages`` can not be set via ``completion_kwargs``! Please set the ``prompt`` argument.
Also see:
- ``openai.resources.chat.completions.Completions.create()``
- OpenAI API reference: `Create chat completion <https://platform.openai.com/docs/api-reference/chat/create>`_
clean_openai_tokens: Remove OpenAI special tokens from the prompt.
Returns:
The result of the OpenAI completion.
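
        Example:
            A sketch with a plain string prompt and with an explicit list of messages
            (``chat`` is an instance of this class; the argument values are only illustrative)::

                result = chat.create_completions(
                    "Say hello!", completion_kwargs={"temperature": 0.0}
                )
                result = chat.create_completions(
                    [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "Say hello!"},
                    ]
                )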
"""
if isinstance(prompt, list):
for message in prompt:
if "role" not in message or "content" not in message:
raise ValueError(
"If prompt is a list of messages, each message must have a 'role' and 'content' key!"
)
if message["role"] not in ["system", "user", "assistant", "tool"]:
raise ValueError(
"If prompt is a list of messages, each message must have a 'role' key with one of the values "
"'system', 'user', 'assistant' or 'tool'!"
)
if completion_kwargs is not None:
# check keys of completion_kwargs
if "model" in completion_kwargs:
raise ValueError(
"'model' can not be set via 'completion_kwargs'! Please set the 'model' in the initializer."
)
if "messages" in completion_kwargs:
raise ValueError(
"'messages' can not be set via 'completion_kwargs'! Please set the 'prompt' argument."
)
else:
completion_kwargs = {} # set default value
completion_kwargs["model"] = self.model
messages = [{"role": "user", "content": prompt}] if isinstance(prompt, str) else prompt
if clean_openai_tokens:
messages = remove_openai_tokens(messages)
chat_completion = self.client.chat.completions.create(
messages=messages, # type: ignore[arg-type]
**completion_kwargs,
)
result = OpenAiChatResult.from_chat_completion(chat_completion, completion_kwargs=completion_kwargs)
return result
async def create_completions_async(
self,
prompt: Union[str, list[dict[str, str]]],
completion_kwargs: Optional[dict[str, Any]] = None,
clean_openai_tokens: bool = False,
) -> OpenAiChatResult:
"""Create a model response for the given prompt (chat conversation).
Args:
prompt: The prompt for the model.
completion_kwargs: Keyword arguments for the OpenAI completion.
- ``model`` can not be set via ``completion_kwargs``! Please set the ``model`` in the initializer.
- ``messages`` can not be set via ``completion_kwargs``! Please set the ``prompt`` argument.
Also see:
- ``openai.resources.chat.completions.Completions.create()``
- OpenAI API reference: `Create chat completion <https://platform.openai.com/docs/api-reference/chat/create>`_
clean_openai_tokens: Remove OpenAI special tokens from the prompt.
Returns:
The result of the OpenAI completion.
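
        Example:
            A sketch of awaiting this coroutine (``chat`` is an instance of this class)::

                import asyncio

                async def main():
                    result = await chat.create_completions_async("Say hello!")
                    print(result.content)

                asyncio.run(main())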
"""
if isinstance(prompt, list):
for message in prompt:
if "role" not in message or "content" not in message:
raise ValueError(
"If prompt is a list of messages, each message must have a 'role' and 'content' key!"
)
if message["role"] not in ["system", "user", "assistant", "tool"]:
raise ValueError(
"If prompt is a list of messages, each message must have a 'role' key with one of the values "
"'system', 'user', 'assistant' or 'tool'!"
)
if completion_kwargs is not None:
# check keys of completion_kwargs
if "model" in completion_kwargs:
raise ValueError(
"'model' can not be set via 'completion_kwargs'! Please set the 'model' in the initializer."
)
if "messages" in completion_kwargs:
raise ValueError(
"'messages' can not be set via 'completion_kwargs'! Please set the 'prompt' argument."
)
else:
completion_kwargs = {} # set default value
completion_kwargs["model"] = self.model
messages = [{"role": "user", "content": prompt}] if isinstance(prompt, str) else prompt
if clean_openai_tokens:
messages = remove_openai_tokens(messages)
chat_completion = await self.async_client.chat.completions.create(
messages=messages, # type: ignore[arg-type]
**completion_kwargs,
)
result = OpenAiChatResult.from_chat_completion(chat_completion, completion_kwargs=completion_kwargs)
return result
# There is a limitation with Python dataclasses when defining a subclass with positional arguments
# while the parent class already defines keyword arguments (positional arguments cannot follow
# keyword arguments). The workaround used here is described at:
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
@dataclass
class _OpenAiAzureChatBase:
azure_endpoint: str
@dataclass
class OpenAiAzureChat(OpenAiChat, _OpenAiAzureChatBase):
"""Tool to interact with Azure OpenAI chat models.
This can also be constructed with :meth:`~OpenAiChat.from_yaml`.
See Also:
* OpenAI API reference: `Create chat completion <https://platform.openai.com/docs/api-reference/chat/create>`_
* `Quickstart: Get started generating text using Azure OpenAI Service <https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart?tabs=command-line&pivots=programming-language-python>`_
        * `AzureADTokenProvider example <https://github.com/openai/openai-python/blob/main/examples/azure_ad.py>`_
Args:
api_key: The OpenAI API key.
model: The OpenAI model name.
api_version: The OpenAI API version.
A common value for this is ``2023-05-15``.
azure_ad_token: The Azure Active Directory token.
azure_ad_token_provider: A function that returns an Azure Active Directory token,
which will be invoked on every request.
Or set to "auto" to use default credentials.
azure_endpoint: The Azure endpoint.
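
    Example:
        A minimal usage sketch (endpoint, deployment name, version and key are only placeholders)::

            chat = OpenAiAzureChat(
                azure_endpoint="https://my-resource.openai.azure.com/",
                model="my-gpt-4-deployment",
                api_version="2023-05-15",
                api_key="...",
            )
            result = chat.create_completions("Say hello!")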
"""
api_version: Optional[str] = None
api_key: Optional[str] = None
azure_ad_token: Optional[str] = None
azure_ad_token_provider: Union[AzureADTokenProvider, str, None] = None
def __post_init__(self) -> None:
"""Do post init."""
# init default token provider if azure_ad_token_provider=="auto"
if self.azure_ad_token_provider == "auto": # NOQA: S105
self.azure_ad_token_provider = get_bearer_token_provider(
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)
self.azure_ad_token_provider = cast("Optional[AzureADTokenProvider]", self.azure_ad_token_provider)
self.client = AzureOpenAI(
api_key=self.api_key,
api_version=self.api_version,
azure_endpoint=self.azure_endpoint,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
)
self.async_client = AsyncAzureOpenAI(
api_key=self.api_key,
api_version=self.api_version,
azure_endpoint=self.azure_endpoint,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
)
@classmethod
def from_yaml(
cls,
yaml_file,
api_key: Optional[str] = None,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Union[AzureADTokenProvider, str, None] = None,
**kwargs,
):
"""Construct this class from a yaml file.
If the ``api_key`` is not set in the yaml file,
it will be loaded from the environment variable ``OPENAI_API_KEY``.
Args:
yaml_file: The yaml file.
api_key: The OpenAI API key.
azure_ad_token: The Azure Active Directory token.
azure_ad_token_provider: A function that returns an Azure Active Directory token,
which will be invoked on every request.
Or set to "auto" to use default credentials.
            kwargs: Extra keyword arguments that override parameters loaded from the yaml file.
Returns:
The constructed class.
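
        Example:
            A sketch of a suitable yaml file (the values are only placeholders)::

                model: my-gpt-4-deployment
                azure_endpoint: https://my-resource.openai.azure.com/
                api_version: 2023-05-15
                azure_ad_token_provider: auto

            and how to load it::

                chat = OpenAiAzureChat.from_yaml("azure_chat.yaml")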
"""
with open(yaml_file, "r") as file:
completion_kwargs = yaml.safe_load(file)
# set azure_ad_token according to this priority:
# method parameter > yaml > environment variable
azure_ad_token = azure_ad_token or completion_kwargs.get("AZURE_AD_TOKEN") or os.getenv("AZURE_AD_TOKEN")
# init the token provider
## method parameter > yaml
azure_ad_token_provider = azure_ad_token_provider or completion_kwargs.get("azure_ad_token_provider")
## if token_provider==auto use default settings
if azure_ad_token_provider == "auto": # NOQA: S105
azure_ad_token_provider = get_bearer_token_provider(
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)
return super().from_yaml(
yaml_file,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
**kwargs,
)