Module containing the DVCStep that sends embedding data into Milvus.

MilvusConnectorStep

Bases: TypedStep[MilvusSettings, DataFrame[EmbeddingResult], Result]

Milvus connector step. It consumes embedding csv files, creates a new schema and inserts the embeddings.

Source code in wurzel/steps/milvus/step.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
class MilvusConnectorStep(TypedStep[MilvusSettings, DataFrame[EmbeddingResult], MilvusResult]):  # pragma: no cover
    """Milvus connector step. It consumes embedding csv files, creates a new schema and inserts the embeddings.

    Every run writes the embeddings into a new versioned collection named
    ``<COLLECTION>_v<N>``, points the alias ``<COLLECTION>`` at it and drops
    the oldest versions so that at most ``COLLECTION_HISTORY_LEN`` remain.
    """

    # Timeout passed to the Milvus client and to the individual client calls below.
    milvus_timeout: float = 20.0

    def __init__(self) -> None:
        """Create the Milvus client and cache the collection-related settings."""
        super().__init__()
        # milvus stuff passed as environment
        # because we need to inject them into the DVC step during runtime,
        # not during DVC pipeline definition time
        # NOTE(review): the URI scheme is always http; settings.SECURED is not consulted here — confirm intended.
        uri = f"http://{self.settings.HOST}:{self.settings.PORT}"
        if not self.settings.PASSWORD or not self.settings.USER:
            log.warning("MILVUS_HOST, MILVUS_USER or MILVUS_PASSWORD for Milvus not provided. Thus running in non-credential Mode")
        self.client: MilvusClient = MilvusClient(
            uri=uri,
            user=self.settings.USER,
            password=self.settings.PASSWORD,
            timeout=self.milvus_timeout,
        )
        self.collection_index: IndexParams = IndexParams(**self.settings.INDEX_PARAMS)
        self.collection_history_len = self.settings.COLLECTION_HISTORY_LEN

        self.collection_prefix = self.settings.COLLECTION

    def __del__(self):
        """Close the client, if ``__init__`` got far enough to create one."""
        client = getattr(self, "client", None)
        if client:
            client.close()

    def run(self, inpt: DataFrame[EmbeddingResult]) -> MilvusResult:
        """Insert the embeddings into a new collection version and retire old versions.

        Args:
            inpt: Embedding rows to insert.

        Returns:
            MilvusResult with ``new`` set to the newest collection name and
            ``old`` set to the one before it (empty string if there is none).
        """
        self._insert_embeddings(inpt)
        try:
            old = self.__construct_last_collection_name()
        except NoPreviousCollection:
            old = ""
        self._retire_collection()
        return MilvusResult(new=self.__construct_current_collection_name(), old=old)

    def _insert_embeddings(self, data: pd.DataFrame):
        """Create the next collection version, fill it, index it, load it and re-point the alias.

        Args:
            data: Rows to insert; the ``vector`` column determines the vector dimension.

        Raises:
            StepFailed: If ``data`` is empty or Milvus reports fewer inserted rows than given.
        """
        if data.empty:
            # Without at least one row the vector dimension of the schema is unknown.
            raise StepFailed("Cannot create a Milvus collection from an empty embedding DataFrame")

        collection_name = self.__construct_next_collection_name()
        log.info(f"Creating milvus collection {collection_name}")
        # Positional access: the first row must be found regardless of the index labels.
        vector_dim = len(data["vector"].iloc[0])
        collection_schema = CollectionSchema(
            fields=[
                FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
                FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=3000),
                FieldSchema(
                    name="vector",
                    dtype=DataType.FLOAT_VECTOR,
                    dim=vector_dim,
                ),
                FieldSchema(name="url", dtype=DataType.VARCHAR, max_length=300),
            ],
            description="Collection for storing Milvus embeddings",
        )

        log.info("schema created")
        self.client.create_collection(collection_name=collection_name, schema=collection_schema)
        log.info("collection created")
        log.info(f"Inserting embedding {len(data)} into collection {collection_name}")
        result: dict = self.client.insert(collection_name=collection_name, data=data.to_dict("records"))
        if result["insert_count"] != len(data):
            raise StepFailed(
                f"Failed to insert df into collection '{collection_name}'. {result['insert_count']}/{len(data)} were successful"
            )
        log.info(f"Successfully inserted {len(data)} vectors into collection '{collection_name}'")
        self.client.create_index(collection_name=collection_name, index_params=self.collection_index)
        log.info(f"Successfully created index {self.collection_index} in collection '{collection_name}'")
        self.client.load_collection(collection_name)
        log.info(f"Successfully loaded the collection '{collection_name}'")

        # Switch the alias before releasing the previous version, so the alias
        # never points at a collection that has already been released.
        self._update_alias(collection_name)
        try:
            previous_name = self.__construct_last_collection_name()
        except NoPreviousCollection:
            # First version ever created: there is nothing to release.
            previous_name = None
        if previous_name is not None:
            self.client.release_collection(previous_name)

    def _retire_collection(self):
        """Drop every versioned collection except the newest ``collection_history_len`` ones."""
        collections_versioned: dict[int, str] = self._get_collection_versions()
        # NOTE: with a history length of 0 the slice below is ``[:0]`` (empty), so nothing is dropped.
        to_delete = sorted(collections_versioned)[: -self.collection_history_len]

        for col_v in to_delete:
            col = collections_versioned[col_v]
            log.info(f"deleting {col} collection caused by retirement")
            self.client.drop_collection(col, timeout=self.milvus_timeout)

    def _update_alias(self, collection_name: str) -> None:
        """Point the alias ``collection_prefix`` at ``collection_name``.

        Creating the alias is tried first; if Milvus raises, the alias is altered instead.
        """
        alias_kwargs = {
            "collection_name": collection_name,
            "alias": self.collection_prefix,
            "timeout": self.milvus_timeout,
        }
        try:
            self.client.create_alias(**alias_kwargs)
        except MilvusException:
            self.client.alter_alias(**alias_kwargs)

    def __construct_next_collection_name(self) -> str:
        """Return the name for the next version: ``_v1`` if none exists, else highest version + 1."""
        previous_collections = self._get_collection_versions()
        if not previous_collections:
            return f"{self.collection_prefix}_v1"
        previous_version = max(previous_collections)
        log.info(f"Found version v{previous_version}")
        return f"{self.collection_prefix}_v{previous_version + 1}"

    def __construct_last_collection_name(self) -> str:
        """Return the name of the second-newest version.

        Raises:
            NoPreviousCollection: If fewer than two versions exist.
        """
        previous_collections = self._get_collection_versions()
        if len(previous_collections) <= 1:
            raise NoPreviousCollection(f"Milvus does not contain a previous collection for {self.collection_prefix}")
        previous_version = sorted(previous_collections)[-2]
        log.info(f"Found previous version v{previous_version}")
        return f"{self.collection_prefix}_v{previous_version}"

    def __construct_current_collection_name(self) -> str:
        """Return the name of the newest version.

        Raises:
            NoPreviousCollection: If no versioned collection exists.
        """
        previous_collections = self._get_collection_versions()
        if not previous_collections:
            raise NoPreviousCollection(f"Milvus does not contain a collection for {self.collection_prefix}")
        current_version = max(previous_collections)
        log.info(f"Found current version v{current_version}")
        return f"{self.collection_prefix}_v{current_version}"

    def _get_collection_versions(self) -> dict[int, str]:
        """Map version number to collection name for collections named ``<prefix>_v<digits>``.

        Only exact matches of that pattern are considered, so collections that
        merely contain the prefix or carry a non-numeric suffix are ignored
        instead of being mis-parsed.
        """
        name_prefix = f"{self.collection_prefix}_v"
        versioned_collections: dict[int, str] = {}
        for name in self.client.list_collections(timeout=self.milvus_timeout):
            if not name.startswith(name_prefix):
                continue
            suffix = name[len(name_prefix) :]
            # isascii() guards against unicode digits that isdigit() accepts but int() rejects.
            if suffix.isascii() and suffix.isdigit():
                versioned_collections[int(suffix)] = name
        return versioned_collections

