import base64
import io
import os
import sys
from datetime import timedelta

import matplotlib.pyplot as plt
import pandas as pd
import yfinance as yf

# Adjust path to import from the Kronos folder
sys.path.append(os.path.join(os.path.dirname(__file__), 'Kronos'))

try:
    from model import Kronos, KronosTokenizer, KronosPredictor
except ImportError:
    # If running from a different directory (like in HF Spaces)
    from Kronos.model import Kronos, KronosTokenizer, KronosPredictor

# Global variables to cache models
_tokenizer = None
_model = None
_predictor = None


# MONKEY PATCH: Fix the 'DatetimeIndex' has no attribute 'dt' error in the Kronos library
def patched_calc_time_stamps(x_timestamp):
    """Build the calendar-feature frame that replaces Kronos' ``calc_time_stamps``.

    Args:
        x_timestamp: Timestamps as a ``pd.Series`` or any datetime-like
            sequence (for example a ``DatetimeIndex``).

    Returns:
        A DataFrame with the columns ``minute``, ``hour``, ``weekday``,
        ``day`` and ``month``. For Series input the frame keeps the
        Series' index.
    """
    if isinstance(x_timestamp, pd.Series):
        # A Series exposes the calendar fields through the ``.dt`` accessor.
        parts = pd.to_datetime(x_timestamp).dt
    else:
        # A DatetimeIndex exposes the same fields directly (it has no ``.dt``).
        parts = pd.DatetimeIndex(x_timestamp)

    time_df = pd.DataFrame()
    for field in ('minute', 'hour', 'weekday', 'day', 'month'):
        time_df[field] = getattr(parts, field)
    return time_df


# Apply the patch to the module that actually defines KronosPredictor.
# Resolving it through ``__module__`` works for both import layouts above
# ('model' and 'Kronos.model') and never patches a second, unused copy.
try:
    _kronos_module = sys.modules[KronosPredictor.__module__]
    _kronos_module.calc_time_stamps = patched_calc_time_stamps
except (KeyError, AttributeError) as e:
    # Best effort: the service still starts, the patch is just not active.
    print(f"Warning: Could not patch Kronos: {e}")


def load_models():
    """Load the Kronos tokenizer, model and predictor once and cache them.

    The first call builds the objects via ``from_pretrained`` and stores
    them in the module-level cache; later calls return the cached predictor.

    Returns:
        The cached ``KronosPredictor`` instance.
    """
    global _tokenizer, _model, _predictor
    if _predictor is not None:
        return _predictor

    print("Loading Kronos models from Hugging Face...")
    _tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
    _model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
    _predictor = KronosPredictor(_model, _tokenizer, max_context=512)
    return _predictor


def get_outlook_category(change_pct: float) -> str:
    """Categorize outlook based on percentage change.

    Args:
        change_pct: Forecast change in percent.

    Returns:
        ``"Extreme Bullish"`` above 30, ``"Bullish"`` above 0,
        ``"Extreme Bearish"`` below -30, otherwise ``"Bearish"``.
        Note that a change of exactly 0 is reported as ``"Bearish"``.
    """
    if change_pct > 30:
        return "Extreme Bullish"
    if change_pct > 0:
        return "Bullish"
    if change_pct < -30:
        return "Extreme Bearish"
    return "Bearish"
# OHLC columns that must be present in every frame handed to Kronos.
_REQUIRED_COLS = ('open', 'high', 'low', 'close')

# Header spellings (compared case-insensitively) accepted as the time
# column, in priority order.
_TIME_COLUMN_ALIASES = ('timestamps', 'time', 'date', 'datetime', 'timestamp')

# Approximate calendar days spanned by one bar, including a safety margin
# for holidays and partial periods.
_CALENDAR_DAYS_PER_BAR = {'1d': 1.5, '1wk': 7 * 1.2, '1mo': 31 * 1.1}

# yfinance period strings with their approximate span in calendar days,
# shortest first. Anything longer falls back to "max".
_HISTORY_PERIODS = (('2y', 730), ('5y', 1825), ('10y', 3650))

# Offsets used to generate future timestamps. These are the object forms of
# the 'B', 'W' and 'M' frequency aliases and yield the same dates without
# relying on alias strings that newer pandas versions deprecate.
_FORECAST_OFFSETS = {
    '1d': pd.offsets.BusinessDay(),
    '1wk': pd.offsets.Week(weekday=6),
    '1mo': pd.offsets.MonthEnd(),
}


