Let's Do Systematic Stock Trading - How to Build an Automated Stock Trading System from Scratch

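We build, little by little, a system that lets a computer buy and sell stocks. Instead of publishing only once everything is finished, each piece is posted as an article as soon as it is written, even if it is incomplete, and is later reorganized into fixed pages with a tidier layout.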

Making Backtrader read KABU+ data, Part 8

Yesterday's article could not get this working. Today's article fixes that so that it runs.

how-to-make-stock-trading-system.dogwood008.com

Note: because of how backtrader_plotting works, on Google Colab the TestStrategyWithLogger class has to be written out to a separate file, test_strategy_with_logger.py, and imported from there.
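One way to do this from a Python cell is to write the source to a file with pathlib and import it with importlib. The following is a minimal sketch under that assumption: the helper name `write_and_import` and the placeholder class body are hypothetical, and in practice MODULE_SOURCE would hold the full strategy code from this article (on Colab the target directory could simply be the working directory).

```python
import importlib
import sys
import tempfile
from pathlib import Path
from types import ModuleType

# Hypothetical placeholder source standing in for the real
# test_strategy_with_logger module (which would contain the bt.Strategy subclass).
MODULE_SOURCE = '''
class TestStrategyWithLogger:
    """Stand-in for the real backtrader strategy class."""
    name = "placeholder"
'''


def write_and_import(module_name: str, source: str, directory: str) -> ModuleType:
    """Write `source` to <directory>/<module_name>.py and import it as a module."""
    path = Path(directory) / f'{module_name}.py'
    path.write_text(source, encoding='utf-8')
    if directory not in sys.path:
        sys.path.insert(0, directory)  # make the directory importable
    importlib.invalidate_caches()      # let the import system notice the new file
    if module_name in sys.modules:
        # Re-running the cell: reload so edits to the file take effect.
        return importlib.reload(sys.modules[module_name])
    return importlib.import_module(module_name)


workdir = tempfile.mkdtemp()
module = write_and_import('test_strategy_with_logger', MODULE_SOURCE, workdir)
TestStrategyWithLogger = module.TestStrategyWithLogger
print(TestStrategyWithLogger.name)  # → placeholder
```

The reload branch matters in a notebook: a plain `import` of an already-imported module returns the cached version, so edits to the file would otherwise be ignored until the runtime restarts.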

!pip install backtrader backtrader_plotting
path_to_csv = '/content/drive/MyDrive/Project/kabu-plus/japan-stock-prices-2_2020_9143_adjc.csv'
import pandas as pd
import backtrader as bt
df = pd.read_csv(path_to_csv)  # quick look at the raw data (named df so it is not overwritten by `import csv` below)
#############################################################
#    Copyright (C) 2020 dogwood008 (original author: Daniel Rodriguez; https://github.com/mementum/backtrader)
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <https://www.gnu.org/licenses/>.
#############################################################

import csv
import itertools
import io
from datetime import date, datetime
from backtrader.utils import date2num
from typing import Any

class KabuPlusJPCSVData(bt.feeds.YahooFinanceCSVData):
    '''
    Parses pre-downloaded KABU+ CSV Data Feeds (or locally generated ones if
    they comply with the Yahoo format)
    Specific parameters:
      - ``dataname``: The filename to parse or a file-like object
      - ``reverse`` (default: ``True``)
        It is assumed that locally stored files have already been reversed
        during the download process
      - ``round`` (default: ``True``)
        Whether to round the values to a specific number of decimals after
        having adjusted the close
      - ``roundvolume`` (default: ``0``)
        Round the resulting volume to the given number of decimals after having
        adjusted it
      - ``decimals`` (default: ``2``)
        Number of decimals to round to
      - ``swapcloses`` (default: ``False``)
        [2018-11-16] It would seem that the order of *close* and *adjusted
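        close* is now fixed. The parameter is retained in case the need to
        swap the columns arises again.
      - ``headers`` (default: ``True``)
        Whether the first line of the file is a header row; it is parsed to
        locate the columns
      - ``header_names``
        Dictionary mapping the internal keys to the column names in the CSV
        header, so columns are looked up by name rather than by position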
    '''
    DATE = 'date'
    OPEN = 'open'
    HIGH = 'high'
    LOW = 'low'
    CLOSE = 'close'
    VOLUME = 'volume'
    ADJUSTED_CLOSE = 'adjusted_close'
    
    params = (
        ('reverse', True),
        ('round', True),
        ('decimals', 2),
        ('roundvolume', False),
        ('swapcloses', False),
        ('headers', True),
        ('header_names', {  # dictionary mapping internal keys to the CSV column names
            DATE: 'date',
            OPEN: 'open',
            HIGH: 'high',
            LOW: 'low',
            CLOSE: 'close',
            VOLUME: 'volumes',
            ADJUSTED_CLOSE: 'adj_close',
        })
    )

    def _fetch_value(self, values: dict, column_name: str) -> Any:
        '''
        Fetch a value by its CSV column name, using the conversion
        dictionary given in the params (header_names).
        '''
        index = self._column_index(self.p.header_names[column_name])
        return values[index]

        
    def _column_index(self, column_name: str) -> int:
        '''
        Return the index of the given column name.
        Raises ValueError if the column is not found.
        '''
        return self._csv_headers.index(column_name)

    # copied from https://github.com/mementum/backtrader/blob/0426c777b0abdfafbb0988f5c31347553256a2de/backtrader/feed.py#L666-L679
    def start(self):
        super(bt.feed.CSVDataBase, self).start()

        if self.f is None:
            if hasattr(self.p.dataname, 'readline'):
                self.f = self.p.dataname
            else:
                # Let an exception propagate to let the caller know
                self.f = io.open(self.p.dataname, 'r')

        if self.p.headers and self.p.header_names:
            _csv_reader = csv.reader([self.f.readline()])
            self._csv_headers = next(_csv_reader)

        self.separator = self.p.separator


    def _loadline(self, linetokens):
        while True:
            nullseen = False
            for tok in linetokens[1:]:
                if tok == 'null':
                    nullseen = True
                    linetokens = self._getnextline()  # refetch tokens
                    if not linetokens:
                        return False  # cannot fetch, go away

                    # out of for to carry on with while True logic
                    break

            if not nullseen:
                break  # can proceed

        dttxt = self._fetch_value(linetokens, self.DATE)
        dt = date(int(dttxt[0:4]), int(dttxt[5:7]), int(dttxt[8:10]))
        dtnum = date2num(datetime.combine(dt, self.p.sessionend))

        self.lines.datetime[0] = dtnum
        o = float(self._fetch_value(linetokens, self.OPEN))
        h = float(self._fetch_value(linetokens, self.HIGH))
        l = float(self._fetch_value(linetokens, self.LOW))
        rawc = float(self._fetch_value(linetokens, self.CLOSE))
        self.lines.openinterest[0] = 0.0

        adjustedclose = float(self._fetch_value(linetokens, self.ADJUSTED_CLOSE))
        v = float(self._fetch_value(linetokens, self.VOLUME))

        if self.p.swapcloses:  # swap closing prices if requested
            rawc, adjustedclose = adjustedclose, rawc

        adjfactor = rawc / adjustedclose

        o /= adjfactor
        h /= adjfactor
        l /= adjfactor
        v *= adjfactor

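        if self.p.round:
            decimals = self.p.decimals
            o = round(o, decimals)
            h = round(h, decimals)
            l = round(l, decimals)
            adjustedclose = round(adjustedclose, decimals)  # rawc is not stored, so round the value actually used

        v = round(v, self.p.roundvolume)

        self.lines.open[0] = o
        self.lines.high[0] = h
        self.lines.low[0] = l
        self.lines.close[0] = adjustedclose  # the close line carries the adjusted close
        self.lines.volume[0] = v
        self.lines.adjclose[0] = adjustedclose  # fill the adjclose line referenced by the strategy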

        return True
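To follow the per-row logic of `_loadline` above outside backtrader, here is a standalone sketch: it looks up columns by header name (as `_fetch_value` and `_column_index` do), parses the date by fixed-position slicing, and rescales open/high/low/volume with the same `adjfactor = rawc / adjustedclose` ratio. The `parse_row` function is hypothetical, the sample row is made-up data, and backtrader's `date2num` conversion and `lines` assignments are deliberately left out.

```python
import csv
import io
from datetime import date

# Internal key -> CSV column name, mirroring the header_names param above.
HEADER_NAMES = {
    'date': 'date', 'open': 'open', 'high': 'high', 'low': 'low',
    'close': 'close', 'volume': 'volumes', 'adjusted_close': 'adj_close',
}

# Made-up sample data for illustration only (not real KABU+ prices).
SAMPLE_CSV = 'date,open,high,low,close,volumes,adj_close\n2020-06-01,1000,1100,900,1000,300,500\n'


def parse_row(headers, tokens, decimals=2):
    """Return (date, open, high, low, close, volume), adjusted as in _loadline."""
    def fetch(key):
        # list.index raises ValueError for a missing column, like _column_index
        return tokens[headers.index(HEADER_NAMES[key])]

    dttxt = fetch('date')  # 'YYYY-MM-DD' (or 'YYYY/MM/DD'): fixed-position slices
    dt = date(int(dttxt[0:4]), int(dttxt[5:7]), int(dttxt[8:10]))
    o, h, l = (float(fetch(k)) for k in ('open', 'high', 'low'))
    rawc = float(fetch('close'))
    adjc = float(fetch('adjusted_close'))
    v = float(fetch('volume'))

    adjfactor = rawc / adjc  # e.g. 2.0 for a bar before a 1-to-2 stock split
    o, h, l = (round(x / adjfactor, decimals) for x in (o, h, l))
    v = round(v * adjfactor, 0)  # volume is scaled the opposite way
    return dt, o, h, l, round(adjc, decimals), v


reader = csv.reader(io.StringIO(SAMPLE_CSV))
headers = next(reader)
print(parse_row(headers, next(reader)))
# → (datetime.date(2020, 6, 1), 500.0, 550.0, 450.0, 500.0, 600.0)
```

On this made-up row the raw close is twice the adjusted close, so open/high/low are halved and the volume is doubled, while the close receives the adjusted close, as in the class above.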
import backtrader as bt
from logging import getLogger, StreamHandler, Formatter, DEBUG, INFO, WARN

from backtrader_plotting import Bokeh
from backtrader_plotting.schemes import Tradimo

# Create a Strategy
class TestStrategyWithLogger(bt.Strategy):
    params = (
        ('size', 1000),
        ('smaperiod', 5),
    )

    def __init__(self, loglevel):
        # Keep a reference to the "close" line in the data[0] dataseries
        self._dataclose = self.datas[0].close
        self._dataadjclose = self.datas[0].adjclose
        self._datavolume = self.datas[0].volume
        self._logger = getLogger(__name__)
        self.handler = StreamHandler()
        self.handler.setLevel(loglevel)
        self._logger.setLevel(loglevel)
        self._logger.addHandler(self.handler)
        self._logger.propagate = False
        self.handler.setFormatter(
                Formatter('[%(levelname)s] %(message)s'))
        self.sma = bt.indicators.SimpleMovingAverage(
            self.datas[0], period=self.params.smaperiod)                

    def _log(self, txt, loglevel=INFO, dt=None):
        ''' Logging function for this strategy '''
        dt = dt or self.datas[0].datetime.date(0)
        self._logger.log(loglevel, '%s, %s' % (dt.isoformat(), txt))

    def _debug(self, txt, dt=None):
        self._log(txt, DEBUG, dt)

    def _info(self, txt, dt=None):
        self._log(txt, INFO, dt)

    def next(self):
        # Simply log the closing price of the series from the reference
        # self._log('(Close, Adj. Close, Volume) = ({close:>5.2f}, {adjc:>5.2f}, {vol:>010.2f})'.format(
        #           close=self._dataclose[0], adjc=self._dataadjclose[0], vol=self._datavolume[0]))

        if self._dataclose[0] < self._dataclose[-1]:
            # current close less than previous close

            if self._dataclose[-1] < self._dataclose[-2]:
                # previous close less than the close before it

                # BUY, BUY, BUY!!! (order size taken from params.size)
                self._info('BUY CREATE, %.2f' % self._dataclose[0])
                self.buy(size=self.params.size)

            if self._dataclose[-2] * .97 > self._dataclose[-1]:
                # Close the position when the previous close fell more than 3% below the close of the day before
                self.close()
# ↑ Save the code up to this line in test_strategy_with_logger.py (to work around a backtrader_plotting error)
del TestStrategyWithLogger
from test_strategy_with_logger import TestStrategyWithLogger
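The entry and exit rules in `next()` can be checked without backtrader by replaying them over a plain list of closing prices. The sketch below is a simplified, hypothetical re-implementation (`signals` is not part of backtrader): index `i` plays the role of `[0]`, `i - 1` of `[-1]` and `i - 2` of `[-2]`. It only reports which rules fire on each bar and ignores order execution, position size, and the SMA warm-up period.

```python
from typing import List, Tuple


def signals(closes: List[float]) -> List[Tuple[int, str]]:
    """Return (bar index, 'BUY' or 'CLOSE') pairs following next()'s rules."""
    events = []
    for i in range(2, len(closes)):
        c0, c1, c2 = closes[i], closes[i - 1], closes[i - 2]  # [0], [-1], [-2]
        if c0 < c1:  # today's close is below yesterday's
            if c1 < c2:  # ...and yesterday's was below the day before: buy
                events.append((i, 'BUY'))
            if c2 * .97 > c1:  # yesterday fell more than 3% vs. the day before
                events.append((i, 'CLOSE'))
    return events


print(signals([100, 99, 98, 97, 100, 95, 90]))
# → [(2, 'BUY'), (3, 'BUY'), (6, 'BUY'), (6, 'CLOSE')]
```

As in the strategy, the 3% exit check sits inside the "today's close is lower" branch, and both rules can fire on the same bar (index 6 here), in which case `self.buy()` and `self.close()` would both be called on that bar.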
class BackTest:
    def __init__(self):
        self.cerebro = bt.Cerebro()

        data = KabuPlusJPCSVData(
            dataname=path_to_csv,
            fromdate=datetime(2020, 1, 1),
            todate=datetime(2020, 11, 30),
            reverse=False)

        self.cerebro.adddata(data)
        # Add a strategy
        IN_DEVELOPMENT = True  # This flag switches the log level so that only WARN and above is logged in production.
        # Ideally the flag would be switched via an environment variable, but that is postponed for now.
        loglevel = DEBUG if IN_DEVELOPMENT else WARN
        self.cerebro.broker.setcash(1000 * 10000)
        self.cerebro.addstrategy(TestStrategyWithLogger, loglevel)

    def run(self):
        initial_cash = self.cerebro.broker.getvalue()
        self.cerebro.run()

        print('Initial Portfolio Value: {val:,}'.format(val=initial_cash))
        print('Final Portfolio Value: {val:,}'.format(val=int(self.cerebro.broker.getvalue())))

        save_file = False
        if save_file:
            b = Bokeh(style='bar', plot_mode='single', scheme=Tradimo(), output_mode='save', filename='chart.html')
        else:
            b = Bokeh(style='bar', plot_mode='single', scheme=Tradimo())
        return self.cerebro.plot(b, iplot=not save_file)

if __name__ == '__main__':
    backtest = BackTest()
    chart = backtest.run()
    from IPython.display import display
    display(chart[0][0])
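The comments in `BackTest.__init__` say that IN_DEVELOPMENT would ideally be switched through an environment variable. Below is a minimal sketch of that idea. Assumptions made for this sketch: the variable name `IN_DEVELOPMENT` and the accepted values (`1`, `true`, `yes`), and the hypothetical helpers `loglevel_from_env` and `get_strategy_logger`, which use only the standard `os` and `logging` modules. The second helper also addresses a notebook pitfall in the current strategy: `TestStrategyWithLogger.__init__` calls `addHandler` on the same named logger every time the strategy is instantiated, so each re-run of the backtest in one Colab session adds another handler and every log line is printed one more time.

```python
import logging
import os


def loglevel_from_env(var: str = 'IN_DEVELOPMENT') -> int:
    """DEBUG when the (hypothetical) env var is truthy, otherwise WARNING."""
    value = os.environ.get(var, '').strip().lower()
    return logging.DEBUG if value in ('1', 'true', 'yes') else logging.WARNING


def get_strategy_logger(name: str, level: int) -> logging.Logger:
    """Return a configured logger, adding a StreamHandler only once per name."""
    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = False
    if not logger.handlers:  # re-running a notebook cell won't stack handlers
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter('[%(levelname)s] %(message)s'))
        logger.addHandler(handler)
    for handler in logger.handlers:
        handler.setLevel(level)
    return logger


os.environ['IN_DEVELOPMENT'] = '1'  # for demonstration; normally set outside the code
level = loglevel_from_env()
logger = get_strategy_logger('test_strategy_with_logger', level)
logger = get_strategy_logger('test_strategy_with_logger', level)  # second call, no new handler
print(logging.getLevelName(level), len(logger.handlers))  # → DEBUG 1
```

Inside the strategy, `self._logger = get_strategy_logger(__name__, loglevel)` could then stand in for the manual handler setup, and BackTest could pass `loglevel_from_env()` instead of deriving the level from a hard-coded flag.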

Output

[Figure: three screenshots of the output]

(C) 2020 dogwood008. Unauthorized reprinting and reproduction are prohibited.