fix: crash because pos enc is on CPU device
#1
by
kabouzeid
- opened
- modeling_aimv2.py +1 -1
modeling_aimv2.py
CHANGED
|
@@ -101,7 +101,7 @@ class AIMv2ViTPreprocessor(nn.Module):
|
|
| 101 |
tokens = self.patchifier(x)
|
| 102 |
pos_embed = get_sincos_pos_embed(
|
| 103 |
H // self.patch_h, W // self.patch_w, embed_dim=self.embed_dim
|
| 104 |
-
)
|
| 105 |
tokens = tokens + pos_embed
|
| 106 |
return tokens
|
| 107 |
|
|
|
|
| 101 |
tokens = self.patchifier(x)
|
| 102 |
pos_embed = get_sincos_pos_embed(
|
| 103 |
H // self.patch_h, W // self.patch_w, embed_dim=self.embed_dim
|
| 104 |
+
).to(tokens.device)
|
| 105 |
tokens = tokens + pos_embed
|
| 106 |
return tokens
|
| 107 |
|