Spaces:
Running
Running
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import os | |
| import sys | |
| import io | |
| import base64 | |
| from datetime import timedelta | |
| import yfinance as yf | |
| # Adjust path to import from the Kronos folder | |
| sys.path.append(os.path.join(os.path.dirname(__file__), 'Kronos')) | |
| try: | |
| from model import Kronos, KronosTokenizer, KronosPredictor | |
| except ImportError: | |
| # If running from a different directory (like in HF Spaces) | |
| from Kronos.model import Kronos, KronosTokenizer, KronosPredictor | |
# Module-level cache for the Kronos objects; all three start out empty and are
# populated together by load_models() on its first call.
_tokenizer = _model = _predictor = None
# MONKEY PATCH: Fix the 'DatetimeIndex' has no attribute 'dt' error in the Kronos library
def patched_calc_time_stamps(x_timestamp):
    """Build the calendar-feature frame Kronos derives from timestamps.

    Args:
        x_timestamp: Timestamps as a ``pd.Series``, a ``pd.DatetimeIndex`` or
            any other sequence ``pd.to_datetime`` can parse (for example a
            list of ISO date strings).

    Returns:
        A DataFrame with the columns ``minute``, ``hour``, ``weekday``,
        ``day`` and ``month``. For Series input the Series' own index is
        kept; otherwise the frame gets a default integer index.
    """
    # Coerce first so string/object input works too; already-datetime input
    # passes through unchanged.
    stamps = pd.to_datetime(x_timestamp)
    if isinstance(stamps, pd.Series):
        # A Series exposes the datetime fields through the .dt accessor.
        fields = stamps.dt
    else:
        # An index exposes minute, hour, ... as direct attributes.
        fields = pd.DatetimeIndex(stamps)

    time_df = pd.DataFrame()
    for name in ('minute', 'hour', 'weekday', 'day', 'month'):
        time_df[name] = getattr(fields, name)
    return time_df
# Apply the patch to every copy of the Kronos module that is already loaded.
# Depending on which import above succeeded, the module is registered either
# as "model.kronos" or (presumably, via the fallback import) as
# "Kronos.model.kronos"; patching only the first name would leave the
# fallback case unpatched.
_kronos_modules = [
    sys.modules[_name]
    for _name in ("model.kronos", "Kronos.model.kronos")
    if sys.modules.get(_name) is not None
]
if not _kronos_modules:
    # Nothing loaded yet: fall back to importing the module by its short name.
    try:
        import model.kronos as mk
        _kronos_modules.append(mk)
    except Exception as e:
        # Best effort only: the app can still start, so warn instead of failing.
        print(f"Warning: Could not patch Kronos: {e}")
for _kronos_module in _kronos_modules:
    _kronos_module.calc_time_stamps = patched_calc_time_stamps
def load_models():
    """Return the shared Kronos predictor, building it on first use.

    The first call downloads the tokenizer and model via ``from_pretrained``
    and stores them, together with the predictor, in the module-level cache
    variables. Later calls return the cached predictor directly.
    """
    global _tokenizer, _model, _predictor

    # Fast path: everything was already built by an earlier call.
    if _predictor is not None:
        return _predictor

    print("Loading Kronos models from Hugging Face...")
    _tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
    _model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
    _predictor = KronosPredictor(_model, _tokenizer, max_context=512)
    return _predictor
def get_outlook_category(change_pct: float):
    """Map a percentage change onto one of four outlook labels.

    Above 30 is "Extreme Bullish", any other positive value is "Bullish",
    below -30 is "Extreme Bearish", and everything else (including exactly 0)
    is "Bearish".
    """
    if change_pct > 30:
        return "Extreme Bullish"
    if change_pct > 0:
        return "Bullish"
    # Non-positive from here on: only the magnitude of the drop matters.
    return "Extreme Bearish" if change_pct < -30 else "Bearish"
