lianghsun commited on
Commit
87cbd1b
·
1 Parent(s): 21e139d

First commit

Browse files
Files changed (4) hide show
  1. Dockerfile +1 -1
  2. app.py +161 -0
  3. requirements.txt +75 -3
  4. static/logo_light.png +0 -0
Dockerfile CHANGED
@@ -17,4 +17,4 @@ EXPOSE 8501
17
 
18
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
 
20
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
17
 
18
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
 
20
+ ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
app.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app.py
# Compute-Optimal LLM Training Estimator (Chinchilla-style)
# ---------------------------------------------------------
# Usage: `streamlit run app.py`
# This tool helps estimate total FLOPs, steps, wall-clock time, and rough cost
# for LLM pretraining given model parameters, token budget, and hardware.
#
# NOTE(review): the original header referenced `streamlit_app.py`, but this
# file is added as `app.py` and launched by the Dockerfile ENTRYPOINT as
# `app.py` — the comments above now match the actual deployment.

import streamlit as st

st.set_page_config(page_title="LLM Compute Estimator", page_icon="🧮", layout="centered")

st.title("🧮 LLM Compute-Optimal Estimator")
st.caption("Estimate total FLOPs, wall-clock time, steps, and cost for pretraining — with a Chinchilla-style token rule.")

# --- Sidebar: assumptions ---
# Static explanatory text; all user inputs live in the main column below.
with st.sidebar:
    # Path is relative to the working directory the app is launched from
    # (the Dockerfile runs `streamlit run app.py` at the repo root).
    st.logo('./static/logo_light.png')
    st.header("Assumptions & Notes")
    st.markdown(
        """
**Formulas**
- **Total FLOPs** ≈ `c * N_params * N_tokens`, with default **c = 6** (forward+backward+optimizer overhead).
- **Compute-optimal tokens** (rule-of-thumb): `N_tokens ≈ k * N_params`, default **k = 20**.
- **Effective compute** = `GPU_count * (peak TFLOPs × 1e12) * efficiency`.

**Disclaimers**
- This is a *back-of-the-envelope* estimator. Real training efficiency depends on data pipeline, parallelism strategy, sequence length, kernel fusion, optimizer, etc.
- Preset TFLOPs are **approximate** and depend on precision (FP8/BF16), sparsity, clocks, and vendor kernels.
"""
    )

# --- 1) Model size & tokens ---
st.subheader("1) Model & Token Budget")
col1, col2, col3 = st.columns([1.2, 1, 1])
with col1:
    model_params_b = st.number_input("Model size (Billions of parameters)", min_value=0.05, value=4.0, step=0.5, format="%.2f")
with col2:
    # c = FLOPs per parameter per token; 6 ≈ 2 (forward) + 4 (backward).
    c_overhead = st.number_input("c (FLOPs constant)", min_value=4.0, value=6.0, step=0.5)
with col3:
    # Chinchilla rule of thumb: ~20 training tokens per model parameter.
    k_tokens_per_param = st.number_input("k (tokens per param for compute-optimal)", min_value=5.0, value=20.0, step=1.0)

use_compute_optimal = st.toggle("Use compute‑optimal tokens (k × params)", value=True)
if use_compute_optimal:
    # Token budget derived from model size; manual input hidden in this mode.
    tokens_b = model_params_b * k_tokens_per_param
    st.info(f"Compute‑optimal token budget ≈ **{tokens_b:,.2f} B** (k = {k_tokens_per_param:g})")
else:
    tokens_b = st.number_input("Token budget (Billions)", min_value=1.0, value=80.0, step=5.0, format="%.2f")

# --- 2) Hardware (moved before batch to define gpu_count first) ---
st.subheader("2) Hardware")
col6, col7 = st.columns(2)
with col6:
    gpu_preset = st.selectbox(
        "GPU preset (approx peak TFLOPs per GPU)",
        (
            "Custom",
            "A100 80GB BF16 ≈ 312",
            "H100 SXM BF16 ≈ 989",
            "B200 (FP8-ish) ≈ 20000",
        ),
        index=0,
        help="Values are back-of-the-envelope. Choose 'Custom' to enter your own.",
    )

