import json
import os
import re
import traceback

from PIL import Image, ImageFile, PngImagePlugin

from .interleave_t2i_dataset import InterleavedBaseIterableDataset
from ..data_utils import pil_img2rgb
from ..distributed_iterable_dataset import DistributedIterableDataset
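

# Pillow safety knobs (standard Pillow settings): allow images up to 200M pixels,
# tolerate truncated image files, and raise the PNG text-chunk ceiling so large
# metadata chunks do not abort decoding with "Decompressed Data Too Large".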
Image.MAX_IMAGE_PIXELS = 200000000
ImageFile.LOAD_TRUNCATED_IMAGES = True

MaximumDecompressedSize = 1024
MegaByte = 2 ** 20
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte


class ThinkTraceJSONLIterableDataset(InterleavedBaseIterableDataset, DistributedIterableDataset):
    def __init__(
        self,
        dataset_name,
        transform,
        tokenizer,
        vit_transform,
        jsonl_path_list,
        data_dir_list,
        num_used_data,
        local_rank=0,
        world_size=1,
        num_workers=8,
        data_status=None,
        shuffle_lines=True,
        shuffle_seed=0,
        image_prefix_dir=None,
    ):
        """
        Dataset for think-trace style JSONL files with interleaved text and images.

        Args:
            dataset_name: Name of the dataset
            transform: Transform for VAE images
            tokenizer: Text tokenizer
            vit_transform: Transform for ViT images
            jsonl_path_list: List of JSONL file paths
            data_dir_list: List of base directories (should match jsonl_path_list)
            num_used_data: List of number of samples to use from each JSONL. If a
                value is None or non-positive, all data from that JSONL is used.
            image_prefix_dir: Absolute path to prepend to relative image paths
            Other args: Standard distributed dataset args
        """
        DistributedIterableDataset.__init__(self, dataset_name, local_rank, world_size, num_workers)
        self.transform = transform
        self.vit_transform = vit_transform
        self.tokenizer = tokenizer
        self.data_status = data_status
        self.image_prefix_dir = image_prefix_dir or ""

        self.start_of_image = tokenizer.convert_tokens_to_ids('<|vision_start|>')
        self.end_of_image = tokenizer.convert_tokens_to_ids('<|vision_end|>')
        self.im_start = tokenizer.convert_tokens_to_ids('<|im_start|>')

        self.data_paths = self.get_data_paths(
            jsonl_path_list,
            num_used_data,
            shuffle_lines,
            shuffle_seed,
        )
        self.set_epoch()
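
    # num_used_data examples (illustrative): [1000, None] keeps the first 1000
    # shuffled lines of the first JSONL and every line of the second; the
    # string 'None' is accepted as an alias for None.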
    def get_data_paths(self, jsonl_path_list, num_used_data, shuffle_lines, shuffle_seed):
        data_paths = []
        if not isinstance(num_used_data, list):
            num_used_data = [num_used_data] * len(jsonl_path_list)

        for jsonl_path, num_data_point in zip(jsonl_path_list, num_used_data):
            with open(jsonl_path, 'r') as f:
                raw_data = f.readlines()
            if shuffle_lines:
                self.rng.seed(shuffle_seed)
                self.rng.shuffle(raw_data)

            # Configs may pass the string 'None' rather than the None literal.
            if num_data_point == 'None':
                num_data_point = None

            if num_data_point is not None and int(num_data_point) > 0:
                raw_data = raw_data[:int(num_data_point)]

            data_paths.extend(raw_data)
        return data_paths

    def extract_image_references(self, text):
        """Extract image references from text like <image_start>[problem_image_1]<image_end>"""
        pattern = r'<image_start>\[([^\]]+)\]<image_end>'
        return re.findall(pattern, text)

    def replace_image_references(self, text):
        """Replace image references with placeholder tokens for processing"""
        pattern = r'<image_start>\[([^\]]+)\]<image_end>'
        return re.sub(pattern, '<IMAGE_PLACEHOLDER>', text)

    def remove_thought_patterns(self, text):
        """Remove 'THOUGHT x:' patterns from text"""
        pattern = r'THOUGHT\s*\d+:\s*'
        return re.sub(pattern, '', text)
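
    # Marker handling at a glance (illustrative):
    #   extract_image_references('see <image_start>[problem_image_1]<image_end>')
    #       -> ['problem_image_1']
    #   replace_image_references('see <image_start>[problem_image_1]<image_end>')
    #       -> 'see <IMAGE_PLACEHOLDER>'
    #   remove_thought_patterns('THOUGHT 1: First, ...') -> 'First, ...'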

    def load_image_safely(self, data_item, image_key):
        """Load an image with null checking and path resolution."""
        if image_key not in data_item or data_item[image_key] is None:
            return None

        image_path = data_item[image_key]
        full_path = os.path.join(self.image_prefix_dir, image_path)

        try:
            return pil_img2rgb(Image.open(full_path))
        except Exception as e:
            print(f"Failed to load image {full_path}: {e}")
            return None

    def parse_row(self, json_line):
        """Parse a single JSON line into the required interleaved format."""
        try:
            data_item = json.loads(json_line.strip())
        except Exception:  # a bare except would also swallow KeyboardInterrupt
            traceback.print_exc()
            return {}

        prompt = "You are an AI reasoning assistant capable of step-by-step interleaved text and visual chain of thought. Think step by step and generate visual aids to enhance your problem-solving. You should first think about the reasoning and planning process in the mind before generating visual aids. Wrap your text reasoning with <think></think> tokens, and wrap your final conclusion with <answer></answer> tokens. Provide your final conclusion clearly in the format of '<answer>Final Answer: <answer here></answer>'"
        question = data_item.get('Question', '')
        reasoning_trace = data_item.get('Text Reasoning Trace', '')
        final_answer = data_item.get('Final Answer', '')

        # Validate the raw fields before wrapping them: the wrapped strings are
        # always non-empty, so checking them instead would never skip a row.
        if not question or not reasoning_trace or not final_answer:
            return {}

        question = f'Question: {question}'
        final_answer = f'<answer>Final Answer: {final_answer}</answer>'

        data = self._init_data()

        # The system prompt conditions the model but carries no loss.
        data = self._add_text(data, prompt, need_loss=False, enable_cfg=True)

        # Interleave the question text with any images it references.
        question_image_refs = self.extract_image_references(question)
        if question_image_refs:
            clean_question = self.replace_image_references(question)
            question_text_parts = clean_question.split('<IMAGE_PLACEHOLDER>')

            if len(question_text_parts) != len(question_image_refs) + 1:
                print(f"Mismatch in question: text parts {len(question_text_parts)}, images {len(question_image_refs)}")
                return {}

            question_images = []
            for image_ref in question_image_refs:
                image = self.load_image_safely(data_item, image_ref)
                if image is None:
                    print(f"Skipping sample due to missing image in question: {image_ref}")
                    return {}
                question_images.append(image)

            for i, text_part in enumerate(question_text_parts):
                if text_part.strip():
                    data = self._add_text(data, text_part.strip(), need_loss=False, enable_cfg=True)
                if i < len(question_images):
                    # Question images are inputs for understanding only: ViT features, no VAE, no loss.
                    data = self._add_image(
                        data, question_images[i],
                        need_loss=False,
                        need_vae=False,
                        need_vit=True,
                        enable_cfg=True,
                    )
        else:
            data = self._add_text(data, question, need_loss=False, enable_cfg=True)

        # Interleave the reasoning trace with the images it references.
        image_refs = self.extract_image_references(reasoning_trace)

        loaded_images = []
        for image_ref in image_refs:
            image = self.load_image_safely(data_item, image_ref)
            if image is not None:
                loaded_images.append(image)
            else:
                print(f"Skipping sample due to missing image: {image_ref}")
                return {}

        clean_reasoning_trace = self.replace_image_references(reasoning_trace)
        clean_reasoning_trace = self.remove_thought_patterns(clean_reasoning_trace)

        # Splitting on the placeholder yields exactly one more text part than images.
        text_parts = clean_reasoning_trace.split('<IMAGE_PLACEHOLDER>')

        if len(text_parts) != len(loaded_images) + 1:
            print(f"Mismatch between text parts ({len(text_parts)}) and images ({len(loaded_images)})")
            return {}

        for i, text_part in enumerate(text_parts):
            if text_part.strip():
                wrapped_text = f"<think>{text_part.strip()}</think>"

                # Supervise the token that follows this segment: start-of-image
                # when an image comes next, im_start after the final segment.
                if i < len(loaded_images):
                    next_token_label = self.start_of_image
                elif i == len(text_parts) - 1:
                    next_token_label = self.im_start
                else:
                    next_token_label = None

                data = self._add_text(data, wrapped_text, need_loss=True, enable_cfg=True, next_token_label=next_token_label)

            if i < len(loaded_images):
                # Trace images are generation targets: VAE latents carry the loss; ViT features are also kept.
                data = self._add_image(
                    data,
                    loaded_images[i],
                    need_loss=True,
                    need_vae=True,
                    need_vit=True,
                    enable_cfg=True,
                )

        data = self._add_text(data, final_answer, need_loss=True, enable_cfg=True)

        return data
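
    # Illustrative record shape (the three text fields are required; image keys
    # are whatever names appear inside <image_start>[...]<image_end> markers):
    #   {
    #     "Question": "... <image_start>[problem_image_1]<image_end> ...",
    #     "Text Reasoning Trace": "THOUGHT 1: ... <image_start>[trace_image_1]<image_end> ...",
    #     "Final Answer": "42",
    #     "problem_image_1": "images/0001_problem.png",
    #     "trace_image_1": "images/0001_trace.png"
    #   }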

    def __iter__(self):
        data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
        if self.data_status is not None:
            row_start_id = self.data_status[worker_id] + 1
        else:
            row_start_id = 0

        print(
            f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
            f"resuming data at row#{row_start_id}"
        )

        while True:
            data_paths_per_worker_ = data_paths_per_worker[row_start_id:]
            for row_idx, json_line in enumerate(data_paths_per_worker_, start=row_start_id):
                try:
                    data = self.parse_row(json_line)
                    if len(data) == 0:
                        continue

                    # Skip rows where no element contributes to the loss.
                    has_loss = any(item['loss'] for item in data['sequence_plan'])
                    if not has_loss:
                        print('No loss defined, skipped.')
                        continue

                    data['data_indexes'] = {
                        "data_indexes": row_idx,
                        "worker_id": worker_id,
                        "dataset_name": self.dataset_name,
                    }
                    yield data

                except Exception as e:
                    print(f"Error processing row {row_idx}: {e}")
                    traceback.print_exc()
                    continue

            # Restart from the top once this worker's shard is exhausted.
            row_start_id = 0
            print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")