Skip to content

Commit dcc4f69

Browse files
committed
exp54: 5-gram fixed alpha=0.2 cache (PR openai#769 recipe)
1 parent 987b26b commit dcc4f69

1 file changed

Lines changed: 16 additions & 11 deletions

File tree

train_gpt.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,12 +1003,13 @@ def eval_val_ngram(
10031003
eval_seq_len: int,
10041004
stride: int,
10051005
batch_seqs: int = 32,
1006-
ngram_order: int = 7,
1007-
ngram_min_order: int = 2,
1006+
ngram_order: int = 5,
1007+
ngram_min_order: int = 5,
10081008
ngram_buckets: int = 4194304,
10091009
ngram_min_count: int = 2,
1010-
ent_base: float = 0.05,
1011-
ent_range: float = 0.55,
1010+
fixed_alpha: float = 0.2,
1011+
ent_base: float = 0.0,
1012+
ent_range: float = 0.0,
10121013
ent_scale: float = 2.0,
10131014
ent_thresh: float = 4.0,
10141015
log_fn=None,
@@ -1111,13 +1112,15 @@ def eval_val_ngram(
11111112
log_fn(f"ngram: done, {has_ngram.sum()} positions with n-gram predictions")
11121113

11131114
# step 3: vectorized mixing
1114-
# debug: report stats on n-gram probs
11151115
if log_fn:
11161116
dbg_mask = scored_mask & has_ngram
11171117
ng_pt = ngram_prob_target[dbg_mask]
11181118
log_fn(f"ngram_stats: mean_prob={ng_pt.mean():.6f} median={np.median(ng_pt):.6f} "
11191119
f"nonzero={np.count_nonzero(ng_pt)}/{len(ng_pt)}")
1120-
alpha_all = ent_base + ent_range / (1.0 + np.exp(-ent_scale * (token_neural_entropy - ent_thresh)))
1120+
if ent_range > 0:
1121+
alpha_all = ent_base + ent_range / (1.0 + np.exp(-ent_scale * (token_neural_entropy - ent_thresh)))
1122+
else:
1123+
alpha_all = np.full(total_tokens, fixed_alpha, dtype=np.float64)
11211124
mixed_nll = np.copy(token_neural_nll)
11221125
# only mix where n-gram assigns nonzero prob to the target token
11231126
mix_mask = scored_mask & has_ngram & (ngram_prob_target > 0)
@@ -1761,24 +1764,26 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
17611764
ngram_enabled = bool(int(os.environ.get("NGRAM_ENABLED", "1")))
17621765
sw_seq_len = effective_eval_seq_len
17631766
if ngram_enabled:
1764-
ngram_order = int(os.environ.get("NGRAM_ORDER", "7"))
1765-
ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", "2"))
1767+
ngram_order = int(os.environ.get("NGRAM_ORDER", "5"))
1768+
ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", "5"))
17661769
ngram_buckets = int(os.environ.get("NGRAM_BUCKETS", "4194304"))
17671770
ngram_min_count = int(os.environ.get("NGRAM_MIN_COUNT", "2"))
1768-
ngram_ent_base = float(os.environ.get("NGRAM_ENT_BASE", "0.05"))
1769-
ngram_ent_range = float(os.environ.get("NGRAM_ENT_RANGE", "0.55"))
1771+
ngram_alpha = float(os.environ.get("NGRAM_ALPHA", "0.2")) # fixed alpha (PR #769)
1772+
ngram_ent_base = float(os.environ.get("NGRAM_ENT_BASE", "0.0")) # 0 = fixed alpha
1773+
ngram_ent_range = float(os.environ.get("NGRAM_ENT_RANGE", "0.0"))
17701774
ngram_ent_scale = float(os.environ.get("NGRAM_ENT_SCALE", "2.0"))
17711775
ngram_ent_thresh = float(os.environ.get("NGRAM_ENT_THRESH", "4.0"))
17721776
torch.cuda.synchronize()
17731777
t_ngram = time.perf_counter()
1774-
log0(f"ngram_eval: order={ngram_order} min_order={ngram_min_order} buckets={ngram_buckets}")
1778+
log0(f"ngram_eval: order={ngram_order} min_order={ngram_min_order} buckets={ngram_buckets} alpha={ngram_alpha}")
17751779
ng_val_loss, ng_val_bpb = eval_val_ngram(
17761780
args, eval_model, rank, world_size, device,
17771781
val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
17781782
eval_seq_len=sw_seq_len if args.eval_stride > 0 else effective_eval_seq_len,
17791783
stride=args.eval_stride if args.eval_stride > 0 else effective_eval_seq_len,
17801784
ngram_order=ngram_order, ngram_min_order=ngram_min_order,
17811785
ngram_buckets=ngram_buckets, ngram_min_count=ngram_min_count,
1786+
fixed_alpha=ngram_alpha,
17821787
ent_base=ngram_ent_base, ent_range=ngram_ent_range,
17831788
ent_scale=ngram_ent_scale, ent_thresh=ngram_ent_thresh,
17841789
log_fn=log0,

0 commit comments

Comments
 (0)