Skip to content

Commit

Permalink
ENH: add concat
Browse files Browse the repository at this point in the history
Adds a top level function `concat` and a `Cycler` method `concat` which
will concatenate two cyclers.  The method can be chained.

closes matplotlib#1
  • Loading branch information
tacaswell committed Oct 3, 2015
1 parent 1f3e5c4 commit ae50f6e
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
35 changes: 35 additions & 0 deletions cycler.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,41 @@ def simplify(self):
trans = self._transpose()
return reduce(add, (_cycler(k, v) for k, v in six.iteritems(trans)))

def concat(self, other):
return concat(self, other)


def concat(left, right):
"""Concatenate two cyclers.
The keys must match exactly.
This returns a single Cycler which is equivalent to
`itertools.chain(left, right)`
Parameters
----------
left, right : `Cycler`
The two `Cycler` instances to concatenate
Returns
-------
ret : `Cycler`
The concatenated `Cycler`
"""
if left.keys != right.keys:
msg = '\n\t'.join(["Keys do not match:",
"Intersection: {both!r}",
"Disjoint: {just_one!r}"
]).format(
both=left.keys&right.keys,
just_one=left.keys^right.keys)

raise ValueError(msg)

_l = left._transpose()
_r = right._transpose()
return reduce(add, (_cycler(k, _l[k] + _r[k]) for k in left.keys))

def cycler(*args, **kwargs):
"""
Expand Down
20 changes: 18 additions & 2 deletions test_cycler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import six
from six.moves import zip, range
from cycler import cycler, Cycler
from cycler import cycler, Cycler, concat
from nose.tools import (assert_equal, assert_not_equal,
assert_raises, assert_true)
from itertools import product, cycle
from itertools import product, cycle, chain
from operator import add, iadd, mul, imul


Expand Down Expand Up @@ -279,3 +279,19 @@ def test_starange_init():
c2 = cycler('lw', range(3))
cy = Cycler(list(c), list(c2), zip)
assert_equal(cy, c + c2)


def test_concat():
a = cycler('a', range(3))
for con, chn in zip(a.concat(a), chain(a, a)):
assert_equal(con, chn)

for con, chn in zip(concat(a, a), chain(a, a)):
assert_equal(con, chn)


def test_concat_fail():
a = cycler('a', range(3))
b = cycler('b', range(3))
assert_raises(ValueError, concat, a, b)
assert_raises(ValueError, a.concat, b)

0 comments on commit ae50f6e

Please sign in to comment.