xinjie.wang committed on
Commit 5638c1f · 1 Parent(s): 22afe09
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +1 -1
  2. embodied_gen/models/sam3d.py +3 -2
  3. embodied_gen/utils/monkey_patches.py +4 -8
  4. thirdparty/sam3d/sam3d/.gitignore +1 -0
  5. thirdparty/sam3d/sam3d/CODE_OF_CONDUCT.md +80 -0
  6. thirdparty/sam3d/sam3d/CONTRIBUTING.md +39 -0
  7. thirdparty/sam3d/sam3d/LICENSE +52 -0
  8. thirdparty/sam3d/sam3d/README.md +152 -0
  9. thirdparty/sam3d/sam3d/checkpoints/.gitignore +2 -0
  10. thirdparty/sam3d/sam3d/demo.py +21 -0
  11. thirdparty/sam3d/sam3d/doc/setup.md +58 -0
  12. thirdparty/sam3d/sam3d/environments/default.yml +216 -0
  13. thirdparty/sam3d/sam3d/notebook/demo_3db_mesh_alignment.ipynb +149 -0
  14. thirdparty/sam3d/sam3d/notebook/demo_multi_object.ipynb +162 -0
  15. thirdparty/sam3d/sam3d/notebook/demo_single_object.ipynb +164 -0
  16. thirdparty/sam3d/sam3d/notebook/inference.py +414 -0
  17. thirdparty/sam3d/sam3d/notebook/mesh_alignment.py +469 -0
  18. thirdparty/sam3d/sam3d/patching/hydra +16 -0
  19. thirdparty/sam3d/sam3d/pyproject.toml +30 -0
  20. thirdparty/sam3d/sam3d/requirements.dev.txt +4 -0
  21. thirdparty/sam3d/sam3d/requirements.inference.txt +4 -0
  22. thirdparty/sam3d/sam3d/requirements.p3d.txt +2 -0
  23. thirdparty/sam3d/sam3d/requirements.txt +88 -0
  24. thirdparty/sam3d/sam3d/sam3d_objects/__init__.py +6 -0
  25. thirdparty/sam3d/sam3d/sam3d_objects/config/__init__.py +1 -0
  26. thirdparty/sam3d/sam3d/sam3d_objects/config/utils.py +174 -0
  27. thirdparty/sam3d/sam3d/sam3d_objects/data/__init__.py +1 -0
  28. thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/__init__.py +1 -0
  29. thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/__init__.py +1 -0
  30. thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/img_and_mask_transforms.py +986 -0
  31. thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/img_processing.py +189 -0
  32. thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/pose_target.py +784 -0
  33. thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/preprocessor.py +203 -0
  34. thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/transforms_3d.py +50 -0
  35. thirdparty/sam3d/sam3d/sam3d_objects/data/utils.py +243 -0
  36. thirdparty/sam3d/sam3d/sam3d_objects/model/__init__.py +1 -0
  37. thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/__init__.py +1 -0
  38. thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/__init__.py +1 -0
  39. thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/__init__.py +1 -0
  40. thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/dino.py +142 -0
  41. thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/embedder_fuser.py +238 -0
  42. thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/point_remapper.py +78 -0
  43. thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/pointmap.py +238 -0
  44. thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/__init__.py +1 -0
  45. thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/base.py +65 -0
  46. thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/classifier_free_guidance.py +259 -0
  47. thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/__init__.py +1 -0
  48. thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/model.py +363 -0
  49. thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/solver.py +126 -0
  50. thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/shortcut/__init__.py +1 -0
README.md CHANGED
@@ -10,7 +10,7 @@ pinned: false
  license: apache-2.0
  short_description: Generate physically plausible 3D model from single image.
  paper: https://huggingface.co/papers/2506.10600
- startup_duration_timeout: 2h
+ startup_duration_timeout: 4h
  ---
 
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
embodied_gen/models/sam3d.py CHANGED
@@ -94,9 +94,10 @@ class Sam3dInference:
      ) -> dict:
          if isinstance(image, Image.Image):
              image = np.array(image)
+         image = self.merge_mask_to_rgba(image, mask)
          return self.pipeline.run(
              image,
-             mask,
+             None,
              seed,
              stage1_only=False,
              with_mesh_postprocess=False,
@@ -132,7 +133,7 @@ if __name__ == "__main__":
 
      start = time()
 
-     output = pipeline(image, mask, seed=42)
+     output = pipeline.run(image, mask, seed=42)
      print(f"Running cost: {round(time()-start, 1)}")
 
      if torch.cuda.is_available():
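The change above stops passing the mask as a separate argument and instead embeds it in the image's alpha channel before calling `pipeline.run`. The new `merge_mask_to_rgba` helper follows the same logic as `Inference.merge_mask_to_rgba` in `thirdparty/sam3d/sam3d/notebook/inference.py` further down in this diff; a minimal sketch of that step, assuming an HxWx3 uint8 image and an HxW binary mask:

```python
import numpy as np

def merge_mask_to_rgba(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Embed a binary object mask into the alpha channel of an RGB image.

    Sketch of the helper used above (mirrors notebook/inference.py): the mask
    becomes the 4th (alpha) channel, so the pipeline can be called with
    mask=None and recover the object mask from the RGBA input.
    """
    mask = mask.astype(np.uint8) * 255  # boolean/0-1 mask -> 0/255 alpha
    mask = mask[..., None]              # HxW -> HxWx1
    return np.concatenate([image[..., :3], mask], axis=-1)  # HxWx4 RGBA
```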
embodied_gen/utils/monkey_patches.py CHANGED
@@ -397,17 +397,13 @@ def monkey_patch_sam3d():
                 exc_info=True,
             )
 
-        # glb.export("sample.glb")
-        logger.info("Finished!")
-
-        return {
+        result = {
             **ss_return_dict,
             **outputs,
-            "pointmap": pts.cpu().permute((1, 2, 0)),  # HxWx3
-            "pointmap_colors": pts_colors.cpu().permute(
-                (1, 2, 0)
-            ),  # HxWx3
+            "pointmap": pts.cpu().permute((1, 2, 0)),
+            "pointmap_colors": pts_colors.cpu().permute((1, 2, 0)),
         }
+        return result
 
     InferencePipelinePointMap.run = patch_run
 
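For context, `monkey_patch_sam3d` swaps the pipeline's `run` method for the patched function at import time. A minimal sketch of that monkey-patching pattern (the body below is illustrative only, not the actual `patch_run`; only `InferencePipelinePointMap` and its `run` attribute come from this diff):

```python
from sam3d_objects.pipeline.inference_pipeline_pointmap import (
    InferencePipelinePointMap,
)

# Keep a handle on the original method so the patch can delegate to it.
_original_run = InferencePipelinePointMap.run

def patch_run(self, *args, **kwargs):
    # Illustrative body: call the unpatched pipeline, then extend the returned
    # dict (the real patch above builds the dict itself and adds pointmap data).
    outputs = _original_run(self, *args, **kwargs)
    return {**outputs, "patched": True}  # hypothetical extra key

# Rebind the class attribute so every pipeline instance picks up the patch.
InferencePipelinePointMap.run = patch_run
```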
 
thirdparty/sam3d/sam3d/.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__
thirdparty/sam3d/sam3d/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ This Code of Conduct also applies outside the project spaces when there is a
56
+ reasonable belief that an individual's behavior may have a negative impact on
57
+ the project or its community.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported by contacting the project team at <[email protected]>. All
63
+ complaints will be reviewed and investigated and will result in a response that
64
+ is deemed necessary and appropriate to the circumstances. The project team is
65
+ obligated to maintain confidentiality with regard to the reporter of an incident.
66
+ Further details of specific enforcement policies may be posted separately.
67
+
68
+ Project maintainers who do not follow or enforce the Code of Conduct in good
69
+ faith may face temporary or permanent repercussions as determined by other
70
+ members of the project's leadership.
71
+
72
+ ## Attribution
73
+
74
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
+
77
+ [homepage]: https://www.contributor-covenant.org
78
+
79
+ For answers to common questions about this code of conduct, see
80
+ https://www.contributor-covenant.org/faq
thirdparty/sam3d/sam3d/CONTRIBUTING.md ADDED
@@ -0,0 +1,39 @@
+ # Contributing to sam-3d-objects
+ We want to make contributing to this project as easy and transparent as
+ possible.
+
+ ## Our Development Process
+ ... (in particular how this is synced with internal changes to the project)
+
+ ## Pull Requests
+ We actively welcome your pull requests.
+
+ 1. Fork the repo and create your branch from `main`.
+ 2. If you've added code that should be tested, add tests.
+ 3. If you've changed APIs, update the documentation.
+ 4. Ensure the test suite passes.
+ 5. Make sure your code lints.
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+ ## Contributor License Agreement ("CLA")
+ In order to accept your pull request, we need you to submit a CLA. You only need
+ to do this once to work on any of Meta's open source projects.
+
+ Complete your CLA here: <https://code.facebook.com/cla>
+
+ ## Issues
+ We use GitHub issues to track public bugs. Please ensure your description is
+ clear and has sufficient instructions to be able to reproduce the issue.
+
+ Meta has a [bounty program](https://bugbounty.meta.com/) for the safe
+ disclosure of security bugs. In those cases, please go through the process
+ outlined on that page and do not file a public issue.
+
+ ## Coding Style
+ * 2 spaces for indentation rather than tabs
+ * 80 character line length
+ * ...
+
+ ## License
+ By contributing to sam-3d-objects, you agree that your contributions will be licensed
+ under the LICENSE file in the root directory of this source tree.
thirdparty/sam3d/sam3d/LICENSE ADDED
@@ -0,0 +1,52 @@
1
+ SAM License
2
+ Last Updated: November 19, 2025
3
+
4
+ “Agreement” means the terms and conditions for use, reproduction, distribution and modification of the SAM Materials set forth herein.
5
+
6
+ “SAM Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed by Meta and made available under this Agreement.
7
+
8
+ “Documentation” means the specifications, manuals and documentation accompanying
9
+ SAM Materials distributed by Meta.
10
+
11
+ “Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
12
+
13
+ “Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) or Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
14
+
15
+ “Sanctions” means any economic or trade sanctions or restrictions administered or enforced by the United States (including the Office of Foreign Assets Control of the U.S. Department of the Treasury (“OFAC”), the U.S. Department of State and the U.S. Department of Commerce), the United Nations, the European Union, or the United Kingdom.
16
+
17
+ “Trade Controls” means any of the following: Sanctions and applicable export and import controls.
18
+
19
+ By using or distributing any portion or element of the SAM Materials, you agree to be bound by this Agreement.
20
+
21
+ 1. License Rights and Redistribution.
22
+
23
+ a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the SAM Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the SAM Materials.
24
+
25
+ i. Grant of Patent License. Subject to the terms and conditions of this License, you are granted a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by Meta that are necessarily infringed alone or by combination of their contribution(s) with the SAM 3 Materials. If you institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the SAM 3 Materials incorporated within the work constitutes direct or contributory patent infringement, then any patent licenses granted to you under this License for that work shall terminate as of the date such litigation is filed.
26
+
27
+ b. Redistribution and Use.
28
+
29
+ i. Distribution of SAM Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the SAM Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement and you shall provide a copy of this Agreement with any such SAM Materials.
30
+
31
+ ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with SAM Materials, you must acknowledge the use of SAM Materials in your publication.
32
+
33
+ iii. Your use of the SAM Materials must comply with applicable laws and regulations, including Trade Control Laws and applicable privacy and data protection laws.
34
+ iv. Your use of the SAM Materials will not involve or encourage others to reverse engineer, decompile or discover the underlying components of the SAM Materials.
35
+ v. You are not the target of Trade Controls and your use of SAM Materials must comply with Trade Controls. You agree not to use, or permit others to use, SAM Materials for any activities subject to the International Traffic in Arms Regulations (ITAR) or end uses prohibited by Trade Controls, including those related to military or warfare purposes, nuclear industries or applications, espionage, or the development or use of guns or illegal weapons.
36
+ 2. User Support. Your use of the SAM Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the SAM Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
37
+
38
+ 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SAM MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SAM MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SAM MATERIALS AND ANY OUTPUT AND RESULTS.
39
+
40
+ 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
41
+
42
+ 5. Intellectual Property.
43
+
44
+ a. Subject to Meta’s ownership of SAM Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the SAM Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
45
+
46
+ b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the SAM Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the SAM Materials.
47
+
48
+ 6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the SAM Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the SAM Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
49
+
50
+ 7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
51
+
52
+ 8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the SAM Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
thirdparty/sam3d/sam3d/README.md ADDED
@@ -0,0 +1,152 @@
1
+ # SAM 3D
2
+
3
+ SAM 3D Objects is one part of SAM 3D, a pair of models for object and human mesh reconstruction. If you’re looking for SAM 3D Body, [click here](https://github.com/facebookresearch/sam-3d-body).
4
+
5
+ # SAM 3D Objects
6
+
7
+ **SAM 3D Team**, [Xingyu Chen](https://scholar.google.com/citations?user=gjSHr6YAAAAJ&hl=en&oi=sra)\*, [Fu-Jen Chu](https://fujenchu.github.io/)\*, [Pierre Gleize](https://scholar.google.com/citations?user=4imOcw4AAAAJ&hl=en&oi=ao)\*, [Kevin J Liang](https://kevinjliang.github.io/)\*, [Alexander Sax](https://alexsax.github.io/)\*, [Hao Tang](https://scholar.google.com/citations?user=XY6Nh9YAAAAJ&hl=en&oi=sra)\*, [Weiyao Wang](https://sites.google.com/view/weiyaowang/home)\*, [Michelle Guo](https://scholar.google.com/citations?user=lyjjpNMAAAAJ&hl=en&oi=ao), [Thibaut Hardin](https://github.com/Thibaut-H), [Xiang Li](https://ryanxli.github.io/)⚬, [Aohan Lin](https://github.com/linaohan), [Jia-Wei Liu](https://jia-wei-liu.github.io/), [Ziqi Ma](https://ziqi-ma.github.io/)⚬, [Anushka Sagar](https://www.linkedin.com/in/anushkasagar/), [Bowen Song](https://scholar.google.com/citations?user=QQKVkfcAAAAJ&hl=en&oi=sra)⚬, [Xiaodong Wang](https://scholar.google.com/citations?authuser=2&user=rMpcFYgAAAAJ), [Jianing Yang](https://jedyang.com/)⚬, [Bowen Zhang](http://home.ustc.edu.cn/~zhangbowen/)⚬, [Piotr Dollár](https://pdollar.github.io/)†, [Georgia Gkioxari](https://georgiagkioxari.com/)†, [Matt Feiszli](https://scholar.google.com/citations?user=A-wA73gAAAAJ&hl=en&oi=ao)†§, [Jitendra Malik](https://people.eecs.berkeley.edu/~malik/)†§
8
+
9
+ ***Meta Superintelligence Labs***
10
+
11
+ *Core contributor (Alphabetical, Equal Contribution), ⚬Intern, †Project leads, §Equal Contribution
12
+
13
+ [[`Paper`](https://ai.meta.com/research/publications/sam-3d-3dfy-anything-in-images/)] [[`Code`](https://github.com/facebookresearch/sam-3d-objects)] [[`Website`](https://ai.meta.com/sam3d/)] [[`Demo`](https://www.aidemos.meta.com/segment-anything/editor/convert-image-to-3d)] [[`Blog`](https://ai.meta.com/blog/sam-3d/)] [[`BibTeX`](#citing-sam-3d-objects)] [[`Roboflow`](https://blog.roboflow.com/sam-3d/)]
14
+
15
+ **SAM 3D Objects** is a foundation model that reconstructs full 3D shape geometry, texture, and layout from a single image, excelling in real-world scenarios with occlusion and clutter by using progressive training and a data engine with human feedback. It outperforms prior 3D generation models in human preference tests on real-world objects and scenes. We released code, weights, online demo, and a new challenging benchmark.
16
+
17
+
18
+ <p align="center"><img src="doc/intro.png"/></p>
19
+
20
+ -----
21
+
22
+ <p align="center"><img src="doc/arch.png"/></p>
23
+
24
+ ## Latest updates
25
+
26
+ **11/19/2025** - Checkpoints Launched, Web Demo and Paper are out.
27
+
28
+ ## Installation
29
+
30
+ Follow the [setup](doc/setup.md) steps before running the following.
31
+
32
+ ## Single or Multi-Object 3D Generation
33
+
34
+ SAM 3D Objects can convert masked objects in an image, into 3D models with pose, shape, texture, and layout. SAM 3D is designed to be robust in challenging natural images, handling small objects and occlusions, unusual poses, and difficult situations encountered in uncurated natural scenes like this kidsroom:
35
+
36
+ <p align="center">
37
+ <img src="notebook/images/shutterstock_stylish_kidsroom_1640806567/image.png" width="55%"/>
38
+ <img src="doc/kidsroom_transparent.gif" width="40%"/>
39
+ </p>
40
+
41
+ For a quick start, run `python demo.py` or use the the following lines of code:
42
+
43
+ ```python
44
+ import sys
45
+
46
+ # import inference code
47
+ sys.path.append("notebook")
48
+ from inference import Inference, load_image, load_single_mask
49
+
50
+ # load model
51
+ tag = "hf"
52
+ config_path = f"checkpoints/{tag}/pipeline.yaml"
53
+ inference = Inference(config_path, compile=False)
54
+
55
+ # load image and mask
56
+ image = load_image("notebook/images/shutterstock_stylish_kidsroom_1640806567/image.png")
57
+ mask = load_single_mask("notebook/images/shutterstock_stylish_kidsroom_1640806567", index=14)
58
+
59
+ # run model
60
+ output = inference(image, mask, seed=42)
61
+
62
+ # export gaussian splat
63
+ output["gs"].save_ply(f"splat.ply")
64
+ ```
65
+
66
+ For more details and multi-object reconstruction, please take a look at out two jupyter notebooks:
67
+ * [single object](notebook/demo_single_object.ipynb)
68
+ * [multi object](notebook/demo_multi_object.ipynb)
69
+
70
+
71
+ ## SAM 3D Body
72
+
73
+ [SAM 3D Body (3DB)](https://github.com/facebookresearch/sam-3d-body) is a robust promptable foundation model for single-image 3D human mesh recovery (HMR).
74
+
75
+ As a way to combine the strengths of both **SAM 3D Objects** and **SAM 3D Body**, we provide an example notebook that demonstrates how to combine the results of both models such that they are aligned in the same frame of reference. Check it out [here](notebook/demo_3db_mesh_alignment.ipynb).
76
+
77
+ ## License
78
+
79
+ The SAM 3D Objects model checkpoints and code are licensed under [SAM License](./LICENSE).
80
+
81
+ ## Contributing
82
+
83
+ See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
84
+
85
+ ## Contributors
86
+
87
+ The SAM 3D Objects project was made possible with the help of many contributors.
88
+
89
+ Robbie Adkins,
90
+ Paris Baptiste,
91
+ Karen Bergan,
92
+ Kai Brown,
93
+ Michelle Chan,
94
+ Ida Cheng,
95
+ Khadijat Durojaiye,
96
+ Patrick Edwards,
97
+ Daniella Factor,
98
+ Facundo Figueroa,
99
+ Rene de la Fuente,
100
+ Eva Galper,
101
+ Cem Gokmen,
102
+ Alex He,
103
+ Enmanuel Hernandez,
104
+ Dex Honsa,
105
+ Leonna Jones,
106
+ Arpit Kalla,
107
+ Kris Kitani,
108
+ Helen Klein,
109
+ Kei Koyama,
110
+ Robert Kuo,
111
+ Vivian Lee,
112
+ Alex Lende,
113
+ Jonny Li,
114
+ Kehan Lyu,
115
+ Faye Ma,
116
+ Mallika Malhotra,
117
+ Sasha Mitts,
118
+ William Ngan,
119
+ George Orlin,
120
+ Peter Park,
121
+ Don Pinkus,
122
+ Roman Radle,
123
+ Nikhila Ravi,
124
+ Azita Shokrpour,
125
+ Jasmine Shone,
126
+ Zayida Suber,
127
+ Phillip Thomas,
128
+ Tatum Turner,
129
+ Joseph Walker,
130
+ Meng Wang,
131
+ Claudette Ward,
132
+ Andrew Westbury,
133
+ Lea Wilken,
134
+ Nan Yang,
135
+ Yael Yungster
136
+
137
+
138
+ ## Citing SAM 3D Objects
139
+
140
+ If you use SAM 3D Objects in your research, please use the following BibTeX entry.
141
+
142
+ ```
143
+ @article{sam3dteam2025sam3d3dfyimages,
144
+ title={SAM 3D: 3Dfy Anything in Images},
145
+ author={SAM 3D Team and Xingyu Chen and Fu-Jen Chu and Pierre Gleize and Kevin J Liang and Alexander Sax and Hao Tang and Weiyao Wang and Michelle Guo and Thibaut Hardin and Xiang Li and Aohan Lin and Jiawei Liu and Ziqi Ma and Anushka Sagar and Bowen Song and Xiaodong Wang and Jianing Yang and Bowen Zhang and Piotr Dollár and Georgia Gkioxari and Matt Feiszli and Jitendra Malik},
146
+ year={2025},
147
+ eprint={2511.16624},
148
+ archivePrefix={arXiv},
149
+ primaryClass={cs.CV},
150
+ url={https://arxiv.org/abs/2511.16624},
151
+ }
152
+ ```
thirdparty/sam3d/sam3d/checkpoints/.gitignore ADDED
@@ -0,0 +1,2 @@
+ *
+ !.gitignore
thirdparty/sam3d/sam3d/demo.py ADDED
@@ -0,0 +1,21 @@
+ import sys
+
+ # import inference code
+ sys.path.append("notebook")
+ from inference import Inference, load_image, load_single_mask
+
+ # load model
+ tag = "hf"
+ config_path = f"checkpoints/{tag}/pipeline.yaml"
+ inference = Inference(config_path, compile=False)
+
+ # load image (RGBA only, mask is embedded in the alpha channel)
+ image = load_image("notebook/images/shutterstock_stylish_kidsroom_1640806567/image.png")
+ mask = load_single_mask("notebook/images/shutterstock_stylish_kidsroom_1640806567", index=14)
+
+ # run model
+ output = inference(image, mask, seed=42)
+
+ # export gaussian splat
+ output["gs"].save_ply(f"splat.ply")
+ print("Your reconstruction has been saved to splat.ply")
thirdparty/sam3d/sam3d/doc/setup.md ADDED
@@ -0,0 +1,58 @@
+ # Setup
+
+ ## Prerequisites
+
+ * A linux 64-bits architecture (i.e. `linux-64` platform in `mamba info`).
+ * A NVIDIA GPU with at least 32 Gb of VRAM.
+
+ ## 1. Setup Python Environment
+
+ The following will install the default environment. If you use `conda` instead of `mamba`, replace its name in the first two lines. Note that you may have to build the environment on a compute node with GPU (e.g., you may get a `RuntimeError: Not compiled with GPU support` error when running certain parts of the code that use Pytorch3D).
+
+ ```bash
+ # create sam3d-objects environment
+ mamba env create -f environments/default.yml
+ mamba activate sam3d-objects
+
+ # for pytorch/cuda dependencies
+ export PIP_EXTRA_INDEX_URL="https://pypi.ngc.nvidia.com https://download.pytorch.org/whl/cu121"
+
+ # install sam3d-objects and core dependencies
+ pip install -e '.[dev]'
+ pip install -e '.[p3d]' # pytorch3d dependency on pytorch is broken, this 2-step approach solves it
+
+ # for inference
+ export PIP_FIND_LINKS="https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.5.1_cu121.html"
+ pip install -e '.[inference]'
+
+ # patch things that aren't yet in official pip packages
+ ./patching/hydra # https://github.com/facebookresearch/hydra/pull/2863
+ ```
+
+ ## 2. Getting Checkpoints
+
+ ### From HuggingFace
+
+ ⚠️ Before using SAM 3D Objects, please request access to the checkpoints on the SAM 3D Objects
+ Hugging Face [repo](https://huggingface.co/facebook/sam-3d-objects). Once accepted, you
+ need to be authenticated to download the checkpoints. You can do this by running
+ the following [steps](https://huggingface.co/docs/huggingface_hub/en/quick-start#authentication)
+ (e.g. `hf auth login` after generating an access token).
+
+ ⚠️ SAM 3D Objects is available via HuggingFace globally, **except** in comprehensively sanctioned jurisdictions.
+ Sanctioned jurisdiction will result in requests being **rejected**.
+
+ ```bash
+ pip install 'huggingface-hub[cli]<1.0'
+
+ TAG=hf
+ hf download \
+   --repo-type model \
+   --local-dir checkpoints/${TAG}-download \
+   --max-workers 1 \
+   facebook/sam-3d-objects
+ mv checkpoints/${TAG}-download/checkpoints checkpoints/${TAG}
+ rm -rf checkpoints/${TAG}-download
+ ```
thirdparty/sam3d/sam3d/environments/default.yml ADDED
@@ -0,0 +1,216 @@
1
+ name: sam3d-objects
2
+ channels:
3
+ - conda-forge
4
+ dependencies:
5
+ - _libgcc_mutex=0.1=conda_forge
6
+ - _openmp_mutex=4.5=2_gnu
7
+ - alsa-lib=1.2.13=hb9d3cd8_0
8
+ - attr=2.5.1=h166bdaf_1
9
+ - binutils=2.43=h4852527_4
10
+ - binutils_impl_linux-64=2.43=h4bf12b8_4
11
+ - binutils_linux-64=2.43=h4852527_4
12
+ - bzip2=1.0.8=h4bc722e_7
13
+ - c-compiler=1.7.0=hd590300_1
14
+ - ca-certificates=2025.1.31=hbcca054_0
15
+ - cairo=1.18.0=h3faef2a_0
16
+ - cuda-cccl=12.1.109=ha770c72_0
17
+ - cuda-cccl-impl=2.0.1=ha770c72_1
18
+ - cuda-cccl_linux-64=12.1.109=ha770c72_0
19
+ - cuda-command-line-tools=12.1.1=ha770c72_0
20
+ - cuda-compiler=12.1.1=hbad6d8a_0
21
+ - cuda-cudart=12.1.105=hd3aeb46_0
22
+ - cuda-cudart-dev=12.1.105=hd3aeb46_0
23
+ - cuda-cudart-dev_linux-64=12.1.105=h59595ed_0
24
+ - cuda-cudart-static=12.1.105=hd3aeb46_0
25
+ - cuda-cudart-static_linux-64=12.1.105=h59595ed_0
26
+ - cuda-cudart_linux-64=12.1.105=h59595ed_0
27
+ - cuda-cuobjdump=12.1.111=h59595ed_0
28
+ - cuda-cupti=12.1.105=h59595ed_0
29
+ - cuda-cupti-dev=12.1.105=h59595ed_0
30
+ - cuda-cuxxfilt=12.1.105=h59595ed_0
31
+ - cuda-driver-dev=12.1.105=hd3aeb46_0
32
+ - cuda-driver-dev_linux-64=12.1.105=h59595ed_0
33
+ - cuda-gdb=12.1.105=hd47b8d6_0
34
+ - cuda-libraries=12.1.1=ha770c72_0
35
+ - cuda-libraries-dev=12.1.1=ha770c72_0
36
+ - cuda-nsight=12.1.105=ha770c72_0
37
+ - cuda-nvcc=12.1.105=hcdd1206_1
38
+ - cuda-nvcc-dev_linux-64=12.1.105=ha770c72_0
39
+ - cuda-nvcc-impl=12.1.105=hd3aeb46_0
40
+ - cuda-nvcc-tools=12.1.105=hd3aeb46_0
41
+ - cuda-nvcc_linux-64=12.1.105=h8a487aa_1
42
+ - cuda-nvdisasm=12.1.105=h59595ed_0
43
+ - cuda-nvml-dev=12.1.105=h59595ed_0
44
+ - cuda-nvprof=12.1.105=h59595ed_0
45
+ - cuda-nvprune=12.1.105=h59595ed_0
46
+ - cuda-nvrtc=12.1.105=hd3aeb46_0
47
+ - cuda-nvrtc-dev=12.1.105=hd3aeb46_0
48
+ - cuda-nvtx=12.1.105=h59595ed_0
49
+ - cuda-nvvp=12.1.105=h59595ed_0
50
+ - cuda-opencl=12.1.105=h59595ed_0
51
+ - cuda-opencl-dev=12.1.105=h59595ed_0
52
+ - cuda-profiler-api=12.1.105=ha770c72_0
53
+ - cuda-sanitizer-api=12.1.105=h59595ed_0
54
+ - cuda-toolkit=12.1.1=ha804496_0
55
+ - cuda-tools=12.1.1=ha770c72_0
56
+ - cuda-version=12.1=h1d6eff3_3
57
+ - cuda-visual-tools=12.1.1=ha770c72_0
58
+ - cxx-compiler=1.7.0=h00ab1b0_1
59
+ - dbus=1.13.6=h5008d03_3
60
+ - expat=2.6.4=h5888daf_0
61
+ - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
62
+ - font-ttf-inconsolata=3.000=h77eed37_0
63
+ - font-ttf-source-code-pro=2.038=h77eed37_0
64
+ - font-ttf-ubuntu=0.83=h77eed37_3
65
+ - fontconfig=2.15.0=h7e30c49_1
66
+ - fonts-conda-ecosystem=1=0
67
+ - fonts-conda-forge=1=0
68
+ - freetype=2.13.3=h48d6fc4_0
69
+ - gcc=12.4.0=h236703b_2
70
+ - gcc_impl_linux-64=12.4.0=h26ba24d_2
71
+ - gcc_linux-64=12.4.0=h6b7512a_8
72
+ - gds-tools=1.6.1.9=hd3aeb46_0
73
+ - gettext=0.23.1=h5888daf_0
74
+ - gettext-tools=0.23.1=h5888daf_0
75
+ - glib=2.82.2=h07242d1_1
76
+ - glib-tools=2.82.2=h4833e2c_1
77
+ - gmp=6.3.0=hac33072_2
78
+ - graphite2=1.3.13=h59595ed_1003
79
+ - gst-plugins-base=1.24.4=h9ad1361_0
80
+ - gstreamer=1.24.4=haf2f30d_0
81
+ - gxx=12.4.0=h236703b_2
82
+ - gxx_impl_linux-64=12.4.0=h3ff227c_2
83
+ - gxx_linux-64=12.4.0=h8489865_8
84
+ - harfbuzz=8.5.0=hfac3d4d_0
85
+ - icu=73.2=h59595ed_0
86
+ - kernel-headers_linux-64=3.10.0=he073ed8_18
87
+ - keyutils=1.6.1=h166bdaf_0
88
+ - krb5=1.21.3=h659f571_0
89
+ - lame=3.100=h166bdaf_1003
90
+ - ld_impl_linux-64=2.43=h712a8e2_4
91
+ - libasprintf=0.23.1=h8e693c7_0
92
+ - libasprintf-devel=0.23.1=h8e693c7_0
93
+ - libcap=2.75=h39aace5_0
94
+ - libclang-cpp15=15.0.7=default_h127d8a8_5
95
+ - libclang13=19.1.2=default_h9c6a7e4_1
96
+ - libcublas=12.1.3.1=hd3aeb46_0
97
+ - libcublas-dev=12.1.3.1=hd3aeb46_0
98
+ - libcufft=11.0.2.54=hd3aeb46_0
99
+ - libcufft-dev=11.0.2.54=hd3aeb46_0
100
+ - libcufile=1.6.1.9=hd3aeb46_0
101
+ - libcufile-dev=1.6.1.9=hd3aeb46_0
102
+ - libcups=2.3.3=h4637d8d_4
103
+ - libcurand=10.3.2.106=hd3aeb46_0
104
+ - libcurand-dev=10.3.2.106=hd3aeb46_0
105
+ - libcusolver=11.4.5.107=hd3aeb46_0
106
+ - libcusolver-dev=11.4.5.107=hd3aeb46_0
107
+ - libcusparse=12.1.0.106=hd3aeb46_0
108
+ - libcusparse-dev=12.1.0.106=hd3aeb46_0
109
+ - libedit=3.1.20250104=pl5321h7949ede_0
110
+ - libevent=2.1.12=hf998b51_1
111
+ - libexpat=2.6.4=h5888daf_0
112
+ - libffi=3.4.6=h2dba641_0
113
+ - libflac=1.4.3=h59595ed_0
114
+ - libgcc=14.2.0=h767d61c_2
115
+ - libgcc-devel_linux-64=12.4.0=h1762d19_102
116
+ - libgcc-ng=14.2.0=h69a702a_2
117
+ - libgcrypt-lib=1.11.0=hb9d3cd8_2
118
+ - libgettextpo=0.23.1=h5888daf_0
119
+ - libgettextpo-devel=0.23.1=h5888daf_0
120
+ - libglib=2.82.2=h2ff4ddf_1
121
+ - libgomp=14.2.0=h767d61c_2
122
+ - libgpg-error=1.51=hbd13f7d_1
123
+ - libiconv=1.18=h4ce23a2_1
124
+ - libjpeg-turbo=3.0.0=hd590300_1
125
+ - libllvm15=15.0.7=hb3ce162_4
126
+ - libllvm19=19.1.2=ha7bfdaf_0
127
+ - liblzma=5.6.4=hb9d3cd8_0
128
+ - liblzma-devel=5.6.4=hb9d3cd8_0
129
+ - libnpp=12.1.0.40=hd3aeb46_0
130
+ - libnpp-dev=12.1.0.40=hd3aeb46_0
131
+ - libnsl=2.0.1=hd590300_0
132
+ - libnuma=2.0.18=h4ab18f5_2
133
+ - libnvjitlink=12.1.105=hd3aeb46_0
134
+ - libnvjitlink-dev=12.1.105=hd3aeb46_0
135
+ - libnvjpeg=12.2.0.2=h59595ed_0
136
+ - libnvjpeg-dev=12.2.0.2=ha770c72_0
137
+ - libogg=1.3.5=h4ab18f5_0
138
+ - libopus=1.3.1=h7f98852_1
139
+ - libpng=1.6.47=h943b412_0
140
+ - libpq=16.8=h87c4ccc_0
141
+ - libsanitizer=12.4.0=ha732cd4_2
142
+ - libsndfile=1.2.2=hc60ed4a_1
143
+ - libsqlite=3.49.1=hee588c1_2
144
+ - libstdcxx=14.2.0=h8f9b012_2
145
+ - libstdcxx-devel_linux-64=12.4.0=h1762d19_102
146
+ - libstdcxx-ng=14.2.0=h4852527_2
147
+ - libsystemd0=257.4=h4e0b6ca_1
148
+ - libuuid=2.38.1=h0b41bf4_0
149
+ - libvorbis=1.3.7=h9c3ff4c_0
150
+ - libxcb=1.15=h0b41bf4_0
151
+ - libxkbcommon=1.7.0=h662e7e4_0
152
+ - libxkbfile=1.1.0=h166bdaf_1
153
+ - libxml2=2.12.7=h4c95cb1_3
154
+ - libzlib=1.3.1=hb9d3cd8_2
155
+ - lz4-c=1.10.0=h5888daf_1
156
+ - mpg123=1.32.9=hc50e24c_0
157
+ - mysql-common=8.3.0=h70512c7_5
158
+ - mysql-libs=8.3.0=ha479ceb_5
159
+ - ncurses=6.5=h2d0b736_3
160
+ - nsight-compute=2023.1.1.4=h3718151_0
161
+ - nspr=4.36=h5888daf_0
162
+ - nss=3.108=h159eef7_0
163
+ - ocl-icd=2.3.2=hb9d3cd8_2
164
+ - opencl-headers=2024.10.24=h5888daf_0
165
+ - openssl=3.4.1=h7b32b05_0
166
+ - packaging=24.2=pyhd8ed1ab_2
167
+ - pcre2=10.44=hba22ea6_2
168
+ - pip=25.0.1=pyh8b19718_0
169
+ - pixman=0.44.2=h29eaf8c_0
170
+ - pthread-stubs=0.4=hb9d3cd8_1002
171
+ - pulseaudio-client=17.0=hb77b528_0
172
+ - python=3.11.0=he550d4f_1_cpython
173
+ - qt-main=5.15.8=hc9dc06e_21
174
+ - readline=8.2=h8c095d6_2
175
+ - setuptools=75.8.2=pyhff2d567_0
176
+ - sysroot_linux-64=2.17=h0157908_18
177
+ - tk=8.6.13=noxft_h4845f30_101
178
+ - tzdata=2025b=h78e105d_0
179
+ - wayland=1.23.1=h3e06ad9_0
180
+ - wheel=0.45.1=pyhd8ed1ab_1
181
+ - xcb-util=0.4.0=hd590300_1
182
+ - xcb-util-image=0.4.0=h8ee46fc_1
183
+ - xcb-util-keysyms=0.4.0=h8ee46fc_1
184
+ - xcb-util-renderutil=0.3.9=hd590300_1
185
+ - xcb-util-wm=0.4.1=h8ee46fc_1
186
+ - xkeyboard-config=2.42=h4ab18f5_0
187
+ - xorg-compositeproto=0.4.2=hb9d3cd8_1002
188
+ - xorg-damageproto=1.2.1=hb9d3cd8_1003
189
+ - xorg-fixesproto=5.0=hb9d3cd8_1003
190
+ - xorg-inputproto=2.3.2=hb9d3cd8_1003
191
+ - xorg-kbproto=1.0.7=hb9d3cd8_1003
192
+ - xorg-libice=1.1.2=hb9d3cd8_0
193
+ - xorg-libsm=1.2.6=he73a12e_0
194
+ - xorg-libx11=1.8.9=h8ee46fc_0
195
+ - xorg-libxau=1.0.12=hb9d3cd8_0
196
+ - xorg-libxcomposite=0.4.6=h0b41bf4_1
197
+ - xorg-libxdamage=1.1.5=h7f98852_1
198
+ - xorg-libxdmcp=1.1.5=hb9d3cd8_0
199
+ - xorg-libxext=1.3.4=h0b41bf4_2
200
+ - xorg-libxfixes=5.0.3=h7f98852_1004
201
+ - xorg-libxi=1.7.10=h4bc722e_1
202
+ - xorg-libxrandr=1.5.2=h7f98852_1
203
+ - xorg-libxrender=0.9.11=hd590300_0
204
+ - xorg-libxtst=1.2.5=h4bc722e_0
205
+ - xorg-randrproto=1.5.0=hb9d3cd8_1002
206
+ - xorg-recordproto=1.14.2=hb9d3cd8_1003
207
+ - xorg-renderproto=0.11.1=hb9d3cd8_1003
208
+ - xorg-util-macros=1.20.2=hb9d3cd8_0
209
+ - xorg-xextproto=7.3.0=hb9d3cd8_1004
210
+ - xorg-xf86vidmodeproto=2.3.1=hb9d3cd8_1005
211
+ - xorg-xproto=7.0.31=hb9d3cd8_1008
212
+ - xz=5.6.4=hbcc6ac9_0
213
+ - xz-gpl-tools=5.6.4=hbcc6ac9_0
214
+ - xz-tools=5.6.4=hb9d3cd8_0
215
+ - zlib=1.3.1=hb9d3cd8_2
216
+ - zstd=1.5.7=hb8e6e7a_2
thirdparty/sam3d/sam3d/notebook/demo_3db_mesh_alignment.ipynb ADDED
@@ -0,0 +1,149 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# SAM 3D Body (3DB) Mesh Alignment to SAM 3D Object Scale\n",
8
+ "\n",
9
+ "This notebook processes a single 3DB mesh and aligns it to the SAM 3D Objects scale.\n",
10
+ "\n",
11
+ "**Input Data:**\n",
12
+ "- `images/human_object/image.jpg` - Input image for MoGe\n",
13
+ "- `meshes/human_object/3DB_results/mask_human.png` - Human mask\n",
14
+ "- `meshes/human_object/3DB_results/human.ply` - Single 3DB mesh in OpenGL coordinates\n",
15
+ "- `meshes/human_object/3DB_results/focal_length.json` - 3DB focal length\n",
16
+ "\n",
17
+ "**Output:**\n",
18
+ "- `meshes/human_object/aligned_meshes/human_aligned.ply` - Aligned 3DB mesh in OpenGL coordinates"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "import os\n",
28
+ "import torch\n",
29
+ "import matplotlib.pyplot as plt\n",
30
+ "from PIL import Image\n",
31
+ "from mesh_alignment import process_and_save_alignment\n",
32
+ "\n",
33
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
34
+ "print(f\"Using device: {device}\")\n",
35
+ "PATH = os.getcwd()\n",
36
+ "print(f\"Current working directory: {PATH}\")\n",
37
+ "\n",
38
+ "# Please inference the SAM 3D Body (3DB) Repo (https://github.com/facebookresearch/sam-3d-body) to get the 3DB Results\n",
39
+ "image_path = f\"{PATH}/images/human_object/image.png\"\n",
40
+ "mask_path = f\"{PATH}/meshes/human_object/3DB_results/mask_human.png\"\n",
41
+ "mesh_path = f\"{PATH}/meshes/human_object/3DB_results/human.ply\"\n",
42
+ "focal_length_json_path = f\"{PATH}/meshes/human_object/3DB_results/focal_length.json\"\n",
43
+ "output_dir = f\"{PATH}/meshes/human_object/aligned_meshes\"\n",
44
+ "os.makedirs(output_dir, exist_ok=True)\n"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "markdown",
49
+ "metadata": {},
50
+ "source": [
51
+ "## 1. Load and Display Input Data"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "input_image = Image.open(image_path)\n",
61
+ "mask = Image.open(mask_path).convert('L')\n",
62
+ "fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
63
+ "axes[0].imshow(input_image)\n",
64
+ "axes[0].set_title('Input Image')\n",
65
+ "axes[0].axis('off')\n",
66
+ "axes[1].imshow(mask, cmap='gray')\n",
67
+ "axes[1].set_title('Mask')\n",
68
+ "axes[1].axis('off')\n",
69
+ "plt.tight_layout()\n",
70
+ "plt.show()"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "markdown",
75
+ "metadata": {},
76
+ "source": [
77
+ "## 2. Process and Save Aligned Mesh"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "\n",
87
+ "success, output_mesh_path, result = process_and_save_alignment(\n",
88
+ " mesh_path=mesh_path,\n",
89
+ " mask_path=mask_path,\n",
90
+ " image_path=image_path,\n",
91
+ " output_dir=output_dir,\n",
92
+ " device=device,\n",
93
+ " focal_length_json_path=focal_length_json_path\n",
94
+ ")\n",
95
+ "\n",
96
+ "if success:\n",
97
+ " print(f\"Alignment completed successfully! Output: {output_mesh_path}\")\n",
98
+ "else:\n",
99
+ " print(\"Alignment failed!\")"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "markdown",
104
+ "metadata": {},
105
+ "source": [
106
+ "## 3. Interactive 3D Visualization\n"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": [
115
+ "from mesh_alignment import visualize_meshes_interactive\n",
116
+ "\n",
117
+ "aligned_mesh_path = f\"{PATH}/meshes/human_object/aligned_meshes/human_aligned.ply\"\n",
118
+ "dfy_mesh_path = f\"{PATH}/meshes/human_object/3Dfy_results/0.glb\"\n",
119
+ "\n",
120
+ "demo, combined_glb_path = visualize_meshes_interactive(\n",
121
+ " aligned_mesh_path=aligned_mesh_path,\n",
122
+ " dfy_mesh_path=dfy_mesh_path,\n",
123
+ " share=True\n",
124
+ ")"
125
+ ]
126
+ }
127
+ ],
128
+ "metadata": {
129
+ "kernelspec": {
130
+ "display_name": "sam3d_objects-3dfy",
131
+ "language": "python",
132
+ "name": "python3"
133
+ },
134
+ "language_info": {
135
+ "codemirror_mode": {
136
+ "name": "ipython",
137
+ "version": 3
138
+ },
139
+ "file_extension": ".py",
140
+ "mimetype": "text/x-python",
141
+ "name": "python",
142
+ "nbconvert_exporter": "python",
143
+ "pygments_lexer": "ipython3",
144
+ "version": "3.11.0"
145
+ }
146
+ },
147
+ "nbformat": 4,
148
+ "nbformat_minor": 4
149
+ }
thirdparty/sam3d/sam3d/notebook/demo_multi_object.ipynb ADDED
@@ -0,0 +1,162 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "# Copyright (c) Meta Platforms, Inc. and affiliates."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "markdown",
14
+ "metadata": {},
15
+ "source": [
16
+ "## 1. Imports and Model Loading"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "import os\n",
26
+ "import uuid\n",
27
+ "import imageio\n",
28
+ "import numpy as np\n",
29
+ "from IPython.display import Image as ImageDisplay\n",
30
+ "\n",
31
+ "from inference import Inference, ready_gaussian_for_video_rendering, load_image, load_masks, display_image, make_scene, render_video, interactive_visualizer"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "PATH = os.getcwd()\n",
41
+ "TAG = \"hf\"\n",
42
+ "config_path = f\"{PATH}/../checkpoints/{TAG}/pipeline.yaml\"\n",
43
+ "inference = Inference(config_path, compile=False)"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "markdown",
48
+ "metadata": {},
49
+ "source": [
50
+ "## 2. Load input image to lift to 3D (multiple objects)"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "IMAGE_PATH = f\"{PATH}/images/shutterstock_stylish_kidsroom_1640806567/image.png\"\n",
60
+ "IMAGE_NAME = os.path.basename(os.path.dirname(IMAGE_PATH))\n",
61
+ "\n",
62
+ "image = load_image(IMAGE_PATH)\n",
63
+ "masks = load_masks(os.path.dirname(IMAGE_PATH), extension=\".png\")\n",
64
+ "display_image(image, masks)"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "markdown",
69
+ "metadata": {},
70
+ "source": [
71
+ "## 3. Generate Gaussian Splats"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": null,
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "outputs = [inference(image, mask, seed=42) for mask in masks]"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "markdown",
85
+ "metadata": {},
86
+ "source": [
87
+ "## 4. Visualize Gaussian Splat of the Scene\n",
88
+ "### a. Animated Gif"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "scene_gs = make_scene(*outputs)\n",
98
+ "scene_gs = ready_gaussian_for_video_rendering(scene_gs)\n",
99
+ "\n",
100
+ "# export gaussian splatting (as point cloud)\n",
101
+ "scene_gs.save_ply(f\"{PATH}/gaussians/multi/{IMAGE_NAME}.ply\")\n",
102
+ "\n",
103
+ "video = render_video(\n",
104
+ " scene_gs,\n",
105
+ " r=1,\n",
106
+ " fov=60,\n",
107
+ " resolution=512,\n",
108
+ ")[\"color\"]\n",
109
+ "\n",
110
+ "# save video as gif\n",
111
+ "imageio.mimsave(\n",
112
+ " os.path.join(f\"{PATH}/gaussians/multi/{IMAGE_NAME}.gif\"),\n",
113
+ " video,\n",
114
+ " format=\"GIF\",\n",
115
+ " duration=1000 / 30, # default assuming 30fps from the input MP4\n",
116
+ " loop=0, # 0 means loop indefinitely\n",
117
+ ")\n",
118
+ "\n",
119
+ "# notebook display\n",
120
+ "ImageDisplay(url=f\"gaussians/multi/{IMAGE_NAME}.gif?cache_invalidator={uuid.uuid4()}\",)"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "markdown",
125
+ "metadata": {},
126
+ "source": [
127
+ "### b. Interactive Visualizer"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": null,
133
+ "metadata": {},
134
+ "outputs": [],
135
+ "source": [
136
+ "# might take a while to load (black screen)\n",
137
+ "interactive_visualizer(f\"{PATH}/gaussians/multi/{IMAGE_NAME}.ply\")"
138
+ ]
139
+ }
140
+ ],
141
+ "metadata": {
142
+ "kernelspec": {
143
+ "display_name": "sam3d-objects",
144
+ "language": "python",
145
+ "name": "python3"
146
+ },
147
+ "language_info": {
148
+ "codemirror_mode": {
149
+ "name": "ipython",
150
+ "version": 3
151
+ },
152
+ "file_extension": ".py",
153
+ "mimetype": "text/x-python",
154
+ "name": "python",
155
+ "nbconvert_exporter": "python",
156
+ "pygments_lexer": "ipython3",
157
+ "version": "3.11.0"
158
+ }
159
+ },
160
+ "nbformat": 4,
161
+ "nbformat_minor": 2
162
+ }
thirdparty/sam3d/sam3d/notebook/demo_single_object.ipynb ADDED
@@ -0,0 +1,164 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "# Copyright (c) Meta Platforms, Inc. and affiliates."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "markdown",
14
+ "metadata": {},
15
+ "source": [
16
+ "## 1. Imports and Model Loading"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "import os\n",
26
+ "import imageio\n",
27
+ "import uuid\n",
28
+ "from IPython.display import Image as ImageDisplay\n",
29
+ "from inference import Inference, ready_gaussian_for_video_rendering, render_video, load_image, load_single_mask, display_image, make_scene, interactive_visualizer"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "PATH = os.getcwd()\n",
39
+ "TAG = \"hf\"\n",
40
+ "config_path = f\"{PATH}/../checkpoints/{TAG}/pipeline.yaml\"\n",
41
+ "inference = Inference(config_path, compile=False)"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "markdown",
46
+ "metadata": {},
47
+ "source": [
48
+ "## 2. Load input image to lift to 3D (single object)"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": null,
54
+ "metadata": {},
55
+ "outputs": [],
56
+ "source": [
57
+ "IMAGE_PATH = f\"{PATH}/images/shutterstock_stylish_kidsroom_1640806567/image.png\"\n",
58
+ "IMAGE_NAME = os.path.basename(os.path.dirname(IMAGE_PATH))\n",
59
+ "\n",
60
+ "image = load_image(IMAGE_PATH)\n",
61
+ "mask = load_single_mask(os.path.dirname(IMAGE_PATH), index=14)\n",
62
+ "display_image(image, masks=[mask])"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "markdown",
67
+ "metadata": {},
68
+ "source": [
69
+ "## 3. Generate Gaussian Splat"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": null,
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "# run model\n",
79
+ "output = inference(image, mask, seed=42)\n",
80
+ "\n",
81
+ "# export gaussian splat (as point cloud)\n",
82
+ "output[\"gs\"].save_ply(f\"{PATH}/gaussians/single/{IMAGE_NAME}.ply\")"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "markdown",
87
+ "metadata": {},
88
+ "source": [
89
+ "## 4. Visualize Gaussian Splat\n",
90
+ "### a. Animated Gif"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "# render gaussian splat\n",
100
+ "scene_gs = make_scene(output)\n",
101
+ "scene_gs = ready_gaussian_for_video_rendering(scene_gs)\n",
102
+ "\n",
103
+ "video = render_video(\n",
104
+ " scene_gs,\n",
105
+ " r=1,\n",
106
+ " fov=60,\n",
107
+ " pitch_deg=15,\n",
108
+ " yaw_start_deg=-45,\n",
109
+ " resolution=512,\n",
110
+ ")[\"color\"]\n",
111
+ "\n",
112
+ "# save video as gif\n",
113
+ "imageio.mimsave(\n",
114
+ " os.path.join(f\"{PATH}/gaussians/single/{IMAGE_NAME}.gif\"),\n",
115
+ " video,\n",
116
+ " format=\"GIF\",\n",
117
+ " duration=1000 / 30, # default assuming 30fps from the input MP4\n",
118
+ " loop=0, # 0 means loop indefinitely\n",
119
+ ")\n",
120
+ "\n",
121
+ "# notebook display\n",
122
+ "ImageDisplay(url=f\"gaussians/single/{IMAGE_NAME}.gif?cache_invalidator={uuid.uuid4()}\")"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "markdown",
127
+ "metadata": {},
128
+ "source": [
129
+ "### b. Interactive Visualizer"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "execution_count": null,
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
138
+ "# might take a while to load (black screen)\n",
139
+ "interactive_visualizer(f\"{PATH}/gaussians/single/{IMAGE_NAME}.ply\")"
140
+ ]
141
+ }
142
+ ],
143
+ "metadata": {
144
+ "kernelspec": {
145
+ "display_name": "sam3d_objects-3dfy",
146
+ "language": "python",
147
+ "name": "python3"
148
+ },
149
+ "language_info": {
150
+ "codemirror_mode": {
151
+ "name": "ipython",
152
+ "version": 3
153
+ },
154
+ "file_extension": ".py",
155
+ "mimetype": "text/x-python",
156
+ "name": "python",
157
+ "nbconvert_exporter": "python",
158
+ "pygments_lexer": "ipython3",
159
+ "version": "3.11.0"
160
+ }
161
+ },
162
+ "nbformat": 4,
163
+ "nbformat_minor": 2
164
+ }
thirdparty/sam3d/sam3d/notebook/inference.py ADDED
@@ -0,0 +1,414 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import os
3
+
4
+ # not ideal to put that here
5
+ os.environ["CUDA_HOME"] = os.environ["CONDA_PREFIX"]
6
+ os.environ["LIDRA_SKIP_INIT"] = "true"
7
+
8
+ import sys
9
+ from typing import Union, Optional, List, Callable
10
+ import numpy as np
11
+ from PIL import Image
12
+ from omegaconf import OmegaConf, DictConfig, ListConfig
13
+ from hydra.utils import instantiate, get_method
14
+ import torch
15
+ import math
16
+ import utils3d
17
+ import shutil
18
+ import subprocess
19
+ import seaborn as sns
20
+ from PIL import Image
21
+ import numpy as np
22
+ import gradio as gr
23
+ import matplotlib.pyplot as plt
24
+ from copy import deepcopy
25
+ from kaolin.visualize import IpyTurntableVisualizer
26
+ from kaolin.render.camera import Camera, CameraExtrinsics, PinholeIntrinsics
27
+ import builtins
28
+ from pytorch3d.transforms import quaternion_multiply, quaternion_invert
29
+
30
+ import sam3d_objects # REMARK(Pierre) : do not remove this import
31
+ from sam3d_objects.pipeline.inference_pipeline_pointmap import InferencePipelinePointMap
32
+ from sam3d_objects.model.backbone.tdfy_dit.utils import render_utils
33
+
34
+ from sam3d_objects.utils.visualization import SceneVisualizer
35
+
36
+ __all__ = ["Inference"]
37
+
38
+ WHITELIST_FILTERS = [
39
+ lambda target: target.split(".", 1)[0] in {"sam3d_objects", "torch", "torchvision", "moge"},
40
+ ]
41
+
42
+ BLACKLIST_FILTERS = [
43
+ lambda target: get_method(target)
44
+ in {
45
+ builtins.exec,
46
+ builtins.eval,
47
+ builtins.__import__,
48
+ os.kill,
49
+ os.system,
50
+ os.putenv,
51
+ os.remove,
52
+ os.removedirs,
53
+ os.rmdir,
54
+ os.fchdir,
55
+ os.setuid,
56
+ os.fork,
57
+ os.forkpty,
58
+ os.killpg,
59
+ os.rename,
60
+ os.renames,
61
+ os.truncate,
62
+ os.replace,
63
+ os.unlink,
64
+ os.fchmod,
65
+ os.fchown,
66
+ os.chmod,
67
+ os.chown,
68
+ os.chroot,
69
+ os.fchdir,
70
+ os.lchown,
71
+ os.getcwd,
72
+ os.chdir,
73
+ shutil.rmtree,
74
+ shutil.move,
75
+ shutil.chown,
76
+ subprocess.Popen,
77
+ builtins.help,
78
+ },
79
+ ]
80
+
81
+
82
+ class Inference:
83
+ # public facing inference API
84
+ # only put publicly exposed arguments here
85
+ def __init__(self, config_file: str, compile: bool = False):
86
+ # load inference pipeline
87
+ config = OmegaConf.load(config_file)
88
+ config.rendering_engine = "pytorch3d" # overwrite to disable nvdiffrast
89
+ config.compile_model = compile
90
+ config.workspace_dir = os.path.dirname(config_file)
91
+ check_hydra_safety(config, WHITELIST_FILTERS, BLACKLIST_FILTERS)
92
+ self._pipeline: InferencePipelinePointMap = instantiate(config)
93
+
94
+ def merge_mask_to_rgba(self, image, mask):
95
+ mask = mask.astype(np.uint8) * 255
96
+ mask = mask[..., None]
97
+ # embed mask in alpha channel
98
+ rgba_image = np.concatenate([image[..., :3], mask], axis=-1)
99
+ return rgba_image
100
+
101
+ def __call__(
102
+ self,
103
+ image: Union[Image.Image, np.ndarray],
104
+ mask: Optional[Union[None, Image.Image, np.ndarray]],
105
+ seed: Optional[int] = None,
106
+ pointmap=None,
107
+ ) -> dict:
108
+ image = self.merge_mask_to_rgba(image, mask)
109
+ return self._pipeline.run(
110
+ image,
111
+ None,
112
+ seed,
113
+ stage1_only=False,
114
+ with_mesh_postprocess=False,
115
+ with_texture_baking=False,
116
+ with_layout_postprocess=True,
117
+ use_vertex_color=True,
118
+ stage1_inference_steps=None,
119
+ pointmap=pointmap,
120
+ )
121
+
122
+
123
+ def _yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs):
124
+ is_list = isinstance(yaws, list)
125
+ if not is_list:
126
+ yaws = [yaws]
127
+ pitchs = [pitchs]
128
+ if not isinstance(rs, list):
129
+ rs = [rs] * len(yaws)
130
+ if not isinstance(fovs, list):
131
+ fovs = [fovs] * len(yaws)
132
+ extrinsics = []
133
+ intrinsics = []
134
+ for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs):
135
+ fov = torch.deg2rad(torch.tensor(float(fov))).cuda()
136
+ yaw = torch.tensor(float(yaw)).cuda()
137
+ pitch = torch.tensor(float(pitch)).cuda()
138
+ orig = (
139
+ torch.tensor(
140
+ [
141
+ torch.sin(yaw) * torch.cos(pitch),
142
+ torch.sin(pitch),
143
+ torch.cos(yaw) * torch.cos(pitch),
144
+ ]
145
+ ).cuda()
146
+ * r
147
+ )
148
+ extr = utils3d.torch.extrinsics_look_at(
149
+ orig,
150
+ torch.tensor([0, 0, 0]).float().cuda(),
151
+ torch.tensor([0, 1, 0]).float().cuda(),
152
+ )
153
+ intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
154
+ extrinsics.append(extr)
155
+ intrinsics.append(intr)
156
+ if not is_list:
157
+ extrinsics = extrinsics[0]
158
+ intrinsics = intrinsics[0]
159
+ return extrinsics, intrinsics
160
+
161
+
162
+ def render_video(
163
+ sample,
164
+ resolution=512,
165
+ bg_color=(0, 0, 0),
166
+ num_frames=300,
167
+ r=2.0,
168
+ fov=40,
169
+ pitch_deg=0,
170
+ yaw_start_deg=-90,
171
+ **kwargs,
172
+ ):
173
+
174
+ yaws = (
175
+ torch.linspace(0, 2 * torch.pi, num_frames) + math.radians(yaw_start_deg)
176
+ ).tolist()
177
+ pitch = [math.radians(pitch_deg)] * num_frames
178
+
179
+ extr, intr = _yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov)
180
+
181
+ return render_utils.render_frames(
182
+ sample,
183
+ extr,
184
+ intr,
185
+ {"resolution": resolution, "bg_color": bg_color, "backend": "gsplat"},
186
+ **kwargs,
187
+ )
188
+
189
+
190
+ def ready_gaussian_for_video_rendering(scene_gs, in_place=False, fix_alignment=False):
191
+ if fix_alignment:
192
+ scene_gs = _fix_gaussian_alignment(scene_gs, in_place=in_place)
193
+ scene_gs = normalized_gaussian(scene_gs, in_place=fix_alignment)
194
+ return scene_gs
195
+
196
+
197
+ def _fix_gaussian_alignment(scene_gs, in_place=False):
198
+ if not in_place:
199
+ scene_gs = deepcopy(scene_gs)
200
+
201
+ device = scene_gs._xyz.device
202
+ dtype = scene_gs._xyz.dtype
203
+ scene_gs._xyz = (
204
+ scene_gs._xyz
205
+ @ torch.tensor(
206
+ [
207
+ [-1, 0, 0],
208
+ [0, 0, 1],
209
+ [0, 1, 0],
210
+ ],
211
+ device=device,
212
+ dtype=dtype,
213
+ ).T
214
+ )
215
+ return scene_gs
216
+
217
+
218
+ def normalized_gaussian(scene_gs, in_place=False, outlier_percentile=None):
219
+ if not in_place:
220
+ scene_gs = deepcopy(scene_gs)
221
+
222
+ orig_xyz = scene_gs.get_xyz
223
+ orig_scale = scene_gs.get_scaling
224
+
225
+ active_mask = (scene_gs.get_opacity > 0.9).squeeze()
226
+ inv_scale = (
227
+ orig_xyz[active_mask].max(dim=0)[0] - orig_xyz[active_mask].min(dim=0)[0]
228
+ ).max()
229
+ norm_scale = orig_scale / inv_scale
230
+ norm_xyz = orig_xyz / inv_scale
231
+
232
+ if outlier_percentile is None:
233
+ lower_bound_xyz = torch.min(norm_xyz[active_mask], dim=0)[0]
234
+ upper_bound_xyz = torch.max(norm_xyz[active_mask], dim=0)[0]
235
+ else:
236
+ lower_bound_xyz = torch.quantile(
237
+ norm_xyz[active_mask],
238
+ outlier_percentile,
239
+ dim=0,
240
+ )
241
+ upper_bound_xyz = torch.quantile(
242
+ norm_xyz[active_mask],
243
+ 1.0 - outlier_percentile,
244
+ dim=0,
245
+ )
246
+
247
+ center = (lower_bound_xyz + upper_bound_xyz) / 2
248
+ norm_xyz = norm_xyz - center
249
+ scene_gs.from_xyz(norm_xyz)
250
+ scene_gs.mininum_kernel_size /= inv_scale.item()
251
+ scene_gs.from_scaling(norm_scale)
252
+ return scene_gs
253
+
254
+
255
+ def make_scene(*outputs, in_place=False):
256
+ if not in_place:
257
+ outputs = [deepcopy(output) for output in outputs]
258
+
259
+ all_outs = []
260
+ minimum_kernel_size = float("inf")
261
+ for output in outputs:
262
+ # move gaussians to scene frame of reference
263
+ PC = SceneVisualizer.object_pointcloud(
264
+ points_local=output["gaussian"][0].get_xyz.unsqueeze(0),
265
+ quat_l2c=output["rotation"],
266
+ trans_l2c=output["translation"],
267
+ scale_l2c=output["scale"],
268
+ )
269
+ output["gaussian"][0].from_xyz(PC.points_list()[0])
270
+ # must ... ROTATE
271
+ output["gaussian"][0].from_rotation(
272
+ quaternion_multiply(
273
+ quaternion_invert(output["rotation"]),
274
+ output["gaussian"][0].get_rotation,
275
+ )
276
+ )
277
+ scale = output["gaussian"][0].get_scaling
278
+ adjusted_scale = scale * output["scale"]
279
+ assert (
280
+ output["scale"][0, 0].item()
281
+ == output["scale"][0, 1].item()
282
+ == output["scale"][0, 2].item()
283
+ )
284
+ output["gaussian"][0].mininum_kernel_size *= output["scale"][0, 0].item()
285
+ adjusted_scale = torch.maximum(
286
+ adjusted_scale,
287
+ torch.tensor(
288
+ output["gaussian"][0].mininum_kernel_size * 1.1,
289
+ device=adjusted_scale.device,
290
+ ),
291
+ )
292
+ output["gaussian"][0].from_scaling(adjusted_scale)
293
+ minimum_kernel_size = min(
294
+ minimum_kernel_size,
295
+ output["gaussian"][0].mininum_kernel_size,
296
+ )
297
+ all_outs.append(output)
298
+
299
+ # merge gaussians
300
+ scene_gs = all_outs[0]["gaussian"][0]
301
+ scene_gs.mininum_kernel_size = minimum_kernel_size
302
+ for out in all_outs[1:]:
303
+ out_gs = out["gaussian"][0]
304
+ scene_gs._xyz = torch.cat([scene_gs._xyz, out_gs._xyz], dim=0)
305
+ scene_gs._features_dc = torch.cat(
306
+ [scene_gs._features_dc, out_gs._features_dc], dim=0
307
+ )
308
+ scene_gs._scaling = torch.cat([scene_gs._scaling, out_gs._scaling], dim=0)
309
+ scene_gs._rotation = torch.cat([scene_gs._rotation, out_gs._rotation], dim=0)
310
+ scene_gs._opacity = torch.cat([scene_gs._opacity, out_gs._opacity], dim=0)
311
+
312
+ return scene_gs
313
+
314
+
315
+ def check_target(
316
+ target: str,
317
+ whitelist_filters: List[Callable],
318
+ blacklist_filters: List[Callable],
319
+ ):
320
+ if any(filt(target) for filt in whitelist_filters):
321
+ if not any(filt(target) for filt in blacklist_filters):
322
+ return
323
+ raise RuntimeError(
324
+ f"target '{target}' is not allowed to be instantiated by hydra; if this is a mistake, please modify the whitelist_filters / blacklist_filters"
325
+ )
326
+
327
+
328
+ def check_hydra_safety(
329
+ config: DictConfig,
330
+ whitelist_filters: List[Callable],
331
+ blacklist_filters: List[Callable],
332
+ ):
333
+ to_check = [config]
334
+ while len(to_check) > 0:
335
+ node = to_check.pop()
336
+ if isinstance(node, DictConfig):
337
+ to_check.extend(list(node.values()))
338
+ if "_target_" in node:
339
+ check_target(node["_target_"], whitelist_filters, blacklist_filters)
340
+ elif isinstance(node, ListConfig):
341
+ to_check.extend(list(node))
342
+
343
+
344
+ def load_image(path):
345
+ image = Image.open(path)
346
+ image = np.array(image)
347
+ image = image.astype(np.uint8)
348
+ return image
349
+
350
+
351
+ def load_mask(path):
352
+ mask = load_image(path)
353
+ mask = mask > 0
354
+ if mask.ndim == 3:
355
+ mask = mask[..., -1]
356
+ return mask
357
+
358
+
359
+ def load_single_mask(folder_path, index=0, extension=".png"):
360
+ masks = load_masks(folder_path, [index], extension)
361
+ return masks[0]
362
+
363
+
364
+ def load_masks(folder_path, indices_list=None, extension=".png"):
365
+ masks = []
366
+ indices_list = [] if indices_list is None else list(indices_list)
367
+ if not len(indices_list) > 0: # get all masks if not provided
368
+ idx = 0
369
+ while os.path.exists(os.path.join(folder_path, f"{idx}{extension}")):
370
+ indices_list.append(idx)
371
+ idx += 1
372
+
373
+ for idx in indices_list:
374
+ mask_path = os.path.join(folder_path, f"{idx}{extension}")
375
+ assert os.path.exists(mask_path), f"Mask path {mask_path} does not exist"
376
+ mask = load_mask(mask_path)
377
+ masks.append(mask)
378
+ return masks
379
+
380
+
381
+ def display_image(image, masks=None):
382
+ def imshow(image, ax):
383
+ ax.axis("off")
384
+ ax.imshow(image)
385
+
386
+ grid = (1, 1) if masks is None else (2, 2)
387
+ fig, axes = plt.subplots(*grid)
388
+ if masks is not None:
389
+ mask_colors = sns.color_palette("husl", len(masks))
390
+ black_image = np.zeros_like(image[..., :3], dtype=float) # background
391
+ mask_display = np.copy(black_image)
392
+ mask_union = np.zeros_like(image[..., :3])
393
+ for i, mask in enumerate(masks):
394
+ mask_display[mask] = mask_colors[i]
395
+ mask_union |= mask[..., None] if mask.ndim == 2 else mask
396
+ imshow(black_image, axes[0, 1])
397
+ imshow(mask_display, axes[1, 0])
398
+ imshow(image * mask_union, axes[1, 1])
399
+
400
+ image_axe = axes if masks is None else axes[0, 0]
401
+ imshow(image, image_axe)
402
+
403
+ fig.tight_layout(pad=0)
404
+ fig.show()
405
+
406
+
407
+ def interactive_visualizer(ply_path):
408
+ with gr.Blocks() as demo:
409
+ gr.Markdown("# 3D Gaussian Splatting (black-screen loading might take a while)")
410
+ gr.Model3D(
411
+ value=ply_path, # splat file
412
+ label="3D Scene",
413
+ )
414
+ demo.launch(share=True)
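The helpers above form the full single-image flow used by the demo notebooks: build an Inference pipeline from a config, run it on an RGB image plus mask, merge the per-object outputs into one Gaussian scene, and render a turntable. A minimal sketch follows; the config and asset paths are placeholders, and it assumes the pipeline output dict carries the gaussian/rotation/translation/scale keys that make_scene() reads.

# Minimal usage sketch (hypothetical paths; assumed to run from the notebook/ directory).
from inference import (
    Inference,
    load_image,
    load_single_mask,
    make_scene,
    ready_gaussian_for_video_rendering,
    render_video,
)

pipe = Inference("checkpoints/pipeline.yaml")             # placeholder config path
image = load_image("examples/chair/image.png")            # placeholder RGB image
mask = load_single_mask("examples/chair/masks", index=0)  # placeholder mask folder

output = pipe(image, mask, seed=42)                       # assumed dict output of the pipeline
scene_gs = make_scene(output)                             # move object gaussians into the scene frame
scene_gs = ready_gaussian_for_video_rendering(scene_gs)   # recenter/normalize for turntable rendering
frames = render_video(scene_gs, num_frames=120)           # rendered frames via render_utils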
thirdparty/sam3d/sam3d/notebook/mesh_alignment.py ADDED
@@ -0,0 +1,469 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ """
3
+ SAM 3D Body (3DB) Mesh Alignment Utilities
4
+ Aligns SAM 3D Body (3DB) meshes to the same scale as the SAM 3D Objects output, i.e. the MoGe point cloud scale.
5
+ """
6
+
7
+ import os
8
+ import math
9
+ import json
10
+ import numpy as np
11
+ import torch
12
+ import trimesh
13
+ from PIL import Image
14
+ import torch.nn.functional as F
15
+ from pytorch3d.structures import Meshes
16
+ from pytorch3d.renderer import PerspectiveCameras, RasterizationSettings, MeshRasterizer, TexturesVertex
17
+ from moge.model.v1 import MoGeModel
18
+
19
+
20
+ def load_3db_mesh(mesh_path, device='cuda'):
21
+ """Load 3DB mesh and convert from OpenGL to PyTorch3D coordinates."""
22
+ mesh = trimesh.load(mesh_path)
23
+ vertices = np.array(mesh.vertices)
24
+ faces = np.array(mesh.faces)
25
+
26
+ # Convert from OpenGL to PyTorch3D coordinates
27
+ vertices[:, 0] *= -1 # Flip X
28
+ vertices[:, 2] *= -1 # Flip Z
29
+
30
+ vertices = torch.from_numpy(vertices).float().to(device)
31
+ faces = torch.from_numpy(faces).long().to(device)
32
+ return vertices, faces
33
+
34
+
35
+ def get_moge_pointcloud(image_tensor, device='cuda'):
36
+ """Generate MoGe point cloud from image tensor."""
37
+ moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device)
38
+ moge_model.eval()
39
+ with torch.no_grad():
40
+ moge_output = moge_model.infer(image_tensor)
41
+ return moge_output
42
+
43
+
44
+ def denormalize_intrinsics(norm_K, height, width):
45
+ """Convert normalized intrinsics to absolute pixel coordinates."""
46
+ cx_norm, cy_norm = norm_K[0, 2], norm_K[1, 2]
47
+ fx_norm, fy_norm = norm_K[0, 0], norm_K[1, 1]
48
+
49
+ fx_abs = fx_norm * width
50
+ fy_abs = fy_norm * height
51
+ cx_abs = cx_norm * width
52
+ cy_abs = cy_norm * height
53
+ fx_abs = fy_abs # assume square pixels: use the vertical focal length for both axes
54
+
55
+ return np.array([
56
+ [fx_abs, 0.0, cx_abs],
57
+ [0.0, fy_abs, cy_abs],
58
+ [0.0, 0.0, 1.0]
59
+ ])
60
+
61
+
62
+ def crop_mesh_with_mask(vertices, faces, focal_length, mask, device='cuda'):
63
+ """Crop mesh vertices to only those visible in the mask."""
64
+ textures = TexturesVertex(verts_features=torch.ones_like(vertices)[None])
65
+ mesh = Meshes(verts=[vertices], faces=[faces], textures=textures)
66
+
67
+ H, W = mask.shape[-2:]
68
+ fx = fy = focal_length
69
+ cx, cy = W / 2.0, H / 2.0
70
+
71
+ camera = PerspectiveCameras(
72
+ focal_length=((fx, fy),),
73
+ principal_point=((cx, cy),),
74
+ image_size=((H, W),),
75
+ in_ndc=False, device=device
76
+ )
77
+
78
+ raster_settings = RasterizationSettings(
79
+ image_size=(H, W), blur_radius=0.0, faces_per_pixel=1,
80
+ cull_backfaces=False, bin_size=0,
81
+ )
82
+
83
+ rasterizer = MeshRasterizer(cameras=camera, raster_settings=raster_settings)
84
+ fragments = rasterizer(mesh)
85
+
86
+ face_indices = fragments.pix_to_face[0, ..., 0] # (H, W)
87
+ visible_mask = (mask > 0) & (face_indices >= 0)
88
+ visible_face_ids = face_indices[visible_mask]
89
+
90
+ visible_faces = faces[visible_face_ids]
91
+ visible_vert_ids = torch.unique(visible_faces)
92
+ verts_cropped = vertices[visible_vert_ids]
93
+
94
+ return verts_cropped, visible_mask
95
+
96
+
97
+ def extract_target_points(pointmap, visible_mask):
98
+ """Extract target points from MoGe pointmap using visible mask."""
99
+ target_points = pointmap[visible_mask.bool()]
100
+
101
+ # Convert from MoGe coordinates to PyTorch3D coordinates
102
+ target_points[:, 0] *= -1
103
+ target_points[:, 1] *= -1
104
+
105
+ # Remove flying points using adaptive quantile filtering
106
+ z_range = torch.max(target_points[:, 2]) - torch.min(target_points[:, 2])
107
+ if z_range > 6.0:
108
+ thresh = 0.90
109
+ elif z_range > 2.0:
110
+ thresh = 0.93
111
+ else:
112
+ thresh = 0.95
113
+
114
+ depth_quantile = torch.quantile(target_points[:, 2], thresh)
115
+ target_points = target_points[target_points[:, 2] <= depth_quantile]
116
+
117
+ # Remove infinite values
118
+ finite_mask = torch.isfinite(target_points).all(dim=1)
119
+ target_points = target_points[finite_mask]
120
+
121
+ return target_points
122
+
123
+
124
+ def align_mesh_to_pointcloud(vertices, target_points):
125
+ """Align mesh vertices to target point cloud using scale and translation."""
126
+ if target_points.shape[0] == 0:
127
+ print("[WARNING] No target points for alignment!")
128
+ return vertices, torch.tensor(1.0), torch.zeros(3)
129
+
130
+ # Scale alignment based on height
131
+ height_src = torch.max(vertices[:, 1]) - torch.min(vertices[:, 1])
132
+ height_tgt = torch.max(target_points[:, 1]) - torch.min(target_points[:, 1])
133
+ scale_factor = height_tgt / height_src
134
+
135
+ vertices_scaled = vertices * scale_factor
136
+
137
+ # Translation alignment based on centers
138
+ center_src = torch.mean(vertices_scaled, dim=0)
139
+ center_tgt = torch.mean(target_points, dim=0)
140
+ translation = center_tgt - center_src
141
+
142
+ vertices_aligned = vertices_scaled + translation
143
+ return vertices_aligned, scale_factor, translation
144
+
145
+
146
+ def load_mask_for_alignment(mask_path):
147
+ """Load mask image as numpy array."""
148
+ mask = Image.open(mask_path).convert('L')
149
+ mask_array = np.array(mask) / 255.0
150
+ return mask_array
151
+
152
+
153
+ def load_focal_length_from_json(json_path):
154
+ """Load focal length from JSON file."""
155
+ try:
156
+ with open(json_path, 'r') as f:
157
+ data = json.load(f)
158
+ focal_length = data.get('focal_length')
159
+ if focal_length is None:
160
+ raise ValueError("'focal_length' key not found in JSON file")
161
+ print(f"[INFO] Loaded focal length from {json_path}: {focal_length}")
162
+ return focal_length
163
+ except Exception as e:
164
+ print(f"[ERROR] Failed to load focal length from {json_path}: {e}")
165
+ raise
166
+
167
+
168
+ def process_3db_alignment(mesh_path, mask_path, image_path, device='cuda', focal_length_json_path=None):
169
+ """Complete pipeline for aligning 3DB mesh to MoGe scale."""
170
+ print(f"[INFO] Processing alignment...")
171
+
172
+ # Load input data
173
+ vertices, faces = load_3db_mesh(mesh_path, device)
174
+
175
+ # Load and preprocess image
176
+ image = Image.open(image_path).convert('RGB')
177
+ image_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1) / 255.0
178
+ image_tensor = image_tensor.to(device)
179
+
180
+ # Load mask and resize to match image
181
+ H, W = image_tensor.shape[1:]
182
+ mask = load_mask_for_alignment(mask_path)
183
+ if mask.shape != (H, W):
184
+ mask = Image.fromarray((mask * 255).astype(np.uint8))
185
+ mask = mask.resize((W, H), Image.NEAREST)
186
+ mask = np.array(mask) / 255.0
187
+ mask = torch.from_numpy(mask).float().to(device)
188
+
189
+ # Generate MoGe point cloud
190
+ print("[INFO] Generating MoGe point cloud...")
191
+ moge_output = get_moge_pointcloud(image_tensor, device)
192
+
193
+ # Load focal length from JSON if provided, otherwise compute from MoGe intrinsics
194
+ if focal_length_json_path is not None:
195
+ focal_length = load_focal_length_from_json(focal_length_json_path)
196
+ else:
197
+ # Compute camera parameters from MoGe intrinsics (fallback)
198
+ intrinsics = denormalize_intrinsics(moge_output['intrinsics'].cpu().numpy(), H, W)
199
+ focal_length = intrinsics[1, 1] # Use fy
200
+ print(f"[INFO] Using computed focal length from MoGe: {focal_length}")
201
+
202
+ # Crop mesh using mask
203
+ print("[INFO] Cropping mesh with mask...")
204
+ verts_cropped, visible_mask = crop_mesh_with_mask(vertices, faces, focal_length, mask, device)
205
+
206
+ # Extract target points from MoGe
207
+ print("[INFO] Extracting target points...")
208
+ target_points = extract_target_points(moge_output['points'], visible_mask)
209
+
210
+ if target_points.shape[0] == 0:
211
+ print("[ERROR] No valid target points found!")
212
+ return None
213
+
214
+ # Perform alignment
215
+ print("[INFO] Aligning mesh to point cloud...")
216
+ aligned_vertices, scale_factor, translation = align_mesh_to_pointcloud(verts_cropped, target_points)
217
+
218
+ # Apply alignment to full mesh
219
+ full_aligned_vertices = (vertices * scale_factor) + translation
220
+
221
+ # Convert back to OpenGL coordinates for final output
222
+ final_vertices_opengl = full_aligned_vertices.cpu().numpy()
223
+ final_vertices_opengl[:, 0] *= -1
224
+ final_vertices_opengl[:, 2] *= -1
225
+
226
+ results = {
227
+ 'aligned_vertices_opengl': final_vertices_opengl,
228
+ 'faces': faces.cpu().numpy(),
229
+ 'scale_factor': scale_factor.item(),
230
+ 'translation': translation.cpu().numpy(),
231
+ 'focal_length': focal_length,
232
+ 'target_points_count': target_points.shape[0],
233
+ 'cropped_vertices_count': verts_cropped.shape[0]
234
+ }
235
+
236
+ print(f"[INFO] Alignment completed - Scale: {scale_factor.item():.4f}, Target points: {target_points.shape[0]}")
237
+ return results
238
+
239
+
240
+ def process_and_save_alignment(mesh_path, mask_path, image_path, output_dir, device='cuda', focal_length_json_path=None):
241
+ """
242
+ Complete pipeline for processing 3DB alignment and saving the result.
243
+
244
+ Args:
245
+ mesh_path: Path to input 3DB mesh (.ply)
246
+ mask_path: Path to mask image (.png)
247
+ image_path: Path to input image (.jpg)
248
+ output_dir: Directory to save aligned mesh
249
+ device: Device to use ('cuda' or 'cpu')
250
+ focal_length_json_path: Optional path to focal length JSON file
251
+
252
+ Returns:
253
+ tuple: (success: bool, output_mesh_path: str or None, result_info: dict or None)
254
+ """
255
+ try:
256
+ print("[INFO] Starting 3DB mesh alignment pipeline...")
257
+
258
+ # Ensure output directory exists
259
+ os.makedirs(output_dir, exist_ok=True)
260
+
261
+ # Process alignment
262
+ result = process_3db_alignment(
263
+ mesh_path=mesh_path,
264
+ mask_path=mask_path,
265
+ image_path=image_path,
266
+ device=device,
267
+ focal_length_json_path=focal_length_json_path
268
+ )
269
+
270
+ if result is not None:
271
+ # Save aligned mesh
272
+ output_mesh_path = os.path.join(output_dir, 'human_aligned.ply')
273
+ aligned_mesh = trimesh.Trimesh(
274
+ vertices=result['aligned_vertices_opengl'],
275
+ faces=result['faces']
276
+ )
277
+ aligned_mesh.export(output_mesh_path)
278
+
279
+ print(f" SUCCESS! Saved aligned mesh to: {output_mesh_path}")
280
+ return True, output_mesh_path, result
281
+ else:
282
+ print(" ERROR: Failed to process mesh alignment")
283
+ return False, None, None
284
+
285
+ except Exception as e:
286
+ print(f" ERROR: Exception during processing: {e}")
287
+ import traceback
288
+ traceback.print_exc()
289
+ return False, None, None
290
+
291
+ finally:
292
+ print(" Processing complete!")
293
+
294
+
295
+ def visualize_meshes_interactive(aligned_mesh_path, dfy_mesh_path, output_dir=None, share=True, height=600):
296
+ """
297
+ Interactive Gradio-based 3D visualization of aligned human and object meshes.
298
+
299
+ Args:
300
+ aligned_mesh_path: Path to aligned mesh PLY file
301
+ dfy_mesh_path: Path to 3Dfy GLB file
302
+ output_dir: Directory to save combined GLB file (defaults to same dir as aligned_mesh_path)
303
+ share: Whether to create a public shareable link (default: True)
304
+ height: Height of the 3D viewer in pixels (default: 600)
305
+
306
+ Returns:
307
+ tuple: (demo, combined_glb_path) - Gradio demo object and path to combined GLB file
308
+ """
309
+ import gradio as gr
310
+
311
+ print("Loading meshes for interactive visualization...")
312
+
313
+ try:
314
+ # Load aligned mesh (PLY)
315
+ aligned_mesh = trimesh.load(aligned_mesh_path)
316
+ print(f"Loaded aligned mesh: {len(aligned_mesh.vertices)} vertices")
317
+
318
+ # Load 3Dfy mesh (GLB - handle scene structure)
319
+ dfy_scene = trimesh.load(dfy_mesh_path)
320
+
321
+ if hasattr(dfy_scene, 'dump'): # It's a scene
322
+ dfy_meshes = [geom for geom in dfy_scene.geometry.values() if hasattr(geom, 'vertices')]
323
+ if len(dfy_meshes) == 1:
324
+ dfy_mesh = dfy_meshes[0]
325
+ elif len(dfy_meshes) > 1:
326
+ dfy_mesh = trimesh.util.concatenate(dfy_meshes)
327
+ else:
328
+ raise ValueError("No valid meshes in GLB file")
329
+ else:
330
+ dfy_mesh = dfy_scene
331
+
332
+ print(f"Loaded 3Dfy mesh: {len(dfy_mesh.vertices)} vertices")
333
+
334
+ # Create combined scene
335
+ scene = trimesh.Scene()
336
+
337
+ # Add both meshes with different colors
338
+ aligned_copy = aligned_mesh.copy()
339
+ aligned_copy.visual.vertex_colors = [255, 0, 0, 200] # Red for aligned human
340
+ scene.add_geometry(aligned_copy, node_name="sam3d_aligned_human")
341
+
342
+ dfy_copy = dfy_mesh.copy()
343
+ dfy_copy.visual.vertex_colors = [0, 0, 255, 200] # Blue for 3Dfy object
344
+ scene.add_geometry(dfy_copy, node_name="dfy_object")
345
+
346
+ # Determine output path
347
+ if output_dir is None:
348
+ output_dir = os.path.dirname(aligned_mesh_path)
349
+ os.makedirs(output_dir, exist_ok=True)
350
+
351
+ combined_glb_path = os.path.join(output_dir, 'combined_scene.glb')
352
+ scene.export(combined_glb_path)
353
+ print(f"Exported combined scene to: {combined_glb_path}")
354
+
355
+ # Create interactive Gradio viewer
356
+ with gr.Blocks() as demo:
357
+ gr.Markdown("# 3D Mesh Alignment Visualization")
358
+ gr.Markdown("**Red**: SAM 3D Body Aligned Human | **Blue**: 3Dfy Object")
359
+ gr.Model3D(
360
+ value=combined_glb_path,
361
+ label="Combined 3D Scene (Interactive)",
362
+ height=height
363
+ )
364
+
365
+ # Launch the viewer
366
+ print("Launching interactive 3D viewer...")
367
+ demo.launch(share=share)
368
+
369
+ return demo, combined_glb_path
370
+
371
+ except Exception as e:
372
+ print(f"ERROR in visualization: {e}")
373
+ import traceback
374
+ traceback.print_exc()
375
+ return None, None
376
+
377
+
378
+ def visualize_meshes_comparison(aligned_mesh_path, dfy_mesh_path, use_interactive=False):
379
+ """
380
+ Simple visualization of both meshes in a single 3D plot.
381
+
382
+ DEPRECATED: Use visualize_meshes_interactive() for better interactive visualization.
383
+
384
+ Args:
385
+ aligned_mesh_path: Path to aligned mesh PLY file
386
+ dfy_mesh_path: Path to 3Dfy GLB file
387
+ use_interactive: Whether to attempt trimesh scene viewer (default: False)
388
+
389
+ Returns:
390
+ tuple: (aligned_mesh, dfy_mesh) trimesh objects or (None, None) if failed
391
+ """
392
+ import matplotlib.pyplot as plt
393
+
394
+ print("Loading meshes for visualization...")
395
+
396
+ try:
397
+ # Load aligned mesh (PLY)
398
+ aligned_mesh = trimesh.load(aligned_mesh_path)
399
+ print(f"Loaded aligned mesh: {len(aligned_mesh.vertices)} vertices")
400
+
401
+ # Load 3Dfy mesh (GLB - handle scene structure)
402
+ dfy_scene = trimesh.load(dfy_mesh_path)
403
+
404
+ if hasattr(dfy_scene, 'dump'): # It's a scene
405
+ dfy_meshes = [geom for geom in dfy_scene.geometry.values() if hasattr(geom, 'vertices')]
406
+ if len(dfy_meshes) == 1:
407
+ dfy_mesh = dfy_meshes[0]
408
+ elif len(dfy_meshes) > 1:
409
+ dfy_mesh = trimesh.util.concatenate(dfy_meshes)
410
+ else:
411
+ raise ValueError("No valid meshes in GLB file")
412
+ else:
413
+ dfy_mesh = dfy_scene
414
+
415
+ print(f"Loaded 3Dfy mesh: {len(dfy_mesh.vertices)} vertices")
416
+
417
+ # Create single 3D plot with both meshes
418
+ fig = plt.figure(figsize=(12, 10))
419
+ ax = fig.add_subplot(111, projection='3d')
420
+
421
+ # Plot both meshes in the same space
422
+ ax.scatter(dfy_mesh.vertices[:, 0],
423
+ dfy_mesh.vertices[:, 1],
424
+ dfy_mesh.vertices[:, 2],
425
+ c='blue', s=0.1, alpha=0.6, label='3Dfy Original')
426
+
427
+ ax.scatter(aligned_mesh.vertices[:, 0],
428
+ aligned_mesh.vertices[:, 1],
429
+ aligned_mesh.vertices[:, 2],
430
+ c='red', s=0.1, alpha=0.6, label='SAM 3D Body Aligned')
431
+
432
+ ax.set_title('Mesh Comparison: 3Dfy vs SAM 3D Body Aligned', fontsize=16, fontweight='bold')
433
+ ax.set_xlabel('X')
434
+ ax.set_ylabel('Y')
435
+ ax.set_zlabel('Z')
436
+ ax.legend()
437
+
438
+ plt.tight_layout()
439
+ plt.show()
440
+
441
+ # Optional trimesh scene viewer
442
+ if use_interactive:
443
+ try:
444
+ print("Creating trimesh scene...")
445
+ scene = trimesh.Scene()
446
+
447
+ # Add both meshes with different colors
448
+ aligned_copy = aligned_mesh.copy()
449
+ aligned_copy.visual.vertex_colors = [255, 0, 0, 200] # Red
450
+ scene.add_geometry(aligned_copy, node_name="sam3d_aligned")
451
+
452
+ dfy_copy = dfy_mesh.copy()
453
+ dfy_copy.visual.vertex_colors = [0, 0, 255, 200] # Blue
454
+ scene.add_geometry(dfy_copy, node_name="dfy_original")
455
+
456
+ print("Opening interactive trimesh viewer...")
457
+ scene.show()
458
+
459
+ except Exception as e:
460
+ print(f"Trimesh viewer not available: {e}")
461
+
462
+ print("Visualization complete")
463
+ return aligned_mesh, dfy_mesh
464
+
465
+ except Exception as e:
466
+ print(f"ERROR in visualization: {e}")
467
+ import traceback
468
+ traceback.print_exc()
469
+ return None, None
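For context, a compact end-to-end call of the alignment utilities above might look like the sketch below; every path is a placeholder, and it assumes a 3DB mesh, a person mask, and a 3Dfy GLB have already been produced by the other demos.

# Minimal usage sketch (hypothetical paths; assumed to run from the notebook/ directory).
from mesh_alignment import process_and_save_alignment, visualize_meshes_interactive

ok, aligned_path, info = process_and_save_alignment(
    mesh_path="outputs/human.ply",          # placeholder 3DB mesh
    mask_path="outputs/human_mask.png",     # placeholder person mask
    image_path="examples/scene.jpg",        # placeholder input image
    output_dir="outputs/aligned",
    focal_length_json_path=None,            # fall back to MoGe intrinsics
)
if ok:
    demo, combined_glb = visualize_meshes_interactive(
        aligned_mesh_path=aligned_path,
        dfy_mesh_path="outputs/object.glb",  # placeholder 3Dfy GLB
        share=False,
    )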
thirdparty/sam3d/sam3d/patching/hydra ADDED
@@ -0,0 +1,16 @@
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import hydra
5
+ import urllib.request
6
+
7
+ if hydra.__version__ != "1.3.2":
8
+ raise RuntimeError("different hydra version has been found, cannot patch")
9
+
10
+ hydra_root = os.path.dirname(hydra.__file__)
11
+ utils_path = os.path.join(hydra_root, "core", "utils.py")
12
+
13
+ urllib.request.urlretrieve(
14
+ "https://raw.githubusercontent.com/gleize/hydra/78f00766b5f37672aa7232ebbf01bdd74246bd60/hydra/core/utils.py",
15
+ utils_path,
16
+ )
thirdparty/sam3d/sam3d/pyproject.toml ADDED
@@ -0,0 +1,30 @@
1
+ [build-system]
2
+ requires = ["hatchling", "hatch-requirements-txt"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [tool.hatch.envs.default.env-vars]
6
+ PIP_EXTRA_INDEX_URL = "https://pypi.ngc.nvidia.com https://download.pytorch.org/whl/cu121"
7
+
8
+ [tool.hatch.metadata]
9
+ # for git-referenced dependencies
10
+ allow-direct-references = true
11
+
12
+ [project]
13
+ name = "sam3d_objects"
14
+ version = "0.0.1"
15
+ # required for "hatch-requirements-txt" to work
16
+ dynamic = ["dependencies", "optional-dependencies"]
17
+
18
+ [tool.hatch.build]
19
+ ignore-vcs = true
20
+ include = ["**/*.py"]
21
+ exclude = ["conftest.py", "*_test.py"]
22
+ packages = ["sam3d_objects"]
23
+
24
+ [tool.hatch.metadata.hooks.requirements_txt]
25
+ files = ["requirements.txt"]
26
+
27
+ [tool.hatch.metadata.hooks.requirements_txt.optional-dependencies]
28
+ p3d = ["requirements.p3d.txt"]
29
+ inference = ["requirements.inference.txt"]
30
+ dev = ["requirements.dev.txt"]
thirdparty/sam3d/sam3d/requirements.dev.txt ADDED
@@ -0,0 +1,4 @@
1
+ pytest
2
+ findpydeps
3
+ pipdeptree
4
+ lovely_tensors
thirdparty/sam3d/sam3d/requirements.inference.txt ADDED
@@ -0,0 +1,4 @@
1
+ kaolin==0.17.0
2
+ gsplat @ git+https://github.com/nerfstudio-project/gsplat.git@2323de5905d5e90e035f792fe65bad0fedd413e7
3
+ seaborn==0.13.2
4
+ gradio==5.49.0
thirdparty/sam3d/sam3d/requirements.p3d.txt ADDED
@@ -0,0 +1,2 @@
1
+ pytorch3d @ git+https://github.com/facebookresearch/pytorch3d.git@75ebeeaea0908c5527e7b1e305fbc7681382db47
2
+ flash_attn==2.8.3
thirdparty/sam3d/sam3d/requirements.txt ADDED
@@ -0,0 +1,88 @@
1
+ astor==0.8.1
2
+ async-timeout==4.0.3
3
+ auto_gptq==0.7.1
4
+ autoflake==2.3.1
5
+ av==12.0.0
6
+ bitsandbytes==0.43.0
7
+ black==24.3.0
8
+ bpy==4.3.0
9
+ colorama==0.4.6
10
+ conda-pack==0.7.1
11
+ crcmod==1.7
12
+ cuda-python==12.1.0
13
+ dataclasses==0.6
14
+ decord==0.6.0
15
+ deprecation==2.1.0
16
+ easydict==1.13
17
+ einops-exts==0.0.4
18
+ exceptiongroup==1.2.0
19
+ fastavro==1.9.4
20
+ fasteners==0.19
21
+ flake8==7.0.0
22
+ Flask==3.0.3
23
+ fqdn==1.5.1
24
+ ftfy==6.2.0
25
+ fvcore==0.1.5.post20221221
26
+ gdown==5.2.0
27
+ h5py==3.12.1
28
+ hdfs==2.7.3
29
+ httplib2==0.22.0
30
+ hydra-core==1.3.2
31
+ hydra-submitit-launcher==1.2.0
32
+ igraph==0.11.8
33
+ imath==0.0.2
34
+ isoduration==20.11.0
35
+ jsonlines==4.0.0
36
+ jsonpickle==3.0.4
37
+ jsonpointer==2.4
38
+ jupyter==1.1.1
39
+ librosa==0.10.1
40
+ lightning==2.3.3
41
+ loguru==0.7.2
42
+ mosaicml-streaming==0.7.5
43
+ nvidia-cuda-nvcc-cu12==12.1.105
44
+ nvidia-pyindex==1.0.9
45
+ objsize==0.7.0
46
+ open3d==0.18.0
47
+ opencv-python==4.9.0.80
48
+ OpenEXR==3.3.3
49
+ optimum==1.18.1
50
+ optree==0.14.1
51
+ orjson==3.10.0
52
+ panda3d-gltf==1.2.1
53
+ pdoc3==0.10.0
54
+ peft==0.10.0
55
+ pip-system-certs==4.0
56
+ point-cloud-utils==0.29.5
57
+ polyscope==2.3.0
58
+ pycocotools==2.0.7
59
+ pydot==1.4.2
60
+ pymeshfix==0.17.0
61
+ pymongo==4.6.3
62
+ pyrender==0.1.45
63
+ PySocks==1.7.1
64
+ pytest==8.1.1
65
+ python-pycg==0.9.2
66
+ randomname==0.2.1
67
+ roma==1.5.1
68
+ rootutils==1.0.7
69
+ Rtree==1.3.0
70
+ sagemaker==2.242.0
71
+ scikit-image==0.23.1
72
+ sentence-transformers==2.6.1
73
+ simplejson==3.19.2
74
+ smplx==0.1.28
75
+ spconv-cu121==2.3.8
76
+ tensorboard==2.16.2
77
+ timm==0.9.16
78
+ tomli==2.0.1
79
+ torchaudio==2.5.1+cu121
80
+ uri-template==1.3.0
81
+ usort==1.0.8.post1
82
+ wandb==0.20.0
83
+ webcolors==1.13
84
+ webdataset==0.2.86
85
+ Werkzeug==3.0.6
86
+ xatlas==0.0.9
87
+ xformers==0.0.28.post3
88
+ MoGe @ git+https://github.com/microsoft/MoGe.git@a8c37341bc0325ca99b9d57981cc3bb2bd3e255b
thirdparty/sam3d/sam3d/sam3d_objects/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import os
3
+
4
+ # Allow skipping initialization for lightweight tools
5
+ if not os.environ.get('LIDRA_SKIP_INIT'):
6
+ import sam3d_objects.init
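The LIDRA_SKIP_INIT gate above is what notebook/inference.py relies on to keep imports lightweight; the sketch below shows the assumed pattern of setting it before the first import so sam3d_objects.init is never pulled in.

# Minimal sketch of the lightweight-import path.
import os

os.environ["LIDRA_SKIP_INIT"] = "true"  # must be set before the first import
import sam3d_objects                    # heavy initialization is skipped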
thirdparty/sam3d/sam3d/sam3d_objects/config/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
thirdparty/sam3d/sam3d/sam3d_objects/config/utils.py ADDED
@@ -0,0 +1,174 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import functools
3
+ from typing import Any, Callable, Union
4
+
5
+ from omegaconf import DictConfig, ListConfig, OmegaConf
6
+ from hydra.utils import instantiate
7
+
8
+ TargetType = Union[str, type, Callable[..., Any]]
9
+ ClassOrCallableType = Union[type, Callable[..., Any]]
10
+
11
+
12
+ def dump_config(config: DictConfig, path: str = "./config.yaml"):
13
+ txt = OmegaConf.to_yaml(config, sort_keys=True)
14
+ with open(path, "w") as f:
15
+ f.write(txt)
16
+
17
+
18
+ def locate(path: str) -> Any:
19
+ if path == "":
20
+ raise ImportError("Empty path")
21
+
22
+ import builtins
23
+ from importlib import import_module
24
+
25
+ parts = [part for part in path.split(".") if part]
26
+
27
+ # load module part
28
+ module = None
29
+ for n in reversed(range(len(parts))):
30
+ try:
31
+ mod = ".".join(parts[:n])
32
+ module = import_module(mod)
33
+ except Exception as e:
34
+ if n == 0:
35
+ raise ImportError(f"Error loading module '{path}'") from e
36
+ continue
37
+ if module:
38
+ break
39
+
40
+ if module:
41
+ obj = module
42
+ else:
43
+ obj = builtins
44
+
45
+ # load object path in module
46
+ for part in parts[n:]:
47
+ mod = mod + "." + part
48
+ if not hasattr(obj, part):
49
+ try:
50
+ import_module(mod)
51
+ except Exception as e:
52
+ raise ImportError(
53
+ f"Encountered error: `{e}` when loading module '{path}'"
54
+ ) from e
55
+ obj = getattr(obj, part)
56
+
57
+ return obj
58
+
59
+
60
+ def full_instance_name(instance: Any) -> str:
61
+ return full_class_name(instance.__class__)
62
+
63
+
64
+ def full_class_name(klass: Any) -> str:
65
+ module = klass.__module__
66
+ if module == "builtins":
67
+ return klass.__qualname__ # avoid outputs like 'builtins.str'
68
+ return module + "." + klass.__qualname__
69
+
70
+
71
+ def ensure_is_subclass(child_class: type, parent_class: type) -> None:
72
+ if not issubclass(child_class, parent_class):
73
+ raise RuntimeError(
74
+ f"class {full_class_name(child_class)} should be a subclass of {full_class_name(parent_class)}"
75
+ )
76
+
77
+
78
+ def find_class_or_callable_from_target(
79
+ target: TargetType,
80
+ ) -> ClassOrCallableType:
81
+ if isinstance(target, str):
82
+ obj = locate(target)
83
+ else:
84
+ obj = target
85
+
86
+ if (not isinstance(obj, type)) and (not callable(obj)):
87
+ raise ValueError(f"Invalid type ({type(obj)}) found for {target}")
88
+
89
+ return obj
90
+
91
+
92
+ def find_and_ensure_is_subclass(target: TargetType, type_: type) -> ClassOrCallableType:
93
+ klass = find_class_or_callable_from_target(target)
94
+ ensure_is_subclass(klass, type_)
95
+ return klass
96
+
97
+
98
+ class StrictPartial:
99
+ # remark : the `/` will handle the `path` argument name conflict (e.g. calling StrictPartial("a.b.c", ..., path="/a/b/c"))
100
+ def __init__(self, path, /, *args, **kwargs):
101
+ class_or_callable = find_class_or_callable_from_target(path)
102
+ self._partial = functools.partial(class_or_callable, *args, **kwargs)
103
+
104
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
105
+ return self._partial(*args, **kwargs)
106
+
107
+
108
+ class RecursivePartial:
109
+ @staticmethod
110
+ def replace_keys(config, key_mapping):
111
+ def recurse(data):
112
+ if isinstance(data, DictConfig):
113
+ new_data = {
114
+ key_mapping[k] if k in key_mapping else k: recurse(v)
115
+ for k, v in data.items()
116
+ }
117
+ new_data = DictConfig(new_data)
118
+ elif isinstance(data, ListConfig):
119
+ new_data = ListConfig([recurse(item) for item in data])
120
+ elif type(data) in {bool, str, int, float, type(None)}:
121
+ new_data = data
122
+ else:
123
+ raise RuntimeError(f"unknown type found: {type(data)}")
124
+
125
+ return new_data
126
+
127
+ return recurse(config)
128
+
129
+ def __init__(self, config):
130
+ self.config = RecursivePartial.replace_keys(
131
+ config, {"_rpartial_target_": "_target_"}
132
+ )
133
+
134
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
135
+ return instantiate(self.config)
136
+
137
+
138
+ class Partial(StrictPartial):
139
+ # remark : allow `path` argument to be exposed for easier use
140
+ def __init__(self, path, *args, **kwargs):
141
+ super().__init__(path, *args, **kwargs)
142
+
143
+
144
+ def subkey(mapping, key):
145
+ return mapping[key]
146
+
147
+
148
+ def make_set(*args):
149
+ return set(args)
150
+
151
+
152
+ def make_tuple(*args):
153
+ return tuple(args)
154
+
155
+
156
+ def make_list_from_kwargs(**kwargs):
157
+ # Filter out None/null values to avoid issues with callbacks
158
+ return [v for v in kwargs.values() if v is not None]
159
+
160
+
161
+ def make_string(value):
162
+ return str(value)
163
+
164
+
165
+ def make_dict(**kwargs):
166
+ return dict(kwargs)
167
+
168
+
169
+ def get_item(data, key: str):
170
+ return data[key]
171
+
172
+
173
+ def get_attr(data, key: str):
174
+ return getattr(data, key)
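The Partial/StrictPartial/locate helpers above let configs reference arbitrary dotted paths and pre-bind arguments. A small illustration follows, using standard-library targets purely as placeholders and skipping the heavy package init via LIDRA_SKIP_INIT.

# Minimal sketch of the dotted-path helpers (standard-library targets for illustration only).
import os

os.environ.setdefault("LIDRA_SKIP_INIT", "true")  # avoid heavy package init for this sketch
from sam3d_objects.config.utils import Partial, StrictPartial, locate

sqrt = locate("math.sqrt")               # resolves the dotted path to math.sqrt
pow2 = StrictPartial("builtins.pow", 2)  # binds the first positional argument to 2
assert sqrt(4.0) == 2.0
assert pow2(10) == 1024                  # pow(2, 10)

join_a = Partial("os.path.join", "a")    # Partial exposes `path` as a regular keyword-capable argument
print(join_a("b"))                       # "a/b" on POSIX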
thirdparty/sam3d/sam3d/sam3d_objects/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/img_and_mask_transforms.py ADDED
@@ -0,0 +1,986 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ from collections import namedtuple
3
+ import random
4
+ from typing import Optional, Dict
5
+
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ import torchvision.transforms.functional
9
+ from sam3d_objects.data.dataset.tdfy.img_processing import pad_to_square_centered
10
+ from sam3d_objects.model.backbone.dit.embedder.point_remapper import PointRemapper
11
+ from typing import Optional, Dict
12
+ from loguru import logger
13
+ import torch
14
+ import torch.nn.functional as F
15
+ import torchvision
16
+ import torchvision.transforms as tv_transforms
17
+ import torchvision.transforms.functional
18
+ import torchvision.transforms.functional as TF
19
+
20
+ from sam3d_objects.data.dataset.tdfy.img_processing import pad_to_square_centered
21
+
22
+
23
+ def UNNORMALIZE(mean, std):
24
+ mean = torch.tensor(mean).reshape((3, 1, 1))
25
+ std = torch.tensor(std).reshape((3, 1, 1))
26
+
27
+ def unnormalize_img(img):
28
+ assert img.ndim == 3 and img.shape[0] == 3
29
+
30
+ return img * std.to(img.device) + mean.to(img.device)
31
+
32
+ return unnormalize_img
33
+
34
+
35
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
36
+ IMAGENET_STD = (0.229, 0.224, 0.225)
37
+
38
+
39
+ IMAGENET_NORMALIZATION = tv_transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
40
+ IMAGENET_UNNORMALIZATION = UNNORMALIZE(IMAGENET_MEAN, IMAGENET_STD)
41
+
42
+
43
+ class BoundingBoxError(Exception):
44
+ pass
45
+
46
+
47
+ def check_bounding_box(bbox_w, bbox_h):
48
+ if bbox_w < 2 or bbox_h < 2:
49
+ raise BoundingBoxError("Bounding box dimensions must be at least 2x2.")
50
+
51
+
52
+ class RGBAImageProcessor:
53
+ def __init__(
54
+ self,
55
+ resize_and_make_square_kwargs: Optional[Dict] = None,
56
+ object_crop_kwargs: Optional[Dict] = None,
57
+ remove_background: bool = False,
58
+ imagenet_normalization: bool = False,
59
+ ):
60
+ self.remove_background = remove_background
61
+ self.resize_and_pad_kwargs = resize_and_make_square_kwargs
62
+ self.object_crop_kwargs = object_crop_kwargs
63
+ self.imagenet_normalization = imagenet_normalization
64
+ if resize_and_make_square_kwargs is not None:
65
+ self.transforms = resize_and_make_square(**resize_and_make_square_kwargs)
66
+
67
+ def __call__(
68
+ self, image: torch.Tensor, mask: Optional[torch.Tensor] = None
69
+ ) -> tuple[torch.Tensor, torch.Tensor]:
70
+ if mask is None:
71
+ assert (
72
+ image.shape[0] == 4
73
+ ), f"Requires 4 channels (RGB + alpha), got {image.shape[0]=}"
74
+ image, mask = split_rgba(image)
75
+ else:
76
+ assert (
77
+ image.shape[0] == 3
78
+ ), f"Requires 3 channels (RGB), got {image.shape[0]=}"
79
+ assert mask.dim() == 2, f"Requires 2D mask, got {mask.dim()=}"
80
+
81
+ if not self.object_crop_kwargs in [None, False]:
82
+ image, mask = crop_around_mask_with_padding(
83
+ image, mask, **self.object_crop_kwargs
84
+ )
85
+
86
+ if self.remove_background:
87
+ image, mask = rembg(image, mask)
88
+
89
+ image = self.transforms["img_transform"](image)
90
+ mask = self.transforms["mask_transform"](mask.unsqueeze(0))
91
+
92
+ if self.imagenet_normalization:
93
+ image = IMAGENET_NORMALIZATION(image)
94
+ return image, mask
95
+
96
+
97
+ def load_rgb(fpath: str) -> torch.Tensor:
98
+ """
99
+ Load a RGB(A) image from a file path.
100
+ """
101
+ image = plt.imread(fpath) # Why use matplotlib?
102
+ if image.dtype == "uint8":
103
+ image = image / 255
104
+ image = image.astype(np.float32)
105
+ image = torch.from_numpy(image)
106
+ image = image.permute(2, 0, 1).contiguous()
107
+ return image
108
+
109
+
110
+ def concat_rgba(
111
+ rgb_image: torch.Tensor,
112
+ mask: torch.Tensor,
113
+ ) -> torch.Tensor:
114
+ """
115
+ Create a 4-channel RGBA image from a 3-channel RGB image and a mask.
116
+ """
117
+ assert rgb_image.dim() == 3, f"{rgb_image.shape=}"
118
+ assert mask.dim() == 2, f"{mask.shape=}"
119
+ assert rgb_image.shape[0] == 3, f"{rgb_image.shape[0]=}"
120
+ assert rgb_image.shape[1:] == mask.shape, f"{rgb_image.shape[1:]=} != {mask.shape=}"
121
+ return torch.cat((rgb_image, mask[None, ...]), dim=0)
122
+
123
+
124
+ def split_rgba(rgba_image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
125
+ """
126
+ Split a 4-channel RGBA image into a 3-channel RGB image and a 1-channel mask.
127
+
128
+ Args:
129
+ rgba_image: A 4-channel RGBA image.
130
+
131
+ Returns:
132
+ A tuple of (rgb_image, mask).
133
+ """
134
+ assert rgba_image.dim() == 3, f"{rgba_image.shape=}"
135
+ assert rgba_image.shape[0] == 4, f"{rgba_image.shape[0]=}"
136
+ return rgba_image[:3], rgba_image[3]
137
+
138
+
139
+ def get_mask(
140
+ rgb_image: torch.Tensor,
141
+ depth_image: torch.Tensor,
142
+ mask_source: str,
143
+ ) -> torch.Tensor:
144
+ """
145
+ Extract a mask from either the alpha channel of an RGB image or a depth image.
146
+
147
+ Args:
148
+ rgb_image: Tensor of shape (B, C, H, W) or (C, H, W) where C >= 4 if using alpha channel
149
+ depth_image: Tensor of shape (B, 1, H, W) or (1, H, W) containing depth information
150
+ mask_source: Source of the mask, either "ALPHA_CHANNEL" or "DEPTH"
151
+
152
+ Returns:
153
+ mask: Tensor of shape (B, 1, H, W) or (1, H, W) containing the extracted mask
154
+ """
155
+ # Handle unbatched inputs (add batch dimension if needed)
156
+ is_batched = len(rgb_image.shape) == 4
157
+
158
+ if not is_batched:
159
+ rgb_image = rgb_image.unsqueeze(0)
160
+ if depth_image is not None:
161
+ depth_image = depth_image.unsqueeze(0)
162
+
163
+ if mask_source == "ALPHA_CHANNEL":
164
+ if rgb_image.shape[1] != 4:
165
+ logger.warning(f"No ALPHA CHANNEL for the image, cannot read mask.")
166
+ mask = None
167
+ else:
168
+ mask = rgb_image[:, 3:4, :, :]
169
+ elif mask_source == "DEPTH":
170
+ mask = depth_image
171
+ else:
172
+ raise ValueError(f"Invalid mask source: {mask_source}")
173
+
174
+ # Remove batch dimension if input was unbatched
175
+ if not is_batched:
176
+ mask = mask.squeeze(0)
177
+
178
+ return mask
179
+
180
+
181
+ def rembg(image, mask, pointmap=None):
182
+ """
183
+ Remove the background from an image using a mask.
184
+ For pointmaps, sets background regions to NaN.
185
+
186
+ This function follows the standard transform pattern:
187
+ - If called with (image, mask), returns (image, mask)
188
+ - If called with (image, mask, pointmap), returns (image, mask, pointmap)
189
+ """
190
+ masked_image = image * mask
191
+
192
+ if pointmap is not None:
193
+ masked_pointmap = torch.where(mask > 0, pointmap, torch.nan)
194
+ return masked_image, mask, masked_pointmap
195
+
196
+ return masked_image, mask
197
+
198
+
199
+ def resize_and_make_square(
200
+ img_size: int,
201
+ make_square: bool | str = False,
202
+ ):
203
+ """
204
+ Create image and mask transforms based on configuration.
205
+
206
+ Returns:
207
+ dict: {"img_transform": img_transform, "mask_transform": mask_transform}
208
+ """
209
+ if isinstance(make_square, str):
210
+ make_square = make_square.lower()
211
+ assert make_square in ["pad", "crop", False]
212
+ pre_resize_transform = tv_transforms.Lambda(lambda x: x)
213
+ post_resize_transform = tv_transforms.Lambda(lambda x: x)
214
+ if make_square == "pad":
215
+ pre_resize_transform = pad_to_square_centered
216
+ elif make_square == "crop":
217
+ post_resize_transform = tv_transforms.CenterCrop(img_size)
218
+
219
+ img_resize = tv_transforms.Resize(img_size)
220
+ mask_resize = tv_transforms.Resize(
221
+ img_size,
222
+ interpolation=tv_transforms.InterpolationMode.BILINEAR,
223
+ )
224
+
225
+ img_transform = tv_transforms.Compose(
226
+ [
227
+ pre_resize_transform,
228
+ img_resize,
229
+ post_resize_transform,
230
+ ]
231
+ )
232
+
233
+ mask_transform = tv_transforms.Compose(
234
+ [
235
+ pre_resize_transform,
236
+ mask_resize,
237
+ post_resize_transform,
238
+ ]
239
+ )
240
+
241
+ return {
242
+ "img_transform": img_transform,
243
+ "mask_transform": mask_transform,
244
+ }
245
+
246
+
247
+ def crop_around_mask_with_random_box_size_factor(
248
+ loaded_image: torch.Tensor,
249
+ mask: torch.Tensor,
250
+ random_box_size_factor: float = 1.0,
251
+ pointmap: Optional[torch.Tensor] = None,
252
+ ) -> np.ndarray:
253
+ return crop_around_mask_with_padding(
254
+ loaded_image,
255
+ mask,
256
+ box_size_factor=1.0 + random.uniform(0, 1) * random_box_size_factor,
257
+ padding_factor=0.0,
258
+ pointmap=pointmap,
259
+ )
260
+
261
+
262
+ def crop_around_mask_with_padding(
263
+ loaded_image: torch.Tensor,
264
+ mask: torch.Tensor,
265
+ box_size_factor: float = 1.6,
266
+ padding_factor: float = 0.1,
267
+ pointmap: Optional[torch.Tensor] = None,
268
+ ) -> np.ndarray:
269
+ # cast to ensure the function can be called normally
270
+ cast_mask = False
271
+ if mask.dim() == 3:
272
+ assert mask.shape[0] == 1, "mask channel dimension must be 1"
273
+ mask = mask[0]
274
+ cast_mask = True
275
+ loaded_image = concat_rgba(loaded_image, mask)
276
+
277
+ bbox = compute_mask_bbox(mask, box_size_factor)
278
+ loaded_image = torchvision.transforms.functional.crop(
279
+ loaded_image, bbox[1], bbox[0], bbox[3] - bbox[1], bbox[2] - bbox[0]
280
+ )
281
+
282
+ # Crop pointmap if provided
283
+ if pointmap is not None:
284
+ pointmap = torchvision.transforms.functional.crop(
285
+ pointmap, bbox[1], bbox[0], bbox[3] - bbox[1], bbox[2] - bbox[0]
286
+ )
287
+
288
+ C, H, W = loaded_image.shape
289
+ max_dim = max(H, W) # Get the larger dimension
290
+
291
+ # Step 1: Pad to square shape
292
+ pad_h = (max_dim - H) // 2
293
+ pad_w = (max_dim - W) // 2
294
+ pad_h_extra = (max_dim - H) - pad_h # To ensure even padding
295
+ pad_w_extra = (max_dim - W) - pad_w
296
+
297
+ loaded_image = torch.nn.functional.pad(
298
+ loaded_image, (pad_w, pad_w_extra, pad_h, pad_h_extra), mode="constant", value=0
299
+ )
300
+ if pointmap is not None:
301
+ pointmap = torch.nn.functional.pad(
302
+ pointmap,
303
+ (pad_w, pad_w_extra, pad_h, pad_h_extra),
304
+ mode="constant",
305
+ value=float("nan"),
306
+ )
307
+
308
+ # Step 2: Extend by padding_factor (default 10%) on each side; empirically this gives better results overall
309
+ if padding_factor > 0:
310
+ extend_size = int(max_dim * padding_factor) # 10% extension on each side
311
+ loaded_image = torch.nn.functional.pad(
312
+ loaded_image,
313
+ (extend_size, extend_size, extend_size, extend_size),
314
+ mode="constant",
315
+ value=0,
316
+ )
317
+
318
+ if pointmap is not None:
319
+ pointmap = torch.nn.functional.pad(
320
+ pointmap,
321
+ (extend_size, extend_size, extend_size, extend_size),
322
+ mode="constant",
323
+ value=float("nan"),
324
+ )
325
+
326
+ rgb_image, mask = split_rgba(loaded_image)
327
+ if cast_mask:
328
+ mask = mask[None]
329
+
330
+ if pointmap is not None:
331
+ return rgb_image, mask, pointmap
332
+ return rgb_image, mask
333
+
334
+
335
+ def compute_mask_bbox(
336
+ mask: torch.Tensor, box_size_factor: float = 1.0
337
+ ) -> tuple[float, float, float, float]:
338
+ """
339
+ Compute a bounding box around a binary mask with optional size adjustment.
340
+
341
+ Args:
342
+ mask: A 2D binary tensor where non-zero values represent the object of interest.
343
+ box_size_factor: Factor to scale the bounding box size. Values > 1.0 create a larger box.
344
+ Default is 1.0 (tight bounding box).
345
+
346
+ Returns:
347
+ A tuple of (x1, y1, x2, y2) coordinates representing the bounding box,
348
+ where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner.
349
+
350
+ Raises:
351
+ ValueError: If mask is not a torch.Tensor or not a 2D tensor.
352
+ """
353
+ if not isinstance(mask, torch.Tensor):
354
+ raise ValueError("Mask must be a torch.Tensor")
355
+ if not mask.dim() == 2:
356
+ raise ValueError("Mask must be a 2D tensor")
357
+ bbox_indices = torch.nonzero(mask)
358
+ if bbox_indices.numel() == 0:
359
+ # Handle empty mask case
360
+ return (0, 0, 0, 0)
361
+
362
+ y_indices = bbox_indices[:, 0]
363
+ x_indices = bbox_indices[:, 1]
364
+
365
+ min_x = torch.min(x_indices).item()
366
+ min_y = torch.min(y_indices).item()
367
+ max_x = torch.max(x_indices).item()
368
+ max_y = torch.max(y_indices).item()
369
+
370
+ bbox = (min_x, min_y, max_x, max_y)
371
+
372
+ center_x = (bbox[0] + bbox[2]) / 2
373
+ center_y = (bbox[1] + bbox[3]) / 2
374
+
375
+ bbox_w, bbox_h = bbox[2] - bbox[0], bbox[3] - bbox[1]
376
+
377
+ check_bounding_box(bbox_w, bbox_h)
378
+
379
+ size = max(bbox_w, bbox_h, 2)
380
+ size = int(size * box_size_factor)
381
+
382
+ bbox = (
383
+ int(center_x - size // 2),
384
+ int(center_y - size // 2),
385
+ int(center_x + size // 2),
386
+ int(center_y + size // 2),
387
+ )
388
+ # bbox = tuple(map(int, bbox))
389
+ return bbox
390
+
391
+
392
+ def crop_and_pad(image, bbox):
393
+ """
394
+ Crop an image using a bounding box and pad with zeros if out of bounds.
395
+
396
+ Args:
397
+ image (torch.Tensor): CxHxW image.
398
+ bbox (tuple): (x1, y1, x2, y2) bounding box.
399
+
400
+ Returns:
401
+ torch.Tensor: Cropped and zero-padded image.
402
+ """
403
+ C, H, W = image.shape
404
+ x1, y1, x2, y2 = bbox
405
+
406
+ # Ensure coordinates are integers
407
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
408
+
409
+ # Compute cropping coordinates
410
+ x1_pad, y1_pad = max(0, -x1), max(0, -y1)
411
+ x2_pad, y2_pad = max(0, x2 - W), max(0, y2 - H)
412
+
413
+ # Compute valid region in the original image
414
+ x1_crop, y1_crop = max(0, x1), max(0, y1)
415
+ x2_crop, y2_crop = min(W, x2), min(H, y2)
416
+
417
+ # Extract the valid part
418
+ cropped = image[:, y1_crop:y2_crop, x1_crop:x2_crop]
419
+
420
+ # Create a zero-padded output
421
+ padded = torch.zeros((C, y2 - y1, x2 - x1), dtype=image.dtype)
422
+
423
+ # Place the cropped image into the zero-padded array
424
+ padded[
425
+ :, y1_pad : y1_pad + cropped.shape[1], x1_pad : x1_pad + cropped.shape[2]
426
+ ] = cropped
427
+
428
+ return padded
429
+
430
+
431
+ def resize_all_to_same_size(
432
+ rgb_image: torch.Tensor,
433
+ mask: torch.Tensor,
434
+ pointmap: Optional[torch.Tensor] = None,
435
+ target_size: Optional[tuple[int, int]] = None,
436
+ ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
437
+ """
438
+ Resize RGB image, mask, and pointmap to the same size.
439
+
440
+ This is crucial when pointmaps have different resolution than RGB images,
441
+ which must be done BEFORE any cropping operations.
442
+
443
+ Args:
444
+ rgb_image: RGB image tensor of shape (C, H, W)
445
+ mask: Mask tensor of shape (H, W) or (1, H, W)
446
+ pointmap: Optional pointmap tensor of shape (C_p, H_p, W_p)
447
+ target_size: Target size as (H, W). If None, uses RGB image size.
448
+
449
+ Returns:
450
+ Tuple of (resized_rgb, resized_mask, resized_pointmap)
451
+ """
452
+ squeeze_mask = (mask.dim() == 2)
453
+ if squeeze_mask:
454
+ mask = mask.unsqueeze(0)
455
+
456
+ if target_size is None:
457
+ target_size = (rgb_image.shape[1], rgb_image.shape[2]) # (H, W)
458
+
459
+ rgb_needs_resize = (rgb_image.shape[1], rgb_image.shape[2]) != target_size
460
+ if rgb_needs_resize:
461
+ rgb_image = torchvision.transforms.functional.resize(
462
+ rgb_image, target_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR
463
+ )
464
+ mask = torchvision.transforms.functional.resize(
465
+ mask, target_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST
466
+ )
467
+
468
+ if pointmap is not None:
469
+ pointmap_size = (pointmap.shape[1], pointmap.shape[2])
470
+ if pointmap_size != target_size:
471
+ # Handle NaN values in pointmap during resizing
472
+ # Direct resize would propagate NaN values, so we need special handling
473
+ nan_mask = torch.isnan(pointmap).any(dim=0)
474
+ pointmap_clean = torch.where(torch.isnan(pointmap), torch.zeros_like(pointmap), pointmap)
475
+ pointmap_resized = torchvision.transforms.functional.resize(
476
+ pointmap_clean, target_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR
477
+ )
478
+
479
+ # Resize the nan mask to identify which regions should remain invalid
480
+ nan_mask_resized = torchvision.transforms.functional.resize(
481
+ nan_mask.unsqueeze(0).float(), target_size,
482
+ interpolation=torchvision.transforms.InterpolationMode.NEAREST
483
+ ).squeeze(0) > 0.5
484
+
485
+ # Restore NaN values in regions that were originally invalid
486
+ pointmap = torch.where(
487
+ nan_mask_resized.unsqueeze(0).expand_as(pointmap_resized),
488
+ torch.full_like(pointmap_resized, float('nan')),
489
+ pointmap_resized
490
+ )
491
+
492
+ if squeeze_mask:
493
+ mask = mask.squeeze(0)
494
+
495
+ if pointmap is not None:
496
+ return rgb_image, mask, pointmap
497
+ return rgb_image, mask
498
+
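A short sketch of the NaN-aware resizing above (same hedged `sam3d_objects` import path as before): a lower-resolution pointmap is brought to the RGB resolution, and pixels that were NaN stay NaN instead of bleeding into neighbours through bilinear interpolation.

```python
import torch
from sam3d_objects.data.dataset.tdfy.img_and_mask_transforms import resize_all_to_same_size

rgb = torch.rand(3, 128, 128)
mask = torch.ones(128, 128)
pointmap = torch.rand(3, 64, 64)
pointmap[:, 10:20, 10:20] = float("nan")   # simulate an invalid depth region

rgb_out, mask_out, pm_out = resize_all_to_same_size(rgb, mask, pointmap)
assert pm_out.shape == (3, 128, 128)
# NaNs are restored only where the original pointmap was invalid.
assert torch.isnan(pm_out).any() and torch.isfinite(pm_out).any()
```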
499
+
500
+ SSINormalizedPointmap = namedtuple("SSINormalizedPointmap", ["pointmap", "scale", "shift"])
501
+ class SSIPointmapNormalizer:
502
+
503
+ def normalize(self, pointmap: torch.Tensor, mask: torch.Tensor,
504
+ scale: Optional[torch.Tensor] = None, shift: Optional[torch.Tensor] = None,
505
+ ) -> SSINormalizedPointmap:
506
+ if scale is None or shift is None:
507
+ normalized_pointmap, scale, shift = normalize_pointmap_ssi(pointmap)
508
+ else:
509
+ assert scale.shape == (3,) and shift.shape == (3,), "scale and shift must be in (3,) format"
510
+ normalized_pointmap = _apply_metric_to_ssi(pointmap, scale, shift)
511
+ return SSINormalizedPointmap(normalized_pointmap, scale, shift)
512
+
513
+ def denormalize(self, pointmap: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
514
+ pointmap = _apply_metric_to_ssi(pointmap, scale, shift, apply_inverse=True)
515
+ return pointmap
516
+
517
+
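A round-trip sketch for the base normalizer (assuming the package and its pytorch3d dependency are installed): `denormalize(normalize(x))` should recover the input up to floating-point error.

```python
import torch
from sam3d_objects.data.dataset.tdfy.img_and_mask_transforms import SSIPointmapNormalizer

normalizer = SSIPointmapNormalizer()
pointmap = torch.rand(3, 32, 32) + 1.0     # (3, H, W), depth-like positive values
mask = torch.ones(1, 32, 32)               # unused by the base class, but required

out = normalizer.normalize(pointmap, mask)  # SSINormalizedPointmap(pointmap, scale, shift)
restored = normalizer.denormalize(out.pointmap, out.scale, out.shift)
assert torch.allclose(restored, pointmap, atol=1e-4)
```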
518
+
519
+ class ObjectCentricSSI(SSIPointmapNormalizer):
520
+ def __init__(self,
521
+ use_scene_scale: bool = True,  # may also be the string "OBJECT_NORM_MEDIAN"; see _compute_scale_and_shift
522
+ quantile_drop_threshold: float = 0.1,
523
+ clip_beyond_scale: Optional[float] = None,
524
+ # scale_factor: float = 3.8076, # e^(1.337); empirical mean of R3+Artist train
525
+ scale_factor: float = 1.0,
526
+ allow_scale_and_shift_override: bool = False,
527
+ raise_on_no_valid_points: bool = False,
528
+ ):
529
+ self.use_scene_scale = use_scene_scale
530
+ self.quantile_drop_threshold = quantile_drop_threshold
531
+ self.clip_beyond_scale = clip_beyond_scale
532
+ self.scale_factor = scale_factor
533
+ self.allow_scale_and_shift_override = allow_scale_and_shift_override
534
+ self.raise_on_no_valid_points = raise_on_no_valid_points
535
+
536
+ def _compute_scale_and_shift(self, pointmap: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
537
+ pointmap_size = (pointmap.shape[1], pointmap.shape[2])
538
+
539
+
540
+ mask_resized = torchvision.transforms.functional.resize(
541
+ mask, pointmap_size,
542
+ interpolation=torchvision.transforms.InterpolationMode.NEAREST
543
+ ).squeeze(0)
544
+
545
+ pointmap_flat = pointmap.reshape(3, -1)
546
+ # Get valid points from the mask
547
+ mask_bool = mask_resized.reshape(-1) > 0.5
548
+ mask_points = pointmap_flat[:, mask_bool]
549
+
550
+ if mask_points.isfinite().max() == 0:
551
+ if self.raise_on_no_valid_points:
552
+ raise ValueError(f"No valid points found in mask")
553
+ logger.warning(f"No valid points found in mask; setting scale to {self.scale_factor} and shift to 0")
554
+ return torch.ones_like(pointmap_flat[:,0]) * self.scale_factor, torch.zeros_like(pointmap_flat[:,0])
555
+
556
+ # Compute median for shift
557
+ shift = mask_points.nanmedian(dim=-1).values
558
+ # logger.info(f"{pointmap.shape=} {mask_resized.shape=} {shift.shape=}")
559
+
560
+
561
+ if self.use_scene_scale == True:  # explicit comparison: use_scene_scale may be a bool or a string
562
+ # Normalize by the scene scale
563
+ points_centered = pointmap_flat - shift.unsqueeze(-1)
564
+ max_dims = points_centered.abs().max(dim=0).values
565
+ scale = max_dims.nanmedian(dim=-1).values
566
+ elif self.use_scene_scale == False:
567
+ # Normalize by the object scale
568
+ shifted_mask_points = mask_points - shift.unsqueeze(-1)
569
+ norm = shifted_mask_points.norm(dim=0)
570
+ quantiles = torch.nanquantile(norm,
571
+ torch.tensor([self.quantile_drop_threshold, 1. - self.quantile_drop_threshold],
572
+ device=shifted_mask_points.device),
573
+ dim=-1)
574
+ scale = (quantiles[1] - quantiles[0]).max(dim=-1).values * 2.0
575
+ elif self.use_scene_scale.upper() == "OBJECT_NORM_MEDIAN":
576
+ # Normalize by the object scale
577
+ shifted_mask_points = mask_points - shift.unsqueeze(-1)
578
+ norm = shifted_mask_points.norm(dim=0)
579
+ scale = norm.nanmedian(dim=-1).values
580
+ else:
581
+ raise ValueError(f"Invalid use_scene_scale: {self.use_scene_scale}")
582
+ scale = scale.expand_as(shift) # per-dim scaling
583
+ scale = scale * self.scale_factor
584
+ return scale, shift
585
+
586
+ def normalize(self, pointmap: torch.Tensor, mask: torch.Tensor,
587
+ scale: Optional[torch.Tensor] = None, shift: Optional[torch.Tensor] = None,
588
+ ) -> SSINormalizedPointmap:
589
+ # 1. Resize the mask to the pointmap resolution (nearest interpolation).
590
+ # 2. Select the masked points: pointmap[:, mask > 0.5].
591
+ # 3. shift = per-axis median of the masked points.
592
+ # 4. scale = scene/object extent from _compute_scale_and_shift; falls back to
593
+ #    scale_factor when the mask contains no valid points.
594
+ assert pointmap.shape[0] == 3, "pointmap must be in (3, H, W) format"
595
+ pointmap_size = (pointmap.shape[1], pointmap.shape[2])
596
+
597
+ _scale, _shift = self._compute_scale_and_shift(pointmap, mask)
598
+ if scale is not None and self.allow_scale_and_shift_override:
599
+ _scale = scale
600
+ if shift is not None and self.allow_scale_and_shift_override:
601
+ _shift = shift
602
+ return_scale, return_shift = _scale, _shift
603
+
604
+ # Apply normalization
605
+ pointmap_normalized = _apply_metric_to_ssi(pointmap, return_scale, return_shift)
606
+
607
+ if self.clip_beyond_scale is not None and self.clip_beyond_scale > 0:
608
+ new_norm = pointmap_normalized.norm(dim=0)
609
+ pointmap_normalized = torch.where(
610
+ new_norm > self.clip_beyond_scale,
611
+ torch.full_like(pointmap_normalized, float('nan')),
612
+ pointmap_normalized
613
+ )
614
+
615
+ return SSINormalizedPointmap(pointmap_normalized, return_scale, return_shift)
616
+
617
+
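A sketch of the object-centric variant (same assumptions as above): with `use_scene_scale=True` the shift is the per-axis median of the masked points and the scale comes from the scene extent.

```python
import torch
from sam3d_objects.data.dataset.tdfy.img_and_mask_transforms import ObjectCentricSSI

normalizer = ObjectCentricSSI(use_scene_scale=True)
pointmap = torch.rand(3, 32, 32) * 2.0 + 1.0
mask = torch.zeros(1, 32, 32)
mask[:, 8:24, 8:24] = 1.0                   # object occupies the central region

out = normalizer.normalize(pointmap, mask)  # SSINormalizedPointmap(pointmap, scale, shift)
obj_pts = pointmap[:, 8:24, 8:24].reshape(3, -1)
assert torch.allclose(out.shift, obj_pts.median(dim=-1).values)
```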
618
+ class ObjectApparentSizeSSI(SSIPointmapNormalizer):
619
+ def __init__(self,
620
+ clip_beyond_scale: Optional[float] = None,
621
+ use_scene_scale: bool = True,
622
+ scale_factor: float = 1.0, # e^(1.337); empirical mean of R3+Artist train
623
+ ):
624
+ self.clip_beyond_scale = clip_beyond_scale
625
+ self.use_scene_scale = use_scene_scale
626
+ self.scale_factor = scale_factor
627
+
628
+ def _get_scale_and_shift(self, pointmap: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
629
+ pointmap_size = (pointmap.shape[1], pointmap.shape[2])
630
+ pointmap_flat = pointmap.reshape(3, -1)
631
+
632
+ if not self.use_scene_scale:
633
+ # Get valid points from the mask
634
+ mask_resized = torchvision.transforms.functional.resize(
635
+ mask, pointmap_size,
636
+ interpolation=torchvision.transforms.InterpolationMode.NEAREST
637
+ ).squeeze(0)
638
+ mask_bool = mask_resized.reshape(-1) > 0.5
639
+ pointmap_flat = pointmap_flat[:, mask_bool]
640
+
641
+ # Median z-distance
642
+ median_z = pointmap_flat[-1, ...].nanmedian().unsqueeze(0)
643
+ scale = median_z.expand(3) * self.scale_factor
644
+ shift = torch.zeros_like(scale)
645
+ # logger.info(f'median z = {median_z}')
646
+ return scale, shift
647
+
648
+ def normalize(self,
649
+ pointmap: torch.Tensor,
650
+ mask: torch.Tensor,
651
+ scale: Optional[torch.Tensor] = None,
652
+ shift: Optional[torch.Tensor] = None,
653
+ ) -> torch.Tensor:
654
+ assert pointmap.shape[0] == 3, "pointmap must be in (3, H, W) format"
655
+ pointmap_size = (pointmap.shape[1], pointmap.shape[2])
656
+
657
+ if scale is None or shift is None:
658
+ scale, shift = self._get_scale_and_shift(pointmap, mask)
659
+ else:
660
+ assert scale.shape == (3,) and shift.shape == (3,), "scale and shift must be in (3,) format"
661
+
662
+ # Apply normalization and clip
663
+ pointmap_normalized = _apply_metric_to_ssi(pointmap, scale, shift)
664
+ # logger.info(f"{pointmap_normalized.shape=}")
665
+
666
+ if self.clip_beyond_scale is not None and self.clip_beyond_scale > 0:
667
+ pointmap_normalized = torch.where(
668
+ pointmap_normalized[-1, ...] > self.clip_beyond_scale,
669
+ torch.full_like(pointmap_normalized, float('nan')),
670
+ pointmap_normalized
671
+ )
672
+
673
+ # return pointmap_normalized, scale, shift
674
+ return SSINormalizedPointmap(pointmap_normalized, scale, shift)
675
+
676
+
677
+ class NormalizedDisparitySpaceSSI(SSIPointmapNormalizer):
678
+ def __init__(self,
679
+ clip_beyond_scale: Optional[float] = None,
680
+ use_scene_scale: bool = True,
681
+ log_disparity_shift: float = 0.0,
682
+ ):
683
+ self.clip_beyond_scale = clip_beyond_scale
684
+ self.use_scene_scale = use_scene_scale
685
+ self.point_remapper = PointRemapper(remap_type="exp_disparity")
686
+ self.log_disparity_shift = log_disparity_shift
687
+
688
+ def normalize(self, pointmap: torch.Tensor, mask: torch.Tensor,
689
+ scale: Optional[torch.Tensor] = None, shift: Optional[torch.Tensor] = None,
690
+ ) -> SSINormalizedPointmap:
691
+ assert pointmap.shape[0] == 3, "pointmap must be in (3, H, W) format"
692
+
693
+
694
+ disparity_space_pointmap = self.point_remapper.forward(pointmap.permute(1, 2, 0)).permute(2, 0, 1)
695
+ if scale is None or shift is None:
696
+ scale, shift = self._get_scale_and_shift(disparity_space_pointmap, mask)
697
+ else:
698
+ assert scale.shape == (3,) and shift.shape == (3,), "scale and shift must be in (3,) format"
699
+
700
+ # pointmap_normalized = pointmap.clone().detach()
701
+ pointmap_normalized = _apply_metric_to_ssi(disparity_space_pointmap, scale, shift)
702
+ # logger.info(f"{pointmap_normalized.shape=}")
703
+
704
+ if self.clip_beyond_scale is not None and self.clip_beyond_scale > 0:
705
+ pointmap_normalized = torch.where(
706
+ pointmap_normalized[2, ...].abs() > self.clip_beyond_scale,
707
+ torch.full_like(pointmap_normalized, float('nan')),
708
+ pointmap_normalized
709
+ )
710
+
711
+ # return pointmap_normalized, scale, shift
712
+ return SSINormalizedPointmap(pointmap_normalized, scale, shift)
713
+
714
+ def denormalize(self, pointmap: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
715
+ pointmap = _apply_metric_to_ssi(pointmap, scale, shift, apply_inverse=True)
716
+ pointmap = self.point_remapper.inverse(pointmap.permute(1, 2, 0)).permute(2, 0, 1)
717
+ return pointmap
718
+
719
+ def _get_scale_and_shift(self, pointmap: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
720
+ pointmap_size = (pointmap.shape[1], pointmap.shape[2])
721
+ mask_resized = torchvision.transforms.functional.resize(
722
+ mask, pointmap_size,
723
+ interpolation=torchvision.transforms.InterpolationMode.NEAREST
724
+ ).squeeze(0)
725
+
726
+ pointmap_flat = pointmap.reshape(3, -1)
727
+ if self.use_scene_scale:
728
+ median_z = pointmap_flat[-1, ...].nanmedian().unsqueeze(0)
729
+ shift = torch.zeros_like(median_z.expand(3))
730
+ shift[-1, ...] = median_z[0] + self.log_disparity_shift
731
+ else:
732
+ # Get valid points from the mask (shift, x/z, y/z, log(z))
733
+ mask_bool = mask_resized.reshape(-1) > 0.5
734
+ pointmap_flat = pointmap_flat[:, mask_bool]
735
+ shift = pointmap_flat.nanmedian(dim=-1).values
736
+
737
+ scale = torch.ones_like(shift)
738
+ # logger.info(f'median z = {median_z}')
739
+ return scale, shift
740
+
741
+ def normalize_pointmap_ssi(pointmap: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
742
+ """
743
+ Normalize pointmap using Scale-Shift Invariant (SSI) normalization.
744
+
745
+ Args:
746
+ pointmap: Pointmap tensor of shape (H, W, 3) or (3, H, W)
747
+
748
+ Returns:
749
+ Tuple of (normalized_pointmap, scale, shift)
750
+ """
751
+ from sam3d_objects.data.dataset.tdfy.pose_target import ScaleShiftInvariant
752
+
753
+ # Convert to (H, W, 3) if needed for get_scale_and_shift
754
+ if pointmap.shape[0] == 3:
755
+ pointmap_hw3 = pointmap.permute(1, 2, 0)
756
+ original_format = 'chw'
757
+ else:
758
+ pointmap_hw3 = pointmap
759
+ original_format = 'hwc'
760
+
761
+ # Get scale and shift using existing method
762
+ scale, shift = ScaleShiftInvariant.get_scale_and_shift(pointmap_hw3)
763
+
764
+ pointmap_normalized = _apply_metric_to_ssi(pointmap, scale, shift)
765
+ return pointmap_normalized, scale, shift
766
+
767
+ def _apply_metric_to_ssi(pointmap: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor, apply_inverse: bool = False) -> torch.Tensor:
768
+ """
769
+ Apply the metric -> SSI scale/shift transform to a pointmap (or its inverse when apply_inverse=True).
770
+
771
+ Args:
772
+ pointmap: Pointmap tensor of shape (H, W, 3) or (3, H, W); scale, shift: per-axis tensors of shape (3,)
773
+
774
+ Returns:
775
+ Transformed pointmap in the same layout as the input.
776
+ """
777
+ from sam3d_objects.data.dataset.tdfy.pose_target import ScaleShiftInvariant
778
+
779
+ # Convert to (H, W, 3) if needed for get_scale_and_shift
780
+ if pointmap.shape[0] == 3:
781
+ pointmap_hw3 = pointmap.permute(1, 2, 0)
782
+ original_format = 'chw'
783
+ else:
784
+ pointmap_hw3 = pointmap
785
+ original_format = 'hwc'
786
+
787
+ # Apply normalization
788
+ ssi_to_metric = ScaleShiftInvariant.ssi_to_metric(scale, shift)
789
+ metric_to_ssi = ssi_to_metric.inverse()
790
+ transform_to_apply = metric_to_ssi
791
+
792
+ if apply_inverse:
793
+ transform_to_apply = ssi_to_metric
794
+
795
+ pointmap_flat = pointmap_hw3.reshape(-1, 3)
796
+ pointmap_normalized = transform_to_apply.transform_points(pointmap_flat)
797
+
798
+ # Reshape back to original format
799
+ if original_format == 'chw':
800
+ pointmap_normalized = pointmap_normalized.reshape(pointmap.shape[1], pointmap.shape[2], 3).permute(2, 0, 1)
801
+ else:
802
+ pointmap_normalized = pointmap_normalized.reshape(pointmap_hw3.shape)
803
+
804
+ return pointmap_normalized
805
+
806
+
807
+ def perturb_mask_translation(
808
+ image: torch.Tensor,
809
+ mask: torch.Tensor,
810
+ max_px_delta: int = 5,
811
+ ):
812
+ """
813
+ Applies data augmentation to the mask by randomly translating the mask.
814
+
815
+ Args:
816
+ image: (C, H, W) float32 [0, 1] tensor.
817
+ mask: (1, H, W) float32 [0, 1] tensor.
818
+ max_px_delta: The maximum number of pixels we will randomly shift by in each 2D direction.
819
+ """
820
+ dx = random.randint(-max_px_delta, max_px_delta)
821
+ dy = random.randint(-max_px_delta, max_px_delta)
822
+
823
+ mask = mask.squeeze(0)
824
+ mask = torch.roll(mask, shifts=(dy, dx), dims=(0, 1))
825
+
826
+ # Zero out wrapped regions
827
+ if dy > 0:
828
+ mask[:dy, :] = 0
829
+ elif dy < 0:
830
+ mask[dy:, :] = 0
831
+ if dx > 0:
832
+ mask[:, :dx] = 0
833
+ elif dx < 0:
834
+ mask[:, dx:] = 0
835
+
836
+ mask = mask.unsqueeze(0)
837
+ return image, mask
838
+
839
+
840
+ def perturb_mask_boundary(
841
+ image: torch.Tensor,
842
+ mask: torch.Tensor,
843
+ kernel_range: tuple[int, int] = (2, 5),
844
+ p_erode: float = 0.1,
845
+ p_dilate: float = 0.8,
846
+ **kwargs,
847
+ ):
848
+ """
849
+ Applies data augmentation to the mask by randomly eroding or dilating the mask.
850
+
851
+ Args:
852
+ image: (C, H, W) float32 [0, 1] tensor.
853
+ mask: (1, H, W) float32 [0, 1] tensor.
854
+ kernel_range: Range of kernel sizes to sample from.
855
+ p_erode: Probability of erosion.
856
+ p_dilate: Probability of dilation.
857
+ kwargs: Kwargs for the cv2 erode/dilate function.
858
+ """
859
+ import cv2
860
+
861
+ C, H, W = image.shape
862
+ assert mask.shape == (1, H, W)
863
+ assert mask.dtype == torch.float32
864
+ assert torch.all((mask == 0) | (mask == 1)), "Mask must be binary (0 or 1)"
865
+
866
+ p_none = 1.0 - p_erode - p_dilate
867
+ assert 0 <= p_none <= 1, "p_erode + p_dilate must be between 0 and 1."
868
+
869
+ # Sample operation.
870
+ op = random.choices(["erode", "dilate", "none"], weights=[p_erode, p_dilate, p_none], k=1)[0]
871
+
872
+ if op == "none":
873
+ pass
874
+ else:
875
+ # Sample kernel size
876
+ ksize = random.randint(*kernel_range)
877
+ kernel = np.ones((ksize, ksize), np.uint8)
878
+
879
+ mask = mask.squeeze().cpu().numpy().astype(np.uint8) # (H, W)
880
+
881
+ if op == "erode":
882
+ mask = cv2.erode(mask, kernel, **kwargs)
883
+ elif op == "dilate":
884
+ mask = cv2.dilate(mask, kernel, **kwargs)
885
+ else:
886
+ raise NotImplementedError
887
+
888
+ mask = torch.from_numpy(mask).float()[None] # (1, H, W)
889
+
890
+ return image, mask
891
+
892
+
893
+ def resolution_blur(
894
+ image: torch.Tensor,
895
+ mask: torch.Tensor,
896
+ scale_range=(0.05, 0.95),
897
+ interpolation_down=tv_transforms.InterpolationMode.BICUBIC,
898
+ interpolation_up=tv_transforms.InterpolationMode.BICUBIC,
899
+ ):
900
+ """
901
+ Blur the input image by applying upsample(downsample(x)).
902
+
903
+ Args:
904
+ image (torch.Tensor): Image tensor of shape (C, H, W), float32, with values in [0, 1].
905
+ mask (torch.Tensor): Mask tensor of shape (1, H, W), float32, with values in [0, 1]. The mask is returned unchanged.
906
+ scale_range: Tuple of (min_scale, max_scale) for downsampling.
907
+ interpolation_down: Interpolation mode for downsampling.
908
+ interpolation_up: Interpolation mode for upsampling.
909
+ """
910
+ C, H, W = image.shape
911
+ scale = random.uniform(*scale_range)
912
+ new_H, new_W = max(1, int(H * scale)), max(1, int(W * scale))
913
+
914
+ # Downsample
915
+ image = TF.resize(image, size=[new_H, new_W], interpolation=interpolation_down)
916
+
917
+ # Upsample back to original size
918
+ image = TF.resize(image, size=[H, W], interpolation=interpolation_up)
919
+
920
+ return image, mask
921
+
922
+
923
+ def gaussian_blur(
924
+ image: torch.Tensor,
925
+ mask: torch.Tensor,
926
+ kernel_range: tuple[int, int] = (3, 15),
927
+ sigma_range: tuple[float, float] = (0.1, 4.0),
928
+ ):
929
+ """
930
+ Apply gaussian blur to the input image.
931
+
932
+ Args:
933
+ image (torch.Tensor): Image tensor of shape (C, H, W), float32, with values in [0, 1].
934
+ mask (torch.Tensor): Mask tensor of shape (1, H, W), float32, with values in [0, 1]. The mask is returned unchanged.
935
+ kernel_range (tuple): Range of odd kernel sizes to sample from for the Gaussian blur (min, max).
936
+ sigma_range (tuple): Range of sigma values (standard deviation) to sample from for the Gaussian kernel (min, max).
937
+ """
938
+ kernel_size = random.choice([k for k in range(kernel_range[0], kernel_range[1]+1) if k % 2 == 1])
939
+ sigma = random.uniform(*sigma_range)
940
+ pad = kernel_size // 2
941
+
942
+ # Step 1: Pad the image
943
+ image = F.pad(image.unsqueeze(0), (pad, pad, pad, pad), mode='replicate')
944
+
945
+ # Step 2: Apply gaussian blur
946
+ image = TF.gaussian_blur(image, kernel_size=[kernel_size, kernel_size], sigma=sigma)
947
+
948
+ # Step 3: Unpad to get back to original size
949
+ image = image[:, :, pad:-pad, pad:-pad]
950
+
951
+ return image.squeeze(0), mask
952
+
953
+
954
+ def apply_blur_augmentation(
955
+ image: torch.Tensor,
956
+ mask: torch.Tensor,
957
+ p_resolution: float = 0.33,
958
+ p_gaussian: float = 0.33,
959
+ gaussian_kwargs: dict = None,
960
+ resolution_kwargs: dict = None,
961
+ ):
962
+ """Apply blur augmentation with configurable parameters"""
963
+
964
+ # Handle None defaults BEFORE unpacking
965
+ if gaussian_kwargs is None:
966
+ gaussian_kwargs = {}
967
+ if resolution_kwargs is None:
968
+ resolution_kwargs = {}
969
+
970
+ p_none = 1.0 - p_gaussian - p_resolution
971
+ assert 0 <= p_none <= 1, "p_gaussian + p_resolution must be between 0 and 1."
972
+
973
+ operation = random.choices(
974
+ ["gaussian", "resolution", "none"],
975
+ weights=[p_gaussian, p_resolution, p_none],
976
+ k=1
977
+ )[0]
978
+
979
+ if operation == "gaussian":
980
+ return gaussian_blur(image, mask, **gaussian_kwargs)
981
+ elif operation == "resolution":
982
+ return resolution_blur(image, mask, **resolution_kwargs)
983
+ elif operation == "none":
984
+ return image, mask
985
+ else:
986
+ raise NotImplementedError
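One possible way the mask and image augmentations above could be chained at data-loading time (a sketch only, not taken from the repository's training code; it assumes the package plus torchvision and opencv-python are installed):

```python
import torch
from sam3d_objects.data.dataset.tdfy.img_and_mask_transforms import (
    apply_blur_augmentation,
    perturb_mask_boundary,
    perturb_mask_translation,
)

image = torch.rand(3, 256, 256)                  # (C, H, W) in [0, 1]
mask = (torch.rand(1, 256, 256) > 0.5).float()   # binary (1, H, W) mask

image, mask = perturb_mask_translation(image, mask, max_px_delta=5)
image, mask = perturb_mask_boundary(image, mask, kernel_range=(2, 5))
image, mask = apply_blur_augmentation(image, mask, p_gaussian=0.33, p_resolution=0.33)
assert image.shape == (3, 256, 256) and mask.shape == (1, 256, 256)
```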
thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/img_processing.py ADDED
@@ -0,0 +1,189 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import math
3
+
4
+ import random
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from torchvision import transforms
10
+ from torchvision.transforms import functional as tv_F
11
+
12
+
13
+ class RandomResizedCrop(transforms.RandomResizedCrop):
14
+ """
15
+ RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
16
+ This may lead to results that differ from torchvision's version.
17
+ Following BYOL's TF code:
18
+ https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
19
+ """
20
+
21
+ @staticmethod
22
+ def get_params(img, scale, ratio):
23
+ width, height = tv_F._get_image_size(img)
24
+ area = height * width
25
+
26
+ target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
27
+ log_ratio = torch.log(torch.tensor(ratio))
28
+ aspect_ratio = torch.exp(
29
+ torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
30
+ ).item()
31
+
32
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
33
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
34
+
35
+ w = min(w, width)
36
+ h = min(h, height)
37
+
38
+ i = torch.randint(0, height - h + 1, size=(1,)).item()
39
+ j = torch.randint(0, width - w + 1, size=(1,)).item()
40
+
41
+ return i, j, h, w
42
+
43
+
44
+ # following PT3D CO3D data to pad image
45
+ def pad_to_square(image, value=0):
46
+ _, _, h, w = image.shape # Assuming image is in (B, C, H, W) format
47
+ if h == w:
48
+ return image # The image is already square
49
+
50
+ # Calculate the padding
51
+ diff = abs(h - w)
52
+ pad2 = diff
53
+
54
+ # Pad the image to make it square
55
+ if h > w:
56
+ padding = (0, pad2, 0, 0) # Pad width (left, right, top, bottom)
57
+ else:
58
+ padding = (0, 0, 0, pad2) # Pad height
59
+ # Apply padding
60
+ padded_image = torch.nn.functional.pad(image, padding, mode="constant", value=value)
61
+ return padded_image
62
+
63
+
64
+ def preprocess_img(
65
+ x,
66
+ mask=None,
67
+ img_target_shape=224,
68
+ mask_target_shape=256,
69
+ normalize=False,
70
+ ):
71
+ if x.shape[-2] != x.shape[-1]:  # compare H and W of the (B, C, H, W) input
72
+ x = pad_to_square(x)
73
+ if mask is not None and mask.shape[-2] != mask.shape[-1]:
74
+ mask = pad_to_square(mask)
75
+ if x.shape[2] != img_target_shape:
76
+ x = F.interpolate(
77
+ x,
78
+ size=(img_target_shape, img_target_shape),
79
+ # scale_factor=float(img_target_shape)/x.shape[2],
80
+ mode="bilinear",
81
+ )
82
+ if mask is not None and mask.shape[2] != mask_target_shape:
83
+ if mask is not None:
84
+ mask = F.interpolate(
85
+ mask,
86
+ size=(mask_target_shape, mask_target_shape),
87
+ # scale_factor=float(mask_target_shape)/mask.shape[2],
88
+ mode="nearest",
89
+ )
90
+ if normalize:
91
+ imgs_normed = resnet_img_normalization(x)
92
+ else:
93
+ imgs_normed = x
94
+ return imgs_normed, mask
95
+
96
+
97
+ def resnet_img_normalization(x):
98
+ resnet_mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).reshape(
99
+ (3, 1, 1)
100
+ )
101
+ resnet_std = torch.tensor([0.229, 0.224, 0.225], device=x.device).reshape((3, 1, 1))
102
+ if x.ndim == 4:
103
+ resnet_mean = resnet_mean[None]
104
+ resnet_std = resnet_std[None]
105
+ x = (x - resnet_mean) / resnet_std
106
+ return x
107
+
108
+
109
+ # pad image to be centered for unprojecting depth
110
+ def pad_to_square_centered(image, value=0, pointmap=None):
111
+ h, w = image.shape[-2], image.shape[-1] # Assuming image is in (B, C, H, W) format
112
+ if h == w:
113
+ if pointmap is not None:
114
+ return image, pointmap
115
+ return image # The image is already square
116
+
117
+ # Calculate the padding
118
+ diff = abs(h - w)
119
+ pad1 = diff // 2
120
+ pad2 = diff - pad1
121
+
122
+ # Pad the image to make it square
123
+ if h > w:
124
+ padding = (pad1, pad2, 0, 0) # Pad width (left, right, top, bottom)
125
+ else:
126
+ padding = (0, 0, pad1, pad2) # Pad height
127
+ # Apply padding to image
128
+ padded_image = F.pad(image, padding, mode="constant", value=value)
129
+
130
+ # Apply padding to pointmap if provided
131
+ if pointmap is not None:
132
+ # Pad pointmap using torch functional with NaN fill value
133
+ padded_pointmap = F.pad(pointmap, padding, mode="constant", value=float("nan"))
134
+
135
+ return padded_image, padded_pointmap
136
+ return padded_image
137
+
138
+
139
+ def crop_img_to_obj(mask, context_size):
140
+ nonzeros = torch.nonzero(mask)
141
+ if len(nonzeros) > 0:
142
+ r_max, c_max = nonzeros.max(dim=0)[0]
143
+ r_min, c_min = nonzeros.min(dim=0)[0]
144
+ box_h = max(1, r_max - r_min)
145
+ box_w = max(1, c_max - c_min)
146
+ left = max(0, c_min - int(box_w * context_size))
147
+ right = min(mask.shape[-1], c_max + int(box_w * context_size))
148
+ top = max(0, r_min - int(box_h * context_size))
149
+ bot = min(mask.shape[-2], r_max + int(box_h * context_size))
150
+ return left, right, top, bot
151
+ return None, None, None, None
152
+
153
+
154
+ def random_pad(img, mask=None, max_ratio=0.0, pointmap=None):
155
+ max_size = int(max(img.shape) * max_ratio)
156
+ padding = tuple([random.randint(0, max_size) for _ in range(4)])
157
+ img = F.pad(img, padding)
158
+ if mask is not None:
159
+ mask = F.pad(mask, padding)
160
+
161
+ if pointmap is not None:
162
+ pointmap = F.pad(pointmap, padding, mode="constant", value=float("nan"))
163
+ return img, mask, pointmap
164
+ return img, mask
165
+
166
+
167
+ def get_img_color_augmentation(
168
+ color_jit_prob=0.5,
169
+ gaussian_blur_prob=0.1,
170
+ ):
171
+ transform = transforms.Compose(
172
+ [
173
+ # (a) Random Color Jitter
174
+ transforms.RandomApply(
175
+ [
176
+ transforms.ColorJitter(
177
+ brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
178
+ )
179
+ ],
180
+ p=color_jit_prob,
181
+ ),
182
+ # (b) Randomly apply GaussianBlur
183
+ transforms.RandomApply(
184
+ [transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))],
185
+ p=gaussian_blur_prob,
186
+ ),
187
+ ]
188
+ )
189
+ return transform
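A usage sketch for `preprocess_img` (assuming the `sam3d_objects.data.dataset.tdfy.img_processing` module path from this diff is importable): batched images and masks are padded to square, resized to their respective target resolutions, and optionally ResNet-normalized.

```python
import torch
from sam3d_objects.data.dataset.tdfy.img_processing import preprocess_img

imgs = torch.rand(2, 3, 480, 640)    # (B, C, H, W), non-square
masks = torch.rand(2, 1, 480, 640)

imgs_out, masks_out = preprocess_img(
    imgs, masks, img_target_shape=224, mask_target_shape=256, normalize=True
)
assert imgs_out.shape == (2, 3, 224, 224)
assert masks_out.shape == (2, 1, 256, 256)
```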
thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/pose_target.py ADDED
@@ -0,0 +1,784 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import torch
3
+ from typing import Dict, Optional, Tuple, Any
4
+ from dataclasses import dataclass, asdict, field
5
+ from loguru import logger
6
+
7
+ from sam3d_objects.data.utils import expand_as_right, tree_tensor_map
8
+ from sam3d_objects.data.dataset.tdfy.transforms_3d import compose_transform, decompose_transform
9
+ from pytorch3d.transforms import Transform3d, quaternion_to_matrix, matrix_to_quaternion
10
+
11
+
12
+ @dataclass
13
+ class InstancePose:
14
+ """
15
+ Stores the pose of an object.
16
+ Also, stores some information about the scene that was used to normalize the pose.
17
+ """
18
+
19
+ instance_scale_l2c: torch.Tensor
20
+ instance_position_l2c: torch.Tensor
21
+ instance_quaternion_l2c: torch.Tensor
22
+ scene_scale: torch.Tensor
23
+ scene_shift: torch.Tensor
24
+
25
+ @classmethod
26
+ def _broadcast_postcompose(
27
+ cls,
28
+ scale: torch.Tensor,
29
+ rotation: torch.Tensor,
30
+ translation: torch.Tensor,
31
+ transform_to_postcompose: Transform3d,
32
+ ) -> Transform3d:
33
+ """
34
+ Assumes scale, rotation, translation are of shape:
35
+ B, K, C
36
+ ---
37
+ B: batch size
38
+ K: number of objects
39
+ C: number of channels
40
+
41
+ Takes a transform where
42
+ get_matrix() has shape (B, 3, 3)
43
+
44
+ Returns pose.compose(transform_to_postcompose)
45
+ """
46
+ scale_c = scale.shape[-1]
47
+ ndim_orig = scale.ndim
48
+ if ndim_orig == 3:
49
+ b, k, _ = scale.shape
50
+ elif ndim_orig == 2:
51
+ b = scale.shape[0]
52
+ k = 1
53
+ elif ndim_orig == 1:
54
+ b = 1
55
+ k = 1
56
+ else:
57
+ raise ValueError(f"Invalid scale shape: {scale.shape}")
58
+
59
+ # Create transform of shape (B * K)
60
+ wide = {"scale": scale, "rotation": rotation, "translation": translation}
61
+ shapes_orig = {k: v.shape for k, v in wide.items()}
62
+ long = tree_tensor_map(lambda x: x.reshape(b * k, x.shape[-1]), wide)
63
+ long["rotation"] = quaternion_to_matrix(long["rotation"])
64
+ if scale_c == 1:
65
+ long["scale"] = long["scale"].expand(b * k, 3)
66
+
67
+ composed = compose_transform(**long)
68
+
69
+ # Apply transform to shape (B * K)
70
+ pc_transform = transform_to_postcompose.get_matrix()
71
+ pc_transform = pc_transform.repeat(k, 1, 1)
72
+ stacked_pc_transform = Transform3d(matrix=pc_transform)
73
+ assert stacked_pc_transform.get_matrix().shape == composed.get_matrix().shape
74
+ postcomposed = composed.compose(stacked_pc_transform)
75
+
76
+ # Decompose transform to shape (B, K, C)
77
+ scale, rotation, translation = decompose_transform(postcomposed)
78
+ rotation = matrix_to_quaternion(rotation)
79
+ pc_long = {"scale": scale, "rotation": rotation, "translation": translation}
80
+ pc_wide = tree_tensor_map(lambda x: x.reshape(b, k, x.shape[-1]), pc_long)
81
+ if scale_c == 1:
82
+ pc_wide["scale"] = pc_wide["scale"][..., 0].unsqueeze(-1)
83
+ for k, shape in shapes_orig.items():
84
+ pc_wide[k] = pc_wide[k].reshape(*shape)
85
+ return pc_wide["scale"], pc_wide["rotation"], pc_wide["translation"]
86
+
87
+
88
+ @dataclass
89
+ class PoseTarget:
90
+ x_instance_scale: torch.Tensor
91
+ x_instance_rotation: torch.Tensor
92
+ x_instance_translation: torch.Tensor
93
+ x_scene_scale: torch.Tensor
94
+ x_scene_center: torch.Tensor
95
+ x_translation_scale: torch.Tensor
96
+ pose_target_convention: str = field(default="unknown")
97
+
98
+
99
+ @dataclass
100
+ class InvariantPoseTarget:
101
+ """
102
+ This is the canonical representation of pose targets, used for computing metrics.
103
+ instance_pose <-> invariant_pose_targets <-> all other pose_target_conventions
104
+
105
+ Background:
106
+ ---
107
+ We want to estimate a transformation T: R³ → R³ despite scene scale ambiguity.
108
+
109
+ The transformation taking object points to scene points is defined as
110
+ T(x) = s · R(q) · x + t
111
+ where:
112
+ - x is a point in the object coordinate frame,
113
+ - q is a unit quaternion representing rotation,
114
+ - s is the object-to-scene scale, and
115
+ - t is the translation.
116
+
117
+ However, there is an inherent scale ambiguity in the scene, denoted as s_scene;
118
+ This ambiguity introduces irreducible error that complicates both evaluation and training.
119
+
120
+ To decouple the scene scale from the invariant quantities, we define:
121
+ T(x) = s_scene · |t_rel| [ s_tilde · R(q) · x + t_unit ]
122
+ where we define
123
+ t_rel = t / s_scene
124
+ s_rel = s / s_scene
125
+ s_tilde = s_rel / |t_rel|
126
+ t_unit = t_rel / |t_rel|
127
+
128
+ During training, you would predict (q, s_tilde, t_unit), leaving s_scene separate.
129
+
130
+
131
+ Hand-wavy error analysis:
132
+ ---
133
+ 1. Naive (coupled) estimate:
134
+ T(x) = s_scene [ s_rel · R(q) · x + t_rel ]
135
+
136
+ We can define:
137
+ U = ln(s_rel)
138
+ V = ln(|t_rel|)
139
+ so that the error is governed by Var(U + V).
140
+
141
+ 2. In the decoupled case, we have:
142
+ T(x) = s_scene · |t_rel| [ s_tilde · R(q) · x + t_unit ]
143
+ = s_scene · |t_rel| [ (s_rel / |t_rel|) R(q) · x + t_unit ]
144
+ Then ln(s_tilde) = ln(s_rel) - ln(|t_rel|) = U - V, and the error is
145
+ Var(U - V) = Var(U) + Var(V) - 2Cov(U, V).
146
+
147
+ """
148
+
149
+ # These are invariant
150
+ q: torch.Tensor
151
+ t_unit: torch.Tensor
152
+ s_scene: torch.Tensor
153
+ t_scene_center: Optional[torch.Tensor] = None
154
+ t_rel_norm: Optional[torch.Tensor] = None
155
+ s_tilde: Optional[torch.Tensor] = None
156
+ s_rel: Optional[torch.Tensor] = None
157
+
158
+ def __post_init__(self):
159
+ # Check that fields that are required always have values.
160
+ if self.q is None:
161
+ raise ValueError("Field 'q' (quaternion) must be provided.")
162
+ if self.s_scene is None:
163
+ raise ValueError("Field 's_scene' must be provided.")
164
+ if self.s_rel is None:
165
+ if self.s_tilde is not None:
166
+ self.s_rel = self.s_tilde * self.t_rel_norm
167
+ else:
168
+ raise ValueError("Field 's_rel' or 's_tilde' must be provided.")
169
+ if self.t_unit is None:
170
+ raise ValueError("Field 't_unit' must be provided.")
171
+
172
+ if self.t_scene_center is None:
173
+ self.t_scene_center = torch.zeros_like(self.t_unit)
174
+
175
+ # There is a simple relationship between s_tilde and t_rel_norm:
176
+ # s_tilde = s_rel / t_rel_norm
177
+ #
178
+ # If one of these is missing and the other is provided, we can compute the missing field.
179
+ if self.s_tilde is None and self.t_rel_norm is not None:
180
+ self.s_tilde = self.s_rel / self.t_rel_norm
181
+ elif self.t_rel_norm is None and self.s_tilde is not None:
182
+ self.t_rel_norm = self.s_rel / self.s_tilde
183
+
184
+ # If both are provided, we check for consistency.
185
+ if self.s_tilde is not None and self.t_rel_norm is not None:
186
+ computed_s_tilde = self.s_rel / self.t_rel_norm
187
+ # If the provided s_tilde deviates from what is computed, update it.
188
+ if not torch.allclose(self.s_tilde, computed_s_tilde, atol=1e-6):
189
+ logger.warning(
190
+ f"s_tilde and t_rel_norm are provided, but they are not consistent. "
191
+ f"Updating s_tilde to {computed_s_tilde}."
192
+ )
193
+ self.s_tilde = computed_s_tilde
194
+
195
+ self._validate_fields()
196
+
197
+ def _validate_fields(self):
198
+ for field in self.__dict__:
199
+ if self.__dict__[field] is None:
200
+ raise ValueError(f"Field '{field}' must be provided.")
201
+
202
+
203
+ @staticmethod
204
+ def from_instance_pose(instance_pose: InstancePose) -> "InvariantPoseTarget":
205
+ q = instance_pose.instance_quaternion_l2c
206
+ s_obj_to_scene = instance_pose.instance_scale_l2c # (..., 1) or (..., 3)
207
+ t_obj_to_scene = instance_pose.instance_position_l2c # (..., 3)
208
+ s_scene = instance_pose.scene_scale # (..., 1) or scalar-broadcastable
209
+ t_scene_center = instance_pose.scene_shift # (..., 3)
210
+
211
+ # Normalize to scene scale (per the derivation)
212
+ if not ( s_obj_to_scene.ndim == (s_scene.ndim + 1)):
213
+ raise ValueError(f"s_scene should be ND [...,3] and s_obj_to_scene should be (N+1)D [...,K,3], but got {s_scene.shape=} {s_obj_to_scene.shape=}")
214
+ if not (t_obj_to_scene.ndim == (s_scene.ndim + 1)):
215
+ raise ValueError(f"t_scene_center should be ND [B,3] and t_obj_to_scene should be (N+1)D [B,K,3], but got {t_scene_center.shape=} {t_obj_to_scene.shape=}")
216
+ s_scene_exp = s_scene.unsqueeze(-2)
217
+
218
+ s_rel = s_obj_to_scene / s_scene_exp
219
+ t_rel = t_obj_to_scene / s_scene_exp
220
+
221
+ # Robust norms
222
+ eps = 1e-8
223
+ t_rel_norm = t_rel.norm(dim=-1, keepdim=True).clamp_min(eps)
224
+
225
+ s_tilde = s_rel / t_rel_norm
226
+ t_unit = t_rel / t_rel_norm
227
+
228
+ return InvariantPoseTarget(
229
+ q=q,
230
+ s_scene=s_scene,
231
+ t_scene_center=t_scene_center,
232
+ s_rel=s_rel,
233
+ s_tilde=s_tilde,
234
+ t_unit=t_unit,
235
+ t_rel_norm=t_rel_norm,
236
+ )
237
+
238
+
239
+ @staticmethod
240
+ def to_instance_pose(invariant_targets: "InvariantPoseTarget") -> InstancePose:
241
+ # scale factor per the derivation: s_scene * |t_rel|
242
+ # Normalize to scene scale (per the derivation)
243
+ t_rel_norm_ndim = invariant_targets.t_rel_norm.ndim
244
+ if not (invariant_targets.s_scene.ndim == (t_rel_norm_ndim - 1)) :
245
+ raise ValueError(f"s_scene should be ND [...,3] and t_rel_norm should be (N+1)D [...,K,3], but got {invariant_targets.s_scene.shape=} {invariant_targets.t_rel_norm.shape=}")
246
+
247
+ scale = invariant_targets.s_scene.unsqueeze(-2) * invariant_targets.t_rel_norm
248
+ return InstancePose(
249
+ instance_scale_l2c=invariant_targets.s_tilde * scale,
250
+ instance_position_l2c=invariant_targets.t_unit * scale,
251
+ instance_quaternion_l2c=invariant_targets.q,
252
+ scene_scale=invariant_targets.s_scene,
253
+ scene_shift=invariant_targets.t_scene_center,
254
+ )
255
+
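A standalone numeric check of the decoupling described in the docstring above (pure torch, identity rotation for brevity): s · R(q) · x + t equals s_scene · |t_rel| · (s_tilde · R(q) · x + t_unit), with t_rel = t / s_scene, s_rel = s / s_scene, s_tilde = s_rel / |t_rel| and t_unit = t_rel / |t_rel|.

```python
import torch

s, s_scene = torch.tensor(2.0), torch.tensor(0.5)
t = torch.tensor([1.0, -2.0, 4.0])
x = torch.rand(3)

t_rel, s_rel = t / s_scene, s / s_scene
t_norm = t_rel.norm()
s_tilde, t_unit = s_rel / t_norm, t_rel / t_norm

lhs = s * x + t                                  # R(q) = identity
rhs = s_scene * t_norm * (s_tilde * x + t_unit)
assert torch.allclose(lhs, rhs, atol=1e-6)
```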
256
+
257
+ class PoseTargetConvention:
258
+ """
259
+ Converts pose_targets <-> instance_pose <-> invariant_pose_targets
260
+ """
261
+
262
+ pose_target_convention: str
263
+
264
+ @classmethod
265
+ def from_invariant(cls, invariant_targets: InvariantPoseTarget) -> PoseTarget:
266
+ raise NotImplementedError("Implement this in a subclass")
267
+
268
+ @classmethod
269
+ def to_invariant(cls, instance_pose: InstancePose) -> InvariantPoseTarget:
270
+ raise NotImplementedError("Implement this in a subclass")
271
+
272
+ @classmethod
273
+ def from_instance_pose(cls, instance_pose: InstancePose) -> PoseTarget:
274
+ invariant_targets = InvariantPoseTarget.from_instance_pose(instance_pose)
275
+ return cls.from_invariant(invariant_targets)
276
+
277
+ @classmethod
278
+ def to_instance_pose(cls, pose_target: PoseTarget) -> InstancePose:
279
+ invariant_targets = cls.to_invariant(pose_target)
280
+ return InvariantPoseTarget.to_instance_pose(invariant_targets)
281
+
282
+
283
+ class ScaleShiftInvariant(PoseTargetConvention):
284
+ """
285
+
286
+ Midas eq. (6): https://arxiv.org/pdf/1907.01341v3
287
+ But for pointmaps (see MoGe): https://arxiv.org/pdf/2410.19115
288
+ """
289
+
290
+ pose_target_convention: str = "ScaleShiftInvariant"
291
+ scale_mean = torch.tensor([1.0232692956924438, 1.0232691764831543, 1.0232692956924438]).to(torch.float32)
292
+ scale_std = torch.tensor([1.3773751258850098, 1.3773752450942993, 1.3773750066757202]).to(torch.float32)
293
+ translation_mean = torch.tensor([0.003191213821992278, 0.017236359417438507, 0.9401122331619263]).to(torch.float32)
294
+ translation_std = torch.tensor([1.341888666152954, 0.7665449380874634, 3.175130605697632]).to(torch.float32)
295
+
296
+ @classmethod
297
+ def from_instance_pose(cls, instance_pose: InstancePose, normalize: bool = False) -> PoseTarget:
298
+ metric_to_ssi = cls.ssi_to_metric(
299
+ instance_pose.scene_scale, instance_pose.scene_shift
300
+ ).inverse()
301
+
302
+ ssi_scale, ssi_rotation, ssi_translation = InstancePose._broadcast_postcompose(
303
+ scale=instance_pose.instance_scale_l2c,
304
+ rotation=instance_pose.instance_quaternion_l2c,
305
+ translation=instance_pose.instance_position_l2c,
306
+ transform_to_postcompose=metric_to_ssi,
307
+ )
308
+ # logger.info(f"{normalize=} {ssi_scale.shape=} {ssi_rotation.shape=} {ssi_translation.shape=}")
309
+ if normalize:
310
+ device = ssi_scale.device
311
+ ssi_scale = (ssi_scale - cls.scale_mean.to(device)) / cls.scale_std.to(device)
312
+ ssi_translation = (ssi_translation - cls.translation_mean.to(device)) / cls.translation_std.to(device)
313
+
314
+ return PoseTarget(
315
+ x_instance_scale=ssi_scale,
316
+ x_instance_rotation=ssi_rotation,
317
+ x_instance_translation=ssi_translation,
318
+ x_scene_scale=instance_pose.scene_scale,
319
+ x_scene_center=instance_pose.scene_shift,
320
+ x_translation_scale=torch.ones_like(ssi_scale)[..., 0].unsqueeze(-1),
321
+ pose_target_convention=cls.pose_target_convention,
322
+ )
323
+
324
+ @classmethod
325
+ def to_instance_pose(cls, pose_target: PoseTarget, normalize: bool = False) -> InstancePose:
326
+ scene_scale = pose_target.x_scene_scale
327
+ scene_shift = pose_target.x_scene_center
328
+ ssi_to_metric = cls.ssi_to_metric(scene_scale, scene_shift)
329
+
330
+ if normalize:
331
+ device = pose_target.x_instance_scale.device
332
+ pose_target.x_instance_scale = pose_target.x_instance_scale * cls.scale_std.to(device) + cls.scale_mean.to(device)
333
+ pose_target.x_instance_translation = pose_target.x_instance_translation * cls.translation_std.to(device) + cls.translation_mean.to(device)
334
+
335
+ ins_scale, ins_rotation, ins_translation = InstancePose._broadcast_postcompose(
336
+ scale=pose_target.x_instance_scale,
337
+ rotation=pose_target.x_instance_rotation,
338
+ translation=pose_target.x_instance_translation,
339
+ transform_to_postcompose=ssi_to_metric,
340
+ )
341
+
342
+ return InstancePose(
343
+ instance_scale_l2c=ins_scale,
344
+ instance_position_l2c=ins_translation,
345
+ instance_quaternion_l2c=ins_rotation,
346
+ scene_scale=scene_scale,
347
+ scene_shift=scene_shift,
348
+ )
349
+
350
+ @classmethod
351
+ def to_invariant(cls, pose_target: PoseTarget, normalize: bool = False) -> InvariantPoseTarget:
352
+ instance_pose = cls.to_instance_pose(pose_target, normalize=normalize)
353
+ return InvariantPoseTarget.from_instance_pose(instance_pose)
354
+
355
+ @classmethod
356
+ def from_invariant(cls, invariant_targets: InvariantPoseTarget, normalize: bool = False) -> PoseTarget:
357
+ instance_pose = InvariantPoseTarget.to_instance_pose(invariant_targets)
358
+ return cls.from_instance_pose(instance_pose, normalize=normalize)
359
+
360
+ @classmethod
361
+ def get_scale_and_shift(cls, pointmap):
362
+ shift_z = pointmap[..., -1].nanmedian().unsqueeze(0)
363
+ shift = torch.zeros_like(shift_z.expand(1, 3))
364
+ shift[..., -1] = shift_z
365
+
366
+ shifted_pointmap = pointmap - shift
367
+ scale = shifted_pointmap.abs().nanmean().to(shift.device)
368
+
369
+ shift = shift.reshape(3)
370
+ scale = scale.expand(3)
371
+
372
+ return scale, shift
373
+
374
+ @staticmethod
375
+ def ssi_to_metric(scale: torch.Tensor, shift: torch.Tensor):
376
+ if scale.ndim == 1:
377
+ scale = scale.unsqueeze(0)
378
+ if shift.ndim == 1:
379
+ shift = shift.unsqueeze(0)
380
+ return Transform3d().scale(scale).translate(shift).to(shift.device)
381
+
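A sketch of the two helpers above (assuming pytorch3d and the `sam3d_objects` package are installed): `get_scale_and_shift` centres the shift on the median z and uses the mean absolute value of the z-centred pointmap as an isotropic scale, and `ssi_to_metric` builds the transform p_metric = p_ssi * scale + shift.

```python
import torch
from sam3d_objects.data.dataset.tdfy.pose_target import ScaleShiftInvariant

pointmap = torch.rand(64, 64, 3) + 1.0            # (H, W, 3)
scale, shift = ScaleShiftInvariant.get_scale_and_shift(pointmap)
assert scale.shape == (3,) and shift.shape == (3,)
assert torch.isclose(shift[-1], pointmap[..., -1].median())

ssi_to_metric = ScaleShiftInvariant.ssi_to_metric(scale, shift)
origin_metric = ssi_to_metric.transform_points(torch.zeros(1, 3))
assert torch.allclose(origin_metric.reshape(3), shift, atol=1e-6)  # SSI origin maps to the shift
```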
382
+
383
+ class ScaleShiftInvariantWTranslationScale(PoseTargetConvention):
384
+ """
385
+
386
+ Midas eq. (6): https://arxiv.org/pdf/1907.01341v3
387
+ But for pointmaps (see MoGe): https://arxiv.org/pdf/2410.19115
388
+ """
389
+
390
+ pose_target_convention: str = "ScaleShiftInvariantWTranslationScale"
391
+ scale_mean = torch.tensor([1.0232692956924438, 1.0232691764831543, 1.0232692956924438]).to(torch.float32)
392
+ scale_std = torch.tensor([1.3773751258850098, 1.3773752450942993, 1.3773750066757202]).to(torch.float32)
393
+ translation_mean = torch.tensor([0.003191213821992278, 0.017236359417438507, 0.9401122331619263]).to(torch.float32)
394
+ translation_std = torch.tensor([1.341888666152954, 0.7665449380874634, 3.175130605697632]).to(torch.float32)
395
+
396
+ @classmethod
397
+ def from_instance_pose(cls, instance_pose: InstancePose, normalize: bool = False) -> PoseTarget:
398
+ metric_to_ssi = cls.ssi_to_metric(
399
+ instance_pose.scene_scale, instance_pose.scene_shift
400
+ ).inverse()
401
+
402
+ ssi_scale, ssi_rotation, ssi_translation = InstancePose._broadcast_postcompose(
403
+ scale=instance_pose.instance_scale_l2c,
404
+ rotation=instance_pose.instance_quaternion_l2c,
405
+ translation=instance_pose.instance_position_l2c,
406
+ transform_to_postcompose=metric_to_ssi,
407
+ )
408
+
409
+ ssi_translation_scale = ssi_translation.norm(dim=-1, keepdim=True)
410
+ ssi_translation_unit = ssi_translation / ssi_translation_scale.clamp_min(1e-7)
411
+
412
+ return PoseTarget(
413
+ x_instance_scale=ssi_scale,
414
+ x_instance_rotation=ssi_rotation,
415
+ x_instance_translation=ssi_translation_unit,
416
+ x_scene_scale=instance_pose.scene_scale,
417
+ x_scene_center=instance_pose.scene_shift,
418
+ x_translation_scale=ssi_translation_scale,
419
+ pose_target_convention=cls.pose_target_convention,
420
+ )
421
+
422
+ @classmethod
423
+ def to_instance_pose(cls, pose_target: PoseTarget, normalize: bool = False) -> InstancePose:
424
+ scene_scale = pose_target.x_scene_scale
425
+ scene_shift = pose_target.x_scene_center
426
+ ssi_to_metric = cls.ssi_to_metric(scene_scale, scene_shift)
427
+
428
+ ins_translation_unit = pose_target.x_instance_translation / pose_target.x_instance_translation.norm(dim=-1, keepdim=True)
429
+ ins_translation = ins_translation_unit * pose_target.x_translation_scale
430
+
431
+
432
+ ins_scale, ins_rotation, ins_translation = InstancePose._broadcast_postcompose(
433
+ scale=pose_target.x_instance_scale,
434
+ rotation=pose_target.x_instance_rotation,
435
+ translation=ins_translation,
436
+ transform_to_postcompose=ssi_to_metric,
437
+ )
438
+
439
+
440
+ return InstancePose(
441
+ instance_scale_l2c=ins_scale,
442
+ instance_position_l2c=ins_translation,
443
+ instance_quaternion_l2c=ins_rotation,
444
+ scene_scale=scene_scale,
445
+ scene_shift=scene_shift,
446
+ )
447
+
448
+ @classmethod
449
+ def to_invariant(cls, pose_target: PoseTarget) -> InvariantPoseTarget:
450
+ instance_pose = cls.to_instance_pose(pose_target)
451
+ return InvariantPoseTarget.from_instance_pose(instance_pose)
452
+
453
+ @classmethod
454
+ def from_invariant(cls, invariant_targets: InvariantPoseTarget) -> PoseTarget:
455
+ instance_pose = InvariantPoseTarget.to_instance_pose(invariant_targets)
456
+ return cls.from_instance_pose(instance_pose)
457
+
458
+ @classmethod
459
+ def get_scale_and_shift(cls, pointmap):
460
+ shift_z = pointmap[..., -1].nanmedian().unsqueeze(0)
461
+ shift = torch.zeros_like(shift_z.expand(1, 3))
462
+ shift[..., -1] = shift_z
463
+
464
+ shifted_pointmap = pointmap - shift
465
+ scale = shifted_pointmap.abs().nanmean().to(shift.device)
466
+
467
+ shift = shift.reshape(3)
468
+ scale = scale.expand(3)
469
+
470
+ return scale, shift
471
+
472
+ @staticmethod
473
+ def ssi_to_metric(scale: torch.Tensor, shift: torch.Tensor):
474
+ if scale.ndim == 1:
475
+ scale = scale.unsqueeze(0)
476
+ if shift.ndim == 1:
477
+ shift = shift.unsqueeze(0)
478
+ return Transform3d().scale(scale).translate(shift).to(shift.device)
479
+
480
+
481
+ class DisparitySpace(PoseTargetConvention):
482
+ pose_target_convention: str = "DisparitySpace"
483
+
484
+ @classmethod
485
+ def from_instance_pose(cls, instance_pose: InstancePose, normalize: bool = False) -> PoseTarget:
486
+
487
+ # x_instance_scale = orig_scale / scene_scale
488
+ # x_instance_translation = [x/z, y/z, 0] / scene_scale
489
+ # x_translation_scale = z / scene_scale
490
+ assert torch.allclose(instance_pose.scene_scale, torch.ones_like(instance_pose.scene_scale))
491
+
492
+ if not instance_pose.scene_shift.ndim == instance_pose.instance_position_l2c.ndim - 1:
493
+ raise ValueError(f"scene_shift must be (N+1)D and instance_position_l2c must be (N+1)D, but got {instance_pose.scene_shift.ndim} and {instance_pose.instance_position_l2c.ndim}")
494
+ shift_xy, shift_z_log = instance_pose.scene_shift.unsqueeze(-2).split([2, 1], dim=-1)
495
+
496
+
497
+ pose_xy, pose_z = instance_pose.instance_position_l2c.split([2, 1], dim=-1)
498
+ # Handle batch dimensions properly
499
+ if shift_xy.ndim < pose_xy.ndim:
500
+ shift_xy = shift_xy.unsqueeze(-2)
501
+ pose_xy_scaled = pose_xy / pose_z - shift_xy
502
+
503
+ pose_z_scaled_log = torch.log(pose_z) - shift_z_log
504
+ x_instance_scale_log = torch.log(instance_pose.instance_scale_l2c) - torch.log(pose_z)
505
+
506
+ x_instance_translation = torch.cat([pose_xy_scaled, torch.zeros_like(pose_z)], dim=-1)
507
+ x_translation_scale = torch.exp(pose_z_scaled_log)
508
+ x_instance_scale = torch.exp(x_instance_scale_log)
509
+
510
+
511
+
512
+ return PoseTarget(
513
+ x_instance_scale=x_instance_scale,
514
+ x_instance_translation=x_instance_translation,
515
+ x_instance_rotation=instance_pose.instance_quaternion_l2c,
516
+ x_scene_scale=instance_pose.scene_scale,
517
+ x_scene_center=instance_pose.scene_shift,
518
+ x_translation_scale=x_translation_scale,
519
+ pose_target_convention=cls.pose_target_convention,
520
+ )
521
+
522
+ @classmethod
523
+ def to_instance_pose(cls, pose_target: PoseTarget, normalize: bool = False) -> InstancePose:
524
+ scene_scale = pose_target.x_scene_scale
525
+ scene_shift = pose_target.x_scene_center
526
+
527
+ if not pose_target.x_scene_center.ndim == pose_target.x_instance_translation.ndim - 1:
528
+ raise ValueError(f"x_scene_center must be (N+1)D and x_instance_translation must be (N+1)D, but got {pose_target.x_scene_center.ndim} and {pose_target.x_instance_translation.ndim}")
529
+ shift_xy, shift_z_log = pose_target.x_scene_center.unsqueeze(-2).split([2, 1], dim=-1)
530
+ scene_z_scale = torch.exp(shift_z_log)
531
+
532
+ z = pose_target.x_translation_scale
533
+ ins_translation = pose_target.x_instance_translation.clone()
534
+ ins_translation[...,2] = 1.0
535
+ ins_translation[...,:2] = ins_translation[...,:2] + shift_xy
536
+ ins_translation = ins_translation * z * scene_z_scale
537
+
538
+ ins_scale = pose_target.x_instance_scale * z * scene_z_scale
539
+
540
+ return InstancePose(
541
+ instance_scale_l2c=ins_scale * scene_scale,
542
+ instance_position_l2c=ins_translation * scene_scale,
543
+ instance_quaternion_l2c=pose_target.x_instance_rotation,
544
+ scene_scale=scene_scale,
545
+ scene_shift=scene_shift,
546
+ )
547
+
548
+ @classmethod
549
+ def to_invariant(cls, pose_target: PoseTarget, normalize: bool = False) -> InvariantPoseTarget:
550
+ instance_pose = cls.to_instance_pose(pose_target, normalize=normalize)
551
+ return InvariantPoseTarget.from_instance_pose(instance_pose)
552
+
553
+ @classmethod
554
+ def from_invariant(cls, invariant_targets: InvariantPoseTarget, normalize: bool = False) -> PoseTarget:
555
+ instance_pose = InvariantPoseTarget.to_instance_pose(invariant_targets)
556
+ return cls.from_instance_pose(instance_pose, normalize=normalize)
557
+
558
+
559
+
560
+ class NormalizedSceneScale(PoseTargetConvention):
561
+ """
562
+ x_instance_scale and x_translation_scale are normalized to x_scene_scale
563
+ """
564
+
565
+ pose_target_convention: str = "NormalizedSceneScale"
566
+
567
+ @classmethod
568
+ def from_invariant(cls, invariant_targets: InvariantPoseTarget):
569
+ translation = invariant_targets.t_unit * invariant_targets.t_rel_norm
570
+ return PoseTarget(
571
+ x_instance_scale=invariant_targets.s_rel,
572
+ x_instance_rotation=invariant_targets.q,
573
+ x_instance_translation=translation,
574
+ x_scene_scale=invariant_targets.s_scene,
575
+ x_scene_center=invariant_targets.t_scene_center,
576
+ x_translation_scale=torch.ones_like(invariant_targets.t_rel_norm),
577
+ pose_target_convention=cls.pose_target_convention,
578
+ )
579
+
580
+ @classmethod
581
+ def to_invariant(cls, pose_target: PoseTarget):
582
+ t_rel_norm = torch.norm(
583
+ pose_target.x_instance_translation, dim=-1, keepdim=True
584
+ )
585
+ return InvariantPoseTarget(
586
+ s_scene=pose_target.x_scene_scale,
587
+ s_rel=pose_target.x_instance_scale,
588
+ q=pose_target.x_instance_rotation,
589
+ t_unit=pose_target.x_instance_translation / t_rel_norm,
590
+ t_rel_norm=t_rel_norm,
591
+ t_scene_center=pose_target.x_scene_center,
592
+ )
593
+
594
+
595
+ class Naive(PoseTargetConvention):
596
+ pose_target_convention: str = "Naive"
597
+
598
+ @classmethod
599
+ def from_invariant(cls, invariant_targets: InvariantPoseTarget):
600
+ s_scene = invariant_targets.s_rel * invariant_targets.s_scene
601
+ t_scene = invariant_targets.t_unit * invariant_targets.t_rel_norm
602
+ return PoseTarget(
603
+ x_instance_scale=s_scene,
604
+ x_instance_rotation=invariant_targets.q,
605
+ x_instance_translation=t_scene,
606
+ x_scene_scale=invariant_targets.s_scene,
607
+ x_scene_center=invariant_targets.t_scene_center,
608
+ x_translation_scale=torch.ones_like(invariant_targets.t_rel_norm),
609
+ pose_target_convention=cls.pose_target_convention,
610
+ )
611
+
612
+ @classmethod
613
+ def to_invariant(cls, pose_target: PoseTarget):
614
+ s_scene = pose_target.x_scene_scale
615
+ t_rel_norm = torch.norm(
616
+ pose_target.x_instance_translation, dim=-1, keepdim=True
617
+ )
618
+ return InvariantPoseTarget(
619
+ s_scene=s_scene,
620
+ t_scene_center=pose_target.x_scene_center,
621
+ s_rel=pose_target.x_instance_scale / s_scene,
622
+ q=pose_target.x_instance_rotation,
623
+ t_unit=pose_target.x_instance_translation / t_rel_norm,
624
+ t_rel_norm=t_rel_norm,
625
+ )
626
+
627
+
628
+ class NormalizedSceneScaleAndTranslation(PoseTargetConvention):
629
+ """
630
+ x_instance_scale and x_translation_scale are normalized to x_scene_scale
631
+ x_instance_translation is unit
632
+ """
633
+
634
+ pose_target_convention: str = "NormalizedSceneScaleAndTranslation"
635
+
636
+ @classmethod
637
+ def from_invariant(cls, invariant_targets: InvariantPoseTarget):
638
+ return PoseTarget(
639
+ x_instance_scale=invariant_targets.s_rel,
640
+ x_instance_rotation=invariant_targets.q,
641
+ x_instance_translation=invariant_targets.t_unit,
642
+ x_scene_scale=invariant_targets.s_scene,
643
+ x_scene_center=invariant_targets.t_scene_center,
644
+ x_translation_scale=invariant_targets.t_rel_norm,
645
+ pose_target_convention=cls.pose_target_convention,
646
+ )
647
+
648
+ @classmethod
649
+ def to_invariant(cls, pose_target: PoseTarget):
650
+ return InvariantPoseTarget(
651
+ s_scene=pose_target.x_scene_scale,
652
+ t_scene_center=pose_target.x_scene_center,
653
+ s_rel=pose_target.x_instance_scale,
654
+ q=pose_target.x_instance_rotation,
655
+ t_unit=pose_target.x_instance_translation,
656
+ t_rel_norm=pose_target.x_translation_scale,
657
+ )
658
+
659
+
660
+ class ApparentSize(PoseTargetConvention):
661
+ pose_target_convention: str = "ApparentSize"
662
+
663
+ @classmethod
664
+ def from_invariant(cls, invariant_targets: InvariantPoseTarget):
665
+ return PoseTarget(
666
+ x_instance_scale=invariant_targets.s_tilde,
667
+ x_instance_rotation=invariant_targets.q,
668
+ x_instance_translation=invariant_targets.t_unit,
669
+ x_scene_scale=invariant_targets.s_scene,
670
+ x_scene_center=invariant_targets.t_scene_center,
671
+ x_translation_scale=invariant_targets.t_rel_norm,
672
+ pose_target_convention=cls.pose_target_convention,
673
+ )
674
+
675
+ @classmethod
676
+ def to_invariant(cls, pose_target: PoseTarget):
677
+ return InvariantPoseTarget(
678
+ s_scene=pose_target.x_scene_scale,
679
+ t_scene_center=pose_target.x_scene_center,
680
+ s_tilde=pose_target.x_instance_scale,
681
+ q=pose_target.x_instance_rotation,
682
+ t_unit=pose_target.x_instance_translation,
683
+ t_rel_norm=pose_target.x_translation_scale,
684
+ )
685
+
686
+
687
+ class Identity(PoseTargetConvention):
688
+ """
689
+ Identity convention - no transformation applied.
690
+ Direct passthrough mapping between instance pose and pose target values.
691
+ This preserves all values including scene_scale and scene_shift.
692
+ """
693
+
694
+ pose_target_convention: str = "Identity"
695
+
696
+ @classmethod
697
+ def from_instance_pose(cls, instance_pose: InstancePose) -> PoseTarget:
698
+ return PoseTarget(
699
+ x_instance_scale=instance_pose.instance_scale_l2c,
700
+ x_instance_rotation=instance_pose.instance_quaternion_l2c,
701
+ x_instance_translation=instance_pose.instance_position_l2c,
702
+ x_scene_scale=instance_pose.scene_scale,
703
+ x_scene_center=instance_pose.scene_shift,
704
+ x_translation_scale=torch.ones_like(instance_pose.instance_scale_l2c)[..., 0].unsqueeze(-1),
705
+ pose_target_convention=cls.pose_target_convention,
706
+ )
707
+
708
+ @classmethod
709
+ def to_instance_pose(cls, pose_target: PoseTarget) -> InstancePose:
710
+ return InstancePose(
711
+ instance_scale_l2c=pose_target.x_instance_scale,
712
+ instance_position_l2c=pose_target.x_instance_translation,
713
+ instance_quaternion_l2c=pose_target.x_instance_rotation,
714
+ scene_scale=pose_target.x_scene_scale,
715
+ scene_shift=pose_target.x_scene_center,
716
+ )
717
+
718
+ @classmethod
719
+ def to_invariant(cls, pose_target: PoseTarget) -> InvariantPoseTarget:
720
+ instance_pose = cls.to_instance_pose(pose_target)
721
+ return InvariantPoseTarget.from_instance_pose(instance_pose)
722
+
723
+ @classmethod
724
+ def from_invariant(cls, invariant_targets: InvariantPoseTarget) -> PoseTarget:
725
+ instance_pose = InvariantPoseTarget.to_instance_pose(invariant_targets)
726
+ return cls.from_instance_pose(instance_pose)
727
+
728
+
729
+ class PoseTargetConverter:
730
+ @staticmethod
731
+ def pose_target_to_instance_pose(pose_target: PoseTarget, normalize: bool = False) -> InstancePose:
732
+ _convention_class = globals()[pose_target.pose_target_convention]
733
+ if _convention_class == ScaleShiftInvariant:
734
+ return _convention_class.to_instance_pose(pose_target, normalize=normalize)
735
+ else:
736
+ return _convention_class.to_instance_pose(pose_target)
737
+
738
+ @staticmethod
739
+ def instance_pose_to_pose_target(
740
+ instance_pose: InstancePose, pose_target_convention: str, normalize: bool = False
741
+ ) -> PoseTarget:
742
+ _convention_class = globals()[pose_target_convention]
743
+ if _convention_class == ScaleShiftInvariant:
744
+ return _convention_class.from_instance_pose(instance_pose, normalize=normalize)
745
+ else:
746
+ return _convention_class.from_instance_pose(instance_pose)
747
+
748
+ @staticmethod
749
+ def dicts_instance_pose_to_pose_target(
750
+ pose_target_convention: str,
751
+ **kwargs,
752
+ ):
753
+ instance_pose = InstancePose(**kwargs)
754
+ pose_target = PoseTargetConverter.instance_pose_to_pose_target(
755
+ instance_pose, pose_target_convention
756
+ )
757
+ return asdict(pose_target)
758
+
759
+ @staticmethod
760
+ def dicts_pose_target_to_instance_pose(
761
+ **kwargs,
762
+ ):
763
+ pose_target_convention = kwargs.get("pose_target_convention")
764
+ _convention_class = globals()[pose_target_convention]
765
+ assert (
766
+ _convention_class.pose_target_convention == pose_target_convention
767
+ ), f"Normalization name mismatch: {_convention_class.pose_target_convention} != {pose_target_convention}"
768
+
769
+ normalize = kwargs.pop("normalize", False)
770
+ pose_target = PoseTarget(**kwargs)
771
+ instance_pose = PoseTargetConverter.pose_target_to_instance_pose(pose_target, normalize)
772
+ return asdict(instance_pose)
773
+
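As a quick illustration of the converter, the sketch below round-trips a pose through the Identity convention. It is a minimal sketch only: the InstancePose field names are taken from the Identity methods above, while the batch-of-one tensor shapes are assumptions.

    import torch

    instance_pose = InstancePose(
        instance_scale_l2c=torch.ones(1, 3),
        instance_position_l2c=torch.zeros(1, 3),
        instance_quaternion_l2c=torch.tensor([[1.0, 0.0, 0.0, 0.0]]),  # identity rotation
        scene_scale=torch.ones(1, 1),
        scene_shift=torch.zeros(1, 3),
    )
    target = PoseTargetConverter.instance_pose_to_pose_target(instance_pose, "Identity")
    recovered = PoseTargetConverter.pose_target_to_instance_pose(target)
    # Identity is a pure passthrough, so `recovered` matches `instance_pose` field by field.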
774
+
775
+ class LogScaleShiftNormalizer:
776
+ def __init__(self, shift_log: torch.Tensor = 0.0, scale_log: torch.Tensor = 1.0):
777
+ self.shift_log = shift_log
778
+ self.scale_log = scale_log
779
+
780
+ def normalize(self, value: torch.Tensor):
781
+ return (torch.log(value) - self.shift_log) / self.scale_log
782
+
783
+ def denormalize(self, value: torch.Tensor):
784
+ return torch.exp(value * self.scale_log + self.shift_log)
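A small round-trip sketch for the log-space normalizer; the shift/scale constants are arbitrary and positive inputs are assumed so the log is defined.

    import torch

    norm = LogScaleShiftNormalizer(shift_log=0.5, scale_log=2.0)
    value = torch.tensor([0.1, 1.0, 10.0])
    roundtrip = norm.denormalize(norm.normalize(value))
    # denormalize() inverts normalize(), so roundtrip equals value up to float error.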
thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/preprocessor.py ADDED
@@ -0,0 +1,203 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import warnings
3
+ import torch
4
+ from loguru import logger
5
+ from dataclasses import dataclass
6
+ from typing import Callable, Optional
7
+ import warnings
8
+
9
+ from .img_and_mask_transforms import (
10
+ SSIPointmapNormalizer,
11
+ )
12
+
13
+
14
+ # Load and process data
15
+ @dataclass
16
+ class PreProcessor:
17
+ """
18
+ Preprocessor configuration for image, mask, and pointmap transforms.
19
+
20
+ Transform application order:
21
+ 1. Pointmap normalization (if normalize_pointmap=True)
22
+ 2. Joint transforms (img_mask_pointmap_joint_transform or img_mask_joint_transform)
23
+ 3. Individual transforms (img_transform, mask_transform, pointmap_transform)
24
+
25
+ For backward compatibility, img_mask_joint_transform is preserved. When both
26
+ img_mask_pointmap_joint_transform and img_mask_joint_transform are present,
27
+ img_mask_pointmap_joint_transform takes priority.
28
+ """
29
+
30
+ img_transform: Callable = (None,)
31
+ mask_transform: Callable = (None,)
32
+ img_mask_joint_transform: list[Callable] = (None,)
33
+ rgb_img_mask_joint_transform: list[Callable] = (None,)
34
+
35
+ # New fields for pointmap support
36
+ pointmap_transform: Callable = (None,)
37
+ img_mask_pointmap_joint_transform: list[Callable] = (None,)
38
+
39
+ # Pointmap normalization option
40
+ normalize_pointmap: bool = False
41
+ pointmap_normalizer: Optional[Callable] = None
42
+ rgb_pointmap_normalizer: Optional[Callable] = None
43
+
44
+ def __post_init__(self):
45
+ if self.pointmap_normalizer is None:
46
+ self.pointmap_normalizer = SSIPointmapNormalizer()
47
+ if not self.normalize_pointmap:
48
+ warnings.warn("normalize_pointmap is False: the normalization moments will still be returned, but the pointmap itself is left unnormalized. This keeps legacy unnormalized-pointmap models working, but it is deprecated and error-prone.", DeprecationWarning, stacklevel=2)
49
+
50
+ if self.rgb_pointmap_normalizer is None:
51
+ logger.warning("No rgb pointmap normalizer provided, using scale + shift ")
52
+ self.rgb_pointmap_normalizer = self.pointmap_normalizer
53
+
54
+
55
+ def _normalize_pointmap(
56
+ self, pointmap: torch.Tensor,
57
+ mask: torch.Tensor,
58
+ pointmap_normalizer: Callable,
59
+ scale: Optional[torch.Tensor] = None,
60
+ shift: Optional[torch.Tensor] = None,
61
+ ):
62
+ if pointmap is None:
63
+ return pointmap, None, None
64
+
65
+ if not self.normalize_pointmap:
66
+ # old behavior: Pose is normalized to the pointmap center, but pointmap is not
67
+ _, pointmap_scale, pointmap_shift = pointmap_normalizer.normalize(pointmap, mask)
68
+ return pointmap, pointmap_scale, pointmap_shift
69
+
70
+ if scale is not None or shift is not None:
71
+ return pointmap_normalizer.normalize(pointmap, mask, scale, shift)
72
+
73
+ return pointmap_normalizer.normalize(pointmap, mask)
74
+
75
+ def _process_image_mask_pointmap_mess(
76
+ self, rgb_image, rgb_image_mask, pointmap=None
77
+ ):
78
+ """Extended version that handles pointmaps"""
79
+
80
+ # Apply pointmap normalization if enabled
81
+ pointmap_for_crop, pointmap_scale, pointmap_shift = self._normalize_pointmap(
82
+ pointmap, rgb_image_mask, self.pointmap_normalizer
83
+ )
84
+
85
+ # Apply transforms to the original full rgb image and mask.
86
+ rgb_image, rgb_image_mask = self._preprocess_rgb_image_mask(rgb_image, rgb_image_mask)
87
+
88
+ # These two are typically used for getting cropped images of the object
89
+ # : first apply joint transforms
90
+ processed_rgb_image, processed_mask, processed_pointmap = (
91
+ self._preprocess_image_mask_pointmap(rgb_image, rgb_image_mask, pointmap_for_crop)
92
+ )
93
+ # : then apply individual transforms on top of the joint transforms
94
+ processed_rgb_image = self._apply_transform(
95
+ processed_rgb_image, self.img_transform
96
+ )
97
+ processed_mask = self._apply_transform(processed_mask, self.mask_transform)
98
+ if processed_pointmap is not None:
99
+ processed_pointmap = self._apply_transform(
100
+ processed_pointmap, self.pointmap_transform
101
+ )
102
+
103
+ # This version is typically the full version of the image
104
+ # : apply individual transforms only
105
+ rgb_image = self._apply_transform(rgb_image, self.img_transform)
106
+ rgb_image_mask = self._apply_transform(rgb_image_mask, self.mask_transform)
107
+
108
+ rgb_pointmap, rgb_pointmap_scale, rgb_pointmap_shift = self._normalize_pointmap(
109
+ pointmap, rgb_image_mask, self.rgb_pointmap_normalizer, pointmap_scale, pointmap_shift
110
+ )
111
+
112
+ if rgb_pointmap is not None:
113
+ rgb_pointmap = self._apply_transform(rgb_pointmap, self.pointmap_transform)
114
+
115
+ result = {
116
+ "mask": processed_mask,
117
+ "image": processed_rgb_image,
118
+ "rgb_image": rgb_image,
119
+ "rgb_image_mask": rgb_image_mask,
120
+ }
121
+
122
+ # Add pointmap results if available
123
+ if processed_pointmap is not None:
124
+ result.update(
125
+ {
126
+ "pointmap": processed_pointmap,
127
+ "rgb_pointmap": rgb_pointmap,
128
+ }
129
+ )
130
+
131
+ # Add normalization parameters if normalization was applied
132
+ if pointmap_scale is not None and pointmap_shift is not None:
133
+ result.update(
134
+ {
135
+ "pointmap_scale": pointmap_scale,
136
+ "pointmap_shift": pointmap_shift,
137
+ "rgb_pointmap_scale": rgb_pointmap_scale,
138
+ "rgb_pointmap_shift": rgb_pointmap_shift,
139
+ }
140
+ )
141
+
142
+ return result
143
+
144
+ def _process_image_and_mask_mess(self, rgb_image, rgb_image_mask):
145
+ """Original method - calls extended version without pointmap"""
146
+ return self._process_image_mask_pointmap_mess(rgb_image, rgb_image_mask, None)
147
+
148
+ def _preprocess_rgb_image_mask(self, rgb_image: torch.Tensor, rgb_image_mask: torch.Tensor):
149
+ """Apply joint transforms to rgb_image and rgb_image_mask."""
150
+ if (
151
+ self.rgb_img_mask_joint_transform != (None,)
152
+ and self.rgb_img_mask_joint_transform is not None
153
+ ):
154
+ for trans in self.rgb_img_mask_joint_transform:
155
+ rgb_image, rgb_image_mask = trans(rgb_image, rgb_image_mask)
156
+ return rgb_image, rgb_image_mask
157
+
158
+ def _preprocess_image_mask_pointmap(self, rgb_image, mask_image, pointmap=None):
159
+ """Apply joint transforms with priority: triple transforms > dual transforms."""
160
+ # Priority: img_mask_pointmap_joint_transform when pointmap is provided
161
+ if (
162
+ self.img_mask_pointmap_joint_transform != (None,)
163
+ and self.img_mask_pointmap_joint_transform is not None
164
+ and pointmap is not None
165
+ ):
166
+ for trans in self.img_mask_pointmap_joint_transform:
167
+ rgb_image, mask_image, pointmap = trans(
168
+ rgb_image, mask_image, pointmap=pointmap
169
+ )
170
+ return rgb_image, mask_image, pointmap
171
+
172
+ # Fallback: img_mask_joint_transform (existing behavior)
173
+ elif (
174
+ self.img_mask_joint_transform != (None,)
175
+ and self.img_mask_joint_transform is not None
176
+ ):
177
+ for trans in self.img_mask_joint_transform:
178
+ rgb_image, mask_image = trans(rgb_image, mask_image)
179
+ return rgb_image, mask_image, pointmap
180
+
181
+ return rgb_image, mask_image, pointmap
182
+
183
+ def _preprocess_image_and_mask(self, rgb_image, mask_image):
184
+ """Backward compatibility wrapper - only applies dual transforms"""
185
+ rgb_image, mask_image, _ = self._preprocess_image_mask_pointmap(
186
+ rgb_image, mask_image, None
187
+ )
188
+ return rgb_image, mask_image
189
+
190
+ # keep here for backward compatibility
191
+ def _preprocess_image_and_mask_inference(self, rgb_image, mask_image):
192
+ warnings.warn(
193
+ "The _preprocess_image_and_mask_inference is deprecated! Please use _preprocess_image_and_mask",
194
+ category=DeprecationWarning,
195
+ stacklevel=2,
196
+ )
197
+ return self._preprocess_image_and_mask(rgb_image, mask_image)
198
+
199
+ def _apply_transform(self, input: torch.Tensor, transform):
200
+ if input is not None and transform is not None and transform != (None,):
201
+ input = transform(input)
202
+
203
+ return input
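A minimal construction sketch for the preprocessor above, assuming torchvision is available; the normalization constants and tensor shapes are illustrative choices, and no crop/joint transforms are configured.

    import torch
    from torchvision import transforms

    pre = PreProcessor(
        img_transform=transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        normalize_pointmap=True,  # __post_init__ installs the default SSIPointmapNormalizer
    )
    rgb = torch.rand(3, 256, 256)
    mask = (torch.rand(1, 256, 256) > 0.5).float()

    out = pre._process_image_and_mask_mess(rgb, mask)
    # out holds "image", "mask", "rgb_image" and "rgb_image_mask"; passing a pointmap to
    # _process_image_mask_pointmap_mess additionally yields "pointmap", "rgb_pointmap"
    # and the *_scale / *_shift normalization moments.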
thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/transforms_3d.py ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ from collections import namedtuple
3
+ import math
4
+ import torch
5
+
6
+ from pytorch3d.transforms import (
7
+ Rotate,
8
+ Translate,
9
+ Scale,
10
+ Transform3d,
11
+ quaternion_to_matrix,
12
+ axis_angle_to_quaternion,
13
+ )
14
+
15
+ DecomposedTransform = namedtuple(
16
+ "DecomposedTransform", ["scale", "rotation", "translation"]
17
+ )
18
+
19
+
20
+ def compose_transform(
21
+ scale: torch.Tensor, rotation: torch.Tensor, translation: torch.Tensor
22
+ ) -> Transform3d:
23
+ """
24
+ Args:
25
+ scale: (..., 3) tensor of scale factors
26
+ rotation: (..., 3, 3) tensor of rotation matrices
27
+ translation: (..., 3) tensor of translation vectors
28
+ """
29
+ tfm = Transform3d(dtype=scale.dtype, device=scale.device)
30
+ return tfm.scale(scale).rotate(rotation).translate(translation)
31
+
32
+
33
+ def decompose_transform(transform: Transform3d) -> DecomposedTransform:
34
+ """
35
+ Returns:
36
+ scale: (..., 3) tensor of scale factors
37
+ rotation: (..., 3, 3) tensor of rotation matrices
38
+ translation: (..., 3) tensor of translation vectors
39
+ """
40
+ matrices = transform.get_matrix()
41
+ scale = torch.norm(matrices[:, :3, :3], dim=-1)
42
+ rotation = matrices[:, :3, :3] / scale.unsqueeze(-1) # Normalize rotation matrix
43
+ translation = matrices[:, 3, :3] # Extract translation vector
44
+ return DecomposedTransform(scale, rotation, translation)
45
+
46
+
47
+ def get_rotation_about_x_axis(angle: float = math.pi / 2) -> torch.Tensor:
48
+ axis = torch.tensor([1.0, 0.0, 0.0])
49
+ axis_angle = axis * angle
50
+ return axis_angle_to_quaternion(axis_angle)
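A round-trip sketch for the helpers above, assuming pytorch3d is installed; the scale, rotation, and translation values are arbitrary.

    import torch
    from pytorch3d.transforms import quaternion_to_matrix

    scale = torch.tensor([[1.0, 2.0, 0.5]])
    rotation = quaternion_to_matrix(get_rotation_about_x_axis())[None]  # (1, 3, 3), +90 deg about x
    translation = torch.tensor([[0.1, -0.2, 0.3]])

    tfm = compose_transform(scale, rotation, translation)
    dec = decompose_transform(tfm)
    # dec.scale, dec.rotation and dec.translation recover the inputs up to float error.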
thirdparty/sam3d/sam3d/sam3d_objects/data/utils.py ADDED
@@ -0,0 +1,243 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ from typing import Any, Iterable, Tuple, Union, Dict, Sequence, Mapping, Container
3
+ import optree
4
+ import torch
5
+ from collections.abc import Iterable
6
+ import inspect
7
+ import ast
8
+ import astor
9
+ from torch.utils import _pytree
10
+
11
+ # None = root, Iterable[Any] = path, Any = path of one
12
+ ChildPathType = Union[None, Iterable[Any], Any]
13
+ ArgsType = Iterable[ChildPathType]
14
+ KwargsType = Mapping[str, ChildPathType]
15
+ ArgsKwargsType = Tuple[ArgsType, KwargsType]
16
+ MappingType = Union[None, ArgsKwargsType, ArgsType, KwargsType]
17
+
18
+
19
+ def tree_transpose_level_one(
20
+ structure,
21
+ check_children=False,
22
+ map_fn=None,
23
+ is_leaf=None,
24
+ ):
25
+ _, outer_spec = optree.tree_flatten(
26
+ structure,
27
+ is_leaf=lambda x: x is not structure,
28
+ none_is_leaf=True,
29
+ )
30
+
31
+ spec = optree.tree_structure(structure, none_is_leaf=True, is_leaf=is_leaf)
32
+ children_spec = spec.children()
33
+ if len(children_spec) > 0:
34
+ inner_spec = children_spec[0]
35
+ if check_children:
36
+ for child_spec in children_spec[1:]:
37
+ assert (
38
+ inner_spec == child_spec
39
+ ), f"one child was found having a different tree structure ({inner_spec} != {child_spec})"
40
+
41
+ structure = optree.tree_transpose(outer_spec, inner_spec, structure)
42
+
43
+ if map_fn is not None:
44
+ structure = optree.tree_map(
45
+ map_fn,
46
+ structure,
47
+ is_leaf=lambda x: optree.tree_structure(
48
+ x, is_leaf=is_leaf, none_is_leaf=True
49
+ )
50
+ == outer_spec,
51
+ none_is_leaf=True,
52
+ )
53
+
54
+ return structure
55
+
56
+
57
58
+ def tree_tensor_map(fn, tree, *rest):
59
+ return optree.tree_map(
60
+ fn,
61
+ tree,
62
+ *rest,
63
+ is_leaf=lambda x: isinstance(x, torch.Tensor),
64
+ none_is_leaf=False,
65
+ )
66
+
67
+
68
+ def to_device(obj, device):
69
+ """Recursively moves all tensors in obj to the specified device.
70
+
71
+ Args:
72
+ obj: Object to move to device - can be a tensor, list, tuple, dict or any nested combination
73
+ device: Target device (e.g. 'cuda', 'cpu', torch.device('cuda:0') etc.)
74
+
75
+ Returns:
76
+ Same object structure with all contained tensors moved to specified device
77
+ """
78
+ to_fn = lambda x: x.to(device)
79
+ return optree.tree_map(to_fn, obj, is_leaf=torch.is_tensor, none_is_leaf=False)
80
+
81
+
82
+ def expand_right(tensor, target_shape):
83
+ """
84
+ e.g. takes a tensor of shape (a, b, c) and right-pads it with broadcast dimensions, expanding to a target_shape such as (a, b, c, d, e)
85
+ """
86
+ current_shape = tensor.shape
87
+ dims_to_add = len(target_shape) - len(current_shape)
88
+ result = tensor
89
+ for _ in range(dims_to_add):
90
+ result = result.unsqueeze(-1)
91
+ expand_shape = list(current_shape) + [-1] * dims_to_add
92
+ for i in range(len(target_shape)):
93
+ if i < len(expand_shape) and expand_shape[i] == -1:
94
+ expand_shape[i] = target_shape[i]
95
+ return result.expand(*expand_shape)
96
+
97
+
98
+ def expand_as_right(tensor, target):
99
+ return expand_right(tensor, target.shape)
100
+
101
+
102
+ def as_keys(path: ChildPathType):
103
+ if isinstance(path, Iterable) and (not isinstance(path, str)):
104
+ return tuple(path)
105
+ elif path is None:
106
+ return ()
107
+ return (path,)
108
+
109
+
110
+ def get_child(obj: Any, *keys: Iterable[Any]):
111
+ for key in keys:
112
+ obj = obj[key]
113
+ return obj
114
+
115
+
116
+ def set_child(obj: Any, value: Any, *keys: Iterable[Any]):
117
+ parent = None
118
+ for key in keys:
119
+ parent = obj
120
+ obj = obj[key]
121
+ if parent is None:
122
+ obj = value
123
+ else:
124
+ parent[key] = value
125
+ return obj
126
+
127
+
128
+ def build_args_batch_extractor(args_mapping: ArgsType):
129
+ def extract_fn(batch):
130
+ return tuple(get_child(batch, *as_keys(path)) for path in args_mapping)
131
+
132
+ return extract_fn
133
+
134
+
135
+ def build_kwargs_batch_extractor(kwargs_mapping: KwargsType):
136
+ def extract_fn(batch):
137
+ return {
138
+ name: get_child(batch, *as_keys(path))
139
+ for name, path in kwargs_mapping.items()
140
+ }
141
+
142
+ return extract_fn
143
+
144
+
145
+ empty_mapping = object()
146
+ kwargs_identity_mapping = object()
147
+
148
+
149
+ def build_batch_extractor(mapping: MappingType):
150
+ extract_args_fn = lambda x: ()
151
+ extract_kwargs_fn = lambda x: {}
152
+
153
+ if mapping is None:
154
+
155
+ def extract_args_fn(batch):
156
+ return (batch,)
157
+
158
+ elif mapping is empty_mapping:
159
+ pass
160
+ elif mapping is kwargs_identity_mapping:
161
+ extract_kwargs_fn = lambda x: x
162
+ elif isinstance(mapping, Sequence) and (not isinstance(mapping, str)):
163
+ if (
164
+ len(mapping) == 2
165
+ and isinstance(mapping[0], Sequence)
166
+ and isinstance(mapping[1], Dict)
167
+ ):
168
+ extract_args_fn = build_args_batch_extractor(mapping[0])
169
+ extract_kwargs_fn = build_kwargs_batch_extractor(mapping[1])
170
+ else:
171
+ extract_args_fn = build_args_batch_extractor(mapping)
172
+ elif isinstance(mapping, Mapping):
173
+ extract_kwargs_fn = build_kwargs_batch_extractor(mapping)
174
+ else:
175
+
176
+ def extract_args_fn(batch):
177
+ return (get_child(batch, *as_keys(mapping)),)
178
+
179
+ def extract_fn(batch):
180
+ return extract_args_fn(batch), extract_kwargs_fn(batch)
181
+
182
+ return extract_fn
183
+
184
+
185
+ # >
186
+
187
+
188
+ def right_broadcasting(arr, target):
189
+ return arr.reshape(arr.shape + (1,) * (target.ndim - arr.ndim))
190
+
191
+
192
+ def get_stats(tensor: torch.Tensor):
193
+ float_tensor = tensor.float()
194
+ return {
195
+ "shape": tuple(tensor.shape),
196
+ "min": tensor.min().item(),
197
+ "max": tensor.max().item(),
198
+ "mean": float_tensor.mean().item(),
199
+ "median": tensor.median().item(),
200
+ "std": float_tensor.std().item(),
201
+ }
202
+
203
+
204
+ def _get_caller_arg_name(argnum=0, parent_frame=1):
205
+ try:
206
+ frame = inspect.currentframe() # current frame
207
+ frame = inspect.getouterframes(frame)[1 + parent_frame] # parent frame
208
+ code = inspect.getframeinfo(frame[0]).code_context[0].strip() # get code line
209
+
210
+ tree = ast.parse(code)
211
+
212
+ for node in ast.walk(tree):
213
+ if isinstance(node, ast.Call):
214
+ args = node.args
215
+ break # only get the first parent call
216
+
217
+ # get first argument string (do not handle '=')
218
+ label = astor.to_source(args[argnum]).strip()
219
+ except:
220
+ # TODO(Pierre) log exception
221
+ label = "{label}"
222
+ return label
223
+
224
+
225
+ def print_stats(tensor, label=None):
226
+ if label is None:
227
+ label = _get_caller_arg_name(argnum=0)
228
+ stats = get_stats(tensor)
229
+ string = f"{label}:\n" + "\n".join(f"- {k}: {v}" for k, v in stats.items())
230
+ print(string)
231
+
232
+
233
+ def tree_reduce_unique(fn, tree, ensure_unique=True, **kwargs):
234
+ values = _pytree.tree_flatten(tree, **kwargs)[0]
235
+ values = tuple(map(fn, values))
236
+ first = values[0]
237
+ if ensure_unique:
238
+ for value in values[1:]:
239
+ if value != first:
240
+ raise RuntimeError(
241
+ f"different values found, {value} and {first} should be the same"
242
+ )
243
+ return first
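The mapping helpers above can be exercised with a tiny, made-up nested batch; the keys below are purely illustrative.

    batch = {"inputs": {"image": 1, "mask": 2}, "target": 3}

    extract = build_batch_extractor(((("inputs", "image"),), {"y": "target"}))
    args, kwargs = extract(batch)
    # args == (1,) and kwargs == {"y": 3}: the sequence part of the mapping selects
    # positional arguments by key path, the dict part selects keyword arguments.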
thirdparty/sam3d/sam3d/sam3d_objects/model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/dino.py ADDED
@@ -0,0 +1,142 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import torch
3
+ from typing import Optional, Dict, Any
4
+ import warnings
5
+ from torchvision.transforms import Normalize
6
+ import torch.nn.functional as F
7
+ from loguru import logger
8
+
9
+
10
+ class Dino(torch.nn.Module):
11
+ def __init__(
12
+ self,
13
+ input_size: int = 224,
14
+ repo_or_dir: str = "facebookresearch/dinov2",
15
+ dino_model: str = "dinov2_vitb14",
16
+ source: str = "github",
17
+ backbone_kwargs: Optional[Dict[str, Any]] = None,
18
+ normalize_images: bool = True,
19
+ # for backward compatibility
20
+ prenorm_features: bool = False,
21
+ freeze_backbone: bool = True,
22
+ prune_network: bool = False,  # False for backward compatibility
23
+ ):
24
+ super().__init__()
25
+ if backbone_kwargs is None:
26
+ backbone_kwargs = {}
27
+
28
+ with warnings.catch_warnings():
29
+ warnings.simplefilter("ignore")
30
+
31
+ logger.info(f"Loading DINO model: {dino_model} from {repo_or_dir} (source: {source})")
32
+ if backbone_kwargs:
33
+ logger.info(f"DINO backbone kwargs: {backbone_kwargs}")
34
+
35
+ self.backbone = torch.hub.load(
36
+ repo_or_dir=repo_or_dir,
37
+ model=dino_model,
38
+ source=source,
39
+ verbose=False,
40
+ **backbone_kwargs,
41
+ )
42
+
43
+ # Log model properties after loading
44
+ logger.info(f"Loaded DINO model - type: {type(self.backbone)}, "
45
+ f"embed_dim: {self.backbone.embed_dim}, "
46
+ f"patch_size: {getattr(self.backbone.patch_embed, 'patch_size', 'N/A')}")
47
+
48
+
49
+ self.resize_input_size = (input_size, input_size)
50
+ self.embed_dim = self.backbone.embed_dim
51
+ self.input_size = input_size
52
+ self.input_channels = 3
53
+ self.normalize_images = normalize_images
54
+ self.prenorm_features = prenorm_features
55
+ self.register_buffer('mean', torch.as_tensor([[0.485, 0.456, 0.406]]).view(-1, 1, 1), persistent=False)
56
+ self.register_buffer('std', torch.as_tensor([[0.229, 0.224, 0.225]]).view(-1, 1, 1), persistent=False)
57
+
58
+ # freeze
59
+ if freeze_backbone:
60
+ self.requires_grad_(False)
61
+ self.eval()
62
+ elif not prune_network:
63
+ logger.warning(
64
+ "Unfreeze encoder w/o prune parameter may lead to error in ddp/fp16 training"
65
+ )
66
+
67
+ if prune_network:
68
+ self._prune_network()
69
+
70
+ def _preprocess_input(self, x):
71
+ _resized_images = torch.nn.functional.interpolate(
72
+ x,
73
+ size=self.resize_input_size,
74
+ mode="bilinear",
75
+ align_corners=False,
76
+ )
77
+
78
+ if x.shape[1] == 1:
79
+ _resized_images = _resized_images.repeat(1, 3, 1, 1)
80
+
81
+ if self.normalize_images:
82
+ _resized_images = _resized_images.sub_(self.mean).div_(self.std)
83
+
84
+ return _resized_images
85
+
86
+ def _forward_intermediate_layers(
87
+ self, input_img, intermediate_layers, cls_token=True
88
+ ):
89
+ return self.backbone.get_intermediate_layers(
90
+ input_img,
91
+ intermediate_layers,
92
+ return_class_token=cls_token,
93
+ )
94
+
95
+ def _forward_last_layer(self, input_img):
96
+ output = self.backbone.forward_features(input_img)
97
+ if self.prenorm_features:
98
+ features = output["x_prenorm"]
99
+ tokens = F.layer_norm(features, features.shape[-1:])
100
+ else:
101
+ tokens = torch.cat(
102
+ [
103
+ output["x_norm_clstoken"].unsqueeze(1),
104
+ output["x_norm_patchtokens"],
105
+ ],
106
+ dim=1,
107
+ )
108
+ return tokens
109
+
110
+ def forward(self, x, **kwargs):
111
+ _resized_images = self._preprocess_input(x)
112
+ tokens = self._forward_last_layer(_resized_images)
113
+ return tokens.to(x.dtype)
114
+
115
+ def _prune_network(self):
116
+ """
117
+ Ran this script:
118
+ out = model(input)
119
+ loss = out.sum()
120
+ loss.backward()
121
+
122
+ for name, p in dino_model.named_parameters():
123
+ if p.grad is None:
124
+ print(name)
125
+ model.zero_grad()
126
+ """
127
+ self.backbone.mask_token = None
128
+ if self.prenorm_features:
129
+ self.backbone.norm = torch.nn.Identity()
130
+
131
+
132
+ class DinoForMasks(torch.nn.Module):
133
+ def __init__(
134
+ self,
135
+ backbone: Dino,
136
+ ):
137
+ super().__init__()
138
+ self.backbone = backbone
139
+ self.embed_dim = self.backbone.embed_dim
140
+
141
+ def forward(self, image, mask):
142
+ return self.backbone.forward(mask)
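An illustrative usage sketch for the encoder above. It assumes torch.hub can download the default facebookresearch/dinov2 weights; batch size and input resolution are arbitrary.

    import torch

    encoder = Dino(input_size=224, dino_model="dinov2_vitb14")
    images = torch.rand(2, 3, 480, 640)  # any resolution; resized to 224x224 internally
    with torch.no_grad():
        tokens = encoder(images)
    # tokens: (2, 257, 768) for ViT-B/14 -- one CLS token plus (224 / 14) ** 2 patch tokens.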
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/embedder_fuser.py ADDED
@@ -0,0 +1,238 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import math
3
+ import torch
4
+ from loguru import logger
5
+ from torch import nn
6
+ from typing import Optional, Tuple, List, Literal, Dict
7
+ from sam3d_objects.model.layers.llama3.ff import FeedForward
8
+ from omegaconf import OmegaConf
9
+
10
+ class EmbedderFuser(torch.nn.Module):
11
+ """
12
+ Fuses individual condition embedders. Requires kwargs in the forward pass.
13
+ Args:
14
+ embedder_list: List of tuples. Each tuple consists of a condition embedder
15
+ and a list of (kwarg_name, pos_group) tuples, where kwarg_name is the forward
16
+ kwarg to embed and pos_group is the positional-encoding group to use.
17
+ use_pos_embedding: whether to add a positional embedding. If used, indices follow
18
+ the order in embedder_list. Choices of None (no pos emb), "random", and "learned".
19
+ projection_pre_norm: pre-normalize features before feeding into projector layers.
20
+ projection_net_hidden_dim_multiplier: hidden dimension for projection layer. If 0, don't use.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ embedder_list: List[Tuple[nn.Module, List[Tuple[str, Optional[str]]]]],
26
+ use_pos_embedding: Optional[Literal["random", "learned"]] = "learned",
27
+ projection_pre_norm: bool = True,
28
+ projection_net_hidden_dim_multiplier: float = 4.0,
29
+ compression_projection_multiplier: float = 0,
30
+ freeze: bool = False,
31
+ drop_modalities_weight: Optional[List[Tuple[List[str], float]]] = None,
32
+ dropout_prob: float = 0.0,
33
+ force_drop_modalities: Optional[List[str]] = None,
34
+ ):
35
+ super().__init__()
36
+ # torch.compile does not support OmegaConf.ListConfig, so we convert to a list
37
+ if not isinstance(embedder_list, List):
38
+ self.embedder_list = OmegaConf.to_container(embedder_list)
39
+ else:
40
+ self.embedder_list = embedder_list
41
+
42
+ self.embed_dims = 0
43
+ self.compression_projection_multiplier = compression_projection_multiplier
44
+ self.concate_embed_dims = 0
45
+ # keep moduleList to be compatible with nn module
46
+ self.module_list = []
47
+ max_positional_embed_idx = 0
48
+ self.positional_embed_map = {}
49
+ for condition_embedder, kwargs_info in self.embedder_list:
50
+ self.embed_dims = max(self.embed_dims, condition_embedder.embed_dim)
51
+ self.module_list.append(condition_embedder)
52
+ for _, pos_group in kwargs_info:
53
+ self.concate_embed_dims += condition_embedder.embed_dim
54
+ if pos_group is not None:
55
+ if pos_group not in self.positional_embed_map:
56
+ self.positional_embed_map[pos_group] = max_positional_embed_idx
57
+ max_positional_embed_idx += 1
58
+ self.module_list = nn.ModuleList(self.module_list)
59
+ self.use_pos_embedding = use_pos_embedding
60
+ if self.use_pos_embedding == "random":
61
+ idx_emb = torch.randn(max_positional_embed_idx + 1, 1, self.embed_dims)
62
+ self.register_buffer("idx_emb", idx_emb)
63
+ elif self.use_pos_embedding == "learned":
64
+ self.idx_emb = nn.Parameter(
65
+ torch.empty(max_positional_embed_idx + 1, self.embed_dims)
66
+ )
67
+ nn.init.normal_(
68
+ self.idx_emb, mean=0.0, std=1.0 / math.sqrt(self.embed_dims)
69
+ )
70
+ else:
71
+ raise NotImplementedError(f"Unknown pos embedding {self.use_pos_embedding}")
72
+
73
+ self.projection_pre_norm = projection_pre_norm
74
+ self.projection_net_hidden_dim_multiplier = projection_net_hidden_dim_multiplier
75
+ if projection_net_hidden_dim_multiplier > 0:
76
+ self.projection_nets = []
77
+ for condition_embedder, _ in self.embedder_list:
78
+ self.projection_nets.append(
79
+ self._make_projection_net(
80
+ condition_embedder.embed_dim,
81
+ self.embed_dims,
82
+ self.projection_net_hidden_dim_multiplier,
83
+ )
84
+ )
85
+ self.projection_nets = nn.ModuleList(self.projection_nets)
86
+
87
+ if compression_projection_multiplier > 0:
88
+ self.compression_projector = self._make_projection_net(
89
+ self.concate_embed_dims,
90
+ self.embed_dims,
91
+ self.compression_projection_multiplier,
92
+ )
93
+
94
+ self.drop_modalities_weight = drop_modalities_weight if drop_modalities_weight is not None else []
95
+ self.dropout_prob = dropout_prob
96
+ self.force_drop_modalities = force_drop_modalities
97
+
98
+ if freeze:
99
+ self.requires_grad_(False)
100
+ self.eval()
101
+
102
+ def _make_projection_net(
103
+ self,
104
+ input_embed_dim,
105
+ output_embed_dim: int,
106
+ multiplier: int,
107
+ ):
108
+ if self.projection_pre_norm:
109
+ pre_norm = nn.LayerNorm(input_embed_dim)
110
+ else:
111
+ pre_norm = nn.Identity()
112
+
113
+ # Per-token projection + gated activation
114
+ ff_net = FeedForward(
115
+ dim=input_embed_dim,
116
+ hidden_dim=int(multiplier * output_embed_dim),
117
+ output_dim=output_embed_dim,
118
+ )
119
+ return nn.Sequential(pre_norm, ff_net)
120
+
121
+ def _build_dropout_distribution(self, device):
122
+ """
123
+ Build the probability distribution for dropout configurations.
124
+
125
+ Returns:
126
+ dropout_configs: List of sets containing modalities to drop
127
+ cumsum_weights: Cumulative sum of weights for sampling
128
+ """
129
+ dropout_configs = []
130
+ weights = []
131
+
132
+ # Add no-dropout configuration with remaining probability
133
+ dropout_configs.append(set())
134
+ weights.append(1.0 - self.dropout_prob)
135
+
136
+ # Add configured dropout patterns
137
+ total_dropout_weight = sum(w for _, w in self.drop_modalities_weight)
138
+ assert total_dropout_weight > 0, "Total dropout weight must be positive when drop_modalities_weight is provided"
139
+ for modality_list, weight in self.drop_modalities_weight:
140
+ dropout_configs.append(set(modality_list))
141
+ # Scale weight by dropout_prob to ensure total probability sums to 1
142
+ weights.append(self.dropout_prob * weight / total_dropout_weight)
143
+
144
+ # Convert weights to cumulative distribution
145
+ weights_tensor = torch.tensor(weights, device=device)
146
+
147
+ was_deterministic = torch.are_deterministic_algorithms_enabled()
148
+ torch.use_deterministic_algorithms(False)
149
+ cumsum_weights = torch.cumsum(weights_tensor, dim=0)
150
+ torch.use_deterministic_algorithms(was_deterministic)
151
+
152
+ return dropout_configs, cumsum_weights
153
+
154
+ def _apply_force_drop(self, kwarg_names: List[str], tokens: List[torch.Tensor]):
155
+ if not self.force_drop_modalities:
156
+ return tokens
157
+
158
+ force_drop_set = set(self.force_drop_modalities)
159
+ result_tokens = []
160
+
161
+ for kwarg_name, token_tensor in zip(kwarg_names, tokens):
162
+ # Create mask: 0 for forced drop, 1 otherwise
163
+ mask = 0.0 if kwarg_name in force_drop_set else 1.0
164
+ result_tokens.append(token_tensor * mask)
165
+
166
+ return result_tokens
167
+
168
+ def _dropout_modalities(self, kwarg_names: List[str], tokens: List[torch.Tensor]):
169
+ # First apply forced drops (deterministic, always applied)
170
+ tokens = self._apply_force_drop(kwarg_names, tokens)
171
+
172
+ # Then apply probabilistic dropout (only in training)
173
+ if not self.training or self.dropout_prob <= 0 or not self.drop_modalities_weight:
174
+ return tokens
175
+
176
+ batch_size = tokens[0].shape[0]
177
+ device = tokens[0].device
178
+
179
+ # Build dropout configurations and sample which to use per batch element
180
+ dropout_configs, cumsum_weights = self._build_dropout_distribution(device)
181
+ rand_vals = torch.rand(batch_size, device=device)
182
+ # Clamp indices to valid range (handle edge case where rand_val == 1.0)
183
+ config_indices = torch.searchsorted(cumsum_weights, rand_vals).clamp(max=len(dropout_configs) - 1)
184
+
185
+ # Apply dropout masks with vectorized operations
186
+ result_tokens = []
187
+ for kwarg_name, token_tensor in zip(kwarg_names, tokens):
188
+ # Start with all ones (no dropout)
189
+ mask = torch.ones(batch_size, dtype=token_tensor.dtype, device=device)
190
+
191
+ # Vectorized mask creation: check all configurations at once
192
+ for config_idx, modalities_to_drop in enumerate(dropout_configs):
193
+ if kwarg_name in modalities_to_drop:
194
+ # Set mask to 0 for all batch elements using this configuration
195
+ mask[config_indices == config_idx] = 0.0
196
+
197
+ # Reshape mask to match token dimensions
198
+ mask = mask.view([batch_size] + [1] * (token_tensor.ndim - 1))
199
+ result_tokens.append(token_tensor * mask)
200
+
201
+ return result_tokens
202
+
203
+ def forward(self, *args, **kwargs):
204
+ tokens = []
205
+ kwarg_names = []
206
+
207
+ for i, (condition_embedder, kwargs_info) in enumerate(self.embedder_list):
208
+ # Ideally, we would batch the inputs; but that assumes same-sized inputs
209
+ for kwarg_name, pos_group in kwargs_info:
210
+ if kwarg_name not in kwargs:
211
+ logger.warning(f"{kwarg_name} not in kwargs to condition embedder!")
212
+ input_cond = kwargs[kwarg_name]
213
+ cond_token = condition_embedder(input_cond)
214
+ if self.projection_net_hidden_dim_multiplier > 0:
215
+ cond_token = self.projection_nets[i](cond_token)
216
+ if pos_group is not None:
217
+ pos_idx = self.positional_embed_map[pos_group]
218
+ if self.use_pos_embedding == "random":
219
+ cond_token += self.idx_emb[pos_idx : pos_idx + 1]
220
+ elif self.use_pos_embedding == "learned":
221
+ cond_token += self.idx_emb[pos_idx : pos_idx + 1, None]
222
+ else:
223
+ raise NotImplementedError(
224
+ f"Unknown pos embedding {self.use_pos_embedding}"
225
+ )
226
+ tokens.append(cond_token)
227
+ kwarg_names.append(kwarg_name)
228
+
229
+ # Apply dropout modalities with preserved order
230
+ tokens = self._dropout_modalities(kwarg_names, tokens)
231
+
232
+ if self.compression_projection_multiplier > 0:
233
+ tokens = torch.cat(tokens, dim=-1)
234
+ tokens = self.compression_projector(tokens)
235
+ else:
236
+ tokens = torch.cat(tokens, dim=1)
237
+
238
+ return tokens
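A minimal wiring sketch for the fuser above. DummyEmbedder is a hypothetical stand-in for real condition embedders (it only needs an embed_dim attribute and token-shaped outputs); the per-embedder projection nets are disabled so the sketch does not depend on the FeedForward layer.

    import torch
    from torch import nn

    class DummyEmbedder(nn.Module):
        def __init__(self, embed_dim=64):
            super().__init__()
            self.embed_dim = embed_dim
            self.proj = nn.Linear(8, embed_dim)

        def forward(self, x):  # (B, T, 8) -> (B, T, embed_dim)
            return self.proj(x)

    fuser = EmbedderFuser(
        embedder_list=[
            (DummyEmbedder(), [("image_tokens", "image")]),
            (DummyEmbedder(), [("pointmap_tokens", "pointmap")]),
        ],
        use_pos_embedding="learned",
        projection_net_hidden_dim_multiplier=0,  # skip the projection nets for brevity
    )
    fused = fuser(image_tokens=torch.rand(2, 5, 8), pointmap_tokens=torch.rand(2, 7, 8))
    # fused: (2, 12, 64) -- tokens from both embedders concatenated along the sequence dim.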
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/point_remapper.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class PointRemapper(nn.Module):
7
+ """Handles remapping of 3D point coordinates and their inverse transformations."""
8
+
9
+ VALID_TYPES = ["linear", "sinh", "exp", "sinh_exp", "exp_disparity"]
10
+
11
+ def __init__(self, remap_type: str = "exp"):
12
+ super().__init__()
13
+ self.remap_type = remap_type
14
+
15
+ if remap_type not in self.VALID_TYPES:
16
+ raise ValueError(
17
+ f"Invalid remap type: {remap_type}. Must be one of {self.VALID_TYPES}"
18
+ )
19
+
20
+ def forward(self, points: torch.Tensor) -> torch.Tensor:
21
+ """Apply remapping to point coordinates."""
22
+ if self.remap_type == "linear":
23
+ return points
24
+
25
+ elif self.remap_type == "sinh":
26
+ return torch.asinh(points)
27
+
28
+ elif self.remap_type == "exp":
29
+ xy_scaled, z_exp = points.split([2, 1], dim=-1)
30
+ # Use log1p for better numerical stability near zero
31
+ z = torch.log1p(z_exp)
32
+ xy = xy_scaled / (1 + z_exp)
33
+ return torch.cat([xy, z], dim=-1)
34
+
35
+ elif self.remap_type == "exp_disparity":
36
+ xy_scaled, z_exp = points.split([2, 1], dim=-1)
37
+ xy = xy_scaled / z_exp
38
+ z = torch.log(z_exp)
39
+ return torch.cat([xy, z], dim=-1)
40
+
41
+ elif self.remap_type == "sinh_exp":
42
+ xy_sinh, z_exp = points.split([2, 1], dim=-1)
43
+ xy = torch.asinh(xy_sinh)
44
+ z = torch.log(z_exp.clamp(min=1e-8))
45
+ return torch.cat([xy, z], dim=-1)
46
+
47
+ else:
48
+ raise ValueError(f"Unknown remap type: {self.remap_type}")
49
+
50
+ def inverse(self, points: torch.Tensor) -> torch.Tensor:
51
+ """Apply inverse remapping to recover original point coordinates."""
52
+ if self.remap_type == "linear":
53
+ return points
54
+
55
+ elif self.remap_type == "sinh":
56
+ return torch.sinh(points)
57
+
58
+ elif self.remap_type == "exp":
59
+ xy, z = points.split([2, 1], dim=-1)
60
+ # Inverse of log1p is expm1(z) = exp(z) - 1
61
+ z_exp = torch.expm1(z)
62
+ # Inverse of xy/(1+z_exp) is xy*(1+z_exp)
63
+ return torch.cat([xy * (1 + z_exp), z_exp], dim=-1)
64
+
65
+ elif self.remap_type == "exp_disparity":
66
+ xy, z = points.split([2, 1], dim=-1)
67
+ z_exp = torch.exp(z)
68
+ return torch.cat([xy * z_exp, z_exp], dim=-1)
69
+
70
+ elif self.remap_type == "sinh_exp":
71
+ xy, z = points.split([2, 1], dim=-1)
72
+ return torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
73
+
74
+ else:
75
+ raise ValueError(f"Unknown remap type: {self.remap_type}")
76
+
77
+ def extra_repr(self) -> str:
78
+ return f"remap_type='{self.remap_type}'"
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/pointmap.py ADDED
@@ -0,0 +1,238 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ from timm.models.vision_transformer import Block
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ from functools import partial
7
+ from loguru import logger
8
+
9
+ from .point_remapper import PointRemapper
10
+
11
+
12
+ class PointPatchEmbed(nn.Module):
13
+ """
14
+ Projects (x,y,z) → D
15
+ Splits into patches (patch_size x patch_size)
16
+ Runs a tiny self-attention block inside each window
17
+ Returns one token per window.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ input_size: int = 256,
23
+ patch_size: int = 8,
24
+ embed_dim: int = 768,
25
+ remap_output: str = "exp", # Add remap_output parameter
26
+ dropout_prob: float = 0.0, # Dropout probability for pointmap
27
+ force_dropout_always: bool = False, # Force dropout during validation/inference
28
+ ):
29
+ super().__init__()
30
+ self.input_size = input_size
31
+ self.patch_size = patch_size
32
+ self.embed_dim = embed_dim
33
+ self.dropout_prob = dropout_prob
34
+ self.force_dropout_always = force_dropout_always
35
+
36
+ # Point remapper
37
+ self.point_remapper = PointRemapper(remap_output)
38
+
39
+ # (1) point embedding
40
+ self.point_proj = nn.Linear(3, embed_dim)
41
+ self.invalid_xyz_token = nn.Parameter(torch.zeros(embed_dim))
42
+
43
+ # Special embedding for dropped patches (used during dropout)
44
+ # Alternative dropout strategies to consider:
45
+ # 1. Drop all tokens entirely or use a single token only
46
+ # 2. Different dropout patterns per window
47
+ # 3. Use dropped_xyz_token/invalid_xyz_token per pixel
48
+ if dropout_prob > 0:
49
+ self.dropped_xyz_token = nn.Parameter(torch.zeros(embed_dim))
50
+
51
+ # (2) positional embedding
52
+ num_patches = input_size // patch_size
53
+ # For patches
54
+ self.pos_embed = nn.Parameter(
55
+ torch.zeros(1, embed_dim, num_patches, num_patches)
56
+ )
57
+ # For points in a patch
58
+ self.pos_embed_window = nn.Parameter(
59
+ torch.zeros(1, 1 + patch_size * patch_size, embed_dim)
60
+ )
61
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
62
+
63
+ # (3) within-patch transformer block(s)
64
+ # From MCC: https://github.com/facebookresearch/MCC/blob/b04c97518360e4fdedfb6f090db7e90d0c2f8ae6/mcc_model.py#L97
65
+ self.blocks = nn.ModuleList(
66
+ [
67
+ Block(
68
+ embed_dim,
69
+ num_heads=16,
70
+ mlp_ratio=2.0,
71
+ qkv_bias=True,
72
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
73
+ )
74
+ ]
75
+ )
76
+ self.initialize_weights()
77
+
78
+ def initialize_weights(self):
79
+ # Initialize positional embeddings with small std
80
+ nn.init.normal_(self.pos_embed, std=0.02)
81
+ nn.init.normal_(self.pos_embed_window, std=0.02)
82
+ nn.init.normal_(self.cls_token, std=0.02)
83
+ nn.init.normal_(self.invalid_xyz_token, std=0.02)
84
+
85
+ # Initialize dropped pointmap token if dropout is enabled
86
+ if self.dropout_prob > 0:
87
+ nn.init.normal_(self.dropped_xyz_token, std=0.02)
88
+
89
+ # Initialize point projection with xavier uniform for better gradient flow
90
+ # This is crucial since pointmaps can have large value ranges
91
+ nn.init.xavier_uniform_(self.point_proj.weight, gain=0.02)
92
+ if self.point_proj.bias is not None:
93
+ nn.init.constant_(self.point_proj.bias, 0)
94
+
95
+ def _get_pos_embed(self, hw):
96
+ h, w = hw
97
+ pos_embed = F.interpolate(
98
+ self.pos_embed, size=(h, w), mode="bilinear", align_corners=False
99
+ )
100
+ pos_embed = pos_embed.permute(0, 2, 3, 1) # (B, H, W, C)
101
+ return pos_embed
102
+
103
+ def resize_input(self, xyz: torch.Tensor) -> torch.Tensor:
104
+ resized_xyz = F.interpolate(xyz, size=self.input_size, mode="nearest")
105
+ resized_xyz = resized_xyz.permute(0, 2, 3, 1) # (B, H, W, C)
106
+ return resized_xyz
107
+
108
+ def apply_pointmap_dropout(self, embeddings: torch.Tensor) -> torch.Tensor:
109
+ """
110
+ Apply dropout to pointmap embeddings.
111
+ Drops entire pointmap for selected samples during training or when forced.
112
+
113
+ When force_dropout_always is True, always drops pointmap regardless of training mode.
114
+ """
115
+ # Check if we should apply dropout
116
+ should_apply_dropout = (self.training or self.force_dropout_always) and self.dropout_prob > 0
117
+
118
+ if not should_apply_dropout:
119
+ return embeddings
120
+
121
+ # Check if dropout infrastructure exists
122
+ if not hasattr(self, 'dropped_xyz_token'):
123
+ if self.force_dropout_always:
124
+ raise RuntimeError(
125
+ "Cannot force dropout: model was initialized with dropout_prob=0. "
126
+ "Re-initialize with dropout_prob > 0 to enable forced dropout."
127
+ )
128
+ return embeddings
129
+
130
+ batch_size, n_windows, embed_dim = embeddings.shape
131
+
132
+ # Decide dropout behavior
133
+ if self.force_dropout_always and not self.training:
134
+ # When forced during inference, always drop (100% dropout)
135
+ drop_mask = torch.ones(batch_size, device=embeddings.device, dtype=torch.bool)
136
+ else:
137
+ # Normal training dropout - use configured probability
138
+ drop_mask = torch.rand(batch_size, device=embeddings.device) < self.dropout_prob
139
+
140
+ # Create dropped embedding for all windows - use same token for all patches
141
+ # Shape: (batch_size, n_windows, embed_dim)
142
+ dropped_embedding = self.dropped_xyz_token.view(1, 1, embed_dim).expand(batch_size, n_windows, embed_dim)
143
+
144
+ # Add positional embeddings to dropped tokens (same as regular embeddings get)
145
+ n_windows_h = n_windows_w = int(n_windows ** 0.5)
146
+ pos_embed_patch = self._get_pos_embed((n_windows_h, n_windows_w)).reshape(
147
+ 1, n_windows, embed_dim
148
+ )
149
+ dropped_embedding = dropped_embedding + pos_embed_patch
150
+ drop_mask_expanded = drop_mask.view(batch_size, 1, 1).expand_as(embeddings)
151
+ embeddings = torch.where(drop_mask_expanded, dropped_embedding, embeddings)
152
+
153
+ return embeddings
154
+
155
+ @torch._dynamo.disable()
156
+ def embed_pointmap_windows(
157
+ self, xyz: torch.Tensor, valid_mask: torch.Tensor = None
158
+ ) -> torch.Tensor:
159
+ """Process pointmap into window embeddings without positional encoding"""
160
+ with torch.no_grad():
161
+ xyz = self.resize_input(xyz)
162
+ if valid_mask is None:
163
+ valid_mask = xyz.isfinite().all(dim=-1)
164
+
165
+ B, H, W, _ = xyz.shape
166
+ assert (
167
+ H % self.patch_size == 0 and W % self.patch_size == 0
168
+ ), "image must be divisible by patch_size"
169
+
170
+ # (1) Handle NaN values before remapping to prevent propagation
171
+ xyz_safe = xyz.clone()
172
+ xyz_safe[~valid_mask] = 0.0 # Set invalid points to 0 before remapping
173
+
174
+ # (1b) remap points to normalize their range
175
+ xyz_remapped = self.point_remapper(xyz_safe)
176
+
177
+ # (2) project + invalid token
178
+ x = self.point_proj(xyz_remapped) # (B,H,W,D)
179
+
180
+ x[~valid_mask] = 0.0 # Stop gradient for invalid points
181
+ x[~valid_mask] += self.invalid_xyz_token
182
+
183
+ return x, B, H, W
184
+
185
+ def inner_forward(
186
+ self, x: torch.Tensor, B: int, H: int, W: int
187
+ ) -> torch.Tensor:
188
+ x = x.view(
189
+ B,
190
+ H // self.patch_size,
191
+ self.patch_size,
192
+ W // self.patch_size,
193
+ self.patch_size,
194
+ self.embed_dim,
195
+ )  # (B, hW, ws, wW, ws, D)
196
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous() # (B, hW, wW, ws, ws, D)
197
+ x = x.view(-1, self.patch_size * self.patch_size, self.embed_dim)
198
+
199
+ # (4) CLS token that contains the patch information
200
+ cls_tok = self.cls_token.expand(x.shape[0], -1, -1)
201
+ toks = torch.cat([cls_tok, x], dim=1)
202
+
203
+ # (5) add positional embedding for window
204
+ toks = toks + self.pos_embed_window
205
+
206
+ # (6) intra-window attention
207
+ for blk in self.blocks:
208
+ toks = blk(toks)
209
+
210
+ # (7) Extract CLS tokens and reshape to (B, n_windows, embed_dim)
211
+ n_windows_h = H // self.patch_size
212
+ n_windows_w = W // self.patch_size
213
+ window_embeddings = toks[:, 0].view(B, n_windows_h * n_windows_w, self.embed_dim)
214
+
215
+ # Add positional embeddings
216
+ pos_embed_patch = self._get_pos_embed((n_windows_h, n_windows_w)).reshape(
217
+ 1, n_windows_h * n_windows_w, self.embed_dim
218
+ )
219
+ out = window_embeddings + pos_embed_patch
220
+
221
+ # Apply dropout if enabled (during training OR when forced)
222
+ if (self.training or self.force_dropout_always) and self.dropout_prob > 0:
223
+ out = self.apply_pointmap_dropout(out)
224
+
225
+ return out
226
+
227
+ def forward(
228
+ self, xyz: torch.Tensor, valid_mask: torch.Tensor = None
229
+ ) -> torch.Tensor:
230
+ """
231
+ xyz : (B, 3, H, W) map of (x,y,z) coordinates
232
+ valid_mask : (B, H, W) boolean - True for valid points (optional)
233
+
234
+ returns: (B, num_windows, D)
235
+ """
236
+ # Get window embeddings
237
+ x, B, H, W = self.embed_pointmap_windows(xyz, valid_mask)
238
+ return self.inner_forward(x, B, H, W)
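A shape-check sketch for the window embedder above; it assumes timm is installed (for the transformer Block) and uses a small 64x64 pointmap so the shapes are easy to follow.

    import torch

    embed = PointPatchEmbed(input_size=64, patch_size=8, embed_dim=768)
    xyz = torch.rand(2, 3, 64, 64)  # (B, 3, H, W) pointmap, all values finite
    tokens = embed(xyz)
    # tokens: (2, 64, 768) -- one token per 8x8 window of the 64x64 map.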
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/base.py ADDED
@@ -0,0 +1,65 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import torch
3
+ from typing import Optional, Union
4
+
5
+
6
+ class Base(torch.nn.Module):
7
+ def __init__(self, seed_or_generator: Optional[Union[int, torch.Generator]] = None):
8
+ super().__init__()
9
+
10
+ if isinstance(seed_or_generator, torch.Generator):
11
+ self.random_generator = seed_or_generator
12
+ elif isinstance(seed_or_generator, int):
13
+ self.seed = seed_or_generator
14
+ elif seed_or_generator is None:
15
+ self.random_generator = torch.default_generator
16
+ else:
17
+ raise RuntimeError(
18
+ f"cannot use argument of type {type(seed_or_generator)} to set random generator"
19
+ )
20
+
21
+ @property
22
+ def seed(self):
23
+ raise AttributeError(f"Cannot read attribute 'seed'.")
24
+
25
+ @seed.setter
26
+ def seed(self, value: int):
27
+ self._random_generator = torch.Generator().manual_seed(value)
28
+
29
+ @property
30
+ def random_generator(self):
31
+ return self._random_generator
32
+
33
+ @random_generator.setter
34
+ def random_generator(self, generator: torch.Generator):
35
+ self._random_generator = generator
36
+
37
+ def forward(self, x_shape, x_device, *args_conditionals, **kwargs_conditionals):
38
+ return self.generate(
39
+ x_shape,
40
+ x_device,
41
+ *args_conditionals,
42
+ **kwargs_conditionals,
43
+ )
44
+
45
+ def generate(self, x_shape, x_device, *args_conditionals, **kwargs_conditionals):
46
+ for _, xt, _ in self.generate_iter(
47
+ x_shape,
48
+ x_device,
49
+ *args_conditionals,
50
+ **kwargs_conditionals,
51
+ ):
52
+ pass
53
+ return xt
54
+
55
+ def generate_iter(
56
+ self,
57
+ x_shape,
58
+ x_device,
59
+ *args_conditionals,
60
+ **kwargs_conditionals,
61
+ ):
62
+ raise NotImplementedError
63
+
64
+ def loss(self, x, *args_conditionals, **kwargs_conditionals):
65
+ raise NotImplementedError
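A minimal, hypothetical subclass sketch showing the contract between generate() and generate_iter(): the iterator yields (step, current_sample, extra) triples and generate() returns the final sample.

    import torch

    class ToyGenerator(Base):
        def generate_iter(self, x_shape, x_device, *args, **kwargs):
            xt = torch.randn(x_shape, device=x_device, generator=self.random_generator)
            for step in range(3):  # stand-in for real solver steps
                xt = 0.5 * xt
                yield step, xt, None

    gen = ToyGenerator(seed_or_generator=0)
    sample = gen.generate((2, 4), "cpu")  # iterates generate_iter and returns the last xt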
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/classifier_free_guidance.py ADDED
@@ -0,0 +1,259 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ from functools import partial
3
+ from numbers import Number
4
+ import torch
5
+ import random
6
+ from torch.utils import _pytree
7
+ from torch.utils._pytree import tree_map_only
8
+ from loguru import logger
9
+
10
+ def _zeros_like(struct):
11
+ def make_zeros(x):
12
+ if isinstance(x, torch.Tensor):
13
+ return torch.zeros_like(x)
14
+ return x
15
+
16
+ return _pytree.tree_map(make_zeros, struct)
17
+
18
+
19
+ def zero_out(args, kwargs):
20
+ args = _zeros_like(args)
21
+ kwargs = _zeros_like(kwargs)
22
+ return args, kwargs
23
+
24
+
25
+ def discard(args, kwargs):
26
+ return (), {}
27
+
28
+
29
+ def _drop_tensors(struct):
30
+ """
31
+ Drop any conditioning that are tensors
32
+ Not using _pytree since we actually want to throw them instead of keeping them.
33
+ """
34
+ if isinstance(struct, dict):
35
+ return {
36
+ k: _drop_tensors(v)
37
+ for k, v in struct.items()
38
+ if not isinstance(v, torch.Tensor)
39
+ }
40
+ elif isinstance(struct, (list, tuple)):
41
+ filtered = [_drop_tensors(x) for x in struct if not isinstance(x, torch.Tensor)]
42
+ return tuple(filtered) if isinstance(struct, tuple) else filtered
43
+ else:
44
+ return struct
45
+
46
+
47
+ def drop_tensors(args, kwargs):
48
+ args = _drop_tensors(args)
49
+ kwargs = _drop_tensors(kwargs)
50
+ return args, kwargs
51
+
52
+
53
+ def add_flag(args, kwargs):
54
+ kwargs["cfg"] = True
55
+ return args, kwargs
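The unconditional-handling helpers above can be compared on a tiny, made-up conditioning set; strength_hint is a hypothetical non-tensor kwarg used only for illustration.

    import torch

    args = (torch.ones(2, 3),)
    kwargs = {"tokens": torch.ones(2, 5), "strength_hint": 0.5}

    zero_out(args, kwargs)      # tensors replaced by zeros_like, non-tensors kept
    discard(args, kwargs)       # ((), {}) -- conditioning removed entirely
    drop_tensors(args, kwargs)  # ((), {"strength_hint": 0.5}) -- only tensors dropped
    add_flag(args, kwargs)      # kwargs gains "cfg": True and is otherwise passed through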
56
+
57
+
58
+ class ClassifierFreeGuidance(torch.nn.Module):
59
+ UNCONDITIONAL_HANDLING_TYPES = {
60
+ "zeros": zero_out,
61
+ "discard": discard,
62
+ "drop_tensors": drop_tensors,
63
+ "add_flag": add_flag,
64
+ }
65
+
66
+ def __init__(
67
+ self,
68
+ backbone, # backbone should be a backbone/generator (e.g. DDPM/DDIM/FlowMatching)
69
+ p_unconditional=0.1,
70
+ strength=3.0,
71
+ # "zeros" = set cond tensors to 0,
72
+ # "discard" = remove cond arguments and let underlying model handle it
73
+ # "drop_tensors" = drop all tensors but leave non-tensors
74
+ # "add_flag" = add an argument in kwargs as "cfg" and defer the handling to generator backbone
75
+ unconditional_handling="zeros",
76
+ interval=None, # only perform cfg if t within interval
77
+ ):
78
+ super().__init__()
79
+
80
+ if not (
81
+ unconditional_handling
82
+ in ClassifierFreeGuidance.UNCONDITIONAL_HANDLING_TYPES
83
+ ):
84
+ raise RuntimeError(
85
+ f"'{unconditional_handling}' is not valid for `unconditional_handling`, should be in {ClassifierFreeGuidance.UNCONDITIONAL_HANDLING_TYPES}"
86
+ )
87
+
88
+ self.backbone = backbone
89
+ self.p_unconditional = p_unconditional
90
+ self.strength = strength
91
+ self.unconditional_handling = unconditional_handling
92
+ self.interval = interval
93
+ self._make_unconditional_args = (
94
+ ClassifierFreeGuidance.UNCONDITIONAL_HANDLING_TYPES[
95
+ self.unconditional_handling
96
+ ]
97
+ )
98
+
99
+ def _cfg_step_tensor(self, y_cond, y_uncond, strength):
100
+ return (1 + strength) * y_cond - strength * y_uncond
101
+
102
+ def _cfg_step(self, y_cond, y_uncond, strength):
103
+ if isinstance(strength, dict):
104
+ return _pytree.tree_map(self._cfg_step_tensor, y_cond, y_uncond, strength)
105
+ else:
106
+ return _pytree.tree_map(partial(self._cfg_step_tensor, strength=strength), y_cond, y_uncond)
107
+
108
+ def inner_forward(self, x, t, is_cond, strength, *args_cond, **kwargs_cond):
109
+ y_cond = self.backbone(x, t, *args_cond, **kwargs_cond)
110
+ if is_cond:
111
+ return y_cond
112
+ else:
113
+ args_cond, kwargs_cond = self._make_unconditional_args(
114
+ args_cond,
115
+ kwargs_cond,
116
+ )
117
+ y_uncond = self.backbone(x, t, *args_cond, **kwargs_cond)
118
+ return self._cfg_step(y_cond, y_uncond, strength)
119
+
120
+ def forward(self, x, t, *args_cond, **kwargs_cond):
121
+ # handle case when no conditional arguments are provided
122
+ if len(args_cond) + len(kwargs_cond) == 0: # unconditional
123
+ if self.unconditional_handling != "discard":
124
+ raise RuntimeError(
125
+ f"cannot call `ClassifierFreeGuidance` module without condition"
126
+ )
127
+ return self.backbone(x, t)
128
+ else: # conditional arguments are provided
129
+ # training mode
130
+ if self.training:
131
+ coin_flip = random.random() < self.p_unconditional
132
+ if coin_flip: # unconditional
133
+ args_cond, kwargs_cond = self._make_unconditional_args(
134
+ args_cond,
135
+ kwargs_cond,
136
+ )
137
+ return self.backbone(x, t, *args_cond, **kwargs_cond)
138
+ else: # inference mode
139
+ strength = get_strength(self.strength, self.interval, t)
140
+ is_cond = not any(x > 0.0 for x in _pytree.tree_flatten(strength)[0])
141
+ return self.inner_forward(
142
+ x, t, is_cond, strength, *args_cond, **kwargs_cond
143
+ )
144
+
145
+ def get_strength(strength, interval, t):
146
+ if interval is None:
147
+ return strength # interval=None means no restriction on when guidance is applied
148
+
149
+ # If interval is not a dict (single tuple), broadcast it
150
+ if not isinstance(interval, dict):
151
+ return _pytree.tree_map(
152
+ lambda x: x if interval[0] <= t <= interval[1] else 0.0,
153
+ strength
154
+ )
155
+
156
+ return _pytree.tree_map(
157
+ lambda x, iv: x if iv[0] <= t <= iv[1] else 0.0,
158
+ strength,
159
+ interval
160
+ )
161
+
162
+ class PointmapCFG(ClassifierFreeGuidance):
163
+
164
+ def __init__(self, *args, strength_pm=0.0, **kwargs):
165
+ super().__init__(*args, **kwargs)
166
+ self.strength_pm = strength_pm
167
+
168
+ def _cfg_step_tensor(self, y_cond, y_uncond, y_unpm, strength, strength_pm):
169
+ # https://arxiv.org/abs/2411.18613
170
+ return y_cond \
171
+ + strength_pm * (y_cond - y_unpm) \
172
+ + strength * (y_unpm - y_uncond)
173
+
174
+ def _cfg_step(self, y_cond, y_uncond, y_pm, strength, strength_pm):
175
+ if isinstance(strength, dict):
176
+ return _pytree.tree_map(self._cfg_step_tensor, y_cond, y_uncond, y_pm, strength, strength_pm)
177
+ else:
178
+ return _pytree.tree_map(partial(self._cfg_step_tensor, strength=strength, strength_pm=strength_pm), y_cond, y_uncond, y_pm)
179
+
180
+ def inner_forward(self, x, t, is_cond, strength, strength_pm, *args_cond, **kwargs_cond):
181
+ y_cond = self.backbone(x, t, *args_cond, **kwargs_cond)
182
+
183
+ if is_cond:
184
+ return y_cond
185
+ else:
186
+ force_drop_modalities = self.backbone.condition_embedder.force_drop_modalities
187
+ self.backbone.condition_embedder.force_drop_modalities = ['pointmap', 'rgb_pointmap']
188
+ y_pm = self.backbone(x, t, *args_cond, **kwargs_cond)
189
+ self.backbone.condition_embedder.force_drop_modalities = force_drop_modalities
190
+
191
+ args_cond, kwargs_cond = self._make_unconditional_args(
192
+ args_cond,
193
+ kwargs_cond,
194
+ )
195
+ y_uncond = self.backbone(x, t, *args_cond, **kwargs_cond)
196
+ return self._cfg_step(y_cond, y_uncond, y_pm, strength, strength_pm)
197
+
198
+ def forward(self, x, t, *args_cond, **kwargs_cond):
199
+ # handle case when no conditional arguments are provided
200
+ if len(args_cond) + len(kwargs_cond) == 0: # unconditional
201
+ if self.unconditional_handling != "discard":
202
+ raise RuntimeError(
203
+ f"cannot call `ClassifierFreeGuidance` module without condition"
204
+ )
205
+ return self.backbone(x, t)
206
+ else: # conditional arguments are provided
207
+ # training mode
208
+ if self.training:
209
+ coin_flip = random.random() < self.p_unconditional
210
+ if coin_flip: # unconditional
211
+ args_cond, kwargs_cond = self._make_unconditional_args(
212
+ args_cond,
213
+ kwargs_cond,
214
+ )
215
+ return self.backbone(x, t, *args_cond, **kwargs_cond)
216
+ else: # inference mode
217
+ strength = get_strength(self.strength, self.interval, t)
218
+ is_cond = not any(x > 0.0 for x in _pytree.tree_flatten(strength)[0])
219
+ strength_pm = get_strength(self.strength_pm, self.interval, t)
220
+ return self.inner_forward(
221
+ x, t, is_cond, strength, strength_pm, *args_cond, **kwargs_cond
222
+ )
223
+
224
+ class ClassifierFreeGuidanceWithExternalUnconditionalProbability(ClassifierFreeGuidance):
225
+
226
+ def __init__(self, *args, use_unconditional_from_flow_matching=False, **kwargs):
227
+ super().__init__(*args, **kwargs)
228
+ self.use_unconditional_from_flow_matching = use_unconditional_from_flow_matching
229
+
230
+ def forward(self, x, t, *args_cond, p_unconditional=None, **kwargs_cond):
231
+ # p_unconditional should be a value in [0, 1], indicating the probability of unconditional
232
+
233
+ if p_unconditional is None:
234
+ coin_flip = random.random() < self.p_unconditional
235
+ else:
236
+ coin_flip = random.random() < p_unconditional
237
+
238
+ # handle case when no conditional arguments are provided
239
+ if len(args_cond) + len(kwargs_cond) == 0: # unconditional
240
+ if self.unconditional_handling != "discard":
241
+ raise RuntimeError(
242
+ f"cannot call `ClassifierFreeGuidance` module without condition"
243
+ )
244
+ return self.backbone(x, t)
245
+ else: # conditional arguments are provided
246
+ # training mode
247
+ if self.training:
248
+ if coin_flip: # unconditional
249
+ args_cond, kwargs_cond = self._make_unconditional_args(
250
+ args_cond,
251
+ kwargs_cond,
252
+ )
253
+ return self.backbone(x, t, *args_cond, **kwargs_cond)
254
+ else: # inference mode
255
+ strength = get_strength(self.strength, self.interval, t)
256
+ is_cond = not any(x > 0.0 for x in _pytree.tree_flatten(strength)[0])
257
+ return self.inner_forward(
258
+ x, t, is_cond, strength, *args_cond, **kwargs_cond
259
+ )
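A minimal usage sketch of the `ClassifierFreeGuidance` wrapper added above, assuming a toy conditional network and illustrative hyperparameters (the `ToyBackbone`, the shapes, and the `interval` value are not part of this diff; the `"zeros"` handling is assumed to zero out tensor conditions as in the helpers earlier in this file):

    import torch

    # Hypothetical stand-in for the real backbone: predicts an output from (x, t, cond).
    class ToyBackbone(torch.nn.Module):
        def __init__(self, dim=8):
            super().__init__()
            self.net = torch.nn.Linear(dim * 2 + 1, dim)

        def forward(self, x, t, cond=None):
            if cond is None:  # covers the "discard" handling, which removes the condition
                cond = torch.zeros_like(x)
            t = t.expand(x.shape[0]).unsqueeze(-1).float()
            return self.net(torch.cat([x, cond, t], dim=-1))

    # ClassifierFreeGuidance is the class defined in the diff above.
    cfg = ClassifierFreeGuidance(
        ToyBackbone(),
        p_unconditional=0.1,           # probability of dropping the condition during training
        strength=3.0,                  # guidance strength at inference
        unconditional_handling="zeros",
        interval=(0.0, 1000.0),        # apply guidance at every (scaled) timestep
    )

    x, cond = torch.randn(4, 8), torch.randn(4, 8)
    t = torch.tensor(500.0)

    cfg.train()
    _ = cfg(x, t, cond)                # condition is randomly dropped with p_unconditional

    cfg.eval()
    with torch.no_grad():
        y = cfg(x, t, cond)            # (1 + strength) * y_cond - strength * y_uncond

The `PointmapCFG` variant below splits guidance into two terms, `strength_pm * (y_cond - y_pm)` for the pointmap condition and `strength * (y_pm - y_uncond)` for the remaining conditions, following arXiv:2411.18613.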
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/model.py ADDED
@@ -0,0 +1,363 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ from typing import Callable, Sequence, Union
3
+ import torch
4
+ import numpy as np
5
+ from functools import partial
6
+ import optree
7
+ import math
8
+
9
+ from sam3d_objects.model.backbone.generator.base import Base
10
+ from sam3d_objects.data.utils import right_broadcasting
11
+ from sam3d_objects.data.utils import tree_tensor_map, tree_reduce_unique
12
+ from sam3d_objects.model.backbone.generator.flow_matching.solver import (
13
+ ODESolver,
14
+ Euler,
15
+ Midpoint,
16
+ RungeKutta4,
17
+ gradient,
18
+ SDE,
19
+ )
20
+
21
+ # default sampler in flow matching
22
+ uniform_sampler = torch.rand
23
+
24
+
25
+ # https://arxiv.org/pdf/2403.03206
26
+ def lognorm_sampler(mean=0.0, std=1.0, **kwargs):
27
+ logit = torch.randn(**kwargs) * std + mean
28
+ return torch.nn.functional.sigmoid(logit)
29
+
30
+
31
+ # for backwards compatibility; please do not use this
32
+ def rev_lognorm_sampler(mean=0.0, std=1.0, **kwargs):
33
+ logit = torch.randn(**kwargs) * std + mean
34
+ return 1 - torch.nn.functional.sigmoid(logit)
35
+
36
+
37
+ # https://arxiv.org/pdf/2210.02747
38
+ class FlowMatching(Base):
39
+ SOLVER_METHODS = {
40
+ "euler": Euler,
41
+ "midpoint": Midpoint,
42
+ "rk4": RungeKutta4,
43
+ "sde": SDE,
44
+ }
45
+
46
+ def __init__(
47
+ self,
48
+ reverse_fn: Callable,
49
+ sigma_min: float = 0.0, # 0.0 = rectified flow
50
+ inference_steps: int = 100,
51
+ time_scale: float = 1000.0, # scale [0,1]-time before passing to `reverse_fn`
52
+ training_time_sampler_fn: Callable = partial(
53
+ lognorm_sampler,
54
+ mean=0,
55
+ std=1,
56
+ ),
57
+ reversed_timestamp=False,
58
+ rescale_t=1.0,
59
+ loss_fn=partial(torch.nn.functional.mse_loss, reduction="mean"),
60
+ loss_weights=1.0,
61
+ solver_method: Union[str, ODESolver] = "euler",
62
+ solver_kwargs: dict = {},
63
+ **kwargs,
64
+ ):
65
+ super().__init__(**kwargs)
66
+
67
+ self.reverse_fn = reverse_fn
68
+ self.sigma_min = sigma_min
69
+ self.inference_steps = inference_steps
70
+ self.time_scale = time_scale
71
+ self.training_time_sampler_fn = training_time_sampler_fn
72
+ self.reversed_timestamp = reversed_timestamp
73
+ self.rescale_t = rescale_t
74
+ self.loss_fn = loss_fn
75
+ self.loss_weights = loss_weights
76
+ self._solver_method, self._solver = self._get_solver(
77
+ solver_method, solver_kwargs
78
+ )
79
+
80
+ def _get_solver(self, solver_method, solver_kwargs):
81
+ if solver_method in FlowMatching.SOLVER_METHODS:
82
+ solver = FlowMatching.SOLVER_METHODS[solver_method](**solver_kwargs)
83
+ elif isinstance(solver_method, ODESolver):
84
+ solver_method = f"custom[{solver_method.__class__.__name__}]"
85
+ solver = solver_method
86
+ else:
87
+ raise ValueError(
88
+ f"Invalid solver `{solver_method}`, should be in {set(self.SOLVER_METHODS.keys())} or an ODESolver instance"
89
+ )
90
+ return solver_method, solver
91
+
92
+ def _generate_noise_tensor(self, x_shape, x_device):
93
+ return torch.randn(
94
+ x_shape,
95
+ # generator=self.random_generator,
96
+ device=x_device,
97
+ )
98
+
99
+ def _generate_noise(self, x_shape, x_device):
100
+ def is_shape(maybe_shape):
101
+ return isinstance(maybe_shape, Sequence) and all(
102
+ (isinstance(s, int) and s >= 0) for s in maybe_shape
103
+ )
104
+
105
+ return optree.tree_map(
106
+ partial(self._generate_noise_tensor, x_device=x_device),
107
+ x_shape,
108
+ is_leaf=is_shape,
109
+ none_is_leaf=False,
110
+ )
111
+
112
+ def _generate_x0_tensor(self, x1: torch.Tensor):
113
+ x0 = self._generate_noise_tensor(x1.shape, x1.device)
114
+ return x0
115
+
116
+ def _generate_xt_tensor(self, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor):
117
+ # equation (22)
118
+ tb = right_broadcasting(t.to(x1.device), x1)
119
+ x_t = (1 - (1 - self.sigma_min) * tb) * x0 + tb * x1
120
+
121
+ return x_t
122
+
123
+ def _generate_target_tensor(self, x0: torch.Tensor, x1: torch.Tensor):
124
+ # equation (23)
125
+ target = x1 - (1 - self.sigma_min) * x0
126
+
127
+ return target
128
+
129
+ def _generate_x0(self, x1):
130
+ return tree_tensor_map(self._generate_x0_tensor, x1)
131
+
132
+ def _generate_xt(self, x0, x1, t):
133
+ return tree_tensor_map(
134
+ partial(self._generate_xt_tensor, t=t),
135
+ x0,
136
+ x1,
137
+ )
138
+
139
+ def _generate_target(self, x0, x1):
140
+ return tree_tensor_map(
141
+ self._generate_target_tensor,
142
+ x0,
143
+ x1,
144
+ )
145
+
146
+ def _generate_t(self, x1):
147
+ first_tensor = optree.tree_flatten(x1)[0][0]
148
+ batch_size = first_tensor.shape[0]
149
+ device = first_tensor.device
150
+
151
+ t = self.training_time_sampler_fn(
152
+ size=(batch_size,),
153
+ generator=self.random_generator,
154
+ ).to(device)
155
+
156
+ return t
157
+
158
+ def loss(self, x1: torch.Tensor, *args_conditionals, **kwargs_conditionals):
159
+ t = self._generate_t(x1)
160
+ x0 = self._generate_x0(x1)
161
+ x_t = self._generate_xt(x0, x1, t)
162
+ target = self._generate_target(x0, x1)
163
+
164
+ prediction = self.reverse_fn(
165
+ x_t,
166
+ t * self.time_scale,
167
+ *args_conditionals,
168
+ **kwargs_conditionals,
169
+ )
170
+
171
+ # broadcast loss_fn and loss_weights over the prediction tree and compute the weighted loss
172
+ loss = optree.tree_broadcast_map(
173
+ lambda fn, weight, pred, targ: weight * fn(pred, targ),
174
+ self.loss_fn,
175
+ self.loss_weights,
176
+ prediction,
177
+ target,
178
+ )
179
+
180
+ total_loss = sum(optree.tree_flatten(loss)[0])
181
+
182
+ # Create detailed loss breakdown
183
+ detail_losses = {
184
+ "flow_matching_loss": total_loss,
185
+ }
186
+ if isinstance(loss, dict):
187
+ detail_losses.update(loss)
188
+ return total_loss, detail_losses
189
+
190
+ def _prepare_t(self, steps=None):
191
+ steps = self.inference_steps if steps is None else steps
192
+ t_seq = torch.linspace(0, 1, steps + 1)
193
+
194
+ if self.rescale_t:
195
+ t_seq = t_seq / (1 + (self.rescale_t - 1) * (1 - t_seq))
196
+
197
+ if self.reversed_timestamp:
198
+ t_seq = 1 - t_seq
199
+
200
+ return t_seq
201
+
202
+ def generate_iter(
203
+ self,
204
+ x_shape,
205
+ x_device,
206
+ *args_conditionals,
207
+ **kwargs_conditionals,
208
+ ):
209
+ x_0 = self._generate_noise(x_shape, x_device)
210
+ t_seq = self._prepare_t().to(x_device)
211
+
212
+ for x_t, t in self._solver.solve_iter(
213
+ self._generate_dynamics,
214
+ x_0,
215
+ t_seq,
216
+ *args_conditionals,
217
+ **kwargs_conditionals,
218
+ ):
219
+ yield t, x_t, ()
220
+
221
+ def _generate_dynamics(
222
+ self,
223
+ x_t,
224
+ t,
225
+ *args_conditionals,
226
+ **kwargs_conditionals,
227
+ ):
228
+ return self.reverse_fn(x_t, t * self.time_scale, *args_conditionals, **kwargs_conditionals)
229
+
230
+ def _log_p0(self, x0):
231
+ x0 = self._tree_flatten(x0)
232
+ inside_exp = -(x0**2).sum(dim=1) / 2
233
+ return inside_exp - math.log(2 * math.pi) / 2 * x0.shape[1]
234
+
235
+ def log_likelihood(
236
+ self,
237
+ x1,
238
+ solver=None,
239
+ steps=None,
240
+ z_samples=1,
241
+ *args_conditionals,
242
+ **kwargs_conditionals,
243
+ ):
244
+ device = tree_reduce_unique(lambda tensor: tensor.device, x1)
245
+ # device = "cuda"
246
+ t_seq = self._prepare_t(steps).to(device)
247
+ t_seq = 1 - t_seq # from x1 to x0
248
+ solver = self._solver if solver is None else self._get_solver(solver)[1]
249
+
250
+ x_0 = solver.solve(
251
+ partial(self._log_likelihood_dynamics, device=device, z_samples=z_samples),
252
+ {"x": x1, "log_p": 0.0},
253
+ t_seq,
254
+ *args_conditionals,
255
+ **kwargs_conditionals,
256
+ )
257
+
258
+ log_p1 = x_0["log_p"] + self._log_p0(x_0["x"])
259
+
260
+ return log_p1
261
+
262
+ def _log_likelihood_dynamics(
263
+ self,
264
+ state,
265
+ t,
266
+ device,
267
+ z_samples,
268
+ *args_conditionals,
269
+ **kwargs_conditionals,
270
+ ):
271
+ t = torch.tensor([t * self.time_scale], device=device, dtype=torch.float32)
272
+ x_t = state["x"]
273
+
274
+ with torch.set_grad_enabled(True):
275
+ tree_tensor_map(lambda x: x.requires_grad_(True), x_t)
276
+ velocity = self.reverse_fn(
277
+ x_t,
278
+ t,
279
+ *args_conditionals,
280
+ **kwargs_conditionals,
281
+ )
282
+
283
+ # compute the divergence estimate
284
+ div = self._compute_hutchinson_divergence(velocity, x_t, z_samples)
285
+
286
+ tree_tensor_map(lambda x: x.requires_grad_(False), x_t)
287
+ velocity = tree_tensor_map(lambda x: x.detach(), velocity)
288
+
289
+ return {"x": velocity, "log_p": div.detach()}
290
+
291
+ def _tree_flatten(self, tree):
292
+ flat_x = tree_tensor_map(lambda x: x.flatten(start_dim=1), tree)
293
+ flat_x, _ = optree.tree_flatten(
294
+ flat_x,
295
+ is_leaf=lambda x: isinstance(x, torch.Tensor),
296
+ )
297
+ flat_x = torch.cat(flat_x, dim=1)
298
+ return flat_x
299
+
300
+ def _compute_hutchinson_divergence(self, velocity, x_t, z_samples):
301
+ flat_velocity = self._tree_flatten(velocity)
302
+ flat_velocity = flat_velocity.unsqueeze(-1)
303
+
304
+ z = torch.randn(
305
+ flat_velocity.shape[:-1] + (z_samples,),
306
+ dtype=flat_velocity.dtype,
307
+ device=flat_velocity.device,
308
+ )
309
+ z = z < 0
310
+ z = z * 2.0 - 1.0
311
+ z = z / math.sqrt(z_samples)
312
+
313
+ # compute the Hutchinson divergence estimator E[z^T D_x(vt) z] = E[D_x(z^T vt) z]
314
+ vt_dot_z = torch.einsum("ijk,ijk->ik", flat_velocity, z)
315
+ grad_vt_dot_z = [
316
+ gradient(vt_dot_z[..., i], x_t, create_graph=(z_samples > 1))
317
+ for i in range(z_samples)
318
+ ]
319
+ grad_vt_dot_z = [self._tree_flatten(g) for g in grad_vt_dot_z]
320
+ grad_vt_dot_z = torch.stack(grad_vt_dot_z, dim=-1)
321
+ div = torch.einsum("ijk,ijk->i", grad_vt_dot_z, z)
322
+ return div
323
+
324
+
325
+ def _get_device(x):
326
+ device = tree_reduce_unique(lambda tensor: tensor.device, x)
327
+ return device
328
+
329
+
330
+ class ConditionalFlowMatching(FlowMatching):
331
+ def generate_iter(
332
+ self,
333
+ x_shape,
334
+ x_device,
335
+ *args_conditionals,
336
+ **kwargs_conditionals,
337
+ ):
338
+ x_0 = self._generate_noise(x_shape, x_device)
339
+ t_seq = self._prepare_t().to(x_device)
340
+
341
+ noise_override = None
342
+ if "noise_override" in kwargs_conditionals:
343
+ noise_override = kwargs_conditionals["noise_override"]
344
+ del kwargs_conditionals["noise_override"]
345
+ if noise_override is not None:
346
+ if isinstance(x_0, dict):
347
+ x_0.update(noise_override)
348
+ else:
349
+ x_0 = noise_override
350
+
351
+ for x_t, t in self._solver.solve_iter(
352
+ self._generate_dynamics,
353
+ x_0,
354
+ t_seq,
355
+ *args_conditionals,
356
+ **kwargs_conditionals,
357
+ ):
358
+ if noise_override is not None:
359
+ if isinstance(noise_override, dict):
360
+ x_t.update(noise_override)
361
+ else:
362
+ x_t = noise_override
363
+ yield t, x_t, ()
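For context, a self-contained sketch of the objective and Euler sampling loop that `FlowMatching` implements (equations (22)-(23) referenced in the comments above, with `sigma_min = 0`, i.e. rectified flow); the toy 1-D velocity network and data distribution are illustrative assumptions, not part of this diff:

    import torch

    # Toy stand-in for `reverse_fn`: a velocity network v(x_t, t) on 1-D data.
    net = torch.nn.Sequential(
        torch.nn.Linear(2, 64), torch.nn.SiLU(), torch.nn.Linear(64, 1)
    )
    opt = torch.optim.Adam(net.parameters(), lr=1e-3)

    def velocity(x_t, t):
        return net(torch.cat([x_t, t[:, None]], dim=-1))

    # One training step: x_t = (1 - t) * x0 + t * x1, target = x1 - x0 (sigma_min = 0).
    x1 = torch.randn(256, 1) * 0.1 + 2.0      # "data": N(2, 0.1^2)
    x0 = torch.randn_like(x1)                 # noise
    t = torch.rand(x1.shape[0])               # uniform time sampler for simplicity
    x_t = (1 - t[:, None]) * x0 + t[:, None] * x1
    loss = torch.nn.functional.mse_loss(velocity(x_t, t), x1 - x0)
    opt.zero_grad()
    loss.backward()
    opt.step()

    # Euler sampling: integrate dx/dt = v(x, t) from t = 0 (noise) to t = 1 (data).
    x = torch.randn(16, 1)
    t_seq = torch.linspace(0, 1, 101)         # 100 steps, as in `_prepare_t`
    with torch.no_grad():
        for t0, t1 in zip(t_seq[:-1], t_seq[1:]):
            x = x + (t1 - t0) * velocity(x, t0.expand(x.shape[0]))

`FlowMatching` additionally rescales time by `time_scale` before calling `reverse_fn` and supports a log-normal time sampler, pytree-structured states, and higher-order solvers, but the update rule is the same.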
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/solver.py ADDED
@@ -0,0 +1,126 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import optree
3
+ import torch
4
+ from functools import partial
5
+
6
+ from sam3d_objects.data.utils import tree_tensor_map
7
+
8
+
9
+ def linear_approximation_step(x_t, dt, velocity):
10
+ # x_tp1 = x_t + velocity * dt
11
+ x_tp1 = tree_tensor_map(lambda x, v: x + v * dt, x_t, velocity)
12
+ return x_tp1
13
+
14
+
15
+ def gradient(output, x, create_graph: bool = False):
16
+ tensors, pyspec = optree.tree_flatten(
17
+ x, is_leaf=lambda x: isinstance(x, torch.Tensor)
18
+ )
19
+ grad_outputs = [torch.ones_like(output).detach() for _ in tensors]
20
+ grads = torch.autograd.grad(
21
+ output,
22
+ tensors,
23
+ grad_outputs=grad_outputs,
24
+ create_graph=create_graph,
25
+ )
26
+ return optree.tree_unflatten(pyspec, grads)
27
+
28
+
29
+ class ODESolver:
30
+ def step(self, dynamics_fn, x_t, t, dt, *args, **kwargs):
31
+ raise NotImplementedError
32
+
33
+ def solve_iter(self, dynamics_fn, x_init, times, *args, **kwargs):
34
+ x_t = x_init
35
+ for t0, t1 in zip(times[:-1], times[1:]):
36
+ dt = t1 - t0
37
+ x_t = self.step(dynamics_fn, x_t, t0, dt, *args, **kwargs)
38
+ yield x_t, t0
39
+
40
+ def solve(self, dynamics_fn, x_init, times, *args, **kwargs):
41
+ for x_t, _ in self.solve_iter(dynamics_fn, x_init, times, *args, **kwargs):
42
+ pass
43
+ return x_t
44
+
45
+
46
+ # https://en.wikipedia.org/wiki/Euler_method
47
+ class Euler(ODESolver):
48
+ def step(self, dynamics_fn, x_t, t, dt, *args, **kwargs):
49
+ velocity = dynamics_fn(x_t, t, *args, **kwargs)
50
+ x_tp1 = linear_approximation_step(x_t, dt, velocity)
51
+ return x_tp1
52
+
53
+
54
+ # https://arxiv.org/abs/2505.05470
55
+ class SDE(ODESolver):
56
+ def __init__(self, **kwargs):
57
+ super().__init__()
58
+ self.sde_strength = kwargs.get("sde_strength", 0.1)
59
+
60
+ def step(self, dynamics_fn, x_t, t, dt, *args, **kwargs):
61
+ velocity = dynamics_fn(x_t, t, *args, **kwargs)
62
+ sigma = 1 - t
63
+ var_t = sigma / (1 - torch.tensor(sigma).clamp(min=dt))
64
+ std_dev_t = (
65
+ torch.sqrt(variance) * self.sde_strength
66
+ ) # self.sde_strength = alpha
67
+
68
+ def compute_mean(x, v):
69
+ drift_term = x * (std_dev_t**2 / (2 * sigma) * dt)
70
+ velocity_term = v * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt
71
+ return x + drift_term + velocity_term
72
+
73
+ prev_sample_mean = tree_tensor_map(compute_mean, x_t, velocity)
74
+
75
+ # Generate noise and compute final sample using tree_tensor_map
76
+ def add_noise(mean_val):
77
+ variance_noise = torch.randn_like(mean_val)
78
+ return mean_val + std_dev_t * torch.sqrt(torch.tensor(dt)) * variance_noise
79
+
80
+ prev_sample = tree_tensor_map(add_noise, prev_sample_mean)
81
+
82
+ return prev_sample
83
+
84
+
85
+ # https://en.wikipedia.org/wiki/Midpoint_method
86
+ class Midpoint(ODESolver):
87
+ def step(self, dynamics_fn, x_t, t, dt, *args, **kwargs):
88
+ half_dt = 0.5 * dt
89
+
90
+ x_mid = Euler.step(self, dynamics_fn, x_t, t, half_dt, *args, **kwargs)
91
+
92
+ velocity_mid = dynamics_fn(x_mid, t + half_dt, *args, **kwargs)
93
+ x_tp1 = linear_approximation_step(x_t, dt, velocity_mid)
94
+ return x_tp1
95
+
96
+
97
+ # https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods
98
+ class RungeKutta4(ODESolver):
99
+
100
+ def k1(self, dynamics_fn, x_t, t, dt, *args, **kwargs):
101
+ return dynamics_fn(x_t, t, *args, **kwargs)
102
+
103
+ def k2(self, dynamics_fn, x_t, t, dt, k1, *args, **kwargs):
104
+ x_k1 = linear_approximation_step(x_t, dt * 0.5, k1)
105
+ return dynamics_fn(x_k1, t + dt * 0.5, *args, **kwargs)
106
+
107
+ def k3(self, dynamics_fn, x_t, t, dt, k2, *args, **kwargs):
108
+ x_k2 = linear_approximation_step(x_t, dt * 0.5, k2)
109
+ return dynamics_fn(x_k2, t + dt * 0.5, *args, **kwargs)
110
+
111
+ def k4(self, dynamics_fn, x_t, t, dt, k3, *args, **kwargs):
112
+ x_k3 = linear_approximation_step(x_t, dt, k3)
113
+ return dynamics_fn(x_k3, t + dt, *args, **kwargs)
114
+
115
+ def step(self, dynamics_fn, x_t, t, dt, *args, **kwargs):
116
+ k1 = self.k1(dynamics_fn, x_t, t, dt, *args, **kwargs)
117
+ k2 = self.k2(dynamics_fn, x_t, t, dt, k1, *args, **kwargs)
118
+ k3 = self.k3(dynamics_fn, x_t, t, dt, k2, *args, **kwargs)
119
+ k4 = self.k4(dynamics_fn, x_t, t, dt, k3, *args, **kwargs)
120
+
121
+ def compute_velocity(k1, k2, k3, k4):
122
+ return (k1 + 2 * k2 + 2 * k3 + k4) / 6
123
+
124
+ velocity_k = tree_tensor_map(compute_velocity, k1, k2, k3, k4)
125
+ x_tp1 = linear_approximation_step(x_t, dt, velocity_k)
126
+ return x_tp1
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/shortcut/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.