# chronos2-excel-forecasting-api / app/main_from_hf_space.py
# Committed by ttzzs
# Deploy Chronos2 Forecasting API v3.0.0 with new SOLID architecture
# Commit c40c447 (verified)
import os
from typing import List, Dict, Optional
import json
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from huggingface_hub import InferenceClient
# =========================
# Configuración
# =========================
# Hugging Face access token; when missing or empty, inference calls fail with 503.
HF_TOKEN = os.getenv("HF_TOKEN")
# Model repository queried on the HF Inference router.
MODEL_ID = os.getenv("CHRONOS_MODEL_ID", "amazon/chronos-t5-large")

app = FastAPI(
    title="Chronos-2 Forecasting API (HF Inference)",
    description=(
        "API de pronósticos usando Chronos-2 via Hugging Face Inference API. "
        "Compatible con Excel Add-in."
    ),
    version="1.0.0",
)

# CORS configuration. Allowed origins come from the comma-separated
# CORS_ALLOW_ORIGINS environment variable; when it is unset the default "*"
# keeps the allow-all behaviour. Restrict it to known domains in production.
# NOTE(review): browsers reject a literal "*" origin for credentialed
# requests; check how CORSMiddleware treats allow_credentials=True together
# with "*" and disable credentials if the Excel add-in does not need them.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# HF Inference client. In this module it only serves as a "token configured"
# flag (call_chronos_inference checks it for None); the actual HTTP call is
# made with `requests`.
if not HF_TOKEN:
    print("⚠️ WARNING: HF_TOKEN no configurado. La API puede no funcionar correctamente.")
    print(" Configura HF_TOKEN en las variables de entorno del Space.")
    client = None
else:
    client = InferenceClient(token=HF_TOKEN)
# =========================
# Modelos Pydantic
# =========================
class UnivariateSeries(BaseModel):
    """A single time series given as a flat list of observations.

    The handlers treat the end of the list as the most recent data
    (backtest_simple, for instance, holds out the tail as its test set).
    """

    # Observed values; the last element is treated as the most recent point.
    values: List[float]
class ForecastUnivariateRequest(BaseModel):
    """Request body for /forecast_univariate.

    quantile_levels and freq are accepted for client compatibility, but the
    forecast_univariate handler in this module does not read them.
    """

    series: UnivariateSeries
    # Forecast horizon; must be positive (0 or negative would request an
    # empty or ill-defined forecast from the model).
    prediction_length: int = Field(7, gt=0, description="Número de pasos a predecir")
    quantile_levels: Optional[List[float]] = Field(
        default=[0.1, 0.5, 0.9],
        description="Cuantiles para intervalos de confianza",
    )
    freq: str = Field("D", description="Frecuencia temporal (D, W, M, etc.)")
class ForecastUnivariateResponse(BaseModel):
    """Forecast for one series: step labels, median path and quantile paths."""

    # Relative step labels produced by the handlers ("t+1", "t+2", ...),
    # not calendar dates.
    timestamps: List[str]
    # Point (median) forecast, one value per step.
    median: List[float]
    # Quantile level as a string (e.g. "0.1") -> forecast path for that level.
    quantiles: Dict[str, List[float]]
class AnomalyDetectionRequest(BaseModel):
    """Request body for /detect_anomalies."""

    # History the forecast is produced from.
    context: UnivariateSeries
    # Observations to check; detect_anomalies requires
    # len(recent_observed) == prediction_length.
    recent_observed: List[float]
    # Number of forecast steps (also the expected length of recent_observed).
    prediction_length: int = 7
    # Band quantile levels; detect_anomalies looks them up in the forecast's
    # quantiles via str(level), e.g. "0.05".
    quantile_low: float = 0.05
    quantile_high: float = 0.95
class AnomalyPoint(BaseModel):
    """Anomaly verdict for a single observed value."""

    # Position of the value within recent_observed.
    index: int
    # The observed value.
    value: float
    # Forecast median for this step.
    predicted_median: float
    # Bounds of the accepted band for this step.
    lower: float
    upper: float
    # True when value lies strictly outside [lower, upper].
    is_anomaly: bool
class AnomalyDetectionResponse(BaseModel):
    """Response of /detect_anomalies."""

    # One entry per checked observation, anomalous or not.
    anomalies: List[AnomalyPoint]
