Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extensible serializers support #209

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions flask_caching/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def cached(
cache_none: bool=False,
make_cache_key: Optional[Callable]=None,
source_check: Optional[bool]=None,
force_tuple: bool = True,
) -> Callable:
"""Decorator. Use this to cache a function. By default the cache key
is `view/request.path`. You are able to use this decorator with any
Expand Down Expand Up @@ -403,6 +404,8 @@ def get_list():
formed with the function's source code hash in
addition to other parameters that may be included
in the formation of the key.
:param force_tuple: Default True. Cast a cached list back to a tuple.
                    JSON has no tuple type, but Flask expects a tuple
                    for ``(response, status)`` return values.
"""

def decorator(f):
Expand Down Expand Up @@ -453,6 +456,9 @@ def decorated_function(*args, **kwargs):
found = False
else:
found = self.cache.has(cache_key)
elif force_tuple and isinstance(rv, list) and len(rv) == 2:
# JSON compatibility for flask
rv = tuple(rv)
except Exception:
if self.app.debug:
raise
Expand Down
35 changes: 34 additions & 1 deletion flask_caching/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
:copyright: (c) 2010 by Thadeus Burgess.
:license: BSD, see LICENSE for more details.
"""
import warnings
try:
import cPickle as pickle
except ImportError: # pragma: no cover
import pickle # type: ignore


def iteritems_wrapper(mappingorseq):
Expand All @@ -28,19 +33,47 @@ def iteritems_wrapper(mappingorseq):
return mappingorseq


def extract_serializer_args(data):
    """Pop serializer-related keyword arguments out of *data*.

    Every key starting with ``"serializer_"`` is removed from *data*
    (mutating it in place) and collected into a new dict, so the remaining
    keys can be forwarded safely to another constructor.

    :param data: Mapping of keyword arguments; mutated in place.
    :returns: Dict containing only the ``serializer_*`` entries.
    """
    prefix = "serializer_"
    # Snapshot the matching keys first: popping while iterating the dict
    # directly would raise RuntimeError.
    serializer_keys = [key for key in data if key.startswith(prefix)]
    return {key: data.pop(key) for key in serializer_keys}


class BaseCache(object):
"""Baseclass for the cache systems. All the cache systems implement this
API or a superset of it.

:param default_timeout: The default timeout (in seconds) that is used if
no timeout is specified on :meth:`set`. A timeout
of 0 indicates that the cache never expires.
:param serializer_impl: Pickle-like serialization implementation. It should
                        support load(-s) and dump(-s) methods and binary
                        strings/files.
:param serializer_error: Deserialization exception - for specified
implementation.
"""

def __init__(self, default_timeout=300):
def __init__(
self,
default_timeout=300,
serializer_impl=pickle,
serializer_error=pickle.PickleError,
):
self.default_timeout = default_timeout
self.ignore_errors = False

if serializer_impl is pickle:
warnings.warn(
"Pickle serializer is not secure and may "
"lead to remote code execution. "
"Consider using another serializer (eg. JSON)."
)
self._serializer = serializer_impl
self._serialization_error = serializer_error

def _normalize_timeout(self, timeout):
if timeout is None:
timeout = self.default_timeout
Expand Down
27 changes: 12 additions & 15 deletions flask_caching/backends/filesystemcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,7 @@
import tempfile
from time import time

from flask_caching.backends.base import BaseCache

try:
import cPickle as pickle
except ImportError: # pragma: no cover
import pickle # type: ignore
from flask_caching.backends.base import BaseCache, extract_serializer_args


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -64,8 +59,11 @@ def __init__(
mode=0o600,
hash_method=hashlib.md5,
ignore_errors=False,
**kwargs
):
super(FileSystemCache, self).__init__(default_timeout)
super(FileSystemCache, self).__init__(
default_timeout, **extract_serializer_args(kwargs)
)
self._path = cache_dir
self._threshold = threshold
self._mode = mode
Expand Down Expand Up @@ -129,7 +127,7 @@ def _prune(self):
try:
remove = False
with open(fname, "rb") as f:
expires = pickle.load(f)
expires, _ = self._serializer.load(f)
remove = (expires != 0 and expires <= now) or idx % 3 == 0
if remove:
os.remove(fname)
Expand Down Expand Up @@ -162,16 +160,16 @@ def get(self, key):
filename = self._get_filename(key)
try:
with open(filename, "rb") as f:
pickle_time = pickle.load(f)
pickle_time, result = self._serializer.load(f)
expired = pickle_time != 0 and pickle_time < time()
if expired:
result = None
os.remove(filename)
else:
hit_or_miss = "hit"
result = pickle.load(f)
except FileNotFoundError:
pass
except (IOError, OSError, pickle.PickleError) as exc:
except (IOError, OSError, self._serialization_error) as exc:
logger.error("get key %r -> %s", key, exc)
expiredstr = "(expired)" if expired else ""
logger.debug("get key %r -> %s %s", key, hit_or_miss, expiredstr)
Expand Down Expand Up @@ -205,8 +203,7 @@ def set(self, key, value, timeout=None, mgmt_element=False):
suffix=self._fs_transaction_suffix, dir=self._path
)
with os.fdopen(fd, "wb") as f:
pickle.dump(timeout, f, 1)
pickle.dump(value, f, pickle.HIGHEST_PROTOCOL)
self._serializer.dump((timeout, value), f)
os.replace(tmp, filename)
os.chmod(filename, self._mode)
except (IOError, OSError) as exc:
Expand Down Expand Up @@ -241,15 +238,15 @@ def has(self, key):
filename = self._get_filename(key)
try:
with open(filename, "rb") as f:
pickle_time = pickle.load(f)
pickle_time, _ = self._serializer.load(f)
expired = pickle_time != 0 and pickle_time < time()
if expired:
os.remove(filename)
else:
result = True
except FileNotFoundError:
pass
except (IOError, OSError, pickle.PickleError) as exc:
except (IOError, OSError, self._serialization_error) as exc:
logger.error("get key %r -> %s", key, exc)
expiredstr = "(expired)" if expired else ""
logger.debug("has key %r -> %s %s", key, result, expiredstr)
Expand Down
21 changes: 11 additions & 10 deletions flask_caching/backends/memcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
"""
from time import time

