Skip to content

Embedding

step

Consists of DVCSteps that embed files and save the results, for example as CSV.

Classes

Embedded

Bases: TypedDict

Dict definition of an embedded document.

Source code in wurzel/steps/embedding/step.py
class Embedded(TypedDict):
    """Dict definition of an embedded document.

    This is the per-split record produced by ``EmbeddingStep._get_embedding``.
    """

    # The original markdown text of the split (not the cleaned embedding input).
    text: str
    # Source URL of the document the split came from.
    url: str
    # The embedding vector returned by the embedding model.
    vector: list[float]
    # Stopword-filtered context built from the document's keywords;
    # ``_get_embedding`` always returns this key.
    keywords: str

EmbeddingStep

Bases: SimpleSplitterStep, TypedStep[EmbeddingSettings, list[MarkdownDataContract], DataFrame[EmbeddingResult]]

Step for consuming list[MarkdownDataContract] and returning DataFrame[EmbeddingResult].

Source code in wurzel/steps/embedding/step.py
class EmbeddingStep(
    SimpleSplitterStep,
    TypedStep[EmbeddingSettings, list[MarkdownDataContract], DataFrame[EmbeddingResult]],
):
    """Step for consuming list[MarkdownDataContract]
    and returning DataFrame[EmbeddingResult].

    The input documents are split with ``_split_markdown`` and every split is
    embedded on its own via ``self.embedding.embed_query``. Splits whose
    embedding raises ``EmbeddingAPIException`` are logged and skipped.
    """

    embedding: HuggingFaceInferenceAPIEmbeddings
    n_jobs: int
    markdown: Markdown
    stopwords: list[str]
    settings: EmbeddingSettings

    def __init__(self) -> None:
        super().__init__()
        self.embedding = self._select_embedding()
        # One worker per CPU minus one, but never fewer than one.
        self.n_jobs = max(1, (os.cpu_count() or 0) - 1)
        # Register a "plain" output format in the 3rd party library Markdown so
        # that ``self.markdown.convert`` renders plain text instead of HTML.
        # NOTE: this mutates the class-level ``Markdown.output_formats`` mapping,
        # so it is visible to every Markdown instance in the process.
        Markdown.output_formats["plain"] = self.__md_to_plain  # type: ignore[index]
        self.markdown = Markdown(output_format="plain")  # type: ignore[arg-type]
        self.markdown.stripTopLevelTags = False
        # ``stopwords`` is the declared attribute; ``settingstopwords`` is the
        # historical name read by ``is_stopword`` (and possibly by other code),
        # so both names are bound to the same list.
        self.stopwords = self._load_stopwords()
        self.settingstopwords = self.stopwords

    def _load_stopwords(self) -> list[str]:
        """Load the stopword list from ``settings.STEPWORDS_PATH``.

        Lines starting with ``;`` are treated as comments and skipped. Every
        other line is stripped of surrounding whitespace.
        """
        path = self.settings.STEPWORDS_PATH
        with path.open(encoding="utf-8") as f:
            return [line.strip() for line in f if not line.startswith(";")]

    def _select_embedding(self) -> HuggingFaceInferenceAPIEmbeddings:
        """Selects the embedding model to be used for generating embeddings.

        Returns
        -------
        HuggingFaceInferenceAPIEmbeddings
            A ``PrefixedAPIEmbeddings`` client for ``settings.API`` that uses
            ``settings.PREFIX_MAP``.

        """
        return PrefixedAPIEmbeddings(self.settings.API, self.settings.PREFIX_MAP)

    def run(self, inpt: list[MarkdownDataContract]) -> DataFrame[EmbeddingResult]:
        """Split the input documents, embed every split and return the results.

        Parameters
        ----------
        inpt : list[MarkdownDataContract]
            The markdown documents to embed.

        Returns
        -------
        DataFrame[EmbeddingResult]
            One row per successfully embedded split. The frame is empty when
            ``inpt`` is empty.

        Raises
        ------
        StepFailed
            If every split failed with ``EmbeddingAPIException``.

        """
        if not inpt:
            log.info("Got empty result in Embedding - Skipping")
            return DataFrame[EmbeddingResult]([])
        splitted_md_rows = self._split_markdown(inpt)
        rows = []
        failed = 0
        for row in tqdm(splitted_md_rows, desc="Calculate Embeddings"):
            try:
                embedded = self._get_embedding(row)
            except EmbeddingAPIException as err:
                # Best effort: a single failing split must not abort the whole step.
                log.warning(
                    f"Skipped because EmbeddingAPIException: {err.message}",
                    extra={"markdown": str(row)},
                )
                failed += 1
            else:
                rows.append(embedded)
        if failed:
            log.warning(f"{failed}/{len(splitted_md_rows)} got skipped")
        # NOTE(review): if splitting yields no rows, this also raises
        # ("all 0 embeddings got skipped"). Confirm that this is intended.
        if failed == len(splitted_md_rows):
            raise StepFailed(f"all {len(splitted_md_rows)} embeddings got skipped")
        return DataFrame[EmbeddingResult](rows)

    def get_embedding_input_from_document(self, doc: MarkdownDataContract) -> str:
        """Clean the document such that it can be used as input to the embedding model.

        Parameters
        ----------
        doc : MarkdownDataContract
            The document containing the page content in Markdown format.

        Returns
        -------
        str
            ``doc.md`` rendered to plain text, with URLs replaced by ``LINK``.

        """
        plain_text = self.markdown.convert(doc.md)
        plain_text = self._replace_link(plain_text)

        return plain_text

    def _get_embedding(self, doc: MarkdownDataContract) -> Embedded:
        """Generate the embedding for a single (already split) document.

        Parameters
        ----------
        doc : MarkdownDataContract
            The document chunk to embed. Its ``keywords`` are reduced to a
            stopword-free context. Its ``md`` is embedded, after being cleaned to
            plain text when ``settings.CLEAN_MD_BEFORE_EMBEDDING`` is set.

        Returns
        -------
        Embedded
            A dict with the original markdown (``text``), its embedding
            (``vector``), the source ``url`` and the ``keywords`` context.

        """
        context = self.get_simple_context(doc.keywords)
        text = self.get_embedding_input_from_document(doc) if self.settings.CLEAN_MD_BEFORE_EMBEDDING else doc.md
        vector = self.embedding.embed_query(text)
        return {"text": doc.md, "vector": vector, "url": doc.url, "keywords": context}

    def is_stopword(self, word: str) -> bool:
        """Return True when ``word`` (lower-cased) is in the loaded stopword list."""
        return word.lower() in self.settingstopwords

    @classmethod
    def whitespace_word_tokenizer(cls, text: str) -> list[str]:
        """Simple Regex based whitespace word tokenizer.

        Splits on whitespace. A run of ``.,!?`` directly before the whitespace is
        kept as its own token, because the regex group is capturing. Empty and
        ``None`` pieces are dropped.
        """
        return [x for x in re.split(r"([.,!?]+)?\s+", text) if x]

    def get_simple_context(self, text: str) -> str:
        """Build a context string from ``text`` by dropping stopword tokens."""
        tokens = self.whitespace_word_tokenizer(text)
        filtered_tokens = [token for token in tokens if not self.is_stopword(token)]
        return " ".join(filtered_tokens)

    @classmethod
    def __md_to_plain(cls, element, stream: Optional[StringIO] = None):
        """Converts a markdown element into plain text.

        Parameters
        ----------
        element : Element
            The markdown element to convert.
        stream : StringIO, optional
            The stream to which the plain text is written. If None, a new stream is created.

        Returns
        -------
        str
            The full content of ``stream`` after ``element`` has been written to it.

        """
        if stream is None:
            stream = StringIO()

        def _write(node) -> None:
            # Depth-first walk in document order: own text, children, then tail.
            if node.text:
                stream.write(node.text)
            for child in node:
                _write(child)
            if node.tail:
                stream.write(node.tail)

        # Read the buffer only once. Calling getvalue() at every recursion level
        # would copy the whole accumulated text once per element.
        _write(element)
        return stream.getvalue()

    @classmethod
    def _replace_link(cls, text: str) -> str:
        """Replaces URLs in the text with a placeholder.

        Parameters
        ----------
        text : str
            The text in which URLs will be replaced.

        Returns
        -------
        str
            The text with URLs replaced by 'LINK'.

        """
        # Extract URL from a string
        url_extract_pattern = (
            "https?:\\/\\/(?:www\\.)?[-a-zA-Z0-9@:%._\\+~#=]{1,256}\\.[a-zA-Z0-9()]{1,6}\\b(?:[-a-zA-Z0-9()@:%_\\+.~#?&\\/=]*)"  # pylint: disable=line-too-long
        )
        # Replace longest links first, so that a link which is a prefix of another
        # link does not break the longer one.
        links = sorted(re.findall(url_extract_pattern, text), key=len, reverse=True)
        for link in links:
            text = text.replace(link, "LINK")
        return text
