使用python搭建一个股票训练程序
python搭建的股票训练系统
·
环境:
win10
python3.75
使用到数据: https://download.csdn.net/download/xy3233/82009700
样式
方向键 → ↑ 进入下一个交易日
数字键1是买 0是卖 (s是独立窗口的保存快捷键,所以没有用s/b)
主要参考代码: https://blog.csdn.net/xy3233/article/details/123083341
完整代码如下
# -*- coding: utf-8 -*-
import datetime
import pandas as pd
import mplfinance as mpf
import numpy as np
import matplotlib.pyplot as plt
# 独立窗口
import matplotlib
matplotlib.use('tkagg')
# 自定义风格和颜色
# 设置mplfinance的蜡烛颜色,up为阳线颜色,down为阴线颜色
my_color = mpf.make_marketcolors(up='r', # 上涨K线的柱子的内部填充色
down='g', # 下跌K线的柱子的内部填充色
edge='inherit', # 边框设置“inherit”代表使用主配色 不设置则为黑色
wick='inherit', # wick设置的就是上下影线的颜色
volume='inherit') # volume设置的是交易量柱子的颜色
# 设置图表的背景色
my_style = mpf.make_mpf_style(marketcolors=my_color,
figcolor='(0.82, 0.83, 0.85)',
gridcolor='(0.82, 0.83, 0.85)')
# 标题格式,字体为中文字体,颜色为黑色,粗体,水平中心对齐
title_font = {'fontname': 'STZhongsong',
'size': '16',
'color': 'black',
'weight': 'bold',
'va': 'bottom',
'ha': 'center'}
# 标签格式,可以显示中文,普通黑色12号字
normal_label_font = {'fontname': 'STZhongsong',
'size': '12',
'color': 'black',
'va': 'bottom',
'ha': 'right'}
def get_stock_data(file_path):
'''
数据来源
:param file_path:
:return:
'''
data = pd.read_csv(file_path, index_col=0)
data['open'] = data['open_price']
data['high'] = data['high_price']
data['low'] = data['low_price']
data['close'] = data['close_price']
data['volume'] = data['deal_quantity']
data['change'] = data['change_rate']
data.index = pd.to_datetime(data['date'], format='%Y-%m-%d')
data.rename(index=pd.Timestamp)
return data
class StockSB:
def __init__(self, stock_data, money=10000):
self.all_data = pd.DataFrame(stock_data)
# DataFrame 初始化 添加相应列
# 添加列 original(投入本金) quantity(持有数量) cost_price(成本价) stock_value(市值) usable_money(可用金额)
self.all_data.insert(self.all_data.shape[1], 'original', money, allow_duplicates=False)
self.all_data.insert(self.all_data.shape[1], 'quantity', 0, allow_duplicates=False)
self.all_data.insert(self.all_data.shape[1], 'cost_price', 0, allow_duplicates=False)
# self.all_data.insert(self.all_data.shape[1], 'stock_value', 0, allow_duplicates=False)
self.all_data.insert(self.all_data.shape[1], 'usable_money', money, allow_duplicates=False)
self.all_data.insert(self.all_data.shape[1], 'b/s', 0, allow_duplicates=False) # 买1 卖 -1
self.start = 0
self.len = 50
self.plot_data = self.all_data.iloc[self.start:self.start + self.len]
# 初始化 今日数据
self.current_data = self.plot_data.iloc[-1]
# 添加三个图表,四个数字分别代表图表左下角在figure中的坐标,以及图表的宽(0.88)、高(0.60)
self.fig = mpf.figure(figsize=(12, 8), facecolor=(0.82, 0.83, 0.85))
# 添加三个图表,四个数字分别代表图表左下角在figure中的坐标,以及图表的宽(0.88)、高(0.60)
self.price_axe = self.fig.add_axes([0.06, 0.25, 0.88, 0.60]) # 添加价格图表 K线图
# 添加第二、三张图表时,使用sharex关键字指明与ax1在x轴上对齐,且共用x轴
self.volume_axe = self.fig.add_axes([0.06, 0.15, 0.88, 0.10], sharex=self.price_axe) # 添加成交量
self.macd_axe = self.fig.add_axes([0.06, 0.05, 0.88, 0.10], sharex=self.price_axe) # 添加macd
# 设置三张图表的Y轴标签
self.price_axe.set_ylabel('price')
self.volume_axe.set_ylabel('volume')
self.macd_axe.set_ylabel('macd')
# 标题等文本
# 初始化figure对象,在figure上预先放置文本并设置格式,文本内容根据需要显示的数据实时更新
self.t1 = self.fig.text(0.50, 0.95, '000001.SH - 平安保险', **title_font)
self.t2 = self.fig.text(0.10, 0.90, '开/收: ', **normal_label_font)
self.t2_1 = self.fig.text(0.20, 0.90, f'', **normal_label_font)
self.t3 = self.fig.text(0.25, 0.90, '高: ', **normal_label_font)
self.t3_1 = self.fig.text(0.30, 0.90, f'', **normal_label_font)
self.t4 = self.fig.text(0.35, 0.90, '低: ', **normal_label_font)
self.t4_1 = self.fig.text(0.40, 0.90, f'', **normal_label_font)
self.t5 = self.fig.text(0.50, 0.90, '量(万手): ', **normal_label_font)
self.t5_1 = self.fig.text(0.55, 0.90, f'', **normal_label_font)
self.t6 = self.fig.text(0.65, 0.90, '当前时间: ', **normal_label_font)
self.t6_1 = self.fig.text(0.75, 0.90, f'', **normal_label_font)
self.t7 = self.fig.text(0.09, 0.87, f'本金', **normal_label_font)
self.t7_1 = self.fig.text(0.20, 0.87, f' ', **normal_label_font)
self.t8 = self.fig.text(0.25, 0.87, f'成本', **normal_label_font)
self.t8_1 = self.fig.text(0.30, 0.87, f' ', **normal_label_font)
self.t9 = self.fig.text(0.35, 0.87, f'总手', **normal_label_font)
self.t9_1 = self.fig.text(0.40, 0.87, f' ', **normal_label_font)
self.t10 = self.fig.text(0.50, 0.87, f'利润', **normal_label_font)
self.t10_1 = self.fig.text(0.55, 0.87, f' ', **normal_label_font)
self.t11 = self.fig.text(0.65, 0.87, f'可用金额', **normal_label_font)
self.t11_1 = self.fig.text(0.75, 0.87, f' ', **normal_label_font)
self.t12 = self.fig.text(0.85, 0.87, f'市值', **normal_label_font)
self.t12_1 = self.fig.text(0.95, 0.87, f' ', **normal_label_font)
self.fig.canvas.mpl_connect('key_press_event', self.on_key_press)
def save_stock_result(self, file_path=None):
# 保存数据
result_path = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
self.all_data.to_csv(result_path + '.csv')
def get_plot_data(self):
return self.plot_data
def get_current_data(self):
return self.current_data
def next_day(self):
print(self.current_data)
# 下一天
if self.start + self.len < self.all_data.shape[0]:
self.start = self.start + 1
self._refresh_data()
print(self.current_data)
def last_day(self):
# 前一天
if self.start > 1:
self.start = self.start - 1
self._refresh_data()
def _refresh_data(self):
# 刷新数据
self.plot_data = self.all_data.iloc[self.start:self.start + self.len]
self.current_data = self.plot_data.iloc[-1]
plot_data = self.plot_data
# 刷新图
# 读取显示区间最后一个交易日的数据
last_data = self.current_data
# 将这些数据分别填入figure对象上的文本中
self.t2_1.set_text(f'{np.round(last_data["open"], 3)} / {np.round(last_data["close"], 3)}')
self.t3_1.set_text(f'{last_data["high"]}')
self.t4_1.set_text(f'{last_data["low"]}')
self.t5_1.set_text(f'{np.round(last_data["volume"] / 10000, 3)}')
self.t6_1.set_text(f'{last_data["date"]}')
self.t7_1.set_text(f'{last_data["original"]}')
self.t8_1.set_text(f'{last_data["cost_price"]}')
self.t9_1.set_text(f'{last_data["quantity"]}')
profit = last_data['usable_money'] + last_data['quantity'] * last_data['close'] - self.current_data['original']
# 利润
self.t10_1.set_text(str(round(profit, 2)))
# 可用金额
self.t11_1.set_text(f'{round(last_data["usable_money"],2)}')
self.t12_1.set_text(str(round(last_data['quantity'] * last_data['close'], 2)))
# 生成一个空列表用于存储多个addplot
ap = []
# 添加均线
ap.append(mpf.make_addplot(plot_data[['ma_5', 'ma_20', 'ma_60']], ax=self.price_axe))
# 添加 diff 和 dea
ap.append(mpf.make_addplot(plot_data[['diff']], color='black', ax=self.macd_axe))
ap.append(mpf.make_addplot(plot_data[['dea']], color='orange', ax=self.macd_axe))
# 添加macd
ap.append(mpf.make_addplot(plot_data[['macd']], type='bar', color='green', ax=self.macd_axe))
# 调用mpf.plot()函数,这里需要指定ax=price_axe,volume=ax2,将K线图显示在ax1中,交易量显示在ax2中
mpf.plot(plot_data, ax=self.price_axe, addplot=ap, volume=self.volume_axe, type='candle', style=my_style,
xrotation=0)
mpf.show()
# self.show_stock()
def buy_stock(self, part=10):
# 买
last_stock_data = self.plot_data.iloc[-2]
# 可用金额
last_usable_money = last_stock_data['usable_money'] * part / 10
# 可购买数量 = 可用金额 / (收盘价*100) 取整
buy_quantity = (last_usable_money // (self.current_data['close'] * 100)) * 100
if buy_quantity < 100:
return
pd_index = self.current_data['date']
# 整体数量
quantity = last_stock_data['quantity'] + buy_quantity
# 今日消费
buy_money = (self.current_data['close'] * buy_quantity) + 5
# 今日成本价 = 原始成本价* 数量 + 今天成本价* 数量 + 5 / 总数量
cost_price = round((last_stock_data['cost_price'] * last_stock_data['quantity'] + buy_money) / quantity, 2)
self.all_data.loc[pd_index:, 'cost_price'] = cost_price
# 当前数量
self.all_data.loc[pd_index:, 'quantity'] = quantity
# 可用金额
self.all_data.loc[pd_index:, 'usable_money'] = last_stock_data['usable_money'] - buy_money
self.all_data.loc[pd_index, 'b/s'] = 1
self.next_day()
def sell_stock(self, part=10):
# 卖
# 数量
temp_quantity = (self.current_data['quantity'] * part / 10) // 100
if temp_quantity < 1:
return
sell_quantity = temp_quantity * 100
pd_index = self.current_data['date']
quantity = self.current_data['quantity'] - sell_quantity
self.all_data.loc[pd_index:, 'quantity'] = quantity
# 可用金额
self.all_data.loc[pd_index:, 'usable_money'] = self.current_data['usable_money'] + \
sell_quantity * self.current_data['close'] - 5
if quantity == 0:
self.all_data.loc[pd_index:, 'cost_price'] = 0
self.all_data.loc[pd_index, 'b/s'] = -1
self.next_day()
def show_stock(self):
profit = self.current_data['usable_money'] + self.current_data['quantity'] * self.current_data['close']
# 可用金额 + 市值 - 本金
print("利润" + str(profit - self.current_data['original']))
print("可用金额" + str(self.current_data['usable_money']))
def on_key_press(self, event):
key = event.key
print("key: " + key)
if key == 'enter':
# 保存图片
img_path = str(self.start) + '-' + str(self.start + self.len)
plt.savefig(img_path + '.jpg')
return
elif key == 'left' or key == 'down':
self.last_day()
elif key == 'right' or key == 'up':
self.next_day()
elif key == '1':
self.buy_stock()
elif key == '0':
self.sell_stock()
self.price_axe.clear()
self.macd_axe.clear()
self.volume_axe.clear()
if key != '1' or key != '0':
self._refresh_data()
if __name__ == '__main__':
f_path = './../0_data/000001_stock.csv'
data = get_stock_data(f_path)
# 选取我需要的数据
all_data = data[
['open', 'high', 'low', 'close', 'volume', 'ma_5', 'ma_20', 'ma_60', 'macd', 'diff', 'dea', 'date']]
start = all_data.shape[0] - 1000 # 开始序号
len = 500 # 显示长度
train_data = all_data.iloc[start:start + len]
sb = StockSB(train_data)
sb.next_day()
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
已为社区贡献2条内容
所有评论(0)