174 changes: 174 additions & 0 deletions fregex/bench.py
@@ -0,0 +1,174 @@
import sys
import ctypes
import random
import time
import statistics
import os
import gc
from pathlib import Path

from nanochat.tokenizer import SPLIT_PATTERN

os.environ.update({
'OMP_NUM_THREADS': '1',
'OPENBLAS_NUM_THREADS': '1',
'MKL_NUM_THREADS': '1',
'VECLIB_MAXIMUM_THREADS': '1',
'NUMEXPR_NUM_THREADS': '1',
'RAYON_NUM_THREADS': '1',
})

try:
    os.setpriority(os.PRIO_PROCESS, 0, -10)  # bump priority to reduce scheduler noise
except OSError:
    pass  # raising priority needs elevated privileges; fall back to default niceness

from rustbpe import split_text as rust_split_text
from fregex.fuzz import gen_valid_unicode_string, compare_pair_text
from fregex.cload import *

PyBytes_AsString = ctypes.pythonapi.PyBytes_AsString
PyBytes_AsString.restype = ctypes.c_void_p
PyBytes_AsString.argtypes = [ctypes.py_object]

def _run_once_c(data: bytes) -> float:
token_list = TokenList()
c_lib.tokenlist_init(ctypes.byref(token_list))
base_ptr = PyBytes_AsString(data)
t0 = time.perf_counter_ns()
c_lib.tokenize_fast(base_ptr, len(data), ctypes.byref(token_list))
dt_ms = (time.perf_counter_ns() - t0) / 1e6
c_lib.tokenlist_free(ctypes.byref(token_list))
return dt_ms

def _run_once_rust(text: str) -> float:
t0 = time.perf_counter_ns()
rust_split_text(SPLIT_PATTERN, text)
return (time.perf_counter_ns() - t0) / 1e6

def stats_summary(times: list) -> dict:
"""Compute statistics from timing list."""
    if not times:
return {}

return {
'min': min(times),
'max': max(times),
'mean': statistics.mean(times),
'median': statistics.median(times),
'stdev': statistics.stdev(times) if len(times) > 1 else 0,
}

def format_stats(name: str, data_size: int, times: list) -> str:
"""Format timing statistics for output."""
    if not times:
return f"{name:20} {data_size:>10} B --\n"

stats = stats_summary(times)

return (f"{name:20} {data_size:>10} B "
f"min={stats['min']:.3f}ms max={stats['max']:.3f}ms "
f"mean={stats['mean']:.3f}ms median={stats['median']:.3f}ms "
f"stdev={stats['stdev']:.3f}ms\n")

def benchmark_dataset(name: str, data_bytes: bytes, iterations: int) -> None:
test_text = data_bytes.decode('utf-8', errors='replace')

print(f"\n--- Dataset: {name} ({len(data_bytes)} bytes, {iterations} iterations) ---")
print()

# Pre-touch data to avoid first-touch/page-fault skew
if data_bytes:
_ = data_bytes[0]
for i in range(0, len(data_bytes), 4096):
_ = data_bytes[i]

# Warm-up
for _ in range(20):
_run_once_c(data_bytes)
_run_once_rust(test_text)

# Disable GC during timed section
gc_was_enabled = gc.isenabled()
if gc_was_enabled:
gc.disable()

c_times = []
rust_times = []
for _ in range(iterations):
c_times.append(_run_once_c(data_bytes))
rust_times.append(_run_once_rust(test_text))

if gc_was_enabled:
gc.enable()

print(format_stats("C tokenizer", len(data_bytes), c_times), end='')
print(format_stats("Rust split", len(data_bytes), rust_times), end='')

if c_times and rust_times:
c_mean = statistics.mean(c_times)
rust_mean = statistics.mean(rust_times)
ratio = rust_mean / c_mean
speedup = "C is faster" if ratio > 1 else "Rust is faster"
print(f"Speedup: {ratio:.2f}x ({speedup})")

print()

# Verify token splits match between C and Python regex tokenizer
cmp_text = data_bytes.decode('utf-8', errors='surrogatepass')
ok, err, out_c, out_py = compare_pair_text(cmp_text)
if ok:
print("Compare: OK (C vs Py splits match)")
else:
print("Compare: MISMATCH (C vs Py)")
if err:
print(err)
if out_c is not None and out_py is not None:
c_lines = out_c.splitlines()
p_lines = out_py.splitlines()
print(f"C tokens: {len(c_lines)} | Py tokens: {len(p_lines)}")
print("--- C (head) ---")
print("\n".join(c_lines[:10]))
print("--- Py (head) ---")
print("\n".join(p_lines[:10]))
# Stop the benchmark if mismatch detected
raise SystemExit(1)

def main():
# Check if files were provided as arguments
    file_args = sys.argv[1:]

# If files provided, benchmark them
if file_args:
for file_path in file_args:
path = Path(file_path)
if not path.exists():
print(f"❌ File not found: {file_path}")
continue

try:
data = path.read_bytes()
benchmark_dataset(path.name, data, 1_000)
except Exception as e:
print(f"❌ Error reading {file_path}: {e}")
else:
# Run random generated data
configs = [
("tiny", 100, 1000),
("small", 1024, 500),
("medium", 10 * 1024, 100),
("large", 100 * 1024, 100),
("xlarge", 1024 * 1024, 100),
]

for name, size_bytes, iterations in configs:
            # Generate reproducible test data; seed with the name directly, since
            # hash() of a str is salted per process and is not stable across runs
            test_text = gen_valid_unicode_string(
                random.Random(name),
size_bytes
)
test_bytes = test_text.encode('utf-8')

benchmark_dataset(name, test_bytes, iterations)

print("=" * 140)

if __name__ == "__main__":
main()
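
For reference, a minimal sketch of driving the benchmark programmatically rather than via command-line arguments; the corpus path below is hypothetical, and this assumes libfregex.dylib and the rustbpe extension are already built and importable.

```python
# Minimal sketch: benchmark one corpus file directly (hypothetical path),
# assuming libfregex.dylib and the rustbpe extension are already built.
from pathlib import Path

from fregex.bench import benchmark_dataset

data = Path("tests/sample.txt").read_bytes()  # hypothetical corpus file
benchmark_dataset("sample", data, iterations=200)
```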
51 changes: 51 additions & 0 deletions fregex/cload.py
@@ -0,0 +1,51 @@
import ctypes

c_lib = ctypes.CDLL("fregex/libfregex.dylib")


class TokenPos(ctypes.Structure):
_fields_ = [
("start", ctypes.c_size_t),
("end", ctypes.c_size_t),
]


class TokenList(ctypes.Structure):
_fields_ = [
("splits", ctypes.POINTER(TokenPos)),
("count", ctypes.c_size_t),
("capacity", ctypes.c_size_t),
]


c_lib.tokenlist_init.argtypes = [ctypes.POINTER(TokenList)]
c_lib.tokenlist_init.restype = None
c_lib.tokenlist_free.argtypes = [ctypes.POINTER(TokenList)]
c_lib.tokenlist_free.restype = None
# Accept a raw pointer to the input buffer rather than a Python bytes object
c_lib.tokenize_fast.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.POINTER(TokenList)]
c_lib.tokenize_fast.restype = None

def tokenize_c_bytes(data: bytes) -> list[bytes]:
# Use a C char* view of the original bytes; offsets computed from this base
c_data = ctypes.c_char_p(data)
tl = TokenList()
c_lib.tokenlist_init(ctypes.byref(tl))
try:
base_addr = ctypes.cast(c_data, ctypes.c_void_p).value
# Pass the same pointer to C
c_lib.tokenize_fast(ctypes.cast(c_data, ctypes.c_void_p), len(data), ctypes.byref(tl))
out: list[bytes] = []
count = int(tl.count)
for i in range(count):
start_addr = int(tl.splits[i].start)
end_addr = int(tl.splits[i].end)
# Compute offsets into our local buffer
off_start = start_addr - base_addr
off_end = end_addr - base_addr
if off_start < 0 or off_end < off_start or off_end > len(data):
raise RuntimeError(f"Invalid span [{start_addr}:{end_addr}] for buffer base {base_addr}")
out.append(data[off_start:off_end])
return out
finally:
c_lib.tokenlist_free(ctypes.byref(tl))
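
As a usage note, here is a minimal sketch of calling the wrapper above on an in-memory byte string; it assumes libfregex.dylib has been compiled and sits at the path hard-coded in cload.py.

```python
# Minimal usage sketch for tokenize_c_bytes; assumes fregex/libfregex.dylib exists.
from fregex.cload import tokenize_c_bytes

data = "Hello, world! Numbers 123 and tabs\ttoo.".encode("utf-8")
tokens = tokenize_c_bytes(data)  # list[bytes], one entry per regex split
print([t.decode("utf-8") for t in tokens])
```

The C side reports each span as absolute pointers into the input buffer, which is why the wrapper converts them back to offsets before slicing.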
162 changes: 162 additions & 0 deletions fregex/compare.py
@@ -0,0 +1,162 @@
import sys
import ctypes
from pathlib import Path

from nanochat.tokenizer import SPLIT_PATTERN
from rustbpe import split_text as rust_split_text
from fregex.cload import *
from fregex.py_tokenizer import tokenize_py as py_tokenize_str

def escape_bytes(b: bytes) -> str:
buf = []
for code in b:
        if code == 0x5C:
            buf.append('\\\\')  # escape the backslash itself so the dump stays unambiguous
elif code == 0x0A:
buf.append('\\n')
elif code == 0x0D:
buf.append('\\r')
elif code == 0x09:
buf.append('\\t')
elif code == 0x0C:
buf.append('\\f')
elif code == 0x0B:
buf.append('\\v')
elif code == 0x22:
buf.append('\\"')
elif code < 32 or code >= 127:
buf.append(f"\\x{code:02X}")
else:
buf.append(chr(code))
return ''.join(buf)

def dump_tokens(tokens: list[bytes]) -> str:
return "\n".join(f"{len(b)}\t{escape_bytes(b)}" for b in tokens)

def tokenize_py_bytes(data: bytes) -> list[bytes]:
text = data.decode('utf-8', errors='surrogatepass')
toks = py_tokenize_str(text)
return [t.encode('utf-8', errors='surrogatepass') for t in toks]

def tokenize_rs_bytes(data: bytes) -> list[bytes]:
text = data.decode('utf-8', errors='surrogatepass')
parts = rust_split_text(SPLIT_PATTERN, text)
return [t.encode('utf-8', errors='surrogatepass') for t in parts]

def compare_one(path: Path) -> int:
data_bytes = Path(path).read_bytes()
try:
c_toks = tokenize_c_bytes(data_bytes)
except Exception as e:
print(f"C tokenizer failed on {path}:\n{e}", file=sys.stderr)
return 1
try:
py_toks = tokenize_py_bytes(data_bytes)
except Exception as e:
print(f"Python tokenizer failed on {path}:\n{e}", file=sys.stderr)
return 1
try:
rs_toks = tokenize_rs_bytes(data_bytes)
except Exception as e:
print(f"Rust split failed on {path}:\n{e}", file=sys.stderr)
return 1

out_c = dump_tokens(c_toks)
out_py = dump_tokens(py_toks)
out_rs = dump_tokens(rs_toks)

if out_c == out_py == out_rs:
print(f"OK {path.name}")
return 0
else:
print(f"DIFF {path.name}")
# Show a small 3-way diff at first differing line, with byte offsets
c_lines = out_c.splitlines()
p_lines = out_py.splitlines()
r_lines = out_rs.splitlines()

def parse_lines(lines):
parsed = []
for ln in lines:
# Format is: "<len>\t<escaped>"
try:
left, right = ln.split('\t', 1)
blen = int(left)
except Exception:
blen = 0
right = ln
parsed.append((blen, right))
return parsed

c_parsed = parse_lines(c_lines)
p_parsed = parse_lines(p_lines)
r_parsed = parse_lines(r_lines)

def byte_offsets(parsed):
offs = []
pos = 0
for blen, _ in parsed:
offs.append((pos, pos + blen))
pos += blen
return offs

c_offs = byte_offsets(c_parsed)
p_offs = byte_offsets(p_parsed)
r_offs = byte_offsets(r_parsed)

def print_unicode_debug(label, offs_list, idx):
if idx >= len(offs_list):
print(f" {label} piece: [n/a]")
return
start, end = offs_list[idx]
piece_bytes = data_bytes[start:end]
piece_text = piece_bytes.decode('utf-8', errors='replace')
if not piece_bytes:
print(f" {label} piece: [EMPTY]")
return
cp_parts = []
for ch in piece_text:
cp_parts.append(f"U+{ord(ch):04X}")
bytes_hex = ' '.join(f"{b:02X}" for b in piece_bytes)
print(f" {label} chars: {' | '.join(cp_parts)}")
print(f" {label} bytes: {bytes_hex} ({len(piece_bytes)}B, {len(piece_text)} chars)")

max_len = max(len(c_lines), len(p_lines), len(r_lines))
for i in range(max_len):
cl = c_lines[i] if i < len(c_lines) else "<eof>"
pl = p_lines[i] if i < len(p_lines) else "<eof>"
rl = r_lines[i] if i < len(r_lines) else "<eof>"
if not (cl == pl == rl):
# Collect byte positions if available
c_pos = f"[{c_offs[i][0]}:{c_offs[i][1]}]" if i < len(c_offs) else "[n/a]"
p_pos = f"[{p_offs[i][0]}:{p_offs[i][1]}]" if i < len(p_offs) else "[n/a]"
r_pos = f"[{r_offs[i][0]}:{r_offs[i][1]}]" if i < len(r_offs) else "[n/a]"
print(
f" line {i+1}:\n"
f" C: {cl} @ bytes {c_pos}\n"
f" Py: {pl} @ bytes {p_pos}\n"
f" Rs: {rl} @ bytes {r_pos}"
)
print(" === Unicode split detail ===")
print_unicode_debug("C", c_offs, i)
print_unicode_debug("Py", p_offs, i)
print_unicode_debug("Rs", r_offs, i)
break
return 2

def main():
if len(sys.argv) < 2:
print(f"Usage: {sys.argv[0]} <tests-dir>")
sys.exit(2)
paths = sorted(Path(sys.argv[1]).glob('*.txt'))
    bad = 0
    for p in paths:
        if compare_one(p) != 0:
            bad += 1
    print(f"Completed. Failures: {bad}")

if __name__ == '__main__':
main()
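
A possible quick check on an in-memory string, reusing the helpers defined above instead of a directory of .txt files (the sample text is arbitrary):

```python
# Quick three-way check on an in-memory string, reusing the helpers above.
from fregex.cload import tokenize_c_bytes
from fregex.compare import tokenize_py_bytes, tokenize_rs_bytes, dump_tokens

sample = "naive cafe 2024 rockets!\n".encode("utf-8")
assert dump_tokens(tokenize_c_bytes(sample)) \
    == dump_tokens(tokenize_py_bytes(sample)) \
    == dump_tokens(tokenize_rs_bytes(sample))
print("splits match")
```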