MilvusSettings

Bases: Settings

MilvusSettings is a configuration class for managing settings related to MilvusDB.

Attributes:
  • HOST (str) –

    The hostname or IP address of the Milvus server. Defaults to "localhost".

  • PORT (int) –

    The port number for the Milvus server. Must be between 1 and 65535. Defaults to 19530.

  • COLLECTION (str) –

    The name of the collection in MilvusDB.

  • COLLECTION_HISTORY_LEN (int) –

    The length of the collection history. Defaults to 10.

  • SEARCH_PARAMS (dict) –

    Parameters for search operations in MilvusDB. Defaults to {"metric_type": "IP", "params": {}}.

  • INDEX_PARAMS (dict) –

    Parameters for indexing operations in MilvusDB. Defaults to {"index_type": "FLAT", "field_name": "vector", "metric_type": "IP", "params": {}}.

  • USER (str) –

    The username for authentication with MilvusDB.

  • PASSWORD (str) –

    The password for authentication with MilvusDB.

  • SECURED (bool) –

    Indicates whether the connection to MilvusDB is secured. Defaults to False.

Methods:

Name Description
parse_json

Validates and parses JSON strings into Python objects for SEARCH_PARAMS and INDEX_PARAMS.

Parameters:
  • v (str or dict) –

    The value to validate and parse.

