add streamer and verbose params, and image_file support for io.BytesIO / PIL.Image.Image, in the infer method
#23 by weege007 · opened
modeling_deepseekocr.py CHANGED (+22 -14)
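In summary: `infer` gains optional `streamer` and `verbose` keyword arguments (defaults preserve the old behavior), and `image_file` may now be an in-memory `io.BytesIO` buffer or a `PIL.Image.Image` in addition to a path string. A minimal usage sketch; the loading calls follow the usual `trust_remote_code` pattern for this repo and are not part of the diff:

```python
from io import BytesIO
from PIL import Image
from transformers import AutoModel, AutoTokenizer

# Assumed loading pattern (not part of this PR).
tok = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-OCR", trust_remote_code=True)
model = AutoModel.from_pretrained("deepseek-ai/DeepSeek-OCR", trust_remote_code=True)
model = model.eval().cuda()

res = model.infer(
    tok,
    prompt="<image>\nFree OCR. ",
    image_file=Image.open("page.png"),  # or BytesIO(raw_bytes), or "page.png"
    verbose=False,                      # new: silence the debug shape prints
    streamer=None,                      # new: inject a custom streamer (None keeps NoEOSTextStreamer)
)
```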
```diff
@@ -27,7 +27,9 @@ import time
 def load_image(image_path):
 
     try:
-        image = Image.open(image_path)
+        image = image_path
+        if not isinstance(image_path, Image.Image):
+            image = Image.open(image_path)
 
         corrected_image = ImageOps.exif_transpose(image)
 
```
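With this guard, `load_image` uses an already-decoded `PIL.Image.Image` directly (the EXIF transposition below still applies) and hands everything else to `PIL.Image.open`, which accepts both paths and file-like objects such as `io.BytesIO`. A quick sketch of the accepted inputs; `raw_png_bytes` is a stand-in for bytes you already have:

```python
from io import BytesIO
from PIL import Image

img_a = load_image("page.png")               # path string, as before
img_b = load_image(BytesIO(raw_png_bytes))   # file-like object (hypothetical bytes)
img_c = load_image(Image.open("page.png"))   # PIL image, used directly
```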
```diff
@@ -353,6 +355,7 @@ class DeepseekOCRConfig(DeepseekV2Config):
 
 class DeepseekOCRModel(DeepseekV2Model):
     config_class = DeepseekOCRConfig
+    verbose = True
 
     def __init__(self, config: DeepseekV2Config):
         super(DeepseekOCRModel, self).__init__(config)
```
```diff
@@ -432,10 +435,11 @@ class DeepseekOCRModel(DeepseekV2Model):
                 global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                 global_features = self.projector(global_features)
 
-                print('=====================')
-                print('BASE: ', global_features.shape)
-                print('PATCHES: ', local_features.shape)
-                print('=====================')
+                if self.verbose:
+                    print('=====================')
+                    print('BASE: ', global_features.shape)
+                    print('PATCHES: ', local_features.shape)
+                    print('=====================')
 
                 _, hw, n_dim = global_features.shape
                 h = w = int(hw ** 0.5)
```
```diff
@@ -475,10 +479,12 @@ class DeepseekOCRModel(DeepseekV2Model):
                 global_features_2 = vision_model(image_ori, global_features_1)
                 global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                 global_features = self.projector(global_features)
-                print('=====================')
-                print('BASE: ', global_features.shape)
-                print('NO PATCHES')
-                print('=====================')
+
+                if self.verbose:
+                    print('=====================')
+                    print('BASE: ', global_features.shape)
+                    print('NO PATCHES')
+                    print('=====================')
                 _, hw, n_dim = global_features.shape
                 h = w = int(hw ** 0.5)
 
```
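Both print blocks are now gated on the class-level `verbose` flag introduced above, so the shape dumps stay on by default but can be silenced per instance. `infer` sets the flag via `self.model.verbose = verbose` (next hunk); it can also be toggled directly:

```python
model.model.verbose = False  # the inner DeepseekOCRModel carries the flag
# ... inference now runs without the BASE/PATCHES shape printouts ...
model.model.verbose = True   # restore the original chatty behavior
```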
```diff
@@ -700,11 +706,13 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
 
-    def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
+    def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False, streamer=None, verbose=True):
         self.disable_torch_init()
+        self.model.verbose = verbose
 
-        os.makedirs(output_path, exist_ok=True)
-        os.makedirs(f'{output_path}/images', exist_ok=True)
+        if len(output_path) > 0 :
+            os.makedirs(output_path, exist_ok=True)
+            os.makedirs(f'{output_path}/images', exist_ok=True)
 
         if prompt and image_file:
             conversation = [
```
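Guarding the `makedirs` calls also fixes a latent failure: `os.makedirs('')` raises `FileNotFoundError`, so the old code could not run with the default `output_path=''`. A sketch of the now-valid call; whether `eval_mode=True` returns the decoded text is an assumption about the surrounding code:

```python
# No output directories are created when output_path is left empty.
text = model.infer(
    tok,
    prompt="<image>\nFree OCR. ",
    image_file="page.png",
    eval_mode=True,
)
```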
```diff
@@ -716,7 +724,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                     # "content": "<image>\nFree OCR. ",
                     # "content": "<image>\nParse the figure. ",
                     # "content": "<image>\nExtract the text in the image. ",
-                    "images": [f'{image_file}'],
+                    "images": [image_file] if isinstance(image_file, (BytesIO, Image.Image)) else [f'{image_file}'],
                 },
                 {"role": "<|Assistant|>", "content": ""},
             ]
```
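This branch references `BytesIO` by bare name; the hunk does not add `from io import BytesIO`, so that import is presumably already present (or still needed) at the top of the file. A sketch of feeding bytes fetched over HTTP; `requests` and the URL are illustrative assumptions:

```python
import requests
from io import BytesIO

resp = requests.get("https://example.com/scan.png", timeout=30)  # hypothetical URL
res = model.infer(
    tok,
    prompt="<image>\nFree OCR. ",
    image_file=BytesIO(resp.content),  # taken by the isinstance branch above
    eval_mode=True,
)
```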
```diff
@@ -910,7 +918,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
         if not eval_mode:
-            streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
+            streamer = streamer or NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
             with torch.autocast("cuda", dtype=torch.bfloat16):
                 with torch.no_grad():
                     output_ids = self.generate(
```
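Since `streamer` only falls back to `NoEOSTextStreamer` when left as `None`, callers can now inject any `transformers` streamer. A sketch with `TextIteratorStreamer`, consuming the stream from another thread because `generate` blocks; the threading wrapper is an assumption, not part of this PR:

```python
from threading import Thread
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
Thread(
    target=model.infer,
    kwargs=dict(tokenizer=tok, prompt="<image>\nFree OCR. ",
                image_file="page.png", streamer=streamer),
).start()

for chunk in streamer:  # yields decoded text pieces as they are generated
    print(chunk, end="", flush=True)
```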