Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from datetime import datetime, timedelta | |
| import pickle | |
| import yfinance as yf | |
| import os | |
| import re | |
| from statsmodels.tsa.arima.model import ARIMA | |
| from prophet import Prophet | |
| from tensorflow import keras | |
| from sklearn.preprocessing import MinMaxScaler | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| # Load your saved models (update paths as needed) | |
| # For Hugging Face, these will be in the same directory as app.py | |
def _load_pickle(path):
    """Unpickle the object stored at ``path``.

    Returns the loaded object, or ``None`` (after printing the reason) when
    the file is missing or cannot be unpickled.
    """
    try:
        with open(path, 'rb') as f:
            # NOTE(security): pickle can execute arbitrary code while loading.
            # Only ship artifacts that were produced by a trusted training run.
            return pickle.load(f)
    except Exception as e:
        print(f"Error loading models ({path}): {e}")
        return None


def load_models():
    """Load the saved ARIMA, Prophet and LSTM artifacts from the working directory.

    Each artifact is loaded independently, so one missing or corrupt file no
    longer discards the models that did load. The LSTM network and its scaler
    are treated as a pair: the LSTM model is only returned when both loaded.

    Returns:
        tuple: ``(arima_model, prophet_model, lstm_model, scaler)``; any entry
        that could not be loaded is ``None``.
    """
    arima_model = _load_pickle('arima_model.pkl')
    prophet_model = _load_pickle('prophet_model.pkl')

    try:
        lstm_model = keras.models.load_model('lstm_model.h5')
    except Exception as e:
        print(f"Error loading models (lstm_model.h5): {e}")
        lstm_model = None

    # The network is unusable without the scaler it was trained with (and the
    # scaler is pointless without the network), so drop both if either failed.
    scaler = _load_pickle('lstm_scaler.pkl') if lstm_model is not None else None
    if scaler is None:
        lstm_model = None

    return arima_model, prophet_model, lstm_model, scaler
# Module-level state shared by the forecasting helpers and the Gradio callback.
SEQ_LENGTH = 60  # LSTM look-back window length; should match the value used in training
(arima_model, prophet_model, lstm_model, scaler) = load_models()
def _detect_date_column(df_raw):
    """Return the name of the column holding dates, or ``None`` if none is found.

    A column qualifies when at least half of its first 20 values start with an
    ISO ``YYYY-MM-DD`` date; otherwise a column literally named ``Date`` is used.
    """
    for col in df_raw.columns:
        sample = df_raw[col].astype(str).head(20)
        hits = sample.str.match(r"^\s*\d{4}-\d{2}-\d{2}").sum()
        if hits >= max(1, int(len(sample) * 0.5)):
            return col
    return 'Date' if 'Date' in df_raw.columns else None


def _pick_price_column(columns):
    """Return the best closing-price column name from ``columns``, or ``None``."""
    for preferred in ('Close', 'Adj Close', 'Close*'):
        if preferred in columns:
            return preferred
    # Fall back to anything that looks like a price column.
    for name in columns:
        lowered = str(name).lower()
        if 'close' in lowered or 'price' in lowered:
            return name
    return None


