import os
from pathlib import Path
from typing import Optional

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline


def _resolve_model_identifier(model_name: str) -> str:
    """Return a local path if *model_name* exists on disk, else the Hub id.

    Lets callers pass either a filesystem checkout of a model or a plain
    Hugging Face Hub identifier through the same parameter.
    """
    path_candidate = Path(model_name)
    if path_candidate.exists():
        return str(path_candidate)
    return model_name


def _resolve_auth_token(token: Optional[str]) -> Optional[str]:
    """Pick the auth token: explicit argument first, then the env vars
    conventionally used by the Hugging Face ecosystem."""
    return token or os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")


def _build_hub_kwargs(token: Optional[str]) -> dict:
    """Prepare kwargs for Hugging Face Hub auth across library versions.

    Recent transformers releases accept ``token=``; returns an empty dict
    when no token is available so anonymous access still works.
    """
    if not token:
        return {}
    return {"token": token}


def _fallback_hub_kwargs(token: Optional[str]) -> dict:
    """Older transformers versions still expect ``use_auth_token``."""
    if not token:
        return {}
    return {"use_auth_token": token}


def load_model_and_tokenizer(model_name: str, token: Optional[str] = None):
    """Load a seq2seq model and its tokenizer for inference.

    Args:
        model_name: Hub model id or local directory path.
        token: Optional Hub auth token; falls back to the
            ``HUGGINGFACEHUB_API_TOKEN`` / ``HF_TOKEN`` environment variables.

    Returns:
        Tuple of ``(model, tokenizer)``.
    """
    resolved_model = _resolve_model_identifier(model_name)
    auth_token = _resolve_auth_token(token)
    kwargs = _build_hub_kwargs(auth_token)
    try:
        tokenizer = AutoTokenizer.from_pretrained(resolved_model, **kwargs)
        model = AutoModelForSeq2SeqLM.from_pretrained(resolved_model, **kwargs)
    except TypeError:
        # Older transformers raise TypeError on the unknown ``token`` kwarg;
        # retry with the legacy ``use_auth_token`` spelling.
        fallback_kwargs = _fallback_hub_kwargs(auth_token)
        tokenizer = AutoTokenizer.from_pretrained(resolved_model, **fallback_kwargs)
        model = AutoModelForSeq2SeqLM.from_pretrained(resolved_model, **fallback_kwargs)
    return model, tokenizer


def generate_answer(model, tokenizer, prompt: str, max_tokens: int = 256):
    """Generate text from *model* given *prompt*.

    Args:
        model: A loaded generation-capable transformers model.
        tokenizer: The tokenizer paired with *model*.
        prompt: Input text.
        max_tokens: Cap on newly generated tokens (``max_new_tokens``).

    Returns:
        The decoded output string with special tokens stripped.
    """
    # Move the encoded inputs to wherever the model lives (CPU/GPU).
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        outputs = model.generate(**inputs, max_new_tokens=max_tokens)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def build_pipeline(model_name: str, task="text2text-generation", token: Optional[str] = None):
    """Return a Hugging Face pipeline for inference.

    Args:
        model_name: Hub model id or local directory path.
        task: Pipeline task name (defaults to ``"text2text-generation"``).
        token: Optional Hub auth token; same env-var fallback as
            :func:`load_model_and_tokenizer`.
    """
    resolved_model = _resolve_model_identifier(model_name)
    auth_token = _resolve_auth_token(token)
    kwargs = _build_hub_kwargs(auth_token)
    try:
        return pipeline(task, model=resolved_model, **kwargs)
    except TypeError:
        # Legacy transformers: retry with ``use_auth_token``.
        fallback_kwargs = _fallback_hub_kwargs(auth_token)
        return pipeline(task, model=resolved_model, **fallback_kwargs)