From ae50f6ed3c64c7269732c50fcb4506aa02048602 Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Fri, 2 Oct 2015 22:13:43 -0400 Subject: [PATCH] ENH: add concat Adds a top level function `concat` and a `Cycler` method `concat` which will concatenate two cyclers. The method can be chained. closes #1 --- cycler.py | 35 +++++++++++++++++++++++++++++++++++ test_cycler.py | 20 ++++++++++++++++++-- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/cycler.py b/cycler.py index 48a2b3f..9e603ba 100644 --- a/cycler.py +++ b/cycler.py @@ -389,6 +389,41 @@ def simplify(self): trans = self._transpose() return reduce(add, (_cycler(k, v) for k, v in six.iteritems(trans))) + def concat(self, other): + return concat(self, other) + + +def concat(left, right): + """Concatenate two cyclers. + + The keys must match exactly. + + This returns a single Cycler which is equivalent to + `itertools.chain(left, right)` + + Parameters + ---------- + left, right : `Cycler` + The two `Cycler` instances to concatenate + + Returns + ------- + ret : `Cycler` + The concatenated `Cycler` + """ + if left.keys != right.keys: + msg = '\n\t'.join(["Keys do not match:", + "Intersection: {both!r}", + "Disjoint: {just_one!r}" + ]).format( + both=left.keys&right.keys, + just_one=left.keys^right.keys) + + raise ValueError(msg) + + _l = left._transpose() + _r = right._transpose() + return reduce(add, (_cycler(k, _l[k] + _r[k]) for k in left.keys)) def cycler(*args, **kwargs): """ diff --git a/test_cycler.py b/test_cycler.py index 6d93566..c8c2f8c 100644 --- a/test_cycler.py +++ b/test_cycler.py @@ -2,10 +2,10 @@ import six from six.moves import zip, range -from cycler import cycler, Cycler +from cycler import cycler, Cycler, concat from nose.tools import (assert_equal, assert_not_equal, assert_raises, assert_true) -from itertools import product, cycle +from itertools import product, cycle, chain from operator import add, iadd, mul, imul @@ -279,3 +279,19 @@ def test_starange_init(): c2 = cycler('lw', range(3)) cy = Cycler(list(c), list(c2), zip) assert_equal(cy, c + c2) + + +def test_concat(): + a = cycler('a', range(3)) + for con, chn in zip(a.concat(a), chain(a, a)): + assert_equal(con, chn) + + for con, chn in zip(concat(a, a), chain(a, a)): + assert_equal(con, chn) + + +def test_concat_fail(): + a = cycler('a', range(3)) + b = cycler('b', range(3)) + assert_raises(ValueError, concat, a, b) + assert_raises(ValueError, a.concat, b)