Module containing the DVCStep that sends embedding data into Qdrant.

QdrantConnectorStep

Bases: TypedStep[QdrantSettings, DataFrame[EmbeddingResult], DataFrame[QdrantResult]]

Qdrant connector step. It consumes embedding CSV files, creates a new collection, and inserts the embeddings.
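
A minimal usage sketch, hedged: the QDRANT__ environment prefix is inferred from the QDRANT__APIKEY warning in __init__, the module path from the source listing below, and the step is normally executed inside a wurzel pipeline rather than standalone. The input frame mirrors the url/text/keywords/vector columns consumed by _create_point:

import os

import pandas as pd

os.environ["QDRANT__COLLECTION"] = "demo_kb"  # required; no default
os.environ["QDRANT__URI"] = "http://localhost:6333"

from wurzel.steps.qdrant.step import QdrantConnectorStep  # assumed module path

step = QdrantConnectorStep()  # connects and picks the next <COLLECTION>_v<N> collection name
frame = pd.DataFrame(
    [
        {
            "url": "https://example.com/doc",
            "text": "Some chunk of text to index.",
            "keywords": "example demo",
            "vector": [0.1, 0.2, 0.3, 0.4],  # dimensionality is read from the first row
        }
    ]
)
result = step.run(frame)  # creates the collection on first use, then upserts in batches
step.finalize()           # builds payload indices, swaps the alias, retires old versions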

Source code in wurzel/steps/qdrant/step.py
class QdrantConnectorStep(TypedStep[QdrantSettings, DataFrame[EmbeddingResult], DataFrame[QdrantResult]]):
    """Qdrant connector step. It consumes embedding csv files, creates a new schema and inserts the embeddings."""

    _timeout: int = 20
    s: QdrantSettings
    client: QdrantClient
    collection_name: str
    result_class = QdrantResult
    vector_key = "vector"

    def __init__(self) -> None:
        super().__init__()
        # Qdrant connection parameters are passed via environment variables
        # because we need to inject them into the DVC step at runtime,
        # not at DVC pipeline definition time
        # uri = ":memory:"
        log.info(f"connecting to {self.settings.URI}")
        if not self.settings.APIKEY:
            log.warning("QDRANT__APIKEY for Qdrant not provided. Thus running in non-credential Mode")
        self.client = QdrantClient(
            location=self.settings.URI,
            api_key=self.settings.APIKEY,
            timeout=self._timeout,
        )
        self.collection_name = self.__construct_next_collection_name()
        self.id_iter = self.__id_gen()

    def __del__(self):
        if getattr(self, "client", None):
            self.client.close()

    def finalize(self) -> None:
        self._create_indices()
        self._update_alias()
        self._retire_collections()
        return super().finalize()

    def __id_gen(self):
        i = 0
        while True:
            i += 1
            yield i

    def run(self, inpt: DataFrame[EmbeddingResult]) -> DataFrame[QdrantResult]:
        if not self.client.collection_exists(self.collection_name):
            self._create_collection(len(inpt["vector"].loc[0]))
        df_result = self._insert_embeddings(inpt)
        return df_result

    def _create_collection(self, size: int):
        log.debug(f"Creating Qdrant collection {self.collection_name}")
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=models.VectorParams(size=size, distance=self.settings.DISTANCE),
            replication_factor=self.settings.REPLICATION_FACTOR,
        )

    def _get_entry_payload(self, row: dict[str, object]) -> dict[str, object]:
        """Create the payload for the entry."""
        payload = {
            "url": row["url"],
            "text": row["text"],
            **self.get_available_hashes(row["text"]),
            "keywords": row["keywords"],
            "history": str(step_history.get()),
        }
        return payload

    def _create_point(self, row: dict) -> models.PointStruct:
        """Creates a Qdrant PointStruct object from a given row dictionary.

        Args:
            row (dict): A dictionary representing a data entry, expected to contain at least the vector data under `self.vector_key`.

        Returns:
            models.PointStruct: An instance of PointStruct with a unique id, vector, and payload extracted from the row.

        Raises:
            KeyError: If the required vector key is not present in the row.

        """
        payload = self._get_entry_payload(row)

        return models.PointStruct(
            id=next(self.id_iter),  # type: ignore[arg-type]
            vector=row[self.vector_key],
            payload=payload,
        )

    def _upsert_points(self, points: list[models.PointStruct]):
        """Inserts a list of points into the Qdrant collection in batches.

        Args:
            points (list[models.PointStruct]): The list of point structures to upsert into the collection.

        Raises:
            StepFailed: If any batch fails to be inserted into the collection.

        Logs:
            Logs a message for each successfully inserted batch, including the collection name and number of points.

        """
        for point_chunk in _batch(points, self.settings.BATCH_SIZE):
            operation_info = self.client.upsert(
                collection_name=self.collection_name,
                wait=True,
                points=point_chunk,
            )
            if operation_info.status != "completed":
                raise StepFailed(f"Failed to insert df chunk into collection '{self.collection_name}' {operation_info}")
            log.info(
                "Successfully inserted vector_chunk",
                extra={"collection": self.collection_name, "count": len(point_chunk)},
            )

    def _build_result_dataframe(self, points: list[models.PointStruct]):
        """Constructs a DataFrame from a list of PointStruct objects.

        Each PointStruct's payload is unpacked into the resulting dictionary, along with its vector, collection name, and ID.
        The resulting list of dictionaries is used to create a DataFrame of the specified result_class.

        Args:
            points (list[models.PointStruct]): A list of PointStruct objects containing payload, vector, and id information.

        """
        result_data = [
            {
                **entry.payload,
                self.vector_key: entry.vector,
                "collection": self.collection_name,
                "id": entry.id,
            }
            for entry in points
        ]
        return DataFrame[self.result_class](result_data)

    def _insert_embeddings(self, data: DataFrame[EmbeddingResult]):
        log.info("Inserting embeddings", extra={"count": len(data), "collection": self.collection_name})

        points = [self._create_point(row) for _, row in data.iterrows()]

        self._upsert_points(points)

        return self._build_result_dataframe(points)

    def _create_indices(self):
        self.client.create_payload_index(
            collection_name=self.collection_name,
            field_name="keywords",
            field_schema=models.TextIndexParams(
                type=models.TextIndexType.TEXT,
                tokenizer=models.TokenizerType.WHITESPACE,
            ),
        )
        self.client.create_payload_index(
            collection_name=self.collection_name,
            field_name="url",
            field_schema=models.TextIndexParams(
                type=models.TextIndexType.TEXT,
                tokenizer=models.TokenizerType.PREFIX,
                min_token_len=3,
            ),
        )
        self.client.create_payload_index(
            collection_name=self.collection_name,
            field_name="text",
            field_schema=models.TextIndexParams(
                type=models.TextIndexType.TEXT,
                tokenizer=models.TokenizerType.MULTILINGUAL,
            ),
        )
        self.client.create_payload_index(
            collection_name=self.collection_name,
            field_name="history",
            field_schema=models.TextIndexParams(type=models.TextIndexType.TEXT, tokenizer=models.TokenizerType.WORD),
        )

    def _retire_collections(self):
        collections_versioned: dict[int, str] = self._get_collection_versions()
        to_delete = list(collections_versioned.keys())[: -self.settings.COLLECTION_HISTORY_LEN]
        if not to_delete:
            return

        for col_v in to_delete:
            col = collections_versioned[col_v]
            log.info(f"deleting {col} collection caused by retirement")
            self.client.delete_collection(col)

    def _update_alias(self):
        success = self.client.update_collection_aliases(
            change_aliases_operations=[
                models.CreateAliasOperation(
                    create_alias=models.CreateAlias(
                        collection_name=self.collection_name,
                        alias_name=self.settings.COLLECTION,
                    )
                )
            ]
        )
        if not success:
            raise CustomQdrantException("Alias Update failed")

    def __construct_next_collection_name(self) -> str:
        previous_collections = self._get_collection_versions()
        if not previous_collections:
            return f"{self.settings.COLLECTION}_v1"
        previous_version = max(previous_collections.keys())
        log.info(f"Found version v{previous_version}")
        return f"{self.settings.COLLECTION}_v{previous_version + 1}"

    def _get_collection_versions(self) -> dict[int, str]:
        previous_collections = self.client.get_collections().collections
        versioned_collections = {
            int(previous.name.split("_v")[-1]): previous.name
            for previous in previous_collections
            if f"{self.settings.COLLECTION}_v" in previous.name
        }
        return dict(sorted(versioned_collections.items()))

    @staticmethod
    def get_available_hashes(text: str, encoding: str = "utf-8") -> dict:
        """Compute `n` hashes for a given input text based.
        The number `n` depends on the optionally installed python libs.
        For now only TLSH (Trend Micro Locality Sensitive Hash) is supported
        ## TLSH
        Given a byte stream with a minimum length of 50 bytes TLSH generates a hash value which can be used for similarity comparisons.

        Args:
            text (str): Input text
            encoding (str, optional): Input text will encoded to bytes using this encoding. Defaults to "utf-8".

        Returns:
            dict[str, str]: keys: `text_<algo>_hash` hash as string ! Dict might be empty!

        """
        hashes = {}
        encoded_text = text.encode(encoding)
        if HAS_TLSH:
            # pylint: disable=no-name-in-module, import-outside-toplevel
            from tlsh import hash as tlsh_hash

            hashes["text_tlsh_hash"] = tlsh_hash(encoded_text)
        hashes["text_sha256_hash"] = sha256(encoded_text).hexdigest()
        return hashes
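
The collection lifecycle (pick the next versioned name, swap the alias, retire old versions) can be illustrated standalone. This sketch mirrors the logic of _get_collection_versions, __construct_next_collection_name, and _retire_collections on example names; all values are illustrative:

COLLECTION = "kb"
COLLECTION_HISTORY_LEN = 2  # keep only the two newest versions

existing = ["kb_v1", "kb_v2", "kb_v3", "unrelated"]

# mirrors _get_collection_versions: version number -> name, sorted ascending
versioned = dict(
    sorted((int(name.split("_v")[-1]), name) for name in existing if f"{COLLECTION}_v" in name)
)
# mirrors __construct_next_collection_name
next_name = f"{COLLECTION}_v{max(versioned) + 1}" if versioned else f"{COLLECTION}_v1"
# mirrors _retire_collections: drop everything but the newest COLLECTION_HISTORY_LEN versions
to_delete = [versioned[v] for v in list(versioned)[:-COLLECTION_HISTORY_LEN]]

print(next_name)  # kb_v4 -- finalize() then points the alias "kb" at this collection
print(to_delete)  # ['kb_v1']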

get_available_hashes(text, encoding='utf-8') staticmethod

Compute n hashes for a given input text. The number n depends on the optionally installed Python libs; for now TLSH (Trend Micro Locality Sensitive Hash) is the only optional hash supported.

TLSH

Given a byte stream with a minimum length of 50 bytes, TLSH generates a hash value that can be used for similarity comparisons.

Parameters:
  • text (str) –

    Input text

  • encoding (str, default: 'utf-8') –

    Input text will be encoded to bytes using this encoding. Defaults to "utf-8".

Returns:
  • dict

    dict[str, str]: Keys of the form text_<algo>_hash mapping to hash strings. text_sha256_hash is always present; text_tlsh_hash only if TLSH is installed.
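
For example (import path assumed from the source listing below; being a staticmethod, it can be called without instantiating the step):

from wurzel.steps.qdrant.step import QdrantConnectorStep  # assumed module path

# TLSH needs at least 50 reasonably varied bytes, hence the longer sample text.
text = "The quick brown fox jumps over the lazy dog, while 42 ships sail east past old lighthouses."
hashes = QdrantConnectorStep.get_available_hashes(text)
print(hashes["text_sha256_hash"])    # always present
print(hashes.get("text_tlsh_hash"))  # present only if the optional tlsh package is installed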

Source code in wurzel/steps/qdrant/step.py
@staticmethod
def get_available_hashes(text: str, encoding: str = "utf-8") -> dict:
    """Compute `n` hashes for a given input text based.
    The number `n` depends on the optionally installed python libs.
    For now only TLSH (Trend Micro Locality Sensitive Hash) is supported
    ## TLSH
    Given a byte stream with a minimum length of 50 bytes TLSH generates a hash value which can be used for similarity comparisons.

    Args:
        text (str): Input text
        encoding (str, optional): Input text will encoded to bytes using this encoding. Defaults to "utf-8".

    Returns:
        dict[str, str]: keys: `text_<algo>_hash` hash as string ! Dict might be empty!

    """
    hashes = {}
    encoded_text = text.encode(encoding)
    if HAS_TLSH:
        # pylint: disable=no-name-in-module, import-outside-toplevel
        from tlsh import hash as tlsh_hash

        hashes["text_tlsh_hash"] = tlsh_hash(encoded_text)
    hashes["text_sha256_hash"] = sha256(encoded_text).hexdigest()
    return hashes

Module containing the DVCStep that sends multi-vector embedding data into Qdrant.

QdrantConnectorMultiVectorStep

Bases: QdrantConnectorStep, TypedStep[QdrantSettings, DataFrame[EmbeddingMultiVectorResult], DataFrame[QdrantMultiVectorResult]]

Qdrant connector step for multi-vector embeddings. It consumes embedding CSV files, creates a new collection with a multivector configuration, and inserts the embeddings.
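
Compared to the parent step, each input row carries one embedding per text split under vectors (the collection is created with Qdrant's MAX_SIM multivector comparator) and the splits themselves are stored in the payload. A hypothetical input row, values illustrative:

row = {
    "url": "https://example.com/doc",
    "text": "Full document text, split into chunks.",
    "keywords": "example",
    "splits": ["Full document text,", "split into chunks."],  # added to the payload
    "vectors": [  # one embedding per split; len(vectors[0]) sets the collection's vector size
        [0.1, 0.2, 0.3],
        [0.4, 0.5, 0.6],
    ],
}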

Source code in wurzel/steps/qdrant/step_multi_vector.py
class QdrantConnectorMultiVectorStep(
    QdrantConnectorStep,
    TypedStep[
        QdrantSettings,
        DataFrame[EmbeddingMultiVectorResult],
        DataFrame[QdrantMultiVectorResult],
    ],
):
    """Qdrant connector step. It consumes embedding csv files, creates a new schema and inserts the embeddings."""

    vector_key = "vectors"
    result_class = QdrantMultiVectorResult

    def _create_collection(self, size: int):
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=models.VectorParams(
                size=size,
                distance=self.settings.DISTANCE,
                multivector_config=models.MultiVectorConfig(comparator=models.MultiVectorComparator.MAX_SIM),
            ),
            replication_factor=self.settings.REPLICATION_FACTOR,
        )

    def run(self, inpt: DataFrame[EmbeddingMultiVectorResult]) -> DataFrame[QdrantMultiVectorResult]:
        if not self.client.collection_exists(self.collection_name):
            log.debug(f"Creating Qdrant collection {self.collection_name}")
            self._create_collection(len(inpt["vectors"].loc[0][0]))
        df_result = self._insert_embeddings(inpt)
        return df_result

    def _get_entry_payload(self, row: dict[str, object]) -> dict[str, object]:
        """Create the payload for the entry."""
        payload = super()._get_entry_payload(row)
        payload["splits"] = row["splits"]
        return payload

QdrantSettings

Bases: Settings

QdrantSettings is a configuration class for managing settings related to the Qdrant database.

Attributes:
  • DISTANCE (Distance) –

    The distance metric to be used, default is Distance.DOT.

  • URI (str) –

    The URI for the Qdrant database, default is "http://localhost:6333".

  • COLLECTION (str) –

    The name of the collection in the Qdrant database.

  • COLLECTION_HISTORY_LEN (int) –

    The length of the collection history, default is 10.

  • SEARCH_PARAMS (dict) –

    Parameters for search operations, default is {"metric_type": "IP", "params": {}}.

  • INDEX_PARAMS (dict) –

    Parameters for index creation, default includes "index_type", "field_name", "distance", and "params".

  • APIKEY (Optional[str]) –

    The API key for authentication, default is an empty string.

  • REPLICATION_FACTOR (int) –

    The replication factor for the database, default is 3, must be greater than 0.

  • BATCH_SIZE (int) –

    The batch size for operations, default is 1024, must be greater than 0.

Methods:
  • parse_json –

    Validates and parses JSON strings into Python objects for SEARCH_PARAMS and INDEX_PARAMS.

Parameters:
  • v (Union[str, dict]) –

    The input value, either a JSON string or a dictionary.

Returns:
  • dict

    The parsed dictionary if the input is a JSON string, otherwise the input value.
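
A hedged configuration sketch (the QDRANT__ environment prefix is inferred from the step's QDRANT__APIKEY warning, and the module path from the source listing below):

import json
import os

os.environ["QDRANT__COLLECTION"] = "kb"  # required; no default
os.environ["QDRANT__URI"] = "http://qdrant.internal:6333"
os.environ["QDRANT__REPLICATION_FACTOR"] = "2"
# dict-valued fields may be passed as JSON strings; parse_json decodes them
os.environ["QDRANT__SEARCH_PARAMS"] = json.dumps({"metric_type": "IP", "params": {"ef": 64}})

from wurzel.steps.qdrant.settings import QdrantSettings  # assumed module path

settings = QdrantSettings()
assert settings.SEARCH_PARAMS["params"]["ef"] == 64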

Source code in wurzel/steps/qdrant/settings.py
class QdrantSettings(Settings):
    """QdrantSettings is a configuration class for managing settings related to the Qdrant database.

    Attributes:
        DISTANCE (Distance): The distance metric to be used, default is Distance.DOT.
        URI (str): The URI for the Qdrant database, default is "http://localhost:6333".
        COLLECTION (str): The name of the collection in the Qdrant database.
        COLLECTION_HISTORY_LEN (int): The length of the collection history, default is 10.
        SEARCH_PARAMS (dict): Parameters for search operations, default is {"metric_type": "IP", "params": {}}.
        INDEX_PARAMS (dict): Parameters for index creation, default includes "index_type", "field_name", "distance", and "params".
        APIKEY (Optional[str]): The API key for authentication, default is an empty string.
        REPLICATION_FACTOR (int): The replication factor for the database, default is 3, must be greater than 0.
        BATCH_SIZE (int): The batch size for operations, default is 1024, must be greater than 0.

    Methods:
        parse_json(v):
            Validates and parses JSON strings into Python objects for SEARCH_PARAMS and INDEX_PARAMS.

            Args:
                v (Union[str, dict]): The input value, either a JSON string or a dictionary.

            Returns:
                dict: The parsed dictionary if the input is a JSON string, otherwise the input value.

    """

    DISTANCE: Distance = Distance.DOT
    URI: str = "http://localhost:6333"
    COLLECTION: str
    COLLECTION_HISTORY_LEN: int = 10
    SEARCH_PARAMS: dict = {"metric_type": "IP", "params": {}}
    INDEX_PARAMS: dict = {
        "index_type": "FLAT",
        "field_name": "vector",
        "distance": "Dot",
        "params": {},
    }
    APIKEY: Optional[str] = ""
    REPLICATION_FACTOR: int = Field(default=3, gt=0)
    BATCH_SIZE: int = Field(default=1024, gt=0)

    @field_validator("SEARCH_PARAMS", "INDEX_PARAMS", mode="before")
    @classmethod
    def parse_json(cls, v):
        """Validation for json."""
        if isinstance(v, str):
            return json.loads(v)
        return v

parse_json(v) classmethod

Parse JSON string values into dicts; pass dicts through unchanged.
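
It accepts either form, for example:

QdrantSettings.parse_json('{"metric_type": "IP", "params": {}}')  # JSON string -> dict
QdrantSettings.parse_json({"metric_type": "IP"})                  # dict passed through unchanged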

Source code in wurzel/steps/qdrant/settings.py
@field_validator("SEARCH_PARAMS", "INDEX_PARAMS", mode="before")
@classmethod
def parse_json(cls, v):
    """Validation for json."""
    if isinstance(v, str):
        return json.loads(v)
    return v