---
license: apache-2.0
datasets:
- yandex/alchemist
language:
- en
base_model:
- ByteDance-Seed/BAGEL-7B-MoT
library_name: diffusers
---

![Intro image](https://i.ibb.co/whm5Dp5F/mosaic-10-1.png "Alchemist's tune generations")

# BAGEL-7B-MoT Alchemist 👨‍🔬

[BAGEL-7B-MoT Alchemist](https://huggingface.co/yandex/bagel-alchemist) is a text-to-image (T2I) fine-tuned version of [BAGEL-7B-MoT](https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT), trained on the [Alchemist](https://huggingface.co/datasets/yandex/alchemist) dataset proposed in the research paper "Alchemist: Turning Public Text-to-Image Data into Generative Gold". The model generates images with improved aesthetics and complexity. See the paper for details on the dataset and training.

## Model usage

For installation and usage instructions, follow the official **BAGEL** [GitHub repository](https://github.com/bytedance-seed/BAGEL):

1️⃣ Set up environment
```
git clone https://github.com/bytedance-seed/BAGEL.git
cd BAGEL
conda create -n bagel python=3.10 -y
conda activate bagel
pip install -r requirements.txt
pip install flash_attn==2.5.8 --no-build-isolation
```

2️⃣ Download pretrained checkpoint
```
from huggingface_hub import snapshot_download

save_dir = "models/BAGEL-7B-MoT-alchemist"
repo_id = "yandex/BAGEL-7B-MoT-alchemist"
cache_dir = save_dir + "/cache"

snapshot_download(
    cache_dir=cache_dir,
    local_dir=save_dir,
    repo_id=repo_id,
    local_dir_use_symlinks=False,
    resume_download=True,
    allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
)
```

3️⃣ Load BAGEL-Alchemist. Note that it was trained on images with a maximum side of 1408 px!
```
import os
from copy import deepcopy
from typing import (
    Any,
    AsyncIterable,
    Callable,
    Dict,
    Generator,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)
import requests
from io import BytesIO

from PIL import Image
import torch
from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights

from data.transforms import ImageTransform
from data.data_utils import pil_img2rgb, add_special_tokens
from modeling.bagel import (
    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer
from modeling.bagel.qwen2_navit import NaiveCache
from modeling.autoencoder import load_ae
from safetensors.torch import load_file

model_path = "/path/to/BAGEL-7B-MoT-alchemist/weights"

# LLM config preparing
llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
llm_config.layer_module = "Qwen2MoTDecoderLayer"

# ViT config preparing
vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
vit_config.rope = False
vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1

# VAE loading
vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))

# Bagel config preparing
config = BagelConfig(
    visual_gen=True,
    visual_und=True,
    llm_config=llm_config,
    vit_config=vit_config,
    vae_config=vae_config,
    vit_max_num_patch_per_side=70,
    connector_act='gelu_pytorch_tanh',
    latent_patch_size=2,
    max_latent_size=88,  # max_latent_size is 88 for BAGEL-alchemist!
)

with init_empty_weights():
    language_model = Qwen2ForCausalLM(llm_config)
    vit_model = SiglipVisionModel(vit_config)
    model = Bagel(language_model, vit_model, config)
    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)

# Tokenizer Preparing
tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

# Image Transform Preparing
vae_transform = ImageTransform(1408, 512, 16)  # maximum image side is 1408 for BAGEL-alchemist!
vit_transform = ImageTransform(980, 224, 14)

# Modify it according to your GPU setting.
# On an A100, 80 GiB is sufficient to load on a single GPU.
max_mem_per_gpu = "40GiB"

device_map = infer_auto_device_map(
    model,
    max_memory={i: max_mem_per_gpu for i in range(torch.cuda.device_count())},
    no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
)
print(device_map)

same_device_modules = [
    'language_model.model.embed_tokens',
    'time_embedder',
    'latent_pos_embed',
    'vae2llm',
    'llm2vae',
    'connector',
    'vit_pos_embed'
]

if torch.cuda.device_count() == 1:
    first_device = device_map.get(same_device_modules[0], "cuda:0")
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device
        else:
            device_map[k] = "cuda:0"
else:
    first_device = device_map.get(same_device_modules[0])
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device

# Thanks @onion-liu: https://github.com/ByteDance-Seed/Bagel/pull/8
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=os.path.join(model_path, "ema.safetensors"),
    device_map=device_map,
    offload_buffers=True,
    dtype=torch.bfloat16,
    force_hooks=True,
    offload_folder="/tmp/offload"
)

model = model.eval()
print('Model loaded')
```

4️⃣ Follow final instructions for inference, e.g.
T2I inference:
```
from inferencer import InterleaveInferencer

inferencer = InterleaveInferencer(
    model=model,
    vae_model=vae_model,
    tokenizer=tokenizer,
    vae_transform=vae_transform,
    vit_transform=vit_transform,
    new_token_ids=new_token_ids
)

inference_hyper = dict(
    cfg_text_scale=6.0,
    cfg_img_scale=1.0,
    cfg_interval=[0.0, 1.0],
    timestep_shift=3.0,
    num_timesteps=50,
    cfg_renorm_min=0.0,
    cfg_renorm_type="global",
)

prompt = "A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere."

print(prompt)
print('-' * 10)
output_dict = inferencer(text=prompt, **inference_hyper)

# display() is provided by Jupyter/IPython notebooks
display(output_dict['image'])
```
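
The two resolution settings in the loading step appear to be linked: `max_latent_size=88` matches the 1408 px maximum side divided by the stride of 16 passed to `ImageTransform(1408, 512, 16)`. The sketch below checks that arithmetic and shows one way to fit a requested size into the trained range. It assumes the third `ImageTransform` argument is the pixel stride per latent position; `latent_side` and `fit_resolution` are hypothetical helpers written for illustration and are not part of the BAGEL codebase.

```python
# Values copied from the loading snippet.
MAX_IMAGE_SIDE = 1408  # first argument of ImageTransform(1408, 512, 16)
STRIDE = 16            # third argument; assumed to be pixels per latent position


def latent_side(pixels: int, stride: int = STRIDE) -> int:
    """Return the number of latent positions along a side of `pixels` px."""
    if pixels % stride != 0:
        raise ValueError(f"{pixels} is not a multiple of {stride}")
    return pixels // stride


def fit_resolution(height: int, width: int,
                   max_side: int = MAX_IMAGE_SIDE,
                   stride: int = STRIDE) -> tuple:
    """Hypothetical helper: shrink (height, width) so the longest side is at
    most `max_side`, then round each side down to a multiple of `stride`."""
    longest = max(height, width)
    if longest > max_side:
        # Integer arithmetic avoids floating-point rounding surprises.
        height = height * max_side // longest
        width = width * max_side // longest

    def snap(value: int) -> int:
        return max(stride, value // stride * stride)

    return snap(height), snap(width)


print(latent_side(MAX_IMAGE_SIDE))   # 88, the max_latent_size used above
print(fit_resolution(2000, 1000))    # (1408, 704)
print(fit_resolution(1024, 1024))    # (1024, 1024), already within range
```

If you change the maximum image side, keeping `max_latent_size` equal to `max_side // stride` preserves this relationship; whether other values work with this checkpoint is not stated in the model card.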