def fetch_stock_data(symbol: str, lookback: int = 300, interval: str = '1d'):
    """Fetch historical data using yfinance and normalize for Kronos.

    Args:
        symbol: Ticker symbol; surrounding whitespace is removed and the
            symbol is upper-cased before the request.
        lookback: Number of most recent bars to keep.
        interval: Bar size passed to yfinance ('1d', '1wk' or '1mo' are the
            values this module maps explicitly).

    Returns:
        A ``(df, fetch_symbol)`` tuple: the last ``lookback`` rows with
        lower-cased column names and the time column renamed to
        ``timestamps``, plus the normalized symbol.

    Raises:
        ValueError: If the symbol is blank or no data comes back.
    """
    fetch_symbol = symbol.strip().upper()
    if not fetch_symbol:
        raise ValueError("Stock symbol must not be empty.")
    print(f"Fetching data for {fetch_symbol} (Interval: {interval}, Lookback: {lookback})...")

    # Rough number of bars a two-year window holds per interval. "2y" is only
    # enough when the requested lookback fits; otherwise ask for the full
    # history and slice. Unlisted intervals keep the daily threshold.
    bars_in_two_years = {'1d': 500, '1wk': 104, '1mo': 24}
    period = "2y" if lookback <= bars_in_two_years.get(interval, 500) else "max"

    df = yf.Ticker(fetch_symbol).history(period=period, interval=interval)
    if df.empty:
        raise ValueError(f"Could not fetch data for symbol: {symbol}")

    # Slice to requested lookback and turn the time index into a column.
    df = df.tail(lookback).reset_index()

    # Normalize column names to lowercase
    df.columns = [c.lower() for c in df.columns]

    # The former index becomes 'date'; intraday data presumably yields
    # 'datetime' instead, so accept either spelling.
    for time_col in ('date', 'datetime'):
        if time_col in df.columns:
            df = df.rename(columns={time_col: 'timestamps'})
            break
    return df, fetch_symbol
def _future_timestamps(last_date, pred_len, interval):
    """Return the ``pred_len`` timestamps that follow ``last_date``."""
    # Offset objects instead of the string aliases 'B' / 'W' / 'M': the 'M'
    # alias is deprecated in recent pandas releases, the objects are not.
    offsets = {
        '1d': pd.offsets.BDay(),
        '1wk': pd.offsets.Week(weekday=6),
        '1mo': pd.offsets.MonthEnd(),
    }
    freq = offsets.get(interval, pd.offsets.BDay())
    start = last_date + pd.Timedelta(days=1)
    return pd.Series(pd.date_range(start=start, periods=pred_len, freq=freq))


def _render_forecast_chart(history, pred_df, last_date, interval):
    """Plot history and forecast closes and return the PNG as base64 text."""
    plt.style.use('dark_background')
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.plot(history['timestamps'], history['close'], label='History', color='#3b82f6', linewidth=2)
        ax.plot(pred_df.index, pred_df['close'], label='AI Prediction', color='#ef4444', linestyle='--', linewidth=2)
        # Vertical marker where history ends and the forecast begins.
        ax.axvline(x=last_date, color='white', linestyle='--', alpha=0.3)
        ax.set_title(f'Kronos AI Forecast ({interval})', fontsize=14, pad=20, color='white')
        ax.legend()
        ax.grid(True, alpha=0.1)
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight', facecolor='#0f172a')
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode('utf-8')


