26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
class MilvusConnectorStep(TypedStep[MilvusSettings, DataFrame[EmbeddingResult], MilvusResult]):  # pragma: no cover
    """Milvus connector step. It consumes embedding csv files, creates a new schema and inserts the embeddings.

    Every run creates a new versioned collection named ``{COLLECTION}_v{N}``
    (N = highest existing version + 1), inserts the embeddings, builds the
    index, loads the collection, points the ``{COLLECTION}`` alias at it,
    releases the previous version and finally drops versions beyond
    ``COLLECTION_HISTORY_LEN``.
    """

    # Timeout (seconds) passed to the client and to the alias/list/drop calls below.
    milvus_timeout: float = 20.0

    def __init__(self) -> None:
        super().__init__()
        # milvus stuff passed as environment
        # because we need to inject them into the DVC step during runtime,
        # not during DVC pipeline definition time
        uri = f"http://{self.settings.HOST}:{self.settings.PORT}"
        if not self.settings.PASSWORD or not self.settings.USER:
            log.warning("MILVUS_USER or MILVUS_PASSWORD for Milvus not provided. Thus running in non-credential Mode")
        self.client: MilvusClient = MilvusClient(
            uri=uri,
            user=self.settings.USER,
            password=self.settings.PASSWORD,
            timeout=self.milvus_timeout,
        )
        self.collection_index: IndexParams = IndexParams(**self.settings.INDEX_PARAMS)
        self.collection_history_len = self.settings.COLLECTION_HISTORY_LEN
        self.collection_prefix = self.settings.COLLECTION

    def __del__(self) -> None:
        # __init__ may have failed before the client attribute was assigned.
        client = getattr(self, "client", None)
        if client is not None:
            client.close()

    def run(self, inpt: DataFrame[EmbeddingResult]) -> MilvusResult:
        """Insert ``inpt`` into a fresh collection version and retire old versions.

        Returns:
            MilvusResult whose ``new`` is the newest collection name and whose
            ``old`` is the version that preceded it, or ``""`` if there was none.
            Note that ``old`` may already have been dropped by retirement when
            the history length is 1.
        """
        self._insert_embeddings(inpt)
        try:
            old = self.__construct_last_collection_name()
        except NoPreviousCollection:
            old = ""
        self._retire_collection()
        return MilvusResult(new=self.__construct_current_collection_name(), old=old)

    def _insert_embeddings(self, data: pd.DataFrame) -> None:
        """Create the next collection version, fill it with ``data`` and make it live.

        Raises:
            StepFailed: if ``data`` has no rows, or Milvus reports fewer inserted
                rows than were sent.
        """
        if len(data) == 0:
            # Fail before creating an empty, index-less collection version.
            raise StepFailed("No embeddings to insert into Milvus")

        collection_name = self.__construct_next_collection_name()
        log.info(f"Creating milvus collection {collection_name}")
        collection_schema = CollectionSchema(
            fields=[
                FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
                FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=3000),
                FieldSchema(
                    name="vector",
                    dtype=DataType.FLOAT_VECTOR,
                    # Positional access: the frame's index is not guaranteed to contain label 0.
                    dim=len(data["vector"].iloc[0]),
                ),
                FieldSchema(name="url", dtype=DataType.VARCHAR, max_length=300),
            ],
            description="Collection for storing Milvus embeddings",
        )
        log.info("schema created")
        self.client.create_collection(collection_name=collection_name, schema=collection_schema)
        log.info("collection created")

        log.info(f"Inserting embedding {len(data)} into collection {collection_name}")
        result: dict = self.client.insert(collection_name=collection_name, data=data.to_dict("records"))
        if result["insert_count"] != len(data):
            raise StepFailed(
                f"Failed to insert df into collection '{collection_name}'. "
                f"{result['insert_count']}/{len(data)} were successful"
            )
        log.info(f"Successfully inserted {len(data)} vectors into collection '{collection_name}'")

        self.client.create_index(collection_name=collection_name, index_params=self.collection_index)
        log.info(f"Successfully created index {self.collection_index} in collection '{collection_name}'")
        self.client.load_collection(collection_name)
        log.info(f"Successfully loaded collection '{collection_name}'")

        # Move the alias to the new, loaded collection *before* releasing the
        # previous one, so queries through the alias never hit an unloaded collection.
        self._update_alias(collection_name)
        try:
            previous_collection = self.__construct_last_collection_name()
        except NoPreviousCollection:
            return
        self.client.release_collection(previous_collection)

    def _retire_collection(self) -> None:
        """Drop all but the newest ``collection_history_len`` collection versions.

        A non-positive history length disables retirement (for 0 this matches the
        previous slicing behaviour, which deleted nothing).
        """
        keep = self.collection_history_len
        if keep <= 0:
            return
        collections_versioned: dict[int, str] = self._get_collection_versions()
        for version in sorted(collections_versioned)[:-keep]:
            col = collections_versioned[version]
            log.info(f"deleting {col} collection caused by retirement")
            self.client.drop_collection(col, timeout=self.milvus_timeout)

    def _update_alias(self, collection_name: str) -> None:
        """Point the alias named ``collection_prefix`` at ``collection_name``.

        Creation is attempted first; on ``MilvusException`` (presumably because
        the alias already exists) the existing alias is altered instead.
        """
        try:
            self.client.create_alias(
                collection_name=collection_name,
                alias=self.collection_prefix,
                timeout=self.milvus_timeout,
            )
        except MilvusException:
            self.client.alter_alias(
                collection_name=collection_name,
                alias=self.collection_prefix,
                timeout=self.milvus_timeout,
            )

    def __construct_next_collection_name(self) -> str:
        """Return ``{prefix}_v{N+1}`` for the highest existing version N, or ``{prefix}_v1``."""
        versions = self._get_collection_versions()
        if not versions:
            return f"{self.collection_prefix}_v1"
        latest_version = max(versions)
        log.info(f"Found version v{latest_version}")
        return f"{self.collection_prefix}_v{latest_version + 1}"

    def __construct_last_collection_name(self) -> str:
        """Return the name of the second-newest collection version.

        Raises:
            NoPreviousCollection: if fewer than two versions exist.
        """
        versions = self._get_collection_versions()
        if len(versions) < 2:
            raise NoPreviousCollection(f"Milvus does not contain a previous collection for {self.collection_prefix}")
        previous_version = sorted(versions)[-2]
        log.info(f"Found previous version v{previous_version}")
        # Return the real collection name rather than re-formatting the parsed int.
        return versions[previous_version]

    def __construct_current_collection_name(self) -> str:
        """Return the name of the newest collection version.

        Raises:
            NoPreviousCollection: if no version exists.
        """
        versions = self._get_collection_versions()
        if not versions:
            raise NoPreviousCollection(f"Milvus does not contain a previous collection for {self.collection_prefix}")
        current_version = max(versions)
        log.info(f"Found current version v{current_version}")
        return versions[current_version]

    def _get_collection_versions(self) -> dict[int, str]:
        """Map version number -> collection name for every ``{prefix}_v{N}`` collection.

        Only names consisting of exactly the configured prefix, ``_v`` and ASCII
        digits are considered. Collections that merely *contain* the prefix
        (e.g. ``{prefix}_backup`` or ``other_{prefix}_v1``) are ignored, so they
        can neither break the integer parsing nor be dropped by retirement.
        """
        marker = f"{self.collection_prefix}_v"
        versioned_collections: dict[int, str] = {}
        for name in self.client.list_collections(timeout=self.milvus_timeout):
            if not name.startswith(marker):
                continue
            suffix = name[len(marker):]
            if suffix.isascii() and suffix.isdigit():
                versioned_collections[int(suffix)] = name
        return versioned_collections