#!/usr/bin/env python # -*- coding: utf-8 -*- # # Copyright @2023 AI, ZHIHU Inc. (zhihu.com) # # @author: ouzebin # @date: 2023/09/27 import copy import ctypes import functools import importlib import json import logging import os import random from collections import defaultdict from collections import OrderedDict from multiprocessing import Lock from multiprocessing import Process from multiprocessing.shared_memory import SharedMemory from typing import Any from typing import Callable from typing import Dict from typing import Iterable from typing import Iterator from typing import List from typing import Optional from typing import Set from typing import Tuple from typing import Union import bmtrain as bmt import numpy as np import torch from numpy.typing import NDArray from fm9g.dataset import PrefetchDecodeDataset from fm9g.utils.bitset import BitSet from fm9g.utils.vdc_sampling import van_der_corput from fm9g.utils.vdc_sampling import van_der_corput_sampling_gen logger = logging.getLogger(__name__) IGNORE_TGT = -100 def load_dataset_cfgs(cfg_path, cfg_json_str=None): if cfg_json_str is not None: cfgs = json.loads(cfg_json_str) else: with open(cfg_path, "r", encoding="utf-8") as fin: cfgs = json.load(fin) transform_basedir = os.path.dirname(os.path.abspath(cfg_path)) path_dict = None platform_config_path = os.getenv("PLATFORM_CONFIG_PATH") try: with open(platform_config_path, "r") as f: platform_cfg = json.load(f) path_dict = platform_cfg["dataset_map"] if bmt.rank() == 0: logger.info(f"Loaded jeeves platform config from '{platform_config_path}', update dataset paths...") except Exception as e: if bmt.rank() == 0: logger.info(f"Failing to load jeeves platform config '{platform_config_path}', error message:\n{str(e)}") task_name2dataset_name = dict() for idx, cfg in enumerate(cfgs): assert "dataset_name" in cfg and isinstance(cfg["dataset_name"], str) assert "task_name" in cfg and isinstance(cfg["task_name"], str) # to be delibrately annoying :) if cfg["task_name"] in task_name2dataset_name: raise ValueError( f"task_name '{cfg['task_name']}' in dataset '{cfg['dataset_name']}'" f"has already been used in '{task_name2dataset_name[cfg['task_name']]}'." ) task_name2dataset_name[cfg["task_name"]] = cfg["dataset_name"] assert "path" in cfg and isinstance(cfg["path"], str) # if path_dict is not None: # cfg["path"] = os.path.join(path_dict[cfg["dataset_name"]], cfg["path"]) # dealing with optional configs if "weight" in cfg: assert isinstance(cfg["weight"], (float, int)) else: cfg["weight"] = 1.0 if "oversize_rule" in cfg: assert cfg["oversize_rule"] in ("drop", "head", "segment") else: cfg["oversize_rule"] = "segment" if "transforms" in cfg: assert isinstance(cfg["transforms"], str) # dealing with relative path if not cfg["transforms"].startswith("/"): cfg["transforms"] = os.path.join(transform_basedir, cfg["transforms"]) if not cfg["transforms"]: cfg["transforms"] = None else: cfg["transforms"] = None if "incontext_weight" in cfg: assert isinstance(cfg["incontext_weight"], (list, tuple)) else: cfg["incontext_weight"] = [1.0] cfg["id"] = idx # dataset and iterator will be built return cfgs def data2ids(data, tokenizer, max_length): text = "\n".join( [ data.get("title", "").strip(), data.get("question", "").strip(), data.get("answer", "").strip(), data.get("abstract", "").strip(), data.get("text", "").strip(), data.get("code", "").strip(), ] ).strip() if not text: logger.warning(f"Warning: skip invalid sample without valid fields: {data}") yield from () return # suppress the annoying warning from tokenizer ids = ( [tokenizer.bos_token_id] + tokenizer.encode(text, max_length=int(1e12), truncation=True) + [tokenizer.eos_token_id] ) src_ids = ids[0:-1] tgt_ids = ids[0:-1] # do not shift because it'll be shifted during loss calculation. if len(src_ids) > max_length: for st in range(0, len(src_ids), max_length): yield src_ids[st : st + max_length], tgt_ids[st : st + max_length] else: yield src_ids, tgt_ids def cricket_data2ids(data, tokenizer, max_length: int, oversize_rule="segment", do_compact=False): assert oversize_rule in ("drop", "head", "segment") if data is None: yield from () return if "output" not in data or not data["output"]: yield from () return if "input" not in data or data["input"] is None: data["input"] = "" src_ids = [tokenizer.bos_token_id] tgt_ids = [] has_input = False is_segment_reenter = False # Use incremental tokenization to avoid waiting for a long document MAX_CHUNK_LENGTH = max_length * 10 for part in ("input", "output"): l, r = 0, min(MAX_CHUNK_LENGTH, len(data[part])) while l < len(data[part]): try: current_slice = data[part][l:r] if not current_slice: break token_ids = tokenizer.encode(current_slice, add_special_tokens=False) except: print("Error in data[part][l:r] {}".format(data)) yield from () return if part == "input": if len(token_ids) > 0: has_input = True if len(token_ids) >= max_length - 2: # input len must < max_length yield from () return src_ids.extend(token_ids) tgt_ids.extend([IGNORE_TGT] * len(token_ids)) l = r r = min(len(data[part]), l + MAX_CHUNK_LENGTH) else: if len(token_ids) + len(tgt_ids) >= max_length: if oversize_rule == "drop": yield from () return elif oversize_rule == "head": selected_token_ids = token_ids[: max_length - len(src_ids) + 1] src_ids.extend(selected_token_ids[:-1]) tgt_ids.extend(selected_token_ids) assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})" yield src_ids[:max_length], tgt_ids[:max_length] return elif oversize_rule == "segment": instruction_rest_space = max_length - 1 - len(token_ids) if has_input: # is instruction data if ( do_compact and len(src_ids) >= 128 # avoid too short instruction info lost and instruction_rest_space / len(src_ids) > 0.8 ): # can be squeezed into max length inputs_len = len(src_ids) keep_len = instruction_rest_space // 2 src_ids = src_ids[:keep_len] + src_ids[inputs_len - keep_len :] tgt_ids = [IGNORE_TGT] * (len(src_ids) - 1) src_ids.extend(token_ids) tgt_ids.extend(token_ids) tgt_ids.append(tokenizer.eos_token_id) assert len(src_ids) < max_length, f"len src_ids: {len(src_ids)}" assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})" yield src_ids, tgt_ids else: # else use head rule selected_token_ids = token_ids[: max_length - len(src_ids) + 1] src_ids.extend(selected_token_ids[:-1]) tgt_ids.extend(selected_token_ids) assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})" yield src_ids[:max_length], tgt_ids[:max_length] return else: # normal segment selected_token_ids = token_ids[: max_length - len(src_ids) + 1] src_ids.extend(selected_token_ids) tgt_ids.extend(selected_token_ids) assert len(src_ids) == max_length + 1, f"len src_ids: {len(src_ids)}" assert len(tgt_ids) == max_length, f"len tgt_ids: {len(tgt_ids)}" yield src_ids[:max_length], tgt_ids[:max_length] src_ids = src_ids[max_length:] tgt_ids = tgt_ids[max_length:] # sliding input str window consumed_str = tokenizer.decode(selected_token_ids) l += len(consumed_str) r = min(len(data[part]), l + MAX_CHUNK_LENGTH) is_segment_reenter = True else: if (is_segment_reenter and len(token_ids) > 8) or ( not is_segment_reenter and len(token_ids) > 0 ): # is segmented LM data src_ids.extend(token_ids) tgt_ids.extend(token_ids) tgt_ids.append(tokenizer.eos_token_id) assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})" yield src_ids, tgt_ids else: yield from () return class SegmentedDataset(torch.utils.data.IterableDataset): def __init__( self, cfg, tokenizer, max_length=1024, transform_func=None, nthreads=1, prefetch_slice=3, slice_size=500, do_compact=False, ): super(SegmentedDataset, self).__init__() self.segment = functools.partial( cricket_data2ids, tokenizer=tokenizer, max_length=max_length, do_compact=do_compact ) self.cfg = cfg self.max_length = max_length self.nthreads = nthreads self.transform_func = transform_func self.prefetch_slice = prefetch_slice self.slice_size = slice_size self.abs_weight = cfg.get("abs_weight", None) self.task_name = cfg["task_name"] self.dataset_name = cfg["dataset_name"] self.oversize_rule = cfg["oversize_rule"] self.dataset = PrefetchDecodeDataset(path=cfg["path"], allow_repeat=cfg.get("allow_repeat", True)) self.exhausted = False self.iterator = None self.counter = 0 self.allow_repeat = cfg.get("allow_repeat", True) self.used = BitSet() self.init_ave_tokens() def init_ave_tokens( self, ): try: shm = SharedMemory(name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}') except FileNotFoundError: bmt.print_rank( "Create Shared Memory {}".format(f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}') ) shm = SharedMemory( create=True, size=ctypes.sizeof(ctypes.c_float), name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}', ) # 使用共享内存 shared_value = ctypes.c_float.from_buffer(shm.buf) _ave_tokens = self.cfg.get( "avg_tokens", self.cfg.get("ave_tokens", self.cfg.get("ave_tokens_per_line", -1)) ) if _ave_tokens > self.max_length: _ave_tokens = self.max_length bmt.print_rank( "Warning: avg_tokens {} is larger than max_length {}, set to max_length".format( _ave_tokens, self.max_length ) ) shared_value.value = _ave_tokens # 不再需要 shared_value 时,删除引用 del shared_value # 现在可以安全地关闭共享内存 shm.close() bmt.print_rank("Init ave_tokens for task {}: {}".format(self.task_name, self.ave_tokens)) @property def ave_tokens( self, ): existing_shm = SharedMemory( name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}' ) # -1 # default length shared_value = ctypes.c_float.from_buffer(existing_shm.buf) tmp = shared_value.value del shared_value existing_shm.close() return tmp def ave_tokens_update(self, length): existing_shm = SharedMemory( name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}' ) # -1 # default length shared_value = ctypes.c_float.from_buffer(existing_shm.buf) if shared_value.value < 0: shared_value.value = float(length) else: shared_value.value = 0.98 * shared_value.value + 0.02 * length del shared_value existing_shm.close() def size(self): return self.dataset.size() def __iter__(self): self.iterate() return self def reset(self): self.exhausted = False if self.iterator is not None: self.iterator.close() self.iterator = None self.used = BitSet() print("Rank {}, Reset dataset:{} done.".format(bmt.rank(), self.dataset_name)) def transform(self, data: dict) -> dict: weight = np.array(self.cfg["incontext_weight"], dtype=np.float32) weight = weight / weight.sum() num_incontext = np.random.choice(weight.shape[0], p=weight) return self.transform_func(data, num_incontext, random.Random()) def segment_iterate(self, sample_iter): for index, data in self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used): for src_ids, tgt_ids in self.segment(self.transform(data)): self.ave_tokens_update(len(src_ids)) # 0 for input ids yield src_ids, tgt_ids, index def iterate(self): # make the dataset itself an iterator sample_iter = self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used) self.iterator = self.segment_iterate(sample_iter) def __next__(self): # advance the task iterator if self.iterator is None: self.iterate() try: return next(self.iterator) except StopIteration: self.exhausted = True return None def load_state_dict(self, state_dict): if state_dict.get("exhausted", False): self.exhausted = True self.used = BitSet() else: used = state_dict.get("used", BitSet()) if len(used) == len(self.dataset): self.exhausted = True self.used = BitSet() else: self.exhausted = False self.used = used self.ave_tokens_update(state_dict.get("ave_tokens", -1)) def state_dict(self): if len(self.used) == len(self.dataset): return dict(exhausted=True, used=BitSet(), ave_tokens=self.ave_tokens) else: return dict(exhausted=False, used=self.used, ave_tokens=self.ave_tokens) def update_state(self, indice): self.used.update(indice) class MixedIndexedDataset(torch.utils.data.IterableDataset): def __init__( self, cfg_path: str, cfg_json_str, tokenizer, max_length, weight_by_size: bool = True, nthreads=5, prefetch_slice=100, parallel_loading=False, vdc_sampling=False, update_weights_frequency=1, seed=42, ): super(MixedIndexedDataset, self).__init__() self.set_seed(seed + bmt.rank()) self.weight_by_size = weight_by_size self.tokenizer = tokenizer self.eos_token_id = self.tokenizer.eos_token_id self.bos_token_id = self.tokenizer.bos_token_id self.path2transform = dict() self.task_dict = OrderedDict() self.nthreads = nthreads self.prefetch_slice = prefetch_slice # useful for indexing self.tasks = [] self.names = [] # ending of iteration self.remain = 0 self.max_length = max_length self.vdc_sampling = vdc_sampling if self.vdc_sampling: self._vdc_values = [van_der_corput(i) for i in range(10**6)] # 精度提高 10^{-6} self.vdc_gen = van_der_corput_sampling_gen(self._vdc_values) self.update_weights_frequency = update_weights_frequency self.path2transform = dict() cfgs = load_dataset_cfgs(cfg_path, cfg_json_str) _sum_weight = sum([cfg["abs_weight"] for cfg in cfgs]) _weights = {cfg["task_name"]: cfg["abs_weight"] / _sum_weight for cfg in cfgs} bmt.print_rank("Absolute Weight of DataSet {}".format(_weights)) if parallel_loading: self.parallel_load(cfgs, max_workers=None) else: self.sequential_load(cfgs) self.weights = None self.update_weights() def set_seed(self, seed): torch.manual_seed(seed) random.seed(seed) np.random.seed(seed) def load_task(self, cfg): logger.info(f"Loading {cfg['path']}") transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"]) task = SegmentedDataset( cfg, self.tokenizer, self.max_length, transform_func=transform_func, nthreads=self.nthreads, prefetch_slice=self.prefetch_slice, do_compact=cfg.get("do_compact", False), # dataset level do_compact ) return task def sequential_load(self, cfgs): self.cfgs = cfgs for cfg in cfgs: # python3.7 and later preserves insertion order to dictionary logger.info(f"Loading {cfg['path']}") transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"]) task = SegmentedDataset( cfg, self.tokenizer, self.max_length, transform_func=transform_func, nthreads=self.nthreads, prefetch_slice=self.prefetch_slice, do_compact=cfg.get("do_compact", False), # dataset level do_compact ) self.task_dict[task.task_name] = task self.tasks.append(task) self.names.append(task.task_name) self.remain += 1 self.weights = None self.update_weights() def load_state_dict(self, state_dict): missing_keys = [] for name, task in self.task_dict.items(): if name in state_dict: task.load_state_dict(state_dict[name]) else: missing_keys.append(name) self.update_weights() return missing_keys def save_state_dict(self, path): state_dict = {} for name, task in self.task_dict.items(): _state_dict = task.state_dict() if isinstance(_state_dict["used"], BitSet): bitset = _state_dict["used"] _file_name = bitset.save(path) _state_dict["used"] = _file_name state_dict[name] = _state_dict else: state_dict[name] = task.state_dict() torch.save(state_dict, path) logger.info("Dataset state saved") def update_states(self, task_ids, indice): is_dict = isinstance(indice, dict) uniq = torch.unique(task_ids) for idx in uniq: idx = idx.item() indexes = indice[idx] if is_dict else indice[task_ids == idx].tolist() self.tasks[idx].update_state(indexes) def get_transform_func(self, module_name: str, transform_script_path): if transform_script_path is None: # allow null transform return lambda data, num_incontext, rand: data module_name = "fm9g_live.transforms.{}".format(module_name) if transform_script_path not in self.path2transform: loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path) spec = importlib.util.spec_from_loader(loader.name, loader) if spec is None: raise RuntimeError("Spec is none! {}".format(module_name)) mod = importlib.util.module_from_spec(spec) self.path2transform[transform_script_path] = { "loader": loader, "module": mod, "last_mtime": 0, } transform_script_info = self.path2transform[transform_script_path] curr_mtime = float(transform_script_info["loader"].path_stats(transform_script_path)["mtime"]) if curr_mtime > transform_script_info["last_mtime"]: transform_script_info["last_mtime"] = curr_mtime transform_script_info["loader"].exec_module(transform_script_info["module"]) transform_func = getattr(transform_script_info["module"], "transform", None) if transform_func is None: raise NotImplementedError("Find no transform funcion in script '{}'".format(transform_script_path)) return transform_func def update_weights(self): task0 = self.tasks[0] if task0.abs_weight is not None: # 这一份config是指定绝对比例的 weights = [] for task in self.tasks: if task.exhausted: weights.append(0) else: if task.ave_tokens == -1: weights.append(task.abs_weight / self.max_length) else: weights.append(task.abs_weight / task.ave_tokens) weights = np.array(weights) else: weights = np.array([0 if task.exhausted else task.weight for task in self.tasks]) if self.weight_by_size: sizes = np.array([task.size() for task in self.tasks], dtype=np.float32) weights *= sizes self.weights = weights / weights.sum() def __iter__(self): for task in self.tasks: task.iterate() return self def __next__(self): step = 1 while True: if self.remain == 0: print("Rank {}, All task exhaust !!!!".format(bmt.rank())) raise StopIteration if self.vdc_sampling: idx = next(self.vdc_gen)(self.weights) else: idx = np.random.choice(len(self.weights), p=self.weights) data = next(self.tasks[idx]) if data is None: if self.tasks[idx].allow_repeat: # _runtime_ave = self.tasks[idx].ave_tokens print("Rank {}, dataset {} exhaust, repeat...".format(bmt.rank(), self.tasks[idx].dataset_name)) # self.tasks[idx] = SegmentedDataset( # self.tasks[idx].cfg, self.tokenizer, self.max_length, transform_func=self.tasks[idx].transform_func, nthreads=self.nthreads, prefetch_slice=self.prefetch_slice # ) # self.tasks[idx].ave_tokens_update(_runtime_ave) self.tasks[idx].reset() else: print("Rank {}, dataset {} exhaust, not repeat.".format(bmt.rank(), self.tasks[idx].dataset_name)) self.tasks[idx].exhaust = True self.remain -= 1 continue if step % self.update_weights_frequency == 0: self.update_weights() step += 1 return dict( task_id=idx, input=data[0], target=data[1], index=data[2], is_long=self.tasks[idx].cfg.get("is_long", False), ) class UnpadBatchedMixedDataset(torch.utils.data.IterableDataset): def __init__(self, mixed_dataset, batch_size, max_length, pose_prob=0.0, pose_scaling_factor=1.0, compact=False): self.max_total_length = batch_size * max_length self.batch_size = 1 # setting compact=True concats segments orignated from the same input # into a long sequence. the relative order of segments should be preserved # in mixed_dataset, e.g., # - ok: task1_seg1, task2_seg1, task1_seg2, task1_seg3 # - not_ok: task1_seg1, task1_seg3, task2_seg1, task1_seg2 self.compact = compact self.total_length = 0 self.task2seqs = defaultdict(list) self.mixed_dataset = mixed_dataset self._max_length = max_length self._pose_prob = pose_prob self._pose_scaling_factor = pose_scaling_factor if self._pose_prob > 0.0: self._scaled_max_length = int(self.max_total_length * self._pose_scaling_factor) else: self._scaled_max_length = max_length def put(self, sample): self.total_length += len(sample["target"]) task_id = sample["task_id"] if self.compact and self.task2seqs[task_id]: last = self.task2seqs[task_id][-1] if last["target"][-1] != self.mixed_dataset.eos_token_id: # concatenate sequantial segments for longer context modeling: why not? last["input"].extend(sample["input"]) last["target"].extend(sample["target"]) return self.task2seqs[task_id].append(sample) def _pose_preprocess( self, input_ids: NDArray[np.int32], ) -> NDArray[np.int32]: """[PoSE](https://arxiv.org/abs/2309.10400v2) GitHub implementation: https://github.com/dwzhu-pku/PoSE/blob/master/src/train_pose.py#L156 """ len_chunk = min(len(input_ids), self._max_length) len_input = len(input_ids) # Chunk input randomly to fit max_length if needed lt1 = 0 rt1 = random.randint(0, (len_chunk + 1) // 2) # Fist chunk only contains 1/2 tokens at most rt2 = random.randint(lt1 + len_chunk, len_input) # Second chunk can randomly shift when not filled max_length lt2 = rt2 - (len_chunk - (rt1 - lt1)) # assure all tokens are used chunked_input_ids = np.concatenate([input_ids[lt1:rt1], input_ids[lt2:rt2]], axis=-1) # Generate PoSE position ids position_ids = np.arange(len(chunked_input_ids), dtype=np.int32) len_position_ids = len(position_ids) lt = 0 rt = random.randint(lt, self._scaled_max_length - len_position_ids) position_ids[: rt1 - lt1] += lt position_ids[rt1 - lt1 :] += rt return position_ids def pop(self): indexes = defaultdict(list) lengths = [] inputs = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32) targets = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=IGNORE_TGT) task_ids = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=-1) position_ids = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32) span_begin = 0 for samples in self.task2seqs.values(): while samples: sample = samples.pop() span_end = span_begin + len(sample["input"]) inputs[0, span_begin:span_end] = torch.tensor(sample["input"], dtype=torch.int32) targets[0, span_begin:span_end] = torch.tensor(sample["target"], dtype=torch.int32) task_ids[0, span_begin:span_end] = torch.tensor(sample["task_id"], dtype=torch.int32) if not sample["is_long"] and self._pose_prob > 0.0 and random.uniform(0, 1) < self._pose_prob: _span_position_ids = self._pose_preprocess(sample["input"]) else: _span_position_ids = np.arange(len(sample["input"]), dtype=np.int32) position_ids[0, span_begin:span_end] = torch.from_numpy(_span_position_ids) # position_ids[0, span_begin:span_end] = torch.arange(len(sample["input"]), dtype=torch.int32) lengths.append(len(sample["target"])) indexes[int(sample["task_id"])].append(sample["index"]) self.total_length -= len(sample["target"]) span_begin = span_end cu_seqlens = torch.cat( [torch.tensor([0] + lengths).cumsum(dim=-1), torch.tensor([self.max_total_length], dtype=torch.int32)], dim=0, ).int() batch = { "inputs": inputs, "targets": targets, "task_ids": task_ids, "indexes": indexes, # adhere to flash attention interface "cu_seqlens": cu_seqlens, "max_seqlen": int(torch.max(cu_seqlens[1:] - cu_seqlens[:-1])), "lengths": torch.tensor(sum(lengths)).int(), "task_names": self.mixed_dataset.names, "position_ids": position_ids, } return batch def will_be_full(self, sample): return self.total_length + len(sample["target"]) > self.max_total_length def __iter__(self): for sample in self.mixed_dataset: if self.will_be_full(sample): yield self.pop() self.put(sample) class CudaPrefetcher(Iterable): """ Wrap around a batch iterator for asynchornously copying data to gpu to shield memcpy latency. """ def __init__(self, loader, tp_size=1, tp_rank=0): self.loader = iter(loader) self.tp_size = tp_size self.tp_rank = tp_rank self.stream = torch.cuda.Stream() self.preload() def preload(self): try: if self.tp_size > 1: if self.tp_rank == 0: data = next(self.loader) print("Rank {}, Preload data done.".format(bmt.rank())) d = {} with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "wb") as fb: for key in data.keys(): if isinstance(data[key], torch.Tensor): np_cur_data = data[key].cpu().numpy() bs = np_cur_data.tobytes() fb.write(bs) d[key] = ["TORCH", str(np_cur_data.dtype), len(bs)] + list(np_cur_data.shape) elif isinstance(data[key], np.ndarray): bs = data[key].tobytes() fb.write(bs) d[key] = ["NUMPY", str(data[key].dtype), len(bs)] + list(data[key].shape) else: d[key] = data[key] try: _ = json.dumps(d) except TypeError: print(d) with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "w") as f: json.dump(d, f) bmt.synchronize() if self.tp_rank != 0: with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "r") as f: data = json.load(f) with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "rb") as fb: bs = fb.read() offset = 0 for key in data.keys(): if isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "NUMPY": nw_offset = offset + data[key][2] data[key] = np.frombuffer(bs[offset:nw_offset], dtype=data[key][1]).reshape( data[key][3:] ) offset = nw_offset elif isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "TORCH": nw_offset = offset + data[key][2] data[key] = torch.from_numpy( np.frombuffer(bs[offset:nw_offset], dtype=data[key][1]) .reshape(data[key][3:]) .copy() ) offset = nw_offset self.data = data else: self.data = next(self.loader) except StopIteration: self.data = None return with torch.cuda.stream(self.stream): for key in self.data.keys(): if isinstance(self.data[key], torch.Tensor): self.data[key] = self.data[key].cuda(non_blocking=True) def __next__(self): torch.cuda.current_stream().wait_stream(self.stream) for key in self.data.keys(): if isinstance(self.data[key], torch.Tensor): self.data[key].record_stream(torch.cuda.current_stream()) data = copy.deepcopy(self.data) self.preload() return data def __iter__(self): return self