def fetch_stock_data(ticker, days=365):
    """Load closing prices for ``ticker`` from a local ``<TICKER>.csv`` file.

    The CSV is looked up next to this source file. Rows whose date or price
    cannot be parsed (extra header/noise rows) are dropped, the result is
    sorted by date and trimmed to the last ``days`` calendar days.

    Args:
        ticker: Stock symbol; case-insensitive. Only letters, digits and
            ``. - = ^`` are accepted, so the value can never escape the data
            directory when it is turned into a file name.
        days: Size of the trailing window in calendar days. ``None`` or a
            non-positive value keeps the whole history.

    Returns:
        tuple: ``(df, None)`` on success, where ``df`` has a single numeric
        ``Price`` column indexed by date; ``(None, message)`` on any failure.
        The function never raises and never returns a bare ``None``.
    """
    try:
        symbol = str(ticker or "").strip().upper()
        # The ticker comes straight from user input and becomes part of a path:
        # reject anything that could contain a path separator or be empty.
        if not re.fullmatch(r"[A-Z0-9^][A-Z0-9.\-=^]{0,19}", symbol):
            return None, f"Invalid ticker symbol: {ticker!r}"

        # Prefer local CSV file named <TICKER>.csv in the project root
        csv_name = f"{symbol}.csv"
        workspace_dir = os.path.dirname(os.path.abspath(__file__))
        csv_path = os.path.join(workspace_dir, csv_name)
        if not os.path.exists(csv_path):
            return None, f"No local data file found for {symbol} (expected {csv_name})"

        # Read everything as strings and filter afterwards: the CSVs may carry
        # extra header/noise rows, which is more robust than skipping rows.
        df_raw = pd.read_csv(csv_path, header=0, dtype=str)

        date_col = _detect_date_column(df_raw)
        if date_col is not None:
            df_raw[date_col] = pd.to_datetime(df_raw[date_col], errors='coerce')
            df = df_raw.dropna(subset=[date_col]).set_index(date_col)
        else:
            # Best effort: try to interpret the index as dates; if that fails,
            # keep the raw frame and let the checks below report the problem.
            df = df_raw.copy()
            try:
                df.index = pd.to_datetime(df.index)
            except (ValueError, TypeError):
                pass

        price_col = _pick_price_column(df.columns)
        if price_col is None:
            return None, f"Local CSV found but no 'Close' column in {csv_name}"

        # Coerce to numeric price and drop rows that can't be converted
        df = df[[price_col]].copy()
        df.columns = ['Price']
        df['Price'] = pd.to_numeric(df['Price'], errors='coerce')
        df = df.dropna(subset=['Price']).sort_index()
        # Remove index name to avoid printing a duplicate label
        df.index.name = None

        # Slice to the requested window (last `days` days)
        if not df.empty and days is not None and days > 0:
            start_dt = df.index.max() - timedelta(days=days - 1)
            df = df.loc[df.index >= start_dt]

        if df.empty:
            return None, f"No data in local CSV for the requested period: {csv_name}"
        return df, None
    except Exception as e:
        return None, f"Error fetching stock data: {e}"
def make_arima_forecast(data, days):
    """Fit an ARIMA(1, 1, 1) model on ``data['Price']`` and forecast ``days`` steps.

    The model is refit on the supplied history on every call.

    Returns:
        The forecast values as an array, or ``None`` if anything goes wrong
        (the error is printed).
    """
    try:
        prices = data['Price']
        fitted_model = ARIMA(prices, order=(1, 1, 1)).fit()
        return fitted_model.forecast(steps=days).values
    except Exception as e:
        print(f"ARIMA Error: {e}")
        return None
def make_prophet_forecast(data, days):
    """Fit a Prophet model on the price history and forecast ``days`` periods.

    ``data`` must expose its dates as the index and prices in a ``Price``
    column. A new model is created and fitted on every call.

    Returns:
        The last ``days`` predicted ``yhat`` values as an array, or ``None``
        if anything goes wrong (the error is printed).
    """
    try:
        # Prophet expects the columns to be called 'ds' (dates) and 'y' (values).
        history = pd.DataFrame({'ds': data.index, 'y': data['Price'].values})

        settings = dict(
            daily_seasonality=True,
            weekly_seasonality=True,
            yearly_seasonality=True,
            changepoint_prior_scale=0.05,
        )
        model = Prophet(**settings)
        model.fit(history)

        # Extend the timeline by `days` periods and keep only the new part.
        horizon = model.make_future_dataframe(periods=days)
        predicted = model.predict(horizon)
        return predicted['yhat'].tail(days).values
    except Exception as e:
        print(f"Prophet Error: {e}")
        return None
def make_lstm_forecast(data, days, model, scaler, seq_length=60):
    """Roll an LSTM model forward ``days`` steps from the end of the price history.

    The ``Price`` column is scaled with ``scaler``, the last ``seq_length``
    scaled values seed the input window, and each prediction is fed back in as
    the newest element of the window for the next step.

    Args:
        data: Frame with a ``Price`` column.
        days: Number of steps to predict.
        model: Object with ``predict(window, verbose=0)`` returning an array
            whose ``[0, 0]`` element is the next scaled value.
        scaler: Object with ``transform`` / ``inverse_transform``.
        seq_length: Length of the input window.

    Returns:
        1-D array of predictions in original price units, or ``None`` if
        anything goes wrong (the error is printed).
    """
    try:
        scaled = scaler.transform(data[['Price']])

        # Seed window shaped (1, seq_length, 1); copied so it is independent
        # of the scaled history.
        window = scaled[-seq_length:].reshape(1, seq_length, 1).copy()

        outputs = []
        for _step in range(days):
            step_pred = model.predict(window, verbose=0)
            outputs.append(step_pred[0, 0])
            # Slide the window: drop the oldest value, append the prediction.
            window = np.concatenate(
                [window[:, 1:, :], step_pred.reshape(1, 1, 1)], axis=1
            )

        # Back to price units.
        as_column = np.array(outputs).reshape(-1, 1)
        return scaler.inverse_transform(as_column).flatten()
    except Exception as e:
        print(f"LSTM Error: {e}")
        return None
