import concurrent.futures
import math
import time
from io import BytesIO
from pathlib import Path

import requests
from convert_atf import ATFConverter
from datasets import Dataset
from PIL import Image
from tqdm.auto import tqdm

atf_converter = ATFConverter()

# Local cache of CDLI tablet photos, one file per P-number (e.g. P123456.jpg).
IMG_CACHE = Path("./data/cdli_images")
IMG_CACHE.mkdir(exist_ok=True, parents=True)

# Images larger than this on either side are scaled down before smart_resize.
MAX_IMG_RES = 2048

# When True, missing images are fetched from cdli.earth; when False, only
# already-cached images are used and uncached tablets are skipped.
DOWNLOAD_MODE = False


def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 28 * 28 * 130,
    max_pixels: int = 28 * 28 * 1280,
):
    """Rescale (height, width) for vision-model input.

    The returned dimensions satisfy:
      1. Both height and width are divisible by ``factor``.
      2. The total number of pixels lies in [``min_pixels``, ``max_pixels``].
      3. The aspect ratio is maintained as closely as possible.

    Raises:
        ValueError: if the absolute aspect ratio exceeds 200:1.
    """
    # Degenerate inputs: instead of raising (as the upstream reference
    # implementation does), grow the short side up to ``factor`` while
    # preserving the aspect ratio.
    if height < factor:
        print(f"smart_resize: height={height} < factor={factor}, reset height=factor")
        width = round((width * factor) / height)
        height = factor
    if width < factor:
        print(f"smart_resize: width={width} < factor={factor}, reset width=factor")
        height = round((height * factor) / width)
        width = factor
    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )
    # Round each side to the nearest multiple of factor, then rescale if the
    # pixel budget is violated (floor when shrinking, ceil when growing, so
    # the corrected size stays inside the budget).
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar


def resize_image(img_path):
    """Resize one cached image file in place to smart_resize-compatible dims.

    No-op (no re-encode) when the image already has the target dimensions.
    """
    # Manage the opened file, not the converted copy: convert() returns a new
    # image and forces a load, so the file handle can be closed immediately.
    with Image.open(img_path) as raw:
        image = raw.convert("RGB")
    width, height = image.size
    # Scale down if larger than MAX_IMG_RES
    if width > MAX_IMG_RES or height > MAX_IMG_RES:
        scale = MAX_IMG_RES / max(width, height)
        height = int(height * scale)
        width = int(width * scale)
    # Always ensure dimensions are multiples of 28 for vision model compatibility
    new_height, new_width = smart_resize(height, width)
    if new_height != image.height or new_width != image.width:
        image = image.resize((new_width, new_height), Image.LANCZOS)
        image.save(img_path)


def resize_cached_images():
    """Resize every cached .jpg concurrently (I/O-bound, so threads suffice)."""
    img_paths = list(IMG_CACHE.glob("*.jpg"))
    pbar = tqdm(img_paths)
    with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
        futures = [executor.submit(resize_image, img_path) for img_path in img_paths]
        for future in concurrent.futures.as_completed(futures):
            pbar.update(1)
    pbar.close()


def get_image(id: int):
    """Return the RGB photo for CDLI artifact ``id`` (P-number), or None.

    Loads from IMG_CACHE when present, otherwise downloads from cdli.earth.
    Either way the image is resized to model-compatible dimensions and the
    resized version is written back to the cache. Returns None on any
    download or processing error (logged via tqdm.write).
    """
    file_name = f"P{id:06d}.jpg"
    url = f"https://cdli.earth/dl/photo/{file_name}"
    cache_file = IMG_CACHE / file_name
    try:
        if cache_file.exists():
            tqdm.write(f"Found {file_name} in cache")
            image = Image.open(cache_file).convert("RGB")
        else:
            response = requests.get(url, timeout=5)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
            tqdm.write(f"Downloaded {file_name}")
            # Rate limiting — only needed when we actually hit the server,
            # not on cache hits.
            time.sleep(0.02)
        width, height = image.size
        # Scale down if larger than MAX_IMG_RES
        if width > MAX_IMG_RES or height > MAX_IMG_RES:
            scale = MAX_IMG_RES / max(width, height)
            height = int(height * scale)
            width = int(width * scale)
        # Always ensure dimensions are multiples of 28 for vision model compatibility
        new_height, new_width = smart_resize(height, width)
        if new_height != image.height or new_width != image.width:
            image = image.resize((new_width, new_height), Image.LANCZOS)
            image.save(cache_file)
    except requests.exceptions.Timeout:
        tqdm.write(f"Timeout downloading {file_name}")
        return None
    except requests.exceptions.RequestException as e:
        tqdm.write(f"Error downloading {file_name}: {e}")
        return None
    except Exception as e:
        tqdm.write(f"Error processing {file_name}: {type(e).__name__}: {e}")
        return None
    return image


def count_repetitions(text: str) -> int:
    """
    Count the total number of repeated token occurrences in a sequence.
    E.g., 122233 has 3 repetitions (2 appears 2 extra times, 3 appears 1 extra time).
    """
    if len(text) < 2:
        return 0
    return len(text) - len(set(text))


