Source code for mltb2.openai

# Copyright (c) 2023-2025 Philip May
# Copyright (c) 2024-2025 Philip May, Deutsche Telekom AG
# Copyright (c) 2024 Alaeddine Abdessalem, Deutsche Telekom AG
# Copyright (c) 2025 Sijun John Tu, Deutsche Telekom AG
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT

"""OpenAI specific module.

Hint:
    Use pip to install the necessary dependencies for this module:
    ``pip install mltb2[openai]``
"""


import os
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Any, Optional, Union, cast

import tiktoken
import yaml
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from openai.lib.azure import AzureADTokenProvider
from openai.types.chat import ChatCompletion
from tiktoken.core import Encoding
from tqdm import tqdm


@dataclass
class OpenAiTokenCounter:
    """Count OpenAI tokens.

    Args:
        model_name: The OpenAI model name. Some examples:

            * ``gpt-4``
            * ``gpt-3.5-turbo``
            * ``text-davinci-003``
            * ``text-embedding-ada-002``
        show_progress_bar: Show a progressbar during processing.
    """

    model_name: str
    encoding: Encoding = field(init=False, repr=False)
    show_progress_bar: bool = False

    def __post_init__(self) -> None:
        """Do post init."""
        self.encoding = tiktoken.encoding_for_model(self.model_name)

    def __call__(self, text: Union[str, Iterable]) -> Union[int, list[int]]:
        """Count tokens for text.

        Args:
            text: The text for which the tokens are to be counted.

        Returns:
            The number of tokens if text was just a ``str``.
            If text is an ``Iterable`` then a ``list`` of number of tokens.
        """
        if isinstance(text, str):
            tokenized_text = self.encoding.encode(text)
            return len(tokenized_text)
        else:
            counts = []
            for t in tqdm(text, disable=not self.show_progress_bar):
                tokenized_text = self.encoding.encode(t)
                counts.append(len(tokenized_text))
            return counts
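
# Illustrative usage sketch (not part of the library source): counting tokens for a
# single string and for a list of strings. The chosen model name is an assumption;
# any model known to ``tiktoken`` works.
#
#     counter = OpenAiTokenCounter(model_name="gpt-3.5-turbo")
#     counter("Hello world!")                # -> int
#     counter(["first text", "second one"])  # -> list[int]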


@dataclass
class OpenAiChatResult:
    """Result of an OpenAI chat completion.

    If you want to convert this to a ``dict`` use ``asdict(open_ai_chat_result)``
    from the ``dataclasses`` module.

    See Also:
        OpenAI API reference: `The chat completion object
        <https://platform.openai.com/docs/api-reference/chat/object>`_

    Args:
        content: the result of the OpenAI completion
        model: model name which has been used
        prompt_tokens: number of tokens of the prompt
        completion_tokens: number of tokens of the completion (``content``)
        total_tokens: number of total tokens (``prompt_tokens + completion_tokens``)
        finish_reason: The reason why the completion stopped.

            * ``stop``: Means the API returned the full completion without running into any token limit.
            * ``length``: Means the API stopped the completion because of running into a token limit.
            * ``content_filter``: When content was omitted due to a flag from the OpenAI content filters.
            * ``tool_calls``: When the model called a tool.
            * ``function_call`` (deprecated): When the model called a function.
        completion_args: The arguments which have been used for the completion. Examples:

            * ``model``: always set
            * ``max_tokens``: only set if ``completion_kwargs`` contained ``max_tokens``
            * ``temperature``: only set if ``completion_kwargs`` contained ``temperature``
            * ``top_p``: only set if ``completion_kwargs`` contained ``top_p``
    """

    content: Optional[str] = None
    model: Optional[str] = None
    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None
    total_tokens: Optional[int] = None
    finish_reason: Optional[str] = None
    completion_args: Optional[dict[str, Any]] = None

    @classmethod
    def from_chat_completion(
        cls,
        chat_completion: ChatCompletion,
        completion_kwargs: Optional[dict[str, Any]] = None,
    ):
        """Construct this class from an OpenAI ``ChatCompletion`` object.

        Args:
            chat_completion: The OpenAI ``ChatCompletion`` object.
            completion_kwargs: The arguments which have been used for the completion.

        Returns:
            The constructed class.
        """
        result = {}
        result["completion_args"] = completion_kwargs
        chat_completion_dict = chat_completion.model_dump()
        result["model"] = chat_completion_dict.get("model")
        usage = chat_completion_dict.get("usage")
        if usage is not None:
            result["prompt_tokens"] = usage.get("prompt_tokens")
            result["completion_tokens"] = usage.get("completion_tokens")
            result["total_tokens"] = usage.get("total_tokens")
        choices = chat_completion_dict.get("choices")
        if choices is not None and len(choices) > 0:
            choice = choices[0]
            result["finish_reason"] = choice.get("finish_reason")
            message = choice.get("message")
            if message is not None:
                result["content"] = message.get("content")
        return cls(**result)  # type: ignore[arg-type]
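
# Illustrative sketch (not part of the library source): ``OpenAiChatResult`` is a plain
# dataclass, so it can be converted to a ``dict`` as noted in the class docstring. The
# variable ``chat_result`` is assumed to come from ``OpenAiChat.create_completions``.
#
#     from dataclasses import asdict
#
#     result_dict = asdict(chat_result)
#     print(result_dict["content"], result_dict["total_tokens"])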


def remove_openai_tokens(messages: list[dict[str, str]]) -> list[dict[str, str]]:
    """Remove OpenAI special tokens from the messages.

    These tokens are ``<|im_start|>`` and ``<|im_end|>`` and they can cause problems
    when passed to the OpenAI API.

    Args:
        messages: The OpenAI messages.

    Returns:
        The messages without OpenAI special tokens.
    """
    result = messages.copy()
    for d in result:
        d["content"] = d["content"].replace("<|im_start|>", "")
        d["content"] = d["content"].replace("<|im_end|>", "")
    return result
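
# Illustrative sketch (not part of the library source): stripping the special tokens
# from a message list before it is sent to the API.
#
#     messages = [{"role": "user", "content": "<|im_start|>Hello<|im_end|>"}]
#     cleaned = remove_openai_tokens(messages)
#     # cleaned[0]["content"] == "Hello"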


@dataclass
class OpenAiChat:
    """Tool to interact with OpenAI chat models.

    This can also be constructed with :meth:`~OpenAiChat.from_yaml`.

    See Also:
        OpenAI API reference: `Create chat completion
        <https://platform.openai.com/docs/api-reference/chat/create>`_

    Args:
        api_key: The OpenAI API key.
        model: The OpenAI model name.
        base_url: Change the base URL - this is useful for Perplexity, for example.
    """

    model: str
    client: Union[OpenAI, AzureOpenAI] = field(init=False, repr=False)
    async_client: Union[AsyncOpenAI, AsyncAzureOpenAI] = field(init=False, repr=False)
    api_key: Optional[str] = None
    base_url: Optional[str] = None

    def __post_init__(self) -> None:
        """Do post init."""
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
        self.async_client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)

    @classmethod
    def from_yaml(cls, yaml_file, api_key: Optional[str] = None, **kwargs):
        """Construct this class from a yaml file.

        If the ``api_key`` is not set in the yaml file, it will be loaded from the
        environment variable ``OPENAI_API_KEY``.

        Args:
            yaml_file: The yaml file.
            api_key: The OpenAI API key.
            kwargs: Extra kwargs to override parameters.

        Returns:
            The constructed class.
        """
        with open(yaml_file, "r") as file:
            completion_kwargs = yaml.safe_load(file)

        # set api_key according to this priority:
        # method parameter > yaml > environment variable
        api_key = api_key or completion_kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
        completion_kwargs["api_key"] = api_key

        if kwargs:
            completion_kwargs.update(kwargs)
        return cls(**completion_kwargs)

    def create_completions(
        self,
        prompt: Union[str, list[dict[str, str]]],
        completion_kwargs: Optional[dict[str, Any]] = None,
        clean_openai_tokens: bool = False,
    ) -> OpenAiChatResult:
        """Create a model response for the given prompt (chat conversation).

        Args:
            prompt: The prompt for the model.
            completion_kwargs: Keyword arguments for the OpenAI completion.

                - ``model`` can not be set via ``completion_kwargs``!
                  Please set the ``model`` in the initializer.
                - ``messages`` can not be set via ``completion_kwargs``!
                  Please set the ``prompt`` argument.

                Also see:

                - ``openai.resources.chat.completions.Completions.create()``
                - OpenAI API reference: `Create chat completion
                  <https://platform.openai.com/docs/api-reference/chat/create>`_
            clean_openai_tokens: Remove OpenAI special tokens from the prompt.

        Returns:
            The result of the OpenAI completion.
        """
        if isinstance(prompt, list):
            for message in prompt:
                if "role" not in message or "content" not in message:
                    raise ValueError(
                        "If prompt is a list of messages, each message must have a 'role' and 'content' key!"
                    )
                if message["role"] not in ["system", "user", "assistant", "tool"]:
                    raise ValueError(
                        "If prompt is a list of messages, each message must have a 'role' key with one of the values "
                        "'system', 'user', 'assistant' or 'tool'!"
                    )
        if completion_kwargs is not None:
            # check keys of completion_kwargs
            if "model" in completion_kwargs:
                raise ValueError(
                    "'model' can not be set via 'completion_kwargs'! Please set the 'model' in the initializer."
                )
            if "messages" in completion_kwargs:
                raise ValueError(
                    "'messages' can not be set via 'completion_kwargs'! Please set the 'prompt' argument."
                )
        else:
            completion_kwargs = {}  # set default value
        completion_kwargs["model"] = self.model
        messages = [{"role": "user", "content": prompt}] if isinstance(prompt, str) else prompt
        if clean_openai_tokens:
            messages = remove_openai_tokens(messages)
        chat_completion = self.client.chat.completions.create(
            messages=messages,  # type: ignore[arg-type]
            **completion_kwargs,
        )
        result = OpenAiChatResult.from_chat_completion(chat_completion, completion_kwargs=completion_kwargs)
        return result

    async def create_completions_async(
        self,
        prompt: Union[str, list[dict[str, str]]],
        completion_kwargs: Optional[dict[str, Any]] = None,
        clean_openai_tokens: bool = False,
    ) -> OpenAiChatResult:
        """Create a model response for the given prompt (chat conversation).

        Args:
            prompt: The prompt for the model.
            completion_kwargs: Keyword arguments for the OpenAI completion.

                - ``model`` can not be set via ``completion_kwargs``!
                  Please set the ``model`` in the initializer.
                - ``messages`` can not be set via ``completion_kwargs``!
                  Please set the ``prompt`` argument.

                Also see:

                - ``openai.resources.chat.completions.Completions.create()``
                - OpenAI API reference: `Create chat completion
                  <https://platform.openai.com/docs/api-reference/chat/create>`_
            clean_openai_tokens: Remove OpenAI special tokens from the prompt.

        Returns:
            The result of the OpenAI completion.
        """
        if isinstance(prompt, list):
            for message in prompt:
                if "role" not in message or "content" not in message:
                    raise ValueError(
                        "If prompt is a list of messages, each message must have a 'role' and 'content' key!"
                    )
                if message["role"] not in ["system", "user", "assistant", "tool"]:
                    raise ValueError(
                        "If prompt is a list of messages, each message must have a 'role' key with one of the values "
                        "'system', 'user', 'assistant' or 'tool'!"
                    )
        if completion_kwargs is not None:
            # check keys of completion_kwargs
            if "model" in completion_kwargs:
                raise ValueError(
                    "'model' can not be set via 'completion_kwargs'! Please set the 'model' in the initializer."
                )
            if "messages" in completion_kwargs:
                raise ValueError(
                    "'messages' can not be set via 'completion_kwargs'! Please set the 'prompt' argument."
                )
        else:
            completion_kwargs = {}  # set default value
        completion_kwargs["model"] = self.model
        messages = [{"role": "user", "content": prompt}] if isinstance(prompt, str) else prompt
        if clean_openai_tokens:
            messages = remove_openai_tokens(messages)
        chat_completion = await self.async_client.chat.completions.create(
            messages=messages,  # type: ignore[arg-type]
            **completion_kwargs,
        )
        result = OpenAiChatResult.from_chat_completion(chat_completion, completion_kwargs=completion_kwargs)
        return result
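
# Illustrative usage sketch (not part of the library source): constructing an
# ``OpenAiChat`` from a yaml file and creating completions synchronously and
# asynchronously. The file name, its keys and the ``temperature`` value are
# assumptions based on the constructor and method signatures above.
#
#     # chat.yaml
#     # model: gpt-4
#     # api_key: sk-...   # optional, falls back to the OPENAI_API_KEY environment variable
#
#     chat = OpenAiChat.from_yaml("chat.yaml")
#     result = chat.create_completions("Say hello.", completion_kwargs={"temperature": 0.0})
#     print(result.content, result.finish_reason, result.total_tokens)
#
#     # message-list prompt and async variant
#     import asyncio
#
#     messages = [
#         {"role": "system", "content": "You are a helpful assistant."},
#         {"role": "user", "content": "Say hello."},
#     ]
#
#     async def main():
#         result = await chat.create_completions_async(messages)
#         print(result.content)
#
#     asyncio.run(main())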


# There is a limitation with Python dataclasses when it comes to defining a subclass with positional arguments
# while the parent class already defines keyword arguments (positional arguments cannot follow keyword arguments).
# The workaround used here is described at:
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
@dataclass
class _OpenAiAzureChatBase:
    azure_endpoint: str


@dataclass
class OpenAiAzureChat(OpenAiChat, _OpenAiAzureChatBase):
    """Tool to interact with Azure OpenAI chat models.

    This can also be constructed with :meth:`~OpenAiChat.from_yaml`.

    See Also:
        * OpenAI API reference: `Create chat completion
          <https://platform.openai.com/docs/api-reference/chat/create>`_
        * `Quickstart: Get started generating text using Azure OpenAI Service
          <https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart?tabs=command-line&pivots=programming-language-python>`_
        * `AzureADTokenProvider example
          <https://github.com/openai/openai-python/blob/main/examples/azure_ad.py>`_

    Args:
        api_key: The OpenAI API key.
        model: The OpenAI model name.
        api_version: The OpenAI API version. A common value for this is ``2023-05-15``.
        azure_ad_token: The Azure Active Directory token.
        azure_ad_token_provider: A function that returns an Azure Active Directory token,
            which will be invoked on every request.
            Or set to ``"auto"`` to use default credentials.
        azure_endpoint: The Azure endpoint.
    """

    api_version: Optional[str] = None
    api_key: Optional[str] = None
    azure_ad_token: Optional[str] = None
    azure_ad_token_provider: Union[AzureADTokenProvider, str, None] = None

    def __post_init__(self) -> None:
        """Do post init."""
        # init default token provider if azure_ad_token_provider == "auto"
        if self.azure_ad_token_provider == "auto":  # NOQA: S105
            self.azure_ad_token_provider = get_bearer_token_provider(
                DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
            )
        self.azure_ad_token_provider = cast("Optional[AzureADTokenProvider]", self.azure_ad_token_provider)
        self.client = AzureOpenAI(
            api_key=self.api_key,
            api_version=self.api_version,
            azure_endpoint=self.azure_endpoint,
            azure_ad_token=self.azure_ad_token,
            azure_ad_token_provider=self.azure_ad_token_provider,
        )
        self.async_client = AsyncAzureOpenAI(
            api_key=self.api_key,
            api_version=self.api_version,
            azure_endpoint=self.azure_endpoint,
            azure_ad_token=self.azure_ad_token,
            azure_ad_token_provider=self.azure_ad_token_provider,
        )

    @classmethod
    def from_yaml(
        cls,
        yaml_file,
        api_key: Optional[str] = None,
        azure_ad_token: Optional[str] = None,
        azure_ad_token_provider: Union[AzureADTokenProvider, str, None] = None,
        **kwargs,
    ):
        """Construct this class from a yaml file.

        If the ``api_key`` is not set in the yaml file, it will be loaded from the
        environment variable ``OPENAI_API_KEY``.

        Args:
            yaml_file: The yaml file.
            api_key: The OpenAI API key.
            azure_ad_token: The Azure Active Directory token.
            azure_ad_token_provider: A function that returns an Azure Active Directory token,
                which will be invoked on every request.
                Or set to ``"auto"`` to use default credentials.
            kwargs: Extra kwargs to override parameters.

        Returns:
            The constructed class.
        """
        with open(yaml_file, "r") as file:
            completion_kwargs = yaml.safe_load(file)

        # set azure_ad_token according to this priority:
        # method parameter > yaml > environment variable
        azure_ad_token = azure_ad_token or completion_kwargs.get("AZURE_AD_TOKEN") or os.getenv("AZURE_AD_TOKEN")

        # init the token provider (priority: method parameter > yaml)
        azure_ad_token_provider = azure_ad_token_provider or completion_kwargs.get("azure_ad_token_provider")
        # if the token provider is set to "auto", use default credentials
        if azure_ad_token_provider == "auto":  # NOQA: S105
            azure_ad_token_provider = get_bearer_token_provider(
                DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
            )

        return super().from_yaml(
            yaml_file,
            api_key=api_key,
            azure_ad_token=azure_ad_token,
            azure_ad_token_provider=azure_ad_token_provider,
            **kwargs,
        )
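
# Illustrative usage sketch (not part of the library source): constructing
# ``OpenAiAzureChat`` with key-based authentication or with the "auto" token provider.
# The endpoint, deployment name and API version below are placeholder assumptions.
#
#     azure_chat = OpenAiAzureChat(
#         model="my-gpt-4-deployment",
#         azure_endpoint="https://my-resource.openai.azure.com",
#         api_version="2023-05-15",
#         azure_ad_token_provider="auto",  # uses DefaultAzureCredential
#     )
#     result = azure_chat.create_completions("Say hello.")
#     print(result.content)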