First commit
Browse files- Dockerfile +1 -1
- app.py +161 -0
- requirements.txt +75 -3
- static/logo_light.png +0 -0
Dockerfile
CHANGED
|
@@ -17,4 +17,4 @@ EXPOSE 8501
|
|
| 17 |
|
| 18 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 19 |
|
| 20 |
-
ENTRYPOINT ["streamlit", "run", "
|
|
|
|
| 17 |
|
| 18 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 19 |
|
| 20 |
+
ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
app.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
# Compute-Optimal LLM Training Estimator (Chinchilla-style)
# ---------------------------------------------------------
# Usage: `streamlit run app.py`  (matches the Dockerfile ENTRYPOINT)
# This tool helps estimate total FLOPs, steps, wall-clock time, and rough cost
# for LLM pretraining given model parameters, token budget, and hardware.

import math

import streamlit as st


def _fmt_or_inf(value: float, fmt: str = ",.0f") -> str:
    """Format *value* with *fmt*, showing '∞' instead of a misleading 0 when
    the estimate is unbounded (division-by-zero guards below yield inf)."""
    return "∞" if math.isinf(value) else f"{value:{fmt}}"


st.set_page_config(page_title="LLM Compute Estimator", page_icon="🧮", layout="centered")

st.title("🧮 LLM Compute-Optimal Estimator")
st.caption("Estimate total FLOPs, wall-clock time, steps, and cost for pretraining — with a Chinchilla-style token rule.")

# --- Sidebar: assumptions ---
with st.sidebar:
    st.logo('./static/logo_light.png')
    st.header("Assumptions & Notes")
    st.markdown(
        """
        **Formulas**
        - **Total FLOPs** ≈ `c * N_params * N_tokens`, with default **c = 6** (forward+backward+optimizer overhead).
        - **Compute-optimal tokens** (rule-of-thumb): `N_tokens ≈ k * N_params`, default **k = 20**.
        - **Effective compute** = `GPU_count * (peak TFLOPs × 1e12) * efficiency`.

        **Disclaimers**
        - This is a *back-of-the-envelope* estimator. Real training efficiency depends on data pipeline, parallelism strategy, sequence length, kernel fusion, optimizer, etc.
        - Preset TFLOPs are **approximate** and depend on precision (FP8/BF16), sparsity, clocks, and vendor kernels.
        """
    )

# --- 1) Model size & tokens ---
st.subheader("1) Model & Token Budget")
col1, col2, col3 = st.columns([1.2, 1, 1])
with col1:
    model_params_b = st.number_input("Model size (Billions of parameters)", min_value=0.05, value=4.0, step=0.5, format="%.2f")
with col2:
    c_overhead = st.number_input("c (FLOPs constant)", min_value=4.0, value=6.0, step=0.5)
with col3:
    k_tokens_per_param = st.number_input("k (tokens per param for compute-optimal)", min_value=5.0, value=20.0, step=1.0)

use_compute_optimal = st.toggle("Use compute‑optimal tokens (k × params)", value=True)
if use_compute_optimal:
    # Chinchilla-style rule of thumb: tokens scale linearly with parameters.
    tokens_b = model_params_b * k_tokens_per_param
    st.info(f"Compute‑optimal token budget ≈ **{tokens_b:,.2f} B** (k = {k_tokens_per_param:g})")
else:
    tokens_b = st.number_input("Token budget (Billions)", min_value=1.0, value=80.0, step=5.0, format="%.2f")

# --- 2) Hardware (moved before batch to define gpu_count first) ---
st.subheader("2) Hardware")
col6, col7 = st.columns(2)
with col6:
    gpu_preset = st.selectbox(
        "GPU preset (approx peak TFLOPs per GPU)",
        (
            "Custom",
            "A100 80GB BF16 ≈ 312",
            "H100 SXM BF16 ≈ 989",
            "B200 (FP8-ish) ≈ 20000",
        ),
        index=0,
        help="Values are back-of-the-envelope. Choose 'Custom' to enter your own.",
    )

# Approximate peak throughput per preset; keys must match the selectbox labels.
preset_map = {
    "A100 80GB BF16 ≈ 312": 312.0,
    "H100 SXM BF16 ≈ 989": 989.0,
    "B200 (FP8-ish) ≈ 20000": 20000.0,
}

with col7:
    if gpu_preset == "Custom":
        peak_tflops = st.number_input("Peak TFLOPs per GPU (approx)", min_value=10.0, value=20000.0, step=100.0)
    else:
        peak_tflops = preset_map[gpu_preset]
        # Disabled input echoes the preset value so the layout stays consistent.
        st.number_input("Peak TFLOPs per GPU (approx)", value=peak_tflops, disabled=True)

col8, col9, col10 = st.columns(3)
with col8:
    gpu_count = st.number_input("GPU count", min_value=1, value=8, step=1)
with col9:
    efficiency = st.slider("Training efficiency (MFU, %)", min_value=10, max_value=95, value=50, step=1)
with col10:
    price_per_gpu_hour = st.number_input("Price per GPU·hour (USD)", min_value=0.0, value=25.0, step=1.0)

# --- 3) Batch & Sequence Settings (tokens_per_step computed from gpu_count) ---
st.subheader("3) Batch & Sequence Settings")
col4, col5 = st.columns(2)
with col4:
    micro_batch = st.number_input("Micro batch size per GPU", min_value=1, value=4, step=1, help="Number of sequences per GPU per optimizer step.")
with col5:
    seq_len = st.number_input("Sequence length (tokens)", min_value=128, value=2048, step=128)

# Global batch in tokens: per-GPU micro batch × sequence length × data-parallel width.
tokens_per_step = micro_batch * seq_len * gpu_count
st.info(f"Tokens per optimization step ≈ {tokens_per_step:,} (with {gpu_count} GPUs)")

# --- Compute ---
N_params = model_params_b * 1e9
N_tokens = tokens_b * 1e9
c = c_overhead

# Total FLOPs (scalar): c ≈ 6 covers forward + backward + optimizer overhead.
flops_total = c * N_params * N_tokens  # in FLOPs

# Effective machine compute per second (peak × MFU across all GPUs).
effective_flops_per_s = gpu_count * (peak_tflops * 1e12) * (efficiency / 100.0)

# Time estimate; guards yield inf rather than raising on a zero denominator.
seconds = flops_total / effective_flops_per_s if effective_flops_per_s > 0 else float('inf')
hours = seconds / 3600
days = hours / 24

# Steps
steps = N_tokens / tokens_per_step if tokens_per_step > 0 else float('inf')

# Throughput
throughput_tokens_per_s = N_tokens / seconds if seconds > 0 else float('inf')

# Cost
cost = price_per_gpu_hour * gpu_count * hours

# --- Display ---
st.divider()
st.subheader("Results")

colA, colB = st.columns(2)
with colA:
    st.metric("Total FLOPs", f"{flops_total:,.2e} FLOPs")
    st.metric("Effective compute", f"{effective_flops_per_s:,.2e} FLOPs/s")
    st.metric("Steps (est)", _fmt_or_inf(steps))
with colB:
    st.metric("Wall‑clock time", f"{hours:,.1f} h (~{days:,.2f} d)")
    st.metric("Throughput", f"{_fmt_or_inf(throughput_tokens_per_s)} tok/s")
    st.metric("Projected cost", f"${_fmt_or_inf(cost)}")

st.divider()

st.markdown(
    f"""
    **Summary**
    - Params: **{model_params_b:,.2f}B** · Tokens: **{tokens_b:,.2f}B** (compute‑optimal: {use_compute_optimal})
    - Constant **c = {c:g}** → Total ≈ **{flops_total:,.2e} FLOPs**.
    - Hardware: **{gpu_count}× GPU**, peak **{peak_tflops:g} TFLOPs/GPU**, MFU **{efficiency}%** → Effective ≈ **{effective_flops_per_s:,.2e} FLOPs/s**.
    - Time ≈ **{hours:,.1f} hours** (≈ {days:,.2f} days). Steps ≈ **{_fmt_or_inf(steps)}** (@ {tokens_per_step:,} tok/step).
    - Rough cost ≈ **${_fmt_or_inf(cost)}** (@ ${price_per_gpu_hour:g}/GPU·h).
    """
)

with st.expander("What is the Chinchilla rule? Is it 1 epoch?"):
    st.markdown(
        """
        **Chinchilla scaling** is a *compute‑optimal* rule of thumb: for a fixed compute budget, scale
        the **training tokens** roughly in proportion to the **model parameters** (commonly ~20× tokens per parameter).
        It is **not** about training for exactly one epoch. In web‑scale pretraining, datasets are often sampled with
        replacement or mixed; you might see data multiple times or less than once. The rule speaks to the *total number
        of tokens* the model should process for best use of compute, not to dataset passes.
        """
    )

st.success("Ready. Tweak inputs on the left to explore different scenarios.")
|
requirements.txt
CHANGED
|
@@ -1,3 +1,75 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiohappyeyeballs==2.6.1
|
| 2 |
+
aiohttp==3.12.15
|
| 3 |
+
aiosignal==1.4.0
|
| 4 |
+
altair==5.5.0
|
| 5 |
+
annotated-types==0.7.0
|
| 6 |
+
anyio==4.10.0
|
| 7 |
+
attrs==25.3.0
|
| 8 |
+
beautifulsoup4==4.13.5
|
| 9 |
+
blinker==1.9.0
|
| 10 |
+
bs4==0.0.2
|
| 11 |
+
cachetools==6.2.0
|
| 12 |
+
certifi==2025.8.3
|
| 13 |
+
charset-normalizer==3.4.3
|
| 14 |
+
click==8.2.1
|
| 15 |
+
datasets==4.1.1
|
| 16 |
+
deprecation==2.1.0
|
| 17 |
+
dill==0.4.0
|
| 18 |
+
distro==1.9.0
|
| 19 |
+
filelock==3.19.1
|
| 20 |
+
frozenlist==1.7.0
|
| 21 |
+
fsspec==2025.9.0
|
| 22 |
+
gitdb==4.0.12
|
| 23 |
+
GitPython==3.1.45
|
| 24 |
+
h11==0.16.0
|
| 25 |
+
hf-xet==1.1.10
|
| 26 |
+
httpcore==1.0.9
|
| 27 |
+
httpx==0.28.1
|
| 28 |
+
huggingface-hub==0.35.3
|
| 29 |
+
idna==3.10
|
| 30 |
+
Jinja2==3.1.6
|
| 31 |
+
jiter==0.10.0
|
| 32 |
+
jsonschema==4.25.1
|
| 33 |
+
jsonschema-specifications==2025.4.1
|
| 34 |
+
lancedb==0.24.3
|
| 35 |
+
MarkupSafe==3.0.2
|
| 36 |
+
multidict==6.6.4
|
| 37 |
+
multiprocess==0.70.16
|
| 38 |
+
narwhals==2.3.0
|
| 39 |
+
numpy==2.3.2
|
| 40 |
+
openai==1.105.0
|
| 41 |
+
overrides==7.7.0
|
| 42 |
+
packaging==25.0
|
| 43 |
+
pandas==2.3.2
|
| 44 |
+
pillow==11.3.0
|
| 45 |
+
propcache==0.3.2
|
| 46 |
+
protobuf==6.32.0
|
| 47 |
+
pyarrow==21.0.0
|
| 48 |
+
pydantic==2.11.7
|
| 49 |
+
pydantic_core==2.33.2
|
| 50 |
+
pydeck==0.9.1
|
| 51 |
+
pylance==0.35.0
|
| 52 |
+
python-dateutil==2.9.0.post0
|
| 53 |
+
python-dotenv==1.1.1
|
| 54 |
+
pytz==2025.2
|
| 55 |
+
PyYAML==6.0.3
|
| 56 |
+
referencing==0.36.2
|
| 57 |
+
requests==2.32.5
|
| 58 |
+
rpds-py==0.27.1
|
| 59 |
+
setuptools==78.1.1
|
| 60 |
+
six==1.17.0
|
| 61 |
+
smmap==5.0.2
|
| 62 |
+
sniffio==1.3.1
|
| 63 |
+
soupsieve==2.8
|
| 64 |
+
streamlit==1.49.1
|
| 65 |
+
tenacity==9.1.2
|
| 66 |
+
toml==0.10.2
|
| 67 |
+
tornado==6.5.2
|
| 68 |
+
tqdm==4.67.1
|
| 69 |
+
typing-inspection==0.4.1
|
| 70 |
+
typing_extensions==4.15.0
|
| 71 |
+
tzdata==2025.2
|
| 72 |
+
urllib3==2.5.0
|
| 73 |
+
wheel==0.45.1
|
| 74 |
+
xxhash==3.5.0
|
| 75 |
+
yarl==1.20.1
|
static/logo_light.png
ADDED
|