Functions
__md_to_plain(element, stream=None) classmethod

Converts a markdown element into plain text.

Parameters

element : Element The markdown element to convert. stream : StringIO, optional The stream to which the plain text is written. If None, a new stream is created.

Returns:

str The plain text representation of the markdown element.

Source code in wurzel/steps/embedding/step.py
@classmethod
def __md_to_plain(cls, element, stream: Optional[StringIO] = None):
    """Converts a markdown element into plain text.

    Parameters
    ----------
    element : Element
        The markdown element to convert.
    stream : StringIO, optional
        The stream to which the plain text is written. If None, a new stream is created.

    Returns
    -------
    str
        The full content of ``stream`` after ``element`` has been written to it.

    """
    if stream is None:
        stream = StringIO()

    def _write(node) -> None:
        # Depth-first walk in document order: own text, children, then tail.
        if node.text:
            stream.write(node.text)
        for child in node:
            _write(child)
        if node.tail:
            stream.write(node.tail)

    # Read the buffer only once. Calling getvalue() at every recursion level
    # would copy the whole accumulated text once per element.
    _write(element)
    return stream.getvalue()
get_embedding_input_from_document(doc)

Clean the document such that it can be used as input to the embedding model.

Parameters

doc : MarkdownDataContract The document containing the page content in Markdown format.