Returns:
  • dict

    The parsed dictionary if the input is a JSON string, or the original value if it is already a dictionary.

Source code in wurzel/steps/milvus/settings.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
class MilvusSettings(Settings):
    """Configuration for connecting to MilvusDB and managing its collections.

    ``SEARCH_PARAMS`` and ``INDEX_PARAMS`` may be given either as dictionaries
    or as JSON strings; strings are decoded by ``parse_json`` before the
    fields are validated.

    Attributes:
        HOST (str): Hostname or IP address of the Milvus server. Defaults to "localhost".
        PORT (int): Port of the Milvus server, greater than 0 and at most 65535. Defaults to 19530.
        COLLECTION (str): Name of the collection in MilvusDB. No default.
        COLLECTION_HISTORY_LEN (int): Length of the collection history. Defaults to 10.
        SEARCH_PARAMS (dict): Parameters for search operations. Defaults to {"metric_type": "IP", "params": {}}.
        INDEX_PARAMS (dict): Parameters for indexing operations. Defaults to {"index_type": "FLAT",
                                "field_name": "vector", "metric_type": "IP", "params": {}}.
        USER (str): Username for authentication with MilvusDB. No default.
        PASSWORD (str): Password for authentication with MilvusDB. No default.
        SECURED (bool): Whether the connection to MilvusDB is secured. Defaults to False.

    """

    HOST: str = "localhost"
    PORT: int = Field(default=19530, gt=0, le=65535)
    COLLECTION: str
    COLLECTION_HISTORY_LEN: int = 10
    SEARCH_PARAMS: dict = {
        "metric_type": "IP",
        "params": {},
    }
    INDEX_PARAMS: dict = {"index_type": "FLAT", "field_name": "vector", "metric_type": "IP", "params": {}}
    USER: str
    PASSWORD: str
    SECURED: bool = False

    @field_validator("SEARCH_PARAMS", "INDEX_PARAMS", mode="before")
    @classmethod
    # pylint: disable-next=R0801
    def parse_json(cls, v):
        """Decode ``v`` with ``json.loads`` if it is a string; return it unchanged otherwise.

        Args:
            v (str or dict): Raw value supplied for SEARCH_PARAMS or INDEX_PARAMS.

        Returns:
            The decoded JSON value for string input, else ``v`` itself.
        """
        return json.loads(v) if isinstance(v, str) else v

parse_json(v) classmethod

Validator that parses JSON strings into Python objects; non-string values are returned unchanged.

Source code in wurzel/steps/milvus/settings.py
54
55
56
57
58
59
60
61
@field_validator("SEARCH_PARAMS", "INDEX_PARAMS", mode="before")
@classmethod
# pylint: disable-next=R0801
def parse_json(cls, v):
    """Decode ``v`` with ``json.loads`` if it is a string; return it unchanged otherwise.

    Args:
        v (str or dict): Raw value supplied for SEARCH_PARAMS or INDEX_PARAMS.

    Returns:
        The decoded JSON value for string input, else ``v`` itself.
    """
    return json.loads(v) if isinstance(v, str) else v