# stock-predictor / kronos_logic.py
# Uploaded to Hugging Face by Shokat ("Upload 3 files", commit bf43499).
import pandas as pd
import matplotlib.pyplot as plt
import os
import sys
import io
import base64
from datetime import timedelta
import yfinance as yf
# Adjust path to import from the Kronos folder.
# The 'Kronos' directory next to this file is appended to sys.path so that its
# `model` package can be imported as a top-level module. This must run before
# the imports below.
sys.path.append(os.path.join(os.path.dirname(__file__), 'Kronos'))
try:
    from model import Kronos, KronosTokenizer, KronosPredictor
except ImportError:
    # If running from a different directory (like in HF Spaces), fall back to
    # importing through the 'Kronos' package itself.
    from Kronos.model import Kronos, KronosTokenizer, KronosPredictor
# Module-level cache for the lazily loaded Kronos components. All three start
# as None and are populated together by load_models() on first use.
_tokenizer = _model = _predictor = None
# MONKEY PATCH: Fix the 'DatetimeIndex' has no attribute 'dt' error in the Kronos library
def patched_calc_time_stamps(x_timestamp):
    """Return calendar features (minute/hour/weekday/day/month) per timestamp.

    Drop-in replacement for Kronos' ``calc_time_stamps``. It accepts a datetime
    Series, a DatetimeIndex, or any array-like of datetimes or date strings.
    Every input is normalized through ``pd.DatetimeIndex``, so all input kinds
    share one code path.

    Args:
        x_timestamp: Timestamps to featurize.

    Returns:
        DataFrame with integer columns 'minute', 'hour', 'weekday'
        (Monday=0), 'day' and 'month', one row per timestamp. A Series input
        keeps its own index; any other input gets a default RangeIndex.
    """
    # pd.DatetimeIndex parses strings and accepts Series/Index/arrays alike.
    # The previous pd.Index() call produced an object Index for string input,
    # which has no .minute.
    stamps = pd.DatetimeIndex(x_timestamp)
    # Preserve the Series' index, as the former `.dt`-based path did.
    index = x_timestamp.index if isinstance(x_timestamp, pd.Series) else None
    return pd.DataFrame(
        {
            'minute': stamps.minute,
            'hour': stamps.hour,
            'weekday': stamps.weekday,
            'day': stamps.day,
            'month': stamps.month,
        },
        index=index,
    )
# Apply the patch immediately.
# Patch the module that actually defines KronosPredictor rather than importing
# "model.kronos" by name. Depending on which import above succeeded, the
# defining module is either `model.kronos` or `Kronos.model.kronos`.
# Patching a differently-named copy would silently have no effect.
# NOTE(review): this assumes KronosPredictor looks up `calc_time_stamps` in the
# globals of its own module (the original name-based patch relied on this too).
try:
    mk = sys.modules[KronosPredictor.__module__]
    mk.calc_time_stamps = patched_calc_time_stamps
except Exception as e:
    # Best effort: keep the app usable even if the patch cannot be applied.
    print(f"Warning: Could not patch Kronos: {e}")
def load_models():
    """Return the shared KronosPredictor, building it on the first call.

    The tokenizer ("NeoQuasar/Kronos-Tokenizer-base") and the model
    ("NeoQuasar/Kronos-small") are loaded with ``from_pretrained`` and stored
    in module globals. Later calls return the cached predictor without
    reloading.
    """
    global _tokenizer, _model, _predictor
    if _predictor is not None:
        return _predictor

    print("Loading Kronos models from Hugging Face...")
    _tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
    _model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
    _predictor = KronosPredictor(_model, _tokenizer, max_context=512)
    return _predictor
def get_outlook_category(change_pct: float):
    """Map a forecast percentage change to a coarse outlook label.

    - Strictly above +30: "Extreme Bullish".
    - Any other positive value: "Bullish".
    - Strictly below -30: "Extreme Bearish".
    - Everything else, including exactly 0 and NaN: "Bearish".
    """
    extreme = 30
    if change_pct > 0:
        return "Extreme Bullish" if change_pct > extreme else "Bullish"
    if change_pct < -extreme:
        return "Extreme Bearish"
    return "Bearish"