Returns:

str Cleaned text that can be used as input to the embedding model.

Source code in wurzel/steps/embedding/step.py
def get_embedding_input_from_document(self, doc: MarkdownDataContract) -> str:
    """Turn a markdown document into the plain text fed to the embedding model.

    Parameters
    ----------
    doc : MarkdownDataContract
        Document whose ``md`` attribute holds the Markdown page content.

    Returns
    -------
    str
        ``doc.md`` rendered through ``self.markdown`` and then passed through
        ``self._replace_link``.

    """
    return self._replace_link(self.markdown.convert(doc.md))
get_simple_context(text)

Simple function to create a context from a text.

Source code in wurzel/steps/embedding/step.py
def get_simple_context(self, text):
    """Tokenize ``text`` and join the non-stopword tokens with single spaces."""
    kept = []
    for token in self.whitespace_word_tokenizer(text):
        if self.is_stopword(token):
            continue
        kept.append(token)
    return " ".join(kept)
is_stopword(word)

Stopword Detection Function.

Source code in wurzel/steps/embedding/step.py
def is_stopword(self, word: str) -> bool:
    """Report whether the lower-cased ``word`` appears in ``self.settingstopwords``."""
    lowered = word.lower()
    return lowered in self.settingstopwords
run(inpt)

Executes the embedding step by splitting the input markdown documents, generating an embedding for each split, and returning the results as a DataFrame.

Source code in wurzel/steps/embedding/step.py
def run(self, inpt: list[MarkdownDataContract]) -> DataFrame[EmbeddingResult]:
    """Split the input documents, embed every split and return the results.

    Parameters
    ----------
    inpt : list[MarkdownDataContract]
        The markdown documents to embed.

    Returns
    -------
    DataFrame[EmbeddingResult]
        One row per successfully embedded split. The frame is empty when
        ``inpt`` is empty.

    Raises
    ------
    StepFailed
        If every split failed with ``EmbeddingAPIException``.

    """
    if not inpt:
        log.info("Got empty result in Embedding - Skipping")
        return DataFrame[EmbeddingResult]([])
    splitted_md_rows = self._split_markdown(inpt)
    rows = []
    failed = 0
    for row in tqdm(splitted_md_rows, desc="Calculate Embeddings"):
        try:
            embedded = self._get_embedding(row)
        except EmbeddingAPIException as err:
            # Best effort: a single failing split must not abort the whole step.
            log.warning(
                f"Skipped because EmbeddingAPIException: {err.message}",
                extra={"markdown": str(row)},
            )
            failed += 1
        else:
            rows.append(embedded)
    if failed:
        log.warning(f"{failed}/{len(splitted_md_rows)} got skipped")
    # NOTE(review): if splitting yields no rows, this also raises
    # ("all 0 embeddings got skipped"). Confirm that this is intended.
    if failed == len(splitted_md_rows):
        raise StepFailed(f"all {len(splitted_md_rows)} embeddings got skipped")
    return DataFrame[EmbeddingResult](rows)
