stock-qlib-public / app_GBDT_catboost.py
tbdavid2019's picture
Add application file
d7d2bc1
import argparse
import yfinance as yf
import pandas as pd
import numpy as np
from catboost import CatBoostRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import os
def fetch_data(ticker, period="3mo"):
    """Download OHLCV bars for *ticker* via ``yf.download``.

    Args:
        ticker: Symbol passed to yfinance, e.g. "TSLA" or "AAPL".
        period: Look-back period forwarded unchanged to ``yf.download``
            (default "3mo").

    Returns:
        A DataFrame restricted to the Open, High, Low, Close and Volume
        columns (in that order), with every row that contains a NaN dropped.
        For a single symbol the column labels are flat field names, so
        ``df['Close']`` is a Series.
    """
    raw = yf.download(ticker, period=period)
    df = raw[['Open', 'High', 'Low', 'Close', 'Volume']].dropna()
    # Some yfinance versions return (field, ticker) MultiIndex columns even for
    # one symbol. Keep only the field level in that case; leave genuinely
    # multi-ticker frames alone so column names stay unique.
    if isinstance(df.columns, pd.MultiIndex) and df.columns.droplevel(0).nunique() <= 1:
        df.columns = df.columns.get_level_values(0)
    return df
def prepare_training_data(df, window_size=5, forecast_days=5):
    """Turn a price table into supervised ``(X, Y)`` training arrays.

    Each sample takes ``window_size`` consecutive rows of *df* (all columns,
    flattened row by row) as features. Its targets are the next
    ``forecast_days`` values of the ``Close`` column. Every possible window
    is produced, including the one whose targets end on the last row of *df*.

    Args:
        df: DataFrame containing a ``Close`` column; all of its columns are
            used as features.
        window_size: Number of past rows per sample; must be >= 1.
        forecast_days: Number of future Close values per sample; must be >= 1.

    Returns:
        ``(X, Y)``, where ``X`` has shape ``(n_samples, window_size * n_columns)``.
        ``Y`` has shape ``(n_samples, forecast_days * k)``, with ``k`` the
        number of columns selected by ``df['Close']`` (1 for a flat frame).
        ``n_samples`` is ``len(df) - window_size - forecast_days + 1``.
        When *df* is too short, ``n_samples`` is 0 and both arrays are empty
        but correctly shaped.

    Raises:
        ValueError: if ``window_size`` or ``forecast_days`` is less than 1.
    """
    if window_size < 1 or forecast_days < 1:
        raise ValueError("window_size and forecast_days must both be >= 1")

    features = df.to_numpy()
    close = df['Close'].to_numpy()
    if close.ndim == 1:
        # A Series -> one column, so the slicing below is uniform for both
        # flat and MultiIndex-column frames.
        close = close[:, np.newaxis]

    # The last valid start index is len(df) - window_size - forecast_days:
    # its target window ends exactly on the final row, hence the "+ 1".
    n_samples = max(len(df) - window_size - forecast_days + 1, 0)
    if n_samples == 0:
        return (
            np.empty((0, window_size * features.shape[1]), dtype=features.dtype),
            np.empty((0, forecast_days * close.shape[1]), dtype=close.dtype),
        )

    X = np.array([features[i:i + window_size].ravel() for i in range(n_samples)])
    Y = np.array([
        close[i + window_size:i + window_size + forecast_days].ravel()
        for i in range(n_samples)
    ])
    return X, Y
def _plot_forecast(ticker, forecast):
    """Show the forecast curve if DISPLAY is set, otherwise save it as a PNG."""
    forecast_days = len(forecast)
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, forecast_days + 1), forecast, label='Predicted', color='orange')
    plt.title(f"{ticker} Forecast for Next {forecast_days} Days (CatBoost)")
    plt.xlabel("Days Ahead")
    plt.ylabel("Predicted Close Price")
    plt.legend()
    plt.grid(True)
    if "DISPLAY" in os.environ:
        plt.show()
    else:
        filename = f"catboost_{ticker.lower()}_forecast.png"
        plt.savefig(filename)
        print(f"圖已儲存為 {filename}")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()


def main(ticker, forecast_days, window_size=10):
    """Download *ticker*, fit a CatBoost multi-horizon model and plot the forecast.

    Args:
        ticker: Symbol to download and forecast.
        forecast_days: Number of future rows (days) of Close to predict; >= 1.
        window_size: Number of most recent rows used as input features. The
            same value is used to build training samples and the final
            prediction input, so their feature widths always match
            (default 10, the previously hard-coded value).

    Returns:
        The array of predicted Close values, one per day ahead.

    Raises:
        ValueError: if ``forecast_days`` or ``window_size`` is less than 1,
            or if the downloaded history is too short to build a sample.
    """
    if forecast_days < 1 or window_size < 1:
        raise ValueError("forecast_days and window_size must both be >= 1")

    print(f"下載 {ticker} 股票資料...")
    df = fetch_data(ticker)
    # One training sample needs window_size input rows plus forecast_days targets.
    min_rows = window_size + forecast_days
    if len(df) < min_rows:
        raise ValueError(
            f"Not enough data for {ticker}: need at least {min_rows} rows, got {len(df)}"
        )

    print("建立訓練資料...")
    X, Y = prepare_training_data(df, window_size=window_size, forecast_days=forecast_days)
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    print("訓練 CatBoost 模型...")
    # MultiOutputRegressor wraps the CatBoost estimator so each column of Y
    # (each day ahead) gets its own fitted regressor.
    model = MultiOutputRegressor(CatBoostRegressor(
        iterations=500,
        learning_rate=0.05,
        depth=6,
        verbose=100
    ))
    model.fit(X_scaled, Y)

    # Predict from the most recent window, using the same window length as training.
    print(f"使用最新資料預測未來 {forecast_days} 天...")
    latest_input = df.iloc[-window_size:].to_numpy().reshape(1, -1)
    latest_input_scaled = scaler.transform(latest_input)
    forecast = model.predict(latest_input_scaled)[0]

    _plot_forecast(ticker, forecast)
    return forecast
if __name__ == "__main__":
    # Command-line entry point: parse the ticker and horizon, then run the forecast.
    parser = argparse.ArgumentParser(description="Forecast future stock price using CatBoost")
    parser.add_argument('--ticker', type=str, default='TSLA', help='股票代碼,例如 TSLA, AAPL')
    parser.add_argument('--days', type=int, default=5, help='預測未來幾天的收盤價')
    args = parser.parse_args()
    # Reject non-positive horizons up front so the user gets a normal usage
    # error instead of a traceback from deep inside training.
    if args.days < 1:
        parser.error("--days must be a positive integer")
    main(args.ticker, args.days)