def predict_from_df(df: pd.DataFrame, pred_len: int = 30, interval: str = '1d'):
    """Run prediction logic on a normalized DataFrame.

    Args:
        df: Frame with 'open', 'high', 'low', 'close' and 'timestamps'
            columns; an optional 'volume' column is passed to the model too.
            The caller's frame is not modified.
        pred_len: Number of future steps to predict (at least 1).
        interval: '1d', '1wk' or '1mo'; selects the spacing of the future
            timestamps. Other values fall back to business days.

    Returns:
        A dict with the keys 'prediction' (list of {'date', 'close'}),
        'chart' (base64 PNG), 'last_price', 'predicted_price', 'change_pct',
        'outlook' and 'message'.

    Raises:
        ValueError: If a required column is missing, ``pred_len`` is smaller
            than 1, or no usable rows remain after cleaning.
    """
    # 1. Validate before touching any data
    required_cols = ['open', 'high', 'low', 'close']
    for col in required_cols:
        if col not in df.columns:
            raise ValueError(f"Required column '{col}' missing.")
    if 'timestamps' not in df.columns:
        raise ValueError("Missing 'timestamps' column.")
    if pred_len < 1:
        raise ValueError("pred_len must be at least 1.")

    # Work on a copy so the dtype conversions below stay local.
    df = df.copy()
    for col in required_cols:
        df[col] = df[col].astype(float)
    df['timestamps'] = pd.to_datetime(df['timestamps'])

    # Handle NaNs (common in stock data)
    df = df.ffill().bfill()
    df = df.dropna(subset=required_cols)
    if df.empty:
        raise ValueError("No usable price rows after cleaning.")

    feature_cols = list(required_cols)
    if 'volume' in df.columns:
        feature_cols.append('volume')
    x_df = df[feature_cols]
    x_timestamp = df['timestamps']
    last_date = x_timestamp.iloc[-1]
    y_timestamp = _future_timestamps(last_date, pred_len, interval)

    # 2. Predict
    predictor = load_models()
    pred_df = predictor.predict(
        df=x_df,
        x_timestamp=x_timestamp,
        y_timestamp=y_timestamp,
        pred_len=pred_len,
        T=1.0,
        top_p=0.9,
        sample_count=1,
        verbose=False
    )

    # 3. Generate visualization (last 100 points of history plus the forecast)
    chart_base64 = _render_forecast_chart(df.tail(100), pred_df, last_date, interval)

    # 4. Results
    last_price = float(df['close'].iloc[-1])
    pred_price = float(pred_df['close'].iloc[-1])
    change_pct = ((pred_price - last_price) / last_price) * 100

    # float(...) keeps NumPy scalar types out of the payload.
    prediction_list = [
        {"date": date.strftime('%Y-%m-%d'), "close": round(float(close), 2)}
        for date, close in zip(pred_df.index, pred_df['close'])
    ]

    outlook = get_outlook_category(change_pct)
    return {
        "prediction": prediction_list,
        "chart": chart_base64,
        "last_price": round(last_price, 2),
        "predicted_price": round(pred_price, 2),
        "change_pct": round(change_pct, 2),
        "outlook": outlook,
        "message": f"{outlook} forecast ({round(change_pct, 1)}%)"
    }
def predict_from_csv(csv_data: str, pred_len: int = 30, interval: str = '1d'):
    """Old entry point for CSV data.

    Args:
        csv_data: CSV text with a header row.
        pred_len: Number of future steps to predict.
        interval: Forwarded to ``predict_from_df``; defaults to daily bars.

    Returns:
        The result of ``predict_from_df`` for the parsed frame.
    """
    df = pd.read_csv(io.StringIO(csv_data))

    # Headers such as "Date, Open" keep their padding when parsed; strip it so
    # the name matching below is reliable.
    df.columns = [str(c).strip() for c in df.columns]

    # Find the time column case-insensitively. Skip the rename when a
    # 'timestamps' column already exists so no duplicate column is created.
    if 'timestamps' not in df.columns:
        by_lower = {}
        for col in df.columns:
            by_lower.setdefault(col.lower(), col)
        for alias in ('timestamps', 'time', 'date', 'datetime', 'timestamp'):
            if alias in by_lower:
                df = df.rename(columns={by_lower[alias]: 'timestamps'})
                break

    # Lower-case the price/volume headers expected downstream.
    price_cols = {'open', 'high', 'low', 'close', 'volume'}
    col_map = {c: c.lower() for c in df.columns if c.lower() in price_cols}
    df = df.rename(columns=col_map)
    return predict_from_df(df, pred_len, interval)
def predict_from_symbol(symbol: str, lookback: int = 300, interval: str = '1d', pred_len: int = 30):
    """Fetch history for a ticker, forecast it, and tag the result with the symbol.

    The returned dict is the one produced by ``predict_from_df`` with an
    extra "symbol" entry holding the normalized ticker.
    """
    history, resolved_symbol = fetch_stock_data(symbol, lookback, interval)
    forecast = predict_from_df(history, pred_len, interval)
    forecast["symbol"] = resolved_symbol
    return forecast
def process_multiple_csvs(csv_list, pred_len=30):
    """Run ``predict_from_csv`` on each CSV text, collecting results in order.

    A failure on one item does not stop the batch: that item's slot holds
    ``{'error': <message>}`` instead of a prediction result.
    """
    def _attempt(csv_text):
        # Best effort per item: turn any failure into an error record.
        try:
            return predict_from_csv(csv_text, pred_len)
        except Exception as exc:
            return {'error': str(exc)}

    return [_attempt(csv_text) for csv_text in csv_list]