6 changes: 5 additions & 1 deletion .gitignore
@@ -5,4 +5,8 @@ logs/
testing
libtraining_data_loader.*
easy_train_data
*.yaml
*.yaml
CMakeCache.txt
CMakeFiles/
Makefile
cmake_install.cmake
10 changes: 7 additions & 3 deletions data_loader/_native.py
@@ -5,7 +5,7 @@
import numpy as np
import torch

from .config import CDataloaderSkipConfig
from .config import CDataloaderSkipConfig, CDataloaderDDPConfig


class SparseBatch(ctypes.Structure):
@@ -144,7 +144,8 @@ def _define_prototypes(self):
# const char* const* filenames,
# int batch_size,
# bool cyclic,
# DataloaderSkipConfig config
# DataloaderSkipConfig config,
# DataloaderDDPConfig ddp_config
# )
self.dll.create_fen_batch_stream.restype = ctypes.c_void_p
self.dll.create_fen_batch_stream.argtypes = [
@@ -154,6 +155,7 @@ def _define_prototypes(self):
ctypes.c_int,
ctypes.c_bool,
CDataloaderSkipConfig,
CDataloaderDDPConfig,
]

# EXPORT void CDECL destroy_fen_batch_stream(FenBatchStream* stream)
@@ -170,7 +172,8 @@ def _define_prototypes(self):
# const char* const* filenames,
# int batch_size,
# bool cyclic,
# DataloaderSkipConfig config
# DataloaderSkipConfig config,
# DataloaderDDPConfig ddp_config
# )
self.dll.create_sparse_batch_stream.restype = ctypes.c_void_p
self.dll.create_sparse_batch_stream.argtypes = [
@@ -181,6 +184,7 @@ def _define_prototypes(self):
ctypes.c_int,
ctypes.c_bool,
CDataloaderSkipConfig,
CDataloaderDDPConfig,
]

# EXPORT void CDECL destroy_sparse_batch_stream(Stream<SparseBatch>* stream)
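The new struct is appended to `argtypes`, so ctypes hands it to the native library by value, and the Python-side field layout must match the C definition exactly. A minimal standalone sketch of that contract (the class is redeclared here so the snippet runs without the compiled library):

```python
import ctypes

# Redeclared from data_loader/config.py for a self-contained example.
class CDataloaderDDPConfig(ctypes.Structure):
    _fields_ = [
        ("rank", ctypes.c_int),
        ("world_size", ctypes.c_int),
    ]

# ctypes passes Structure arguments by value; the layout here must match
# the C-side DataloaderDDPConfig field-for-field or values will be garbled.
print(ctypes.sizeof(CDataloaderDDPConfig))  # 8 on typical platforms
```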
19 changes: 19 additions & 0 deletions data_loader/config.py
@@ -15,6 +15,12 @@ class DataloaderSkipConfig:
pc_y3: float = 1.0


@dataclass
class DataloaderDDPConfig:
rank: int = 0
world_size: int = 1


class CDataloaderSkipConfig(ctypes.Structure):
_fields_ = [
("filtered", ctypes.c_bool),
@@ -40,3 +46,16 @@ def __init__(self, config: DataloaderSkipConfig):
pc_y2=config.pc_y2,
pc_y3=config.pc_y3,
)


class CDataloaderDDPConfig(ctypes.Structure):
_fields_ = [
("rank", ctypes.c_int),
("world_size", ctypes.c_int),
]

def __init__(self, config: DataloaderDDPConfig):
super().__init__(
rank=config.rank,
world_size=config.world_size,
)
14 changes: 11 additions & 3 deletions data_loader/dataset.py
@@ -5,7 +5,7 @@
from torch.utils.data import Dataset

from . import stream
from .config import DataloaderSkipConfig
from .config import DataloaderSkipConfig, DataloaderDDPConfig


class FenBatchProvider:
@@ -16,6 +16,7 @@ def __init__(
num_workers,
batch_size=None,
config: DataloaderSkipConfig = DataloaderSkipConfig(),
ddp_config: DataloaderDDPConfig = None,
):
self.filename = filename
self.cyclic = cyclic
@@ -25,7 +26,7 @@ def __init__(

if batch_size:
self.stream = stream.create_fen_batch_stream(
self.num_workers, [self.filename], batch_size, cyclic, config
self.num_workers, [self.filename], batch_size, cyclic, config, ddp_config
)
else:
# doesn't work yet
@@ -67,6 +68,7 @@ def __init__(
num_workers,
batch_size=None,
config: DataloaderSkipConfig = DataloaderSkipConfig(),
ddp_config: DataloaderDDPConfig = None,
):
self.feature_set = feature_set.encode("utf-8")
self.create_stream = create_stream
@@ -87,10 +89,11 @@ def __init__(
batch_size,
cyclic,
config,
ddp_config,
)
else:
self.stream = self.create_stream(
self.feature_set, self.num_workers, self.filenames, cyclic, config
self.feature_set, self.num_workers, self.filenames, cyclic, config, ddp_config
)

def __iter__(self):
@@ -119,6 +122,7 @@ def __init__(
cyclic=True,
num_workers=1,
config: DataloaderSkipConfig = DataloaderSkipConfig(),
ddp_config: DataloaderDDPConfig = None,
):
super().__init__(
feature_set,
@@ -131,6 +135,7 @@ def __init__(
num_workers,
batch_size,
config,
ddp_config,
)


@@ -143,6 +148,7 @@ def __init__(
cyclic=True,
num_workers=1,
config: DataloaderSkipConfig = DataloaderSkipConfig(),
ddp_config: DataloaderDDPConfig = None,
):
super().__init__()
self.feature_set = feature_set
@@ -151,6 +157,7 @@ def __init__(
self.cyclic = cyclic
self.num_workers = num_workers
self.config = config
self.ddp_config = ddp_config

def __iter__(self):
return SparseBatchProvider(
@@ -160,6 +167,7 @@ def __iter__(self):
cyclic=self.cyclic,
num_workers=self.num_workers,
config=self.config,
ddp_config=self.ddp_config,
)


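With the new keyword threaded through, a caller can pin the shard explicitly instead of relying on environment detection. A hedged usage sketch — the file name, worker count, and batch size are placeholders, not values from this PR:

```python
from data_loader.dataset import FenBatchProvider
from data_loader.config import DataloaderDDPConfig

provider = FenBatchProvider(
    "train.binpack",          # placeholder path
    cyclic=True,
    num_workers=4,
    batch_size=8192,          # placeholder batch size
    ddp_config=DataloaderDDPConfig(rank=0, world_size=2),
)
```

Leaving `ddp_config=None` (the default) defers to the environment-variable fallback added in stream.py below.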
28 changes: 27 additions & 1 deletion data_loader/stream.py
@@ -1,7 +1,21 @@
import ctypes
import os

from ._native import c_lib, SparseBatchPtr, FenBatchPtr
from .config import CDataloaderSkipConfig, DataloaderSkipConfig
from .config import (
CDataloaderSkipConfig,
DataloaderSkipConfig,
CDataloaderDDPConfig,
DataloaderDDPConfig,
)


def _get_ddp_rank_and_world_size():
"""Get DDP rank and world size from environment variables."""
# Prefer the global RANK; LOCAL_RANK alone collides across multi-node jobs.
rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0")))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
print(f"DDP rank: {rank}, world size: {world_size}", flush=True)
return rank, world_size


def _to_c_str_array(str_list):
@@ -16,14 +30,20 @@ def create_fen_batch_stream(
batch_size,
cyclic,
config: DataloaderSkipConfig,
ddp_config: DataloaderDDPConfig = None,
) -> ctypes.c_void_p:
if ddp_config is None:
rank, world_size = _get_ddp_rank_and_world_size()
ddp_config = DataloaderDDPConfig(rank=rank, world_size=world_size)

return c_lib.dll.create_fen_batch_stream(
concurrency,
len(filenames),
_to_c_str_array(filenames),
batch_size,
cyclic,
CDataloaderSkipConfig(config),
CDataloaderDDPConfig(ddp_config),
)


@@ -46,7 +66,12 @@ def create_sparse_batch_stream(
batch_size,
cyclic,
config: DataloaderSkipConfig,
ddp_config: DataloaderDDPConfig = None,
) -> ctypes.c_void_p:
if ddp_config is None:
rank, world_size = _get_ddp_rank_and_world_size()
ddp_config = DataloaderDDPConfig(rank=rank, world_size=world_size)

return c_lib.dll.create_sparse_batch_stream(
feature_set,
concurrency,
@@ -55,6 +80,7 @@ def create_sparse_batch_stream(
batch_size,
cyclic,
CDataloaderSkipConfig(config),
CDataloaderDDPConfig(ddp_config),
)


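The fallback mirrors the variables that `torchrun` exports, so an unmodified launch command picks up the right shard in each process. A small sketch of the resolution logic in isolation:

```python
import os

# Simulate what torchrun exports for rank 1 of a 2-process job.
os.environ.update({"RANK": "1", "LOCAL_RANK": "1", "WORLD_SIZE": "2"})

rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0")))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
print(rank, world_size)  # 1 2

# With none of the variables set, the defaults are rank 0 / world size 1,
# so a plain single-process run keeps its old behaviour.
```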
76 changes: 74 additions & 2 deletions lib/nnue_training_data_formats.h
@@ -6808,6 +6808,32 @@ namespace binpack
{
m_file.seekg(0);
}

bool skipChunks(std::size_t n)
{
if (n == 0) return true;

std::size_t skipped = 0;

while (skipped < n)
{
if (!hasNextChunk())
{
return false;
}

auto curPos = m_file.tellg();
Header header = readChunkHeader();
m_file.seekg(header.chunkSize, std::ios_base::cur);

assert(m_file.tellg() > curPos);

++skipped;
}

return true;
}

[[nodiscard]] std::vector<unsigned char> readNextChunk()
{
@@ -7620,12 +7646,16 @@ namespace binpack
std::vector<std::string> paths,
std::ios_base::openmode om = std::ios_base::app,
bool cyclic = false,
std::function<bool(const TrainingDataEntry&)> skipPredicate = nullptr
std::function<bool(const TrainingDataEntry&)> skipPredicate = nullptr,
int rank = 0,
int world_size = 1
) :
m_concurrency(concurrency),
m_bufferOffset(0),
m_cyclic(cyclic),
m_skipPredicate(std::move(skipPredicate))
m_skipPredicate(std::move(skipPredicate)),
m_rank(rank),
m_world_size(world_size)
{
m_numRunningWorkers.store(0);
std::vector<double> sizes; // discrete distribution wants double weights
@@ -7643,6 +7673,10 @@

m_inputFileDistribution = std::discrete_distribution<>(sizes.begin(), sizes.end());

// Initialize DDP seek tracking
m_files_seeked_for_ddp.resize(m_inputFiles.size(), false);
m_ddp_chunks_to_skip_after_read.resize(m_inputFiles.size(), 0);

m_stopFlag.store(false);

auto worker = [this]()
@@ -7830,6 +7864,12 @@
std::function<bool(const TrainingDataEntry&)> m_skipPredicate;

std::vector<std::thread> m_workers;

// DDP support
int m_rank;
int m_world_size;
std::vector<std::uint8_t> m_files_seeked_for_ddp; // Whether each file has had its initial DDP skip applied
std::vector<std::size_t> m_ddp_chunks_to_skip_after_read;

bool fetchNextChunkIfNeeded(std::size_t& m_offset, std::vector<unsigned char>& m_chunk)
{
@@ -7841,18 +7881,50 @@

std::unique_lock lock(m_fileMutex);

// DDP: chunk-based skipping
if (m_world_size > 1)
{
if (!m_files_seeked_for_ddp[fileId])
{
inputFile.skipChunks(static_cast<std::size_t>(m_rank));
m_files_seeked_for_ddp[fileId] = true;
}
else if (m_ddp_chunks_to_skip_after_read[fileId] > 0)
{
const bool success = inputFile.skipChunks(m_ddp_chunks_to_skip_after_read[fileId]);
m_ddp_chunks_to_skip_after_read[fileId] = 0;
if (!success && m_cyclic)
{
inputFile.seek_to_start();
inputFile.skipChunks(static_cast<std::size_t>(m_rank));
}
}
}

if (!inputFile.hasNextChunk())
{
if (m_cyclic)
{
inputFile.seek_to_start();

if (m_world_size > 1)
{
inputFile.skipChunks(static_cast<std::size_t>(m_rank));
}
}
else
{
return true;
}
}

m_chunk = inputFile.readNextChunk();
m_offset = 0;

if (m_world_size > 1)
{
m_ddp_chunks_to_skip_after_read[fileId] = static_cast<std::size_t>(m_world_size - 1);
}
}

return false;
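Taken together, the initial `skipChunks(rank)` and the `world_size - 1` skip after every read give each rank a strided slice of every file: rank r reads chunks r, r + W, r + 2W, and so on. A small Python model of the resulting assignment (the helper name is mine, not part of the reader):

```python
def chunks_for_rank(num_chunks, rank, world_size):
    """Model the reader's pattern: skip `rank` chunks before the first
    read, then `world_size - 1` chunks after each read."""
    indices, pos = [], rank
    while pos < num_chunks:
        indices.append(pos)
        pos += world_size
    return indices

# 10 chunks split across 4 ranks -- disjoint and collectively exhaustive:
for r in range(4):
    print(r, chunks_for_rank(10, r, 4))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6]
# 3 [3, 7]
```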
8 changes: 4 additions & 4 deletions lib/nnue_training_data_stream.h
@@ -185,8 +185,8 @@ namespace training_data {
static constexpr auto openmode = std::ios::in | std::ios::binary;
static inline const std::string extension = "binpack";

BinpackSfenInputParallelStream(int concurrency, const std::vector<std::string>& filenames, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate) :
m_stream(std::make_unique<binpack::CompressedTrainingDataEntryParallelReader>(concurrency, filenames, openmode, cyclic, skipPredicate)),
BinpackSfenInputParallelStream(int concurrency, const std::vector<std::string>& filenames, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate, int rank = 0, int world_size = 1) :
m_stream(std::make_unique<binpack::CompressedTrainingDataEntryParallelReader>(concurrency, filenames, openmode, cyclic, skipPredicate, rank, world_size)),
m_filenames(filenames),
m_concurrency(concurrency),
m_eof(false),
@@ -243,13 +243,13 @@ namespace training_data {
return nullptr;
}

inline std::unique_ptr<BasicSfenInputStream> open_sfen_input_file_parallel(int concurrency, const std::vector<std::string>& filenames, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate = nullptr)
inline std::unique_ptr<BasicSfenInputStream> open_sfen_input_file_parallel(int concurrency, const std::vector<std::string>& filenames, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate = nullptr, int rank = 0, int world_size = 1)
{
// TODO (low priority): optimize and parallelize .bin reading.
if (has_extension(filenames[0], BinSfenInputStream::extension))
return std::make_unique<BinSfenInputStream>(filenames[0], cyclic, std::move(skipPredicate));
else if (has_extension(filenames[0], BinpackSfenInputParallelStream::extension))
return std::make_unique<BinpackSfenInputParallelStream>(concurrency, filenames, cyclic, std::move(skipPredicate));
return std::make_unique<BinpackSfenInputParallelStream>(concurrency, filenames, cyclic, std::move(skipPredicate), rank, world_size);

return nullptr;
}