def fetch_stock_data(symbol: str, lookback: int = 300, interval: str = '1d'):
    """Fetch recent price history with yfinance and normalize it for Kronos.

    Args:
        symbol: Ticker symbol. Surrounding whitespace is stripped and the
            symbol is upper-cased before fetching.
        lookback: Number of most recent bars to return (must be >= 1).
        interval: yfinance bar interval, e.g. '1d', '1wk' or '1mo'.

    Returns:
        Tuple ``(df, fetch_symbol)``. ``df`` holds at most ``lookback`` rows
        with lowercase column names and the date column renamed to
        'timestamps'. ``fetch_symbol`` is the normalized ticker.

    Raises:
        ValueError: if ``lookback`` < 1 or yfinance returns no data.
    """
    if lookback < 1:
        raise ValueError("lookback must be a positive integer.")

    fetch_symbol = symbol.strip().upper()
    print(f"Fetching data for {fetch_symbol} (Interval: {interval}, Lookback: {lookback})...")
    ticker = yf.Ticker(fetch_symbol)

    # Estimate how many calendar days `lookback` bars span, padded for
    # weekends and holidays. Then pick the smallest standard yfinance period
    # that covers that span. A fixed "2y" window would silently truncate
    # weekly/monthly requests (e.g. 300 weekly bars need ~6 years).
    days_per_bar = {'1d': 1.6, '1wk': 7.5, '1mo': 31.0}.get(interval, 1.6)
    needed_days = lookback * days_per_bar
    period = "max"
    for candidate, span_days in (("1y", 365), ("2y", 730), ("5y", 1826), ("10y", 3652)):
        if needed_days <= span_days:
            period = candidate
            break

    df = ticker.history(period=period, interval=interval)
    if df.empty:
        raise ValueError(f"Could not fetch data for symbol: {symbol}")

    # Slice to the requested lookback and turn the date index into a column.
    df = df.tail(lookback).reset_index()

    # Normalize column names to lowercase.
    df.columns = [str(c).lower() for c in df.columns]
    # The index column is 'Date' for daily+ bars. Intraday data is expected to
    # use 'Datetime' (NOTE(review): confirm), so accept either name.
    for date_col in ('date', 'datetime'):
        if date_col in df.columns:
            df = df.rename(columns={date_col: 'timestamps'})
            break
    return df, fetch_symbol
def _future_timestamps(last_date, pred_len: int, interval: str) -> pd.Series:
    """Return ``pred_len`` bar timestamps that follow ``last_date``.

    Frequencies by interval:
    - '1d': business days.
    - '1wk': Monday-anchored weeks.
    - '1mo': month starts.
    - Anything else: business days.

    Anchoring on period starts assumes bars are labelled by the start of
    their period, as yfinance does (NOTE(review): confirm for CSV input).
    Week-end/month-end anchors ('W'/'M') would place the first forecast bar
    inside the last observed bar's own week or month. 'M' is also deprecated
    in recent pandas.
    """
    freq = {'1d': 'B', '1wk': 'W-MON', '1mo': 'MS'}.get(interval, 'B')
    # Start one day after the last bar so it is never repeated in the forecast.
    start = last_date + pd.Timedelta(days=1)
    return pd.Series(pd.date_range(start=start, periods=pred_len, freq=freq))


def _render_forecast_chart(history: pd.DataFrame, pred_df: pd.DataFrame, last_date, interval: str) -> str:
    """Plot recent closes plus the forecast and return the PNG as base64 text."""
    # Scope the dark style to this figure instead of mutating global rcParams.
    # The figure is created inside the context so the style applies to it.
    with plt.style.context('dark_background'):
        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            ax.plot(history['timestamps'], history['close'], label='History', color='#3b82f6', linewidth=2)
            ax.plot(pred_df.index, pred_df['close'], label='AI Prediction', color='#ef4444', linestyle='--', linewidth=2)
            ax.axvline(x=last_date, color='white', linestyle='--', alpha=0.3)
            ax.set_title(f'Kronos AI Forecast ({interval})', fontsize=14, pad=20, color='white')
            ax.legend()
            ax.grid(True, alpha=0.1)
            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=100, bbox_inches='tight', facecolor='#0f172a')
        finally:
            # Release the figure even if plotting fails, so a long-running
            # process does not accumulate open figures.
            plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode('utf-8')