def _history_period(lookback: int, interval: str) -> str:
    """Return a yfinance period string long enough to hold `lookback` bars."""
    days_per_bar = _CALENDAR_DAYS_PER_BAR.get(interval)
    if days_per_bar is None:
        # Unknown interval: keep the historical rule of thumb.
        return "max" if lookback > 500 else "2y"

    needed_days = lookback * days_per_bar
    for period, span_days in _HISTORY_PERIODS:
        if needed_days <= span_days:
            return period
    return "max"


def _rename_time_column(df: pd.DataFrame) -> pd.DataFrame:
    """Rename the first recognised time column to 'timestamps' (case-insensitive)."""
    if 'timestamps' in df.columns:
        return df

    # Map normalised header -> original header, first occurrence wins.
    by_lower = {}
    for col in df.columns:
        by_lower.setdefault(str(col).strip().lower(), col)

    for alias in _TIME_COLUMN_ALIASES:
        if alias in by_lower:
            return df.rename(columns={by_lower[alias]: 'timestamps'})
    return df


def _future_timestamps(last_date, pred_len: int, interval: str) -> pd.Series:
    """Build `pred_len` timestamps after `last_date` at the interval's frequency."""
    offset = _FORECAST_OFFSETS.get(interval, _FORECAST_OFFSETS['1d'])
    future = pd.date_range(
        start=last_date + pd.Timedelta(days=1),
        periods=pred_len,
        freq=offset,
    )
    return pd.Series(future)


def _render_forecast_chart(history: pd.DataFrame, pred_df: pd.DataFrame,
                           last_date, interval: str) -> str:
    """Plot recent history plus the forecast and return the PNG as base64 text."""
    plt.style.use('dark_background')
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        # Show last 100 points of history
        recent = history.tail(100)
        ax.plot(recent['timestamps'], recent['close'], label='History',
                color='#3b82f6', linewidth=2)
        ax.plot(pred_df.index, pred_df['close'], label='AI Prediction',
                color='#ef4444', linestyle='--', linewidth=2)
        # Vertical marker where history ends and the forecast begins.
        ax.axvline(x=last_date, color='white', linestyle='--', alpha=0.3)
        ax.set_title(f'Kronos AI Forecast ({interval})', fontsize=14, pad=20,
                     color='white')
        ax.legend()
        ax.grid(True, alpha=0.1)

        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight',
                    facecolor='#0f172a')
    finally:
        # Always release the figure, even when plotting fails.
        plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode('utf-8')


def fetch_stock_data(symbol: str, lookback: int = 300, interval: str = '1d'):
    """Fetch historical data using yfinance and normalize for Kronos.

    Args:
        symbol: Ticker symbol; surrounding whitespace is stripped and the
            symbol is upper-cased before the request.
        lookback: Number of most recent bars to keep (must be >= 1).
        interval: yfinance bar interval, e.g. '1d', '1wk' or '1mo'.

    Returns:
        A tuple ``(df, fetch_symbol)`` where ``df`` has lower-case column
        names and a 'timestamps' column, and ``fetch_symbol`` is the
        normalized symbol that was requested.

    Raises:
        ValueError: If `lookback` is not positive or no data was returned.
    """
    if lookback < 1:
        raise ValueError("lookback must be at least 1.")

    fetch_symbol = symbol.strip().upper()
    print(f"Fetching data for {fetch_symbol} (Interval: {interval}, Lookback: {lookback})...")

    ticker = yf.Ticker(fetch_symbol)
    # Request a period that actually covers `lookback` bars at this interval;
    # a fixed "2y" window holds far fewer weekly/monthly bars than requested.
    df = ticker.history(period=_history_period(lookback, interval), interval=interval)

    if df.empty:
        raise ValueError(f"Could not fetch data for symbol: {symbol}")

    # Slice to requested lookback and expose the index as a regular column.
    df = df.tail(lookback).reset_index()

    # Normalize column names to lowercase
    df.columns = [str(c).lower() for c in df.columns]
    df = _rename_time_column(df)

    return df, fetch_symbol


