import json
import random

import numpy as np
import torch

from .data_utils import (
    get_flattened_position_ids_interpolate,
    get_flattened_position_ids_extrapolate,
    len2weight,
    patchify,
    prepare_attention_mask_per_sample,
)
from .dataset_info import DATASET_INFO, DATASET_REGISTRY
from .transforms import ImageTransform
from .video_utils import FrameSampler


class DataConfig:
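    """Hyperparameters shared by all grouped datasets.

    The three ``*_cond_dropout_prob`` values give the probability of dropping a
    conditioning text / ViT image / VAE image while packing (classifier-free
    guidance style). ``vae_image_downsample`` is the pixel-to-latent spatial
    downsample factor; ``max_latent_size`` and ``max_num_patch_per_side`` bound
    the position-id grids for VAE latents and ViT patches, respectively.
    """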
    def __init__(
        self,
        grouped_datasets,
        text_cond_dropout_prob=0.1,
        vit_cond_dropout_prob=0.4,
        vae_cond_dropout_prob=0.1,
        vae_image_downsample=16,
        max_latent_size=32,
        vit_patch_size=14,
        max_num_patch_per_side=70,
    ):
        self.grouped_datasets = grouped_datasets
        self.text_cond_dropout_prob = text_cond_dropout_prob
        self.vit_cond_dropout_prob = vit_cond_dropout_prob
        self.vit_patch_size = vit_patch_size
        self.max_num_patch_per_side = max_num_patch_per_side
        self.vae_cond_dropout_prob = vae_cond_dropout_prob
        self.vae_image_downsample = vae_image_downsample
        self.max_latent_size = max_latent_size


class PackedDataset(torch.utils.data.IterableDataset):
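    """Iterable dataset that packs variable-length samples into fixed token budgets.

    Samples are drawn from several grouped datasets according to their weights
    (groups marked ``is_mandatory`` contribute to every packed sequence), then
    concatenated until roughly ``expected_num_tokens`` tokens are collected.
    Samples that would overflow ``max_num_tokens`` are parked in a small buffer
    and retried later.
    """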
    def __init__(
        self,
        data_config,
        tokenizer,
        special_tokens,
        local_rank,
        world_size,
        num_workers,
        expected_num_tokens=32768,
        max_num_tokens_per_sample=16384,
        max_num_tokens=36864,
        prefer_buffer_before=16384,
        max_buffer_size=50,
        interpolate_pos=False,
        use_flex=False,
        data_status=None,
    ):
        super().__init__()
        self.expected_num_tokens = expected_num_tokens
        self.max_num_tokens_per_sample = max_num_tokens_per_sample
        self.prefer_buffer_before = prefer_buffer_before
        self.max_num_tokens = max_num_tokens
        self.max_buffer_size = max_buffer_size
        self.tokenizer = tokenizer
        self.local_rank = local_rank
        self.world_size = world_size
        self.num_workers = num_workers
        self.use_flex = use_flex
        # Expose the special token ids (e.g. bos_token_id, start_of_image) as attributes.
        for k, v in special_tokens.items():
            setattr(self, k, v)

        grouped_datasets, is_mandatory, grouped_weights = self.build_datasets(
            data_config.grouped_datasets, data_status
        )
        self.grouped_datasets = grouped_datasets
        self.dataset_iters = [iter(dataset) for dataset in grouped_datasets]
        self.is_mandatory = is_mandatory
        self.grouped_weights = grouped_weights
        self.data_config = data_config
        self.interpolate_pos = interpolate_pos
        if self.interpolate_pos:
            self.get_flattened_position_ids = get_flattened_position_ids_interpolate
        else:
            self.get_flattened_position_ids = get_flattened_position_ids_extrapolate

    def build_datasets(self, datasets_metainfo, data_status):
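        """Instantiate one dataset per group from ``DATASET_REGISTRY``.

        Pops per-group options (``is_mandatory``, ``weight``, transform and
        frame-sampler args) out of ``datasets_metainfo``, resolves data
        directories and optional parquet/json/jsonl metadata via
        ``DATASET_INFO``, and optionally resumes each group from
        ``data_status``. Returns the datasets with their mandatory flags and
        sampling weights.
        """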
        datasets = []
        is_mandatory = []
        grouped_weights = []
        for grouped_dataset_name, dataset_args in datasets_metainfo.items():
            is_mandatory.append(dataset_args.pop('is_mandatory', False))
            grouped_weights.append(dataset_args.pop('weight', 0.0))

            if 'frame_sampler_args' in dataset_args:
                frame_sampler = FrameSampler(**dataset_args.pop('frame_sampler_args'))
                dataset_args['frame_sampler'] = frame_sampler
            if 'image_transform_args' in dataset_args:
                transform = ImageTransform(**dataset_args.pop('image_transform_args'))
                dataset_args['transform'] = transform
            if 'vit_image_transform_args' in dataset_args:
                vit_transform = ImageTransform(**dataset_args.pop('vit_image_transform_args'))
                dataset_args['vit_transform'] = vit_transform

            assert 'dataset_names' in dataset_args
            dataset_names = dataset_args.pop('dataset_names')
            dataset_args['data_dir_list'] = []
            for item in dataset_names:
                if self.local_rank == 0:
                    print(f'Preparing Dataset {grouped_dataset_name}/{item}')
                meta_info = DATASET_INFO[grouped_dataset_name][item]
                dataset_args['data_dir_list'].append(meta_info['data_dir'])

                if 'parquet_info_path' in meta_info:
                    if 'parquet_info' not in dataset_args:
                        dataset_args['parquet_info'] = {}
                    with open(meta_info['parquet_info_path'], 'r') as f:
                        parquet_info = json.load(f)
                    dataset_args['parquet_info'].update(parquet_info)

                if 'json_dir' in meta_info:
                    if 'json_dir_list' not in dataset_args:
                        dataset_args['json_dir_list'] = [meta_info['json_dir']]
                    else:
                        dataset_args['json_dir_list'].append(meta_info['json_dir'])

                if 'jsonl_path' in meta_info:
                    if 'jsonl_path_list' not in dataset_args:
                        dataset_args['jsonl_path_list'] = [meta_info['jsonl_path']]
                    else:
                        dataset_args['jsonl_path_list'].append(meta_info['jsonl_path'])

                if 'image_prefix_dir' in meta_info:
                    dataset_args['image_prefix_dir'] = meta_info['image_prefix_dir']

            resume_data_status = dataset_args.pop('resume_data_status', True)
            if data_status is not None and grouped_dataset_name in data_status and resume_data_status:
                data_status_per_group = data_status[grouped_dataset_name]
            else:
                data_status_per_group = None
            dataset = DATASET_REGISTRY[grouped_dataset_name](
                dataset_name=grouped_dataset_name,
                tokenizer=self.tokenizer,
                local_rank=self.local_rank,
                world_size=self.world_size,
                num_workers=self.num_workers,
                data_status=data_status_per_group,
                **dataset_args
            )
            datasets.append(dataset)

        return datasets, is_mandatory, grouped_weights

    def set_epoch(self, seed):
        for dataset in self.grouped_datasets:
            dataset.set_epoch(seed)

    def set_sequence_status(self):
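        """Return a fresh accumulator for one packed sequence.

        ``curr`` tracks the next position index; the remaining lists collect,
        per packed position, the text/ViT/VAE tokens together with their
        indexes, position ids, loss targets, and attention metadata.
        """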
        sequence_status = dict(
            curr=0,
            sample_lens=list(),
            packed_position_ids=list(),
            nested_attention_masks=list(),
            split_lens=list(),
            attn_modes=list(),
            packed_text_ids=list(),
            packed_text_indexes=list(),
            packed_label_ids=list(),
            ce_loss_indexes=list(),
            ce_loss_weights=list(),
            vae_image_tensors=list(),
            packed_latent_position_ids=list(),
            vae_latent_shapes=list(),
            packed_vae_token_indexes=list(),
            packed_timesteps=list(),
            mse_loss_indexes=list(),
            packed_vit_tokens=list(),
            vit_token_seqlens=list(),
            packed_vit_position_ids=list(),
            packed_vit_token_indexes=list(),
        )
        return sequence_status

    def to_tensor(self, sequence_status):
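        """Convert an accumulated sequence into a dict of tensors.

        With FlexAttention (``use_flex``) the sequence is padded to
        ``max_num_tokens`` and described by ``split_lens``/``attn_modes``;
        otherwise per-sample nested attention masks are passed through. VAE
        images are zero-padded to a common size so they can be batched.
        """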
        data = dict(
            sequence_length=sum(sequence_status['sample_lens']),
            sample_lens=sequence_status['sample_lens'],
            packed_text_ids=torch.tensor(sequence_status['packed_text_ids']),
            packed_text_indexes=torch.tensor(sequence_status['packed_text_indexes']),
            packed_position_ids=torch.tensor(sequence_status['packed_position_ids']),
        )
        if not self.use_flex:
            data['nested_attention_masks'] = sequence_status['nested_attention_masks']
        else:
            sequence_len = data['sequence_length']
            pad_len = self.max_num_tokens - sequence_len
            data['split_lens'] = sequence_status['split_lens'] + [pad_len]
            data['attn_modes'] = sequence_status['attn_modes'] + ['causal']
            data['sample_lens'] += [pad_len]

        if len(sequence_status['vae_image_tensors']) > 0:
            image_tensors = sequence_status.pop('vae_image_tensors')
            image_sizes = [item.shape for item in image_tensors]
            max_image_size = [max(item) for item in list(zip(*image_sizes))]
            padded_images = torch.zeros(size=(len(image_tensors), *max_image_size))
            for i, image_tensor in enumerate(image_tensors):
                padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor

            data['padded_images'] = padded_images
            data['patchified_vae_latent_shapes'] = sequence_status['vae_latent_shapes']
            data['packed_latent_position_ids'] = torch.cat(sequence_status['packed_latent_position_ids'], dim=0)
            data['packed_vae_token_indexes'] = torch.tensor(sequence_status['packed_vae_token_indexes'])

        if len(sequence_status['packed_vit_tokens']) > 0:
            data['packed_vit_tokens'] = torch.cat(sequence_status['packed_vit_tokens'], dim=0)
            data['packed_vit_position_ids'] = torch.cat(sequence_status['packed_vit_position_ids'], dim=0)
            data['packed_vit_token_indexes'] = torch.tensor(sequence_status['packed_vit_token_indexes'])
            data['vit_token_seqlens'] = torch.tensor(sequence_status['vit_token_seqlens'])

        if len(sequence_status['packed_timesteps']) > 0:
            data['packed_timesteps'] = torch.tensor(sequence_status['packed_timesteps'])
            data['mse_loss_indexes'] = torch.tensor(sequence_status['mse_loss_indexes'])

        if len(sequence_status['packed_label_ids']) > 0:
            data['packed_label_ids'] = torch.tensor(sequence_status['packed_label_ids'])
            data['ce_loss_indexes'] = torch.tensor(sequence_status['ce_loss_indexes'])
            data['ce_loss_weights'] = torch.tensor(sequence_status['ce_loss_weights'])

        return data

    def print_debug_info(self, data, sequence_status):
        """Print detailed debug information in an intuitive table format."""
        print("\n" + "=" * 120)
        print("DEBUG: Complete Sequence Analysis")
        print("=" * 120)

        print(f"Sequence Length: {data['sequence_length']}")
        print(f"Sample Lengths: {data['sample_lens']}")

        packed_text_ids = data['packed_text_ids'].tolist()
        packed_text_indexes = data['packed_text_indexes'].tolist()

        # Optional keys are absent when the sequence contains no such tokens.
        ce_loss_indexes = set(data['ce_loss_indexes'].tolist()) if 'ce_loss_indexes' in data else set()
        mse_loss_indexes = set(data['mse_loss_indexes'].tolist()) if 'mse_loss_indexes' in data else set()
        vit_token_indexes = set(data['packed_vit_token_indexes'].tolist()) if 'packed_vit_token_indexes' in data else set()
        vae_token_indexes = set(data['packed_vae_token_indexes'].tolist()) if 'packed_vae_token_indexes' in data else set()

        # Map each CE-loss position to its label token.
        label_mapping = {}
        if 'ce_loss_indexes' in data:
            ce_indexes = data['ce_loss_indexes'].tolist()
            ce_labels = data['packed_label_ids'].tolist()
            for i, pos in enumerate(ce_indexes):
                label_mapping[pos] = ce_labels[i]

        # Map each VAE token position to its diffusion timestep. packed_timesteps
        # is built in parallel with packed_vae_token_indexes, so zip the ordered
        # lists rather than indexing through an (unordered) set.
        timestep_mapping = {}
        if 'packed_timesteps' in data and 'packed_vae_token_indexes' in data:
            timestep_mapping = dict(zip(data['packed_vae_token_indexes'].tolist(),
                                        data['packed_timesteps'].tolist()))

        print(f"\n1. Raw Token IDs: {packed_text_ids}")

        try:
            decoded_text_tokens = [self.tokenizer.decode([token_id]) for token_id in packed_text_ids]
            print(f"2. Decoded Tokens: {decoded_text_tokens}")
        except Exception as e:
            print(f"2. Error decoding tokens: {e}")
            decoded_text_tokens = ["<ERROR>"] * len(packed_text_ids)

        print("\n3. Complete Sequence Table:")
        print("-" * 120)
        print(f"{'Order':<6} | {'Token Type':<12} | {'Token/Content':<30} | {'Loss Type':<10} | {'Label':<30} | {'Notes':<20}")
        print("-" * 120)

        text_token_idx = 0
        for pos in range(data['sequence_length']):
            if pos in packed_text_indexes:
                token_id = packed_text_ids[text_token_idx]
                try:
                    decoded_token = self.tokenizer.decode([token_id])
                    token_content = f"ID:{token_id} '{decoded_token}'"
                except Exception:
                    token_content = f"ID:{token_id} '<ERROR>'"
                token_type = "TEXT"
                text_token_idx += 1
            elif pos in vit_token_indexes:
                token_type = "VIT_IMAGE"
                token_content = "[VIT Image Patch]"
            elif pos in vae_token_indexes:
                token_type = "VAE_IMAGE"
                token_content = "[VAE Image Latent]"
            else:
                token_type = "UNKNOWN"
                token_content = "[Unknown Position]"

            if pos in ce_loss_indexes:
                loss_type = "CE"
            elif pos in mse_loss_indexes:
                loss_type = "MSE"
            else:
                loss_type = "None"

            if pos in label_mapping:
                label_id = label_mapping[pos]
                try:
                    decoded_label = self.tokenizer.decode([label_id])
                    label_content = f"ID:{label_id} '{decoded_label}'"
                except Exception:
                    label_content = f"ID:{label_id} '<ERROR>'"
            elif pos in mse_loss_indexes:
                label_content = "[Image Generation Target]"
            else:
                label_content = "N/A"

            notes = ""
            if pos in mse_loss_indexes and pos in timestep_mapping:
                timestep = timestep_mapping[pos]
                notes = "No noise" if timestep == float('-inf') else f"t={timestep:.3f}"

            print(f"{pos:<6} | {token_type:<12} | {token_content:<30} | {loss_type:<10} | {label_content:<30} | {notes:<20}")

        print("-" * 120)

        total_positions = data['sequence_length']
        ce_positions = len(ce_loss_indexes)
        mse_positions = len(mse_loss_indexes)
        vit_positions = len(vit_token_indexes)
        vae_positions = len(vae_token_indexes)
        text_positions = len(packed_text_indexes)
        no_loss_positions = total_positions - ce_positions - mse_positions

        print("\nSummary Statistics:")
        print(f"  Total positions: {total_positions}")
        print(f"  Text tokens: {text_positions} ({text_positions / total_positions * 100:.1f}%)")
        print(f"  VIT image tokens: {vit_positions} ({vit_positions / total_positions * 100:.1f}%)")
        print(f"  VAE image tokens: {vae_positions} ({vae_positions / total_positions * 100:.1f}%)")
        print(f"  Positions with CE loss: {ce_positions} ({ce_positions / total_positions * 100:.1f}%)")
        print(f"  Positions with MSE loss: {mse_positions} ({mse_positions / total_positions * 100:.1f}%)")
        print(f"  Positions with no loss: {no_loss_positions} ({no_loss_positions / total_positions * 100:.1f}%)")

        print("=" * 120 + "\n")

    def __iter__(self):
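        """Yield packed batches indefinitely.

        Each iteration packs samples (mandatory groups first, then groups drawn
        by weight, preferring buffered leftovers while below
        ``prefer_buffer_before``) until the sequence reaches
        ``expected_num_tokens``, then yields the tensorized batch. Oversized
        samples are skipped; samples that would overflow ``max_num_tokens`` are
        buffered for later.
        """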
        total_weights = sum(self.grouped_weights)
        assert total_weights > 0.0
        group_cumprobs = [sum(self.grouped_weights[:i + 1]) / total_weights
                          for i in range(len(self.grouped_weights))]
        sequence_status = self.set_sequence_status()
        batch_data_indexes = []

        buffer = []
        while True:
            # At the start of a fresh sequence, pack one sample from every
            # mandatory group before sampling by weight.
            if sequence_status['curr'] == 0:
                for group_index, group_iter in enumerate(self.dataset_iters):
                    if self.is_mandatory[group_index]:
                        while True:
                            sample = next(group_iter)
                            # Two extra tokens per plan item for the begin/end special tokens.
                            num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
                            if num_tokens < self.max_num_tokens_per_sample:
                                sequence_status = self.pack_sequence(sample, sequence_status)
                                batch_data_indexes.append(sample['data_indexes'])
                                break
                            else:
                                print(f"skip a sample with length {num_tokens}")
                                continue

            if sequence_status['curr'] < self.prefer_buffer_before and len(buffer) > 0:
                sample = buffer.pop(0)
                sample_from_buffer = True
            else:
                # Sample a group according to the normalized cumulative weights.
                n = random.random()
                group_index = 0
                for i, cumprob in enumerate(group_cumprobs):
                    if n < cumprob:
                        group_index = i
                        break
                sample = next(self.dataset_iters[group_index])
                sample_from_buffer = False

            num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
            if num_tokens > self.max_num_tokens_per_sample:
                print(f"skip a sample with length {num_tokens}")
                continue

            if sequence_status['curr'] + num_tokens > self.max_num_tokens:
                if len(buffer) < self.max_buffer_size and not sample_from_buffer:
                    buffer.append(sample)
                else:
                    print(f"Yielding data with length {sum(sequence_status['sample_lens'])}")
                    data = self.to_tensor(sequence_status)
                    data['batch_data_indexes'] = batch_data_indexes
                    yield data
                    sequence_status = self.set_sequence_status()
                    batch_data_indexes = []
                continue

            sequence_status = self.pack_sequence(sample, sequence_status)
            batch_data_indexes.append(sample['data_indexes'])

            if sequence_status['curr'] >= self.expected_num_tokens:
                data = self.to_tensor(sequence_status)
                data['batch_data_indexes'] = batch_data_indexes
                yield data
                sequence_status = self.set_sequence_status()
                batch_data_indexes = []

    def pack_sequence(self, sample, sequence_status):
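        """Append one sample to the running sequence accumulator.

        Walks the sample's ``sequence_plan``, appending text tokens, ViT
        patches, or VAE latents (with their begin/end-of-image special tokens,
        loss indexes, RoPE position ids, and attention modes). Conditioning
        items may be dropped at random according to the CFG dropout
        probabilities.
        """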
        image_tensor_list = sample['image_tensor_list']
        text_ids_list = sample['text_ids_list']
        sequence_plan = sample['sequence_plan']

        split_lens, attn_modes = list(), list()
        curr = sequence_status['curr']
        curr_rope_id = 0
        sample_lens = 0

        for item in sequence_plan:
            split_start = item.get('split_start', True)
            if split_start:
                curr_split_len = 0

            if item['type'] == 'text':
                text_ids = text_ids_list.pop(0)
                # Randomly drop conditioning text for classifier-free guidance.
                if item['enable_cfg'] == 1 and random.random() < self.data_config.text_cond_dropout_prob:
                    continue

                shifted_text_ids = [self.bos_token_id] + text_ids
                sequence_status['packed_text_ids'].extend(shifted_text_ids)
                sequence_status['packed_text_indexes'].extend(range(curr, curr + len(shifted_text_ids)))
                if item['loss'] == 1:
                    sequence_status['ce_loss_indexes'].extend(range(curr, curr + len(shifted_text_ids)))
                    sequence_status['ce_loss_weights'].extend(
                        [len2weight(len(shifted_text_ids))] * len(shifted_text_ids)
                    )
                    sequence_status['packed_label_ids'].extend(text_ids + [self.eos_token_id])
                curr += len(shifted_text_ids)
                curr_split_len += len(shifted_text_ids)

                # Add the EOS token.
                sequence_status['packed_text_ids'].append(self.eos_token_id)
                sequence_status['packed_text_indexes'].append(curr)
                if item['special_token_loss'] == 1:
                    sequence_status['ce_loss_indexes'].append(curr)
                    sequence_status['ce_loss_weights'].append(1.0)
                    sequence_status['packed_label_ids'].append(item['special_token_label'])
                curr += 1
                curr_split_len += 1

                # Text splits attend causally; RoPE ids advance token by token.
                attn_modes.append("causal")
                sequence_status['packed_position_ids'].extend(range(curr_rope_id, curr_rope_id + curr_split_len))
                curr_rope_id += curr_split_len

            elif item['type'] == 'vit_image':
                image_tensor = image_tensor_list.pop(0)
                # Randomly drop the ViT image condition for classifier-free guidance.
                if item['enable_cfg'] == 1 and random.random() < self.data_config.vit_cond_dropout_prob:
                    curr_rope_id += 1
                    continue

                # Add the start-of-image token.
                sequence_status['packed_text_ids'].append(self.start_of_image)
                sequence_status['packed_text_indexes'].append(curr)
                curr += 1
                curr_split_len += 1

                # Add the ViT patch tokens.
                vit_tokens = patchify(image_tensor, self.data_config.vit_patch_size)
                num_img_tokens = vit_tokens.shape[0]
                sequence_status['packed_vit_token_indexes'].extend(range(curr, curr + num_img_tokens))
                curr += num_img_tokens
                curr_split_len += num_img_tokens

                sequence_status['packed_vit_tokens'].append(vit_tokens)
                sequence_status['vit_token_seqlens'].append(num_img_tokens)
                sequence_status['packed_vit_position_ids'].append(
                    self.get_flattened_position_ids(
                        image_tensor.size(1), image_tensor.size(2),
                        self.data_config.vit_patch_size,
                        max_num_patches_per_side=self.data_config.max_num_patch_per_side
                    )
                )

                # Add the end-of-image token.
                sequence_status['packed_text_ids'].append(self.end_of_image)
                sequence_status['packed_text_indexes'].append(curr)
                if item['special_token_loss'] == 1:
                    sequence_status['ce_loss_indexes'].append(curr)
                    sequence_status['ce_loss_weights'].append(1.0)
                    sequence_status['packed_label_ids'].append(item['special_token_label'])
                curr += 1
                curr_split_len += 1

                # Image splits attend fully and share a single RoPE id.
                attn_modes.append("full")
                sequence_status['packed_position_ids'].extend([curr_rope_id] * curr_split_len)
                curr_rope_id += 1

            elif item['type'] == 'vae_image':
                image_tensor = image_tensor_list.pop(0)
                # Randomly drop the VAE image condition for classifier-free guidance.
                if item['enable_cfg'] == 1 and random.random() < self.data_config.vae_cond_dropout_prob:
                    curr_rope_id += 1
                    continue

                # Add the start-of-image token.
                sequence_status['packed_text_ids'].append(self.start_of_image)
                sequence_status['packed_text_indexes'].append(curr)
                if item['special_token_loss'] == 1:
                    sequence_status['ce_loss_indexes'].append(curr)
                    sequence_status['ce_loss_weights'].append(1.0)
                    sequence_status['packed_label_ids'].append(item['special_token_label'])
                curr += 1
                curr_split_len += 1

                # Add the VAE latent tokens.
                sequence_status['vae_image_tensors'].append(image_tensor)
                sequence_status['packed_latent_position_ids'].append(
                    self.get_flattened_position_ids(
                        image_tensor.size(1), image_tensor.size(2),
                        self.data_config.vae_image_downsample,
                        max_num_patches_per_side=self.data_config.max_latent_size
                    )
                )
                H, W = image_tensor.shape[1:]
                h = H // self.data_config.vae_image_downsample
                w = W // self.data_config.vae_image_downsample
                sequence_status['vae_latent_shapes'].append((h, w))

                num_img_tokens = w * h
                sequence_status['packed_vae_token_indexes'].extend(range(curr, curr + num_img_tokens))
                if item['loss'] == 1:
                    sequence_status['mse_loss_indexes'].extend(range(curr, curr + num_img_tokens))
                    # Frames that continue a split reuse the previous timestep.
                    if split_start:
                        timestep = np.random.randn()
                else:
                    # Clean conditioning images are marked with -inf (no noise).
                    timestep = float('-inf')
                sequence_status['packed_timesteps'].extend([timestep] * num_img_tokens)
                curr += num_img_tokens
                curr_split_len += num_img_tokens

                # Add the end-of-image token.
                sequence_status['packed_text_ids'].append(self.end_of_image)
                sequence_status['packed_text_indexes'].append(curr)
                if item['special_token_loss'] == 1:
                    sequence_status['ce_loss_indexes'].append(curr)
                    sequence_status['ce_loss_weights'].append(1.0)
                    sequence_status['packed_label_ids'].append(item['special_token_label'])
                curr += 1
                curr_split_len += 1

                # Update the attention mode and RoPE ids for this split.
                if split_start:
                    if item['loss'] == 1 and 'frame_delta' not in item.keys():
                        attn_modes.append("noise")
                    else:
                        attn_modes.append("full")
                sequence_status['packed_position_ids'].extend([curr_rope_id] * (num_img_tokens + 2))
                if 'frame_delta' in item.keys():
                    curr_rope_id += item['frame_delta']
                elif item['loss'] == 0:
                    curr_rope_id += 1

            if item.get('split_end', True):
                split_lens.append(curr_split_len)
                sample_lens += curr_split_len

        sequence_status['curr'] = curr
        sequence_status['sample_lens'].append(sample_lens)

        if not self.use_flex:
            sequence_status['nested_attention_masks'].append(
                prepare_attention_mask_per_sample(split_lens, attn_modes)
            )
        else:
            sequence_status['split_lens'].extend(split_lens)
            sequence_status['attn_modes'].extend(attn_modes)

        return sequence_status


class SimpleCustomBatch:
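    """Thin wrapper around a single packed batch.

    Assumes ``batch_size=1`` (``batch[0]`` is the packed dict yielded by
    ``PackedDataset``) and exposes ``pin_memory``/``cuda`` so the batch works
    with ``DataLoader(pin_memory=True)`` and explicit device transfer.
    """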
    def __init__(self, batch):
        data = batch[0]
        self.batch_data_indexes = data['batch_data_indexes']
        self.sequence_length = data["sequence_length"]
        self.sample_lens = data["sample_lens"]
        self.packed_text_ids = data["packed_text_ids"]
        self.packed_text_indexes = data["packed_text_indexes"]
        self.packed_position_ids = data["packed_position_ids"]

        self.use_flex = "nested_attention_masks" not in data.keys()

        if self.use_flex:
            self.split_lens = data["split_lens"]
            self.attn_modes = data["attn_modes"]
        else:
            self.nested_attention_masks = data["nested_attention_masks"]

        if "padded_images" in data.keys():
            self.padded_images = data["padded_images"]
            self.patchified_vae_latent_shapes = data["patchified_vae_latent_shapes"]
            self.packed_latent_position_ids = data["packed_latent_position_ids"]
            self.packed_vae_token_indexes = data["packed_vae_token_indexes"]

        if "packed_vit_tokens" in data.keys():
            self.packed_vit_tokens = data["packed_vit_tokens"]
            self.packed_vit_position_ids = data["packed_vit_position_ids"]
            self.packed_vit_token_indexes = data["packed_vit_token_indexes"]
            self.vit_token_seqlens = data["vit_token_seqlens"]

        if "packed_timesteps" in data.keys():
            self.packed_timesteps = data["packed_timesteps"]
            self.mse_loss_indexes = data["mse_loss_indexes"]

        if "packed_label_ids" in data.keys():
            self.packed_label_ids = data["packed_label_ids"]
            self.ce_loss_indexes = data["ce_loss_indexes"]
            self.ce_loss_weights = data["ce_loss_weights"]

    def pin_memory(self):
        self.packed_text_ids = self.packed_text_ids.pin_memory()
        self.packed_text_indexes = self.packed_text_indexes.pin_memory()
        self.packed_position_ids = self.packed_position_ids.pin_memory()

        if not self.use_flex:
            self.nested_attention_masks = [item.pin_memory() for item in self.nested_attention_masks]

        if hasattr(self, 'padded_images'):
            self.padded_images = self.padded_images.pin_memory()
            self.packed_vae_token_indexes = self.packed_vae_token_indexes.pin_memory()
            self.packed_latent_position_ids = self.packed_latent_position_ids.pin_memory()

        if hasattr(self, 'packed_timesteps'):
            self.packed_timesteps = self.packed_timesteps.pin_memory()
            self.mse_loss_indexes = self.mse_loss_indexes.pin_memory()

        if hasattr(self, 'packed_vit_tokens'):
            self.packed_vit_tokens = self.packed_vit_tokens.pin_memory()
            self.packed_vit_position_ids = self.packed_vit_position_ids.pin_memory()
            self.packed_vit_token_indexes = self.packed_vit_token_indexes.pin_memory()
            self.vit_token_seqlens = self.vit_token_seqlens.pin_memory()

        if hasattr(self, 'packed_label_ids'):
            self.packed_label_ids = self.packed_label_ids.pin_memory()
            self.ce_loss_indexes = self.ce_loss_indexes.pin_memory()
            self.ce_loss_weights = self.ce_loss_weights.pin_memory()

        return self

    def cuda(self, device):
        self.packed_text_ids = self.packed_text_ids.to(device)
        self.packed_text_indexes = self.packed_text_indexes.to(device)
        self.packed_position_ids = self.packed_position_ids.to(device)

        if not self.use_flex:
            self.nested_attention_masks = [item.to(device) for item in self.nested_attention_masks]

        if hasattr(self, 'padded_images'):
            self.padded_images = self.padded_images.to(device)
            self.packed_vae_token_indexes = self.packed_vae_token_indexes.to(device)
            self.packed_latent_position_ids = self.packed_latent_position_ids.to(device)

        if hasattr(self, 'packed_timesteps'):
            self.packed_timesteps = self.packed_timesteps.to(device)
            self.mse_loss_indexes = self.mse_loss_indexes.to(device)

        if hasattr(self, 'packed_vit_tokens'):
            self.packed_vit_tokens = self.packed_vit_tokens.to(device)
            self.packed_vit_position_ids = self.packed_vit_position_ids.to(device)
            self.packed_vit_token_indexes = self.packed_vit_token_indexes.to(device)
            self.vit_token_seqlens = self.vit_token_seqlens.to(device)

        if hasattr(self, 'packed_label_ids'):
            self.packed_label_ids = self.packed_label_ids.to(device)
            self.ce_loss_indexes = self.ce_loss_indexes.to(device)
            self.ce_loss_weights = self.ce_loss_weights.to(device)

        return self

    def to_dict(self):
        data = dict(
            sequence_length=self.sequence_length,
            sample_lens=self.sample_lens,
            packed_text_ids=self.packed_text_ids,
            packed_text_indexes=self.packed_text_indexes,
            packed_position_ids=self.packed_position_ids,
            batch_data_indexes=self.batch_data_indexes,
        )

        if not self.use_flex:
            data['nested_attention_masks'] = self.nested_attention_masks
        else:
            data['split_lens'] = self.split_lens
            data['attn_modes'] = self.attn_modes

        if hasattr(self, 'padded_images'):
            data['padded_images'] = self.padded_images
            data['patchified_vae_latent_shapes'] = self.patchified_vae_latent_shapes
            data['packed_latent_position_ids'] = self.packed_latent_position_ids
            data['packed_vae_token_indexes'] = self.packed_vae_token_indexes

        if hasattr(self, 'packed_vit_tokens'):
            data['packed_vit_tokens'] = self.packed_vit_tokens
            data['packed_vit_position_ids'] = self.packed_vit_position_ids
            data['packed_vit_token_indexes'] = self.packed_vit_token_indexes
            data['vit_token_seqlens'] = self.vit_token_seqlens

        if hasattr(self, 'packed_timesteps'):
            data['packed_timesteps'] = self.packed_timesteps
            data['mse_loss_indexes'] = self.mse_loss_indexes

        if hasattr(self, 'packed_label_ids'):
            data['packed_label_ids'] = self.packed_label_ids
            data['ce_loss_indexes'] = self.ce_loss_indexes
            data['ce_loss_weights'] = self.ce_loss_weights

        return data


def collate_wrapper():
    def collate_fn(batch):
        return SimpleCustomBatch(batch)
    return collate_fn
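

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). It shows how the
# pieces above are typically wired into a torch DataLoader; `tokenizer`,
# `special_tokens`, and `grouped_datasets` are placeholders assumed to come
# from the surrounding training setup, and `special_tokens` must provide the
# ids used during packing (bos_token_id, eos_token_id, start_of_image,
# end_of_image).
#
#   from torch.utils.data import DataLoader
#
#   data_config = DataConfig(grouped_datasets=grouped_datasets)
#   dataset = PackedDataset(
#       data_config,
#       tokenizer=tokenizer,
#       special_tokens=special_tokens,
#       local_rank=0,
#       world_size=1,
#       num_workers=1,
#   )
#   dataset.set_epoch(seed=0)
#   loader = DataLoader(
#       dataset,
#       batch_size=1,                  # each yielded item is already a packed batch
#       collate_fn=collate_wrapper(),  # wraps items in SimpleCustomBatch
#       num_workers=1,
#       pin_memory=True,               # routed to SimpleCustomBatch.pin_memory()
#   )
#   batch = next(iter(loader)).cuda('cuda:0')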