|
4 | 4 | from langchain_core.stores import BaseStore
|
5 | 5 |
|
6 | 6 |
|
| 7 | +class MongoDBByteStore(BaseStore[str, bytes]): |
| 8 | + """BaseStore implementation using MongoDB as the underlying store. |
| 9 | +
|
| 10 | + Examples: |
| 11 | + Create a MongoDBByteStore instance and perform operations on it: |
| 12 | +
|
| 13 | + .. code-block:: python |
| 14 | +
|
| 15 | + # Instantiate the MongoDBByteStore with a MongoDB connection |
| 16 | + from langchain.storage import MongoDBByteStore |
| 17 | +
|
| 18 | + mongo_conn_str = "mongodb://localhost:27017/" |
| 19 | + mongodb_store = MongoDBBytesStore(mongo_conn_str, db_name="test-db", |
| 20 | + collection_name="test-collection") |
| 21 | +
|
| 22 | + # Set values for keys |
| 23 | + mongodb_store.mset([("key1", "hello"), ("key2", "workd")]) |
| 24 | +
|
| 25 | + # Get values for keys |
| 26 | + values = mongodb_store.mget(["key1", "key2"]) |
| 27 | + # [bytes1, bytes1] |
| 28 | +
|
| 29 | + # Iterate over keys |
| 30 | + for key in mongodb_store.yield_keys(): |
| 31 | + print(key) |
| 32 | +
|
| 33 | + # Delete keys |
| 34 | + mongodb_store.mdelete(["key1", "key2"]) |
| 35 | + """ |
| 36 | + |
| 37 | + def __init__( |
| 38 | + self, |
| 39 | + connection_string: str, |
| 40 | + db_name: str, |
| 41 | + collection_name: str, |
| 42 | + *, |
| 43 | + client_kwargs: Optional[dict] = None, |
| 44 | + ) -> None: |
| 45 | + """Initialize the MongoDBStore with a MongoDB connection string. |
| 46 | +
|
| 47 | + Args: |
| 48 | + connection_string (str): MongoDB connection string |
| 49 | + db_name (str): name to use |
| 50 | + collection_name (str): collection name to use |
| 51 | + client_kwargs (dict): Keyword arguments to pass to the Mongo client |
| 52 | + """ |
| 53 | + try: |
| 54 | + from pymongo import MongoClient |
| 55 | + except ImportError as e: |
| 56 | + raise ImportError( |
| 57 | + "The MongoDBStore requires the pymongo library to be " |
| 58 | + "installed. " |
| 59 | + "pip install pymongo" |
| 60 | + ) from e |
| 61 | + |
| 62 | + if not connection_string: |
| 63 | + raise ValueError("connection_string must be provided.") |
| 64 | + if not db_name: |
| 65 | + raise ValueError("db_name must be provided.") |
| 66 | + if not collection_name: |
| 67 | + raise ValueError("collection_name must be provided.") |
| 68 | + |
| 69 | + self.client: MongoClient = MongoClient( |
| 70 | + connection_string, **(client_kwargs or {}) |
| 71 | + ) |
| 72 | + self.collection = self.client[db_name][collection_name] |
| 73 | + |
| 74 | + def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]: |
| 75 | + """Get the list of documents associated with the given keys. |
| 76 | +
|
| 77 | + Args: |
| 78 | + keys (list[str]): A list of keys representing Document IDs.. |
| 79 | +
|
| 80 | + Returns: |
| 81 | + list[Document]: A list of Documents corresponding to the provided |
| 82 | + keys, where each Document is either retrieved successfully or |
| 83 | + represented as None if not found. |
| 84 | + """ |
| 85 | + result = self.collection.find({"_id": {"$in": keys}}) |
| 86 | + result_dict = {doc["_id"]: doc["value"] for doc in result} |
| 87 | + return [result_dict.get(key) for key in keys] |
| 88 | + |
| 89 | + def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: |
| 90 | + """Set the given key-value pairs. |
| 91 | +
|
| 92 | + Args: |
| 93 | + key_value_pairs (list[tuple[str, Document]]): A list of id-document |
| 94 | + pairs. |
| 95 | + """ |
| 96 | + from pymongo import UpdateOne |
| 97 | + |
| 98 | + updates = [{"_id": k, "value": v} for k, v in key_value_pairs] |
| 99 | + self.collection.bulk_write( |
| 100 | + [UpdateOne({"_id": u["_id"]}, {"$set": u}, upsert=True) for u in updates] |
| 101 | + ) |
| 102 | + |
| 103 | + def mdelete(self, keys: Sequence[str]) -> None: |
| 104 | + """Delete the given ids. |
| 105 | +
|
| 106 | + Args: |
| 107 | + keys (list[str]): A list of keys representing Document IDs.. |
| 108 | + """ |
| 109 | + self.collection.delete_many({"_id": {"$in": keys}}) |
| 110 | + |
| 111 | + def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]: |
| 112 | + """Yield keys in the store. |
| 113 | +
|
| 114 | + Args: |
| 115 | + prefix (str): prefix of keys to retrieve. |
| 116 | + """ |
| 117 | + if prefix is None: |
| 118 | + for doc in self.collection.find(projection=["_id"]): |
| 119 | + yield doc["_id"] |
| 120 | + else: |
| 121 | + for doc in self.collection.find( |
| 122 | + {"_id": {"$regex": f"^{prefix}"}}, projection=["_id"] |
| 123 | + ): |
| 124 | + yield doc["_id"] |
| 125 | + |
| 126 | + |
7 | 127 | class MongoDBStore(BaseStore[str, Document]):
|
8 | 128 | """BaseStore implementation using MongoDB as the underlying store.
|
9 | 129 |
|
@@ -68,7 +188,9 @@ def __init__(
|
68 | 188 | if not collection_name:
|
69 | 189 | raise ValueError("collection_name must be provided.")
|
70 | 190 |
|
71 |
| - self.client = MongoClient(connection_string, **(client_kwargs or {})) |
| 191 | + self.client: MongoClient = MongoClient( |
| 192 | + connection_string, **(client_kwargs or {}) |
| 193 | + ) |
72 | 194 | self.collection = self.client[db_name][collection_name]
|
73 | 195 |
|
74 | 196 | def mget(self, keys: Sequence[str]) -> List[Optional[Document]]:
|
|
0 commit comments