whitespace_word_tokenizer(text) classmethod

Simple Regex based whitespace word tokenizer.

Source code in wurzel/steps/embedding/step.py
@classmethod
def whitespace_word_tokenizer(cls, text: str) -> list[str]:
    """Split ``text`` on whitespace with a small regex.

    A run of ``.,!?`` right before the whitespace is captured by the regex
    group and therefore survives as a separate token. Empty strings and
    ``None`` placeholders produced by ``re.split`` are discarded.
    """
    pieces = re.split(r"([.,!?]+)?\s+", text)
    return list(filter(None, pieces))

step_multivector

Consists of DVCSteps that embed files and save the results, for example as CSV.

Classes

EmbeddingMultiVectorStep

Bases: EmbeddingStep, TypedStep[EmbeddingSettings, list[MarkdownDataContract], DataFrame[EmbeddingMultiVectorResult]]

Step for consuming list[MarkdownDataContract] and returning DataFrame[EmbeddingMultiVectorResult].

Source code in wurzel/steps/embedding/step_multivector.py
class EmbeddingMultiVectorStep(
    EmbeddingStep,
    TypedStep[
        EmbeddingSettings,
        list[MarkdownDataContract],
        DataFrame[EmbeddingMultiVectorResult],
    ],
):
    """Step for consuming list[MarkdownDataContract]
    and returning DataFrame[EmbeddingMultiVectorResult].

    Unlike ``EmbeddingStep``, each input document yields exactly one row that
    holds one vector per split of that document.
    """

    def run(self, inpt: list[MarkdownDataContract]) -> DataFrame[EmbeddingMultiVectorResult]:
        """Executes the embedding step by processing a list of MarkdownDataContract objects,
        generating embeddings for each document, and returning the results as a DataFrame.

        Args:
            inpt (list[MarkdownDataContract]): A list of markdown data contracts to process.

        Returns:
            DataFrame[EmbeddingMultiVectorResult]: A DataFrame containing the embedding results.
                It is empty when ``inpt`` is empty.

        Raises:
            StepFailed: If the input is non-empty and every document failed to generate embeddings.

        Logs:
            - Warnings for documents skipped due to EmbeddingAPIException.
            - A summary warning if some or all documents are skipped.

        """
        if not inpt:
            # An empty input is not an error, matching EmbeddingStep.run. Without
            # this guard, ``failed == len(inpt)`` is ``0 == 0`` and the step
            # would raise StepFailed("All 0 embeddings got skipped").
            log.info("Got empty result in Embedding - Skipping")
            return DataFrame[EmbeddingMultiVectorResult]([])

        def process_document(doc: MarkdownDataContract):
            # Returns None for a skipped document so that failures can be counted below.
            try:
                return self._get_embedding(doc)
            except EmbeddingAPIException as err:
                log.warning(
                    f"Skipped because EmbeddingAPIException: {err.message}",
                    extra={"markdown": str(doc)},
                )
                return None

        # NOTE(review): with N_JOBS > 1, the worker threads share self.markdown and
        # self.embedding. A python-markdown ``Markdown`` instance keeps per-convert
        # state, so concurrent use may not be safe. Confirm before raising N_JOBS.
        results = Parallel(backend="threading", n_jobs=self.settings.N_JOBS)(delayed(process_document)(doc) for doc in inpt)

        rows = [res for res in results if res is not None]
        failed = len(results) - len(rows)

        if failed:
            log.warning(f"{failed}/{len(inpt)} got skipped")
        if failed == len(inpt):
            raise StepFailed(f"All {len(inpt)} embeddings got skipped")

        return DataFrame[EmbeddingMultiVectorResult](rows)

    def _get_embedding(self, doc: MarkdownDataContract) -> _EmbeddedMultiVector:
        """Split one document and embed every split.

        Parameters
        ----------
        doc : MarkdownDataContract
            The full document to split and embed.

        Returns
        -------
        _EmbeddedMultiVector
            A dict with the original markdown (``text``), one vector per split
            (``vectors``), the source ``url``, the stopword-filtered ``keywords``
            context, and the markdown of every split (``splits``).

        Raises
        ------
        EmbeddingAPIException
            If splitting fails or produces no splits.

        """
        try:
            splitted_md_rows = self._split_markdown([doc])
        except SplittException as err:
            raise EmbeddingAPIException("splitting failed") from err
        # Every split is always cleaned (plain text, links replaced) before being
        # embedded. NOTE(review): unlike EmbeddingStep._get_embedding, the
        # CLEAN_MD_BEFORE_EMBEDDING setting is not consulted here.
        vectors = [
            self.embedding.embed_query(self.get_embedding_input_from_document(split))
            for split in splitted_md_rows
        ]
        if not vectors:
            raise EmbeddingAPIException("Embedding failed for all splits")

        context = self.get_simple_context(doc.keywords)

        return {
            "text": doc.md,
            "vectors": vectors,
            "url": doc.url,
            "keywords": context,
            "splits": [split.md for split in splitted_md_rows],
        }
