@@ -1003,12 +1003,13 @@ def eval_val_ngram(
10031003 eval_seq_len : int ,
10041004 stride : int ,
10051005 batch_seqs : int = 32 ,
1006- ngram_order : int = 7 ,
1007- ngram_min_order : int = 2 ,
1006+ ngram_order : int = 5 ,
1007+ ngram_min_order : int = 5 ,
10081008 ngram_buckets : int = 4194304 ,
10091009 ngram_min_count : int = 2 ,
1010- ent_base : float = 0.05 ,
1011- ent_range : float = 0.55 ,
1010+ fixed_alpha : float = 0.2 ,
1011+ ent_base : float = 0.0 ,
1012+ ent_range : float = 0.0 ,
10121013 ent_scale : float = 2.0 ,
10131014 ent_thresh : float = 4.0 ,
10141015 log_fn = None ,
@@ -1111,13 +1112,15 @@ def eval_val_ngram(
11111112 log_fn (f"ngram: done, { has_ngram .sum ()} positions with n-gram predictions" )
11121113
11131114 # step 3: vectorized mixing
1114- # debug: report stats on n-gram probs
11151115 if log_fn :
11161116 dbg_mask = scored_mask & has_ngram
11171117 ng_pt = ngram_prob_target [dbg_mask ]
11181118 log_fn (f"ngram_stats: mean_prob={ ng_pt .mean ():.6f} median={ np .median (ng_pt ):.6f} "
11191119 f"nonzero={ np .count_nonzero (ng_pt )} /{ len (ng_pt )} " )
1120- alpha_all = ent_base + ent_range / (1.0 + np .exp (- ent_scale * (token_neural_entropy - ent_thresh )))
1120+ if ent_range > 0 :
1121+ alpha_all = ent_base + ent_range / (1.0 + np .exp (- ent_scale * (token_neural_entropy - ent_thresh )))
1122+ else :
1123+ alpha_all = np .full (total_tokens , fixed_alpha , dtype = np .float64 )
11211124 mixed_nll = np .copy (token_neural_nll )
11221125 # only mix where n-gram assigns nonzero prob to the target token
11231126 mix_mask = scored_mask & has_ngram & (ngram_prob_target > 0 )
@@ -1761,24 +1764,26 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
17611764 ngram_enabled = bool (int (os .environ .get ("NGRAM_ENABLED" , "1" )))
17621765 sw_seq_len = effective_eval_seq_len
17631766 if ngram_enabled :
1764- ngram_order = int (os .environ .get ("NGRAM_ORDER" , "7 " ))
1765- ngram_min_order = int (os .environ .get ("NGRAM_MIN_ORDER" , "2 " ))
1767+ ngram_order = int (os .environ .get ("NGRAM_ORDER" , "5 " ))
1768+ ngram_min_order = int (os .environ .get ("NGRAM_MIN_ORDER" , "5 " ))
17661769 ngram_buckets = int (os .environ .get ("NGRAM_BUCKETS" , "4194304" ))
17671770 ngram_min_count = int (os .environ .get ("NGRAM_MIN_COUNT" , "2" ))
1768- ngram_ent_base = float (os .environ .get ("NGRAM_ENT_BASE" , "0.05" ))
1769- ngram_ent_range = float (os .environ .get ("NGRAM_ENT_RANGE" , "0.55" ))
1771+ ngram_alpha = float (os .environ .get ("NGRAM_ALPHA" , "0.2" )) # fixed alpha (PR #769)
1772+ ngram_ent_base = float (os .environ .get ("NGRAM_ENT_BASE" , "0.0" )) # 0 = fixed alpha
1773+ ngram_ent_range = float (os .environ .get ("NGRAM_ENT_RANGE" , "0.0" ))
17701774 ngram_ent_scale = float (os .environ .get ("NGRAM_ENT_SCALE" , "2.0" ))
17711775 ngram_ent_thresh = float (os .environ .get ("NGRAM_ENT_THRESH" , "4.0" ))
17721776 torch .cuda .synchronize ()
17731777 t_ngram = time .perf_counter ()
1774- log0 (f"ngram_eval: order={ ngram_order } min_order={ ngram_min_order } buckets={ ngram_buckets } " )
1778+ log0 (f"ngram_eval: order={ ngram_order } min_order={ ngram_min_order } buckets={ ngram_buckets } alpha= { ngram_alpha } " )
17751779 ng_val_loss , ng_val_bpb = eval_val_ngram (
17761780 args , eval_model , rank , world_size , device ,
17771781 val_tokens , base_bytes_lut , has_leading_space_lut , is_boundary_token_lut ,
17781782 eval_seq_len = sw_seq_len if args .eval_stride > 0 else effective_eval_seq_len ,
17791783 stride = args .eval_stride if args .eval_stride > 0 else effective_eval_seq_len ,
17801784 ngram_order = ngram_order , ngram_min_order = ngram_min_order ,
17811785 ngram_buckets = ngram_buckets , ngram_min_count = ngram_min_count ,
1786+ fixed_alpha = ngram_alpha ,
17821787 ent_base = ngram_ent_base , ent_range = ngram_ent_range ,
17831788 ent_scale = ngram_ent_scale , ent_thresh = ngram_ent_thresh ,
17841789 log_fn = log0 ,
0 commit comments