Skip to content

Commit

Permalink
Add async methods to BaseStore (langchain-ai#16669)
Browse files Browse the repository at this point in the history
- **Description:**

The BaseStore methods are currently blocking. Some implementations
(AstraDBStore, RedisStore) would benefit from having async methods.
Also once we have async methods for BaseStore, we can implement the
async `aembed_documents` in CacheBackedEmbeddings to cache the
embeddings asynchronously.

* adds async methods amget, amset, amedelete and ayield_keys to
BaseStore
  * implements the async methods for InMemoryStore
  * adds tests for InMemoryStore async methods

- **Twitter handle:** cbornet_
  • Loading branch information
cbornet authored Feb 1, 2024
1 parent 17e8863 commit a0ec045
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 1 deletion.
64 changes: 63 additions & 1 deletion libs/core/langchain_core/stores.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
from abc import ABC, abstractmethod
from typing import Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar, Union
from typing import (
AsyncIterator,
Generic,
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)

from langchain_core.runnables import run_in_executor

K = TypeVar("K")
V = TypeVar("V")
Expand All @@ -20,6 +32,18 @@ def mget(self, keys: Sequence[K]) -> List[Optional[V]]:
If a key is not found, the corresponding value will be None.
"""

async def amget(self, keys: Sequence[K]) -> List[Optional[V]]:
"""Get the values associated with the given keys.
Args:
keys (Sequence[K]): A sequence of keys.
Returns:
A sequence of optional values associated with the keys.
If a key is not found, the corresponding value will be None.
"""
return await run_in_executor(None, self.mget, keys)

@abstractmethod
def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
"""Set the values for the given keys.
Expand All @@ -28,6 +52,14 @@ def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
"""

async def amset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
"""Set the values for the given keys.
Args:
key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
"""
return await run_in_executor(None, self.mset, key_value_pairs)

@abstractmethod
def mdelete(self, keys: Sequence[K]) -> None:
"""Delete the given keys and their associated values.
Expand All @@ -36,6 +68,14 @@ def mdelete(self, keys: Sequence[K]) -> None:
keys (Sequence[K]): A sequence of keys to delete.
"""

async def amdelete(self, keys: Sequence[K]) -> None:
"""Delete the given keys and their associated values.
Args:
keys (Sequence[K]): A sequence of keys to delete.
"""
return await run_in_executor(None, self.mdelete, keys)

@abstractmethod
def yield_keys(
self, *, prefix: Optional[str] = None
Expand All @@ -52,5 +92,27 @@ def yield_keys(
depending on what makes more sense for the given store.
"""

async def ayield_keys(
self, *, prefix: Optional[str] = None
) -> Union[AsyncIterator[K], AsyncIterator[str]]:
"""Get an iterator over keys that match the given prefix.
Args:
prefix (str): The prefix to match.
Returns:
Iterator[K | str]: An iterator over keys that match the given prefix.
This method is allowed to return an iterator over either K or str
depending on what makes more sense for the given store.
"""
iterator = await run_in_executor(None, self.yield_keys, prefix=prefix)
done = object()
while True:
item = await run_in_executor(None, lambda it: next(it, done), iterator)
if item is done:
break
yield item


ByteStore = BaseStore[str, bytes]
49 changes: 49 additions & 0 deletions libs/langchain/langchain/storage/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
from typing import (
Any,
AsyncIterator,
Dict,
Generic,
Iterator,
Expand Down Expand Up @@ -60,6 +61,18 @@ def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
"""
return [self.store.get(key) for key in keys]

async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
"""Get the values associated with the given keys.
Args:
keys (Sequence[str]): A sequence of keys.
Returns:
A sequence of optional values associated with the keys.
If a key is not found, the corresponding value will be None.
"""
return self.mget(keys)

def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
"""Set the values for the given keys.
Expand All @@ -72,6 +85,17 @@ def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
for key, value in key_value_pairs:
self.store[key] = value

async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
"""Set the values for the given keys.
Args:
key_value_pairs (Sequence[Tuple[str, V]]): A sequence of key-value pairs.
Returns:
None
"""
return self.mset(key_value_pairs)

def mdelete(self, keys: Sequence[str]) -> None:
"""Delete the given keys and their associated values.
Expand All @@ -82,6 +106,14 @@ def mdelete(self, keys: Sequence[str]) -> None:
if key in self.store:
del self.store[key]

async def amdelete(self, keys: Sequence[str]) -> None:
"""Delete the given keys and their associated values.
Args:
keys (Sequence[str]): A sequence of keys to delete.
"""
self.mdelete(keys)

def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
"""Get an iterator over keys that match the given prefix.
Expand All @@ -98,6 +130,23 @@ def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
if key.startswith(prefix):
yield key

async def ayield_keys(self, prefix: Optional[str] = None) -> AsyncIterator[str]:
"""Get an async iterator over keys that match the given prefix.
Args:
prefix (str, optional): The prefix to match. Defaults to None.
Returns:
AsyncIterator[str]: An async iterator over keys that match the given prefix.
"""
if prefix is None:
for key in self.store.keys():
yield key
else:
for key in self.store.keys():
if key.startswith(prefix):
yield key


InMemoryStore = InMemoryBaseStore[Any]
InMemoryByteStore = InMemoryBaseStore[bytes]
47 changes: 47 additions & 0 deletions libs/langchain/tests/unit_tests/storage/test_in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@ def test_mget() -> None:
assert non_existent_value == [None]


async def test_amget() -> None:
store = InMemoryStore()
await store.amset([("key1", "value1"), ("key2", "value2")])

values = await store.amget(["key1", "key2"])
assert values == ["value1", "value2"]

# Test non-existent key
non_existent_value = await store.amget(["key3"])
assert non_existent_value == [None]


def test_mset() -> None:
store = InMemoryStore()
store.mset([("key1", "value1"), ("key2", "value2")])
Expand All @@ -21,6 +33,14 @@ def test_mset() -> None:
assert values == ["value1", "value2"]


async def test_amset() -> None:
store = InMemoryStore()
await store.amset([("key1", "value1"), ("key2", "value2")])

values = await store.amget(["key1", "key2"])
assert values == ["value1", "value2"]


def test_mdelete() -> None:
store = InMemoryStore()
store.mset([("key1", "value1"), ("key2", "value2")])
Expand All @@ -34,6 +54,19 @@ def test_mdelete() -> None:
store.mdelete(["key3"]) # No error should be raised


async def test_amdelete() -> None:
store = InMemoryStore()
await store.amset([("key1", "value1"), ("key2", "value2")])

await store.amdelete(["key1"])

values = await store.amget(["key1", "key2"])
assert values == [None, "value2"]

# Test deleting non-existent key
await store.amdelete(["key3"]) # No error should be raised


def test_yield_keys() -> None:
store = InMemoryStore()
store.mset([("key1", "value1"), ("key2", "value2"), ("key3", "value3")])
Expand All @@ -46,3 +79,17 @@ def test_yield_keys() -> None:

keys_with_invalid_prefix = list(store.yield_keys(prefix="x"))
assert keys_with_invalid_prefix == []


async def test_ayield_keys() -> None:
store = InMemoryStore()
await store.amset([("key1", "value1"), ("key2", "value2"), ("key3", "value3")])

keys = [key async for key in store.ayield_keys()]
assert set(keys) == {"key1", "key2", "key3"}

keys_with_prefix = [key async for key in store.ayield_keys(prefix="key")]
assert set(keys_with_prefix) == {"key1", "key2", "key3"}

keys_with_invalid_prefix = [key async for key in store.ayield_keys(prefix="x")]
assert keys_with_invalid_prefix == []

0 comments on commit a0ec045

Please sign in to comment.