def create_forecast_plot(historical_data, forecasts, ticker, model_names):
    """Build an interactive Plotly chart of the history plus each model's forecast.

    Args:
        historical_data: Non-empty frame indexed by date with a ``Price`` column.
        forecasts: Sequence of forecast arrays, parallel to ``model_names``.
            ``None`` entries are skipped, wherever they appear in the sequence.
        ticker: Symbol used in the chart title.
        model_names: Display name for each forecast.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()

    # Historical data
    fig.add_trace(go.Scatter(
        x=historical_data.index,
        y=historical_data['Price'],
        mode='lines',
        name='Historical Price',
        line=dict(color='blue', width=2)
    ))

    # Forecasts start the day after the last historical observation.
    first_future_date = historical_data.index[-1] + timedelta(days=1)

    colors = ['red', 'purple', 'orange']
    for i, (forecast, name) in enumerate(zip(forecasts, model_names)):
        if forecast is None:
            continue
        # Dates are derived per forecast, so a missing first entry (or
        # forecasts of different lengths) cannot break the x axis.
        future_dates = pd.date_range(start=first_future_date, periods=len(forecast))
        fig.add_trace(go.Scatter(
            x=future_dates,
            y=forecast,
            mode='lines+markers',
            name=f'{name} Forecast',
            # Cycle through the palette so extra models never raise IndexError.
            line=dict(color=colors[i % len(colors)], width=2, dash='dash'),
            marker=dict(size=6)
        ))

    fig.update_layout(
        title=f'{ticker} Stock Price Forecast',
        xaxis_title='Date',
        yaxis_title='Price ($)',
        hovermode='x unified',
        template='plotly_white',
        height=600,
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01
        )
    )
    return fig
def predict_stock(ticker, forecast_days, model_choice):
    """Gradio callback: fetch data, run the selected model(s), build the outputs.

    Args:
        ticker: Stock symbol typed by the user (may be empty or ``None``).
        forecast_days: Forecast horizon in days; coerced to ``int`` because UI
            sliders may deliver a float.
        model_choice: One of ``"All Models"``, ``"ARIMA"``, ``"Prophet"``,
            ``"LSTM"``.

    Returns:
        tuple: ``(figure, summary_markdown, forecast_table)``. On any failure
        the figure and table are ``None`` and the summary carries the message.
    """
    # Validate inputs. Strip before checking so whitespace-only input is
    # rejected here instead of reaching the data loader as an empty symbol.
    ticker = str(ticker or "").strip().upper()
    if not ticker:
        return None, "Please enter a stock ticker symbol", None

    try:
        forecast_days = int(forecast_days)
    except (TypeError, ValueError):
        return None, "Error: forecast days must be a whole number", None
    if forecast_days < 1:
        return None, "Error: forecast days must be at least 1", None

    # Fetch data (2 years of history)
    fetched = fetch_stock_data(ticker, days=730)
    if fetched is None:
        # Defensive: the loader is expected to return a (data, error) pair.
        return None, f"Error: no data available for {ticker}", None
    data, error = fetched
    if error:
        return None, f"Error: {error}", None
    if data is None or data.empty:
        return None, f"Error: no data available for {ticker}", None

    # Make forecasts based on model choice
    forecasts = []
    model_names = []

    if model_choice in ["All Models", "ARIMA"]:
        arima_forecast = make_arima_forecast(data, forecast_days)
        if arima_forecast is not None:
            forecasts.append(arima_forecast)
            model_names.append("ARIMA")

    if model_choice in ["All Models", "Prophet"]:
        prophet_forecast = make_prophet_forecast(data, forecast_days)
        if prophet_forecast is not None:
            forecasts.append(prophet_forecast)
            model_names.append("Prophet")

    if model_choice in ["All Models", "LSTM"] and lstm_model is not None:
        lstm_forecast = make_lstm_forecast(data, forecast_days, lstm_model, scaler, SEQ_LENGTH)
        if lstm_forecast is not None:
            forecasts.append(lstm_forecast)
            model_names.append("LSTM")

    if not forecasts:
        if model_choice == "LSTM" and lstm_model is None:
            # Retrying cannot help when the model file never loaded.
            return None, "LSTM model is not available. Please choose another model.", None
        return None, "Failed to generate forecasts. Please try again.", None

    # Create plot
    fig = create_forecast_plot(data, forecasts, ticker, model_names)

    # Create forecast table
    future_dates = pd.date_range(
        start=data.index[-1] + timedelta(days=1),
        periods=forecast_days
    )
    forecast_df = pd.DataFrame({'Date': future_dates.strftime('%Y-%m-%d')})
    for forecast, name in zip(forecasts, model_names):
        forecast_df[f'{name} Prediction ($)'] = np.round(forecast, 2)

    # Summary statistics. Built line by line so no source indentation leaks
    # into the markdown (indented lines would render as a code block).
    last_price = float(data['Price'].iloc[-1])
    summary_lines = [
        f"📊 **Forecast Summary for {ticker}**",
        "",
        f"- Current Price: ${last_price:.2f}",
        f"- Forecast Period: {forecast_days} days",
        f"- Models Used: {', '.join(model_names)}",
        "",
        f"**Predicted Price Range (Day {forecast_days}):**",
        "",
    ]
    for forecast, name in zip(forecasts, model_names):
        final_price = float(forecast[-1])
        if last_price:
            change = (final_price - last_price) / last_price * 100
            summary_lines.append(f"- {name}: ${final_price:.2f} ({change:+.2f}%)")
        else:
            # A zero reference price makes a percentage change meaningless.
            summary_lines.append(f"- {name}: ${final_price:.2f}")
    summary = "\n".join(summary_lines)

    return fig, summary, forecast_df
# Create Gradio Interface
# Layout: inputs on the left, plot on the right, then the summary and the
# detailed table in their own rows, followed by examples and an info footer.
with gr.Blocks(title="Stock Price Forecasting", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 📈 Stock Price Forecasting App

        Predict future stock prices using ARIMA, Prophet, and LSTM models.
        Enter a stock ticker symbol and select forecast parameters below.

        **Note:** Predictions are for educational purposes only. Not financial advice.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            ticker_input = gr.Textbox(
                label="Stock Ticker Symbol",
                placeholder="e.g., AAPL, GOOGL, TSLA",
                value="AAPL"
            )
            forecast_days = gr.Slider(
                minimum=1,
                maximum=90,
                value=30,
                step=1,
                label="Forecast Days"
            )
            model_choice = gr.Radio(
                choices=["All Models", "ARIMA", "Prophet", "LSTM"],
                value="All Models",
                label="Select Model(s)"
            )
            predict_btn = gr.Button("🔮 Generate Forecast", variant="primary", size="lg")

        with gr.Column(scale=2):
            output_plot = gr.Plot(label="Forecast Visualization")

    with gr.Row():
        output_summary = gr.Markdown(label="Forecast Summary")

    with gr.Row():
        output_table = gr.Dataframe(
            label="Detailed Forecast",
            wrap=True,
            interactive=False
        )

    # Examples (clicking one only fills in the inputs)
    gr.Examples(
        examples=[
            ["AAPL", 30, "All Models"],
            ["GOOGL", 14, "Prophet"],
            ["TSLA", 60, "LSTM"],
            ["MSFT", 45, "ARIMA"],
        ],
        inputs=[ticker_input, forecast_days, model_choice],
    )

    # Connect the button to the function
    predict_btn.click(
        fn=predict_stock,
        inputs=[ticker_input, forecast_days, model_choice],
        outputs=[output_plot, output_summary, output_table]
    )

    gr.Markdown(
        """
        ---

        ### 📚 About the Models

        - **ARIMA**: Statistical model for time series forecasting
        - **Prophet**: Facebook's forecasting tool, excellent for seasonality
        - **LSTM**: Deep learning model that captures complex patterns

        ### ⚠️ Disclaimer

        This tool is for educational and research purposes only. Stock market predictions are inherently uncertain.
        Always conduct thorough research and consult with financial advisors before making investment decisions.
        """
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()