Skip to content

Commit

Permalink
BUG: Allow multiple names for vector indicators (kernc#382)
Browse files Browse the repository at this point in the history
Previously we only allowed one name per vector
indicator:

    def _my_indicator(open, close):
    	return tuple(
	    _my_indicator_one(open, close),
	    _my_indicator_two(open, close),
	)

    self.I(
        _my_indicator,
	# One name is used to describe two values
        name="My Indicator",
	self.data.Open,
	self.data.Close
    )

Now, the user can supply two (or more) names to annotate
each value individually. The names will be shown in the
plot legend. The following is now valid:

    self.I(
        _my_indicator,
	# Two names can now be passed
        name=["My Indicator One", "My Indicator Two"],
	self.data.Open,
	self.data.Close
    )

Co-authored-by: kernc <[email protected]>
  • Loading branch information
ivaigult and kernc committed May 9, 2023
1 parent 0ce24d8 commit 172767c
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 13 deletions.
32 changes: 22 additions & 10 deletions backtesting/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
import warnings
from colorsys import hls_to_rgb, rgb_to_hls
from itertools import cycle, combinations
from itertools import cycle, combinations, repeat
from functools import partial
from typing import Callable, List, Union

Expand Down Expand Up @@ -537,10 +537,22 @@ def __eq__(self, other):
colors = value._opts['color']
colors = colors and cycle(_as_list(colors)) or (
cycle([next(ohlc_colors)]) if is_overlay else colorgen())
legend_label = LegendStr(value.name)
for j, arr in enumerate(value, 1):

tooltip_label = value.name if isinstance(value.name, str) else ", ".join(value.name)

if len(value) == 1:
legend_labels = [LegendStr(item) for item in _as_list(value.name)]
elif isinstance(value.name, str):
legend_labels = [
LegendStr(f"{name}[{index}]")
for index, name in enumerate(repeat(value.name, len(value)))
]
else:
legend_labels = [LegendStr(item) for item in value.name]

for j, arr in enumerate(value):
color = next(colors)
source_name = f'{legend_label}_{i}_{j}'
source_name = f'{legend_labels[j]}_{i}_{j}'
if arr.dtype == bool:
arr = arr.astype(int)
source.add(arr, source_name)
Expand All @@ -550,24 +562,24 @@ def __eq__(self, other):
if is_scatter:
fig.scatter(
'index', source_name, source=source,
legend_label=legend_label, color=color,
legend_label=legend_labels[j], color=color,
line_color='black', fill_alpha=.8,
marker='circle', radius=BAR_WIDTH / 2 * 1.5)
else:
fig.line(
'index', source_name, source=source,
legend_label=legend_label, line_color=color,
legend_label=legend_labels[j], line_color=color,
line_width=1.3)
else:
if is_scatter:
r = fig.scatter(
'index', source_name, source=source,
legend_label=LegendStr(legend_label), color=color,
legend_label=legend_labels[j], color=color,
marker='circle', radius=BAR_WIDTH / 2 * .9)
else:
r = fig.line(
'index', source_name, source=source,
legend_label=LegendStr(legend_label), line_color=color,
legend_label=legend_labels[j], line_color=color,
line_width=1.3)
# Add dashed centerline just because
mean = float(pd.Series(arr).mean())
Expand All @@ -578,9 +590,9 @@ def __eq__(self, other):
line_color='#666666', line_dash='dashed',
line_width=.5))
if is_overlay:
ohlc_tooltips.append((legend_label, NBSP.join(tooltips)))
ohlc_tooltips.append((tooltip_label, NBSP.join(tooltips)))
else:
set_tooltips(fig, [(legend_label, NBSP.join(tooltips))], vline=True, renderers=[r])
set_tooltips(fig, [(tooltip_label, NBSP.join(tooltips))], vline=True, renderers=[r])
# If the sole indicator line on this figure,
# have the legend only contain text without the glyph
if len(value) == 1:
Expand Down
21 changes: 18 additions & 3 deletions backtesting/backtesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def I(self, # noqa: E743
same length as `backtesting.backtesting.Strategy.data`.
In the plot legend, the indicator is labeled with
function name, unless `name` overrides it.
function name, unless `name` overrides it. If `func` returns
multiple arrays, `name` can be a sequence of strings, and
its size must agree with the number of arrays returned.
If `plot` is `True`, the indicator is plotted on the resulting
`backtesting.backtesting.Backtest.plot`.
Expand All @@ -115,13 +117,21 @@ def I(self, # noqa: E743
def init():
self.sma = self.I(ta.SMA, self.data.Close, self.n_sma)
"""
def _format_name(name: str) -> str:
return name.format(*map(_as_str, args),
**dict(zip(kwargs.keys(), map(_as_str, kwargs.values()))))

if name is None:
params = ','.join(filter(None, map(_as_str, chain(args, kwargs.values()))))
func_name = _as_str(func)
name = (f'{func_name}({params})' if params else f'{func_name}')
elif isinstance(name, str):
name = _format_name(name)
elif try_(lambda: all(isinstance(item, str) for item in name), False):
name = [_format_name(item) for item in name]
else:
name = name.format(*map(_as_str, args),
**dict(zip(kwargs.keys(), map(_as_str, kwargs.values()))))
raise TypeError(f'Unexpected `name=` type {type(name)}; expected `str` or '
'`Sequence[str]`')

try:
value = func(*args, **kwargs)
Expand All @@ -139,6 +149,11 @@ def init():
if is_arraylike and np.argmax(value.shape) == 0:
value = value.T

if isinstance(name, list) and (np.atleast_2d(value).shape[0] != len(name)):
raise ValueError(
f'Length of `name=` ({len(name)}) must agree with the number '
f'of arrays the indicator returns ({value.shape[0]}).')

if not is_arraylike or not 1 <= value.ndim <= 2 or value.shape[-1] != len(self._data.Close):
raise ValueError(
'Indicators must return (optionally a tuple of) numpy.arrays of same '
Expand Down
31 changes: 31 additions & 0 deletions backtesting/test/_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,37 @@ def test_resample(self):
# Give browser time to open before tempfile is removed
time.sleep(1)

def test_indicator_name(self):
test_self = self

class S(Strategy):
def init(self):
def _SMA():
return SMA(self.data.Close, 5), SMA(self.data.Close, 10)

test_self.assertRaises(TypeError, self.I, _SMA, name=42)
test_self.assertRaises(ValueError, self.I, _SMA, name=("SMA One", ))
test_self.assertRaises(
ValueError, self.I, _SMA, name=("SMA One", "SMA Two", "SMA Three"))

for overlay in (True, False):
self.I(SMA, self.data.Close, 5, overlay=overlay)
self.I(SMA, self.data.Close, 5, name="My SMA", overlay=overlay)
self.I(SMA, self.data.Close, 5, name=("My SMA", ), overlay=overlay)
self.I(_SMA, overlay=overlay)
self.I(_SMA, name="My SMA", overlay=overlay)
self.I(_SMA, name=("SMA One", "SMA Two"), overlay=overlay)

def next(self):
pass

bt = Backtest(GOOG, S)
bt.run()
with _tempfile() as f:
bt.plot(filename=f,
plot_drawdown=False, plot_equity=False, plot_pl=False, plot_volume=False,
open_browser=False)

def test_indicator_color(self):
class S(Strategy):
def init(self):
Expand Down

0 comments on commit 172767c

Please sign in to comment.