class BacktestRequest(BaseModel):
    """Request body for /backtest_simple."""

    series: UnivariateSeries
    # NOTE: not read by backtest_simple, which forecasts test_length steps.
    prediction_length: int = 7
    # Number of trailing points held out as the test set.
    test_length: int = 28
class BacktestMetrics(BaseModel):
    """Error metrics of a backtest forecast against held-out actuals."""

    # Mean absolute error.
    mae: float
    # Mean absolute percentage error, expressed in percent (ratio * 100).
    mape: float
    # Root mean squared error.
    rmse: float
class BacktestResponse(BaseModel):
    """Response of /backtest_simple."""

    metrics: BacktestMetrics
    # Forecast median for each held-out step.
    forecast_median: List[float]
    # Step labels "test_t1", "test_t2", ... for the held-out steps.
    forecast_timestamps: List[str]
    # The held-out actual values.
    actuals: List[float]
# Modelos para Multi-Series
class MultiSeriesItem(BaseModel):
    """One identified series in a /forecast_multi_id request."""

    # Caller-chosen identifier. It appears in validation error messages but
    # is not echoed in the response, whose forecasts follow input order.
    series_id: str
    values: List[float]
class ForecastMultiIdRequest(BaseModel):
    """Request body for /forecast_multi_id: several independent series.

    quantile_levels and freq are accepted for client compatibility, but the
    forecast_multi_id handler in this module does not read them.
    """

    series_list: List[MultiSeriesItem]
    # Forecast horizon shared by every series; must be positive.
    prediction_length: int = Field(7, gt=0, description="Número de pasos a predecir")
    quantile_levels: Optional[List[float]] = Field(
        default=[0.1, 0.5, 0.9],
        description="Cuantiles para intervalos de confianza",
    )
    freq: str = Field("D", description="Frecuencia temporal (D, W, M, etc.)")
class ForecastMultiIdResponse(BaseModel):
    """Response of /forecast_multi_id."""

    # One forecast per input series, in the same order as series_list.
    forecasts: List[ForecastUnivariateResponse]
# Modelos para Covariates
class CovariateData(BaseModel):
    """A named exogenous (covariate) series."""

    # Covariate values; required length depends on the request field using it.
    values: List[float]
    name: str = Field(..., description="Nombre de la covariable")
class ForecastWithCovariatesRequest(BaseModel):
    """Request body for /forecast_with_covariates.

    The forecast_with_covariates handler checks the covariate lengths but
    forecasts from target_series alone; quantile_levels and freq are not read.
    """

    target_series: UnivariateSeries
    # Historical covariates; each must match the length of target_series.
    covariates_history: List[CovariateData]
    # Future covariates; each must have length prediction_length.
    covariates_future: List[CovariateData]
    # Forecast horizon; must be positive.
    prediction_length: int = Field(7, gt=0, description="Número de pasos a predecir")
    quantile_levels: Optional[List[float]] = Field(
        default=[0.1, 0.5, 0.9],
        description="Cuantiles para intervalos de confianza",
    )
    freq: str = Field("D", description="Frecuencia temporal")
# Modelos para Scenarios
class ScenarioData(BaseModel):
    """One "what-if" scenario."""

    scenario_name: str
    # Presumably covariate name -> future values (TODO confirm with the
    # add-in); generate_scenarios currently reads only scenario_name.
    covariate_values: Dict[str, List[float]]
class GenerateScenariosRequest(BaseModel):
    """Request body for /generate_scenarios.

    freq is accepted for client compatibility, but the generate_scenarios
    handler in this module does not read it.
    """

    target_series: UnivariateSeries
    scenarios: List[ScenarioData]
    # Forecast horizon; must be positive.
    prediction_length: int = Field(7, gt=0, description="Número de pasos a predecir")
    freq: str = Field("D", description="Frecuencia temporal")
class ScenarioForecast(BaseModel):
    """Forecast returned for one named scenario."""

    scenario_name: str
    # Relative step labels ("t+1", "t+2", ...).
    timestamps: List[str]
    median: List[float]
    # Quantile level as a string -> forecast path.
    quantiles: Dict[str, List[float]]
