jameshuntercarter committed
Commit 6f91e60 · verified · 1 Parent(s): d549dfb

Upload 138 files

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .gitattributes +21 -0
  2. LICENSE +175 -0
  3. README.md +325 -0
  4. __pycache__/predict.cpython-311.pyc +0 -0
  5. boson_multimodal/.DS_Store +0 -0
  6. boson_multimodal/__init__.py +1 -0
  7. boson_multimodal/__pycache__/__init__.cpython-311.pyc +0 -0
  8. boson_multimodal/__pycache__/constants.cpython-311.pyc +0 -0
  9. boson_multimodal/__pycache__/data_types.cpython-311.pyc +0 -0
  10. boson_multimodal/audio_processing/LICENSE +51 -0
  11. boson_multimodal/audio_processing/__pycache__/higgs_audio_tokenizer.cpython-311.pyc +0 -0
  12. boson_multimodal/audio_processing/__pycache__/semantic_module.cpython-311.pyc +0 -0
  13. boson_multimodal/audio_processing/descriptaudiocodec/__init__.py +0 -0
  14. boson_multimodal/audio_processing/descriptaudiocodec/__pycache__/__init__.cpython-311.pyc +0 -0
  15. boson_multimodal/audio_processing/descriptaudiocodec/dac/model/__pycache__/base.cpython-311.pyc +0 -0
  16. boson_multimodal/audio_processing/descriptaudiocodec/dac/model/__pycache__/dac.cpython-311.pyc +0 -0
  17. boson_multimodal/audio_processing/descriptaudiocodec/dac/model/base.py +286 -0
  18. boson_multimodal/audio_processing/descriptaudiocodec/dac/model/dac.py +365 -0
  19. boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/layers.py +33 -0
  20. boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/quantize.py +251 -0
  21. boson_multimodal/audio_processing/higgs_audio_tokenizer.py +327 -0
  22. boson_multimodal/audio_processing/quantization/__init__.py +8 -0
  23. boson_multimodal/audio_processing/quantization/__pycache__/__init__.cpython-311.pyc +0 -0
  24. boson_multimodal/audio_processing/quantization/__pycache__/core_vq_lsx_version.cpython-311.pyc +0 -0
  25. boson_multimodal/audio_processing/quantization/__pycache__/ddp_utils.cpython-311.pyc +0 -0
  26. boson_multimodal/audio_processing/quantization/__pycache__/distrib.cpython-311.pyc +0 -0
  27. boson_multimodal/audio_processing/quantization/__pycache__/vq.cpython-311.pyc +0 -0
  28. boson_multimodal/audio_processing/quantization/ac.py +292 -0
  29. boson_multimodal/audio_processing/quantization/core_vq.py +360 -0
  30. boson_multimodal/audio_processing/quantization/core_vq_lsx_version.py +425 -0
  31. boson_multimodal/audio_processing/quantization/ddp_utils.py +197 -0
  32. boson_multimodal/audio_processing/quantization/distrib.py +123 -0
  33. boson_multimodal/audio_processing/quantization/vq.py +116 -0
  34. boson_multimodal/audio_processing/semantic_module.py +282 -0
  35. boson_multimodal/constants.py +3 -0
  36. boson_multimodal/data_collator/__init__.py +0 -0
  37. boson_multimodal/data_collator/__pycache__/__init__.cpython-311.pyc +0 -0
  38. boson_multimodal/data_collator/__pycache__/higgs_audio_collator.cpython-311.pyc +0 -0
  39. boson_multimodal/data_collator/higgs_audio_collator.py +509 -0
  40. boson_multimodal/data_types.py +38 -0
  41. boson_multimodal/dataset/__init__.py +0 -0
  42. boson_multimodal/dataset/__pycache__/__init__.cpython-311.pyc +0 -0
  43. boson_multimodal/dataset/__pycache__/chatml_dataset.cpython-311.pyc +0 -0
  44. boson_multimodal/dataset/chatml_dataset.py +533 -0
  45. boson_multimodal/model/__init__.py +0 -0
  46. boson_multimodal/model/__pycache__/__init__.cpython-311.pyc +0 -0
  47. boson_multimodal/model/higgs_audio/__init__.py +9 -0
  48. boson_multimodal/model/higgs_audio/__pycache__/__init__.cpython-311.pyc +0 -0
  49. boson_multimodal/model/higgs_audio/__pycache__/audio_head.cpython-311.pyc +0 -0
  50. boson_multimodal/model/higgs_audio/__pycache__/common.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,24 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ boson_multimodal/model/higgs_audio/__pycache__/modeling_higgs_audio.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
37
+ examples/serve_engine/voice_examples/old_man.wav filter=lfs diff=lfs merge=lfs -text
38
+ examples/voice_prompts/belinda.wav filter=lfs diff=lfs merge=lfs -text
39
+ examples/voice_prompts/bigbang_amy.wav filter=lfs diff=lfs merge=lfs -text
40
+ examples/voice_prompts/bigbang_sheldon.wav filter=lfs diff=lfs merge=lfs -text
41
+ examples/voice_prompts/broom_salesman.wav filter=lfs diff=lfs merge=lfs -text
42
+ examples/voice_prompts/chadwick.wav filter=lfs diff=lfs merge=lfs -text
43
+ examples/voice_prompts/en_man.wav filter=lfs diff=lfs merge=lfs -text
44
+ examples/voice_prompts/en_woman.wav filter=lfs diff=lfs merge=lfs -text
45
+ examples/voice_prompts/fiftyshades_anna.wav filter=lfs diff=lfs merge=lfs -text
46
+ examples/voice_prompts/mabaoguo.wav filter=lfs diff=lfs merge=lfs -text
47
+ examples/voice_prompts/mabel.wav filter=lfs diff=lfs merge=lfs -text
48
+ examples/voice_prompts/shrek_donkey_es.wav filter=lfs diff=lfs merge=lfs -text
49
+ examples/voice_prompts/shrek_donkey.wav filter=lfs diff=lfs merge=lfs -text
50
+ examples/voice_prompts/shrek_fiona.wav filter=lfs diff=lfs merge=lfs -text
51
+ examples/voice_prompts/shrek_shrek.wav filter=lfs diff=lfs merge=lfs -text
52
+ examples/voice_prompts/vex.wav filter=lfs diff=lfs merge=lfs -text
53
+ examples/voice_prompts/zh_man_sichuan.wav filter=lfs diff=lfs merge=lfs -text
54
+ figures/emergent-tts-emotions-win-rate.png filter=lfs diff=lfs merge=lfs -text
55
+ figures/higgs_audio_tokenizer_architecture.png filter=lfs diff=lfs merge=lfs -text
56
+ figures/higgs_audio_v2_architecture_combined.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,175 @@
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
README.md ADDED
@@ -0,0 +1,325 @@
1
+ <h1 align="center">Higgs Audio V2: Redefining Expressiveness in Audio Generation</h1>
2
+
3
+ <div align="center" style="display: flex; justify-content: center; margin-top: 10px;">
4
+ <a href="https://boson.ai/blog/higgs-audio-v2"><img src='https://img.shields.io/badge/🚀-Launch Blogpost-228B22' style="margin-right: 5px;"></a>
5
+ <a href="https://boson.ai/demo/tts"><img src="https://img.shields.io/badge/🕹️-Boson%20AI%20Playground-9C276A" style="margin-right: 5px;"></a>
6
+ <a href="https://huggingface.co/spaces/smola/higgs_audio_v2"><img src="https://img.shields.io/badge/🎮-HF%20Space%20Playground-8A2BE2" style="margin-right: 5px;"></a>
7
+ <a href="https://huggingface.co/bosonai/higgs-audio-v2-generation-3B-base"><img src="https://img.shields.io/badge/🤗-Checkpoints (3.6B LLM + 2.2B audio adapter)-ED5A22.svg" style="margin-right: 5px;"></a>
8
+ <a href="https://replicate.com/lucataco/higgs-audio-v2-generation-3b-base"><img src="https://replicate.com/lucataco/higgs-audio-v2-generation-3b-base/badge"></a>
9
+ </div>
10
+
11
+
12
+ We are open-sourcing Higgs Audio v2, a powerful audio foundation model pretrained on over 10 million hours of audio data and a diverse set of text data. Despite having no post-training or fine-tuning, Higgs Audio v2 excels in expressive audio generation, thanks to its deep language and acoustic understanding.
13
+
14
+ On [EmergentTTS-Eval](https://github.com/boson-ai/emergenttts-eval-public), it achieves win rates of **75.7%** and **55.7%** over "gpt-4o-mini-tts" on the "Emotions" and "Questions" categories, respectively. It also obtains state-of-the-art performance on traditional TTS benchmarks like Seed-TTS Eval and Emotional Speech Dataset (ESD). Moreover, the model demonstrates capabilities rarely seen in previous systems, including generating natural multi-speaker dialogues in multiple languages, automatic prosody adaptation during narration, melodic humming with the cloned voice, and simultaneous generation of speech and background music.
15
+
16
+ <p align="center">
17
+ <img src="figures/emergent-tts-emotions-win-rate.png" width=900>
18
+ </p>
19
+
20
+ Here's a demo video that shows some of its emergent capabilities (remember to unmute):
21
+
22
+ <video src="https://github.com/user-attachments/assets/0fd73fad-097f-48a9-9f3f-bc2a63b3818d" type="video/mp4" width="80%" controls>
23
+ </video>
24
+
25
+ Here's another demo video that showcases the model's multilingual capability and how it enables live translation (remember to unmute):
26
+
27
+ <video src="https://github.com/user-attachments/assets/2b9b01ff-67fc-4bd9-9714-7c7df09e38d6" type="video/mp4" width="80%" controls>
28
+ </video>
29
+
30
+ ## Installation
31
+
32
+ We recommend using an NVIDIA Deep Learning Container to manage the CUDA environment. The following two Docker images have been verified:
33
+ - nvcr.io/nvidia/pytorch:25.02-py3
34
+ - nvcr.io/nvidia/pytorch:25.01-py3
35
+
36
+ Here's an example command for launching a Docker container environment. Please also check the [official NVIDIA documentation](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch).
37
+
38
+ ```bash
39
+ docker run --gpus all --ipc=host --net=host --ulimit memlock=-1 --ulimit stack=67108864 -it --rm nvcr.io/nvidia/pytorch:25.02-py3 bash
40
+ ```
41
+
42
+ ### Option 1: Direct installation
43
+
44
+
45
+ ```bash
46
+ git clone https://github.com/boson-ai/higgs-audio.git
47
+ cd higgs-audio
48
+
49
+ pip install -r requirements.txt
50
+ pip install -e .
51
+ ```
52
+
53
+ ### Option 2: Using venv
54
+
55
+ ```bash
56
+ git clone https://github.com/boson-ai/higgs-audio.git
57
+ cd higgs-audio
58
+
59
+ python3 -m venv higgs_audio_env
60
+ source higgs_audio_env/bin/activate
61
+ pip install -r requirements.txt
62
+ pip install -e .
63
+ ```
64
+
65
+
66
+ ### Option 3: Using conda
67
+ ```bash
68
+ git clone https://github.com/boson-ai/higgs-audio.git
69
+ cd higgs-audio
70
+
71
+ conda create -n higgs_audio_env python=3.10
72
+ conda activate higgs_audio_env
73
+ pip install -r requirements.txt
74
+ pip install -e .
75
+ ```
76
+
77
+ ### Option 4: Using uv
78
+ ```bash
79
+ git clone https://github.com/boson-ai/higgs-audio.git
80
+ cd higgs-audio
81
+
82
+ uv venv --python 3.10
83
+ source .venv/bin/activate
84
+ uv pip install -r requirements.txt
85
+ uv pip install -e .
86
+ ```
87
+
88
+ ### Option 5: Using vLLM
89
+
90
+ For advanced usage with higher throughput, we also provide an OpenAI-compatible API server backed by the vLLM engine.
91
+ Please refer to [examples/vllm](./examples/vllm) for more details.
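
Once the server from [examples/vllm](./examples/vllm) is running, it can be reached with any OpenAI-style client. The sketch below is only a minimal connectivity check using `requests`; the port, route, model name, and the exact request/response format for audio are assumptions on our part, so follow [examples/vllm](./examples/vllm) for the authoritative invocation.

```python
import requests

# Placeholder endpoint and model name -- adjust to match how you launched the server.
response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "bosonai/higgs-audio-v2-generation-3B-base",
        "messages": [{"role": "user", "content": "The sun rises in the east and sets in the west."}],
    },
    timeout=600,
)
print(response.status_code)
print(response.json())
```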
92
+
93
+
94
+ ## Usage
95
+
96
+ > [!TIP]
97
+ > For optimal performance, run the generation examples on a machine equipped with a GPU that has at least 24 GB of memory!
98
+
99
+ ### Get Started
100
+
101
+ Here's a basic Python snippet to help you get started.
102
+
103
+ ```python
104
+ from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse
105
+ from boson_multimodal.data_types import ChatMLSample, Message, AudioContent
106
+
107
+ import torch
108
+ import torchaudio
111
+
112
+ MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
113
+ AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"
114
+
115
+ system_prompt = (
116
+ "Generate audio following instruction.\n\n<|scene_desc_start|>\nAudio is recorded from a quiet room.\n<|scene_desc_end|>"
117
+ )
118
+
119
+ messages = [
120
+ Message(
121
+ role="system",
122
+ content=system_prompt,
123
+ ),
124
+ Message(
125
+ role="user",
126
+ content="The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years.",
127
+ ),
128
+ ]
129
+ device = "cuda" if torch.cuda.is_available() else "cpu"
130
+
131
+ serve_engine = HiggsAudioServeEngine(MODEL_PATH, AUDIO_TOKENIZER_PATH, device=device)
132
+
133
+ output: HiggsAudioResponse = serve_engine.generate(
134
+ chat_ml_sample=ChatMLSample(messages=messages),
135
+ max_new_tokens=1024,
136
+ temperature=0.3,
137
+ top_p=0.95,
138
+ top_k=50,
139
+ stop_strings=["<|end_of_text|>", "<|eot_id|>"],
140
+ )
141
+ torchaudio.save("output.wav", torch.from_numpy(output.audio)[None, :], output.sampling_rate)
142
+ ```
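
If you need the waveform at a different sample rate (say 16 kHz for a downstream ASR or evaluation pipeline), you can resample the returned array with torchaudio before saving. This short follow-up builds on the variables from the snippet above; 16 kHz is just an example target rate.

```python
import torch
import torchaudio
import torchaudio.functional as AF

waveform = torch.from_numpy(output.audio)[None, :]  # shape [1, num_samples]
resampled = AF.resample(waveform, orig_freq=output.sampling_rate, new_freq=16_000)
torchaudio.save("output_16k.wav", resampled, 16_000)
```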
143
+
144
+ We also provide a list of examples under [examples](./examples). Below, we highlight a few of them to help you use Higgs Audio v2.
145
+
146
+ ### Zero-Shot Voice Cloning
147
+ Generate audio that sounds similar to the provided [reference audio](./examples/voice_prompts/belinda.wav).
148
+
149
+ ```bash
150
+ python3 examples/generation.py \
151
+ --transcript "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years." \
152
+ --ref_audio belinda \
153
+ --temperature 0.3 \
154
+ --out_path generation.wav
155
+ ```
156
+
157
+ The generation script will automatically use `cuda:0` if it finds that CUDA is available. To change the device ID, specify `--device_id`:
158
+
159
+ ```bash
160
+ python3 examples/generation.py \
161
+ --transcript "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years." \
162
+ --ref_audio belinda \
163
+ --temperature 0.3 \
164
+ --device_id 0 \
165
+ --out_path generation.wav
166
+ ```
167
+
168
+ You can also try other voices. Check out more example voices in [examples/voice_prompts](./examples/voice_prompts), or add your own voice to the folder.
169
+
170
+ ```bash
171
+ python3 examples/generation.py \
172
+ --transcript "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years." \
173
+ --ref_audio broom_salesman \
174
+ --temperature 0.3 \
175
+ --out_path generation.wav
176
+ ```
177
+
178
+ ### Voice Cloning via Cog (Replicate)
179
+
180
+ You can also run Higgs Audio v2 using [Cog](https://cog.run), which packages the model for reproducible inference. This is useful for deploying on Replicate or other platforms.
181
+
182
+ #### Prerequisites
183
+ - [Install Cog](https://cog.run/getting-started)
184
+ - GPU with at least 24GB VRAM (e.g., A100, RTX 4090)
185
+
186
+ #### Basic Text-to-Speech
187
+ ```bash
188
+ cog predict -i text="The sun rises in the east and sets in the west."
189
+ ```
190
+
191
+ #### Voice Cloning with Reference Audio
192
+ To clone a voice, provide a reference audio file:
193
+
194
+ ```bash
195
+ cog predict -i text="The sun rises in the east and sets in the west." \
196
+ -i ref_audio=@/path/to/reference_audio.wav
197
+ ```
198
+
199
+ #### Customization Parameters
200
+ - `text` (str): Text to convert to speech
201
+ - `ref_audio` (Path, optional): Reference audio file for voice cloning (WAV, MP3, etc.)
202
+ - `scene_description` (str): Scene context for audio generation (default: "Audio is recorded from a quiet room.")
203
+ - `temperature` (float): Controls randomness, 0.1-1.0 (default: 0.3, lower = more deterministic)
204
+ - `top_p` (float): Nucleus sampling parameter, 0.1-1.0 (default: 0.95)
205
+ - `top_k` (int): Top-k sampling, 1-100 (default: 50)
206
+ - `max_new_tokens` (int): Maximum audio tokens to generate, 256-2048 (default: 1024)
207
+ - `system_message` (str): Custom system prompt (optional)
208
+
209
+ #### Example: Generate with Custom Scene
210
+ ```bash
211
+ cog predict -i text="Generate a whisper voice in a noisy cafe." \
212
+ -i scene_description="Audio is recorded from a busy cafe with background chatter." \
213
+ -i temperature=0.5
214
+ ```
215
+
216
+ #### Example: Clone Multiple Voices
217
+ ```bash
218
+ cog predict -i text="Speaker one talks here." -i [email protected]
219
+ cog predict -i text="Speaker two talks here." -i [email protected]
220
+ ```
221
+
222
+ ### Single-speaker Generation with Smart Voice
223
+ If you do not specify a reference voice, the model will decide the voice based on the transcript it sees.
224
+
225
+ ```bash
226
+ python3 examples/generation.py \
227
+ --transcript "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years." \
228
+ --temperature 0.3 \
229
+ --out_path generation.wav
230
+ ```
231
+
232
+
233
+ ### Multi-speaker Dialog with Smart Voice
234
+ Generate multi-speaker dialog. The model will decide the voices based on the transcript it sees.
235
+
236
+ ```bash
237
+ python3 examples/generation.py \
238
+ --transcript examples/transcript/multi_speaker/en_argument.txt \
239
+ --seed 12345 \
240
+ --out_path generation.wav
241
+ ```
242
+
243
+ ### Multi-speaker Dialog with Voice Clone
244
+
245
+ Generate multi-speaker dialog with the voices you picked.
246
+
247
+ ```bash
248
+ python3 examples/generation.py \
249
+ --transcript examples/transcript/multi_speaker/en_argument.txt \
250
+ --ref_audio belinda,broom_salesman \
251
+ --ref_audio_in_system_message \
252
+ --chunk_method speaker \
253
+ --seed 12345 \
254
+ --out_path generation.wav
255
+ ```
256
+
257
+
258
+ ## Technical Details
259
+ <img src="figures/higgs_audio_v2_architecture_combined.png" width=900>
260
+
261
+
262
+ Higgs Audio v2 adopts the "generation variant" depicted in the architecture figure above. Its strong performance is driven by three key technical innovations:
263
+ - We developed an automated annotation pipeline that leverages multiple ASR models, sound event classification models, and our in-house audio understanding model. Using this pipeline, we cleaned and annotated 10 million hours of audio data, which we refer to as **AudioVerse**. The in-house understanding model is finetuned on top of [Higgs Audio v1 Understanding](https://www.boson.ai/blog/higgs-audio), which adopts the "understanding variant" shown in the architecture figure.
264
+ - We trained a unified audio tokenizer from scratch that captures both semantic and acoustic features. Learn more in the [tokenizer blog](./tech_blogs/TOKENIZER_BLOG.md).
265
+ - We proposed the DualFFN architecture, which enhances the LLM's ability to model acoustic tokens with minimal computational overhead; a rough sketch of the idea is shown below. See the [architecture blog](./tech_blogs/ARCHITECTURE_BLOG.md) for details.
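
The snippet below is a minimal, illustrative sketch of the dual-FFN idea, assuming shared self-attention over the mixed text/audio sequence with separate feed-forward paths selected by a token-type mask. All class names, dimensions, and the routing scheme are our own assumptions for illustration, not the actual implementation; refer to the architecture blog and `boson_multimodal/model/higgs_audio` for that.

```python
# Illustrative sketch only -- NOT the Higgs Audio v2 code. Shapes and module names
# are assumptions chosen to show how a dual-FFN decoder layer can route text and
# audio tokens through separate feed-forward networks while sharing attention.
import torch
import torch.nn as nn


class DualFFNDecoderLayerSketch(nn.Module):
    def __init__(self, d_model: int = 512, n_heads: int = 8, d_ff: int = 2048):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm_attn = nn.LayerNorm(d_model)
        self.norm_ffn = nn.LayerNorm(d_model)
        # Separate feed-forward paths: one for text positions, one for audio positions.
        self.text_ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.audio_ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

    def forward(self, x: torch.Tensor, audio_mask: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq, d_model]; audio_mask: [batch, seq] bool, True at audio-token positions.
        h = self.norm_attn(x)
        # A real decoder layer would also pass a causal attention mask here.
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        h = self.norm_ffn(x)
        # Route each position to its own FFN (computing both and selecting keeps the sketch simple).
        ffn_out = torch.where(audio_mask.unsqueeze(-1), self.audio_ffn(h), self.text_ffn(h))
        return x + ffn_out


if __name__ == "__main__":
    layer = DualFFNDecoderLayerSketch()
    tokens = torch.randn(2, 16, 512)
    is_audio = torch.zeros(2, 16, dtype=torch.bool)
    is_audio[:, 8:] = True  # pretend the second half of the sequence is audio tokens
    print(layer(tokens, is_audio).shape)  # torch.Size([2, 16, 512])
```

An efficient implementation would gather only the audio positions into the audio FFN instead of computing both paths for every token, which is where the "minimal computational overhead" comes from.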
266
+
267
+ ## Evaluation
268
+
269
+ Here's the performance of Higgs Audio v2 on four benchmarks, [Seed-TTS Eval](https://github.com/BytedanceSpeech/seed-tts-eval), [Emotional Speech Dataset (ESD)](https://paperswithcode.com/dataset/esd), [EmergentTTS-Eval](https://arxiv.org/abs/2505.23009), and Multi-speaker Eval:
270
+
271
+ #### Seed-TTS Eval & ESD
272
+
273
+ We prompt Higgs Audio v2 with the reference text, reference audio, and target text for zero-shot TTS. We use the standard evaluation metrics from Seed-TTS Eval and ESD.
274
+
275
+ | | SeedTTS-Eval| | ESD | |
276
+ |------------------------------|--------|--------|---------|-------------------|
277
+ | | WER ↓ | SIM ↑ | WER ↓ | SIM (emo2vec) ↑ |
278
+ | Cosyvoice2 | 2.28 | 65.49 | 2.71 | 80.48 |
279
+ | Qwen2.5-omni† | 2.33 | 64.10 | - | - |
280
+ | ElevenLabs Multilingual V2 | **1.43** | 50.00 | 1.66 | 65.87 |
281
+ | Higgs Audio v1 | 2.18 | 66.27 | **1.49** | 82.84 |
282
+ | Higgs Audio v2 (base) | 2.44 | **67.70** | 1.78 | **86.13** |
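
For readers less familiar with these metrics, WER is the usual edit-distance-based word error rate computed over an ASR transcript of the generated audio,

$$
\mathrm{WER} = \frac{S + D + I}{N},
$$

where $S$, $D$, and $I$ are the substitution, deletion, and insertion counts against the reference transcript and $N$ is the number of reference words. SIM is an embedding-similarity score between the generated and reference audio (speaker-embedding similarity for Seed-TTS Eval, emotion2vec similarity for the ESD column), as prescribed by each benchmark's standard evaluation scripts; higher is better.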
283
+
284
+
285
+ #### EmergentTTS-Eval ("Emotions" and "Questions")
286
+
287
+ Following the [EmergentTTS-Eval Paper](https://arxiv.org/abs/2505.23009), we report the win-rate over "gpt-4o-mini-tts" with the "alloy" voice. The judge model is Gemini 2.5 Pro.
288
+
289
+ | Model | Emotions (%) ↑ | Questions (%) ↑ |
290
+ |------------------------------------|--------------|----------------|
291
+ | Higgs Audio v2 (base) | **75.71%** | **55.71%** |
292
+ | [gpt-4o-audio-preview†](https://platform.openai.com/docs/models/gpt-4o-audio-preview) | 61.64% | 47.85% |
293
+ | [Hume.AI](https://www.hume.ai/research) | 61.60% | 43.21% |
294
+ | **BASELINE:** [gpt-4o-mini-tts](https://platform.openai.com/docs/models/gpt-4o-mini-tts) | 50.00% | 50.00% |
295
+ | [Qwen 2.5 Omni†](https://github.com/QwenLM/Qwen2.5-Omni) | 41.60% | 51.78% |
296
+ | [minimax/speech-02-hd](https://replicate.com/minimax/speech-02-hd) | 40.86% | 47.32% |
297
+ | [ElevenLabs Multilingual v2](https://elevenlabs.io/blog/eleven-multilingual-v2) | 30.35% | 39.46% |
298
+ | [DeepGram Aura-2](https://deepgram.com/learn/introducing-aura-2-enterprise-text-to-speech) | 29.28% | 48.21% |
299
+ | [Sesame csm-1B](https://github.com/SesameAILabs/csm) | 15.96% | 31.78% |
300
+
301
+ <sup><sub>'†' means using the strong-prompting method described in the paper.</sub></sup>
302
+
303
+
304
+ #### Multi-speaker Eval
305
+
306
+ We also designed a multi-speaker evaluation benchmark to evaluate the capability of Higgs Audio v2 for multi-speaker dialog generation. The benchmark contains three subsets:
307
+
308
+ - `two-speaker-conversation`: 1000 synthetic dialogues involving two speakers. We fix two reference audio clips to evaluate the model's ability to clone both voices across conversations of 4 to 10 turns between two randomly chosen personas.
309
+ - `small talk (no ref)`: 250 synthetic dialogues curated in the same way as above, but characterized by short utterances and a limited number of turns (4–6). We do not fix reference audio in this case; the set is designed to evaluate the model's ability to automatically assign appropriate voices to speakers.
310
+ - `small talk (ref)`: 250 synthetic dialogues similar to the above, but with even shorter utterances, since this set is meant to include reference clips in its context, similar to `two-speaker-conversation`.
311
+
312
+
313
+ We report the word error rate (WER) and the geometric mean of intra-speaker similarity and inter-speaker dissimilarity on these three subsets (a rough sketch of this similarity metric is shown after the table). Besides Higgs Audio v2, we also evaluated [MoonCast](https://github.com/jzq2000/MoonCast) and [nari-labs/Dia-1.6B-0626](https://huggingface.co/nari-labs/Dia-1.6B-0626), two of the most popular open-source models capable of multi-speaker dialog generation. Results are summarized in the following table. We were not able to run [nari-labs/Dia-1.6B-0626](https://huggingface.co/nari-labs/Dia-1.6B-0626) on our "two-speaker-conversation" subset due to its strict limits on utterance and output-audio length.
314
+
315
+ | | two-speaker-conversation | |small talk | | small talk (no ref) | |
316
+ | ---------------------------------------------- | -------------- | ------------------ | ---------- | -------------- | ------------------- | -------------- |
317
+ | | WER ↓ | Mean Sim & Dis-sim ↑ | WER ↓ | Mean Sim & Dis-sim ↑ | WER ↓ | Mean Sim & Dis-sim ↑ |
318
+ | [MoonCast](https://github.com/jzq2000/MoonCast) | 38.77 | 46.02 | **8.33** | 63.68 | 24.65 | 53.94 |
319
+ | [nari-labs/Dia-1.6B-0626](https://huggingface.co/nari-labs/Dia-1.6B-0626) | \- | \- | 17.62 | 63.15 | 19.46 | **61.14** |
320
+ | Higgs Audio v2 (base) | **18.88** | **51.95** | 11.89 | **67.92** | **14.65** | 55.28 |
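
As a rough illustration of the similarity metric above, the sketch below computes the geometric mean of intra-speaker similarity and inter-speaker dissimilarity from precomputed cosine similarities. The choice of speaker-embedding model, the utterance pairing, and any score rescaling are assumptions here, not the exact evaluation code.

```python
import numpy as np


def speaker_consistency_score(same_speaker_sims, cross_speaker_sims):
    """Geometric mean of intra-speaker similarity and inter-speaker dissimilarity.

    Both inputs are iterables of cosine similarities (assumed rescaled to [0, 1]):
    - same_speaker_sims:  pairs of utterances generated for the *same* speaker
    - cross_speaker_sims: pairs of utterances generated for *different* speakers
    """
    intra_sim = float(np.mean(list(same_speaker_sims)))            # higher = more consistent voice
    inter_dissim = 1.0 - float(np.mean(list(cross_speaker_sims)))  # higher = more distinct voices
    return float(np.sqrt(intra_sim * inter_dissim))


# Example: a dialog with well-separated voices scores higher than one where the voices blur together.
print(speaker_consistency_score([0.85, 0.9], [0.2, 0.3]))  # ~0.81
print(speaker_consistency_score([0.7, 0.6], [0.6, 0.7]))   # ~0.48
```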
321
+
322
+
323
+ ## Third-Party Licenses
324
+
325
+ The `boson_multimodal/audio_processing/` directory contains code derived from third-party repositories, primarily from [xcodec](https://github.com/zhenye234/xcodec). Please see the [`LICENSE`](boson_multimodal/audio_processing/LICENSE) in that directory for complete attribution and licensing information.
__pycache__/predict.cpython-311.pyc ADDED
Binary file (5.98 kB).
 
boson_multimodal/.DS_Store ADDED
Binary file (6.15 kB).
 
boson_multimodal/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model.higgs_audio import HiggsAudioConfig, HiggsAudioModel
boson_multimodal/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (246 Bytes).
 
boson_multimodal/__pycache__/constants.cpython-311.pyc ADDED
Binary file (248 Bytes).
 
boson_multimodal/__pycache__/data_types.cpython-311.pyc ADDED
Binary file (2.35 kB).
 
boson_multimodal/audio_processing/LICENSE ADDED
@@ -0,0 +1,51 @@
1
+ Third-Party License Attribution for Audio Processing Module
2
+ ===========================================================
3
+
4
+ This directory contains code derived from multiple open-source projects.
5
+ The following sections detail the licenses and attributions for third-party code.
6
+
7
+ ## XCodec Repository
8
+ The code in this directory is derived from:
9
+ https://github.com/zhenye234/xcodec
10
+
11
+ ## Individual File Attributions
12
+
13
+ ### Quantization Module (quantization/)
14
+ - Several files contain code derived from Meta Platforms, Inc. and the vector-quantize-pytorch repository
15
+ - Individual files contain their own license headers where applicable
16
+ - The vector-quantize-pytorch portions are licensed under the MIT License
17
+
18
+ ## License Terms
19
+
20
+ ### MIT License (for applicable portions)
21
+ Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ of this software and associated documentation files (the "Software"), to deal
23
+ in the Software without restriction, including without limitation the rights
24
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ copies of the Software, and to permit persons to whom the Software is
26
+ furnished to do so, subject to the following conditions:
27
+
28
+ The above copyright notice and this permission notice shall be included in all
29
+ copies or substantial portions of the Software.
30
+
31
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ SOFTWARE.
38
+
39
+ ## Attribution Requirements
40
+ When using this code, please ensure proper attribution to:
41
+ 1. The original xcodec repository: https://github.com/zhenye234/xcodec
42
+ 2. Any other repositories mentioned in individual file headers
43
+ 3. This derivative work and its modifications
44
+
45
+ ## Disclaimer
46
+ This directory contains modified versions of the original code. Please refer to
47
+ the original repositories for the canonical implementations and their specific
48
+ license terms.
49
+
50
+ For any questions about licensing or attribution, please check the individual
51
+ file headers and the original source repositories.
boson_multimodal/audio_processing/__pycache__/higgs_audio_tokenizer.cpython-311.pyc ADDED
Binary file (18.9 kB).
 
boson_multimodal/audio_processing/__pycache__/semantic_module.cpython-311.pyc ADDED
Binary file (12.4 kB).
 
boson_multimodal/audio_processing/descriptaudiocodec/__init__.py ADDED
File without changes
boson_multimodal/audio_processing/descriptaudiocodec/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (178 Bytes).
 
boson_multimodal/audio_processing/descriptaudiocodec/dac/model/__pycache__/base.cpython-311.pyc ADDED
Binary file (13.5 kB).
 
boson_multimodal/audio_processing/descriptaudiocodec/dac/model/__pycache__/dac.cpython-311.pyc ADDED
Binary file (17.6 kB).
 
boson_multimodal/audio_processing/descriptaudiocodec/dac/model/base.py ADDED
@@ -0,0 +1,286 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import tqdm
9
+ from audiotools import AudioSignal
10
+ from torch import nn
11
+
12
+ SUPPORTED_VERSIONS = ["1.0.0"]
13
+
14
+
15
+ @dataclass
16
+ class DACFile:
17
+ codes: torch.Tensor
18
+
19
+ # Metadata
20
+ chunk_length: int
21
+ original_length: int
22
+ input_db: float
23
+ channels: int
24
+ sample_rate: int
25
+ padding: bool
26
+ dac_version: str
27
+
28
+ def save(self, path):
29
+ artifacts = {
30
+ "codes": self.codes.numpy().astype(np.uint16),
31
+ "metadata": {
32
+ "input_db": self.input_db.numpy().astype(np.float32),
33
+ "original_length": self.original_length,
34
+ "sample_rate": self.sample_rate,
35
+ "chunk_length": self.chunk_length,
36
+ "channels": self.channels,
37
+ "padding": self.padding,
38
+ "dac_version": SUPPORTED_VERSIONS[-1],
39
+ },
40
+ }
41
+ path = Path(path).with_suffix(".dac")
42
+ with open(path, "wb") as f:
43
+ np.save(f, artifacts)
44
+ return path
45
+
46
+ @classmethod
47
+ def load(cls, path):
48
+ artifacts = np.load(path, allow_pickle=True)[()]
49
+ codes = torch.from_numpy(artifacts["codes"].astype(int))
50
+ if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
51
+ raise RuntimeError(f"Given file {path} can't be loaded with this version of descript-audio-codec.")
52
+ return cls(codes=codes, **artifacts["metadata"])
53
+
54
+
55
+ class CodecMixin:
56
+ @property
57
+ def padding(self):
58
+ if not hasattr(self, "_padding"):
59
+ self._padding = True
60
+ return self._padding
61
+
62
+ @padding.setter
63
+ def padding(self, value):
64
+ assert isinstance(value, bool)
65
+
66
+ layers = [l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))]
67
+
68
+ for layer in layers:
69
+ if value:
70
+ if hasattr(layer, "original_padding"):
71
+ layer.padding = layer.original_padding
72
+ else:
73
+ layer.original_padding = layer.padding
74
+ layer.padding = tuple(0 for _ in range(len(layer.padding)))
75
+
76
+ self._padding = value
77
+
78
+ def get_delay(self):
79
+ # Any number works here, delay is invariant to input length
80
+ l_out = self.get_output_length(0)
81
+ L = l_out
82
+
83
+ layers = []
84
+ for layer in self.modules():
85
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
86
+ layers.append(layer)
87
+
88
+ for layer in reversed(layers):
89
+ d = layer.dilation[0]
90
+ k = layer.kernel_size[0]
91
+ s = layer.stride[0]
92
+
93
+ if isinstance(layer, nn.ConvTranspose1d):
94
+ L = ((L - d * (k - 1) - 1) / s) + 1
95
+ elif isinstance(layer, nn.Conv1d):
96
+ L = (L - 1) * s + d * (k - 1) + 1
97
+
98
+ L = math.ceil(L)
99
+
100
+ l_in = L
101
+
102
+ return (l_in - l_out) // 2
103
+
104
+ def get_output_length(self, input_length):
105
+ L = input_length
106
+ # Calculate output length
107
+ for layer in self.modules():
108
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
109
+ d = layer.dilation[0]
110
+ k = layer.kernel_size[0]
111
+ s = layer.stride[0]
112
+
113
+ if isinstance(layer, nn.Conv1d):
114
+ L = ((L - d * (k - 1) - 1) / s) + 1
115
+ elif isinstance(layer, nn.ConvTranspose1d):
116
+ L = (L - 1) * s + d * (k - 1) + 1
117
+
118
+ L = math.floor(L)
119
+ return L
120
+
121
+ @torch.no_grad()
122
+ def compress(
123
+ self,
124
+ audio_path_or_signal: Union[str, Path, AudioSignal],
125
+ win_duration: float = 1.0,
126
+ verbose: bool = False,
127
+ normalize_db: float = -16,
128
+ n_quantizers: int = None,
129
+ ) -> DACFile:
130
+ """Processes an audio signal from a file or AudioSignal object into
131
+ discrete codes. This function processes the signal in short windows,
132
+ using constant GPU memory.
133
+
134
+ Parameters
135
+ ----------
136
+ audio_path_or_signal : Union[str, Path, AudioSignal]
137
+ audio signal to compress into discrete codes
138
+ win_duration : float, optional
139
+ window duration in seconds, by default 1.0
140
+ verbose : bool, optional
141
+ by default False
142
+ normalize_db : float, optional
143
+ normalize db, by default -16
144
+
145
+ Returns
146
+ -------
147
+ DACFile
148
+ Object containing compressed codes and metadata
149
+ required for decompression
150
+ """
151
+ audio_signal = audio_path_or_signal
152
+ if isinstance(audio_signal, (str, Path)):
153
+ audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
154
+
155
+ self.eval()
156
+ original_padding = self.padding
157
+ original_device = audio_signal.device
158
+
159
+ audio_signal = audio_signal.clone()
160
+ original_sr = audio_signal.sample_rate
161
+
162
+ resample_fn = audio_signal.resample
163
+ loudness_fn = audio_signal.loudness
164
+
165
+ # If audio is > 10 minutes long, use the ffmpeg versions
166
+ if audio_signal.signal_duration >= 10 * 60 * 60:
167
+ resample_fn = audio_signal.ffmpeg_resample
168
+ loudness_fn = audio_signal.ffmpeg_loudness
169
+
170
+ original_length = audio_signal.signal_length
171
+ resample_fn(self.sample_rate)
172
+ input_db = loudness_fn()
173
+
174
+ if normalize_db is not None:
175
+ audio_signal.normalize(normalize_db)
176
+ audio_signal.ensure_max_of_audio()
177
+
178
+ nb, nac, nt = audio_signal.audio_data.shape
179
+ audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
180
+ win_duration = audio_signal.signal_duration if win_duration is None else win_duration
181
+
182
+ if audio_signal.signal_duration <= win_duration:
183
+ # Unchunked compression (used if signal length < win duration)
184
+ self.padding = True
185
+ n_samples = nt
186
+ hop = nt
187
+ else:
188
+ # Chunked inference
189
+ self.padding = False
190
+ # Zero-pad signal on either side by the delay
191
+ audio_signal.zero_pad(self.delay, self.delay)
192
+ n_samples = int(win_duration * self.sample_rate)
193
+ # Round n_samples to nearest hop length multiple
194
+ n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
195
+ hop = self.get_output_length(n_samples)
196
+
197
+ codes = []
198
+ range_fn = range if not verbose else tqdm.trange
199
+
200
+ for i in range_fn(0, nt, hop):
201
+ x = audio_signal[..., i : i + n_samples]
202
+ x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
203
+
204
+ audio_data = x.audio_data.to(self.device)
205
+ audio_data = self.preprocess(audio_data, self.sample_rate)
206
+ _, c, _, _, _ = self.encode(audio_data, n_quantizers)
207
+ codes.append(c.to(original_device))
208
+ chunk_length = c.shape[-1]
209
+
210
+ codes = torch.cat(codes, dim=-1)
211
+
212
+ dac_file = DACFile(
213
+ codes=codes,
214
+ chunk_length=chunk_length,
215
+ original_length=original_length,
216
+ input_db=input_db,
217
+ channels=nac,
218
+ sample_rate=original_sr,
219
+ padding=self.padding,
220
+ dac_version=SUPPORTED_VERSIONS[-1],
221
+ )
222
+
223
+ if n_quantizers is not None:
224
+ codes = codes[:, :n_quantizers, :]
225
+
226
+ self.padding = original_padding
227
+ return dac_file
228
+
229
+ @torch.no_grad()
230
+ def decompress(
231
+ self,
232
+ obj: Union[str, Path, DACFile],
233
+ verbose: bool = False,
234
+ ) -> AudioSignal:
235
+ """Reconstruct audio from a given .dac file
236
+
237
+ Parameters
238
+ ----------
239
+ obj : Union[str, Path, DACFile]
240
+ .dac file location or corresponding DACFile object.
241
+ verbose : bool, optional
242
+ Prints progress if True, by default False
243
+
244
+ Returns
245
+ -------
246
+ AudioSignal
247
+ Object with the reconstructed audio
248
+ """
249
+ self.eval()
250
+ if isinstance(obj, (str, Path)):
251
+ obj = DACFile.load(obj)
252
+
253
+ original_padding = self.padding
254
+ self.padding = obj.padding
255
+
256
+ range_fn = range if not verbose else tqdm.trange
257
+ codes = obj.codes
258
+ original_device = codes.device
259
+ chunk_length = obj.chunk_length
260
+ recons = []
261
+
262
+ for i in range_fn(0, codes.shape[-1], chunk_length):
263
+ c = codes[..., i : i + chunk_length].to(self.device)
264
+ z = self.quantizer.from_codes(c)[0]
265
+ r = self.decode(z)
266
+ recons.append(r.to(original_device))
267
+
268
+ recons = torch.cat(recons, dim=-1)
269
+ recons = AudioSignal(recons, self.sample_rate)
270
+
271
+ resample_fn = recons.resample
272
+ loudness_fn = recons.loudness
273
+
274
+ # If audio is > 10 minutes long, use the ffmpeg versions
275
+ if recons.signal_duration >= 10 * 60 * 60:
276
+ resample_fn = recons.ffmpeg_resample
277
+ loudness_fn = recons.ffmpeg_loudness
278
+
279
+ recons.normalize(obj.input_db)
280
+ resample_fn(obj.sample_rate)
281
+ recons = recons[..., : obj.original_length]
282
+ loudness_fn()
283
+ recons.audio_data = recons.audio_data.reshape(-1, obj.channels, obj.original_length)
284
+
285
+ self.padding = original_padding
286
+ return recons
boson_multimodal/audio_processing/descriptaudiocodec/dac/model/dac.py ADDED
@@ -0,0 +1,365 @@
1
+ import math
2
+ from typing import List
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from audiotools.ml import BaseModel
9
+ from torch import nn
10
+
11
+ from .base import CodecMixin
12
+ from dac.nn.layers import Snake1d
13
+ from dac.nn.layers import WNConv1d
14
+ from dac.nn.layers import WNConvTranspose1d
15
+ from dac.nn.quantize import ResidualVectorQuantize
16
+
17
+
18
+ def init_weights(m):
19
+ if isinstance(m, nn.Conv1d):
20
+ nn.init.trunc_normal_(m.weight, std=0.02)
21
+ nn.init.constant_(m.bias, 0)
22
+
23
+
24
+ class ResidualUnit(nn.Module):
25
+ def __init__(self, dim: int = 16, dilation: int = 1):
26
+ super().__init__()
27
+ pad = ((7 - 1) * dilation) // 2
28
+ self.block = nn.Sequential(
29
+ Snake1d(dim),
30
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
31
+ Snake1d(dim),
32
+ WNConv1d(dim, dim, kernel_size=1),
33
+ )
34
+
35
+ def forward(self, x):
36
+ y = self.block(x)
37
+ pad = (x.shape[-1] - y.shape[-1]) // 2
38
+ if pad > 0:
39
+ x = x[..., pad:-pad]
40
+ return x + y
41
+
42
+
43
+ class EncoderBlock(nn.Module):
44
+ def __init__(self, dim: int = 16, stride: int = 1):
45
+ super().__init__()
46
+ self.block = nn.Sequential(
47
+ ResidualUnit(dim // 2, dilation=1),
48
+ ResidualUnit(dim // 2, dilation=3),
49
+ ResidualUnit(dim // 2, dilation=9),
50
+ Snake1d(dim // 2),
51
+ WNConv1d(
52
+ dim // 2,
53
+ dim,
54
+ kernel_size=2 * stride,
55
+ stride=stride,
56
+ padding=math.ceil(stride / 2),
57
+ ),
58
+ )
59
+
60
+ def forward(self, x):
61
+ return self.block(x)
62
+
63
+
64
+ class Encoder(nn.Module):
65
+ def __init__(
66
+ self,
67
+ d_model: int = 64,
68
+ strides: list = [2, 4, 8, 8],
69
+ d_latent: int = 256,
70
+ ):
71
+ super().__init__()
72
+ # Create first convolution
73
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
74
+
75
+ # Create EncoderBlocks that double channels as they downsample by `stride`
76
+ for stride in strides:
77
+ d_model *= 2
78
+ self.block += [EncoderBlock(d_model, stride=stride)]
79
+
80
+ # Create last convolution
81
+ self.block += [
82
+ Snake1d(d_model),
83
+ WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
84
+ ]
85
+
86
+ # Wrap block into nn.Sequential
87
+ self.block = nn.Sequential(*self.block)
88
+ self.enc_dim = d_model
89
+
90
+ def forward(self, x):
91
+ return self.block(x)
92
+
93
+
94
+ class DecoderBlock(nn.Module):
95
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, out_pad=0):
96
+ super().__init__()
97
+ self.block = nn.Sequential(
98
+ Snake1d(input_dim),
99
+ WNConvTranspose1d(
100
+ input_dim,
101
+ output_dim,
102
+ kernel_size=2 * stride,
103
+ stride=stride,
104
+ padding=math.ceil(stride / 2),
105
+ output_padding=stride % 2, # out_pad,
106
+ ),
107
+ ResidualUnit(output_dim, dilation=1),
108
+ ResidualUnit(output_dim, dilation=3),
109
+ ResidualUnit(output_dim, dilation=9),
110
+ )
111
+
112
+ def forward(self, x):
113
+ return self.block(x)
114
+
115
+
116
+ class Decoder(nn.Module):
117
+ def __init__(
118
+ self,
119
+ input_channel,
120
+ channels,
121
+ rates,
122
+ d_out: int = 1,
123
+ ):
124
+ super().__init__()
125
+
126
+ # Add first conv layer
127
+ layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
128
+
129
+ # Add upsampling + MRF blocks
130
+ for i, stride in enumerate(rates):
131
+ input_dim = channels // 2**i
132
+ output_dim = channels // 2 ** (i + 1)
133
+ if i == 1:
134
+ out_pad = 1
135
+ else:
136
+ out_pad = 0
137
+ layers += [DecoderBlock(input_dim, output_dim, stride, out_pad)]
138
+
139
+ # Add final conv layer
140
+ layers += [
141
+ Snake1d(output_dim),
142
+ WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
143
+ # nn.Tanh(),
144
+ ]
145
+
146
+ self.model = nn.Sequential(*layers)
147
+
148
+ def forward(self, x):
149
+ return self.model(x)
150
+
151
+
152
+ class DAC(BaseModel, CodecMixin):
153
+ def __init__(
154
+ self,
155
+ encoder_dim: int = 64,
156
+ encoder_rates: List[int] = [2, 4, 8, 8],
157
+ latent_dim: int = None,
158
+ decoder_dim: int = 1536,
159
+ decoder_rates: List[int] = [8, 8, 4, 2],
160
+ n_codebooks: int = 9,
161
+ codebook_size: int = 1024,
162
+ codebook_dim: Union[int, list] = 8,
163
+ quantizer_dropout: bool = False,
164
+ sample_rate: int = 44100,
165
+ ):
166
+ super().__init__()
167
+
168
+ self.encoder_dim = encoder_dim
169
+ self.encoder_rates = encoder_rates
170
+ self.decoder_dim = decoder_dim
171
+ self.decoder_rates = decoder_rates
172
+ self.sample_rate = sample_rate
173
+
174
+ if latent_dim is None:
175
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
176
+
177
+ self.latent_dim = latent_dim
178
+
179
+ self.hop_length = np.prod(encoder_rates)
180
+ self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
181
+
182
+ self.n_codebooks = n_codebooks
183
+ self.codebook_size = codebook_size
184
+ self.codebook_dim = codebook_dim
185
+ self.quantizer = ResidualVectorQuantize(
186
+ input_dim=latent_dim,
187
+ n_codebooks=n_codebooks,
188
+ codebook_size=codebook_size,
189
+ codebook_dim=codebook_dim,
190
+ quantizer_dropout=quantizer_dropout,
191
+ )
192
+
193
+ self.decoder = Decoder(
194
+ latent_dim,
195
+ decoder_dim,
196
+ decoder_rates,
197
+ )
198
+ self.sample_rate = sample_rate
199
+ self.apply(init_weights)
200
+
201
+ self.delay = self.get_delay()
202
+
203
+ def preprocess(self, audio_data, sample_rate):
204
+ if sample_rate is None:
205
+ sample_rate = self.sample_rate
206
+ assert sample_rate == self.sample_rate
207
+
208
+ length = audio_data.shape[-1]
209
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
210
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
211
+
212
+ return audio_data
213
+
214
+ def encode(
215
+ self,
216
+ audio_data: torch.Tensor,
217
+ n_quantizers: int = None,
218
+ ):
219
+ """Encode given audio data and return quantized latent codes
220
+
221
+ Parameters
222
+ ----------
223
+ audio_data : Tensor[B x 1 x T]
224
+ Audio data to encode
225
+ n_quantizers : int, optional
226
+ Number of quantizers to use, by default None
227
+ If None, all quantizers are used.
228
+
229
+ Returns
230
+ -------
231
+ dict
232
+ A dictionary with the following keys:
233
+ "z" : Tensor[B x D x T]
234
+ Quantized continuous representation of input
235
+ "codes" : Tensor[B x N x T]
236
+ Codebook indices for each codebook
237
+ (quantized discrete representation of input)
238
+ "latents" : Tensor[B x N*D x T]
239
+ Projected latents (continuous representation of input before quantization)
240
+ "vq/commitment_loss" : Tensor[1]
241
+ Commitment loss to train encoder to predict vectors closer to codebook
242
+ entries
243
+ "vq/codebook_loss" : Tensor[1]
244
+ Codebook loss to update the codebook
245
+ "length" : int
246
+ Number of samples in input audio
247
+ """
248
+ z = self.encoder(audio_data)
249
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
250
+ return z, codes, latents, commitment_loss, codebook_loss
251
+
252
+ def decode(self, z: torch.Tensor):
253
+ """Decode given latent codes and return audio data
254
+
255
+ Parameters
256
+ ----------
257
+ z : Tensor[B x D x T]
258
+ Quantized continuous representation of input
259
+ length : int, optional
260
+ Number of samples in output audio, by default None
261
+
262
+ Returns
263
+ -------
264
+ dict
265
+ A dictionary with the following keys:
266
+ "audio" : Tensor[B x 1 x length]
267
+ Decoded audio data.
268
+ """
269
+ return self.decoder(z)
270
+
271
+ def forward(
272
+ self,
273
+ audio_data: torch.Tensor,
274
+ sample_rate: int = None,
275
+ n_quantizers: int = None,
276
+ ):
277
+ """Model forward pass
278
+
279
+ Parameters
280
+ ----------
281
+ audio_data : Tensor[B x 1 x T]
282
+ Audio data to encode
283
+ sample_rate : int, optional
284
+ Sample rate of audio data in Hz, by default None
285
+ If None, defaults to `self.sample_rate`
286
+ n_quantizers : int, optional
287
+ Number of quantizers to use, by default None.
288
+ If None, all quantizers are used.
289
+
290
+ Returns
291
+ -------
292
+ dict
293
+ A dictionary with the following keys:
294
+ "z" : Tensor[B x D x T]
295
+ Quantized continuous representation of input
296
+ "codes" : Tensor[B x N x T]
297
+ Codebook indices for each codebook
298
+ (quantized discrete representation of input)
299
+ "latents" : Tensor[B x N*D x T]
300
+ Projected latents (continuous representation of input before quantization)
301
+ "vq/commitment_loss" : Tensor[1]
302
+ Commitment loss to train encoder to predict vectors closer to codebook
303
+ entries
304
+ "vq/codebook_loss" : Tensor[1]
305
+ Codebook loss to update the codebook
306
+ "length" : int
307
+ Number of samples in input audio
308
+ "audio" : Tensor[B x 1 x length]
309
+ Decoded audio data.
310
+ """
311
+ length = audio_data.shape[-1]
312
+ audio_data = self.preprocess(audio_data, sample_rate)
313
+ z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
314
+
315
+ x = self.decode(z)
316
+ return {
317
+ "audio": x[..., :length],
318
+ "z": z,
319
+ "codes": codes,
320
+ "latents": latents,
321
+ "vq/commitment_loss": commitment_loss,
322
+ "vq/codebook_loss": codebook_loss,
323
+ }
324
+
325
+
326
+ if __name__ == "__main__":
327
+ import numpy as np
328
+ from functools import partial
329
+
330
+ model = DAC().to("cpu")
331
+
332
+ for n, m in model.named_modules():
333
+ o = m.extra_repr()
334
+ p = sum([np.prod(p.size()) for p in m.parameters()])
335
+ fn = lambda o, p: o + f" {p / 1e6:<.3f}M params."
336
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
337
+ print(model)
338
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
339
+
340
+ length = 88200 * 2
341
+ x = torch.randn(1, 1, length).to(model.device)
342
+ x.requires_grad_(True)
343
+ x.retain_grad()
344
+
345
+ # Make a forward pass
346
+ out = model(x)["audio"]
347
+ print("Input shape:", x.shape)
348
+ print("Output shape:", out.shape)
349
+
350
+ # Create gradient variable
351
+ grad = torch.zeros_like(out)
352
+ grad[:, :, grad.shape[-1] // 2] = 1
353
+
354
+ # Make a backward pass
355
+ out.backward(grad)
356
+
357
+ # Check non-zero values
358
+ gradmap = x.grad.squeeze(0)
359
+ gradmap = (gradmap != 0).sum(0) # sum across features
360
+ rf = (gradmap != 0).sum()
361
+
362
+ print(f"Receptive field: {rf.item()}")
363
+
364
+ x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
365
+ model.decompress(model.compress(x, verbose=True), verbose=True)
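For orientation, a minimal round-trip sketch for the DAC model above. It assumes the upstream default constructor arguments (44.1 kHz audio) and that `DAC` is importable from this module; the shapes in the comments follow the docstrings rather than a verified run.

# Hedged sketch: encode/decode round trip with the vendored DAC model (assumed defaults).
import torch

model = DAC().eval()                              # default config, assumed 44.1 kHz
wav = torch.randn(1, 1, 44100)                    # [B x 1 x T] mono waveform
wav = model.preprocess(wav, model.sample_rate)    # right-pad T to a multiple of the hop length
with torch.no_grad():
    z, codes, latents, commit_loss, cb_loss = model.encode(wav)
    recon = model.decode(z)                       # [B x 1 x T_padded] reconstructed audio
print(codes.shape)                                # [1, n_codebooks, T_padded // hop_length]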
boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/layers.py ADDED
@@ -0,0 +1,33 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from torch.nn.utils import weight_norm
7
+
8
+
9
+ def WNConv1d(*args, **kwargs):
10
+ return weight_norm(nn.Conv1d(*args, **kwargs))
11
+
12
+
13
+ def WNConvTranspose1d(*args, **kwargs):
14
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15
+
16
+
17
+ # Scripting this brings model speed up 1.4x
18
+ @torch.jit.script
19
+ def snake(x, alpha):
20
+ shape = x.shape
21
+ x = x.reshape(shape[0], shape[1], -1)
22
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
23
+ x = x.reshape(shape)
24
+ return x
25
+
26
+
27
+ class Snake1d(nn.Module):
28
+ def __init__(self, channels):
29
+ super().__init__()
30
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
31
+
32
+ def forward(self, x):
33
+ return snake(x, self.alpha)
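As a quick reference, a tiny sketch of the Snake1d module above; the channel count is illustrative.

# Hedged sketch: Snake1d preserves the input shape, applying x + sin^2(alpha * x) / alpha per channel.
import torch

act = Snake1d(channels=16)      # one learnable alpha per channel
x = torch.randn(2, 16, 100)     # [B x C x T]
y = act(x)
print(y.shape)                  # torch.Size([2, 16, 100])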
boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/quantize.py ADDED
@@ -0,0 +1,251 @@
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+ from dac.nn.layers import WNConv1d
11
+
12
+
13
+ class VectorQuantize(nn.Module):
14
+ """
15
+ Implementation of VQ similar to Karpathy's repo:
16
+ https://github.com/karpathy/deep-vector-quantization
17
+ Additionally uses following tricks from Improved VQGAN
18
+ (https://arxiv.org/pdf/2110.04627.pdf):
19
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
20
+ for improved codebook usage
21
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
22
+ improves training stability
23
+ """
24
+
25
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
26
+ super().__init__()
27
+ self.codebook_size = codebook_size
28
+ self.codebook_dim = codebook_dim
29
+
30
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
31
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
32
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
33
+
34
+ def forward(self, z):
35
+ """Quantized the input tensor using a fixed codebook and returns
36
+ the corresponding codebook vectors
37
+
38
+ Parameters
39
+ ----------
40
+ z : Tensor[B x D x T]
41
+
42
+ Returns
43
+ -------
44
+ Tensor[B x D x T]
45
+ Quantized continuous representation of input
46
+ Tensor[1]
47
+ Commitment loss to train encoder to predict vectors closer to codebook
48
+ entries
49
+ Tensor[1]
50
+ Codebook loss to update the codebook
51
+ Tensor[B x T]
52
+ Codebook indices (quantized discrete representation of input)
53
+ Tensor[B x D x T]
54
+ Projected latents (continuous representation of input before quantization)
55
+ """
56
+
57
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
58
+ z_e = self.in_proj(z) # z_e : (B x D x T)
59
+ z_q, indices = self.decode_latents(z_e)
60
+
61
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
62
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
63
+
64
+ z_q = z_e + (z_q - z_e).detach() # noop in forward pass, straight-through gradient estimator in backward pass
65
+
66
+ z_q = self.out_proj(z_q)
67
+
68
+ return z_q, commitment_loss, codebook_loss, indices, z_e
69
+
70
+ def embed_code(self, embed_id):
71
+ return F.embedding(embed_id, self.codebook.weight)
72
+
73
+ def decode_code(self, embed_id):
74
+ return self.embed_code(embed_id).transpose(1, 2)
75
+
76
+ def decode_latents(self, latents):
77
+ encodings = rearrange(latents, "b d t -> (b t) d")
78
+ codebook = self.codebook.weight # codebook: (N x D)
79
+
80
+ # L2 normalize encodings and codebook (ViT-VQGAN)
81
+ encodings = F.normalize(encodings)
82
+ codebook = F.normalize(codebook)
83
+
84
+ # Compute euclidean distance with codebook
85
+ dist = (
86
+ encodings.pow(2).sum(1, keepdim=True)
87
+ - 2 * encodings @ codebook.t()
88
+ + codebook.pow(2).sum(1, keepdim=True).t()
89
+ )
90
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
91
+ z_q = self.decode_code(indices)
92
+ return z_q, indices
93
+
94
+
95
+ class ResidualVectorQuantize(nn.Module):
96
+ """
97
+ Introduced in SoundStream: An end2end neural audio codec
98
+ https://arxiv.org/abs/2107.03312
99
+ """
100
+
101
+ def __init__(
102
+ self,
103
+ input_dim: int = 512,
104
+ n_codebooks: int = 9,
105
+ codebook_size: int = 1024,
106
+ codebook_dim: Union[int, list] = 8,
107
+ quantizer_dropout: float = 0.0,
108
+ ):
109
+ super().__init__()
110
+ if isinstance(codebook_dim, int):
111
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
112
+
113
+ self.n_codebooks = n_codebooks
114
+ self.codebook_dim = codebook_dim
115
+ self.codebook_size = codebook_size
116
+
117
+ self.quantizers = nn.ModuleList(
118
+ [VectorQuantize(input_dim, codebook_size, codebook_dim[i]) for i in range(n_codebooks)]
119
+ )
120
+ self.quantizer_dropout = quantizer_dropout
121
+
122
+ def forward(self, z, n_quantizers: int = None):
123
+ """Quantized the input tensor using a fixed set of `n` codebooks and returns
124
+ the corresponding codebook vectors
125
+ Parameters
126
+ ----------
127
+ z : Tensor[B x D x T]
128
+ n_quantizers : int, optional
129
+ No. of quantizers to use
130
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
131
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
132
+ when in training mode, and a random number of quantizers is used.
133
+ Returns
134
+ -------
135
+ tuple
136
+ A tuple with the following elements:
137
+
138
+ "z" : Tensor[B x D x T]
139
+ Quantized continuous representation of input
140
+ "codes" : Tensor[B x N x T]
141
+ Codebook indices for each codebook
142
+ (quantized discrete representation of input)
143
+ "latents" : Tensor[B x N*D x T]
144
+ Projected latents (continuous representation of input before quantization)
145
+ "vq/commitment_loss" : Tensor[1]
146
+ Commitment loss to train encoder to predict vectors closer to codebook
147
+ entries
148
+ "vq/codebook_loss" : Tensor[1]
149
+ Codebook loss to update the codebook
150
+ """
151
+ z_q = 0
152
+ residual = z
153
+ commitment_loss = 0
154
+ codebook_loss = 0
155
+
156
+ codebook_indices = []
157
+ latents = []
158
+
159
+ if n_quantizers is None:
160
+ n_quantizers = self.n_codebooks
161
+ if self.training:
162
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
163
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
164
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
165
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
166
+ n_quantizers = n_quantizers.to(z.device)
167
+
168
+ for i, quantizer in enumerate(self.quantizers):
169
+ if self.training is False and i >= n_quantizers:
170
+ break
171
+
172
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual)
173
+
174
+ # Create mask to apply quantizer dropout
175
+ mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
176
+ z_q = z_q + z_q_i * mask[:, None, None]
177
+ residual = residual - z_q_i
178
+
179
+ # Sum losses
180
+ commitment_loss += (commitment_loss_i * mask).mean()
181
+ codebook_loss += (codebook_loss_i * mask).mean()
182
+
183
+ codebook_indices.append(indices_i)
184
+ latents.append(z_e_i)
185
+
186
+ codes = torch.stack(codebook_indices, dim=1)
187
+ latents = torch.cat(latents, dim=1)
188
+
189
+ return z_q, codes, latents, commitment_loss, codebook_loss
190
+
191
+ def from_codes(self, codes: torch.Tensor):
192
+ """Given the quantized codes, reconstruct the continuous representation
193
+ Parameters
194
+ ----------
195
+ codes : Tensor[B x N x T]
196
+ Quantized discrete representation of input
197
+ Returns
198
+ -------
199
+ Tensor[B x D x T]
200
+ Quantized continuous representation of input
201
+ """
202
+ z_q = 0.0
203
+ z_p = []
204
+ n_codebooks = codes.shape[1]
205
+ for i in range(n_codebooks):
206
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
207
+ z_p.append(z_p_i)
208
+
209
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
210
+ z_q = z_q + z_q_i
211
+ return z_q, torch.cat(z_p, dim=1), codes
212
+
213
+ def from_latents(self, latents: torch.Tensor):
214
+ """Given the unquantized latents, reconstruct the
215
+ continuous representation after quantization.
216
+
217
+ Parameters
218
+ ----------
219
+ latents : Tensor[B x N x T]
220
+ Continuous representation of input after projection
221
+
222
+ Returns
223
+ -------
224
+ Tensor[B x D x T]
225
+ Quantized representation of full-projected space
226
+ Tensor[B x D x T]
227
+ Quantized representation of latent space
228
+ """
229
+ z_q = 0
230
+ z_p = []
231
+ codes = []
232
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
233
+
234
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
235
+ for i in range(n_codebooks):
236
+ j, k = dims[i], dims[i + 1]
237
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
238
+ z_p.append(z_p_i)
239
+ codes.append(codes_i)
240
+
241
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
242
+ z_q = z_q + z_q_i
243
+
244
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
245
+
246
+
247
+ if __name__ == "__main__":
248
+ rvq = ResidualVectorQuantize(quantizer_dropout=True)
249
+ x = torch.randn(16, 512, 80)
250
+ z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
250
+ print(latents.shape)
boson_multimodal/audio_processing/higgs_audio_tokenizer.py ADDED
@@ -0,0 +1,327 @@
1
+ # Based on code from: https://github.com/zhenye234/xcodec
2
+ # Licensed under MIT License
3
+ # Modifications by BosonAI
4
+
5
+ import math
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from typing import Optional, Union, Sequence
11
+ import numpy as np
12
+ from transformers import AutoModel
13
+ import torchaudio
14
+ import json
15
+ import librosa
16
+ from huggingface_hub import snapshot_download
17
+
18
+ from vector_quantize_pytorch import ResidualFSQ
19
+ from .descriptaudiocodec.dac.model import dac as dac2
20
+ from .quantization.vq import ResidualVectorQuantizer
21
+ from .semantic_module import Encoder, Decoder
22
+
23
+
24
+ class EncodedResult:
25
+ def __init__(self, audio_codes):
26
+ self.audio_codes = audio_codes
27
+
28
+
29
+ class HiggsAudioFeatureExtractor(nn.Module):
30
+ def __init__(self, sampling_rate=16000):
31
+ super().__init__()
32
+ self.sampling_rate = sampling_rate
33
+
34
+ def forward(self, raw_audio, sampling_rate=16000, return_tensors="pt"):
35
+ # Convert from librosa to torch
36
+ audio_signal = torch.tensor(raw_audio)
37
+ audio_signal = audio_signal.unsqueeze(0)
38
+ if len(audio_signal.shape) < 3:
39
+ audio_signal = audio_signal.unsqueeze(0)
40
+ return {"input_values": audio_signal}
41
+
42
+
43
+ class HiggsAudioTokenizer(nn.Module):
44
+ def __init__(
45
+ self,
46
+ n_filters: int = 32,
47
+ D: int = 128,
48
+ target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6],
49
+ ratios: Sequence[int] = [8, 5, 4, 2], # downsampling by 320
50
+ sample_rate: int = 16000,
51
+ bins: int = 1024,
52
+ n_q: int = 8,
53
+ codebook_dim: Optional[int] = None,
54
+ normalize: bool = False,
55
+ causal: bool = False,
56
+ semantic_techer: str = "hubert_base_general",
57
+ last_layer_semantic: bool = True,
58
+ merge_mode: str = "concat",
59
+ downsample_mode: str = "step_down",
60
+ semantic_mode: str = "classic",
61
+ vq_scale: int = 1,
62
+ semantic_sample_rate: int = None,
63
+ device: str = "cuda",
64
+ ):
65
+ super().__init__()
66
+ self.hop_length = np.prod(ratios)
67
+ self.semantic_techer = semantic_techer
68
+
69
+ self.frame_rate = math.ceil(sample_rate / np.prod(ratios)) # 50 Hz
70
+
71
+ self.target_bandwidths = target_bandwidths
72
+ self.n_q = n_q
73
+ self.sample_rate = sample_rate
74
+ self.encoder = dac2.Encoder(64, ratios, D)
75
+
76
+ self.decoder_2 = dac2.Decoder(D, 1024, ratios)
77
+ self.last_layer_semantic = last_layer_semantic
78
+ self.device = device
79
+ if semantic_techer == "hubert_base":
80
+ self.semantic_model = AutoModel.from_pretrained("facebook/hubert-base-ls960")
81
+ self.semantic_sample_rate = 16000
82
+ self.semantic_dim = 768
83
+ self.encoder_semantic_dim = 768
84
+
85
+ elif semantic_techer == "wavlm_base_plus":
86
+ self.semantic_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus")
87
+ self.semantic_sample_rate = 16000
88
+ self.semantic_dim = 768
89
+ self.encoder_semantic_dim = 768
90
+
91
+ elif semantic_techer == "hubert_base_general":
92
+ self.semantic_model = AutoModel.from_pretrained("bosonai/hubert_base", trust_remote_code=True)
93
+ self.semantic_sample_rate = 16000
94
+ self.semantic_dim = 768
95
+ self.encoder_semantic_dim = 768
96
+
97
+ # Overwrite semantic model sr to ensure semantic_downsample_factor is an integer
98
+ if semantic_sample_rate is not None:
99
+ self.semantic_sample_rate = semantic_sample_rate
100
+
101
+ self.semantic_model.eval()
102
+
103
+ # freeze the semantic model parameters (no gradient updates)
104
+ for param in self.semantic_model.parameters():
105
+ param.requires_grad = False
106
+
107
+ self.semantic_downsample_factor = int(self.hop_length / (self.sample_rate / self.semantic_sample_rate) / 320)
108
+
109
+ self.quantizer_dim = int((D + self.encoder_semantic_dim) // vq_scale)
110
+ self.encoder_semantic = Encoder(input_channels=self.semantic_dim, encode_channels=self.encoder_semantic_dim)
111
+ self.decoder_semantic = Decoder(
112
+ code_dim=self.encoder_semantic_dim, output_channels=self.semantic_dim, decode_channels=self.semantic_dim
113
+ )
114
+
115
+ # out_D=D+768
116
+ if isinstance(bins, int): # RVQ
117
+ self.quantizer = ResidualVectorQuantizer(
118
+ dimension=self.quantizer_dim, codebook_dim=codebook_dim, n_q=n_q, bins=bins
119
+ )
120
+ self.quantizer_type = "RVQ"
121
+ else: # RFSQ
122
+ self.quantizer = ResidualFSQ(dim=self.quantizer_dim, levels=bins, num_quantizers=n_q)
123
+ self.quantizer_type = "RFSQ"
124
+
125
+ self.fc_prior = nn.Linear(D + self.encoder_semantic_dim, self.quantizer_dim)
126
+ self.fc_post1 = nn.Linear(self.quantizer_dim, self.encoder_semantic_dim)
127
+ self.fc_post2 = nn.Linear(self.quantizer_dim, D)
128
+
129
+ self.downsample_mode = downsample_mode
130
+ if downsample_mode == "avg":
131
+ self.semantic_pooling = nn.AvgPool1d(
132
+ kernel_size=self.semantic_downsample_factor, stride=self.semantic_downsample_factor
133
+ )
134
+
135
+ self.audio_tokenizer_feature_extractor = HiggsAudioFeatureExtractor(sampling_rate=self.sample_rate)
136
+
137
+ @property
138
+ def tps(self):
139
+ return self.frame_rate
140
+
141
+ @property
142
+ def sampling_rate(self):
143
+ return self.sample_rate
144
+
145
+ @property
146
+ def num_codebooks(self):
147
+ return self.n_q
148
+
149
+ @property
150
+ def codebook_size(self):
151
+ return self.quantizer_dim
152
+
153
+ def get_last_layer(self):
154
+ return self.decoder.layers[-1].weight
155
+
156
+ def calculate_rec_loss(self, rec, target):
157
+ target = target / target.norm(dim=-1, keepdim=True)
158
+ rec = rec / rec.norm(dim=-1, keepdim=True)
159
+ rec_loss = (1 - (target * rec).sum(-1)).mean()
160
+
161
+ return rec_loss
162
+
163
+ @torch.no_grad()
164
+ def get_regress_target(self, x):
165
+ x = torchaudio.functional.resample(x, self.sample_rate, self.semantic_sample_rate)
166
+
167
+ if (
168
+ self.semantic_techer == "hubert_base"
169
+ or self.semantic_techer == "hubert_base_general"
170
+ or self.semantic_techer == "wavlm_base_plus"
171
+ ):
172
+ x = x[:, 0, :]
173
+ x = F.pad(x, (160, 160))
174
+ target = self.semantic_model(x, output_hidden_states=True).hidden_states
175
+ target = torch.stack(target, dim=1) # .transpose(-1, -2)#.flatten(start_dim=1, end_dim=2)
176
+
177
+ # average for all layers
178
+ target = target.mean(1)
179
+ # target = target[9]
180
+ # if self.hop_length > 320:
181
+ # target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
182
+
183
+ elif self.semantic_techer == "w2v_bert2":
184
+ target = self.semantic_model(x)
185
+
186
+ elif self.semantic_techer.startswith("whisper"):
187
+ if self.last_layer_semantic:
188
+ target = self.semantic_model(x, avg_layers=False)
189
+ else:
190
+ target = self.semantic_model(x, avg_layers=True)
191
+
192
+ elif self.semantic_techer.startswith("mert_music"):
193
+ if self.last_layer_semantic:
194
+ target = self.semantic_model(x, avg_layers=False)
195
+ else:
196
+ target = self.semantic_model(x, avg_layers=True)
197
+
198
+ elif self.semantic_techer.startswith("qwen_audio_omni"):
199
+ target = self.semantic_model(x)
200
+
201
+ if self.downsample_mode == "step_down":
202
+ if self.semantic_downsample_factor > 1:
203
+ target = target[:, :: self.semantic_downsample_factor, :]
204
+
205
+ elif self.downsample_mode == "avg":
206
+ target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
207
+ return target
208
+
209
+ def forward(self, x: torch.Tensor, bw: int):
210
+ e_semantic_input = self.get_regress_target(x).detach()
211
+
212
+ e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
213
+ e_acoustic = self.encoder(x)
214
+
215
+ e = torch.cat([e_acoustic, e_semantic], dim=1)
216
+
217
+ e = self.fc_prior(e.transpose(1, 2))
218
+
219
+ if self.quantizer_type == "RVQ":
220
+ e = e.transpose(1, 2)
221
+ quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
222
+ quantized = quantized.transpose(1, 2)
223
+ else:
224
+ quantized, codes = self.quantizer(e)
225
+ commit_loss = torch.tensor(0.0)
226
+
227
+ quantized_semantic = self.fc_post1(quantized).transpose(1, 2)
228
+ quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)
229
+
230
+ o = self.decoder_2(quantized_acoustic)
231
+
232
+ o_semantic = self.decoder_semantic(quantized_semantic)
233
+ semantic_recon_loss = F.mse_loss(e_semantic_input.transpose(1, 2).detach(), o_semantic)
234
+
235
+ return o, commit_loss, semantic_recon_loss, None
236
+
237
+ def encode(self, audio_path_or_wv, sr=None, loudness_normalize=False, loudness_threshold=-23.0):
238
+ if isinstance(audio_path_or_wv, str):
239
+ wv, sr = librosa.load(audio_path_or_wv, mono=True, sr=None)
240
+ else:
241
+ wv = audio_path_or_wv
242
+ assert sr is not None
243
+ if loudness_normalize:
244
+ import pyloudnorm as pyln
245
+
246
+ meter = pyln.Meter(sr)
247
+ l = meter.integrated_loudness(wv)
248
+ wv = pyln.normalize.loudness(wv, l, loudness_threshold)
249
+ if sr != self.sampling_rate:
250
+ wv = librosa.resample(wv, orig_sr=sr, target_sr=self.sampling_rate)
251
+ if self.audio_tokenizer_feature_extractor is not None:
252
+ inputs = self.audio_tokenizer_feature_extractor(
253
+ raw_audio=wv, sampling_rate=self.audio_tokenizer_feature_extractor.sampling_rate, return_tensors="pt"
254
+ )
255
+ input_values = inputs["input_values"].to(self.device)
256
+ else:
257
+ input_values = torch.from_numpy(wv).float().unsqueeze(0)
258
+ with torch.no_grad():
259
+ encoder_outputs = self._xcodec_encode(input_values)
260
+ vq_code = encoder_outputs.audio_codes[0]
261
+ return vq_code
262
+
263
+ def _xcodec_encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> "EncodedResult":
264
+ bw = target_bw
265
+
266
+ e_semantic_input = self.get_regress_target(x).detach()
267
+
268
+ e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
269
+ e_acoustic = self.encoder(x)
270
+
271
+ if e_acoustic.shape[2] != e_semantic.shape[2]:
272
+ pad_size = 160 * self.semantic_downsample_factor
273
+ e_acoustic = self.encoder(F.pad(x[:, 0, :], (pad_size, pad_size)).unsqueeze(0))
274
+
275
+ if e_acoustic.shape[2] != e_semantic.shape[2]:
276
+ if e_acoustic.shape[2] > e_semantic.shape[2]:
277
+ e_acoustic = e_acoustic[:, :, : e_semantic.shape[2]]
278
+ else:
279
+ e_semantic = e_semantic[:, :, : e_acoustic.shape[2]]
280
+
281
+ e = torch.cat([e_acoustic, e_semantic], dim=1)
282
+
283
+ e = self.fc_prior(e.transpose(1, 2))
284
+
285
+ if self.quantizer_type == "RVQ":
286
+ e = e.transpose(1, 2)
287
+ quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
288
+ codes = codes.permute(1, 0, 2)
289
+ else:
290
+ quantized, codes = self.quantizer(e)
291
+ codes = codes.permute(0, 2, 1)
292
+
293
+ # return codes
294
+ return EncodedResult(codes)
295
+
296
+ def decode(self, vq_code: torch.Tensor) -> np.ndarray:
297
+ if self.quantizer_type == "RVQ":
298
+ vq_code = vq_code.permute(1, 0, 2)
299
+ quantized = self.quantizer.decode(vq_code)
300
+ quantized = quantized.transpose(1, 2)
301
+ else:
302
+ vq_code = vq_code.permute(0, 2, 1)
303
+ quantized = self.quantizer.get_output_from_indices(vq_code)
304
+ quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)
305
+
306
+ o = self.decoder_2(quantized_acoustic)
307
+ return o.cpu().numpy()
308
+
309
+
310
+ def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"):
311
+ is_local = os.path.exists(tokenizer_name_or_path)
312
+ if not is_local:
313
+ tokenizer_path = snapshot_download(tokenizer_name_or_path)
314
+ else:
315
+ tokenizer_path = tokenizer_name_or_path
316
+ config_path = os.path.join(tokenizer_path, "config.json")
317
+ model_path = os.path.join(tokenizer_path, "model.pth")
318
+ config = json.load(open(config_path))
319
+ model = HiggsAudioTokenizer(
320
+ **config,
321
+ device=device,
322
+ )
323
+ parameter_dict = torch.load(model_path, map_location=device)
324
+ model.load_state_dict(parameter_dict, strict=False)
325
+ model.to(device)
326
+ model.eval()
327
+ return model
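For context, a hedged end-to-end sketch of how this tokenizer is typically driven. The checkpoint id and audio file name below are placeholders for illustration, a CUDA device is assumed, and loading will pull the semantic teacher model (e.g. `bosonai/hubert_base`) from the Hugging Face Hub.

# Hedged sketch: waveform -> discrete codes -> waveform with the tokenizer above.
import torch

tokenizer = load_higgs_audio_tokenizer("bosonai/higgs-audio-v2-tokenizer", device="cuda")
codes = tokenizer.encode("speech.wav")        # [n_q x T_frames] codes, tokenizer.tps frames per second
print(codes.shape, tokenizer.tps, tokenizer.sampling_rate)
wav = tokenizer.decode(codes.unsqueeze(0))    # add the batch dim back; returns a [B x 1 x T] numpy array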
boson_multimodal/audio_processing/quantization/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # flake8: noqa
8
+ from .vq import QuantizedResult, ResidualVectorQuantizer
boson_multimodal/audio_processing/quantization/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (270 Bytes).
 
boson_multimodal/audio_processing/quantization/__pycache__/core_vq_lsx_version.cpython-311.pyc ADDED
Binary file (21.6 kB).
 
boson_multimodal/audio_processing/quantization/__pycache__/ddp_utils.cpython-311.pyc ADDED
Binary file (10.4 kB).
 
boson_multimodal/audio_processing/quantization/__pycache__/distrib.cpython-311.pyc ADDED
Binary file (6.9 kB).
 
boson_multimodal/audio_processing/quantization/__pycache__/vq.cpython-311.pyc ADDED
Binary file (6.58 kB).
 
boson_multimodal/audio_processing/quantization/ac.py ADDED
@@ -0,0 +1,292 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Arithmetic coder."""
8
+
9
+ import io
10
+ import math
11
+ import random
12
+ import typing as tp
13
+ import torch
14
+
15
+ from ..binary import BitPacker, BitUnpacker
16
+
17
+
18
+ def build_stable_quantized_cdf(
19
+ pdf: torch.Tensor, total_range_bits: int, roundoff: float = 1e-8, min_range: int = 2, check: bool = True
20
+ ) -> torch.Tensor:
21
+ """Turn the given PDF into a quantized CDF that splits
22
+ [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
23
+ to the PDF.
24
+
25
+ Args:
26
+ pdf (torch.Tensor): probability distribution, shape should be `[N]`.
27
+ total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
28
+ during the coding process is `[0, 2 ** total_range_bits - 1]`.
29
+ roundoff (float): will round the pdf up to that level to remove difference coming
30
+ from e.g. evaluating the Language Model on different architectures.
31
+ min_range (int): minimum range width. Should always be at least 2 for numerical
32
+ stability. Use this to avoid pathological behavior is a value
33
+ that is expected to be rare actually happens in real life.
34
+ check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
35
+ """
36
+ pdf = pdf.detach()
37
+ if roundoff:
38
+ pdf = (pdf / roundoff).floor() * roundoff
39
+ # interpolate with uniform distribution to achieve desired minimum probability.
40
+ total_range = 2**total_range_bits
41
+ cardinality = len(pdf)
42
+ alpha = min_range * cardinality / total_range
43
+ assert alpha <= 1, "you must reduce min_range"
44
+ ranges = (((1 - alpha) * total_range) * pdf).floor().long()
45
+ ranges += min_range
46
+ quantized_cdf = torch.cumsum(ranges, dim=-1)
47
+ if min_range < 2:
48
+ raise ValueError("min_range must be at least 2.")
49
+ if check:
50
+ assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
51
+ if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
52
+ raise ValueError("You must increase your total_range_bits.")
53
+ return quantized_cdf
54
+
55
+
56
+ class ArithmeticCoder:
57
+ """ArithmeticCoder,
58
+ Let us take a distribution `p` over `N` symbols, and assume we have a stream
59
+ of random variables `s_t` sampled from `p`. Let us assume that we have a budget
60
+ of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
61
+ corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
62
+ sequence `(s_t)` by doing the following:
63
+
64
+ 1) Initialize the current range to` [0 ** 2 B - 1]`.
65
+ 2) For each time step t, split the current range into contiguous chunks,
66
+ one for each possible outcome, with size roughly proportional to `p`.
67
+ For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
68
+ would be `{[0, 2], [3, 3]}`.
69
+ 3) Select the chunk corresponding to `s_t`, and replace the current range with this.
70
+ 4) When done encoding all the values, just select any value remaining in the range.
71
+
72
+ You will notice that this procedure can fail: for instance if at any point in time
73
+ the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
74
+ possible outcome. Intuitively, the more likely a value is, the less the range width
75
+ will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
76
+ coding scheme, likely outcomes would take less bits, and more of them can be coded
77
+ with a fixed budget.
78
+
79
+ In practice, we do not know `B` ahead of time, but we have a way to inject new bits
80
+ when the current range decreases below a given limit (given by `total_range_bits`), without
81
+ having to redo all the computations. If we encode mostly likely values, we will seldom
82
+ need to inject new bits, but a single rare value can deplete our stock of entropy!
83
+
84
+ In this explanation, we assumed that the distribution `p` was constant. In fact, the present
85
+ code works for any sequence `(p_t)` possibly different for each timestep.
86
+ We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
87
+ the KL between the true distribution and `p_t`, the most efficient the coding will be.
88
+
89
+ Args:
90
+ fo (IO[bytes]): file-like object to which the bytes will be written to.
91
+ total_range_bits (int): the range `M` described above is `2 ** total_range_bits.
92
+ Any time the current range width fall under this limit, new bits will
93
+ be injected to rescale the initial range.
94
+ """
95
+
96
+ def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
97
+ assert total_range_bits <= 30
98
+ self.total_range_bits = total_range_bits
99
+ self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
100
+ self.low: int = 0
101
+ self.high: int = 0
102
+ self.max_bit: int = -1
103
+ self._dbg: tp.List[tp.Any] = []
104
+ self._dbg2: tp.List[tp.Any] = []
105
+
106
+ @property
107
+ def delta(self) -> int:
108
+ """Return the current range width."""
109
+ return self.high - self.low + 1
110
+
111
+ def _flush_common_prefix(self):
112
+ # If self.low and self.high start with the sames bits,
113
+ # those won't change anymore as we always just increase the range
114
+ # by powers of 2, and we can flush them out to the bit stream.
115
+ assert self.high >= self.low, (self.low, self.high)
116
+ assert self.high < 2 ** (self.max_bit + 1)
117
+ while self.max_bit >= 0:
118
+ b1 = self.low >> self.max_bit
119
+ b2 = self.high >> self.max_bit
120
+ if b1 == b2:
121
+ self.low -= b1 << self.max_bit
122
+ self.high -= b1 << self.max_bit
123
+ assert self.high >= self.low, (self.high, self.low, self.max_bit)
124
+ assert self.low >= 0
125
+ self.max_bit -= 1
126
+ self.packer.push(b1)
127
+ else:
128
+ break
129
+
130
+ def push(self, symbol: int, quantized_cdf: torch.Tensor):
131
+ """Push the given symbol on the stream, flushing out bits
132
+ if possible.
133
+
134
+ Args:
135
+ symbol (int): symbol to encode with the AC.
136
+ quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
137
+ to build this from your pdf estimate.
138
+ """
139
+ while self.delta < 2**self.total_range_bits:
140
+ self.low *= 2
141
+ self.high = self.high * 2 + 1
142
+ self.max_bit += 1
143
+
144
+ range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
145
+ range_high = quantized_cdf[symbol].item() - 1
146
+ effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
147
+ effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
148
+ assert self.low <= self.high
149
+ self.high = self.low + effective_high
150
+ self.low = self.low + effective_low
151
+ assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
152
+ self._dbg.append((self.low, self.high))
153
+ self._dbg2.append((self.low, self.high))
154
+ outs = self._flush_common_prefix()
155
+ assert self.low <= self.high
156
+ assert self.max_bit >= -1
157
+ assert self.max_bit <= 61, self.max_bit
158
+ return outs
159
+
160
+ def flush(self):
161
+ """Flush the remaining information to the stream."""
162
+ while self.max_bit >= 0:
163
+ b1 = (self.low >> self.max_bit) & 1
164
+ self.packer.push(b1)
165
+ self.max_bit -= 1
166
+ self.packer.flush()
167
+
168
+
169
+ class ArithmeticDecoder:
170
+ """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
171
+
172
+ Note that this must be called with **exactly** the same parameters and sequence
173
+ of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
174
+
175
+ If the AC encoder current range is [L, H], with `L` and `H` having the some common
176
+ prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
177
+ For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
178
+ `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained
179
+ for a specific sequence of symbols and a binary-search allows us to decode those symbols.
180
+ At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
181
+ and we will need to read new bits from the stream and repeat the process.
182
+
183
+ """
184
+
185
+ def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
186
+ self.total_range_bits = total_range_bits
187
+ self.low: int = 0
188
+ self.high: int = 0
189
+ self.current: int = 0
190
+ self.max_bit: int = -1
191
+ self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
192
+ # Following is for debugging
193
+ self._dbg: tp.List[tp.Any] = []
194
+ self._dbg2: tp.List[tp.Any] = []
195
+ self._last: tp.Any = None
196
+
197
+ @property
198
+ def delta(self) -> int:
199
+ return self.high - self.low + 1
200
+
201
+ def _flush_common_prefix(self):
202
+ # Given the current range [L, H], if both have a common prefix,
203
+ # we know we can remove it from our representation to avoid handling large numbers.
204
+ while self.max_bit >= 0:
205
+ b1 = self.low >> self.max_bit
206
+ b2 = self.high >> self.max_bit
207
+ if b1 == b2:
208
+ self.low -= b1 << self.max_bit
209
+ self.high -= b1 << self.max_bit
210
+ self.current -= b1 << self.max_bit
211
+ assert self.high >= self.low
212
+ assert self.low >= 0
213
+ self.max_bit -= 1
214
+ else:
215
+ break
216
+
217
+ def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
218
+ """Pull a symbol, reading as many bits from the stream as required.
219
+ This returns `None` when the stream has been exhausted.
220
+
221
+ Args:
222
+ quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
223
+ to build this from your pdf estimate. This must be **exatly**
224
+ the same cdf as the one used at encoding time.
225
+ """
226
+ while self.delta < 2**self.total_range_bits:
227
+ bit = self.unpacker.pull()
228
+ if bit is None:
229
+ return None
230
+ self.low *= 2
231
+ self.high = self.high * 2 + 1
232
+ self.current = self.current * 2 + bit
233
+ self.max_bit += 1
234
+
235
+ def bin_search(low_idx: int, high_idx: int):
236
+ # Binary search is not just for coding interviews :)
237
+ if high_idx < low_idx:
238
+ raise RuntimeError("Binary search failed")
239
+ mid = (low_idx + high_idx) // 2
240
+ range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
241
+ range_high = quantized_cdf[mid].item() - 1
242
+ effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
243
+ effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
244
+ low = effective_low + self.low
245
+ high = effective_high + self.low
246
+ if self.current >= low:
247
+ if self.current <= high:
248
+ return (mid, low, high, self.current)
249
+ else:
250
+ return bin_search(mid + 1, high_idx)
251
+ else:
252
+ return bin_search(low_idx, mid - 1)
253
+
254
+ self._last = (self.low, self.high, self.current, self.max_bit)
255
+ sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
256
+ self._dbg.append((self.low, self.high, self.current))
257
+ self._flush_common_prefix()
258
+ self._dbg2.append((self.low, self.high, self.current))
259
+
260
+ return sym
261
+
262
+
263
+ def test():
264
+ torch.manual_seed(1234)
265
+ random.seed(1234)
266
+ for _ in range(4):
267
+ pdfs = []
268
+ cardinality = random.randrange(4000)
269
+ steps = random.randrange(100, 500)
270
+ fo = io.BytesIO()
271
+ encoder = ArithmeticCoder(fo)
272
+ symbols = []
273
+ for step in range(steps):
274
+ pdf = torch.softmax(torch.randn(cardinality), dim=0)
275
+ pdfs.append(pdf)
276
+ q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
277
+ symbol = torch.multinomial(pdf, 1).item()
278
+ symbols.append(symbol)
279
+ encoder.push(symbol, q_cdf)
280
+ encoder.flush()
281
+
282
+ fo.seek(0)
283
+ decoder = ArithmeticDecoder(fo)
284
+ for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
285
+ q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
286
+ decoded_symbol = decoder.pull(q_cdf)
287
+ assert decoded_symbol == symbol, idx
288
+ assert decoder.pull(torch.zeros(1)) is None
289
+
290
+
291
+ if __name__ == "__main__":
292
+ test()
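Alongside `test()` above, a stripped-down hedged sketch of the coder/decoder round trip with one fixed distribution; it assumes the `..binary` bit-packing helpers imported at the top of this file are present in the package.

# Hedged sketch: arithmetic-code three symbols under a fixed 4-symbol distribution.
import io
import torch

pdf = torch.tensor([0.7, 0.2, 0.05, 0.05])
fo = io.BytesIO()
coder = ArithmeticCoder(fo)
cdf = build_stable_quantized_cdf(pdf, coder.total_range_bits)
for sym in (0, 2, 1):
    coder.push(sym, cdf)
coder.flush()

fo.seek(0)
decoder = ArithmeticDecoder(fo)
print([decoder.pull(cdf) for _ in range(3)])   # expected: [0, 2, 1]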
boson_multimodal/audio_processing/quantization/core_vq.py ADDED
@@ -0,0 +1,360 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # This implementation is inspired from
8
+ # https://github.com/lucidrains/vector-quantize-pytorch
9
+ # which is released under MIT License. Hereafter, the original license:
10
+ # MIT License
11
+ #
12
+ # Copyright (c) 2020 Phil Wang
13
+ #
14
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ # of this software and associated documentation files (the "Software"), to deal
16
+ # in the Software without restriction, including without limitation the rights
17
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ # copies of the Software, and to permit persons to whom the Software is
19
+ # furnished to do so, subject to the following conditions:
20
+ #
21
+ # The above copyright notice and this permission notice shall be included in all
22
+ # copies or substantial portions of the Software.
23
+ #
24
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
+ # SOFTWARE.
31
+
32
+ """Core vector quantization implementation."""
33
+
34
+ import typing as tp
35
+
36
+ from einops import rearrange, repeat
37
+ import torch
38
+ from torch import nn
39
+ import torch.nn.functional as F
40
+
41
+ from .distrib import broadcast_tensors, rank
42
+
43
+
44
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
45
+ return val if val is not None else d
46
+
47
+
48
+ def ema_inplace(moving_avg, new, decay: float):
49
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
50
+
51
+
52
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
53
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
54
+
55
+
56
+ def uniform_init(*shape: int):
57
+ t = torch.empty(shape)
58
+ nn.init.kaiming_uniform_(t)
59
+ return t
60
+
61
+
62
+ def sample_vectors(samples, num: int):
63
+ num_samples, device = samples.shape[0], samples.device
64
+
65
+ if num_samples >= num:
66
+ indices = torch.randperm(num_samples, device=device)[:num]
67
+ else:
68
+ indices = torch.randint(0, num_samples, (num,), device=device)
69
+
70
+ return samples[indices]
71
+
72
+
73
+ def kmeans(samples, num_clusters: int, num_iters: int = 10):
74
+ dim, dtype = samples.shape[-1], samples.dtype
75
+
76
+ means = sample_vectors(samples, num_clusters)
77
+
78
+ for _ in range(num_iters):
79
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
80
+ dists = -(diffs**2).sum(dim=-1)
81
+
82
+ buckets = dists.max(dim=-1).indices
83
+ bins = torch.bincount(buckets, minlength=num_clusters)
84
+ zero_mask = bins == 0
85
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
86
+
87
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
88
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
89
+ new_means = new_means / bins_min_clamped[..., None]
90
+
91
+ means = torch.where(zero_mask[..., None], means, new_means)
92
+
93
+ return means, bins
94
+
95
+
96
+ class EuclideanCodebook(nn.Module):
97
+ """Codebook with Euclidean distance.
98
+ Args:
99
+ dim (int): Dimension.
100
+ codebook_size (int): Codebook size.
101
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
102
+ If set to true, run the k-means algorithm on the first training batch and use
103
+ the learned centroids as initialization.
104
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
105
+ decay (float): Decay for exponential moving average over the codebooks.
106
+ epsilon (float): Epsilon value for numerical stability.
107
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
108
+ that have an exponential moving average cluster size less than the specified threshold with
109
+ randomly selected vector from the current batch.
110
+ """
111
+
112
+ def __init__(
113
+ self,
114
+ dim: int,
115
+ codebook_size: int,
116
+ kmeans_init: bool = False,
117
+ kmeans_iters: int = 10,
118
+ decay: float = 0.99,
119
+ epsilon: float = 1e-5,
120
+ threshold_ema_dead_code: int = 2,
121
+ ):
122
+ super().__init__()
123
+ self.decay = decay
124
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
125
+ embed = init_fn(codebook_size, dim)
126
+
127
+ self.codebook_size = codebook_size
128
+
129
+ self.kmeans_iters = kmeans_iters
130
+ self.epsilon = epsilon
131
+ self.threshold_ema_dead_code = threshold_ema_dead_code
132
+
133
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
134
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
135
+ self.register_buffer("embed", embed)
136
+ self.register_buffer("embed_avg", embed.clone())
137
+
138
+ @torch.jit.ignore
139
+ def init_embed_(self, data):
140
+ if self.inited:
141
+ return
142
+
143
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
144
+ self.embed.data.copy_(embed)
145
+ self.embed_avg.data.copy_(embed.clone())
146
+ self.cluster_size.data.copy_(cluster_size)
147
+ self.inited.data.copy_(torch.Tensor([True]))
148
+ # Make sure all buffers across workers are in sync after initialization
149
+ broadcast_tensors(self.buffers())
150
+
151
+ def replace_(self, samples, mask):
152
+ modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
153
+ self.embed.data.copy_(modified_codebook)
154
+
155
+ def expire_codes_(self, batch_samples):
156
+ if self.threshold_ema_dead_code == 0:
157
+ return
158
+
159
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
160
+ if not torch.any(expired_codes):
161
+ return
162
+
163
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
164
+ self.replace_(batch_samples, mask=expired_codes)
165
+ broadcast_tensors(self.buffers())
166
+
167
+ def preprocess(self, x):
168
+ x = rearrange(x, "... d -> (...) d")
169
+ return x
170
+
171
+ def quantize(self, x):
172
+ embed = self.embed.t()
173
+ dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
174
+ embed_ind = dist.max(dim=-1).indices
175
+ return embed_ind
176
+
177
+ def postprocess_emb(self, embed_ind, shape):
178
+ return embed_ind.view(*shape[:-1])
179
+
180
+ def dequantize(self, embed_ind):
181
+ quantize = F.embedding(embed_ind, self.embed) # get embedding based on index
182
+ return quantize
183
+
184
+ def encode(self, x):
185
+ shape = x.shape
186
+ # pre-process
187
+ x = self.preprocess(x)
188
+ # quantize
189
+ embed_ind = self.quantize(x) # get index based on Euclidean distance
190
+ # post-process
191
+ embed_ind = self.postprocess_emb(embed_ind, shape)
192
+ return embed_ind
193
+
194
+ def decode(self, embed_ind):
195
+ quantize = self.dequantize(embed_ind)
196
+ return quantize
197
+
198
+ def forward(self, x):
199
+ shape, dtype = x.shape, x.dtype
200
+ x = self.preprocess(x)
201
+
202
+ self.init_embed_(x)
203
+
204
+ embed_ind = self.quantize(x)
205
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
206
+ embed_ind = self.postprocess_emb(embed_ind, shape)
207
+ quantize = self.dequantize(embed_ind)
208
+
209
+ if self.training:
210
+ # We do the expiry of code at that point as buffers are in sync
211
+ # and all the workers will take the same decision.
212
+ self.expire_codes_(x)
213
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
214
+ embed_sum = x.t() @ embed_onehot
215
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
216
+ cluster_size = (
217
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
218
+ )
219
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
220
+ self.embed.data.copy_(embed_normalized)
221
+
222
+ return quantize, embed_ind
223
+
224
+
225
+ class VectorQuantization(nn.Module):
226
+ """Vector quantization implementation.
227
+ Currently supports only euclidean distance.
228
+ Args:
229
+ dim (int): Dimension
230
+ codebook_size (int): Codebook size
231
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
232
+ decay (float): Decay for exponential moving average over the codebooks.
233
+ epsilon (float): Epsilon value for numerical stability.
234
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
235
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
236
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
237
+ that have an exponential moving average cluster size less than the specified threshold with
238
+ randomly selected vector from the current batch.
239
+ commitment_weight (float): Weight for commitment loss.
240
+ """
241
+
242
+ def __init__(
243
+ self,
244
+ dim: int,
245
+ codebook_size: int,
246
+ codebook_dim: tp.Optional[int] = None,
247
+ decay: float = 0.99,
248
+ epsilon: float = 1e-5,
249
+ kmeans_init: bool = True,
250
+ kmeans_iters: int = 50,
251
+ threshold_ema_dead_code: int = 2,
252
+ commitment_weight: float = 1.0,
253
+ ):
254
+ super().__init__()
255
+ _codebook_dim: int = default(codebook_dim, dim)
256
+
257
+ requires_projection = _codebook_dim != dim
258
+ self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
259
+ self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
260
+
261
+ self.epsilon = epsilon
262
+ self.commitment_weight = commitment_weight
263
+
264
+ self._codebook = EuclideanCodebook(
265
+ dim=_codebook_dim,
266
+ codebook_size=codebook_size,
267
+ kmeans_init=kmeans_init,
268
+ kmeans_iters=kmeans_iters,
269
+ decay=decay,
270
+ epsilon=epsilon,
271
+ threshold_ema_dead_code=threshold_ema_dead_code,
272
+ )
273
+ self.codebook_size = codebook_size
274
+
275
+ @property
276
+ def codebook(self):
277
+ return self._codebook.embed
278
+
279
+ def encode(self, x):
280
+ x = rearrange(x, "b d n -> b n d")
281
+ x = self.project_in(x)
282
+ embed_in = self._codebook.encode(x)
283
+ return embed_in
284
+
285
+ def decode(self, embed_ind):
286
+ quantize = self._codebook.decode(embed_ind)
287
+ quantize = self.project_out(quantize)
288
+ quantize = rearrange(quantize, "b n d -> b d n")
289
+ return quantize
290
+
291
+ def forward(self, x):
292
+ device = x.device
293
+ x = rearrange(x, "b d n -> b n d")
294
+ x = self.project_in(x)
295
+
296
+ quantize, embed_ind = self._codebook(x)
297
+
298
+ if self.training:
299
+ quantize = x + (quantize - x).detach()
300
+
301
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
302
+
303
+ if self.training:
304
+ if self.commitment_weight > 0:
305
+ commit_loss = F.mse_loss(quantize.detach(), x)
306
+ loss = loss + commit_loss * self.commitment_weight
307
+
308
+ quantize = self.project_out(quantize)
309
+ quantize = rearrange(quantize, "b n d -> b d n")
310
+ return quantize, embed_ind, loss
311
+
312
+
313
+ class ResidualVectorQuantization(nn.Module):
314
+ """Residual vector quantization implementation.
315
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
316
+ """
317
+
318
+ def __init__(self, *, num_quantizers, **kwargs):
319
+ super().__init__()
320
+ self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
321
+
322
+ def forward(self, x, n_q: tp.Optional[int] = None):
323
+ quantized_out = 0.0
324
+ residual = x
325
+
326
+ all_losses = []
327
+ all_indices = []
328
+
329
+ n_q = n_q or len(self.layers)
330
+
331
+ for layer in self.layers[:n_q]:
332
+ quantized, indices, loss = layer(residual)
333
+ residual = residual - quantized
334
+ quantized_out = quantized_out + quantized
335
+
336
+ all_indices.append(indices)
337
+ all_losses.append(loss)
338
+
339
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
340
+ return quantized_out, out_indices, out_losses
341
+
342
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
343
+ residual = x
344
+ all_indices = []
345
+ n_q = n_q or len(self.layers)
346
+ for layer in self.layers[:n_q]:
347
+ indices = layer.encode(residual)
348
+ quantized = layer.decode(indices)
349
+ residual = residual - quantized
350
+ all_indices.append(indices)
351
+ out_indices = torch.stack(all_indices)
352
+ return out_indices
353
+
354
+ def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
355
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
356
+ for i, indices in enumerate(q_indices):
357
+ layer = self.layers[i]
358
+ quantized = layer.decode(indices)
359
+ quantized_out = quantized_out + quantized
360
+ return quantized_out
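A small hedged sketch of the EMA-based residual VQ above, run in a single (non-distributed) process on random features; the layer arguments are illustrative rather than the values used elsewhere in this repo.

# Hedged sketch: residual VQ over [B x D x T] features with 4 quantizer layers.
import torch

rvq = ResidualVectorQuantization(num_quantizers=4, dim=256, codebook_size=1024)
x = torch.randn(8, 256, 50)              # [B x D x T]
quantized, indices, losses = rvq(x)      # first forward also runs the k-means codebook init
codes = rvq.encode(x)                    # [n_q x B x T] discrete indices
x_hat = rvq.decode(codes)                # [B x D x T] reconstruction from the codes
print(codes.shape, x_hat.shape)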
boson_multimodal/audio_processing/quantization/core_vq_lsx_version.py ADDED
@@ -0,0 +1,425 @@
1
+ # Copyright (c)
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # This implementation is inspired from
6
+ # https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py and
7
+ # https://github.com/clementchadebec/benchmark_VAE/blob/dfa0dcf6c79172df5d27769c09c860c42008baaa/src/pythae/models/vq_vae/vq_vae_utils.py#L81
8
+ #
9
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
10
+ # All rights reserved.
11
+ #
12
+ # This source code is licensed under the license found in the
13
+ # LICENSE file in the root directory of this source tree.
14
+ #
15
+ # This implementation is inspired from
16
+ # https://github.com/lucidrains/vector-quantize-pytorch
17
+ # which is released under MIT License. Hereafter, the original license:
18
+ # MIT License
19
+ #
20
+ # Copyright (c) 2020 Phil Wang
21
+ #
22
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
23
+ # of this software and associated documentation files (the "Software"), to deal
24
+ # in the Software without restriction, including without limitation the rights
25
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
26
+ # copies of the Software, and to permit persons to whom the Software is
27
+ # furnished to do so, subject to the following conditions:
28
+ #
29
+ # The above copyright notice and this permission notice shall be included in all
30
+ # copies or substantial portions of the Software.
31
+ #
32
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
38
+ # SOFTWARE.
39
+
40
+ """Core vector quantization implementation."""
41
+
42
+ import typing as tp
43
+
44
+ from einops import rearrange
45
+ import torch
46
+ from torch import nn
47
+ import torch.nn.functional as F
48
+ import torch.distributed as dist
49
+
50
+ from .distrib import broadcast_tensors, is_distributed
51
+ from .ddp_utils import SyncFunction
52
+
53
+
54
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
55
+ return val if val is not None else d
56
+
57
+
58
+ def ema_inplace(moving_avg, new, decay: float):
59
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
60
+
61
+
62
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
63
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
64
+
65
+
66
+ def uniform_init(*shape: int):
67
+ t = torch.empty(shape)
68
+ nn.init.kaiming_uniform_(t)
69
+ return t
70
+
71
+
72
+ def sample_vectors(samples, num: int):
73
+ num_samples, device = samples.shape[0], samples.device
74
+
75
+ if num_samples >= num:
76
+ indices = torch.randperm(num_samples, device=device)[:num]
77
+ else:
78
+ indices = torch.randint(0, num_samples, (num,), device=device)
79
+
80
+ return samples[indices]
81
+
82
+
83
+ def kmeans(samples, num_clusters: int, num_iters: int = 10, frames_to_use: int = 10_000, batch_size: int = 64):
84
+ """
85
+ Memory-efficient K-means clustering.
86
+ Args:
87
+ samples (tensor): shape [N, D]
88
+ num_clusters (int): number of centroids.
89
+ num_iters (int): number of iterations.
90
+ frames_to_use (int): subsample size from total samples.
91
+ batch_size (int): batch size used in distance computation.
92
+ Returns:
93
+ means: [num_clusters, D]
94
+ bins: [num_clusters] (number of points per cluster)
95
+ """
96
+ N, D = samples.shape
97
+ dtype, device = samples.dtype, samples.device
98
+
99
+ if frames_to_use < N:
100
+ indices = torch.randperm(N, device=device)[:frames_to_use]
101
+ samples = samples[indices]
102
+
103
+ means = sample_vectors(samples, num_clusters)
104
+
105
+ for _ in range(num_iters):
106
+ # Store cluster assignments
107
+ all_assignments = []
108
+
109
+ for i in range(0, samples.shape[0], batch_size):
110
+ batch = samples[i : i + batch_size] # [B, D]
111
+ dists = torch.cdist(batch, means, p=2) # [B, C]
112
+ assignments = dists.argmin(dim=1) # [B]
113
+ all_assignments.append(assignments)
114
+
115
+ buckets = torch.cat(all_assignments, dim=0) # [N]
116
+ bins = torch.bincount(buckets, minlength=num_clusters)
117
+ zero_mask = bins == 0
118
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
119
+
120
+ # Compute new means
121
+ new_means = torch.zeros_like(means)
122
+ for i in range(num_clusters):
123
+ mask = buckets == i
124
+ if mask.any():
125
+ new_means[i] = samples[mask].mean(dim=0)
126
+
127
+ means = torch.where(zero_mask[:, None], means, new_means)
128
+
129
+ return means, bins
130
+
131
+
132
+ class EuclideanCodebook(nn.Module):
133
+ """Codebook with Euclidean distance.
134
+ Args:
135
+ dim (int): Dimension.
136
+ codebook_size (int): Codebook size.
137
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
138
+ If set to true, run the k-means algorithm on the first training batch and use
139
+ the learned centroids as initialization.
140
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
141
+ decay (float): Decay for exponential moving average over the codebooks.
142
+ epsilon (float): Epsilon value for numerical stability.
143
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
144
+ that have an exponential moving average cluster size less than the specified threshold with
145
+ randomly selected vector from the current batch.
146
+ """
147
+
148
+ def __init__(
149
+ self,
150
+ dim: int,
151
+ codebook_size: int,
152
+ kmeans_init: bool = False,
153
+ kmeans_iters: int = 10,
154
+ decay: float = 0.99,
155
+ epsilon: float = 1e-5,
156
+ threshold_ema_dead_code: int = 2,
157
+ ):
158
+ super().__init__()
159
+ self.decay = decay
160
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
161
+ embed = init_fn(codebook_size, dim)
162
+
163
+ self.codebook_size = codebook_size
164
+
165
+ self.kmeans_iters = kmeans_iters
166
+ self.epsilon = epsilon
167
+ self.threshold_ema_dead_code = threshold_ema_dead_code
168
+
169
+ # Flag variable to indicate whether the codebook is initialized
170
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
171
+ # Running EMA cluster size/count: N_i^t in eq. (6) of the VQ-VAE paper
172
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
173
+ # Codebook
174
+ self.register_buffer("embed", embed)
175
+ # EMA codebook: eq. (7) in vqvae paper
176
+ self.register_buffer("embed_avg", embed.clone())
177
+
178
+ @torch.jit.ignore
179
+ def init_embed_(self, data):
180
+ """Initialize codebook.
181
+ Args:
182
+ data (tensor): [B * T, D].
183
+ """
184
+ if self.inited:
185
+ return
186
+
187
+ ## NOTE (snippet added by Songxiang Liu): gather data from all gpus
188
+ if dist.is_available() and dist.is_initialized():
189
+ # [B * T * world_size, D]
190
+ data = SyncFunction.apply(data)
191
+
192
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
193
+ self.embed.data.copy_(embed)
194
+ self.embed_avg.data.copy_(embed.clone())
195
+ self.cluster_size.data.copy_(cluster_size)
196
+ self.inited.data.copy_(torch.Tensor([True]))
197
+ # Make sure all buffers across workers are in sync after initialization
198
+ broadcast_tensors(self.buffers())
199
+
200
+ def replace_(self, samples, mask):
201
+ modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
202
+ self.embed.data.copy_(modified_codebook)
203
+
204
+ def expire_codes_(self, batch_samples):
205
+ if self.threshold_ema_dead_code == 0:
206
+ return
207
+
208
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
209
+ if not torch.any(expired_codes):
210
+ return
211
+
212
+ ## NOTE (snippet added by Songxiang Liu): gather data from all gpus
213
+ if is_distributed():
214
+ # [B * T * world_size, D]
215
+ batch_samples = SyncFunction.apply(batch_samples)
216
+
217
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
218
+ self.replace_(batch_samples, mask=expired_codes)
219
+ broadcast_tensors(self.buffers())
220
+
221
+ def preprocess(self, x):
222
+ x = rearrange(x, "... d -> (...) d")
223
+ return x
224
+
225
+ def quantize(self, x):
226
+ embed = self.embed.t()
227
+ dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
228
+ embed_ind = dist.max(dim=-1).indices
229
+ return embed_ind
230
+
231
+ def postprocess_emb(self, embed_ind, shape):
232
+ return embed_ind.view(*shape[:-1])
233
+
234
+ def dequantize(self, embed_ind):
235
+ quantize = F.embedding(embed_ind, self.embed)
236
+ return quantize
237
+
238
+ def encode(self, x):
239
+ shape = x.shape
240
+ # pre-process
241
+ x = self.preprocess(x) # [B, T, D] -> [B*T, D]
242
+ # quantize
243
+ embed_ind = self.quantize(x)
244
+ # post-process
245
+ embed_ind = self.postprocess_emb(embed_ind, shape)
246
+ return embed_ind
247
+
248
+ def decode(self, embed_ind):
249
+ quantize = self.dequantize(embed_ind)
250
+ return quantize
251
+
252
+ def forward(self, x):
253
+ # shape: [B, T, D]
254
+ shape, dtype = x.shape, x.dtype
255
+ x = self.preprocess(x) # [B, T, D] -> [B*T, D]
256
+
257
+ # Initialize codebook
258
+ self.init_embed_(x)
259
+
260
+ embed_ind = self.quantize(x) # [B*T,]
261
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) # [B*T, cb-size]
262
+ embed_ind = self.postprocess_emb(embed_ind, shape) # [B, T]
263
+ quantize = self.dequantize(embed_ind) # [B, T, D]
264
+
265
+ if self.training:
266
+ ### Update codebook by EMA
267
+ embed_onehot_sum = embed_onehot.sum(0) # [cb-size,]
268
+ embed_sum = x.t() @ embed_onehot # [D, cb-size]
269
+ if is_distributed():
270
+ dist.all_reduce(embed_onehot_sum)
271
+ dist.all_reduce(embed_sum)
272
+ # Update ema cluster count N_i^t, eq. (6) in vqvae paper
273
+ self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
274
+ # Update ema embed: eq. (7) in vqvae paper
275
+ self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay)
276
+ # apply laplace smoothing
277
+ n = self.cluster_size.sum()
278
+ cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n
279
+ # Update ema embed: eq. (8) in vqvae paper
280
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
281
+ self.embed.data.copy_(embed_normalized)
282
+
283
+ # We do the expiry of code at that point as buffers are in sync
284
+ # and all the workers will take the same decision.
285
+ self.expire_codes_(x)
286
+
287
+ return quantize, embed_ind
288
+
289
+
290
+ class VectorQuantization(nn.Module):
291
+ """Vector quantization implementation.
292
+ Currently supports only euclidean distance.
293
+ Args:
294
+ dim (int): Dimension
295
+ codebook_size (int): Codebook size
296
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
297
+ decay (float): Decay for exponential moving average over the codebooks.
298
+ epsilon (float): Epsilon value for numerical stability.
299
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
300
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
301
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
302
+ that have an exponential moving average cluster size less than the specified threshold with
303
+ randomly selected vector from the current batch.
304
+ commitment_weight (float): Weight for commitment loss.
305
+ """
306
+
307
+ def __init__(
308
+ self,
309
+ dim: int,
310
+ codebook_size: int,
311
+ codebook_dim: tp.Optional[int] = None,
312
+ decay: float = 0.99,
313
+ epsilon: float = 1e-5,
314
+ kmeans_init: bool = True,
315
+ kmeans_iters: int = 50,
316
+ threshold_ema_dead_code: int = 2,
317
+ commitment_weight: float = 1.0,
318
+ ):
319
+ super().__init__()
320
+ _codebook_dim: int = default(codebook_dim, dim)
321
+
322
+ requires_projection = _codebook_dim != dim
323
+ self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
324
+ self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
325
+
326
+ self.epsilon = epsilon
327
+ self.commitment_weight = commitment_weight
328
+
329
+ self._codebook = EuclideanCodebook(
330
+ dim=_codebook_dim,
331
+ codebook_size=codebook_size,
332
+ kmeans_init=kmeans_init,
333
+ kmeans_iters=kmeans_iters,
334
+ decay=decay,
335
+ epsilon=epsilon,
336
+ threshold_ema_dead_code=threshold_ema_dead_code,
337
+ )
338
+ self.codebook_size = codebook_size
339
+
340
+ @property
341
+ def codebook(self):
342
+ return self._codebook.embed
343
+
344
+ def encode(self, x):
345
+ x = rearrange(x, "b d n -> b n d")
346
+ x = self.project_in(x)
347
+ embed_in = self._codebook.encode(x)
348
+ return embed_in
349
+
350
+ def decode(self, embed_ind):
351
+ quantize = self._codebook.decode(embed_ind)
352
+ quantize = self.project_out(quantize)
353
+ quantize = rearrange(quantize, "b n d -> b d n")
354
+ return quantize
355
+
356
+ def forward(self, x):
357
+ device = x.device
358
+ x = x.transpose(1, 2).contiguous() # [b d n] -> [b n d]
359
+ x = self.project_in(x)
360
+
361
+ quantize, embed_ind = self._codebook(x)
362
+
363
+ if self.training:
364
+ quantize = x + (quantize - x).detach()
365
+
366
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
367
+
368
+ if self.training:
369
+ if self.commitment_weight > 0:
370
+ commit_loss = F.mse_loss(quantize.detach(), x)
371
+ loss = loss + commit_loss * self.commitment_weight
372
+
373
+ quantize = self.project_out(quantize)
374
+ quantize = quantize.transpose(1, 2).contiguous() # [b n d] -> [b d n]
375
+ return quantize, embed_ind, loss
376
+
377
+
378
+ class ResidualVectorQuantization(nn.Module):
379
+ """Residual vector quantization implementation.
380
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
381
+ """
382
+
383
+ def __init__(self, *, num_quantizers, **kwargs):
384
+ super().__init__()
385
+ self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
386
+
387
+ def forward(self, x, n_q: tp.Optional[int] = None):
388
+ quantized_out = 0.0
389
+ residual = x
390
+
391
+ all_losses = []
392
+ all_indices = []
393
+
394
+ n_q = n_q or len(self.layers)
395
+
396
+ for layer in self.layers[:n_q]:
397
+ quantized, indices, loss = layer(residual)
398
+ residual = residual - quantized
399
+ quantized_out = quantized_out + quantized
400
+
401
+ all_indices.append(indices)
402
+ all_losses.append(loss)
403
+
404
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
405
+ return quantized_out, out_indices, out_losses
406
+
407
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
408
+ residual = x
409
+ all_indices = []
410
+ n_q = n_q or len(self.layers)
411
+ for layer in self.layers[:n_q]:
412
+ indices = layer.encode(residual)
413
+ quantized = layer.decode(indices)
414
+ residual = residual - quantized
415
+ all_indices.append(indices)
416
+ out_indices = torch.stack(all_indices)
417
+ return out_indices
418
+
419
+ def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
420
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
421
+ for i, indices in enumerate(q_indices):
422
+ layer = self.layers[i]
423
+ quantized = layer.decode(indices)
424
+ quantized_out = quantized_out + quantized
425
+ return quantized_out
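
A minimal usage sketch for the residual VQ defined above (toy sizes, single process, k-means warm start disabled so it runs without real data; the import path assumes the package layout of this repo). Each layer quantizes the residual left by the previous layers, so the stacked code indices refine the reconstruction progressively.

import torch
from boson_multimodal.audio_processing.quantization.core_vq_lsx_version import ResidualVectorQuantization

rvq = ResidualVectorQuantization(num_quantizers=4, dim=8, codebook_size=16, kmeans_init=False)
x = torch.randn(2, 8, 50)                     # [batch, dim, frames]
quantized, codes, losses = rvq(x)             # uses all 4 quantizers by default
print(quantized.shape)                        # torch.Size([2, 8, 50])
print(codes.shape)                            # torch.Size([4, 2, 50]) -- one index per layer per frame
coarse = rvq.decode(rvq.encode(x, n_q=2))     # reconstruction from the first 2 layers only
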
boson_multimodal/audio_processing/quantization/ddp_utils.py ADDED
@@ -0,0 +1,197 @@
1
+ import logging.config  # also binds `logging`; needed for logging.config.dictConfig in get_logger
2
+ import random
3
+ import subprocess
4
+ from datetime import datetime
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.distributed as dist
9
+ from torch.nn.parallel import DistributedDataParallel
10
+ from torch.nn.parallel.distributed import _find_tensors
11
+ import torch.optim
12
+ import torch.utils.data
13
+ from packaging import version
14
+ from omegaconf import OmegaConf
15
+
16
+
17
+ def set_random_seed(seed):
18
+ random.seed(seed)
19
+ np.random.seed(seed)
20
+ torch.manual_seed(seed)
21
+ torch.cuda.manual_seed_all(seed)
22
+
23
+
24
+ def is_logging_process():
25
+ return not dist.is_initialized() or dist.get_rank() == 0
26
+
27
+
28
+ def get_logger(cfg, name=None):
29
+ # log_file_path is used when unit testing
30
+ if is_logging_process():
31
+ logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_config, resolve=True))
32
+ return logging.getLogger(name)
33
+
34
+
35
+ # from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
36
+ class SyncFunction(torch.autograd.Function):
37
+ @staticmethod
38
+ # @torch.no_grad()
39
+ def forward(ctx, tensor):
40
+ ctx.batch_size = tensor.shape[0]
41
+
42
+ gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
43
+
44
+ torch.distributed.all_gather(gathered_tensor, tensor)
45
+ gathered_tensor = torch.cat(gathered_tensor, 0)
46
+
47
+ return gathered_tensor
48
+
49
+ @staticmethod
50
+ def backward(ctx, grad_output):
51
+ grad_input = grad_output.clone()
52
+ torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
53
+
54
+ idx_from = torch.distributed.get_rank() * ctx.batch_size
55
+ idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
56
+ return grad_input[idx_from:idx_to]
57
+
58
+
59
+ def get_timestamp():
60
+ return datetime.now().strftime("%y%m%d-%H%M%S")
61
+
62
+
63
+ def get_commit_hash():
64
+ message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
65
+ return message.strip().decode("utf-8")
66
+
67
+
68
+ class DDP(DistributedDataParallel):
69
+ """
70
+ Override the forward call in lightning so it goes to training and validation step respectively
71
+ """
72
+
73
+ def forward(self, *inputs, **kwargs): # pragma: no cover
74
+ if version.parse(torch.__version__[:6]) < version.parse("1.11"):
75
+ self._sync_params()
76
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
77
+ assert len(self.device_ids) == 1
78
+ if self.module.training:
79
+ output = self.module.training_step(*inputs[0], **kwargs[0])
80
+ elif self.module.testing:
81
+ output = self.module.test_step(*inputs[0], **kwargs[0])
82
+ else:
83
+ output = self.module.validation_step(*inputs[0], **kwargs[0])
84
+ if torch.is_grad_enabled():
85
+ # We'll return the output object verbatim since it is a freeform
86
+ # object. We need to find any tensors in this object, though,
87
+ # because we need to figure out which parameters were used during
88
+ # this forward pass, to ensure we short circuit reduction for any
89
+ # unused parameters. Only if `find_unused_parameters` is set.
90
+ if self.find_unused_parameters:
91
+ self.reducer.prepare_for_backward(list(_find_tensors(output)))
92
+ else:
93
+ self.reducer.prepare_for_backward([])
94
+ else:
95
+ from torch.nn.parallel.distributed import (
96
+ logging,
97
+ Join,
98
+ _DDPSink,
99
+ _tree_flatten_with_rref,
100
+ _tree_unflatten_with_rref,
101
+ )
102
+
103
+ with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
104
+ if torch.is_grad_enabled() and self.require_backward_grad_sync:
105
+ self.logger.set_runtime_stats_and_log()
106
+ self.num_iterations += 1
107
+ self.reducer.prepare_for_forward()
108
+
109
+ # Notify the join context that this process has not joined, if
110
+ # needed
111
+ work = Join.notify_join_context(self)
112
+ if work:
113
+ self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size)
114
+
115
+ # Calling _rebuild_buckets before forward computation,
116
+ # It may allocate new buckets before deallocating old buckets
117
+ # inside _rebuild_buckets. To save peak memory usage,
118
+ # call _rebuild_buckets before the peak memory usage increases
119
+ # during forward computation.
120
+ # This should be called only once during whole training period.
121
+ if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
122
+ logging.info("Reducer buckets have been rebuilt in this iteration.")
123
+ self._has_rebuilt_buckets = True
124
+
125
+ # sync params according to location (before/after forward) user
126
+ # specified as part of hook, if hook was specified.
127
+ buffer_hook_registered = hasattr(self, "buffer_hook")
128
+ if self._check_sync_bufs_pre_fwd():
129
+ self._sync_buffers()
130
+
131
+ if self._join_config.enable:
132
+ # Notify joined ranks whether they should sync in backwards pass or not.
133
+ self._check_global_requires_backward_grad_sync(is_joined_rank=False)
134
+
135
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
136
+ if self.module.training:
137
+ output = self.module.training_step(*inputs[0], **kwargs[0])
138
+ elif self.module.testing:
139
+ output = self.module.test_step(*inputs[0], **kwargs[0])
140
+ else:
141
+ output = self.module.validation_step(*inputs[0], **kwargs[0])
142
+
143
+ # sync params according to location (before/after forward) user
144
+ # specified as part of hook, if hook was specified.
145
+ if self._check_sync_bufs_post_fwd():
146
+ self._sync_buffers()
147
+
148
+ if torch.is_grad_enabled() and self.require_backward_grad_sync:
149
+ self.require_forward_param_sync = True
150
+ # We'll return the output object verbatim since it is a freeform
151
+ # object. We need to find any tensors in this object, though,
152
+ # because we need to figure out which parameters were used during
153
+ # this forward pass, to ensure we short circuit reduction for any
154
+ # unused parameters. Only if `find_unused_parameters` is set.
155
+ if self.find_unused_parameters and not self.static_graph:
156
+ # Do not need to populate this for static graph.
157
+ self.reducer.prepare_for_backward(list(_find_tensors(output)))
158
+ else:
159
+ self.reducer.prepare_for_backward([])
160
+ else:
161
+ self.require_forward_param_sync = False
162
+
163
+ # TODO: DDPSink is currently enabled for unused parameter detection and
164
+ # static graph training for first iteration.
165
+ if (self.find_unused_parameters and not self.static_graph) or (
166
+ self.static_graph and self.num_iterations == 1
167
+ ):
168
+ state_dict = {
169
+ "static_graph": self.static_graph,
170
+ "num_iterations": self.num_iterations,
171
+ }
172
+
173
+ output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output)
174
+ output_placeholders = [None for _ in range(len(output_tensor_list))]
175
+ # Do not touch tensors that have no grad_fn, which can cause issues
176
+ # such as https://github.com/pytorch/pytorch/issues/60733
177
+ for i, output in enumerate(output_tensor_list):
178
+ if torch.is_tensor(output) and output.grad_fn is None:
179
+ output_placeholders[i] = output
180
+
181
+ # When find_unused_parameters=True, makes tensors which require grad
182
+ # run through the DDPSink backward pass. When not all outputs are
183
+ # used in loss, this makes those corresponding tensors receive
184
+ # undefined gradient which the reducer then handles to ensure
185
+ # param.grad field is not touched and we don't error out.
186
+ passthrough_tensor_list = _DDPSink.apply(
187
+ self.reducer,
188
+ state_dict,
189
+ *output_tensor_list,
190
+ )
191
+ for i in range(len(output_placeholders)):
192
+ if output_placeholders[i] is None:
193
+ output_placeholders[i] = passthrough_tensor_list[i]
194
+
195
+ # Reconstruct output data structure.
196
+ output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref)
197
+ return output
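
A minimal sketch of the guarded all-gather pattern that SyncFunction enables (the same pattern used in `init_embed_` and `expire_codes_` above): gather per-rank feature batches into one tensor for global statistics while still routing gradients back to the local slice, and fall through unchanged when torch.distributed is not initialized. `gather_for_stats` is an illustrative helper, not a name from this repo.

import torch
import torch.distributed as dist
from boson_multimodal.audio_processing.quantization.ddp_utils import SyncFunction

def gather_for_stats(local_feats: torch.Tensor) -> torch.Tensor:
    # local_feats: [B, D] on this rank -> [B * world_size, D] when distributed, unchanged otherwise
    if dist.is_available() and dist.is_initialized():
        return SyncFunction.apply(local_feats)
    return local_feats
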
boson_multimodal/audio_processing/quantization/distrib.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Torch distributed utilities."""
8
+
9
+ import typing as tp
10
+
11
+ import torch
12
+
13
+
14
+ def rank():
15
+ if torch.distributed.is_initialized():
16
+ return torch.distributed.get_rank()
17
+ else:
18
+ return 0
19
+
20
+
21
+ def world_size():
22
+ if torch.distributed.is_initialized():
23
+ return torch.distributed.get_world_size()
24
+ else:
25
+ return 1
26
+
27
+
28
+ def is_distributed():
29
+ return world_size() > 1
30
+
31
+
32
+ def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
33
+ if is_distributed():
34
+ return torch.distributed.all_reduce(tensor, op)
35
+
36
+
37
+ def _is_complex_or_float(tensor):
38
+ return torch.is_floating_point(tensor) or torch.is_complex(tensor)
39
+
40
+
41
+ def _check_number_of_params(params: tp.List[torch.Tensor]):
42
+ # utility function to check that the number of params in all workers is the same,
43
+ # and thus avoid a deadlock with distributed all reduce.
44
+ if not is_distributed() or not params:
45
+ return
46
+ # print('params[0].device ', params[0].device)
47
+ tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
48
+ all_reduce(tensor)
49
+ if tensor.item() != len(params) * world_size():
50
+ # If not all the workers have the same number, for at least one of them,
51
+ # this inequality will hold.
52
+ raise RuntimeError(
53
+ f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one."
54
+ )
55
+
56
+
57
+ def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
58
+ """Broadcast the tensors from the given parameters to all workers.
59
+ This can be used to ensure that all workers have the same model to start with.
60
+ """
61
+ if not is_distributed():
62
+ return
63
+ tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
64
+ _check_number_of_params(tensors)
65
+ handles = []
66
+ for tensor in tensors:
67
+ handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
68
+ handles.append(handle)
69
+ for handle in handles:
70
+ handle.wait()
71
+
72
+
73
+ def sync_buffer(buffers, average=True):
74
+ """
75
+ Sync grad for buffers. If average is False, broadcast instead of averaging.
76
+ """
77
+ if not is_distributed():
78
+ return
79
+ handles = []
80
+ for buffer in buffers:
81
+ if torch.is_floating_point(buffer.data):
82
+ if average:
83
+ handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
84
+ else:
85
+ handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
86
+ handles.append((buffer, handle))
87
+ for buffer, handle in handles:
88
+ handle.wait()
89
+ if average:
90
+ buffer.data /= world_size()
91
+
92
+
93
+ def sync_grad(params):
94
+ """
95
+ Simpler alternative to DistributedDataParallel, that doesn't rely
96
+ on any black magic. For simple models it can also be as fast.
97
+ Just call this on your model parameters after the call to backward!
98
+ """
99
+ if not is_distributed():
100
+ return
101
+ handles = []
102
+ for p in params:
103
+ if p.grad is not None:
104
+ handle = torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
105
+ handles.append((p, handle))
106
+ for p, handle in handles:
107
+ handle.wait()
108
+ p.grad.data /= world_size()
109
+
110
+
111
+ def average_metrics(metrics: tp.Dict[str, float], count=1.0):
112
+ """Average a dictionary of metrics across all workers, using the optional
113
+ `count` as unnormalized weight.
114
+ """
115
+ if not is_distributed():
116
+ return metrics
117
+ keys, values = zip(*metrics.items())
118
+ device = "cuda" if torch.cuda.is_available() else "cpu"
119
+ tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
120
+ tensor *= count
121
+ all_reduce(tensor)
122
+ averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
123
+ return dict(zip(keys, averaged))
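
A minimal sketch of how these helpers can replace DistributedDataParallel for simple models, as the `sync_grad` docstring suggests: run a normal per-rank backward pass, then average gradients and buffers by hand. `model`, `optimizer`, `loss_fn`, and `batch` are placeholders, not names from this repo.

from boson_multimodal.audio_processing.quantization import distrib

def train_step(model, optimizer, loss_fn, batch):
    optimizer.zero_grad()
    loss = loss_fn(model(batch))
    loss.backward()
    distrib.sync_grad(model.parameters())   # all-reduce and average gradients; no-op in a single process
    optimizer.step()
    distrib.sync_buffer(model.buffers())    # keep EMA codebooks / running stats consistent across ranks
    return loss.detach()
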
boson_multimodal/audio_processing/quantization/vq.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Residual vector quantizer implementation."""
8
+
9
+ from dataclasses import dataclass, field
10
+ import math
11
+ import typing as tp
12
+
13
+ import torch
14
+ from torch import nn
15
+
16
+ # from .core_vq import ResidualVectorQuantization
17
+ from .core_vq_lsx_version import ResidualVectorQuantization
18
+
19
+
20
+ @dataclass
21
+ class QuantizedResult:
22
+ quantized: torch.Tensor
23
+ codes: torch.Tensor
24
+ bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
25
+ penalty: tp.Optional[torch.Tensor] = None
26
+ metrics: dict = field(default_factory=dict)
27
+
28
+
29
+ class ResidualVectorQuantizer(nn.Module):
30
+ """Residual Vector Quantizer.
31
+ Args:
32
+ dimension (int): Dimension of the codebooks.
33
+ n_q (int): Number of residual vector quantizers used.
34
+ bins (int): Codebook size.
35
+ decay (float): Decay for exponential moving average over the codebooks.
36
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
37
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
38
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
39
+ that have an exponential moving average cluster size less than the specified threshold with
40
+ randomly selected vector from the current batch.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ dimension: int = 256,
46
+ codebook_dim: tp.Optional[int] = None,
47
+ n_q: int = 8,
48
+ bins: int = 1024,
49
+ decay: float = 0.99,
50
+ kmeans_init: bool = True,
51
+ kmeans_iters: int = 50,
52
+ threshold_ema_dead_code: int = 2,
53
+ ):
54
+ super().__init__()
55
+ self.n_q = n_q
56
+ self.dimension = dimension
57
+ self.codebook_dim = codebook_dim
58
+ self.bins = bins
59
+ self.decay = decay
60
+ self.kmeans_init = kmeans_init
61
+ self.kmeans_iters = kmeans_iters
62
+ self.threshold_ema_dead_code = threshold_ema_dead_code
63
+ self.vq = ResidualVectorQuantization(
64
+ dim=self.dimension,
65
+ codebook_dim=self.codebook_dim,
66
+ codebook_size=self.bins,
67
+ num_quantizers=self.n_q,
68
+ decay=self.decay,
69
+ kmeans_init=self.kmeans_init,
70
+ kmeans_iters=self.kmeans_iters,
71
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
72
+ )
73
+
74
+ def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None): # -> QuantizedResult:
75
+ """Residual vector quantization on the given input tensor.
76
+ Args:
77
+ x (torch.Tensor): Input tensor.
78
+ sample_rate (int): Sample rate of the input tensor.
79
+ bandwidth (float): Target bandwidth.
80
+ Returns:
81
+ QuantizedResult:
82
+ The quantized (or approximately quantized) representation with
83
+ the associated bandwidth and any penalty term for the loss.
84
+ """
85
+ bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
86
+ n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
87
+ quantized, codes, commit_loss = self.vq(x, n_q=n_q)
88
+ bw = torch.tensor(n_q * bw_per_q).to(x)
89
+ return quantized, codes, bw, torch.mean(commit_loss)
90
+ # return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
91
+
92
+ def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int:
93
+ """Return n_q based on specified target bandwidth."""
94
+ bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
95
+ n_q = self.n_q
96
+ if bandwidth and bandwidth > 0.0:
97
+ n_q = int(max(1, math.floor(bandwidth / bw_per_q)))
98
+ return n_q
99
+
100
+ def get_bandwidth_per_quantizer(self, sample_rate: int):
101
+ """Return bandwidth per quantizer for a given input sample rate."""
102
+ return math.log2(self.bins) * sample_rate / 1000
103
+
104
+ def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor:
105
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
106
+ The RVQ encode method sets the appropriate number of quantizer to use
107
+ and returns indices for each quantizer.
108
+ """
109
+ n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
110
+ codes = self.vq.encode(x, n_q=n_q)
111
+ return codes
112
+
113
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
114
+ """Decode the given codes to the quantized representation."""
115
+ quantized = self.vq.decode(codes)
116
+ return quantized
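
A worked example of the bandwidth bookkeeping above, assuming the `sample_rate` argument is the frame rate of the latent sequence (quantized frames per second) rather than the raw audio rate: with bins=1024 each code costs log2(1024) = 10 bits per frame, so at 50 frames/s one quantizer uses 10 * 50 / 1000 = 0.5 kbps, and a 2 kbps budget maps to floor(2.0 / 0.5) = 4 quantizers.

from boson_multimodal.audio_processing.quantization.vq import ResidualVectorQuantizer

rvq = ResidualVectorQuantizer(dimension=256, n_q=8, bins=1024)
assert rvq.get_bandwidth_per_quantizer(sample_rate=50) == 0.5                    # kbps per quantizer
assert rvq.get_num_quantizers_for_bandwidth(sample_rate=50, bandwidth=2.0) == 4  # quantizers used
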
boson_multimodal/audio_processing/semantic_module.py ADDED
@@ -0,0 +1,282 @@
1
+ # Based on code from: https://github.com/zhenye234/xcodec
2
+ # Licensed under MIT License
3
+ # Modifications by BosonAI
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ class Conv1d1x1(nn.Conv1d):
10
+ """1x1 Conv1d."""
11
+
12
+ def __init__(self, in_channels, out_channels, bias=True):
13
+ super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, bias=bias)
14
+
15
+
16
+ class Conv1d(nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_channels: int,
20
+ out_channels: int,
21
+ kernel_size: int,
22
+ stride: int = 1,
23
+ padding: int = -1,
24
+ dilation: int = 1,
25
+ groups: int = 1,
26
+ bias: bool = True,
27
+ ):
28
+ super().__init__()
29
+ self.in_channels = in_channels
30
+ self.out_channels = out_channels
31
+ self.kernel_size = kernel_size
32
+ if padding < 0:
33
+ padding = (kernel_size - 1) // 2 * dilation
34
+ self.dilation = dilation
35
+ self.conv = nn.Conv1d(
36
+ in_channels=in_channels,
37
+ out_channels=out_channels,
38
+ kernel_size=kernel_size,
39
+ stride=stride,
40
+ padding=padding,
41
+ dilation=dilation,
42
+ groups=groups,
43
+ bias=bias,
44
+ )
45
+
46
+ def forward(self, x):
47
+ """
48
+ Args:
49
+ x (Tensor): Float tensor variable with the shape (B, C, T).
50
+ Returns:
51
+ Tensor: Float tensor variable with the shape (B, C, T).
52
+ """
53
+ x = self.conv(x)
54
+ return x
55
+
56
+
57
+ class ResidualUnit(nn.Module):
58
+ def __init__(
59
+ self,
60
+ in_channels: int,
61
+ out_channels: int,
62
+ kernel_size=3,
63
+ dilation=1,
64
+ bias=False,
65
+ nonlinear_activation="ELU",
66
+ nonlinear_activation_params={},
67
+ ):
68
+ super().__init__()
69
+ self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
70
+ self.conv1 = Conv1d(
71
+ in_channels=in_channels,
72
+ out_channels=out_channels,
73
+ kernel_size=kernel_size,
74
+ stride=1,
75
+ dilation=dilation,
76
+ bias=bias,
77
+ )
78
+ self.conv2 = Conv1d1x1(out_channels, out_channels, bias)
79
+
80
+ def forward(self, x):
81
+ y = self.conv1(self.activation(x))
82
+ y = self.conv2(self.activation(y))
83
+ return x + y
84
+
85
+
86
+ class ConvTranspose1d(nn.Module):
87
+ def __init__(
88
+ self,
89
+ in_channels: int,
90
+ out_channels: int,
91
+ kernel_size: int,
92
+ stride: int,
93
+ padding=-1,
94
+ output_padding=-1,
95
+ groups=1,
96
+ bias=True,
97
+ ):
98
+ super().__init__()
99
+ if padding < 0:
100
+ padding = (stride + 1) // 2
101
+ if output_padding < 0:
102
+ output_padding = 1 if stride % 2 else 0
103
+ self.deconv = nn.ConvTranspose1d(
104
+ in_channels=in_channels,
105
+ out_channels=out_channels,
106
+ kernel_size=kernel_size,
107
+ stride=stride,
108
+ padding=padding,
109
+ output_padding=output_padding,
110
+ groups=groups,
111
+ bias=bias,
112
+ )
113
+
114
+ def forward(self, x):
115
+ """
116
+ Args:
117
+ x (Tensor): Float tensor variable with the shape (B, C, T).
118
+ Returns:
119
+ Tensor: Float tensor variable with the shape (B, C', T').
120
+ """
121
+ x = self.deconv(x)
122
+ return x
123
+
124
+
125
+ class EncoderBlock(nn.Module):
126
+ def __init__(
127
+ self, in_channels: int, out_channels: int, stride: int, dilations=(1, 1), unit_kernel_size=3, bias=True
128
+ ):
129
+ super().__init__()
130
+ self.res_units = torch.nn.ModuleList()
131
+ for dilation in dilations:
132
+ self.res_units += [ResidualUnit(in_channels, in_channels, kernel_size=unit_kernel_size, dilation=dilation)]
133
+ self.num_res = len(self.res_units)
134
+
135
+ self.conv = Conv1d(
136
+ in_channels=in_channels,
137
+ out_channels=out_channels,
138
+ kernel_size=3 if stride == 1 else (2 * stride), # special case: stride=1, do not use kernel=2
139
+ stride=stride,
140
+ bias=bias,
141
+ )
142
+
143
+ def forward(self, x):
144
+ for idx in range(self.num_res):
145
+ x = self.res_units[idx](x)
146
+ x = self.conv(x)
147
+ return x
148
+
149
+
150
+ class Encoder(nn.Module):
151
+ def __init__(
152
+ self,
153
+ input_channels: int,
154
+ encode_channels: int,
155
+ channel_ratios=(1, 1),
156
+ strides=(1, 1),
157
+ kernel_size=3,
158
+ bias=True,
159
+ block_dilations=(1, 1),
160
+ unit_kernel_size=3,
161
+ ):
162
+ super().__init__()
163
+ assert len(channel_ratios) == len(strides)
164
+
165
+ self.conv = Conv1d(
166
+ in_channels=input_channels, out_channels=encode_channels, kernel_size=kernel_size, stride=1, bias=False
167
+ )
168
+ self.conv_blocks = torch.nn.ModuleList()
169
+ in_channels = encode_channels
170
+ for idx, stride in enumerate(strides):
171
+ out_channels = int(encode_channels * channel_ratios[idx]) # could be float
172
+ self.conv_blocks += [
173
+ EncoderBlock(
174
+ in_channels,
175
+ out_channels,
176
+ stride,
177
+ dilations=block_dilations,
178
+ unit_kernel_size=unit_kernel_size,
179
+ bias=bias,
180
+ )
181
+ ]
182
+ in_channels = out_channels
183
+ self.num_blocks = len(self.conv_blocks)
184
+ self.out_channels = out_channels
185
+
186
+ def forward(self, x):
187
+ x = self.conv(x)
188
+ for i in range(self.num_blocks):
189
+ x = self.conv_blocks[i](x)
190
+ return x
191
+
192
+
193
+ class DecoderBlock(nn.Module):
194
+ """Decoder block (no up-sampling)"""
195
+
196
+ def __init__(
197
+ self, in_channels: int, out_channels: int, stride: int, dilations=(1, 1), unit_kernel_size=3, bias=True
198
+ ):
199
+ super().__init__()
200
+
201
+ if stride == 1:
202
+ self.conv = Conv1d(
203
+ in_channels=in_channels,
204
+ out_channels=out_channels,
205
+ kernel_size=3, # fix kernel=3 when stride=1 for unchanged shape
206
+ stride=stride,
207
+ bias=bias,
208
+ )
209
+ else:
210
+ self.conv = ConvTranspose1d(
211
+ in_channels=in_channels,
212
+ out_channels=out_channels,
213
+ kernel_size=(2 * stride),
214
+ stride=stride,
215
+ bias=bias,
216
+ )
217
+
218
+ self.res_units = torch.nn.ModuleList()
219
+ for idx, dilation in enumerate(dilations):
220
+ self.res_units += [
221
+ ResidualUnit(out_channels, out_channels, kernel_size=unit_kernel_size, dilation=dilation)
222
+ ]
223
+ self.num_res = len(self.res_units)
224
+
225
+ def forward(self, x):
226
+ x = self.conv(x)
227
+ for idx in range(self.num_res):
228
+ x = self.res_units[idx](x)
229
+ return x
230
+
231
+
232
+ class Decoder(nn.Module):
233
+ def __init__(
234
+ self,
235
+ code_dim: int,
236
+ output_channels: int,
237
+ decode_channels: int,
238
+ channel_ratios=(1, 1),
239
+ strides=(1, 1),
240
+ kernel_size=3,
241
+ bias=True,
242
+ block_dilations=(1, 1),
243
+ unit_kernel_size=3,
244
+ ):
245
+ super().__init__()
246
+ assert len(channel_ratios) == len(strides)
247
+
248
+ self.conv1 = Conv1d(
249
+ in_channels=code_dim,
250
+ out_channels=int(decode_channels * channel_ratios[0]),
251
+ kernel_size=kernel_size,
252
+ stride=1,
253
+ bias=False,
254
+ )
255
+
256
+ self.conv_blocks = torch.nn.ModuleList()
257
+ for idx, stride in enumerate(strides):
258
+ in_channels = int(decode_channels * channel_ratios[idx])
259
+ if idx < (len(channel_ratios) - 1):
260
+ out_channels = int(decode_channels * channel_ratios[idx + 1])
261
+ else:
262
+ out_channels = decode_channels
263
+ self.conv_blocks += [
264
+ DecoderBlock(
265
+ in_channels,
266
+ out_channels,
267
+ stride,
268
+ dilations=block_dilations,
269
+ unit_kernel_size=unit_kernel_size,
270
+ bias=bias,
271
+ )
272
+ ]
273
+ self.num_blocks = len(self.conv_blocks)
274
+
275
+ self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False)
276
+
277
+ def forward(self, z):
278
+ x = self.conv1(z)
279
+ for i in range(self.num_blocks):
280
+ x = self.conv_blocks[i](x)
281
+ x = self.conv2(x)
282
+ return x
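
A minimal shape check for the encoder/decoder pair above (toy configuration chosen only for illustration): each stride-s EncoderBlock downsamples time by s, and the mirrored DecoderBlock undoes it with a transposed convolution, so the overall time resolution changes by the product of the strides.

import torch
from boson_multimodal.audio_processing.semantic_module import Encoder, Decoder

enc = Encoder(input_channels=1, encode_channels=8, channel_ratios=(2, 4), strides=(2, 5))
dec = Decoder(code_dim=enc.out_channels, output_channels=1, decode_channels=8,
              channel_ratios=(4, 2), strides=(5, 2))
x = torch.randn(1, 1, 200)   # [batch, channels, samples]
z = enc(x)                   # time downsampled by 2 * 5 = 10
print(z.shape)               # torch.Size([1, 32, 20])
y = dec(z)                   # time upsampled back by 5 * 2 = 10
print(y.shape)               # torch.Size([1, 1, 200])
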
boson_multimodal/constants.py ADDED
@@ -0,0 +1,3 @@
1
+ AUDIO_IN_TOKEN = "<|AUDIO|>"
2
+ AUDIO_OUT_TOKEN = "<|AUDIO_OUT|>"
3
+ EOS_TOKEN = "<|end_of_text|>"
boson_multimodal/data_collator/__init__.py ADDED
File without changes
boson_multimodal/data_collator/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (156 Bytes). View file
 
boson_multimodal/data_collator/__pycache__/higgs_audio_collator.cpython-311.pyc ADDED
Binary file (23.8 kB). View file
 
boson_multimodal/data_collator/higgs_audio_collator.py ADDED
@@ -0,0 +1,509 @@
1
+ import librosa
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import math
5
+ from typing import List, Tuple
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Optional
9
+ from transformers.models.whisper.processing_whisper import WhisperProcessor
10
+
11
+ from ..dataset.chatml_dataset import ChatMLDatasetSample
12
+ from ..model.higgs_audio.utils import build_delay_pattern_mask
13
+
14
+
15
+ def _ceil_to_nearest(n, round_to):
16
+ return (n + round_to - 1) // round_to * round_to
17
+
18
+
19
+ def _ceil_to_next_power_of_two(x):
20
+ return 1 if x == 0 else 2 ** (x - 1).bit_length()
21
+
22
+
23
+ @dataclass
24
+ class HiggsAudioBatchInput:
25
+ input_ids: torch.LongTensor # shape (bsz, seq_len).
26
+ attention_mask: torch.Tensor # shape (bsz, seq_len).
27
+ audio_features: Optional[torch.Tensor] # shape (num_audio_in, feature_dim, max_mel_seq_len).
28
+ audio_feature_attention_mask: Optional[torch.Tensor] # shape (num_audio_in, max_mel_seq_len).
29
+ audio_out_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_out_total_length)
30
+ audio_out_ids_start: Optional[torch.LongTensor] # shape (num_audio_out,)
31
+ # audio_out_ids_start_group_loc has the same length as audio_out_ids_start. It is used to recover the group (sample) location in the batch for each audio segment.
32
+ # Currently, we concatenate audio segments along dim 0 to handle variable audio segment lengths. However, in the alignment stage, we need the location information.
33
+ # For example,
34
+ # audio_out_ids_start = [0, 2, 4, 8]; the first two audio segments come from the same sample in the batch, and the other two come from two different samples.
35
+ # For a batch of 3 samples, we would then have the group locations:
36
+ # audio_out_ids_start_group_loc = [0, 0, 1, 2]
37
+ audio_out_ids_start_group_loc: Optional[
38
+ torch.LongTensor
39
+ ] # shape (num_audio_out,), specifies each audio segment's sample (group) location in the batch
40
+ audio_in_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_in_total_length)
41
+ audio_in_ids_start: Optional[torch.LongTensor] # shape (num_audio_in,)
42
+ label_ids: Optional[torch.LongTensor] # shape (bsz, seq_len)
43
+ label_audio_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_out_total_length)
44
+ reward: Optional[float] = None
45
+
46
+
47
+ class HiggsAudioSampleCollator:
48
+ """Sample collator for Higgs-Audio model.
49
+
50
+ Args:
51
+ whisper_processor (WhisperProcessor): The whisper processor.
52
+ audio_in_token_id (int): The token id for audio-in.
53
+ audio_out_token_id (int): The token id for audio-out.
54
+ pad_token_id (int): The token id for padding.
55
+ audio_stream_bos_id (int): The token id for audio-stream beginning of sentence.
56
+ audio_stream_eos_id (int): The token id for audio-stream end of sentence.
57
+ round_to (int): Round the padded sequence length up to a multiple of this value.
58
+ pad_left (bool): Whether to pad left.
59
+ return_audio_in_tokens (bool): Whether to return audio-in tokens.
60
+ use_delay_pattern (bool): Whether to use delay pattern.
61
+ disable_audio_codes_transform (bool): If True, do not add audio-stream bos and eos tokens to the audio codes.
62
+ chunk_size_seconds (int): The chunk size in seconds.
63
+ add_new_bos_eos_for_long_chunk (bool): Whether to add new bos and eos tokens for long chunks.
64
+ mask_audio_out_token_label (bool): Whether to always mask the label associated with <|AUDIO_OUT|> token. Since we will always have `<|AUDIO_OUT|>` after `<|audio_bos|>`, we can safely mask <|AUDIO_OUT|>.
65
+
66
+ """
67
+
68
+ def __init__(
69
+ self,
70
+ whisper_processor: WhisperProcessor,
71
+ audio_in_token_id,
72
+ audio_out_token_id,
73
+ pad_token_id,
74
+ audio_stream_bos_id,
75
+ audio_stream_eos_id,
76
+ round_to=8,
77
+ pad_left=False,
78
+ encode_whisper_embed=True,
79
+ return_audio_in_tokens=True,
80
+ audio_num_codebooks=None,
81
+ use_delay_pattern=False,
82
+ disable_audio_codes_transform=False,
83
+ chunk_size_seconds=30, # Maximum duration for each chunk
84
+ add_new_bos_eos_for_long_chunk=True,
85
+ mask_audio_out_token_label=True,
86
+ ):
87
+ self.whisper_processor = whisper_processor
88
+ self.round_to = round_to
89
+ self.pad_left = pad_left
90
+ self.audio_in_token_id = audio_in_token_id
91
+ self.audio_out_token_id = audio_out_token_id
92
+ self.audio_stream_bos_id = audio_stream_bos_id
93
+ self.audio_stream_eos_id = audio_stream_eos_id
94
+ self.pad_token_id = pad_token_id
95
+ self.encode_whisper_embed = encode_whisper_embed
96
+ self.return_audio_in_tokens = return_audio_in_tokens
97
+ self.audio_num_codebooks = audio_num_codebooks
98
+ self.use_delay_pattern = use_delay_pattern
99
+ if encode_whisper_embed:
100
+ self.chunk_size_seconds = chunk_size_seconds
101
+ self.chunk_size_samples = int(chunk_size_seconds * whisper_processor.feature_extractor.sampling_rate)
102
+ else:
103
+ self.chunk_size_seconds = None
104
+ self.chunk_size_samples = None
105
+ self.disable_audio_codes_transform = disable_audio_codes_transform
106
+ self.add_new_bos_eos_for_long_chunk = add_new_bos_eos_for_long_chunk
107
+ self.mask_audio_out_token_label = mask_audio_out_token_label
108
+
109
+ def _process_and_duplicate_audio_tokens(
110
+ self, input_ids: torch.Tensor, audio_idx: int, wv: torch.Tensor, sr: int, labels: Optional[torch.Tensor] = None
111
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
112
+ """Process long audio and duplicate corresponding audio tokens.
113
+
114
+ Args:
115
+ input_ids: Input token ids
116
+ audio_idx: Index of the audio token in the sequence
117
+ wv: Audio waveform
118
+ sr: Sample rate
119
+ labels: Optional label ids to be duplicated alongside input ids
120
+
121
+ Returns:
122
+ Tuple of:
123
+ - New input ids with duplicated audio tokens
124
+ - New label ids (if labels were provided) or None
125
+ - Number of chunks created
126
+ """
127
+ # Calculate number of chunks needed
128
+ total_samples = len(wv)
129
+ num_chunks = math.ceil(total_samples / self.chunk_size_samples)
130
+
131
+ if num_chunks <= 1:
132
+ return input_ids, labels, 1
133
+
134
+ # Get the three tokens: <|audio_bos|><|AUDIO|><|audio_eos|>
135
+ audio_token_seq = input_ids[audio_idx - 1 : audio_idx + 2]
136
+ # Duplicate sequence for each chunk
137
+ duplicated_sequence = audio_token_seq.repeat(num_chunks)
138
+
139
+ # Create new input_ids with duplicated tokens
140
+ new_input_ids = torch.cat([input_ids[: audio_idx - 1], duplicated_sequence, input_ids[audio_idx + 2 :]])
141
+
142
+ # If labels are provided, duplicate them as well
143
+ new_labels = None
144
+ if labels is not None:
145
+ label_seq = labels[audio_idx - 1 : audio_idx + 2]
146
+ duplicated_labels = label_seq.repeat(num_chunks)
147
+ new_labels = torch.cat([labels[: audio_idx - 1], duplicated_labels, labels[audio_idx + 2 :]])
148
+
149
+ return new_input_ids, new_labels, num_chunks
150
+
151
+ def __call__(self, batch: List[ChatMLDatasetSample]):
152
+ """Collate the input data with support for long audio processing."""
153
+
154
+ label_ids = None
155
+ label_audio_ids = None
156
+ if all([ele.label_ids is None for ele in batch]):
157
+ return_labels = False
158
+ else:
159
+ return_labels = True
160
+
161
+ if self.encode_whisper_embed:
162
+ # Process each sample in the batch to handle long audio
163
+ # TODO(?) The implementation here can be optimized.
164
+ processed_batch = []
165
+ for i in range(len(batch)):
166
+ sample = batch[i]
167
+ audio_in_mask = sample.input_ids == self.audio_in_token_id
168
+ audio_in_indices = torch.where(audio_in_mask)[0]
169
+ audio_out_mask = sample.input_ids == self.audio_out_token_id
170
+
171
+ # Process each audio token and duplicate if needed
172
+ modified_input_ids = sample.input_ids
173
+ modified_labels = sample.label_ids if return_labels else None
174
+ modified_waveforms_concat = []
175
+ modified_waveforms_start = []
176
+ modified_sample_rate = []
177
+ offset = 0 # Track position changes from duplicating tokens
178
+ curr_wv_offset = 0
179
+
180
+ # Process input audio tokens
181
+ for idx, audio_idx in enumerate(audio_in_indices):
182
+ # Get the audio for this token
183
+ wv, sr = sample.get_wv(idx) # Use idx since we want the original audio index
184
+ if sr != self.whisper_processor.feature_extractor.sampling_rate:
185
+ resampled_wv = librosa.resample(
186
+ wv.cpu().numpy(),
187
+ orig_sr=sr,
188
+ target_sr=self.whisper_processor.feature_extractor.sampling_rate,
189
+ )
190
+ else:
191
+ resampled_wv = wv.cpu().numpy()
192
+ wv = torch.tensor(resampled_wv, device=wv.device)
193
+ sr = self.whisper_processor.feature_extractor.sampling_rate
194
+
195
+ # Process and duplicate tokens if necessary
196
+ token_pos = audio_idx + offset
197
+ modified_input_ids, modified_labels, num_chunks = self._process_and_duplicate_audio_tokens(
198
+ modified_input_ids, token_pos, wv, sr, modified_labels
199
+ )
200
+
201
+ # Update audio data
202
+ for chunk_idx in range(num_chunks):
203
+ chunk_start = chunk_idx * self.chunk_size_samples
204
+ chunk_end = min((chunk_idx + 1) * self.chunk_size_samples, len(wv))
205
+ chunk_wv = wv[chunk_start:chunk_end]
206
+ modified_waveforms_concat.append(chunk_wv)
207
+ modified_waveforms_start.append(curr_wv_offset)
208
+ curr_wv_offset += len(chunk_wv)
209
+ modified_sample_rate.append(sr)
210
+
211
+ # Update offset for next iteration
212
+ offset += (num_chunks - 1) * 3 # Each new chunk adds 3 more tokens
213
+
214
+ # Create new sample with modified tokens and audio data
215
+ processed_sample = ChatMLDatasetSample(
216
+ input_ids=modified_input_ids,
217
+ label_ids=modified_labels if return_labels else sample.label_ids,
218
+ audio_ids_concat=sample.audio_ids_concat,
219
+ audio_ids_start=sample.audio_ids_start,
220
+ audio_waveforms_concat=torch.cat(modified_waveforms_concat)
221
+ if modified_waveforms_concat
222
+ else sample.audio_waveforms_concat,
223
+ audio_waveforms_start=torch.tensor(modified_waveforms_start, dtype=torch.long)
224
+ if modified_waveforms_start
225
+ else sample.audio_waveforms_start,
226
+ audio_sample_rate=torch.tensor(modified_sample_rate)
227
+ if modified_sample_rate
228
+ else sample.audio_sample_rate,
229
+ audio_speaker_indices=torch.tensor([]),
230
+ # FIXME(sxjscience): The logic here is not correct for audio_label_ids_concat.
231
+ audio_label_ids_concat=sample.audio_label_ids_concat,
232
+ )
233
+ # audio_in_chunk_len = len(torch.where(modified_input_ids == self.audio_in_token_id)[0])
234
+ # assert audio_in_chunk_len == processed_sample.num_audios(), f"Mismatch: audio_in_chunk_len={audio_in_chunk_len}, processed_sample.num_audios()={processed_sample.num_audios()}"
235
+ processed_batch.append(processed_sample)
236
+ else:
237
+ processed_batch = batch
238
+
239
+ # Get the max sequence length based on processed batch
240
+ max_seq_length = _ceil_to_nearest(max([len(sample.input_ids) for sample in processed_batch]), self.round_to)
241
+
242
+ # Get the ids for audio-in and audio-out for each batch
243
+ audio_in_wv_l = []
244
+ audio_in_ids_l = []
245
+ audio_out_ids_l = []
246
+ audio_out_ids_group_loc_l = []
247
+ audio_in_label_ids_l = None
248
+ audio_out_label_ids_l = None
249
+ reward_l = []
250
+
251
+ if return_labels:
252
+ audio_out_no_train_flag = [] # Whether the audio-out data should be trained on or not.
253
+
254
+ # Process the audio inputs and outputs
255
+ for i in range(len(processed_batch)):
256
+ audio_in_mask = processed_batch[i].input_ids == self.audio_in_token_id
257
+ audio_out_mask = processed_batch[i].input_ids == self.audio_out_token_id
258
+ audio_ids = torch.ones_like(processed_batch[i].input_ids)
259
+ audio_ids[audio_in_mask ^ audio_out_mask] = torch.cumsum(audio_ids[audio_in_mask ^ audio_out_mask], 0) - 1
260
+ audio_in_ids = audio_ids[audio_in_mask]
261
+ audio_out_ids = audio_ids[audio_out_mask]
262
+
263
+ if return_labels:
264
+ audio_out_no_train_flag.append(processed_batch[i].label_ids[audio_out_mask] < 0)
265
+ if self.mask_audio_out_token_label:
266
+ processed_batch[i].label_ids[audio_out_mask] = -100
267
+
268
+ # Process audio inputs
269
+ if self.return_audio_in_tokens:
270
+ audio_in_ids_l.extend(
271
+ [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_in_ids]
272
+ )
273
+ if processed_batch[i].audio_label_ids_concat is not None:
274
+ if audio_in_label_ids_l is None:
275
+ audio_in_label_ids_l = []
276
+ audio_in_label_ids_l.extend(
277
+ [
278
+ processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
279
+ for idx in audio_in_ids
280
+ ]
281
+ )
282
+
283
+ audio_out_ids_l.extend(
284
+ [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids]
285
+ )
286
+ audio_out_ids_group_loc_l.append(i)
287
+ if processed_batch[i].reward is not None:
288
+ reward_l.append(processed_batch[i].reward)
289
+
290
+ if processed_batch[i].audio_label_ids_concat is not None:
291
+ if audio_out_label_ids_l is None:
292
+ audio_out_label_ids_l = []
293
+ audio_out_label_ids_l.extend(
294
+ [
295
+ processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
296
+ for idx in audio_out_ids
297
+ ]
298
+ )
299
+
300
+ if self.encode_whisper_embed:
301
+ for idx in audio_in_ids:
302
+ wv, sr = processed_batch[i].get_wv(idx)
303
+ resampled_wv = wv.cpu().numpy()
304
+ # Split long audio into chunks
305
+ total_samples = len(resampled_wv)
306
+ for chunk_start in range(0, total_samples, self.chunk_size_samples):
307
+ chunk_end = min(chunk_start + self.chunk_size_samples, total_samples)
308
+ chunk = resampled_wv[chunk_start:chunk_end]
309
+ audio_in_wv_l.append(chunk)
310
+ # assert len(audio_in_wv_l) == processed_batch[i].num_audios(), \
311
+ # f"Assertion failed: Mismatch in number of audios. " \
312
+ # f"Expected {processed_batch[i].num_audios()}, but got {len(audio_in_wv_l)} at index {i}."
313
+
314
+ if return_labels:
315
+ audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0)
316
+
317
+ # Process all audio features
318
+ if len(audio_in_wv_l) > 0:
319
+ feature_ret = self.whisper_processor.feature_extractor(
320
+ audio_in_wv_l,
321
+ sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
322
+ return_attention_mask=True,
323
+ padding="max_length",
324
+ )
325
+ audio_features = torch.from_numpy(feature_ret["input_features"])
326
+ audio_feature_attention_mask = torch.from_numpy(feature_ret["attention_mask"])
327
+ else:
328
+ if self.encode_whisper_embed:
329
+ audio_features = torch.zeros(
330
+ (
331
+ 0,
332
+ self.whisper_processor.feature_extractor.feature_size,
333
+ self.whisper_processor.feature_extractor.nb_max_frames,
334
+ ),
335
+ dtype=torch.float32,
336
+ )
337
+ audio_feature_attention_mask = torch.zeros(
338
+ (0, self.whisper_processor.feature_extractor.nb_max_frames), dtype=torch.int32
339
+ )
340
+ else:
341
+ audio_features = None
342
+ audio_feature_attention_mask = None
343
+
344
+ # Process audio input tokens
345
+ if len(audio_in_ids_l) > 0:
346
+ # Append audio-stream-bos and eos tokens
347
+ new_audio_in_ids_l = []
348
+ for ele in audio_in_ids_l:
349
+ if self.disable_audio_codes_transform:
350
+ # Do not add audio-stream-bos or eos tokens.
351
+ # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
352
+ audio_codes = ele
353
+ else:
354
+ audio_codes = torch.cat(
355
+ [
356
+ torch.full((ele.shape[0], 1), self.audio_stream_bos_id, dtype=torch.long),
357
+ ele,
358
+ torch.full((ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long),
359
+ ],
360
+ dim=1,
361
+ )
362
+ if self.use_delay_pattern:
363
+ audio_codes = build_delay_pattern_mask(
364
+ audio_codes.unsqueeze(0),
365
+ bos_token_id=self.audio_stream_bos_id,
366
+ pad_token_id=self.audio_stream_eos_id,
367
+ )[0].squeeze(0)
368
+ new_audio_in_ids_l.append(audio_codes)
369
+ audio_in_ids = torch.cat(new_audio_in_ids_l, dim=1).long()
370
+ audio_in_ids_start = torch.cumsum(
371
+ torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_in_ids_l[:-1]]), dim=0
372
+ )
373
+ else:
374
+ audio_in_ids = torch.zeros((0, 0), dtype=torch.long)
375
+ audio_in_ids_start = torch.zeros(0, dtype=torch.long)
376
+
377
+ # Process audio output tokens
378
+ audio_out_ids_start_group_loc = None
379
+ if len(audio_out_ids_l) > 0:
380
+ new_audio_out_ids_l = []
381
+ label_audio_ids_l = []
382
+ for idx, ele in enumerate(audio_out_ids_l):
383
+ if self.disable_audio_codes_transform:
384
+ # Do not add audio-stream-bos or eos tokens.
385
+ # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
386
+ audio_codes = ele
387
+ if return_labels:
388
+ label_audio_ids = audio_out_label_ids_l[idx]
389
+ else:
390
+ audio_codes = torch.cat(
391
+ [
392
+ torch.full((ele.shape[0], 1), self.audio_stream_bos_id, dtype=torch.long),
393
+ ele,
394
+ torch.full((ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long),
395
+ ],
396
+ dim=1,
397
+ )
398
+ if return_labels:
399
+ label_audio_ids = torch.cat(
400
+ [
401
+ torch.full((ele.shape[0], 1), -100, dtype=torch.long),
402
+ ele,
403
+ torch.full((ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long),
404
+ ],
405
+ dim=1,
406
+ )
407
+ if self.use_delay_pattern:
408
+ audio_codes = build_delay_pattern_mask(
409
+ audio_codes.unsqueeze(0),
410
+ bos_token_id=self.audio_stream_bos_id,
411
+ pad_token_id=self.audio_stream_eos_id,
412
+ )[0].squeeze(0)
413
+ if return_labels:
414
+ label_audio_ids = build_delay_pattern_mask(
415
+ label_audio_ids.unsqueeze(0),
416
+ bos_token_id=-100,
417
+ pad_token_id=-100,
418
+ )[0].squeeze(0)
419
+ new_audio_out_ids_l.append(audio_codes)
420
+
421
+ if return_labels:
422
+ if audio_out_no_train_flag[idx]:
423
+ label_audio_ids[:] = -100
424
+ label_audio_ids_l.append(label_audio_ids)
425
+
426
+ audio_out_ids = torch.cat(new_audio_out_ids_l, dim=1).long()
427
+ if return_labels:
428
+ label_audio_ids = torch.cat(label_audio_ids_l, dim=1).long()
429
+ audio_out_ids_start = torch.cumsum(
430
+ torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_out_ids_l[:-1]]), dim=0
431
+ )
432
+ audio_out_ids_start_group_loc = torch.tensor(audio_out_ids_group_loc_l, dtype=torch.long)
433
+ else:
434
+ audio_out_ids = torch.zeros((0, 0), dtype=torch.long)
435
+ audio_out_ids_start = torch.zeros(0, dtype=torch.long)
436
+ if return_labels:
437
+ label_audio_ids = torch.zeros((0, 0), dtype=torch.long)
438
+
439
+ reward = torch.tensor(reward_l, dtype=torch.float32)
440
+
441
+ # Handle padding for input ids and attention mask
442
+ if self.pad_left:
443
+ input_ids = torch.stack(
444
+ [
445
+ F.pad(ele.input_ids, (max_seq_length - len(ele.input_ids), 0), value=self.pad_token_id)
446
+ for ele in processed_batch
447
+ ]
448
+ )
449
+ if return_labels:
450
+ label_ids = torch.stack(
451
+ [
452
+ F.pad(ele.label_ids, (max_seq_length - len(ele.label_ids), 0), value=-100)
453
+ for ele in processed_batch
454
+ ]
455
+ )
456
+ attention_mask = torch.stack(
457
+ [
458
+ F.pad(torch.ones_like(ele.input_ids), (max_seq_length - len(ele.input_ids), 0), value=0)
459
+ for ele in processed_batch
460
+ ]
461
+ )
462
+ else:
463
+ input_ids = torch.stack(
464
+ [
465
+ F.pad(ele.input_ids, (0, max_seq_length - len(ele.input_ids)), value=self.pad_token_id)
466
+ for ele in processed_batch
467
+ ]
468
+ )
469
+ if return_labels:
470
+ label_ids = torch.stack(
471
+ [
472
+ F.pad(ele.label_ids, (0, max_seq_length - len(ele.label_ids)), value=-100)
473
+ for ele in processed_batch
474
+ ]
475
+ )
476
+ attention_mask = torch.stack(
477
+ [
478
+ F.pad(torch.ones_like(ele.input_ids), (0, max_seq_length - len(ele.input_ids)), value=0)
479
+ for ele in processed_batch
480
+ ]
481
+ )
482
+
483
+ if not self.return_audio_in_tokens:
484
+ audio_in_ids = None
485
+ audio_in_ids_start = None
486
+
487
+ # Apply audio_num_codebooks limit if specified
488
+ if self.audio_num_codebooks is not None:
489
+ if audio_in_ids is not None:
490
+ audio_in_ids = audio_in_ids[: self.audio_num_codebooks]
491
+ if audio_out_ids is not None:
492
+ audio_out_ids = audio_out_ids[: self.audio_num_codebooks]
493
+ if label_audio_ids is not None:
494
+ label_audio_ids = label_audio_ids[: self.audio_num_codebooks]
495
+
496
+ return HiggsAudioBatchInput(
497
+ input_ids=input_ids,
498
+ attention_mask=attention_mask,
499
+ audio_features=audio_features,
500
+ audio_feature_attention_mask=audio_feature_attention_mask,
501
+ audio_out_ids=audio_out_ids,
502
+ audio_out_ids_start=audio_out_ids_start,
503
+ audio_out_ids_start_group_loc=audio_out_ids_start_group_loc,
504
+ audio_in_ids=audio_in_ids,
505
+ audio_in_ids_start=audio_in_ids_start,
506
+ label_ids=label_ids,
507
+ label_audio_ids=label_audio_ids,
508
+ reward=reward,
509
+ )
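The delay pattern applied through build_delay_pattern_mask above follows the MusicGen-style idea of shifting codebook k right by k frames so the codebooks can be decoded autoregressively. The sketch below is a simplified illustration only; the repository's actual build_delay_pattern_mask may fill the BOS/EOS frames differently.

import torch

def toy_delay_pattern(codes, bos_id, pad_id):
    # codes: (num_codebooks, T) -> (num_codebooks, T + num_codebooks - 1), codebook k delayed by k frames.
    num_codebooks, T = codes.shape
    out = torch.full((num_codebooks, T + num_codebooks - 1), pad_id, dtype=codes.dtype)
    for k in range(num_codebooks):
        out[k, :k] = bos_id          # the first k frames of codebook k hold the stream-BOS id
        out[k, k:k + T] = codes[k]   # the original codes, delayed by k frames
    return out

codes = torch.arange(8).reshape(2, 4)  # 2 codebooks, 4 frames
print(toy_delay_pattern(codes, bos_id=-1, pad_id=-2))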
boson_multimodal/data_types.py ADDED
@@ -0,0 +1,38 @@
1
+ """Basic data types for multimodal ChatML format."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, List, Optional, Union
5
+
6
+
7
+ @dataclass
8
+ class AudioContent:
9
+ audio_url: str
10
+ # Base64 encoded audio bytes
11
+ raw_audio: Optional[str] = None
12
+ offset: Optional[float] = None
13
+ duration: Optional[float] = None
14
+ row_id: Optional[int] = None
15
+ type: str = "audio"
16
+
17
+
18
+ @dataclass
19
+ class TextContent:
20
+ text: str
21
+ type: str = "text"
22
+
23
+
24
+ @dataclass
25
+ class Message:
26
+ role: str
27
+ content: Union[str, AudioContent, TextContent, List[Union[str, AudioContent, TextContent]]]
28
+ recipient: Optional[str] = None
29
+
30
+
31
+ @dataclass
32
+ class ChatMLSample:
33
+ """Dataclass to hold multimodal ChatML data."""
34
+
35
+ messages: List[Message]
36
+ start_index: Optional[int] = None # Messages in messages[:start_index] are masked (excluded from the loss) when finetuning the LLM.
37
+ misc: Optional[Dict] = None
38
+ speaker: Optional[str] = None
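A minimal example of assembling these dataclasses into a sample; the URL, text, and speaker name below are placeholders.

from boson_multimodal.data_types import AudioContent, ChatMLSample, Message, TextContent

sample = ChatMLSample(
    messages=[
        Message(role="system", content="You are a helpful voice assistant."),
        Message(
            role="user",
            content=[
                TextContent(text="Please transcribe the following clip."),
                AudioContent(audio_url="https://example.com/clip.wav"),
            ],
        ),
        Message(role="assistant", content="Sure, here is the transcription."),
    ],
    speaker="narrator_01",
)
print(len(sample.messages))  # 3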
boson_multimodal/dataset/__init__.py ADDED
File without changes
boson_multimodal/dataset/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (150 Bytes). View file
 
boson_multimodal/dataset/__pycache__/chatml_dataset.cpython-311.pyc ADDED
Binary file (29.2 kB). View file
 
boson_multimodal/dataset/chatml_dataset.py ADDED
@@ -0,0 +1,533 @@
1
+ import dacite
2
+ import pandas as pd
3
+ import torch
4
+ import json
5
+
6
+ import numpy as np
7
+ import multiprocessing as mp
8
+
9
+ from dataclasses import dataclass, fields
10
+ from abc import ABC, abstractmethod
11
+ from typing import Union, List, Dict, Optional
12
+
13
+ from ..data_types import ChatMLSample, TextContent, AudioContent
14
+ from ..constants import AUDIO_IN_TOKEN, AUDIO_OUT_TOKEN
15
+
16
+ from loguru import logger
17
+
18
+ # Whisper processor, 30 sec -> 3000 features
19
+ # Then we divide by 4 in the audio tokenizer, reducing 3000 features to 750, i.e., 25 Hz
20
+ WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC = 25
21
+
22
+
23
+ @dataclass
24
+ class ChatMLDatasetSample:
25
+ input_ids: torch.LongTensor # Shape (seq_len,): The input text tokens.
26
+ label_ids: torch.LongTensor # Shape (seq_len,): The label ids.
27
+ audio_ids_concat: torch.LongTensor # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated.
28
+ # Here `audio_seq_len` is the length of the concatenated audio tokens.
29
+ audio_ids_start: (
30
+ torch.LongTensor
31
+ ) # Shape (num_audios,): The start index of each audio token in the concatenated audio tokens.
32
+ audio_waveforms_concat: (
33
+ torch.Tensor
34
+ ) # Shape (total_wv_length,): The concatenated audio waveforms for audio-in features.
35
+ audio_waveforms_start: (
36
+ torch.LongTensor
37
+ ) # Shape (num_audios,): The start index of each audio waveform in the concatenated audio waveforms.
38
+ audio_sample_rate: torch.Tensor # Shape (num_audios,): The sampling rate of the audio waveforms.
39
+ audio_speaker_indices: (
40
+ torch.LongTensor
41
+ ) # Shape (num_audios,) -1 means unknown speaker: The speaker indices for each audio.
42
+ audio_label_ids_concat: Optional[torch.LongTensor] = (
43
+ None # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated.
44
+ )
45
+ # Here `audio_seq_len` is the length of the concatenated audio tokens.
46
+ reward: Optional[float] = None
47
+
48
+ def num_audios(self):
49
+ return max(len(self.audio_waveforms_start), len(self.audio_ids_start))
50
+
51
+ def get_audio_codes(self, idx):
52
+ code_start = self.audio_ids_start[idx]
53
+ if idx < len(self.audio_ids_start) - 1:
54
+ code_end = self.audio_ids_start[idx + 1]
55
+ else:
56
+ code_end = self.audio_ids_concat.shape[-1]
57
+
58
+ return self.audio_ids_concat[:, code_start:code_end]
59
+
60
+ def get_audio_codes_labels(self, idx):
61
+ if self.audio_label_ids_concat is None:
62
+ return None
63
+ code_start = self.audio_ids_start[idx]
64
+ if idx < len(self.audio_ids_start) - 1:
65
+ code_end = self.audio_ids_start[idx + 1]
66
+ else:
67
+ code_end = self.audio_ids_concat.shape[-1]
68
+
69
+ return self.audio_label_ids_concat[:, code_start:code_end]
70
+
71
+ def get_wv(self, idx):
72
+ wv_start = self.audio_waveforms_start[idx]
73
+ sr = self.audio_sample_rate[idx]
74
+ if idx < len(self.audio_waveforms_start) - 1:
75
+ wv_end = self.audio_waveforms_start[idx + 1]
76
+ else:
77
+ wv_end = self.audio_waveforms_concat.shape[-1]
78
+ return self.audio_waveforms_concat[wv_start:wv_end], sr
79
+
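The getters above all index a concatenate-plus-offsets layout: every clip's codes sit side by side in audio_ids_concat, and audio_ids_start records where each clip begins. A toy illustration with made-up shapes:

import torch

audio_ids_concat = torch.arange(2 * 7).reshape(2, 7)  # 2 codebooks, two clips of length 4 and 3
audio_ids_start = torch.tensor([0, 4])                # clip 0 starts at frame 0, clip 1 at frame 4

clip0 = audio_ids_concat[:, audio_ids_start[0]:audio_ids_start[1]]  # shape (2, 4)
clip1 = audio_ids_concat[:, audio_ids_start[1]:]                    # shape (2, 3)
print(clip0.shape, clip1.shape)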
80
+ def cal_num_tokens(
81
+ self,
82
+ encode_whisper_embed: bool = True,
83
+ encode_audio_in_tokens: bool = False,
84
+ encode_audio_out_tokens: bool = True,
85
+ audio_in_token_id: int = 128015,
86
+ audio_out_token_id: int = 128016,
87
+ ) -> int:
88
+ # We first exclude <|AUDIO|> and <|AUDIO_OUT|> because we merge late and replace those positions with the actual audio features and audio token ids.
89
+ # We assume that audio_ids are always present when audio_waveforms are (but not vice versa).
90
+ num_tokens = len(self.input_ids) - len(self.audio_ids_start)
91
+
92
+ if encode_whisper_embed and len(self.audio_waveforms_concat) > 0:
93
+ audio_lengths = torch.diff(self.audio_waveforms_start)
94
+ if len(audio_lengths):
95
+ # Sum before calling .item()
96
+ num_tokens += (
97
+ (
98
+ np.ceil(WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC * audio_lengths / self.audio_sample_rate[:-1])
99
+ ).sum()
100
+ ).item()
101
+ # add the last audio's token estimation
102
+ num_tokens += (
103
+ np.ceil(
104
+ WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC
105
+ * (self.audio_waveforms_concat.shape[0] - self.audio_waveforms_start[-1])
106
+ / self.audio_sample_rate[-1]
107
+ )
108
+ ).item()
109
+
110
+ if self.audio_ids_concat.size(1) > 0:
111
+ audio_io_ids = self.input_ids[
112
+ (self.input_ids == audio_in_token_id) | (self.input_ids == audio_out_token_id)
113
+ ]
114
+ audio_io_id_lengths = torch.concat(
115
+ [
116
+ torch.diff(self.audio_ids_start),
117
+ torch.tensor([self.audio_ids_concat.shape[-1] - self.audio_ids_start[-1]]),
118
+ ]
119
+ )
120
+ if encode_audio_in_tokens:
121
+ num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_in_token_id]).item()
122
+
123
+ if encode_audio_out_tokens:
124
+ num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_out_token_id]).item()
125
+
126
+ return int(num_tokens)
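As a concrete check of the Whisper-embedding term above (the clip length and sample rate are illustrative): a 3-second clip at 16 kHz contributes ceil(25 * 48000 / 16000) = 75 token positions.

import math

sr, num_samples = 16_000, 48_000         # 3 seconds at 16 kHz (illustrative)
print(math.ceil(25 * num_samples / sr))  # 75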
127
+
128
+ @classmethod
129
+ def merge(
130
+ cls,
131
+ samples: List["ChatMLDatasetSample"],
132
+ eos_token_id: int,
133
+ ignore_index: int,
134
+ padding_size: Optional[int] = None,
135
+ ) -> "ChatMLDatasetSample":
136
+ """Merges a list of ChatMLDatasetSample instances, inserting eos_token_id and ignore_index between them, and adjusting offsets for audio_ids_start and audio_waveforms_start.
137
+
138
+ Args:
139
+ samples (List[ChatMLDatasetSample]): List of samples to merge.
140
+ eos_token_id (int): Tokens to be inserted into input_ids between samples.
141
+ ignore_index (int): Default label for padding.
142
+ padding_size (Optional[int]): If provided, append this many padding tokens to the merged sequence.
143
+
144
+ Returns:
145
+ ChatMLDatasetSample: Merged and potentially padded sample.
146
+ """
147
+ if not samples:
148
+ logger.fatal("The samples list is empty and cannot be merged.")
149
+ raise ValueError("The samples list is empty and cannot be merged.")
150
+
151
+ # Initialize empty lists for concatenation
152
+ input_ids_list = []
153
+ label_ids_list = []
154
+ audio_ids_concat_list = []
155
+ audio_ids_start_list = []
156
+ audio_waveforms_concat_list = []
157
+ audio_waveforms_start_list = []
158
+ audio_sample_rate_list = []
159
+ audio_speaker_indices_list = []
160
+
161
+ # Track offsets
162
+ audio_ids_offset = 0
163
+ audio_waveforms_offset = 0
164
+
165
+ for sample in samples:
166
+ # Add input_ids and label_ids with padding
167
+ if input_ids_list:
168
+ input_ids_list.append(torch.tensor([eos_token_id], dtype=torch.long))
169
+ label_ids_list.append(torch.tensor([ignore_index], dtype=torch.long))
170
+ input_ids_list.append(sample.input_ids)
171
+ label_ids_list.append(sample.label_ids)
172
+
173
+ # Add audio_ids_concat and handle empty audio ids
174
+ if sample.audio_ids_concat.size(1) > 0:
175
+ audio_ids_concat_list.append(sample.audio_ids_concat)
176
+
177
+ # Offset and add audio_ids_start
178
+ audio_ids_start_list.append(sample.audio_ids_start + audio_ids_offset)
179
+ audio_ids_offset += sample.audio_ids_concat.size(
180
+ 1
181
+ ) # (num_codebooks, seq_len): Update offset by audio_seq_len
182
+
183
+ # Add audio_waveforms_concat
184
+ if sample.audio_waveforms_concat.size(0) > 0:
185
+ # Check dimensions of the audio waveform to ensure consistency
186
+ if (
187
+ audio_waveforms_concat_list
188
+ and sample.audio_waveforms_concat.dim() != audio_waveforms_concat_list[0].dim()
189
+ ):
190
+ logger.warning(
191
+ f"Skipping audio waveform with inconsistent dimensions: expected {audio_waveforms_concat_list[0].dim()}D, got {sample.audio_waveforms_concat.dim()}D"
192
+ )
193
+ continue
194
+
195
+ audio_waveforms_concat_list.append(sample.audio_waveforms_concat)
196
+ audio_waveforms_start_list.append(sample.audio_waveforms_start + audio_waveforms_offset)
197
+ audio_waveforms_offset += sample.audio_waveforms_concat.size(0)
198
+
199
+ # Add audio_sample_rate and audio_speaker_indices
200
+ audio_sample_rate_list.append(sample.audio_sample_rate)
201
+
202
+ audio_speaker_indices_list.append(sample.audio_speaker_indices)
203
+
204
+ # Concatenate all tensors
205
+ input_ids = torch.cat(input_ids_list, dim=0)
206
+ label_ids = torch.cat(label_ids_list, dim=0)
207
+
208
+ # Apply padding if padding_size is specified
209
+ if padding_size is not None and padding_size > 0:
210
+ input_ids = torch.cat([input_ids, torch.full((padding_size,), eos_token_id, dtype=torch.long)], dim=0)
211
+ label_ids = torch.cat([label_ids, torch.full((padding_size,), ignore_index, dtype=torch.long)], dim=0)
212
+
213
+ # Safely concatenate audio tensors with proper error handling
214
+ try:
215
+ audio_ids_concat = torch.cat(audio_ids_concat_list, dim=1) if audio_ids_concat_list else torch.tensor([[]])
216
+ audio_ids_start = torch.cat(audio_ids_start_list, dim=0) if audio_ids_start_list else torch.tensor([])
217
+
218
+ # Check for dimensional consistency in audio waveforms
219
+ if audio_waveforms_concat_list:
220
+ dims = [t.dim() for t in audio_waveforms_concat_list]
221
+ if not all(d == dims[0] for d in dims):
222
+ # If dimensions don't match, log warning and filter out the problematic tensors
223
+ logger.warning(
224
+ f"Inconsistent dimensions in audio waveforms: {dims}. Filtering to keep only consistent ones."
225
+ )
226
+ expected_dim = max(set(dims), key=dims.count) # Most common dimension
227
+ audio_waveforms_concat_list = [t for t in audio_waveforms_concat_list if t.dim() == expected_dim]
228
+
229
+ # Recalculate audio_waveforms_start with the filtered list
230
+ if audio_waveforms_concat_list:
231
+ audio_waveforms_offset = 0
232
+ audio_waveforms_start_list = []
233
+ for waveform in audio_waveforms_concat_list:
234
+ audio_waveforms_start_list.append(torch.tensor([audio_waveforms_offset]))
235
+ audio_waveforms_offset += waveform.size(0)
236
+
237
+ audio_waveforms_concat = (
238
+ torch.cat(audio_waveforms_concat_list, dim=0) if audio_waveforms_concat_list else torch.tensor([])
239
+ )
240
+ audio_waveforms_start = (
241
+ torch.cat(audio_waveforms_start_list, dim=0) if audio_waveforms_start_list else torch.tensor([])
242
+ )
243
+ audio_sample_rate = (
244
+ torch.cat(audio_sample_rate_list, dim=0) if audio_sample_rate_list else torch.tensor([])
245
+ )
246
+ audio_speaker_indices = (
247
+ torch.cat(audio_speaker_indices_list, dim=0) if audio_speaker_indices_list else torch.tensor([])
248
+ )
249
+
250
+ except RuntimeError as e:
251
+ logger.error(f"Error during tensor concatenation: {str(e)}")
252
+ logger.warning("Falling back to empty audio tensors")
253
+ # Fall back to empty tensors
254
+ audio_ids_concat = torch.tensor([[]])
255
+ audio_ids_start = torch.tensor([])
256
+ audio_waveforms_concat = torch.tensor([])
257
+ audio_waveforms_start = torch.tensor([])
258
+ audio_sample_rate = torch.tensor([])
259
+ audio_speaker_indices = torch.tensor([])
260
+
261
+ # Create the merged sample
262
+ merged_sample = cls(
263
+ input_ids=input_ids,
264
+ label_ids=label_ids,
265
+ audio_ids_concat=audio_ids_concat,
266
+ audio_ids_start=audio_ids_start,
267
+ audio_waveforms_concat=audio_waveforms_concat,
268
+ audio_waveforms_start=audio_waveforms_start,
269
+ audio_sample_rate=audio_sample_rate,
270
+ audio_speaker_indices=audio_speaker_indices,
271
+ )
272
+
273
+ return merged_sample
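A small sketch of merging two text-only samples (empty audio fields); the eos id 128009, the codebook count 8, and the token values are placeholders.

import torch
from boson_multimodal.dataset.chatml_dataset import ChatMLDatasetSample

def text_only(ids):
    n = torch.tensor(ids, dtype=torch.long)
    return ChatMLDatasetSample(
        input_ids=n,
        label_ids=n.clone(),
        audio_ids_concat=torch.zeros((8, 0), dtype=torch.long),
        audio_ids_start=torch.zeros(0, dtype=torch.long),
        audio_waveforms_concat=torch.zeros(0),
        audio_waveforms_start=torch.zeros(0, dtype=torch.long),
        audio_sample_rate=torch.zeros(0),
        audio_speaker_indices=torch.zeros(0, dtype=torch.long),
    )

merged = ChatMLDatasetSample.merge(
    [text_only([1, 2, 3]), text_only([4, 5])], eos_token_id=128009, ignore_index=-100
)
print(merged.input_ids)  # tensor([1, 2, 3, 128009, 4, 5])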
274
+
275
+
276
+ @dataclass
277
+ class RankedChatMLDatasetSampleTuple:
278
+ samples: List[ChatMLDatasetSample]
279
+ scores: List[float]
280
+
281
+ def max_score_sample(self) -> ChatMLDatasetSample:
282
+ idx = self.scores.index(max(self.scores))
283
+ self.samples[idx].reward = self.scores[idx]
284
+ return self.samples[idx]
285
+
286
+ def min_score_sample(self) -> ChatMLDatasetSample:
287
+ idx = self.scores.index(min(self.scores))
288
+ self.samples[idx].reward = self.scores[idx]
289
+ return self.samples[idx]
290
+
291
+
292
+ @dataclass
293
+ class ChatMLDatasetStorageSample:
294
+ input_tokens: torch.LongTensor
295
+ label_tokens: torch.LongTensor
296
+ audio_bytes_cache_dir_index: int
297
+ audio_codes_cache_dir_index: int
298
+ audio_bytes_indices: torch.LongTensor
299
+ audio_codes_indices: torch.LongTensor
300
+ speaker_indices: torch.LongTensor
301
+ file_index: int
302
+ original_sample_index: int
303
+
304
+
305
+ # TODO(sxjscience): We need to revisit the logic for parsing speaker ids.
306
+ # Currently, we assume that the speaker id is stored in the "misc" field of ChatMLSample.
307
+ def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
308
+ """Preprocess the ChatML sample to get the tokens for the text part.
309
+
310
+ Args:
311
+ sample (ChatMLSample): The ChatML sample to preprocess.
312
+ tokenizer: The tokenizer to use for encoding the text.
313
+
314
+ """
315
+
316
+ try:
317
+ if not isinstance(sample, ChatMLSample):
318
+ # Handle all fields that could be NaN
319
+ if "speaker" in sample and pd.isna(sample["speaker"]):
320
+ sample["speaker"] = None
321
+ if "start_index" in sample and pd.isna(sample["start_index"]):
322
+ sample["start_index"] = None
323
+ if "content" in sample and pd.isna(sample["content"]):
324
+ sample["content"] = ""
325
+
326
+ # Convert any other potential NaN values in nested structures
327
+ def convert_nan_to_none(obj):
328
+ import numpy as np
329
+
330
+ if isinstance(obj, (pd.Series, np.ndarray)):
331
+ return obj.tolist()
332
+ elif pd.api.types.is_scalar(obj) and pd.isna(obj):
333
+ return None
334
+ elif isinstance(obj, dict):
335
+ return {k: convert_nan_to_none(v) for k, v in obj.items()}
336
+ elif isinstance(obj, (list, tuple)): # Handle both lists and tuples
337
+ return [convert_nan_to_none(item) for item in obj]
338
+ return obj
339
+
340
+ # Clean the sample data
341
+ clean_sample = convert_nan_to_none(sample)
342
+
343
+ val_keys = []
344
+ for field in fields(ChatMLSample):
345
+ if field.name in clean_sample:
346
+ val_keys.append(field.name)
347
+ clean_sample = {k: clean_sample[k] for k in val_keys}
348
+
349
+ try:
350
+ sample = dacite.from_dict(
351
+ data_class=ChatMLSample, data=clean_sample, config=dacite.Config(strict=True, check_types=True)
352
+ )
353
+ except Exception as e:
354
+ print(f"Failed to convert to ChatMLSample: {e}")
355
+ print(f"Clean sample: {json.dumps(clean_sample, indent=2)}")
356
+ return None, None, None, None
357
+
358
+ input_tokens = []
359
+ label_tokens = []
360
+ audio_contents = []
361
+ speaker_id = None
362
+ if sample.speaker is not None:
363
+ speaker_id = sample.speaker
364
+ elif sample.misc is not None:
365
+ if "speaker" in sample.misc:
366
+ speaker_id = sample.misc["speaker"]
367
+
368
+ total_m = len(sample.messages)
369
+ for turn_id, message in enumerate(sample.messages):
370
+ role = message.role
371
+ recipient = message.recipient
372
+ content = message.content
373
+ content_l = []
374
+
375
+ if isinstance(content, str):
376
+ content_l.append(TextContent(text=content))
377
+ elif isinstance(content, TextContent):
378
+ content_l.append(content)
379
+ elif isinstance(content, AudioContent):
380
+ content_l.append(content)
381
+ elif isinstance(content, list):
382
+ for ele in content:
383
+ if isinstance(ele, str):
384
+ content_l.append(TextContent(text=ele))
385
+ else:
386
+ content_l.append(ele)
387
+ if turn_id == 0:
388
+ prefix = f"<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>\n\n"
389
+ else:
390
+ prefix = f"<|start_header_id|>{role}<|end_header_id|>\n\n"
391
+ eot_postfix = "<|eot_id|>"
392
+ eom_postfix = "<|eom_id|>"
393
+
394
+ prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
395
+ input_tokens.extend(prefix_tokens)
396
+ label_tokens.extend([-100 for _ in prefix_tokens])
397
+
398
+ if recipient:
399
+ assert role == "assistant", "Recipient is only available for assistant role."
400
+ recipient_tokens = tokenizer.encode(f"{recipient}<|recipient|>", add_special_tokens=False)
401
+ input_tokens.extend(recipient_tokens)
402
+ label_tokens.extend(recipient_tokens)
403
+
404
+ for content in content_l:
405
+ if content.type == "text":
406
+ text_tokens = tokenizer.encode(content.text, add_special_tokens=False)
407
+ input_tokens.extend(text_tokens)
408
+ if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
409
+ label_tokens.extend(text_tokens)
410
+ else:
411
+ label_tokens.extend([-100 for _ in text_tokens])
412
+
413
+ elif content.type == "audio":
414
+ # Generate the text-part of the audio tokens
415
+ audio_contents.append(content)
416
+ if role == "user" or role == "system":
417
+ # Add the text tokens
418
+ text_tokens = tokenizer.encode(
419
+ f"<|audio_bos|><|AUDIO|><|audio_eos|>",
420
+ add_special_tokens=False,
421
+ )
422
+ input_tokens.extend(text_tokens)
423
+ label_tokens.extend([-100 for _ in text_tokens])
424
+ elif role == "assistant":
425
+ # Add the text tokens for audio-out part.
426
+ text_tokens = tokenizer.encode(
427
+ f"<|audio_out_bos|><|AUDIO_OUT|><|audio_eos|>",
428
+ add_special_tokens=False,
429
+ )
430
+ input_tokens.extend(text_tokens)
431
+ if sample.start_index is None or turn_id >= sample.start_index:
432
+ label_tokens.extend(text_tokens)
433
+ else:
434
+ label_tokens.extend([-100 for _ in text_tokens])
435
+ next_id = turn_id + 1
436
+ if role == "assistant" and next_id != total_m and sample.messages[next_id].role == "assistant":
437
+ postfix_tokens = tokenizer.encode(eom_postfix, add_special_tokens=False)
438
+ input_tokens.extend(postfix_tokens)
439
+ else:
440
+ postfix_tokens = tokenizer.encode(eot_postfix, add_special_tokens=False)
441
+ input_tokens.extend(postfix_tokens)
442
+ if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
443
+ label_tokens.extend(postfix_tokens)
444
+ else:
445
+ label_tokens.extend([-100 for _ in postfix_tokens])
446
+
447
+ return input_tokens, label_tokens, audio_contents, speaker_id
448
+
449
+ except Exception as e:
450
+ print(f"Error in prepare_chatml_sample: {str(e)}")
451
+ print(f"Sample data: {json.dumps(sample, indent=2)}")
452
+ return None, None, None, None
453
+
454
+
455
+ def extract_generation_prompt_from_input_tokens(input_tokens, tokenizer):
456
+ """Extract the generation prompt and reference answer from the input tokens.
457
+
458
+ For example:
459
+
460
+ Input Text = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
461
+ What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
462
+ <|start_header_id|>assistant<|end_header_id|>\n\nAt first they went by quick, too quick to even get.<|eot_id|>'
463
+
464
+ -->
465
+
466
+ Prompt = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
467
+ What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
468
+ <|start_header_id|>assistant<|end_header_id|>\n\n',
469
+ Reference = 'At first they went by quick, too quick to even get.'
470
+
471
+ Args:
472
+ input_tokens: The input tokens.
473
+ audio_contents: The audio contents.
474
+ tokenizer: The tokenizer to use for decoding the text.
475
+
476
+ Returns:
477
+ prompt_tokens: The tokens for the prompt.
478
+ reference_answer: The reference answer.
479
+ num_audios_in_reference: The number of audios in the reference answer.
480
+
481
+ """
482
+ input_text = tokenizer.decode(input_tokens)
483
+ generation_prefix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
484
+ postfix = "<|eot_id|>"
485
+ assert generation_prefix in input_text
486
+ generation_prompt_end_loc = input_text.rfind(generation_prefix) + len(generation_prefix)
487
+ generation_prompt = input_text[:generation_prompt_end_loc]
488
+ reference_answer = input_text[generation_prompt_end_loc : input_text.find(postfix, generation_prompt_end_loc)]
489
+ num_audios_in_reference = reference_answer.count(AUDIO_IN_TOKEN) + reference_answer.count(AUDIO_OUT_TOKEN)
490
+ return tokenizer.encode(generation_prompt, add_special_tokens=False), reference_answer, num_audios_in_reference
491
+
492
+
493
+ def prepare_chatml_dataframe_single_process(df, tokenizer):
494
+ """Prepare the ChatML DataFrame."""
495
+ ret = []
496
+ for _, row in df.iterrows():
497
+ input_tokens, label_tokens, audio_contents, speaker_id = prepare_chatml_sample(row.to_dict(), tokenizer)
498
+ ret.append((input_tokens, label_tokens, audio_contents, speaker_id))
499
+ return ret
500
+
501
+
502
+ def prepare_chatml_dataframe(df, tokenizer, num_process=16):
503
+ if num_process is None:
504
+ return prepare_chatml_dataframe_single_process(df, tokenizer)
505
+ else:
506
+ num_process = max(min(len(df) // 1000, num_process), 1)
507
+ workloads = np.array_split(df, num_process)
508
+ with mp.Pool(num_process) as pool:
509
+ ret = pool.starmap(
510
+ prepare_chatml_dataframe_single_process, [(workload, tokenizer) for workload in workloads]
511
+ )
512
+ return sum(ret, [])
513
+
514
+
515
+ class DatasetInterface(ABC):
516
+ @abstractmethod
517
+ def __getitem__(self, idx) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
518
+ """Retrieve a dataset sample by index."""
519
+ raise NotImplementedError
520
+
521
+
522
+ class IterableDatasetInterface(ABC):
523
+ @abstractmethod
524
+ def __iter__(self) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
525
+ """Retrieve a sample by iterating through the dataset."""
526
+ raise NotImplementedError
527
+
528
+
529
+ @dataclass
530
+ class DatasetInfo:
531
+ dataset_type: str
532
+ group_type: Optional[str] = None
533
+ mask_text: Optional[bool] = None # Whether to mask the text tokens for pretraining samples.
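End-to-end, prepare_chatml_sample turns a ChatMLSample into token and label lists plus the audio contents to be filled in later. A sketch of calling it on a text-only sample; the checkpoint id is a placeholder, and the tokenizer is assumed to already contain the Llama-3-style and audio special tokens used above.

from transformers import AutoTokenizer

from boson_multimodal.data_types import ChatMLSample, Message
from boson_multimodal.dataset.chatml_dataset import prepare_chatml_sample

tokenizer = AutoTokenizer.from_pretrained("bosonai/higgs-audio-v2-generation-3B-base")  # placeholder id
sample = ChatMLSample(
    messages=[
        Message(role="user", content="Say hello."),
        Message(role="assistant", content="Hello!"),
    ]
)
input_tokens, label_tokens, audio_contents, speaker_id = prepare_chatml_sample(sample, tokenizer)
print(len(input_tokens), len(label_tokens), len(audio_contents), speaker_id)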
boson_multimodal/model/__init__.py ADDED
File without changes
boson_multimodal/model/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (148 Bytes). View file
 
boson_multimodal/model/higgs_audio/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from transformers import AutoConfig, AutoModel
2
+
3
+ from .configuration_higgs_audio import HiggsAudioConfig, HiggsAudioEncoderConfig
4
+ from .modeling_higgs_audio import HiggsAudioModel
5
+
6
+
7
+ AutoConfig.register("higgs_audio_encoder", HiggsAudioEncoderConfig)
8
+ AutoConfig.register("higgs_audio", HiggsAudioConfig)
9
+ AutoModel.register(HiggsAudioConfig, HiggsAudioModel)
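Because the package registers the config and model classes with the Auto* factories on import, downstream code can resolve the "higgs_audio" model type through the standard transformers entry points. A sketch, assuming a checkpoint whose config declares model_type "higgs_audio"; the id below is a placeholder.

from transformers import AutoConfig, AutoModel

import boson_multimodal.model.higgs_audio  # noqa: F401  (runs the register() calls above)

config = AutoConfig.from_pretrained("bosonai/higgs-audio-v2-generation-3B-base")  # placeholder id
model = AutoModel.from_config(config)
print(type(model).__name__)  # HiggsAudioModel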
boson_multimodal/model/higgs_audio/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (665 Bytes). View file
 
boson_multimodal/model/higgs_audio/__pycache__/audio_head.cpython-311.pyc ADDED
Binary file (5.67 kB). View file
 
boson_multimodal/model/higgs_audio/__pycache__/common.cpython-311.pyc ADDED
Binary file (2.02 kB). View file