def get_dataset(file="./data/cdli_dataset.parquet"):
    """Build (or load from parquet) the CDLI dataset and split train/test.

    Parses the CDLI ATF dump, converts each tablet's transliteration to
    unicode cuneiform, filters out tablets that are unparseable, too
    short/long, too repetitive, missing symbols, or lacking a usable photo,
    then persists the result to ``file`` and returns a train/test split.
    """
    if Path(file).exists():
        return Dataset.from_parquet(file).train_test_split(test_size=1000, seed=42)

    # 1. Get all the ids from cdli.atf (source: https://github.com/cdli-gh/data/raw/refs/heads/master/cdliatf_unblocked.atf)
    cdli_raw = Path("./data/cdli.atf").read_text(encoding="utf-8").split("&P")
    cdli_filtered = [
        section.strip()
        for section in cdli_raw
        if section.strip()  # Ignore empty sections
        and "@tablet" in section  # Only include tablets
        and len(section) > 50  # Ignore short sections
        and len(section) < 1000  # Ignore long sections
        and any(lang in section for lang in ["sux", "akk"])  # Limit supported languages
    ]

    ids = []
    atfs = []
    unicodes = []
    for section in tqdm(cdli_filtered, desc="Parsing CDLI dump"):
        # Split section at first space to get the ID, ignore if not parseable
        lines = section.splitlines()
        id_part = lines[0].split("=")[0].strip()
        if not id_part.isdigit():
            continue
        # Strip comment (#), quote (>>/<<), and ruling (||) lines from the ATF body.
        atf = "\n".join(
            [
                line
                for line in lines[1:]
                if not (
                    line.startswith("# ")
                    or line.startswith(">>")
                    or line.startswith("<<")
                    or line.startswith("||")
                )
            ]
        )
        parsed = atf_converter.parse(atf)
        if parsed is None:
            tqdm.write(f"=====\033[91m {id_part} skip (parse fail) \033[0m=====")
            continue
        unicode_parts = [
            f"@{face}\n{parsed.get_unicode(face)}"
            for face in parsed.ALL_FACES
            if parsed.get_unicode(face)
        ]
        # Skip massive tablets
        unicode_len = sum(len(part) for part in unicode_parts)
        if unicode_len > 300 or unicode_len < 20:
            tqdm.write(f"=====\033[91m {id_part} skip (too short/long) \033[0m=====")
            continue
        # Skip tablets that are poorly translated to unicode
        if sum(part.count("x") for part in unicode_parts) >= 2:
            tqdm.write(f"=====\033[91m {id_part} skip (missing symbols) \033[0m=====")
            continue
        unicode = "\n".join(unicode_parts)
        # Drop the super repetitive admin tablets (model ends up getting stuck repeating the common phrases)
        if count_repetitions(unicode) / len(unicode) > 0.7:
            tqdm.write(f"=====\033[91m {id_part} skip (too repetitive) \033[0m=====")
            continue
        # Ignore if we don't have an image for this atf
        cache_file = IMG_CACHE / f"P{int(id_part):06d}.jpg"
        if DOWNLOAD_MODE:
            image = get_image(int(id_part))
        elif cache_file.exists():
            image = Image.open(cache_file).convert("RGB")
        else:
            tqdm.write(f"=====\033[91m {id_part} skip (no img) \033[0m=====")
            continue
        if not image:
            tqdm.write(f"=====\033[91m {id_part} skip (no img) \033[0m=====")
            continue
        # Drop low res, B&W, or non-isolated background
        try:
            if min(image.size) < 100:
                tqdm.write(f"=====\033[91m {id_part} skip (lowres) \033[0m=====")
                continue
            # Downsample to ~150px height so the pixel statistics below are cheap.
            scale = 150 / image.height
            small_image = image.resize(
                (int(image.width * scale), int(image.height * scale)), Image.LANCZOS
            )
            pixels = list(small_image.getdata())
            small_image.close()
            image.close()
            # Mostly-gray images are likely B&W scans; almost-no-gray means noise.
            bw_pixels = sum(1 for r, g, b in pixels if r == g == b)
            bw_percent = bw_pixels / len(pixels)
            if bw_percent > 0.95 or bw_percent < 0.1:
                tqdm.write(
                    f"=====\033[91m {id_part} skip (bw {bw_percent*100:.1f}%) \033[0m====="
                )
                continue
            # Require >=15% pure-black pixels: tablets photographed on an
            # isolated black background.
            if sum(1 for r, g, b in pixels if r == g == b == 0) / len(pixels) < 0.15:
                tqdm.write(
                    f"=====\033[91m {id_part} skip (not on black background) \033[0m====="
                )
                continue
        except Exception as e:
            tqdm.write(
                f"=====\033[91m {id_part} skip (err img check: {e}) \033[0m====="
            )
            continue
        ids.append(int(id_part))
        atfs.append(atf)
        unicodes.append(unicode)
        tqdm.write(f"=====\033[32m {id_part} unicode (len {unicode_len}) \033[0m=====")

    dataset = Dataset.from_dict(
        {
            "id": ids,
            "atf": atfs,
            "unicode": unicodes,
        }
    )
    dataset.to_parquet(file)
    return dataset.train_test_split(test_size=1000, seed=42)