class GenerateScenariosResponse(BaseModel):
    """Response of /generate_scenarios."""

    # One forecast per requested scenario, in request order.
    scenarios: List[ScenarioForecast]
# Modelos para Multivariate
class MultivariateSeries(BaseModel):
    """One named component series in a /forecast_multivariate request."""

    series_name: str
    values: List[float]
class ForecastMultivariateRequest(BaseModel):
    """Request body for /forecast_multivariate.

    quantile_levels and freq are accepted for client compatibility, but the
    forecast_multivariate handler in this module does not read them.
    """

    series_list: List[MultivariateSeries]
    # Forecast horizon shared by every series; must be positive.
    prediction_length: int = Field(7, gt=0, description="Número de pasos a predecir")
    quantile_levels: Optional[List[float]] = Field(
        default=[0.1, 0.5, 0.9],
        description="Cuantiles para intervalos de confianza",
    )
    freq: str = Field("D", description="Frecuencia temporal")
class MultivariateForecast(BaseModel):
    """Forecast for one named component series."""

    series_name: str
    # Relative step labels ("t+1", "t+2", ...).
    timestamps: List[str]
    median: List[float]
    # Quantile level as a string -> forecast path.
    quantiles: Dict[str, List[float]]
class ForecastMultivariateResponse(BaseModel):
    """Response of /forecast_multivariate."""

    # One forecast per input series, in the same order as series_list.
    forecasts: List[MultivariateForecast]
# =========================
# Función auxiliar para llamar a HF Inference
# =========================
def call_chronos_inference(series: List[float], prediction_length: int) -> Dict:
    """
    Call the Hugging Face Inference API for the configured Chronos model.

    Args:
        series: Historical values, sent as the model ``inputs``.
        prediction_length: Number of future steps requested.

    Returns:
        The decoded JSON body of the response. Its shape depends on the
        remote model; ``process_chronos_output`` normalises it.

    Raises:
        HTTPException: 503 if HF_TOKEN is not configured or the model is
            still loading; the upstream status code for any other non-200
            response; 504 on timeout; 502 on other network errors; 500 for
            any other unexpected failure.
    """
    if client is None:
        raise HTTPException(
            status_code=503,
            detail="HF_TOKEN no configurado. Contacta al administrador del servicio."
        )

    # Imported outside the try block: if the import itself failed inside it,
    # evaluating `except requests.exceptions...` would raise a NameError.
    import requests

    url = f"https://router.huggingface.co/hf-inference/models/{MODEL_ID}"
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    payload = {
        "inputs": series,
        "parameters": {
            "prediction_length": prediction_length,
            "num_samples": 100,  # sample paths, so quantiles can be derived
        },
    }

    try:
        response = requests.post(url, headers=headers, json=payload, timeout=60)
    except requests.exceptions.Timeout as e:
        # Timeout must precede RequestException (it is a subclass).
        raise HTTPException(
            status_code=504,
            detail="Timeout al comunicarse con HuggingFace API. El modelo puede estar cargando."
        ) from e
    except requests.exceptions.RequestException as e:
        raise HTTPException(
            status_code=502,
            detail=f"Error de conexión con la API de HuggingFace: {e}"
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error inesperado: {str(e)}"
        ) from e

    # These HTTPExceptions are raised outside any broad `except` so they reach
    # the client with their intended status codes instead of becoming a 500.
    if response.status_code == 503:
        raise HTTPException(
            status_code=503,
            detail="El modelo está cargando. Por favor, intenta de nuevo en 30-60 segundos."
        )
    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code,
            detail=f"Error de la API de HuggingFace: {response.text}"
        )

    try:
        return response.json()
    except ValueError as e:
        # Body was not valid JSON.
        raise HTTPException(
            status_code=500,
            detail=f"Error inesperado: {str(e)}"
        ) from e