def predict_from_df(df: pd.DataFrame, pred_len: int = 30, interval: str = '1d') -> dict:
    """Run prediction logic on a normalized DataFrame.

    Args:
        df: Frame with 'open', 'high', 'low', 'close' and 'timestamps'
            columns; 'volume' is used when present. The frame is not
            modified.
        pred_len: Number of future steps to forecast (must be >= 1).
        interval: Bar interval ('1d', '1wk', '1mo'); other values fall back
            to business-day spacing for the future timestamps.

    Returns:
        A dict with the keys 'prediction' (list of ``{"date", "close"}``),
        'chart' (base64 PNG), 'last_price', 'predicted_price', 'change_pct',
        'outlook' and 'message'.

    Raises:
        ValueError: If a required column is missing, `pred_len` is not
            positive, no rows remain after cleaning, or the last close is 0.
    """
    required_cols = list(_REQUIRED_COLS)

    # 1. Validate everything up front so bad input fails before any work.
    for col in required_cols:
        if col not in df.columns:
            raise ValueError(f"Required column '{col}' missing.")
    if 'timestamps' not in df.columns:
        raise ValueError("Missing 'timestamps' column.")
    if pred_len < 1:
        raise ValueError("pred_len must be at least 1.")

    # Work on a copy: the conversions below must not leak into the caller's frame.
    df = df.copy()
    for col in required_cols:
        df[col] = df[col].astype(float)
    df['timestamps'] = pd.to_datetime(df['timestamps'])

    # Handle NaNs (common in stock data)
    df = df.ffill().bfill()
    df = df.dropna(subset=required_cols)
    if df.empty:
        raise ValueError("No usable rows left after cleaning the input data.")

    feature_cols = required_cols + ['volume'] if 'volume' in df.columns else required_cols
    x_df = df[feature_cols]
    x_timestamp = df['timestamps']
    last_date = x_timestamp.iloc[-1]

    last_price = float(df['close'].iloc[-1])
    if last_price == 0:
        # Checked before inference: the percentage change would be undefined.
        raise ValueError("Last close price is zero; percentage change is undefined.")

    y_timestamp = _future_timestamps(last_date, pred_len, interval)

    # 2. Predict
    predictor = load_models()
    pred_df = predictor.predict(
        df=x_df,
        x_timestamp=x_timestamp,
        y_timestamp=y_timestamp,
        pred_len=pred_len,
        T=1.0,
        top_p=0.9,
        sample_count=1,
        verbose=False
    )

    # 3. Generate visualization
    chart_base64 = _render_forecast_chart(df, pred_df, last_date, interval)

    # 4. Results
    pred_price = float(pred_df['close'].iloc[-1])
    change_pct = ((pred_price - last_price) / last_price) * 100

    # Plain floats keep the payload serializable regardless of the model's dtype.
    prediction_list = [
        {"date": stamp.strftime('%Y-%m-%d'), "close": round(float(close), 2)}
        for stamp, close in pred_df['close'].items()
    ]

    outlook = get_outlook_category(change_pct)

    return {
        "prediction": prediction_list,
        "chart": chart_base64,
        "last_price": round(last_price, 2),
        "predicted_price": round(pred_price, 2),
        "change_pct": round(change_pct, 2),
        "outlook": outlook,
        "message": f"{outlook} forecast ({round(change_pct, 1)}%)"
    }


def predict_from_csv(csv_data: str, pred_len: int = 30, interval: str = '1d') -> dict:
    """Old entry point for CSV data.

    Args:
        csv_data: CSV text with a time column (e.g. 'time', 'Date',
            'date'; matched case-insensitively) and OHLC columns in any
            letter case.
        pred_len: Number of future steps to forecast.
        interval: Bar interval of the CSV rows, passed to `predict_from_df`.

    Returns:
        The result dict produced by `predict_from_df`.

    Raises:
        ValueError: Propagated from `predict_from_df` for invalid input.
    """
    df = pd.read_csv(io.StringIO(csv_data))
    df = _rename_time_column(df)

    # Lower-case only the price/volume headers; other columns stay untouched.
    wanted = ('open', 'high', 'low', 'close', 'volume')
    col_map = {c: c.lower() for c in df.columns if c.lower() in wanted}
    df = df.rename(columns=col_map)

    return predict_from_df(df, pred_len, interval)


def predict_from_symbol(symbol: str, lookback: int = 300, interval: str = '1d', pred_len: int = 30) -> dict:
    """New entry point for symbol-based prediction.

    Fetches history via `fetch_stock_data`, runs `predict_from_df` and adds
    the normalized symbol to the result under the 'symbol' key.
    """
    df, actual_symbol = fetch_stock_data(symbol, lookback, interval)
    result = predict_from_df(df, pred_len, interval)
    result["symbol"] = actual_symbol
    return result


def process_multiple_csvs(csv_list, pred_len=30, interval='1d') -> list:
    """Run `predict_from_csv` on each CSV text and collect the outcomes.

    A failing CSV does not abort the batch: its slot in the returned list
    holds ``{'error': <message>}`` instead of a result dict.
    """
    results = []
    for csv_data in csv_list:
        try:
            results.append(predict_from_csv(csv_data, pred_len, interval))
        except Exception as e:
            # Batch boundary: report the failure and continue with the rest.
            results.append({'error': str(e)})
    return results