def predict_from_df(df: pd.DataFrame, pred_len: int = 30, interval: str = '1d'):
    """Run the Kronos forecast on a normalized OHLC(V) DataFrame.

    Args:
        df: Frame with 'open', 'high', 'low', 'close' and 'timestamps'
            columns, plus an optional 'volume' column. It is copied, never
            modified in place.
        pred_len: Number of future bars to forecast (must be >= 1).
        interval: Bar interval ('1d', '1wk' or '1mo'). It sets the spacing of
            the forecast timestamps and appears in the chart title.

    Returns:
        dict with these keys:
        - "prediction": list of {"date": 'YYYY-MM-DD', "close": float}.
        - "chart": base64-encoded PNG.
        - "last_price", "predicted_price", "change_pct": rounded floats.
        - "outlook": label string.
        - "message": summary string.

    Raises:
        ValueError: in any of these cases:
            - pred_len < 1.
            - A required column is missing or cannot be parsed.
            - No row has valid prices.
            - The last close is zero.
    """
    if pred_len < 1:
        raise ValueError("pred_len must be a positive integer.")

    # Work on a copy: the coercions below must not mutate the caller's frame.
    df = df.copy()

    # 1. Validate and coerce columns
    required_cols = ['open', 'high', 'low', 'close']
    for col in required_cols:
        if col not in df.columns:
            raise ValueError(f"Required column '{col}' missing.")
        df[col] = df[col].astype(float)
    if 'timestamps' not in df.columns:
        raise ValueError("Missing 'timestamps' column.")
    df['timestamps'] = pd.to_datetime(df['timestamps'])

    # Handle NaNs (common in stock data). Fill gaps from neighbouring rows,
    # then drop rows still missing a price. After ffill+bfill that only
    # happens when a price column is entirely empty.
    df = df.ffill().bfill()
    df = df.dropna(subset=required_cols)
    if df.empty:
        raise ValueError("No rows with valid price data to predict from.")

    feature_cols = required_cols + (['volume'] if 'volume' in df.columns else [])
    x_df = df[feature_cols]
    x_timestamp = df['timestamps']
    last_date = x_timestamp.iloc[-1]
    last_price = float(df['close'].iloc[-1])
    # Fail before the expensive model call if the % change would be undefined.
    if last_price == 0:
        raise ValueError("Last close price is zero; percentage change is undefined.")
    y_timestamp = _future_timestamps(last_date, pred_len, interval)

    # 2. Predict
    predictor = load_models()
    pred_df = predictor.predict(
        df=x_df,
        x_timestamp=x_timestamp,
        y_timestamp=y_timestamp,
        pred_len=pred_len,
        T=1.0,
        top_p=0.9,
        sample_count=1,
        verbose=False,
    )

    # 3. Visualization: last 100 points of history plus the forecast
    chart_base64 = _render_forecast_chart(df.tail(100), pred_df, last_date, interval)

    # 4. Results
    pred_price = float(pred_df['close'].iloc[-1])
    change_pct = ((pred_price - last_price) / last_price) * 100
    prediction_list = [
        {"date": date.strftime('%Y-%m-%d'), "close": round(float(close), 2)}
        for date, close in pred_df['close'].items()
    ]
    outlook = get_outlook_category(change_pct)
    return {
        "prediction": prediction_list,
        "chart": chart_base64,
        "last_price": round(last_price, 2),
        "predicted_price": round(pred_price, 2),
        "change_pct": round(change_pct, 2),
        "outlook": outlook,
        "message": f"{outlook} forecast ({round(change_pct, 1)}%)"
    }
def predict_from_csv(csv_data: str, pred_len: int = 30, interval: str = '1d'):
    """Forecast from raw CSV text (legacy entry point).

    Normalizes the CSV headers to the names predict_from_df requires, then
    delegates to it.

    Args:
        csv_data: CSV text with a header row.
            - OHLCV headers are matched case-insensitively ('Close' -> 'close').
            - The time column may be called time/date/datetime/timestamp in
              any case, or already 'timestamps'.
        pred_len: Number of future bars to forecast.
        interval: Bar interval of the data ('1d', '1wk', '1mo'). Defaults to
            '1d', matching earlier behavior.

    Returns:
        The result dict produced by predict_from_df.
    """
    df = pd.read_csv(io.StringIO(csv_data))

    # predict_from_df does no renaming itself: it needs an exact 'timestamps'
    # column and lowercase OHLCV names.
    if 'timestamps' not in df.columns:
        # Aliases are tried in priority order. Only one column is renamed, so
        # the result never contains duplicate 'timestamps' columns.
        source = next(
            (c for alias in ('time', 'date', 'datetime', 'timestamp')
             for c in df.columns if str(c).lower() == alias),
            None,
        )
        if source is not None:
            df = df.rename(columns={source: 'timestamps'})

    ohlcv = {'open', 'high', 'low', 'close', 'volume'}
    col_map = {}
    for c in df.columns:
        target = str(c).lower()
        # Skip names that are already canonical or whose canonical form is
        # taken, so the rename never produces duplicate columns.
        if target in ohlcv and target not in df.columns and target not in col_map.values():
            col_map[c] = target
    df = df.rename(columns=col_map)
    return predict_from_df(df, pred_len, interval)
def predict_from_symbol(symbol: str, lookback: int = 300, interval: str = '1d', pred_len: int = 30):
    """Fetch recent history for a ticker symbol and forecast it.

    Returns the predict_from_df result dict with one extra key, "symbol". It
    holds the normalized ticker that fetch_stock_data actually requested.
    """
    history, resolved_symbol = fetch_stock_data(symbol, lookback, interval)
    forecast = predict_from_df(history, pred_len, interval)
    forecast["symbol"] = resolved_symbol
    return forecast
def process_multiple_csvs(csv_list, pred_len=30):
    """Forecast each CSV string independently, in input order.

    A failure in one CSV does not abort the batch. That item's slot holds
    ``{'error': <exception message>}`` instead of a result dict.
    """
    def _predict_or_error(csv_text):
        try:
            return predict_from_csv(csv_text, pred_len)
        except Exception as exc:
            return {'error': str(exc)}

    return [_predict_or_error(csv_text) for csv_text in csv_list]