Spaces:
Sleeping
Sleeping
import argparse | |
import yfinance as yf | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import matplotlib.pyplot as plt | |
from sklearn.preprocessing import StandardScaler | |
import os | |
from datetime import datetime, timedelta | |
import pandas as pd | |
# Usage:
# python3 app_Time_sandwich.py --ticker AAPL --days 10 --period 6mo
# python3 app_Time_sandwich.py --ticker AAPL --days 10 --period 6mo --cutoff 2025-03-15 --compare real
# ==== Sandwich model architecture ====
class SandwichModel(nn.Module):
    """Encoder -> LSTM -> decoder regressor mapping a feature window to a multi-step output.

    Two ReLU-activated linear layers project every time step to ``hidden_dim``,
    a single-layer ``batch_first`` LSTM runs over the sequence, and two linear
    layers turn the LSTM output at the last time step into ``output_dim`` values.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Width of the encoder layers, the LSTM and the first decoder layer.
        output_dim: Number of values predicted per sample.
    """

    def __init__(self, input_dim, hidden_dim=64, output_dim=5):
        super().__init__()
        # Sub-modules are created in a fixed order with fixed attribute names:
        # this keeps random initialisation and state_dict keys stable.
        self.encoder1 = nn.Linear(input_dim, hidden_dim)
        self.encoder2 = nn.Linear(hidden_dim, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.decoder1 = nn.Linear(hidden_dim, hidden_dim)
        self.decoder2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, seq, input_dim)`` to ``(batch, output_dim)``."""
        encoded = self.encoder2(self.encoder1(x).relu()).relu()
        sequence_out = self.lstm(encoded)[0]
        # Only the final time step's hidden output feeds the decoder.
        last_step = sequence_out[:, -1, :]
        return self.decoder2(self.decoder1(last_step).relu())
# ==== Data processing ====
def fetch_data(ticker, period="3mo"):
    """Download daily OHLCV bars for *ticker* via ``yf.download``.

    Args:
        ticker: Ticker symbol passed to ``yf.download``.
        period: yfinance period string, e.g. ``"3mo"`` or ``"6mo"``.

    Returns:
        DataFrame with columns Open, High, Low, Close and Volume. Rows with
        any missing value are dropped and the index is timezone-naive.

    Raises:
        ValueError: If the download returned no rows.
    """
    raw = yf.download(ticker, period=period)
    if raw is None or raw.empty:
        # An empty frame means nothing was downloaded (e.g. unknown ticker).
        # Fail here with a clear message instead of a KeyError on the
        # column selection below.
        raise ValueError(f"No price data returned for {ticker!r} (period={period!r})")
    if isinstance(raw.columns, pd.MultiIndex) and raw.columns.get_level_values(-1).nunique() == 1:
        # NOTE: yfinance may return (field, ticker) column pairs even for a
        # single ticker. Keep only the field level so df['Close'] is a Series
        # rather than a one-column frame.
        raw.columns = raw.columns.get_level_values(0)
    df = raw[['Open', 'High', 'Low', 'Close', 'Volume']].dropna()
    # Drop any timezone so the index compares directly with naive datetimes.
    df.index = df.index.tz_localize(None)
    return df
def prepare_data(df, window_size=10, forecast_days=5):
    """Slice *df* into supervised (input window, future Close) training pairs.

    Sample ``i`` uses rows ``i .. i+window_size-1`` (all columns) as input.
    Its target is the next ``forecast_days`` values of ``df['Close']``.

    Args:
        df: Frame with a ``'Close'`` column; every column becomes a feature.
        window_size: Number of consecutive rows per input window.
        forecast_days: Number of future Close values per target.

    Returns:
        ``(X, Y)`` where ``X`` has shape ``(n_samples, window_size, n_columns)``
        and ``Y`` has shape ``(n_samples, forecast_days)``. ``Y`` gets an extra
        trailing axis if ``df['Close']`` is itself a one-column frame.
        ``n_samples`` is ``len(df) - window_size - forecast_days + 1``, or zero
        when *df* is too short. Empty results still carry these trailing
        dimensions.
    """
    features = df.to_numpy()
    close = df['Close'].to_numpy()
    # +1: the window whose target ends exactly on the last row is still valid.
    n_samples = max(0, len(df) - window_size - forecast_days + 1)
    if n_samples == 0:
        return (
            np.empty((0, window_size) + features.shape[1:]),
            np.empty((0, forecast_days) + close.shape[1:]),
        )
    X = np.stack([features[i:i + window_size] for i in range(n_samples)])
    Y = np.stack([close[i + window_size:i + window_size + forecast_days] for i in range(n_samples)])
    return X, Y
# ==== Main pipeline ====
def _split_at_cutoff(df_all, cutoff_str):
    """Split *df_all* at the ``YYYY-MM-DD`` date *cutoff_str* into (cutoff, train, test).

    Training rows are strictly before the cutoff; test rows are on or after it.
    Without a cutoff, all rows are training data, the test frame is empty, and
    the cutoff is the last index value.
    """
    if not cutoff_str:
        return df_all.index[-1], df_all, pd.DataFrame()
    cutoff = datetime.strptime(cutoff_str, "%Y-%m-%d")
    return cutoff, df_all[df_all.index < cutoff], df_all[df_all.index >= cutoff]


def _train_sandwich(X, Y, forecast_days, epochs=200, lr=0.001):
    """Fit a SandwichModel on standardised windows *X* and standardised targets *Y*.

    Returns ``(model, feature_scaler, y_mean, y_std)``. Map raw model outputs
    back to prices with ``pred * y_std + y_mean``.
    """
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X.reshape(-1, X.shape[-1])).reshape(X.shape)
    # Y can carry a trailing singleton axis when df['Close'] was a one-column
    # frame; flatten it to (n_samples, forecast_days).
    Y = Y.reshape(len(Y), -1)
    # Standardise the targets as well. Raw prices sit far from the near-zero
    # outputs of a freshly initialised network, and this short, low-learning-rate
    # run is unlikely to bridge that gap on its own.
    y_mean = float(Y.mean())
    y_std = float(Y.std()) or 1.0  # constant prices: avoid division by zero
    X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
    Y_tensor = torch.tensor((Y - y_mean) / y_std, dtype=torch.float32)
    model = SandwichModel(input_dim=X.shape[2], output_dim=forecast_days)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    model.train()
    # Full-batch training: every epoch uses all windows at once.
    for epoch in range(epochs):
        pred = model(X_tensor)
        loss = loss_fn(pred, Y_tensor)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % 50 == 0:
            print(f"Epoch {epoch} | Loss: {loss.item():.4f}")
    return model, scaler, y_mean, y_std


def _forecast_next(model, scaler, df_train, window_size, y_mean, y_std):
    """Predict prices for the horizon after the last *window_size* rows of *df_train*."""
    latest = df_train.iloc[-window_size:].to_numpy()
    latest_scaled = scaler.transform(latest).reshape(1, window_size, -1)
    model.eval()
    with torch.no_grad():
        scaled = model(torch.tensor(latest_scaled, dtype=torch.float32)).numpy()[0]
    return scaled * y_std + y_mean


def _plot_forecast(ticker, forecast_dates, forecast, real_segment, forecast_days):
    """Plot the forecast (and optional real prices); show it or save it as a PNG."""
    plt.figure(figsize=(10, 5))
    plt.plot(forecast_dates, forecast, label='Sandwich', color='teal')
    if real_segment is not None:
        plt.plot(real_segment.index, real_segment.values, label='Real', color='black', linestyle='--')
    plt.title(f"{ticker} Forecast for Next {forecast_days} Days (Sandwich)")
    plt.xlabel("Date")
    plt.ylabel("Predicted Close Price")
    plt.legend()
    plt.grid(True)
    filename = f"sandwich_{ticker.lower()}_forecast.png"
    # Without an X display (e.g. a headless server), write the figure to disk.
    if "DISPLAY" in os.environ:
        plt.show()
    else:
        plt.savefig(filename)
        print(f"📊 圖已儲存為 {filename}")
    plt.close()


def main(ticker, forecast_days, period, cutoff_str, compare_real):
    """Train a SandwichModel on *ticker* history and plot a *forecast_days* Close forecast.

    Args:
        ticker: Symbol passed to ``fetch_data``.
        forecast_days: Number of future trading rows to forecast.
        period: yfinance period string for the download.
        cutoff_str: Optional ``YYYY-MM-DD`` backtest date. Only rows strictly
            before it are used for training; rows from it onward form the real
            comparison segment.
        compare_real: If true and at least *forecast_days* rows exist from the
            cutoff onward, overlay the real Close prices.

    Raises:
        ValueError: If *forecast_days* < 1, no data was fetched, or there is
            too little history before the cutoff to build one training sample.
    """
    window_size = 10  # rows of history fed to the model per sample
    if forecast_days < 1:
        raise ValueError(f"forecast_days must be >= 1, got {forecast_days}")
    print(f"📈 預測 {ticker} 未來 {forecast_days} 天股價(使用 Sandwich 模型)")
    df_all = fetch_data(ticker, period)
    if df_all.empty:
        raise ValueError(f"No usable price data for {ticker!r}")
    cutoff, df_train, df_test = _split_at_cutoff(df_all, cutoff_str)

    X, Y = prepare_data(df_train, window_size=window_size, forecast_days=forecast_days)
    if len(X) == 0:
        raise ValueError(
            f"Not enough rows before {cutoff:%Y-%m-%d} ({len(df_train)}) to build a "
            f"{window_size}-row window plus a {forecast_days}-row target; "
            "use a longer --period or a later --cutoff"
        )

    print("🧠 開始訓練...")
    model, scaler, y_mean, y_std = _train_sandwich(X, Y, forecast_days)
    forecast = _forecast_next(model, scaler, df_train, window_size, y_mean, y_std)

    real_segment = None
    if compare_real and not df_test.empty:
        segment = df_test['Close'].iloc[:forecast_days]
        if len(segment) == forecast_days:
            real_segment = segment

    # The model predicts the next trading rows, not calendar days. Place the
    # forecast on the real trading dates when they are known; otherwise use
    # business days after the last training row. Holidays are not excluded.
    if real_segment is not None:
        forecast_dates = real_segment.index
    else:
        forecast_dates = pd.bdate_range(
            start=df_train.index[-1] + pd.Timedelta(days=1), periods=forecast_days
        )
    _plot_forecast(ticker, forecast_dates, forecast, real_segment, forecast_days)
# ==== CLI ====
if __name__ == "__main__":
    # Command-line entry point: parse the flags and hand them to main().
    cli = argparse.ArgumentParser(
        description="Sandwich-based stock price forecast with backtest support"
    )
    cli.add_argument('--ticker', type=str, default='TSLA')
    cli.add_argument('--days', type=int, default=5)
    cli.add_argument('--period', type=str, default='3mo')
    # Backtest start date, e.g. 2025-03-15.
    cli.add_argument('--cutoff', type=str, default='',
                     help='模擬預測的起始日,如 2025-03-15')
    # Pass "real" (any case) to overlay the real price line.
    cli.add_argument('--compare', type=str, default='',
                     help='輸入 "real" 顯示真實價格線')
    opts = cli.parse_args()
    main(
        ticker=opts.ticker,
        forecast_days=opts.days,
        period=opts.period,
        cutoff_str=opts.cutoff,
        compare_real=opts.compare.lower() == 'real',
    )