smoldlm-144m / generate.py
HoangHa's picture
Add generate.py
7c65c7f verified
#!/usr/bin/env python3
"""Generate text with SmolDLM-144M.

Usage:
    python generate.py "The meaning of life is"
    python generate.py --steps 20 --temperature 0.5 "Once upon a time"
    python generate.py --checkpoint checkpoints/model_step_025000.safetensors "Hello world"

Install:
    pip install torch safetensors tokenizers huggingface-hub
"""
import argparse
from modeling_smoldlm import SmolDLM, generate
def main(argv=None):
    """Parse command-line options, load SmolDLM, and print one generation.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behaviour; passing a list lets other code or tests
            drive the CLI without touching ``sys.argv``.

    Invalid numeric options (``--steps`` or ``--max-tokens`` below 1, or a
    ``--temperature`` that is negative, NaN or infinite) are rejected through
    ``parser.error``, which prints a usage message and exits with status 2
    before any model weights are loaded.
    """
    import math  # local import: only needed for the temperature sanity check

    parser = argparse.ArgumentParser(description="Generate text with SmolDLM-144M")
    parser.add_argument("prompt", nargs="?", default="The meaning of life is",
                        help="Text prompt to continue from")
    parser.add_argument("--steps", type=int, default=10,
                        help="Denoising steps per block (default: 10)")
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="Sampling temperature (default: 0.7)")
    parser.add_argument("--max-tokens", type=int, default=256,
                        help="Maximum tokens to generate (default: 256)")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Specific checkpoint (e.g., checkpoints/model_step_025000.safetensors)")
    parser.add_argument("--device", type=str, default="auto",
                        help="Device: auto, cpu, cuda, mps (default: auto)")
    # Named --repo (not --model) so existing "--m..." abbreviations of
    # --max-tokens stay unambiguous under argparse prefix matching.
    parser.add_argument("--repo", type=str, default="HoangHa/smoldlm-144m",
                        help="Model repo id passed to SmolDLM.from_pretrained "
                             "(default: HoangHa/smoldlm-144m)")
    args = parser.parse_args(argv)

    # Fail fast with a clear CLI error instead of letting nonsensical values
    # surface later from deep inside model loading or the sampling loop.
    if args.steps < 1:
        parser.error(f"--steps must be a positive integer (got {args.steps})")
    if args.max_tokens < 1:
        parser.error(f"--max-tokens must be a positive integer (got {args.max_tokens})")
    # argparse's type=float happily accepts "nan" and "inf", so check finiteness too.
    if not math.isfinite(args.temperature) or args.temperature < 0:
        parser.error(
            f"--temperature must be a finite, non-negative number (got {args.temperature})"
        )

    model = SmolDLM.from_pretrained(
        args.repo,
        checkpoint=args.checkpoint,
        device=args.device,
    )
    text = generate(
        model,
        prompt=args.prompt,
        max_new_tokens=args.max_tokens,
        steps=args.steps,
        temperature=args.temperature,
    )
    # Leading newline separates the output from any loading/progress logs.
    print(f"\n{text}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, so importing this module
    # (e.g. to reuse main()) has no side effects.
    main()