From 281fc0a3d3a4414fb477cbdbffeeca9caca0bb07 Mon Sep 17 00:00:00 2001 From: Ivan Vaigult Date: Mon, 8 May 2023 19:28:01 +0100 Subject: [PATCH] BUG: Allow multiple names for vector indicators (#382) Previously we only allowed one name per vector indicator: def _my_indicator(open, close): return tuple( _my_indicator_one(open, close), _my_indicator_two(open, close), ) self.I( _my_indicator, # One name is used to describe two values name="My Indicator", self.data.Open, self.data.Close ) Now, the user can supply two (or more) names to annotate each value individually. The names will be shown in the plot legend. The following is now valid: self.I( _my_indicator, # One name is used to describe two values name=["My Indicator One", "My Indicator Two"], self.data.Open, self.data.Close ) --- backtesting/_plotting.py | 34 ++++++++++++++++++++++++---------- backtesting/backtesting.py | 21 ++++++++++++++++++--- backtesting/test/_test.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 13 deletions(-) diff --git a/backtesting/_plotting.py b/backtesting/_plotting.py index 844318aa..0f284a0a 100644 --- a/backtesting/_plotting.py +++ b/backtesting/_plotting.py @@ -3,7 +3,7 @@ import sys import warnings from colorsys import hls_to_rgb, rgb_to_hls -from itertools import cycle, combinations +from itertools import cycle, combinations, repeat from functools import partial from typing import Callable, List, Union @@ -537,10 +537,24 @@ def __eq__(self, other): colors = value._opts['color'] colors = colors and cycle(_as_list(colors)) or ( cycle([next(ohlc_colors)]) if is_overlay else colorgen()) - legend_label = LegendStr(value.name) - for j, arr in enumerate(value, 1): + + tooltip_label = value.name + + if len(value) == 1: + assert isinstance(value.name, str) + legend_labels = [LegendStr(value.name)] + elif isinstance(value.name, str): + legend_labels = [ + LegendStr(f"{name}[{index}]") + for index, name in enumerate(repeat(value.name, len(value))) + ] + else: + legend_labels = [LegendStr(item) for item in value.name] + tooltip_label = ", ".join(value.name) + + for j, arr in enumerate(value): color = next(colors) - source_name = f'{legend_label}_{i}_{j}' + source_name = f'{legend_labels[j]}_{i}_{j}' if arr.dtype == bool: arr = arr.astype(int) source.add(arr, source_name) @@ -550,24 +564,24 @@ def __eq__(self, other): if is_scatter: fig.scatter( 'index', source_name, source=source, - legend_label=legend_label, color=color, + legend_label=legend_labels[j], color=color, line_color='black', fill_alpha=.8, marker='circle', radius=BAR_WIDTH / 2 * 1.5) else: fig.line( 'index', source_name, source=source, - legend_label=legend_label, line_color=color, + legend_label=legend_labels[j], line_color=color, line_width=1.3) else: if is_scatter: r = fig.scatter( 'index', source_name, source=source, - legend_label=LegendStr(legend_label), color=color, + legend_label=legend_labels[j], color=color, marker='circle', radius=BAR_WIDTH / 2 * .9) else: r = fig.line( 'index', source_name, source=source, - legend_label=LegendStr(legend_label), line_color=color, + legend_label=legend_labels[j], line_color=color, line_width=1.3) # Add dashed centerline just because mean = float(pd.Series(arr).mean()) @@ -578,9 +592,9 @@ def __eq__(self, other): line_color='#666666', line_dash='dashed', line_width=.5)) if is_overlay: - ohlc_tooltips.append((legend_label, NBSP.join(tooltips))) + ohlc_tooltips.append((tooltip_label, NBSP.join(tooltips))) else: - set_tooltips(fig, [(legend_label, NBSP.join(tooltips))], vline=True, renderers=[r]) + set_tooltips(fig, [(tooltip_label, NBSP.join(tooltips))], vline=True, renderers=[r]) # If the sole indicator line on this figure, # have the legend only contain text without the glyph if len(value) == 1: diff --git a/backtesting/backtesting.py b/backtesting/backtesting.py index 9c168703..3e5a1022 100644 --- a/backtesting/backtesting.py +++ b/backtesting/backtesting.py @@ -90,7 +90,9 @@ def I(self, # noqa: E743 same length as `backtesting.backtesting.Strategy.data`. In the plot legend, the indicator is labeled with - function name, unless `name` overrides it. + function name, unless `name` overrides it. If `func` returns + multiple arrays, `name` can be a collection of strings, and + the size must agree with the number of arrays returned. If `plot` is `True`, the indicator is plotted on the resulting `backtesting.backtesting.Backtest.plot`. @@ -115,13 +117,21 @@ def I(self, # noqa: E743 def init(): self.sma = self.I(ta.SMA, self.data.Close, self.n_sma) """ + def _format_name(name: str) -> str: + return name.format(*map(_as_str, args), + **dict(zip(kwargs.keys(), map(_as_str, kwargs.values())))) + if name is None: params = ','.join(filter(None, map(_as_str, chain(args, kwargs.values())))) func_name = _as_str(func) name = (f'{func_name}({params})' if params else f'{func_name}') + elif isinstance(name, str): + name = _format_name(name) + elif try_(lambda: all(isinstance(item, str) for item in name), False): + name = [_format_name(item) for item in name] else: - name = name.format(*map(_as_str, args), - **dict(zip(kwargs.keys(), map(_as_str, kwargs.values())))) + raise TypeError(f'Unexpected `name` type {type(name)}, `str` or `Iterable[str]` ' + 'was expected.') try: value = func(*args, **kwargs) @@ -139,6 +149,11 @@ def init(): if is_arraylike and np.argmax(value.shape) == 0: value = value.T + if isinstance(name, list) and (value.ndim != 2 or value.shape[0] != len(name)): + raise ValueError( + f'The number of `name` elements ({len(name)}) must agree with the nubmer ' + f'of arrays ({value.shape[0]}) the indicator returns.') + if not is_arraylike or not 1 <= value.ndim <= 2 or value.shape[-1] != len(self._data.Close): raise ValueError( 'Indicators must return (optionally a tuple of) numpy.arrays of same ' diff --git a/backtesting/test/_test.py b/backtesting/test/_test.py index d8d87814..6f5fab18 100644 --- a/backtesting/test/_test.py +++ b/backtesting/test/_test.py @@ -755,6 +755,36 @@ def test_resample(self): # Give browser time to open before tempfile is removed time.sleep(1) + def test_indicator_name(self): + test_self = self + + class S(Strategy): + def init(self): + def _SMA(): + return SMA(self.data.Close, 5), SMA(self.data.Close, 10) + + test_self.assertRaises(TypeError, self.I, _SMA, name=42) + test_self.assertRaises(ValueError, self.I, _SMA, name=("SMA One", )) + test_self.assertRaises( + ValueError, self.I, _SMA, name=("SMA One", "SMA Two", "SMA Three")) + + for overlay in (True, False): + self.I(SMA, self.data.Close, 5, overlay=overlay) + self.I(SMA, self.data.Close, 5, name="My SMA", overlay=overlay) + self.I(_SMA, overlay=overlay) + self.I(_SMA, name="My SMA", overlay=overlay) + self.I(_SMA, name=("SMA One", "SMA Two"), overlay=overlay) + + def next(self): + pass + + bt = Backtest(GOOG, S) + bt.run() + with _tempfile() as f: + bt.plot(filename=f, + plot_drawdown=False, plot_equity=False, plot_pl=False, plot_volume=False, + open_browser=False) + def test_indicator_color(self): class S(Strategy): def init(self):