def process_chronos_output(raw_output: Dict, prediction_length: int) -> Dict:
    """
    Normalise a raw Chronos inference response into a median and quantiles.

    Recognised shapes:
      * a flat list of point forecasts, or a single-series batch ``[[...]]``;
      * a dict whose "forecast" entry is such a list;
      * a dict whose "forecast" entry is a dict with "median" (or "mean")
        and optionally "quantiles" (level -> path).
    For list inputs no uncertainty is available, so the "0.1", "0.5" and
    "0.9" quantiles all repeat the point forecast. Any other shape yields an
    all-zero placeholder forecast (best-effort, as before).

    Args:
        raw_output: Decoded JSON from ``call_chronos_inference`` (list or dict).
        prediction_length: Maximum number of steps kept in every path.

    Returns:
        {"median": list, "quantiles": dict[str, list]} with paths truncated
        to ``prediction_length``.
    """

    def _point_forecast(values: List[float]) -> Dict:
        # Without real quantiles, every level reuses the point forecast.
        median = values[:prediction_length]
        return {
            "median": median,
            "quantiles": {"0.1": median, "0.5": median, "0.9": median},
        }

    if isinstance(raw_output, list):
        # Unwrap a single-series batch; otherwise the median would be a list
        # of lists and fail the List[float] response validation.
        if len(raw_output) == 1 and isinstance(raw_output[0], list):
            return _point_forecast(raw_output[0])
        return _point_forecast(raw_output)

    if isinstance(raw_output, dict) and "forecast" in raw_output:
        forecast = raw_output["forecast"]
        if isinstance(forecast, list):
            return _point_forecast(forecast)
        if isinstance(forecast, dict):
            if "median" in forecast:
                median = forecast["median"][:prediction_length]
            else:
                median = forecast.get("mean", [0] * prediction_length)[:prediction_length]
            raw_quantiles = forecast.get("quantiles", {})
            quantiles = {}
            if isinstance(raw_quantiles, dict):
                # Truncate like the median so all paths share the same length.
                quantiles = {
                    str(level): path[:prediction_length] if isinstance(path, list) else path
                    for level, path in raw_quantiles.items()
                }
            return {"median": median, "quantiles": quantiles}

    # Unrecognised format: all-zero placeholder.
    return {
        "median": [0] * prediction_length,
        "quantiles": {
            "0.1": [0] * prediction_length,
            "0.5": [0] * prediction_length,
            "0.9": [0] * prediction_length,
        },
    }
# =========================
# Endpoints
# =========================
@app.get("/")
def root():
    """Return basic service metadata: name, version, model and useful links."""
    info = dict(
        name="Chronos-2 Forecasting API",
        version="1.0.0",
        model=MODEL_ID,
        status="running",
        docs="/docs",
        health="/health",
    )
    return info
@app.get("/health")
def health():
    """
    Report service health.

    The token counts as configured only when HF_TOKEN is non-empty, which is
    the same test used when the inference client is created. An empty
    string is therefore reported as not configured.
    """
    token_configured = bool(HF_TOKEN)
    return {
        "status": "ok" if token_configured else "warning",
        "model_id": MODEL_ID,
        "hf_token_configured": token_configured,
        "message": "Ready" if token_configured else "HF_TOKEN not configured",
    }
@app.post("/forecast_univariate", response_model=ForecastUnivariateResponse)
def forecast_univariate(req: ForecastUnivariateRequest):
    """
    Forecast a single univariate series (endpoint used by the Excel add-in).

    Only ``req.series`` and ``req.prediction_length`` are used. The returned
    quantile keys are whatever ``process_chronos_output`` produced, which is
    not necessarily ``req.quantile_levels``. Timestamps are relative step
    labels ("t+1", "t+2", ...), not calendar dates.

    Raises:
        HTTPException(400): the series is empty or has fewer than 3 points.
        HTTPException: any error propagated from ``call_chronos_inference``.
    """
    values = req.series.values
    n = len(values)
    if n == 0:
        raise HTTPException(status_code=400, detail="La serie no puede estar vacía.")
    if n < 3:
        raise HTTPException(
            status_code=400,
            detail="La serie debe tener al menos 3 puntos históricos."
        )

    raw_output = call_chronos_inference(values, req.prediction_length)
    processed = process_chronos_output(raw_output, req.prediction_length)
    timestamps = [f"t+{i + 1}" for i in range(req.prediction_length)]

    return ForecastUnivariateResponse(
        timestamps=timestamps,
        median=processed["median"],
        quantiles=processed["quantiles"],
    )