Functions
run(inpt)

Executes the embedding step by processing a list of MarkdownDataContract objects, generating embeddings for each document, and returning the results as a DataFrame.

Parameters:

Name Type Description Default
inpt list[MarkdownDataContract]

A list of markdown data contracts to process.

required

Returns:

Type Description
DataFrame[EmbeddingMultiVectorResult]

DataFrame[EmbeddingMultiVectorResult]: A DataFrame containing the embedding results.

Raises:

Type Description
StepFailed

If all input documents fail to generate embeddings.

Logs
  • Warnings for documents skipped due to EmbeddingAPIException.
  • A summary warning if some or all documents are skipped.
Source code in wurzel/steps/embedding/step_multivector.py
def run(self, inpt: list[MarkdownDataContract]) -> DataFrame[EmbeddingMultiVectorResult]:
    """Executes the embedding step by processing a list of MarkdownDataContract objects,
    generating embeddings for each document, and returning the results as a DataFrame.

    Args:
        inpt (list[MarkdownDataContract]): A list of markdown data contracts to process.

    Returns:
        DataFrame[EmbeddingMultiVectorResult]: A DataFrame containing the embedding results.
            It is empty when ``inpt`` is empty.

    Raises:
        StepFailed: If the input is non-empty and every document failed to generate embeddings.

    Logs:
        - Warnings for documents skipped due to EmbeddingAPIException.
        - A summary warning if some or all documents are skipped.

    """
    if not inpt:
        # An empty input is not an error. Without this guard, ``failed == len(inpt)``
        # is ``0 == 0`` and the step would raise StepFailed("All 0 embeddings got skipped").
        log.info("Got empty result in Embedding - Skipping")
        return DataFrame[EmbeddingMultiVectorResult]([])

    def process_document(doc: MarkdownDataContract):
        # Returns None for a skipped document so that failures can be counted below.
        try:
            return self._get_embedding(doc)
        except EmbeddingAPIException as err:
            log.warning(
                f"Skipped because EmbeddingAPIException: {err.message}",
                extra={"markdown": str(doc)},
            )
            return None

    # NOTE(review): with N_JOBS > 1, the worker threads share self.markdown and
    # self.embedding. Confirm that both are safe for concurrent use.
    results = Parallel(backend="threading", n_jobs=self.settings.N_JOBS)(delayed(process_document)(doc) for doc in inpt)

    rows = [res for res in results if res is not None]
    failed = len(results) - len(rows)

    if failed:
        log.warning(f"{failed}/{len(inpt)} got skipped")
    if failed == len(inpt):
        raise StepFailed(f"All {len(inpt)} embeddings got skipped")

    return DataFrame[EmbeddingMultiVectorResult](rows)

settings

Classes

EmbeddingSettings

Bases: SplitterSettings

EmbeddingSettings is a configuration class for embedding-related settings.

Attributes:

Name Type Description
API Url

The API endpoint for embedding operations.

NORMALIZE bool

A flag indicating whether to normalize embeddings. Defaults to False.

BATCH_SIZE int

The batch size for processing embeddings. Must be greater than 0. Defaults to 100.

TOKEN_COUNT_MIN int

The minimum token count for processing. Must be greater than 0. Defaults to 64.

TOKEN_COUNT_MAX int

The maximum token count for processing. Must be greater than 1. Defaults to 256.

TOKEN_COUNT_BUFFER int

The buffer size for token count. Must be greater than 0. Defaults to 32.

