Shaikat01 commited on
Commit
928f9ee
ยท
verified ยท
1 Parent(s): 1c5e7b0

Upload 7 files

Browse files
Files changed (7) hide show
  1. README.md +34 -0
  2. app.py +325 -0
  3. arima_model.pkl +3 -0
  4. lstm_model.h5 +3 -0
  5. lstm_scaler.pkl +3 -0
  6. prophet_model.pkl +3 -0
  7. requirements.txt +9 -0
README.md ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Stock Price Forecasting
3
+ emoji: ๐Ÿ“ˆ
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 4.16.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ # Stock Price Forecasting App
14
+
15
+ This application uses three different models (ARIMA, Prophet, and LSTM) to forecast stock prices.
16
+
17
+ ## Features
18
+ - Real-time stock data fetching from Yahoo Finance
19
+ - Multiple forecasting models
20
+ - Interactive visualizations
21
+ - Customizable forecast periods
22
+
23
+ ## Models
24
+ 1. **ARIMA** - Traditional statistical model
25
+ 2. **Prophet** - Facebook's time series forecasting
26
+ 3. **LSTM** - Deep learning neural network
27
+
28
+ ## Usage
29
+ 1. Enter a stock ticker symbol (e.g., AAPL, GOOGL)
30
+ 2. Select forecast period (1-90 days)
31
+ 3. Choose which model(s) to use
32
+ 4. Click "Generate Forecast"
33
+
34
+ โš ๏ธ **Disclaimer**: For educational purposes only. Not financial advice.
app.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ import plotly.graph_objects as go
5
+ from datetime import datetime, timedelta
6
+ import pickle
7
+ import yfinance as yf
8
+ from statsmodels.tsa.arima.model import ARIMA
9
+ from prophet import Prophet
10
+ from tensorflow import keras
11
+ from sklearn.preprocessing import MinMaxScaler
12
+ import warnings
13
+ warnings.filterwarnings('ignore')
14
+
15
+ # Load your saved models (update paths as needed)
16
+ # For Hugging Face, these will be in the same directory as app.py
17
+ def load_models():
18
+ """Load all three models"""
19
+ try:
20
+ # Load ARIMA model
21
+ with open('arima_model.pkl', 'rb') as f:
22
+ arima_model = pickle.load(f)
23
+
24
+ # Load Prophet model
25
+ with open('prophet_model.pkl', 'rb') as f:
26
+ prophet_model = pickle.load(f)
27
+
28
+ # Load LSTM model and scaler
29
+ lstm_model = keras.models.load_model('lstm_model.h5')
30
+ with open('lstm_scaler.pkl', 'rb') as f:
31
+ scaler = pickle.load(f)
32
+
33
+ return arima_model, prophet_model, lstm_model, scaler
34
+ except Exception as e:
35
+ print(f"Error loading models: {e}")
36
+ return None, None, None, None
37
+
38
+ # Global variables for models
39
+ arima_model, prophet_model, lstm_model, scaler = load_models()
40
+ SEQ_LENGTH = 60 # Should match your training
41
+
42
+ def fetch_stock_data(ticker, days=365):
43
+ """Fetch stock data from Yahoo Finance"""
44
+ try:
45
+ end_date = datetime.now()
46
+ start_date = end_date - timedelta(days=days)
47
+ df = yf.download(ticker, start=start_date, end=end_date, progress=False)
48
+ if df.empty:
49
+ return None, f"No data found for ticker: {ticker}"
50
+ df = df[['Close']].copy()
51
+ df.columns = ['Price']
52
+ return df, None
53
+ except Exception as e:
54
+ return None, str(e)
55
+
56
+ def make_arima_forecast(data, days):
57
+ """Make ARIMA forecast"""
58
+ try:
59
+ # Retrain ARIMA with recent data (or use loaded model)
60
+ model = ARIMA(data['Price'], order=(1, 1, 1))
61
+ fitted = model.fit()
62
+ forecast = fitted.forecast(steps=days)
63
+ return forecast.values
64
+ except Exception as e:
65
+ print(f"ARIMA Error: {e}")
66
+ return None
67
+
68
+ def make_prophet_forecast(data, days):
69
+ """Make Prophet forecast"""
70
+ try:
71
+ # Prepare data for Prophet
72
+ prophet_data = pd.DataFrame({
73
+ 'ds': data.index,
74
+ 'y': data['Price'].values
75
+ })
76
+
77
+ # Create and fit model
78
+ model = Prophet(
79
+ daily_seasonality=True,
80
+ weekly_seasonality=True,
81
+ yearly_seasonality=True,
82
+ changepoint_prior_scale=0.05
83
+ )
84
+ model.fit(prophet_data)
85
+
86
+ # Make forecast
87
+ future = model.make_future_dataframe(periods=days)
88
+ forecast = model.predict(future)
89
+ return forecast['yhat'].tail(days).values
90
+ except Exception as e:
91
+ print(f"Prophet Error: {e}")
92
+ return None
93
+
94
+ def make_lstm_forecast(data, days, model, scaler, seq_length=60):
95
+ """Make LSTM forecast"""
96
+ try:
97
+ # Scale the data
98
+ scaled_data = scaler.transform(data[['Price']])
99
+
100
+ # Prepare the last sequence
101
+ last_sequence = scaled_data[-seq_length:].reshape(1, seq_length, 1)
102
+
103
+ predictions = []
104
+ current_sequence = last_sequence.copy()
105
+
106
+ # Generate predictions day by day
107
+ for _ in range(days):
108
+ pred = model.predict(current_sequence, verbose=0)
109
+ predictions.append(pred[0, 0])
110
+
111
+ # Update sequence
112
+ current_sequence = np.append(current_sequence[:, 1:, :],
113
+ pred.reshape(1, 1, 1), axis=1)
114
+
115
+ # Inverse transform predictions
116
+ predictions = scaler.inverse_transform(np.array(predictions).reshape(-1, 1))
117
+ return predictions.flatten()
118
+ except Exception as e:
119
+ print(f"LSTM Error: {e}")
120
+ return None
121
+
122
+ def create_forecast_plot(historical_data, forecasts, ticker, model_names):
123
+ """Create interactive plotly chart"""
124
+ fig = go.Figure()
125
+
126
+ # Historical data
127
+ fig.add_trace(go.Scatter(
128
+ x=historical_data.index,
129
+ y=historical_data['Price'],
130
+ mode='lines',
131
+ name='Historical Price',
132
+ line=dict(color='blue', width=2)
133
+ ))
134
+
135
+ # Generate future dates
136
+ last_date = historical_data.index[-1]
137
+ future_dates = pd.date_range(start=last_date + timedelta(days=1),
138
+ periods=len(forecasts[0]))
139
+
140
+ # Plot forecasts
141
+ colors = ['red', 'purple', 'orange']
142
+ for i, (forecast, name) in enumerate(zip(forecasts, model_names)):
143
+ if forecast is not None:
144
+ fig.add_trace(go.Scatter(
145
+ x=future_dates,
146
+ y=forecast,
147
+ mode='lines+markers',
148
+ name=f'{name} Forecast',
149
+ line=dict(color=colors[i], width=2, dash='dash'),
150
+ marker=dict(size=6)
151
+ ))
152
+
153
+ fig.update_layout(
154
+ title=f'{ticker} Stock Price Forecast',
155
+ xaxis_title='Date',
156
+ yaxis_title='Price ($)',
157
+ hovermode='x unified',
158
+ template='plotly_white',
159
+ height=600,
160
+ showlegend=True,
161
+ legend=dict(
162
+ yanchor="top",
163
+ y=0.99,
164
+ xanchor="left",
165
+ x=0.01
166
+ )
167
+ )
168
+
169
+ return fig
170
+
171
+ def predict_stock(ticker, forecast_days, model_choice):
172
+ """Main prediction function"""
173
+ # Validate inputs
174
+ if not ticker:
175
+ return None, "Please enter a stock ticker symbol", None
176
+
177
+ ticker = ticker.upper().strip()
178
+
179
+ # Fetch data
180
+ data, error = fetch_stock_data(ticker, days=730) # 2 years of data
181
+ if error:
182
+ return None, f"Error: {error}", None
183
+
184
+ # Make forecasts based on model choice
185
+ forecasts = []
186
+ model_names = []
187
+
188
+ if model_choice in ["All Models", "ARIMA"]:
189
+ arima_forecast = make_arima_forecast(data, forecast_days)
190
+ if arima_forecast is not None:
191
+ forecasts.append(arima_forecast)
192
+ model_names.append("ARIMA")
193
+
194
+ if model_choice in ["All Models", "Prophet"]:
195
+ prophet_forecast = make_prophet_forecast(data, forecast_days)
196
+ if prophet_forecast is not None:
197
+ forecasts.append(prophet_forecast)
198
+ model_names.append("Prophet")
199
+
200
+ if model_choice in ["All Models", "LSTM"] and lstm_model is not None:
201
+ lstm_forecast = make_lstm_forecast(data, forecast_days, lstm_model, scaler, SEQ_LENGTH)
202
+ if lstm_forecast is not None:
203
+ forecasts.append(lstm_forecast)
204
+ model_names.append("LSTM")
205
+
206
+ if not forecasts:
207
+ return None, "Failed to generate forecasts. Please try again.", None
208
+
209
+ # Create plot
210
+ fig = create_forecast_plot(data, forecasts, ticker, model_names)
211
+
212
+ # Create forecast table
213
+ future_dates = pd.date_range(
214
+ start=data.index[-1] + timedelta(days=1),
215
+ periods=forecast_days
216
+ )
217
+
218
+ forecast_df = pd.DataFrame({'Date': future_dates.strftime('%Y-%m-%d')})
219
+ for forecast, name in zip(forecasts, model_names):
220
+ forecast_df[f'{name} Prediction ($)'] = np.round(forecast, 2)
221
+
222
+ # Summary statistics
223
+ summary = f"""
224
+ ๐Ÿ“Š **Forecast Summary for {ticker}**
225
+
226
+ - Current Price: ${data['Price'].iloc[-1]:.2f}
227
+ - Forecast Period: {forecast_days} days
228
+ - Models Used: {', '.join(model_names)}
229
+
230
+ **Predicted Price Range (Day {forecast_days}):**
231
+ """
232
+
233
+ for forecast, name in zip(forecasts, model_names):
234
+ final_price = forecast[-1]
235
+ change = ((final_price - data['Price'].iloc[-1]) / data['Price'].iloc[-1]) * 100
236
+ summary += f"\n- {name}: ${final_price:.2f} ({change:+.2f}%)"
237
+
238
+ return fig, summary, forecast_df
239
+
240
+ # Create Gradio Interface
241
+ with gr.Blocks(title="Stock Price Forecasting", theme=gr.themes.Soft()) as demo:
242
+ gr.Markdown(
243
+ """
244
+ # ๐Ÿ“ˆ Stock Price Forecasting App
245
+
246
+ Predict future stock prices using ARIMA, Prophet, and LSTM models.
247
+ Enter a stock ticker symbol and select forecast parameters below.
248
+
249
+ **Note:** Predictions are for educational purposes only. Not financial advice.
250
+ """
251
+ )
252
+
253
+ with gr.Row():
254
+ with gr.Column(scale=1):
255
+ ticker_input = gr.Textbox(
256
+ label="Stock Ticker Symbol",
257
+ placeholder="e.g., AAPL, GOOGL, TSLA",
258
+ value="AAPL"
259
+ )
260
+
261
+ forecast_days = gr.Slider(
262
+ minimum=1,
263
+ maximum=90,
264
+ value=30,
265
+ step=1,
266
+ label="Forecast Days"
267
+ )
268
+
269
+ model_choice = gr.Radio(
270
+ choices=["All Models", "ARIMA", "Prophet", "LSTM"],
271
+ value="All Models",
272
+ label="Select Model(s)"
273
+ )
274
+
275
+ predict_btn = gr.Button("๐Ÿ”ฎ Generate Forecast", variant="primary", size="lg")
276
+
277
+ with gr.Column(scale=2):
278
+ output_plot = gr.Plot(label="Forecast Visualization")
279
+
280
+ with gr.Row():
281
+ output_summary = gr.Markdown(label="Forecast Summary")
282
+
283
+ with gr.Row():
284
+ output_table = gr.Dataframe(
285
+ label="Detailed Forecast",
286
+ wrap=True,
287
+ interactive=False
288
+ )
289
+
290
+ # Examples
291
+ gr.Examples(
292
+ examples=[
293
+ ["AAPL", 30, "All Models"],
294
+ ["GOOGL", 14, "Prophet"],
295
+ ["TSLA", 60, "LSTM"],
296
+ ["MSFT", 45, "ARIMA"],
297
+ ],
298
+ inputs=[ticker_input, forecast_days, model_choice],
299
+ )
300
+
301
+ # Connect the button to the function
302
+ predict_btn.click(
303
+ fn=predict_stock,
304
+ inputs=[ticker_input, forecast_days, model_choice],
305
+ outputs=[output_plot, output_summary, output_table]
306
+ )
307
+
308
+ gr.Markdown(
309
+ """
310
+ ---
311
+ ### ๐Ÿ“š About the Models
312
+
313
+ - **ARIMA**: Statistical model for time series forecasting
314
+ - **Prophet**: Facebook's forecasting tool, excellent for seasonality
315
+ - **LSTM**: Deep learning model that captures complex patterns
316
+
317
+ ### โš ๏ธ Disclaimer
318
+ This tool is for educational and research purposes only. Stock market predictions are inherently uncertain.
319
+ Always conduct thorough research and consult with financial advisors before making investment decisions.
320
+ """
321
+ )
322
+
323
+ # Launch the app
324
+ if __name__ == "__main__":
325
+ demo.launch()
arima_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f03d08049cdf6937a0c1bd79933cdb5822a144f05313ed980e5578c5a8dff1e7
3
+ size 2287121
lstm_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04df565ab23e73b8b8af3f36f94cabe79c0867dfa9afe427c4e59fb446744583
3
+ size 426896
lstm_scaler.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cef5a6476cc98fe5a756665fb1917a7f3a602e72b0c46c317466f29013fb27a7
3
+ size 616
prophet_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff674ae41051bc6246ea30276ccf3671b986f7dc039c078141c1de2b61801649
3
+ size 98483
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio==4.16.0
2
+ pandas==2.1.4
3
+ numpy==1.26.3
4
+ plotly==5.18.0
5
+ yfinance==0.2.35
6
+ statsmodels==0.14.1
7
+ prophet==1.1.5
8
+ tensorflow==2.15.0
9
+ scikit-learn==1.3.2