@app.post("/detect_anomalies", response_model=AnomalyDetectionResponse)
def detect_anomalies(req: AnomalyDetectionRequest):
    """
    Flag observed values that fall outside a forecast band.

    A ``prediction_length``-step forecast is produced from ``req.context``.
    Each value of ``req.recent_observed`` is an anomaly when it lies strictly
    outside [lower, upper]. The bounds come from the forecast quantile paths
    keyed by ``str(req.quantile_low)`` / ``str(req.quantile_high)``. If a
    level is missing or its path is too short, a band of ±20% of |median| is
    used for that bound instead.

    Raises:
        HTTPException(400): empty context; recent_observed length differs
            from prediction_length; prediction_length < 1; or the quantile
            levels do not satisfy 0 < quantile_low < quantile_high < 1.
        HTTPException: any error propagated from ``call_chronos_inference``.
    """
    if not req.context.values:
        raise HTTPException(status_code=400, detail="El contexto no puede estar vacío.")
    if len(req.recent_observed) != req.prediction_length:
        raise HTTPException(
            status_code=400,
            detail="recent_observed debe tener la misma longitud que prediction_length."
        )
    if req.prediction_length < 1:
        raise HTTPException(status_code=400, detail="prediction_length debe ser mayor que 0.")
    if not 0.0 < req.quantile_low < req.quantile_high < 1.0:
        raise HTTPException(
            status_code=400,
            detail="Se requiere 0 < quantile_low < quantile_high < 1."
        )

    raw_output = call_chronos_inference(req.context.values, req.prediction_length)
    processed = process_chronos_output(raw_output, req.prediction_length)
    median = processed["median"]
    quantiles = processed["quantiles"]

    # A missing level yields None rather than the median. Substituting the
    # median would collapse the band to a single value and flag nearly every
    # observation (e.g. the default 0.05/0.95 when only 0.1/0.5/0.9 exist).
    q_low = quantiles.get(str(req.quantile_low))
    q_high = quantiles.get(str(req.quantile_high))
    # NOTE(review): when the model returned no uncertainty, the available
    # quantile paths may equal the median, which also gives a zero-width band.

    anomalies: List[AnomalyPoint] = []
    # zip() stops at the shorter sequence, so observations beyond the
    # returned forecast horizon are skipped.
    for i, (obs, predicted) in enumerate(zip(req.recent_observed, median)):
        # abs() keeps lower <= upper when the median is negative.
        margin = 0.2 * abs(predicted)
        lower = q_low[i] if q_low is not None and i < len(q_low) else predicted - margin
        upper = q_high[i] if q_high is not None and i < len(q_high) else predicted + margin
        anomalies.append(
            AnomalyPoint(
                index=i,
                value=obs,
                predicted_median=predicted,
                lower=lower,
                upper=upper,
                is_anomaly=(obs < lower) or (obs > upper),
            )
        )

    return AnomalyDetectionResponse(anomalies=anomalies)
@app.post("/backtest_simple", response_model=BacktestResponse)
def backtest_simple(req: BacktestRequest):
    """
    Hold out the last ``test_length`` points, forecast them and score the result.

    The model sees the series without its last ``test_length`` values and is
    asked for ``test_length`` steps; ``req.prediction_length`` is not used.
    MAE, RMSE and MAPE (in percent) are computed against the held-out actuals.

    Raises:
        HTTPException(400): the series is not longer than test_length, or
            test_length < 1.
        HTTPException(502): the model's forecast length differs from
            test_length.
        HTTPException: any error propagated from ``call_chronos_inference``.
    """
    values = np.array(req.series.values, dtype=float)
    n = len(values)
    if n <= req.test_length:
        raise HTTPException(
            status_code=400,
            detail="La serie debe ser más larga que test_length."
        )
    if req.test_length < 1:
        # An empty test set would make every metric NaN.
        raise HTTPException(status_code=400, detail="test_length debe ser mayor que 0.")

    split = n - req.test_length
    train = values[:split].tolist()
    test = values[split:].tolist()

    raw_output = call_chronos_inference(train, req.test_length)
    processed = process_chronos_output(raw_output, req.test_length)
    forecast = np.array(processed["median"], dtype=float)
    test_arr = np.array(test, dtype=float)

    # Without this check a length-1 forecast would silently broadcast against
    # the actuals, and any other mismatch would surface as a bare 500.
    if forecast.shape != test_arr.shape:
        raise HTTPException(
            status_code=502,
            detail=(
                f"El modelo devolvió {len(processed['median'])} pasos; "
                f"se esperaban {req.test_length}."
            ),
        )

    errors = test_arr - forecast
    mae = float(np.mean(np.abs(errors)))
    rmse = float(np.sqrt(np.mean(errors ** 2)))
    # eps avoids division by zero; MAPE still grows very large when actuals
    # are close to zero.
    eps = 1e-8
    mape = float(np.mean(np.abs(errors / (test_arr + eps)))) * 100.0

    timestamps = [f"test_t{i + 1}" for i in range(req.test_length)]
    return BacktestResponse(
        metrics=BacktestMetrics(mae=mae, mape=mape, rmse=rmse),
        forecast_median=forecast.tolist(),
        forecast_timestamps=timestamps,
        actuals=test,
    )