# Maps each non-Custom preset label to its peak TFLOPs value.
# NOTE(review): the B200 figure of 20000 TFLOPs looks far above published
# FP8 specs — confirm the intended precision/sparsity assumption.
preset_map = {
    "A100 80GB BF16 ≈ 312": 312.0,
    "H100 SXM BF16 ≈ 989": 989.0,
    "B200 (FP8-ish) ≈ 20000": 20000.0,
}

with col7:
    if gpu_preset == "Custom":
        # NOTE(review): the Custom default of 20000 appears copied from the
        # B200 preset; a more typical single-GPU default may be intended.
        peak_tflops = st.number_input("Peak TFLOPs per GPU (approx)", min_value=10.0, value=20000.0, step=100.0)
    else:
        peak_tflops = preset_map[gpu_preset]
        # Show the preset value read-only so the user sees what is used.
        st.number_input("Peak TFLOPs per GPU (approx)", value=peak_tflops, disabled=True)

col8, col9, col10 = st.columns(3)
with col8:
    gpu_count = st.number_input("GPU count", min_value=1, value=8, step=1)
with col9:
    # Model FLOPs Utilization, entered as a percentage (10–95).
    efficiency = st.slider("Training efficiency (MFU, %)", min_value=10, max_value=95, value=50, step=1)
with col10:
    price_per_gpu_hour = st.number_input("Price per GPU·hour (USD)", min_value=0.0, value=25.0, step=1.0)

# --- 3) Batch & Sequence Settings (tokens_per_step computed from gpu_count) ---
st.subheader("3) Batch & Sequence Settings")
col4, col5 = st.columns(2)
with col4:
    micro_batch = st.number_input("Micro batch size per GPU", min_value=1, value=4, step=1, help="Number of sequences per GPU per optimizer step.")
with col5:
    seq_len = st.number_input("Sequence length (tokens)", min_value=128, value=2048, step=128)

# Global tokens consumed per optimizer step (assumes pure data parallelism;
# no gradient accumulation is modeled).
tokens_per_step = micro_batch * seq_len * gpu_count
st.info(f"Tokens per optimization step ≈ {tokens_per_step:,} (with {gpu_count} GPUs)")

# --- Compute ---
# Convert billions to absolute counts for the formulas.
N_params = model_params_b * 1e9
N_tokens = tokens_b * 1e9
c = c_overhead

# Total FLOPs (scalar): C ≈ c * N * D.
flops_total = c * N_params * N_tokens  # in FLOPs

# Effective machine compute per second (peak scaled by MFU across all GPUs).
effective_flops_per_s = gpu_count * (peak_tflops * 1e12) * (efficiency / 100.0)

# Time estimate; guard against a zero-compute edge (widget mins make it
# unreachable in practice, but keep the estimator total).
seconds = flops_total / effective_flops_per_s if effective_flops_per_s > 0 else float('inf')
hours = seconds / 3600
days = hours / 24

# Optimizer steps needed to consume the token budget.
steps = N_tokens / tokens_per_step if tokens_per_step > 0 else float('inf')

# Aggregate training throughput implied by the time estimate.
throughput_tokens_per_s = N_tokens / seconds if seconds > 0 else float('inf')

# Cost: billed per GPU-hour across the whole fleet.
cost = price_per_gpu_hour * gpu_count * hours

# --- Display ---
st.divider()
st.subheader("Results")

colA, colB = st.columns(2)
with colA:
    st.metric("Total FLOPs", f"{flops_total:,.2e} FLOPs")
    st.metric("Effective compute", f"{effective_flops_per_s:,.2e} FLOPs/s")
    # Infinite (degenerate) values render as 0 rather than 'inf'.
    st.metric("Steps (est)", f"{0 if steps == float('inf') else steps:,.0f}")
with colB:
    st.metric("Wall‑clock time", f"{hours:,.1f} h (~{days:,.2f} d)")
    st.metric("Throughput", f"{0 if throughput_tokens_per_s == float('inf') else throughput_tokens_per_s:,.0f} tok/s")
    st.metric("Projected cost", f"${0 if cost == float('inf') else cost:,.0f}")

st.divider()