STEPWORDS_PATH Path

The file path to the stopwords file. Defaults to "data/german_stopwords_full.txt".

N_JOBS int

The number of parallel jobs to use. Must be greater than 0. Defaults to 1.

PREFIX_MAP dict[Pattern, str]

A mapping of regex patterns to string prefixes. This is validated and transformed using the _wrap_validator_model_mapping method.

Methods:

Name Description
_wrap_validator_model_mapping(input_dict: dict[str, str], handler)

A static method to wrap and validate the model mapping. It converts string regex keys in the input dictionary to compiled regex patterns and applies a handler function to the result.

Source code in wurzel/steps/embedding/settings.py
class EmbeddingSettings(SplitterSettings):
    """EmbeddingSettings is a configuration class for embedding-related settings.

    Attributes:
        API (Url): The API endpoint for embedding operations.
        NORMALIZE (bool): A flag indicating whether to normalize embeddings. Defaults to False.
        BATCH_SIZE (int): The batch size for processing embeddings. Must be greater than 0. Defaults to 100.
        TOKEN_COUNT_MIN (int): The minimum token count for processing. Must be greater than 0. Defaults to 64.
        TOKEN_COUNT_MAX (int): The maximum token count for processing. Must be greater than 1. Defaults to 256.
        TOKEN_COUNT_BUFFER (int): The buffer size for token count. Must be greater than 0. Defaults to 32.
        STEPWORDS_PATH (Path): The file path to the stopwords file. Defaults to "data/german_stopwords_full.txt".
        N_JOBS (int): The number of parallel jobs to use. Must be greater than 0. Defaults to 1.
        PREFIX_MAP (dict[re.Pattern, str]): A mapping of regex patterns to string prefixes.
            This is validated and transformed using the `_wrap_validator_model_mapping` method.
        CLEAN_MD_BEFORE_EMBEDDING (bool): Whether the markdown is converted to plain text
            (with links replaced) before it is embedded. Defaults to True.
        TOKENIZER_MODEL (str): The tokenizer model to use for splitting documents.
            Defaults to "gpt-3.5-turbo".

    Methods:
        _wrap_validator_model_mapping(input_dict: dict[str, str], handler):
            A static method to wrap and validate the model mapping. It converts string regex keys
            in the input dictionary to compiled regex patterns and applies a handler function to the result.

    """

    @staticmethod
    def _wrap_validator_model_mapping(input_dict: dict[str, str], handler):
        """Compile string keys of ``input_dict`` to regex patterns, then delegate to ``handler``.

        Keys that are not strings (for example, already compiled patterns) are
        passed through unchanged.

        Raises:
            ValueError: If a string key is not a valid regular expression. Pydantic
                reports this as a validation error.

        """
        if not isinstance(input_dict, dict):
            # Let pydantic's own validation produce a proper validation error,
            # instead of crashing here with an AttributeError on ``.items()``.
            return handler(input_dict)
        new_dict = {}
        for regex, prefix in input_dict.items():
            if isinstance(regex, str):
                try:
                    new_dict[re.compile(regex)] = prefix
                except re.error as err:
                    raise ValueError(f"invalid regex {regex!r} in PREFIX_MAP: {err}") from err
            else:
                new_dict[regex] = prefix
        return handler(new_dict)

    API: Url
    NORMALIZE: bool = False
    BATCH_SIZE: int = Field(100, gt=0)
    TOKEN_COUNT_MIN: int = Field(64, gt=0)
    TOKEN_COUNT_MAX: int = Field(256, gt=1)
    TOKEN_COUNT_BUFFER: int = Field(32, gt=0)
    STEPWORDS_PATH: Path = Path("data/german_stopwords_full.txt")
    N_JOBS: int = Field(1, gt=0)
    # NOTE(review): pydantic skips validators for defaults unless validate_default
    # is enabled. Check SplitterSettings' model_config to confirm whether this
    # default's keys are compiled, or whether they stay plain strings.
    PREFIX_MAP: Annotated[dict[re.Pattern, str], WrapValidator(_wrap_validator_model_mapping)] = Field(
        default={"e5-": "query: ", "DPR|dpr": ""}
    )
    CLEAN_MD_BEFORE_EMBEDDING: bool = True
    TOKENIZER_MODEL: str = Field("gpt-3.5-turbo", description="The tokenizer model to use for splitting documents.")