# =========================
# Endpoints simplificados para testing
# =========================
@app.post("/simple_forecast")
def simple_forecast(series: List[float], prediction_length: int = 7):
    """
    Minimal forecasting endpoint for quick manual testing.

    Returns the input, the requested horizon, the median forecast and the
    model id.

    Raises:
        HTTPException(400): the series is empty or prediction_length < 1.
        HTTPException: any error propagated from ``call_chronos_inference``.
    """
    if not series:
        raise HTTPException(status_code=400, detail="Serie vacía")
    if prediction_length < 1:
        raise HTTPException(status_code=400, detail="prediction_length debe ser mayor que 0.")

    raw_output = call_chronos_inference(series, prediction_length)
    processed = process_chronos_output(raw_output, prediction_length)
    return {
        "input_series": series,
        "prediction_length": prediction_length,
        "forecast": processed["median"],
        "model": MODEL_ID,
    }
# =========================
# NUEVOS ENDPOINTS IMPLEMENTADOS
# =========================
@app.post("/forecast_multi_id", response_model=ForecastMultiIdResponse)
def forecast_multi_id(req: ForecastMultiIdRequest):
    """
    Forecast several independent series (e.g. products or locations).

    Each series is sent to the model separately. Forecasts are returned in
    the same order as ``req.series_list``; the response carries no
    series_id, so clients match results by position.

    Raises:
        HTTPException(400): the list is empty or any series has fewer than 3
            points. This is checked for every series before any inference
            call is made.
        HTTPException: any error propagated from ``call_chronos_inference``.
    """
    if not req.series_list:
        raise HTTPException(status_code=400, detail="La lista de series no puede estar vacía.")

    # Validate everything up front so that a bad entry late in the list does
    # not waste the (paid, slow) inference calls already made for earlier ones.
    for series_item in req.series_list:
        if len(series_item.values) < 3:
            raise HTTPException(
                status_code=400,
                detail=f"La serie '{series_item.series_id}' debe tener al menos 3 puntos."
            )

    # Step labels depend only on the horizon, so build them once.
    timestamps = [f"t+{i + 1}" for i in range(req.prediction_length)]

    forecasts = []
    for series_item in req.series_list:
        raw_output = call_chronos_inference(series_item.values, req.prediction_length)
        processed = process_chronos_output(raw_output, req.prediction_length)
        forecasts.append(
            ForecastUnivariateResponse(
                timestamps=timestamps,
                median=processed["median"],
                quantiles=processed["quantiles"],
            )
        )

    return ForecastMultiIdResponse(forecasts=forecasts)