from flask_caching.backends.base import BaseCache, iteritems_wrapper

try:
import cPickle as pickle
except ImportError: # pragma: no cover
import pickle # type: ignore
from flask_caching.backends.base import (
BaseCache,
extract_serializer_args,
iteritems_wrapper,
)


_test_memcached_key = re.compile(r"[^\x00-\x21\xff]{1,250}$").match
Expand Down Expand Up @@ -60,8 +59,10 @@ class MemcachedCache(BaseCache):
different prefix.
"""

def __init__(self, servers=None, default_timeout=300, key_prefix=None):
super(MemcachedCache, self).__init__(default_timeout)
def __init__(self, servers=None, default_timeout=300, key_prefix=None, **kwargs):
super(MemcachedCache, self).__init__(
default_timeout, **extract_serializer_args(kwargs)
)
if servers is None or isinstance(servers, (list, tuple)):
if servers is None:
servers = ["127.0.0.1:11211"]
Expand Down Expand Up @@ -294,7 +295,7 @@ def _set(self, key, value, timeout=None):
# I didn't find a good way to avoid pickling/unpickling when the
# key is smaller than chunksize, because for <werkzeug.requests>
# getting the length consumes the data iterator.
serialized = pickle.dumps(value, 2)
serialized = self._serializer.dumps(value)
values = {}
len_ser = len(serialized)
chks = range(0, len_ser, self.chunksize)
Expand Down Expand Up @@ -333,4 +334,4 @@ def _get(self, key):
if not serialized:
return None

return pickle.loads(serialized)
return self._serializer.loads(serialized)
21 changes: 9 additions & 12 deletions flask_caching/backends/rediscache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@
:copyright: (c) 2010 by Thadeus Burgess.
:license: BSD, see LICENSE for more details.
"""
from flask_caching.backends.base import BaseCache, iteritems_wrapper

try:
import cPickle as pickle
except ImportError: # pragma: no cover
import pickle # type: ignore
from flask_caching.backends.base import (
BaseCache, extract_serializer_args, iteritems_wrapper
)


class RedisCache(BaseCache):
Expand Down Expand Up @@ -49,7 +46,7 @@ def __init__(
key_prefix=None,
**kwargs
):
super().__init__(default_timeout)
super().__init__(default_timeout, **extract_serializer_args(kwargs))
if host is None:
raise ValueError("RedisCache host parameter may not be None")
if isinstance(host, str):
Expand Down Expand Up @@ -90,7 +87,7 @@ def dump_object(self, value):
t = type(value)
if t == int:
return str(value).encode("ascii")
return b"!" + pickle.dumps(value)
return b"!" + self._serializer.dumps(value)

