Spaces:
Sleeping
Sleeping
import argparse | |
import yfinance as yf | |
import numpy as np | |
import pandas as pd | |
import torch | |
import torch.nn as nn | |
import matplotlib.pyplot as plt | |
from sklearn.preprocessing import StandardScaler | |
from datetime import datetime, timedelta | |
import os | |
# 用法示範 (usage example):
# python3 app_Time_tcts.py --ticker AAPL --days 10 --period 6mo --cutoff 2025-03-15 --compare real
# ==== 簡化的 TCTS-like 結構:GRU + 動態加權(模擬效果用) ==== | |
class TCTSModel(nn.Module):
    """Simplified TCTS-like regressor: GRU encoder followed by attention pooling.

    The GRU encodes every time step. A small MLP scores each step, and the
    scores are softmax-normalised over the time axis. The weighted sum of the
    GRU outputs is then projected to ``output_dim`` values.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: GRU hidden size (also the width of the scoring MLP).
        output_dim: Number of outputs (used as the forecast horizon by main).
    """

    def __init__(self, input_dim, hidden_dim=64, output_dim=5):
        super().__init__()
        self.rnn = nn.GRU(input_dim, hidden_dim, batch_first=True)
        # Same layer order as before, so state_dict keys stay attn.0.* / attn.2.*.
        scorer_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Softmax(dim=1),  # normalise the per-step scores across time steps
        ]
        self.attn = nn.Sequential(*scorer_layers)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map a (batch, time, features) tensor to (batch, output_dim) outputs."""
        hidden_seq = self.rnn(x)[0]           # (batch, time, hidden)
        step_weights = self.attn(hidden_seq)  # (batch, time, 1); sums to 1 over time
        pooled = (step_weights * hidden_seq).sum(dim=1)
        return self.fc(pooled)
# ==== 資料處理 ==== | |
def fetch_data(ticker, period="3mo"):
    """Download daily OHLCV bars for a single *ticker* via yfinance.

    Args:
        ticker: Symbol understood by ``yf.download`` (e.g. "AAPL").
        period: yfinance look-back period string (e.g. "3mo", "6mo").

    Returns:
        DataFrame with flat columns Open, High, Low, Close, Volume.
        Rows containing NaN are dropped. The index is timezone-naive, so it
        can be compared with naive datetimes such as the CLI cutoff.

    Raises:
        ValueError: if yfinance returns no rows (unknown ticker, bad period,
            network failure reported as an empty frame).
    """
    raw = yf.download(ticker, period=period)
    if raw is None or raw.empty:
        raise ValueError(f"No price data returned for ticker {ticker!r} (period={period!r})")
    # Newer yfinance releases return (field, ticker) MultiIndex columns even for
    # one ticker. Keep only the field level so df['Close'] is a Series rather
    # than a one-column DataFrame. (Assumes the default group_by='column' layout.)
    if isinstance(raw.columns, pd.MultiIndex):
        raw.columns = raw.columns.get_level_values(0)
    df = raw[['Open', 'High', 'Low', 'Close', 'Volume']].dropna()
    # Strip any timezone so comparisons against naive datetimes work.
    df.index = df.index.tz_localize(None)
    return df
def prepare_data(df, window_size=10, forecast_days=5):
    """Build sliding-window samples from a row-ordered price DataFrame.

    Each sample pairs two things:
    - the input: ``window_size`` consecutive rows of *all* columns of *df*;
    - the target: the ``Close`` values of the ``forecast_days`` rows that follow.

    Args:
        df: DataFrame with a ``Close`` column. If ``df['Close']`` yields a
            DataFrame (e.g. MultiIndex columns), its first column is used.
        window_size: Number of input rows per sample (must be >= 1).
        forecast_days: Number of target rows per sample (must be >= 1).

    Returns:
        ``(X, Y)`` with shapes ``(n, window_size, n_columns)`` and
        ``(n, forecast_days)``, where
        ``n = max(len(df) - window_size - forecast_days + 1, 0)``.
        Too-short input yields zero-length arrays with those ranks.

    Raises:
        ValueError: if window_size or forecast_days is < 1.
    """
    if window_size < 1 or forecast_days < 1:
        raise ValueError(
            f"window_size and forecast_days must be >= 1, got {window_size} and {forecast_days}"
        )
    values = df.to_numpy()
    close = np.asarray(df['Close'])
    if close.ndim > 1:
        close = close[:, 0]
    # The last valid start index is len(df) - window_size - forecast_days: its
    # target ends exactly on the final row. Hence the +1 in the sample count.
    n_samples = max(len(df) - window_size - forecast_days + 1, 0)
    starts = range(n_samples)
    X = np.array([values[i:i + window_size] for i in starts])
    Y = np.array([close[i + window_size:i + window_size + forecast_days] for i in starts])
    # Reshape so that the empty case still has the documented rank.
    return (
        X.reshape(n_samples, window_size, values.shape[1]),
        Y.reshape(n_samples, forecast_days),
    )
# ==== 主流程 ==== | |
def _split_at_cutoff(df_all, cutoff_str):
    """Split *df_all* into (train, held-out) frames around an optional cutoff date.

    Rows strictly before the ``YYYY-MM-DD`` cutoff are for training; rows on or
    after it are held out. An empty *cutoff_str* trains on every row and
    returns an empty held-out frame.
    """
    if not cutoff_str:
        return df_all, df_all.iloc[0:0]
    cutoff = datetime.strptime(cutoff_str, "%Y-%m-%d")
    return df_all[df_all.index < cutoff], df_all[df_all.index >= cutoff]


def _forecast_dates(df_train, df_test, forecast_days):
    """Return the dates that the ``forecast_days`` model outputs refer to."""
    # Training targets are the next `forecast_days` rows (trading days) after a
    # window. So prefer the actual held-out trading dates; otherwise fall back
    # to weekdays after the last training date (exchange holidays not excluded).
    if len(df_test) >= forecast_days:
        return df_test.index[:forecast_days]
    start = df_train.index[-1] + pd.Timedelta(days=1)
    return pd.bdate_range(start=start, periods=forecast_days)


def _plot_forecast(ticker, forecast_days, forecast_dates, forecast, df_test, compare_real):
    """Plot the forecast (optionally against real closes), then show or save it."""
    plt.figure(figsize=(10, 5))
    plt.plot(forecast_dates, forecast, label='TCTS', color='darkred')
    if compare_real and not df_test.empty:
        real_segment = df_test['Close'].iloc[:forecast_days]
        if len(real_segment) == forecast_days:
            plt.plot(real_segment.index, real_segment.values, label='Real', color='black', linestyle='--')
        else:
            print(f"⚠️ 截止日後的真實資料不足 {forecast_days} 天,未繪製真實價格線")
    plt.title(f"{ticker} Forecast for Next {forecast_days} Days (TCTS)")
    plt.xlabel("Date")
    plt.ylabel("Predicted Close Price")
    plt.legend()
    plt.grid(True)
    filename = f"tcts_{ticker.lower()}_forecast.png"
    if "DISPLAY" in os.environ:
        plt.show()
    else:
        plt.savefig(filename)
        print(f"📊 圖已儲存為 {filename}")
    # Release the figure so repeated calls in one process do not accumulate figures.
    plt.close()


def main(ticker, forecast_days, period, cutoff_str, compare_real):
    """Train a TCTS-like model on *ticker* and plot a *forecast_days*-day forecast.

    Args:
        ticker: Symbol passed to fetch_data (e.g. "AAPL").
        forecast_days: Number of future trading days to predict.
        period: yfinance look-back period string (e.g. "3mo", "6mo").
        cutoff_str: Optional "YYYY-MM-DD". Rows before it are used for
            training; rows on/after it are held out as "future" data.
            An empty string trains on everything.
        compare_real: If True and the held-out data covers the whole horizon,
            overlay the real close prices on the plot.

    Raises:
        ValueError: if cutoff_str is malformed, or the training split is too
            short to build a single (window, horizon) sample.
    """
    window_size = 10  # input window length; shared by training samples and inference
    print(f"📈 預測 {ticker} 未來 {forecast_days} 天股價(TCTS 模型)")
    df_all = fetch_data(ticker, period)
    # Split into training data and held-out "future" data.
    df_train, df_test = _split_at_cutoff(df_all, cutoff_str)

    X, Y = prepare_data(df_train, window_size=window_size, forecast_days=forecast_days)
    if len(X) == 0:
        raise ValueError(
            f"Not enough training rows ({len(df_train)}) for a {window_size}-day window "
            f"plus a {forecast_days}-day horizon; use a longer --period or a later --cutoff."
        )

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X.reshape(-1, X.shape[-1])).reshape(X.shape)
    X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
    # Use an explicit reshape, not squeeze(). squeeze() also drops the horizon
    # axis when forecast_days == 1 (or the batch axis with one sample), and
    # MSELoss would then silently broadcast [N, 1] against [N].
    Y_flat = np.asarray(Y, dtype=np.float64).reshape(len(Y), forecast_days)
    # Standardize the targets as well. Raw prices (often hundreds) are far from
    # the network's initial output scale, and 200 Adam steps at lr=1e-3 cannot
    # close that gap. Predictions are mapped back below.
    y_mean = Y_flat.mean()
    y_std = Y_flat.std() or 1.0  # guard against a constant price series
    Y_tensor = torch.tensor((Y_flat - y_mean) / y_std, dtype=torch.float32)

    model = TCTSModel(input_dim=X.shape[2], output_dim=forecast_days)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()
    print("🧠 訓練中...")
    model.train()
    for epoch in range(200):
        pred = model(X_tensor)
        loss = loss_fn(pred, Y_tensor)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % 50 == 0:
            print(f"Epoch {epoch} | Loss: {loss.item():.4f}")

    # Forecast from the most recent training window.
    latest = df_train.iloc[-window_size:].to_numpy()
    latest_scaled = scaler.transform(latest).reshape(1, window_size, -1)
    latest_tensor = torch.tensor(latest_scaled, dtype=torch.float32)
    model.eval()
    with torch.no_grad():
        forecast = model(latest_tensor).numpy()[0] * y_std + y_mean

    forecast_dates = _forecast_dates(df_train, df_test, forecast_days)
    _plot_forecast(ticker, forecast_days, forecast_dates, forecast, df_test, compare_real)
# ==== CLI ==== | |
if __name__ == "__main__":
    # Command-line entry point; arguments are validated here so that bad input
    # gets a usage message instead of a deep traceback from training code.
    parser = argparse.ArgumentParser(description="TCTS stock price forecast with cutoff support")
    parser.add_argument('--ticker', type=str, default='TSLA')
    parser.add_argument('--days', type=int, default=5)
    parser.add_argument('--period', type=str, default='3mo')
    parser.add_argument('--cutoff', type=str, default='', help='模擬預測的起始日,如 2025-03-15')
    parser.add_argument('--compare', type=str, default='', help='輸入 "real" 顯示真實價格線')
    args = parser.parse_args()
    if args.days < 1:
        parser.error("--days must be a positive integer")
    if args.cutoff:
        try:
            datetime.strptime(args.cutoff, "%Y-%m-%d")
        except ValueError:
            parser.error(f"--cutoff must be in YYYY-MM-DD format, got {args.cutoff!r}")
    compare_real = args.compare.lower() == 'real'
    main(args.ticker, args.days, args.period, args.cutoff, compare_real)