@app.post("/forecast_with_covariates")
def forecast_with_covariates(req: ForecastWithCovariatesRequest):
    """
    Forecast the target series while accepting exogenous covariates.

    NOTE: Chronos is a purely univariate model. The covariates are only
    checked for length and echoed back by name; the forecast is produced
    from ``req.target_series`` alone. For genuine covariate-aware forecasts
    consider models such as TimesFM, Temporal Fusion Transformer or Prophet.

    Raises:
        HTTPException(400): the target has fewer than 3 points; a historical
            covariate's length differs from the target's; or a future
            covariate's length differs from prediction_length.
    """
    history = req.target_series.values
    horizon = req.prediction_length

    if len(history) < 3:
        raise HTTPException(
            status_code=400,
            detail="La serie objetivo debe tener al menos 3 puntos."
        )

    bad_past = next(
        (c for c in req.covariates_history if len(c.values) != len(history)), None
    )
    if bad_past is not None:
        raise HTTPException(
            status_code=400,
            detail=f"La covariable '{bad_past.name}' debe tener la misma longitud que la serie objetivo."
        )

    bad_future = next(
        (c for c in req.covariates_future if len(c.values) != horizon), None
    )
    if bad_future is not None:
        raise HTTPException(
            status_code=400,
            detail=f"La covariable futura '{bad_future.name}' debe tener longitud = prediction_length."
        )

    # Approximation: only the target series drives the forecast.
    result = process_chronos_output(call_chronos_inference(history, horizon), horizon)

    return {
        "timestamps": [f"t+{step}" for step in range(1, horizon + 1)],
        "median": result["median"],
        "quantiles": result["quantiles"],
        "note": "Chronos-2 no usa covariables nativamente. Este forecast se basa solo en la serie objetivo.",
        "covariates_used": [c.name for c in req.covariates_history],
        "covariates_future": [c.name for c in req.covariates_future],
    }
@app.post("/generate_scenarios", response_model=GenerateScenariosResponse)
def generate_scenarios(req: GenerateScenariosRequest):
    """
    Produce one forecast per "what-if" scenario.

    Chronos does not consume covariates, so a single base forecast is
    computed from ``req.target_series`` and reused for every scenario. Only
    ``scenario_name`` differs between the returned entries; each scenario's
    ``covariate_values`` are ignored.

    Raises:
        HTTPException(400): the target has fewer than 3 points or no
            scenario was supplied.
        HTTPException: any error propagated from ``call_chronos_inference``.
    """
    history = req.target_series.values
    horizon = req.prediction_length

    if len(history) < 3:
        raise HTTPException(
            status_code=400,
            detail="La serie objetivo debe tener al menos 3 puntos."
        )
    if not req.scenarios:
        raise HTTPException(
            status_code=400,
            detail="Debe proporcionar al menos un escenario."
        )

    base = process_chronos_output(call_chronos_inference(history, horizon), horizon)
    labels = [f"t+{step}" for step in range(1, horizon + 1)]

    return GenerateScenariosResponse(
        scenarios=[
            ScenarioForecast(
                scenario_name=scenario.scenario_name,
                timestamps=labels,
                median=base["median"],
                quantiles=base["quantiles"],
            )
            for scenario in req.scenarios
        ]
    )
@app.post("/forecast_multivariate", response_model=ForecastMultivariateResponse)
def forecast_multivariate(req: ForecastMultivariateRequest):
    """
    Forecast several related series, each one independently.

    NOTE: Chronos is fundamentally univariate, so cross-series correlations
    are not modelled. For true multivariate forecasting consider Temporal
    Fusion Transformer, DeepAR or Vector Autoregression (VAR). Forecasts are
    returned in the same order as ``req.series_list``.

    Raises:
        HTTPException(400): the list is empty or any series has fewer than 3
            points. This is checked for every series before any inference
            call is made.
        HTTPException: any error propagated from ``call_chronos_inference``.
    """
    if not req.series_list:
        raise HTTPException(
            status_code=400,
            detail="La lista de series no puede estar vacía."
        )

    # Validate everything up front so a bad entry late in the list does not
    # waste the inference calls already made for earlier series.
    for series_item in req.series_list:
        if len(series_item.values) < 3:
            raise HTTPException(
                status_code=400,
                detail=f"La serie '{series_item.series_name}' debe tener al menos 3 puntos."
            )

    # Step labels depend only on the horizon, so build them once.
    timestamps = [f"t+{i + 1}" for i in range(req.prediction_length)]

    forecasts = []
    for series_item in req.series_list:
        raw_output = call_chronos_inference(series_item.values, req.prediction_length)
        processed = process_chronos_output(raw_output, req.prediction_length)
        forecasts.append(
            MultivariateForecast(
                series_name=series_item.series_name,
                timestamps=timestamps,
                median=processed["median"],
                quantiles=processed["quantiles"],
            )
        )

    return ForecastMultivariateResponse(forecasts=forecasts)
if __name__ == "__main__":
    # Local/Space entry point: serve the app on $PORT (default 7860).
    import uvicorn

    listen_port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)