st.markdown(
    f"""
**Summary**
- Params: **{model_params_b:,.2f}B** · Tokens: **{tokens_b:,.2f}B** (compute‑optimal: {use_compute_optimal})
- Constant **c = {c:g}** → Total ≈ **{flops_total:,.2e} FLOPs**.
- Hardware: **{gpu_count}× GPU**, peak **{peak_tflops:g} TFLOPs/GPU**, MFU **{efficiency}%** → Effective ≈ **{effective_flops_per_s:,.2e} FLOPs/s**.
- Time ≈ **{hours:,.1f} hours** (≈ {days:,.2f} days). Steps ≈ **{0 if steps == float('inf') else steps:,.0f}** (@ {tokens_per_step:,} tok/step).
- Rough cost ≈ **${0 if cost == float('inf') else cost:,.0f}** (@ ${price_per_gpu_hour:g}/GPU·h).
"""
)

with st.expander("What is the Chinchilla rule? Is it 1 epoch?"):
    st.markdown(
        """
**Chinchilla scaling** is a *compute‑optimal* rule of thumb: for a fixed compute budget, scale
the **training tokens** roughly in proportion to the **model parameters** (commonly ~20× tokens per parameter).
It is **not** about training for exactly one epoch. In web‑scale pretraining, datasets are often sampled with
replacement or mixed; you might see data multiple times or less than once. The rule speaks to the *total number
of tokens* the model should process for best use of compute, not to dataset passes.
"""
    )

st.success("Ready. Tweak inputs on the left to explore different scenarios.")
requirements.txt CHANGED
@@ -1,3 +1,75 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohappyeyeballs==2.6.1
2
+ aiohttp==3.12.15
3
+ aiosignal==1.4.0
4
+ altair==5.5.0
5
+ annotated-types==0.7.0
6
+ anyio==4.10.0
7
+ attrs==25.3.0
8
+ beautifulsoup4==4.13.5
9
+ blinker==1.9.0
10
+ bs4==0.0.2
11
+ cachetools==6.2.0
12
+ certifi==2025.8.3
13
+ charset-normalizer==3.4.3
14
+ click==8.2.1
15
+ datasets==4.1.1
16
+ deprecation==2.1.0
17
+ dill==0.4.0
18
+ distro==1.9.0
19
+ filelock==3.19.1
20
+ frozenlist==1.7.0
21
+ fsspec==2025.9.0
22
+ gitdb==4.0.12
23
+ GitPython==3.1.45
24
+ h11==0.16.0
25
+ hf-xet==1.1.10
26
+ httpcore==1.0.9
27
+ httpx==0.28.1
28
+ huggingface-hub==0.35.3
29
+ idna==3.10
30
+ Jinja2==3.1.6
31
+ jiter==0.10.0
32
+ jsonschema==4.25.1
33
+ jsonschema-specifications==2025.4.1
34
+ lancedb==0.24.3
35
+ MarkupSafe==3.0.2
36
+ multidict==6.6.4
37
+ multiprocess==0.70.16
38
+ narwhals==2.3.0
39
+ numpy==2.3.2
40
+ openai==1.105.0
41
+ overrides==7.7.0
42
+ packaging==25.0
43
+ pandas==2.3.2
44
+ pillow==11.3.0
45
+ propcache==0.3.2
46
+ protobuf==6.32.0
47
+ pyarrow==21.0.0
48
+ pydantic==2.11.7
49
+ pydantic_core==2.33.2
50
+ pydeck==0.9.1
51
+ pylance==0.35.0
52
+ python-dateutil==2.9.0.post0
53
+ python-dotenv==1.1.1
54
+ pytz==2025.2
55
+ PyYAML==6.0.3
56
+ referencing==0.36.2
57
+ requests==2.32.5
58
+ rpds-py==0.27.1
59
+ setuptools==78.1.1
60
+ six==1.17.0
61
+ smmap==5.0.2
62
+ sniffio==1.3.1
63
+ soupsieve==2.8
64
+ streamlit==1.49.1
65
+ tenacity==9.1.2
66
+ toml==0.10.2
67
+ tornado==6.5.2
68
+ tqdm==4.67.1
69
+ typing-inspection==0.4.1
70
+ typing_extensions==4.15.0
71
+ tzdata==2025.2
72
+ urllib3==2.5.0
73
+ wheel==0.45.1
74
+ xxhash==3.5.0
75
+ yarl==1.20.1
static/logo_light.png ADDED