| import torch |
| import gradio as gr |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| import threading |
| import time |
|
|
| |
# Shared state filled in by load_model() once loading succeeds.
model = tokenizer = None  # the causal-LM model and its tokenizer
model_loaded = False  # set to True by load_model() after both are assigned
|
|
def load_model(model_name="UnfilteredAI/Promt-generator"):
    """Load the prompt-generator model and tokenizer into module globals.

    Intended to run on a background thread. On success it sets the globals
    ``tokenizer`` and ``model`` and then flips ``model_loaded`` to True. On
    failure it prints the error with its traceback and leaves
    ``model_loaded`` False, so the UI keeps running instead of crashing.

    Args:
        model_name: Hugging Face Hub id (or local path) passed to both
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.
    """
    global model, tokenizer, model_loaded
    import traceback

    try:
        print("Loading Prompt Generator model...")
        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        # Half precision only on GPU; the CPU path keeps float32.
        dtype = torch.float16 if use_cuda else torch.float32

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
        ).to(device)

        # Set last, so readers that see True also see model/tokenizer assigned.
        model_loaded = True
        print("Prompt Generator model loaded successfully!")
    except Exception as e:
        # Best-effort loader: report the failure with full context and keep serving.
        print(f"Error loading model: {e}")
        traceback.print_exc()
        model_loaded = False
|
|
def generate_prompt(input_text, max_length, temperature, top_p, num_return_sequences):
    """Generate one or more enhanced prompts that continue ``input_text``.

    Args:
        input_text: Seed text from the textbox. ``None`` or whitespace-only
            input is rejected with a message instead of raising.
        max_length: Upper bound on the total sequence length in tokens,
            prompt included, passed as ``max_length`` to ``model.generate``.
            May arrive as a float from a Gradio slider.
        temperature: Sampling temperature (higher means more random).
        top_p: Nucleus-sampling probability mass.
        num_return_sequences: Number of samples to draw. May arrive as a
            float from a Gradio slider.

    Returns:
        The decoded sequences joined by a ``---`` separator. Each is the full
        output of a causal LM, so it includes the seed text. Otherwise a
        user-facing status or error message.
    """
    if not model_loaded:
        if loading_thread.is_alive():
            return "模型尚未加载完成,请稍等..."
        # The loader has exited. Re-read the flag in case it was set between
        # the two checks; if it is still False, loading failed and waiting won't help.
        if not model_loaded:
            return "模型加载失败,请查看服务端日志。"

    if input_text is None or not input_text.strip():
        return "请输入一些文本作为提示词的起始内容。"

    try:
        # Gradio sliders can deliver floats (e.g. 3.0); generate() needs ints here.
        max_length = int(max_length)
        num_return_sequences = int(num_return_sequences)

        # Put the inputs on whatever device load_model() placed the model on.
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

        # max_length counts prompt tokens too. With no room left, generate()
        # cannot produce any new tokens, so say so explicitly.
        prompt_len = inputs["input_ids"].shape[-1]
        if prompt_len >= max_length:
            return (
                f"输入文本过长({prompt_len} 个 token),"
                "请缩短输入或调大“最大长度”。"
            )

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                temperature=float(temperature),
                top_p=float(top_p),
                do_sample=True,
                num_return_sequences=num_return_sequences,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        generated_prompts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return "\n\n---\n\n".join(generated_prompts)

    except Exception as e:
        # UI boundary: show the error in the output box rather than failing the handler.
        return f"生成提示词时出错: {str(e)}"
|
|
def clear_output():
    """Return an empty string, used to reset the generated-prompt textbox."""
    empty_text = ""
    return empty_text
|
|
| |
# Load the model in the background so the UI can start serving right away.
# Until load_model() sets model_loaded, generate_prompt() returns a status
# message instead of generating.
# daemon=True lets the process exit (e.g. on Ctrl+C) without waiting for a
# model download/load that is still in progress.
loading_thread = threading.Thread(target=load_model, name="model-loader", daemon=True)
loading_thread.start()
|
|
| |
with gr.Blocks(title="AI Prompt Generator") as demo:
    gr.Markdown("# 🎨 AI Prompt Generator")
    gr.Markdown("基于 UnfilteredAI/Promt-generator 模型的智能提示词生成器")

    with gr.Row():
        # Left column: seed text, action buttons and the generated output.
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="输入起始文本",
                placeholder="例如: a red car, beautiful landscape, futuristic city...",
                lines=3,
            )

            with gr.Row():
                generate_btn = gr.Button("生成提示词", variant="primary", scale=2)
                clear_btn = gr.Button("清空", scale=1)

            output_text = gr.Textbox(
                label="生成的提示词",
                lines=10,
                max_lines=20,
                show_copy_button=True,
            )

        # Right column: sampling parameters and usage notes.
        with gr.Column(scale=1):
            gr.Markdown("### 生成参数")

            max_length = gr.Slider(
                label="最大长度", minimum=50, maximum=500, value=150, step=10
            )
            temperature = gr.Slider(
                label="Temperature (创造性)", minimum=0.1, maximum=2.0, value=0.8, step=0.1
            )
            top_p = gr.Slider(
                label="Top-p (多样性)", minimum=0.1, maximum=1.0, value=0.9, step=0.05
            )
            num_return_sequences = gr.Slider(
                label="生成数量", minimum=1, maximum=5, value=3, step=1
            )

            gr.Markdown("### 使用说明")
            gr.Markdown(
                """- **输入起始文本**: 描述你想要的内容主题
- **Temperature**: 控制生成的随机性,越高越有创意
- **Top-p**: 控制词汇选择的多样性
- **生成数量**: 一次生成多个不同的提示词"""
            )

    # Clicking the button and submitting the textbox call the same handler
    # with the same inputs, in generate_prompt's parameter order.
    generation_inputs = [input_text, max_length, temperature, top_p, num_return_sequences]
    generate_btn.click(fn=generate_prompt, inputs=generation_inputs, outputs=output_text)
    input_text.submit(fn=generate_prompt, inputs=generation_inputs, outputs=output_text)

    clear_btn.click(fn=clear_output, outputs=output_text)
|
|
if __name__ == "__main__":
    # Listen on all interfaces on a fixed port, without a public share link.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7861,
        share=False,
        show_error=True,  # display handler errors in the browser UI
    )
    demo.launch(**launch_options)
|
|