# Copyright (c) 2023-2024 Philip May
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT
"""ArangoDB utils module.
Hint:
Use pip to install the necessary dependencies for this module:
``pip install mltb2[arangodb]``
"""
import gzip
from argparse import ArgumentParser
from collections.abc import Sequence
from contextlib import closing
from dataclasses import dataclass
from typing import Any, Optional, Union
import jsonlines
from arango import ArangoClient
from arango.database import StandardDatabase
from dotenv import dotenv_values
from pandas import DataFrame
from tqdm import tqdm
from mltb2.db import AbstractBatchDataManager
[docs]
def _check_config_keys(config: dict[str, Optional[str]], expected_config_keys: Sequence[str]) -> None:
"""Check if all expected keys are in config.
This is useful to check if a config file contains all necessary keys.
"""
for expected_config_key in expected_config_keys:
if expected_config_key not in config:
raise ValueError(f"Config file must contain '{expected_config_key}'!")
[docs]
@dataclass
class ArangoConnectionManager:
    """Holder of ArangoDB connection settings.

    Serves as base class for the managers in this module: it stores the connection
    parameters and offers factory methods to create a client and a database connection.

    Args:
        hosts: ArangoDB host or hosts.
        db_name: ArangoDB database name.
        username: ArangoDB username.
        password: ArangoDB password.
    """

    hosts: Union[str, Sequence[str]]
    db_name: str
    username: str
    password: str

    def _arango_client_factory(self) -> ArangoClient:
        """Build a new ArangoDB client for the configured ``hosts``.

        Returns:
            A freshly created ``ArangoClient``.
        """
        return ArangoClient(hosts=self.hosts)

    def _connection_factory(self, arango_client: ArangoClient) -> StandardDatabase:
        """Open the configured database through the given client.

        Args:
            arango_client: ArangoDB client used to create the connection.

        Returns:
            The database handle for ``db_name``, requested with the stored username and password.
        """
        credentials = {"username": self.username, "password": self.password}
        return arango_client.db(self.db_name, **credentials)
[docs]
@dataclass
class ArangoBatchDataManager(AbstractBatchDataManager, ArangoConnectionManager):
    """Batch data manager backed by ArangoDB (implements ``AbstractBatchDataManager``).

    Args:
        hosts: ArangoDB host or hosts.
        db_name: ArangoDB database name.
        username: ArangoDB username.
        password: ArangoDB password.
        collection_name: Collection whose documents are processed.
        attribute_name: Marker attribute for processed documents. The default query only
            returns documents that do not have this attribute; documents that have it
            are considered as already processed.
        batch_size: The batch size.
        aql_overwrite: AQL string that replaces the default query.
    """

    collection_name: str
    attribute_name: str
    batch_size: int = 20
    aql_overwrite: Optional[str] = None

    @classmethod
    def from_config_file(cls, config_file_name, aql_overwrite: Optional[str] = None):
        """Create an instance from the values of a config file.

        The config file must contain these values:

        - ``hosts``
        - ``db_name``
        - ``username``
        - ``password``
        - ``collection_name``
        - ``attribute_name``
        - ``batch_size``

        Config file example:

        .. code-block::

            hosts="https://arangodb.com"
            db_name="my_ml_database"
            username="my_username"
            password="secret"
            collection_name="my_ml_data_collection"
            attribute_name="processing_metadata"
            batch_size=100

        Args:
            config_file_name: The config file name (path).
            aql_overwrite: AQL string that replaces the default query.

        Raises:
            ValueError: If one of the required keys is missing from the config file.
        """
        config = dotenv_values(config_file_name)

        # every key except ``batch_size`` is handed over as plain string
        string_keys = ("hosts", "db_name", "username", "password", "collection_name", "attribute_name")
        _check_config_keys(config, [*string_keys, "batch_size"])

        string_kwargs = {key: config[key] for key in string_keys}
        return cls(
            **string_kwargs,  # type: ignore[arg-type]
            batch_size=int(config["batch_size"]),  # type: ignore[arg-type]
            aql_overwrite=aql_overwrite,
        )

    def load_batch(self) -> Sequence:
        """Fetch one batch of documents from the ArangoDB database.

        Without ``aql_overwrite`` the query selects up to ``batch_size`` documents of the
        collection that do not have the attribute ``attribute_name``. The bind variables
        ``@coll``, ``attribute`` and ``batch_size`` are passed to the query in either case.

        Returns:
            The loaded batch of data.
        """
        query = self.aql_overwrite
        if query is None:
            query = "FOR doc IN @@coll FILTER !HAS(doc, @attribute) LIMIT @batch_size RETURN doc"

        query_bind_vars = {
            "@coll": self.collection_name,
            "attribute": self.attribute_name,
            "batch_size": self.batch_size,
        }

        with closing(self._arango_client_factory()) as arango_client:
            database = self._connection_factory(arango_client)
            result_cursor = database.aql.execute(
                query,
                bind_vars=query_bind_vars,  # type: ignore[arg-type]
                batch_size=self.batch_size,
            )
            # the cursor is closed on leaving the block, the client afterwards
            with closing(result_cursor) as open_cursor:  # type: ignore[type-var]
                return open_cursor.batch()  # type: ignore[union-attr,return-value]

    def save_batch(self, batch: Sequence) -> None:
        """Write a batch of documents back to the configured collection.

        The documents are bulk imported with ``on_duplicate="update"``.

        Args:
            batch: The batch of data to save.
        """
        with closing(self._arango_client_factory()) as arango_client:
            database = self._connection_factory(arango_client)
            database.collection(self.collection_name).import_bulk(batch, on_duplicate="update")
