Commit 37fb8aad authored by joeld's avatar joeld

efficient multiprocess data loader: streams input files line by line in parallel

parent 0f68664f
import os
import gzip
import json
import glob
import multiprocessing


class DataLoader:
    """Partitions JSONL corpora across processes: each of total_procs
    workers streams every input file and keeps the lines whose index
    (mod total_procs) equals its own rank, writing them to its shard."""

    def __init__(self, input_dir, output_dir, total_procs, proc_rank):
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.total_procs = total_procs
        self.proc_rank = proc_rank
        os.makedirs(self.output_dir, exist_ok=True)

    def _get_input_files(self):
        # Sorted so every rank walks the files in the same order.
        return sorted(glob.glob(os.path.join(self.input_dir, "*.jsonl")) +
                      glob.glob(os.path.join(self.input_dir, "*.jsonl.gz")))

    def _open_partition_file(self):
        partition_path = os.path.join(self.output_dir, f"shard_{self.proc_rank}.jsonl")
        return open(partition_path, "w", encoding="utf-8")

    def partition_data(self):
        input_files = self._get_input_files()
        if not input_files:
            raise FileNotFoundError(f"No input files found in {self.input_dir}")
        # Use a context manager so the shard is closed even if a read fails.
        with self._open_partition_file() as partition_file:
            for file_path in input_files:
                open_func = gzip.open if file_path.endswith(".gz") else open
                with open_func(file_path, "rt", encoding="utf-8") as infile:
                    for line_index, line in enumerate(infile):
                        # Round-robin split: keep only this rank's lines.
                        if line_index % self.total_procs == self.proc_rank:
                            partition_file.write(line)
# TODO: weakness: every rank streams every file in full just to keep
# 1/total_procs of the lines, so a file with far more lines than the others
# makes all ranks pay its full scan cost, and per-shard bytes can still skew
# when line lengths vary widely between files.
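# One possible mitigation (a sketch, not part of the original commit; the
# helper below is hypothetical): assign whole files to ranks greedily by
# size, so each rank reads only its own files instead of scanning everything.
def assign_files_by_size(input_dir, total_procs):
    """Greedy bin packing: hand each file to the currently least-loaded
    rank (by total bytes), keeping per-rank I/O roughly balanced even
    when file sizes differ a lot."""
    files = sorted(glob.glob(os.path.join(input_dir, "*.jsonl")) +
                   glob.glob(os.path.join(input_dir, "*.jsonl.gz")),
                   key=os.path.getsize, reverse=True)
    assignments = [[] for _ in range(total_procs)]
    bytes_per_rank = [0] * total_procs
    for path in files:
        rank = bytes_per_rank.index(min(bytes_per_rank))
        assignments[rank].append(path)
        bytes_per_rank[rank] += os.path.getsize(path)
    return assignments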
# Example (single rank; proc_rank must be in range(total_procs), so a rank of
# 100 with 4 procs would match no lines and write an empty shard):
# loader = DataLoader(input_dir="/data/cc_news", output_dir="/data/processed", total_procs=4, proc_rank=0)
# loader.partition_data()
"""
INPUT_DIR = r"C:\Dev_Projects\HTYLLM-PG\data\cc_news"
OUTPUT_DIR = r"C:\Dev_Projects\HTYLLM-PG\data\processed"
TOTAL_PROCS = int(os.getenv("TOTAL_PROCS", "4"))
def test_partition_data(proc_rank):
loader = DataLoader(input_dir=INPUT_DIR, output_dir=OUTPUT_DIR, total_procs=TOTAL_PROCS, proc_rank=proc_rank)
loader.partition_data()
print(f"Process {proc_rank} completed.")
def main():
processes = []
for proc_rank in range(TOTAL_PROCS):
p = multiprocessing.Process(target=test_partition_data, args=(proc_rank,))
processes.append(p)
p.start()
for p in processes:
p.join()
print("All processes completed.")
if __name__ == "__main__":
main()
"""