From da7d38f49b01c7b77280e3f76d70ce9fae4679d4 Mon Sep 17 00:00:00 2001 From: skshetry <18718008+skshetry@users.noreply.github.com> Date: Tue, 7 Jan 2025 21:31:12 +0545 Subject: [PATCH] prefetch: use a separate temporary cache for prefetching (#730) --- examples/get_started/torch-loader.py | 9 +- src/datachain/asyn.py | 22 ++- src/datachain/cache.py | 40 ++++-- src/datachain/catalog/catalog.py | 6 + src/datachain/lib/dc.py | 2 + src/datachain/lib/file.py | 19 ++- src/datachain/lib/pytorch.py | 70 ++++++++-- src/datachain/lib/udf.py | 122 ++++++++++------ src/datachain/progress.py | 19 ++- src/datachain/query/dataset.py | 200 +++++++++++++++------------ src/datachain/query/dispatch.py | 38 ++--- src/datachain/utils.py | 15 +- tests/func/test_datachain.py | 69 ++++++++- tests/func/test_pytorch.py | 41 ++++++ tests/unit/test_asyn.py | 33 +++++ tests/unit/test_cache.py | 28 +++- tests/unit/test_pytorch.py | 58 ++++++++ 17 files changed, 600 insertions(+), 191 deletions(-) create mode 100644 tests/unit/test_pytorch.py diff --git a/examples/get_started/torch-loader.py b/examples/get_started/torch-loader.py index 4e9812a0c..116cc9afb 100644 --- a/examples/get_started/torch-loader.py +++ b/examples/get_started/torch-loader.py @@ -56,7 +56,7 @@ def forward(self, x): if __name__ == "__main__": ds = ( DataChain.from_storage(STORAGE, type="image") - .settings(cache=True, prefetch=25) + .settings(prefetch=25) .filter(C("file.path").glob("*.jpg")) .map( label=lambda path: label_to_int(basename(path)[:3], CLASSES), @@ -68,7 +68,7 @@ def forward(self, x): train_loader = DataLoader( ds.to_pytorch(transform=transform), batch_size=25, - num_workers=max(4, os.cpu_count() or 2), + num_workers=min(4, os.cpu_count() or 2), persistent_workers=True, multiprocessing_context=multiprocessing.get_context("spawn"), ) @@ -80,10 +80,7 @@ def forward(self, x): # Train the model for epoch in range(NUM_EPOCHS): with tqdm( - train_loader, - desc=f"epoch {epoch + 1}/{NUM_EPOCHS}", - unit="batch", - leave=False, + train_loader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}", unit="batch" ) as loader: for data in loader: inputs, labels = data diff --git a/src/datachain/asyn.py b/src/datachain/asyn.py index 1b87afc41..4b6a6cfbc 100644 --- a/src/datachain/asyn.py +++ b/src/datachain/asyn.py @@ -8,12 +8,14 @@ Iterable, Iterator, ) -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, wait from heapq import heappop, heappush from typing import Any, Callable, Generic, Optional, TypeVar from fsspec.asyn import get_loop +from datachain.utils import safe_closing + ASYNC_WORKERS = 20 InputT = TypeVar("InputT", contravariant=True) # noqa: PLC0105 @@ -56,6 +58,7 @@ def __init__( self.pool = ThreadPoolExecutor(workers) self._tasks: set[asyncio.Task] = set() self._shutdown_producer = threading.Event() + self._producer_is_shutdown = threading.Event() def start_task(self, coro: Coroutine) -> asyncio.Task: task = self.loop.create_task(coro) @@ -64,11 +67,16 @@ def start_task(self, coro: Coroutine) -> asyncio.Task: return task def _produce(self) -> None: - for item in self.iterable: - if self._shutdown_producer.is_set(): - return - fut = asyncio.run_coroutine_threadsafe(self.work_queue.put(item), self.loop) - fut.result() # wait until the item is in the queue + try: + with safe_closing(self.iterable): + for item in self.iterable: + if self._shutdown_producer.is_set(): + return + coro = self.work_queue.put(item) + fut = asyncio.run_coroutine_threadsafe(coro, self.loop) + fut.result() # wait until the item is in the queue + finally: + self._producer_is_shutdown.set() async def produce(self) -> None: await self.to_thread(self._produce) @@ -179,6 +187,8 @@ def iterate(self, timeout=None) -> Generator[ResultT, None, None]: self.shutdown_producer() if not async_run.done(): async_run.cancel() + wait([async_run]) + self._producer_is_shutdown.wait() def __iter__(self): return self.iterate() diff --git a/src/datachain/cache.py b/src/datachain/cache.py index 3edaa9a0a..ae8457d98 100644 --- a/src/datachain/cache.py +++ b/src/datachain/cache.py @@ -1,8 +1,12 @@ import os +from collections.abc import Iterator +from contextlib import contextmanager +from tempfile import mkdtemp from typing import TYPE_CHECKING, Optional from dvc_data.hashfile.db.local import LocalHashFileDB from dvc_objects.fs.local import LocalFileSystem +from dvc_objects.fs.utils import remove from fsspec.callbacks import Callback, TqdmCallback from .progress import Tqdm @@ -20,6 +24,23 @@ def try_scandir(path): pass +def get_temp_cache(tmp_dir: str, prefix: Optional[str] = None) -> "DataChainCache": + cache_dir = mkdtemp(prefix=prefix, dir=tmp_dir) + return DataChainCache(cache_dir, tmp_dir=tmp_dir) + + +@contextmanager +def temporary_cache( + tmp_dir: str, prefix: Optional[str] = None, delete: bool = True +) -> Iterator["DataChainCache"]: + cache = get_temp_cache(tmp_dir, prefix=prefix) + try: + yield cache + finally: + if delete: + cache.destroy() + + class DataChainCache: def __init__(self, cache_dir: str, tmp_dir: str): self.odb = LocalHashFileDB( @@ -28,6 +49,9 @@ def __init__(self, cache_dir: str, tmp_dir: str): tmp_dir=tmp_dir, ) + def __eq__(self, other) -> bool: + return self.odb == other.odb + @property def cache_dir(self): return self.odb.path @@ -82,20 +106,18 @@ async def download( os.unlink(tmp_info) def store_data(self, file: "File", contents: bytes) -> None: - checksum = file.get_hash() - dst = self.path_from_checksum(checksum) - if not os.path.exists(dst): - # Create the file only if it's not already in cache - os.makedirs(os.path.dirname(dst), exist_ok=True) - with open(dst, mode="wb") as f: - f.write(contents) - - def clear(self): + self.odb.add_bytes(file.get_hash(), contents) + + def clear(self) -> None: """ Completely clear the cache. """ self.odb.clear() + def destroy(self) -> None: + # `clear` leaves the prefix directory structure intact. + remove(self.cache_dir) + def get_total_size(self) -> int: total = 0 for subdir in try_scandir(self.odb.path): diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index 4e0fe2f0f..5971ea16d 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -536,6 +536,12 @@ def find_column_to_str( # noqa: PLR0911 return "" +def clone_catalog_with_cache(catalog: "Catalog", cache: "DataChainCache") -> "Catalog": + clone = catalog.copy() + clone.cache = cache + return clone + + class Catalog: def __init__( self, diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index a9542d9d8..78b7f6971 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -451,6 +451,7 @@ def from_storage( return dc if update or not list_ds_exists: + # disable prefetch for listing, as it pre-downloads all files ( cls.from_records( DataChain.DEFAULT_FILE_RECORD, @@ -458,6 +459,7 @@ def from_storage( settings=settings, in_memory=in_memory, ) + .settings(prefetch=0) .gen( list_bucket(list_uri, cache, client_config=client_config), output={f"{object_name}": File}, diff --git a/src/datachain/lib/file.py b/src/datachain/lib/file.py index 985ab318a..1aaa3fc91 100644 --- a/src/datachain/lib/file.py +++ b/src/datachain/lib/file.py @@ -269,10 +269,21 @@ def ensure_cached(self) -> None: client = self._catalog.get_client(self.source) client.download(self, callback=self._download_cb) - async def _prefetch(self) -> None: - if self._caching_enabled: - client = self._catalog.get_client(self.source) - await client._download(self, callback=self._download_cb) + async def _prefetch(self, download_cb: Optional["Callback"] = None) -> bool: + from datachain.client.hf import HfClient + + if self._catalog is None: + raise RuntimeError("cannot prefetch file because catalog is not setup") + + client = self._catalog.get_client(self.source) + if client.protocol == HfClient.protocol: + return False + + await client._download(self, callback=download_cb or self._download_cb) + self._set_stream( + self._catalog, caching_enabled=True, download_cb=DEFAULT_CALLBACK + ) + return True def get_local_path(self) -> Optional[str]: """Return path to a file in a local cache. diff --git a/src/datachain/lib/pytorch.py b/src/datachain/lib/pytorch.py index e85fb0aae..985af387c 100644 --- a/src/datachain/lib/pytorch.py +++ b/src/datachain/lib/pytorch.py @@ -1,5 +1,8 @@ import logging -from collections.abc import Iterator +import os +import weakref +from collections.abc import Generator, Iterable, Iterator +from contextlib import closing from typing import TYPE_CHECKING, Any, Callable, Optional from PIL import Image @@ -9,15 +12,19 @@ from torchvision.transforms import v2 from datachain import Session -from datachain.asyn import AsyncMapper +from datachain.cache import get_temp_cache from datachain.catalog import Catalog, get_catalog from datachain.lib.dc import DataChain from datachain.lib.settings import Settings from datachain.lib.text import convert_text +from datachain.progress import CombinedDownloadCallback +from datachain.query.dataset import get_download_callback if TYPE_CHECKING: from torchvision.transforms.v2 import Transform + from datachain.cache import DataChainCache as Cache + logger = logging.getLogger("datachain") @@ -75,6 +82,19 @@ def __init__( if (prefetch := dc_settings.prefetch) is not None: self.prefetch = prefetch + self._cache = catalog.cache + self._prefetch_cache: Optional[Cache] = None + if prefetch and not self.cache: + tmp_dir = catalog.cache.tmp_dir + assert tmp_dir + self._prefetch_cache = get_temp_cache(tmp_dir, prefix="prefetch-") + self._cache = self._prefetch_cache + weakref.finalize(self, self._prefetch_cache.destroy) + + def close(self) -> None: + if self._prefetch_cache: + self._prefetch_cache.destroy() + def _init_catalog(self, catalog: "Catalog"): # For compatibility with multiprocessing, # we can only store params in __init__(), as Catalog isn't picklable @@ -89,9 +109,15 @@ def _get_catalog(self) -> "Catalog": ms = ms_cls(*ms_args, **ms_kwargs) wh_cls, wh_args, wh_kwargs = self._wh_params wh = wh_cls(*wh_args, **wh_kwargs) - return Catalog(ms, wh, **self._catalog_params) + catalog = Catalog(ms, wh, **self._catalog_params) + catalog.cache = self._cache + return catalog - def _rows_iter(self, total_rank: int, total_workers: int): + def _row_iter( + self, + total_rank: int, + total_workers: int, + ) -> Generator[tuple[Any, ...], None, None]: catalog = self._get_catalog() session = Session("PyTorch", catalog=catalog) ds = DataChain.from_dataset( @@ -104,16 +130,34 @@ def _rows_iter(self, total_rank: int, total_workers: int): ds = ds.chunk(total_rank, total_workers) yield from ds.collect() - def __iter__(self) -> Iterator[Any]: - total_rank, total_workers = self.get_rank_and_workers() - rows = self._rows_iter(total_rank, total_workers) - if self.prefetch > 0: - from datachain.lib.udf import _prefetch_input - - rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate() - yield from map(self._process_row, rows) + def _iter_with_prefetch(self) -> Generator[tuple[Any], None, None]: + from datachain.lib.udf import _prefetch_inputs - def _process_row(self, row_features): + total_rank, total_workers = self.get_rank_and_workers() + download_cb = CombinedDownloadCallback() + if os.getenv("DATACHAIN_SHOW_PREFETCH_PROGRESS"): + download_cb = get_download_callback( + f"{total_rank}/{total_workers}", + position=total_rank, + leave=True, + ) + + rows = self._row_iter(total_rank, total_workers) + rows = _prefetch_inputs( + rows, + self.prefetch, + download_cb=download_cb, + after_prefetch=download_cb.increment_file_count, + ) + + with download_cb, closing(rows): + yield from rows + + def __iter__(self) -> Iterator[list[Any]]: + with closing(self._iter_with_prefetch()) as rows: + yield from map(self._process_row, rows) + + def _process_row(self, row_features: Iterable[Any]) -> list[Any]: row = [] for fr in row_features: if hasattr(fr, "read"): diff --git a/src/datachain/lib/udf.py b/src/datachain/lib/udf.py index c59442d6b..4a69ff6fd 100644 --- a/src/datachain/lib/udf.py +++ b/src/datachain/lib/udf.py @@ -1,14 +1,16 @@ -import contextlib import sys import traceback -from collections.abc import Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Callable, Optional +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +from contextlib import closing, nullcontext +from functools import partial +from typing import TYPE_CHECKING, Any, Optional, TypeVar import attrs from fsspec.callbacks import DEFAULT_CALLBACK, Callback from pydantic import BaseModel from datachain.asyn import AsyncMapper +from datachain.cache import temporary_cache from datachain.dataset import RowDict from datachain.lib.convert.flatten import flatten from datachain.lib.data_model import DataValue @@ -21,17 +23,22 @@ Partition, RowsOutputBatch, ) +from datachain.utils import safe_closing if TYPE_CHECKING: from collections import abc + from contextlib import AbstractContextManager from typing_extensions import Self + from datachain.cache import DataChainCache as Cache from datachain.catalog import Catalog from datachain.lib.signal_schema import SignalSchema from datachain.lib.udf_signature import UdfSignature from datachain.query.batch import RowsOutput +T = TypeVar("T", bound=Sequence[Any]) + class UdfError(DataChainParamsError): def __init__(self, msg): @@ -98,6 +105,10 @@ def run( processed_cb, ) + @property + def prefetch(self) -> int: + return self.inner.prefetch + class UDFBase(AbstractUDF): """Base class for stateful user-defined functions. @@ -148,12 +159,11 @@ def process(self, file) -> list[float]: """ is_output_batched = False - catalog: "Optional[Catalog]" + prefetch: int = 0 def __init__(self): self.params: Optional[SignalSchema] = None self.output = None - self.catalog = None self._func = None def process(self, *args, **kwargs): @@ -242,26 +252,23 @@ def _obj_to_list(obj): return flatten(obj) if isinstance(obj, BaseModel) else [obj] def _parse_row( - self, row_dict: RowDict, cache: bool, download_cb: Callback + self, row_dict: RowDict, catalog: "Catalog", cache: bool, download_cb: Callback ) -> list[DataValue]: assert self.params row = [row_dict[p] for p in self.params.to_udf_spec()] obj_row = self.params.row_to_objs(row) for obj in obj_row: if isinstance(obj, File): - assert self.catalog is not None - obj._set_stream( - self.catalog, caching_enabled=cache, download_cb=download_cb - ) + obj._set_stream(catalog, caching_enabled=cache, download_cb=download_cb) return obj_row - def _prepare_row(self, row, udf_fields, cache, download_cb): + def _prepare_row(self, row, udf_fields, catalog, cache, download_cb): row_dict = RowDict(zip(udf_fields, row)) - return self._parse_row(row_dict, cache, download_cb) + return self._parse_row(row_dict, catalog, cache, download_cb) - def _prepare_row_and_id(self, row, udf_fields, cache, download_cb): + def _prepare_row_and_id(self, row, udf_fields, catalog, cache, download_cb): row_dict = RowDict(zip(udf_fields, row)) - udf_input = self._parse_row(row_dict, cache, download_cb) + udf_input = self._parse_row(row_dict, catalog, cache, download_cb) return row_dict["sys__id"], *udf_input def process_safe(self, obj_rows): @@ -279,13 +286,47 @@ def process_safe(self, obj_rows): return result_objs -async def _prefetch_input(row): +def noop(*args, **kwargs): + pass + + +async def _prefetch_input( + row: T, + download_cb: Optional["Callback"] = None, + after_prefetch: "Callable[[], None]" = noop, +) -> T: for obj in row: - if isinstance(obj, File): - await obj._prefetch() + if isinstance(obj, File) and await obj._prefetch(download_cb): + after_prefetch() return row +def _prefetch_inputs( + prepared_inputs: "Iterable[T]", + prefetch: int = 0, + download_cb: Optional["Callback"] = None, + after_prefetch: "Callable[[], None]" = noop, +) -> "abc.Generator[T, None, None]": + if prefetch > 0: + f = partial( + _prefetch_input, + download_cb=download_cb, + after_prefetch=after_prefetch, + ) + prepared_inputs = AsyncMapper(f, prepared_inputs, workers=prefetch).iterate() # type: ignore[assignment] + yield from prepared_inputs + + +def _get_cache( + cache: "Cache", prefetch: int = 0, use_cache: bool = False +) -> "AbstractContextManager[Cache]": + tmp_dir = cache.tmp_dir + assert tmp_dir + if prefetch and not use_cache: + return temporary_cache(tmp_dir, prefix="prefetch-") + return nullcontext(cache) + + class Mapper(UDFBase): """Inherit from this class to pass to `DataChain.map()`.""" @@ -300,18 +341,18 @@ def run( download_cb: Callback = DEFAULT_CALLBACK, processed_cb: Callback = DEFAULT_CALLBACK, ) -> Iterator[Iterable[UDFResult]]: - self.catalog = catalog self.setup() - prepared_inputs: abc.Generator[Sequence[Any], None, None] = ( - self._prepare_row_and_id(row, udf_fields, cache, download_cb) - for row in udf_inputs - ) - if self.prefetch > 0: - prepared_inputs = AsyncMapper( - _prefetch_input, prepared_inputs, workers=self.prefetch - ).iterate() - with contextlib.closing(prepared_inputs): + def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]": + with safe_closing(udf_inputs): + for row in udf_inputs: + yield self._prepare_row_and_id( + row, udf_fields, catalog, cache, download_cb + ) + + prepared_inputs = _prepare_rows(udf_inputs) + prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch) + with closing(prepared_inputs): for id_, *udf_args in prepared_inputs: result_objs = self.process_safe(udf_args) udf_output = self._flatten_row(result_objs) @@ -336,14 +377,15 @@ def run( download_cb: Callback = DEFAULT_CALLBACK, processed_cb: Callback = DEFAULT_CALLBACK, ) -> Iterator[Iterable[UDFResult]]: - self.catalog = catalog self.setup() for batch in udf_inputs: n_rows = len(batch.rows) row_ids, *udf_args = zip( *[ - self._prepare_row_and_id(row, udf_fields, cache, download_cb) + self._prepare_row_and_id( + row, udf_fields, catalog, cache, download_cb + ) for row in batch.rows ] ) @@ -378,17 +420,18 @@ def run( download_cb: Callback = DEFAULT_CALLBACK, processed_cb: Callback = DEFAULT_CALLBACK, ) -> Iterator[Iterable[UDFResult]]: - self.catalog = catalog self.setup() - prepared_inputs: abc.Generator[Sequence[Any], None, None] = ( - self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs - ) - if self.prefetch > 0: - prepared_inputs = AsyncMapper( - _prefetch_input, prepared_inputs, workers=self.prefetch - ).iterate() - with contextlib.closing(prepared_inputs): + def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]": + with safe_closing(udf_inputs): + for row in udf_inputs: + yield self._prepare_row( + row, udf_fields, catalog, cache, download_cb + ) + + prepared_inputs = _prepare_rows(udf_inputs) + prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch) + with closing(prepared_inputs): for row in prepared_inputs: result_objs = self.process_safe(row) udf_outputs = (self._flatten_row(row) for row in result_objs) @@ -413,13 +456,12 @@ def run( download_cb: Callback = DEFAULT_CALLBACK, processed_cb: Callback = DEFAULT_CALLBACK, ) -> Iterator[Iterable[UDFResult]]: - self.catalog = catalog self.setup() for batch in udf_inputs: udf_args = zip( *[ - self._prepare_row(row, udf_fields, cache, download_cb) + self._prepare_row(row, udf_fields, catalog, cache, download_cb) for row in batch.rows ] ) diff --git a/src/datachain/progress.py b/src/datachain/progress.py index bc8b8a0ac..c7e5e80cc 100644 --- a/src/datachain/progress.py +++ b/src/datachain/progress.py @@ -5,6 +5,7 @@ from threading import RLock from typing import Any, ClassVar +from fsspec import Callback from fsspec.callbacks import TqdmCallback from tqdm import tqdm @@ -132,8 +133,24 @@ def format_dict(self): return d -class CombinedDownloadCallback(TqdmCallback): +class CombinedDownloadCallback(Callback): def set_size(self, size): # This is a no-op to prevent fsspec's .get_file() from setting the combined # download size to the size of the current file. pass + + def increment_file_count(self, n: int = 1) -> None: + pass + + +class TqdmCombinedDownloadCallback(CombinedDownloadCallback, TqdmCallback): + def __init__(self, tqdm_kwargs=None, *args, **kwargs): + self.files_count = 0 + tqdm_kwargs = tqdm_kwargs or {} + tqdm_kwargs.setdefault("postfix", {}).setdefault("files", self.files_count) + super().__init__(tqdm_kwargs, *args, **kwargs) + + def increment_file_count(self, n: int = 1) -> None: + self.files_count += n + if self.tqdm is not None: + self.tqdm.postfix = f"{self.files_count} files" diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index 6245e20a2..875bbcf25 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -35,6 +35,7 @@ from sqlalchemy.sql.selectable import Select from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper +from datachain.catalog.catalog import clone_catalog_with_cache from datachain.data_storage.schema import ( PARTITION_COLUMN_ID, partition_col_names, @@ -43,7 +44,8 @@ from datachain.dataset import DatasetStatus, RowDict from datachain.error import DatasetNotFoundError, QueryScriptCancelError from datachain.func.base import Function -from datachain.progress import CombinedDownloadCallback +from datachain.lib.udf import UDFAdapter, _get_cache +from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback from datachain.query.schema import C, UDFParamSpec, normalize_param from datachain.query.session import Session from datachain.sql.functions.random import rand @@ -52,6 +54,7 @@ determine_processes, filtered_cloudpickle_dumps, get_datachain_executable, + safe_closing, ) if TYPE_CHECKING: @@ -349,15 +352,16 @@ def process_udf_outputs( warehouse.insert_rows_done(udf_table) -def get_download_callback() -> Callback: - return CombinedDownloadCallback( +def get_download_callback(suffix: str = "", **kwargs) -> CombinedDownloadCallback: + return TqdmCombinedDownloadCallback( { - "desc": "Download", + "desc": "Download" + suffix, "unit": "B", "unit_scale": True, "unit_divisor": 1024, "leave": False, - } + **kwargs, + }, ) @@ -418,97 +422,109 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: udf_fields = [str(c.name) for c in query.selected_columns] - try: - if workers: - if self.catalog.in_memory: - raise RuntimeError( - "In-memory databases cannot be used with " - "distributed processing." - ) + prefetch = self.udf.prefetch + with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache: + catalog = clone_catalog_with_cache(self.catalog, _cache) + try: + if workers: + if catalog.in_memory: + raise RuntimeError( + "In-memory databases cannot be used with " + "distributed processing." + ) - from datachain.catalog.loader import get_distributed_class - - distributor = get_distributed_class(min_task_size=self.min_task_size) - distributor( - self.udf, - self.catalog, - udf_table, - query, - workers, - processes, - udf_fields=udf_fields, - is_generator=self.is_generator, - use_partitioning=use_partitioning, - cache=self.cache, - ) - elif processes: - # Parallel processing (faster for more CPU-heavy UDFs) - if self.catalog.in_memory: - raise RuntimeError( - "In-memory databases cannot be used with parallel processing." - ) - udf_info: UdfInfo = { - "udf_data": filtered_cloudpickle_dumps(self.udf), - "catalog_init": self.catalog.get_init_params(), - "metastore_clone_params": self.catalog.metastore.clone_params(), - "warehouse_clone_params": self.catalog.warehouse.clone_params(), - "table": udf_table, - "query": query, - "udf_fields": udf_fields, - "batching": batching, - "processes": processes, - "is_generator": self.is_generator, - "cache": self.cache, - } - - # Run the UDFDispatcher in another process to avoid needing - # if __name__ == '__main__': in user scripts - exec_cmd = get_datachain_executable() - cmd = [*exec_cmd, "internal-run-udf"] - envs = dict(os.environ) - envs.update({"PYTHONPATH": os.getcwd()}) - process_data = filtered_cloudpickle_dumps(udf_info) - - with subprocess.Popen(cmd, env=envs, stdin=subprocess.PIPE) as process: # noqa: S603 - process.communicate(process_data) - if retval := process.poll(): - raise RuntimeError(f"UDF Execution Failed! Exit code: {retval}") - else: - # Otherwise process single-threaded (faster for smaller UDFs) - warehouse = self.catalog.warehouse - - udf_inputs = batching(warehouse.dataset_select_paginated, query) - download_cb = get_download_callback() - processed_cb = get_processed_callback() - generated_cb = get_generated_callback(self.is_generator) - try: - udf_results = self.udf.run( - udf_fields, - udf_inputs, - self.catalog, - self.cache, - download_cb, - processed_cb, + from datachain.catalog.loader import get_distributed_class + + distributor = get_distributed_class( + min_task_size=self.min_task_size ) - process_udf_outputs( - warehouse, - udf_table, - udf_results, + distributor( self.udf, - cb=generated_cb, + catalog, + udf_table, + query, + workers, + processes, + udf_fields=udf_fields, + is_generator=self.is_generator, + use_partitioning=use_partitioning, + cache=self.cache, ) - finally: - download_cb.close() - processed_cb.close() - generated_cb.close() - - except QueryScriptCancelError: - self.catalog.warehouse.close() - sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE) - except (Exception, KeyboardInterrupt): - # Close any open database connections if an error is encountered - self.catalog.warehouse.close() - raise + elif processes: + # Parallel processing (faster for more CPU-heavy UDFs) + if catalog.in_memory: + raise RuntimeError( + "In-memory databases cannot be used " + "with parallel processing." + ) + udf_info: UdfInfo = { + "udf_data": filtered_cloudpickle_dumps(self.udf), + "catalog_init": catalog.get_init_params(), + "metastore_clone_params": catalog.metastore.clone_params(), + "warehouse_clone_params": catalog.warehouse.clone_params(), + "table": udf_table, + "query": query, + "udf_fields": udf_fields, + "batching": batching, + "processes": processes, + "is_generator": self.is_generator, + "cache": self.cache, + } + + # Run the UDFDispatcher in another process to avoid needing + # if __name__ == '__main__': in user scripts + exec_cmd = get_datachain_executable() + cmd = [*exec_cmd, "internal-run-udf"] + envs = dict(os.environ) + envs.update({"PYTHONPATH": os.getcwd()}) + process_data = filtered_cloudpickle_dumps(udf_info) + + with subprocess.Popen( # noqa: S603 + cmd, env=envs, stdin=subprocess.PIPE + ) as process: + process.communicate(process_data) + if retval := process.poll(): + raise RuntimeError( + f"UDF Execution Failed! Exit code: {retval}" + ) + else: + # Otherwise process single-threaded (faster for smaller UDFs) + warehouse = catalog.warehouse + + udf_inputs = batching(warehouse.dataset_select_paginated, query) + download_cb = get_download_callback() + processed_cb = get_processed_callback() + generated_cb = get_generated_callback(self.is_generator) + + try: + udf_results = self.udf.run( + udf_fields, + udf_inputs, + catalog, + self.cache, + download_cb, + processed_cb, + ) + with safe_closing(udf_results): + process_udf_outputs( + warehouse, + udf_table, + udf_results, + self.udf, + cb=generated_cb, + ) + finally: + download_cb.close() + processed_cb.close() + generated_cb.close() + + except QueryScriptCancelError: + self.catalog.warehouse.close() + sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE) + except (Exception, KeyboardInterrupt): + # Close any open database connections if an error is encountered + self.catalog.warehouse.close() + raise def create_partitions_table(self, query: Select) -> "Table": """ diff --git a/src/datachain/query/dispatch.py b/src/datachain/query/dispatch.py index 722f68c10..549557014 100644 --- a/src/datachain/query/dispatch.py +++ b/src/datachain/query/dispatch.py @@ -14,7 +14,9 @@ from sqlalchemy.sql import func from datachain.catalog import Catalog +from datachain.catalog.catalog import clone_catalog_with_cache from datachain.catalog.loader import get_distributed_class +from datachain.lib.udf import _get_cache from datachain.query.batch import RowsOutput, RowsOutputBatch from datachain.query.dataset import ( get_download_callback, @@ -25,7 +27,7 @@ from datachain.query.queue import get_from_queue, put_into_queue from datachain.query.udf import UdfInfo from datachain.query.utils import get_query_id_column -from datachain.utils import batched, flatten +from datachain.utils import batched, flatten, safe_closing if TYPE_CHECKING: from sqlalchemy import Select, Table @@ -304,21 +306,25 @@ def run(self) -> None: processed_cb = ProcessedCallback() generated_cb = get_generated_callback(self.is_generator) - udf_results = self.udf.run( - self.udf_fields, - self.get_inputs(), - self.catalog, - self.cache, - download_cb=self.cb, - processed_cb=processed_cb, - ) - process_udf_outputs( - self.catalog.warehouse, - self.table, - self.notify_and_process(udf_results, processed_cb), - self.udf, - cb=generated_cb, - ) + prefetch = self.udf.prefetch + with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache: + catalog = clone_catalog_with_cache(self.catalog, _cache) + udf_results = self.udf.run( + self.udf_fields, + self.get_inputs(), + catalog, + self.cache, + download_cb=self.cb, + processed_cb=processed_cb, + ) + with safe_closing(udf_results): + process_udf_outputs( + catalog.warehouse, + self.table, + self.notify_and_process(udf_results, processed_cb), + self.udf, + cb=generated_cb, + ) put_into_queue( self.done_queue, diff --git a/src/datachain/utils.py b/src/datachain/utils.py index f8ddbd1e1..f3967011d 100644 --- a/src/datachain/utils.py +++ b/src/datachain/utils.py @@ -9,6 +9,7 @@ import sys import time from collections.abc import Iterable, Iterator, Sequence +from contextlib import contextmanager from datetime import date, datetime, timezone from itertools import chain, islice from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union @@ -22,6 +23,7 @@ if TYPE_CHECKING: import pandas as pd + from typing_extensions import Self NUL = b"\0" TIME_ZERO = datetime.fromtimestamp(0, tz=timezone.utc) @@ -33,7 +35,7 @@ STUDIO_URL = "https://studio.datachain.ai" -T = TypeVar("T", bound="DataChainDir") +T = TypeVar("T") class DataChainDir: @@ -90,7 +92,7 @@ def default_root(cls) -> str: return osp.join(root_dir, cls.DEFAULT) @classmethod - def find(cls: type[T], create: bool = True) -> T: + def find(cls, create: bool = True) -> "Self": try: root = os.environ[cls.ENV_VAR] except KeyError: @@ -479,3 +481,12 @@ def row_to_nested_dict( for h, v in zip(headers, row): nested_dict_path_set(result, h, v) return result + + +@contextmanager +def safe_closing(thing: T) -> Iterator[T]: + try: + yield thing + finally: + if hasattr(thing, "close"): + thing.close() diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py index f4f504a4d..f2e8c4a65 100644 --- a/tests/func/test_datachain.py +++ b/tests/func/test_datachain.py @@ -220,9 +220,24 @@ def test_from_storage_dependencies(cloud_test_catalog, cloud_type): @pytest.mark.parametrize("prefetch", [0, 2]) def test_map_file(cloud_test_catalog, use_cache, prefetch): ctc = cloud_test_catalog + ctc.catalog.cache.clear() + + def is_prefetched(file: File) -> bool: + return file._catalog.cache.contains(file) and bool(file.get_local_path()) + + def verify_cache_used(file): + catalog = file._catalog + if use_cache or not prefetch: + assert catalog.cache == cloud_test_catalog.catalog.cache + return + head, tail = os.path.split(catalog.cache.cache_dir) + assert head == catalog.cache.tmp_dir + assert tail.startswith("prefetch-") def new_signal(file: File) -> str: - assert bool(file.get_local_path()) is (use_cache and prefetch > 0) + assert is_prefetched(file) == (prefetch > 0) + verify_cache_used(file) + with file.open() as f: return file.name + " -> " + f.read().decode("utf-8") @@ -245,6 +260,7 @@ def new_signal(file: File) -> str: assert set(dc.collect("signal")) == expected for file in dc.collect("file"): assert bool(file.get_local_path()) is use_cache + assert not os.listdir(ctc.catalog.cache.tmp_dir) @pytest.mark.parametrize("use_cache", [False, True]) @@ -1273,6 +1289,57 @@ def gen_func(): ).show() +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("prefetch", [0, 2]) +def test_gen_file(cloud_test_catalog, use_cache, prefetch): + ctc = cloud_test_catalog + ctc.catalog.cache.clear() + + def is_prefetched(file: File) -> bool: + return file._catalog.cache.contains(file) and bool(file.get_local_path()) + + def verify_cache_used(file): + catalog = file._catalog + if use_cache or not prefetch: + assert catalog.cache == cloud_test_catalog.catalog.cache + return + head, tail = os.path.split(catalog.cache.cache_dir) + assert head == catalog.cache.tmp_dir + assert tail.startswith("prefetch-") + + def new_signal(file: File) -> list[str]: + assert is_prefetched(file) == (prefetch > 0) + verify_cache_used(file) + + with file.open("rb") as f: + return [file.name, f.read().decode("utf-8")] + + dc = ( + DataChain.from_storage(ctc.src_uri, session=ctc.session) + .settings(cache=use_cache, prefetch=prefetch) + .gen(signal=new_signal, output=str) + .save() + ) + expected = { + "Cats and Dogs", + "arf", + "bark", + "cat1", + "cat2", + "description", + "dog1", + "dog2", + "dog3", + "dog4", + "meow", + "mrow", + "ruff", + "woof", + } + assert set(dc.collect("signal")) == expected + assert not os.listdir(ctc.catalog.cache.tmp_dir) + + def test_similarity_search(cloud_test_catalog): session = cloud_test_catalog.session src_uri = cloud_test_catalog.src_uri diff --git a/tests/func/test_pytorch.py b/tests/func/test_pytorch.py index 7920dc732..2ebcbc9ab 100644 --- a/tests/func/test_pytorch.py +++ b/tests/func/test_pytorch.py @@ -1,3 +1,6 @@ +import os +from contextlib import closing + import open_clip import pytest import torch @@ -7,6 +10,7 @@ from torchvision.transforms import v2 from datachain.lib.dc import DataChain +from datachain.lib.file import File from datachain.lib.pytorch import PytorchDataset @@ -80,6 +84,43 @@ def test_to_pytorch(fake_dataset): assert img.size() == Size([3, 64, 64]) +@pytest.mark.parametrize("use_cache", (True, False)) +@pytest.mark.parametrize("prefetch", (0, 2)) +def test_prefetch(mocker, catalog, fake_dataset, use_cache, prefetch): + catalog.cache.clear() + + dataset = fake_dataset.limit(10) + ds = dataset.settings(cache=use_cache, prefetch=prefetch).to_pytorch() + + iter_with_prefetch = ds._iter_with_prefetch + cache = ds._cache + + def is_prefetched(file: File): + assert file._catalog + assert file._catalog.cache == cache + return cache.contains(file) + + def check_prefetched(): + for row in iter_with_prefetch(): + files = [f for f in row if isinstance(f, File)] + assert files + files_not_in_cache = [f for f in files if not is_prefetched(f)] + if prefetch: + assert not files_not_in_cache, "Some files are not in cache" + else: + assert files == files_not_in_cache, "Some files are in cache" + yield row + + # we peek internally with `_iter_with_prefetch` to check if the files are prefetched + # as `__iter__` transforms them. + m = mocker.patch.object(ds, "_iter_with_prefetch", wraps=check_prefetched) + with closing(ds), closing(iter(ds)) as rows: + assert next(rows) + m.assert_called_once() + # prefetch cache directory should be removed after `close()` + assert os.path.exists(cache.cache_dir) == (use_cache or not prefetch) + + def test_hf_to_pytorch(catalog, fake_image_dir): hf_ds = load_dataset("imagefolder", data_dir=fake_image_dir) chain = DataChain.from_hf(hf_ds) diff --git a/tests/unit/test_asyn.py b/tests/unit/test_asyn.py index e102dd3a2..b77be2d79 100644 --- a/tests/unit/test_asyn.py +++ b/tests/unit/test_asyn.py @@ -1,5 +1,7 @@ import asyncio import functools +import itertools +import threading from collections import Counter from contextlib import contextmanager from queue import Queue @@ -143,6 +145,37 @@ async def process(x): assert list(it) == [] +@pytest.mark.parametrize("create_mapper", [AsyncMapper, OrderedMapper]) +@pytest.mark.parametrize("stop_at", [10, None]) +def test_mapper_closes_iterable(create_mapper, stop_at): + """Test that the iterable is closed when the `.iterate()` is closed or exhausted.""" + + async def process(x): + return x + + iterable_closed = False + start_thread = None + close_thread = None + + def gen(): + nonlocal iterable_closed, start_thread, close_thread + start_thread = threading.get_ident() + try: + yield from range(50) + finally: + iterable_closed = True + close_thread = threading.get_ident() + + mapper = create_mapper(process, gen(), workers=10, loop=get_loop()) + it = mapper.iterate() + list(itertools.islice(it, stop_at)) + if stop_at is not None: + it.close() + assert iterable_closed + assert start_thread == close_thread + assert start_thread != threading.get_ident() + + @pytest.mark.parametrize("create_mapper", [AsyncMapper, OrderedMapper]) @settings(deadline=None) @given( diff --git a/tests/unit/test_cache.py b/tests/unit/test_cache.py index 012ff4e2c..d926acdb5 100644 --- a/tests/unit/test_cache.py +++ b/tests/unit/test_cache.py @@ -1,6 +1,8 @@ +import os + import pytest -from datachain.cache import DataChainCache +from datachain.cache import DataChainCache, get_temp_cache, temporary_cache from datachain.lib.file import File @@ -53,3 +55,27 @@ def test_remove(cache): assert cache.contains(uid) cache.remove(uid) assert not cache.contains(uid) + + +def test_destroy(cache: DataChainCache): + file = File(source="s3://foo", path="data/bar", etag="xyz", size=3, location=None) + cache.store_data(file, b"foo") + assert cache.contains(file) + + cache.destroy() + assert not os.path.exists(cache.cache_dir) + + +def test_get_temp_cache(tmp_path): + temp = get_temp_cache(tmp_path, prefix="test-") + assert os.path.isdir(temp.cache_dir) + assert isinstance(temp, DataChainCache) + head, tail = os.path.split(temp.cache_dir) + assert head == str(tmp_path) + assert tail.startswith("test-") + + +def test_temporary_cache(tmp_path): + with temporary_cache(tmp_path, prefix="test-") as temp: + assert os.path.isdir(temp.cache_dir) + assert not os.path.exists(temp.cache_dir) diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py new file mode 100644 index 000000000..2018837aa --- /dev/null +++ b/tests/unit/test_pytorch.py @@ -0,0 +1,58 @@ +import gc +import os + +import pytest + +from datachain.cache import DataChainCache +from datachain.lib.pytorch import PytorchDataset +from datachain.lib.settings import Settings + + +@pytest.mark.parametrize( + "cache,prefetch", [(True, 0), (True, 10), (False, 10), (False, 0)] +) +def test_cache(catalog, cache, prefetch): + settings = Settings(cache=cache, prefetch=prefetch) + ds = PytorchDataset("fake", 1, catalog, dc_settings=settings) + assert ds.cache == cache + assert ds.prefetch == prefetch + + if cache or not prefetch: + assert catalog.cache is ds._cache + return + + assert catalog.cache is not ds._cache + head, tail = os.path.split(ds._cache.cache_dir) + assert head == catalog.cache.tmp_dir + assert tail.startswith("prefetch-") + + +@pytest.mark.parametrize("cache", [True, False]) +def test_close(mocker, catalog, cache): + spy = mocker.spy(DataChainCache, "destroy") + ds = PytorchDataset( + "fake", 1, catalog, dc_settings=Settings(cache=cache, prefetch=10) + ) + + ds.close() + + if cache: + spy.assert_not_called() + else: + spy.assert_called_once() + + +@pytest.mark.parametrize("cache", [True, False]) +def test_prefetch_cache_gets_destroyed_on_gc(mocker, catalog, cache): + spy = mocker.patch.object(DataChainCache, "destroy") + ds = PytorchDataset( + "fake", 1, catalog, dc_settings=Settings(cache=cache, prefetch=10) + ) + + del ds + gc.collect() + + if cache: + spy.assert_not_called() + else: + spy.assert_called_once()