prefetch: use a separate temporary cache for prefetching (#730)
skshetry authored Jan 7, 2025
1 parent ad44884 commit da7d38f
Showing 17 changed files with 600 additions and 191 deletions.
9 changes: 3 additions & 6 deletions examples/get_started/torch-loader.py
@@ -56,7 +56,7 @@ def forward(self, x):
if __name__ == "__main__":
ds = (
DataChain.from_storage(STORAGE, type="image")
- .settings(cache=True, prefetch=25)
+ .settings(prefetch=25)
.filter(C("file.path").glob("*.jpg"))
.map(
label=lambda path: label_to_int(basename(path)[:3], CLASSES),
@@ -68,7 +68,7 @@ def forward(self, x):
train_loader = DataLoader(
ds.to_pytorch(transform=transform),
batch_size=25,
- num_workers=max(4, os.cpu_count() or 2),
+ num_workers=min(4, os.cpu_count() or 2),
persistent_workers=True,
multiprocessing_context=multiprocessing.get_context("spawn"),
)
@@ -80,10 +80,7 @@ def forward(self, x):
# Train the model
for epoch in range(NUM_EPOCHS):
with tqdm(
- train_loader,
- desc=f"epoch {epoch + 1}/{NUM_EPOCHS}",
- unit="batch",
- leave=False,
+ train_loader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}", unit="batch"
) as loader:
for data in loader:
inputs, labels = data
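Note on the example change: the point of this commit is that `prefetch` no longer needs `cache=True` to work; without a persistent cache, prefetched objects are meant to land in a temporary cache that is discarded afterwards. A minimal sketch of the two modes (`STORAGE` as defined at the top of the example; the behavioral comments reflect the commit's intent, not a documented API guarantee):

```python
from datachain import DataChain

# Persistent cache: prefetched objects are kept for later runs.
ds_cached = DataChain.from_storage(STORAGE, type="image").settings(cache=True, prefetch=25)

# No persistent cache (as in the updated example): prefetched objects
# should go to a temporary cache that is cleaned up after iteration.
ds_prefetch_only = DataChain.from_storage(STORAGE, type="image").settings(prefetch=25)
```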
22 changes: 16 additions & 6 deletions src/datachain/asyn.py
@@ -8,12 +8,14 @@
Iterable,
Iterator,
)
- from concurrent.futures import ThreadPoolExecutor
+ from concurrent.futures import ThreadPoolExecutor, wait
from heapq import heappop, heappush
from typing import Any, Callable, Generic, Optional, TypeVar

from fsspec.asyn import get_loop

from datachain.utils import safe_closing

ASYNC_WORKERS = 20

InputT = TypeVar("InputT", contravariant=True) # noqa: PLC0105
@@ -56,6 +58,7 @@ def __init__(
self.pool = ThreadPoolExecutor(workers)
self._tasks: set[asyncio.Task] = set()
self._shutdown_producer = threading.Event()
self._producer_is_shutdown = threading.Event()

def start_task(self, coro: Coroutine) -> asyncio.Task:
task = self.loop.create_task(coro)
@@ -64,11 +67,16 @@ def start_task(self, coro: Coroutine) -> asyncio.Task:
return task

def _produce(self) -> None:
- for item in self.iterable:
- if self._shutdown_producer.is_set():
- return
- fut = asyncio.run_coroutine_threadsafe(self.work_queue.put(item), self.loop)
- fut.result()  # wait until the item is in the queue
+ try:
+ with safe_closing(self.iterable):
+ for item in self.iterable:
+ if self._shutdown_producer.is_set():
+ return
+ coro = self.work_queue.put(item)
+ fut = asyncio.run_coroutine_threadsafe(coro, self.loop)
+ fut.result()  # wait until the item is in the queue
+ finally:
+ self._producer_is_shutdown.set()

async def produce(self) -> None:
await self.to_thread(self._produce)
Expand Down Expand Up @@ -179,6 +187,8 @@ def iterate(self, timeout=None) -> Generator[ResultT, None, None]:
self.shutdown_producer()
if not async_run.done():
async_run.cancel()
wait([async_run])
self._producer_is_shutdown.wait()

def __iter__(self):
return self.iterate()
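`safe_closing` is imported from `datachain.utils` but not defined in this diff. Judging from how `_produce` uses it, it is presumably a small context manager along these lines (a sketch under that assumption, not the actual implementation):

```python
from contextlib import contextmanager

@contextmanager
def safe_closing(thing):
    # Yield the object; on exit, close it if it has a close() method.
    try:
        yield thing
    finally:
        if hasattr(thing, "close"):
            thing.close()
```

This lets the producer close the wrapped iterable even when iteration stops early (e.g. after `shutdown_producer`), while the new `_producer_is_shutdown` event tells `iterate()` when the producer has actually finished cleaning up.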
40 changes: 31 additions & 9 deletions src/datachain/cache.py
@@ -1,8 +1,12 @@
import os
from collections.abc import Iterator
from contextlib import contextmanager
from tempfile import mkdtemp
from typing import TYPE_CHECKING, Optional

from dvc_data.hashfile.db.local import LocalHashFileDB
from dvc_objects.fs.local import LocalFileSystem
from dvc_objects.fs.utils import remove
from fsspec.callbacks import Callback, TqdmCallback

from .progress import Tqdm
@@ -20,6 +24,23 @@ def try_scandir(path):
pass


def get_temp_cache(tmp_dir: str, prefix: Optional[str] = None) -> "DataChainCache":
cache_dir = mkdtemp(prefix=prefix, dir=tmp_dir)
return DataChainCache(cache_dir, tmp_dir=tmp_dir)


@contextmanager
def temporary_cache(
tmp_dir: str, prefix: Optional[str] = None, delete: bool = True
) -> Iterator["DataChainCache"]:
cache = get_temp_cache(tmp_dir, prefix=prefix)
try:
yield cache
finally:
if delete:
cache.destroy()


class DataChainCache:
def __init__(self, cache_dir: str, tmp_dir: str):
self.odb = LocalHashFileDB(
@@ -28,6 +49,9 @@ def __init__(self, cache_dir: str, tmp_dir: str):
tmp_dir=tmp_dir,
)

def __eq__(self, other) -> bool:
return self.odb == other.odb

@property
def cache_dir(self):
return self.odb.path
@@ -82,20 +106,18 @@ async def download(
os.unlink(tmp_info)

def store_data(self, file: "File", contents: bytes) -> None:
- checksum = file.get_hash()
- dst = self.path_from_checksum(checksum)
- if not os.path.exists(dst):
- # Create the file only if it's not already in cache
- os.makedirs(os.path.dirname(dst), exist_ok=True)
- with open(dst, mode="wb") as f:
- f.write(contents)
- def clear(self):
+ self.odb.add_bytes(file.get_hash(), contents)
+ def clear(self) -> None:
"""
Completely clear the cache.
"""
self.odb.clear()

def destroy(self) -> None:
# `clear` leaves the prefix directory structure intact.
remove(self.cache_dir)

def get_total_size(self) -> int:
total = 0
for subdir in try_scandir(self.odb.path):
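A short usage sketch for the new cache helpers (the temporary directory is a placeholder; only methods visible in this diff are used):

```python
import os
from datachain.cache import get_temp_cache, temporary_cache

tmp_dir = "/tmp/datachain-tmp"  # placeholder; normally the catalog cache's tmp_dir
os.makedirs(tmp_dir, exist_ok=True)

# Context-managed variant: the cache directory is removed on exit (delete=True).
with temporary_cache(tmp_dir, prefix="prefetch-") as cache:
    print(cache.cache_dir)         # e.g. /tmp/datachain-tmp/prefetch-abc123
    print(cache.get_total_size())  # 0 for a fresh cache

# Manual variant: the caller is responsible for calling destroy().
cache = get_temp_cache(tmp_dir, prefix="worker-")
cache.destroy()
```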
6 changes: 6 additions & 0 deletions src/datachain/catalog/catalog.py
@@ -536,6 +536,12 @@ def find_column_to_str( # noqa: PLR0911
return ""


def clone_catalog_with_cache(catalog: "Catalog", cache: "DataChainCache") -> "Catalog":
clone = catalog.copy()
clone.cache = cache
return clone


class Catalog:
def __init__(
self,
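`clone_catalog_with_cache` is trivial on its own, but it is what lets prefetching swap in a throwaway cache without mutating the shared catalog. A sketch of the presumed pattern (the function name and loop body are illustrative, not part of this commit):

```python
from datachain.cache import temporary_cache
from datachain.catalog.catalog import clone_catalog_with_cache

def prefetch_with_temp_cache(catalog, files):
    # Prefetch into a temporary cache so the user's persistent cache is
    # never populated as a side effect; the clone shares everything else.
    assert catalog.cache.tmp_dir
    with temporary_cache(catalog.cache.tmp_dir, prefix="prefetch-") as tmp_cache:
        cloned = clone_catalog_with_cache(catalog, tmp_cache)
        for file in files:
            file._set_stream(cloned, caching_enabled=True)
            file.ensure_cached()
    # tmp_cache, and everything downloaded into it, is destroyed here
```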
2 changes: 2 additions & 0 deletions src/datachain/lib/dc.py
@@ -451,13 +451,15 @@ def from_storage(
return dc

if update or not list_ds_exists:
# disable prefetch for listing, as it pre-downloads all files
(
cls.from_records(
DataChain.DEFAULT_FILE_RECORD,
session=session,
settings=settings,
in_memory=in_memory,
)
.settings(prefetch=0)
.gen(
list_bucket(list_uri, cache, client_config=client_config),
output={f"{object_name}": File},
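For context, `prefetch` is a per-chain setting and `0` disables it, which is what the listing step above relies on. A minimal sketch (the bucket URI is a placeholder):

```python
from datachain import DataChain

# Listing-style usage where pre-downloading every object would be wasteful.
dc = DataChain.from_storage("s3://my-bucket/data/").settings(prefetch=0)
```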
19 changes: 15 additions & 4 deletions src/datachain/lib/file.py
@@ -269,10 +269,21 @@ def ensure_cached(self) -> None:
client = self._catalog.get_client(self.source)
client.download(self, callback=self._download_cb)

- async def _prefetch(self) -> None:
- if self._caching_enabled:
- client = self._catalog.get_client(self.source)
- await client._download(self, callback=self._download_cb)
+ async def _prefetch(self, download_cb: Optional["Callback"] = None) -> bool:
+ from datachain.client.hf import HfClient
+ if self._catalog is None:
+ raise RuntimeError("cannot prefetch file because catalog is not setup")
+ client = self._catalog.get_client(self.source)
+ if client.protocol == HfClient.protocol:
+ return False
+ await client._download(self, callback=download_cb or self._download_cb)
+ self._set_stream(
+ self._catalog, caching_enabled=True, download_cb=DEFAULT_CALLBACK
+ )
+ return True

def get_local_path(self) -> Optional[str]:
"""Return path to a file in a local cache.
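The reworked `_prefetch` now reports whether it did anything: it returns `False` for Hugging Face sources (which are skipped) and `True` once the object has been downloaded and the file re-bound to read from the cache. A minimal calling sketch (assumes `file` is a `File` already bound to a catalog):

```python
import asyncio

async def prefetch_one(file) -> bool:
    # True  -> downloaded into the cache; the file now streams from it
    # False -> skipped, e.g. for hf:// sources
    return await file._prefetch()

# prefetched = asyncio.run(prefetch_one(file))
```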
70 changes: 57 additions & 13 deletions src/datachain/lib/pytorch.py
@@ -1,5 +1,8 @@
import logging
- from collections.abc import Iterator
+ import os
+ import weakref
+ from collections.abc import Generator, Iterable, Iterator
+ from contextlib import closing
from typing import TYPE_CHECKING, Any, Callable, Optional

from PIL import Image
@@ -9,15 +12,19 @@
from torchvision.transforms import v2

from datachain import Session
- from datachain.asyn import AsyncMapper
+ from datachain.cache import get_temp_cache
from datachain.catalog import Catalog, get_catalog
from datachain.lib.dc import DataChain
from datachain.lib.settings import Settings
from datachain.lib.text import convert_text
from datachain.progress import CombinedDownloadCallback
from datachain.query.dataset import get_download_callback

if TYPE_CHECKING:
from torchvision.transforms.v2 import Transform

from datachain.cache import DataChainCache as Cache


logger = logging.getLogger("datachain")

@@ -75,6 +82,19 @@ def __init__(
if (prefetch := dc_settings.prefetch) is not None:
self.prefetch = prefetch

self._cache = catalog.cache
self._prefetch_cache: Optional[Cache] = None
if prefetch and not self.cache:
tmp_dir = catalog.cache.tmp_dir
assert tmp_dir
self._prefetch_cache = get_temp_cache(tmp_dir, prefix="prefetch-")
self._cache = self._prefetch_cache
weakref.finalize(self, self._prefetch_cache.destroy)

def close(self) -> None:
if self._prefetch_cache:
self._prefetch_cache.destroy()

def _init_catalog(self, catalog: "Catalog"):
# For compatibility with multiprocessing,
# we can only store params in __init__(), as Catalog isn't picklable
@@ -89,9 +109,15 @@ def _get_catalog(self) -> "Catalog":
ms = ms_cls(*ms_args, **ms_kwargs)
wh_cls, wh_args, wh_kwargs = self._wh_params
wh = wh_cls(*wh_args, **wh_kwargs)
- return Catalog(ms, wh, **self._catalog_params)
+ catalog = Catalog(ms, wh, **self._catalog_params)
+ catalog.cache = self._cache
+ return catalog

- def _rows_iter(self, total_rank: int, total_workers: int):
+ def _row_iter(
+ self,
+ total_rank: int,
+ total_workers: int,
+ ) -> Generator[tuple[Any, ...], None, None]:
catalog = self._get_catalog()
session = Session("PyTorch", catalog=catalog)
ds = DataChain.from_dataset(
@@ -104,16 +130,34 @@ def _rows_iter(self, total_rank: int, total_workers: int):
ds = ds.chunk(total_rank, total_workers)
yield from ds.collect()

- def __iter__(self) -> Iterator[Any]:
- total_rank, total_workers = self.get_rank_and_workers()
- rows = self._rows_iter(total_rank, total_workers)
- if self.prefetch > 0:
- from datachain.lib.udf import _prefetch_input
- rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()
- yield from map(self._process_row, rows)
+ def _iter_with_prefetch(self) -> Generator[tuple[Any], None, None]:
+ from datachain.lib.udf import _prefetch_inputs

- def _process_row(self, row_features):
+ total_rank, total_workers = self.get_rank_and_workers()
+ download_cb = CombinedDownloadCallback()
+ if os.getenv("DATACHAIN_SHOW_PREFETCH_PROGRESS"):
+ download_cb = get_download_callback(
+ f"{total_rank}/{total_workers}",
+ position=total_rank,
+ leave=True,
+ )
+ rows = self._row_iter(total_rank, total_workers)
+ rows = _prefetch_inputs(
+ rows,
+ self.prefetch,
+ download_cb=download_cb,
+ after_prefetch=download_cb.increment_file_count,
+ )
+ with download_cb, closing(rows):
+ yield from rows
+ def __iter__(self) -> Iterator[list[Any]]:
+ with closing(self._iter_with_prefetch()) as rows:
+ yield from map(self._process_row, rows)
+ def _process_row(self, row_features: Iterable[Any]) -> list[Any]:
row = []
for fr in row_features:
if hasattr(fr, "read"):
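Taken together, a dataset produced by `to_pytorch()` now prefetches into a temporary cache of its own when `cache` is off, and can optionally report prefetch progress per rank. A usage sketch (`ds` and `transform` as in the torch-loader example above; the environment variable is the one checked in `_iter_with_prefetch`):

```python
import os
from torch.utils.data import DataLoader

os.environ["DATACHAIN_SHOW_PREFETCH_PROGRESS"] = "1"  # opt-in progress bars

loader = DataLoader(
    ds.to_pytorch(transform=transform),
    batch_size=25,
    num_workers=min(4, os.cpu_count() or 2),
)
for inputs, labels in loader:
    pass  # training step goes here
```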