import copy
import os
import re
from collections import defaultdict
from typing import List, Optional, Union

import datasets
import numpy as np
import torch
from omegaconf import DictConfig, ListConfig
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin

import verl.utils.torch_functional as verl_F
from verl.utils.model import compute_position_id_with_mask


def collate_fn(data_list: list[dict]) -> dict:
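    """Collate a batch of sample dicts into batched tensors and object arrays.

    Tensor values are stacked along a new batch dimension (so they must share a
    shape within the batch); everything else is gathered into a numpy array of
    ``dtype=object`` so ragged values (lists, strings, dicts) survive batching.
    """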
    tensors = defaultdict(list)
    non_tensors = defaultdict(list)

    for data in data_list:
        for key, val in data.items():
            if isinstance(val, torch.Tensor):
                tensors[key].append(val)
            else:
                non_tensors[key].append(val)

    # Stack tensor entries along a new batch dimension.
    for key, val in tensors.items():
        tensors[key] = torch.stack(val, dim=0)

    # Keep non-tensor entries as object arrays so ragged values are preserved.
    for key, val in non_tensors.items():
        non_tensors[key] = np.array(val, dtype=object)

    return {**tensors, **non_tensors}


class RLHFDataset(Dataset):
    """
    We assume the dataset contains a column of prompts (in chat format) along
    with any other per-example information, e.g. images, videos, or metadata.
    """

    def __init__(
        self,
        data_files: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        config: DictConfig,
        processor: Optional[ProcessorMixin] = None,
    ):
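        """
        Args:
            data_files: One or more parquet files holding the dataset; remote
                paths are copied to a local cache before reading.
            tokenizer: Tokenizer used for chat templating, length filtering,
                and prompt encoding.
            config: Dataset options such as ``prompt_key``, ``max_prompt_length``,
                and ``truncation``; missing keys fall back to the defaults below.
            processor: Optional multi-modal processor; when given, image/video
                columns are processed and multi-modal model inputs are returned.
        """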
        if not isinstance(data_files, (List, ListConfig)):
            data_files = [data_files]

        # Deep-copy so later downloads never mutate the caller's list.
        self.data_files = copy.deepcopy(data_files)
        self.original_data_files = copy.deepcopy(data_files)
        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config

        self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf"))
        self.prompt_key = config.get("prompt_key", "prompt")
        self.image_key = config.get("image_key", "images")
        self.video_key = config.get("video_key", "videos")
        self.max_prompt_length = config.get("max_prompt_length", 1024)

        self.return_raw_chat = config.get("return_raw_chat", False)
        self.truncation = config.get("truncation", "error")
        self.filter_overlong_prompts = config.get("filter_overlong_prompts", True)

        # Use a quarter of the CPUs for prompt filtering by default, capped at the CPU count.
        self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4))
        self.num_workers = min(self.num_workers, os.cpu_count())

        # Whether the whole dataframe is kept when pickling; see __getstate__.
        self.serialize_dataset = False
        self._download()
        self._read_files_and_tokenize()

    def _download(self, use_origin_parquet=False):
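        """Copy each data file into the local cache directory so that remote
        paths become locally readable parquet files."""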
        from verl.utils.fs import copy_to_local

        data_files = self.data_files if not use_origin_parquet else self.original_data_files
        for i, parquet_file in enumerate(data_files):
            self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir)

    def _read_files_and_tokenize(self):
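        """Load every parquet file into a single ``datasets.Dataset`` and, if
        enabled, drop examples whose chat-templated prompt exceeds
        ``max_prompt_length`` tokens."""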
        dataframes = []
        for parquet_file in self.data_files:
            dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
            dataframes.append(dataframe)
        self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)

        print(f"dataset len: {len(self.dataframe)}")

        if self.filter_overlong_prompts:
            # Local aliases used inside the filter closure.
            tokenizer = self.tokenizer
            prompt_key = self.prompt_key
            self.dataframe = self.dataframe.filter(
                lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True))
                <= self.max_prompt_length,
                num_proc=self.num_workers,
                desc=f"Filtering prompts longer than {self.max_prompt_length} tokens",
            )

            print(f"filtered dataset len: {len(self.dataframe)}")

    def resume_dataset_state(self):
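        """Restore the dataframe after loading from a dataloader checkpoint.

        ``__getstate__`` normally drops the dataframe, so on resume we re-download
        the original files and re-filter. Checkpoints written by older versions
        serialized the whole dataset and are used as-is.
        """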
        self.serialize_dataset = not hasattr(self, "original_data_files")
        if not self.serialize_dataset:
            self._download(use_origin_parquet=True)
            self._read_files_and_tokenize()
        else:
            print("old dataloader ckpt file is used, please train from scratch for better ckpt performance")

    def __len__(self):
        return len(self.dataframe)

    def _build_messages(self, example: dict):
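        """Pop the prompt column from ``example`` and, for multi-modal rows, split
        each message's text on ``<image>``/``<video>`` placeholders into the
        structured content list expected by the processor's chat template."""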
        messages: list = example.pop(self.prompt_key)

        if self.image_key in example or self.video_key in example:
            for message in messages:
                content = message["content"]
                content_list = []
                # The capturing group keeps the placeholders themselves in the split output.
                for segment in re.split("(<image>|<video>)", content):
                    if segment == "<image>":
                        content_list.append({"type": "image"})
                    elif segment == "<video>":
                        content_list.append({"type": "video"})
                    else:
                        content_list.append({"type": "text", "text": segment})

                message["content"] = content_list

        return messages

    def __getitem__(self, item):
        """
        Note that we also return the raw input ids so that they can be combined
        with other chat templates downstream.
        """
        row_dict: dict = self.dataframe[item]
        messages = self._build_messages(row_dict)
        model_inputs = {}

        if self.processor is not None:
            # Multi-modal path: the processor tokenizes text and image/video inputs together.
            from verl.utils.dataset.vision_utils import process_image, process_video

            raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            multi_modal_data = {}

            images = None
            if self.image_key in row_dict:
                images = [process_image(image) for image in row_dict.pop(self.image_key)]
                multi_modal_data["image"] = images

            videos = None
            if self.video_key in row_dict:
                videos = [process_video(video) for video in row_dict.pop(self.video_key)]
                multi_modal_data["video"] = [video.numpy() for video in videos]

            model_inputs = self.processor(text=[raw_prompt], images=images, videos=videos, return_tensors="pt")

            input_ids = model_inputs.pop("input_ids")
            attention_mask = model_inputs.pop("attention_mask")

            if "second_per_grid_ts" in model_inputs:
                model_inputs.pop("second_per_grid_ts")

            # multi_modal_inputs must be a plain dict rather than a BatchFeature.
            row_dict["multi_modal_data"] = multi_modal_data
            row_dict["multi_modal_inputs"] = dict(model_inputs)

            # second_per_grid_ts is not used for training, so drop it defensively here too.
            row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None)

        else:
            raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False)
            input_ids = model_inputs.pop("input_ids")
            attention_mask = model_inputs.pop("attention_mask")

        # Left-pad (or truncate, per self.truncation) to max_prompt_length.
        input_ids, attention_mask = verl_F.postprocess_data(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=self.max_prompt_length,
            pad_token_id=self.tokenizer.pad_token_id,
            left_pad=True,
            truncation=self.truncation,
        )

        if self.processor is not None and self.processor.image_processor.__class__.__name__ == "Qwen2VLImageProcessor":
            # Qwen2-VL uses mrope: rope indices are derived from the image/video grids.
            from verl.models.transformers.qwen2_vl import get_rope_index

            position_ids = [
                get_rope_index(
                    self.processor,
                    input_ids=input_ids[0],
                    image_grid_thw=model_inputs.get("image_grid_thw"),
                    video_grid_thw=model_inputs.get("video_grid_thw"),
                    second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
                    attention_mask=attention_mask[0],
                )
            ]
        else:
            position_ids = compute_position_id_with_mask(attention_mask)

        row_dict["input_ids"] = input_ids[0]
        row_dict["attention_mask"] = attention_mask[0]
        row_dict["position_ids"] = position_ids[0]

        # Also keep the raw (unpadded) prompt ids, truncated per self.truncation.
        raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
        if len(raw_prompt_ids) > self.max_prompt_length:
            if self.truncation == "left":
                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
            elif self.truncation == "right":
                raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
            elif self.truncation == "error":
                raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")

        row_dict["raw_prompt_ids"] = raw_prompt_ids

        if self.return_raw_chat:
            row_dict["raw_prompt"] = messages

        # Per-example index used by the trainer for grouping; defaults to 0.
        index = row_dict.get("extra_info", {}).get("index", 0)
        row_dict["index"] = index

        return row_dict

    def __getstate__(self):
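        """Drop the (potentially large) dataframe when pickling; it is rebuilt
        via ``resume_dataset_state`` on restore."""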
        if not self.serialize_dataset:
            state = self.__dict__.copy()

            if "dataframe" in state:
                del state["dataframe"]
            return state

        return self.__dict__.copy()
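

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library: the parquet path and model
    # name below are illustrative assumptions, not values taken from this repo.
    from omegaconf import OmegaConf
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    config = OmegaConf.create(
        {
            "prompt_key": "prompt",
            "max_prompt_length": 1024,
            "filter_overlong_prompts": True,
            "truncation": "error",
        }
    )
    dataset = RLHFDataset(data_files="/path/to/train.parquet", tokenizer=tokenizer, config=config)
    # Each example carries left-padded input_ids/attention_mask/position_ids
    # plus the raw prompt ids for downstream rollout engines.
    print(dataset[0]["input_ids"].shape)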