def load_object(self, value):
"""The reversal of :meth:`dump_object`. This might be called with
Expand All @@ -100,8 +97,8 @@ def load_object(self, value):
return None
if value.startswith(b"!"):
try:
return pickle.loads(value[1:])
except pickle.PickleError:
return self._serializer.loads(value[1:])
except self._serialization_error:
return None
try:
return int(value)
Expand Down Expand Up @@ -273,7 +270,7 @@ def __init__(
self._read_clients = sentinel.slave_for(master)

self.key_prefix = key_prefix or ""

class RedisClusterCache(RedisCache):
"""Uses the Redis key-value store as a cache backend.

Expand Down Expand Up @@ -324,7 +321,7 @@ def __init__(self,
# Skips the check of cluster-require-full-coverage config,
# useful for clusters without the CONFIG command (like aws)
skip_full_coverage_check = kwargs.pop('skip_full_coverage_check', True)

cluster = RedisCluster(startup_nodes=startup_nodes,
password=password,
skip_full_coverage_check=skip_full_coverage_check,
Expand Down
21 changes: 10 additions & 11 deletions flask_caching/backends/simplecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@
import logging
from time import time

from flask_caching.backends.base import BaseCache

try:
import cPickle as pickle
except ImportError: # pragma: no cover
import pickle # type: ignore
from flask_caching.backends.base import BaseCache, extract_serializer_args


logger = logging.getLogger(__name__)
Expand All @@ -41,8 +36,12 @@ class SimpleCache(BaseCache):
``False``.
"""

def __init__(self, threshold=500, default_timeout=300, ignore_errors=False):
super(SimpleCache, self).__init__(default_timeout)
def __init__(
self, threshold=500, default_timeout=300, ignore_errors=False, **kwargs
):
super(SimpleCache, self).__init__(
default_timeout, **extract_serializer_args(kwargs)
)
self._cache = {}
self.clear = self._cache.clear
self._threshold = threshold
Expand Down Expand Up @@ -78,7 +77,7 @@ def get(self, key):
if not expired:
hit_or_miss = "hit"
try:
result = pickle.loads(value)
result = self._serializer.loads(value)
except Exception as exc:
logger.error("get key %r -> %s", key, exc)
expiredstr = "(expired)" if expired else ""
Expand All @@ -88,15 +87,15 @@ def get(self, key):
def set(self, key, value, timeout=None):
expires = self._normalize_timeout(timeout)
self._prune()
item = (expires, pickle.dumps(value, pickle.HIGHEST_PROTOCOL))
item = (expires, self._serializer.dumps(value))
self._cache[key] = item
logger.debug("set key %r", key)
return True

def add(self, key, value, timeout=None):
expires = self._normalize_timeout(timeout)
self._prune()
item = (expires, pickle.dumps(value, pickle.HIGHEST_PROTOCOL))
item = (expires, self._serializer.dumps(value))
updated = False
should_add = key not in self._cache
if should_add:
Expand Down
19 changes: 8 additions & 11 deletions flask_caching/backends/uwsgicache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,7 @@
"""
import platform

from flask_caching.backends.base import BaseCache

try:
import cPickle as pickle
except ImportError: # pragma: no cover
import pickle # type: ignore
from flask_caching.backends.base import BaseCache, extract_serializer_args


class UWSGICache(BaseCache):
Expand All @@ -34,8 +29,10 @@ class UWSGICache(BaseCache):
you only have to provide the name of the cache.
"""

def __init__(self, default_timeout=300, cache=""):
super(UWSGICache, self).__init__(default_timeout)
def __init__(self, default_timeout=300, cache="", **kwargs):
super(UWSGICache, self).__init__(
default_timeout, **extract_serializer_args(kwargs)
)

if platform.python_implementation() == "PyPy":
raise RuntimeError(
Expand All @@ -62,23 +59,23 @@ def get(self, key):
rv = self._uwsgi.cache_get(key, self.cache)
if rv is None:
return
return pickle.loads(rv)
return self._serializer.loads(rv)

def delete(self, key):
return self._uwsgi.cache_del(key, self.cache)

def set(self, key, value, timeout=None):
return self._uwsgi.cache_update(
key,
pickle.dumps(value),
self._serializer.dumps(value),
self._normalize_timeout(timeout),
self.cache,
)

def add(self, key, value, timeout=None):
return self._uwsgi.cache_set(
key,
pickle.dumps(value),
self._serializer.dumps(value),
self._normalize_timeout(timeout),
self.cache,
)
Expand Down
Loading