[docs]
def arango_collection_backup() -> None:
    """Commandline tool to do an ArangoDB backup of a collection.

    The backup is written to a gzip compressed JSONL file in the current working directory.
    The file is named ``<collection name>_backup.jsonl.gz``.

    Run ``arango-col-backup -h`` to get command line help.

    Command line arguments:
        ``--conf``: Config file containing ``hosts``, ``db_name``, ``username`` and ``password``.
        ``--col``: Name of the collection to back up.

    Raises:
        ValueError: If the config file lacks one of the required keys.
    """
    # argument parsing
    description = (
        "ArangoDB backup of a collection. "
        "The backup is written to a gzip compressed JSONL file in the current working directory."
    )
    argument_parser = ArgumentParser(description=description)
    argument_parser.add_argument(
        "--conf", type=str, required=True, help="Config file containing 'hosts', 'db_name', 'username' and 'password'."
    )
    argument_parser.add_argument("--col", type=str, required=True, help="Collection name to backup.")
    args = argument_parser.parse_args()

    # load and check config file
    arango_config = dotenv_values(args.conf)
    expected_config_file_keys = ["hosts", "db_name", "username", "password"]
    _check_config_keys(arango_config, expected_config_file_keys)

    output_file_name = f"./{args.col}_backup.jsonl.gz"
    print(f"Writing backup to '{output_file_name}'...")

    with (
        closing(ArangoClient(hosts=arango_config["hosts"])) as arango_client,  # type: ignore[arg-type]
        gzip.open(output_file_name, "w") as gzip_out,
    ):
        connection = arango_client.db(
            arango_config["db_name"],  # type: ignore[arg-type]
            arango_config["username"],  # type: ignore[arg-type]
            arango_config["password"],  # type: ignore[arg-type]
        )
        jsonlines_writer = jsonlines.Writer(gzip_out)  # type: ignore[arg-type]

        # Create the cursor outside of the ``try`` block: if the query fails there is no
        # cursor to close, and the database error must propagate unchanged instead of being
        # masked by an ``UnboundLocalError`` raised from the ``finally`` clause.
        cursor = connection.aql.execute(
            "FOR doc IN @@coll RETURN doc",
            bind_vars={"@coll": args.col},
            batch_size=100,
            max_runtime=60 * 60,  # type: ignore[arg-type] # 1 hour
            stream=True,
        )
        try:
            # write one JSON line per document
            for doc in tqdm(cursor):
                jsonlines_writer.write(doc)
        finally:
            # always close the cursor, also when iterating or writing fails
            cursor.close(ignore_missing=True)  # type: ignore[union-attr]
[docs]
@dataclass
class ArangoImportDataManager(ArangoConnectionManager):
    """Import helper that fills data into an ArangoDB collection.

    Args:
        hosts: ArangoDB host or hosts.
        db_name: ArangoDB database name.
        username: ArangoDB username.
        password: ArangoDB password.
    """

    @classmethod
    def from_config_file(cls, config_file_name):
        """Create an instance from the values of a config file.

        The config file must contain at least these values:

        - ``hosts``
        - ``db_name``
        - ``username``
        - ``password``

        Config file example:

        .. code-block::

            hosts="https://arangodb.com"
            db_name="my_ml_database"
            username="my_username"
            password="secret"

        Args:
            config_file_name: The config file name (path).

        Raises:
            ValueError: If one of the required keys is missing from the config file.
        """
        config = dotenv_values(config_file_name)

        # the required config keys match the constructor argument names
        required_keys = ["hosts", "db_name", "username", "password"]
        _check_config_keys(config, required_keys)

        return cls(**{key: config[key] for key in required_keys})

    def import_dicts(
        self, dicts: Sequence[dict[str, Any]], collection_name: str, create_collection: bool = False
    ) -> None:
        """Bulk import dictionaries into an ArangoDB collection.

        Args:
            dicts: The data to import.
            collection_name: The collection name to import to.
            create_collection: If ``True`` the collection is created if it does not exist.

        Raises:
            ValueError: If the collection does not exist and ``create_collection`` is ``False``.
            arango.exceptions.DocumentInsertError: If import fails.
        """
        with closing(self._arango_client_factory()) as arango_client:
            database = self._connection_factory(arango_client)

            # resolve the target collection: use the existing one, create it, or give up
            if database.has_collection(collection_name):
                target_collection = database.collection(collection_name)
            elif create_collection:
                target_collection = database.create_collection(collection_name)
            else:
                raise ValueError(
                    f"Collection '{collection_name}' does not exist! "
                    "Create it or specify 'create_collection=True'."
                )

            target_collection.import_bulk(  # type: ignore[union-attr]
                dicts,
                halt_on_error=True,
                details=False,
                overwrite=False,
                on_duplicate="error",
                sync=True,
                batch_size=100,
            )

    def import_dataframe(self, dataframe: DataFrame, collection_name: str, create_collection: bool = False) -> None:
        """Bulk import the rows of a Pandas ``DataFrame`` into an ArangoDB collection.

        The dataframe is converted to one dictionary per row and then handed to ``import_dicts``.

        Args:
            dataframe: The Pandas data to import.
            collection_name: The collection name to import to.
            create_collection: If ``True`` the collection is created if it does not exist.

        Raises:
            ValueError: If the collection does not exist and ``create_collection`` is ``False``.
            arango.exceptions.DocumentInsertError: If import fails.
        """
        self.import_dicts(dataframe.to_dict(orient="records"), collection_name, create_collection)