Update modeling_llava_qwen2.py
modeling_llava_qwen2.py (+5 -5)
@@ -535,13 +535,13 @@ class SigLipVisionTower(nn.Module):
         if type(images) is list:
             image_features = []
             for image in images:
-                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
+                image_forward_out = self.vision_tower(image.to(device="cuda:0", dtype=self.dtype).unsqueeze(0),
                                                       output_hidden_states=True)
                 image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
                 assert image_feature.shape[-2] == 729
                 image_features.append(image_feature)
         else:
-            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
+            image_forward_outs = self.vision_tower(images.to(device="cuda:0", dtype=self.dtype),
                                                    output_hidden_states=True)
             image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
             assert image_features.shape[-2] == 729
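The two changed calls pin the image tensors to "cuda:0" instead of the tower's own self.device, which suggests the Space was hitting a device-mismatch RuntimeError when the vision tower and the rest of the model ended up on different devices. Below is a minimal sketch of that failure mode and the fix, assuming a CUDA device is available; the Linear module is a stand-in for the vision tower, not code from this repo:

import torch
import torch.nn as nn

tower = nn.Linear(16, 16).to(device="cuda:0", dtype=torch.float16)  # stand-in for the vision tower
image = torch.randn(1, 16)  # arrives on CPU in float32

# tower(image)  # would raise RuntimeError: tensors on different devices

# The commit's pattern: move the input onto a fixed device and the module's
# dtype before the forward pass.
out = tower(image.to(device="cuda:0", dtype=tower.weight.dtype))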
@@ -550,7 +550,7 @@ class SigLipVisionTower(nn.Module):
 
     @property
     def dummy_feature(self):
-        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+        return torch.zeros(1, self.hidden_size, device="cuda:0", dtype=self.dtype)
 
     @property
     def dtype(self):
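dummy_feature is the zero placeholder used when a batch carries no real image; allocating it directly on "cuda:0" keeps it on the same device as the real features it may be concatenated with. A hypothetical check of that invariant (the 1152 hidden size is an assumption for a SigLIP-style tower; only the 729 token count appears in the asserts above):

import torch

dummy = torch.zeros(1, 1152, device="cuda:0", dtype=torch.float16)   # placeholder feature
real = torch.randn(729, 1152, device="cuda:0", dtype=torch.float16)  # one image's patch features
merged = torch.cat([real, dummy], dim=0)  # torch.cat requires matching device and dtype
assert merged.device == torch.device("cuda:0") and merged.shape == (730, 1152)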
@@ -682,9 +682,9 @@ class LlavaMetaForCausalLM(ABC):
             image_features = self.encode_images(concat_images)
             split_sizes = [image.shape[0] for image in images]
             image_features = torch.split(image_features, split_sizes, dim=0)
-            image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
+            image_features = [x.flatten(0, 1).to("cuda:0") for x in image_features]
         else:
-            image_features = self.encode_images(images).to(self.device)
+            image_features = self.encode_images(images).to("cuda:0")
 
         # Let's just add dummy tensors if they do not exist,
         # it is a headache to deal with None all the time.
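The same pinning is applied where the encoded features re-enter the language-model path: after encode_images, each image's crops are flattened into one token sequence and moved to "cuda:0", presumably to match the device of the text embeddings they are spliced into. A sketch of the split/flatten/move pattern with made-up shapes (three images with 2, 1, and 3 crops, 729 patch tokens, 32-dim features; assumes a CUDA device):

import torch

split_sizes = [2, 1, 3]                                  # crops per image
image_features = torch.randn(sum(split_sizes), 729, 32)  # stand-in for encode_images output
per_image = torch.split(image_features, split_sizes, dim=0)
# flatten(0, 1) merges the crop and patch-token dimensions, then .to pins the
# result to the device the text embeddings live on.
per_image = [x.flatten(0, 1).to("cuda:0") for x in per_image]
assert [t.shape[0] for t in per_image] == [2 * 729, 729, 3 * 729]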