Spaces:
Sleeping
Sleeping
import argparse | |
import yfinance as yf | |
import pandas as pd | |
import numpy as np | |
from catboost import CatBoostRegressor | |
from sklearn.multioutput import MultiOutputRegressor | |
from sklearn.preprocessing import StandardScaler | |
import matplotlib.pyplot as plt | |
import os | |
def fetch_data(ticker, period="3mo"):
    """Download OHLCV price history for one ticker via yfinance.

    Args:
        ticker: Symbol understood by Yahoo Finance, e.g. "TSLA".
        period: Look-back window passed straight to ``yf.download``
            (default "3mo").

    Returns:
        DataFrame restricted to the ``Open, High, Low, Close, Volume``
        columns, with any row containing a NaN dropped.

    Raises:
        ValueError: if the download is empty or no complete rows remain
            (e.g. unknown ticker or network failure).
    """
    raw = yf.download(ticker, period=period)
    if raw is None or raw.empty:
        raise ValueError(f"No price data returned for ticker {ticker!r}")

    # Recent yfinance releases return (field, ticker) MultiIndex columns even
    # for a single ticker, which makes df['Close'] a one-column DataFrame.
    # Flatten to the field level only when exactly one ticker is present so
    # column names stay unique.
    if isinstance(raw.columns, pd.MultiIndex) and raw.columns.get_level_values(-1).nunique() == 1:
        raw.columns = raw.columns.get_level_values(0)

    df = raw[['Open', 'High', 'Low', 'Close', 'Volume']].dropna()
    if df.empty:
        raise ValueError(f"No complete OHLCV rows for ticker {ticker!r}")
    return df
def prepare_training_data(df, window_size=5, forecast_days=5):
    """Turn a price DataFrame into supervised ``(X, Y)`` training arrays.

    Each sample's features are ``window_size`` consecutive rows of *all*
    columns of ``df``, flattened row by row into one vector. Its targets are
    the ``forecast_days`` ``Close`` values that immediately follow that window.
    Every window whose full target horizon fits inside ``df`` is used,
    including the last one, whose horizon ends on the final row.

    Args:
        df: DataFrame with a ``'Close'`` column; rows presumably in
            chronological order (as downloaded) — the windows rely on it.
        window_size: Number of past rows per sample; must be >= 1.
        forecast_days: Number of future closes per sample; must be >= 1.

    Returns:
        ``(X, Y)``. ``X`` has shape ``(n_samples, window_size * n_columns)``.
        ``Y`` has shape ``(n_samples, forecast_days)``, or wider if
        ``df['Close']`` itself has several columns.

    Raises:
        ValueError: if a size argument is < 1 or ``df`` has fewer than
            ``window_size + forecast_days`` rows.
    """
    if window_size < 1 or forecast_days < 1:
        raise ValueError("window_size and forecast_days must both be >= 1")

    # +1 because the last valid start index is
    # len(df) - window_size - forecast_days: its target slice ends exactly
    # at the final row. The original loop stopped one sample short.
    n_samples = len(df) - window_size - forecast_days + 1
    if n_samples < 1:
        raise ValueError(
            f"need at least {window_size + forecast_days} rows, got {len(df)}"
        )

    close = df['Close']
    X = [df.iloc[i:i + window_size].values.flatten() for i in range(n_samples)]
    Y = [
        close.iloc[i + window_size:i + window_size + forecast_days].values
        for i in range(n_samples)
    ]
    # reshape flattens a possible (forecast_days, 1) target block, which
    # occurs when 'Close' is a one-column DataFrame (MultiIndex download).
    return np.array(X), np.array(Y).reshape(n_samples, -1)
def _build_model():
    """Return an unfitted CatBoost regressor wrapped to predict several targets at once."""
    # MultiOutputRegressor fits one independent CatBoost model per target
    # column, i.e. one per forecast day.
    return MultiOutputRegressor(CatBoostRegressor(
        iterations=500,
        learning_rate=0.05,
        depth=6,
        verbose=100,
    ))


def _plot_forecast(ticker, forecast, forecast_days):
    """Plot predicted closes; show them if DISPLAY is set, otherwise save a PNG."""
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(range(1, forecast_days + 1), forecast, label='Predicted', color='orange')
        ax.set_title(f"{ticker} Forecast for Next {forecast_days} Days (CatBoost)")
        ax.set_xlabel("Days Ahead")
        ax.set_ylabel("Predicted Close Price")
        ax.legend()
        ax.grid(True)
        # Without an X display (e.g. a headless server) there is nowhere to
        # show the window, so write the chart to disk instead.
        if "DISPLAY" in os.environ:
            plt.show()
        else:
            filename = f"catboost_{ticker.lower()}_forecast.png"
            fig.savefig(filename)
            print(f"圖已儲存為 {filename}")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def main(ticker, forecast_days, window_size=10):
    """Download prices, train a CatBoost model, and plot a multi-day close forecast.

    Args:
        ticker: Symbol passed to ``fetch_data``, e.g. "TSLA".
        forecast_days: Number of future closing prices to predict.
        window_size: Number of most recent rows used as model input. The
            same value is used for the training windows and for the final
            prediction input, so the feature layouts always match
            (default 10, the previously hard-coded value).
    """
    print(f"下載 {ticker} 股票資料...")
    df = fetch_data(ticker)

    print("建立訓練資料...")
    X, Y = prepare_training_data(df, window_size=window_size, forecast_days=forecast_days)
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    print("訓練 CatBoost 模型...")
    model = _build_model()
    model.fit(X_scaled, Y)

    # Predict the future from the most recent window, flattened exactly like
    # a training row.
    print(f"使用最新資料預測未來 {forecast_days} 天...")
    latest_input = df.iloc[-window_size:].values.flatten().reshape(1, -1)
    forecast = model.predict(scaler.transform(latest_input))[0]

    _plot_forecast(ticker, forecast, forecast_days)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Forecast future stock price using CatBoost")
    parser.add_argument('--ticker', type=str, default='TSLA', help='股票代碼,例如 TSLA, AAPL')
    parser.add_argument('--days', type=int, default=5, help='預測未來幾天的收盤價')
    args = parser.parse_args()
    # Reject non-positive horizons up front with a usage message instead of
    # letting them fail deep inside data preparation or model fitting.
    if args.days < 1:
        parser.error("--days must be a positive integer")
    main(args.ticker, args.days)