diff --git a/records/track_1_short/2026-01-31-BigramHashH2D/112c686e-b0d6-4dc8-814a-1ad1f5d5b274.txt b/records/track_1_short/2026-01-31-BigramHashH2D/112c686e-b0d6-4dc8-814a-1ad1f5d5b274.txt new file mode 100644 index 000000000..e0c66b33c --- /dev/null +++ b/records/track_1_short/2026-01-31-BigramHashH2D/112c686e-b0d6-4dc8-814a-1ad1f5d5b274.txt @@ -0,0 +1,3976 @@ +import os +import sys + +# Read the current file and the kernels file code ASAP, for logging +with open(sys.argv[0], 'r') as f: + code = f.read() +with open(os.path.join(os.path.dirname(sys.argv[0]), 'triton_kernels.py'), 'r') as f: + code += f"\n\n{'-'*40}\n# triton_kernels.py\n{'-'*40}\n\n" + code += f.read() + +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate, pairwise +from pathlib import Path +import gc + +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" +import torch +import triton + +torch.empty( + 1, device=f"cuda:{os.environ['LOCAL_RANK']}", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +from kernels import get_kernel +from torch import Tensor, nn + +from triton_kernels import XXT, ba_plus_cAA, FusedLinearReLUSquareFunction, FusedSoftcappedCrossEntropy + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Distributed training setup +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +grad_scale = 2 / grad_accum_steps # consistent grad magnitudes between different num_devices +assert torch.cuda.is_available() +device = 
torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng +# Transposed layout by @ChrisJMcCormick allows for faster gradient accumulation. + +@torch.library.custom_op("nanogpt::mm_t", mutates_args=()) +def mm_t_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + """Computes y = x @ w with F8 weights stored as (in_features, out_features).""" + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + assert x.shape[1] == w.shape[0] # x: (batch, in), w: (in, out) + + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + + # _scaled_mm requires column-major B. w_f8 is row-major (in, out). + # .T.contiguous().T creates a column-major view without changing logical shape. 
+ w_f8_col_major = w_f8.T.contiguous().T + + out = torch._scaled_mm( + x_f8, + w_f8_col_major, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_t_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[0] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_t_backward", mutates_args=()) +def mm_t_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + + x_scale = grad.new_tensor(x_s, dtype=torch.float32) + w_scale = grad.new_tensor(w_s, dtype=torch.float32) + grad_scale = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + + # grad_x = grad @ w.T + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=grad_scale, + scale_b=w_scale, + use_fast_accum=False, + ) + + # grad_w = x.T @ grad + # Result is (in, out), naturally matching weight storage. No final .T needed. 
+ grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_scale, + scale_b=grad_scale, + use_fast_accum=False, + ) + + return grad_x, grad_w + + grad_x, grad_w = impl(g, x_f8, w_f8) + + return grad_x, grad_w + +@mm_t_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.to(torch.float32) + +def backward_t(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_t_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context_t(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_t_op.register_autograd(backward_t, setup_context=setup_context_t) + +# ----------------------------------------------------------------------------- +# Polar Express + +# Computed for num_iters=5, safety_factor=2e-2, cushion=2 +polar_express_coeffs = [ + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323) +] + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def polar_express(G: torch.Tensor, split_baddbmm: bool = False): + """ + Polar Express Sign Method: https://arxiv.org/pdf/2505.16932 + by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. 
+ """ + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + # Select batched vs unbatched + if split_baddbmm: + BX_matmul = torch.bmm if X.ndim > 2 else torch.mm + else: + aX_plus_BX = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the iterations + for a, b, c in polar_express_coeffs: + XXT(X, out=A) # A = X @ X.mT + ba_plus_cAA(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + + # Referencing X twice causes pytorch to make a defensive copy, + # resulting in a cudaMemcpyAsync in baddbmm. + # For large matrices (i.e., the mlp weights), it's faster to split + # the operation into two kernels to avoid this. + if split_baddbmm: + BX_matmul(B, X, out=C) # C = B @ X + C.add_(X, alpha=a) # C = C + a*X (in-place, X only read) + else: + aX_plus_BX(X, B, X, beta=a, out=C) # C = a * X + B @ X + + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +# ----------------------------------------------------------------------------- +# Combined NorMuon + Adam Optimizer + +@dataclass +class ParamConfig: + """Per-parameter configuration for NorMuonAndAdam optimizer.""" + label: str + optim: str # "adam" or "normuon" + comms: str # "none", "replicated", or "sharded" + adam_betas: tuple[float, float] | None + lr_mul: float + wd_mul: float + lr: float + initial_lr: float + weight_decay: float + # Adam-specific + eps: float | None = None + # NorMuon-specific + reshape: tuple | None = None + chunk_size: int | None = None + momentum: float | None = None + beta2: float | None = None + per_matrix_lr_mul: list[float] | None = None + + +class NorMuonAndAdam: + """ + Combined optimizer that handles both NorMuon (for projection 
matrices) and + Adam (for embeddings/scalars/gate weights). + + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, Muon uses a Newton-Schulz iteration (replaced + here with Polar Express), which has the advantage that it can be stably run in bfloat16 on the GPU. + + Muon is applied only to the projection matrices in the attention and MLP layers, and is not recommended + for embeddings, scalars, or individual weight vectors (e.g., bias terms or gate weights). + + Differences from standard Muon: + - Newton-Shulz is replaced with Polar Express for the orthogonalization step + - NorMuon adds a low-rank variance estimator similar to Adafactor. https://arxiv.org/pdf/2510.05491 + - Cautious weight decay, a gated version of decoupled weight decay + - Mantissa tracking for precision + + Adam (for embeddings/scalars/gates): + - Standard Adam with bias correction + - Cautious weight decay + + Configuration: + Unlike torch.optim.Optimizer, this class uses per-parameter configs from a `param_table` dict + and does not include parameter "groups". All parameters require a .label attribute, and a + corresponding entry in the param_table to specify their hyperparameters (lr_mul, wd_mul, adam_betas, etc.). + + Communication and ordering: + Gradient communication is explicitly scheduled rather than hook-driven. + Reductions are launched in `scatter_order`, while update math and final + gathers are executed in `work_order`. These orders are independent and + must each contain every parameter label exactly once. + + Two communication modes are supported per parameter: + - 'replicated': Gradients are all-reduced and each rank computes the full update. 
+ - 'sharded': Gradients are reduce-scattered, each rank updates its shard, + and results are all-gathered. + + Adam parameters may be freely sharded. NorMuon operates on full matrices; sharding is + supported by grouping matrices into parameter banks. NorMuon parameters must have a + `.reshape` attribute that reshapes the bank so that the leading dimension is divisible + by world_size. + + # Contributors include @YouJiacheng, @KonstantinWilleke, @alexrgilbert, @adricarda, + # @tuttyfrutyee, @vdlad, @ryanyang0, @vagrawal, @varunneal, @chrisjmccormick + """ + def __init__(self, named_params, param_table: dict, scatter_order: list, work_order: list, + adam_defaults: dict, normuon_defaults: dict): + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + + # Store defaults for each optimizer type + self.adam_defaults = adam_defaults + self.normuon_defaults = normuon_defaults + self.param_table = param_table + self.scatter_order = scatter_order + self.work_order = work_order + + # Collect params by label and build config + self.param_cfgs: dict[nn.Parameter, ParamConfig] = {} + self.param_states: dict[nn.Parameter, dict] = {} + self._param_by_label: dict[str, nn.Parameter] = {} + for name, param in named_params: + label = getattr(param, "label", None) + assert label is not None and label in param_table # all params must have valid label + assert label not in self._param_by_label # exactly one param per label + self._param_by_label[label] = param + self._build_param_cfg(param, label) + + # Assert scatter_order and work_order match present labels exactly + present = set(self._param_by_label.keys()) + assert set(scatter_order) == present and set(work_order) == present + + # Handle world_size=1: overwrite comms to "none" + if self.world_size == 1: + for p_cfg in self.param_cfgs.values(): + p_cfg.comms = "none" + + # Initialize state for all params + self._init_state() + + # 0-D CPU tensors to avoid recompilation + self._step_size_t = torch.tensor(0.0, 
dtype=torch.float32, device="cpu") + self._eff_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._eff_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + + # Track async operations + self._reduce_futures: dict[nn.Parameter, tuple] = {} + + # Embed/lm_head tying state + self.split_embed = False + self._lm_head_param = self._param_by_label.get("lm_head") + self._embed_param = self._param_by_label.get("embed") + + def _build_param_cfg(self, param: nn.Parameter, label: str): + """Build config for a single parameter from param_table.""" + table_entry = self.param_table[label] + optim = table_entry["optim"] + comms = table_entry["comms"] + adam_betas = table_entry.get("adam_betas") + lr_mul = table_entry.get("lr_mul", 1.0) + wd_mul = table_entry.get("wd_mul", 1.0) + + if optim == "adam": + chunk_size = param.shape[0] // self.world_size if comms == "sharded" else None + p_cfg = ParamConfig( + label=label, + optim=optim, + comms=comms, + adam_betas=tuple(adam_betas) if adam_betas else None, + lr_mul=lr_mul, + wd_mul=wd_mul, + lr=self.adam_defaults["lr"], + initial_lr=self.adam_defaults["lr"], + weight_decay=self.adam_defaults["weight_decay"], + eps=self.adam_defaults["eps"], + chunk_size=chunk_size, + ) + elif optim == "normuon": + reshape = getattr(param, "reshape", None) + if reshape is None: + raise ValueError(f"NorMuon param {label} must have .reshape attribute") + if reshape[0] % self.world_size != 0: + raise ValueError(f"reshape[0]={reshape[0]} must be divisible by world_size") + + chunk_size = reshape[0] // self.world_size + chunk_shape = (chunk_size, *reshape[1:]) + # Shape-based LR multiplier for NorMuon + shape_mult = max(1.0, chunk_shape[-2] / chunk_shape[-1]) ** 0.5 if len(chunk_shape) >= 2 else 1.0 + lr_mul = shape_mult * lr_mul + + # Per-matrix LR multipliers for MLP c_proj (2x LR on odd indices) + per_matrix_lr_mul = None + if label == "mlp": + rank = dist.get_rank() if dist.is_initialized() else 0 + start_idx = rank * chunk_size + 
per_matrix_lr_mul = [] + for i in range(chunk_size): + global_idx = start_idx + i + is_c_proj = (global_idx % 2 == 1) + per_matrix_lr_mul.append(2.0 if is_c_proj else 1.0) + + p_cfg = ParamConfig( + label=label, + optim=optim, + comms=comms, + adam_betas=tuple(adam_betas) if adam_betas else None, + lr_mul=lr_mul, + wd_mul=wd_mul, + lr=self.normuon_defaults["lr"], + initial_lr=self.normuon_defaults["lr"], + weight_decay=self.normuon_defaults["weight_decay"], + reshape=reshape, + chunk_size=chunk_size, + momentum=self.normuon_defaults["momentum"], + beta2=self.normuon_defaults["beta2"], + per_matrix_lr_mul=per_matrix_lr_mul, + ) + else: + raise ValueError(f"Unknown optim type: {optim}") + + self.param_cfgs[param] = p_cfg + + def _init_state(self): + """Initialize optimizer state for all parameters.""" + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "adam": + # Sharded params use chunk state, replicated use full state + if p_cfg.comms == "sharded": + chunk = param[:p_cfg.chunk_size] + else: + chunk = param + exp_avg = torch.zeros_like(chunk, dtype=torch.float32, device=param.device) + self.param_states[param] = dict(step=0, exp_avg=exp_avg, exp_avg_sq=torch.zeros_like(exp_avg)) + + elif p_cfg.optim == "normuon": + chunk_shape = (p_cfg.chunk_size, *p_cfg.reshape[1:]) + + # Momentum buffer (FP32 for precision) + momentum_buffer = torch.zeros( + chunk_shape, dtype=torch.float32, device=param.device + ) + + # Second momentum buffer - reduced along one dimension + if chunk_shape[-2] >= chunk_shape[-1]: + second_mom_shape = (*chunk_shape[:-1], 1) + else: + second_mom_shape = (*chunk_shape[:-2], 1, chunk_shape[-1]) + second_momentum_buffer = torch.zeros( + second_mom_shape, dtype=torch.float32, device=param.device + ) + + # Mantissa buffer for precision tracking + mantissa = torch.zeros( + chunk_shape, dtype=torch.uint16, device=param.device + ) + + self.param_states[param] = dict( + momentum_buffer=momentum_buffer, + 
second_momentum_buffer=second_momentum_buffer, + mantissa=mantissa, + ) + + # ----------------------------------- + # Reduce/Gather operations + + def _launch_reduce(self, param: nn.Parameter, grad: Tensor): + """Launch async reduce for a parameter based on its comms policy.""" + p_cfg = self.param_cfgs[param] + + if p_cfg.comms == "none": + if p_cfg.optim == "normuon": + # NorMuon needs reshaped gradient even without communication + grad = grad.view(p_cfg.reshape) + self._reduce_futures[param] = (None, grad) + elif p_cfg.comms == "replicated": + future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future() + self._reduce_futures[param] = (future, grad) + elif p_cfg.comms == "sharded": + if p_cfg.optim == "normuon": + # NorMuon: reshape before reduce_scatter + grad_reshaped = grad.view(p_cfg.reshape) + grad_chunk = torch.empty( + (p_cfg.chunk_size, *grad_reshaped.shape[1:]), + dtype=grad.dtype, + device=grad.device + ) + future = dist.reduce_scatter_tensor( + grad_chunk, grad_reshaped.contiguous(), op=dist.ReduceOp.AVG, async_op=True + ).get_future() + self._reduce_futures[param] = (future, grad_chunk) + else: + # Adam: simple reduce_scatter + grad_chunk = torch.empty_like(grad[:p_cfg.chunk_size]) + future = dist.reduce_scatter_tensor( + grad_chunk, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + self._reduce_futures[param] = (future, grad_chunk) + + def _launch_gather(self, param: nn.Parameter, p_slice: Tensor) -> "torch.futures.Future": + """Launch async all_gather for a sharded parameter.""" + p_cfg = self.param_cfgs[param] + if p_cfg.optim == "normuon": + full_param = param.data.view(p_cfg.reshape) + assert full_param.is_contiguous() + return dist.all_gather_into_tensor( + full_param, p_slice.contiguous(), async_op=True + ).get_future() + else: + return dist.all_gather_into_tensor( + param, p_slice.contiguous(), async_op=True + ).get_future() + + # ----------------------------------- + # State management + + def reset(self): + 
"""Reset NorMuon momentum buffers and split_embed state (called on training reset).""" + self.split_embed = False + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "normuon": + p_state = self.param_states[param] + p_state["momentum_buffer"].zero_() + p_state["mantissa"].zero_() + p_state["second_momentum_buffer"].zero_() + + def copy_lm_state_to_embed(self): + """ + Copy the optimizer state from the lm_head to the embed at the untie point. + This requires an all-gather + reshard because of different sharding: + - lm_head (768, 50304) is sharded to (96, 50304) per rank (along model_dim) + - embed (50304, 768) is sharded to (6288, 768) per rank (along vocab_size) + + We all-gather the lm_head momentum, transpose it, then each rank takes their + embed shard to get the correct momentum state. + """ + lm_head = self._lm_head_param + embed = self._embed_param + lm_state = self.param_states[lm_head] + embed_state = self.param_states[embed] + lm_cfg = self.param_cfgs[lm_head] + embed_cfg = self.param_cfgs[embed] + + embed_state['step'] = lm_state['step'] # Preserve step count for bias correction + + # Copy optimizer state with all-gather + transpose + reshard + if self.world_size > 1: + rank = dist.get_rank() + lm_chunk_size = lm_cfg.chunk_size # 96 + embed_chunk_size = embed_cfg.chunk_size # 6288 + + # All-gather lm_head momentum to get full (768, 50304) tensor + for key in ["exp_avg", "exp_avg_sq"]: + lm_chunk = lm_state[key] # (96, 50304) + full_lm = torch.empty(lm_head.shape[0], lm_head.shape[1], dtype=lm_chunk.dtype, device=lm_chunk.device) + dist.all_gather_into_tensor(full_lm, lm_chunk.contiguous()) + embed_state[key].copy_(full_lm.T[rank * embed_chunk_size:(rank + 1) * embed_chunk_size]) + else: + # Single GPU: simple transpose + for key in ["exp_avg", "exp_avg_sq"]: + embed_state[key].copy_(lm_state[key].T) + + # Mark as split + self.split_embed = True + + def state_dict(self): + """Return the optimizer state as a dict.""" + return { + 
"param_states": {id(p): s for p, s in self.param_states.items()}, + "param_cfgs": {id(p): s for p, s in self.param_cfgs.items()}, + } + + def load_state_dict(self, state_dict): + """Load optimizer state from a dict.""" + # Build id->param mapping + id_to_param = {id(p): p for p in self.param_cfgs.keys()} + + # Load state, preserving dtypes + for param_id, saved_p_state in state_dict["param_states"].items(): + if param_id in id_to_param: + param = id_to_param[param_id] + p_state = self.param_states[param] + for k, v in saved_p_state.items(): + if isinstance(v, torch.Tensor) and k in p_state: + target_dtype = p_state[k].dtype + p_state[k] = v.to(dtype=target_dtype, device=p_state[k].device) + else: + p_state[k] = v + + # ----------------------------------- + # Unified optimizer step with explicit ordering + + @torch.no_grad() + def step(self, do_adam: bool = True): + """ + Combined optimizer step with explicit ordering. + + Args: + do_adam: If True, update Adam params. NorMuon params always updated. + + Flow: + 1. Scatter phase: Launch reduces in scatter_order + 2. Work phase: Process updates in work_order + - Wait for reduce, compute update, launch gather + 3. Finalize phase: Wait for gathers + + While the embeddings are tied: + - Comms and update math are only done on lm_head. + - We add embed.grad.T into lm_head.grad before comms. 
+ - After lm_head gather, we copy lm_head.data.T --> embed.data + """ + rank = dist.get_rank() if dist.is_initialized() else 0 + lm_param, embed_param = self._lm_head_param, self._embed_param + + # ===== Phase 1: Launch reduces in scatter_order ===== + for label in self.scatter_order: + param = self._param_by_label[label] + p_cfg = self.param_cfgs[param] + + if p_cfg.optim == "adam" and not do_adam: + continue + if param.grad is None: + continue + + # lm_head when tied: aggregate embed.grad.T (transposed shapes) + if label == "lm_head" and do_adam and not self.split_embed: + if embed_param is not None and embed_param.grad is not None: + param.grad.add_(embed_param.grad.T) + + # Skip embed when tied (copied from lm_head after gather) + if label == "embed" and not self.split_embed: + continue + + self._launch_reduce(param, param.grad) + + # ===== Phase 2: Process updates in work_order ===== + gather_futures = [] + lm_head_gather_future = None + + for label in self.work_order: + param = self._param_by_label[label] + if param not in self._reduce_futures: + continue + + p_cfg = self.param_cfgs[param] + if p_cfg.optim == "adam" and not do_adam: + continue + # Wait for reduce + future, grad_chunk = self._reduce_futures[param] + if future is not None: + future.wait() + # Apply update based on optim type + if p_cfg.optim == "adam": + p_slice = self._adam_update(param, grad_chunk, p_cfg, rank) + else: + p_slice = self._normuon_update(param, grad_chunk, p_cfg, rank) + # Launch gather for sharded params + if p_cfg.comms == "sharded" and self.world_size > 1: + gather_fut = self._launch_gather(param, p_slice) + if label == "lm_head": + lm_head_gather_future = gather_fut + else: + gather_futures.append(gather_fut) + + # ===== Phase 3: Wait for gathers, sync embed if tied ===== + # Wait for lm_head gather first so we can copy to embed while other gathers complete + if lm_head_gather_future is not None: + lm_head_gather_future.wait() + + # When tied: copy lm_head.T to embed + if 
do_adam and not self.split_embed and embed_param is not None and lm_param is not None: + embed_param.data.copy_(lm_param.data.T) + + # Wait for remaining gathers + for fut in gather_futures: + fut.wait() + + self._reduce_futures.clear() + + # Clear grads for updated params + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "adam" and not do_adam: + continue # Don't clear Adam grads on even steps + param.grad = None + + # ----------------------------------- + # Adam update + + def _adam_update(self, param: nn.Parameter, grad_chunk: Tensor, p_cfg: ParamConfig, rank: int) -> Tensor: + """Apply Adam update to a parameter. Returns the updated p_slice.""" + beta1, beta2 = p_cfg.adam_betas + lr = p_cfg.lr * p_cfg.lr_mul + + # Get parameter slice + if p_cfg.comms == "sharded": + p_slice = param[rank * p_cfg.chunk_size:(rank + 1) * p_cfg.chunk_size] + else: + p_slice = param + + p_state = self.param_states[param] + p_state["step"] += 1 + t = p_state["step"] + + bias1, bias2 = 1 - beta1 ** t, 1 - beta2 ** t + self._step_size_t.fill_(lr * (bias2 ** 0.5 / bias1)) + self._eff_wd_t.fill_(lr * lr * p_cfg.weight_decay * p_cfg.wd_mul) + + NorMuonAndAdam._adam_update_step( + p_slice, grad_chunk, p_state["exp_avg"], p_state["exp_avg_sq"], + beta1, beta2, p_cfg.eps, self._step_size_t, self._eff_wd_t + ) + + return p_slice + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _adam_update_step(p_slice, g_slice, exp_avg, exp_avg_sq, beta1, beta2, eps, step_size_t, eff_wd_t): + """Compiled Adam update step.""" + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + update = exp_avg.div(exp_avg_sq.sqrt().add_(eps)).mul_(step_size_t) + # Cautious weight decay + mask = (update * p_slice) > 0 + update.addcmul_(p_slice, mask, value=eff_wd_t) + p_slice.add_(other=update, alpha=-1.0) + + # ----------------------------------- + # NorMuon update + + def _normuon_update(self, param: nn.Parameter, 
grad_chunk: Tensor, p_cfg: ParamConfig, rank: int) -> Tensor: + """Apply NorMuon update to a parameter. Returns the updated p_slice.""" + chunk_shape = grad_chunk.shape + + p_state = self.param_states[param] + grad_chunk = grad_chunk.float() # FP32 for momentum + + # Momentum update + momentum_buffer = p_state["momentum_buffer"] + momentum_buffer.lerp_(grad_chunk, 1 - p_cfg.momentum) + updated_grads = grad_chunk.lerp_(momentum_buffer, p_cfg.momentum) + + self._eff_lr_t.fill_(p_cfg.lr_mul * p_cfg.lr) + self._eff_wd_t.fill_(p_cfg.wd_mul * p_cfg.weight_decay * p_cfg.lr) + + # Polar Express orthogonalization + is_large_matrix = chunk_shape[-2] > 1024 + v_chunk = polar_express(updated_grads, split_baddbmm=is_large_matrix) + + # Variance reduction + red_dim = -1 if chunk_shape[-2] >= chunk_shape[-1] else -2 + v_chunk = NorMuonAndAdam._apply_normuon_variance_reduction( + v_chunk, p_state["second_momentum_buffer"], p_cfg.beta2, red_dim + ) + + # Update parameter, in place, with cautious weight decay + param_view = param.data.view(p_cfg.reshape) + p_slice = param_view[rank * p_cfg.chunk_size:(rank + 1) * p_cfg.chunk_size] + + # MLP has per-matrix LR multipliers (c_proj gets 2x LR) + if p_cfg.per_matrix_lr_mul is not None: + for mat_idx in range(p_cfg.chunk_size): + self._eff_lr_t.fill_(p_cfg.lr_mul * p_cfg.per_matrix_lr_mul[mat_idx] * p_cfg.lr) + self._eff_wd_t.fill_(p_cfg.wd_mul * p_cfg.weight_decay * p_cfg.lr) + NorMuonAndAdam._cautious_wd_and_update_inplace( + p_slice[mat_idx].view(torch.uint16), p_state["mantissa"][mat_idx], v_chunk[mat_idx], + self._eff_wd_t, self._eff_lr_t + ) + else: + NorMuonAndAdam._cautious_wd_and_update_inplace( + p_slice.view(torch.uint16), p_state["mantissa"], v_chunk, + self._eff_wd_t, self._eff_lr_t + ) + + return p_slice + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _cautious_wd_and_update_inplace(p, mantissa, grad, wd_tensor, lr_tensor): + """ + Cautious weight decay + parameter update. 
wd_tensor and lr_tensor are 0-D CPU tensors. + Mantissa is tracked to enable higher precision updates on bfloat16 parameters. + bfloat16 format: 1 sign bit + 8 exponent bits + 7 mantissa bits = 16 bits total + float32 format: 1 sign bit + 8 exponent bits + 23 mantissa bits = 32 bits total + """ + assert p.dtype == mantissa.dtype == torch.uint16 + grad = grad.float() + wd_factor = wd_tensor.to(torch.float32) + lr_factor = lr_tensor.to(torch.float32) + p_precise_raw = (p.to(torch.uint32) << 16) | mantissa.to(torch.uint32) + p_precise = p_precise_raw.view(torch.float32) + mask = (grad * p_precise) >= 0 + p_precise.copy_(p_precise - (p_precise * mask * wd_factor * lr_factor) - (grad * lr_factor)) + p.copy_((p_precise_raw >> 16).to(torch.uint16)) + mantissa.copy_(p_precise_raw.to(torch.uint16)) + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _apply_normuon_variance_reduction(v_chunk, second_momentum_buffer, beta2, red_dim): + """NorMuon variance reduction. Algebraically fuses the normalization steps to minimize memory ops.""" + v_mean = v_chunk.float().square().mean(dim=red_dim, keepdim=True) + red_dim_size = v_chunk.size(red_dim) + v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True).mul_(red_dim_size) + v_norm = v_norm_sq.sqrt_() + second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2) + step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt_() + scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square() + v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt_() + final_scale = step_size * (v_norm / v_norm_new.clamp_min_(1e-10)) + return v_chunk.mul_(final_scale.type_as(v_chunk)) + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinearT(nn.Module): + """ + Linear layer with transposed weight storage (in_features, out_features) which + 
addresses the slow kernel that was used for gradient accumulation. @chrisjmccormick + """ + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + self.weight = nn.Parameter(torch.empty(in_features, out_features, dtype=torch.bfloat16)) + self.reset_parameters() + + def reset_parameters(self) -> None: + with torch.no_grad(): + nn.init.zeros_(self.weight) # @Grad62304977 and others + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out = torch.ops.nanogpt.mm_t(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return x @ self.weight.type_as(x) + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len, paired=False): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.paired = paired + self.reset() + + def rotary(self, x_BTHD): + assert self.factor1.size(0) >= x_BTHD.size(-3) + factor1, factor2 = ( + self.factor1[None, : x_BTHD.size(-3), None, :], + self.factor2[None, : x_BTHD.size(-3), None, :], + ) + x_flip = x_BTHD.view(*x_BTHD.shape[:-1], x_BTHD.shape[-1] // 2, 2).flip(-1).view(x_BTHD.shape) + return factor1 * x_BTHD + factor2 * x_flip + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + angular_freq = angular_freq.repeat_interleave(2) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//2)]) + t = torch.arange(2*self.max_seq_len, dtype=torch.float32, device=device) + if not self.paired: + theta 
= torch.outer(t, angular_freq) + self.factor1 = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.factor2 = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + else: + t_even = 2 * t + t_odd = 2 * t + 1 + theta1 = torch.outer(t_even, angular_freq) + theta2 = torch.outer(t_odd, angular_freq) + self.factor1 = nn.Buffer( + torch.cat((theta1.cos(), theta2.cos()), dim=-1).to(torch.bfloat16), + persistent=False + ) + self.factor2 = nn.Buffer( + torch.cat((theta1.sin(), theta2.sin()), dim=-1).to(torch.bfloat16), + persistent=False + ) + self.factor2[..., 1::2] *= -1 + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(2*self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + if not self.paired: + theta = torch.outer(t, self.angular_freq) + self.factor1.copy_(theta.cos()) + self.factor2.copy_(theta.sin()) + else: + t_even = 2 * t + t_odd = 2 * t + 1 + theta1 = torch.outer(t_even, self.angular_freq) + theta2 = torch.outer(t_odd, self.angular_freq) + self.factor1.copy_(torch.cat((theta1.cos(), theta2.cos()), dim=-1)) + self.factor2.copy_(torch.cat((theta1.sin(), theta2.sin()), dim=-1)) + self.factor2[..., 1::2] *= -1 + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + yarn: Yarn + key_offset: bool + attn_gate_w: torch.Tensor + ve_gate_w: torch.Tensor + +flash_attn_interface = 
get_kernel('varunneal/flash-attention-3').flash_attn_interface

class CausalSelfAttention(nn.Module):
    """Varlen causal self-attention whose weights live in a shared parameter bank.

    With paired=True, adjacent heads are coupled: the sequence is interleaved
    to double its length so each head's queries can also attend to its
    partner head's keys.
    """
    def __init__(self, dim: int, head_dim: int, num_heads: int, paired: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dim = dim
        self.hdim = num_heads * head_dim
        self.paired = paired
        assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim"
        # Weights are stored in parameter banks and passed via forward()

    def forward(self, x: Tensor, attn_args: AttnArgs, qkvo_w: Tensor):
        B, T = x.size(0), x.size(1) # batch size, sequence length
        assert B == 1, "varlen sequences requires B == 1"
        assert T % 16 == 0
        # unpack attention args
        yarn = attn_args.yarn
        ve, sa_lambdas, key_offset = attn_args.ve, attn_args.sa_lambdas, attn_args.key_offset
        seqlens, bm_size = attn_args.seqlens, attn_args.bm_size
        # sparse gated attention to enable context based no-op by @classiclarryd
        # only include gates on layers with value embeds used on forward pass
        attn_gate_w, ve_gate_w = attn_args.attn_gate_w, attn_args.ve_gate_w

        # merged QKV projection; sa_lambdas[0] is folded into the weight matrix
        q, k, v = F.linear(x, sa_lambdas[0] * qkvo_w[:self.dim * 3].type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
        max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size))

        q, k = norm(q), norm(k) # QK norm @Grad62304977

        if not self.paired:
            q, k = yarn.rotary(q), yarn.rotary(k)

            if key_offset:
                # shift keys forward for the stationary head dims. Enables 1-layer induction.
                k[:, 1:, :, self.head_dim // 2:] = k[:, :-1, :, self.head_dim // 2:]

            if ve is not None:
                # gate in [0, 2) mixes the token value embedding into v; the gate
                # reads only the first 12 channels of x
                ve_gate_out = 2 * torch.sigmoid(F.linear(x[..., :12], ve_gate_w)).view(B, T, self.num_heads, 1)
                v = v + ve_gate_out * ve.view_as(v) # @ KoszarskyB & @Grad62304977

        else:
            # Paired heads: adjacent heads' queries attend to each other's keys.
            # Two copies of the input stream are interleaved to achieve this, which:
            # - doubles the length of each sequence
            # - halves the effective window size
            q = q.view(B, T, self.num_heads // 2, self.head_dim * 2)
            k = k.view(B, T, self.num_heads // 2, self.head_dim * 2)
            v = v.reshape(B, T * 2, self.num_heads // 2, self.head_dim)

            q, k = yarn.rotary(q), yarn.rotary(k)

            q = q.view(B, T * 2, self.num_heads // 2, self.head_dim)
            k = k.view(B, T * 2, self.num_heads // 2, self.head_dim)

            if ve is not None:
                ve_gate_out = 2 * torch.sigmoid(F.linear(x[..., :12], ve_gate_w)).view(B, T * 2, self.num_heads // 2, 1)
                v = v + ve_gate_out * ve.view_as(v)

            # interleaving doubled the token count, so sequence boundaries and
            # the max length double as well
            seqlens = 2 * seqlens
            max_len = 2 * max_len

        # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng
        y = flash_attn_interface.flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens,
                                                        max_seqlen_q=max_len, max_seqlen_k=max_len,
                                                        causal=True, softmax_scale=yarn.attn_scale, window_size=(bm_size, 0))
        y = y.view(B, T, self.num_heads, self.head_dim)
        # per-head output gate in (0, 1): enables a context-dependent no-op
        y = y * torch.sigmoid(F.linear(x[..., :12], attn_gate_w)).view(B, T, self.num_heads, 1)
        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
        y = F.linear(y, sa_lambdas[1] * qkvo_w[self.dim * 3:].type_as(y)) # sa_lambdas[1] pre-multiplied to O @shenberg
        return y

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        # Weights are stored in parameter banks and passed via forward()

    def forward(self, x: Tensor, c_fc: Tensor, c_proj: Tensor):
        # relu(x)^2:
        # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
        # Fused triton kernel for relu(x @ W1.T)^2 @ W2.T
        return FusedLinearReLUSquareFunction.apply(x, c_fc, c_proj)

class Block(nn.Module):
    """One transformer layer: optional attention and optional MLP, each behind
    a pre-norm residual connection."""
    def __init__(self, dim: int, head_dim: int, num_heads: int, has_attn: bool, has_mlp: bool, use_paired_head: bool):
        super().__init__()
        # skip attention of blocks.6 (the 7th layer) by @YouJiacheng
        self.attn = CausalSelfAttention(dim, head_dim, num_heads, paired=use_paired_head) if has_attn else None
        # skip MLP blocks for first MLP layer by @EmelyanenkoK
        self.mlp = MLP() if has_mlp else None

    def forward(self, x: Tensor, attn_args: AttnArgs, qkvo_w: Tensor = None, c_fc: Tensor = None, c_proj: Tensor = None):
        if self.attn is not None:
            x = x + self.attn(norm(x), attn_args, qkvo_w)
        if self.mlp is not None:
            x = x + self.mlp(norm(x), c_fc, c_proj)
        return x

# -----------------------------------------------------------------------------
# The main model

def next_multiple_of_n(v: float | int, *, n: int):
    # smallest multiple of n that is >= v (note: never less than n itself)
    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)

@dataclass
class ForwardScheduleConfig:
    """Schedule-dependent inputs to GPT.forward for the current step."""
    mtp_weights: torch.Tensor  # multi-token-prediction loss weights
    ws_short: int              # short attention window, in tokens
    ws_long: int               # long attention window, in tokens

class GPT(nn.Module):
    """GPT-style decoder with banked attention/MLP weights, value embeddings,
    bigram-hash embeddings and a schedule-driven sliding-window attention."""
    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int):
        super().__init__()
        self.num_layers = num_layers
        self.vocab_size = next_multiple_of_n(vocab_size, n=128)

        self.smear_gate = nn.Linear(12, 1, bias=False)
        nn.init.zeros_(self.smear_gate.weight)
        self.smear_gate.weight.label = 'smear_gate'

        self.skip_gate = nn.Linear(12, 1, bias=False)
        nn.init.zeros_(self.skip_gate.weight)
        self.skip_gate.weight.label = 'skip_gate'

        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
        self.value_embeds = nn.Parameter(torch.zeros(5 * self.vocab_size, model_dim, dtype=torch.bfloat16))
        self.value_embeds.label = 'value_embed'

        # parameter banks for attention and value embedding gate weights
        self.attn_gate_bank = nn.Parameter(torch.zeros(10, num_heads, 12)) # 10 layers
        self.attn_gate_bank.label = 'attn_gate_bank'
        self.ve_gate_bank = nn.Parameter(torch.zeros(5, num_heads, 12)) # 5 unique gates
        self.ve_gate_bank.label = 've_gate_bank'

        # -----------------------------------
        # Parameter banks for sharded optimization, by @chrisjmccormick

        # Identify which layers have attention/MLP
        # Attention is skipped in layer 6 by @YouJiacheng
        self.attn_layer_indices = [i for i in range(num_layers) if i != 6]
        # All layers have MLP (At 11 layers--dropped first layer @EmelyanenkoK)
        self.mlp_layer_indices = list(range(num_layers))

        hdim = num_heads * head_dim
        mlp_hdim = 4 * model_dim

        # Create index mappings: layer_idx -> bank_idx
        self.layer_to_attn_idx = {layer_idx: bank_idx for bank_idx, layer_idx in enumerate(self.attn_layer_indices)}
        self.layer_to_mlp_idx = {layer_idx: bank_idx for bank_idx, layer_idx in enumerate(self.mlp_layer_indices)}

        # Attention bank: stores QKVO weights for all attention layers
        # merged QKVO weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
        # https://x.com/hi_tysam/status/1879699187107033311
        # Simplified layout by @chrisjmccormick
        # Shape: (num_attn_layers, 4*model_dim, hdim) = (10, 3072, 768)
        # Reshape for sharding: (40, 768, 768) for even distribution across 8 GPUs
        self.attn_bank = nn.Parameter(torch.empty(len(self.attn_layer_indices), 4 * model_dim, hdim))
        self.attn_bank.label = 'attn'
        self.attn_bank.reshape = (len(self.attn_layer_indices) * 4, hdim, hdim) # (40, 768, 768)

        # MLP bank: stores c_fc and c_proj for all MLP layers
        # Shape: (num_mlp_layers + padding, 2, mlp_hdim, model_dim) = (12, 2, 3072, 768)
        # We add 1 padding layer (index 11) to get 12*2=24 matrices for even distribution across 8 GPUs
        # Reshape for sharding: (24, 3072, 768)
        num_mlp_with_padding = len(self.mlp_layer_indices) + 1 # 11 + 1 = 12
        self.mlp_bank = nn.Parameter(torch.empty(num_mlp_with_padding, 2, mlp_hdim, model_dim))
        self.mlp_bank.label = 'mlp'
        self.mlp_bank.reshape = (num_mlp_with_padding * 2, mlp_hdim, model_dim) # (24, 3072, 768)

        # improved init scale by @YouJiacheng and @srashedll
        std = 0.5 * model_dim ** -0.5
        bound = (3 ** 0.5) * std
        with torch.no_grad():
            self.attn_bank.uniform_(-bound, bound)
            self.mlp_bank[:, 0, :, :].uniform_(-bound, bound) # c_fc
            self.mlp_bank[:, 1, :, :].zero_() # c_proj - zero init suggested by @Grad62304977

        # Create blocks with has_attn/has_mlp flags
        self.paired_head_layers = [0, 2, 5, 9]
        self.blocks = nn.ModuleList([
            Block(model_dim, head_dim, num_heads,
                  has_attn=(i in self.layer_to_attn_idx),
                  has_mlp=(i in self.layer_to_mlp_idx),
                  use_paired_head=(i in self.paired_head_layers))
            for i in range(num_layers)
        ])
        self.yarn = Yarn(head_dim, max_seq_len)
        self.yarn_paired_head = Yarn(head_dim, max_seq_len, paired=True)
        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
        use_fp8 = not os.environ.get("DISABLE_FP8", False)
        # Transposed weight storage for faster gradient accumulation
        self.lm_head = CastedLinearT(model_dim, self.vocab_size, use_fp8=use_fp8, x_s=100/448, w_s=1.6/448, grad_s=grad_scale * 0.75/448)

        nn.init.normal_(self.lm_head.weight, mean=0, std=0.005)
        self.lm_head.weight.label = 'lm_head'

        self.embed = nn.Embedding(self.vocab_size, model_dim)
        self.embed.weight.label = 'embed'
        with torch.no_grad():
            # embed starts tied to lm_head (transposed); the optimizer keeps them
            # in sync until the split step
            self.embed.weight.copy_(self.lm_head.weight.T)

        self.bigram_embed = nn.Embedding(args.bigram_vocab_size, model_dim)
        self.bigram_embed.weight.label = 'bigram_embed'
        nn.init.zeros_(self.bigram_embed.weight)

        # x0_lambdas separated out for different optimizer treatment (no beta smoothing)
        self.x0_lambdas = nn.Parameter(torch.zeros(num_layers))
        self.x0_lambdas.label = 'x0_lambdas'

        # pad so the scalar vector length divides evenly across ranks
        pad = (-num_layers * 3 - 3) % dist.get_world_size() # updated: 3*num_layers instead of 4*
        self.scalars = nn.Parameter(
            torch.cat(
                [
                    1.1 * torch.ones(num_layers), # resid lambdas. 1.1 init such that layer i weight is i^(num_layers-i).
                    *[torch.tensor([0.5, 1.0]) for _ in range(num_layers)], # SA lambdas
                    0.1 * torch.ones(num_layers), # bigram lambdas
                    torch.zeros(1), # smear_lambda
                    0.5*torch.ones(1), # backout_lambda
                    -1.5 * torch.ones(1), # skip_lambda -> σ(-1.5) ≈ 0.18
                    torch.ones(pad),
                ]
            )
        )
        self.scalars.label = 'scalars'

    @staticmethod
    @torch.compile(dynamic=False, fullgraph=True)
    def _compute_bigram_hash(x: Tensor, mod: int) -> Tensor:
        """
        Computes bigram hash on GPU for each position using [prev_token, curr_token].
        Mathematically identical to the CPU version but computed on device.
        Position 0 has no predecessor and gets the sentinel value `mod`
        (hashes of later positions are strictly less than `mod`).
        """
        rand_int_1 = 36313
        rand_int_2 = 27191
        result = torch.empty_like(x)
        result[0] = mod
        result[1:] = torch.bitwise_xor(rand_int_1 * x[1:], rand_int_2 * x[:-1]) % mod
        return result

    def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, schedule_cfg: ForwardScheduleConfig):
        """Run the model and return the (training or validation) loss.

        input_seq/target_seq are flat 1-D token tensors; seqlens holds the
        cumulative document boundaries for varlen attention.
        """
        assert input_seq.ndim == 1

        # unpack schedule_cfg
        mtp_weights, ws_short, ws_long = schedule_cfg.mtp_weights, schedule_cfg.ws_short, schedule_cfg.ws_long

        # set configs
        skip_connections = []
        skip_in = [3] # long attention window on layer 3
        skip_out = [6] # no attn op on layer 6
        x_backout = None
        backout_layer = 7

        # set lambdas (slices of the packed learnable scalar vector)
        resid_lambdas = self.scalars[: 1 * self.num_layers]
        x0_lambdas = self.x0_lambdas
        sa_lambdas = self.scalars[1 * self.num_layers: 3 * self.num_layers].view(-1, 2)
        bigram_lambdas = self.scalars[3 * self.num_layers: 4 * self.num_layers]
        smear_lambda = self.scalars[4 * self.num_layers]
        backout_lambda = self.scalars[4 * self.num_layers+1]
        skip_lambda = self.scalars[4 * self.num_layers+2]

        # set block masks and key shift
        bm_sizes = [ws_short, ws_short, ws_short, ws_long, ws_short, ws_short, None, ws_short, ws_short, ws_short, ws_long]
        assert len(bm_sizes) == self.num_layers
        key_offset = [b==ws_long for b in bm_sizes] # apply partial key offset to long windows

        # Embedding lookup - embed is synced from lm_head during tied phase by optimizer
        x = self.embed(input_seq)
        # Compute bigram hash on GPU (moved from CPU data loader)
        bigram_seq = self._compute_bigram_hash(input_seq, args.bigram_vocab_size - 1)
        x0_bigram = self.bigram_embed(bigram_seq)[None]

        # Value embeddings - always computed (not precomputed)
        ve = self.value_embeds.view(5, self.vocab_size, -1)[:, input_seq]
        # 01 
234 structure on token value embeddings by @photomz
        ve = [ve[0], ve[1]] + [None] * (self.num_layers - 5) + [ve[2], ve[3], ve[4]]
        assert len(ve) == self.num_layers

        # smear token embed forward 1 position @classiclarryd
        smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)]))
        x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]])
        x = x0 = norm(x[None])

        # unbind gate banks to avoid select_backwards kernel
        ag = [w.bfloat16() for w in self.attn_gate_bank.unbind(0)]
        veg = [w.bfloat16() for w in self.ve_gate_bank.unbind(0)]
        attn_gates = ag[:6] + [None] + ag[6:]
        ve_gates = [veg[0], veg[1]] + [None] * (self.num_layers - 5) + [veg[2], veg[3], veg[4]]
        assert len(attn_gates) == self.num_layers
        assert len(ve_gates) == self.num_layers

        # unbind weight banks to avoid select_backwards kernel
        attn_weights = self.attn_bank.unbind(0) # tuple of [4*dim, hdim] tensors
        mlp_fcs = self.mlp_bank[:, 0, :, :].unbind(0) # tuple of [mlp_hdim, dim] tensors
        mlp_projs = self.mlp_bank[:, 1, :, :].unbind(0) # tuple of [mlp_hdim, dim] tensors

        for i in range(self.num_layers):
            yarn = self.yarn_paired_head if i in self.paired_head_layers else self.yarn
            attn_args = AttnArgs(
                ve=ve[i],
                sa_lambdas=sa_lambdas[i],
                seqlens=seqlens,
                bm_size=bm_sizes[i],
                yarn=yarn,
                key_offset=key_offset[i],
                attn_gate_w=attn_gates[i],
                ve_gate_w=ve_gates[i]
            )
            if i in skip_out:
                # gated long skip connection: layer 3 output re-enters at layer 6
                skip_gate_out = torch.sigmoid(skip_lambda) * 2 * torch.sigmoid(self.skip_gate(x0[..., :self.skip_gate.weight.size(-1)]))
                x = x + skip_gate_out * skip_connections.pop()
            if i == 0:
                x = (resid_lambdas[0] + x0_lambdas[0]) * x + bigram_lambdas[0] * x0_bigram
            else:
                # learned mix of residual stream, initial embedding and bigram embedding
                x = resid_lambdas[i] * x + x0_lambdas[i] * x0 + bigram_lambdas[i] * x0_bigram

            # Get weights for this layer from banks
            qkvo_w = attn_weights[self.layer_to_attn_idx[i]] if i in self.layer_to_attn_idx else None
            c_fc = mlp_fcs[self.layer_to_mlp_idx[i]] if i in self.layer_to_mlp_idx else None
            c_proj = mlp_projs[self.layer_to_mlp_idx[i]] if i in self.layer_to_mlp_idx else None

            x = self.blocks[i](x, attn_args, qkvo_w, c_fc, c_proj)
            if i in skip_in:
                skip_connections.append(x)
            if i == backout_layer:
                x_backout = x

        # back out contributions from first 7 layers that are only required for downstream context and not direct prediction
        x -= backout_lambda * x_backout
        x = norm(x)
        logits = self.lm_head(x)
        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15
        # @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1). @classiclarryd updated to 23*sigmoid((logits+5)/7.5)
        if self.training:
            # fused kernel applies the softcap and multi-token-prediction weighting
            losses = FusedSoftcappedCrossEntropy.apply(logits.view(-1, logits.size(-1)), target_seq, mtp_weights, 23.0, 5.0, 7.5)
            loss = losses.sum()
        else:
            logits = 23 * torch.sigmoid((logits + 5) / 7.5)
            logits_for_loss = logits.float()
            loss = F.cross_entropy(logits_for_loss.view(-1, logits_for_loss.size(-1)), target_seq, reduction="mean")
        return loss
# -----------------------------------------------------------------------------
# Distributed data loader

def _load_data_shard(file: Path):
    """Load one .bin token shard into a pinned uint16 CPU tensor.

    Layout: 256 int32 header (magic, version, token count) followed by
    uint16 tokens.
    """
    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
    assert header[1] == 1, "unsupported version"
    num_tokens = int(header[2]) # number of tokens (claimed)
    with file.open("rb", buffering=0) as f:
        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
        f.seek(256 * 4)
        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
    return tokens

BOS_ID = 50256

class Shard:
    """One token shard plus an index of BOS positions, used to cut
    BOS-aligned per-rank batches."""
    def __init__(self, tokens: Tensor, world_size: int = 1):
        self.tokens = tokens
        self.size = tokens.numel()
        self.world_size = 
world_size + self.i = 0 + + # Partial index now, full index async + self.bos_idx = (tokens[:6_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self._full_idx = None + self._loader_thread = None + self._ready = threading.Event() + self._loader_thread = threading.Thread(target=self._scan) + self._loader_thread.start() + + def _scan(self): + self._full_idx = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self._ready.set() + + def _maybe_switch(self): + # Switch to full index as soon as async scan completes + if self.bos_idx is not self._full_idx and self._ready.is_set(): + self._loader_thread.join() + self.bos_idx = self._full_idx + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + self._maybe_switch() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + return starts, ends + + @staticmethod + def load_async(file: Path, world_size: int = 1): + """Returns getter function for async shard loading""" + result = {} + ready = threading.Event() + def load(): + tokens = _load_data_shard(file) + result['shard'] = Shard(tokens, world_size) + ready.set() + thread = threading.Thread(target=load) + thread.start() + def get(): + ready.wait() + thread.join() + return result['shard'] + return get + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each 
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + shard = Shard(tokens, world_size) + next_shard_getter = Shard.load_async(next(file_iter), world_size) + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = shard.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
+ shard = next_shard_getter() + tokens = shard.tokens + try: + next_shard_getter = Shard.load_async(next(file_iter), world_size) + except StopIteration: + next_shard_getter = None # no more shards to preload + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + # Cast to int32 on CPU before transfer to avoid dtype conversion during .to() + _inputs = _inputs.to(dtype=torch.int32) + _targets = _targets.to(dtype=torch.int64) + _cum_lengths = _cum_lengths.to(dtype=torch.int32) + # Bigram hash computation moved to GPU in forward() + + new_params = yield ( + _inputs.to(device="cuda", non_blocking=True), + _targets.to(device="cuda", non_blocking=True), + _cum_lengths.to(device="cuda", non_blocking=True), + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens // new_grad_accum_steps + max_seq_len = new_max_seq_len + +# ----------------------------------------------------------------------------- +# Training Management + +@dataclass 
class Hyperparameters:
    """Static run configuration (data paths, batch sizes, schedule lengths)."""
    # data
    data_path = os.environ.get("DATA_PATH", ".")
    train_files: str = os.path.join(data_path, "data/fineweb10B/fineweb_train_*.bin") # input .bin to train on
    val_files: str = os.path.join(data_path, "data/fineweb10B/fineweb_val_*.bin") # input .bin to eval validation loss on
    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
    # batch sizes
    train_max_seq_len: int = 128 * 16
    val_batch_size: int = 4 * 64 * 1024 * 8
    # schedule
    num_scheduled_iterations: int = 1515 # number of steps to complete lr and ws schedule
    num_extension_iterations: int = 40 # number of steps to continue training at final lr and ws
    # evaluation and logging
    run_id: str = f"{uuid.uuid4()}"
    val_loss_every: int = 250 # every how many steps to evaluate val loss? 0 for only at the end
    save_checkpoint: bool = False
    # bigram hash embedding
    bigram_vocab_size: int = 50304 * 5

args = Hyperparameters()

@dataclass
class TrainingStage:
    """One stage of the training schedule (see TRAINING_STAGES)."""
    lr_mul: float                  # learning-rate multiplier for this stage
    batch_size: int                # global tokens per optimizer step
    window_sizes: tuple[int, int] # (short, long) in block units
    mtp_weights_start: list[float] # MTP loss weights at stage start
    mtp_weights_end: list[float]   # MTP loss weights at stage end (linear interp)
    duration: float | None = None  # fraction of scheduled steps; None for the extension stage

class TrainingSchedule:
    """
    Training schedule initialized via TRAINING_STAGES
    1. Multi Token Prediction schedule of [1, 0.5, 0.25->0] -> [1, 0.5->0] -> [1] @varunneal
    2. Sliding Attention window schedule of [1,3] -> [3,7] -> [5,11] -> [6,13]
    3. YaRN updates to RoPE on window changes
    4. Split embed and lm head at 2/3 of training
    5. Batch size schedule of 8 -> 16 -> 24
    6. Post training extension of long windows from 13 to 20
    """

    def __init__(self, stages: list[TrainingStage], scheduled_iterations: int, extension_iterations: int,
                 cooldown_frac: float = 0.5, split_embed_stage: int = 2, ws_post_yarn_ext: int = 20):
        self.stages = stages
        self.scheduled_iterations = scheduled_iterations
        self.cooldown_frac = cooldown_frac
        # increase final validation ws, used for YaRN extension and short window size @classiclarryd
        self.ws_post_yarn_ext = ws_post_yarn_ext

        self.total_steps = self.scheduled_iterations + extension_iterations

        # Build stage boundaries (last is extension stage)
        ends = [0] + [round(c * scheduled_iterations) for c in accumulate(s.duration for s in stages[:-1])] + [self.total_steps]
        assert self.scheduled_iterations == ends[-2]
        self.boundaries = list(pairwise(ends))

        # Split embed at specified stage (ensure odd step for Adam)
        self.split_step = self.boundaries[split_embed_stage][0] | 1

        # Precompute MTP weights for all steps
        self.mtp_weights = []
        for step in range(self.total_steps + 1):
            stage, t = self.lookup(step)
            w = [a + (b - a) * t for a, b in zip(stage.mtp_weights_start, stage.mtp_weights_end)]
            self.mtp_weights.append(torch.tensor(w, device=device))

    def lookup(self, step: int) -> tuple[TrainingStage, float]:
        # Returns stage and % of the way through that stage
        for i, (start, end) in enumerate(self.boundaries):
            if step < end:
                t = (step - start) / (end - start)
                return self.stages[i], t
        return self.stages[-1], 1.0

    def get_lr(self, step: int) -> float:
        # learning rate schedule: tied to batch size schedule, with cooldown at the end
        # (linear decay to 0.1x over the final cooldown_frac of scheduled steps)
        stage, _ = self.lookup(step)
        lr = stage.lr_mul
        cd_start = int(self.scheduled_iterations * (1 - self.cooldown_frac))
        if step >= cd_start:
            t = min(1.0, (step - cd_start) / (self.scheduled_iterations - cd_start))
            lr = lr * (1 - t) + 0.1 * t
        return lr

# window_sizes are in units of `block_size` tokens (defined in TrainingManager)
TRAINING_STAGES = [
    TrainingStage(duration=1/3, batch_size=8 * 2048 * 8, window_sizes=(1, 3), lr_mul=1.0,
                  mtp_weights_start=[1.0, 0.5, 0.25], mtp_weights_end=[1.0, 0.5, 0.0]),
    TrainingStage(duration=1/3, batch_size=16 * 2048 * 8, window_sizes=(3, 7), lr_mul=1.52, # (16/8)**0.6
                  mtp_weights_start=[1.0, 0.5], mtp_weights_end=[1.0, 0.0]),
    TrainingStage(duration=1/3, batch_size=24 * 2048 * 8, window_sizes=(5, 11), lr_mul=1.73, # (24/8)**0.5
                  mtp_weights_start=[1.0], mtp_weights_end=[1.0]),
    # extension stage
    TrainingStage(batch_size=24 * 2048 * 8, window_sizes=(6, 13), lr_mul=1.0, # lr_mul is not used
                  mtp_weights_start=[1.0], mtp_weights_end=[1.0]),
]

training_schedule = TrainingSchedule(TRAINING_STAGES, args.num_scheduled_iterations, args.num_extension_iterations, cooldown_frac=0.55)

def get_muon_momentum(step: int, muon_warmup_steps=300, muon_cooldown_steps=50, momentum_min=0.85, momentum_max=0.95):
    # warmup phase: linearly increase momentum from min to max
    # cooldown phase: linearly decrease momentum from max to min
    momentum_cd_start = training_schedule.total_steps - muon_cooldown_steps
    if step < muon_warmup_steps:
        frac = step / muon_warmup_steps
        momentum = momentum_min + frac * (momentum_max - momentum_min)
    elif step > momentum_cd_start:
        frac = (step - momentum_cd_start) / muon_cooldown_steps
        momentum = momentum_max - frac * (momentum_max - momentum_min)
    else:
        momentum = momentum_max
    return momentum

class TrainingManager():
    """
    Manages the NorMuonAndAdam for all parameters with explicit ordering.
    1. Scalars are given higher momentum terms to smooth learning @ChrisJMcCormick
    2. Adam optimizers are only stepped on odd steps @classiclarryd
    3. Explicit scatter_order and work_order for communication scheduling (no backward hooks)
    4. Muon has a linear momentum warmup and cooldown schedule
    5. Learning rates follow a linear decay schedule
    6. 
Embed is tied to lm_head until split step (2/3 of training), then untied @classiclarryd
    """
    def __init__(self, model):
        self.model = model
        self.block_size = 128  # attention window unit, in tokens

        # - Ordering dictates when to launch reduce/reduce_scatter operations
        # - "sharded" parameters use reduce_scatter/all_gather and "replicated" ones use all_reduce
        # - lr_mul and wd_mul are per-parameter learning rate and weight decay multipliers
        self.param_table = {
            "attn": {"optim": "normuon", "comms": "sharded", "adam_betas": None},
            "mlp": {"optim": "normuon", "comms": "sharded", "adam_betas": None},
            "scalars": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 5.0, "wd_mul": 0.0},
            "value_embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.75, 0.95], "lr_mul": 75., "wd_mul": 5.0},
            "bigram_embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.75, 0.95], "lr_mul": 75., "wd_mul": 5.0},
            "smear_gate": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 0.01, "wd_mul": 0.0},
            "skip_gate": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 0.05, "wd_mul": 0.0},
            "attn_gate_bank": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99]},
            "ve_gate_bank": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99]},
            "x0_lambdas": {"optim": "adam", "comms": "replicated", "adam_betas": [0.65, 0.95], "lr_mul": 5.0, "wd_mul": 0.0},
            "lm_head": {"optim": "adam", "comms": "sharded", "adam_betas": [0.5, 0.95], "wd_mul": 150.},
            "embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.5, 0.95], "wd_mul": 150.},
        }

        # - Process smaller/faster params first while large reduces complete
        # - lm_head must complete before embed sync (when tied)
        self.work_order = [
            "scalars", "smear_gate", "skip_gate", "attn_gate_bank", "ve_gate_bank", "x0_lambdas", # Small, fast
            "value_embed", "bigram_embed", # Medium
            "lm_head", "embed", # lm_head must complete before embed sync (when tied)
            "attn", "mlp", # Large, polar express - process last to maximize overlap
        ]

        adam_defaults = dict(
            lr=0.008,
            eps=1e-10,
            weight_decay=0.005,
        )

        normuon_defaults = dict(
            lr=0.023,
            momentum=0.95,
            beta2=0.95,
            weight_decay=1.2,
        )

        self.optimizer = NorMuonAndAdam(
            model.named_parameters(),
            param_table=self.param_table,
            scatter_order=list(self.param_table.keys()), # Dict order defines scatter priority
            work_order=self.work_order,
            adam_defaults=adam_defaults,
            normuon_defaults=normuon_defaults,
        )

        # Split embed from lm_head at 2/3 of training (on an odd step so Adam updates)
        self.split_step = training_schedule.split_step

        self.reset()

    def apply_final_ws_ext(self):
        # final-validation-only extension of the long window (YaRN extension)
        self.ws_long = training_schedule.ws_post_yarn_ext

    def get_forward_args(self):
        # convert window sizes from block units to tokens for GPT.forward
        return ForwardScheduleConfig(
            mtp_weights = self.mtp_weights,
            ws_short = self.ws_short * self.block_size,
            ws_long = self.ws_long * self.block_size
        )

    def _is_adam_step(self, step: int):
        """Adam params are only updated on odd steps."""
        return step % 2 == 1

    def get_transition_steps(self):
        # first step of every stage after the first
        return [start for start, _ in training_schedule.boundaries[1:]]

    def advance_schedule(self, step: int):
        """Update window sizes, batch size and MTP weights for this step;
        applies YaRN to both rotary tables when the long window changes."""
        stage, _ = training_schedule.lookup(step)
        self.ws_short, new_ws_long = stage.window_sizes
        if new_ws_long != self.ws_long:
            self.model.yarn.apply(self.ws_long * self.block_size, new_ws_long * self.block_size)
            self.model.yarn_paired_head.apply(self.ws_long * self.block_size, new_ws_long * self.block_size)

        new_batch_size = stage.batch_size
        if new_batch_size != self.batch_size:
            # non-None args are .send()-ed into the train loader by the caller
            self.train_loader_send_args = (new_batch_size, args.train_max_seq_len, grad_accum_steps)
            self.batch_size = new_batch_size
        else:
            self.train_loader_send_args = None

        self.ws_long = new_ws_long
        self.mtp_weights = training_schedule.mtp_weights[step]

    def step_optimizers(self, step: int):
        """Apply the schedule (lr, muon momentum) and take one optimizer step."""
        step_lr = training_schedule.get_lr(step)
        muon_momentum = get_muon_momentum(step)
        do_adam = self._is_adam_step(step)

        # Update learning rates and momentum for all params
        for param, p_cfg in self.optimizer.param_cfgs.items():
            p_cfg.lr = p_cfg.initial_lr * step_lr
            if p_cfg.optim == "normuon":
                p_cfg.momentum = muon_momentum

        # Step optimizer with do_adam flag
        self.optimizer.step(do_adam=do_adam)

        # At split step: copy lm_head optimizer state to embed and mark as split
        if step == self.split_step:
            self.optimizer.copy_lm_state_to_embed()

    def reset(self, state=None):
        """Restore optimizer state (if given) and re-initialize schedule state
        to stage 0; used after the kernel-warmup phase."""
        if state is not None:
            self.optimizer.load_state_dict(state)

        # Reset NorMuon momentum buffers and split_embed state
        self.optimizer.reset()

        stage, _ = training_schedule.lookup(0)
        self.ws_short, self.ws_long = stage.window_sizes
        self.batch_size = stage.batch_size
        self.model.yarn.reset()
        self.model.yarn_paired_head.reset()

    def get_state(self):
        return copy.deepcopy(self.optimizer.state_dict())

# -----------------------------------------------------------------------------
# int main

# begin logging (only rank 0 writes the logfile)
logfile = None
if master_process:
    run_id = args.run_id
    os.makedirs("logs", exist_ok=True)
    logfile = f"logs/{run_id}.txt"
    print(logfile)
def print0(s, console=False):
    # rank-0-only logging; echoes to stdout when console=True
    if master_process:
        with open(logfile, "a") as f:
            if console:
                print(s)
            print(s, file=f)

# begin by printing this file (the Python code)
print0(code)
print0("="*100)
# log information about the hardware/software environment this is running on
print0(f"Running Python {sys.version}")
print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
print0(f"Running Triton version {triton.__version__}")

def nvidia_smi():
    import subprocess # avoid top level import
    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
print0(nvidia_smi())
print0("="*100)

model: nn.Module = GPT(
    vocab_size=50257,
    num_layers=11,
    num_heads=6,
    
head_dim=128, + model_dim=768, + max_seq_len=args.val_batch_size // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.weight.data = m.weight.data.bfloat16() +model.attn_gate_bank.data = model.attn_gate_bank.data.bfloat16() +model.ve_gate_bank.data = model.ve_gate_bank.data.bfloat16() +model.attn_bank.data = model.attn_bank.data.bfloat16() +model.mlp_bank.data = model.mlp_bank.data.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) +training_manager = TrainingManager(model) + +######################################## +# Warmup kernels # +######################################## +print0("Compiling model and warming up kernels (~7 minutes on first execution)", console=True) +# Warmup the training kernels, then re-initialize the state so we aren't cheating +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizer=training_manager.get_state()) # save the initial state +train_loader = distributed_data_generator(args.train_files, TRAINING_STAGES[0].batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + +transition_steps = training_manager.get_transition_steps() +# first few steps plus transitions +warmup_steps = sorted({0, 1, 2} | set(s + offset for s in transition_steps for offset in [-1, 0, 1] if s + offset >= 0)) +print0(f"Sampling steps {warmup_steps} for warmup", console=True) +for step in warmup_steps: + training_manager.advance_schedule(step) + model.eval() + with torch.no_grad(): + inputs, targets, cum_seqlens = next(val_loader) + model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) + model.train() + for idx in range(grad_accum_steps): + send_args = training_manager.train_loader_send_args + inputs, 
targets, cum_seqlens = train_loader.send(send_args) + (model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) * grad_scale).backward() + training_manager.step_optimizers(step) +print0("Resetting Model", console=True) +model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +training_manager.reset(initial_state["optimizer"]) +del val_loader, train_loader, initial_state +model.train() + +######################################## +# Training and validation # +######################################## +train_loader = distributed_data_generator(args.train_files, TRAINING_STAGES[0].batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) + +gc.collect() + +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = training_schedule.total_steps +for step in range(train_steps + 1): + last_step = (step == train_steps) + training_manager.advance_schedule(step) + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + if last_step: + training_manager.apply_final_ws_ext() + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) + val_loss /= val_steps + del val_loader + dist.reduce(val_loss, 0, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", 
console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizer=training_manager.get_state()) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for idx in range(grad_accum_steps): + inputs, targets, cum_seqlens = train_loader.send(training_manager.train_loader_send_args) + (model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) * grad_scale).backward() + training_manager.step_optimizers(step) + + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + + +---------------------------------------- +# triton_kernels.py +---------------------------------------- + +import torch +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * 
num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.jit +def XXT_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = 
m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def XXT(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + # Hardcoded configs based on H100 autotuning + if K == 768: + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 64 + num_stages, num_warps = 4, 4 + else: + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 64, 128, 128 + num_stages, num_warps = 4, 4 + + grid = (batch_size * triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(M, BLOCK_SIZE_N),) + XXT_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + LOWER_UPPER=1, + num_stages=num_stages, + num_warps=num_warps, + ) + return out + +@triton.jit +def ba_plus_cAA_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: 
tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from XXT_kernel, but also loads and adds a block of A + # Performance is slightly slower than XXT_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = 
C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ba_plus_cAA(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + # Hardcoded config based on H100 autotuning (M=768) + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 64 + num_stages, num_warps = 4, 4 + + grid = (batch_size * triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(M, BLOCK_SIZE_N),) + ba_plus_cAA_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + LOWER_UPPER=1, + num_stages=num_stages, + num_warps=num_warps, + ) + return out + +# ----------------------------------------------------------------------------- +# Triton kernel for MLP: relu(x @ W1.T)^2, by @andrewbriand, @jrauvola + +@triton.jit +def linear_relu_square_kernel(a_desc, b_desc, c_desc, 
aux_desc, + M, N, K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, + ): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tile_id_c = start_pid - NUM_SMS + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + + tile_id_c += NUM_SMS + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am_c = pid_m * BLOCK_SIZE_M + offs_bn_c = pid_n * BLOCK_SIZE_N + + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + + c0 = acc0.to(dtype) + if not FORWARD: + c0_pre = aux_desc.load([offs_am_c, offs_bn_c]) + c0 = 2 * c0 * tl.where(c0_pre > 0, c0_pre, 0) + + c_desc.store([offs_am_c, offs_bn_c], c0) + + if FORWARD: + c0_post = tl.maximum(c0, 0) + c0_post = c0_post * c0_post + aux_desc.store([offs_am_c, offs_bn_c], c0_post) + + c1 = acc1.to(dtype) + if not FORWARD: + c1_pre = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + c1 = 2 * c1 * tl.where(c1_pre > 0, c1_pre, 0) + + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + + if FORWARD: + c1_post = tl.maximum(c1, 0) + c1_post = c1_post * c1_post + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1_post) + + +def linear_relu_square(a, b, aux=None): + M, K = a.shape + N, K = b.shape + dtype = 
a.dtype + + c = torch.empty((M, N), device=a.device, dtype=dtype) + + FORWARD = False + if aux is None: + FORWARD = True + aux = torch.empty((M, N), device=a.device, dtype=dtype) + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 256 + BLOCK_SIZE_K = 64 + num_stages = 4 if FORWARD else 3 + num_warps = 8 + + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + + def grid(META): + return (min( + NUM_SMS, + triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), + ), ) + + linear_relu_square_kernel[grid]( + a_desc, b_desc, c_desc, aux_desc, + M, N, K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=1, + NUM_SMS=NUM_SMS, + FORWARD=FORWARD, + num_stages=num_stages, + num_warps=num_warps + ) + + if FORWARD: + return c, aux + else: + return c + +class FusedLinearReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, W1, W2): + pre, post = linear_relu_square(x.view((-1, x.shape[-1])), W1) + x3 = post @ W2 + ctx.save_for_backward(x, W1, W2, pre, post) + return x3.view(x.shape) + + @staticmethod + def backward(ctx, grad_output): + x, W1, W2, pre, post = ctx.saved_tensors + dW2 = post.T @ grad_output + dpre = linear_relu_square(grad_output.view((-1, grad_output.shape[-1])), W2, aux=pre) + dW1 = dpre.T @ x + dx = dpre @ W1 + return dx.view(x.shape), dW1, dW2 + +# ----------------------------------------------------------------------------- +# Fused Softcapped Cross Entropy + + +@triton.jit +def fused_softcapped_entropy_fwd_kernel( + logits_ptr, losses_ptr, lse_ptr, targets_ptr, mtp_weights_ptr, + stride_logits_n, stride_logits_v, + n_rows, n_cols, n_predict, + A, B, C, + 
BLOCK_SIZE: tl.constexpr +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + + max_val = -float('inf') + sum_exp = 0.0 + + for off in range(0, n_cols, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + val = tl.load(logits_row_ptr + cols, mask=mask, other=-float('inf')).to(tl.float32) + z = A * tl.sigmoid((val + B) / C) + z = tl.where(mask, z, -float('inf')) + curr_max = tl.max(z, axis=0) + new_max = tl.maximum(max_val, curr_max) + sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(z - new_max), axis=0) + max_val = new_max + + lse = max_val + tl.log(sum_exp) + tl.store(lse_ptr + row_idx, lse) + + total_loss = 0.0 + for k in range(n_predict): + target_idx = row_idx + k + if target_idx < n_rows: + weight = tl.load(mtp_weights_ptr + k) + if weight > 0: + target = tl.load(targets_ptr + target_idx).to(tl.int32) + if target >= 0 and target < n_cols: + val_target = tl.load(logits_row_ptr + target).to(tl.float32) + z_target = A * tl.sigmoid((val_target + B) / C) + total_loss += weight * (lse - z_target) + + tl.store(losses_ptr + row_idx, total_loss) + +@triton.jit +def fused_softcapped_entropy_bwd_kernel( + grad_input_ptr, grad_output_ptr, lse_ptr, logits_ptr, targets_ptr, mtp_weights_ptr, + stride_logits_n, stride_logits_v, stride_grad_n, stride_grad_v, + n_rows, n_cols, n_predict, + A, B, C, + BLOCK_SIZE: tl.constexpr +): + row_idx = tl.program_id(0).to(tl.int64) + + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + grad_row_ptr = grad_input_ptr + row_idx * stride_grad_n + + lse = tl.load(lse_ptr + row_idx) + grad_loss = tl.load(grad_output_ptr + row_idx) + + S_w = 0.0 + for k in range(n_predict): + if row_idx + k < n_rows: + S_w += tl.load(mtp_weights_ptr + k) + + for off in range(0, n_cols, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + val = tl.load(logits_row_ptr + cols, mask=mask, other=0.0).to(tl.float32) + u = (val + B) / C + sigmoid_u 
= tl.sigmoid(u) + z = A * sigmoid_u + p = tl.exp(z - lse) + + term1 = S_w * p + term2 = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for k in range(n_predict): + if row_idx + k < n_rows: + target = tl.load(targets_ptr + row_idx + k).to(tl.int32) + weight = tl.load(mtp_weights_ptr + k) + term2 += tl.where(cols == target, weight, 0.0) + + grad_z = grad_loss * (term1 - term2) + dz_dx = (1.0 / C) * z * (1.0 - sigmoid_u) + grad_x = grad_z * dz_dx + tl.store(grad_row_ptr + cols, grad_x.to(tl.bfloat16), mask=mask) + +class FusedSoftcappedCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, logits, targets, mtp_weights, A=23.0, B=5.0, C=7.5): + n_rows, n_cols = logits.shape + if mtp_weights is None: + mtp_weights = torch.tensor([1.0], device=logits.device, dtype=torch.float32) + n_predict = mtp_weights.shape[0] + + losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + lse = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + + logits = logits.contiguous() + targets = targets.contiguous() + mtp_weights = mtp_weights.contiguous() + + grid = (n_rows,) + fused_softcapped_entropy_fwd_kernel[grid]( + logits, losses, lse, targets, mtp_weights, + logits.stride(0), logits.stride(1), + n_rows, n_cols, n_predict, + A, B, C, + BLOCK_SIZE=1024, + num_warps=8, + num_stages=4 + ) + + ctx.save_for_backward(logits, targets, mtp_weights, lse) + ctx.params = (A, B, C) + return losses + + @staticmethod + def backward(ctx, grad_output): + logits, targets, mtp_weights, lse = ctx.saved_tensors + A, B, C = ctx.params + n_rows, n_cols = logits.shape + n_predict = mtp_weights.shape[0] + + grad_input = torch.empty((n_rows, n_cols), dtype=torch.bfloat16, device=logits.device) + grad_output = grad_output.contiguous() + + grid = (n_rows,) + fused_softcapped_entropy_bwd_kernel[grid]( + grad_input, grad_output, lse, logits, targets, mtp_weights, + logits.stride(0), logits.stride(1), grad_input.stride(0), grad_input.stride(1), + n_rows, n_cols, 
n_predict, + A, B, C, + BLOCK_SIZE=1024, + num_warps=8, + num_stages=4 + ) + return grad_input, None, None, None, None, None + +==================================================================================================== +Running Python 3.12.7 (main, Jan 31 2026, 04:21:49) [GCC 13.2.0] +Running PyTorch 2.10.0.dev20251210+cu126 compiled for CUDA 12.6 +Running Triton version 3.6.0 +Sun Feb 1 06:03:44 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 32C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 35C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:71:00.0 Off | 0 | +| N/A 37C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:79:00.0 Off | 0 | +| N/A 33C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:7F:00.0 Off | 0 | +| N/A 31C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | +| N/A 37C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 35C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 12926 C /usr/local/bin/python 1510MiB | +| 1 N/A N/A 12927 C /usr/local/bin/python 1510MiB | +| 2 N/A N/A 12928 C /usr/local/bin/python 1510MiB | +| 3 N/A N/A 12929 C /usr/local/bin/python 1510MiB | +| 4 N/A N/A 12930 C /usr/local/bin/python 1510MiB | +| 5 N/A N/A 12931 C /usr/local/bin/python 1510MiB | +| 6 N/A N/A 12932 C /usr/local/bin/python 1510MiB | +| 7 N/A N/A 12933 C /usr/local/bin/python 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +Compiling model and warming up kernels (~7 minutes on first execution) +Sampling steps [0, 1, 2, 504, 505, 506, 1009, 1010, 1011, 1514, 1515, 1516] for warmup +Resetting Model +step:0/1555 val_loss:10.8301 train_time:0ms step_avg:0.03ms +step:1/1555 train_time:80ms step_avg:80.22ms +step:2/1555 train_time:104ms 
step_avg:51.92ms +step:3/1555 train_time:126ms step_avg:42.13ms +step:4/1555 train_time:149ms step_avg:37.27ms +step:5/1555 train_time:179ms step_avg:35.85ms +step:6/1555 train_time:217ms step_avg:36.19ms +step:7/1555 train_time:248ms step_avg:35.40ms +step:8/1555 train_time:285ms step_avg:35.67ms +step:9/1555 train_time:316ms step_avg:35.11ms +step:10/1555 train_time:354ms step_avg:35.40ms +step:11/1555 train_time:384ms step_avg:34.94ms +step:12/1555 train_time:422ms step_avg:35.20ms +step:13/1555 train_time:453ms step_avg:34.86ms +step:14/1555 train_time:491ms step_avg:35.06ms +step:15/1555 train_time:522ms step_avg:34.80ms +step:16/1555 train_time:560ms step_avg:35.00ms +step:17/1555 train_time:591ms step_avg:34.76ms +step:18/1555 train_time:628ms step_avg:34.91ms +step:19/1555 train_time:659ms step_avg:34.69ms +step:20/1555 train_time:697ms step_avg:34.85ms +step:21/1555 train_time:728ms step_avg:34.68ms +step:22/1555 train_time:766ms step_avg:34.81ms +step:23/1555 train_time:797ms step_avg:34.65ms +step:24/1555 train_time:835ms step_avg:34.78ms +step:25/1555 train_time:865ms step_avg:34.62ms +step:26/1555 train_time:903ms step_avg:34.75ms +step:27/1555 train_time:934ms step_avg:34.60ms +step:28/1555 train_time:972ms step_avg:34.71ms +step:29/1555 train_time:1003ms step_avg:34.57ms +step:30/1555 train_time:1040ms step_avg:34.67ms +step:31/1555 train_time:1072ms step_avg:34.57ms +step:32/1555 train_time:1109ms step_avg:34.66ms +step:33/1555 train_time:1141ms step_avg:34.56ms +step:34/1555 train_time:1179ms step_avg:34.68ms +step:35/1555 train_time:1210ms step_avg:34.57ms +step:36/1555 train_time:1248ms step_avg:34.65ms +step:37/1555 train_time:1279ms step_avg:34.56ms +step:38/1555 train_time:1317ms step_avg:34.65ms +step:39/1555 train_time:1347ms step_avg:34.55ms +step:40/1555 train_time:1385ms step_avg:34.63ms +step:41/1555 train_time:1416ms step_avg:34.55ms +step:42/1555 train_time:1455ms step_avg:34.63ms +step:43/1555 train_time:1485ms step_avg:34.54ms 
+step:44/1555 train_time:1523ms step_avg:34.62ms +step:45/1555 train_time:1554ms step_avg:34.53ms +step:46/1555 train_time:1592ms step_avg:34.60ms +step:47/1555 train_time:1623ms step_avg:34.53ms +step:48/1555 train_time:1661ms step_avg:34.60ms +step:49/1555 train_time:1691ms step_avg:34.52ms +step:50/1555 train_time:1729ms step_avg:34.58ms +step:51/1555 train_time:1760ms step_avg:34.51ms +step:52/1555 train_time:1798ms step_avg:34.57ms +step:53/1555 train_time:1829ms step_avg:34.51ms +step:54/1555 train_time:1866ms step_avg:34.56ms +step:55/1555 train_time:1898ms step_avg:34.50ms +step:56/1555 train_time:1935ms step_avg:34.56ms +step:57/1555 train_time:1966ms step_avg:34.49ms +step:58/1555 train_time:2004ms step_avg:34.55ms +step:59/1555 train_time:2035ms step_avg:34.48ms +step:60/1555 train_time:2073ms step_avg:34.55ms +step:61/1555 train_time:2104ms step_avg:34.49ms +step:62/1555 train_time:2141ms step_avg:34.54ms +step:63/1555 train_time:2172ms step_avg:34.48ms +step:64/1555 train_time:2210ms step_avg:34.53ms +step:65/1555 train_time:2241ms step_avg:34.48ms +step:66/1555 train_time:2279ms step_avg:34.52ms +step:67/1555 train_time:2309ms step_avg:34.47ms +step:68/1555 train_time:2348ms step_avg:34.52ms +step:69/1555 train_time:2379ms step_avg:34.48ms +step:70/1555 train_time:2417ms step_avg:34.53ms +step:71/1555 train_time:2448ms step_avg:34.48ms +step:72/1555 train_time:2486ms step_avg:34.52ms +step:73/1555 train_time:2517ms step_avg:34.48ms +step:74/1555 train_time:2555ms step_avg:34.53ms +step:75/1555 train_time:2585ms step_avg:34.47ms +step:76/1555 train_time:2624ms step_avg:34.52ms +step:77/1555 train_time:2655ms step_avg:34.48ms +step:78/1555 train_time:2693ms step_avg:34.52ms +step:79/1555 train_time:2724ms step_avg:34.47ms +step:80/1555 train_time:2761ms step_avg:34.52ms +step:81/1555 train_time:2792ms step_avg:34.47ms +step:82/1555 train_time:2829ms step_avg:34.51ms +step:83/1555 train_time:2861ms step_avg:34.46ms +step:84/1555 train_time:2898ms 
step_avg:34.50ms +step:85/1555 train_time:2929ms step_avg:34.45ms +step:86/1555 train_time:2966ms step_avg:34.49ms +step:87/1555 train_time:2997ms step_avg:34.45ms +step:88/1555 train_time:3035ms step_avg:34.49ms +step:89/1555 train_time:3066ms step_avg:34.45ms +step:90/1555 train_time:3104ms step_avg:34.49ms +step:91/1555 train_time:3135ms step_avg:34.45ms +step:92/1555 train_time:3173ms step_avg:34.49ms +step:93/1555 train_time:3203ms step_avg:34.45ms +step:94/1555 train_time:3241ms step_avg:34.48ms +step:95/1555 train_time:3272ms step_avg:34.45ms +step:96/1555 train_time:3310ms step_avg:34.48ms +step:97/1555 train_time:3341ms step_avg:34.44ms +step:98/1555 train_time:3379ms step_avg:34.48ms +step:99/1555 train_time:3410ms step_avg:34.44ms +step:100/1555 train_time:3447ms step_avg:34.47ms +step:101/1555 train_time:3478ms step_avg:34.44ms +step:102/1555 train_time:3516ms step_avg:34.47ms +step:103/1555 train_time:3548ms step_avg:34.45ms +step:104/1555 train_time:3586ms step_avg:34.48ms +step:105/1555 train_time:3617ms step_avg:34.44ms +step:106/1555 train_time:3654ms step_avg:34.47ms +step:107/1555 train_time:3685ms step_avg:34.44ms +step:108/1555 train_time:3723ms step_avg:34.47ms +step:109/1555 train_time:3754ms step_avg:34.44ms +step:110/1555 train_time:3791ms step_avg:34.47ms +step:111/1555 train_time:3823ms step_avg:34.44ms +step:112/1555 train_time:3861ms step_avg:34.47ms +step:113/1555 train_time:3891ms step_avg:34.44ms +step:114/1555 train_time:3929ms step_avg:34.46ms +step:115/1555 train_time:3960ms step_avg:34.43ms +step:116/1555 train_time:3997ms step_avg:34.46ms +step:117/1555 train_time:4028ms step_avg:34.43ms +step:118/1555 train_time:4066ms step_avg:34.46ms +step:119/1555 train_time:4097ms step_avg:34.43ms +step:120/1555 train_time:4135ms step_avg:34.46ms +step:121/1555 train_time:4165ms step_avg:34.42ms +step:122/1555 train_time:4203ms step_avg:34.45ms +step:123/1555 train_time:4234ms step_avg:34.42ms +step:124/1555 train_time:4271ms 
step_avg:34.45ms +step:125/1555 train_time:4302ms step_avg:34.42ms +step:126/1555 train_time:4340ms step_avg:34.45ms +step:127/1555 train_time:4371ms step_avg:34.42ms +step:128/1555 train_time:4409ms step_avg:34.45ms +step:129/1555 train_time:4440ms step_avg:34.42ms +step:130/1555 train_time:4479ms step_avg:34.45ms +step:131/1555 train_time:4509ms step_avg:34.42ms +step:132/1555 train_time:4546ms step_avg:34.44ms +step:133/1555 train_time:4577ms step_avg:34.42ms +step:134/1555 train_time:4615ms step_avg:34.44ms +step:135/1555 train_time:4646ms step_avg:34.41ms +step:136/1555 train_time:4684ms step_avg:34.44ms +step:137/1555 train_time:4715ms step_avg:34.41ms +step:138/1555 train_time:4752ms step_avg:34.44ms +step:139/1555 train_time:4783ms step_avg:34.41ms +step:140/1555 train_time:4821ms step_avg:34.44ms +step:141/1555 train_time:4852ms step_avg:34.41ms +step:142/1555 train_time:4889ms step_avg:34.43ms +step:143/1555 train_time:4920ms step_avg:34.41ms +step:144/1555 train_time:4958ms step_avg:34.43ms +step:145/1555 train_time:4988ms step_avg:34.40ms +step:146/1555 train_time:5026ms step_avg:34.42ms +step:147/1555 train_time:5057ms step_avg:34.40ms +step:148/1555 train_time:5095ms step_avg:34.43ms +step:149/1555 train_time:5126ms step_avg:34.40ms +step:150/1555 train_time:5163ms step_avg:34.42ms +step:151/1555 train_time:5194ms step_avg:34.40ms +step:152/1555 train_time:5231ms step_avg:34.42ms +step:153/1555 train_time:5262ms step_avg:34.40ms +step:154/1555 train_time:5300ms step_avg:34.42ms +step:155/1555 train_time:5331ms step_avg:34.39ms +step:156/1555 train_time:5368ms step_avg:34.41ms +step:157/1555 train_time:5400ms step_avg:34.39ms +step:158/1555 train_time:5437ms step_avg:34.41ms +step:159/1555 train_time:5468ms step_avg:34.39ms +step:160/1555 train_time:5506ms step_avg:34.41ms +step:161/1555 train_time:5537ms step_avg:34.39ms +step:162/1555 train_time:5575ms step_avg:34.41ms +step:163/1555 train_time:5606ms step_avg:34.39ms +step:164/1555 train_time:5643ms 
step_avg:34.41ms +step:165/1555 train_time:5675ms step_avg:34.39ms +step:166/1555 train_time:5712ms step_avg:34.41ms +step:167/1555 train_time:5744ms step_avg:34.39ms +step:168/1555 train_time:5782ms step_avg:34.42ms +step:169/1555 train_time:5814ms step_avg:34.40ms +step:170/1555 train_time:5851ms step_avg:34.42ms +step:171/1555 train_time:5883ms step_avg:34.40ms +step:172/1555 train_time:5921ms step_avg:34.42ms +step:173/1555 train_time:5951ms step_avg:34.40ms +step:174/1555 train_time:5989ms step_avg:34.42ms +step:175/1555 train_time:6021ms step_avg:34.40ms +step:176/1555 train_time:6058ms step_avg:34.42ms +step:177/1555 train_time:6089ms step_avg:34.40ms +step:178/1555 train_time:6127ms step_avg:34.42ms +step:179/1555 train_time:6157ms step_avg:34.40ms +step:180/1555 train_time:6195ms step_avg:34.42ms +step:181/1555 train_time:6226ms step_avg:34.40ms +step:182/1555 train_time:6264ms step_avg:34.42ms +step:183/1555 train_time:6294ms step_avg:34.39ms +step:184/1555 train_time:6331ms step_avg:34.41ms +step:185/1555 train_time:6362ms step_avg:34.39ms +step:186/1555 train_time:6399ms step_avg:34.41ms +step:187/1555 train_time:6430ms step_avg:34.39ms +step:188/1555 train_time:6467ms step_avg:34.40ms +step:189/1555 train_time:6498ms step_avg:34.38ms +step:190/1555 train_time:6535ms step_avg:34.40ms +step:191/1555 train_time:6567ms step_avg:34.38ms +step:192/1555 train_time:6604ms step_avg:34.40ms +step:193/1555 train_time:6635ms step_avg:34.38ms +step:194/1555 train_time:6673ms step_avg:34.40ms +step:195/1555 train_time:6703ms step_avg:34.38ms +step:196/1555 train_time:6741ms step_avg:34.39ms +step:197/1555 train_time:6772ms step_avg:34.38ms +step:198/1555 train_time:6809ms step_avg:34.39ms +step:199/1555 train_time:6841ms step_avg:34.37ms +step:200/1555 train_time:6878ms step_avg:34.39ms +step:201/1555 train_time:6909ms step_avg:34.37ms +step:202/1555 train_time:6946ms step_avg:34.39ms +step:203/1555 train_time:6977ms step_avg:34.37ms +step:204/1555 train_time:7014ms 
step_avg:34.38ms +step:205/1555 train_time:7045ms step_avg:34.37ms +step:206/1555 train_time:7083ms step_avg:34.38ms +step:207/1555 train_time:7114ms step_avg:34.37ms +step:208/1555 train_time:7152ms step_avg:34.38ms +step:209/1555 train_time:7182ms step_avg:34.37ms +step:210/1555 train_time:7220ms step_avg:34.38ms +step:211/1555 train_time:7251ms step_avg:34.36ms +step:212/1555 train_time:7288ms step_avg:34.38ms +step:213/1555 train_time:7319ms step_avg:34.36ms +step:214/1555 train_time:7357ms step_avg:34.38ms +step:215/1555 train_time:7388ms step_avg:34.36ms +step:216/1555 train_time:7425ms step_avg:34.38ms +step:217/1555 train_time:7457ms step_avg:34.36ms +step:218/1555 train_time:7494ms step_avg:34.38ms +step:219/1555 train_time:7525ms step_avg:34.36ms +step:220/1555 train_time:7562ms step_avg:34.37ms +step:221/1555 train_time:7593ms step_avg:34.36ms +step:222/1555 train_time:7631ms step_avg:34.37ms +step:223/1555 train_time:7661ms step_avg:34.36ms +step:224/1555 train_time:7700ms step_avg:34.37ms +step:225/1555 train_time:7731ms step_avg:34.36ms +step:226/1555 train_time:7768ms step_avg:34.37ms +step:227/1555 train_time:7799ms step_avg:34.36ms +step:228/1555 train_time:7837ms step_avg:34.37ms +step:229/1555 train_time:7867ms step_avg:34.36ms +step:230/1555 train_time:7905ms step_avg:34.37ms +step:231/1555 train_time:7935ms step_avg:34.35ms +step:232/1555 train_time:7973ms step_avg:34.37ms +step:233/1555 train_time:8004ms step_avg:34.35ms +step:234/1555 train_time:8042ms step_avg:34.37ms +step:235/1555 train_time:8073ms step_avg:34.35ms +step:236/1555 train_time:8110ms step_avg:34.36ms +step:237/1555 train_time:8142ms step_avg:34.35ms +step:238/1555 train_time:8179ms step_avg:34.37ms +step:239/1555 train_time:8210ms step_avg:34.35ms +step:240/1555 train_time:8248ms step_avg:34.37ms +step:241/1555 train_time:8278ms step_avg:34.35ms +step:242/1555 train_time:8316ms step_avg:34.36ms +step:243/1555 train_time:8348ms step_avg:34.35ms +step:244/1555 train_time:8385ms 
step_avg:34.37ms +step:245/1555 train_time:8417ms step_avg:34.35ms +step:246/1555 train_time:8454ms step_avg:34.37ms +step:247/1555 train_time:8485ms step_avg:34.35ms +step:248/1555 train_time:8523ms step_avg:34.37ms +step:249/1555 train_time:8553ms step_avg:34.35ms +step:250/1555 train_time:8591ms step_avg:34.36ms +step:250/1555 val_loss:4.5532 train_time:8641ms step_avg:34.56ms +step:251/1555 train_time:8662ms step_avg:34.51ms +step:252/1555 train_time:8686ms step_avg:34.47ms +step:253/1555 train_time:8707ms step_avg:34.41ms +step:254/1555 train_time:8731ms step_avg:34.37ms +step:255/1555 train_time:8763ms step_avg:34.37ms +step:256/1555 train_time:8802ms step_avg:34.38ms +step:257/1555 train_time:8835ms step_avg:34.38ms +step:258/1555 train_time:8874ms step_avg:34.40ms +step:259/1555 train_time:8907ms step_avg:34.39ms +step:260/1555 train_time:8944ms step_avg:34.40ms +step:261/1555 train_time:8974ms step_avg:34.38ms +step:262/1555 train_time:9012ms step_avg:34.40ms +step:263/1555 train_time:9042ms step_avg:34.38ms +step:264/1555 train_time:9080ms step_avg:34.39ms +step:265/1555 train_time:9110ms step_avg:34.38ms +step:266/1555 train_time:9148ms step_avg:34.39ms +step:267/1555 train_time:9179ms step_avg:34.38ms +step:268/1555 train_time:9217ms step_avg:34.39ms +step:269/1555 train_time:9247ms step_avg:34.38ms +step:270/1555 train_time:9285ms step_avg:34.39ms +step:271/1555 train_time:9315ms step_avg:34.37ms +step:272/1555 train_time:9353ms step_avg:34.38ms +step:273/1555 train_time:9383ms step_avg:34.37ms +step:274/1555 train_time:9420ms step_avg:34.38ms +step:275/1555 train_time:9451ms step_avg:34.37ms +step:276/1555 train_time:9489ms step_avg:34.38ms +step:277/1555 train_time:9520ms step_avg:34.37ms +step:278/1555 train_time:9558ms step_avg:34.38ms +step:279/1555 train_time:9593ms step_avg:34.39ms +step:280/1555 train_time:9625ms step_avg:34.38ms +step:281/1555 train_time:9656ms step_avg:34.36ms +step:282/1555 train_time:9694ms step_avg:34.38ms +step:283/1555 
train_time:9725ms step_avg:34.36ms +step:284/1555 train_time:9762ms step_avg:34.37ms +step:285/1555 train_time:9794ms step_avg:34.36ms +step:286/1555 train_time:9832ms step_avg:34.38ms +step:287/1555 train_time:9863ms step_avg:34.36ms +step:288/1555 train_time:9900ms step_avg:34.38ms +step:289/1555 train_time:9931ms step_avg:34.36ms +step:290/1555 train_time:9969ms step_avg:34.37ms +step:291/1555 train_time:9999ms step_avg:34.36ms +step:292/1555 train_time:10037ms step_avg:34.37ms +step:293/1555 train_time:10068ms step_avg:34.36ms +step:294/1555 train_time:10105ms step_avg:34.37ms +step:295/1555 train_time:10136ms step_avg:34.36ms +step:296/1555 train_time:10173ms step_avg:34.37ms +step:297/1555 train_time:10204ms step_avg:34.36ms +step:298/1555 train_time:10242ms step_avg:34.37ms +step:299/1555 train_time:10273ms step_avg:34.36ms +step:300/1555 train_time:10311ms step_avg:34.37ms +step:301/1555 train_time:10341ms step_avg:34.36ms +step:302/1555 train_time:10378ms step_avg:34.37ms +step:303/1555 train_time:10409ms step_avg:34.35ms +step:304/1555 train_time:10447ms step_avg:34.36ms +step:305/1555 train_time:10477ms step_avg:34.35ms +step:306/1555 train_time:10515ms step_avg:34.36ms +step:307/1555 train_time:10546ms step_avg:34.35ms +step:308/1555 train_time:10583ms step_avg:34.36ms +step:309/1555 train_time:10613ms step_avg:34.35ms +step:310/1555 train_time:10651ms step_avg:34.36ms +step:311/1555 train_time:10682ms step_avg:34.35ms +step:312/1555 train_time:10719ms step_avg:34.35ms +step:313/1555 train_time:10750ms step_avg:34.34ms +step:314/1555 train_time:10787ms step_avg:34.35ms +step:315/1555 train_time:10818ms step_avg:34.34ms +step:316/1555 train_time:10856ms step_avg:34.35ms +step:317/1555 train_time:10886ms step_avg:34.34ms +step:318/1555 train_time:10924ms step_avg:34.35ms +step:319/1555 train_time:10955ms step_avg:34.34ms +step:320/1555 train_time:10993ms step_avg:34.35ms +step:321/1555 train_time:11024ms step_avg:34.34ms +step:322/1555 train_time:11062ms 
step_avg:34.35ms +step:323/1555 train_time:11092ms step_avg:34.34ms +step:324/1555 train_time:11130ms step_avg:34.35ms +step:325/1555 train_time:11161ms step_avg:34.34ms +step:326/1555 train_time:11199ms step_avg:34.35ms +step:327/1555 train_time:11230ms step_avg:34.34ms +step:328/1555 train_time:11267ms step_avg:34.35ms +step:329/1555 train_time:11298ms step_avg:34.34ms +step:330/1555 train_time:11336ms step_avg:34.35ms +step:331/1555 train_time:11366ms step_avg:34.34ms +step:332/1555 train_time:11404ms step_avg:34.35ms +step:333/1555 train_time:11434ms step_avg:34.34ms +step:334/1555 train_time:11472ms step_avg:34.35ms +step:335/1555 train_time:11502ms step_avg:34.34ms +step:336/1555 train_time:11540ms step_avg:34.35ms +step:337/1555 train_time:11571ms step_avg:34.34ms +step:338/1555 train_time:11608ms step_avg:34.34ms +step:339/1555 train_time:11639ms step_avg:34.33ms +step:340/1555 train_time:11677ms step_avg:34.34ms +step:341/1555 train_time:11708ms step_avg:34.33ms +step:342/1555 train_time:11745ms step_avg:34.34ms +step:343/1555 train_time:11776ms step_avg:34.33ms +step:344/1555 train_time:11814ms step_avg:34.34ms +step:345/1555 train_time:11845ms step_avg:34.33ms +step:346/1555 train_time:11883ms step_avg:34.34ms +step:347/1555 train_time:11913ms step_avg:34.33ms +step:348/1555 train_time:11952ms step_avg:34.34ms +step:349/1555 train_time:11983ms step_avg:34.33ms +step:350/1555 train_time:12020ms step_avg:34.34ms +step:351/1555 train_time:12051ms step_avg:34.33ms +step:352/1555 train_time:12089ms step_avg:34.34ms +step:353/1555 train_time:12119ms step_avg:34.33ms +step:354/1555 train_time:12157ms step_avg:34.34ms +step:355/1555 train_time:12188ms step_avg:34.33ms +step:356/1555 train_time:12225ms step_avg:34.34ms +step:357/1555 train_time:12256ms step_avg:34.33ms +step:358/1555 train_time:12294ms step_avg:34.34ms +step:359/1555 train_time:12325ms step_avg:34.33ms +step:360/1555 train_time:12362ms step_avg:34.34ms +step:361/1555 train_time:12393ms 
step_avg:34.33ms +step:362/1555 train_time:12431ms step_avg:34.34ms +step:363/1555 train_time:12462ms step_avg:34.33ms +step:364/1555 train_time:12500ms step_avg:34.34ms +step:365/1555 train_time:12530ms step_avg:34.33ms +step:366/1555 train_time:12568ms step_avg:34.34ms +step:367/1555 train_time:12598ms step_avg:34.33ms +step:368/1555 train_time:12636ms step_avg:34.34ms +step:369/1555 train_time:12667ms step_avg:34.33ms +step:370/1555 train_time:12704ms step_avg:34.34ms +step:371/1555 train_time:12735ms step_avg:34.33ms +step:372/1555 train_time:12773ms step_avg:34.34ms +step:373/1555 train_time:12804ms step_avg:34.33ms +step:374/1555 train_time:12841ms step_avg:34.33ms +step:375/1555 train_time:12872ms step_avg:34.33ms +step:376/1555 train_time:12909ms step_avg:34.33ms +step:377/1555 train_time:12940ms step_avg:34.32ms +step:378/1555 train_time:12977ms step_avg:34.33ms +step:379/1555 train_time:13008ms step_avg:34.32ms +step:380/1555 train_time:13046ms step_avg:34.33ms +step:381/1555 train_time:13077ms step_avg:34.32ms +step:382/1555 train_time:13115ms step_avg:34.33ms +step:383/1555 train_time:13146ms step_avg:34.32ms +step:384/1555 train_time:13183ms step_avg:34.33ms +step:385/1555 train_time:13214ms step_avg:34.32ms +step:386/1555 train_time:13252ms step_avg:34.33ms +step:387/1555 train_time:13283ms step_avg:34.32ms +step:388/1555 train_time:13321ms step_avg:34.33ms +step:389/1555 train_time:13351ms step_avg:34.32ms +step:390/1555 train_time:13388ms step_avg:34.33ms +step:391/1555 train_time:13419ms step_avg:34.32ms +step:392/1555 train_time:13457ms step_avg:34.33ms +step:393/1555 train_time:13488ms step_avg:34.32ms +step:394/1555 train_time:13525ms step_avg:34.33ms +step:395/1555 train_time:13556ms step_avg:34.32ms +step:396/1555 train_time:13593ms step_avg:34.33ms +step:397/1555 train_time:13624ms step_avg:34.32ms +step:398/1555 train_time:13662ms step_avg:34.33ms +step:399/1555 train_time:13693ms step_avg:34.32ms +step:400/1555 train_time:13731ms 
step_avg:34.33ms +step:401/1555 train_time:13761ms step_avg:34.32ms +step:402/1555 train_time:13799ms step_avg:34.33ms +step:403/1555 train_time:13829ms step_avg:34.32ms +step:404/1555 train_time:13867ms step_avg:34.32ms +step:405/1555 train_time:13898ms step_avg:34.32ms +step:406/1555 train_time:13935ms step_avg:34.32ms +step:407/1555 train_time:13966ms step_avg:34.31ms +step:408/1555 train_time:14003ms step_avg:34.32ms +step:409/1555 train_time:14034ms step_avg:34.31ms +step:410/1555 train_time:14071ms step_avg:34.32ms +step:411/1555 train_time:14102ms step_avg:34.31ms +step:412/1555 train_time:14139ms step_avg:34.32ms +step:413/1555 train_time:14170ms step_avg:34.31ms +step:414/1555 train_time:14207ms step_avg:34.32ms +step:415/1555 train_time:14238ms step_avg:34.31ms +step:416/1555 train_time:14275ms step_avg:34.32ms +step:417/1555 train_time:14306ms step_avg:34.31ms +step:418/1555 train_time:14343ms step_avg:34.31ms +step:419/1555 train_time:14374ms step_avg:34.30ms +step:420/1555 train_time:14411ms step_avg:34.31ms +step:421/1555 train_time:14442ms step_avg:34.30ms +step:422/1555 train_time:14480ms step_avg:34.31ms +step:423/1555 train_time:14511ms step_avg:34.30ms +step:424/1555 train_time:14548ms step_avg:34.31ms +step:425/1555 train_time:14579ms step_avg:34.30ms +step:426/1555 train_time:14617ms step_avg:34.31ms +step:427/1555 train_time:14648ms step_avg:34.30ms +step:428/1555 train_time:14685ms step_avg:34.31ms +step:429/1555 train_time:14716ms step_avg:34.30ms +step:430/1555 train_time:14753ms step_avg:34.31ms +step:431/1555 train_time:14784ms step_avg:34.30ms +step:432/1555 train_time:14822ms step_avg:34.31ms +step:433/1555 train_time:14852ms step_avg:34.30ms +step:434/1555 train_time:14890ms step_avg:34.31ms +step:435/1555 train_time:14921ms step_avg:34.30ms +step:436/1555 train_time:14958ms step_avg:34.31ms +step:437/1555 train_time:14989ms step_avg:34.30ms +step:438/1555 train_time:15027ms step_avg:34.31ms +step:439/1555 train_time:15058ms 
step_avg:34.30ms +step:440/1555 train_time:15095ms step_avg:34.31ms +step:441/1555 train_time:15125ms step_avg:34.30ms +step:442/1555 train_time:15163ms step_avg:34.30ms +step:443/1555 train_time:15193ms step_avg:34.30ms +step:444/1555 train_time:15231ms step_avg:34.30ms +step:445/1555 train_time:15262ms step_avg:34.30ms +step:446/1555 train_time:15299ms step_avg:34.30ms +step:447/1555 train_time:15330ms step_avg:34.29ms +step:448/1555 train_time:15367ms step_avg:34.30ms +step:449/1555 train_time:15398ms step_avg:34.29ms +step:450/1555 train_time:15436ms step_avg:34.30ms +step:451/1555 train_time:15467ms step_avg:34.29ms +step:452/1555 train_time:15505ms step_avg:34.30ms +step:453/1555 train_time:15535ms step_avg:34.29ms +step:454/1555 train_time:15574ms step_avg:34.30ms +step:455/1555 train_time:15604ms step_avg:34.30ms +step:456/1555 train_time:15642ms step_avg:34.30ms +step:457/1555 train_time:15672ms step_avg:34.29ms +step:458/1555 train_time:15710ms step_avg:34.30ms +step:459/1555 train_time:15741ms step_avg:34.29ms +step:460/1555 train_time:15779ms step_avg:34.30ms +step:461/1555 train_time:15809ms step_avg:34.29ms +step:462/1555 train_time:15847ms step_avg:34.30ms +step:463/1555 train_time:15878ms step_avg:34.29ms +step:464/1555 train_time:15916ms step_avg:34.30ms +step:465/1555 train_time:15947ms step_avg:34.29ms +step:466/1555 train_time:15984ms step_avg:34.30ms +step:467/1555 train_time:16016ms step_avg:34.30ms +step:468/1555 train_time:16054ms step_avg:34.30ms +step:469/1555 train_time:16085ms step_avg:34.30ms +step:470/1555 train_time:16122ms step_avg:34.30ms +step:471/1555 train_time:16154ms step_avg:34.30ms +step:472/1555 train_time:16192ms step_avg:34.30ms +step:473/1555 train_time:16222ms step_avg:34.30ms +step:474/1555 train_time:16260ms step_avg:34.30ms +step:475/1555 train_time:16291ms step_avg:34.30ms +step:476/1555 train_time:16328ms step_avg:34.30ms +step:477/1555 train_time:16359ms step_avg:34.29ms +step:478/1555 train_time:16396ms 
step_avg:34.30ms +step:479/1555 train_time:16427ms step_avg:34.29ms +step:480/1555 train_time:16464ms step_avg:34.30ms +step:481/1555 train_time:16495ms step_avg:34.29ms +step:482/1555 train_time:16534ms step_avg:34.30ms +step:483/1555 train_time:16564ms step_avg:34.29ms +step:484/1555 train_time:16601ms step_avg:34.30ms +step:485/1555 train_time:16632ms step_avg:34.29ms +step:486/1555 train_time:16670ms step_avg:34.30ms +step:487/1555 train_time:16701ms step_avg:34.29ms +step:488/1555 train_time:16739ms step_avg:34.30ms +step:489/1555 train_time:16769ms step_avg:34.29ms +step:490/1555 train_time:16807ms step_avg:34.30ms +step:491/1555 train_time:16837ms step_avg:34.29ms +step:492/1555 train_time:16875ms step_avg:34.30ms +step:493/1555 train_time:16905ms step_avg:34.29ms +step:494/1555 train_time:16943ms step_avg:34.30ms +step:495/1555 train_time:16974ms step_avg:34.29ms +step:496/1555 train_time:17012ms step_avg:34.30ms +step:497/1555 train_time:17042ms step_avg:34.29ms +step:498/1555 train_time:17080ms step_avg:34.30ms +step:499/1555 train_time:17111ms step_avg:34.29ms +step:500/1555 train_time:17148ms step_avg:34.30ms +step:500/1555 val_loss:4.2277 train_time:17197ms step_avg:34.39ms +step:501/1555 train_time:17218ms step_avg:34.37ms +step:502/1555 train_time:17238ms step_avg:34.34ms +step:503/1555 train_time:17257ms step_avg:34.31ms +step:504/1555 train_time:17287ms step_avg:34.30ms +step:505/1555 train_time:17319ms step_avg:34.30ms +step:506/1555 train_time:17364ms step_avg:34.32ms +step:507/1555 train_time:17416ms step_avg:34.35ms +step:508/1555 train_time:17480ms step_avg:34.41ms +step:509/1555 train_time:17538ms step_avg:34.46ms +step:510/1555 train_time:17603ms step_avg:34.52ms +step:511/1555 train_time:17661ms step_avg:34.56ms +step:512/1555 train_time:17724ms step_avg:34.62ms +step:513/1555 train_time:17781ms step_avg:34.66ms +step:514/1555 train_time:17845ms step_avg:34.72ms +step:515/1555 train_time:17902ms step_avg:34.76ms +step:516/1555 
train_time:17966ms step_avg:34.82ms +step:517/1555 train_time:18023ms step_avg:34.86ms +step:518/1555 train_time:18086ms step_avg:34.92ms +step:519/1555 train_time:18145ms step_avg:34.96ms +step:520/1555 train_time:18211ms step_avg:35.02ms +step:521/1555 train_time:18270ms step_avg:35.07ms +step:522/1555 train_time:18337ms step_avg:35.13ms +step:523/1555 train_time:18394ms step_avg:35.17ms +step:524/1555 train_time:18458ms step_avg:35.23ms +step:525/1555 train_time:18516ms step_avg:35.27ms +step:526/1555 train_time:18580ms step_avg:35.32ms +step:527/1555 train_time:18638ms step_avg:35.37ms +step:528/1555 train_time:18702ms step_avg:35.42ms +step:529/1555 train_time:18760ms step_avg:35.46ms +step:530/1555 train_time:18823ms step_avg:35.52ms +step:531/1555 train_time:18881ms step_avg:35.56ms +step:532/1555 train_time:18945ms step_avg:35.61ms +step:533/1555 train_time:19001ms step_avg:35.65ms +step:534/1555 train_time:19065ms step_avg:35.70ms +step:535/1555 train_time:19123ms step_avg:35.74ms +step:536/1555 train_time:19188ms step_avg:35.80ms +step:537/1555 train_time:19247ms step_avg:35.84ms +step:538/1555 train_time:19312ms step_avg:35.90ms +step:539/1555 train_time:19372ms step_avg:35.94ms +step:540/1555 train_time:19435ms step_avg:35.99ms +step:541/1555 train_time:19493ms step_avg:36.03ms +step:542/1555 train_time:19556ms step_avg:36.08ms +step:543/1555 train_time:19615ms step_avg:36.12ms +step:544/1555 train_time:19679ms step_avg:36.17ms +step:545/1555 train_time:19737ms step_avg:36.21ms +step:546/1555 train_time:19801ms step_avg:36.27ms +step:547/1555 train_time:19858ms step_avg:36.30ms +step:548/1555 train_time:19922ms step_avg:36.35ms +step:549/1555 train_time:19979ms step_avg:36.39ms +step:550/1555 train_time:20043ms step_avg:36.44ms +step:551/1555 train_time:20100ms step_avg:36.48ms +step:552/1555 train_time:20165ms step_avg:36.53ms +step:553/1555 train_time:20223ms step_avg:36.57ms +step:554/1555 train_time:20288ms step_avg:36.62ms +step:555/1555 
train_time:20346ms step_avg:36.66ms +step:556/1555 train_time:20412ms step_avg:36.71ms +step:557/1555 train_time:20471ms step_avg:36.75ms +step:558/1555 train_time:20534ms step_avg:36.80ms +step:559/1555 train_time:20591ms step_avg:36.84ms +step:560/1555 train_time:20655ms step_avg:36.88ms +step:561/1555 train_time:20712ms step_avg:36.92ms +step:562/1555 train_time:20776ms step_avg:36.97ms +step:563/1555 train_time:20834ms step_avg:37.01ms +step:564/1555 train_time:20898ms step_avg:37.05ms +step:565/1555 train_time:20956ms step_avg:37.09ms +step:566/1555 train_time:21020ms step_avg:37.14ms +step:567/1555 train_time:21078ms step_avg:37.17ms +step:568/1555 train_time:21142ms step_avg:37.22ms +step:569/1555 train_time:21200ms step_avg:37.26ms +step:570/1555 train_time:21265ms step_avg:37.31ms +step:571/1555 train_time:21322ms step_avg:37.34ms +step:572/1555 train_time:21387ms step_avg:37.39ms +step:573/1555 train_time:21445ms step_avg:37.43ms +step:574/1555 train_time:21510ms step_avg:37.47ms +step:575/1555 train_time:21568ms step_avg:37.51ms +step:576/1555 train_time:21632ms step_avg:37.56ms +step:577/1555 train_time:21689ms step_avg:37.59ms +step:578/1555 train_time:21753ms step_avg:37.63ms +step:579/1555 train_time:21810ms step_avg:37.67ms +step:580/1555 train_time:21876ms step_avg:37.72ms +step:581/1555 train_time:21932ms step_avg:37.75ms +step:582/1555 train_time:21997ms step_avg:37.79ms +step:583/1555 train_time:22055ms step_avg:37.83ms +step:584/1555 train_time:22119ms step_avg:37.87ms +step:585/1555 train_time:22176ms step_avg:37.91ms +step:586/1555 train_time:22241ms step_avg:37.95ms +step:587/1555 train_time:22299ms step_avg:37.99ms +step:588/1555 train_time:22365ms step_avg:38.03ms +step:589/1555 train_time:22422ms step_avg:38.07ms +step:590/1555 train_time:22486ms step_avg:38.11ms +step:591/1555 train_time:22544ms step_avg:38.15ms +step:592/1555 train_time:22609ms step_avg:38.19ms +step:593/1555 train_time:22666ms step_avg:38.22ms +step:594/1555 
train_time:22730ms step_avg:38.27ms +step:595/1555 train_time:22788ms step_avg:38.30ms +step:596/1555 train_time:22852ms step_avg:38.34ms +step:597/1555 train_time:22910ms step_avg:38.37ms +step:598/1555 train_time:22974ms step_avg:38.42ms +step:599/1555 train_time:23031ms step_avg:38.45ms +step:600/1555 train_time:23094ms step_avg:38.49ms +step:601/1555 train_time:23153ms step_avg:38.52ms +step:602/1555 train_time:23216ms step_avg:38.57ms +step:603/1555 train_time:23275ms step_avg:38.60ms +step:604/1555 train_time:23340ms step_avg:38.64ms +step:605/1555 train_time:23397ms step_avg:38.67ms +step:606/1555 train_time:23462ms step_avg:38.72ms +step:607/1555 train_time:23519ms step_avg:38.75ms +step:608/1555 train_time:23584ms step_avg:38.79ms +step:609/1555 train_time:23642ms step_avg:38.82ms +step:610/1555 train_time:23707ms step_avg:38.86ms +step:611/1555 train_time:23764ms step_avg:38.89ms +step:612/1555 train_time:23829ms step_avg:38.94ms +step:613/1555 train_time:23887ms step_avg:38.97ms +step:614/1555 train_time:23951ms step_avg:39.01ms +step:615/1555 train_time:24008ms step_avg:39.04ms +step:616/1555 train_time:24074ms step_avg:39.08ms +step:617/1555 train_time:24131ms step_avg:39.11ms +step:618/1555 train_time:24195ms step_avg:39.15ms +step:619/1555 train_time:24253ms step_avg:39.18ms +step:620/1555 train_time:24317ms step_avg:39.22ms +step:621/1555 train_time:24374ms step_avg:39.25ms +step:622/1555 train_time:24438ms step_avg:39.29ms +step:623/1555 train_time:24496ms step_avg:39.32ms +step:624/1555 train_time:24561ms step_avg:39.36ms +step:625/1555 train_time:24619ms step_avg:39.39ms +step:626/1555 train_time:24684ms step_avg:39.43ms +step:627/1555 train_time:24742ms step_avg:39.46ms +step:628/1555 train_time:24806ms step_avg:39.50ms +step:629/1555 train_time:24864ms step_avg:39.53ms +step:630/1555 train_time:24929ms step_avg:39.57ms +step:631/1555 train_time:24987ms step_avg:39.60ms +step:632/1555 train_time:25051ms step_avg:39.64ms +step:633/1555 
train_time:25109ms step_avg:39.67ms +step:634/1555 train_time:25173ms step_avg:39.70ms +step:635/1555 train_time:25230ms step_avg:39.73ms +step:636/1555 train_time:25293ms step_avg:39.77ms +step:637/1555 train_time:25351ms step_avg:39.80ms +step:638/1555 train_time:25415ms step_avg:39.84ms +step:639/1555 train_time:25472ms step_avg:39.86ms +step:640/1555 train_time:25537ms step_avg:39.90ms +step:641/1555 train_time:25596ms step_avg:39.93ms +step:642/1555 train_time:25661ms step_avg:39.97ms +step:643/1555 train_time:25719ms step_avg:40.00ms +step:644/1555 train_time:25783ms step_avg:40.04ms +step:645/1555 train_time:25842ms step_avg:40.06ms +step:646/1555 train_time:25907ms step_avg:40.10ms +step:647/1555 train_time:25965ms step_avg:40.13ms +step:648/1555 train_time:26029ms step_avg:40.17ms +step:649/1555 train_time:26087ms step_avg:40.20ms +step:650/1555 train_time:26151ms step_avg:40.23ms +step:651/1555 train_time:26210ms step_avg:40.26ms +step:652/1555 train_time:26274ms step_avg:40.30ms +step:653/1555 train_time:26331ms step_avg:40.32ms +step:654/1555 train_time:26394ms step_avg:40.36ms +step:655/1555 train_time:26452ms step_avg:40.39ms +step:656/1555 train_time:26516ms step_avg:40.42ms +step:657/1555 train_time:26572ms step_avg:40.45ms +step:658/1555 train_time:26637ms step_avg:40.48ms +step:659/1555 train_time:26694ms step_avg:40.51ms +step:660/1555 train_time:26760ms step_avg:40.54ms +step:661/1555 train_time:26817ms step_avg:40.57ms +step:662/1555 train_time:26882ms step_avg:40.61ms +step:663/1555 train_time:26940ms step_avg:40.63ms +step:664/1555 train_time:27005ms step_avg:40.67ms +step:665/1555 train_time:27063ms step_avg:40.70ms +step:666/1555 train_time:27128ms step_avg:40.73ms +step:667/1555 train_time:27186ms step_avg:40.76ms +step:668/1555 train_time:27250ms step_avg:40.79ms +step:669/1555 train_time:27308ms step_avg:40.82ms +step:670/1555 train_time:27372ms step_avg:40.85ms +step:671/1555 train_time:27431ms step_avg:40.88ms +step:672/1555 
train_time:27494ms step_avg:40.91ms +step:673/1555 train_time:27551ms step_avg:40.94ms +step:674/1555 train_time:27615ms step_avg:40.97ms +step:675/1555 train_time:27674ms step_avg:41.00ms +step:676/1555 train_time:27738ms step_avg:41.03ms +step:677/1555 train_time:27796ms step_avg:41.06ms +step:678/1555 train_time:27861ms step_avg:41.09ms +step:679/1555 train_time:27918ms step_avg:41.12ms +step:680/1555 train_time:27983ms step_avg:41.15ms +step:681/1555 train_time:28042ms step_avg:41.18ms +step:682/1555 train_time:28106ms step_avg:41.21ms +step:683/1555 train_time:28163ms step_avg:41.23ms +step:684/1555 train_time:28228ms step_avg:41.27ms +step:685/1555 train_time:28286ms step_avg:41.29ms +step:686/1555 train_time:28351ms step_avg:41.33ms +step:687/1555 train_time:28408ms step_avg:41.35ms +step:688/1555 train_time:28473ms step_avg:41.38ms +step:689/1555 train_time:28530ms step_avg:41.41ms +step:690/1555 train_time:28594ms step_avg:41.44ms +step:691/1555 train_time:28652ms step_avg:41.46ms +step:692/1555 train_time:28716ms step_avg:41.50ms +step:693/1555 train_time:28774ms step_avg:41.52ms +step:694/1555 train_time:28838ms step_avg:41.55ms +step:695/1555 train_time:28895ms step_avg:41.58ms +step:696/1555 train_time:28960ms step_avg:41.61ms +step:697/1555 train_time:29018ms step_avg:41.63ms +step:698/1555 train_time:29083ms step_avg:41.67ms +step:699/1555 train_time:29141ms step_avg:41.69ms +step:700/1555 train_time:29205ms step_avg:41.72ms +step:701/1555 train_time:29263ms step_avg:41.75ms +step:702/1555 train_time:29327ms step_avg:41.78ms +step:703/1555 train_time:29385ms step_avg:41.80ms +step:704/1555 train_time:29450ms step_avg:41.83ms +step:705/1555 train_time:29507ms step_avg:41.85ms +step:706/1555 train_time:29571ms step_avg:41.89ms +step:707/1555 train_time:29629ms step_avg:41.91ms +step:708/1555 train_time:29694ms step_avg:41.94ms +step:709/1555 train_time:29752ms step_avg:41.96ms +step:710/1555 train_time:29816ms step_avg:41.99ms +step:711/1555 
train_time:29873ms step_avg:42.02ms +step:712/1555 train_time:29937ms step_avg:42.05ms +step:713/1555 train_time:29994ms step_avg:42.07ms +step:714/1555 train_time:30058ms step_avg:42.10ms +step:715/1555 train_time:30116ms step_avg:42.12ms +step:716/1555 train_time:30182ms step_avg:42.15ms +step:717/1555 train_time:30241ms step_avg:42.18ms +step:718/1555 train_time:30305ms step_avg:42.21ms +step:719/1555 train_time:30364ms step_avg:42.23ms +step:720/1555 train_time:30428ms step_avg:42.26ms +step:721/1555 train_time:30485ms step_avg:42.28ms +step:722/1555 train_time:30549ms step_avg:42.31ms +step:723/1555 train_time:30607ms step_avg:42.33ms +step:724/1555 train_time:30671ms step_avg:42.36ms +step:725/1555 train_time:30729ms step_avg:42.38ms +step:726/1555 train_time:30792ms step_avg:42.41ms +step:727/1555 train_time:30850ms step_avg:42.43ms +step:728/1555 train_time:30913ms step_avg:42.46ms +step:729/1555 train_time:30971ms step_avg:42.48ms +step:730/1555 train_time:31035ms step_avg:42.51ms +step:731/1555 train_time:31093ms step_avg:42.54ms +step:732/1555 train_time:31158ms step_avg:42.57ms +step:733/1555 train_time:31216ms step_avg:42.59ms +step:734/1555 train_time:31281ms step_avg:42.62ms +step:735/1555 train_time:31339ms step_avg:42.64ms +step:736/1555 train_time:31404ms step_avg:42.67ms +step:737/1555 train_time:31462ms step_avg:42.69ms +step:738/1555 train_time:31526ms step_avg:42.72ms +step:739/1555 train_time:31584ms step_avg:42.74ms +step:740/1555 train_time:31648ms step_avg:42.77ms +step:741/1555 train_time:31705ms step_avg:42.79ms +step:742/1555 train_time:31770ms step_avg:42.82ms +step:743/1555 train_time:31829ms step_avg:42.84ms +step:744/1555 train_time:31892ms step_avg:42.87ms +step:745/1555 train_time:31951ms step_avg:42.89ms +step:746/1555 train_time:32014ms step_avg:42.91ms +step:747/1555 train_time:32072ms step_avg:42.93ms +step:748/1555 train_time:32137ms step_avg:42.96ms +step:749/1555 train_time:32195ms step_avg:42.98ms +step:750/1555 
train_time:32258ms step_avg:43.01ms +step:750/1555 val_loss:3.8724 train_time:32342ms step_avg:43.12ms +step:751/1555 train_time:32363ms step_avg:43.09ms +step:752/1555 train_time:32392ms step_avg:43.07ms +step:753/1555 train_time:32443ms step_avg:43.09ms +step:754/1555 train_time:32512ms step_avg:43.12ms +step:755/1555 train_time:32569ms step_avg:43.14ms +step:756/1555 train_time:32633ms step_avg:43.17ms +step:757/1555 train_time:32691ms step_avg:43.19ms +step:758/1555 train_time:32754ms step_avg:43.21ms +step:759/1555 train_time:32811ms step_avg:43.23ms +step:760/1555 train_time:32874ms step_avg:43.25ms +step:761/1555 train_time:32930ms step_avg:43.27ms +step:762/1555 train_time:32994ms step_avg:43.30ms +step:763/1555 train_time:33050ms step_avg:43.32ms +step:764/1555 train_time:33113ms step_avg:43.34ms +step:765/1555 train_time:33170ms step_avg:43.36ms +step:766/1555 train_time:33233ms step_avg:43.39ms +step:767/1555 train_time:33291ms step_avg:43.40ms +step:768/1555 train_time:33356ms step_avg:43.43ms +step:769/1555 train_time:33416ms step_avg:43.45ms +step:770/1555 train_time:33482ms step_avg:43.48ms +step:771/1555 train_time:33540ms step_avg:43.50ms +step:772/1555 train_time:33605ms step_avg:43.53ms +step:773/1555 train_time:33664ms step_avg:43.55ms +step:774/1555 train_time:33728ms step_avg:43.58ms +step:775/1555 train_time:33785ms step_avg:43.59ms +step:776/1555 train_time:33849ms step_avg:43.62ms +step:777/1555 train_time:33906ms step_avg:43.64ms +step:778/1555 train_time:33970ms step_avg:43.66ms +step:779/1555 train_time:34027ms step_avg:43.68ms +step:780/1555 train_time:34090ms step_avg:43.71ms +step:781/1555 train_time:34147ms step_avg:43.72ms +step:782/1555 train_time:34211ms step_avg:43.75ms +step:783/1555 train_time:34269ms step_avg:43.77ms +step:784/1555 train_time:34333ms step_avg:43.79ms +step:785/1555 train_time:34393ms step_avg:43.81ms +step:786/1555 train_time:34457ms step_avg:43.84ms +step:787/1555 train_time:34515ms step_avg:43.86ms 
+step:788/1555 train_time:34580ms step_avg:43.88ms +step:789/1555 train_time:34637ms step_avg:43.90ms +step:790/1555 train_time:34702ms step_avg:43.93ms +step:791/1555 train_time:34759ms step_avg:43.94ms +step:792/1555 train_time:34824ms step_avg:43.97ms +step:793/1555 train_time:34882ms step_avg:43.99ms +step:794/1555 train_time:34945ms step_avg:44.01ms +step:795/1555 train_time:35003ms step_avg:44.03ms +step:796/1555 train_time:35066ms step_avg:44.05ms +step:797/1555 train_time:35124ms step_avg:44.07ms +step:798/1555 train_time:35189ms step_avg:44.10ms +step:799/1555 train_time:35247ms step_avg:44.11ms +step:800/1555 train_time:35311ms step_avg:44.14ms +step:801/1555 train_time:35371ms step_avg:44.16ms +step:802/1555 train_time:35434ms step_avg:44.18ms +step:803/1555 train_time:35494ms step_avg:44.20ms +step:804/1555 train_time:35557ms step_avg:44.22ms +step:805/1555 train_time:35614ms step_avg:44.24ms +step:806/1555 train_time:35678ms step_avg:44.27ms +step:807/1555 train_time:35735ms step_avg:44.28ms +step:808/1555 train_time:35800ms step_avg:44.31ms +step:809/1555 train_time:35857ms step_avg:44.32ms +step:810/1555 train_time:35921ms step_avg:44.35ms +step:811/1555 train_time:35978ms step_avg:44.36ms +step:812/1555 train_time:36043ms step_avg:44.39ms +step:813/1555 train_time:36100ms step_avg:44.40ms +step:814/1555 train_time:36166ms step_avg:44.43ms +step:815/1555 train_time:36224ms step_avg:44.45ms +step:816/1555 train_time:36289ms step_avg:44.47ms +step:817/1555 train_time:36347ms step_avg:44.49ms +step:818/1555 train_time:36411ms step_avg:44.51ms +step:819/1555 train_time:36469ms step_avg:44.53ms +step:820/1555 train_time:36534ms step_avg:44.55ms +step:821/1555 train_time:36593ms step_avg:44.57ms +step:822/1555 train_time:36656ms step_avg:44.59ms +step:823/1555 train_time:36713ms step_avg:44.61ms +step:824/1555 train_time:36776ms step_avg:44.63ms +step:825/1555 train_time:36834ms step_avg:44.65ms +step:826/1555 train_time:36898ms step_avg:44.67ms 
+step:827/1555 train_time:36955ms step_avg:44.69ms +step:828/1555 train_time:37020ms step_avg:44.71ms +step:829/1555 train_time:37077ms step_avg:44.73ms +step:830/1555 train_time:37142ms step_avg:44.75ms +step:831/1555 train_time:37200ms step_avg:44.77ms +step:832/1555 train_time:37264ms step_avg:44.79ms +step:833/1555 train_time:37323ms step_avg:44.81ms +step:834/1555 train_time:37388ms step_avg:44.83ms +step:835/1555 train_time:37445ms step_avg:44.84ms +step:836/1555 train_time:37511ms step_avg:44.87ms +step:837/1555 train_time:37569ms step_avg:44.89ms +step:838/1555 train_time:37634ms step_avg:44.91ms +step:839/1555 train_time:37691ms step_avg:44.92ms +step:840/1555 train_time:37754ms step_avg:44.95ms +step:841/1555 train_time:37812ms step_avg:44.96ms +step:842/1555 train_time:37875ms step_avg:44.98ms +step:843/1555 train_time:37933ms step_avg:45.00ms +step:844/1555 train_time:37997ms step_avg:45.02ms +step:845/1555 train_time:38054ms step_avg:45.03ms +step:846/1555 train_time:38119ms step_avg:45.06ms +step:847/1555 train_time:38175ms step_avg:45.07ms +step:848/1555 train_time:38240ms step_avg:45.09ms +step:849/1555 train_time:38298ms step_avg:45.11ms +step:850/1555 train_time:38363ms step_avg:45.13ms +step:851/1555 train_time:38421ms step_avg:45.15ms +step:852/1555 train_time:38486ms step_avg:45.17ms +step:853/1555 train_time:38544ms step_avg:45.19ms +step:854/1555 train_time:38609ms step_avg:45.21ms +step:855/1555 train_time:38667ms step_avg:45.22ms +step:856/1555 train_time:38731ms step_avg:45.25ms +step:857/1555 train_time:38788ms step_avg:45.26ms +step:858/1555 train_time:38852ms step_avg:45.28ms +step:859/1555 train_time:38911ms step_avg:45.30ms +step:860/1555 train_time:38974ms step_avg:45.32ms +step:861/1555 train_time:39032ms step_avg:45.33ms +step:862/1555 train_time:39096ms step_avg:45.35ms +step:863/1555 train_time:39153ms step_avg:45.37ms +step:864/1555 train_time:39217ms step_avg:45.39ms +step:865/1555 train_time:39274ms step_avg:45.40ms 
+step:866/1555 train_time:39338ms step_avg:45.42ms +step:867/1555 train_time:39396ms step_avg:45.44ms +step:868/1555 train_time:39460ms step_avg:45.46ms +step:869/1555 train_time:39518ms step_avg:45.48ms +step:870/1555 train_time:39583ms step_avg:45.50ms +step:871/1555 train_time:39641ms step_avg:45.51ms +step:872/1555 train_time:39706ms step_avg:45.53ms +step:873/1555 train_time:39764ms step_avg:45.55ms +step:874/1555 train_time:39828ms step_avg:45.57ms +step:875/1555 train_time:39886ms step_avg:45.58ms +step:876/1555 train_time:39951ms step_avg:45.61ms +step:877/1555 train_time:40008ms step_avg:45.62ms +step:878/1555 train_time:40073ms step_avg:45.64ms +step:879/1555 train_time:40131ms step_avg:45.65ms +step:880/1555 train_time:40195ms step_avg:45.68ms +step:881/1555 train_time:40252ms step_avg:45.69ms +step:882/1555 train_time:40316ms step_avg:45.71ms +step:883/1555 train_time:40374ms step_avg:45.72ms +step:884/1555 train_time:40438ms step_avg:45.74ms +step:885/1555 train_time:40496ms step_avg:45.76ms +step:886/1555 train_time:40560ms step_avg:45.78ms +step:887/1555 train_time:40618ms step_avg:45.79ms +step:888/1555 train_time:40683ms step_avg:45.81ms +step:889/1555 train_time:40740ms step_avg:45.83ms +step:890/1555 train_time:40805ms step_avg:45.85ms +step:891/1555 train_time:40863ms step_avg:45.86ms +step:892/1555 train_time:40927ms step_avg:45.88ms +step:893/1555 train_time:40985ms step_avg:45.90ms +step:894/1555 train_time:41050ms step_avg:45.92ms +step:895/1555 train_time:41107ms step_avg:45.93ms +step:896/1555 train_time:41172ms step_avg:45.95ms +step:897/1555 train_time:41231ms step_avg:45.97ms +step:898/1555 train_time:41295ms step_avg:45.99ms +step:899/1555 train_time:41353ms step_avg:46.00ms +step:900/1555 train_time:41416ms step_avg:46.02ms +step:901/1555 train_time:41475ms step_avg:46.03ms +step:902/1555 train_time:41537ms step_avg:46.05ms +step:903/1555 train_time:41595ms step_avg:46.06ms +step:904/1555 train_time:41659ms step_avg:46.08ms 
+step:905/1555 train_time:41716ms step_avg:46.10ms +step:906/1555 train_time:41781ms step_avg:46.12ms +step:907/1555 train_time:41839ms step_avg:46.13ms +step:908/1555 train_time:41903ms step_avg:46.15ms +step:909/1555 train_time:41961ms step_avg:46.16ms +step:910/1555 train_time:42026ms step_avg:46.18ms +step:911/1555 train_time:42084ms step_avg:46.20ms +step:912/1555 train_time:42148ms step_avg:46.22ms +step:913/1555 train_time:42207ms step_avg:46.23ms +step:914/1555 train_time:42271ms step_avg:46.25ms +step:915/1555 train_time:42329ms step_avg:46.26ms +step:916/1555 train_time:42394ms step_avg:46.28ms +step:917/1555 train_time:42451ms step_avg:46.29ms +step:918/1555 train_time:42514ms step_avg:46.31ms +step:919/1555 train_time:42572ms step_avg:46.32ms +step:920/1555 train_time:42636ms step_avg:46.34ms +step:921/1555 train_time:42694ms step_avg:46.36ms +step:922/1555 train_time:42758ms step_avg:46.37ms +step:923/1555 train_time:42816ms step_avg:46.39ms +step:924/1555 train_time:42880ms step_avg:46.41ms +step:925/1555 train_time:42937ms step_avg:46.42ms +step:926/1555 train_time:43002ms step_avg:46.44ms +step:927/1555 train_time:43060ms step_avg:46.45ms +step:928/1555 train_time:43125ms step_avg:46.47ms +step:929/1555 train_time:43183ms step_avg:46.48ms +step:930/1555 train_time:43248ms step_avg:46.50ms +step:931/1555 train_time:43306ms step_avg:46.52ms +step:932/1555 train_time:43370ms step_avg:46.53ms +step:933/1555 train_time:43429ms step_avg:46.55ms +step:934/1555 train_time:43494ms step_avg:46.57ms +step:935/1555 train_time:43551ms step_avg:46.58ms +step:936/1555 train_time:43614ms step_avg:46.60ms +step:937/1555 train_time:43672ms step_avg:46.61ms +step:938/1555 train_time:43736ms step_avg:46.63ms +step:939/1555 train_time:43794ms step_avg:46.64ms +step:940/1555 train_time:43857ms step_avg:46.66ms +step:941/1555 train_time:43916ms step_avg:46.67ms +step:942/1555 train_time:43980ms step_avg:46.69ms +step:943/1555 train_time:44038ms step_avg:46.70ms 
+step:944/1555 train_time:44103ms step_avg:46.72ms +step:945/1555 train_time:44161ms step_avg:46.73ms +step:946/1555 train_time:44225ms step_avg:46.75ms +step:947/1555 train_time:44284ms step_avg:46.76ms +step:948/1555 train_time:44349ms step_avg:46.78ms +step:949/1555 train_time:44407ms step_avg:46.79ms +step:950/1555 train_time:44470ms step_avg:46.81ms +step:951/1555 train_time:44528ms step_avg:46.82ms +step:952/1555 train_time:44592ms step_avg:46.84ms +step:953/1555 train_time:44650ms step_avg:46.85ms +step:954/1555 train_time:44714ms step_avg:46.87ms +step:955/1555 train_time:44771ms step_avg:46.88ms +step:956/1555 train_time:44835ms step_avg:46.90ms +step:957/1555 train_time:44894ms step_avg:46.91ms +step:958/1555 train_time:44957ms step_avg:46.93ms +step:959/1555 train_time:45015ms step_avg:46.94ms +step:960/1555 train_time:45080ms step_avg:46.96ms +step:961/1555 train_time:45138ms step_avg:46.97ms +step:962/1555 train_time:45203ms step_avg:46.99ms +step:963/1555 train_time:45261ms step_avg:47.00ms +step:964/1555 train_time:45326ms step_avg:47.02ms +step:965/1555 train_time:45384ms step_avg:47.03ms +step:966/1555 train_time:45448ms step_avg:47.05ms +step:967/1555 train_time:45506ms step_avg:47.06ms +step:968/1555 train_time:45571ms step_avg:47.08ms +step:969/1555 train_time:45629ms step_avg:47.09ms +step:970/1555 train_time:45694ms step_avg:47.11ms +step:971/1555 train_time:45751ms step_avg:47.12ms +step:972/1555 train_time:45814ms step_avg:47.13ms +step:973/1555 train_time:45873ms step_avg:47.15ms +step:974/1555 train_time:45937ms step_avg:47.16ms +step:975/1555 train_time:45994ms step_avg:47.17ms +step:976/1555 train_time:46058ms step_avg:47.19ms +step:977/1555 train_time:46116ms step_avg:47.20ms +step:978/1555 train_time:46180ms step_avg:47.22ms +step:979/1555 train_time:46237ms step_avg:47.23ms +step:980/1555 train_time:46303ms step_avg:47.25ms +step:981/1555 train_time:46360ms step_avg:47.26ms +step:982/1555 train_time:46424ms step_avg:47.28ms 
+step:983/1555 train_time:46483ms step_avg:47.29ms +step:984/1555 train_time:46548ms step_avg:47.30ms +step:985/1555 train_time:46606ms step_avg:47.32ms +step:986/1555 train_time:46670ms step_avg:47.33ms +step:987/1555 train_time:46728ms step_avg:47.34ms +step:988/1555 train_time:46792ms step_avg:47.36ms +step:989/1555 train_time:46850ms step_avg:47.37ms +step:990/1555 train_time:46914ms step_avg:47.39ms +step:991/1555 train_time:46973ms step_avg:47.40ms +step:992/1555 train_time:47036ms step_avg:47.42ms +step:993/1555 train_time:47094ms step_avg:47.43ms +step:994/1555 train_time:47157ms step_avg:47.44ms +step:995/1555 train_time:47215ms step_avg:47.45ms +step:996/1555 train_time:47279ms step_avg:47.47ms +step:997/1555 train_time:47336ms step_avg:47.48ms +step:998/1555 train_time:47402ms step_avg:47.50ms +step:999/1555 train_time:47460ms step_avg:47.51ms +step:1000/1555 train_time:47525ms step_avg:47.52ms +step:1000/1555 val_loss:3.5699 train_time:47608ms step_avg:47.61ms +step:1001/1555 train_time:47628ms step_avg:47.58ms +step:1002/1555 train_time:47650ms step_avg:47.55ms +step:1003/1555 train_time:47705ms step_avg:47.56ms +step:1004/1555 train_time:47775ms step_avg:47.58ms +step:1005/1555 train_time:47834ms step_avg:47.60ms +step:1006/1555 train_time:47899ms step_avg:47.61ms +step:1007/1555 train_time:47956ms step_avg:47.62ms +step:1008/1555 train_time:48020ms step_avg:47.64ms +step:1009/1555 train_time:48077ms step_avg:47.65ms +step:1010/1555 train_time:48141ms step_avg:47.66ms +step:1011/1555 train_time:48201ms step_avg:47.68ms +step:1012/1555 train_time:48285ms step_avg:47.71ms +step:1013/1555 train_time:48368ms step_avg:47.75ms +step:1014/1555 train_time:48458ms step_avg:47.79ms +step:1015/1555 train_time:48541ms step_avg:47.82ms +step:1016/1555 train_time:48632ms step_avg:47.87ms +step:1017/1555 train_time:48718ms step_avg:47.90ms +step:1018/1555 train_time:48811ms step_avg:47.95ms +step:1019/1555 train_time:48898ms step_avg:47.99ms +step:1020/1555 
train_time:48988ms step_avg:48.03ms +step:1021/1555 train_time:49072ms step_avg:48.06ms +step:1022/1555 train_time:49162ms step_avg:48.10ms +step:1023/1555 train_time:49244ms step_avg:48.14ms +step:1024/1555 train_time:49334ms step_avg:48.18ms +step:1025/1555 train_time:49416ms step_avg:48.21ms +step:1026/1555 train_time:49505ms step_avg:48.25ms +step:1027/1555 train_time:49590ms step_avg:48.29ms +step:1028/1555 train_time:49681ms step_avg:48.33ms +step:1029/1555 train_time:49767ms step_avg:48.36ms +step:1030/1555 train_time:49860ms step_avg:48.41ms +step:1031/1555 train_time:49944ms step_avg:48.44ms +step:1032/1555 train_time:50034ms step_avg:48.48ms +step:1033/1555 train_time:50118ms step_avg:48.52ms +step:1034/1555 train_time:50207ms step_avg:48.56ms +step:1035/1555 train_time:50291ms step_avg:48.59ms +step:1036/1555 train_time:50380ms step_avg:48.63ms +step:1037/1555 train_time:50463ms step_avg:48.66ms +step:1038/1555 train_time:50553ms step_avg:48.70ms +step:1039/1555 train_time:50638ms step_avg:48.74ms +step:1040/1555 train_time:50727ms step_avg:48.78ms +step:1041/1555 train_time:50812ms step_avg:48.81ms +step:1042/1555 train_time:50904ms step_avg:48.85ms +step:1043/1555 train_time:50987ms step_avg:48.89ms +step:1044/1555 train_time:51079ms step_avg:48.93ms +step:1045/1555 train_time:51162ms step_avg:48.96ms +step:1046/1555 train_time:51252ms step_avg:49.00ms +step:1047/1555 train_time:51336ms step_avg:49.03ms +step:1048/1555 train_time:51424ms step_avg:49.07ms +step:1049/1555 train_time:51508ms step_avg:49.10ms +step:1050/1555 train_time:51599ms step_avg:49.14ms +step:1051/1555 train_time:51682ms step_avg:49.17ms +step:1052/1555 train_time:51773ms step_avg:49.21ms +step:1053/1555 train_time:51858ms step_avg:49.25ms +step:1054/1555 train_time:51947ms step_avg:49.29ms +step:1055/1555 train_time:52032ms step_avg:49.32ms +step:1056/1555 train_time:52122ms step_avg:49.36ms +step:1057/1555 train_time:52205ms step_avg:49.39ms +step:1058/1555 train_time:52296ms 
step_avg:49.43ms +step:1059/1555 train_time:52379ms step_avg:49.46ms +step:1060/1555 train_time:52469ms step_avg:49.50ms +step:1061/1555 train_time:52554ms step_avg:49.53ms +step:1062/1555 train_time:52643ms step_avg:49.57ms +step:1063/1555 train_time:52728ms step_avg:49.60ms +step:1064/1555 train_time:52818ms step_avg:49.64ms +step:1065/1555 train_time:52902ms step_avg:49.67ms +step:1066/1555 train_time:52993ms step_avg:49.71ms +step:1067/1555 train_time:53077ms step_avg:49.74ms +step:1068/1555 train_time:53168ms step_avg:49.78ms +step:1069/1555 train_time:53252ms step_avg:49.81ms +step:1070/1555 train_time:53342ms step_avg:49.85ms +step:1071/1555 train_time:53425ms step_avg:49.88ms +step:1072/1555 train_time:53516ms step_avg:49.92ms +step:1073/1555 train_time:53599ms step_avg:49.95ms +step:1074/1555 train_time:53689ms step_avg:49.99ms +step:1075/1555 train_time:53774ms step_avg:50.02ms +step:1076/1555 train_time:53864ms step_avg:50.06ms +step:1077/1555 train_time:53948ms step_avg:50.09ms +step:1078/1555 train_time:54040ms step_avg:50.13ms +step:1079/1555 train_time:54122ms step_avg:50.16ms +step:1080/1555 train_time:54212ms step_avg:50.20ms +step:1081/1555 train_time:54298ms step_avg:50.23ms +step:1082/1555 train_time:54386ms step_avg:50.26ms +step:1083/1555 train_time:54472ms step_avg:50.30ms +step:1084/1555 train_time:54562ms step_avg:50.33ms +step:1085/1555 train_time:54646ms step_avg:50.36ms +step:1086/1555 train_time:54737ms step_avg:50.40ms +step:1087/1555 train_time:54820ms step_avg:50.43ms +step:1088/1555 train_time:54910ms step_avg:50.47ms +step:1089/1555 train_time:54994ms step_avg:50.50ms +step:1090/1555 train_time:55083ms step_avg:50.54ms +step:1091/1555 train_time:55168ms step_avg:50.57ms +step:1092/1555 train_time:55258ms step_avg:50.60ms +step:1093/1555 train_time:55341ms step_avg:50.63ms +step:1094/1555 train_time:55431ms step_avg:50.67ms +step:1095/1555 train_time:55515ms step_avg:50.70ms +step:1096/1555 train_time:55604ms step_avg:50.73ms 
+step:1097/1555 train_time:55688ms step_avg:50.76ms +step:1098/1555 train_time:55778ms step_avg:50.80ms +step:1099/1555 train_time:55863ms step_avg:50.83ms +step:1100/1555 train_time:55953ms step_avg:50.87ms +step:1101/1555 train_time:56038ms step_avg:50.90ms +step:1102/1555 train_time:56127ms step_avg:50.93ms +step:1103/1555 train_time:56212ms step_avg:50.96ms +step:1104/1555 train_time:56301ms step_avg:51.00ms +step:1105/1555 train_time:56385ms step_avg:51.03ms +step:1106/1555 train_time:56475ms step_avg:51.06ms +step:1107/1555 train_time:56559ms step_avg:51.09ms +step:1108/1555 train_time:56649ms step_avg:51.13ms +step:1109/1555 train_time:56733ms step_avg:51.16ms +step:1110/1555 train_time:56823ms step_avg:51.19ms +step:1111/1555 train_time:56908ms step_avg:51.22ms +step:1112/1555 train_time:56999ms step_avg:51.26ms +step:1113/1555 train_time:57083ms step_avg:51.29ms +step:1114/1555 train_time:57173ms step_avg:51.32ms +step:1115/1555 train_time:57258ms step_avg:51.35ms +step:1116/1555 train_time:57346ms step_avg:51.38ms +step:1117/1555 train_time:57430ms step_avg:51.41ms +step:1118/1555 train_time:57521ms step_avg:51.45ms +step:1119/1555 train_time:57605ms step_avg:51.48ms +step:1120/1555 train_time:57695ms step_avg:51.51ms +step:1121/1555 train_time:57779ms step_avg:51.54ms +step:1122/1555 train_time:57868ms step_avg:51.58ms +step:1123/1555 train_time:57952ms step_avg:51.60ms +step:1124/1555 train_time:58042ms step_avg:51.64ms +step:1125/1555 train_time:58127ms step_avg:51.67ms +step:1126/1555 train_time:58217ms step_avg:51.70ms +step:1127/1555 train_time:58301ms step_avg:51.73ms +step:1128/1555 train_time:58390ms step_avg:51.76ms +step:1129/1555 train_time:58474ms step_avg:51.79ms +step:1130/1555 train_time:58564ms step_avg:51.83ms +step:1131/1555 train_time:58649ms step_avg:51.86ms +step:1132/1555 train_time:58739ms step_avg:51.89ms +step:1133/1555 train_time:58822ms step_avg:51.92ms +step:1134/1555 train_time:58913ms step_avg:51.95ms +step:1135/1555 
train_time:58997ms step_avg:51.98ms +step:1136/1555 train_time:59087ms step_avg:52.01ms +step:1137/1555 train_time:59171ms step_avg:52.04ms +step:1138/1555 train_time:59261ms step_avg:52.07ms +step:1139/1555 train_time:59345ms step_avg:52.10ms +step:1140/1555 train_time:59436ms step_avg:52.14ms +step:1141/1555 train_time:59519ms step_avg:52.16ms +step:1142/1555 train_time:59608ms step_avg:52.20ms +step:1143/1555 train_time:59693ms step_avg:52.22ms +step:1144/1555 train_time:59783ms step_avg:52.26ms +step:1145/1555 train_time:59868ms step_avg:52.29ms +step:1146/1555 train_time:59960ms step_avg:52.32ms +step:1147/1555 train_time:60043ms step_avg:52.35ms +step:1148/1555 train_time:60132ms step_avg:52.38ms +step:1149/1555 train_time:60216ms step_avg:52.41ms +step:1150/1555 train_time:60306ms step_avg:52.44ms +step:1151/1555 train_time:60391ms step_avg:52.47ms +step:1152/1555 train_time:60481ms step_avg:52.50ms +step:1153/1555 train_time:60565ms step_avg:52.53ms +step:1154/1555 train_time:60655ms step_avg:52.56ms +step:1155/1555 train_time:60739ms step_avg:52.59ms +step:1156/1555 train_time:60829ms step_avg:52.62ms +step:1157/1555 train_time:60913ms step_avg:52.65ms +step:1158/1555 train_time:61003ms step_avg:52.68ms +step:1159/1555 train_time:61088ms step_avg:52.71ms +step:1160/1555 train_time:61179ms step_avg:52.74ms +step:1161/1555 train_time:61262ms step_avg:52.77ms +step:1162/1555 train_time:61352ms step_avg:52.80ms +step:1163/1555 train_time:61438ms step_avg:52.83ms +step:1164/1555 train_time:61526ms step_avg:52.86ms +step:1165/1555 train_time:61611ms step_avg:52.88ms +step:1166/1555 train_time:61701ms step_avg:52.92ms +step:1167/1555 train_time:61785ms step_avg:52.94ms +step:1168/1555 train_time:61876ms step_avg:52.98ms +step:1169/1555 train_time:61960ms step_avg:53.00ms +step:1170/1555 train_time:62049ms step_avg:53.03ms +step:1171/1555 train_time:62134ms step_avg:53.06ms +step:1172/1555 train_time:62224ms step_avg:53.09ms +step:1173/1555 train_time:62309ms 
step_avg:53.12ms +step:1174/1555 train_time:62399ms step_avg:53.15ms +step:1175/1555 train_time:62483ms step_avg:53.18ms +step:1176/1555 train_time:62574ms step_avg:53.21ms +step:1177/1555 train_time:62658ms step_avg:53.24ms +step:1178/1555 train_time:62747ms step_avg:53.27ms +step:1179/1555 train_time:62831ms step_avg:53.29ms +step:1180/1555 train_time:62920ms step_avg:53.32ms +step:1181/1555 train_time:63004ms step_avg:53.35ms +step:1182/1555 train_time:63095ms step_avg:53.38ms +step:1183/1555 train_time:63178ms step_avg:53.40ms +step:1184/1555 train_time:63267ms step_avg:53.44ms +step:1185/1555 train_time:63352ms step_avg:53.46ms +step:1186/1555 train_time:63442ms step_avg:53.49ms +step:1187/1555 train_time:63526ms step_avg:53.52ms +step:1188/1555 train_time:63617ms step_avg:53.55ms +step:1189/1555 train_time:63700ms step_avg:53.57ms +step:1190/1555 train_time:63790ms step_avg:53.61ms +step:1191/1555 train_time:63874ms step_avg:53.63ms +step:1192/1555 train_time:63963ms step_avg:53.66ms +step:1193/1555 train_time:64047ms step_avg:53.69ms +step:1194/1555 train_time:64138ms step_avg:53.72ms +step:1195/1555 train_time:64221ms step_avg:53.74ms +step:1196/1555 train_time:64311ms step_avg:53.77ms +step:1197/1555 train_time:64395ms step_avg:53.80ms +step:1198/1555 train_time:64485ms step_avg:53.83ms +step:1199/1555 train_time:64570ms step_avg:53.85ms +step:1200/1555 train_time:64660ms step_avg:53.88ms +step:1201/1555 train_time:64744ms step_avg:53.91ms +step:1202/1555 train_time:64834ms step_avg:53.94ms +step:1203/1555 train_time:64919ms step_avg:53.96ms +step:1204/1555 train_time:65008ms step_avg:53.99ms +step:1205/1555 train_time:65092ms step_avg:54.02ms +step:1206/1555 train_time:65182ms step_avg:54.05ms +step:1207/1555 train_time:65265ms step_avg:54.07ms +step:1208/1555 train_time:65356ms step_avg:54.10ms +step:1209/1555 train_time:65440ms step_avg:54.13ms +step:1210/1555 train_time:65529ms step_avg:54.16ms +step:1211/1555 train_time:65614ms step_avg:54.18ms 
+step:1212/1555 train_time:65704ms step_avg:54.21ms +step:1213/1555 train_time:65789ms step_avg:54.24ms +step:1214/1555 train_time:65880ms step_avg:54.27ms +step:1215/1555 train_time:65964ms step_avg:54.29ms +step:1216/1555 train_time:66054ms step_avg:54.32ms +step:1217/1555 train_time:66138ms step_avg:54.34ms +step:1218/1555 train_time:66226ms step_avg:54.37ms +step:1219/1555 train_time:66311ms step_avg:54.40ms +step:1220/1555 train_time:66401ms step_avg:54.43ms +step:1221/1555 train_time:66484ms step_avg:54.45ms +step:1222/1555 train_time:66575ms step_avg:54.48ms +step:1223/1555 train_time:66660ms step_avg:54.51ms +step:1224/1555 train_time:66749ms step_avg:54.53ms +step:1225/1555 train_time:66834ms step_avg:54.56ms +step:1226/1555 train_time:66923ms step_avg:54.59ms +step:1227/1555 train_time:67008ms step_avg:54.61ms +step:1228/1555 train_time:67098ms step_avg:54.64ms +step:1229/1555 train_time:67182ms step_avg:54.66ms +step:1230/1555 train_time:67272ms step_avg:54.69ms +step:1231/1555 train_time:67356ms step_avg:54.72ms +step:1232/1555 train_time:67446ms step_avg:54.75ms +step:1233/1555 train_time:67531ms step_avg:54.77ms +step:1234/1555 train_time:67621ms step_avg:54.80ms +step:1235/1555 train_time:67704ms step_avg:54.82ms +step:1236/1555 train_time:67794ms step_avg:54.85ms +step:1237/1555 train_time:67879ms step_avg:54.87ms +step:1238/1555 train_time:67968ms step_avg:54.90ms +step:1239/1555 train_time:68053ms step_avg:54.93ms +step:1240/1555 train_time:68143ms step_avg:54.95ms +step:1241/1555 train_time:68227ms step_avg:54.98ms +step:1242/1555 train_time:68317ms step_avg:55.01ms +step:1243/1555 train_time:68400ms step_avg:55.03ms +step:1244/1555 train_time:68489ms step_avg:55.06ms +step:1245/1555 train_time:68574ms step_avg:55.08ms +step:1246/1555 train_time:68664ms step_avg:55.11ms +step:1247/1555 train_time:68749ms step_avg:55.13ms +step:1248/1555 train_time:68839ms step_avg:55.16ms +step:1249/1555 train_time:68922ms step_avg:55.18ms +step:1250/1555 
train_time:69013ms step_avg:55.21ms +step:1250/1555 val_loss:3.3977 train_time:69127ms step_avg:55.30ms +step:1251/1555 train_time:69149ms step_avg:55.28ms +step:1252/1555 train_time:69188ms step_avg:55.26ms +step:1253/1555 train_time:69275ms step_avg:55.29ms +step:1254/1555 train_time:69371ms step_avg:55.32ms +step:1255/1555 train_time:69455ms step_avg:55.34ms +step:1256/1555 train_time:69546ms step_avg:55.37ms +step:1257/1555 train_time:69629ms step_avg:55.39ms +step:1258/1555 train_time:69717ms step_avg:55.42ms +step:1259/1555 train_time:69801ms step_avg:55.44ms +step:1260/1555 train_time:69890ms step_avg:55.47ms +step:1261/1555 train_time:69972ms step_avg:55.49ms +step:1262/1555 train_time:70062ms step_avg:55.52ms +step:1263/1555 train_time:70149ms step_avg:55.54ms +step:1264/1555 train_time:70240ms step_avg:55.57ms +step:1265/1555 train_time:70327ms step_avg:55.59ms +step:1266/1555 train_time:70417ms step_avg:55.62ms +step:1267/1555 train_time:70502ms step_avg:55.65ms +step:1268/1555 train_time:70592ms step_avg:55.67ms +step:1269/1555 train_time:70674ms step_avg:55.69ms +step:1270/1555 train_time:70765ms step_avg:55.72ms +step:1271/1555 train_time:70848ms step_avg:55.74ms +step:1272/1555 train_time:70936ms step_avg:55.77ms +step:1273/1555 train_time:71020ms step_avg:55.79ms +step:1274/1555 train_time:71110ms step_avg:55.82ms +step:1275/1555 train_time:71194ms step_avg:55.84ms +step:1276/1555 train_time:71286ms step_avg:55.87ms +step:1277/1555 train_time:71371ms step_avg:55.89ms +step:1278/1555 train_time:71461ms step_avg:55.92ms +step:1279/1555 train_time:71546ms step_avg:55.94ms +step:1280/1555 train_time:71636ms step_avg:55.97ms +step:1281/1555 train_time:71721ms step_avg:55.99ms +step:1282/1555 train_time:71810ms step_avg:56.01ms +step:1283/1555 train_time:71893ms step_avg:56.03ms +step:1284/1555 train_time:71982ms step_avg:56.06ms +step:1285/1555 train_time:72065ms step_avg:56.08ms +step:1286/1555 train_time:72156ms step_avg:56.11ms +step:1287/1555 
train_time:72241ms step_avg:56.13ms +step:1288/1555 train_time:72332ms step_avg:56.16ms +step:1289/1555 train_time:72416ms step_avg:56.18ms +step:1290/1555 train_time:72507ms step_avg:56.21ms +step:1291/1555 train_time:72591ms step_avg:56.23ms +step:1292/1555 train_time:72679ms step_avg:56.25ms +step:1293/1555 train_time:72763ms step_avg:56.27ms +step:1294/1555 train_time:72852ms step_avg:56.30ms +step:1295/1555 train_time:72935ms step_avg:56.32ms +step:1296/1555 train_time:73025ms step_avg:56.35ms +step:1297/1555 train_time:73109ms step_avg:56.37ms +step:1298/1555 train_time:73200ms step_avg:56.39ms +step:1299/1555 train_time:73285ms step_avg:56.42ms +step:1300/1555 train_time:73376ms step_avg:56.44ms +step:1301/1555 train_time:73460ms step_avg:56.46ms +step:1302/1555 train_time:73552ms step_avg:56.49ms +step:1303/1555 train_time:73634ms step_avg:56.51ms +step:1304/1555 train_time:73725ms step_avg:56.54ms +step:1305/1555 train_time:73810ms step_avg:56.56ms +step:1306/1555 train_time:73898ms step_avg:56.58ms +step:1307/1555 train_time:73982ms step_avg:56.60ms +step:1308/1555 train_time:74072ms step_avg:56.63ms +step:1309/1555 train_time:74156ms step_avg:56.65ms +step:1310/1555 train_time:74247ms step_avg:56.68ms +step:1311/1555 train_time:74331ms step_avg:56.70ms +step:1312/1555 train_time:74420ms step_avg:56.72ms +step:1313/1555 train_time:74505ms step_avg:56.74ms +step:1314/1555 train_time:74596ms step_avg:56.77ms +step:1315/1555 train_time:74680ms step_avg:56.79ms +step:1316/1555 train_time:74771ms step_avg:56.82ms +step:1317/1555 train_time:74854ms step_avg:56.84ms +step:1318/1555 train_time:74943ms step_avg:56.86ms +step:1319/1555 train_time:75026ms step_avg:56.88ms +step:1320/1555 train_time:75116ms step_avg:56.91ms +step:1321/1555 train_time:75201ms step_avg:56.93ms +step:1322/1555 train_time:75292ms step_avg:56.95ms +step:1323/1555 train_time:75376ms step_avg:56.97ms +step:1324/1555 train_time:75466ms step_avg:57.00ms +step:1325/1555 train_time:75551ms 
step_avg:57.02ms +step:1326/1555 train_time:75642ms step_avg:57.05ms +step:1327/1555 train_time:75725ms step_avg:57.07ms +step:1328/1555 train_time:75815ms step_avg:57.09ms +step:1329/1555 train_time:75898ms step_avg:57.11ms +step:1330/1555 train_time:75988ms step_avg:57.13ms +step:1331/1555 train_time:76072ms step_avg:57.15ms +step:1332/1555 train_time:76161ms step_avg:57.18ms +step:1333/1555 train_time:76246ms step_avg:57.20ms +step:1334/1555 train_time:76335ms step_avg:57.22ms +step:1335/1555 train_time:76420ms step_avg:57.24ms +step:1336/1555 train_time:76510ms step_avg:57.27ms +step:1337/1555 train_time:76594ms step_avg:57.29ms +step:1338/1555 train_time:76684ms step_avg:57.31ms +step:1339/1555 train_time:76768ms step_avg:57.33ms +step:1340/1555 train_time:76856ms step_avg:57.36ms +step:1341/1555 train_time:76940ms step_avg:57.38ms +step:1342/1555 train_time:77031ms step_avg:57.40ms +step:1343/1555 train_time:77114ms step_avg:57.42ms +step:1344/1555 train_time:77204ms step_avg:57.44ms +step:1345/1555 train_time:77289ms step_avg:57.46ms +step:1346/1555 train_time:77378ms step_avg:57.49ms +step:1347/1555 train_time:77462ms step_avg:57.51ms +step:1348/1555 train_time:77553ms step_avg:57.53ms +step:1349/1555 train_time:77636ms step_avg:57.55ms +step:1350/1555 train_time:77726ms step_avg:57.58ms +step:1351/1555 train_time:77810ms step_avg:57.59ms +step:1352/1555 train_time:77899ms step_avg:57.62ms +step:1353/1555 train_time:77984ms step_avg:57.64ms +step:1354/1555 train_time:78075ms step_avg:57.66ms +step:1355/1555 train_time:78158ms step_avg:57.68ms +step:1356/1555 train_time:78249ms step_avg:57.71ms +step:1357/1555 train_time:78333ms step_avg:57.73ms +step:1358/1555 train_time:78423ms step_avg:57.75ms +step:1359/1555 train_time:78507ms step_avg:57.77ms +step:1360/1555 train_time:78599ms step_avg:57.79ms +step:1361/1555 train_time:78684ms step_avg:57.81ms +step:1362/1555 train_time:78775ms step_avg:57.84ms +step:1363/1555 train_time:78858ms step_avg:57.86ms 
+step:1364/1555 train_time:78950ms step_avg:57.88ms +step:1365/1555 train_time:79032ms step_avg:57.90ms +step:1366/1555 train_time:79123ms step_avg:57.92ms +step:1367/1555 train_time:79207ms step_avg:57.94ms +step:1368/1555 train_time:79297ms step_avg:57.97ms +step:1369/1555 train_time:79381ms step_avg:57.98ms +step:1370/1555 train_time:79473ms step_avg:58.01ms +step:1371/1555 train_time:79557ms step_avg:58.03ms +step:1372/1555 train_time:79647ms step_avg:58.05ms +step:1373/1555 train_time:79731ms step_avg:58.07ms +step:1374/1555 train_time:79820ms step_avg:58.09ms +step:1375/1555 train_time:79905ms step_avg:58.11ms +step:1376/1555 train_time:79994ms step_avg:58.14ms +step:1377/1555 train_time:80078ms step_avg:58.15ms +step:1378/1555 train_time:80169ms step_avg:58.18ms +step:1379/1555 train_time:80252ms step_avg:58.20ms +step:1380/1555 train_time:80342ms step_avg:58.22ms +step:1381/1555 train_time:80426ms step_avg:58.24ms +step:1382/1555 train_time:80515ms step_avg:58.26ms +step:1383/1555 train_time:80600ms step_avg:58.28ms +step:1384/1555 train_time:80691ms step_avg:58.30ms +step:1385/1555 train_time:80774ms step_avg:58.32ms +step:1386/1555 train_time:80864ms step_avg:58.34ms +step:1387/1555 train_time:80948ms step_avg:58.36ms +step:1388/1555 train_time:81037ms step_avg:58.38ms +step:1389/1555 train_time:81120ms step_avg:58.40ms +step:1390/1555 train_time:81211ms step_avg:58.43ms +step:1391/1555 train_time:81295ms step_avg:58.44ms +step:1392/1555 train_time:81386ms step_avg:58.47ms +step:1393/1555 train_time:81470ms step_avg:58.49ms +step:1394/1555 train_time:81559ms step_avg:58.51ms +step:1395/1555 train_time:81644ms step_avg:58.53ms +step:1396/1555 train_time:81733ms step_avg:58.55ms +step:1397/1555 train_time:81818ms step_avg:58.57ms +step:1398/1555 train_time:81909ms step_avg:58.59ms +step:1399/1555 train_time:81992ms step_avg:58.61ms +step:1400/1555 train_time:82081ms step_avg:58.63ms +step:1401/1555 train_time:82165ms step_avg:58.65ms +step:1402/1555 
train_time:82256ms step_avg:58.67ms +step:1403/1555 train_time:82340ms step_avg:58.69ms +step:1404/1555 train_time:82431ms step_avg:58.71ms +step:1405/1555 train_time:82514ms step_avg:58.73ms +step:1406/1555 train_time:82603ms step_avg:58.75ms +step:1407/1555 train_time:82687ms step_avg:58.77ms +step:1408/1555 train_time:82777ms step_avg:58.79ms +step:1409/1555 train_time:82861ms step_avg:58.81ms +step:1410/1555 train_time:82953ms step_avg:58.83ms +step:1411/1555 train_time:83036ms step_avg:58.85ms +step:1412/1555 train_time:83126ms step_avg:58.87ms +step:1413/1555 train_time:83210ms step_avg:58.89ms +step:1414/1555 train_time:83300ms step_avg:58.91ms +step:1415/1555 train_time:83385ms step_avg:58.93ms +step:1416/1555 train_time:83477ms step_avg:58.95ms +step:1417/1555 train_time:83559ms step_avg:58.97ms +step:1418/1555 train_time:83650ms step_avg:58.99ms +step:1419/1555 train_time:83733ms step_avg:59.01ms +step:1420/1555 train_time:83824ms step_avg:59.03ms +step:1421/1555 train_time:83907ms step_avg:59.05ms +step:1422/1555 train_time:83997ms step_avg:59.07ms +step:1423/1555 train_time:84080ms step_avg:59.09ms +step:1424/1555 train_time:84172ms step_avg:59.11ms +step:1425/1555 train_time:84255ms step_avg:59.13ms +step:1426/1555 train_time:84345ms step_avg:59.15ms +step:1427/1555 train_time:84430ms step_avg:59.17ms +step:1428/1555 train_time:84519ms step_avg:59.19ms +step:1429/1555 train_time:84603ms step_avg:59.20ms +step:1430/1555 train_time:84693ms step_avg:59.23ms +step:1431/1555 train_time:84777ms step_avg:59.24ms +step:1432/1555 train_time:84870ms step_avg:59.27ms +step:1433/1555 train_time:84953ms step_avg:59.28ms +step:1434/1555 train_time:85043ms step_avg:59.31ms +step:1435/1555 train_time:85127ms step_avg:59.32ms +step:1436/1555 train_time:85216ms step_avg:59.34ms +step:1437/1555 train_time:85300ms step_avg:59.36ms +step:1438/1555 train_time:85392ms step_avg:59.38ms +step:1439/1555 train_time:85475ms step_avg:59.40ms +step:1440/1555 train_time:85565ms 
step_avg:59.42ms +step:1441/1555 train_time:85649ms step_avg:59.44ms +step:1442/1555 train_time:85739ms step_avg:59.46ms +step:1443/1555 train_time:85824ms step_avg:59.48ms +step:1444/1555 train_time:85914ms step_avg:59.50ms +step:1445/1555 train_time:85998ms step_avg:59.51ms +step:1446/1555 train_time:86090ms step_avg:59.54ms +step:1447/1555 train_time:86174ms step_avg:59.55ms +step:1448/1555 train_time:86263ms step_avg:59.57ms +step:1449/1555 train_time:86349ms step_avg:59.59ms +step:1450/1555 train_time:86437ms step_avg:59.61ms +step:1451/1555 train_time:86522ms step_avg:59.63ms +step:1452/1555 train_time:86612ms step_avg:59.65ms +step:1453/1555 train_time:86696ms step_avg:59.67ms +step:1454/1555 train_time:86787ms step_avg:59.69ms +step:1455/1555 train_time:86870ms step_avg:59.70ms +step:1456/1555 train_time:86960ms step_avg:59.73ms +step:1457/1555 train_time:87046ms step_avg:59.74ms +step:1458/1555 train_time:87134ms step_avg:59.76ms +step:1459/1555 train_time:87219ms step_avg:59.78ms +step:1460/1555 train_time:87310ms step_avg:59.80ms +step:1461/1555 train_time:87393ms step_avg:59.82ms +step:1462/1555 train_time:87483ms step_avg:59.84ms +step:1463/1555 train_time:87567ms step_avg:59.85ms +step:1464/1555 train_time:87657ms step_avg:59.88ms +step:1465/1555 train_time:87740ms step_avg:59.89ms +step:1466/1555 train_time:87831ms step_avg:59.91ms +step:1467/1555 train_time:87914ms step_avg:59.93ms +step:1468/1555 train_time:88005ms step_avg:59.95ms +step:1469/1555 train_time:88090ms step_avg:59.97ms +step:1470/1555 train_time:88179ms step_avg:59.99ms +step:1471/1555 train_time:88264ms step_avg:60.00ms +step:1472/1555 train_time:88354ms step_avg:60.02ms +step:1473/1555 train_time:88439ms step_avg:60.04ms +step:1474/1555 train_time:88528ms step_avg:60.06ms +step:1475/1555 train_time:88611ms step_avg:60.08ms +step:1476/1555 train_time:88700ms step_avg:60.09ms +step:1477/1555 train_time:88784ms step_avg:60.11ms +step:1478/1555 train_time:88875ms step_avg:60.13ms 
+step:1479/1555 train_time:88958ms step_avg:60.15ms +step:1480/1555 train_time:89050ms step_avg:60.17ms +step:1481/1555 train_time:89133ms step_avg:60.18ms +step:1482/1555 train_time:89224ms step_avg:60.20ms +step:1483/1555 train_time:89309ms step_avg:60.22ms +step:1484/1555 train_time:89397ms step_avg:60.24ms +step:1485/1555 train_time:89481ms step_avg:60.26ms +step:1486/1555 train_time:89571ms step_avg:60.28ms +step:1487/1555 train_time:89655ms step_avg:60.29ms +step:1488/1555 train_time:89746ms step_avg:60.31ms +step:1489/1555 train_time:89829ms step_avg:60.33ms +step:1490/1555 train_time:89919ms step_avg:60.35ms +step:1491/1555 train_time:90004ms step_avg:60.36ms +step:1492/1555 train_time:90094ms step_avg:60.38ms +step:1493/1555 train_time:90178ms step_avg:60.40ms +step:1494/1555 train_time:90269ms step_avg:60.42ms +step:1495/1555 train_time:90353ms step_avg:60.44ms +step:1496/1555 train_time:90442ms step_avg:60.46ms +step:1497/1555 train_time:90526ms step_avg:60.47ms +step:1498/1555 train_time:90615ms step_avg:60.49ms +step:1499/1555 train_time:90699ms step_avg:60.51ms +step:1500/1555 train_time:90791ms step_avg:60.53ms +step:1500/1555 val_loss:3.2939 train_time:90905ms step_avg:60.60ms +step:1501/1555 train_time:90928ms step_avg:60.58ms +step:1502/1555 train_time:90967ms step_avg:60.56ms +step:1503/1555 train_time:91051ms step_avg:60.58ms +step:1504/1555 train_time:91146ms step_avg:60.60ms +step:1505/1555 train_time:91231ms step_avg:60.62ms +step:1506/1555 train_time:91322ms step_avg:60.64ms +step:1507/1555 train_time:91405ms step_avg:60.65ms +step:1508/1555 train_time:91493ms step_avg:60.67ms +step:1509/1555 train_time:91577ms step_avg:60.69ms +step:1510/1555 train_time:91666ms step_avg:60.71ms +step:1511/1555 train_time:91749ms step_avg:60.72ms +step:1512/1555 train_time:91840ms step_avg:60.74ms +step:1513/1555 train_time:91925ms step_avg:60.76ms +step:1514/1555 train_time:92016ms step_avg:60.78ms +step:1515/1555 train_time:92103ms step_avg:60.79ms 
+step:1516/1555 train_time:92198ms step_avg:60.82ms +step:1517/1555 train_time:92284ms step_avg:60.83ms +step:1518/1555 train_time:92373ms step_avg:60.85ms +step:1519/1555 train_time:92457ms step_avg:60.87ms +step:1520/1555 train_time:92546ms step_avg:60.89ms +step:1521/1555 train_time:92630ms step_avg:60.90ms +step:1522/1555 train_time:92720ms step_avg:60.92ms +step:1523/1555 train_time:92804ms step_avg:60.94ms +step:1524/1555 train_time:92895ms step_avg:60.95ms +step:1525/1555 train_time:92980ms step_avg:60.97ms +step:1526/1555 train_time:93071ms step_avg:60.99ms +step:1527/1555 train_time:93159ms step_avg:61.01ms +step:1528/1555 train_time:93248ms step_avg:61.03ms +step:1529/1555 train_time:93334ms step_avg:61.04ms +step:1530/1555 train_time:93423ms step_avg:61.06ms +step:1531/1555 train_time:93506ms step_avg:61.08ms +step:1532/1555 train_time:93596ms step_avg:61.09ms +step:1533/1555 train_time:93680ms step_avg:61.11ms +step:1534/1555 train_time:93769ms step_avg:61.13ms +step:1535/1555 train_time:93854ms step_avg:61.14ms +step:1536/1555 train_time:93945ms step_avg:61.16ms +step:1537/1555 train_time:94031ms step_avg:61.18ms +step:1538/1555 train_time:94123ms step_avg:61.20ms +step:1539/1555 train_time:94207ms step_avg:61.21ms +step:1540/1555 train_time:94299ms step_avg:61.23ms +step:1541/1555 train_time:94383ms step_avg:61.25ms +step:1542/1555 train_time:94473ms step_avg:61.27ms +step:1543/1555 train_time:94557ms step_avg:61.28ms +step:1544/1555 train_time:94646ms step_avg:61.30ms +step:1545/1555 train_time:94730ms step_avg:61.31ms +step:1546/1555 train_time:94821ms step_avg:61.33ms +step:1547/1555 train_time:94906ms step_avg:61.35ms +step:1548/1555 train_time:94997ms step_avg:61.37ms +step:1549/1555 train_time:95083ms step_avg:61.38ms +step:1550/1555 train_time:95172ms step_avg:61.40ms +step:1551/1555 train_time:95259ms step_avg:61.42ms +step:1552/1555 train_time:95348ms step_avg:61.44ms +step:1553/1555 train_time:95433ms step_avg:61.45ms +step:1554/1555 
train_time:95524ms step_avg:61.47ms +step:1555/1555 train_time:95608ms step_avg:61.48ms +step:1555/1555 val_loss:3.2777 train_time:95723ms step_avg:61.56ms +peak memory allocated: 30746 MiB reserved: 46798 MiB diff --git a/records/track_1_short/2026-01-31-BigramHashH2D/439741b3-3557-40ec-8dbd-774921ae6a7d.txt b/records/track_1_short/2026-01-31-BigramHashH2D/439741b3-3557-40ec-8dbd-774921ae6a7d.txt new file mode 100644 index 000000000..bba92cbfd --- /dev/null +++ b/records/track_1_short/2026-01-31-BigramHashH2D/439741b3-3557-40ec-8dbd-774921ae6a7d.txt @@ -0,0 +1,3976 @@ +import os +import sys + +# Read the current file and the kernels file code ASAP, for logging +with open(sys.argv[0], 'r') as f: + code = f.read() +with open(os.path.join(os.path.dirname(sys.argv[0]), 'triton_kernels.py'), 'r') as f: + code += f"\n\n{'-'*40}\n# triton_kernels.py\n{'-'*40}\n\n" + code += f.read() + +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate, pairwise +from pathlib import Path +import gc + +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" +import torch +import triton + +torch.empty( + 1, device=f"cuda:{os.environ['LOCAL_RANK']}", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +from kernels import get_kernel +from torch import Tensor, nn + +from triton_kernels import XXT, ba_plus_cAA, FusedLinearReLUSquareFunction, FusedSoftcappedCrossEntropy + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Distributed training setup +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size 
must be a divisor of 8" +grad_accum_steps = 8 // world_size +grad_scale = 2 / grad_accum_steps # consistent grad magnitudes between different num_devices +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng +# Transposed layout by @ChrisJMcCormick allows for faster gradient accumulation. + +@torch.library.custom_op("nanogpt::mm_t", mutates_args=()) +def mm_t_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + """Computes y = x @ w with F8 weights stored as (in_features, out_features).""" + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + assert x.shape[1] == w.shape[0] # x: (batch, in), w: (in, out) + + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + + # _scaled_mm requires column-major B. w_f8 is row-major (in, out). + # .T.contiguous().T creates a column-major view without changing logical shape. 
+ w_f8_col_major = w_f8.T.contiguous().T + + out = torch._scaled_mm( + x_f8, + w_f8_col_major, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_t_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[0] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_t_backward", mutates_args=()) +def mm_t_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + + x_scale = grad.new_tensor(x_s, dtype=torch.float32) + w_scale = grad.new_tensor(w_s, dtype=torch.float32) + grad_scale = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + + # grad_x = grad @ w.T + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=grad_scale, + scale_b=w_scale, + use_fast_accum=False, + ) + + # grad_w = x.T @ grad + # Result is (in, out), naturally matching weight storage. No final .T needed. 
+ grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_scale, + scale_b=grad_scale, + use_fast_accum=False, + ) + + return grad_x, grad_w + + grad_x, grad_w = impl(g, x_f8, w_f8) + + return grad_x, grad_w + +@mm_t_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.to(torch.float32) + +def backward_t(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_t_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context_t(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_t_op.register_autograd(backward_t, setup_context=setup_context_t) + +# ----------------------------------------------------------------------------- +# Polar Express + +# Computed for num_iters=5, safety_factor=2e-2, cushion=2 +polar_express_coeffs = [ + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323) +] + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def polar_express(G: torch.Tensor, split_baddbmm: bool = False): + """ + Polar Express Sign Method: https://arxiv.org/pdf/2505.16932 + by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. 
+ """ + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + # Select batched vs unbatched + if split_baddbmm: + BX_matmul = torch.bmm if X.ndim > 2 else torch.mm + else: + aX_plus_BX = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the iterations + for a, b, c in polar_express_coeffs: + XXT(X, out=A) # A = X @ X.mT + ba_plus_cAA(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + + # Referencing X twice causes pytorch to make a defensive copy, + # resulting in a cudaMemcpyAsync in baddbmm. + # For large matrices (i.e., the mlp weights), it's faster to split + # the operation into two kernels to avoid this. + if split_baddbmm: + BX_matmul(B, X, out=C) # C = B @ X + C.add_(X, alpha=a) # C = C + a*X (in-place, X only read) + else: + aX_plus_BX(X, B, X, beta=a, out=C) # C = a * X + B @ X + + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +# ----------------------------------------------------------------------------- +# Combined NorMuon + Adam Optimizer + +@dataclass +class ParamConfig: + """Per-parameter configuration for NorMuonAndAdam optimizer.""" + label: str + optim: str # "adam" or "normuon" + comms: str # "none", "replicated", or "sharded" + adam_betas: tuple[float, float] | None + lr_mul: float + wd_mul: float + lr: float + initial_lr: float + weight_decay: float + # Adam-specific + eps: float | None = None + # NorMuon-specific + reshape: tuple | None = None + chunk_size: int | None = None + momentum: float | None = None + beta2: float | None = None + per_matrix_lr_mul: list[float] | None = None + + +class NorMuonAndAdam: + """ + Combined optimizer that handles both NorMuon (for projection 
matrices) and + Adam (for embeddings/scalars/gate weights). + + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, Muon uses a Newton-Schulz iteration (replaced + here with Polar Express), which has the advantage that it can be stably run in bfloat16 on the GPU. + + Muon is applied only to the projection matrices in the attention and MLP layers, and is not recommended + for embeddings, scalars, or individual weight vectors (e.g., bias terms or gate weights). + + Differences from standard Muon: + - Newton-Shulz is replaced with Polar Express for the orthogonalization step + - NorMuon adds a low-rank variance estimator similar to Adafactor. https://arxiv.org/pdf/2510.05491 + - Cautious weight decay, a gated version of decoupled weight decay + - Mantissa tracking for precision + + Adam (for embeddings/scalars/gates): + - Standard Adam with bias correction + - Cautious weight decay + + Configuration: + Unlike torch.optim.Optimizer, this class uses per-parameter configs from a `param_table` dict + and does not include parameter "groups". All parameters require a .label attribute, and a + corresponding entry in the param_table to specify their hyperparameters (lr_mul, wd_mul, adam_betas, etc.). + + Communication and ordering: + Gradient communication is explicitly scheduled rather than hook-driven. + Reductions are launched in `scatter_order`, while update math and final + gathers are executed in `work_order`. These orders are independent and + must each contain every parameter label exactly once. + + Two communication modes are supported per parameter: + - 'replicated': Gradients are all-reduced and each rank computes the full update. 
+ - 'sharded': Gradients are reduce-scattered, each rank updates its shard, + and results are all-gathered. + + Adam parameters may be freely sharded. NorMuon operates on full matrices; sharding is + supported by grouping matrices into parameter banks. NorMuon parameters must have a + `.reshape` attribute that reshapes the bank so that the leading dimension is divisible + by world_size. + + # Contributors include @YouJiacheng, @KonstantinWilleke, @alexrgilbert, @adricarda, + # @tuttyfrutyee, @vdlad, @ryanyang0, @vagrawal, @varunneal, @chrisjmccormick + """ + def __init__(self, named_params, param_table: dict, scatter_order: list, work_order: list, + adam_defaults: dict, normuon_defaults: dict): + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + + # Store defaults for each optimizer type + self.adam_defaults = adam_defaults + self.normuon_defaults = normuon_defaults + self.param_table = param_table + self.scatter_order = scatter_order + self.work_order = work_order + + # Collect params by label and build config + self.param_cfgs: dict[nn.Parameter, ParamConfig] = {} + self.param_states: dict[nn.Parameter, dict] = {} + self._param_by_label: dict[str, nn.Parameter] = {} + for name, param in named_params: + label = getattr(param, "label", None) + assert label is not None and label in param_table # all params must have valid label + assert label not in self._param_by_label # exactly one param per label + self._param_by_label[label] = param + self._build_param_cfg(param, label) + + # Assert scatter_order and work_order match present labels exactly + present = set(self._param_by_label.keys()) + assert set(scatter_order) == present and set(work_order) == present + + # Handle world_size=1: overwrite comms to "none" + if self.world_size == 1: + for p_cfg in self.param_cfgs.values(): + p_cfg.comms = "none" + + # Initialize state for all params + self._init_state() + + # 0-D CPU tensors to avoid recompilation + self._step_size_t = torch.tensor(0.0, 
dtype=torch.float32, device="cpu") + self._eff_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._eff_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + + # Track async operations + self._reduce_futures: dict[nn.Parameter, tuple] = {} + + # Embed/lm_head tying state + self.split_embed = False + self._lm_head_param = self._param_by_label.get("lm_head") + self._embed_param = self._param_by_label.get("embed") + + def _build_param_cfg(self, param: nn.Parameter, label: str): + """Build config for a single parameter from param_table.""" + table_entry = self.param_table[label] + optim = table_entry["optim"] + comms = table_entry["comms"] + adam_betas = table_entry.get("adam_betas") + lr_mul = table_entry.get("lr_mul", 1.0) + wd_mul = table_entry.get("wd_mul", 1.0) + + if optim == "adam": + chunk_size = param.shape[0] // self.world_size if comms == "sharded" else None + p_cfg = ParamConfig( + label=label, + optim=optim, + comms=comms, + adam_betas=tuple(adam_betas) if adam_betas else None, + lr_mul=lr_mul, + wd_mul=wd_mul, + lr=self.adam_defaults["lr"], + initial_lr=self.adam_defaults["lr"], + weight_decay=self.adam_defaults["weight_decay"], + eps=self.adam_defaults["eps"], + chunk_size=chunk_size, + ) + elif optim == "normuon": + reshape = getattr(param, "reshape", None) + if reshape is None: + raise ValueError(f"NorMuon param {label} must have .reshape attribute") + if reshape[0] % self.world_size != 0: + raise ValueError(f"reshape[0]={reshape[0]} must be divisible by world_size") + + chunk_size = reshape[0] // self.world_size + chunk_shape = (chunk_size, *reshape[1:]) + # Shape-based LR multiplier for NorMuon + shape_mult = max(1.0, chunk_shape[-2] / chunk_shape[-1]) ** 0.5 if len(chunk_shape) >= 2 else 1.0 + lr_mul = shape_mult * lr_mul + + # Per-matrix LR multipliers for MLP c_proj (2x LR on odd indices) + per_matrix_lr_mul = None + if label == "mlp": + rank = dist.get_rank() if dist.is_initialized() else 0 + start_idx = rank * chunk_size + 
per_matrix_lr_mul = [] + for i in range(chunk_size): + global_idx = start_idx + i + is_c_proj = (global_idx % 2 == 1) + per_matrix_lr_mul.append(2.0 if is_c_proj else 1.0) + + p_cfg = ParamConfig( + label=label, + optim=optim, + comms=comms, + adam_betas=tuple(adam_betas) if adam_betas else None, + lr_mul=lr_mul, + wd_mul=wd_mul, + lr=self.normuon_defaults["lr"], + initial_lr=self.normuon_defaults["lr"], + weight_decay=self.normuon_defaults["weight_decay"], + reshape=reshape, + chunk_size=chunk_size, + momentum=self.normuon_defaults["momentum"], + beta2=self.normuon_defaults["beta2"], + per_matrix_lr_mul=per_matrix_lr_mul, + ) + else: + raise ValueError(f"Unknown optim type: {optim}") + + self.param_cfgs[param] = p_cfg + + def _init_state(self): + """Initialize optimizer state for all parameters.""" + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "adam": + # Sharded params use chunk state, replicated use full state + if p_cfg.comms == "sharded": + chunk = param[:p_cfg.chunk_size] + else: + chunk = param + exp_avg = torch.zeros_like(chunk, dtype=torch.float32, device=param.device) + self.param_states[param] = dict(step=0, exp_avg=exp_avg, exp_avg_sq=torch.zeros_like(exp_avg)) + + elif p_cfg.optim == "normuon": + chunk_shape = (p_cfg.chunk_size, *p_cfg.reshape[1:]) + + # Momentum buffer (FP32 for precision) + momentum_buffer = torch.zeros( + chunk_shape, dtype=torch.float32, device=param.device + ) + + # Second momentum buffer - reduced along one dimension + if chunk_shape[-2] >= chunk_shape[-1]: + second_mom_shape = (*chunk_shape[:-1], 1) + else: + second_mom_shape = (*chunk_shape[:-2], 1, chunk_shape[-1]) + second_momentum_buffer = torch.zeros( + second_mom_shape, dtype=torch.float32, device=param.device + ) + + # Mantissa buffer for precision tracking + mantissa = torch.zeros( + chunk_shape, dtype=torch.uint16, device=param.device + ) + + self.param_states[param] = dict( + momentum_buffer=momentum_buffer, + 
second_momentum_buffer=second_momentum_buffer, + mantissa=mantissa, + ) + + # ----------------------------------- + # Reduce/Gather operations + + def _launch_reduce(self, param: nn.Parameter, grad: Tensor): + """Launch async reduce for a parameter based on its comms policy.""" + p_cfg = self.param_cfgs[param] + + if p_cfg.comms == "none": + if p_cfg.optim == "normuon": + # NorMuon needs reshaped gradient even without communication + grad = grad.view(p_cfg.reshape) + self._reduce_futures[param] = (None, grad) + elif p_cfg.comms == "replicated": + future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future() + self._reduce_futures[param] = (future, grad) + elif p_cfg.comms == "sharded": + if p_cfg.optim == "normuon": + # NorMuon: reshape before reduce_scatter + grad_reshaped = grad.view(p_cfg.reshape) + grad_chunk = torch.empty( + (p_cfg.chunk_size, *grad_reshaped.shape[1:]), + dtype=grad.dtype, + device=grad.device + ) + future = dist.reduce_scatter_tensor( + grad_chunk, grad_reshaped.contiguous(), op=dist.ReduceOp.AVG, async_op=True + ).get_future() + self._reduce_futures[param] = (future, grad_chunk) + else: + # Adam: simple reduce_scatter + grad_chunk = torch.empty_like(grad[:p_cfg.chunk_size]) + future = dist.reduce_scatter_tensor( + grad_chunk, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + self._reduce_futures[param] = (future, grad_chunk) + + def _launch_gather(self, param: nn.Parameter, p_slice: Tensor) -> "torch.futures.Future": + """Launch async all_gather for a sharded parameter.""" + p_cfg = self.param_cfgs[param] + if p_cfg.optim == "normuon": + full_param = param.data.view(p_cfg.reshape) + assert full_param.is_contiguous() + return dist.all_gather_into_tensor( + full_param, p_slice.contiguous(), async_op=True + ).get_future() + else: + return dist.all_gather_into_tensor( + param, p_slice.contiguous(), async_op=True + ).get_future() + + # ----------------------------------- + # State management + + def reset(self): + 
"""Reset NorMuon momentum buffers and split_embed state (called on training reset).""" + self.split_embed = False + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "normuon": + p_state = self.param_states[param] + p_state["momentum_buffer"].zero_() + p_state["mantissa"].zero_() + p_state["second_momentum_buffer"].zero_() + + def copy_lm_state_to_embed(self): + """ + Copy the optimizer state from the lm_head to the embed at the untie point. + This requires an all-gather + reshard because of different sharding: + - lm_head (768, 50304) is sharded to (96, 50304) per rank (along model_dim) + - embed (50304, 768) is sharded to (6288, 768) per rank (along vocab_size) + + We all-gather the lm_head momentum, transpose it, then each rank takes their + embed shard to get the correct momentum state. + """ + lm_head = self._lm_head_param + embed = self._embed_param + lm_state = self.param_states[lm_head] + embed_state = self.param_states[embed] + lm_cfg = self.param_cfgs[lm_head] + embed_cfg = self.param_cfgs[embed] + + embed_state['step'] = lm_state['step'] # Preserve step count for bias correction + + # Copy optimizer state with all-gather + transpose + reshard + if self.world_size > 1: + rank = dist.get_rank() + lm_chunk_size = lm_cfg.chunk_size # 96 + embed_chunk_size = embed_cfg.chunk_size # 6288 + + # All-gather lm_head momentum to get full (768, 50304) tensor + for key in ["exp_avg", "exp_avg_sq"]: + lm_chunk = lm_state[key] # (96, 50304) + full_lm = torch.empty(lm_head.shape[0], lm_head.shape[1], dtype=lm_chunk.dtype, device=lm_chunk.device) + dist.all_gather_into_tensor(full_lm, lm_chunk.contiguous()) + embed_state[key].copy_(full_lm.T[rank * embed_chunk_size:(rank + 1) * embed_chunk_size]) + else: + # Single GPU: simple transpose + for key in ["exp_avg", "exp_avg_sq"]: + embed_state[key].copy_(lm_state[key].T) + + # Mark as split + self.split_embed = True + + def state_dict(self): + """Return the optimizer state as a dict.""" + return { + 
"param_states": {id(p): s for p, s in self.param_states.items()}, + "param_cfgs": {id(p): s for p, s in self.param_cfgs.items()}, + } + + def load_state_dict(self, state_dict): + """Load optimizer state from a dict.""" + # Build id->param mapping + id_to_param = {id(p): p for p in self.param_cfgs.keys()} + + # Load state, preserving dtypes + for param_id, saved_p_state in state_dict["param_states"].items(): + if param_id in id_to_param: + param = id_to_param[param_id] + p_state = self.param_states[param] + for k, v in saved_p_state.items(): + if isinstance(v, torch.Tensor) and k in p_state: + target_dtype = p_state[k].dtype + p_state[k] = v.to(dtype=target_dtype, device=p_state[k].device) + else: + p_state[k] = v + + # ----------------------------------- + # Unified optimizer step with explicit ordering + + @torch.no_grad() + def step(self, do_adam: bool = True): + """ + Combined optimizer step with explicit ordering. + + Args: + do_adam: If True, update Adam params. NorMuon params always updated. + + Flow: + 1. Scatter phase: Launch reduces in scatter_order + 2. Work phase: Process updates in work_order + - Wait for reduce, compute update, launch gather + 3. Finalize phase: Wait for gathers + + While the embeddings are tied: + - Comms and update math are only done on lm_head. + - We add embed.grad.T into lm_head.grad before comms. 
+ - After lm_head gather, we copy lm_head.data.T --> embed.data + """ + rank = dist.get_rank() if dist.is_initialized() else 0 + lm_param, embed_param = self._lm_head_param, self._embed_param + + # ===== Phase 1: Launch reduces in scatter_order ===== + for label in self.scatter_order: + param = self._param_by_label[label] + p_cfg = self.param_cfgs[param] + + if p_cfg.optim == "adam" and not do_adam: + continue + if param.grad is None: + continue + + # lm_head when tied: aggregate embed.grad.T (transposed shapes) + if label == "lm_head" and do_adam and not self.split_embed: + if embed_param is not None and embed_param.grad is not None: + param.grad.add_(embed_param.grad.T) + + # Skip embed when tied (copied from lm_head after gather) + if label == "embed" and not self.split_embed: + continue + + self._launch_reduce(param, param.grad) + + # ===== Phase 2: Process updates in work_order ===== + gather_futures = [] + lm_head_gather_future = None + + for label in self.work_order: + param = self._param_by_label[label] + if param not in self._reduce_futures: + continue + + p_cfg = self.param_cfgs[param] + if p_cfg.optim == "adam" and not do_adam: + continue + # Wait for reduce + future, grad_chunk = self._reduce_futures[param] + if future is not None: + future.wait() + # Apply update based on optim type + if p_cfg.optim == "adam": + p_slice = self._adam_update(param, grad_chunk, p_cfg, rank) + else: + p_slice = self._normuon_update(param, grad_chunk, p_cfg, rank) + # Launch gather for sharded params + if p_cfg.comms == "sharded" and self.world_size > 1: + gather_fut = self._launch_gather(param, p_slice) + if label == "lm_head": + lm_head_gather_future = gather_fut + else: + gather_futures.append(gather_fut) + + # ===== Phase 3: Wait for gathers, sync embed if tied ===== + # Wait for lm_head gather first so we can copy to embed while other gathers complete + if lm_head_gather_future is not None: + lm_head_gather_future.wait() + + # When tied: copy lm_head.T to embed + if 
do_adam and not self.split_embed and embed_param is not None and lm_param is not None: + embed_param.data.copy_(lm_param.data.T) + + # Wait for remaining gathers + for fut in gather_futures: + fut.wait() + + self._reduce_futures.clear() + + # Clear grads for updated params + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "adam" and not do_adam: + continue # Don't clear Adam grads on even steps + param.grad = None + + # ----------------------------------- + # Adam update + + def _adam_update(self, param: nn.Parameter, grad_chunk: Tensor, p_cfg: ParamConfig, rank: int) -> Tensor: + """Apply Adam update to a parameter. Returns the updated p_slice.""" + beta1, beta2 = p_cfg.adam_betas + lr = p_cfg.lr * p_cfg.lr_mul + + # Get parameter slice + if p_cfg.comms == "sharded": + p_slice = param[rank * p_cfg.chunk_size:(rank + 1) * p_cfg.chunk_size] + else: + p_slice = param + + p_state = self.param_states[param] + p_state["step"] += 1 + t = p_state["step"] + + bias1, bias2 = 1 - beta1 ** t, 1 - beta2 ** t + self._step_size_t.fill_(lr * (bias2 ** 0.5 / bias1)) + self._eff_wd_t.fill_(lr * lr * p_cfg.weight_decay * p_cfg.wd_mul) + + NorMuonAndAdam._adam_update_step( + p_slice, grad_chunk, p_state["exp_avg"], p_state["exp_avg_sq"], + beta1, beta2, p_cfg.eps, self._step_size_t, self._eff_wd_t + ) + + return p_slice + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _adam_update_step(p_slice, g_slice, exp_avg, exp_avg_sq, beta1, beta2, eps, step_size_t, eff_wd_t): + """Compiled Adam update step.""" + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + update = exp_avg.div(exp_avg_sq.sqrt().add_(eps)).mul_(step_size_t) + # Cautious weight decay + mask = (update * p_slice) > 0 + update.addcmul_(p_slice, mask, value=eff_wd_t) + p_slice.add_(other=update, alpha=-1.0) + + # ----------------------------------- + # NorMuon update + + def _normuon_update(self, param: nn.Parameter, 
grad_chunk: Tensor, p_cfg: ParamConfig, rank: int) -> Tensor: + """Apply NorMuon update to a parameter. Returns the updated p_slice.""" + chunk_shape = grad_chunk.shape + + p_state = self.param_states[param] + grad_chunk = grad_chunk.float() # FP32 for momentum + + # Momentum update + momentum_buffer = p_state["momentum_buffer"] + momentum_buffer.lerp_(grad_chunk, 1 - p_cfg.momentum) + updated_grads = grad_chunk.lerp_(momentum_buffer, p_cfg.momentum) + + self._eff_lr_t.fill_(p_cfg.lr_mul * p_cfg.lr) + self._eff_wd_t.fill_(p_cfg.wd_mul * p_cfg.weight_decay * p_cfg.lr) + + # Polar Express orthogonalization + is_large_matrix = chunk_shape[-2] > 1024 + v_chunk = polar_express(updated_grads, split_baddbmm=is_large_matrix) + + # Variance reduction + red_dim = -1 if chunk_shape[-2] >= chunk_shape[-1] else -2 + v_chunk = NorMuonAndAdam._apply_normuon_variance_reduction( + v_chunk, p_state["second_momentum_buffer"], p_cfg.beta2, red_dim + ) + + # Update parameter, in place, with cautious weight decay + param_view = param.data.view(p_cfg.reshape) + p_slice = param_view[rank * p_cfg.chunk_size:(rank + 1) * p_cfg.chunk_size] + + # MLP has per-matrix LR multipliers (c_proj gets 2x LR) + if p_cfg.per_matrix_lr_mul is not None: + for mat_idx in range(p_cfg.chunk_size): + self._eff_lr_t.fill_(p_cfg.lr_mul * p_cfg.per_matrix_lr_mul[mat_idx] * p_cfg.lr) + self._eff_wd_t.fill_(p_cfg.wd_mul * p_cfg.weight_decay * p_cfg.lr) + NorMuonAndAdam._cautious_wd_and_update_inplace( + p_slice[mat_idx].view(torch.uint16), p_state["mantissa"][mat_idx], v_chunk[mat_idx], + self._eff_wd_t, self._eff_lr_t + ) + else: + NorMuonAndAdam._cautious_wd_and_update_inplace( + p_slice.view(torch.uint16), p_state["mantissa"], v_chunk, + self._eff_wd_t, self._eff_lr_t + ) + + return p_slice + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _cautious_wd_and_update_inplace(p, mantissa, grad, wd_tensor, lr_tensor): + """ + Cautious weight decay + parameter update. 
wd_tensor and lr_tensor are 0-D CPU tensors. + Mantissa is tracked to enable higher precision updates on bfloat16 parameters. + bfloat16 format: 1 sign bit + 8 exponent bits + 7 mantissa bits = 16 bits total + float32 format: 1 sign bit + 8 exponent bits + 23 mantissa bits = 32 bits total + """ + assert p.dtype == mantissa.dtype == torch.uint16 + grad = grad.float() + wd_factor = wd_tensor.to(torch.float32) + lr_factor = lr_tensor.to(torch.float32) + p_precise_raw = (p.to(torch.uint32) << 16) | mantissa.to(torch.uint32) + p_precise = p_precise_raw.view(torch.float32) + mask = (grad * p_precise) >= 0 + p_precise.copy_(p_precise - (p_precise * mask * wd_factor * lr_factor) - (grad * lr_factor)) + p.copy_((p_precise_raw >> 16).to(torch.uint16)) + mantissa.copy_(p_precise_raw.to(torch.uint16)) + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _apply_normuon_variance_reduction(v_chunk, second_momentum_buffer, beta2, red_dim): + """NorMuon variance reduction. Algebraically fuses the normalization steps to minimize memory ops.""" + v_mean = v_chunk.float().square().mean(dim=red_dim, keepdim=True) + red_dim_size = v_chunk.size(red_dim) + v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True).mul_(red_dim_size) + v_norm = v_norm_sq.sqrt_() + second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2) + step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt_() + scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square() + v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt_() + final_scale = step_size * (v_norm / v_norm_new.clamp_min_(1e-10)) + return v_chunk.mul_(final_scale.type_as(v_chunk)) + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinearT(nn.Module): + """ + Linear layer with transposed weight storage (in_features, out_features) which + 
addresses the slow kernel that was used for gradient accumulation. @chrisjmccormick + """ + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + self.weight = nn.Parameter(torch.empty(in_features, out_features, dtype=torch.bfloat16)) + self.reset_parameters() + + def reset_parameters(self) -> None: + with torch.no_grad(): + nn.init.zeros_(self.weight) # @Grad62304977 and others + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out = torch.ops.nanogpt.mm_t(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return x @ self.weight.type_as(x) + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len, paired=False): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.paired = paired + self.reset() + + def rotary(self, x_BTHD): + assert self.factor1.size(0) >= x_BTHD.size(-3) + factor1, factor2 = ( + self.factor1[None, : x_BTHD.size(-3), None, :], + self.factor2[None, : x_BTHD.size(-3), None, :], + ) + x_flip = x_BTHD.view(*x_BTHD.shape[:-1], x_BTHD.shape[-1] // 2, 2).flip(-1).view(x_BTHD.shape) + return factor1 * x_BTHD + factor2 * x_flip + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + angular_freq = angular_freq.repeat_interleave(2) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//2)]) + t = torch.arange(2*self.max_seq_len, dtype=torch.float32, device=device) + if not self.paired: + theta 
= torch.outer(t, angular_freq) + self.factor1 = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.factor2 = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + else: + t_even = 2 * t + t_odd = 2 * t + 1 + theta1 = torch.outer(t_even, angular_freq) + theta2 = torch.outer(t_odd, angular_freq) + self.factor1 = nn.Buffer( + torch.cat((theta1.cos(), theta2.cos()), dim=-1).to(torch.bfloat16), + persistent=False + ) + self.factor2 = nn.Buffer( + torch.cat((theta1.sin(), theta2.sin()), dim=-1).to(torch.bfloat16), + persistent=False + ) + self.factor2[..., 1::2] *= -1 + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(2*self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + if not self.paired: + theta = torch.outer(t, self.angular_freq) + self.factor1.copy_(theta.cos()) + self.factor2.copy_(theta.sin()) + else: + t_even = 2 * t + t_odd = 2 * t + 1 + theta1 = torch.outer(t_even, self.angular_freq) + theta2 = torch.outer(t_odd, self.angular_freq) + self.factor1.copy_(torch.cat((theta1.cos(), theta2.cos()), dim=-1)) + self.factor2.copy_(torch.cat((theta1.sin(), theta2.sin()), dim=-1)) + self.factor2[..., 1::2] *= -1 + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + yarn: Yarn + key_offset: bool + attn_gate_w: torch.Tensor + ve_gate_w: torch.Tensor + +flash_attn_interface = 
get_kernel('varunneal/flash-attention-3').flash_attn_interface + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, paired: bool = False): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + self.paired = paired + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + # Weights are stored in parameter banks and passed via forward() + + def forward(self, x: Tensor, attn_args: AttnArgs, qkvo_w: Tensor): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + yarn = attn_args.yarn + ve, sa_lambdas, key_offset = attn_args.ve, attn_args.sa_lambdas, attn_args.key_offset + seqlens, bm_size = attn_args.seqlens, attn_args.bm_size + # sparse gated attention to enable context based no-op by @classiclarryd + # only include gates on layers with value embeds used on forward pass + attn_gate_w, ve_gate_w = attn_args.attn_gate_w, attn_args.ve_gate_w + + q, k, v = F.linear(x, sa_lambdas[0] * qkvo_w[:self.dim * 3].type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + q, k = norm(q), norm(k) # QK norm @Grad62304977 + + if not self.paired: + q, k = yarn.rotary(q), yarn.rotary(k) + + if key_offset: + # shift keys forward for the stationary head dims. Enables 1-layer induction. + k[:, 1:, :, self.head_dim // 2:] = k[:, :-1, :, self.head_dim // 2:] + + if ve is not None: + ve_gate_out = 2 * torch.sigmoid(F.linear(x[..., :12], ve_gate_w)).view(B, T, self.num_heads, 1) + v = v + ve_gate_out * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + + else: + # Paired heads: adjacent heads' queries attend to each other's keys. 
+ # Two copies of the input stream are interleaved to achieve this, which: + # - doubles the length of each sequence + # - halves the effective window size + q = q.view(B, T, self.num_heads // 2, self.head_dim * 2) + k = k.view(B, T, self.num_heads // 2, self.head_dim * 2) + v = v.reshape(B, T * 2, self.num_heads // 2, self.head_dim) + + q, k = yarn.rotary(q), yarn.rotary(k) + + q = q.view(B, T * 2, self.num_heads // 2, self.head_dim) + k = k.view(B, T * 2, self.num_heads // 2, self.head_dim) + + if ve is not None: + ve_gate_out = 2 * torch.sigmoid(F.linear(x[..., :12], ve_gate_w)).view(B, T * 2, self.num_heads // 2, 1) + v = v + ve_gate_out * ve.view_as(v) + + seqlens = 2 * seqlens + max_len = 2 * max_len + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_interface.flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, + max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=yarn.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(F.linear(x[..., :12], attn_gate_w)).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, sa_lambdas[1] * qkvo_w[self.dim * 3:].type_as(y)) # sa_lambdas[1] pre-multiplied to O @shenberg + return y + +class MLP(nn.Module): + def __init__(self): + super().__init__() + # Weights are stored in parameter banks and passed via forward() + + def forward(self, x: Tensor, c_fc: Tensor, c_proj: Tensor): + # relu(x)^2: + # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + # Fused triton kernel for relu(x @ W1.T)^2 @ W2.T + return FusedLinearReLUSquareFunction.apply(x, c_fc, c_proj) + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, has_attn: bool, has_mlp: bool, use_paired_head: bool): + 
super().__init__() + # skip attention of blocks.6 (the 7th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads, paired=use_paired_head) if has_attn else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP() if has_mlp else None + + def forward(self, x: Tensor, attn_args: AttnArgs, qkvo_w: Tensor = None, c_fc: Tensor = None, c_proj: Tensor = None): + if self.attn is not None: + x = x + self.attn(norm(x), attn_args, qkvo_w) + if self.mlp is not None: + x = x + self.mlp(norm(x), c_fc, c_proj) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +@dataclass +class ForwardScheduleConfig: + mtp_weights: torch.Tensor + ws_short: int + ws_long: int + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + self.num_layers = num_layers + self.vocab_size = next_multiple_of_n(vocab_size, n=128) + + self.smear_gate = nn.Linear(12, 1, bias=False) + nn.init.zeros_(self.smear_gate.weight) + self.smear_gate.weight.label = 'smear_gate' + + self.skip_gate = nn.Linear(12, 1, bias=False) + nn.init.zeros_(self.skip_gate.weight) + self.skip_gate.weight.label = 'skip_gate' + + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.Parameter(torch.zeros(5 * self.vocab_size, model_dim, dtype=torch.bfloat16)) + self.value_embeds.label = 'value_embed' + + # parameter banks for attention and value embedding gate weights + self.attn_gate_bank = nn.Parameter(torch.zeros(10, num_heads, 12)) # 10 layers + 
self.attn_gate_bank.label = 'attn_gate_bank' + self.ve_gate_bank = nn.Parameter(torch.zeros(5, num_heads, 12)) # 5 unique gates + self.ve_gate_bank.label = 've_gate_bank' + + # ----------------------------------- + # Parameter banks for sharded optimization, by @chrisjmccormick + + # Identify which layers have attention/MLP + # Attention is skipped in layer 6 by @YouJiacheng + self.attn_layer_indices = [i for i in range(num_layers) if i != 6] + # All layers have MLP (At 11 layers--dropped first layer @EmelyanenkoK) + self.mlp_layer_indices = list(range(num_layers)) + + hdim = num_heads * head_dim + mlp_hdim = 4 * model_dim + + # Create index mappings: layer_idx -> bank_idx + self.layer_to_attn_idx = {layer_idx: bank_idx for bank_idx, layer_idx in enumerate(self.attn_layer_indices)} + self.layer_to_mlp_idx = {layer_idx: bank_idx for bank_idx, layer_idx in enumerate(self.mlp_layer_indices)} + + # Attention bank: stores QKVO weights for all attention layers + # merged QKVO weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # Simplified layout by @chrisjmccormick + # Shape: (num_attn_layers, 4*model_dim, hdim) = (10, 3072, 768) + # Reshape for sharding: (40, 768, 768) for even distribution across 8 GPUs + self.attn_bank = nn.Parameter(torch.empty(len(self.attn_layer_indices), 4 * model_dim, hdim)) + self.attn_bank.label = 'attn' + self.attn_bank.reshape = (len(self.attn_layer_indices) * 4, hdim, hdim) # (40, 768, 768) + + # MLP bank: stores c_fc and c_proj for all MLP layers + # Shape: (num_mlp_layers + padding, 2, mlp_hdim, model_dim) = (12, 2, 3072, 768) + # We add 1 padding layer (index 11) to get 12*2=24 matrices for even distribution across 8 GPUs + # Reshape for sharding: (24, 3072, 768) + num_mlp_with_padding = len(self.mlp_layer_indices) + 1 # 11 + 1 = 12 + self.mlp_bank = nn.Parameter(torch.empty(num_mlp_with_padding, 2, mlp_hdim, model_dim)) + 
self.mlp_bank.label = 'mlp' + self.mlp_bank.reshape = (num_mlp_with_padding * 2, mlp_hdim, model_dim) # (24, 3072, 768) + + # improved init scale by @YouJiacheng and @srashedll + std = 0.5 * model_dim ** -0.5 + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.attn_bank.uniform_(-bound, bound) + self.mlp_bank[:, 0, :, :].uniform_(-bound, bound) # c_fc + self.mlp_bank[:, 1, :, :].zero_() # c_proj - zero init suggested by @Grad62304977 + + # Create blocks with has_attn/has_mlp flags + self.paired_head_layers = [0, 2, 5, 9] + self.blocks = nn.ModuleList([ + Block(model_dim, head_dim, num_heads, + has_attn=(i in self.layer_to_attn_idx), + has_mlp=(i in self.layer_to_mlp_idx), + use_paired_head=(i in self.paired_head_layers)) + for i in range(num_layers) + ]) + self.yarn = Yarn(head_dim, max_seq_len) + self.yarn_paired_head = Yarn(head_dim, max_seq_len, paired=True) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + # Transposed weight storage for faster gradient accumulation + self.lm_head = CastedLinearT(model_dim, self.vocab_size, use_fp8=use_fp8, x_s=100/448, w_s=1.6/448, grad_s=grad_scale * 0.75/448) + + nn.init.normal_(self.lm_head.weight, mean=0, std=0.005) + self.lm_head.weight.label = 'lm_head' + + self.embed = nn.Embedding(self.vocab_size, model_dim) + self.embed.weight.label = 'embed' + with torch.no_grad(): + self.embed.weight.copy_(self.lm_head.weight.T) + + self.bigram_embed = nn.Embedding(args.bigram_vocab_size, model_dim) + self.bigram_embed.weight.label = 'bigram_embed' + nn.init.zeros_(self.bigram_embed.weight) + + # x0_lambdas separated out for different optimizer treatment (no beta smoothing) + self.x0_lambdas = nn.Parameter(torch.zeros(num_layers)) + self.x0_lambdas.label = 'x0_lambdas' + + pad = (-num_layers * 3 - 3) % dist.get_world_size() # updated: 3*num_layers instead of 4* + self.scalars = nn.Parameter( + torch.cat( + [ + 1.1 * torch.ones(num_layers), # resid lambdas. 1.1 init such that layer i weight is i^(num_layers-i). + *[torch.tensor([0.5, 1.0]) for _ in range(num_layers)], # SA lambdas + 0.1 * torch.ones(num_layers), # bigram lambdas + torch.zeros(1), # smear_lambda + 0.5*torch.ones(1), # backout_lambda + -1.5 * torch.ones(1), # skip_lambda -> σ(-1.5) ≈ 0.18 + torch.ones(pad), + ] + ) + ) + self.scalars.label = 'scalars' + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _compute_bigram_hash(x: Tensor, mod: int) -> Tensor: + """ + Computes bigram hash on GPU for each position using [prev_token, curr_token]. + Mathematically identical to the CPU version but computed on device. 
+ """ + rand_int_1 = 36313 + rand_int_2 = 27191 + result = torch.empty_like(x) + result[0] = mod + result[1:] = torch.bitwise_xor(rand_int_1 * x[1:], rand_int_2 * x[:-1]) % mod + return result + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, schedule_cfg: ForwardScheduleConfig): + assert input_seq.ndim == 1 + + # unpack schedule_cfg + mtp_weights, ws_short, ws_long = schedule_cfg.mtp_weights, schedule_cfg.ws_short, schedule_cfg.ws_long + + # set configs + skip_connections = [] + skip_in = [3] # long attention window on layer 3 + skip_out = [6] # no attn op on layer 6 + x_backout = None + backout_layer = 7 + + # set lambdas + resid_lambdas = self.scalars[: 1 * self.num_layers] + x0_lambdas = self.x0_lambdas + sa_lambdas = self.scalars[1 * self.num_layers: 3 * self.num_layers].view(-1, 2) + bigram_lambdas = self.scalars[3 * self.num_layers: 4 * self.num_layers] + smear_lambda = self.scalars[4 * self.num_layers] + backout_lambda = self.scalars[4 * self.num_layers+1] + skip_lambda = self.scalars[4 * self.num_layers+2] + + # set block masks and key shift + bm_sizes = [ws_short, ws_short, ws_short, ws_long, ws_short, ws_short, None, ws_short, ws_short, ws_short, ws_long] + assert len(bm_sizes) == self.num_layers + key_offset = [b==ws_long for b in bm_sizes] # apply partial key offset to long windows + + # Embedding lookup - embed is synced from lm_head during tied phase by optimizer + x = self.embed(input_seq) + # Compute bigram hash on GPU (moved from CPU data loader) + bigram_seq = self._compute_bigram_hash(input_seq, args.bigram_vocab_size - 1) + x0_bigram = self.bigram_embed(bigram_seq)[None] + + # Value embeddings - always computed (not precomputed) + ve = self.value_embeds.view(5, self.vocab_size, -1)[:, input_seq] + # 01 ... 
234 structure on token value embeddings by @photomz + ve = [ve[0], ve[1]] + [None] * (self.num_layers - 5) + [ve[2], ve[3], ve[4]] + assert len(ve) == self.num_layers + + # smear token embed forward 1 position @classiclarryd + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # unbind gate banks to avoid select_backwards kernel + ag = [w.bfloat16() for w in self.attn_gate_bank.unbind(0)] + veg = [w.bfloat16() for w in self.ve_gate_bank.unbind(0)] + attn_gates = ag[:6] + [None] + ag[6:] + ve_gates = [veg[0], veg[1]] + [None] * (self.num_layers - 5) + [veg[2], veg[3], veg[4]] + assert len(attn_gates) == self.num_layers + assert len(ve_gates) == self.num_layers + + # unbind weight banks to avoid select_backwards kernel + attn_weights = self.attn_bank.unbind(0) # tuple of [4*dim, hdim] tensors + mlp_fcs = self.mlp_bank[:, 0, :, :].unbind(0) # tuple of [mlp_hdim, dim] tensors + mlp_projs = self.mlp_bank[:, 1, :, :].unbind(0) # tuple of [mlp_hdim, dim] tensors + + for i in range(self.num_layers): + yarn = self.yarn_paired_head if i in self.paired_head_layers else self.yarn + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + yarn=yarn, + key_offset=key_offset[i], + attn_gate_w=attn_gates[i], + ve_gate_w=ve_gates[i] + ) + if i in skip_out: + skip_gate_out = torch.sigmoid(skip_lambda) * 2 * torch.sigmoid(self.skip_gate(x0[..., :self.skip_gate.weight.size(-1)])) + x = x + skip_gate_out * skip_connections.pop() + if i == 0: + x = (resid_lambdas[0] + x0_lambdas[0]) * x + bigram_lambdas[0] * x0_bigram + else: + x = resid_lambdas[i] * x + x0_lambdas[i] * x0 + bigram_lambdas[i] * x0_bigram + + # Get weights for this layer from banks + qkvo_w = attn_weights[self.layer_to_attn_idx[i]] if i in self.layer_to_attn_idx else None + c_fc = mlp_fcs[self.layer_to_mlp_idx[i]] if i in 
self.layer_to_mlp_idx else None + c_proj = mlp_projs[self.layer_to_mlp_idx[i]] if i in self.layer_to_mlp_idx else None + + x = self.blocks[i](x, attn_args, qkvo_w, c_fc, c_proj) + if i in skip_in: + skip_connections.append(x) + if i == backout_layer: + x_backout = x + + # back out contributions from first 7 layers that are only required for downstream context and not direct prediction + x -= backout_lambda * x_backout + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15 + # @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1). @classiclarryd updated to 23*sigmoid((logits+5)/7.5) + if self.training: + losses = FusedSoftcappedCrossEntropy.apply(logits.view(-1, logits.size(-1)), target_seq, mtp_weights, 23.0, 5.0, 7.5) + loss = losses.sum() + else: + logits = 23 * torch.sigmoid((logits + 5) / 7.5) + logits_for_loss = logits.float() + loss = F.cross_entropy(logits_for_loss.view(-1, logits_for_loss.size(-1)), target_seq, reduction="mean") + return loss +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class Shard: + def __init__(self, tokens: Tensor, world_size: int = 1): + self.tokens = tokens + self.size = tokens.numel() + self.world_size = 
world_size + self.i = 0 + + # Partial index now, full index async + self.bos_idx = (tokens[:6_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self._full_idx = None + self._loader_thread = None + self._ready = threading.Event() + self._loader_thread = threading.Thread(target=self._scan) + self._loader_thread.start() + + def _scan(self): + self._full_idx = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self._ready.set() + + def _maybe_switch(self): + # Switch to full index as soon as async scan completes + if self.bos_idx is not self._full_idx and self._ready.is_set(): + self._loader_thread.join() + self.bos_idx = self._full_idx + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + self._maybe_switch() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + return starts, ends + + @staticmethod + def load_async(file: Path, world_size: int = 1): + """Returns getter function for async shard loading""" + result = {} + ready = threading.Event() + def load(): + tokens = _load_data_shard(file) + result['shard'] = Shard(tokens, world_size) + ready.set() + thread = threading.Thread(target=load) + thread.start() + def get(): + ready.wait() + thread.join() + return result['shard'] + return get + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each 
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + shard = Shard(tokens, world_size) + next_shard_getter = Shard.load_async(next(file_iter), world_size) + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = shard.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
+ shard = next_shard_getter() + tokens = shard.tokens + try: + next_shard_getter = Shard.load_async(next(file_iter), world_size) + except StopIteration: + next_shard_getter = None # no more shards to preload + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + # Cast to int32 on CPU before transfer to avoid dtype conversion during .to() + _inputs = _inputs.to(dtype=torch.int32) + _targets = _targets.to(dtype=torch.int64) + _cum_lengths = _cum_lengths.to(dtype=torch.int32) + # Bigram hash computation moved to GPU in forward() + + new_params = yield ( + _inputs.to(device="cuda", non_blocking=True), + _targets.to(device="cuda", non_blocking=True), + _cum_lengths.to(device="cuda", non_blocking=True), + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens // new_grad_accum_steps + max_seq_len = new_max_seq_len + +# ----------------------------------------------------------------------------- +# Training Management + +@dataclass 
+class Hyperparameters: + # data + data_path = os.environ.get("DATA_PATH", ".") + train_files: str = os.path.join(data_path, "data/fineweb10B/fineweb_train_*.bin") # input .bin to train on + val_files: str = os.path.join(data_path, "data/fineweb10B/fineweb_val_*.bin") # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + # batch sizes + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # schedule + num_scheduled_iterations: int = 1515 # number of steps to complete lr and ws schedule + num_extension_iterations: int = 40 # number of steps to continue training at final lr and ws + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 250 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # bigram hash embedding + bigram_vocab_size: int = 50304 * 5 + +args = Hyperparameters() + +@dataclass +class TrainingStage: + lr_mul: float + batch_size: int + window_sizes: tuple[int, int] # (short, long) in block units + mtp_weights_start: list[float] + mtp_weights_end: list[float] + duration: float = None + +class TrainingSchedule: + """ + Training schedule initialized via TRAINING_STAGES + 1. Multi Token Prediction schedule of [1, 0.5, 0.25->0] -> [1, 0.5->0] -> [1] @varunneal + 2. Sliding Attention window schedule of [1,3] -> [3,7] -> [5,11] -> [6,13] + 3. YaRN updates to RoPE on window changes + 4. Split embed and lm head at 2/3 of training + 5. Batch size schedule of 8 -> 16 -> 24 + 6. 
Post training extension of long windows from 13 to 20 + """ + + def __init__(self, stages: list[TrainingStage], scheduled_iterations: int, extension_iterations: int, + cooldown_frac: float = 0.5, split_embed_stage: int = 2, ws_post_yarn_ext: int = 20): + self.stages = stages + self.scheduled_iterations = scheduled_iterations + self.cooldown_frac = cooldown_frac + # increase final validation ws, used for YaRN extension and short window size @classiclarryd + self.ws_post_yarn_ext = ws_post_yarn_ext + + self.total_steps = self.scheduled_iterations + extension_iterations + + # Build stage boundaries (last is extension stage) + ends = [0] + [round(c * scheduled_iterations) for c in accumulate(s.duration for s in stages[:-1])] + [self.total_steps] + assert self.scheduled_iterations == ends[-2] + self.boundaries = list(pairwise(ends)) + + # Split embed at specified stage (ensure odd step for Adam) + self.split_step = self.boundaries[split_embed_stage][0] | 1 + + # Precompute MTP weights for all steps + self.mtp_weights = [] + for step in range(self.total_steps + 1): + stage, t = self.lookup(step) + w = [a + (b - a) * t for a, b in zip(stage.mtp_weights_start, stage.mtp_weights_end)] + self.mtp_weights.append(torch.tensor(w, device=device)) + + def lookup(self, step: int) -> tuple[TrainingStage, float]: + # Returns stage and % of the way through that stage + for i, (start, end) in enumerate(self.boundaries): + if step < end: + t = (step - start) / (end - start) + return self.stages[i], t + return self.stages[-1], 1.0 + + def get_lr(self, step: int) -> float: + # learning rate schedule: tied to batch size schedule, with cooldown at the end + stage, _ = self.lookup(step) + lr = stage.lr_mul + cd_start = int(self.scheduled_iterations * (1 - self.cooldown_frac)) + if step >= cd_start: + t = min(1.0, (step - cd_start) / (self.scheduled_iterations - cd_start)) + lr = lr * (1 - t) + 0.1 * t + return lr + +# window_sizes are in units of `block_size` tokens (defined in 
TrainingManager) +TRAINING_STAGES = [ + TrainingStage(duration=1/3, batch_size=8 * 2048 * 8, window_sizes=(1, 3), lr_mul=1.0, + mtp_weights_start=[1.0, 0.5, 0.25], mtp_weights_end=[1.0, 0.5, 0.0]), + TrainingStage(duration=1/3, batch_size=16 * 2048 * 8, window_sizes=(3, 7), lr_mul=1.52, # (16/8)**0.6 + mtp_weights_start=[1.0, 0.5], mtp_weights_end=[1.0, 0.0]), + TrainingStage(duration=1/3, batch_size=24 * 2048 * 8, window_sizes=(5, 11), lr_mul=1.73, # (24/8)**0.5 + mtp_weights_start=[1.0], mtp_weights_end=[1.0]), + # extension stage + TrainingStage(batch_size=24 * 2048 * 8, window_sizes=(6, 13), lr_mul=1.0, # lr_mul is not used + mtp_weights_start=[1.0], mtp_weights_end=[1.0]), +] + +training_schedule = TrainingSchedule(TRAINING_STAGES, args.num_scheduled_iterations, args.num_extension_iterations, cooldown_frac=0.55) + +def get_muon_momentum(step: int, muon_warmup_steps=300, muon_cooldown_steps=50, momentum_min=0.85, momentum_max=0.95): + # warmup phase: linearly increase momentum from min to max + # cooldown phase: linearly decrease momentum from max to min + momentum_cd_start = training_schedule.total_steps - muon_cooldown_steps + if step < muon_warmup_steps: + frac = step / muon_warmup_steps + momentum = momentum_min + frac * (momentum_max - momentum_min) + elif step > momentum_cd_start: + frac = (step - momentum_cd_start) / muon_cooldown_steps + momentum = momentum_max - frac * (momentum_max - momentum_min) + else: + momentum = momentum_max + return momentum + +class TrainingManager(): + """ + Manages the NorMuonAndAdam for all parameters with explicit ordering. + 1. Scalars are given higher momentum terms to smooth learning @ChrisJMcCormick + 2. Adam optimizers are only stepped on odd steps @classiclarryd + 3. Explicit scatter_order and work_order for communication scheduling (no backward hooks) + 4. Muon has a linear momentum warmup and cooldown schedule + 5. Learning rates follow a linear decay schedule + 6. 
Embed is tied to lm_head until split step (2/3 of training), then untied @classiclarryd + """ + def __init__(self, model): + self.model = model + self.block_size = 128 + + # - Ordering dictates when to launch reduce/reduce_scatter operations + # - "sharded" parameters use reduce_scatter/all_gather and "replicated" ones use all_reduce + # - lr_mul and wd_mul are per-parameter learning rate and weight decay multipliers + self.param_table = { + "attn": {"optim": "normuon", "comms": "sharded", "adam_betas": None}, + "mlp": {"optim": "normuon", "comms": "sharded", "adam_betas": None}, + "scalars": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 5.0, "wd_mul": 0.0}, + "value_embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.75, 0.95], "lr_mul": 75., "wd_mul": 5.0}, + "bigram_embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.75, 0.95], "lr_mul": 75., "wd_mul": 5.0}, + "smear_gate": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 0.01, "wd_mul": 0.0}, + "skip_gate": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 0.05, "wd_mul": 0.0}, + "attn_gate_bank": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99]}, + "ve_gate_bank": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99]}, + "x0_lambdas": {"optim": "adam", "comms": "replicated", "adam_betas": [0.65, 0.95], "lr_mul": 5.0, "wd_mul": 0.0}, + "lm_head": {"optim": "adam", "comms": "sharded", "adam_betas": [0.5, 0.95], "wd_mul": 150.}, + "embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.5, 0.95], "wd_mul": 150.}, + } + + # - Process smaller/faster params first while large reduces complete + # - lm_head must complete before embed sync (when tied) + self.work_order = [ + "scalars", "smear_gate", "skip_gate", "attn_gate_bank", "ve_gate_bank", "x0_lambdas", # Small, fast + "value_embed", "bigram_embed", # Medium + "lm_head", "embed", # lm_head must complete before embed 
sync (when tied) + "attn", "mlp", # Large, polar express - process last to maximize overlap + ] + + adam_defaults = dict( + lr=0.008, + eps=1e-10, + weight_decay=0.005, + ) + + normuon_defaults = dict( + lr=0.023, + momentum=0.95, + beta2=0.95, + weight_decay=1.2, + ) + + self.optimizer = NorMuonAndAdam( + model.named_parameters(), + param_table=self.param_table, + scatter_order=list(self.param_table.keys()), # Dict order defines scatter priority + work_order=self.work_order, + adam_defaults=adam_defaults, + normuon_defaults=normuon_defaults, + ) + + # Split embed from lm_head at 2/3 of training (on an odd step so Adam updates) + self.split_step = training_schedule.split_step + + self.reset() + + def apply_final_ws_ext(self): + self.ws_long = training_schedule.ws_post_yarn_ext + + def get_forward_args(self): + return ForwardScheduleConfig( + mtp_weights = self.mtp_weights, + ws_short = self.ws_short * self.block_size, + ws_long = self.ws_long * self.block_size + ) + + def _is_adam_step(self, step: int): + """Adam params are only updated on odd steps.""" + return step % 2 == 1 + + def get_transition_steps(self): + return [start for start, _ in training_schedule.boundaries[1:]] + + def advance_schedule(self, step: int): + stage, _ = training_schedule.lookup(step) + self.ws_short, new_ws_long = stage.window_sizes + if new_ws_long != self.ws_long: + self.model.yarn.apply(self.ws_long * self.block_size, new_ws_long * self.block_size) + self.model.yarn_paired_head.apply(self.ws_long * self.block_size, new_ws_long * self.block_size) + + new_batch_size = stage.batch_size + if new_batch_size != self.batch_size: + self.train_loader_send_args = (new_batch_size, args.train_max_seq_len, grad_accum_steps) + self.batch_size = new_batch_size + else: + self.train_loader_send_args = None + + self.ws_long = new_ws_long + self.mtp_weights = training_schedule.mtp_weights[step] + + def step_optimizers(self, step: int): + step_lr = training_schedule.get_lr(step) + muon_momentum = 
get_muon_momentum(step) + do_adam = self._is_adam_step(step) + + # Update learning rates and momentum for all params + for param, p_cfg in self.optimizer.param_cfgs.items(): + p_cfg.lr = p_cfg.initial_lr * step_lr + if p_cfg.optim == "normuon": + p_cfg.momentum = muon_momentum + + # Step optimizer with do_adam flag + self.optimizer.step(do_adam=do_adam) + + # At split step: copy lm_head optimizer state to embed and mark as split + if step == self.split_step: + self.optimizer.copy_lm_state_to_embed() + + def reset(self, state=None): + if state is not None: + self.optimizer.load_state_dict(state) + + # Reset NorMuon momentum buffers and split_embed state + self.optimizer.reset() + + stage, _ = training_schedule.lookup(0) + self.ws_short, self.ws_long = stage.window_sizes + self.batch_size = stage.batch_size + self.model.yarn.reset() + self.model.yarn_paired_head.reset() + + def get_state(self): + return copy.deepcopy(self.optimizer.state_dict()) + +# ----------------------------------------------------------------------------- +# int main + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=11, + num_heads=6, + 
head_dim=128, + model_dim=768, + max_seq_len=args.val_batch_size // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.weight.data = m.weight.data.bfloat16() +model.attn_gate_bank.data = model.attn_gate_bank.data.bfloat16() +model.ve_gate_bank.data = model.ve_gate_bank.data.bfloat16() +model.attn_bank.data = model.attn_bank.data.bfloat16() +model.mlp_bank.data = model.mlp_bank.data.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) +training_manager = TrainingManager(model) + +######################################## +# Warmup kernels # +######################################## +print0("Compiling model and warming up kernels (~7 minutes on first execution)", console=True) +# Warmup the training kernels, then re-initialize the state so we aren't cheating +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizer=training_manager.get_state()) # save the initial state +train_loader = distributed_data_generator(args.train_files, TRAINING_STAGES[0].batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + +transition_steps = training_manager.get_transition_steps() +# first few steps plus transitions +warmup_steps = sorted({0, 1, 2} | set(s + offset for s in transition_steps for offset in [-1, 0, 1] if s + offset >= 0)) +print0(f"Sampling steps {warmup_steps} for warmup", console=True) +for step in warmup_steps: + training_manager.advance_schedule(step) + model.eval() + with torch.no_grad(): + inputs, targets, cum_seqlens = next(val_loader) + model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) + model.train() + for idx in range(grad_accum_steps): + send_args = training_manager.train_loader_send_args + inputs, 
targets, cum_seqlens = train_loader.send(send_args) + (model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) * grad_scale).backward() + training_manager.step_optimizers(step) +print0("Resetting Model", console=True) +model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +training_manager.reset(initial_state["optimizer"]) +del val_loader, train_loader, initial_state +model.train() + +######################################## +# Training and validation # +######################################## +train_loader = distributed_data_generator(args.train_files, TRAINING_STAGES[0].batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) + +gc.collect() + +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = training_schedule.total_steps +for step in range(train_steps + 1): + last_step = (step == train_steps) + training_manager.advance_schedule(step) + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + if last_step: + training_manager.apply_final_ws_ext() + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) + val_loss /= val_steps + del val_loader + dist.reduce(val_loss, 0, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", 
console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizer=training_manager.get_state()) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for idx in range(grad_accum_steps): + inputs, targets, cum_seqlens = train_loader.send(training_manager.train_loader_send_args) + (model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) * grad_scale).backward() + training_manager.step_optimizers(step) + + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + + +---------------------------------------- +# triton_kernels.py +---------------------------------------- + +import torch +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * 
num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.jit +def XXT_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = 
m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def XXT(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + # Hardcoded configs based on H100 autotuning + if K == 768: + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 64 + num_stages, num_warps = 4, 4 + else: + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 64, 128, 128 + num_stages, num_warps = 4, 4 + + grid = (batch_size * triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(M, BLOCK_SIZE_N),) + XXT_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + LOWER_UPPER=1, + num_stages=num_stages, + num_warps=num_warps, + ) + return out + +@triton.jit +def ba_plus_cAA_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: 
tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from XXT_kernel, but also loads and adds a block of A + # Performance is slightly slower than XXT_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = 
C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ba_plus_cAA(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + # Hardcoded config based on H100 autotuning (M=768) + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 64 + num_stages, num_warps = 4, 4 + + grid = (batch_size * triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(M, BLOCK_SIZE_N),) + ba_plus_cAA_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + LOWER_UPPER=1, + num_stages=num_stages, + num_warps=num_warps, + ) + return out + +# ----------------------------------------------------------------------------- +# Triton kernel for MLP: relu(x @ W1.T)^2, by @andrewbriand, @jrauvola + +@triton.jit +def linear_relu_square_kernel(a_desc, b_desc, c_desc, 
aux_desc, + M, N, K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, + ): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tile_id_c = start_pid - NUM_SMS + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + + tile_id_c += NUM_SMS + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am_c = pid_m * BLOCK_SIZE_M + offs_bn_c = pid_n * BLOCK_SIZE_N + + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + + c0 = acc0.to(dtype) + if not FORWARD: + c0_pre = aux_desc.load([offs_am_c, offs_bn_c]) + c0 = 2 * c0 * tl.where(c0_pre > 0, c0_pre, 0) + + c_desc.store([offs_am_c, offs_bn_c], c0) + + if FORWARD: + c0_post = tl.maximum(c0, 0) + c0_post = c0_post * c0_post + aux_desc.store([offs_am_c, offs_bn_c], c0_post) + + c1 = acc1.to(dtype) + if not FORWARD: + c1_pre = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + c1 = 2 * c1 * tl.where(c1_pre > 0, c1_pre, 0) + + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + + if FORWARD: + c1_post = tl.maximum(c1, 0) + c1_post = c1_post * c1_post + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1_post) + + +def linear_relu_square(a, b, aux=None): + M, K = a.shape + N, K = b.shape + dtype = 
a.dtype + + c = torch.empty((M, N), device=a.device, dtype=dtype) + + FORWARD = False + if aux is None: + FORWARD = True + aux = torch.empty((M, N), device=a.device, dtype=dtype) + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 256 + BLOCK_SIZE_K = 64 + num_stages = 4 if FORWARD else 3 + num_warps = 8 + + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + + def grid(META): + return (min( + NUM_SMS, + triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), + ), ) + + linear_relu_square_kernel[grid]( + a_desc, b_desc, c_desc, aux_desc, + M, N, K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=1, + NUM_SMS=NUM_SMS, + FORWARD=FORWARD, + num_stages=num_stages, + num_warps=num_warps + ) + + if FORWARD: + return c, aux + else: + return c + +class FusedLinearReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, W1, W2): + pre, post = linear_relu_square(x.view((-1, x.shape[-1])), W1) + x3 = post @ W2 + ctx.save_for_backward(x, W1, W2, pre, post) + return x3.view(x.shape) + + @staticmethod + def backward(ctx, grad_output): + x, W1, W2, pre, post = ctx.saved_tensors + dW2 = post.T @ grad_output + dpre = linear_relu_square(grad_output.view((-1, grad_output.shape[-1])), W2, aux=pre) + dW1 = dpre.T @ x + dx = dpre @ W1 + return dx.view(x.shape), dW1, dW2 + +# ----------------------------------------------------------------------------- +# Fused Softcapped Cross Entropy + + +@triton.jit +def fused_softcapped_entropy_fwd_kernel( + logits_ptr, losses_ptr, lse_ptr, targets_ptr, mtp_weights_ptr, + stride_logits_n, stride_logits_v, + n_rows, n_cols, n_predict, + A, B, C, + 
BLOCK_SIZE: tl.constexpr +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + + max_val = -float('inf') + sum_exp = 0.0 + + for off in range(0, n_cols, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + val = tl.load(logits_row_ptr + cols, mask=mask, other=-float('inf')).to(tl.float32) + z = A * tl.sigmoid((val + B) / C) + z = tl.where(mask, z, -float('inf')) + curr_max = tl.max(z, axis=0) + new_max = tl.maximum(max_val, curr_max) + sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(z - new_max), axis=0) + max_val = new_max + + lse = max_val + tl.log(sum_exp) + tl.store(lse_ptr + row_idx, lse) + + total_loss = 0.0 + for k in range(n_predict): + target_idx = row_idx + k + if target_idx < n_rows: + weight = tl.load(mtp_weights_ptr + k) + if weight > 0: + target = tl.load(targets_ptr + target_idx).to(tl.int32) + if target >= 0 and target < n_cols: + val_target = tl.load(logits_row_ptr + target).to(tl.float32) + z_target = A * tl.sigmoid((val_target + B) / C) + total_loss += weight * (lse - z_target) + + tl.store(losses_ptr + row_idx, total_loss) + +@triton.jit +def fused_softcapped_entropy_bwd_kernel( + grad_input_ptr, grad_output_ptr, lse_ptr, logits_ptr, targets_ptr, mtp_weights_ptr, + stride_logits_n, stride_logits_v, stride_grad_n, stride_grad_v, + n_rows, n_cols, n_predict, + A, B, C, + BLOCK_SIZE: tl.constexpr +): + row_idx = tl.program_id(0).to(tl.int64) + + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + grad_row_ptr = grad_input_ptr + row_idx * stride_grad_n + + lse = tl.load(lse_ptr + row_idx) + grad_loss = tl.load(grad_output_ptr + row_idx) + + S_w = 0.0 + for k in range(n_predict): + if row_idx + k < n_rows: + S_w += tl.load(mtp_weights_ptr + k) + + for off in range(0, n_cols, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + val = tl.load(logits_row_ptr + cols, mask=mask, other=0.0).to(tl.float32) + u = (val + B) / C + sigmoid_u 
= tl.sigmoid(u) + z = A * sigmoid_u + p = tl.exp(z - lse) + + term1 = S_w * p + term2 = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for k in range(n_predict): + if row_idx + k < n_rows: + target = tl.load(targets_ptr + row_idx + k).to(tl.int32) + weight = tl.load(mtp_weights_ptr + k) + term2 += tl.where(cols == target, weight, 0.0) + + grad_z = grad_loss * (term1 - term2) + dz_dx = (1.0 / C) * z * (1.0 - sigmoid_u) + grad_x = grad_z * dz_dx + tl.store(grad_row_ptr + cols, grad_x.to(tl.bfloat16), mask=mask) + +class FusedSoftcappedCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, logits, targets, mtp_weights, A=23.0, B=5.0, C=7.5): + n_rows, n_cols = logits.shape + if mtp_weights is None: + mtp_weights = torch.tensor([1.0], device=logits.device, dtype=torch.float32) + n_predict = mtp_weights.shape[0] + + losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + lse = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + + logits = logits.contiguous() + targets = targets.contiguous() + mtp_weights = mtp_weights.contiguous() + + grid = (n_rows,) + fused_softcapped_entropy_fwd_kernel[grid]( + logits, losses, lse, targets, mtp_weights, + logits.stride(0), logits.stride(1), + n_rows, n_cols, n_predict, + A, B, C, + BLOCK_SIZE=1024, + num_warps=8, + num_stages=4 + ) + + ctx.save_for_backward(logits, targets, mtp_weights, lse) + ctx.params = (A, B, C) + return losses + + @staticmethod + def backward(ctx, grad_output): + logits, targets, mtp_weights, lse = ctx.saved_tensors + A, B, C = ctx.params + n_rows, n_cols = logits.shape + n_predict = mtp_weights.shape[0] + + grad_input = torch.empty((n_rows, n_cols), dtype=torch.bfloat16, device=logits.device) + grad_output = grad_output.contiguous() + + grid = (n_rows,) + fused_softcapped_entropy_bwd_kernel[grid]( + grad_input, grad_output, lse, logits, targets, mtp_weights, + logits.stride(0), logits.stride(1), grad_input.stride(0), grad_input.stride(1), + n_rows, n_cols, 
n_predict, + A, B, C, + BLOCK_SIZE=1024, + num_warps=8, + num_stages=4 + ) + return grad_input, None, None, None, None, None + +==================================================================================================== +Running Python 3.12.7 (main, Jan 31 2026, 04:21:49) [GCC 13.2.0] +Running PyTorch 2.10.0.dev20251210+cu126 compiled for CUDA 12.6 +Running Triton version 3.6.0 +Sun Feb 1 05:57:40 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 28C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 30C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:71:00.0 Off | 0 | +| N/A 31C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:79:00.0 Off | 0 | +| N/A 30C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:7F:00.0 Off | 0 | +| N/A 28C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | +| N/A 31C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 30C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 30C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 95 C /usr/local/bin/python 1510MiB | +| 1 N/A N/A 96 C /usr/local/bin/python 1510MiB | +| 2 N/A N/A 97 C /usr/local/bin/python 1510MiB | +| 3 N/A N/A 98 C /usr/local/bin/python 1510MiB | +| 4 N/A N/A 99 C /usr/local/bin/python 1510MiB | +| 5 N/A N/A 100 C /usr/local/bin/python 1510MiB | +| 6 N/A N/A 101 C /usr/local/bin/python 1510MiB | +| 7 N/A N/A 102 C /usr/local/bin/python 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +Compiling model and warming up kernels (~7 minutes on first execution) +Sampling steps [0, 1, 2, 504, 505, 506, 1009, 1010, 1011, 1514, 1515, 1516] for warmup +Resetting Model +step:0/1555 val_loss:10.8316 train_time:0ms step_avg:0.03ms +step:1/1555 train_time:92ms step_avg:92.38ms +step:2/1555 train_time:117ms step_avg:58.35ms 
+step:3/1555 train_time:138ms step_avg:46.13ms +step:4/1555 train_time:160ms step_avg:40.03ms +step:5/1555 train_time:186ms step_avg:37.15ms +step:6/1555 train_time:224ms step_avg:37.31ms +step:7/1555 train_time:254ms step_avg:36.28ms +step:8/1555 train_time:292ms step_avg:36.47ms +step:9/1555 train_time:322ms step_avg:35.78ms +step:10/1555 train_time:360ms step_avg:35.96ms +step:11/1555 train_time:390ms step_avg:35.46ms +step:12/1555 train_time:428ms step_avg:35.66ms +step:13/1555 train_time:459ms step_avg:35.27ms +step:14/1555 train_time:496ms step_avg:35.45ms +step:15/1555 train_time:527ms step_avg:35.12ms +step:16/1555 train_time:565ms step_avg:35.30ms +step:17/1555 train_time:596ms step_avg:35.04ms +step:18/1555 train_time:633ms step_avg:35.18ms +step:19/1555 train_time:665ms step_avg:35.02ms +step:20/1555 train_time:703ms step_avg:35.14ms +step:21/1555 train_time:733ms step_avg:34.89ms +step:22/1555 train_time:770ms step_avg:35.01ms +step:23/1555 train_time:801ms step_avg:34.82ms +step:24/1555 train_time:838ms step_avg:34.94ms +step:25/1555 train_time:869ms step_avg:34.78ms +step:26/1555 train_time:907ms step_avg:34.89ms +step:27/1555 train_time:938ms step_avg:34.73ms +step:28/1555 train_time:975ms step_avg:34.83ms +step:29/1555 train_time:1006ms step_avg:34.69ms +step:30/1555 train_time:1044ms step_avg:34.80ms +step:31/1555 train_time:1075ms step_avg:34.68ms +step:32/1555 train_time:1114ms step_avg:34.80ms +step:33/1555 train_time:1145ms step_avg:34.69ms +step:34/1555 train_time:1183ms step_avg:34.81ms +step:35/1555 train_time:1214ms step_avg:34.68ms +step:36/1555 train_time:1252ms step_avg:34.77ms +step:37/1555 train_time:1284ms step_avg:34.69ms +step:38/1555 train_time:1322ms step_avg:34.80ms +step:39/1555 train_time:1352ms step_avg:34.68ms +step:40/1555 train_time:1390ms step_avg:34.75ms +step:41/1555 train_time:1421ms step_avg:34.65ms +step:42/1555 train_time:1458ms step_avg:34.72ms +step:43/1555 train_time:1489ms step_avg:34.63ms +step:44/1555 
train_time:1527ms step_avg:34.70ms +step:45/1555 train_time:1557ms step_avg:34.61ms +step:46/1555 train_time:1595ms step_avg:34.67ms +step:47/1555 train_time:1625ms step_avg:34.58ms +step:48/1555 train_time:1663ms step_avg:34.65ms +step:49/1555 train_time:1694ms step_avg:34.57ms +step:50/1555 train_time:1732ms step_avg:34.63ms +step:51/1555 train_time:1762ms step_avg:34.55ms +step:52/1555 train_time:1799ms step_avg:34.60ms +step:53/1555 train_time:1830ms step_avg:34.53ms +step:54/1555 train_time:1868ms step_avg:34.59ms +step:55/1555 train_time:1899ms step_avg:34.52ms +step:56/1555 train_time:1937ms step_avg:34.58ms +step:57/1555 train_time:1967ms step_avg:34.51ms +step:58/1555 train_time:2005ms step_avg:34.57ms +step:59/1555 train_time:2035ms step_avg:34.50ms +step:60/1555 train_time:2073ms step_avg:34.55ms +step:61/1555 train_time:2104ms step_avg:34.49ms +step:62/1555 train_time:2142ms step_avg:34.55ms +step:63/1555 train_time:2173ms step_avg:34.49ms +step:64/1555 train_time:2211ms step_avg:34.54ms +step:65/1555 train_time:2242ms step_avg:34.49ms +step:66/1555 train_time:2280ms step_avg:34.55ms +step:67/1555 train_time:2311ms step_avg:34.49ms +step:68/1555 train_time:2349ms step_avg:34.54ms +step:69/1555 train_time:2380ms step_avg:34.49ms +step:70/1555 train_time:2418ms step_avg:34.54ms +step:71/1555 train_time:2449ms step_avg:34.49ms +step:72/1555 train_time:2487ms step_avg:34.54ms +step:73/1555 train_time:2518ms step_avg:34.49ms +step:74/1555 train_time:2555ms step_avg:34.53ms +step:75/1555 train_time:2586ms step_avg:34.48ms +step:76/1555 train_time:2624ms step_avg:34.53ms +step:77/1555 train_time:2654ms step_avg:34.47ms +step:78/1555 train_time:2692ms step_avg:34.51ms +step:79/1555 train_time:2723ms step_avg:34.47ms +step:80/1555 train_time:2760ms step_avg:34.50ms +step:81/1555 train_time:2791ms step_avg:34.46ms +step:82/1555 train_time:2830ms step_avg:34.51ms +step:83/1555 train_time:2860ms step_avg:34.45ms +step:84/1555 train_time:2897ms step_avg:34.49ms 
+step:85/1555 train_time:2928ms step_avg:34.45ms +step:86/1555 train_time:2966ms step_avg:34.49ms +step:87/1555 train_time:2997ms step_avg:34.45ms +step:88/1555 train_time:3035ms step_avg:34.48ms +step:89/1555 train_time:3065ms step_avg:34.44ms +step:90/1555 train_time:3103ms step_avg:34.48ms +step:91/1555 train_time:3134ms step_avg:34.44ms +step:92/1555 train_time:3171ms step_avg:34.47ms +step:93/1555 train_time:3202ms step_avg:34.43ms +step:94/1555 train_time:3240ms step_avg:34.47ms +step:95/1555 train_time:3271ms step_avg:34.43ms +step:96/1555 train_time:3309ms step_avg:34.47ms +step:97/1555 train_time:3340ms step_avg:34.43ms +step:98/1555 train_time:3377ms step_avg:34.46ms +step:99/1555 train_time:3408ms step_avg:34.43ms +step:100/1555 train_time:3446ms step_avg:34.46ms +step:101/1555 train_time:3476ms step_avg:34.42ms +step:102/1555 train_time:3514ms step_avg:34.45ms +step:103/1555 train_time:3545ms step_avg:34.42ms +step:104/1555 train_time:3583ms step_avg:34.45ms +step:105/1555 train_time:3613ms step_avg:34.41ms +step:106/1555 train_time:3651ms step_avg:34.44ms +step:107/1555 train_time:3682ms step_avg:34.41ms +step:108/1555 train_time:3720ms step_avg:34.44ms +step:109/1555 train_time:3750ms step_avg:34.40ms +step:110/1555 train_time:3788ms step_avg:34.43ms +step:111/1555 train_time:3819ms step_avg:34.40ms +step:112/1555 train_time:3856ms step_avg:34.43ms +step:113/1555 train_time:3887ms step_avg:34.40ms +step:114/1555 train_time:3925ms step_avg:34.43ms +step:115/1555 train_time:3956ms step_avg:34.40ms +step:116/1555 train_time:3993ms step_avg:34.43ms +step:117/1555 train_time:4024ms step_avg:34.39ms +step:118/1555 train_time:4062ms step_avg:34.42ms +step:119/1555 train_time:4092ms step_avg:34.39ms +step:120/1555 train_time:4130ms step_avg:34.42ms +step:121/1555 train_time:4160ms step_avg:34.38ms +step:122/1555 train_time:4197ms step_avg:34.41ms +step:123/1555 train_time:4228ms step_avg:34.38ms +step:124/1555 train_time:4266ms step_avg:34.40ms +step:125/1555 
train_time:4297ms step_avg:34.37ms +step:126/1555 train_time:4334ms step_avg:34.40ms +step:127/1555 train_time:4365ms step_avg:34.37ms +step:128/1555 train_time:4402ms step_avg:34.39ms +step:129/1555 train_time:4433ms step_avg:34.37ms +step:130/1555 train_time:4471ms step_avg:34.39ms +step:131/1555 train_time:4502ms step_avg:34.37ms +step:132/1555 train_time:4539ms step_avg:34.39ms +step:133/1555 train_time:4570ms step_avg:34.36ms +step:134/1555 train_time:4608ms step_avg:34.38ms +step:135/1555 train_time:4639ms step_avg:34.36ms +step:136/1555 train_time:4676ms step_avg:34.38ms +step:137/1555 train_time:4707ms step_avg:34.36ms +step:138/1555 train_time:4745ms step_avg:34.39ms +step:139/1555 train_time:4776ms step_avg:34.36ms +step:140/1555 train_time:4813ms step_avg:34.38ms +step:141/1555 train_time:4844ms step_avg:34.36ms +step:142/1555 train_time:4883ms step_avg:34.38ms +step:143/1555 train_time:4914ms step_avg:34.36ms +step:144/1555 train_time:4951ms step_avg:34.38ms +step:145/1555 train_time:4982ms step_avg:34.36ms +step:146/1555 train_time:5019ms step_avg:34.38ms +step:147/1555 train_time:5050ms step_avg:34.35ms +step:148/1555 train_time:5088ms step_avg:34.38ms +step:149/1555 train_time:5118ms step_avg:34.35ms +step:150/1555 train_time:5156ms step_avg:34.37ms +step:151/1555 train_time:5186ms step_avg:34.35ms +step:152/1555 train_time:5224ms step_avg:34.37ms +step:153/1555 train_time:5255ms step_avg:34.34ms +step:154/1555 train_time:5292ms step_avg:34.37ms +step:155/1555 train_time:5323ms step_avg:34.34ms +step:156/1555 train_time:5361ms step_avg:34.36ms +step:157/1555 train_time:5392ms step_avg:34.34ms +step:158/1555 train_time:5430ms step_avg:34.37ms +step:159/1555 train_time:5460ms step_avg:34.34ms +step:160/1555 train_time:5498ms step_avg:34.36ms +step:161/1555 train_time:5529ms step_avg:34.34ms +step:162/1555 train_time:5567ms step_avg:34.36ms +step:163/1555 train_time:5597ms step_avg:34.34ms +step:164/1555 train_time:5635ms step_avg:34.36ms +step:165/1555 
train_time:5666ms step_avg:34.34ms +step:166/1555 train_time:5703ms step_avg:34.36ms +step:167/1555 train_time:5734ms step_avg:34.34ms +step:168/1555 train_time:5772ms step_avg:34.36ms +step:169/1555 train_time:5803ms step_avg:34.33ms +step:170/1555 train_time:5840ms step_avg:34.35ms +step:171/1555 train_time:5871ms step_avg:34.33ms +step:172/1555 train_time:5908ms step_avg:34.35ms +step:173/1555 train_time:5939ms step_avg:34.33ms +step:174/1555 train_time:5976ms step_avg:34.35ms +step:175/1555 train_time:6008ms step_avg:34.33ms +step:176/1555 train_time:6046ms step_avg:34.35ms +step:177/1555 train_time:6076ms step_avg:34.33ms +step:178/1555 train_time:6114ms step_avg:34.35ms +step:179/1555 train_time:6145ms step_avg:34.33ms +step:180/1555 train_time:6183ms step_avg:34.35ms +step:181/1555 train_time:6213ms step_avg:34.33ms +step:182/1555 train_time:6251ms step_avg:34.35ms +step:183/1555 train_time:6281ms step_avg:34.32ms +step:184/1555 train_time:6319ms step_avg:34.34ms +step:185/1555 train_time:6349ms step_avg:34.32ms +step:186/1555 train_time:6387ms step_avg:34.34ms +step:187/1555 train_time:6418ms step_avg:34.32ms +step:188/1555 train_time:6456ms step_avg:34.34ms +step:189/1555 train_time:6487ms step_avg:34.32ms +step:190/1555 train_time:6525ms step_avg:34.34ms +step:191/1555 train_time:6555ms step_avg:34.32ms +step:192/1555 train_time:6593ms step_avg:34.34ms +step:193/1555 train_time:6624ms step_avg:34.32ms +step:194/1555 train_time:6661ms step_avg:34.34ms +step:195/1555 train_time:6692ms step_avg:34.32ms +step:196/1555 train_time:6729ms step_avg:34.33ms +step:197/1555 train_time:6760ms step_avg:34.31ms +step:198/1555 train_time:6797ms step_avg:34.33ms +step:199/1555 train_time:6828ms step_avg:34.31ms +step:200/1555 train_time:6865ms step_avg:34.33ms +step:201/1555 train_time:6896ms step_avg:34.31ms +step:202/1555 train_time:6933ms step_avg:34.32ms +step:203/1555 train_time:6964ms step_avg:34.30ms +step:204/1555 train_time:7001ms step_avg:34.32ms +step:205/1555 
train_time:7032ms step_avg:34.30ms +step:206/1555 train_time:7070ms step_avg:34.32ms +step:207/1555 train_time:7101ms step_avg:34.30ms +step:208/1555 train_time:7138ms step_avg:34.32ms +step:209/1555 train_time:7169ms step_avg:34.30ms +step:210/1555 train_time:7207ms step_avg:34.32ms +step:211/1555 train_time:7238ms step_avg:34.30ms +step:212/1555 train_time:7275ms step_avg:34.32ms +step:213/1555 train_time:7306ms step_avg:34.30ms +step:214/1555 train_time:7344ms step_avg:34.32ms +step:215/1555 train_time:7375ms step_avg:34.30ms +step:216/1555 train_time:7412ms step_avg:34.32ms +step:217/1555 train_time:7443ms step_avg:34.30ms +step:218/1555 train_time:7480ms step_avg:34.31ms +step:219/1555 train_time:7512ms step_avg:34.30ms +step:220/1555 train_time:7549ms step_avg:34.32ms +step:221/1555 train_time:7580ms step_avg:34.30ms +step:222/1555 train_time:7618ms step_avg:34.31ms +step:223/1555 train_time:7648ms step_avg:34.30ms +step:224/1555 train_time:7686ms step_avg:34.31ms +step:225/1555 train_time:7716ms step_avg:34.30ms +step:226/1555 train_time:7754ms step_avg:34.31ms +step:227/1555 train_time:7785ms step_avg:34.30ms +step:228/1555 train_time:7824ms step_avg:34.32ms +step:229/1555 train_time:7853ms step_avg:34.29ms +step:230/1555 train_time:7891ms step_avg:34.31ms +step:231/1555 train_time:7921ms step_avg:34.29ms +step:232/1555 train_time:7959ms step_avg:34.31ms +step:233/1555 train_time:7990ms step_avg:34.29ms +step:234/1555 train_time:8027ms step_avg:34.30ms +step:235/1555 train_time:8057ms step_avg:34.29ms +step:236/1555 train_time:8095ms step_avg:34.30ms +step:237/1555 train_time:8126ms step_avg:34.29ms +step:238/1555 train_time:8165ms step_avg:34.31ms +step:239/1555 train_time:8195ms step_avg:34.29ms +step:240/1555 train_time:8233ms step_avg:34.30ms +step:241/1555 train_time:8263ms step_avg:34.29ms +step:242/1555 train_time:8300ms step_avg:34.30ms +step:243/1555 train_time:8331ms step_avg:34.28ms +step:244/1555 train_time:8368ms step_avg:34.30ms +step:245/1555 
train_time:8399ms step_avg:34.28ms +step:246/1555 train_time:8436ms step_avg:34.29ms +step:247/1555 train_time:8467ms step_avg:34.28ms +step:248/1555 train_time:8505ms step_avg:34.29ms +step:249/1555 train_time:8535ms step_avg:34.28ms +step:250/1555 train_time:8573ms step_avg:34.29ms +step:250/1555 val_loss:4.5356 train_time:8622ms step_avg:34.49ms +step:251/1555 train_time:8642ms step_avg:34.43ms +step:252/1555 train_time:8664ms step_avg:34.38ms +step:253/1555 train_time:8683ms step_avg:34.32ms +step:254/1555 train_time:8711ms step_avg:34.30ms +step:255/1555 train_time:8745ms step_avg:34.29ms +step:256/1555 train_time:8785ms step_avg:34.32ms +step:257/1555 train_time:8816ms step_avg:34.30ms +step:258/1555 train_time:8854ms step_avg:34.32ms +step:259/1555 train_time:8884ms step_avg:34.30ms +step:260/1555 train_time:8922ms step_avg:34.32ms +step:261/1555 train_time:8953ms step_avg:34.30ms +step:262/1555 train_time:8991ms step_avg:34.32ms +step:263/1555 train_time:9022ms step_avg:34.30ms +step:264/1555 train_time:9059ms step_avg:34.32ms +step:265/1555 train_time:9090ms step_avg:34.30ms +step:266/1555 train_time:9127ms step_avg:34.31ms +step:267/1555 train_time:9158ms step_avg:34.30ms +step:268/1555 train_time:9195ms step_avg:34.31ms +step:269/1555 train_time:9226ms step_avg:34.30ms +step:270/1555 train_time:9263ms step_avg:34.31ms +step:271/1555 train_time:9294ms step_avg:34.29ms +step:272/1555 train_time:9331ms step_avg:34.30ms +step:273/1555 train_time:9361ms step_avg:34.29ms +step:274/1555 train_time:9399ms step_avg:34.30ms +step:275/1555 train_time:9430ms step_avg:34.29ms +step:276/1555 train_time:9467ms step_avg:34.30ms +step:277/1555 train_time:9497ms step_avg:34.29ms +step:278/1555 train_time:9535ms step_avg:34.30ms +step:279/1555 train_time:9566ms step_avg:34.29ms +step:280/1555 train_time:9603ms step_avg:34.30ms +step:281/1555 train_time:9634ms step_avg:34.28ms +step:282/1555 train_time:9672ms step_avg:34.30ms +step:283/1555 train_time:9703ms 
step_avg:34.29ms +step:284/1555 train_time:9740ms step_avg:34.30ms +step:285/1555 train_time:9771ms step_avg:34.28ms +step:286/1555 train_time:9809ms step_avg:34.30ms +step:287/1555 train_time:9840ms step_avg:34.28ms +step:288/1555 train_time:9878ms step_avg:34.30ms +step:289/1555 train_time:9909ms step_avg:34.29ms +step:290/1555 train_time:9946ms step_avg:34.30ms +step:291/1555 train_time:9977ms step_avg:34.28ms +step:292/1555 train_time:10015ms step_avg:34.30ms +step:293/1555 train_time:10045ms step_avg:34.28ms +step:294/1555 train_time:10083ms step_avg:34.30ms +step:295/1555 train_time:10113ms step_avg:34.28ms +step:296/1555 train_time:10151ms step_avg:34.29ms +step:297/1555 train_time:10182ms step_avg:34.28ms +step:298/1555 train_time:10219ms step_avg:34.29ms +step:299/1555 train_time:10250ms step_avg:34.28ms +step:300/1555 train_time:10287ms step_avg:34.29ms +step:301/1555 train_time:10318ms step_avg:34.28ms +step:302/1555 train_time:10356ms step_avg:34.29ms +step:303/1555 train_time:10387ms step_avg:34.28ms +step:304/1555 train_time:10424ms step_avg:34.29ms +step:305/1555 train_time:10455ms step_avg:34.28ms +step:306/1555 train_time:10493ms step_avg:34.29ms +step:307/1555 train_time:10524ms step_avg:34.28ms +step:308/1555 train_time:10561ms step_avg:34.29ms +step:309/1555 train_time:10593ms step_avg:34.28ms +step:310/1555 train_time:10630ms step_avg:34.29ms +step:311/1555 train_time:10661ms step_avg:34.28ms +step:312/1555 train_time:10699ms step_avg:34.29ms +step:313/1555 train_time:10729ms step_avg:34.28ms +step:314/1555 train_time:10767ms step_avg:34.29ms +step:315/1555 train_time:10798ms step_avg:34.28ms +step:316/1555 train_time:10836ms step_avg:34.29ms +step:317/1555 train_time:10867ms step_avg:34.28ms +step:318/1555 train_time:10905ms step_avg:34.29ms +step:319/1555 train_time:10936ms step_avg:34.28ms +step:320/1555 train_time:10974ms step_avg:34.29ms +step:321/1555 train_time:11005ms step_avg:34.28ms +step:322/1555 train_time:11042ms step_avg:34.29ms 
+step:323/1555 train_time:11073ms step_avg:34.28ms +step:324/1555 train_time:11111ms step_avg:34.29ms +step:325/1555 train_time:11142ms step_avg:34.28ms +step:326/1555 train_time:11181ms step_avg:34.30ms +step:327/1555 train_time:11211ms step_avg:34.29ms +step:328/1555 train_time:11249ms step_avg:34.30ms +step:329/1555 train_time:11280ms step_avg:34.29ms +step:330/1555 train_time:11318ms step_avg:34.30ms +step:331/1555 train_time:11349ms step_avg:34.29ms +step:332/1555 train_time:11386ms step_avg:34.30ms +step:333/1555 train_time:11417ms step_avg:34.28ms +step:334/1555 train_time:11455ms step_avg:34.30ms +step:335/1555 train_time:11485ms step_avg:34.28ms +step:336/1555 train_time:11522ms step_avg:34.29ms +step:337/1555 train_time:11553ms step_avg:34.28ms +step:338/1555 train_time:11591ms step_avg:34.29ms +step:339/1555 train_time:11621ms step_avg:34.28ms +step:340/1555 train_time:11659ms step_avg:34.29ms +step:341/1555 train_time:11690ms step_avg:34.28ms +step:342/1555 train_time:11727ms step_avg:34.29ms +step:343/1555 train_time:11758ms step_avg:34.28ms +step:344/1555 train_time:11795ms step_avg:34.29ms +step:345/1555 train_time:11826ms step_avg:34.28ms +step:346/1555 train_time:11864ms step_avg:34.29ms +step:347/1555 train_time:11895ms step_avg:34.28ms +step:348/1555 train_time:11933ms step_avg:34.29ms +step:349/1555 train_time:11964ms step_avg:34.28ms +step:350/1555 train_time:12001ms step_avg:34.29ms +step:351/1555 train_time:12032ms step_avg:34.28ms +step:352/1555 train_time:12070ms step_avg:34.29ms +step:353/1555 train_time:12100ms step_avg:34.28ms +step:354/1555 train_time:12138ms step_avg:34.29ms +step:355/1555 train_time:12169ms step_avg:34.28ms +step:356/1555 train_time:12206ms step_avg:34.29ms +step:357/1555 train_time:12237ms step_avg:34.28ms +step:358/1555 train_time:12274ms step_avg:34.29ms +step:359/1555 train_time:12305ms step_avg:34.28ms +step:360/1555 train_time:12343ms step_avg:34.29ms +step:361/1555 train_time:12373ms step_avg:34.27ms 
+step:362/1555 train_time:12411ms step_avg:34.28ms +step:363/1555 train_time:12442ms step_avg:34.28ms +step:364/1555 train_time:12479ms step_avg:34.28ms +step:365/1555 train_time:12510ms step_avg:34.27ms +step:366/1555 train_time:12547ms step_avg:34.28ms +step:367/1555 train_time:12578ms step_avg:34.27ms +step:368/1555 train_time:12616ms step_avg:34.28ms +step:369/1555 train_time:12646ms step_avg:34.27ms +step:370/1555 train_time:12684ms step_avg:34.28ms +step:371/1555 train_time:12714ms step_avg:34.27ms +step:372/1555 train_time:12752ms step_avg:34.28ms +step:373/1555 train_time:12783ms step_avg:34.27ms +step:374/1555 train_time:12821ms step_avg:34.28ms +step:375/1555 train_time:12851ms step_avg:34.27ms +step:376/1555 train_time:12889ms step_avg:34.28ms +step:377/1555 train_time:12920ms step_avg:34.27ms +step:378/1555 train_time:12958ms step_avg:34.28ms +step:379/1555 train_time:12989ms step_avg:34.27ms +step:380/1555 train_time:13026ms step_avg:34.28ms +step:381/1555 train_time:13057ms step_avg:34.27ms +step:382/1555 train_time:13096ms step_avg:34.28ms +step:383/1555 train_time:13126ms step_avg:34.27ms +step:384/1555 train_time:13164ms step_avg:34.28ms +step:385/1555 train_time:13194ms step_avg:34.27ms +step:386/1555 train_time:13232ms step_avg:34.28ms +step:387/1555 train_time:13263ms step_avg:34.27ms +step:388/1555 train_time:13300ms step_avg:34.28ms +step:389/1555 train_time:13332ms step_avg:34.27ms +step:390/1555 train_time:13369ms step_avg:34.28ms +step:391/1555 train_time:13400ms step_avg:34.27ms +step:392/1555 train_time:13438ms step_avg:34.28ms +step:393/1555 train_time:13469ms step_avg:34.27ms +step:394/1555 train_time:13506ms step_avg:34.28ms +step:395/1555 train_time:13537ms step_avg:34.27ms +step:396/1555 train_time:13575ms step_avg:34.28ms +step:397/1555 train_time:13605ms step_avg:34.27ms +step:398/1555 train_time:13642ms step_avg:34.28ms +step:399/1555 train_time:13674ms step_avg:34.27ms +step:400/1555 train_time:13711ms step_avg:34.28ms 
+step:401/1555 train_time:13742ms step_avg:34.27ms +step:402/1555 train_time:13779ms step_avg:34.28ms +step:403/1555 train_time:13810ms step_avg:34.27ms +step:404/1555 train_time:13848ms step_avg:34.28ms +step:405/1555 train_time:13879ms step_avg:34.27ms +step:406/1555 train_time:13917ms step_avg:34.28ms +step:407/1555 train_time:13948ms step_avg:34.27ms +step:408/1555 train_time:13985ms step_avg:34.28ms +step:409/1555 train_time:14016ms step_avg:34.27ms +step:410/1555 train_time:14054ms step_avg:34.28ms +step:411/1555 train_time:14084ms step_avg:34.27ms +step:412/1555 train_time:14122ms step_avg:34.28ms +step:413/1555 train_time:14153ms step_avg:34.27ms +step:414/1555 train_time:14190ms step_avg:34.28ms +step:415/1555 train_time:14220ms step_avg:34.27ms +step:416/1555 train_time:14258ms step_avg:34.27ms +step:417/1555 train_time:14288ms step_avg:34.26ms +step:418/1555 train_time:14326ms step_avg:34.27ms +step:419/1555 train_time:14357ms step_avg:34.26ms +step:420/1555 train_time:14395ms step_avg:34.27ms +step:421/1555 train_time:14425ms step_avg:34.26ms +step:422/1555 train_time:14463ms step_avg:34.27ms +step:423/1555 train_time:14494ms step_avg:34.26ms +step:424/1555 train_time:14531ms step_avg:34.27ms +step:425/1555 train_time:14562ms step_avg:34.26ms +step:426/1555 train_time:14600ms step_avg:34.27ms +step:427/1555 train_time:14631ms step_avg:34.26ms +step:428/1555 train_time:14668ms step_avg:34.27ms +step:429/1555 train_time:14699ms step_avg:34.26ms +step:430/1555 train_time:14736ms step_avg:34.27ms +step:431/1555 train_time:14767ms step_avg:34.26ms +step:432/1555 train_time:14805ms step_avg:34.27ms +step:433/1555 train_time:14835ms step_avg:34.26ms +step:434/1555 train_time:14873ms step_avg:34.27ms +step:435/1555 train_time:14904ms step_avg:34.26ms +step:436/1555 train_time:14942ms step_avg:34.27ms +step:437/1555 train_time:14973ms step_avg:34.26ms +step:438/1555 train_time:15011ms step_avg:34.27ms +step:439/1555 train_time:15042ms step_avg:34.26ms 
+step:440/1555 train_time:15079ms step_avg:34.27ms +step:441/1555 train_time:15110ms step_avg:34.26ms +step:442/1555 train_time:15147ms step_avg:34.27ms +step:443/1555 train_time:15177ms step_avg:34.26ms +step:444/1555 train_time:15215ms step_avg:34.27ms +step:445/1555 train_time:15246ms step_avg:34.26ms +step:446/1555 train_time:15283ms step_avg:34.27ms +step:447/1555 train_time:15314ms step_avg:34.26ms +step:448/1555 train_time:15353ms step_avg:34.27ms +step:449/1555 train_time:15383ms step_avg:34.26ms +step:450/1555 train_time:15420ms step_avg:34.27ms +step:451/1555 train_time:15451ms step_avg:34.26ms +step:452/1555 train_time:15490ms step_avg:34.27ms +step:453/1555 train_time:15520ms step_avg:34.26ms +step:454/1555 train_time:15558ms step_avg:34.27ms +step:455/1555 train_time:15589ms step_avg:34.26ms +step:456/1555 train_time:15626ms step_avg:34.27ms +step:457/1555 train_time:15657ms step_avg:34.26ms +step:458/1555 train_time:15695ms step_avg:34.27ms +step:459/1555 train_time:15725ms step_avg:34.26ms +step:460/1555 train_time:15763ms step_avg:34.27ms +step:461/1555 train_time:15794ms step_avg:34.26ms +step:462/1555 train_time:15832ms step_avg:34.27ms +step:463/1555 train_time:15862ms step_avg:34.26ms +step:464/1555 train_time:15901ms step_avg:34.27ms +step:465/1555 train_time:15931ms step_avg:34.26ms +step:466/1555 train_time:15969ms step_avg:34.27ms +step:467/1555 train_time:16000ms step_avg:34.26ms +step:468/1555 train_time:16038ms step_avg:34.27ms +step:469/1555 train_time:16068ms step_avg:34.26ms +step:470/1555 train_time:16106ms step_avg:34.27ms +step:471/1555 train_time:16137ms step_avg:34.26ms +step:472/1555 train_time:16175ms step_avg:34.27ms +step:473/1555 train_time:16205ms step_avg:34.26ms +step:474/1555 train_time:16243ms step_avg:34.27ms +step:475/1555 train_time:16274ms step_avg:34.26ms +step:476/1555 train_time:16312ms step_avg:34.27ms +step:477/1555 train_time:16343ms step_avg:34.26ms +step:478/1555 train_time:16381ms step_avg:34.27ms 
+step:479/1555 train_time:16411ms step_avg:34.26ms +step:480/1555 train_time:16449ms step_avg:34.27ms +step:481/1555 train_time:16480ms step_avg:34.26ms +step:482/1555 train_time:16517ms step_avg:34.27ms +step:483/1555 train_time:16548ms step_avg:34.26ms +step:484/1555 train_time:16585ms step_avg:34.27ms +step:485/1555 train_time:16615ms step_avg:34.26ms +step:486/1555 train_time:16653ms step_avg:34.27ms +step:487/1555 train_time:16684ms step_avg:34.26ms +step:488/1555 train_time:16721ms step_avg:34.26ms +step:489/1555 train_time:16752ms step_avg:34.26ms +step:490/1555 train_time:16790ms step_avg:34.27ms +step:491/1555 train_time:16821ms step_avg:34.26ms +step:492/1555 train_time:16858ms step_avg:34.27ms +step:493/1555 train_time:16889ms step_avg:34.26ms +step:494/1555 train_time:16927ms step_avg:34.26ms +step:495/1555 train_time:16958ms step_avg:34.26ms +step:496/1555 train_time:16995ms step_avg:34.26ms +step:497/1555 train_time:17026ms step_avg:34.26ms +step:498/1555 train_time:17063ms step_avg:34.26ms +step:499/1555 train_time:17094ms step_avg:34.26ms +step:500/1555 train_time:17132ms step_avg:34.26ms +step:500/1555 val_loss:4.2220 train_time:17181ms step_avg:34.36ms +step:501/1555 train_time:17204ms step_avg:34.34ms +step:502/1555 train_time:17226ms step_avg:34.31ms +step:503/1555 train_time:17245ms step_avg:34.28ms +step:504/1555 train_time:17272ms step_avg:34.27ms +step:505/1555 train_time:17305ms step_avg:34.27ms +step:506/1555 train_time:17351ms step_avg:34.29ms +step:507/1555 train_time:17401ms step_avg:34.32ms +step:508/1555 train_time:17466ms step_avg:34.38ms +step:509/1555 train_time:17523ms step_avg:34.43ms +step:510/1555 train_time:17587ms step_avg:34.48ms +step:511/1555 train_time:17644ms step_avg:34.53ms +step:512/1555 train_time:17707ms step_avg:34.58ms +step:513/1555 train_time:17763ms step_avg:34.63ms +step:514/1555 train_time:17827ms step_avg:34.68ms +step:515/1555 train_time:17883ms step_avg:34.72ms +step:516/1555 train_time:17946ms 
step_avg:34.78ms +step:517/1555 train_time:18003ms step_avg:34.82ms +step:518/1555 train_time:18068ms step_avg:34.88ms +step:519/1555 train_time:18127ms step_avg:34.93ms +step:520/1555 train_time:18192ms step_avg:34.98ms +step:521/1555 train_time:18252ms step_avg:35.03ms +step:522/1555 train_time:18317ms step_avg:35.09ms +step:523/1555 train_time:18375ms step_avg:35.13ms +step:524/1555 train_time:18440ms step_avg:35.19ms +step:525/1555 train_time:18497ms step_avg:35.23ms +step:526/1555 train_time:18562ms step_avg:35.29ms +step:527/1555 train_time:18620ms step_avg:35.33ms +step:528/1555 train_time:18684ms step_avg:35.39ms +step:529/1555 train_time:18742ms step_avg:35.43ms +step:530/1555 train_time:18805ms step_avg:35.48ms +step:531/1555 train_time:18863ms step_avg:35.52ms +step:532/1555 train_time:18927ms step_avg:35.58ms +step:533/1555 train_time:18984ms step_avg:35.62ms +step:534/1555 train_time:19048ms step_avg:35.67ms +step:535/1555 train_time:19106ms step_avg:35.71ms +step:536/1555 train_time:19172ms step_avg:35.77ms +step:537/1555 train_time:19230ms step_avg:35.81ms +step:538/1555 train_time:19294ms step_avg:35.86ms +step:539/1555 train_time:19351ms step_avg:35.90ms +step:540/1555 train_time:19415ms step_avg:35.95ms +step:541/1555 train_time:19474ms step_avg:36.00ms +step:542/1555 train_time:19539ms step_avg:36.05ms +step:543/1555 train_time:19596ms step_avg:36.09ms +step:544/1555 train_time:19661ms step_avg:36.14ms +step:545/1555 train_time:19720ms step_avg:36.18ms +step:546/1555 train_time:19783ms step_avg:36.23ms +step:547/1555 train_time:19840ms step_avg:36.27ms +step:548/1555 train_time:19904ms step_avg:36.32ms +step:549/1555 train_time:19961ms step_avg:36.36ms +step:550/1555 train_time:20026ms step_avg:36.41ms +step:551/1555 train_time:20084ms step_avg:36.45ms +step:552/1555 train_time:20148ms step_avg:36.50ms +step:553/1555 train_time:20207ms step_avg:36.54ms +step:554/1555 train_time:20271ms step_avg:36.59ms +step:555/1555 train_time:20329ms 
step_avg:36.63ms +step:556/1555 train_time:20393ms step_avg:36.68ms +step:557/1555 train_time:20450ms step_avg:36.71ms +step:558/1555 train_time:20514ms step_avg:36.76ms +step:559/1555 train_time:20571ms step_avg:36.80ms +step:560/1555 train_time:20636ms step_avg:36.85ms +step:561/1555 train_time:20694ms step_avg:36.89ms +step:562/1555 train_time:20759ms step_avg:36.94ms +step:563/1555 train_time:20817ms step_avg:36.98ms +step:564/1555 train_time:20881ms step_avg:37.02ms +step:565/1555 train_time:20939ms step_avg:37.06ms +step:566/1555 train_time:21004ms step_avg:37.11ms +step:567/1555 train_time:21061ms step_avg:37.15ms +step:568/1555 train_time:21126ms step_avg:37.19ms +step:569/1555 train_time:21185ms step_avg:37.23ms +step:570/1555 train_time:21248ms step_avg:37.28ms +step:571/1555 train_time:21306ms step_avg:37.31ms +step:572/1555 train_time:21371ms step_avg:37.36ms +step:573/1555 train_time:21429ms step_avg:37.40ms +step:574/1555 train_time:21493ms step_avg:37.44ms +step:575/1555 train_time:21550ms step_avg:37.48ms +step:576/1555 train_time:21614ms step_avg:37.52ms +step:577/1555 train_time:21672ms step_avg:37.56ms +step:578/1555 train_time:21737ms step_avg:37.61ms +step:579/1555 train_time:21794ms step_avg:37.64ms +step:580/1555 train_time:21859ms step_avg:37.69ms +step:581/1555 train_time:21918ms step_avg:37.72ms +step:582/1555 train_time:21982ms step_avg:37.77ms +step:583/1555 train_time:22039ms step_avg:37.80ms +step:584/1555 train_time:22104ms step_avg:37.85ms +step:585/1555 train_time:22162ms step_avg:37.88ms +step:586/1555 train_time:22228ms step_avg:37.93ms +step:587/1555 train_time:22285ms step_avg:37.96ms +step:588/1555 train_time:22349ms step_avg:38.01ms +step:589/1555 train_time:22408ms step_avg:38.04ms +step:590/1555 train_time:22472ms step_avg:38.09ms +step:591/1555 train_time:22530ms step_avg:38.12ms +step:592/1555 train_time:22593ms step_avg:38.16ms +step:593/1555 train_time:22650ms step_avg:38.20ms +step:594/1555 train_time:22714ms 
step_avg:38.24ms +step:595/1555 train_time:22772ms step_avg:38.27ms +step:596/1555 train_time:22837ms step_avg:38.32ms +step:597/1555 train_time:22894ms step_avg:38.35ms +step:598/1555 train_time:22960ms step_avg:38.39ms +step:599/1555 train_time:23018ms step_avg:38.43ms +step:600/1555 train_time:23083ms step_avg:38.47ms +step:601/1555 train_time:23141ms step_avg:38.50ms +step:602/1555 train_time:23205ms step_avg:38.55ms +step:603/1555 train_time:23263ms step_avg:38.58ms +step:604/1555 train_time:23327ms step_avg:38.62ms +step:605/1555 train_time:23385ms step_avg:38.65ms +step:606/1555 train_time:23449ms step_avg:38.69ms +step:607/1555 train_time:23506ms step_avg:38.73ms +step:608/1555 train_time:23570ms step_avg:38.77ms +step:609/1555 train_time:23629ms step_avg:38.80ms +step:610/1555 train_time:23692ms step_avg:38.84ms +step:611/1555 train_time:23750ms step_avg:38.87ms +step:612/1555 train_time:23815ms step_avg:38.91ms +step:613/1555 train_time:23873ms step_avg:38.94ms +step:614/1555 train_time:23938ms step_avg:38.99ms +step:615/1555 train_time:23996ms step_avg:39.02ms +step:616/1555 train_time:24061ms step_avg:39.06ms +step:617/1555 train_time:24119ms step_avg:39.09ms +step:618/1555 train_time:24184ms step_avg:39.13ms +step:619/1555 train_time:24241ms step_avg:39.16ms +step:620/1555 train_time:24306ms step_avg:39.20ms +step:621/1555 train_time:24364ms step_avg:39.23ms +step:622/1555 train_time:24428ms step_avg:39.27ms +step:623/1555 train_time:24486ms step_avg:39.30ms +step:624/1555 train_time:24550ms step_avg:39.34ms +step:625/1555 train_time:24607ms step_avg:39.37ms +step:626/1555 train_time:24670ms step_avg:39.41ms +step:627/1555 train_time:24728ms step_avg:39.44ms +step:628/1555 train_time:24792ms step_avg:39.48ms +step:629/1555 train_time:24850ms step_avg:39.51ms +step:630/1555 train_time:24915ms step_avg:39.55ms +step:631/1555 train_time:24973ms step_avg:39.58ms +step:632/1555 train_time:25038ms step_avg:39.62ms +step:633/1555 train_time:25096ms 
step_avg:39.65ms +step:634/1555 train_time:25160ms step_avg:39.69ms +step:635/1555 train_time:25220ms step_avg:39.72ms +step:636/1555 train_time:25284ms step_avg:39.76ms +step:637/1555 train_time:25342ms step_avg:39.78ms +step:638/1555 train_time:25406ms step_avg:39.82ms +step:639/1555 train_time:25464ms step_avg:39.85ms +step:640/1555 train_time:25528ms step_avg:39.89ms +step:641/1555 train_time:25585ms step_avg:39.91ms +step:642/1555 train_time:25649ms step_avg:39.95ms +step:643/1555 train_time:25707ms step_avg:39.98ms +step:644/1555 train_time:25771ms step_avg:40.02ms +step:645/1555 train_time:25830ms step_avg:40.05ms +step:646/1555 train_time:25893ms step_avg:40.08ms +step:647/1555 train_time:25950ms step_avg:40.11ms +step:648/1555 train_time:26014ms step_avg:40.15ms +step:649/1555 train_time:26071ms step_avg:40.17ms +step:650/1555 train_time:26136ms step_avg:40.21ms +step:651/1555 train_time:26194ms step_avg:40.24ms +step:652/1555 train_time:26260ms step_avg:40.28ms +step:653/1555 train_time:26318ms step_avg:40.30ms +step:654/1555 train_time:26383ms step_avg:40.34ms +step:655/1555 train_time:26441ms step_avg:40.37ms +step:656/1555 train_time:26505ms step_avg:40.40ms +step:657/1555 train_time:26563ms step_avg:40.43ms +step:658/1555 train_time:26627ms step_avg:40.47ms +step:659/1555 train_time:26684ms step_avg:40.49ms +step:660/1555 train_time:26748ms step_avg:40.53ms +step:661/1555 train_time:26806ms step_avg:40.55ms +step:662/1555 train_time:26870ms step_avg:40.59ms +step:663/1555 train_time:26928ms step_avg:40.62ms +step:664/1555 train_time:26992ms step_avg:40.65ms +step:665/1555 train_time:27049ms step_avg:40.68ms +step:666/1555 train_time:27114ms step_avg:40.71ms +step:667/1555 train_time:27172ms step_avg:40.74ms +step:668/1555 train_time:27237ms step_avg:40.77ms +step:669/1555 train_time:27294ms step_avg:40.80ms +step:670/1555 train_time:27360ms step_avg:40.84ms +step:671/1555 train_time:27418ms step_avg:40.86ms +step:672/1555 train_time:27483ms 
step_avg:40.90ms +step:673/1555 train_time:27540ms step_avg:40.92ms +step:674/1555 train_time:27605ms step_avg:40.96ms +step:675/1555 train_time:27663ms step_avg:40.98ms +step:676/1555 train_time:27728ms step_avg:41.02ms +step:677/1555 train_time:27785ms step_avg:41.04ms +step:678/1555 train_time:27849ms step_avg:41.08ms +step:679/1555 train_time:27906ms step_avg:41.10ms +step:680/1555 train_time:27970ms step_avg:41.13ms +step:681/1555 train_time:28029ms step_avg:41.16ms +step:682/1555 train_time:28092ms step_avg:41.19ms +step:683/1555 train_time:28150ms step_avg:41.21ms +step:684/1555 train_time:28214ms step_avg:41.25ms +step:685/1555 train_time:28272ms step_avg:41.27ms +step:686/1555 train_time:28338ms step_avg:41.31ms +step:687/1555 train_time:28395ms step_avg:41.33ms +step:688/1555 train_time:28460ms step_avg:41.37ms +step:689/1555 train_time:28518ms step_avg:41.39ms +step:690/1555 train_time:28582ms step_avg:41.42ms +step:691/1555 train_time:28640ms step_avg:41.45ms +step:692/1555 train_time:28705ms step_avg:41.48ms +step:693/1555 train_time:28763ms step_avg:41.51ms +step:694/1555 train_time:28827ms step_avg:41.54ms +step:695/1555 train_time:28885ms step_avg:41.56ms +step:696/1555 train_time:28949ms step_avg:41.59ms +step:697/1555 train_time:29007ms step_avg:41.62ms +step:698/1555 train_time:29071ms step_avg:41.65ms +step:699/1555 train_time:29129ms step_avg:41.67ms +step:700/1555 train_time:29193ms step_avg:41.70ms +step:701/1555 train_time:29250ms step_avg:41.73ms +step:702/1555 train_time:29315ms step_avg:41.76ms +step:703/1555 train_time:29373ms step_avg:41.78ms +step:704/1555 train_time:29438ms step_avg:41.81ms +step:705/1555 train_time:29496ms step_avg:41.84ms +step:706/1555 train_time:29559ms step_avg:41.87ms +step:707/1555 train_time:29618ms step_avg:41.89ms +step:708/1555 train_time:29683ms step_avg:41.93ms +step:709/1555 train_time:29741ms step_avg:41.95ms +step:710/1555 train_time:29805ms step_avg:41.98ms +step:711/1555 train_time:29863ms 
step_avg:42.00ms +step:712/1555 train_time:29927ms step_avg:42.03ms +step:713/1555 train_time:29985ms step_avg:42.05ms +step:714/1555 train_time:30049ms step_avg:42.09ms +step:715/1555 train_time:30107ms step_avg:42.11ms +step:716/1555 train_time:30170ms step_avg:42.14ms +step:717/1555 train_time:30229ms step_avg:42.16ms +step:718/1555 train_time:30292ms step_avg:42.19ms +step:719/1555 train_time:30350ms step_avg:42.21ms +step:720/1555 train_time:30414ms step_avg:42.24ms +step:721/1555 train_time:30472ms step_avg:42.26ms +step:722/1555 train_time:30537ms step_avg:42.30ms +step:723/1555 train_time:30595ms step_avg:42.32ms +step:724/1555 train_time:30661ms step_avg:42.35ms +step:725/1555 train_time:30720ms step_avg:42.37ms +step:726/1555 train_time:30784ms step_avg:42.40ms +step:727/1555 train_time:30841ms step_avg:42.42ms +step:728/1555 train_time:30905ms step_avg:42.45ms +step:729/1555 train_time:30962ms step_avg:42.47ms +step:730/1555 train_time:31028ms step_avg:42.50ms +step:731/1555 train_time:31085ms step_avg:42.52ms +step:732/1555 train_time:31149ms step_avg:42.55ms +step:733/1555 train_time:31207ms step_avg:42.57ms +step:734/1555 train_time:31271ms step_avg:42.60ms +step:735/1555 train_time:31329ms step_avg:42.62ms +step:736/1555 train_time:31393ms step_avg:42.65ms +step:737/1555 train_time:31451ms step_avg:42.67ms +step:738/1555 train_time:31516ms step_avg:42.70ms +step:739/1555 train_time:31573ms step_avg:42.72ms +step:740/1555 train_time:31639ms step_avg:42.76ms +step:741/1555 train_time:31697ms step_avg:42.78ms +step:742/1555 train_time:31761ms step_avg:42.80ms +step:743/1555 train_time:31819ms step_avg:42.83ms +step:744/1555 train_time:31884ms step_avg:42.85ms +step:745/1555 train_time:31942ms step_avg:42.87ms +step:746/1555 train_time:32005ms step_avg:42.90ms +step:747/1555 train_time:32063ms step_avg:42.92ms +step:748/1555 train_time:32128ms step_avg:42.95ms +step:749/1555 train_time:32186ms step_avg:42.97ms +step:750/1555 train_time:32249ms 
step_avg:43.00ms +step:750/1555 val_loss:3.8682 train_time:32331ms step_avg:43.11ms +step:751/1555 train_time:32354ms step_avg:43.08ms +step:752/1555 train_time:32376ms step_avg:43.05ms +step:753/1555 train_time:32431ms step_avg:43.07ms +step:754/1555 train_time:32502ms step_avg:43.11ms +step:755/1555 train_time:32562ms step_avg:43.13ms +step:756/1555 train_time:32626ms step_avg:43.16ms +step:757/1555 train_time:32683ms step_avg:43.17ms +step:758/1555 train_time:32746ms step_avg:43.20ms +step:759/1555 train_time:32804ms step_avg:43.22ms +step:760/1555 train_time:32867ms step_avg:43.25ms +step:761/1555 train_time:32924ms step_avg:43.26ms +step:762/1555 train_time:32988ms step_avg:43.29ms +step:763/1555 train_time:33045ms step_avg:43.31ms +step:764/1555 train_time:33109ms step_avg:43.34ms +step:765/1555 train_time:33166ms step_avg:43.35ms +step:766/1555 train_time:33230ms step_avg:43.38ms +step:767/1555 train_time:33287ms step_avg:43.40ms +step:768/1555 train_time:33352ms step_avg:43.43ms +step:769/1555 train_time:33412ms step_avg:43.45ms +step:770/1555 train_time:33480ms step_avg:43.48ms +step:771/1555 train_time:33539ms step_avg:43.50ms +step:772/1555 train_time:33603ms step_avg:43.53ms +step:773/1555 train_time:33660ms step_avg:43.55ms +step:774/1555 train_time:33724ms step_avg:43.57ms +step:775/1555 train_time:33782ms step_avg:43.59ms +step:776/1555 train_time:33845ms step_avg:43.61ms +step:777/1555 train_time:33902ms step_avg:43.63ms +step:778/1555 train_time:33965ms step_avg:43.66ms +step:779/1555 train_time:34022ms step_avg:43.67ms +step:780/1555 train_time:34086ms step_avg:43.70ms +step:781/1555 train_time:34143ms step_avg:43.72ms +step:782/1555 train_time:34207ms step_avg:43.74ms +step:783/1555 train_time:34264ms step_avg:43.76ms +step:784/1555 train_time:34330ms step_avg:43.79ms +step:785/1555 train_time:34388ms step_avg:43.81ms +step:786/1555 train_time:34454ms step_avg:43.84ms +step:787/1555 train_time:34513ms step_avg:43.85ms +step:788/1555 
train_time:34578ms step_avg:43.88ms +step:789/1555 train_time:34636ms step_avg:43.90ms +step:790/1555 train_time:34700ms step_avg:43.92ms +step:791/1555 train_time:34757ms step_avg:43.94ms +step:792/1555 train_time:34821ms step_avg:43.97ms +step:793/1555 train_time:34879ms step_avg:43.98ms +step:794/1555 train_time:34943ms step_avg:44.01ms +step:795/1555 train_time:35001ms step_avg:44.03ms +step:796/1555 train_time:35064ms step_avg:44.05ms +step:797/1555 train_time:35122ms step_avg:44.07ms +step:798/1555 train_time:35186ms step_avg:44.09ms +step:799/1555 train_time:35244ms step_avg:44.11ms +step:800/1555 train_time:35308ms step_avg:44.13ms +step:801/1555 train_time:35365ms step_avg:44.15ms +step:802/1555 train_time:35431ms step_avg:44.18ms +step:803/1555 train_time:35489ms step_avg:44.20ms +step:804/1555 train_time:35555ms step_avg:44.22ms +step:805/1555 train_time:35613ms step_avg:44.24ms +step:806/1555 train_time:35677ms step_avg:44.26ms +step:807/1555 train_time:35735ms step_avg:44.28ms +step:808/1555 train_time:35799ms step_avg:44.31ms +step:809/1555 train_time:35858ms step_avg:44.32ms +step:810/1555 train_time:35921ms step_avg:44.35ms +step:811/1555 train_time:35979ms step_avg:44.36ms +step:812/1555 train_time:36043ms step_avg:44.39ms +step:813/1555 train_time:36100ms step_avg:44.40ms +step:814/1555 train_time:36164ms step_avg:44.43ms +step:815/1555 train_time:36223ms step_avg:44.45ms +step:816/1555 train_time:36286ms step_avg:44.47ms +step:817/1555 train_time:36343ms step_avg:44.48ms +step:818/1555 train_time:36407ms step_avg:44.51ms +step:819/1555 train_time:36466ms step_avg:44.53ms +step:820/1555 train_time:36531ms step_avg:44.55ms +step:821/1555 train_time:36589ms step_avg:44.57ms +step:822/1555 train_time:36655ms step_avg:44.59ms +step:823/1555 train_time:36713ms step_avg:44.61ms +step:824/1555 train_time:36778ms step_avg:44.63ms +step:825/1555 train_time:36836ms step_avg:44.65ms +step:826/1555 train_time:36900ms step_avg:44.67ms +step:827/1555 
train_time:36957ms step_avg:44.69ms +step:828/1555 train_time:37021ms step_avg:44.71ms +step:829/1555 train_time:37077ms step_avg:44.73ms +step:830/1555 train_time:37142ms step_avg:44.75ms +step:831/1555 train_time:37200ms step_avg:44.77ms +step:832/1555 train_time:37263ms step_avg:44.79ms +step:833/1555 train_time:37321ms step_avg:44.80ms +step:834/1555 train_time:37386ms step_avg:44.83ms +step:835/1555 train_time:37443ms step_avg:44.84ms +step:836/1555 train_time:37508ms step_avg:44.87ms +step:837/1555 train_time:37566ms step_avg:44.88ms +step:838/1555 train_time:37632ms step_avg:44.91ms +step:839/1555 train_time:37689ms step_avg:44.92ms +step:840/1555 train_time:37755ms step_avg:44.95ms +step:841/1555 train_time:37813ms step_avg:44.96ms +step:842/1555 train_time:37878ms step_avg:44.99ms +step:843/1555 train_time:37937ms step_avg:45.00ms +step:844/1555 train_time:38000ms step_avg:45.02ms +step:845/1555 train_time:38057ms step_avg:45.04ms +step:846/1555 train_time:38122ms step_avg:45.06ms +step:847/1555 train_time:38179ms step_avg:45.08ms +step:848/1555 train_time:38242ms step_avg:45.10ms +step:849/1555 train_time:38300ms step_avg:45.11ms +step:850/1555 train_time:38364ms step_avg:45.13ms +step:851/1555 train_time:38422ms step_avg:45.15ms +step:852/1555 train_time:38487ms step_avg:45.17ms +step:853/1555 train_time:38545ms step_avg:45.19ms +step:854/1555 train_time:38609ms step_avg:45.21ms +step:855/1555 train_time:38667ms step_avg:45.22ms +step:856/1555 train_time:38731ms step_avg:45.25ms +step:857/1555 train_time:38789ms step_avg:45.26ms +step:858/1555 train_time:38855ms step_avg:45.29ms +step:859/1555 train_time:38913ms step_avg:45.30ms +step:860/1555 train_time:38977ms step_avg:45.32ms +step:861/1555 train_time:39034ms step_avg:45.34ms +step:862/1555 train_time:39099ms step_avg:45.36ms +step:863/1555 train_time:39156ms step_avg:45.37ms +step:864/1555 train_time:39220ms step_avg:45.39ms +step:865/1555 train_time:39277ms step_avg:45.41ms +step:866/1555 
train_time:39343ms step_avg:45.43ms +step:867/1555 train_time:39401ms step_avg:45.44ms +step:868/1555 train_time:39464ms step_avg:45.47ms +step:869/1555 train_time:39522ms step_avg:45.48ms +step:870/1555 train_time:39588ms step_avg:45.50ms +step:871/1555 train_time:39645ms step_avg:45.52ms +step:872/1555 train_time:39709ms step_avg:45.54ms +step:873/1555 train_time:39766ms step_avg:45.55ms +step:874/1555 train_time:39831ms step_avg:45.57ms +step:875/1555 train_time:39889ms step_avg:45.59ms +step:876/1555 train_time:39954ms step_avg:45.61ms +step:877/1555 train_time:40012ms step_avg:45.62ms +step:878/1555 train_time:40077ms step_avg:45.65ms +step:879/1555 train_time:40135ms step_avg:45.66ms +step:880/1555 train_time:40199ms step_avg:45.68ms +step:881/1555 train_time:40256ms step_avg:45.69ms +step:882/1555 train_time:40320ms step_avg:45.71ms +step:883/1555 train_time:40378ms step_avg:45.73ms +step:884/1555 train_time:40444ms step_avg:45.75ms +step:885/1555 train_time:40502ms step_avg:45.77ms +step:886/1555 train_time:40565ms step_avg:45.78ms +step:887/1555 train_time:40623ms step_avg:45.80ms +step:888/1555 train_time:40688ms step_avg:45.82ms +step:889/1555 train_time:40745ms step_avg:45.83ms +step:890/1555 train_time:40809ms step_avg:45.85ms +step:891/1555 train_time:40866ms step_avg:45.87ms +step:892/1555 train_time:40931ms step_avg:45.89ms +step:893/1555 train_time:40988ms step_avg:45.90ms +step:894/1555 train_time:41054ms step_avg:45.92ms +step:895/1555 train_time:41113ms step_avg:45.94ms +step:896/1555 train_time:41177ms step_avg:45.96ms +step:897/1555 train_time:41236ms step_avg:45.97ms +step:898/1555 train_time:41300ms step_avg:45.99ms +step:899/1555 train_time:41357ms step_avg:46.00ms +step:900/1555 train_time:41422ms step_avg:46.02ms +step:901/1555 train_time:41480ms step_avg:46.04ms +step:902/1555 train_time:41544ms step_avg:46.06ms +step:903/1555 train_time:41602ms step_avg:46.07ms +step:904/1555 train_time:41665ms step_avg:46.09ms +step:905/1555 
train_time:41723ms step_avg:46.10ms +step:906/1555 train_time:41787ms step_avg:46.12ms +step:907/1555 train_time:41845ms step_avg:46.14ms +step:908/1555 train_time:41909ms step_avg:46.16ms +step:909/1555 train_time:41967ms step_avg:46.17ms +step:910/1555 train_time:42032ms step_avg:46.19ms +step:911/1555 train_time:42089ms step_avg:46.20ms +step:912/1555 train_time:42155ms step_avg:46.22ms +step:913/1555 train_time:42213ms step_avg:46.24ms +step:914/1555 train_time:42277ms step_avg:46.26ms +step:915/1555 train_time:42336ms step_avg:46.27ms +step:916/1555 train_time:42399ms step_avg:46.29ms +step:917/1555 train_time:42457ms step_avg:46.30ms +step:918/1555 train_time:42521ms step_avg:46.32ms +step:919/1555 train_time:42579ms step_avg:46.33ms +step:920/1555 train_time:42643ms step_avg:46.35ms +step:921/1555 train_time:42701ms step_avg:46.36ms +step:922/1555 train_time:42764ms step_avg:46.38ms +step:923/1555 train_time:42822ms step_avg:46.39ms +step:924/1555 train_time:42886ms step_avg:46.41ms +step:925/1555 train_time:42944ms step_avg:46.43ms +step:926/1555 train_time:43009ms step_avg:46.45ms +step:927/1555 train_time:43067ms step_avg:46.46ms +step:928/1555 train_time:43132ms step_avg:46.48ms +step:929/1555 train_time:43190ms step_avg:46.49ms +step:930/1555 train_time:43255ms step_avg:46.51ms +step:931/1555 train_time:43313ms step_avg:46.52ms +step:932/1555 train_time:43377ms step_avg:46.54ms +step:933/1555 train_time:43435ms step_avg:46.55ms +step:934/1555 train_time:43500ms step_avg:46.57ms +step:935/1555 train_time:43557ms step_avg:46.59ms +step:936/1555 train_time:43622ms step_avg:46.60ms +step:937/1555 train_time:43679ms step_avg:46.62ms +step:938/1555 train_time:43743ms step_avg:46.63ms +step:939/1555 train_time:43801ms step_avg:46.65ms +step:940/1555 train_time:43864ms step_avg:46.66ms +step:941/1555 train_time:43922ms step_avg:46.68ms +step:942/1555 train_time:43986ms step_avg:46.69ms +step:943/1555 train_time:44044ms step_avg:46.71ms +step:944/1555 
train_time:44108ms step_avg:46.72ms +step:945/1555 train_time:44166ms step_avg:46.74ms +step:946/1555 train_time:44231ms step_avg:46.76ms +step:947/1555 train_time:44288ms step_avg:46.77ms +step:948/1555 train_time:44354ms step_avg:46.79ms +step:949/1555 train_time:44412ms step_avg:46.80ms +step:950/1555 train_time:44476ms step_avg:46.82ms +step:951/1555 train_time:44535ms step_avg:46.83ms +step:952/1555 train_time:44599ms step_avg:46.85ms +step:953/1555 train_time:44657ms step_avg:46.86ms +step:954/1555 train_time:44721ms step_avg:46.88ms +step:955/1555 train_time:44778ms step_avg:46.89ms +step:956/1555 train_time:44843ms step_avg:46.91ms +step:957/1555 train_time:44900ms step_avg:46.92ms +step:958/1555 train_time:44964ms step_avg:46.94ms +step:959/1555 train_time:45022ms step_avg:46.95ms +step:960/1555 train_time:45086ms step_avg:46.96ms +step:961/1555 train_time:45144ms step_avg:46.98ms +step:962/1555 train_time:45208ms step_avg:46.99ms +step:963/1555 train_time:45266ms step_avg:47.01ms +step:964/1555 train_time:45330ms step_avg:47.02ms +step:965/1555 train_time:45388ms step_avg:47.03ms +step:966/1555 train_time:45453ms step_avg:47.05ms +step:967/1555 train_time:45512ms step_avg:47.06ms +step:968/1555 train_time:45576ms step_avg:47.08ms +step:969/1555 train_time:45635ms step_avg:47.09ms +step:970/1555 train_time:45699ms step_avg:47.11ms +step:971/1555 train_time:45756ms step_avg:47.12ms +step:972/1555 train_time:45820ms step_avg:47.14ms +step:973/1555 train_time:45878ms step_avg:47.15ms +step:974/1555 train_time:45942ms step_avg:47.17ms +step:975/1555 train_time:46000ms step_avg:47.18ms +step:976/1555 train_time:46064ms step_avg:47.20ms +step:977/1555 train_time:46122ms step_avg:47.21ms +step:978/1555 train_time:46186ms step_avg:47.22ms +step:979/1555 train_time:46244ms step_avg:47.24ms +step:980/1555 train_time:46307ms step_avg:47.25ms +step:981/1555 train_time:46366ms step_avg:47.26ms +step:982/1555 train_time:46430ms step_avg:47.28ms +step:983/1555 
train_time:46488ms step_avg:47.29ms +step:984/1555 train_time:46553ms step_avg:47.31ms +step:985/1555 train_time:46611ms step_avg:47.32ms +step:986/1555 train_time:46676ms step_avg:47.34ms +step:987/1555 train_time:46734ms step_avg:47.35ms +step:988/1555 train_time:46797ms step_avg:47.37ms +step:989/1555 train_time:46855ms step_avg:47.38ms +step:990/1555 train_time:46919ms step_avg:47.39ms +step:991/1555 train_time:46977ms step_avg:47.40ms +step:992/1555 train_time:47042ms step_avg:47.42ms +step:993/1555 train_time:47101ms step_avg:47.43ms +step:994/1555 train_time:47164ms step_avg:47.45ms +step:995/1555 train_time:47222ms step_avg:47.46ms +step:996/1555 train_time:47285ms step_avg:47.47ms +step:997/1555 train_time:47343ms step_avg:47.49ms +step:998/1555 train_time:47408ms step_avg:47.50ms +step:999/1555 train_time:47466ms step_avg:47.51ms +step:1000/1555 train_time:47531ms step_avg:47.53ms +step:1000/1555 val_loss:3.5680 train_time:47614ms step_avg:47.61ms +step:1001/1555 train_time:47636ms step_avg:47.59ms +step:1002/1555 train_time:47657ms step_avg:47.56ms +step:1003/1555 train_time:47717ms step_avg:47.57ms +step:1004/1555 train_time:47782ms step_avg:47.59ms +step:1005/1555 train_time:47842ms step_avg:47.60ms +step:1006/1555 train_time:47906ms step_avg:47.62ms +step:1007/1555 train_time:47964ms step_avg:47.63ms +step:1008/1555 train_time:48028ms step_avg:47.65ms +step:1009/1555 train_time:48085ms step_avg:47.66ms +step:1010/1555 train_time:48149ms step_avg:47.67ms +step:1011/1555 train_time:48208ms step_avg:47.68ms +step:1012/1555 train_time:48294ms step_avg:47.72ms +step:1013/1555 train_time:48378ms step_avg:47.76ms +step:1014/1555 train_time:48468ms step_avg:47.80ms +step:1015/1555 train_time:48552ms step_avg:47.83ms +step:1016/1555 train_time:48644ms step_avg:47.88ms +step:1017/1555 train_time:48731ms step_avg:47.92ms +step:1018/1555 train_time:48823ms step_avg:47.96ms +step:1019/1555 train_time:48907ms step_avg:48.00ms +step:1020/1555 train_time:48999ms 
step_avg:48.04ms +step:1021/1555 train_time:49082ms step_avg:48.07ms +step:1022/1555 train_time:49170ms step_avg:48.11ms +step:1023/1555 train_time:49254ms step_avg:48.15ms +step:1024/1555 train_time:49343ms step_avg:48.19ms +step:1025/1555 train_time:49425ms step_avg:48.22ms +step:1026/1555 train_time:49514ms step_avg:48.26ms +step:1027/1555 train_time:49600ms step_avg:48.30ms +step:1028/1555 train_time:49690ms step_avg:48.34ms +step:1029/1555 train_time:49777ms step_avg:48.37ms +step:1030/1555 train_time:49867ms step_avg:48.41ms +step:1031/1555 train_time:49952ms step_avg:48.45ms +step:1032/1555 train_time:50043ms step_avg:48.49ms +step:1033/1555 train_time:50126ms step_avg:48.52ms +step:1034/1555 train_time:50215ms step_avg:48.56ms +step:1035/1555 train_time:50299ms step_avg:48.60ms +step:1036/1555 train_time:50388ms step_avg:48.64ms +step:1037/1555 train_time:50471ms step_avg:48.67ms +step:1038/1555 train_time:50563ms step_avg:48.71ms +step:1039/1555 train_time:50646ms step_avg:48.75ms +step:1040/1555 train_time:50737ms step_avg:48.79ms +step:1041/1555 train_time:50822ms step_avg:48.82ms +step:1042/1555 train_time:50913ms step_avg:48.86ms +step:1043/1555 train_time:50998ms step_avg:48.90ms +step:1044/1555 train_time:51088ms step_avg:48.93ms +step:1045/1555 train_time:51172ms step_avg:48.97ms +step:1046/1555 train_time:51263ms step_avg:49.01ms +step:1047/1555 train_time:51346ms step_avg:49.04ms +step:1048/1555 train_time:51436ms step_avg:49.08ms +step:1049/1555 train_time:51519ms step_avg:49.11ms +step:1050/1555 train_time:51608ms step_avg:49.15ms +step:1051/1555 train_time:51695ms step_avg:49.19ms +step:1052/1555 train_time:51785ms step_avg:49.23ms +step:1053/1555 train_time:51870ms step_avg:49.26ms +step:1054/1555 train_time:51961ms step_avg:49.30ms +step:1055/1555 train_time:52044ms step_avg:49.33ms +step:1056/1555 train_time:52134ms step_avg:49.37ms +step:1057/1555 train_time:52220ms step_avg:49.40ms +step:1058/1555 train_time:52310ms step_avg:49.44ms 
+step:1059/1555 train_time:52392ms step_avg:49.47ms +step:1060/1555 train_time:52483ms step_avg:49.51ms +step:1061/1555 train_time:52564ms step_avg:49.54ms +step:1062/1555 train_time:52656ms step_avg:49.58ms +step:1063/1555 train_time:52741ms step_avg:49.62ms +step:1064/1555 train_time:52830ms step_avg:49.65ms +step:1065/1555 train_time:52915ms step_avg:49.69ms +step:1066/1555 train_time:53007ms step_avg:49.73ms +step:1067/1555 train_time:53089ms step_avg:49.76ms +step:1068/1555 train_time:53181ms step_avg:49.79ms +step:1069/1555 train_time:53264ms step_avg:49.83ms +step:1070/1555 train_time:53355ms step_avg:49.86ms +step:1071/1555 train_time:53438ms step_avg:49.90ms +step:1072/1555 train_time:53527ms step_avg:49.93ms +step:1073/1555 train_time:53611ms step_avg:49.96ms +step:1074/1555 train_time:53702ms step_avg:50.00ms +step:1075/1555 train_time:53786ms step_avg:50.03ms +step:1076/1555 train_time:53876ms step_avg:50.07ms +step:1077/1555 train_time:53960ms step_avg:50.10ms +step:1078/1555 train_time:54049ms step_avg:50.14ms +step:1079/1555 train_time:54134ms step_avg:50.17ms +step:1080/1555 train_time:54224ms step_avg:50.21ms +step:1081/1555 train_time:54308ms step_avg:50.24ms +step:1082/1555 train_time:54399ms step_avg:50.28ms +step:1083/1555 train_time:54482ms step_avg:50.31ms +step:1084/1555 train_time:54572ms step_avg:50.34ms +step:1085/1555 train_time:54656ms step_avg:50.37ms +step:1086/1555 train_time:54746ms step_avg:50.41ms +step:1087/1555 train_time:54831ms step_avg:50.44ms +step:1088/1555 train_time:54922ms step_avg:50.48ms +step:1089/1555 train_time:55005ms step_avg:50.51ms +step:1090/1555 train_time:55095ms step_avg:50.55ms +step:1091/1555 train_time:55180ms step_avg:50.58ms +step:1092/1555 train_time:55269ms step_avg:50.61ms +step:1093/1555 train_time:55353ms step_avg:50.64ms +step:1094/1555 train_time:55443ms step_avg:50.68ms +step:1095/1555 train_time:55526ms step_avg:50.71ms +step:1096/1555 train_time:55617ms step_avg:50.75ms +step:1097/1555 
train_time:55701ms step_avg:50.78ms +step:1098/1555 train_time:55790ms step_avg:50.81ms +step:1099/1555 train_time:55875ms step_avg:50.84ms +step:1100/1555 train_time:55964ms step_avg:50.88ms +step:1101/1555 train_time:56048ms step_avg:50.91ms +step:1102/1555 train_time:56140ms step_avg:50.94ms +step:1103/1555 train_time:56224ms step_avg:50.97ms +step:1104/1555 train_time:56313ms step_avg:51.01ms +step:1105/1555 train_time:56397ms step_avg:51.04ms +step:1106/1555 train_time:56486ms step_avg:51.07ms +step:1107/1555 train_time:56570ms step_avg:51.10ms +step:1108/1555 train_time:56660ms step_avg:51.14ms +step:1109/1555 train_time:56743ms step_avg:51.17ms +step:1110/1555 train_time:56834ms step_avg:51.20ms +step:1111/1555 train_time:56917ms step_avg:51.23ms +step:1112/1555 train_time:57007ms step_avg:51.26ms +step:1113/1555 train_time:57091ms step_avg:51.29ms +step:1114/1555 train_time:57183ms step_avg:51.33ms +step:1115/1555 train_time:57265ms step_avg:51.36ms +step:1116/1555 train_time:57355ms step_avg:51.39ms +step:1117/1555 train_time:57439ms step_avg:51.42ms +step:1118/1555 train_time:57528ms step_avg:51.46ms +step:1119/1555 train_time:57613ms step_avg:51.49ms +step:1120/1555 train_time:57704ms step_avg:51.52ms +step:1121/1555 train_time:57787ms step_avg:51.55ms +step:1122/1555 train_time:57878ms step_avg:51.58ms +step:1123/1555 train_time:57962ms step_avg:51.61ms +step:1124/1555 train_time:58051ms step_avg:51.65ms +step:1125/1555 train_time:58136ms step_avg:51.68ms +step:1126/1555 train_time:58225ms step_avg:51.71ms +step:1127/1555 train_time:58309ms step_avg:51.74ms +step:1128/1555 train_time:58399ms step_avg:51.77ms +step:1129/1555 train_time:58483ms step_avg:51.80ms +step:1130/1555 train_time:58573ms step_avg:51.83ms +step:1131/1555 train_time:58655ms step_avg:51.86ms +step:1132/1555 train_time:58745ms step_avg:51.90ms +step:1133/1555 train_time:58828ms step_avg:51.92ms +step:1134/1555 train_time:58920ms step_avg:51.96ms +step:1135/1555 train_time:59003ms 
step_avg:51.98ms +step:1136/1555 train_time:59093ms step_avg:52.02ms +step:1137/1555 train_time:59178ms step_avg:52.05ms +step:1138/1555 train_time:59268ms step_avg:52.08ms +step:1139/1555 train_time:59353ms step_avg:52.11ms +step:1140/1555 train_time:59443ms step_avg:52.14ms +step:1141/1555 train_time:59526ms step_avg:52.17ms +step:1142/1555 train_time:59618ms step_avg:52.20ms +step:1143/1555 train_time:59702ms step_avg:52.23ms +step:1144/1555 train_time:59791ms step_avg:52.27ms +step:1145/1555 train_time:59877ms step_avg:52.29ms +step:1146/1555 train_time:59965ms step_avg:52.33ms +step:1147/1555 train_time:60049ms step_avg:52.35ms +step:1148/1555 train_time:60140ms step_avg:52.39ms +step:1149/1555 train_time:60224ms step_avg:52.41ms +step:1150/1555 train_time:60314ms step_avg:52.45ms +step:1151/1555 train_time:60398ms step_avg:52.47ms +step:1152/1555 train_time:60487ms step_avg:52.51ms +step:1153/1555 train_time:60571ms step_avg:52.53ms +step:1154/1555 train_time:60663ms step_avg:52.57ms +step:1155/1555 train_time:60745ms step_avg:52.59ms +step:1156/1555 train_time:60836ms step_avg:52.63ms +step:1157/1555 train_time:60920ms step_avg:52.65ms +step:1158/1555 train_time:61009ms step_avg:52.68ms +step:1159/1555 train_time:61094ms step_avg:52.71ms +step:1160/1555 train_time:61185ms step_avg:52.75ms +step:1161/1555 train_time:61270ms step_avg:52.77ms +step:1162/1555 train_time:61359ms step_avg:52.80ms +step:1163/1555 train_time:61442ms step_avg:52.83ms +step:1164/1555 train_time:61533ms step_avg:52.86ms +step:1165/1555 train_time:61617ms step_avg:52.89ms +step:1166/1555 train_time:61707ms step_avg:52.92ms +step:1167/1555 train_time:61792ms step_avg:52.95ms +step:1168/1555 train_time:61883ms step_avg:52.98ms +step:1169/1555 train_time:61966ms step_avg:53.01ms +step:1170/1555 train_time:62056ms step_avg:53.04ms +step:1171/1555 train_time:62140ms step_avg:53.07ms +step:1172/1555 train_time:62230ms step_avg:53.10ms +step:1173/1555 train_time:62315ms step_avg:53.12ms 
+step:1174/1555 train_time:62404ms step_avg:53.16ms +step:1175/1555 train_time:62488ms step_avg:53.18ms +step:1176/1555 train_time:62581ms step_avg:53.22ms +step:1177/1555 train_time:62664ms step_avg:53.24ms +step:1178/1555 train_time:62754ms step_avg:53.27ms +step:1179/1555 train_time:62839ms step_avg:53.30ms +step:1180/1555 train_time:62929ms step_avg:53.33ms +step:1181/1555 train_time:63014ms step_avg:53.36ms +step:1182/1555 train_time:63104ms step_avg:53.39ms +step:1183/1555 train_time:63187ms step_avg:53.41ms +step:1184/1555 train_time:63278ms step_avg:53.44ms +step:1185/1555 train_time:63361ms step_avg:53.47ms +step:1186/1555 train_time:63450ms step_avg:53.50ms +step:1187/1555 train_time:63535ms step_avg:53.53ms +step:1188/1555 train_time:63625ms step_avg:53.56ms +step:1189/1555 train_time:63708ms step_avg:53.58ms +step:1190/1555 train_time:63800ms step_avg:53.61ms +step:1191/1555 train_time:63883ms step_avg:53.64ms +step:1192/1555 train_time:63973ms step_avg:53.67ms +step:1193/1555 train_time:64058ms step_avg:53.69ms +step:1194/1555 train_time:64148ms step_avg:53.72ms +step:1195/1555 train_time:64232ms step_avg:53.75ms +step:1196/1555 train_time:64324ms step_avg:53.78ms +step:1197/1555 train_time:64408ms step_avg:53.81ms +step:1198/1555 train_time:64498ms step_avg:53.84ms +step:1199/1555 train_time:64582ms step_avg:53.86ms +step:1200/1555 train_time:64671ms step_avg:53.89ms +step:1201/1555 train_time:64756ms step_avg:53.92ms +step:1202/1555 train_time:64845ms step_avg:53.95ms +step:1203/1555 train_time:64929ms step_avg:53.97ms +step:1204/1555 train_time:65021ms step_avg:54.00ms +step:1205/1555 train_time:65104ms step_avg:54.03ms +step:1206/1555 train_time:65193ms step_avg:54.06ms +step:1207/1555 train_time:65278ms step_avg:54.08ms +step:1208/1555 train_time:65367ms step_avg:54.11ms +step:1209/1555 train_time:65451ms step_avg:54.14ms +step:1210/1555 train_time:65541ms step_avg:54.17ms +step:1211/1555 train_time:65624ms step_avg:54.19ms +step:1212/1555 
train_time:65715ms step_avg:54.22ms +step:1213/1555 train_time:65799ms step_avg:54.25ms +step:1214/1555 train_time:65888ms step_avg:54.27ms +step:1215/1555 train_time:65973ms step_avg:54.30ms +step:1216/1555 train_time:66063ms step_avg:54.33ms +step:1217/1555 train_time:66147ms step_avg:54.35ms +step:1218/1555 train_time:66238ms step_avg:54.38ms +step:1219/1555 train_time:66321ms step_avg:54.41ms +step:1220/1555 train_time:66410ms step_avg:54.43ms +step:1221/1555 train_time:66497ms step_avg:54.46ms +step:1222/1555 train_time:66586ms step_avg:54.49ms +step:1223/1555 train_time:66669ms step_avg:54.51ms +step:1224/1555 train_time:66760ms step_avg:54.54ms +step:1225/1555 train_time:66844ms step_avg:54.57ms +step:1226/1555 train_time:66934ms step_avg:54.60ms +step:1227/1555 train_time:67018ms step_avg:54.62ms +step:1228/1555 train_time:67108ms step_avg:54.65ms +step:1229/1555 train_time:67193ms step_avg:54.67ms +step:1230/1555 train_time:67284ms step_avg:54.70ms +step:1231/1555 train_time:67367ms step_avg:54.73ms +step:1232/1555 train_time:67458ms step_avg:54.75ms +step:1233/1555 train_time:67542ms step_avg:54.78ms +step:1234/1555 train_time:67631ms step_avg:54.81ms +step:1235/1555 train_time:67715ms step_avg:54.83ms +step:1236/1555 train_time:67805ms step_avg:54.86ms +step:1237/1555 train_time:67889ms step_avg:54.88ms +step:1238/1555 train_time:67981ms step_avg:54.91ms +step:1239/1555 train_time:68064ms step_avg:54.93ms +step:1240/1555 train_time:68153ms step_avg:54.96ms +step:1241/1555 train_time:68237ms step_avg:54.99ms +step:1242/1555 train_time:68326ms step_avg:55.01ms +step:1243/1555 train_time:68410ms step_avg:55.04ms +step:1244/1555 train_time:68501ms step_avg:55.06ms +step:1245/1555 train_time:68585ms step_avg:55.09ms +step:1246/1555 train_time:68675ms step_avg:55.12ms +step:1247/1555 train_time:68759ms step_avg:55.14ms +step:1248/1555 train_time:68848ms step_avg:55.17ms +step:1249/1555 train_time:68933ms step_avg:55.19ms +step:1250/1555 train_time:69024ms 
step_avg:55.22ms +step:1250/1555 val_loss:3.3954 train_time:69139ms step_avg:55.31ms +step:1251/1555 train_time:69161ms step_avg:55.28ms +step:1252/1555 train_time:69201ms step_avg:55.27ms +step:1253/1555 train_time:69291ms step_avg:55.30ms +step:1254/1555 train_time:69382ms step_avg:55.33ms +step:1255/1555 train_time:69466ms step_avg:55.35ms +step:1256/1555 train_time:69555ms step_avg:55.38ms +step:1257/1555 train_time:69639ms step_avg:55.40ms +step:1258/1555 train_time:69728ms step_avg:55.43ms +step:1259/1555 train_time:69811ms step_avg:55.45ms +step:1260/1555 train_time:69899ms step_avg:55.48ms +step:1261/1555 train_time:69983ms step_avg:55.50ms +step:1262/1555 train_time:70071ms step_avg:55.52ms +step:1263/1555 train_time:70156ms step_avg:55.55ms +step:1264/1555 train_time:70250ms step_avg:55.58ms +step:1265/1555 train_time:70336ms step_avg:55.60ms +step:1266/1555 train_time:70428ms step_avg:55.63ms +step:1267/1555 train_time:70512ms step_avg:55.65ms +step:1268/1555 train_time:70601ms step_avg:55.68ms +step:1269/1555 train_time:70685ms step_avg:55.70ms +step:1270/1555 train_time:70773ms step_avg:55.73ms +step:1271/1555 train_time:70856ms step_avg:55.75ms +step:1272/1555 train_time:70947ms step_avg:55.78ms +step:1273/1555 train_time:71030ms step_avg:55.80ms +step:1274/1555 train_time:71120ms step_avg:55.82ms +step:1275/1555 train_time:71208ms step_avg:55.85ms +step:1276/1555 train_time:71297ms step_avg:55.88ms +step:1277/1555 train_time:71383ms step_avg:55.90ms +step:1278/1555 train_time:71473ms step_avg:55.93ms +step:1279/1555 train_time:71558ms step_avg:55.95ms +step:1280/1555 train_time:71648ms step_avg:55.97ms +step:1281/1555 train_time:71731ms step_avg:56.00ms +step:1282/1555 train_time:71820ms step_avg:56.02ms +step:1283/1555 train_time:71904ms step_avg:56.04ms +step:1284/1555 train_time:71993ms step_avg:56.07ms +step:1285/1555 train_time:72077ms step_avg:56.09ms +step:1286/1555 train_time:72168ms step_avg:56.12ms +step:1287/1555 train_time:72252ms 
step_avg:56.14ms +step:1288/1555 train_time:72344ms step_avg:56.17ms +step:1289/1555 train_time:72429ms step_avg:56.19ms +step:1290/1555 train_time:72519ms step_avg:56.22ms +step:1291/1555 train_time:72603ms step_avg:56.24ms +step:1292/1555 train_time:72693ms step_avg:56.26ms +step:1293/1555 train_time:72776ms step_avg:56.28ms +step:1294/1555 train_time:72866ms step_avg:56.31ms +step:1295/1555 train_time:72949ms step_avg:56.33ms +step:1296/1555 train_time:73039ms step_avg:56.36ms +step:1297/1555 train_time:73124ms step_avg:56.38ms +step:1298/1555 train_time:73215ms step_avg:56.41ms +step:1299/1555 train_time:73301ms step_avg:56.43ms +step:1300/1555 train_time:73394ms step_avg:56.46ms +step:1301/1555 train_time:73476ms step_avg:56.48ms +step:1302/1555 train_time:73568ms step_avg:56.50ms +step:1303/1555 train_time:73651ms step_avg:56.52ms +step:1304/1555 train_time:73740ms step_avg:56.55ms +step:1305/1555 train_time:73824ms step_avg:56.57ms +step:1306/1555 train_time:73913ms step_avg:56.59ms +step:1307/1555 train_time:73996ms step_avg:56.62ms +step:1308/1555 train_time:74089ms step_avg:56.64ms +step:1309/1555 train_time:74170ms step_avg:56.66ms +step:1310/1555 train_time:74261ms step_avg:56.69ms +step:1311/1555 train_time:74346ms step_avg:56.71ms +step:1312/1555 train_time:74436ms step_avg:56.73ms +step:1313/1555 train_time:74522ms step_avg:56.76ms +step:1314/1555 train_time:74612ms step_avg:56.78ms +step:1315/1555 train_time:74696ms step_avg:56.80ms +step:1316/1555 train_time:74787ms step_avg:56.83ms +step:1317/1555 train_time:74870ms step_avg:56.85ms +step:1318/1555 train_time:74959ms step_avg:56.87ms +step:1319/1555 train_time:75044ms step_avg:56.89ms +step:1320/1555 train_time:75134ms step_avg:56.92ms +step:1321/1555 train_time:75218ms step_avg:56.94ms +step:1322/1555 train_time:75310ms step_avg:56.97ms +step:1323/1555 train_time:75393ms step_avg:56.99ms +step:1324/1555 train_time:75483ms step_avg:57.01ms +step:1325/1555 train_time:75567ms step_avg:57.03ms 
+step:1326/1555 train_time:75657ms step_avg:57.06ms +step:1327/1555 train_time:75742ms step_avg:57.08ms +step:1328/1555 train_time:75832ms step_avg:57.10ms +step:1329/1555 train_time:75916ms step_avg:57.12ms +step:1330/1555 train_time:76007ms step_avg:57.15ms +step:1331/1555 train_time:76090ms step_avg:57.17ms +step:1332/1555 train_time:76180ms step_avg:57.19ms +step:1333/1555 train_time:76265ms step_avg:57.21ms +step:1334/1555 train_time:76354ms step_avg:57.24ms +step:1335/1555 train_time:76438ms step_avg:57.26ms +step:1336/1555 train_time:76529ms step_avg:57.28ms +step:1337/1555 train_time:76612ms step_avg:57.30ms +step:1338/1555 train_time:76702ms step_avg:57.33ms +step:1339/1555 train_time:76786ms step_avg:57.35ms +step:1340/1555 train_time:76874ms step_avg:57.37ms +step:1341/1555 train_time:76958ms step_avg:57.39ms +step:1342/1555 train_time:77049ms step_avg:57.41ms +step:1343/1555 train_time:77132ms step_avg:57.43ms +step:1344/1555 train_time:77223ms step_avg:57.46ms +step:1345/1555 train_time:77308ms step_avg:57.48ms +step:1346/1555 train_time:77398ms step_avg:57.50ms +step:1347/1555 train_time:77483ms step_avg:57.52ms +step:1348/1555 train_time:77572ms step_avg:57.55ms +step:1349/1555 train_time:77656ms step_avg:57.57ms +step:1350/1555 train_time:77747ms step_avg:57.59ms +step:1351/1555 train_time:77830ms step_avg:57.61ms +step:1352/1555 train_time:77920ms step_avg:57.63ms +step:1353/1555 train_time:78005ms step_avg:57.65ms +step:1354/1555 train_time:78094ms step_avg:57.68ms +step:1355/1555 train_time:78178ms step_avg:57.70ms +step:1356/1555 train_time:78268ms step_avg:57.72ms +step:1357/1555 train_time:78352ms step_avg:57.74ms +step:1358/1555 train_time:78442ms step_avg:57.76ms +step:1359/1555 train_time:78527ms step_avg:57.78ms +step:1360/1555 train_time:78617ms step_avg:57.81ms +step:1361/1555 train_time:78702ms step_avg:57.83ms +step:1362/1555 train_time:78793ms step_avg:57.85ms +step:1363/1555 train_time:78876ms step_avg:57.87ms +step:1364/1555 
train_time:78967ms step_avg:57.89ms +step:1365/1555 train_time:79050ms step_avg:57.91ms +step:1366/1555 train_time:79139ms step_avg:57.94ms +step:1367/1555 train_time:79223ms step_avg:57.95ms +step:1368/1555 train_time:79313ms step_avg:57.98ms +step:1369/1555 train_time:79398ms step_avg:58.00ms +step:1370/1555 train_time:79490ms step_avg:58.02ms +step:1371/1555 train_time:79574ms step_avg:58.04ms +step:1372/1555 train_time:79664ms step_avg:58.06ms +step:1373/1555 train_time:79748ms step_avg:58.08ms +step:1374/1555 train_time:79838ms step_avg:58.11ms +step:1375/1555 train_time:79922ms step_avg:58.12ms +step:1376/1555 train_time:80012ms step_avg:58.15ms +step:1377/1555 train_time:80095ms step_avg:58.17ms +step:1378/1555 train_time:80187ms step_avg:58.19ms +step:1379/1555 train_time:80270ms step_avg:58.21ms +step:1380/1555 train_time:80360ms step_avg:58.23ms +step:1381/1555 train_time:80446ms step_avg:58.25ms +step:1382/1555 train_time:80535ms step_avg:58.27ms +step:1383/1555 train_time:80619ms step_avg:58.29ms +step:1384/1555 train_time:80710ms step_avg:58.32ms +step:1385/1555 train_time:80794ms step_avg:58.33ms +step:1386/1555 train_time:80883ms step_avg:58.36ms +step:1387/1555 train_time:80967ms step_avg:58.38ms +step:1388/1555 train_time:81057ms step_avg:58.40ms +step:1389/1555 train_time:81141ms step_avg:58.42ms +step:1390/1555 train_time:81231ms step_avg:58.44ms +step:1391/1555 train_time:81315ms step_avg:58.46ms +step:1392/1555 train_time:81406ms step_avg:58.48ms +step:1393/1555 train_time:81490ms step_avg:58.50ms +step:1394/1555 train_time:81580ms step_avg:58.52ms +step:1395/1555 train_time:81665ms step_avg:58.54ms +step:1396/1555 train_time:81754ms step_avg:58.56ms +step:1397/1555 train_time:81839ms step_avg:58.58ms +step:1398/1555 train_time:81929ms step_avg:58.60ms +step:1399/1555 train_time:82013ms step_avg:58.62ms +step:1400/1555 train_time:82102ms step_avg:58.64ms +step:1401/1555 train_time:82186ms step_avg:58.66ms +step:1402/1555 train_time:82275ms 
step_avg:58.68ms +step:1403/1555 train_time:82359ms step_avg:58.70ms +step:1404/1555 train_time:82451ms step_avg:58.73ms +step:1405/1555 train_time:82534ms step_avg:58.74ms +step:1406/1555 train_time:82624ms step_avg:58.77ms +step:1407/1555 train_time:82709ms step_avg:58.78ms +step:1408/1555 train_time:82798ms step_avg:58.81ms +step:1409/1555 train_time:82883ms step_avg:58.82ms +step:1410/1555 train_time:82972ms step_avg:58.85ms +step:1411/1555 train_time:83056ms step_avg:58.86ms +step:1412/1555 train_time:83147ms step_avg:58.89ms +step:1413/1555 train_time:83230ms step_avg:58.90ms +step:1414/1555 train_time:83320ms step_avg:58.92ms +step:1415/1555 train_time:83405ms step_avg:58.94ms +step:1416/1555 train_time:83494ms step_avg:58.96ms +step:1417/1555 train_time:83577ms step_avg:58.98ms +step:1418/1555 train_time:83669ms step_avg:59.00ms +step:1419/1555 train_time:83754ms step_avg:59.02ms +step:1420/1555 train_time:83843ms step_avg:59.04ms +step:1421/1555 train_time:83927ms step_avg:59.06ms +step:1422/1555 train_time:84016ms step_avg:59.08ms +step:1423/1555 train_time:84100ms step_avg:59.10ms +step:1424/1555 train_time:84191ms step_avg:59.12ms +step:1425/1555 train_time:84274ms step_avg:59.14ms +step:1426/1555 train_time:84364ms step_avg:59.16ms +step:1427/1555 train_time:84449ms step_avg:59.18ms +step:1428/1555 train_time:84537ms step_avg:59.20ms +step:1429/1555 train_time:84623ms step_avg:59.22ms +step:1430/1555 train_time:84713ms step_avg:59.24ms +step:1431/1555 train_time:84796ms step_avg:59.26ms +step:1432/1555 train_time:84887ms step_avg:59.28ms +step:1433/1555 train_time:84970ms step_avg:59.30ms +step:1434/1555 train_time:85060ms step_avg:59.32ms +step:1435/1555 train_time:85144ms step_avg:59.33ms +step:1436/1555 train_time:85233ms step_avg:59.35ms +step:1437/1555 train_time:85316ms step_avg:59.37ms +step:1438/1555 train_time:85406ms step_avg:59.39ms +step:1439/1555 train_time:85491ms step_avg:59.41ms +step:1440/1555 train_time:85580ms step_avg:59.43ms 
+step:1441/1555 train_time:85663ms step_avg:59.45ms +step:1442/1555 train_time:85753ms step_avg:59.47ms +step:1443/1555 train_time:85837ms step_avg:59.49ms +step:1444/1555 train_time:85928ms step_avg:59.51ms +step:1445/1555 train_time:86012ms step_avg:59.52ms +step:1446/1555 train_time:86101ms step_avg:59.54ms +step:1447/1555 train_time:86186ms step_avg:59.56ms +step:1448/1555 train_time:86275ms step_avg:59.58ms +step:1449/1555 train_time:86360ms step_avg:59.60ms +step:1450/1555 train_time:86450ms step_avg:59.62ms +step:1451/1555 train_time:86532ms step_avg:59.64ms +step:1452/1555 train_time:86623ms step_avg:59.66ms +step:1453/1555 train_time:86708ms step_avg:59.67ms +step:1454/1555 train_time:86797ms step_avg:59.70ms +step:1455/1555 train_time:86882ms step_avg:59.71ms +step:1456/1555 train_time:86972ms step_avg:59.73ms +step:1457/1555 train_time:87056ms step_avg:59.75ms +step:1458/1555 train_time:87147ms step_avg:59.77ms +step:1459/1555 train_time:87230ms step_avg:59.79ms +step:1460/1555 train_time:87319ms step_avg:59.81ms +step:1461/1555 train_time:87405ms step_avg:59.83ms +step:1462/1555 train_time:87494ms step_avg:59.85ms +step:1463/1555 train_time:87577ms step_avg:59.86ms +step:1464/1555 train_time:87668ms step_avg:59.88ms +step:1465/1555 train_time:87751ms step_avg:59.90ms +step:1466/1555 train_time:87841ms step_avg:59.92ms +step:1467/1555 train_time:87927ms step_avg:59.94ms +step:1468/1555 train_time:88018ms step_avg:59.96ms +step:1469/1555 train_time:88102ms step_avg:59.97ms +step:1470/1555 train_time:88192ms step_avg:59.99ms +step:1471/1555 train_time:88275ms step_avg:60.01ms +step:1472/1555 train_time:88365ms step_avg:60.03ms +step:1473/1555 train_time:88450ms step_avg:60.05ms +step:1474/1555 train_time:88538ms step_avg:60.07ms +step:1475/1555 train_time:88623ms step_avg:60.08ms +step:1476/1555 train_time:88712ms step_avg:60.10ms +step:1477/1555 train_time:88796ms step_avg:60.12ms +step:1478/1555 train_time:88888ms step_avg:60.14ms +step:1479/1555 
train_time:88970ms step_avg:60.16ms +step:1480/1555 train_time:89060ms step_avg:60.18ms +step:1481/1555 train_time:89146ms step_avg:60.19ms +step:1482/1555 train_time:89236ms step_avg:60.21ms +step:1483/1555 train_time:89320ms step_avg:60.23ms +step:1484/1555 train_time:89411ms step_avg:60.25ms +step:1485/1555 train_time:89495ms step_avg:60.27ms +step:1486/1555 train_time:89585ms step_avg:60.29ms +step:1487/1555 train_time:89668ms step_avg:60.30ms +step:1488/1555 train_time:89757ms step_avg:60.32ms +step:1489/1555 train_time:89842ms step_avg:60.34ms +step:1490/1555 train_time:89933ms step_avg:60.36ms +step:1491/1555 train_time:90017ms step_avg:60.37ms +step:1492/1555 train_time:90106ms step_avg:60.39ms +step:1493/1555 train_time:90191ms step_avg:60.41ms +step:1494/1555 train_time:90280ms step_avg:60.43ms +step:1495/1555 train_time:90363ms step_avg:60.44ms +step:1496/1555 train_time:90453ms step_avg:60.46ms +step:1497/1555 train_time:90537ms step_avg:60.48ms +step:1498/1555 train_time:90629ms step_avg:60.50ms +step:1499/1555 train_time:90712ms step_avg:60.51ms +step:1500/1555 train_time:90802ms step_avg:60.53ms +step:1500/1555 val_loss:3.2932 train_time:90916ms step_avg:60.61ms +step:1501/1555 train_time:90942ms step_avg:60.59ms +step:1502/1555 train_time:90978ms step_avg:60.57ms +step:1503/1555 train_time:91061ms step_avg:60.59ms +step:1504/1555 train_time:91153ms step_avg:60.61ms +step:1505/1555 train_time:91238ms step_avg:60.62ms +step:1506/1555 train_time:91329ms step_avg:60.64ms +step:1507/1555 train_time:91411ms step_avg:60.66ms +step:1508/1555 train_time:91500ms step_avg:60.68ms +step:1509/1555 train_time:91583ms step_avg:60.69ms +step:1510/1555 train_time:91673ms step_avg:60.71ms +step:1511/1555 train_time:91757ms step_avg:60.73ms +step:1512/1555 train_time:91844ms step_avg:60.74ms +step:1513/1555 train_time:91931ms step_avg:60.76ms +step:1514/1555 train_time:92022ms step_avg:60.78ms +step:1515/1555 train_time:92108ms step_avg:60.80ms +step:1516/1555 
train_time:92203ms step_avg:60.82ms +step:1517/1555 train_time:92288ms step_avg:60.84ms +step:1518/1555 train_time:92377ms step_avg:60.85ms +step:1519/1555 train_time:92461ms step_avg:60.87ms +step:1520/1555 train_time:92550ms step_avg:60.89ms +step:1521/1555 train_time:92633ms step_avg:60.90ms +step:1522/1555 train_time:92722ms step_avg:60.92ms +step:1523/1555 train_time:92806ms step_avg:60.94ms +step:1524/1555 train_time:92898ms step_avg:60.96ms +step:1525/1555 train_time:92983ms step_avg:60.97ms +step:1526/1555 train_time:93077ms step_avg:60.99ms +step:1527/1555 train_time:93161ms step_avg:61.01ms +step:1528/1555 train_time:93250ms step_avg:61.03ms +step:1529/1555 train_time:93335ms step_avg:61.04ms +step:1530/1555 train_time:93424ms step_avg:61.06ms +step:1531/1555 train_time:93508ms step_avg:61.08ms +step:1532/1555 train_time:93600ms step_avg:61.10ms +step:1533/1555 train_time:93683ms step_avg:61.11ms +step:1534/1555 train_time:93773ms step_avg:61.13ms +step:1535/1555 train_time:93857ms step_avg:61.14ms +step:1536/1555 train_time:93947ms step_avg:61.16ms +step:1537/1555 train_time:94034ms step_avg:61.18ms +step:1538/1555 train_time:94123ms step_avg:61.20ms +step:1539/1555 train_time:94209ms step_avg:61.21ms +step:1540/1555 train_time:94300ms step_avg:61.23ms +step:1541/1555 train_time:94383ms step_avg:61.25ms +step:1542/1555 train_time:94475ms step_avg:61.27ms +step:1543/1555 train_time:94558ms step_avg:61.28ms +step:1544/1555 train_time:94648ms step_avg:61.30ms +step:1545/1555 train_time:94732ms step_avg:61.32ms +step:1546/1555 train_time:94821ms step_avg:61.33ms +step:1547/1555 train_time:94907ms step_avg:61.35ms +step:1548/1555 train_time:94998ms step_avg:61.37ms +step:1549/1555 train_time:95082ms step_avg:61.38ms +step:1550/1555 train_time:95174ms step_avg:61.40ms +step:1551/1555 train_time:95258ms step_avg:61.42ms +step:1552/1555 train_time:95347ms step_avg:61.44ms +step:1553/1555 train_time:95432ms step_avg:61.45ms +step:1554/1555 train_time:95523ms 
step_avg:61.47ms +step:1555/1555 train_time:95608ms step_avg:61.48ms +step:1555/1555 val_loss:3.2771 train_time:95723ms step_avg:61.56ms +peak memory allocated: 31630 MiB reserved: 46538 MiB diff --git a/records/track_1_short/2026-01-31-BigramHashH2D/4cf9ea26-1f26-4e64-b052-0b0ff0763178.txt b/records/track_1_short/2026-01-31-BigramHashH2D/4cf9ea26-1f26-4e64-b052-0b0ff0763178.txt new file mode 100644 index 000000000..e677a7bd3 --- /dev/null +++ b/records/track_1_short/2026-01-31-BigramHashH2D/4cf9ea26-1f26-4e64-b052-0b0ff0763178.txt @@ -0,0 +1,3976 @@ +import os +import sys + +# Read the current file and the kernels file code ASAP, for logging +with open(sys.argv[0], 'r') as f: + code = f.read() +with open(os.path.join(os.path.dirname(sys.argv[0]), 'triton_kernels.py'), 'r') as f: + code += f"\n\n{'-'*40}\n# triton_kernels.py\n{'-'*40}\n\n" + code += f.read() + +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate, pairwise +from pathlib import Path +import gc + +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" +import torch +import triton + +torch.empty( + 1, device=f"cuda:{os.environ['LOCAL_RANK']}", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +from kernels import get_kernel +from torch import Tensor, nn + +from triton_kernels import XXT, ba_plus_cAA, FusedLinearReLUSquareFunction, FusedSoftcappedCrossEntropy + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Distributed training setup +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor 
of 8" +grad_accum_steps = 8 // world_size +grad_scale = 2 / grad_accum_steps # consistent grad magnitudes between different num_devices +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng +# Transposed layout by @ChrisJMcCormick allows for faster gradient accumulation. + +@torch.library.custom_op("nanogpt::mm_t", mutates_args=()) +def mm_t_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + """Computes y = x @ w with F8 weights stored as (in_features, out_features).""" + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + assert x.shape[1] == w.shape[0] # x: (batch, in), w: (in, out) + + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + + # _scaled_mm requires column-major B. w_f8 is row-major (in, out). + # .T.contiguous().T creates a column-major view without changing logical shape. 
+ w_f8_col_major = w_f8.T.contiguous().T + + out = torch._scaled_mm( + x_f8, + w_f8_col_major, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_t_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[0] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_t_backward", mutates_args=()) +def mm_t_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + + x_scale = grad.new_tensor(x_s, dtype=torch.float32) + w_scale = grad.new_tensor(w_s, dtype=torch.float32) + grad_scale = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + + # grad_x = grad @ w.T + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=grad_scale, + scale_b=w_scale, + use_fast_accum=False, + ) + + # grad_w = x.T @ grad + # Result is (in, out), naturally matching weight storage. No final .T needed. 
+ grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_scale, + scale_b=grad_scale, + use_fast_accum=False, + ) + + return grad_x, grad_w + + grad_x, grad_w = impl(g, x_f8, w_f8) + + return grad_x, grad_w + +@mm_t_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.to(torch.float32) + +def backward_t(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_t_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context_t(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_t_op.register_autograd(backward_t, setup_context=setup_context_t) + +# ----------------------------------------------------------------------------- +# Polar Express + +# Computed for num_iters=5, safety_factor=2e-2, cushion=2 +polar_express_coeffs = [ + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323) +] + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def polar_express(G: torch.Tensor, split_baddbmm: bool = False): + """ + Polar Express Sign Method: https://arxiv.org/pdf/2505.16932 + by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. 
+ """ + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + # Select batched vs unbatched + if split_baddbmm: + BX_matmul = torch.bmm if X.ndim > 2 else torch.mm + else: + aX_plus_BX = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the iterations + for a, b, c in polar_express_coeffs: + XXT(X, out=A) # A = X @ X.mT + ba_plus_cAA(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + + # Referencing X twice causes pytorch to make a defensive copy, + # resulting in a cudaMemcpyAsync in baddbmm. + # For large matrices (i.e., the mlp weights), it's faster to split + # the operation into two kernels to avoid this. + if split_baddbmm: + BX_matmul(B, X, out=C) # C = B @ X + C.add_(X, alpha=a) # C = C + a*X (in-place, X only read) + else: + aX_plus_BX(X, B, X, beta=a, out=C) # C = a * X + B @ X + + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +# ----------------------------------------------------------------------------- +# Combined NorMuon + Adam Optimizer + +@dataclass +class ParamConfig: + """Per-parameter configuration for NorMuonAndAdam optimizer.""" + label: str + optim: str # "adam" or "normuon" + comms: str # "none", "replicated", or "sharded" + adam_betas: tuple[float, float] | None + lr_mul: float + wd_mul: float + lr: float + initial_lr: float + weight_decay: float + # Adam-specific + eps: float | None = None + # NorMuon-specific + reshape: tuple | None = None + chunk_size: int | None = None + momentum: float | None = None + beta2: float | None = None + per_matrix_lr_mul: list[float] | None = None + + +class NorMuonAndAdam: + """ + Combined optimizer that handles both NorMuon (for projection 
matrices) and + Adam (for embeddings/scalars/gate weights). + + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, Muon uses a Newton-Schulz iteration (replaced + here with Polar Express), which has the advantage that it can be stably run in bfloat16 on the GPU. + + Muon is applied only to the projection matrices in the attention and MLP layers, and is not recommended + for embeddings, scalars, or individual weight vectors (e.g., bias terms or gate weights). + + Differences from standard Muon: + - Newton-Shulz is replaced with Polar Express for the orthogonalization step + - NorMuon adds a low-rank variance estimator similar to Adafactor. https://arxiv.org/pdf/2510.05491 + - Cautious weight decay, a gated version of decoupled weight decay + - Mantissa tracking for precision + + Adam (for embeddings/scalars/gates): + - Standard Adam with bias correction + - Cautious weight decay + + Configuration: + Unlike torch.optim.Optimizer, this class uses per-parameter configs from a `param_table` dict + and does not include parameter "groups". All parameters require a .label attribute, and a + corresponding entry in the param_table to specify their hyperparameters (lr_mul, wd_mul, adam_betas, etc.). + + Communication and ordering: + Gradient communication is explicitly scheduled rather than hook-driven. + Reductions are launched in `scatter_order`, while update math and final + gathers are executed in `work_order`. These orders are independent and + must each contain every parameter label exactly once. + + Two communication modes are supported per parameter: + - 'replicated': Gradients are all-reduced and each rank computes the full update. 
+ - 'sharded': Gradients are reduce-scattered, each rank updates its shard, + and results are all-gathered. + + Adam parameters may be freely sharded. NorMuon operates on full matrices; sharding is + supported by grouping matrices into parameter banks. NorMuon parameters must have a + `.reshape` attribute that reshapes the bank so that the leading dimension is divisible + by world_size. + + # Contributors include @YouJiacheng, @KonstantinWilleke, @alexrgilbert, @adricarda, + # @tuttyfrutyee, @vdlad, @ryanyang0, @vagrawal, @varunneal, @chrisjmccormick + """ + def __init__(self, named_params, param_table: dict, scatter_order: list, work_order: list, + adam_defaults: dict, normuon_defaults: dict): + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + + # Store defaults for each optimizer type + self.adam_defaults = adam_defaults + self.normuon_defaults = normuon_defaults + self.param_table = param_table + self.scatter_order = scatter_order + self.work_order = work_order + + # Collect params by label and build config + self.param_cfgs: dict[nn.Parameter, ParamConfig] = {} + self.param_states: dict[nn.Parameter, dict] = {} + self._param_by_label: dict[str, nn.Parameter] = {} + for name, param in named_params: + label = getattr(param, "label", None) + assert label is not None and label in param_table # all params must have valid label + assert label not in self._param_by_label # exactly one param per label + self._param_by_label[label] = param + self._build_param_cfg(param, label) + + # Assert scatter_order and work_order match present labels exactly + present = set(self._param_by_label.keys()) + assert set(scatter_order) == present and set(work_order) == present + + # Handle world_size=1: overwrite comms to "none" + if self.world_size == 1: + for p_cfg in self.param_cfgs.values(): + p_cfg.comms = "none" + + # Initialize state for all params + self._init_state() + + # 0-D CPU tensors to avoid recompilation + self._step_size_t = torch.tensor(0.0, 
dtype=torch.float32, device="cpu") + self._eff_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._eff_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + + # Track async operations + self._reduce_futures: dict[nn.Parameter, tuple] = {} + + # Embed/lm_head tying state + self.split_embed = False + self._lm_head_param = self._param_by_label.get("lm_head") + self._embed_param = self._param_by_label.get("embed") + + def _build_param_cfg(self, param: nn.Parameter, label: str): + """Build config for a single parameter from param_table.""" + table_entry = self.param_table[label] + optim = table_entry["optim"] + comms = table_entry["comms"] + adam_betas = table_entry.get("adam_betas") + lr_mul = table_entry.get("lr_mul", 1.0) + wd_mul = table_entry.get("wd_mul", 1.0) + + if optim == "adam": + chunk_size = param.shape[0] // self.world_size if comms == "sharded" else None + p_cfg = ParamConfig( + label=label, + optim=optim, + comms=comms, + adam_betas=tuple(adam_betas) if adam_betas else None, + lr_mul=lr_mul, + wd_mul=wd_mul, + lr=self.adam_defaults["lr"], + initial_lr=self.adam_defaults["lr"], + weight_decay=self.adam_defaults["weight_decay"], + eps=self.adam_defaults["eps"], + chunk_size=chunk_size, + ) + elif optim == "normuon": + reshape = getattr(param, "reshape", None) + if reshape is None: + raise ValueError(f"NorMuon param {label} must have .reshape attribute") + if reshape[0] % self.world_size != 0: + raise ValueError(f"reshape[0]={reshape[0]} must be divisible by world_size") + + chunk_size = reshape[0] // self.world_size + chunk_shape = (chunk_size, *reshape[1:]) + # Shape-based LR multiplier for NorMuon + shape_mult = max(1.0, chunk_shape[-2] / chunk_shape[-1]) ** 0.5 if len(chunk_shape) >= 2 else 1.0 + lr_mul = shape_mult * lr_mul + + # Per-matrix LR multipliers for MLP c_proj (2x LR on odd indices) + per_matrix_lr_mul = None + if label == "mlp": + rank = dist.get_rank() if dist.is_initialized() else 0 + start_idx = rank * chunk_size + 
per_matrix_lr_mul = [] + for i in range(chunk_size): + global_idx = start_idx + i + is_c_proj = (global_idx % 2 == 1) + per_matrix_lr_mul.append(2.0 if is_c_proj else 1.0) + + p_cfg = ParamConfig( + label=label, + optim=optim, + comms=comms, + adam_betas=tuple(adam_betas) if adam_betas else None, + lr_mul=lr_mul, + wd_mul=wd_mul, + lr=self.normuon_defaults["lr"], + initial_lr=self.normuon_defaults["lr"], + weight_decay=self.normuon_defaults["weight_decay"], + reshape=reshape, + chunk_size=chunk_size, + momentum=self.normuon_defaults["momentum"], + beta2=self.normuon_defaults["beta2"], + per_matrix_lr_mul=per_matrix_lr_mul, + ) + else: + raise ValueError(f"Unknown optim type: {optim}") + + self.param_cfgs[param] = p_cfg + + def _init_state(self): + """Initialize optimizer state for all parameters.""" + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "adam": + # Sharded params use chunk state, replicated use full state + if p_cfg.comms == "sharded": + chunk = param[:p_cfg.chunk_size] + else: + chunk = param + exp_avg = torch.zeros_like(chunk, dtype=torch.float32, device=param.device) + self.param_states[param] = dict(step=0, exp_avg=exp_avg, exp_avg_sq=torch.zeros_like(exp_avg)) + + elif p_cfg.optim == "normuon": + chunk_shape = (p_cfg.chunk_size, *p_cfg.reshape[1:]) + + # Momentum buffer (FP32 for precision) + momentum_buffer = torch.zeros( + chunk_shape, dtype=torch.float32, device=param.device + ) + + # Second momentum buffer - reduced along one dimension + if chunk_shape[-2] >= chunk_shape[-1]: + second_mom_shape = (*chunk_shape[:-1], 1) + else: + second_mom_shape = (*chunk_shape[:-2], 1, chunk_shape[-1]) + second_momentum_buffer = torch.zeros( + second_mom_shape, dtype=torch.float32, device=param.device + ) + + # Mantissa buffer for precision tracking + mantissa = torch.zeros( + chunk_shape, dtype=torch.uint16, device=param.device + ) + + self.param_states[param] = dict( + momentum_buffer=momentum_buffer, + 
second_momentum_buffer=second_momentum_buffer, + mantissa=mantissa, + ) + + # ----------------------------------- + # Reduce/Gather operations + + def _launch_reduce(self, param: nn.Parameter, grad: Tensor): + """Launch async reduce for a parameter based on its comms policy.""" + p_cfg = self.param_cfgs[param] + + if p_cfg.comms == "none": + if p_cfg.optim == "normuon": + # NorMuon needs reshaped gradient even without communication + grad = grad.view(p_cfg.reshape) + self._reduce_futures[param] = (None, grad) + elif p_cfg.comms == "replicated": + future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future() + self._reduce_futures[param] = (future, grad) + elif p_cfg.comms == "sharded": + if p_cfg.optim == "normuon": + # NorMuon: reshape before reduce_scatter + grad_reshaped = grad.view(p_cfg.reshape) + grad_chunk = torch.empty( + (p_cfg.chunk_size, *grad_reshaped.shape[1:]), + dtype=grad.dtype, + device=grad.device + ) + future = dist.reduce_scatter_tensor( + grad_chunk, grad_reshaped.contiguous(), op=dist.ReduceOp.AVG, async_op=True + ).get_future() + self._reduce_futures[param] = (future, grad_chunk) + else: + # Adam: simple reduce_scatter + grad_chunk = torch.empty_like(grad[:p_cfg.chunk_size]) + future = dist.reduce_scatter_tensor( + grad_chunk, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + self._reduce_futures[param] = (future, grad_chunk) + + def _launch_gather(self, param: nn.Parameter, p_slice: Tensor) -> "torch.futures.Future": + """Launch async all_gather for a sharded parameter.""" + p_cfg = self.param_cfgs[param] + if p_cfg.optim == "normuon": + full_param = param.data.view(p_cfg.reshape) + assert full_param.is_contiguous() + return dist.all_gather_into_tensor( + full_param, p_slice.contiguous(), async_op=True + ).get_future() + else: + return dist.all_gather_into_tensor( + param, p_slice.contiguous(), async_op=True + ).get_future() + + # ----------------------------------- + # State management + + def reset(self): + 
"""Reset NorMuon momentum buffers and split_embed state (called on training reset).""" + self.split_embed = False + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "normuon": + p_state = self.param_states[param] + p_state["momentum_buffer"].zero_() + p_state["mantissa"].zero_() + p_state["second_momentum_buffer"].zero_() + + def copy_lm_state_to_embed(self): + """ + Copy the optimizer state from the lm_head to the embed at the untie point. + This requires an all-gather + reshard because of different sharding: + - lm_head (768, 50304) is sharded to (96, 50304) per rank (along model_dim) + - embed (50304, 768) is sharded to (6288, 768) per rank (along vocab_size) + + We all-gather the lm_head momentum, transpose it, then each rank takes their + embed shard to get the correct momentum state. + """ + lm_head = self._lm_head_param + embed = self._embed_param + lm_state = self.param_states[lm_head] + embed_state = self.param_states[embed] + lm_cfg = self.param_cfgs[lm_head] + embed_cfg = self.param_cfgs[embed] + + embed_state['step'] = lm_state['step'] # Preserve step count for bias correction + + # Copy optimizer state with all-gather + transpose + reshard + if self.world_size > 1: + rank = dist.get_rank() + lm_chunk_size = lm_cfg.chunk_size # 96 + embed_chunk_size = embed_cfg.chunk_size # 6288 + + # All-gather lm_head momentum to get full (768, 50304) tensor + for key in ["exp_avg", "exp_avg_sq"]: + lm_chunk = lm_state[key] # (96, 50304) + full_lm = torch.empty(lm_head.shape[0], lm_head.shape[1], dtype=lm_chunk.dtype, device=lm_chunk.device) + dist.all_gather_into_tensor(full_lm, lm_chunk.contiguous()) + embed_state[key].copy_(full_lm.T[rank * embed_chunk_size:(rank + 1) * embed_chunk_size]) + else: + # Single GPU: simple transpose + for key in ["exp_avg", "exp_avg_sq"]: + embed_state[key].copy_(lm_state[key].T) + + # Mark as split + self.split_embed = True + + def state_dict(self): + """Return the optimizer state as a dict.""" + return { + 
"param_states": {id(p): s for p, s in self.param_states.items()}, + "param_cfgs": {id(p): s for p, s in self.param_cfgs.items()}, + } + + def load_state_dict(self, state_dict): + """Load optimizer state from a dict.""" + # Build id->param mapping + id_to_param = {id(p): p for p in self.param_cfgs.keys()} + + # Load state, preserving dtypes + for param_id, saved_p_state in state_dict["param_states"].items(): + if param_id in id_to_param: + param = id_to_param[param_id] + p_state = self.param_states[param] + for k, v in saved_p_state.items(): + if isinstance(v, torch.Tensor) and k in p_state: + target_dtype = p_state[k].dtype + p_state[k] = v.to(dtype=target_dtype, device=p_state[k].device) + else: + p_state[k] = v + + # ----------------------------------- + # Unified optimizer step with explicit ordering + + @torch.no_grad() + def step(self, do_adam: bool = True): + """ + Combined optimizer step with explicit ordering. + + Args: + do_adam: If True, update Adam params. NorMuon params always updated. + + Flow: + 1. Scatter phase: Launch reduces in scatter_order + 2. Work phase: Process updates in work_order + - Wait for reduce, compute update, launch gather + 3. Finalize phase: Wait for gathers + + While the embeddings are tied: + - Comms and update math are only done on lm_head. + - We add embed.grad.T into lm_head.grad before comms. 
+ - After lm_head gather, we copy lm_head.data.T --> embed.data + """ + rank = dist.get_rank() if dist.is_initialized() else 0 + lm_param, embed_param = self._lm_head_param, self._embed_param + + # ===== Phase 1: Launch reduces in scatter_order ===== + for label in self.scatter_order: + param = self._param_by_label[label] + p_cfg = self.param_cfgs[param] + + if p_cfg.optim == "adam" and not do_adam: + continue + if param.grad is None: + continue + + # lm_head when tied: aggregate embed.grad.T (transposed shapes) + if label == "lm_head" and do_adam and not self.split_embed: + if embed_param is not None and embed_param.grad is not None: + param.grad.add_(embed_param.grad.T) + + # Skip embed when tied (copied from lm_head after gather) + if label == "embed" and not self.split_embed: + continue + + self._launch_reduce(param, param.grad) + + # ===== Phase 2: Process updates in work_order ===== + gather_futures = [] + lm_head_gather_future = None + + for label in self.work_order: + param = self._param_by_label[label] + if param not in self._reduce_futures: + continue + + p_cfg = self.param_cfgs[param] + if p_cfg.optim == "adam" and not do_adam: + continue + # Wait for reduce + future, grad_chunk = self._reduce_futures[param] + if future is not None: + future.wait() + # Apply update based on optim type + if p_cfg.optim == "adam": + p_slice = self._adam_update(param, grad_chunk, p_cfg, rank) + else: + p_slice = self._normuon_update(param, grad_chunk, p_cfg, rank) + # Launch gather for sharded params + if p_cfg.comms == "sharded" and self.world_size > 1: + gather_fut = self._launch_gather(param, p_slice) + if label == "lm_head": + lm_head_gather_future = gather_fut + else: + gather_futures.append(gather_fut) + + # ===== Phase 3: Wait for gathers, sync embed if tied ===== + # Wait for lm_head gather first so we can copy to embed while other gathers complete + if lm_head_gather_future is not None: + lm_head_gather_future.wait() + + # When tied: copy lm_head.T to embed + if 
do_adam and not self.split_embed and embed_param is not None and lm_param is not None:
            # Tied phase: keep the embedding table equal to the (transposed) lm_head.
            embed_param.data.copy_(lm_param.data.T)

        # Wait for remaining gathers
        for fut in gather_futures:
            fut.wait()

        self._reduce_futures.clear()

        # Clear grads for updated params
        for param, p_cfg in self.param_cfgs.items():
            if p_cfg.optim == "adam" and not do_adam:
                continue # Don't clear Adam grads on even steps
            param.grad = None

    # -----------------------------------
    # Adam update

    def _adam_update(self, param: nn.Parameter, grad_chunk: Tensor, p_cfg: ParamConfig, rank: int) -> Tensor:
        """Apply Adam update to a parameter. Returns the updated p_slice.

        param: the full parameter tensor (sliced locally when comms == "sharded").
        grad_chunk: gradient for this rank's chunk of the parameter.
        p_cfg: per-parameter config (lr, betas, eps, weight decay, comms mode).
        rank: this process's rank, used to select the local shard.
        """
        beta1, beta2 = p_cfg.adam_betas
        lr = p_cfg.lr * p_cfg.lr_mul

        # Get parameter slice ("sharded" params are updated one chunk per rank)
        if p_cfg.comms == "sharded":
            p_slice = param[rank * p_cfg.chunk_size:(rank + 1) * p_cfg.chunk_size]
        else:
            p_slice = param

        p_state = self.param_states[param]
        p_state["step"] += 1
        t = p_state["step"]

        # Standard Adam bias correction folded into one scalar step size,
        # written into 0-D tensors so the compiled step below sees tensors, not Python floats.
        bias1, bias2 = 1 - beta1 ** t, 1 - beta2 ** t
        self._step_size_t.fill_(lr * (bias2 ** 0.5 / bias1))
        # NOTE(review): effective weight decay scales with lr**2 (lr * lr * wd) — looks
        # deliberate (decay tracks the squared step size), but worth confirming.
        self._eff_wd_t.fill_(lr * lr * p_cfg.weight_decay * p_cfg.wd_mul)

        NorMuonAndAdam._adam_update_step(
            p_slice, grad_chunk, p_state["exp_avg"], p_state["exp_avg_sq"],
            beta1, beta2, p_cfg.eps, self._step_size_t, self._eff_wd_t
        )

        return p_slice

    @staticmethod
    @torch.compile(dynamic=False, fullgraph=True)
    def _adam_update_step(p_slice, g_slice, exp_avg, exp_avg_sq, beta1, beta2, eps, step_size_t, eff_wd_t):
        """Compiled Adam update step."""
        # First/second moment EMAs, updated in place.
        exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
        update = exp_avg.div(exp_avg_sq.sqrt().add_(eps)).mul_(step_size_t)
        # Cautious weight decay: only decay entries where the update and the weight agree in sign.
        mask = (update * p_slice) > 0
        update.addcmul_(p_slice, mask, value=eff_wd_t)
        p_slice.add_(other=update, alpha=-1.0)

    # -----------------------------------
    # NorMuon update

    def _normuon_update(self, param: nn.Parameter, grad_chunk: Tensor, p_cfg: ParamConfig, rank: int) -> Tensor:
        """Apply NorMuon update to a parameter. Returns the updated p_slice."""
        chunk_shape = grad_chunk.shape

        p_state = self.param_states[param]
        grad_chunk = grad_chunk.float() # FP32 for momentum

        # Momentum update: buffer tracks the gradient EMA, then the gradient is
        # interpolated back toward the buffer (Nesterov-style lookahead).
        momentum_buffer = p_state["momentum_buffer"]
        momentum_buffer.lerp_(grad_chunk, 1 - p_cfg.momentum)
        updated_grads = grad_chunk.lerp_(momentum_buffer, p_cfg.momentum)

        self._eff_lr_t.fill_(p_cfg.lr_mul * p_cfg.lr)
        self._eff_wd_t.fill_(p_cfg.wd_mul * p_cfg.weight_decay * p_cfg.lr)

        # Polar Express orthogonalization
        is_large_matrix = chunk_shape[-2] > 1024
        v_chunk = polar_express(updated_grads, split_baddbmm=is_large_matrix)

        # Variance reduction: reduce along the longer matrix dimension's complement.
        red_dim = -1 if chunk_shape[-2] >= chunk_shape[-1] else -2
        v_chunk = NorMuonAndAdam._apply_normuon_variance_reduction(
            v_chunk, p_state["second_momentum_buffer"], p_cfg.beta2, red_dim
        )

        # Update parameter, in place, with cautious weight decay
        param_view = param.data.view(p_cfg.reshape)
        p_slice = param_view[rank * p_cfg.chunk_size:(rank + 1) * p_cfg.chunk_size]

        # MLP has per-matrix LR multipliers (c_proj gets 2x LR)
        if p_cfg.per_matrix_lr_mul is not None:
            for mat_idx in range(p_cfg.chunk_size):
                self._eff_lr_t.fill_(p_cfg.lr_mul * p_cfg.per_matrix_lr_mul[mat_idx] * p_cfg.lr)
                self._eff_wd_t.fill_(p_cfg.wd_mul * p_cfg.weight_decay * p_cfg.lr)
                NorMuonAndAdam._cautious_wd_and_update_inplace(
                    p_slice[mat_idx].view(torch.uint16), p_state["mantissa"][mat_idx], v_chunk[mat_idx],
                    self._eff_wd_t, self._eff_lr_t
                )
        else:
            NorMuonAndAdam._cautious_wd_and_update_inplace(
                p_slice.view(torch.uint16), p_state["mantissa"], v_chunk,
                self._eff_wd_t, self._eff_lr_t
            )

        return p_slice

    @staticmethod
    @torch.compile(dynamic=False, fullgraph=True)
    def _cautious_wd_and_update_inplace(p, mantissa, grad, wd_tensor, lr_tensor):
        """
        Cautious weight decay + parameter update. wd_tensor and lr_tensor are 0-D CPU tensors.
        Mantissa is tracked to enable higher precision updates on bfloat16 parameters.
        bfloat16 format: 1 sign bit + 8 exponent bits + 7 mantissa bits = 16 bits total
        float32 format: 1 sign bit + 8 exponent bits + 23 mantissa bits = 32 bits total
        """
        assert p.dtype == mantissa.dtype == torch.uint16
        grad = grad.float()
        wd_factor = wd_tensor.to(torch.float32)
        lr_factor = lr_tensor.to(torch.float32)
        # Reassemble a full float32 word: bf16 bits in the high 16, tracked mantissa in the low 16.
        p_precise_raw = (p.to(torch.uint32) << 16) | mantissa.to(torch.uint32)
        p_precise = p_precise_raw.view(torch.float32)
        mask = (grad * p_precise) >= 0
        # p_precise aliases p_precise_raw's storage, so this copy_ also updates the raw bits.
        p_precise.copy_(p_precise - (p_precise * mask * wd_factor * lr_factor) - (grad * lr_factor))
        # Split the updated float32 bits back into the bf16 weight and the extra mantissa.
        p.copy_((p_precise_raw >> 16).to(torch.uint16))
        mantissa.copy_(p_precise_raw.to(torch.uint16))

    @staticmethod
    @torch.compile(dynamic=False, fullgraph=True)
    def _apply_normuon_variance_reduction(v_chunk, second_momentum_buffer, beta2, red_dim):
        """NorMuon variance reduction. Algebraically fuses the normalization steps to minimize memory ops."""
        # Per-row (or per-column) mean square of the orthogonalized update.
        v_mean = v_chunk.float().square().mean(dim=red_dim, keepdim=True)
        red_dim_size = v_chunk.size(red_dim)
        v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True).mul_(red_dim_size)
        v_norm = v_norm_sq.sqrt_()
        second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
        step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt_()
        # Rescale so the overall Frobenius norm is preserved after the per-row step sizes.
        scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
        v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt_()
        final_scale = step_size * (v_norm / v_norm_new.clamp_min_(1e-10))
        return v_chunk.mul_(final_scale.type_as(v_chunk))

# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the model

def norm(x: Tensor):
    """RMS-normalize x over its last dimension (no learned scale)."""
    return F.rms_norm(x, (x.size(-1),))


class CastedLinearT(nn.Module):
    """
    Linear layer with transposed weight storage (in_features, out_features) which
    addresses the slow kernel that was used for gradient accumulation. @chrisjmccormick
    """
    def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.use_fp8 = use_fp8
        # Scales for the FP8 custom op: input, weight, and gradient scaling factors.
        self.x_s = x_s
        self.w_s = w_s
        self.grad_s = grad_s

        self.weight = nn.Parameter(torch.empty(in_features, out_features, dtype=torch.bfloat16))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        with torch.no_grad():
            nn.init.zeros_(self.weight) # @Grad62304977 and others

    def forward(self, x: Tensor):
        if self.use_fp8 and self.training:
            # FP8 training path: the custom op returns a 3-tuple; element 0 is the matmul output.
            _x = x.flatten(0, -2)
            out = torch.ops.nanogpt.mm_t(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0]
            return out.reshape(*x.shape[:-1], -1)
        else:
            # Eval / non-FP8 path: plain matmul against the (in, out)-stored weight.
            return x @ self.weight.type_as(x)

# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the model

class Yarn(nn.Module):
    """Rotary position embedding whose frequencies can be rescaled in place
    (see apply()) when the attention window is extended during training.
    Buffers cover 2*max_seq_len positions — presumably to accommodate the
    paired-head mode, which doubles the effective sequence length."""
    def __init__(self, head_dim, max_seq_len, paired=False):
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.paired = paired
        self.reset()

    def rotary(self, x_BTHD):
        # Applies the precomputed cos/sin factors to a (B, T, H, D) tensor.
        assert self.factor1.size(0) >= x_BTHD.size(-3)
        factor1, factor2 = (
            self.factor1[None, : x_BTHD.size(-3), None, :],
            self.factor2[None, : x_BTHD.size(-3), None, :],
        )
        # Swap each adjacent pair of channels (x1, x2) -> (x2, x1).
        x_flip = x_BTHD.view(*x_BTHD.shape[:-1], x_BTHD.shape[-1] // 2, 2).flip(-1).view(x_BTHD.shape)
        return factor1 * x_BTHD + factor2 * x_flip

    def reset(self):
        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device)
        angular_freq = angular_freq.repeat_interleave(2)
        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning):
        # the second half of the head dims gets zero frequency (stationary dims).
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//2)])
        t = torch.arange(2*self.max_seq_len, dtype=torch.float32, device=device)
        if not self.paired:
            theta = torch.outer(t, angular_freq)
            self.factor1 = nn.Buffer(
                theta.cos().to(torch.bfloat16), persistent=False
            )
            self.factor2 = nn.Buffer(
                theta.sin().to(torch.bfloat16), persistent=False
            )
        else:
            # Paired mode: even/odd interleaved positions get separate cos/sin tables,
            # concatenated along the feature dimension.
            t_even = 2 * t
            t_odd = 2 * t + 1
            theta1 = torch.outer(t_even, angular_freq)
            theta2 = torch.outer(t_odd, angular_freq)
            self.factor1 = nn.Buffer(
                torch.cat((theta1.cos(), theta2.cos()), dim=-1).to(torch.bfloat16),
                persistent=False
            )
            self.factor2 = nn.Buffer(
                torch.cat((theta1.sin(), theta2.sin()), dim=-1).to(torch.bfloat16),
                persistent=False
            )
            self.factor2[..., 1::2] *= -1
        self.angular_freq = angular_freq
        # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
        self.attn_scale = 0.1

    def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32):
        # YaRN-style frequency interpolation: low-frequency dims are rescaled by
        # old_window/new_window, high-frequency dims are left (mostly) unchanged.
        rotations = old_window * self.angular_freq / (2 * torch.pi)
        scaling_factor = old_window / new_window
        interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1)
        self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor)
        t = torch.arange(2*self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device)
        if not self.paired:
            theta = torch.outer(t, self.angular_freq)
            self.factor1.copy_(theta.cos())
            self.factor2.copy_(theta.sin())
        else:
            t_even = 2 * t
            t_odd = 2 * t + 1
            theta1 = torch.outer(t_even, self.angular_freq)
            theta2 = torch.outer(t_odd, self.angular_freq)
            self.factor1.copy_(torch.cat((theta1.cos(), theta2.cos()), dim=-1))
            self.factor2.copy_(torch.cat((theta1.sin(), theta2.sin()), dim=-1))
            self.factor2[..., 1::2] *= -1
        # Temperature adjustment on window extension.
        self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1

@dataclass
class AttnArgs:
    # Per-layer attention arguments assembled in GPT.forward().
    ve: torch.Tensor          # value embeddings for this layer (or None)
    sa_lambdas: torch.Tensor  # (2,) multipliers for the QKV and O projections
    seqlens: torch.Tensor     # cumulative sequence lengths for varlen attention
    bm_size: int              # sliding-window size in tokens
    yarn: Yarn                # rotary embedding module to use for this layer
    key_offset: bool          # shift keys forward on stationary dims (long windows)
    attn_gate_w: torch.Tensor # per-head output gate weights
    ve_gate_w: torch.Tensor   # per-head value-embedding gate weights (or None)

flash_attn_interface =
get_kernel('varunneal/flash-attention-3').flash_attn_interface

class CausalSelfAttention(nn.Module):
    def __init__(self, dim: int, head_dim: int, num_heads: int, paired: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dim = dim
        self.hdim = num_heads * head_dim
        self.paired = paired
        assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim"
        # Weights are stored in parameter banks and passed via forward()

    def forward(self, x: Tensor, attn_args: AttnArgs, qkvo_w: Tensor):
        """Varlen causal attention over packed sequences.

        x: (1, T, dim) packed token stream; qkvo_w: (4*dim, hdim) merged QKV+O weights.
        """
        B, T = x.size(0), x.size(1) # batch size, sequence length
        assert B == 1, "varlen sequences requires B == 1"
        assert T % 16 == 0
        # unpack attention args
        yarn = attn_args.yarn
        ve, sa_lambdas, key_offset = attn_args.ve, attn_args.sa_lambdas, attn_args.key_offset
        seqlens, bm_size = attn_args.seqlens, attn_args.bm_size
        # sparse gated attention to enable context based no-op by @classiclarryd
        # only include gates on layers with value embeds used on forward pass
        attn_gate_w, ve_gate_w = attn_args.attn_gate_w, attn_args.ve_gate_w

        # First 3*dim rows of qkvo_w are the merged QKV projection (O is the last dim rows).
        q, k, v = F.linear(x, sa_lambdas[0] * qkvo_w[:self.dim * 3].type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
        max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size))

        q, k = norm(q), norm(k) # QK norm @Grad62304977

        if not self.paired:
            q, k = yarn.rotary(q), yarn.rotary(k)

            if key_offset:
                # shift keys forward for the stationary head dims. Enables 1-layer induction.
                k[:, 1:, :, self.head_dim // 2:] = k[:, :-1, :, self.head_dim // 2:]

            if ve is not None:
                # Gate input is the first 12 channels of x; gate output in (0, 2).
                ve_gate_out = 2 * torch.sigmoid(F.linear(x[..., :12], ve_gate_w)).view(B, T, self.num_heads, 1)
                v = v + ve_gate_out * ve.view_as(v) # @ KoszarskyB & @Grad62304977

        else:
            # Paired heads: adjacent heads' queries attend to each other's keys.
            # Two copies of the input stream are interleaved to achieve this, which:
            # - doubles the length of each sequence
            # - halves the effective window size
            q = q.view(B, T, self.num_heads // 2, self.head_dim * 2)
            k = k.view(B, T, self.num_heads // 2, self.head_dim * 2)
            v = v.reshape(B, T * 2, self.num_heads // 2, self.head_dim)

            q, k = yarn.rotary(q), yarn.rotary(k)

            q = q.view(B, T * 2, self.num_heads // 2, self.head_dim)
            k = k.view(B, T * 2, self.num_heads // 2, self.head_dim)

            if ve is not None:
                ve_gate_out = 2 * torch.sigmoid(F.linear(x[..., :12], ve_gate_w)).view(B, T * 2, self.num_heads // 2, 1)
                v = v + ve_gate_out * ve.view_as(v)

            # Interleaving doubled every sequence, so boundaries and max length double too.
            seqlens = 2 * seqlens
            max_len = 2 * max_len

        # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng
        y = flash_attn_interface.flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens,
                                                        max_seqlen_q=max_len, max_seqlen_k=max_len,
                                                        causal=True, softmax_scale=yarn.attn_scale, window_size=(bm_size, 0))
        y = y.view(B, T, self.num_heads, self.head_dim)
        # Per-head output gate, driven by the first 12 channels of x.
        y = y * torch.sigmoid(F.linear(x[..., :12], attn_gate_w)).view(B, T, self.num_heads, 1)
        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
        y = F.linear(y, sa_lambdas[1] * qkvo_w[self.dim * 3:].type_as(y)) # sa_lambdas[1] pre-multiplied to O @shenberg
        return y

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        # Weights are stored in parameter banks and passed via forward()

    def forward(self, x: Tensor, c_fc: Tensor, c_proj: Tensor):
        # relu(x)^2:
        # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
        # Fused triton kernel for relu(x @ W1.T)^2 @ W2.T
        return FusedLinearReLUSquareFunction.apply(x, c_fc, c_proj)

class Block(nn.Module):
    def __init__(self, dim: int, head_dim: int, num_heads: int, has_attn: bool, has_mlp: bool, use_paired_head: bool):
        super().__init__()
        # skip attention of blocks.6 (the 7th layer) by @YouJiacheng
        self.attn = CausalSelfAttention(dim, head_dim, num_heads, paired=use_paired_head) if has_attn else None
        # skip MLP blocks for first MLP layer by @EmelyanenkoK
        self.mlp = MLP() if has_mlp else None

    def forward(self, x: Tensor, attn_args: AttnArgs, qkvo_w: Tensor = None, c_fc: Tensor = None, c_proj: Tensor = None):
        # Pre-norm residual block; attn/mlp weights are supplied by the caller from the banks.
        if self.attn is not None:
            x = x + self.attn(norm(x), attn_args, qkvo_w)
        if self.mlp is not None:
            x = x + self.mlp(norm(x), c_fc, c_proj)
        return x

# -----------------------------------------------------------------------------
# The main model

def next_multiple_of_n(v: float | int, *, n: int):
    """Return the smallest multiple of n that is >= v (and at least n)."""
    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)

@dataclass
class ForwardScheduleConfig:
    # Per-step forward-pass schedule: MTP loss weights and short/long window sizes.
    mtp_weights: torch.Tensor
    ws_short: int
    ws_long: int

class GPT(nn.Module):
    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int):
        super().__init__()
        self.num_layers = num_layers
        self.vocab_size = next_multiple_of_n(vocab_size, n=128)

        # Small sigmoid gates fed by the first 12 embedding channels (see forward()).
        self.smear_gate = nn.Linear(12, 1, bias=False)
        nn.init.zeros_(self.smear_gate.weight)
        self.smear_gate.weight.label = 'smear_gate'

        self.skip_gate = nn.Linear(12, 1, bias=False)
        nn.init.zeros_(self.skip_gate.weight)
        self.skip_gate.weight.label = 'skip_gate'

        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
        self.value_embeds = nn.Parameter(torch.zeros(5 * self.vocab_size, model_dim, dtype=torch.bfloat16))
        self.value_embeds.label = 'value_embed'

        # parameter banks for attention and value embedding gate weights
        self.attn_gate_bank = nn.Parameter(torch.zeros(10, num_heads, 12)) # 10 layers
        self.attn_gate_bank.label = 'attn_gate_bank'
        self.ve_gate_bank = nn.Parameter(torch.zeros(5, num_heads, 12)) # 5 unique gates
        self.ve_gate_bank.label = 've_gate_bank'

        # -----------------------------------
        # Parameter banks for sharded optimization, by @chrisjmccormick

        # Identify which layers have attention/MLP
        # Attention is skipped in layer 6 by @YouJiacheng
        self.attn_layer_indices = [i for i in range(num_layers) if i != 6]
        # All layers have MLP (At 11 layers--dropped first layer @EmelyanenkoK)
        self.mlp_layer_indices = list(range(num_layers))

        hdim = num_heads * head_dim
        mlp_hdim = 4 * model_dim

        # Create index mappings: layer_idx -> bank_idx
        self.layer_to_attn_idx = {layer_idx: bank_idx for bank_idx, layer_idx in enumerate(self.attn_layer_indices)}
        self.layer_to_mlp_idx = {layer_idx: bank_idx for bank_idx, layer_idx in enumerate(self.mlp_layer_indices)}

        # Attention bank: stores QKVO weights for all attention layers
        # merged QKVO weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
        # https://x.com/hi_tysam/status/1879699187107033311
        # Simplified layout by @chrisjmccormick
        # Shape: (num_attn_layers, 4*model_dim, hdim) = (10, 3072, 768)
        # Reshape for sharding: (40, 768, 768) for even distribution across 8 GPUs
        self.attn_bank = nn.Parameter(torch.empty(len(self.attn_layer_indices), 4 * model_dim, hdim))
        self.attn_bank.label = 'attn'
        self.attn_bank.reshape = (len(self.attn_layer_indices) * 4, hdim, hdim) # (40, 768, 768)

        # MLP bank: stores c_fc and c_proj for all MLP layers
        # Shape: (num_mlp_layers + padding, 2, mlp_hdim, model_dim) = (12, 2, 3072, 768)
        # We add 1 padding layer (index 11) to get 12*2=24 matrices for even distribution across 8 GPUs
        # Reshape for sharding: (24, 3072, 768)
        num_mlp_with_padding = len(self.mlp_layer_indices) + 1 # 11 + 1 = 12
        self.mlp_bank = nn.Parameter(torch.empty(num_mlp_with_padding, 2, mlp_hdim, model_dim))
        self.mlp_bank.label = 'mlp'
        self.mlp_bank.reshape = (num_mlp_with_padding * 2, mlp_hdim, model_dim) # (24, 3072, 768)

        # improved init scale by @YouJiacheng and @srashedll
        std = 0.5 * model_dim ** -0.5
        bound = (3 ** 0.5) * std
        with torch.no_grad():
            self.attn_bank.uniform_(-bound, bound)
            self.mlp_bank[:, 0, :, :].uniform_(-bound, bound) # c_fc
            self.mlp_bank[:, 1, :, :].zero_() # c_proj - zero init suggested by @Grad62304977

        # Create blocks with has_attn/has_mlp flags
        self.paired_head_layers = [0, 2, 5, 9]
        self.blocks = nn.ModuleList([
            Block(model_dim, head_dim, num_heads,
                  has_attn=(i in self.layer_to_attn_idx),
                  has_mlp=(i in self.layer_to_mlp_idx),
                  use_paired_head=(i in self.paired_head_layers))
            for i in range(num_layers)
        ])
        self.yarn = Yarn(head_dim, max_seq_len)
        self.yarn_paired_head = Yarn(head_dim, max_seq_len, paired=True)
        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
        use_fp8 = not os.environ.get("DISABLE_FP8", False)
        # Transposed weight storage for faster gradient accumulation
        self.lm_head = CastedLinearT(model_dim, self.vocab_size, use_fp8=use_fp8, x_s=100/448, w_s=1.6/448, grad_s=grad_scale * 0.75/448)

        nn.init.normal_(self.lm_head.weight, mean=0, std=0.005)
        self.lm_head.weight.label = 'lm_head'

        self.embed = nn.Embedding(self.vocab_size, model_dim)
        self.embed.weight.label = 'embed'
        with torch.no_grad():
            # Tied init: embedding starts as the transpose of lm_head.
            self.embed.weight.copy_(self.lm_head.weight.T)

        self.bigram_embed = nn.Embedding(args.bigram_vocab_size, model_dim)
        self.bigram_embed.weight.label = 'bigram_embed'
        nn.init.zeros_(self.bigram_embed.weight)

        # x0_lambdas separated out for different optimizer treatment (no beta smoothing)
        self.x0_lambdas = nn.Parameter(torch.zeros(num_layers))
        self.x0_lambdas.label = 'x0_lambdas'

        # NOTE(review): pad formula uses 3*num_layers + 3 while the concatenation below
        # holds 4*num_layers + 3 entries; scalars are "replicated" in the optimizer table,
        # so the padding presumably only matters for world-size alignment — confirm.
        pad = (-num_layers * 3 - 3) % dist.get_world_size() # updated: 3*num_layers instead of 4*
        self.scalars = nn.Parameter(
            torch.cat(
                [
                    1.1 * torch.ones(num_layers), # resid lambdas. 1.1 init such that layer i weight is i^(num_layers-i).
                    *[torch.tensor([0.5, 1.0]) for _ in range(num_layers)], # SA lambdas
                    0.1 * torch.ones(num_layers), # bigram lambdas
                    torch.zeros(1), # smear_lambda
                    0.5*torch.ones(1), # backout_lambda
                    -1.5 * torch.ones(1), # skip_lambda -> σ(-1.5) ≈ 0.18
                    torch.ones(pad),
                ]
            )
        )
        self.scalars.label = 'scalars'

    @staticmethod
    @torch.compile(dynamic=False, fullgraph=True)
    def _compute_bigram_hash(x: Tensor, mod: int) -> Tensor:
        """
        Computes bigram hash on GPU for each position using [prev_token, curr_token].
        Mathematically identical to the CPU version but computed on device.
+ """ + rand_int_1 = 36313 + rand_int_2 = 27191 + result = torch.empty_like(x) + result[0] = mod + result[1:] = torch.bitwise_xor(rand_int_1 * x[1:], rand_int_2 * x[:-1]) % mod + return result + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, schedule_cfg: ForwardScheduleConfig): + assert input_seq.ndim == 1 + + # unpack schedule_cfg + mtp_weights, ws_short, ws_long = schedule_cfg.mtp_weights, schedule_cfg.ws_short, schedule_cfg.ws_long + + # set configs + skip_connections = [] + skip_in = [3] # long attention window on layer 3 + skip_out = [6] # no attn op on layer 6 + x_backout = None + backout_layer = 7 + + # set lambdas + resid_lambdas = self.scalars[: 1 * self.num_layers] + x0_lambdas = self.x0_lambdas + sa_lambdas = self.scalars[1 * self.num_layers: 3 * self.num_layers].view(-1, 2) + bigram_lambdas = self.scalars[3 * self.num_layers: 4 * self.num_layers] + smear_lambda = self.scalars[4 * self.num_layers] + backout_lambda = self.scalars[4 * self.num_layers+1] + skip_lambda = self.scalars[4 * self.num_layers+2] + + # set block masks and key shift + bm_sizes = [ws_short, ws_short, ws_short, ws_long, ws_short, ws_short, None, ws_short, ws_short, ws_short, ws_long] + assert len(bm_sizes) == self.num_layers + key_offset = [b==ws_long for b in bm_sizes] # apply partial key offset to long windows + + # Embedding lookup - embed is synced from lm_head during tied phase by optimizer + x = self.embed(input_seq) + # Compute bigram hash on GPU (moved from CPU data loader) + bigram_seq = self._compute_bigram_hash(input_seq, args.bigram_vocab_size - 1) + x0_bigram = self.bigram_embed(bigram_seq)[None] + + # Value embeddings - always computed (not precomputed) + ve = self.value_embeds.view(5, self.vocab_size, -1)[:, input_seq] + # 01 ... 
234 structure on token value embeddings by @photomz + ve = [ve[0], ve[1]] + [None] * (self.num_layers - 5) + [ve[2], ve[3], ve[4]] + assert len(ve) == self.num_layers + + # smear token embed forward 1 position @classiclarryd + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # unbind gate banks to avoid select_backwards kernel + ag = [w.bfloat16() for w in self.attn_gate_bank.unbind(0)] + veg = [w.bfloat16() for w in self.ve_gate_bank.unbind(0)] + attn_gates = ag[:6] + [None] + ag[6:] + ve_gates = [veg[0], veg[1]] + [None] * (self.num_layers - 5) + [veg[2], veg[3], veg[4]] + assert len(attn_gates) == self.num_layers + assert len(ve_gates) == self.num_layers + + # unbind weight banks to avoid select_backwards kernel + attn_weights = self.attn_bank.unbind(0) # tuple of [4*dim, hdim] tensors + mlp_fcs = self.mlp_bank[:, 0, :, :].unbind(0) # tuple of [mlp_hdim, dim] tensors + mlp_projs = self.mlp_bank[:, 1, :, :].unbind(0) # tuple of [mlp_hdim, dim] tensors + + for i in range(self.num_layers): + yarn = self.yarn_paired_head if i in self.paired_head_layers else self.yarn + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + yarn=yarn, + key_offset=key_offset[i], + attn_gate_w=attn_gates[i], + ve_gate_w=ve_gates[i] + ) + if i in skip_out: + skip_gate_out = torch.sigmoid(skip_lambda) * 2 * torch.sigmoid(self.skip_gate(x0[..., :self.skip_gate.weight.size(-1)])) + x = x + skip_gate_out * skip_connections.pop() + if i == 0: + x = (resid_lambdas[0] + x0_lambdas[0]) * x + bigram_lambdas[0] * x0_bigram + else: + x = resid_lambdas[i] * x + x0_lambdas[i] * x0 + bigram_lambdas[i] * x0_bigram + + # Get weights for this layer from banks + qkvo_w = attn_weights[self.layer_to_attn_idx[i]] if i in self.layer_to_attn_idx else None + c_fc = mlp_fcs[self.layer_to_mlp_idx[i]] if i in 
self.layer_to_mlp_idx else None + c_proj = mlp_projs[self.layer_to_mlp_idx[i]] if i in self.layer_to_mlp_idx else None + + x = self.blocks[i](x, attn_args, qkvo_w, c_fc, c_proj) + if i in skip_in: + skip_connections.append(x) + if i == backout_layer: + x_backout = x + + # back out contributions from first 7 layers that are only required for downstream context and not direct prediction + x -= backout_lambda * x_backout + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15 + # @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1). @classiclarryd updated to 23*sigmoid((logits+5)/7.5) + if self.training: + losses = FusedSoftcappedCrossEntropy.apply(logits.view(-1, logits.size(-1)), target_seq, mtp_weights, 23.0, 5.0, 7.5) + loss = losses.sum() + else: + logits = 23 * torch.sigmoid((logits + 5) / 7.5) + logits_for_loss = logits.float() + loss = F.cross_entropy(logits_for_loss.view(-1, logits_for_loss.size(-1)), target_seq, reduction="mean") + return loss +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class Shard: + def __init__(self, tokens: Tensor, world_size: int = 1): + self.tokens = tokens + self.size = tokens.numel() + self.world_size = 
world_size
        self.i = 0

        # Partial index now, full index async: index only the first 6M tokens
        # synchronously so training can start, then scan the rest on a thread.
        self.bos_idx = (tokens[:6_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
        self._full_idx = None
        self._loader_thread = None
        self._ready = threading.Event()
        self._loader_thread = threading.Thread(target=self._scan)
        self._loader_thread.start()

    def _scan(self):
        # Background scan of the whole shard for BOS positions.
        self._full_idx = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
        self._ready.set()

    def _maybe_switch(self):
        # Switch to full index as soon as async scan completes
        if self.bos_idx is not self._full_idx and self._ready.is_set():
            self._loader_thread.join()
            self.bos_idx = self._full_idx

    def next_batch(self, num_tokens_local: int, max_seq_len: int):
        """Return per-rank (starts, ends) document spans covering exactly
        num_tokens_local + 1 tokens per rank. Raises StopIteration when the
        shard tail has too few BOS tokens left (the generator catches this)."""
        self._maybe_switch()
        n = len(self.bos_idx)
        starts = [[] for _ in range(self.world_size)]
        ends = [[] for _ in range(self.world_size)]

        idx = self.i
        for r in range(self.world_size):
            cur_len = 0
            while cur_len <= num_tokens_local:
                if idx >= n:
                    raise StopIteration(f"Insufficient BOS ahead; hit tail of shard.")
                cur = self.bos_idx[idx]
                starts[r].append(cur)
                # Truncate at the next BOS (or shard end), max_seq_len, or the batch budget.
                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
                          cur + max_seq_len,
                          cur + num_tokens_local - cur_len + 1)
                ends[r].append(end)
                cur_len += end - cur
                idx += 1

            assert cur_len == num_tokens_local + 1
        self.i = idx
        return starts, ends

    @staticmethod
    def load_async(file: Path, world_size: int = 1):
        """Returns getter function for async shard loading"""
        result = {}
        ready = threading.Event()
        def load():
            tokens = _load_data_shard(file)
            result['shard'] = Shard(tokens, world_size)
            ready.set()
        thread = threading.Thread(target=load)
        thread.start()
        def get():
            ready.wait()
            thread.join()
            return result['shard']
        return get

def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
    # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size"
    num_tokens = num_tokens // grad_accum_steps

    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")

    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
    tokens = _load_data_shard(next(file_iter))
    if align_to_bos:
        shard = Shard(tokens, world_size)
        next_shard_getter = Shard.load_async(next(file_iter), world_size)
    else:
        pos = 0 # for unaligned case

    while True:
        num_tokens_local = num_tokens // world_size
        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400

        if align_to_bos:
            try:
                seq_starts, seq_ends = shard.next_batch(num_tokens_local, max_seq_len)
                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
            except StopIteration:
                # This shard is exhausted, load the next one in the next loop iteration.
                # (StopIteration from next_batch is used deliberately as control flow here.)
                shard = next_shard_getter()
                tokens = shard.tokens
                try:
                    next_shard_getter = Shard.load_async(next(file_iter), world_size)
                except StopIteration:
                    next_shard_getter = None # no more shards to preload
                continue

            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
            _inputs = buf[:-1]
            _targets = buf[1:]
            end_idxs[-1] -= 1 # last document was too long to account for _targets offset
            cum_lengths = (end_idxs - start_idxs).cumsum(0)

        else:
            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
                tokens, pos = _load_data_shard(next(file_iter)), 0

            pos_local = pos + rank * num_tokens_local
            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
            _inputs = buf[:-1].view(num_tokens_local, )
            _targets = buf[1:].view(num_tokens_local, )

            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
            pos += num_tokens


        # Pad document boundaries to a fixed-size tensor (stable shape for CUDA transfer).
        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
        _cum_lengths[0] = 0
        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths

        # Cast to int32 on CPU before transfer to avoid dtype conversion during .to()
        _inputs = _inputs.to(dtype=torch.int32)
        _targets = _targets.to(dtype=torch.int64)
        _cum_lengths = _cum_lengths.to(dtype=torch.int32)
        # Bigram hash computation moved to GPU in forward()

        new_params = yield (
            _inputs.to(device="cuda", non_blocking=True),
            _targets.to(device="cuda", non_blocking=True),
            _cum_lengths.to(device="cuda", non_blocking=True),
        )

        if new_params is not None:
            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "Num tokens must be divisible by world size"
            num_tokens = new_num_tokens // new_grad_accum_steps
            max_seq_len = new_max_seq_len

# -----------------------------------------------------------------------------
# Training Management

@dataclass
+class Hyperparameters: + # data + data_path = os.environ.get("DATA_PATH", ".") + train_files: str = os.path.join(data_path, "data/fineweb10B/fineweb_train_*.bin") # input .bin to train on + val_files: str = os.path.join(data_path, "data/fineweb10B/fineweb_val_*.bin") # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + # batch sizes + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # schedule + num_scheduled_iterations: int = 1515 # number of steps to complete lr and ws schedule + num_extension_iterations: int = 40 # number of steps to continue training at final lr and ws + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 250 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # bigram hash embedding + bigram_vocab_size: int = 50304 * 5 + +args = Hyperparameters() + +@dataclass +class TrainingStage: + lr_mul: float + batch_size: int + window_sizes: tuple[int, int] # (short, long) in block units + mtp_weights_start: list[float] + mtp_weights_end: list[float] + duration: float = None + +class TrainingSchedule: + """ + Training schedule initialized via TRAINING_STAGES + 1. Multi Token Prediction schedule of [1, 0.5, 0.25->0] -> [1, 0.5->0] -> [1] @varunneal + 2. Sliding Attention window schedule of [1,3] -> [3,7] -> [5,11] -> [6,13] + 3. YaRN updates to RoPE on window changes + 4. Split embed and lm head at 2/3 of training + 5. Batch size schedule of 8 -> 16 -> 24 + 6. 
Post training extension of long windows from 13 to 20 + """ + + def __init__(self, stages: list[TrainingStage], scheduled_iterations: int, extension_iterations: int, + cooldown_frac: float = 0.5, split_embed_stage: int = 2, ws_post_yarn_ext: int = 20): + self.stages = stages + self.scheduled_iterations = scheduled_iterations + self.cooldown_frac = cooldown_frac + # increase final validation ws, used for YaRN extension and short window size @classiclarryd + self.ws_post_yarn_ext = ws_post_yarn_ext + + self.total_steps = self.scheduled_iterations + extension_iterations + + # Build stage boundaries (last is extension stage) + ends = [0] + [round(c * scheduled_iterations) for c in accumulate(s.duration for s in stages[:-1])] + [self.total_steps] + assert self.scheduled_iterations == ends[-2] + self.boundaries = list(pairwise(ends)) + + # Split embed at specified stage (ensure odd step for Adam) + self.split_step = self.boundaries[split_embed_stage][0] | 1 + + # Precompute MTP weights for all steps + self.mtp_weights = [] + for step in range(self.total_steps + 1): + stage, t = self.lookup(step) + w = [a + (b - a) * t for a, b in zip(stage.mtp_weights_start, stage.mtp_weights_end)] + self.mtp_weights.append(torch.tensor(w, device=device)) + + def lookup(self, step: int) -> tuple[TrainingStage, float]: + # Returns stage and % of the way through that stage + for i, (start, end) in enumerate(self.boundaries): + if step < end: + t = (step - start) / (end - start) + return self.stages[i], t + return self.stages[-1], 1.0 + + def get_lr(self, step: int) -> float: + # learning rate schedule: tied to batch size schedule, with cooldown at the end + stage, _ = self.lookup(step) + lr = stage.lr_mul + cd_start = int(self.scheduled_iterations * (1 - self.cooldown_frac)) + if step >= cd_start: + t = min(1.0, (step - cd_start) / (self.scheduled_iterations - cd_start)) + lr = lr * (1 - t) + 0.1 * t + return lr + +# window_sizes are in units of `block_size` tokens (defined in 
TrainingManager) +TRAINING_STAGES = [ + TrainingStage(duration=1/3, batch_size=8 * 2048 * 8, window_sizes=(1, 3), lr_mul=1.0, + mtp_weights_start=[1.0, 0.5, 0.25], mtp_weights_end=[1.0, 0.5, 0.0]), + TrainingStage(duration=1/3, batch_size=16 * 2048 * 8, window_sizes=(3, 7), lr_mul=1.52, # (16/8)**0.6 + mtp_weights_start=[1.0, 0.5], mtp_weights_end=[1.0, 0.0]), + TrainingStage(duration=1/3, batch_size=24 * 2048 * 8, window_sizes=(5, 11), lr_mul=1.73, # (24/8)**0.5 + mtp_weights_start=[1.0], mtp_weights_end=[1.0]), + # extension stage + TrainingStage(batch_size=24 * 2048 * 8, window_sizes=(6, 13), lr_mul=1.0, # lr_mul is not used + mtp_weights_start=[1.0], mtp_weights_end=[1.0]), +] + +training_schedule = TrainingSchedule(TRAINING_STAGES, args.num_scheduled_iterations, args.num_extension_iterations, cooldown_frac=0.55) + +def get_muon_momentum(step: int, muon_warmup_steps=300, muon_cooldown_steps=50, momentum_min=0.85, momentum_max=0.95): + # warmup phase: linearly increase momentum from min to max + # cooldown phase: linearly decrease momentum from max to min + momentum_cd_start = training_schedule.total_steps - muon_cooldown_steps + if step < muon_warmup_steps: + frac = step / muon_warmup_steps + momentum = momentum_min + frac * (momentum_max - momentum_min) + elif step > momentum_cd_start: + frac = (step - momentum_cd_start) / muon_cooldown_steps + momentum = momentum_max - frac * (momentum_max - momentum_min) + else: + momentum = momentum_max + return momentum + +class TrainingManager(): + """ + Manages the NorMuonAndAdam for all parameters with explicit ordering. + 1. Scalars are given higher momentum terms to smooth learning @ChrisJMcCormick + 2. Adam optimizers are only stepped on odd steps @classiclarryd + 3. Explicit scatter_order and work_order for communication scheduling (no backward hooks) + 4. Muon has a linear momentum warmup and cooldown schedule + 5. Learning rates follow a linear decay schedule + 6. 
Embed is tied to lm_head until split step (2/3 of training), then untied @classiclarryd + """ + def __init__(self, model): + self.model = model + self.block_size = 128 + + # - Ordering dictates when to launch reduce/reduce_scatter operations + # - "sharded" parameters use reduce_scatter/all_gather and "replicated" ones use all_reduce + # - lr_mul and wd_mul are per-parameter learning rate and weight decay multipliers + self.param_table = { + "attn": {"optim": "normuon", "comms": "sharded", "adam_betas": None}, + "mlp": {"optim": "normuon", "comms": "sharded", "adam_betas": None}, + "scalars": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 5.0, "wd_mul": 0.0}, + "value_embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.75, 0.95], "lr_mul": 75., "wd_mul": 5.0}, + "bigram_embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.75, 0.95], "lr_mul": 75., "wd_mul": 5.0}, + "smear_gate": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 0.01, "wd_mul": 0.0}, + "skip_gate": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 0.05, "wd_mul": 0.0}, + "attn_gate_bank": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99]}, + "ve_gate_bank": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99]}, + "x0_lambdas": {"optim": "adam", "comms": "replicated", "adam_betas": [0.65, 0.95], "lr_mul": 5.0, "wd_mul": 0.0}, + "lm_head": {"optim": "adam", "comms": "sharded", "adam_betas": [0.5, 0.95], "wd_mul": 150.}, + "embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.5, 0.95], "wd_mul": 150.}, + } + + # - Process smaller/faster params first while large reduces complete + # - lm_head must complete before embed sync (when tied) + self.work_order = [ + "scalars", "smear_gate", "skip_gate", "attn_gate_bank", "ve_gate_bank", "x0_lambdas", # Small, fast + "value_embed", "bigram_embed", # Medium + "lm_head", "embed", # lm_head must complete before embed 
sync (when tied) + "attn", "mlp", # Large, polar express - process last to maximize overlap + ] + + adam_defaults = dict( + lr=0.008, + eps=1e-10, + weight_decay=0.005, + ) + + normuon_defaults = dict( + lr=0.023, + momentum=0.95, + beta2=0.95, + weight_decay=1.2, + ) + + self.optimizer = NorMuonAndAdam( + model.named_parameters(), + param_table=self.param_table, + scatter_order=list(self.param_table.keys()), # Dict order defines scatter priority + work_order=self.work_order, + adam_defaults=adam_defaults, + normuon_defaults=normuon_defaults, + ) + + # Split embed from lm_head at 2/3 of training (on an odd step so Adam updates) + self.split_step = training_schedule.split_step + + self.reset() + + def apply_final_ws_ext(self): + self.ws_long = training_schedule.ws_post_yarn_ext + + def get_forward_args(self): + return ForwardScheduleConfig( + mtp_weights = self.mtp_weights, + ws_short = self.ws_short * self.block_size, + ws_long = self.ws_long * self.block_size + ) + + def _is_adam_step(self, step: int): + """Adam params are only updated on odd steps.""" + return step % 2 == 1 + + def get_transition_steps(self): + return [start for start, _ in training_schedule.boundaries[1:]] + + def advance_schedule(self, step: int): + stage, _ = training_schedule.lookup(step) + self.ws_short, new_ws_long = stage.window_sizes + if new_ws_long != self.ws_long: + self.model.yarn.apply(self.ws_long * self.block_size, new_ws_long * self.block_size) + self.model.yarn_paired_head.apply(self.ws_long * self.block_size, new_ws_long * self.block_size) + + new_batch_size = stage.batch_size + if new_batch_size != self.batch_size: + self.train_loader_send_args = (new_batch_size, args.train_max_seq_len, grad_accum_steps) + self.batch_size = new_batch_size + else: + self.train_loader_send_args = None + + self.ws_long = new_ws_long + self.mtp_weights = training_schedule.mtp_weights[step] + + def step_optimizers(self, step: int): + step_lr = training_schedule.get_lr(step) + muon_momentum = 
get_muon_momentum(step) + do_adam = self._is_adam_step(step) + + # Update learning rates and momentum for all params + for param, p_cfg in self.optimizer.param_cfgs.items(): + p_cfg.lr = p_cfg.initial_lr * step_lr + if p_cfg.optim == "normuon": + p_cfg.momentum = muon_momentum + + # Step optimizer with do_adam flag + self.optimizer.step(do_adam=do_adam) + + # At split step: copy lm_head optimizer state to embed and mark as split + if step == self.split_step: + self.optimizer.copy_lm_state_to_embed() + + def reset(self, state=None): + if state is not None: + self.optimizer.load_state_dict(state) + + # Reset NorMuon momentum buffers and split_embed state + self.optimizer.reset() + + stage, _ = training_schedule.lookup(0) + self.ws_short, self.ws_long = stage.window_sizes + self.batch_size = stage.batch_size + self.model.yarn.reset() + self.model.yarn_paired_head.reset() + + def get_state(self): + return copy.deepcopy(self.optimizer.state_dict()) + +# ----------------------------------------------------------------------------- +# int main + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=11, + num_heads=6, + 
head_dim=128, + model_dim=768, + max_seq_len=args.val_batch_size // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.weight.data = m.weight.data.bfloat16() +model.attn_gate_bank.data = model.attn_gate_bank.data.bfloat16() +model.ve_gate_bank.data = model.ve_gate_bank.data.bfloat16() +model.attn_bank.data = model.attn_bank.data.bfloat16() +model.mlp_bank.data = model.mlp_bank.data.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) +training_manager = TrainingManager(model) + +######################################## +# Warmup kernels # +######################################## +print0("Compiling model and warming up kernels (~7 minutes on first execution)", console=True) +# Warmup the training kernels, then re-initialize the state so we aren't cheating +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizer=training_manager.get_state()) # save the initial state +train_loader = distributed_data_generator(args.train_files, TRAINING_STAGES[0].batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + +transition_steps = training_manager.get_transition_steps() +# first few steps plus transitions +warmup_steps = sorted({0, 1, 2} | set(s + offset for s in transition_steps for offset in [-1, 0, 1] if s + offset >= 0)) +print0(f"Sampling steps {warmup_steps} for warmup", console=True) +for step in warmup_steps: + training_manager.advance_schedule(step) + model.eval() + with torch.no_grad(): + inputs, targets, cum_seqlens = next(val_loader) + model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) + model.train() + for idx in range(grad_accum_steps): + send_args = training_manager.train_loader_send_args + inputs, 
targets, cum_seqlens = train_loader.send(send_args) + (model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) * grad_scale).backward() + training_manager.step_optimizers(step) +print0("Resetting Model", console=True) +model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +training_manager.reset(initial_state["optimizer"]) +del val_loader, train_loader, initial_state +model.train() + +######################################## +# Training and validation # +######################################## +train_loader = distributed_data_generator(args.train_files, TRAINING_STAGES[0].batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) + +gc.collect() + +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = training_schedule.total_steps +for step in range(train_steps + 1): + last_step = (step == train_steps) + training_manager.advance_schedule(step) + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + if last_step: + training_manager.apply_final_ws_ext() + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) + val_loss /= val_steps + del val_loader + dist.reduce(val_loss, 0, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", 
console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizer=training_manager.get_state()) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for idx in range(grad_accum_steps): + inputs, targets, cum_seqlens = train_loader.send(training_manager.train_loader_send_args) + (model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) * grad_scale).backward() + training_manager.step_optimizers(step) + + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + + +---------------------------------------- +# triton_kernels.py +---------------------------------------- + +import torch +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * 
num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.jit +def XXT_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = 
m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def XXT(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + # Hardcoded configs based on H100 autotuning + if K == 768: + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 64 + num_stages, num_warps = 4, 4 + else: + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 64, 128, 128 + num_stages, num_warps = 4, 4 + + grid = (batch_size * triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(M, BLOCK_SIZE_N),) + XXT_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + LOWER_UPPER=1, + num_stages=num_stages, + num_warps=num_warps, + ) + return out + +@triton.jit +def ba_plus_cAA_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: 
tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from XXT_kernel, but also loads and adds a block of A + # Performance is slightly slower than XXT_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = 
C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ba_plus_cAA(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + # Hardcoded config based on H100 autotuning (M=768) + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 64 + num_stages, num_warps = 4, 4 + + grid = (batch_size * triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(M, BLOCK_SIZE_N),) + ba_plus_cAA_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + LOWER_UPPER=1, + num_stages=num_stages, + num_warps=num_warps, + ) + return out + +# ----------------------------------------------------------------------------- +# Triton kernel for MLP: relu(x @ W1.T)^2, by @andrewbriand, @jrauvola + +@triton.jit +def linear_relu_square_kernel(a_desc, b_desc, c_desc, 
aux_desc, + M, N, K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, + ): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tile_id_c = start_pid - NUM_SMS + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + + tile_id_c += NUM_SMS + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am_c = pid_m * BLOCK_SIZE_M + offs_bn_c = pid_n * BLOCK_SIZE_N + + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + + c0 = acc0.to(dtype) + if not FORWARD: + c0_pre = aux_desc.load([offs_am_c, offs_bn_c]) + c0 = 2 * c0 * tl.where(c0_pre > 0, c0_pre, 0) + + c_desc.store([offs_am_c, offs_bn_c], c0) + + if FORWARD: + c0_post = tl.maximum(c0, 0) + c0_post = c0_post * c0_post + aux_desc.store([offs_am_c, offs_bn_c], c0_post) + + c1 = acc1.to(dtype) + if not FORWARD: + c1_pre = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + c1 = 2 * c1 * tl.where(c1_pre > 0, c1_pre, 0) + + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + + if FORWARD: + c1_post = tl.maximum(c1, 0) + c1_post = c1_post * c1_post + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1_post) + + +def linear_relu_square(a, b, aux=None): + M, K = a.shape + N, K = b.shape + dtype = 
a.dtype + + c = torch.empty((M, N), device=a.device, dtype=dtype) + + FORWARD = False + if aux is None: + FORWARD = True + aux = torch.empty((M, N), device=a.device, dtype=dtype) + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 256 + BLOCK_SIZE_K = 64 + num_stages = 4 if FORWARD else 3 + num_warps = 8 + + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + + def grid(META): + return (min( + NUM_SMS, + triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), + ), ) + + linear_relu_square_kernel[grid]( + a_desc, b_desc, c_desc, aux_desc, + M, N, K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=1, + NUM_SMS=NUM_SMS, + FORWARD=FORWARD, + num_stages=num_stages, + num_warps=num_warps + ) + + if FORWARD: + return c, aux + else: + return c + +class FusedLinearReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, W1, W2): + pre, post = linear_relu_square(x.view((-1, x.shape[-1])), W1) + x3 = post @ W2 + ctx.save_for_backward(x, W1, W2, pre, post) + return x3.view(x.shape) + + @staticmethod + def backward(ctx, grad_output): + x, W1, W2, pre, post = ctx.saved_tensors + dW2 = post.T @ grad_output + dpre = linear_relu_square(grad_output.view((-1, grad_output.shape[-1])), W2, aux=pre) + dW1 = dpre.T @ x + dx = dpre @ W1 + return dx.view(x.shape), dW1, dW2 + +# ----------------------------------------------------------------------------- +# Fused Softcapped Cross Entropy + + +@triton.jit +def fused_softcapped_entropy_fwd_kernel( + logits_ptr, losses_ptr, lse_ptr, targets_ptr, mtp_weights_ptr, + stride_logits_n, stride_logits_v, + n_rows, n_cols, n_predict, + A, B, C, + 
BLOCK_SIZE: tl.constexpr +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + + max_val = -float('inf') + sum_exp = 0.0 + + for off in range(0, n_cols, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + val = tl.load(logits_row_ptr + cols, mask=mask, other=-float('inf')).to(tl.float32) + z = A * tl.sigmoid((val + B) / C) + z = tl.where(mask, z, -float('inf')) + curr_max = tl.max(z, axis=0) + new_max = tl.maximum(max_val, curr_max) + sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(z - new_max), axis=0) + max_val = new_max + + lse = max_val + tl.log(sum_exp) + tl.store(lse_ptr + row_idx, lse) + + total_loss = 0.0 + for k in range(n_predict): + target_idx = row_idx + k + if target_idx < n_rows: + weight = tl.load(mtp_weights_ptr + k) + if weight > 0: + target = tl.load(targets_ptr + target_idx).to(tl.int32) + if target >= 0 and target < n_cols: + val_target = tl.load(logits_row_ptr + target).to(tl.float32) + z_target = A * tl.sigmoid((val_target + B) / C) + total_loss += weight * (lse - z_target) + + tl.store(losses_ptr + row_idx, total_loss) + +@triton.jit +def fused_softcapped_entropy_bwd_kernel( + grad_input_ptr, grad_output_ptr, lse_ptr, logits_ptr, targets_ptr, mtp_weights_ptr, + stride_logits_n, stride_logits_v, stride_grad_n, stride_grad_v, + n_rows, n_cols, n_predict, + A, B, C, + BLOCK_SIZE: tl.constexpr +): + row_idx = tl.program_id(0).to(tl.int64) + + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + grad_row_ptr = grad_input_ptr + row_idx * stride_grad_n + + lse = tl.load(lse_ptr + row_idx) + grad_loss = tl.load(grad_output_ptr + row_idx) + + S_w = 0.0 + for k in range(n_predict): + if row_idx + k < n_rows: + S_w += tl.load(mtp_weights_ptr + k) + + for off in range(0, n_cols, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + val = tl.load(logits_row_ptr + cols, mask=mask, other=0.0).to(tl.float32) + u = (val + B) / C + sigmoid_u 
= tl.sigmoid(u) + z = A * sigmoid_u + p = tl.exp(z - lse) + + term1 = S_w * p + term2 = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for k in range(n_predict): + if row_idx + k < n_rows: + target = tl.load(targets_ptr + row_idx + k).to(tl.int32) + weight = tl.load(mtp_weights_ptr + k) + term2 += tl.where(cols == target, weight, 0.0) + + grad_z = grad_loss * (term1 - term2) + dz_dx = (1.0 / C) * z * (1.0 - sigmoid_u) + grad_x = grad_z * dz_dx + tl.store(grad_row_ptr + cols, grad_x.to(tl.bfloat16), mask=mask) + +class FusedSoftcappedCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, logits, targets, mtp_weights, A=23.0, B=5.0, C=7.5): + n_rows, n_cols = logits.shape + if mtp_weights is None: + mtp_weights = torch.tensor([1.0], device=logits.device, dtype=torch.float32) + n_predict = mtp_weights.shape[0] + + losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + lse = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + + logits = logits.contiguous() + targets = targets.contiguous() + mtp_weights = mtp_weights.contiguous() + + grid = (n_rows,) + fused_softcapped_entropy_fwd_kernel[grid]( + logits, losses, lse, targets, mtp_weights, + logits.stride(0), logits.stride(1), + n_rows, n_cols, n_predict, + A, B, C, + BLOCK_SIZE=1024, + num_warps=8, + num_stages=4 + ) + + ctx.save_for_backward(logits, targets, mtp_weights, lse) + ctx.params = (A, B, C) + return losses + + @staticmethod + def backward(ctx, grad_output): + logits, targets, mtp_weights, lse = ctx.saved_tensors + A, B, C = ctx.params + n_rows, n_cols = logits.shape + n_predict = mtp_weights.shape[0] + + grad_input = torch.empty((n_rows, n_cols), dtype=torch.bfloat16, device=logits.device) + grad_output = grad_output.contiguous() + + grid = (n_rows,) + fused_softcapped_entropy_bwd_kernel[grid]( + grad_input, grad_output, lse, logits, targets, mtp_weights, + logits.stride(0), logits.stride(1), grad_input.stride(0), grad_input.stride(1), + n_rows, n_cols, 
n_predict, + A, B, C, + BLOCK_SIZE=1024, + num_warps=8, + num_stages=4 + ) + return grad_input, None, None, None, None, None + +==================================================================================================== +Running Python 3.12.7 (main, Jan 31 2026, 04:21:49) [GCC 13.2.0] +Running PyTorch 2.10.0.dev20251210+cu126 compiled for CUDA 12.6 +Running Triton version 3.6.0 +Sun Feb 1 06:10:04 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:71:00.0 Off | 0 | +| N/A 39C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:79:00.0 Off | 0 | +| N/A 34C P0 126W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:7F:00.0 Off | 0 | +| N/A 33C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | +| N/A 39C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 37C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 18107 C /usr/local/bin/python 1510MiB | +| 1 N/A N/A 18108 C /usr/local/bin/python 1510MiB | +| 2 N/A N/A 18109 C /usr/local/bin/python 1510MiB | +| 3 N/A N/A 18110 C /usr/local/bin/python 1510MiB | +| 4 N/A N/A 18111 C /usr/local/bin/python 1510MiB | +| 5 N/A N/A 18112 C /usr/local/bin/python 1510MiB | +| 6 N/A N/A 18113 C /usr/local/bin/python 1510MiB | +| 7 N/A N/A 18114 C /usr/local/bin/python 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +Compiling model and warming up kernels (~7 minutes on first execution) +Sampling steps [0, 1, 2, 504, 505, 506, 1009, 1010, 1011, 1514, 1515, 1516] for warmup +Resetting Model +step:0/1555 val_loss:10.8306 train_time:0ms step_avg:0.03ms +step:1/1555 train_time:74ms step_avg:74.30ms +step:2/1555 train_time:95ms 
step_avg:47.73ms +step:3/1555 train_time:119ms step_avg:39.57ms +step:4/1555 train_time:156ms step_avg:39.10ms +step:5/1555 train_time:187ms step_avg:37.37ms +step:6/1555 train_time:224ms step_avg:37.38ms +step:7/1555 train_time:255ms step_avg:36.43ms +step:8/1555 train_time:293ms step_avg:36.61ms +step:9/1555 train_time:324ms step_avg:36.01ms +step:10/1555 train_time:361ms step_avg:36.13ms +step:11/1555 train_time:392ms step_avg:35.67ms +step:12/1555 train_time:430ms step_avg:35.85ms +step:13/1555 train_time:461ms step_avg:35.45ms +step:14/1555 train_time:498ms step_avg:35.59ms +step:15/1555 train_time:529ms step_avg:35.29ms +step:16/1555 train_time:568ms step_avg:35.47ms +step:17/1555 train_time:599ms step_avg:35.24ms +step:18/1555 train_time:637ms step_avg:35.36ms +step:19/1555 train_time:667ms step_avg:35.13ms +step:20/1555 train_time:705ms step_avg:35.23ms +step:21/1555 train_time:736ms step_avg:35.05ms +step:22/1555 train_time:774ms step_avg:35.17ms +step:23/1555 train_time:805ms step_avg:34.99ms +step:24/1555 train_time:842ms step_avg:35.08ms +step:25/1555 train_time:873ms step_avg:34.93ms +step:26/1555 train_time:911ms step_avg:35.04ms +step:27/1555 train_time:942ms step_avg:34.89ms +step:28/1555 train_time:980ms step_avg:34.99ms +step:29/1555 train_time:1011ms step_avg:34.85ms +step:30/1555 train_time:1048ms step_avg:34.93ms +step:31/1555 train_time:1079ms step_avg:34.82ms +step:32/1555 train_time:1117ms step_avg:34.90ms +step:33/1555 train_time:1148ms step_avg:34.80ms +step:34/1555 train_time:1186ms step_avg:34.88ms +step:35/1555 train_time:1217ms step_avg:34.78ms +step:36/1555 train_time:1255ms step_avg:34.86ms +step:37/1555 train_time:1286ms step_avg:34.75ms +step:38/1555 train_time:1323ms step_avg:34.82ms +step:39/1555 train_time:1355ms step_avg:34.73ms +step:40/1555 train_time:1392ms step_avg:34.80ms +step:41/1555 train_time:1423ms step_avg:34.72ms +step:42/1555 train_time:1461ms step_avg:34.78ms +step:43/1555 train_time:1492ms step_avg:34.70ms 
+step:44/1555 train_time:1530ms step_avg:34.77ms +step:45/1555 train_time:1562ms step_avg:34.70ms +step:46/1555 train_time:1599ms step_avg:34.76ms +step:47/1555 train_time:1630ms step_avg:34.68ms +step:48/1555 train_time:1669ms step_avg:34.76ms +step:49/1555 train_time:1699ms step_avg:34.68ms +step:50/1555 train_time:1737ms step_avg:34.75ms +step:51/1555 train_time:1769ms step_avg:34.68ms +step:52/1555 train_time:1807ms step_avg:34.74ms +step:53/1555 train_time:1838ms step_avg:34.67ms +step:54/1555 train_time:1875ms step_avg:34.73ms +step:55/1555 train_time:1906ms step_avg:34.66ms +step:56/1555 train_time:1944ms step_avg:34.71ms +step:57/1555 train_time:1975ms step_avg:34.65ms +step:58/1555 train_time:2013ms step_avg:34.71ms +step:59/1555 train_time:2045ms step_avg:34.65ms +step:60/1555 train_time:2082ms step_avg:34.70ms +step:61/1555 train_time:2113ms step_avg:34.64ms +step:62/1555 train_time:2151ms step_avg:34.70ms +step:63/1555 train_time:2182ms step_avg:34.64ms +step:64/1555 train_time:2220ms step_avg:34.68ms +step:65/1555 train_time:2251ms step_avg:34.63ms +step:66/1555 train_time:2289ms step_avg:34.68ms +step:67/1555 train_time:2320ms step_avg:34.63ms +step:68/1555 train_time:2357ms step_avg:34.67ms +step:69/1555 train_time:2389ms step_avg:34.62ms +step:70/1555 train_time:2426ms step_avg:34.66ms +step:71/1555 train_time:2457ms step_avg:34.61ms +step:72/1555 train_time:2495ms step_avg:34.65ms +step:73/1555 train_time:2526ms step_avg:34.60ms +step:74/1555 train_time:2563ms step_avg:34.63ms +step:75/1555 train_time:2594ms step_avg:34.58ms +step:76/1555 train_time:2631ms step_avg:34.62ms +step:77/1555 train_time:2662ms step_avg:34.58ms +step:78/1555 train_time:2700ms step_avg:34.61ms +step:79/1555 train_time:2731ms step_avg:34.57ms +step:80/1555 train_time:2769ms step_avg:34.61ms +step:81/1555 train_time:2800ms step_avg:34.57ms +step:82/1555 train_time:2838ms step_avg:34.60ms +step:83/1555 train_time:2869ms step_avg:34.56ms +step:84/1555 train_time:2906ms 
step_avg:34.60ms +step:85/1555 train_time:2938ms step_avg:34.56ms +step:86/1555 train_time:2976ms step_avg:34.60ms +step:87/1555 train_time:3007ms step_avg:34.56ms +step:88/1555 train_time:3044ms step_avg:34.59ms +step:89/1555 train_time:3076ms step_avg:34.56ms +step:90/1555 train_time:3113ms step_avg:34.59ms +step:91/1555 train_time:3144ms step_avg:34.55ms +step:92/1555 train_time:3181ms step_avg:34.58ms +step:93/1555 train_time:3213ms step_avg:34.55ms +step:94/1555 train_time:3250ms step_avg:34.58ms +step:95/1555 train_time:3281ms step_avg:34.54ms +step:96/1555 train_time:3319ms step_avg:34.57ms +step:97/1555 train_time:3350ms step_avg:34.53ms +step:98/1555 train_time:3387ms step_avg:34.57ms +step:99/1555 train_time:3419ms step_avg:34.54ms +step:100/1555 train_time:3457ms step_avg:34.57ms +step:101/1555 train_time:3488ms step_avg:34.53ms +step:102/1555 train_time:3525ms step_avg:34.56ms +step:103/1555 train_time:3557ms step_avg:34.53ms +step:104/1555 train_time:3594ms step_avg:34.56ms +step:105/1555 train_time:3626ms step_avg:34.53ms +step:106/1555 train_time:3663ms step_avg:34.56ms +step:107/1555 train_time:3694ms step_avg:34.53ms +step:108/1555 train_time:3732ms step_avg:34.56ms +step:109/1555 train_time:3763ms step_avg:34.52ms +step:110/1555 train_time:3800ms step_avg:34.55ms +step:111/1555 train_time:3831ms step_avg:34.52ms +step:112/1555 train_time:3869ms step_avg:34.54ms +step:113/1555 train_time:3900ms step_avg:34.51ms +step:114/1555 train_time:3937ms step_avg:34.54ms +step:115/1555 train_time:3969ms step_avg:34.51ms +step:116/1555 train_time:4006ms step_avg:34.53ms +step:117/1555 train_time:4037ms step_avg:34.50ms +step:118/1555 train_time:4075ms step_avg:34.53ms +step:119/1555 train_time:4106ms step_avg:34.50ms +step:120/1555 train_time:4143ms step_avg:34.53ms +step:121/1555 train_time:4175ms step_avg:34.50ms +step:122/1555 train_time:4213ms step_avg:34.53ms +step:123/1555 train_time:4244ms step_avg:34.50ms +step:124/1555 train_time:4281ms 
step_avg:34.52ms +step:125/1555 train_time:4313ms step_avg:34.50ms +step:126/1555 train_time:4350ms step_avg:34.53ms +step:127/1555 train_time:4381ms step_avg:34.50ms +step:128/1555 train_time:4418ms step_avg:34.52ms +step:129/1555 train_time:4450ms step_avg:34.49ms +step:130/1555 train_time:4488ms step_avg:34.53ms +step:131/1555 train_time:4520ms step_avg:34.50ms +step:132/1555 train_time:4557ms step_avg:34.52ms +step:133/1555 train_time:4588ms step_avg:34.50ms +step:134/1555 train_time:4626ms step_avg:34.52ms +step:135/1555 train_time:4657ms step_avg:34.50ms +step:136/1555 train_time:4695ms step_avg:34.52ms +step:137/1555 train_time:4726ms step_avg:34.49ms +step:138/1555 train_time:4763ms step_avg:34.51ms +step:139/1555 train_time:4794ms step_avg:34.49ms +step:140/1555 train_time:4832ms step_avg:34.51ms +step:141/1555 train_time:4862ms step_avg:34.49ms +step:142/1555 train_time:4900ms step_avg:34.50ms +step:143/1555 train_time:4931ms step_avg:34.48ms +step:144/1555 train_time:4968ms step_avg:34.50ms +step:145/1555 train_time:4999ms step_avg:34.48ms +step:146/1555 train_time:5037ms step_avg:34.50ms +step:147/1555 train_time:5068ms step_avg:34.48ms +step:148/1555 train_time:5105ms step_avg:34.49ms +step:149/1555 train_time:5136ms step_avg:34.47ms +step:150/1555 train_time:5174ms step_avg:34.49ms +step:151/1555 train_time:5205ms step_avg:34.47ms +step:152/1555 train_time:5242ms step_avg:34.49ms +step:153/1555 train_time:5274ms step_avg:34.47ms +step:154/1555 train_time:5311ms step_avg:34.49ms +step:155/1555 train_time:5342ms step_avg:34.47ms +step:156/1555 train_time:5380ms step_avg:34.49ms +step:157/1555 train_time:5411ms step_avg:34.47ms +step:158/1555 train_time:5450ms step_avg:34.49ms +step:159/1555 train_time:5481ms step_avg:34.47ms +step:160/1555 train_time:5518ms step_avg:34.49ms +step:161/1555 train_time:5550ms step_avg:34.47ms +step:162/1555 train_time:5587ms step_avg:34.49ms +step:163/1555 train_time:5619ms step_avg:34.47ms +step:164/1555 train_time:5656ms 
step_avg:34.49ms +step:165/1555 train_time:5688ms step_avg:34.47ms +step:166/1555 train_time:5725ms step_avg:34.49ms +step:167/1555 train_time:5756ms step_avg:34.47ms +step:168/1555 train_time:5794ms step_avg:34.49ms +step:169/1555 train_time:5826ms step_avg:34.47ms +step:170/1555 train_time:5863ms step_avg:34.49ms +step:171/1555 train_time:5894ms step_avg:34.47ms +step:172/1555 train_time:5931ms step_avg:34.48ms +step:173/1555 train_time:5962ms step_avg:34.46ms +step:174/1555 train_time:6000ms step_avg:34.48ms +step:175/1555 train_time:6031ms step_avg:34.46ms +step:176/1555 train_time:6069ms step_avg:34.48ms +step:177/1555 train_time:6100ms step_avg:34.46ms +step:178/1555 train_time:6137ms step_avg:34.48ms +step:179/1555 train_time:6169ms step_avg:34.46ms +step:180/1555 train_time:6206ms step_avg:34.48ms +step:181/1555 train_time:6237ms step_avg:34.46ms +step:182/1555 train_time:6275ms step_avg:34.48ms +step:183/1555 train_time:6306ms step_avg:34.46ms +step:184/1555 train_time:6343ms step_avg:34.47ms +step:185/1555 train_time:6374ms step_avg:34.45ms +step:186/1555 train_time:6412ms step_avg:34.47ms +step:187/1555 train_time:6443ms step_avg:34.45ms +step:188/1555 train_time:6480ms step_avg:34.47ms +step:189/1555 train_time:6511ms step_avg:34.45ms +step:190/1555 train_time:6548ms step_avg:34.47ms +step:191/1555 train_time:6579ms step_avg:34.45ms +step:192/1555 train_time:6617ms step_avg:34.46ms +step:193/1555 train_time:6648ms step_avg:34.45ms +step:194/1555 train_time:6686ms step_avg:34.46ms +step:195/1555 train_time:6717ms step_avg:34.44ms +step:196/1555 train_time:6755ms step_avg:34.46ms +step:197/1555 train_time:6786ms step_avg:34.44ms +step:198/1555 train_time:6823ms step_avg:34.46ms +step:199/1555 train_time:6854ms step_avg:34.44ms +step:200/1555 train_time:6892ms step_avg:34.46ms +step:201/1555 train_time:6923ms step_avg:34.44ms +step:202/1555 train_time:6960ms step_avg:34.45ms +step:203/1555 train_time:6991ms step_avg:34.44ms +step:204/1555 train_time:7028ms 
step_avg:34.45ms +step:205/1555 train_time:7059ms step_avg:34.44ms +step:206/1555 train_time:7097ms step_avg:34.45ms +step:207/1555 train_time:7128ms step_avg:34.43ms +step:208/1555 train_time:7165ms step_avg:34.45ms +step:209/1555 train_time:7196ms step_avg:34.43ms +step:210/1555 train_time:7234ms step_avg:34.45ms +step:211/1555 train_time:7265ms step_avg:34.43ms +step:212/1555 train_time:7303ms step_avg:34.45ms +step:213/1555 train_time:7334ms step_avg:34.43ms +step:214/1555 train_time:7372ms step_avg:34.45ms +step:215/1555 train_time:7404ms step_avg:34.44ms +step:216/1555 train_time:7441ms step_avg:34.45ms +step:217/1555 train_time:7472ms step_avg:34.43ms +step:218/1555 train_time:7509ms step_avg:34.45ms +step:219/1555 train_time:7540ms step_avg:34.43ms +step:220/1555 train_time:7577ms step_avg:34.44ms +step:221/1555 train_time:7609ms step_avg:34.43ms +step:222/1555 train_time:7647ms step_avg:34.44ms +step:223/1555 train_time:7678ms step_avg:34.43ms +step:224/1555 train_time:7715ms step_avg:34.44ms +step:225/1555 train_time:7746ms step_avg:34.43ms +step:226/1555 train_time:7783ms step_avg:34.44ms +step:227/1555 train_time:7814ms step_avg:34.42ms +step:228/1555 train_time:7852ms step_avg:34.44ms +step:229/1555 train_time:7883ms step_avg:34.43ms +step:230/1555 train_time:7921ms step_avg:34.44ms +step:231/1555 train_time:7952ms step_avg:34.42ms +step:232/1555 train_time:7990ms step_avg:34.44ms +step:233/1555 train_time:8020ms step_avg:34.42ms +step:234/1555 train_time:8058ms step_avg:34.43ms +step:235/1555 train_time:8088ms step_avg:34.42ms +step:236/1555 train_time:8126ms step_avg:34.43ms +step:237/1555 train_time:8156ms step_avg:34.42ms +step:238/1555 train_time:8194ms step_avg:34.43ms +step:239/1555 train_time:8225ms step_avg:34.41ms +step:240/1555 train_time:8262ms step_avg:34.42ms +step:241/1555 train_time:8293ms step_avg:34.41ms +step:242/1555 train_time:8330ms step_avg:34.42ms +step:243/1555 train_time:8361ms step_avg:34.41ms +step:244/1555 train_time:8399ms 
step_avg:34.42ms +step:245/1555 train_time:8430ms step_avg:34.41ms +step:246/1555 train_time:8467ms step_avg:34.42ms +step:247/1555 train_time:8498ms step_avg:34.41ms +step:248/1555 train_time:8536ms step_avg:34.42ms +step:249/1555 train_time:8568ms step_avg:34.41ms +step:250/1555 train_time:8605ms step_avg:34.42ms +step:250/1555 val_loss:4.5542 train_time:8656ms step_avg:34.62ms +step:251/1555 train_time:8673ms step_avg:34.55ms +step:252/1555 train_time:8692ms step_avg:34.49ms +step:253/1555 train_time:8709ms step_avg:34.42ms +step:254/1555 train_time:8745ms step_avg:34.43ms +step:255/1555 train_time:8778ms step_avg:34.42ms +step:256/1555 train_time:8816ms step_avg:34.44ms +step:257/1555 train_time:8848ms step_avg:34.43ms +step:258/1555 train_time:8886ms step_avg:34.44ms +step:259/1555 train_time:8917ms step_avg:34.43ms +step:260/1555 train_time:8954ms step_avg:34.44ms +step:261/1555 train_time:8986ms step_avg:34.43ms +step:262/1555 train_time:9024ms step_avg:34.44ms +step:263/1555 train_time:9055ms step_avg:34.43ms +step:264/1555 train_time:9092ms step_avg:34.44ms +step:265/1555 train_time:9124ms step_avg:34.43ms +step:266/1555 train_time:9161ms step_avg:34.44ms +step:267/1555 train_time:9192ms step_avg:34.43ms +step:268/1555 train_time:9229ms step_avg:34.44ms +step:269/1555 train_time:9260ms step_avg:34.42ms +step:270/1555 train_time:9297ms step_avg:34.43ms +step:271/1555 train_time:9328ms step_avg:34.42ms +step:272/1555 train_time:9365ms step_avg:34.43ms +step:273/1555 train_time:9396ms step_avg:34.42ms +step:274/1555 train_time:9433ms step_avg:34.43ms +step:275/1555 train_time:9464ms step_avg:34.41ms +step:276/1555 train_time:9502ms step_avg:34.43ms +step:277/1555 train_time:9533ms step_avg:34.41ms +step:278/1555 train_time:9570ms step_avg:34.42ms +step:279/1555 train_time:9601ms step_avg:34.41ms +step:280/1555 train_time:9638ms step_avg:34.42ms +step:281/1555 train_time:9669ms step_avg:34.41ms +step:282/1555 train_time:9707ms step_avg:34.42ms +step:283/1555 
train_time:9738ms step_avg:34.41ms +step:284/1555 train_time:9775ms step_avg:34.42ms +step:285/1555 train_time:9806ms step_avg:34.41ms +step:286/1555 train_time:9844ms step_avg:34.42ms +step:287/1555 train_time:9876ms step_avg:34.41ms +step:288/1555 train_time:9913ms step_avg:34.42ms +step:289/1555 train_time:9944ms step_avg:34.41ms +step:290/1555 train_time:9982ms step_avg:34.42ms +step:291/1555 train_time:10013ms step_avg:34.41ms +step:292/1555 train_time:10050ms step_avg:34.42ms +step:293/1555 train_time:10081ms step_avg:34.41ms +step:294/1555 train_time:10119ms step_avg:34.42ms +step:295/1555 train_time:10150ms step_avg:34.41ms +step:296/1555 train_time:10187ms step_avg:34.42ms +step:297/1555 train_time:10218ms step_avg:34.40ms +step:298/1555 train_time:10255ms step_avg:34.41ms +step:299/1555 train_time:10286ms step_avg:34.40ms +step:300/1555 train_time:10324ms step_avg:34.41ms +step:301/1555 train_time:10355ms step_avg:34.40ms +step:302/1555 train_time:10392ms step_avg:34.41ms +step:303/1555 train_time:10423ms step_avg:34.40ms +step:304/1555 train_time:10460ms step_avg:34.41ms +step:305/1555 train_time:10491ms step_avg:34.40ms +step:306/1555 train_time:10528ms step_avg:34.41ms +step:307/1555 train_time:10559ms step_avg:34.39ms +step:308/1555 train_time:10596ms step_avg:34.40ms +step:309/1555 train_time:10627ms step_avg:34.39ms +step:310/1555 train_time:10664ms step_avg:34.40ms +step:311/1555 train_time:10695ms step_avg:34.39ms +step:312/1555 train_time:10732ms step_avg:34.40ms +step:313/1555 train_time:10763ms step_avg:34.39ms +step:314/1555 train_time:10800ms step_avg:34.40ms +step:315/1555 train_time:10831ms step_avg:34.38ms +step:316/1555 train_time:10869ms step_avg:34.39ms +step:317/1555 train_time:10900ms step_avg:34.38ms +step:318/1555 train_time:10937ms step_avg:34.39ms +step:319/1555 train_time:10969ms step_avg:34.39ms +step:320/1555 train_time:11006ms step_avg:34.39ms +step:321/1555 train_time:11037ms step_avg:34.38ms +step:322/1555 train_time:11074ms 
step_avg:34.39ms +step:323/1555 train_time:11105ms step_avg:34.38ms +step:324/1555 train_time:11143ms step_avg:34.39ms +step:325/1555 train_time:11174ms step_avg:34.38ms +step:326/1555 train_time:11211ms step_avg:34.39ms +step:327/1555 train_time:11243ms step_avg:34.38ms +step:328/1555 train_time:11280ms step_avg:34.39ms +step:329/1555 train_time:11311ms step_avg:34.38ms +step:330/1555 train_time:11349ms step_avg:34.39ms +step:331/1555 train_time:11380ms step_avg:34.38ms +step:332/1555 train_time:11417ms step_avg:34.39ms +step:333/1555 train_time:11448ms step_avg:34.38ms +step:334/1555 train_time:11486ms step_avg:34.39ms +step:335/1555 train_time:11517ms step_avg:34.38ms +step:336/1555 train_time:11554ms step_avg:34.39ms +step:337/1555 train_time:11586ms step_avg:34.38ms +step:338/1555 train_time:11624ms step_avg:34.39ms +step:339/1555 train_time:11655ms step_avg:34.38ms +step:340/1555 train_time:11692ms step_avg:34.39ms +step:341/1555 train_time:11723ms step_avg:34.38ms +step:342/1555 train_time:11760ms step_avg:34.39ms +step:343/1555 train_time:11791ms step_avg:34.38ms +step:344/1555 train_time:11828ms step_avg:34.38ms +step:345/1555 train_time:11859ms step_avg:34.37ms +step:346/1555 train_time:11896ms step_avg:34.38ms +step:347/1555 train_time:11927ms step_avg:34.37ms +step:348/1555 train_time:11965ms step_avg:34.38ms +step:349/1555 train_time:11995ms step_avg:34.37ms +step:350/1555 train_time:12033ms step_avg:34.38ms +step:351/1555 train_time:12063ms step_avg:34.37ms +step:352/1555 train_time:12101ms step_avg:34.38ms +step:353/1555 train_time:12132ms step_avg:34.37ms +step:354/1555 train_time:12169ms step_avg:34.38ms +step:355/1555 train_time:12200ms step_avg:34.37ms +step:356/1555 train_time:12237ms step_avg:34.37ms +step:357/1555 train_time:12269ms step_avg:34.37ms +step:358/1555 train_time:12306ms step_avg:34.38ms +step:359/1555 train_time:12338ms step_avg:34.37ms +step:360/1555 train_time:12374ms step_avg:34.37ms +step:361/1555 train_time:12405ms 
step_avg:34.36ms +step:362/1555 train_time:12443ms step_avg:34.37ms +step:363/1555 train_time:12474ms step_avg:34.36ms +step:364/1555 train_time:12511ms step_avg:34.37ms +step:365/1555 train_time:12542ms step_avg:34.36ms +step:366/1555 train_time:12580ms step_avg:34.37ms +step:367/1555 train_time:12611ms step_avg:34.36ms +step:368/1555 train_time:12649ms step_avg:34.37ms +step:369/1555 train_time:12680ms step_avg:34.36ms +step:370/1555 train_time:12717ms step_avg:34.37ms +step:371/1555 train_time:12748ms step_avg:34.36ms +step:372/1555 train_time:12786ms step_avg:34.37ms +step:373/1555 train_time:12816ms step_avg:34.36ms +step:374/1555 train_time:12853ms step_avg:34.37ms +step:375/1555 train_time:12885ms step_avg:34.36ms +step:376/1555 train_time:12922ms step_avg:34.37ms +step:377/1555 train_time:12953ms step_avg:34.36ms +step:378/1555 train_time:12990ms step_avg:34.37ms +step:379/1555 train_time:13021ms step_avg:34.36ms +step:380/1555 train_time:13059ms step_avg:34.37ms +step:381/1555 train_time:13090ms step_avg:34.36ms +step:382/1555 train_time:13127ms step_avg:34.36ms +step:383/1555 train_time:13158ms step_avg:34.35ms +step:384/1555 train_time:13195ms step_avg:34.36ms +step:385/1555 train_time:13226ms step_avg:34.35ms +step:386/1555 train_time:13263ms step_avg:34.36ms +step:387/1555 train_time:13294ms step_avg:34.35ms +step:388/1555 train_time:13331ms step_avg:34.36ms +step:389/1555 train_time:13363ms step_avg:34.35ms +step:390/1555 train_time:13400ms step_avg:34.36ms +step:391/1555 train_time:13431ms step_avg:34.35ms +step:392/1555 train_time:13468ms step_avg:34.36ms +step:393/1555 train_time:13499ms step_avg:34.35ms +step:394/1555 train_time:13536ms step_avg:34.36ms +step:395/1555 train_time:13568ms step_avg:34.35ms +step:396/1555 train_time:13606ms step_avg:34.36ms +step:397/1555 train_time:13636ms step_avg:34.35ms +step:398/1555 train_time:13674ms step_avg:34.36ms +step:399/1555 train_time:13705ms step_avg:34.35ms +step:400/1555 train_time:13743ms 
step_avg:34.36ms +step:401/1555 train_time:13774ms step_avg:34.35ms +step:402/1555 train_time:13812ms step_avg:34.36ms +step:403/1555 train_time:13843ms step_avg:34.35ms +step:404/1555 train_time:13881ms step_avg:34.36ms +step:405/1555 train_time:13912ms step_avg:34.35ms +step:406/1555 train_time:13949ms step_avg:34.36ms +step:407/1555 train_time:13980ms step_avg:34.35ms +step:408/1555 train_time:14017ms step_avg:34.36ms +step:409/1555 train_time:14049ms step_avg:34.35ms +step:410/1555 train_time:14086ms step_avg:34.36ms +step:411/1555 train_time:14117ms step_avg:34.35ms +step:412/1555 train_time:14154ms step_avg:34.36ms +step:413/1555 train_time:14185ms step_avg:34.35ms +step:414/1555 train_time:14222ms step_avg:34.35ms +step:415/1555 train_time:14253ms step_avg:34.35ms +step:416/1555 train_time:14290ms step_avg:34.35ms +step:417/1555 train_time:14321ms step_avg:34.34ms +step:418/1555 train_time:14358ms step_avg:34.35ms +step:419/1555 train_time:14389ms step_avg:34.34ms +step:420/1555 train_time:14427ms step_avg:34.35ms +step:421/1555 train_time:14458ms step_avg:34.34ms +step:422/1555 train_time:14495ms step_avg:34.35ms +step:423/1555 train_time:14526ms step_avg:34.34ms +step:424/1555 train_time:14563ms step_avg:34.35ms +step:425/1555 train_time:14594ms step_avg:34.34ms +step:426/1555 train_time:14632ms step_avg:34.35ms +step:427/1555 train_time:14662ms step_avg:34.34ms +step:428/1555 train_time:14700ms step_avg:34.34ms +step:429/1555 train_time:14731ms step_avg:34.34ms +step:430/1555 train_time:14768ms step_avg:34.34ms +step:431/1555 train_time:14799ms step_avg:34.34ms +step:432/1555 train_time:14836ms step_avg:34.34ms +step:433/1555 train_time:14867ms step_avg:34.34ms +step:434/1555 train_time:14905ms step_avg:34.34ms +step:435/1555 train_time:14936ms step_avg:34.34ms +step:436/1555 train_time:14973ms step_avg:34.34ms +step:437/1555 train_time:15004ms step_avg:34.33ms +step:438/1555 train_time:15041ms step_avg:34.34ms +step:439/1555 train_time:15073ms 
step_avg:34.33ms +step:440/1555 train_time:15110ms step_avg:34.34ms +step:441/1555 train_time:15141ms step_avg:34.33ms +step:442/1555 train_time:15179ms step_avg:34.34ms +step:443/1555 train_time:15209ms step_avg:34.33ms +step:444/1555 train_time:15247ms step_avg:34.34ms +step:445/1555 train_time:15278ms step_avg:34.33ms +step:446/1555 train_time:15315ms step_avg:34.34ms +step:447/1555 train_time:15346ms step_avg:34.33ms +step:448/1555 train_time:15383ms step_avg:34.34ms +step:449/1555 train_time:15414ms step_avg:34.33ms +step:450/1555 train_time:15452ms step_avg:34.34ms +step:451/1555 train_time:15483ms step_avg:34.33ms +step:452/1555 train_time:15521ms step_avg:34.34ms +step:453/1555 train_time:15552ms step_avg:34.33ms +step:454/1555 train_time:15589ms step_avg:34.34ms +step:455/1555 train_time:15620ms step_avg:34.33ms +step:456/1555 train_time:15657ms step_avg:34.34ms +step:457/1555 train_time:15688ms step_avg:34.33ms +step:458/1555 train_time:15725ms step_avg:34.33ms +step:459/1555 train_time:15756ms step_avg:34.33ms +step:460/1555 train_time:15793ms step_avg:34.33ms +step:461/1555 train_time:15824ms step_avg:34.33ms +step:462/1555 train_time:15862ms step_avg:34.33ms +step:463/1555 train_time:15893ms step_avg:34.33ms +step:464/1555 train_time:15930ms step_avg:34.33ms +step:465/1555 train_time:15961ms step_avg:34.32ms +step:466/1555 train_time:15998ms step_avg:34.33ms +step:467/1555 train_time:16029ms step_avg:34.32ms +step:468/1555 train_time:16066ms step_avg:34.33ms +step:469/1555 train_time:16097ms step_avg:34.32ms +step:470/1555 train_time:16134ms step_avg:34.33ms +step:471/1555 train_time:16165ms step_avg:34.32ms +step:472/1555 train_time:16202ms step_avg:34.33ms +step:473/1555 train_time:16233ms step_avg:34.32ms +step:474/1555 train_time:16270ms step_avg:34.33ms +step:475/1555 train_time:16302ms step_avg:34.32ms +step:476/1555 train_time:16339ms step_avg:34.32ms +step:477/1555 train_time:16370ms step_avg:34.32ms +step:478/1555 train_time:16407ms 
step_avg:34.32ms +step:479/1555 train_time:16438ms step_avg:34.32ms +step:480/1555 train_time:16476ms step_avg:34.32ms +step:481/1555 train_time:16506ms step_avg:34.32ms +step:482/1555 train_time:16544ms step_avg:34.32ms +step:483/1555 train_time:16575ms step_avg:34.32ms +step:484/1555 train_time:16612ms step_avg:34.32ms +step:485/1555 train_time:16643ms step_avg:34.32ms +step:486/1555 train_time:16681ms step_avg:34.32ms +step:487/1555 train_time:16712ms step_avg:34.32ms +step:488/1555 train_time:16749ms step_avg:34.32ms +step:489/1555 train_time:16780ms step_avg:34.32ms +step:490/1555 train_time:16817ms step_avg:34.32ms +step:491/1555 train_time:16848ms step_avg:34.31ms +step:492/1555 train_time:16886ms step_avg:34.32ms +step:493/1555 train_time:16917ms step_avg:34.31ms +step:494/1555 train_time:16954ms step_avg:34.32ms +step:495/1555 train_time:16985ms step_avg:34.31ms +step:496/1555 train_time:17023ms step_avg:34.32ms +step:497/1555 train_time:17055ms step_avg:34.31ms +step:498/1555 train_time:17092ms step_avg:34.32ms +step:499/1555 train_time:17123ms step_avg:34.31ms +step:500/1555 train_time:17160ms step_avg:34.32ms +step:500/1555 val_loss:4.2227 train_time:17210ms step_avg:34.42ms +step:501/1555 train_time:17227ms step_avg:34.39ms +step:502/1555 train_time:17246ms step_avg:34.36ms +step:503/1555 train_time:17263ms step_avg:34.32ms +step:504/1555 train_time:17299ms step_avg:34.32ms +step:505/1555 train_time:17331ms step_avg:34.32ms +step:506/1555 train_time:17372ms step_avg:34.33ms +step:507/1555 train_time:17427ms step_avg:34.37ms +step:508/1555 train_time:17491ms step_avg:34.43ms +step:509/1555 train_time:17549ms step_avg:34.48ms +step:510/1555 train_time:17613ms step_avg:34.53ms +step:511/1555 train_time:17671ms step_avg:34.58ms +step:512/1555 train_time:17734ms step_avg:34.64ms +step:513/1555 train_time:17790ms step_avg:34.68ms +step:514/1555 train_time:17854ms step_avg:34.74ms +step:515/1555 train_time:17910ms step_avg:34.78ms +step:516/1555 
train_time:17973ms step_avg:34.83ms +step:517/1555 train_time:18030ms step_avg:34.87ms +step:518/1555 train_time:18093ms step_avg:34.93ms +step:519/1555 train_time:18152ms step_avg:34.98ms +step:520/1555 train_time:18217ms step_avg:35.03ms +step:521/1555 train_time:18276ms step_avg:35.08ms +step:522/1555 train_time:18341ms step_avg:35.14ms +step:523/1555 train_time:18399ms step_avg:35.18ms +step:524/1555 train_time:18464ms step_avg:35.24ms +step:525/1555 train_time:18522ms step_avg:35.28ms +step:526/1555 train_time:18587ms step_avg:35.34ms +step:527/1555 train_time:18644ms step_avg:35.38ms +step:528/1555 train_time:18708ms step_avg:35.43ms +step:529/1555 train_time:18765ms step_avg:35.47ms +step:530/1555 train_time:18829ms step_avg:35.53ms +step:531/1555 train_time:18886ms step_avg:35.57ms +step:532/1555 train_time:18950ms step_avg:35.62ms +step:533/1555 train_time:19007ms step_avg:35.66ms +step:534/1555 train_time:19071ms step_avg:35.71ms +step:535/1555 train_time:19130ms step_avg:35.76ms +step:536/1555 train_time:19195ms step_avg:35.81ms +step:537/1555 train_time:19254ms step_avg:35.86ms +step:538/1555 train_time:19318ms step_avg:35.91ms +step:539/1555 train_time:19376ms step_avg:35.95ms +step:540/1555 train_time:19439ms step_avg:36.00ms +step:541/1555 train_time:19497ms step_avg:36.04ms +step:542/1555 train_time:19561ms step_avg:36.09ms +step:543/1555 train_time:19619ms step_avg:36.13ms +step:544/1555 train_time:19683ms step_avg:36.18ms +step:545/1555 train_time:19741ms step_avg:36.22ms +step:546/1555 train_time:19805ms step_avg:36.27ms +step:547/1555 train_time:19862ms step_avg:36.31ms +step:548/1555 train_time:19926ms step_avg:36.36ms +step:549/1555 train_time:19984ms step_avg:36.40ms +step:550/1555 train_time:20048ms step_avg:36.45ms +step:551/1555 train_time:20106ms step_avg:36.49ms +step:552/1555 train_time:20171ms step_avg:36.54ms +step:553/1555 train_time:20230ms step_avg:36.58ms +step:554/1555 train_time:20294ms step_avg:36.63ms +step:555/1555 
train_time:20352ms step_avg:36.67ms +step:556/1555 train_time:20416ms step_avg:36.72ms +step:557/1555 train_time:20474ms step_avg:36.76ms +step:558/1555 train_time:20538ms step_avg:36.81ms +step:559/1555 train_time:20595ms step_avg:36.84ms +step:560/1555 train_time:20660ms step_avg:36.89ms +step:561/1555 train_time:20716ms step_avg:36.93ms +step:562/1555 train_time:20781ms step_avg:36.98ms +step:563/1555 train_time:20838ms step_avg:37.01ms +step:564/1555 train_time:20902ms step_avg:37.06ms +step:565/1555 train_time:20960ms step_avg:37.10ms +step:566/1555 train_time:21025ms step_avg:37.15ms +step:567/1555 train_time:21082ms step_avg:37.18ms +step:568/1555 train_time:21147ms step_avg:37.23ms +step:569/1555 train_time:21205ms step_avg:37.27ms +step:570/1555 train_time:21270ms step_avg:37.32ms +step:571/1555 train_time:21328ms step_avg:37.35ms +step:572/1555 train_time:21392ms step_avg:37.40ms +step:573/1555 train_time:21451ms step_avg:37.44ms +step:574/1555 train_time:21514ms step_avg:37.48ms +step:575/1555 train_time:21572ms step_avg:37.52ms +step:576/1555 train_time:21637ms step_avg:37.56ms +step:577/1555 train_time:21694ms step_avg:37.60ms +step:578/1555 train_time:21758ms step_avg:37.64ms +step:579/1555 train_time:21815ms step_avg:37.68ms +step:580/1555 train_time:21879ms step_avg:37.72ms +step:581/1555 train_time:21936ms step_avg:37.76ms +step:582/1555 train_time:22000ms step_avg:37.80ms +step:583/1555 train_time:22058ms step_avg:37.84ms +step:584/1555 train_time:22123ms step_avg:37.88ms +step:585/1555 train_time:22181ms step_avg:37.92ms +step:586/1555 train_time:22247ms step_avg:37.96ms +step:587/1555 train_time:22304ms step_avg:38.00ms +step:588/1555 train_time:22369ms step_avg:38.04ms +step:589/1555 train_time:22426ms step_avg:38.08ms +step:590/1555 train_time:22491ms step_avg:38.12ms +step:591/1555 train_time:22549ms step_avg:38.15ms +step:592/1555 train_time:22613ms step_avg:38.20ms +step:593/1555 train_time:22670ms step_avg:38.23ms +step:594/1555 
train_time:22735ms step_avg:38.27ms +step:595/1555 train_time:22793ms step_avg:38.31ms +step:596/1555 train_time:22857ms step_avg:38.35ms +step:597/1555 train_time:22914ms step_avg:38.38ms +step:598/1555 train_time:22977ms step_avg:38.42ms +step:599/1555 train_time:23035ms step_avg:38.46ms +step:600/1555 train_time:23100ms step_avg:38.50ms +step:601/1555 train_time:23157ms step_avg:38.53ms +step:602/1555 train_time:23222ms step_avg:38.58ms +step:603/1555 train_time:23280ms step_avg:38.61ms +step:604/1555 train_time:23345ms step_avg:38.65ms +step:605/1555 train_time:23402ms step_avg:38.68ms +step:606/1555 train_time:23468ms step_avg:38.73ms +step:607/1555 train_time:23525ms step_avg:38.76ms +step:608/1555 train_time:23590ms step_avg:38.80ms +step:609/1555 train_time:23649ms step_avg:38.83ms +step:610/1555 train_time:23714ms step_avg:38.87ms +step:611/1555 train_time:23771ms step_avg:38.91ms +step:612/1555 train_time:23835ms step_avg:38.95ms +step:613/1555 train_time:23893ms step_avg:38.98ms +step:614/1555 train_time:23957ms step_avg:39.02ms +step:615/1555 train_time:24015ms step_avg:39.05ms +step:616/1555 train_time:24078ms step_avg:39.09ms +step:617/1555 train_time:24135ms step_avg:39.12ms +step:618/1555 train_time:24199ms step_avg:39.16ms +step:619/1555 train_time:24256ms step_avg:39.19ms +step:620/1555 train_time:24322ms step_avg:39.23ms +step:621/1555 train_time:24379ms step_avg:39.26ms +step:622/1555 train_time:24444ms step_avg:39.30ms +step:623/1555 train_time:24502ms step_avg:39.33ms +step:624/1555 train_time:24568ms step_avg:39.37ms +step:625/1555 train_time:24626ms step_avg:39.40ms +step:626/1555 train_time:24690ms step_avg:39.44ms +step:627/1555 train_time:24748ms step_avg:39.47ms +step:628/1555 train_time:24813ms step_avg:39.51ms +step:629/1555 train_time:24871ms step_avg:39.54ms +step:630/1555 train_time:24935ms step_avg:39.58ms +step:631/1555 train_time:24993ms step_avg:39.61ms +step:632/1555 train_time:25057ms step_avg:39.65ms +step:633/1555 
train_time:25114ms step_avg:39.67ms +step:634/1555 train_time:25177ms step_avg:39.71ms +step:635/1555 train_time:25235ms step_avg:39.74ms +step:636/1555 train_time:25298ms step_avg:39.78ms +step:637/1555 train_time:25356ms step_avg:39.81ms +step:638/1555 train_time:25420ms step_avg:39.84ms +step:639/1555 train_time:25478ms step_avg:39.87ms +step:640/1555 train_time:25542ms step_avg:39.91ms +step:641/1555 train_time:25600ms step_avg:39.94ms +step:642/1555 train_time:25666ms step_avg:39.98ms +step:643/1555 train_time:25724ms step_avg:40.01ms +step:644/1555 train_time:25788ms step_avg:40.04ms +step:645/1555 train_time:25846ms step_avg:40.07ms +step:646/1555 train_time:25910ms step_avg:40.11ms +step:647/1555 train_time:25968ms step_avg:40.14ms +step:648/1555 train_time:26033ms step_avg:40.17ms +step:649/1555 train_time:26091ms step_avg:40.20ms +step:650/1555 train_time:26155ms step_avg:40.24ms +step:651/1555 train_time:26213ms step_avg:40.27ms +step:652/1555 train_time:26276ms step_avg:40.30ms +step:653/1555 train_time:26334ms step_avg:40.33ms +step:654/1555 train_time:26398ms step_avg:40.36ms +step:655/1555 train_time:26456ms step_avg:40.39ms +step:656/1555 train_time:26519ms step_avg:40.43ms +step:657/1555 train_time:26577ms step_avg:40.45ms +step:658/1555 train_time:26641ms step_avg:40.49ms +step:659/1555 train_time:26699ms step_avg:40.51ms +step:660/1555 train_time:26765ms step_avg:40.55ms +step:661/1555 train_time:26823ms step_avg:40.58ms +step:662/1555 train_time:26888ms step_avg:40.62ms +step:663/1555 train_time:26945ms step_avg:40.64ms +step:664/1555 train_time:27010ms step_avg:40.68ms +step:665/1555 train_time:27067ms step_avg:40.70ms +step:666/1555 train_time:27133ms step_avg:40.74ms +step:667/1555 train_time:27191ms step_avg:40.77ms +step:668/1555 train_time:27255ms step_avg:40.80ms +step:669/1555 train_time:27312ms step_avg:40.83ms +step:670/1555 train_time:27376ms step_avg:40.86ms +step:671/1555 train_time:27434ms step_avg:40.89ms +step:672/1555 
train_time:27499ms step_avg:40.92ms +step:673/1555 train_time:27556ms step_avg:40.95ms +step:674/1555 train_time:27620ms step_avg:40.98ms +step:675/1555 train_time:27677ms step_avg:41.00ms +step:676/1555 train_time:27741ms step_avg:41.04ms +step:677/1555 train_time:27799ms step_avg:41.06ms +step:678/1555 train_time:27864ms step_avg:41.10ms +step:679/1555 train_time:27922ms step_avg:41.12ms +step:680/1555 train_time:27987ms step_avg:41.16ms +step:681/1555 train_time:28045ms step_avg:41.18ms +step:682/1555 train_time:28109ms step_avg:41.22ms +step:683/1555 train_time:28167ms step_avg:41.24ms +step:684/1555 train_time:28232ms step_avg:41.27ms +step:685/1555 train_time:28289ms step_avg:41.30ms +step:686/1555 train_time:28355ms step_avg:41.33ms +step:687/1555 train_time:28412ms step_avg:41.36ms +step:688/1555 train_time:28476ms step_avg:41.39ms +step:689/1555 train_time:28535ms step_avg:41.41ms +step:690/1555 train_time:28599ms step_avg:41.45ms +step:691/1555 train_time:28657ms step_avg:41.47ms +step:692/1555 train_time:28720ms step_avg:41.50ms +step:693/1555 train_time:28778ms step_avg:41.53ms +step:694/1555 train_time:28842ms step_avg:41.56ms +step:695/1555 train_time:28899ms step_avg:41.58ms +step:696/1555 train_time:28965ms step_avg:41.62ms +step:697/1555 train_time:29023ms step_avg:41.64ms +step:698/1555 train_time:29087ms step_avg:41.67ms +step:699/1555 train_time:29145ms step_avg:41.70ms +step:700/1555 train_time:29209ms step_avg:41.73ms +step:701/1555 train_time:29267ms step_avg:41.75ms +step:702/1555 train_time:29330ms step_avg:41.78ms +step:703/1555 train_time:29390ms step_avg:41.81ms +step:704/1555 train_time:29455ms step_avg:41.84ms +step:705/1555 train_time:29512ms step_avg:41.86ms +step:706/1555 train_time:29576ms step_avg:41.89ms +step:707/1555 train_time:29635ms step_avg:41.92ms +step:708/1555 train_time:29698ms step_avg:41.95ms +step:709/1555 train_time:29755ms step_avg:41.97ms +step:710/1555 train_time:29819ms step_avg:42.00ms +step:711/1555 
train_time:29877ms step_avg:42.02ms +step:712/1555 train_time:29941ms step_avg:42.05ms +step:713/1555 train_time:29999ms step_avg:42.07ms +step:714/1555 train_time:30064ms step_avg:42.11ms +step:715/1555 train_time:30122ms step_avg:42.13ms +step:716/1555 train_time:30187ms step_avg:42.16ms +step:717/1555 train_time:30245ms step_avg:42.18ms +step:718/1555 train_time:30309ms step_avg:42.21ms +step:719/1555 train_time:30367ms step_avg:42.23ms +step:720/1555 train_time:30431ms step_avg:42.27ms +step:721/1555 train_time:30490ms step_avg:42.29ms +step:722/1555 train_time:30554ms step_avg:42.32ms +step:723/1555 train_time:30613ms step_avg:42.34ms +step:724/1555 train_time:30676ms step_avg:42.37ms +step:725/1555 train_time:30733ms step_avg:42.39ms +step:726/1555 train_time:30797ms step_avg:42.42ms +step:727/1555 train_time:30856ms step_avg:42.44ms +step:728/1555 train_time:30919ms step_avg:42.47ms +step:729/1555 train_time:30977ms step_avg:42.49ms +step:730/1555 train_time:31043ms step_avg:42.52ms +step:731/1555 train_time:31099ms step_avg:42.54ms +step:732/1555 train_time:31165ms step_avg:42.57ms +step:733/1555 train_time:31222ms step_avg:42.60ms +step:734/1555 train_time:31287ms step_avg:42.62ms +step:735/1555 train_time:31345ms step_avg:42.65ms +step:736/1555 train_time:31409ms step_avg:42.68ms +step:737/1555 train_time:31467ms step_avg:42.70ms +step:738/1555 train_time:31531ms step_avg:42.72ms +step:739/1555 train_time:31590ms step_avg:42.75ms +step:740/1555 train_time:31653ms step_avg:42.77ms +step:741/1555 train_time:31711ms step_avg:42.80ms +step:742/1555 train_time:31775ms step_avg:42.82ms +step:743/1555 train_time:31833ms step_avg:42.84ms +step:744/1555 train_time:31897ms step_avg:42.87ms +step:745/1555 train_time:31955ms step_avg:42.89ms +step:746/1555 train_time:32019ms step_avg:42.92ms +step:747/1555 train_time:32076ms step_avg:42.94ms +step:748/1555 train_time:32141ms step_avg:42.97ms +step:749/1555 train_time:32199ms step_avg:42.99ms +step:750/1555 
train_time:32263ms step_avg:43.02ms +step:750/1555 val_loss:3.8721 train_time:32346ms step_avg:43.13ms +step:751/1555 train_time:32366ms step_avg:43.10ms +step:752/1555 train_time:32387ms step_avg:43.07ms +step:753/1555 train_time:32447ms step_avg:43.09ms +step:754/1555 train_time:32515ms step_avg:43.12ms +step:755/1555 train_time:32574ms step_avg:43.14ms +step:756/1555 train_time:32638ms step_avg:43.17ms +step:757/1555 train_time:32695ms step_avg:43.19ms +step:758/1555 train_time:32759ms step_avg:43.22ms +step:759/1555 train_time:32816ms step_avg:43.24ms +step:760/1555 train_time:32878ms step_avg:43.26ms +step:761/1555 train_time:32935ms step_avg:43.28ms +step:762/1555 train_time:32998ms step_avg:43.30ms +step:763/1555 train_time:33054ms step_avg:43.32ms +step:764/1555 train_time:33117ms step_avg:43.35ms +step:765/1555 train_time:33175ms step_avg:43.37ms +step:766/1555 train_time:33237ms step_avg:43.39ms +step:767/1555 train_time:33295ms step_avg:43.41ms +step:768/1555 train_time:33360ms step_avg:43.44ms +step:769/1555 train_time:33419ms step_avg:43.46ms +step:770/1555 train_time:33488ms step_avg:43.49ms +step:771/1555 train_time:33549ms step_avg:43.51ms +step:772/1555 train_time:33614ms step_avg:43.54ms +step:773/1555 train_time:33672ms step_avg:43.56ms +step:774/1555 train_time:33735ms step_avg:43.58ms +step:775/1555 train_time:33792ms step_avg:43.60ms +step:776/1555 train_time:33856ms step_avg:43.63ms +step:777/1555 train_time:33912ms step_avg:43.64ms +step:778/1555 train_time:33976ms step_avg:43.67ms +step:779/1555 train_time:34032ms step_avg:43.69ms +step:780/1555 train_time:34095ms step_avg:43.71ms +step:781/1555 train_time:34153ms step_avg:43.73ms +step:782/1555 train_time:34216ms step_avg:43.75ms +step:783/1555 train_time:34274ms step_avg:43.77ms +step:784/1555 train_time:34337ms step_avg:43.80ms +step:785/1555 train_time:34397ms step_avg:43.82ms +step:786/1555 train_time:34463ms step_avg:43.85ms +step:787/1555 train_time:34521ms step_avg:43.86ms 
+step:788/1555 train_time:34586ms step_avg:43.89ms +step:789/1555 train_time:34644ms step_avg:43.91ms +step:790/1555 train_time:34708ms step_avg:43.93ms +step:791/1555 train_time:34767ms step_avg:43.95ms +step:792/1555 train_time:34830ms step_avg:43.98ms +step:793/1555 train_time:34888ms step_avg:43.99ms +step:794/1555 train_time:34952ms step_avg:44.02ms +step:795/1555 train_time:35009ms step_avg:44.04ms +step:796/1555 train_time:35073ms step_avg:44.06ms +step:797/1555 train_time:35130ms step_avg:44.08ms +step:798/1555 train_time:35194ms step_avg:44.10ms +step:799/1555 train_time:35252ms step_avg:44.12ms +step:800/1555 train_time:35316ms step_avg:44.14ms +step:801/1555 train_time:35374ms step_avg:44.16ms +step:802/1555 train_time:35438ms step_avg:44.19ms +step:803/1555 train_time:35496ms step_avg:44.20ms +step:804/1555 train_time:35561ms step_avg:44.23ms +step:805/1555 train_time:35617ms step_avg:44.25ms +step:806/1555 train_time:35682ms step_avg:44.27ms +step:807/1555 train_time:35740ms step_avg:44.29ms +step:808/1555 train_time:35804ms step_avg:44.31ms +step:809/1555 train_time:35863ms step_avg:44.33ms +step:810/1555 train_time:35927ms step_avg:44.35ms +step:811/1555 train_time:35984ms step_avg:44.37ms +step:812/1555 train_time:36048ms step_avg:44.39ms +step:813/1555 train_time:36105ms step_avg:44.41ms +step:814/1555 train_time:36170ms step_avg:44.43ms +step:815/1555 train_time:36228ms step_avg:44.45ms +step:816/1555 train_time:36292ms step_avg:44.48ms +step:817/1555 train_time:36350ms step_avg:44.49ms +step:818/1555 train_time:36415ms step_avg:44.52ms +step:819/1555 train_time:36474ms step_avg:44.53ms +step:820/1555 train_time:36537ms step_avg:44.56ms +step:821/1555 train_time:36595ms step_avg:44.57ms +step:822/1555 train_time:36659ms step_avg:44.60ms +step:823/1555 train_time:36716ms step_avg:44.61ms +step:824/1555 train_time:36780ms step_avg:44.64ms +step:825/1555 train_time:36837ms step_avg:44.65ms +step:826/1555 train_time:36902ms step_avg:44.68ms 
+step:827/1555 train_time:36959ms step_avg:44.69ms +step:828/1555 train_time:37024ms step_avg:44.72ms +step:829/1555 train_time:37081ms step_avg:44.73ms +step:830/1555 train_time:37145ms step_avg:44.75ms +step:831/1555 train_time:37202ms step_avg:44.77ms +step:832/1555 train_time:37266ms step_avg:44.79ms +step:833/1555 train_time:37325ms step_avg:44.81ms +step:834/1555 train_time:37389ms step_avg:44.83ms +step:835/1555 train_time:37448ms step_avg:44.85ms +step:836/1555 train_time:37512ms step_avg:44.87ms +step:837/1555 train_time:37572ms step_avg:44.89ms +step:838/1555 train_time:37635ms step_avg:44.91ms +step:839/1555 train_time:37693ms step_avg:44.93ms +step:840/1555 train_time:37758ms step_avg:44.95ms +step:841/1555 train_time:37815ms step_avg:44.96ms +step:842/1555 train_time:37880ms step_avg:44.99ms +step:843/1555 train_time:37936ms step_avg:45.00ms +step:844/1555 train_time:38001ms step_avg:45.02ms +step:845/1555 train_time:38058ms step_avg:45.04ms +step:846/1555 train_time:38122ms step_avg:45.06ms +step:847/1555 train_time:38179ms step_avg:45.08ms +step:848/1555 train_time:38243ms step_avg:45.10ms +step:849/1555 train_time:38301ms step_avg:45.11ms +step:850/1555 train_time:38366ms step_avg:45.14ms +step:851/1555 train_time:38423ms step_avg:45.15ms +step:852/1555 train_time:38489ms step_avg:45.17ms +step:853/1555 train_time:38547ms step_avg:45.19ms +step:854/1555 train_time:38612ms step_avg:45.21ms +step:855/1555 train_time:38670ms step_avg:45.23ms +step:856/1555 train_time:38734ms step_avg:45.25ms +step:857/1555 train_time:38792ms step_avg:45.27ms +step:858/1555 train_time:38857ms step_avg:45.29ms +step:859/1555 train_time:38915ms step_avg:45.30ms +step:860/1555 train_time:38978ms step_avg:45.32ms +step:861/1555 train_time:39035ms step_avg:45.34ms +step:862/1555 train_time:39099ms step_avg:45.36ms +step:863/1555 train_time:39156ms step_avg:45.37ms +step:864/1555 train_time:39220ms step_avg:45.39ms +step:865/1555 train_time:39277ms step_avg:45.41ms 
+step:866/1555 train_time:39341ms step_avg:45.43ms +step:867/1555 train_time:39399ms step_avg:45.44ms +step:868/1555 train_time:39464ms step_avg:45.47ms +step:869/1555 train_time:39522ms step_avg:45.48ms +step:870/1555 train_time:39587ms step_avg:45.50ms +step:871/1555 train_time:39645ms step_avg:45.52ms +step:872/1555 train_time:39709ms step_avg:45.54ms +step:873/1555 train_time:39768ms step_avg:45.55ms +step:874/1555 train_time:39833ms step_avg:45.58ms +step:875/1555 train_time:39890ms step_avg:45.59ms +step:876/1555 train_time:39955ms step_avg:45.61ms +step:877/1555 train_time:40013ms step_avg:45.62ms +step:878/1555 train_time:40077ms step_avg:45.65ms +step:879/1555 train_time:40134ms step_avg:45.66ms +step:880/1555 train_time:40198ms step_avg:45.68ms +step:881/1555 train_time:40256ms step_avg:45.69ms +step:882/1555 train_time:40319ms step_avg:45.71ms +step:883/1555 train_time:40377ms step_avg:45.73ms +step:884/1555 train_time:40440ms step_avg:45.75ms +step:885/1555 train_time:40498ms step_avg:45.76ms +step:886/1555 train_time:40563ms step_avg:45.78ms +step:887/1555 train_time:40620ms step_avg:45.80ms +step:888/1555 train_time:40686ms step_avg:45.82ms +step:889/1555 train_time:40744ms step_avg:45.83ms +step:890/1555 train_time:40809ms step_avg:45.85ms +step:891/1555 train_time:40867ms step_avg:45.87ms +step:892/1555 train_time:40931ms step_avg:45.89ms +step:893/1555 train_time:40989ms step_avg:45.90ms +step:894/1555 train_time:41053ms step_avg:45.92ms +step:895/1555 train_time:41111ms step_avg:45.93ms +step:896/1555 train_time:41176ms step_avg:45.96ms +step:897/1555 train_time:41233ms step_avg:45.97ms +step:898/1555 train_time:41297ms step_avg:45.99ms +step:899/1555 train_time:41355ms step_avg:46.00ms +step:900/1555 train_time:41419ms step_avg:46.02ms +step:901/1555 train_time:41477ms step_avg:46.03ms +step:902/1555 train_time:41541ms step_avg:46.05ms +step:903/1555 train_time:41598ms step_avg:46.07ms +step:904/1555 train_time:41663ms step_avg:46.09ms 
+step:905/1555 train_time:41720ms step_avg:46.10ms +step:906/1555 train_time:41785ms step_avg:46.12ms +step:907/1555 train_time:41844ms step_avg:46.13ms +step:908/1555 train_time:41909ms step_avg:46.16ms +step:909/1555 train_time:41967ms step_avg:46.17ms +step:910/1555 train_time:42030ms step_avg:46.19ms +step:911/1555 train_time:42089ms step_avg:46.20ms +step:912/1555 train_time:42153ms step_avg:46.22ms +step:913/1555 train_time:42211ms step_avg:46.23ms +step:914/1555 train_time:42277ms step_avg:46.26ms +step:915/1555 train_time:42335ms step_avg:46.27ms +step:916/1555 train_time:42398ms step_avg:46.29ms +step:917/1555 train_time:42456ms step_avg:46.30ms +step:918/1555 train_time:42519ms step_avg:46.32ms +step:919/1555 train_time:42577ms step_avg:46.33ms +step:920/1555 train_time:42640ms step_avg:46.35ms +step:921/1555 train_time:42698ms step_avg:46.36ms +step:922/1555 train_time:42762ms step_avg:46.38ms +step:923/1555 train_time:42821ms step_avg:46.39ms +step:924/1555 train_time:42885ms step_avg:46.41ms +step:925/1555 train_time:42943ms step_avg:46.42ms +step:926/1555 train_time:43008ms step_avg:46.44ms +step:927/1555 train_time:43065ms step_avg:46.46ms +step:928/1555 train_time:43129ms step_avg:46.47ms +step:929/1555 train_time:43187ms step_avg:46.49ms +step:930/1555 train_time:43251ms step_avg:46.51ms +step:931/1555 train_time:43309ms step_avg:46.52ms +step:932/1555 train_time:43375ms step_avg:46.54ms +step:933/1555 train_time:43433ms step_avg:46.55ms +step:934/1555 train_time:43497ms step_avg:46.57ms +step:935/1555 train_time:43555ms step_avg:46.58ms +step:936/1555 train_time:43619ms step_avg:46.60ms +step:937/1555 train_time:43676ms step_avg:46.61ms +step:938/1555 train_time:43740ms step_avg:46.63ms +step:939/1555 train_time:43798ms step_avg:46.64ms +step:940/1555 train_time:43861ms step_avg:46.66ms +step:941/1555 train_time:43919ms step_avg:46.67ms +step:942/1555 train_time:43984ms step_avg:46.69ms +step:943/1555 train_time:44041ms step_avg:46.70ms 
+step:944/1555 train_time:44107ms step_avg:46.72ms +step:945/1555 train_time:44165ms step_avg:46.74ms +step:946/1555 train_time:44229ms step_avg:46.75ms +step:947/1555 train_time:44287ms step_avg:46.77ms +step:948/1555 train_time:44351ms step_avg:46.78ms +step:949/1555 train_time:44409ms step_avg:46.80ms +step:950/1555 train_time:44474ms step_avg:46.82ms +step:951/1555 train_time:44532ms step_avg:46.83ms +step:952/1555 train_time:44599ms step_avg:46.85ms +step:953/1555 train_time:44656ms step_avg:46.86ms +step:954/1555 train_time:44719ms step_avg:46.88ms +step:955/1555 train_time:44777ms step_avg:46.89ms +step:956/1555 train_time:44842ms step_avg:46.91ms +step:957/1555 train_time:44900ms step_avg:46.92ms +step:958/1555 train_time:44965ms step_avg:46.94ms +step:959/1555 train_time:45022ms step_avg:46.95ms +step:960/1555 train_time:45087ms step_avg:46.97ms +step:961/1555 train_time:45144ms step_avg:46.98ms +step:962/1555 train_time:45208ms step_avg:46.99ms +step:963/1555 train_time:45266ms step_avg:47.01ms +step:964/1555 train_time:45331ms step_avg:47.02ms +step:965/1555 train_time:45389ms step_avg:47.03ms +step:966/1555 train_time:45453ms step_avg:47.05ms +step:967/1555 train_time:45511ms step_avg:47.06ms +step:968/1555 train_time:45577ms step_avg:47.08ms +step:969/1555 train_time:45634ms step_avg:47.09ms +step:970/1555 train_time:45698ms step_avg:47.11ms +step:971/1555 train_time:45756ms step_avg:47.12ms +step:972/1555 train_time:45819ms step_avg:47.14ms +step:973/1555 train_time:45876ms step_avg:47.15ms +step:974/1555 train_time:45941ms step_avg:47.17ms +step:975/1555 train_time:45998ms step_avg:47.18ms +step:976/1555 train_time:46062ms step_avg:47.19ms +step:977/1555 train_time:46119ms step_avg:47.20ms +step:978/1555 train_time:46184ms step_avg:47.22ms +step:979/1555 train_time:46242ms step_avg:47.23ms +step:980/1555 train_time:46307ms step_avg:47.25ms +step:981/1555 train_time:46365ms step_avg:47.26ms +step:982/1555 train_time:46429ms step_avg:47.28ms 
+step:983/1555 train_time:46488ms step_avg:47.29ms +step:984/1555 train_time:46553ms step_avg:47.31ms +step:985/1555 train_time:46611ms step_avg:47.32ms +step:986/1555 train_time:46676ms step_avg:47.34ms +step:987/1555 train_time:46734ms step_avg:47.35ms +step:988/1555 train_time:46798ms step_avg:47.37ms +step:989/1555 train_time:46856ms step_avg:47.38ms +step:990/1555 train_time:46919ms step_avg:47.39ms +step:991/1555 train_time:46977ms step_avg:47.40ms +step:992/1555 train_time:47041ms step_avg:47.42ms +step:993/1555 train_time:47098ms step_avg:47.43ms +step:994/1555 train_time:47163ms step_avg:47.45ms +step:995/1555 train_time:47219ms step_avg:47.46ms +step:996/1555 train_time:47285ms step_avg:47.47ms +step:997/1555 train_time:47343ms step_avg:47.49ms +step:998/1555 train_time:47409ms step_avg:47.50ms +step:999/1555 train_time:47467ms step_avg:47.51ms +step:1000/1555 train_time:47530ms step_avg:47.53ms +step:1000/1555 val_loss:3.5717 train_time:47613ms step_avg:47.61ms +step:1001/1555 train_time:47631ms step_avg:47.58ms +step:1002/1555 train_time:47654ms step_avg:47.56ms +step:1003/1555 train_time:47710ms step_avg:47.57ms +step:1004/1555 train_time:47780ms step_avg:47.59ms +step:1005/1555 train_time:47838ms step_avg:47.60ms +step:1006/1555 train_time:47902ms step_avg:47.62ms +step:1007/1555 train_time:47959ms step_avg:47.63ms +step:1008/1555 train_time:48023ms step_avg:47.64ms +step:1009/1555 train_time:48080ms step_avg:47.65ms +step:1010/1555 train_time:48144ms step_avg:47.67ms +step:1011/1555 train_time:48204ms step_avg:47.68ms +step:1012/1555 train_time:48289ms step_avg:47.72ms +step:1013/1555 train_time:48373ms step_avg:47.75ms +step:1014/1555 train_time:48464ms step_avg:47.79ms +step:1015/1555 train_time:48547ms step_avg:47.83ms +step:1016/1555 train_time:48638ms step_avg:47.87ms +step:1017/1555 train_time:48724ms step_avg:47.91ms +step:1018/1555 train_time:48816ms step_avg:47.95ms +step:1019/1555 train_time:48901ms step_avg:47.99ms +step:1020/1555 
train_time:48992ms step_avg:48.03ms +step:1021/1555 train_time:49076ms step_avg:48.07ms +step:1022/1555 train_time:49166ms step_avg:48.11ms +step:1023/1555 train_time:49248ms step_avg:48.14ms +step:1024/1555 train_time:49338ms step_avg:48.18ms +step:1025/1555 train_time:49421ms step_avg:48.22ms +step:1026/1555 train_time:49510ms step_avg:48.26ms +step:1027/1555 train_time:49595ms step_avg:48.29ms +step:1028/1555 train_time:49685ms step_avg:48.33ms +step:1029/1555 train_time:49771ms step_avg:48.37ms +step:1030/1555 train_time:49862ms step_avg:48.41ms +step:1031/1555 train_time:49946ms step_avg:48.44ms +step:1032/1555 train_time:50039ms step_avg:48.49ms +step:1033/1555 train_time:50122ms step_avg:48.52ms +step:1034/1555 train_time:50210ms step_avg:48.56ms +step:1035/1555 train_time:50293ms step_avg:48.59ms +step:1036/1555 train_time:50383ms step_avg:48.63ms +step:1037/1555 train_time:50466ms step_avg:48.66ms +step:1038/1555 train_time:50556ms step_avg:48.71ms +step:1039/1555 train_time:50641ms step_avg:48.74ms +step:1040/1555 train_time:50731ms step_avg:48.78ms +step:1041/1555 train_time:50817ms step_avg:48.82ms +step:1042/1555 train_time:50907ms step_avg:48.85ms +step:1043/1555 train_time:50991ms step_avg:48.89ms +step:1044/1555 train_time:51081ms step_avg:48.93ms +step:1045/1555 train_time:51165ms step_avg:48.96ms +step:1046/1555 train_time:51254ms step_avg:49.00ms +step:1047/1555 train_time:51338ms step_avg:49.03ms +step:1048/1555 train_time:51426ms step_avg:49.07ms +step:1049/1555 train_time:51510ms step_avg:49.10ms +step:1050/1555 train_time:51601ms step_avg:49.14ms +step:1051/1555 train_time:51684ms step_avg:49.18ms +step:1052/1555 train_time:51776ms step_avg:49.22ms +step:1053/1555 train_time:51861ms step_avg:49.25ms +step:1054/1555 train_time:51951ms step_avg:49.29ms +step:1055/1555 train_time:52036ms step_avg:49.32ms +step:1056/1555 train_time:52127ms step_avg:49.36ms +step:1057/1555 train_time:52211ms step_avg:49.40ms +step:1058/1555 train_time:52301ms 
step_avg:49.43ms +step:1059/1555 train_time:52384ms step_avg:49.47ms +step:1060/1555 train_time:52473ms step_avg:49.50ms +step:1061/1555 train_time:52557ms step_avg:49.54ms +step:1062/1555 train_time:52648ms step_avg:49.57ms +step:1063/1555 train_time:52733ms step_avg:49.61ms +step:1064/1555 train_time:52823ms step_avg:49.65ms +step:1065/1555 train_time:52908ms step_avg:49.68ms +step:1066/1555 train_time:52999ms step_avg:49.72ms +step:1067/1555 train_time:53083ms step_avg:49.75ms +step:1068/1555 train_time:53172ms step_avg:49.79ms +step:1069/1555 train_time:53257ms step_avg:49.82ms +step:1070/1555 train_time:53346ms step_avg:49.86ms +step:1071/1555 train_time:53430ms step_avg:49.89ms +step:1072/1555 train_time:53521ms step_avg:49.93ms +step:1073/1555 train_time:53604ms step_avg:49.96ms +step:1074/1555 train_time:53695ms step_avg:50.00ms +step:1075/1555 train_time:53779ms step_avg:50.03ms +step:1076/1555 train_time:53868ms step_avg:50.06ms +step:1077/1555 train_time:53952ms step_avg:50.09ms +step:1078/1555 train_time:54043ms step_avg:50.13ms +step:1079/1555 train_time:54127ms step_avg:50.16ms +step:1080/1555 train_time:54219ms step_avg:50.20ms +step:1081/1555 train_time:54301ms step_avg:50.23ms +step:1082/1555 train_time:54391ms step_avg:50.27ms +step:1083/1555 train_time:54475ms step_avg:50.30ms +step:1084/1555 train_time:54565ms step_avg:50.34ms +step:1085/1555 train_time:54649ms step_avg:50.37ms +step:1086/1555 train_time:54739ms step_avg:50.40ms +step:1087/1555 train_time:54823ms step_avg:50.44ms +step:1088/1555 train_time:54914ms step_avg:50.47ms +step:1089/1555 train_time:54998ms step_avg:50.50ms +step:1090/1555 train_time:55089ms step_avg:50.54ms +step:1091/1555 train_time:55173ms step_avg:50.57ms +step:1092/1555 train_time:55263ms step_avg:50.61ms +step:1093/1555 train_time:55346ms step_avg:50.64ms +step:1094/1555 train_time:55436ms step_avg:50.67ms +step:1095/1555 train_time:55520ms step_avg:50.70ms +step:1096/1555 train_time:55610ms step_avg:50.74ms 
+step:1097/1555 train_time:55694ms step_avg:50.77ms +step:1098/1555 train_time:55784ms step_avg:50.81ms +step:1099/1555 train_time:55868ms step_avg:50.84ms +step:1100/1555 train_time:55959ms step_avg:50.87ms +step:1101/1555 train_time:56043ms step_avg:50.90ms +step:1102/1555 train_time:56134ms step_avg:50.94ms +step:1103/1555 train_time:56219ms step_avg:50.97ms +step:1104/1555 train_time:56308ms step_avg:51.00ms +step:1105/1555 train_time:56391ms step_avg:51.03ms +step:1106/1555 train_time:56482ms step_avg:51.07ms +step:1107/1555 train_time:56565ms step_avg:51.10ms +step:1108/1555 train_time:56655ms step_avg:51.13ms +step:1109/1555 train_time:56739ms step_avg:51.16ms +step:1110/1555 train_time:56829ms step_avg:51.20ms +step:1111/1555 train_time:56914ms step_avg:51.23ms +step:1112/1555 train_time:57003ms step_avg:51.26ms +step:1113/1555 train_time:57087ms step_avg:51.29ms +step:1114/1555 train_time:57177ms step_avg:51.33ms +step:1115/1555 train_time:57261ms step_avg:51.35ms +step:1116/1555 train_time:57350ms step_avg:51.39ms +step:1117/1555 train_time:57435ms step_avg:51.42ms +step:1118/1555 train_time:57525ms step_avg:51.45ms +step:1119/1555 train_time:57608ms step_avg:51.48ms +step:1120/1555 train_time:57699ms step_avg:51.52ms +step:1121/1555 train_time:57783ms step_avg:51.55ms +step:1122/1555 train_time:57873ms step_avg:51.58ms +step:1123/1555 train_time:57958ms step_avg:51.61ms +step:1124/1555 train_time:58048ms step_avg:51.64ms +step:1125/1555 train_time:58133ms step_avg:51.67ms +step:1126/1555 train_time:58223ms step_avg:51.71ms +step:1127/1555 train_time:58306ms step_avg:51.74ms +step:1128/1555 train_time:58396ms step_avg:51.77ms +step:1129/1555 train_time:58480ms step_avg:51.80ms +step:1130/1555 train_time:58569ms step_avg:51.83ms +step:1131/1555 train_time:58654ms step_avg:51.86ms +step:1132/1555 train_time:58744ms step_avg:51.89ms +step:1133/1555 train_time:58828ms step_avg:51.92ms +step:1134/1555 train_time:58920ms step_avg:51.96ms +step:1135/1555 
train_time:59003ms step_avg:51.98ms +step:1136/1555 train_time:59093ms step_avg:52.02ms +step:1137/1555 train_time:59177ms step_avg:52.05ms +step:1138/1555 train_time:59266ms step_avg:52.08ms +step:1139/1555 train_time:59352ms step_avg:52.11ms +step:1140/1555 train_time:59441ms step_avg:52.14ms +step:1141/1555 train_time:59525ms step_avg:52.17ms +step:1142/1555 train_time:59616ms step_avg:52.20ms +step:1143/1555 train_time:59700ms step_avg:52.23ms +step:1144/1555 train_time:59790ms step_avg:52.26ms +step:1145/1555 train_time:59874ms step_avg:52.29ms +step:1146/1555 train_time:59964ms step_avg:52.32ms +step:1147/1555 train_time:60048ms step_avg:52.35ms +step:1148/1555 train_time:60139ms step_avg:52.39ms +step:1149/1555 train_time:60222ms step_avg:52.41ms +step:1150/1555 train_time:60313ms step_avg:52.45ms +step:1151/1555 train_time:60397ms step_avg:52.47ms +step:1152/1555 train_time:60486ms step_avg:52.51ms +step:1153/1555 train_time:60571ms step_avg:52.53ms +step:1154/1555 train_time:60661ms step_avg:52.57ms +step:1155/1555 train_time:60745ms step_avg:52.59ms +step:1156/1555 train_time:60835ms step_avg:52.63ms +step:1157/1555 train_time:60919ms step_avg:52.65ms +step:1158/1555 train_time:61008ms step_avg:52.68ms +step:1159/1555 train_time:61093ms step_avg:52.71ms +step:1160/1555 train_time:61182ms step_avg:52.74ms +step:1161/1555 train_time:61266ms step_avg:52.77ms +step:1162/1555 train_time:61357ms step_avg:52.80ms +step:1163/1555 train_time:61440ms step_avg:52.83ms +step:1164/1555 train_time:61530ms step_avg:52.86ms +step:1165/1555 train_time:61615ms step_avg:52.89ms +step:1166/1555 train_time:61705ms step_avg:52.92ms +step:1167/1555 train_time:61789ms step_avg:52.95ms +step:1168/1555 train_time:61879ms step_avg:52.98ms +step:1169/1555 train_time:61962ms step_avg:53.00ms +step:1170/1555 train_time:62053ms step_avg:53.04ms +step:1171/1555 train_time:62137ms step_avg:53.06ms +step:1172/1555 train_time:62226ms step_avg:53.09ms +step:1173/1555 train_time:62312ms 
step_avg:53.12ms +step:1174/1555 train_time:62402ms step_avg:53.15ms +step:1175/1555 train_time:62485ms step_avg:53.18ms +step:1176/1555 train_time:62576ms step_avg:53.21ms +step:1177/1555 train_time:62659ms step_avg:53.24ms +step:1178/1555 train_time:62749ms step_avg:53.27ms +step:1179/1555 train_time:62834ms step_avg:53.29ms +step:1180/1555 train_time:62924ms step_avg:53.33ms +step:1181/1555 train_time:63008ms step_avg:53.35ms +step:1182/1555 train_time:63099ms step_avg:53.38ms +step:1183/1555 train_time:63182ms step_avg:53.41ms +step:1184/1555 train_time:63271ms step_avg:53.44ms +step:1185/1555 train_time:63357ms step_avg:53.47ms +step:1186/1555 train_time:63446ms step_avg:53.50ms +step:1187/1555 train_time:63530ms step_avg:53.52ms +step:1188/1555 train_time:63621ms step_avg:53.55ms +step:1189/1555 train_time:63703ms step_avg:53.58ms +step:1190/1555 train_time:63793ms step_avg:53.61ms +step:1191/1555 train_time:63877ms step_avg:53.63ms +step:1192/1555 train_time:63966ms step_avg:53.66ms +step:1193/1555 train_time:64051ms step_avg:53.69ms +step:1194/1555 train_time:64142ms step_avg:53.72ms +step:1195/1555 train_time:64225ms step_avg:53.75ms +step:1196/1555 train_time:64317ms step_avg:53.78ms +step:1197/1555 train_time:64400ms step_avg:53.80ms +step:1198/1555 train_time:64490ms step_avg:53.83ms +step:1199/1555 train_time:64574ms step_avg:53.86ms +step:1200/1555 train_time:64665ms step_avg:53.89ms +step:1201/1555 train_time:64749ms step_avg:53.91ms +step:1202/1555 train_time:64840ms step_avg:53.94ms +step:1203/1555 train_time:64924ms step_avg:53.97ms +step:1204/1555 train_time:65014ms step_avg:54.00ms +step:1205/1555 train_time:65098ms step_avg:54.02ms +step:1206/1555 train_time:65186ms step_avg:54.05ms +step:1207/1555 train_time:65271ms step_avg:54.08ms +step:1208/1555 train_time:65361ms step_avg:54.11ms +step:1209/1555 train_time:65445ms step_avg:54.13ms +step:1210/1555 train_time:65536ms step_avg:54.16ms +step:1211/1555 train_time:65620ms step_avg:54.19ms 
+step:1212/1555 train_time:65710ms step_avg:54.22ms +step:1213/1555 train_time:65795ms step_avg:54.24ms +step:1214/1555 train_time:65884ms step_avg:54.27ms +step:1215/1555 train_time:65968ms step_avg:54.29ms +step:1216/1555 train_time:66060ms step_avg:54.33ms +step:1217/1555 train_time:66143ms step_avg:54.35ms +step:1218/1555 train_time:66233ms step_avg:54.38ms +step:1219/1555 train_time:66317ms step_avg:54.40ms +step:1220/1555 train_time:66405ms step_avg:54.43ms +step:1221/1555 train_time:66489ms step_avg:54.45ms +step:1222/1555 train_time:66581ms step_avg:54.49ms +step:1223/1555 train_time:66664ms step_avg:54.51ms +step:1224/1555 train_time:66754ms step_avg:54.54ms +step:1225/1555 train_time:66839ms step_avg:54.56ms +step:1226/1555 train_time:66928ms step_avg:54.59ms +step:1227/1555 train_time:67015ms step_avg:54.62ms +step:1228/1555 train_time:67103ms step_avg:54.64ms +step:1229/1555 train_time:67188ms step_avg:54.67ms +step:1230/1555 train_time:67278ms step_avg:54.70ms +step:1231/1555 train_time:67362ms step_avg:54.72ms +step:1232/1555 train_time:67451ms step_avg:54.75ms +step:1233/1555 train_time:67535ms step_avg:54.77ms +step:1234/1555 train_time:67626ms step_avg:54.80ms +step:1235/1555 train_time:67710ms step_avg:54.83ms +step:1236/1555 train_time:67800ms step_avg:54.85ms +step:1237/1555 train_time:67884ms step_avg:54.88ms +step:1238/1555 train_time:67974ms step_avg:54.91ms +step:1239/1555 train_time:68060ms step_avg:54.93ms +step:1240/1555 train_time:68148ms step_avg:54.96ms +step:1241/1555 train_time:68233ms step_avg:54.98ms +step:1242/1555 train_time:68323ms step_avg:55.01ms +step:1243/1555 train_time:68406ms step_avg:55.03ms +step:1244/1555 train_time:68496ms step_avg:55.06ms +step:1245/1555 train_time:68580ms step_avg:55.08ms +step:1246/1555 train_time:68670ms step_avg:55.11ms +step:1247/1555 train_time:68754ms step_avg:55.14ms +step:1248/1555 train_time:68844ms step_avg:55.16ms +step:1249/1555 train_time:68928ms step_avg:55.19ms +step:1250/1555 
train_time:69020ms step_avg:55.22ms +step:1250/1555 val_loss:3.3999 train_time:69133ms step_avg:55.31ms +step:1251/1555 train_time:69151ms step_avg:55.28ms +step:1252/1555 train_time:69195ms step_avg:55.27ms +step:1253/1555 train_time:69283ms step_avg:55.29ms +step:1254/1555 train_time:69373ms step_avg:55.32ms +step:1255/1555 train_time:69456ms step_avg:55.34ms +step:1256/1555 train_time:69547ms step_avg:55.37ms +step:1257/1555 train_time:69629ms step_avg:55.39ms +step:1258/1555 train_time:69718ms step_avg:55.42ms +step:1259/1555 train_time:69801ms step_avg:55.44ms +step:1260/1555 train_time:69891ms step_avg:55.47ms +step:1261/1555 train_time:69973ms step_avg:55.49ms +step:1262/1555 train_time:70063ms step_avg:55.52ms +step:1263/1555 train_time:70150ms step_avg:55.54ms +step:1264/1555 train_time:70242ms step_avg:55.57ms +step:1265/1555 train_time:70330ms step_avg:55.60ms +step:1266/1555 train_time:70421ms step_avg:55.63ms +step:1267/1555 train_time:70506ms step_avg:55.65ms +step:1268/1555 train_time:70596ms step_avg:55.67ms +step:1269/1555 train_time:70678ms step_avg:55.70ms +step:1270/1555 train_time:70768ms step_avg:55.72ms +step:1271/1555 train_time:70851ms step_avg:55.74ms +step:1272/1555 train_time:70940ms step_avg:55.77ms +step:1273/1555 train_time:71025ms step_avg:55.79ms +step:1274/1555 train_time:71115ms step_avg:55.82ms +step:1275/1555 train_time:71200ms step_avg:55.84ms +step:1276/1555 train_time:71292ms step_avg:55.87ms +step:1277/1555 train_time:71376ms step_avg:55.89ms +step:1278/1555 train_time:71468ms step_avg:55.92ms +step:1279/1555 train_time:71552ms step_avg:55.94ms +step:1280/1555 train_time:71641ms step_avg:55.97ms +step:1281/1555 train_time:71725ms step_avg:55.99ms +step:1282/1555 train_time:71814ms step_avg:56.02ms +step:1283/1555 train_time:71897ms step_avg:56.04ms +step:1284/1555 train_time:71986ms step_avg:56.06ms +step:1285/1555 train_time:72071ms step_avg:56.09ms +step:1286/1555 train_time:72161ms step_avg:56.11ms +step:1287/1555 
train_time:72247ms step_avg:56.14ms +step:1288/1555 train_time:72337ms step_avg:56.16ms +step:1289/1555 train_time:72423ms step_avg:56.19ms +step:1290/1555 train_time:72514ms step_avg:56.21ms +step:1291/1555 train_time:72598ms step_avg:56.23ms +step:1292/1555 train_time:72688ms step_avg:56.26ms +step:1293/1555 train_time:72771ms step_avg:56.28ms +step:1294/1555 train_time:72860ms step_avg:56.31ms +step:1295/1555 train_time:72944ms step_avg:56.33ms +step:1296/1555 train_time:73034ms step_avg:56.35ms +step:1297/1555 train_time:73118ms step_avg:56.37ms +step:1298/1555 train_time:73210ms step_avg:56.40ms +step:1299/1555 train_time:73294ms step_avg:56.42ms +step:1300/1555 train_time:73384ms step_avg:56.45ms +step:1301/1555 train_time:73468ms step_avg:56.47ms +step:1302/1555 train_time:73559ms step_avg:56.50ms +step:1303/1555 train_time:73644ms step_avg:56.52ms +step:1304/1555 train_time:73735ms step_avg:56.54ms +step:1305/1555 train_time:73817ms step_avg:56.56ms +step:1306/1555 train_time:73907ms step_avg:56.59ms +step:1307/1555 train_time:73991ms step_avg:56.61ms +step:1308/1555 train_time:74080ms step_avg:56.64ms +step:1309/1555 train_time:74165ms step_avg:56.66ms +step:1310/1555 train_time:74255ms step_avg:56.68ms +step:1311/1555 train_time:74339ms step_avg:56.70ms +step:1312/1555 train_time:74432ms step_avg:56.73ms +step:1313/1555 train_time:74515ms step_avg:56.75ms +step:1314/1555 train_time:74605ms step_avg:56.78ms +step:1315/1555 train_time:74688ms step_avg:56.80ms +step:1316/1555 train_time:74778ms step_avg:56.82ms +step:1317/1555 train_time:74861ms step_avg:56.84ms +step:1318/1555 train_time:74951ms step_avg:56.87ms +step:1319/1555 train_time:75035ms step_avg:56.89ms +step:1320/1555 train_time:75125ms step_avg:56.91ms +step:1321/1555 train_time:75208ms step_avg:56.93ms +step:1322/1555 train_time:75299ms step_avg:56.96ms +step:1323/1555 train_time:75383ms step_avg:56.98ms +step:1324/1555 train_time:75474ms step_avg:57.00ms +step:1325/1555 train_time:75558ms 
step_avg:57.02ms +step:1326/1555 train_time:75649ms step_avg:57.05ms +step:1327/1555 train_time:75734ms step_avg:57.07ms +step:1328/1555 train_time:75823ms step_avg:57.10ms +step:1329/1555 train_time:75906ms step_avg:57.11ms +step:1330/1555 train_time:75995ms step_avg:57.14ms +step:1331/1555 train_time:76078ms step_avg:57.16ms +step:1332/1555 train_time:76171ms step_avg:57.19ms +step:1333/1555 train_time:76254ms step_avg:57.21ms +step:1334/1555 train_time:76344ms step_avg:57.23ms +step:1335/1555 train_time:76429ms step_avg:57.25ms +step:1336/1555 train_time:76518ms step_avg:57.27ms +step:1337/1555 train_time:76604ms step_avg:57.30ms +step:1338/1555 train_time:76694ms step_avg:57.32ms +step:1339/1555 train_time:76777ms step_avg:57.34ms +step:1340/1555 train_time:76868ms step_avg:57.36ms +step:1341/1555 train_time:76953ms step_avg:57.38ms +step:1342/1555 train_time:77041ms step_avg:57.41ms +step:1343/1555 train_time:77127ms step_avg:57.43ms +step:1344/1555 train_time:77217ms step_avg:57.45ms +step:1345/1555 train_time:77302ms step_avg:57.47ms +step:1346/1555 train_time:77393ms step_avg:57.50ms +step:1347/1555 train_time:77477ms step_avg:57.52ms +step:1348/1555 train_time:77567ms step_avg:57.54ms +step:1349/1555 train_time:77651ms step_avg:57.56ms +step:1350/1555 train_time:77740ms step_avg:57.59ms +step:1351/1555 train_time:77825ms step_avg:57.61ms +step:1352/1555 train_time:77915ms step_avg:57.63ms +step:1353/1555 train_time:77999ms step_avg:57.65ms +step:1354/1555 train_time:78089ms step_avg:57.67ms +step:1355/1555 train_time:78173ms step_avg:57.69ms +step:1356/1555 train_time:78263ms step_avg:57.72ms +step:1357/1555 train_time:78348ms step_avg:57.74ms +step:1358/1555 train_time:78438ms step_avg:57.76ms +step:1359/1555 train_time:78523ms step_avg:57.78ms +step:1360/1555 train_time:78613ms step_avg:57.80ms +step:1361/1555 train_time:78696ms step_avg:57.82ms +step:1362/1555 train_time:78787ms step_avg:57.85ms +step:1363/1555 train_time:78871ms step_avg:57.87ms 
+step:1364/1555 train_time:78961ms step_avg:57.89ms +step:1365/1555 train_time:79045ms step_avg:57.91ms +step:1366/1555 train_time:79135ms step_avg:57.93ms +step:1367/1555 train_time:79219ms step_avg:57.95ms +step:1368/1555 train_time:79308ms step_avg:57.97ms +step:1369/1555 train_time:79393ms step_avg:57.99ms +step:1370/1555 train_time:79483ms step_avg:58.02ms +step:1371/1555 train_time:79568ms step_avg:58.04ms +step:1372/1555 train_time:79657ms step_avg:58.06ms +step:1373/1555 train_time:79742ms step_avg:58.08ms +step:1374/1555 train_time:79833ms step_avg:58.10ms +step:1375/1555 train_time:79916ms step_avg:58.12ms +step:1376/1555 train_time:80007ms step_avg:58.14ms +step:1377/1555 train_time:80091ms step_avg:58.16ms +step:1378/1555 train_time:80180ms step_avg:58.19ms +step:1379/1555 train_time:80264ms step_avg:58.20ms +step:1380/1555 train_time:80355ms step_avg:58.23ms +step:1381/1555 train_time:80438ms step_avg:58.25ms +step:1382/1555 train_time:80529ms step_avg:58.27ms +step:1383/1555 train_time:80612ms step_avg:58.29ms +step:1384/1555 train_time:80702ms step_avg:58.31ms +step:1385/1555 train_time:80786ms step_avg:58.33ms +step:1386/1555 train_time:80875ms step_avg:58.35ms +step:1387/1555 train_time:80960ms step_avg:58.37ms +step:1388/1555 train_time:81050ms step_avg:58.39ms +step:1389/1555 train_time:81134ms step_avg:58.41ms +step:1390/1555 train_time:81224ms step_avg:58.43ms +step:1391/1555 train_time:81308ms step_avg:58.45ms +step:1392/1555 train_time:81398ms step_avg:58.48ms +step:1393/1555 train_time:81482ms step_avg:58.49ms +step:1394/1555 train_time:81572ms step_avg:58.52ms +step:1395/1555 train_time:81655ms step_avg:58.53ms +step:1396/1555 train_time:81746ms step_avg:58.56ms +step:1397/1555 train_time:81831ms step_avg:58.58ms +step:1398/1555 train_time:81919ms step_avg:58.60ms +step:1399/1555 train_time:82003ms step_avg:58.62ms +step:1400/1555 train_time:82093ms step_avg:58.64ms +step:1401/1555 train_time:82177ms step_avg:58.66ms +step:1402/1555 
train_time:82268ms step_avg:58.68ms +step:1403/1555 train_time:82353ms step_avg:58.70ms +step:1404/1555 train_time:82443ms step_avg:58.72ms +step:1405/1555 train_time:82527ms step_avg:58.74ms +step:1406/1555 train_time:82616ms step_avg:58.76ms +step:1407/1555 train_time:82700ms step_avg:58.78ms +step:1408/1555 train_time:82792ms step_avg:58.80ms +step:1409/1555 train_time:82875ms step_avg:58.82ms +step:1410/1555 train_time:82966ms step_avg:58.84ms +step:1411/1555 train_time:83050ms step_avg:58.86ms +step:1412/1555 train_time:83139ms step_avg:58.88ms +step:1413/1555 train_time:83224ms step_avg:58.90ms +step:1414/1555 train_time:83314ms step_avg:58.92ms +step:1415/1555 train_time:83398ms step_avg:58.94ms +step:1416/1555 train_time:83487ms step_avg:58.96ms +step:1417/1555 train_time:83572ms step_avg:58.98ms +step:1418/1555 train_time:83661ms step_avg:59.00ms +step:1419/1555 train_time:83746ms step_avg:59.02ms +step:1420/1555 train_time:83836ms step_avg:59.04ms +step:1421/1555 train_time:83919ms step_avg:59.06ms +step:1422/1555 train_time:84010ms step_avg:59.08ms +step:1423/1555 train_time:84094ms step_avg:59.10ms +step:1424/1555 train_time:84184ms step_avg:59.12ms +step:1425/1555 train_time:84268ms step_avg:59.14ms +step:1426/1555 train_time:84358ms step_avg:59.16ms +step:1427/1555 train_time:84442ms step_avg:59.17ms +step:1428/1555 train_time:84532ms step_avg:59.20ms +step:1429/1555 train_time:84616ms step_avg:59.21ms +step:1430/1555 train_time:84706ms step_avg:59.24ms +step:1431/1555 train_time:84792ms step_avg:59.25ms +step:1432/1555 train_time:84881ms step_avg:59.27ms +step:1433/1555 train_time:84966ms step_avg:59.29ms +step:1434/1555 train_time:85055ms step_avg:59.31ms +step:1435/1555 train_time:85139ms step_avg:59.33ms +step:1436/1555 train_time:85230ms step_avg:59.35ms +step:1437/1555 train_time:85314ms step_avg:59.37ms +step:1438/1555 train_time:85403ms step_avg:59.39ms +step:1439/1555 train_time:85488ms step_avg:59.41ms +step:1440/1555 train_time:85578ms 
step_avg:59.43ms +step:1441/1555 train_time:85662ms step_avg:59.45ms +step:1442/1555 train_time:85754ms step_avg:59.47ms +step:1443/1555 train_time:85837ms step_avg:59.49ms +step:1444/1555 train_time:85928ms step_avg:59.51ms +step:1445/1555 train_time:86011ms step_avg:59.52ms +step:1446/1555 train_time:86101ms step_avg:59.54ms +step:1447/1555 train_time:86186ms step_avg:59.56ms +step:1448/1555 train_time:86276ms step_avg:59.58ms +step:1449/1555 train_time:86360ms step_avg:59.60ms +step:1450/1555 train_time:86450ms step_avg:59.62ms +step:1451/1555 train_time:86535ms step_avg:59.64ms +step:1452/1555 train_time:86625ms step_avg:59.66ms +step:1453/1555 train_time:86708ms step_avg:59.68ms +step:1454/1555 train_time:86799ms step_avg:59.70ms +step:1455/1555 train_time:86883ms step_avg:59.71ms +step:1456/1555 train_time:86974ms step_avg:59.74ms +step:1457/1555 train_time:87057ms step_avg:59.75ms +step:1458/1555 train_time:87148ms step_avg:59.77ms +step:1459/1555 train_time:87232ms step_avg:59.79ms +step:1460/1555 train_time:87321ms step_avg:59.81ms +step:1461/1555 train_time:87405ms step_avg:59.83ms +step:1462/1555 train_time:87495ms step_avg:59.85ms +step:1463/1555 train_time:87580ms step_avg:59.86ms +step:1464/1555 train_time:87670ms step_avg:59.88ms +step:1465/1555 train_time:87754ms step_avg:59.90ms +step:1466/1555 train_time:87844ms step_avg:59.92ms +step:1467/1555 train_time:87930ms step_avg:59.94ms +step:1468/1555 train_time:88018ms step_avg:59.96ms +step:1469/1555 train_time:88103ms step_avg:59.97ms +step:1470/1555 train_time:88194ms step_avg:60.00ms +step:1471/1555 train_time:88276ms step_avg:60.01ms +step:1472/1555 train_time:88366ms step_avg:60.03ms +step:1473/1555 train_time:88450ms step_avg:60.05ms +step:1474/1555 train_time:88540ms step_avg:60.07ms +step:1475/1555 train_time:88625ms step_avg:60.08ms +step:1476/1555 train_time:88715ms step_avg:60.10ms +step:1477/1555 train_time:88799ms step_avg:60.12ms +step:1478/1555 train_time:88889ms step_avg:60.14ms 
+step:1479/1555 train_time:88973ms step_avg:60.16ms +step:1480/1555 train_time:89063ms step_avg:60.18ms +step:1481/1555 train_time:89148ms step_avg:60.19ms +step:1482/1555 train_time:89238ms step_avg:60.21ms +step:1483/1555 train_time:89323ms step_avg:60.23ms +step:1484/1555 train_time:89414ms step_avg:60.25ms +step:1485/1555 train_time:89498ms step_avg:60.27ms +step:1486/1555 train_time:89588ms step_avg:60.29ms +step:1487/1555 train_time:89671ms step_avg:60.30ms +step:1488/1555 train_time:89761ms step_avg:60.32ms +step:1489/1555 train_time:89847ms step_avg:60.34ms +step:1490/1555 train_time:89937ms step_avg:60.36ms +step:1491/1555 train_time:90021ms step_avg:60.38ms +step:1492/1555 train_time:90111ms step_avg:60.40ms +step:1493/1555 train_time:90195ms step_avg:60.41ms +step:1494/1555 train_time:90284ms step_avg:60.43ms +step:1495/1555 train_time:90370ms step_avg:60.45ms +step:1496/1555 train_time:90459ms step_avg:60.47ms +step:1497/1555 train_time:90544ms step_avg:60.48ms +step:1498/1555 train_time:90635ms step_avg:60.50ms +step:1499/1555 train_time:90718ms step_avg:60.52ms +step:1500/1555 train_time:90808ms step_avg:60.54ms +step:1500/1555 val_loss:3.2959 train_time:90924ms step_avg:60.62ms +step:1501/1555 train_time:90942ms step_avg:60.59ms +step:1502/1555 train_time:90986ms step_avg:60.58ms +step:1503/1555 train_time:91071ms step_avg:60.59ms +step:1504/1555 train_time:91164ms step_avg:60.61ms +step:1505/1555 train_time:91250ms step_avg:60.63ms +step:1506/1555 train_time:91341ms step_avg:60.65ms +step:1507/1555 train_time:91423ms step_avg:60.67ms +step:1508/1555 train_time:91512ms step_avg:60.68ms +step:1509/1555 train_time:91595ms step_avg:60.70ms +step:1510/1555 train_time:91684ms step_avg:60.72ms +step:1511/1555 train_time:91767ms step_avg:60.73ms +step:1512/1555 train_time:91856ms step_avg:60.75ms +step:1513/1555 train_time:91943ms step_avg:60.77ms +step:1514/1555 train_time:92033ms step_avg:60.79ms +step:1515/1555 train_time:92120ms step_avg:60.81ms 
+step:1516/1555 train_time:92216ms step_avg:60.83ms +step:1517/1555 train_time:92302ms step_avg:60.85ms +step:1518/1555 train_time:92391ms step_avg:60.86ms +step:1519/1555 train_time:92476ms step_avg:60.88ms +step:1520/1555 train_time:92565ms step_avg:60.90ms +step:1521/1555 train_time:92648ms step_avg:60.91ms +step:1522/1555 train_time:92737ms step_avg:60.93ms +step:1523/1555 train_time:92821ms step_avg:60.95ms +step:1524/1555 train_time:92911ms step_avg:60.97ms +step:1525/1555 train_time:92996ms step_avg:60.98ms +step:1526/1555 train_time:93089ms step_avg:61.00ms +step:1527/1555 train_time:93175ms step_avg:61.02ms +step:1528/1555 train_time:93266ms step_avg:61.04ms +step:1529/1555 train_time:93350ms step_avg:61.05ms +step:1530/1555 train_time:93441ms step_avg:61.07ms +step:1531/1555 train_time:93524ms step_avg:61.09ms +step:1532/1555 train_time:93613ms step_avg:61.10ms +step:1533/1555 train_time:93696ms step_avg:61.12ms +step:1534/1555 train_time:93787ms step_avg:61.14ms +step:1535/1555 train_time:93870ms step_avg:61.15ms +step:1536/1555 train_time:93961ms step_avg:61.17ms +step:1537/1555 train_time:94047ms step_avg:61.19ms +step:1538/1555 train_time:94139ms step_avg:61.21ms +step:1539/1555 train_time:94224ms step_avg:61.22ms +step:1540/1555 train_time:94314ms step_avg:61.24ms +step:1541/1555 train_time:94399ms step_avg:61.26ms +step:1542/1555 train_time:94490ms step_avg:61.28ms +step:1543/1555 train_time:94574ms step_avg:61.29ms +step:1544/1555 train_time:94663ms step_avg:61.31ms +step:1545/1555 train_time:94748ms step_avg:61.33ms +step:1546/1555 train_time:94837ms step_avg:61.34ms +step:1547/1555 train_time:94921ms step_avg:61.36ms +step:1548/1555 train_time:95011ms step_avg:61.38ms +step:1549/1555 train_time:95096ms step_avg:61.39ms +step:1550/1555 train_time:95189ms step_avg:61.41ms +step:1551/1555 train_time:95275ms step_avg:61.43ms +step:1552/1555 train_time:95366ms step_avg:61.45ms +step:1553/1555 train_time:95450ms step_avg:61.46ms +step:1554/1555 
train_time:95540ms step_avg:61.48ms +step:1555/1555 train_time:95624ms step_avg:61.49ms +step:1555/1555 val_loss:3.2797 train_time:95739ms step_avg:61.57ms +peak memory allocated: 31630 MiB reserved: 46578 MiB diff --git a/records/track_1_short/2026-01-31-BigramHashH2D/8d0fb296-a033-445f-ac3a-7cbd6d0e4af6.txt b/records/track_1_short/2026-01-31-BigramHashH2D/8d0fb296-a033-445f-ac3a-7cbd6d0e4af6.txt new file mode 100644 index 000000000..9512d52b0 --- /dev/null +++ b/records/track_1_short/2026-01-31-BigramHashH2D/8d0fb296-a033-445f-ac3a-7cbd6d0e4af6.txt @@ -0,0 +1,3976 @@ +import os +import sys + +# Read the current file and the kernels file code ASAP, for logging +with open(sys.argv[0], 'r') as f: + code = f.read() +with open(os.path.join(os.path.dirname(sys.argv[0]), 'triton_kernels.py'), 'r') as f: + code += f"\n\n{'-'*40}\n# triton_kernels.py\n{'-'*40}\n\n" + code += f.read() + +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate, pairwise +from pathlib import Path +import gc + +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" +import torch +import triton + +torch.empty( + 1, device=f"cuda:{os.environ['LOCAL_RANK']}", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +from kernels import get_kernel +from torch import Tensor, nn + +from triton_kernels import XXT, ba_plus_cAA, FusedLinearReLUSquareFunction, FusedSoftcappedCrossEntropy + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Distributed training setup +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size 
must be a divisor of 8" +grad_accum_steps = 8 // world_size +grad_scale = 2 / grad_accum_steps # consistent grad magnitudes between different num_devices +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng +# Transposed layout by @ChrisJMcCormick allows for faster gradient accumulation. + +@torch.library.custom_op("nanogpt::mm_t", mutates_args=()) +def mm_t_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + """Computes y = x @ w with F8 weights stored as (in_features, out_features).""" + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + assert x.shape[1] == w.shape[0] # x: (batch, in), w: (in, out) + + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + + # _scaled_mm requires column-major B. w_f8 is row-major (in, out). + # .T.contiguous().T creates a column-major view without changing logical shape. 
+ w_f8_col_major = w_f8.T.contiguous().T + + out = torch._scaled_mm( + x_f8, + w_f8_col_major, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_t_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[0] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_t_backward", mutates_args=()) +def mm_t_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + + x_scale = grad.new_tensor(x_s, dtype=torch.float32) + w_scale = grad.new_tensor(w_s, dtype=torch.float32) + grad_scale = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + + # grad_x = grad @ w.T + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=grad_scale, + scale_b=w_scale, + use_fast_accum=False, + ) + + # grad_w = x.T @ grad + # Result is (in, out), naturally matching weight storage. No final .T needed. 
+ grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_scale, + scale_b=grad_scale, + use_fast_accum=False, + ) + + return grad_x, grad_w + + grad_x, grad_w = impl(g, x_f8, w_f8) + + return grad_x, grad_w + +@mm_t_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.to(torch.float32) + +def backward_t(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_t_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context_t(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_t_op.register_autograd(backward_t, setup_context=setup_context_t) + +# ----------------------------------------------------------------------------- +# Polar Express + +# Computed for num_iters=5, safety_factor=2e-2, cushion=2 +polar_express_coeffs = [ + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323) +] + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def polar_express(G: torch.Tensor, split_baddbmm: bool = False): + """ + Polar Express Sign Method: https://arxiv.org/pdf/2505.16932 + by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. 
+ """ + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + # Select batched vs unbatched + if split_baddbmm: + BX_matmul = torch.bmm if X.ndim > 2 else torch.mm + else: + aX_plus_BX = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the iterations + for a, b, c in polar_express_coeffs: + XXT(X, out=A) # A = X @ X.mT + ba_plus_cAA(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + + # Referencing X twice causes pytorch to make a defensive copy, + # resulting in a cudaMemcpyAsync in baddbmm. + # For large matrices (i.e., the mlp weights), it's faster to split + # the operation into two kernels to avoid this. + if split_baddbmm: + BX_matmul(B, X, out=C) # C = B @ X + C.add_(X, alpha=a) # C = C + a*X (in-place, X only read) + else: + aX_plus_BX(X, B, X, beta=a, out=C) # C = a * X + B @ X + + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +# ----------------------------------------------------------------------------- +# Combined NorMuon + Adam Optimizer + +@dataclass +class ParamConfig: + """Per-parameter configuration for NorMuonAndAdam optimizer.""" + label: str + optim: str # "adam" or "normuon" + comms: str # "none", "replicated", or "sharded" + adam_betas: tuple[float, float] | None + lr_mul: float + wd_mul: float + lr: float + initial_lr: float + weight_decay: float + # Adam-specific + eps: float | None = None + # NorMuon-specific + reshape: tuple | None = None + chunk_size: int | None = None + momentum: float | None = None + beta2: float | None = None + per_matrix_lr_mul: list[float] | None = None + + +class NorMuonAndAdam: + """ + Combined optimizer that handles both NorMuon (for projection 
matrices) and + Adam (for embeddings/scalars/gate weights). + + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, Muon uses a Newton-Schulz iteration (replaced + here with Polar Express), which has the advantage that it can be stably run in bfloat16 on the GPU. + + Muon is applied only to the projection matrices in the attention and MLP layers, and is not recommended + for embeddings, scalars, or individual weight vectors (e.g., bias terms or gate weights). + + Differences from standard Muon: + - Newton-Shulz is replaced with Polar Express for the orthogonalization step + - NorMuon adds a low-rank variance estimator similar to Adafactor. https://arxiv.org/pdf/2510.05491 + - Cautious weight decay, a gated version of decoupled weight decay + - Mantissa tracking for precision + + Adam (for embeddings/scalars/gates): + - Standard Adam with bias correction + - Cautious weight decay + + Configuration: + Unlike torch.optim.Optimizer, this class uses per-parameter configs from a `param_table` dict + and does not include parameter "groups". All parameters require a .label attribute, and a + corresponding entry in the param_table to specify their hyperparameters (lr_mul, wd_mul, adam_betas, etc.). + + Communication and ordering: + Gradient communication is explicitly scheduled rather than hook-driven. + Reductions are launched in `scatter_order`, while update math and final + gathers are executed in `work_order`. These orders are independent and + must each contain every parameter label exactly once. + + Two communication modes are supported per parameter: + - 'replicated': Gradients are all-reduced and each rank computes the full update. 
+ - 'sharded': Gradients are reduce-scattered, each rank updates its shard, + and results are all-gathered. + + Adam parameters may be freely sharded. NorMuon operates on full matrices; sharding is + supported by grouping matrices into parameter banks. NorMuon parameters must have a + `.reshape` attribute that reshapes the bank so that the leading dimension is divisible + by world_size. + + # Contributors include @YouJiacheng, @KonstantinWilleke, @alexrgilbert, @adricarda, + # @tuttyfrutyee, @vdlad, @ryanyang0, @vagrawal, @varunneal, @chrisjmccormick + """ + def __init__(self, named_params, param_table: dict, scatter_order: list, work_order: list, + adam_defaults: dict, normuon_defaults: dict): + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + + # Store defaults for each optimizer type + self.adam_defaults = adam_defaults + self.normuon_defaults = normuon_defaults + self.param_table = param_table + self.scatter_order = scatter_order + self.work_order = work_order + + # Collect params by label and build config + self.param_cfgs: dict[nn.Parameter, ParamConfig] = {} + self.param_states: dict[nn.Parameter, dict] = {} + self._param_by_label: dict[str, nn.Parameter] = {} + for name, param in named_params: + label = getattr(param, "label", None) + assert label is not None and label in param_table # all params must have valid label + assert label not in self._param_by_label # exactly one param per label + self._param_by_label[label] = param + self._build_param_cfg(param, label) + + # Assert scatter_order and work_order match present labels exactly + present = set(self._param_by_label.keys()) + assert set(scatter_order) == present and set(work_order) == present + + # Handle world_size=1: overwrite comms to "none" + if self.world_size == 1: + for p_cfg in self.param_cfgs.values(): + p_cfg.comms = "none" + + # Initialize state for all params + self._init_state() + + # 0-D CPU tensors to avoid recompilation + self._step_size_t = torch.tensor(0.0, 
dtype=torch.float32, device="cpu") + self._eff_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._eff_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + + # Track async operations + self._reduce_futures: dict[nn.Parameter, tuple] = {} + + # Embed/lm_head tying state + self.split_embed = False + self._lm_head_param = self._param_by_label.get("lm_head") + self._embed_param = self._param_by_label.get("embed") + + def _build_param_cfg(self, param: nn.Parameter, label: str): + """Build config for a single parameter from param_table.""" + table_entry = self.param_table[label] + optim = table_entry["optim"] + comms = table_entry["comms"] + adam_betas = table_entry.get("adam_betas") + lr_mul = table_entry.get("lr_mul", 1.0) + wd_mul = table_entry.get("wd_mul", 1.0) + + if optim == "adam": + chunk_size = param.shape[0] // self.world_size if comms == "sharded" else None + p_cfg = ParamConfig( + label=label, + optim=optim, + comms=comms, + adam_betas=tuple(adam_betas) if adam_betas else None, + lr_mul=lr_mul, + wd_mul=wd_mul, + lr=self.adam_defaults["lr"], + initial_lr=self.adam_defaults["lr"], + weight_decay=self.adam_defaults["weight_decay"], + eps=self.adam_defaults["eps"], + chunk_size=chunk_size, + ) + elif optim == "normuon": + reshape = getattr(param, "reshape", None) + if reshape is None: + raise ValueError(f"NorMuon param {label} must have .reshape attribute") + if reshape[0] % self.world_size != 0: + raise ValueError(f"reshape[0]={reshape[0]} must be divisible by world_size") + + chunk_size = reshape[0] // self.world_size + chunk_shape = (chunk_size, *reshape[1:]) + # Shape-based LR multiplier for NorMuon + shape_mult = max(1.0, chunk_shape[-2] / chunk_shape[-1]) ** 0.5 if len(chunk_shape) >= 2 else 1.0 + lr_mul = shape_mult * lr_mul + + # Per-matrix LR multipliers for MLP c_proj (2x LR on odd indices) + per_matrix_lr_mul = None + if label == "mlp": + rank = dist.get_rank() if dist.is_initialized() else 0 + start_idx = rank * chunk_size + 
per_matrix_lr_mul = [] + for i in range(chunk_size): + global_idx = start_idx + i + is_c_proj = (global_idx % 2 == 1) + per_matrix_lr_mul.append(2.0 if is_c_proj else 1.0) + + p_cfg = ParamConfig( + label=label, + optim=optim, + comms=comms, + adam_betas=tuple(adam_betas) if adam_betas else None, + lr_mul=lr_mul, + wd_mul=wd_mul, + lr=self.normuon_defaults["lr"], + initial_lr=self.normuon_defaults["lr"], + weight_decay=self.normuon_defaults["weight_decay"], + reshape=reshape, + chunk_size=chunk_size, + momentum=self.normuon_defaults["momentum"], + beta2=self.normuon_defaults["beta2"], + per_matrix_lr_mul=per_matrix_lr_mul, + ) + else: + raise ValueError(f"Unknown optim type: {optim}") + + self.param_cfgs[param] = p_cfg + + def _init_state(self): + """Initialize optimizer state for all parameters.""" + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "adam": + # Sharded params use chunk state, replicated use full state + if p_cfg.comms == "sharded": + chunk = param[:p_cfg.chunk_size] + else: + chunk = param + exp_avg = torch.zeros_like(chunk, dtype=torch.float32, device=param.device) + self.param_states[param] = dict(step=0, exp_avg=exp_avg, exp_avg_sq=torch.zeros_like(exp_avg)) + + elif p_cfg.optim == "normuon": + chunk_shape = (p_cfg.chunk_size, *p_cfg.reshape[1:]) + + # Momentum buffer (FP32 for precision) + momentum_buffer = torch.zeros( + chunk_shape, dtype=torch.float32, device=param.device + ) + + # Second momentum buffer - reduced along one dimension + if chunk_shape[-2] >= chunk_shape[-1]: + second_mom_shape = (*chunk_shape[:-1], 1) + else: + second_mom_shape = (*chunk_shape[:-2], 1, chunk_shape[-1]) + second_momentum_buffer = torch.zeros( + second_mom_shape, dtype=torch.float32, device=param.device + ) + + # Mantissa buffer for precision tracking + mantissa = torch.zeros( + chunk_shape, dtype=torch.uint16, device=param.device + ) + + self.param_states[param] = dict( + momentum_buffer=momentum_buffer, + 
second_momentum_buffer=second_momentum_buffer, + mantissa=mantissa, + ) + + # ----------------------------------- + # Reduce/Gather operations + + def _launch_reduce(self, param: nn.Parameter, grad: Tensor): + """Launch async reduce for a parameter based on its comms policy.""" + p_cfg = self.param_cfgs[param] + + if p_cfg.comms == "none": + if p_cfg.optim == "normuon": + # NorMuon needs reshaped gradient even without communication + grad = grad.view(p_cfg.reshape) + self._reduce_futures[param] = (None, grad) + elif p_cfg.comms == "replicated": + future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future() + self._reduce_futures[param] = (future, grad) + elif p_cfg.comms == "sharded": + if p_cfg.optim == "normuon": + # NorMuon: reshape before reduce_scatter + grad_reshaped = grad.view(p_cfg.reshape) + grad_chunk = torch.empty( + (p_cfg.chunk_size, *grad_reshaped.shape[1:]), + dtype=grad.dtype, + device=grad.device + ) + future = dist.reduce_scatter_tensor( + grad_chunk, grad_reshaped.contiguous(), op=dist.ReduceOp.AVG, async_op=True + ).get_future() + self._reduce_futures[param] = (future, grad_chunk) + else: + # Adam: simple reduce_scatter + grad_chunk = torch.empty_like(grad[:p_cfg.chunk_size]) + future = dist.reduce_scatter_tensor( + grad_chunk, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + self._reduce_futures[param] = (future, grad_chunk) + + def _launch_gather(self, param: nn.Parameter, p_slice: Tensor) -> "torch.futures.Future": + """Launch async all_gather for a sharded parameter.""" + p_cfg = self.param_cfgs[param] + if p_cfg.optim == "normuon": + full_param = param.data.view(p_cfg.reshape) + assert full_param.is_contiguous() + return dist.all_gather_into_tensor( + full_param, p_slice.contiguous(), async_op=True + ).get_future() + else: + return dist.all_gather_into_tensor( + param, p_slice.contiguous(), async_op=True + ).get_future() + + # ----------------------------------- + # State management + + def reset(self): + 
"""Reset NorMuon momentum buffers and split_embed state (called on training reset).""" + self.split_embed = False + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "normuon": + p_state = self.param_states[param] + p_state["momentum_buffer"].zero_() + p_state["mantissa"].zero_() + p_state["second_momentum_buffer"].zero_() + + def copy_lm_state_to_embed(self): + """ + Copy the optimizer state from the lm_head to the embed at the untie point. + This requires an all-gather + reshard because of different sharding: + - lm_head (768, 50304) is sharded to (96, 50304) per rank (along model_dim) + - embed (50304, 768) is sharded to (6288, 768) per rank (along vocab_size) + + We all-gather the lm_head momentum, transpose it, then each rank takes their + embed shard to get the correct momentum state. + """ + lm_head = self._lm_head_param + embed = self._embed_param + lm_state = self.param_states[lm_head] + embed_state = self.param_states[embed] + lm_cfg = self.param_cfgs[lm_head] + embed_cfg = self.param_cfgs[embed] + + embed_state['step'] = lm_state['step'] # Preserve step count for bias correction + + # Copy optimizer state with all-gather + transpose + reshard + if self.world_size > 1: + rank = dist.get_rank() + lm_chunk_size = lm_cfg.chunk_size # 96 + embed_chunk_size = embed_cfg.chunk_size # 6288 + + # All-gather lm_head momentum to get full (768, 50304) tensor + for key in ["exp_avg", "exp_avg_sq"]: + lm_chunk = lm_state[key] # (96, 50304) + full_lm = torch.empty(lm_head.shape[0], lm_head.shape[1], dtype=lm_chunk.dtype, device=lm_chunk.device) + dist.all_gather_into_tensor(full_lm, lm_chunk.contiguous()) + embed_state[key].copy_(full_lm.T[rank * embed_chunk_size:(rank + 1) * embed_chunk_size]) + else: + # Single GPU: simple transpose + for key in ["exp_avg", "exp_avg_sq"]: + embed_state[key].copy_(lm_state[key].T) + + # Mark as split + self.split_embed = True + + def state_dict(self): + """Return the optimizer state as a dict.""" + return { + 
"param_states": {id(p): s for p, s in self.param_states.items()}, + "param_cfgs": {id(p): s for p, s in self.param_cfgs.items()}, + } + + def load_state_dict(self, state_dict): + """Load optimizer state from a dict.""" + # Build id->param mapping + id_to_param = {id(p): p for p in self.param_cfgs.keys()} + + # Load state, preserving dtypes + for param_id, saved_p_state in state_dict["param_states"].items(): + if param_id in id_to_param: + param = id_to_param[param_id] + p_state = self.param_states[param] + for k, v in saved_p_state.items(): + if isinstance(v, torch.Tensor) and k in p_state: + target_dtype = p_state[k].dtype + p_state[k] = v.to(dtype=target_dtype, device=p_state[k].device) + else: + p_state[k] = v + + # ----------------------------------- + # Unified optimizer step with explicit ordering + + @torch.no_grad() + def step(self, do_adam: bool = True): + """ + Combined optimizer step with explicit ordering. + + Args: + do_adam: If True, update Adam params. NorMuon params always updated. + + Flow: + 1. Scatter phase: Launch reduces in scatter_order + 2. Work phase: Process updates in work_order + - Wait for reduce, compute update, launch gather + 3. Finalize phase: Wait for gathers + + While the embeddings are tied: + - Comms and update math are only done on lm_head. + - We add embed.grad.T into lm_head.grad before comms. 
+ - After lm_head gather, we copy lm_head.data.T --> embed.data + """ + rank = dist.get_rank() if dist.is_initialized() else 0 + lm_param, embed_param = self._lm_head_param, self._embed_param + + # ===== Phase 1: Launch reduces in scatter_order ===== + for label in self.scatter_order: + param = self._param_by_label[label] + p_cfg = self.param_cfgs[param] + + if p_cfg.optim == "adam" and not do_adam: + continue + if param.grad is None: + continue + + # lm_head when tied: aggregate embed.grad.T (transposed shapes) + if label == "lm_head" and do_adam and not self.split_embed: + if embed_param is not None and embed_param.grad is not None: + param.grad.add_(embed_param.grad.T) + + # Skip embed when tied (copied from lm_head after gather) + if label == "embed" and not self.split_embed: + continue + + self._launch_reduce(param, param.grad) + + # ===== Phase 2: Process updates in work_order ===== + gather_futures = [] + lm_head_gather_future = None + + for label in self.work_order: + param = self._param_by_label[label] + if param not in self._reduce_futures: + continue + + p_cfg = self.param_cfgs[param] + if p_cfg.optim == "adam" and not do_adam: + continue + # Wait for reduce + future, grad_chunk = self._reduce_futures[param] + if future is not None: + future.wait() + # Apply update based on optim type + if p_cfg.optim == "adam": + p_slice = self._adam_update(param, grad_chunk, p_cfg, rank) + else: + p_slice = self._normuon_update(param, grad_chunk, p_cfg, rank) + # Launch gather for sharded params + if p_cfg.comms == "sharded" and self.world_size > 1: + gather_fut = self._launch_gather(param, p_slice) + if label == "lm_head": + lm_head_gather_future = gather_fut + else: + gather_futures.append(gather_fut) + + # ===== Phase 3: Wait for gathers, sync embed if tied ===== + # Wait for lm_head gather first so we can copy to embed while other gathers complete + if lm_head_gather_future is not None: + lm_head_gather_future.wait() + + # When tied: copy lm_head.T to embed + if 
do_adam and not self.split_embed and embed_param is not None and lm_param is not None: + embed_param.data.copy_(lm_param.data.T) + + # Wait for remaining gathers + for fut in gather_futures: + fut.wait() + + self._reduce_futures.clear() + + # Clear grads for updated params + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "adam" and not do_adam: + continue # Don't clear Adam grads on even steps + param.grad = None + + # ----------------------------------- + # Adam update + + def _adam_update(self, param: nn.Parameter, grad_chunk: Tensor, p_cfg: ParamConfig, rank: int) -> Tensor: + """Apply Adam update to a parameter. Returns the updated p_slice.""" + beta1, beta2 = p_cfg.adam_betas + lr = p_cfg.lr * p_cfg.lr_mul + + # Get parameter slice + if p_cfg.comms == "sharded": + p_slice = param[rank * p_cfg.chunk_size:(rank + 1) * p_cfg.chunk_size] + else: + p_slice = param + + p_state = self.param_states[param] + p_state["step"] += 1 + t = p_state["step"] + + bias1, bias2 = 1 - beta1 ** t, 1 - beta2 ** t + self._step_size_t.fill_(lr * (bias2 ** 0.5 / bias1)) + self._eff_wd_t.fill_(lr * lr * p_cfg.weight_decay * p_cfg.wd_mul) + + NorMuonAndAdam._adam_update_step( + p_slice, grad_chunk, p_state["exp_avg"], p_state["exp_avg_sq"], + beta1, beta2, p_cfg.eps, self._step_size_t, self._eff_wd_t + ) + + return p_slice + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _adam_update_step(p_slice, g_slice, exp_avg, exp_avg_sq, beta1, beta2, eps, step_size_t, eff_wd_t): + """Compiled Adam update step.""" + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + update = exp_avg.div(exp_avg_sq.sqrt().add_(eps)).mul_(step_size_t) + # Cautious weight decay + mask = (update * p_slice) > 0 + update.addcmul_(p_slice, mask, value=eff_wd_t) + p_slice.add_(other=update, alpha=-1.0) + + # ----------------------------------- + # NorMuon update + + def _normuon_update(self, param: nn.Parameter, 
grad_chunk: Tensor, p_cfg: ParamConfig, rank: int) -> Tensor: + """Apply NorMuon update to a parameter. Returns the updated p_slice.""" + chunk_shape = grad_chunk.shape + + p_state = self.param_states[param] + grad_chunk = grad_chunk.float() # FP32 for momentum + + # Momentum update + momentum_buffer = p_state["momentum_buffer"] + momentum_buffer.lerp_(grad_chunk, 1 - p_cfg.momentum) + updated_grads = grad_chunk.lerp_(momentum_buffer, p_cfg.momentum) + + self._eff_lr_t.fill_(p_cfg.lr_mul * p_cfg.lr) + self._eff_wd_t.fill_(p_cfg.wd_mul * p_cfg.weight_decay * p_cfg.lr) + + # Polar Express orthogonalization + is_large_matrix = chunk_shape[-2] > 1024 + v_chunk = polar_express(updated_grads, split_baddbmm=is_large_matrix) + + # Variance reduction + red_dim = -1 if chunk_shape[-2] >= chunk_shape[-1] else -2 + v_chunk = NorMuonAndAdam._apply_normuon_variance_reduction( + v_chunk, p_state["second_momentum_buffer"], p_cfg.beta2, red_dim + ) + + # Update parameter, in place, with cautious weight decay + param_view = param.data.view(p_cfg.reshape) + p_slice = param_view[rank * p_cfg.chunk_size:(rank + 1) * p_cfg.chunk_size] + + # MLP has per-matrix LR multipliers (c_proj gets 2x LR) + if p_cfg.per_matrix_lr_mul is not None: + for mat_idx in range(p_cfg.chunk_size): + self._eff_lr_t.fill_(p_cfg.lr_mul * p_cfg.per_matrix_lr_mul[mat_idx] * p_cfg.lr) + self._eff_wd_t.fill_(p_cfg.wd_mul * p_cfg.weight_decay * p_cfg.lr) + NorMuonAndAdam._cautious_wd_and_update_inplace( + p_slice[mat_idx].view(torch.uint16), p_state["mantissa"][mat_idx], v_chunk[mat_idx], + self._eff_wd_t, self._eff_lr_t + ) + else: + NorMuonAndAdam._cautious_wd_and_update_inplace( + p_slice.view(torch.uint16), p_state["mantissa"], v_chunk, + self._eff_wd_t, self._eff_lr_t + ) + + return p_slice + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _cautious_wd_and_update_inplace(p, mantissa, grad, wd_tensor, lr_tensor): + """ + Cautious weight decay + parameter update. 
wd_tensor and lr_tensor are 0-D CPU tensors. + Mantissa is tracked to enable higher precision updates on bfloat16 parameters. + bfloat16 format: 1 sign bit + 8 exponent bits + 7 mantissa bits = 16 bits total + float32 format: 1 sign bit + 8 exponent bits + 23 mantissa bits = 32 bits total + """ + assert p.dtype == mantissa.dtype == torch.uint16 + grad = grad.float() + wd_factor = wd_tensor.to(torch.float32) + lr_factor = lr_tensor.to(torch.float32) + p_precise_raw = (p.to(torch.uint32) << 16) | mantissa.to(torch.uint32) + p_precise = p_precise_raw.view(torch.float32) + mask = (grad * p_precise) >= 0 + p_precise.copy_(p_precise - (p_precise * mask * wd_factor * lr_factor) - (grad * lr_factor)) + p.copy_((p_precise_raw >> 16).to(torch.uint16)) + mantissa.copy_(p_precise_raw.to(torch.uint16)) + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _apply_normuon_variance_reduction(v_chunk, second_momentum_buffer, beta2, red_dim): + """NorMuon variance reduction. Algebraically fuses the normalization steps to minimize memory ops.""" + v_mean = v_chunk.float().square().mean(dim=red_dim, keepdim=True) + red_dim_size = v_chunk.size(red_dim) + v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True).mul_(red_dim_size) + v_norm = v_norm_sq.sqrt_() + second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2) + step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt_() + scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square() + v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt_() + final_scale = step_size * (v_norm / v_norm_new.clamp_min_(1e-10)) + return v_chunk.mul_(final_scale.type_as(v_chunk)) + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinearT(nn.Module): + """ + Linear layer with transposed weight storage (in_features, out_features) which + 
addresses the slow kernel that was used for gradient accumulation. @chrisjmccormick + """ + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + self.weight = nn.Parameter(torch.empty(in_features, out_features, dtype=torch.bfloat16)) + self.reset_parameters() + + def reset_parameters(self) -> None: + with torch.no_grad(): + nn.init.zeros_(self.weight) # @Grad62304977 and others + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out = torch.ops.nanogpt.mm_t(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return x @ self.weight.type_as(x) + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len, paired=False): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.paired = paired + self.reset() + + def rotary(self, x_BTHD): + assert self.factor1.size(0) >= x_BTHD.size(-3) + factor1, factor2 = ( + self.factor1[None, : x_BTHD.size(-3), None, :], + self.factor2[None, : x_BTHD.size(-3), None, :], + ) + x_flip = x_BTHD.view(*x_BTHD.shape[:-1], x_BTHD.shape[-1] // 2, 2).flip(-1).view(x_BTHD.shape) + return factor1 * x_BTHD + factor2 * x_flip + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + angular_freq = angular_freq.repeat_interleave(2) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//2)]) + t = torch.arange(2*self.max_seq_len, dtype=torch.float32, device=device) + if not self.paired: + theta 
= torch.outer(t, angular_freq) + self.factor1 = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.factor2 = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + else: + t_even = 2 * t + t_odd = 2 * t + 1 + theta1 = torch.outer(t_even, angular_freq) + theta2 = torch.outer(t_odd, angular_freq) + self.factor1 = nn.Buffer( + torch.cat((theta1.cos(), theta2.cos()), dim=-1).to(torch.bfloat16), + persistent=False + ) + self.factor2 = nn.Buffer( + torch.cat((theta1.sin(), theta2.sin()), dim=-1).to(torch.bfloat16), + persistent=False + ) + self.factor2[..., 1::2] *= -1 + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(2*self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + if not self.paired: + theta = torch.outer(t, self.angular_freq) + self.factor1.copy_(theta.cos()) + self.factor2.copy_(theta.sin()) + else: + t_even = 2 * t + t_odd = 2 * t + 1 + theta1 = torch.outer(t_even, self.angular_freq) + theta2 = torch.outer(t_odd, self.angular_freq) + self.factor1.copy_(torch.cat((theta1.cos(), theta2.cos()), dim=-1)) + self.factor2.copy_(torch.cat((theta1.sin(), theta2.sin()), dim=-1)) + self.factor2[..., 1::2] *= -1 + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + yarn: Yarn + key_offset: bool + attn_gate_w: torch.Tensor + ve_gate_w: torch.Tensor + +flash_attn_interface = 
get_kernel('varunneal/flash-attention-3').flash_attn_interface + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, paired: bool = False): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + self.paired = paired + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + # Weights are stored in parameter banks and passed via forward() + + def forward(self, x: Tensor, attn_args: AttnArgs, qkvo_w: Tensor): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + yarn = attn_args.yarn + ve, sa_lambdas, key_offset = attn_args.ve, attn_args.sa_lambdas, attn_args.key_offset + seqlens, bm_size = attn_args.seqlens, attn_args.bm_size + # sparse gated attention to enable context based no-op by @classiclarryd + # only include gates on layers with value embeds used on forward pass + attn_gate_w, ve_gate_w = attn_args.attn_gate_w, attn_args.ve_gate_w + + q, k, v = F.linear(x, sa_lambdas[0] * qkvo_w[:self.dim * 3].type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + q, k = norm(q), norm(k) # QK norm @Grad62304977 + + if not self.paired: + q, k = yarn.rotary(q), yarn.rotary(k) + + if key_offset: + # shift keys forward for the stationary head dims. Enables 1-layer induction. + k[:, 1:, :, self.head_dim // 2:] = k[:, :-1, :, self.head_dim // 2:] + + if ve is not None: + ve_gate_out = 2 * torch.sigmoid(F.linear(x[..., :12], ve_gate_w)).view(B, T, self.num_heads, 1) + v = v + ve_gate_out * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + + else: + # Paired heads: adjacent heads' queries attend to each other's keys. 
+ # Two copies of the input stream are interleaved to achieve this, which: + # - doubles the length of each sequence + # - halves the effective window size + q = q.view(B, T, self.num_heads // 2, self.head_dim * 2) + k = k.view(B, T, self.num_heads // 2, self.head_dim * 2) + v = v.reshape(B, T * 2, self.num_heads // 2, self.head_dim) + + q, k = yarn.rotary(q), yarn.rotary(k) + + q = q.view(B, T * 2, self.num_heads // 2, self.head_dim) + k = k.view(B, T * 2, self.num_heads // 2, self.head_dim) + + if ve is not None: + ve_gate_out = 2 * torch.sigmoid(F.linear(x[..., :12], ve_gate_w)).view(B, T * 2, self.num_heads // 2, 1) + v = v + ve_gate_out * ve.view_as(v) + + seqlens = 2 * seqlens + max_len = 2 * max_len + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_interface.flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, + max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=yarn.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(F.linear(x[..., :12], attn_gate_w)).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, sa_lambdas[1] * qkvo_w[self.dim * 3:].type_as(y)) # sa_lambdas[1] pre-multiplied to O @shenberg + return y + +class MLP(nn.Module): + def __init__(self): + super().__init__() + # Weights are stored in parameter banks and passed via forward() + + def forward(self, x: Tensor, c_fc: Tensor, c_proj: Tensor): + # relu(x)^2: + # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + # Fused triton kernel for relu(x @ W1.T)^2 @ W2.T + return FusedLinearReLUSquareFunction.apply(x, c_fc, c_proj) + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, has_attn: bool, has_mlp: bool, use_paired_head: bool): + 
super().__init__() + # skip attention of blocks.6 (the 7th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads, paired=use_paired_head) if has_attn else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP() if has_mlp else None + + def forward(self, x: Tensor, attn_args: AttnArgs, qkvo_w: Tensor = None, c_fc: Tensor = None, c_proj: Tensor = None): + if self.attn is not None: + x = x + self.attn(norm(x), attn_args, qkvo_w) + if self.mlp is not None: + x = x + self.mlp(norm(x), c_fc, c_proj) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +@dataclass +class ForwardScheduleConfig: + mtp_weights: torch.Tensor + ws_short: int + ws_long: int + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + self.num_layers = num_layers + self.vocab_size = next_multiple_of_n(vocab_size, n=128) + + self.smear_gate = nn.Linear(12, 1, bias=False) + nn.init.zeros_(self.smear_gate.weight) + self.smear_gate.weight.label = 'smear_gate' + + self.skip_gate = nn.Linear(12, 1, bias=False) + nn.init.zeros_(self.skip_gate.weight) + self.skip_gate.weight.label = 'skip_gate' + + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.Parameter(torch.zeros(5 * self.vocab_size, model_dim, dtype=torch.bfloat16)) + self.value_embeds.label = 'value_embed' + + # parameter banks for attention and value embedding gate weights + self.attn_gate_bank = nn.Parameter(torch.zeros(10, num_heads, 12)) # 10 layers + 
self.attn_gate_bank.label = 'attn_gate_bank' + self.ve_gate_bank = nn.Parameter(torch.zeros(5, num_heads, 12)) # 5 unique gates + self.ve_gate_bank.label = 've_gate_bank' + + # ----------------------------------- + # Parameter banks for sharded optimization, by @chrisjmccormick + + # Identify which layers have attention/MLP + # Attention is skipped in layer 6 by @YouJiacheng + self.attn_layer_indices = [i for i in range(num_layers) if i != 6] + # All layers have MLP (At 11 layers--dropped first layer @EmelyanenkoK) + self.mlp_layer_indices = list(range(num_layers)) + + hdim = num_heads * head_dim + mlp_hdim = 4 * model_dim + + # Create index mappings: layer_idx -> bank_idx + self.layer_to_attn_idx = {layer_idx: bank_idx for bank_idx, layer_idx in enumerate(self.attn_layer_indices)} + self.layer_to_mlp_idx = {layer_idx: bank_idx for bank_idx, layer_idx in enumerate(self.mlp_layer_indices)} + + # Attention bank: stores QKVO weights for all attention layers + # merged QKVO weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # Simplified layout by @chrisjmccormick + # Shape: (num_attn_layers, 4*model_dim, hdim) = (10, 3072, 768) + # Reshape for sharding: (40, 768, 768) for even distribution across 8 GPUs + self.attn_bank = nn.Parameter(torch.empty(len(self.attn_layer_indices), 4 * model_dim, hdim)) + self.attn_bank.label = 'attn' + self.attn_bank.reshape = (len(self.attn_layer_indices) * 4, hdim, hdim) # (40, 768, 768) + + # MLP bank: stores c_fc and c_proj for all MLP layers + # Shape: (num_mlp_layers + padding, 2, mlp_hdim, model_dim) = (12, 2, 3072, 768) + # We add 1 padding layer (index 11) to get 12*2=24 matrices for even distribution across 8 GPUs + # Reshape for sharding: (24, 3072, 768) + num_mlp_with_padding = len(self.mlp_layer_indices) + 1 # 11 + 1 = 12 + self.mlp_bank = nn.Parameter(torch.empty(num_mlp_with_padding, 2, mlp_hdim, model_dim)) + 
self.mlp_bank.label = 'mlp' + self.mlp_bank.reshape = (num_mlp_with_padding * 2, mlp_hdim, model_dim) # (24, 3072, 768) + + # improved init scale by @YouJiacheng and @srashedll + std = 0.5 * model_dim ** -0.5 + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.attn_bank.uniform_(-bound, bound) + self.mlp_bank[:, 0, :, :].uniform_(-bound, bound) # c_fc + self.mlp_bank[:, 1, :, :].zero_() # c_proj - zero init suggested by @Grad62304977 + + # Create blocks with has_attn/has_mlp flags + self.paired_head_layers = [0, 2, 5, 9] + self.blocks = nn.ModuleList([ + Block(model_dim, head_dim, num_heads, + has_attn=(i in self.layer_to_attn_idx), + has_mlp=(i in self.layer_to_mlp_idx), + use_paired_head=(i in self.paired_head_layers)) + for i in range(num_layers) + ]) + self.yarn = Yarn(head_dim, max_seq_len) + self.yarn_paired_head = Yarn(head_dim, max_seq_len, paired=True) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + # Transposed weight storage for faster gradient accumulation + self.lm_head = CastedLinearT(model_dim, self.vocab_size, use_fp8=use_fp8, x_s=100/448, w_s=1.6/448, grad_s=grad_scale * 0.75/448) + + nn.init.normal_(self.lm_head.weight, mean=0, std=0.005) + self.lm_head.weight.label = 'lm_head' + + self.embed = nn.Embedding(self.vocab_size, model_dim) + self.embed.weight.label = 'embed' + with torch.no_grad(): + self.embed.weight.copy_(self.lm_head.weight.T) + + self.bigram_embed = nn.Embedding(args.bigram_vocab_size, model_dim) + self.bigram_embed.weight.label = 'bigram_embed' + nn.init.zeros_(self.bigram_embed.weight) + + # x0_lambdas separated out for different optimizer treatment (no beta smoothing) + self.x0_lambdas = nn.Parameter(torch.zeros(num_layers)) + self.x0_lambdas.label = 'x0_lambdas' + + pad = (-num_layers * 3 - 3) % dist.get_world_size() # updated: 3*num_layers instead of 4* + self.scalars = nn.Parameter( + torch.cat( + [ + 1.1 * torch.ones(num_layers), # resid lambdas. 1.1 init such that layer i weight is i^(num_layers-i). + *[torch.tensor([0.5, 1.0]) for _ in range(num_layers)], # SA lambdas + 0.1 * torch.ones(num_layers), # bigram lambdas + torch.zeros(1), # smear_lambda + 0.5*torch.ones(1), # backout_lambda + -1.5 * torch.ones(1), # skip_lambda -> σ(-1.5) ≈ 0.18 + torch.ones(pad), + ] + ) + ) + self.scalars.label = 'scalars' + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _compute_bigram_hash(x: Tensor, mod: int) -> Tensor: + """ + Computes bigram hash on GPU for each position using [prev_token, curr_token]. + Mathematically identical to the CPU version but computed on device. 
+ """ + rand_int_1 = 36313 + rand_int_2 = 27191 + result = torch.empty_like(x) + result[0] = mod + result[1:] = torch.bitwise_xor(rand_int_1 * x[1:], rand_int_2 * x[:-1]) % mod + return result + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, schedule_cfg: ForwardScheduleConfig): + assert input_seq.ndim == 1 + + # unpack schedule_cfg + mtp_weights, ws_short, ws_long = schedule_cfg.mtp_weights, schedule_cfg.ws_short, schedule_cfg.ws_long + + # set configs + skip_connections = [] + skip_in = [3] # long attention window on layer 3 + skip_out = [6] # no attn op on layer 6 + x_backout = None + backout_layer = 7 + + # set lambdas + resid_lambdas = self.scalars[: 1 * self.num_layers] + x0_lambdas = self.x0_lambdas + sa_lambdas = self.scalars[1 * self.num_layers: 3 * self.num_layers].view(-1, 2) + bigram_lambdas = self.scalars[3 * self.num_layers: 4 * self.num_layers] + smear_lambda = self.scalars[4 * self.num_layers] + backout_lambda = self.scalars[4 * self.num_layers+1] + skip_lambda = self.scalars[4 * self.num_layers+2] + + # set block masks and key shift + bm_sizes = [ws_short, ws_short, ws_short, ws_long, ws_short, ws_short, None, ws_short, ws_short, ws_short, ws_long] + assert len(bm_sizes) == self.num_layers + key_offset = [b==ws_long for b in bm_sizes] # apply partial key offset to long windows + + # Embedding lookup - embed is synced from lm_head during tied phase by optimizer + x = self.embed(input_seq) + # Compute bigram hash on GPU (moved from CPU data loader) + bigram_seq = self._compute_bigram_hash(input_seq, args.bigram_vocab_size - 1) + x0_bigram = self.bigram_embed(bigram_seq)[None] + + # Value embeddings - always computed (not precomputed) + ve = self.value_embeds.view(5, self.vocab_size, -1)[:, input_seq] + # 01 ... 
234 structure on token value embeddings by @photomz + ve = [ve[0], ve[1]] + [None] * (self.num_layers - 5) + [ve[2], ve[3], ve[4]] + assert len(ve) == self.num_layers + + # smear token embed forward 1 position @classiclarryd + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # unbind gate banks to avoid select_backwards kernel + ag = [w.bfloat16() for w in self.attn_gate_bank.unbind(0)] + veg = [w.bfloat16() for w in self.ve_gate_bank.unbind(0)] + attn_gates = ag[:6] + [None] + ag[6:] + ve_gates = [veg[0], veg[1]] + [None] * (self.num_layers - 5) + [veg[2], veg[3], veg[4]] + assert len(attn_gates) == self.num_layers + assert len(ve_gates) == self.num_layers + + # unbind weight banks to avoid select_backwards kernel + attn_weights = self.attn_bank.unbind(0) # tuple of [4*dim, hdim] tensors + mlp_fcs = self.mlp_bank[:, 0, :, :].unbind(0) # tuple of [mlp_hdim, dim] tensors + mlp_projs = self.mlp_bank[:, 1, :, :].unbind(0) # tuple of [mlp_hdim, dim] tensors + + for i in range(self.num_layers): + yarn = self.yarn_paired_head if i in self.paired_head_layers else self.yarn + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + yarn=yarn, + key_offset=key_offset[i], + attn_gate_w=attn_gates[i], + ve_gate_w=ve_gates[i] + ) + if i in skip_out: + skip_gate_out = torch.sigmoid(skip_lambda) * 2 * torch.sigmoid(self.skip_gate(x0[..., :self.skip_gate.weight.size(-1)])) + x = x + skip_gate_out * skip_connections.pop() + if i == 0: + x = (resid_lambdas[0] + x0_lambdas[0]) * x + bigram_lambdas[0] * x0_bigram + else: + x = resid_lambdas[i] * x + x0_lambdas[i] * x0 + bigram_lambdas[i] * x0_bigram + + # Get weights for this layer from banks + qkvo_w = attn_weights[self.layer_to_attn_idx[i]] if i in self.layer_to_attn_idx else None + c_fc = mlp_fcs[self.layer_to_mlp_idx[i]] if i in 
self.layer_to_mlp_idx else None + c_proj = mlp_projs[self.layer_to_mlp_idx[i]] if i in self.layer_to_mlp_idx else None + + x = self.blocks[i](x, attn_args, qkvo_w, c_fc, c_proj) + if i in skip_in: + skip_connections.append(x) + if i == backout_layer: + x_backout = x + + # back out contributions from first 7 layers that are only required for downstream context and not direct prediction + x -= backout_lambda * x_backout + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15 + # @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1). @classiclarryd updated to 23*sigmoid((logits+5)/7.5) + if self.training: + losses = FusedSoftcappedCrossEntropy.apply(logits.view(-1, logits.size(-1)), target_seq, mtp_weights, 23.0, 5.0, 7.5) + loss = losses.sum() + else: + logits = 23 * torch.sigmoid((logits + 5) / 7.5) + logits_for_loss = logits.float() + loss = F.cross_entropy(logits_for_loss.view(-1, logits_for_loss.size(-1)), target_seq, reduction="mean") + return loss +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class Shard: + def __init__(self, tokens: Tensor, world_size: int = 1): + self.tokens = tokens + self.size = tokens.numel() + self.world_size = 
world_size + self.i = 0 + + # Partial index now, full index async + self.bos_idx = (tokens[:6_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self._full_idx = None + self._loader_thread = None + self._ready = threading.Event() + self._loader_thread = threading.Thread(target=self._scan) + self._loader_thread.start() + + def _scan(self): + self._full_idx = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self._ready.set() + + def _maybe_switch(self): + # Switch to full index as soon as async scan completes + if self.bos_idx is not self._full_idx and self._ready.is_set(): + self._loader_thread.join() + self.bos_idx = self._full_idx + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + self._maybe_switch() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + return starts, ends + + @staticmethod + def load_async(file: Path, world_size: int = 1): + """Returns getter function for async shard loading""" + result = {} + ready = threading.Event() + def load(): + tokens = _load_data_shard(file) + result['shard'] = Shard(tokens, world_size) + ready.set() + thread = threading.Thread(target=load) + thread.start() + def get(): + ready.wait() + thread.join() + return result['shard'] + return get + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each 
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + shard = Shard(tokens, world_size) + next_shard_getter = Shard.load_async(next(file_iter), world_size) + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = shard.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
+ shard = next_shard_getter() + tokens = shard.tokens + try: + next_shard_getter = Shard.load_async(next(file_iter), world_size) + except StopIteration: + next_shard_getter = None # no more shards to preload + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + # Cast to int32 on CPU before transfer to avoid dtype conversion during .to() + _inputs = _inputs.to(dtype=torch.int32) + _targets = _targets.to(dtype=torch.int64) + _cum_lengths = _cum_lengths.to(dtype=torch.int32) + # Bigram hash computation moved to GPU in forward() + + new_params = yield ( + _inputs.to(device="cuda", non_blocking=True), + _targets.to(device="cuda", non_blocking=True), + _cum_lengths.to(device="cuda", non_blocking=True), + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens // new_grad_accum_steps + max_seq_len = new_max_seq_len + +# ----------------------------------------------------------------------------- +# Training Management + +@dataclass 
+class Hyperparameters: + # data + data_path = os.environ.get("DATA_PATH", ".") + train_files: str = os.path.join(data_path, "data/fineweb10B/fineweb_train_*.bin") # input .bin to train on + val_files: str = os.path.join(data_path, "data/fineweb10B/fineweb_val_*.bin") # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + # batch sizes + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # schedule + num_scheduled_iterations: int = 1515 # number of steps to complete lr and ws schedule + num_extension_iterations: int = 40 # number of steps to continue training at final lr and ws + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 250 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # bigram hash embedding + bigram_vocab_size: int = 50304 * 5 + +args = Hyperparameters() + +@dataclass +class TrainingStage: + lr_mul: float + batch_size: int + window_sizes: tuple[int, int] # (short, long) in block units + mtp_weights_start: list[float] + mtp_weights_end: list[float] + duration: float = None + +class TrainingSchedule: + """ + Training schedule initialized via TRAINING_STAGES + 1. Multi Token Prediction schedule of [1, 0.5, 0.25->0] -> [1, 0.5->0] -> [1] @varunneal + 2. Sliding Attention window schedule of [1,3] -> [3,7] -> [5,11] -> [6,13] + 3. YaRN updates to RoPE on window changes + 4. Split embed and lm head at 2/3 of training + 5. Batch size schedule of 8 -> 16 -> 24 + 6. 
Post training extension of long windows from 13 to 20 + """ + + def __init__(self, stages: list[TrainingStage], scheduled_iterations: int, extension_iterations: int, + cooldown_frac: float = 0.5, split_embed_stage: int = 2, ws_post_yarn_ext: int = 20): + self.stages = stages + self.scheduled_iterations = scheduled_iterations + self.cooldown_frac = cooldown_frac + # increase final validation ws, used for YaRN extension and short window size @classiclarryd + self.ws_post_yarn_ext = ws_post_yarn_ext + + self.total_steps = self.scheduled_iterations + extension_iterations + + # Build stage boundaries (last is extension stage) + ends = [0] + [round(c * scheduled_iterations) for c in accumulate(s.duration for s in stages[:-1])] + [self.total_steps] + assert self.scheduled_iterations == ends[-2] + self.boundaries = list(pairwise(ends)) + + # Split embed at specified stage (ensure odd step for Adam) + self.split_step = self.boundaries[split_embed_stage][0] | 1 + + # Precompute MTP weights for all steps + self.mtp_weights = [] + for step in range(self.total_steps + 1): + stage, t = self.lookup(step) + w = [a + (b - a) * t for a, b in zip(stage.mtp_weights_start, stage.mtp_weights_end)] + self.mtp_weights.append(torch.tensor(w, device=device)) + + def lookup(self, step: int) -> tuple[TrainingStage, float]: + # Returns stage and % of the way through that stage + for i, (start, end) in enumerate(self.boundaries): + if step < end: + t = (step - start) / (end - start) + return self.stages[i], t + return self.stages[-1], 1.0 + + def get_lr(self, step: int) -> float: + # learning rate schedule: tied to batch size schedule, with cooldown at the end + stage, _ = self.lookup(step) + lr = stage.lr_mul + cd_start = int(self.scheduled_iterations * (1 - self.cooldown_frac)) + if step >= cd_start: + t = min(1.0, (step - cd_start) / (self.scheduled_iterations - cd_start)) + lr = lr * (1 - t) + 0.1 * t + return lr + +# window_sizes are in units of `block_size` tokens (defined in 
TrainingManager) +TRAINING_STAGES = [ + TrainingStage(duration=1/3, batch_size=8 * 2048 * 8, window_sizes=(1, 3), lr_mul=1.0, + mtp_weights_start=[1.0, 0.5, 0.25], mtp_weights_end=[1.0, 0.5, 0.0]), + TrainingStage(duration=1/3, batch_size=16 * 2048 * 8, window_sizes=(3, 7), lr_mul=1.52, # (16/8)**0.6 + mtp_weights_start=[1.0, 0.5], mtp_weights_end=[1.0, 0.0]), + TrainingStage(duration=1/3, batch_size=24 * 2048 * 8, window_sizes=(5, 11), lr_mul=1.73, # (24/8)**0.5 + mtp_weights_start=[1.0], mtp_weights_end=[1.0]), + # extension stage + TrainingStage(batch_size=24 * 2048 * 8, window_sizes=(6, 13), lr_mul=1.0, # lr_mul is not used + mtp_weights_start=[1.0], mtp_weights_end=[1.0]), +] + +training_schedule = TrainingSchedule(TRAINING_STAGES, args.num_scheduled_iterations, args.num_extension_iterations, cooldown_frac=0.55) + +def get_muon_momentum(step: int, muon_warmup_steps=300, muon_cooldown_steps=50, momentum_min=0.85, momentum_max=0.95): + # warmup phase: linearly increase momentum from min to max + # cooldown phase: linearly decrease momentum from max to min + momentum_cd_start = training_schedule.total_steps - muon_cooldown_steps + if step < muon_warmup_steps: + frac = step / muon_warmup_steps + momentum = momentum_min + frac * (momentum_max - momentum_min) + elif step > momentum_cd_start: + frac = (step - momentum_cd_start) / muon_cooldown_steps + momentum = momentum_max - frac * (momentum_max - momentum_min) + else: + momentum = momentum_max + return momentum + +class TrainingManager(): + """ + Manages the NorMuonAndAdam for all parameters with explicit ordering. + 1. Scalars are given higher momentum terms to smooth learning @ChrisJMcCormick + 2. Adam optimizers are only stepped on odd steps @classiclarryd + 3. Explicit scatter_order and work_order for communication scheduling (no backward hooks) + 4. Muon has a linear momentum warmup and cooldown schedule + 5. Learning rates follow a linear decay schedule + 6. 
Embed is tied to lm_head until split step (2/3 of training), then untied @classiclarryd + """ + def __init__(self, model): + self.model = model + self.block_size = 128 + + # - Ordering dictates when to launch reduce/reduce_scatter operations + # - "sharded" parameters use reduce_scatter/all_gather and "replicated" ones use all_reduce + # - lr_mul and wd_mul are per-parameter learning rate and weight decay multipliers + self.param_table = { + "attn": {"optim": "normuon", "comms": "sharded", "adam_betas": None}, + "mlp": {"optim": "normuon", "comms": "sharded", "adam_betas": None}, + "scalars": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 5.0, "wd_mul": 0.0}, + "value_embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.75, 0.95], "lr_mul": 75., "wd_mul": 5.0}, + "bigram_embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.75, 0.95], "lr_mul": 75., "wd_mul": 5.0}, + "smear_gate": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 0.01, "wd_mul": 0.0}, + "skip_gate": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 0.05, "wd_mul": 0.0}, + "attn_gate_bank": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99]}, + "ve_gate_bank": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99]}, + "x0_lambdas": {"optim": "adam", "comms": "replicated", "adam_betas": [0.65, 0.95], "lr_mul": 5.0, "wd_mul": 0.0}, + "lm_head": {"optim": "adam", "comms": "sharded", "adam_betas": [0.5, 0.95], "wd_mul": 150.}, + "embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.5, 0.95], "wd_mul": 150.}, + } + + # - Process smaller/faster params first while large reduces complete + # - lm_head must complete before embed sync (when tied) + self.work_order = [ + "scalars", "smear_gate", "skip_gate", "attn_gate_bank", "ve_gate_bank", "x0_lambdas", # Small, fast + "value_embed", "bigram_embed", # Medium + "lm_head", "embed", # lm_head must complete before embed 
sync (when tied) + "attn", "mlp", # Large, polar express - process last to maximize overlap + ] + + adam_defaults = dict( + lr=0.008, + eps=1e-10, + weight_decay=0.005, + ) + + normuon_defaults = dict( + lr=0.023, + momentum=0.95, + beta2=0.95, + weight_decay=1.2, + ) + + self.optimizer = NorMuonAndAdam( + model.named_parameters(), + param_table=self.param_table, + scatter_order=list(self.param_table.keys()), # Dict order defines scatter priority + work_order=self.work_order, + adam_defaults=adam_defaults, + normuon_defaults=normuon_defaults, + ) + + # Split embed from lm_head at 2/3 of training (on an odd step so Adam updates) + self.split_step = training_schedule.split_step + + self.reset() + + def apply_final_ws_ext(self): + self.ws_long = training_schedule.ws_post_yarn_ext + + def get_forward_args(self): + return ForwardScheduleConfig( + mtp_weights = self.mtp_weights, + ws_short = self.ws_short * self.block_size, + ws_long = self.ws_long * self.block_size + ) + + def _is_adam_step(self, step: int): + """Adam params are only updated on odd steps.""" + return step % 2 == 1 + + def get_transition_steps(self): + return [start for start, _ in training_schedule.boundaries[1:]] + + def advance_schedule(self, step: int): + stage, _ = training_schedule.lookup(step) + self.ws_short, new_ws_long = stage.window_sizes + if new_ws_long != self.ws_long: + self.model.yarn.apply(self.ws_long * self.block_size, new_ws_long * self.block_size) + self.model.yarn_paired_head.apply(self.ws_long * self.block_size, new_ws_long * self.block_size) + + new_batch_size = stage.batch_size + if new_batch_size != self.batch_size: + self.train_loader_send_args = (new_batch_size, args.train_max_seq_len, grad_accum_steps) + self.batch_size = new_batch_size + else: + self.train_loader_send_args = None + + self.ws_long = new_ws_long + self.mtp_weights = training_schedule.mtp_weights[step] + + def step_optimizers(self, step: int): + step_lr = training_schedule.get_lr(step) + muon_momentum = 
get_muon_momentum(step) + do_adam = self._is_adam_step(step) + + # Update learning rates and momentum for all params + for param, p_cfg in self.optimizer.param_cfgs.items(): + p_cfg.lr = p_cfg.initial_lr * step_lr + if p_cfg.optim == "normuon": + p_cfg.momentum = muon_momentum + + # Step optimizer with do_adam flag + self.optimizer.step(do_adam=do_adam) + + # At split step: copy lm_head optimizer state to embed and mark as split + if step == self.split_step: + self.optimizer.copy_lm_state_to_embed() + + def reset(self, state=None): + if state is not None: + self.optimizer.load_state_dict(state) + + # Reset NorMuon momentum buffers and split_embed state + self.optimizer.reset() + + stage, _ = training_schedule.lookup(0) + self.ws_short, self.ws_long = stage.window_sizes + self.batch_size = stage.batch_size + self.model.yarn.reset() + self.model.yarn_paired_head.reset() + + def get_state(self): + return copy.deepcopy(self.optimizer.state_dict()) + +# ----------------------------------------------------------------------------- +# int main + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=11, + num_heads=6, + 
head_dim=128, + model_dim=768, + max_seq_len=args.val_batch_size // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.weight.data = m.weight.data.bfloat16() +model.attn_gate_bank.data = model.attn_gate_bank.data.bfloat16() +model.ve_gate_bank.data = model.ve_gate_bank.data.bfloat16() +model.attn_bank.data = model.attn_bank.data.bfloat16() +model.mlp_bank.data = model.mlp_bank.data.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) +training_manager = TrainingManager(model) + +######################################## +# Warmup kernels # +######################################## +print0("Compiling model and warming up kernels (~7 minutes on first execution)", console=True) +# Warmup the training kernels, then re-initialize the state so we aren't cheating +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizer=training_manager.get_state()) # save the initial state +train_loader = distributed_data_generator(args.train_files, TRAINING_STAGES[0].batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + +transition_steps = training_manager.get_transition_steps() +# first few steps plus transitions +warmup_steps = sorted({0, 1, 2} | set(s + offset for s in transition_steps for offset in [-1, 0, 1] if s + offset >= 0)) +print0(f"Sampling steps {warmup_steps} for warmup", console=True) +for step in warmup_steps: + training_manager.advance_schedule(step) + model.eval() + with torch.no_grad(): + inputs, targets, cum_seqlens = next(val_loader) + model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) + model.train() + for idx in range(grad_accum_steps): + send_args = training_manager.train_loader_send_args + inputs, 
targets, cum_seqlens = train_loader.send(send_args) + (model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) * grad_scale).backward() + training_manager.step_optimizers(step) +print0("Resetting Model", console=True) +model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +training_manager.reset(initial_state["optimizer"]) +del val_loader, train_loader, initial_state +model.train() + +######################################## +# Training and validation # +######################################## +train_loader = distributed_data_generator(args.train_files, TRAINING_STAGES[0].batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) + +gc.collect() + +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = training_schedule.total_steps +for step in range(train_steps + 1): + last_step = (step == train_steps) + training_manager.advance_schedule(step) + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + if last_step: + training_manager.apply_final_ws_ext() + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) + val_loss /= val_steps + del val_loader + dist.reduce(val_loss, 0, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", 
console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizer=training_manager.get_state()) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for idx in range(grad_accum_steps): + inputs, targets, cum_seqlens = train_loader.send(training_manager.train_loader_send_args) + (model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) * grad_scale).backward() + training_manager.step_optimizers(step) + + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + + +---------------------------------------- +# triton_kernels.py +---------------------------------------- + +import torch +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * 
num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.jit +def XXT_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = 
m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def XXT(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + # Hardcoded configs based on H100 autotuning + if K == 768: + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 64 + num_stages, num_warps = 4, 4 + else: + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 64, 128, 128 + num_stages, num_warps = 4, 4 + + grid = (batch_size * triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(M, BLOCK_SIZE_N),) + XXT_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + LOWER_UPPER=1, + num_stages=num_stages, + num_warps=num_warps, + ) + return out + +@triton.jit +def ba_plus_cAA_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: 
tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from XXT_kernel, but also loads and adds a block of A + # Performance is slightly slower than XXT_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = 
C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ba_plus_cAA(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + # Hardcoded config based on H100 autotuning (M=768) + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 64 + num_stages, num_warps = 4, 4 + + grid = (batch_size * triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(M, BLOCK_SIZE_N),) + ba_plus_cAA_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + LOWER_UPPER=1, + num_stages=num_stages, + num_warps=num_warps, + ) + return out + +# ----------------------------------------------------------------------------- +# Triton kernel for MLP: relu(x @ W1.T)^2, by @andrewbriand, @jrauvola + +@triton.jit +def linear_relu_square_kernel(a_desc, b_desc, c_desc, 
aux_desc, + M, N, K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, + ): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tile_id_c = start_pid - NUM_SMS + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + + tile_id_c += NUM_SMS + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am_c = pid_m * BLOCK_SIZE_M + offs_bn_c = pid_n * BLOCK_SIZE_N + + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + + c0 = acc0.to(dtype) + if not FORWARD: + c0_pre = aux_desc.load([offs_am_c, offs_bn_c]) + c0 = 2 * c0 * tl.where(c0_pre > 0, c0_pre, 0) + + c_desc.store([offs_am_c, offs_bn_c], c0) + + if FORWARD: + c0_post = tl.maximum(c0, 0) + c0_post = c0_post * c0_post + aux_desc.store([offs_am_c, offs_bn_c], c0_post) + + c1 = acc1.to(dtype) + if not FORWARD: + c1_pre = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + c1 = 2 * c1 * tl.where(c1_pre > 0, c1_pre, 0) + + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + + if FORWARD: + c1_post = tl.maximum(c1, 0) + c1_post = c1_post * c1_post + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1_post) + + +def linear_relu_square(a, b, aux=None): + M, K = a.shape + N, K = b.shape + dtype = 
a.dtype + + c = torch.empty((M, N), device=a.device, dtype=dtype) + + FORWARD = False + if aux is None: + FORWARD = True + aux = torch.empty((M, N), device=a.device, dtype=dtype) + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 256 + BLOCK_SIZE_K = 64 + num_stages = 4 if FORWARD else 3 + num_warps = 8 + + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + + def grid(META): + return (min( + NUM_SMS, + triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), + ), ) + + linear_relu_square_kernel[grid]( + a_desc, b_desc, c_desc, aux_desc, + M, N, K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=1, + NUM_SMS=NUM_SMS, + FORWARD=FORWARD, + num_stages=num_stages, + num_warps=num_warps + ) + + if FORWARD: + return c, aux + else: + return c + +class FusedLinearReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, W1, W2): + pre, post = linear_relu_square(x.view((-1, x.shape[-1])), W1) + x3 = post @ W2 + ctx.save_for_backward(x, W1, W2, pre, post) + return x3.view(x.shape) + + @staticmethod + def backward(ctx, grad_output): + x, W1, W2, pre, post = ctx.saved_tensors + dW2 = post.T @ grad_output + dpre = linear_relu_square(grad_output.view((-1, grad_output.shape[-1])), W2, aux=pre) + dW1 = dpre.T @ x + dx = dpre @ W1 + return dx.view(x.shape), dW1, dW2 + +# ----------------------------------------------------------------------------- +# Fused Softcapped Cross Entropy + + +@triton.jit +def fused_softcapped_entropy_fwd_kernel( + logits_ptr, losses_ptr, lse_ptr, targets_ptr, mtp_weights_ptr, + stride_logits_n, stride_logits_v, + n_rows, n_cols, n_predict, + A, B, C, + 
BLOCK_SIZE: tl.constexpr +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + + max_val = -float('inf') + sum_exp = 0.0 + + for off in range(0, n_cols, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + val = tl.load(logits_row_ptr + cols, mask=mask, other=-float('inf')).to(tl.float32) + z = A * tl.sigmoid((val + B) / C) + z = tl.where(mask, z, -float('inf')) + curr_max = tl.max(z, axis=0) + new_max = tl.maximum(max_val, curr_max) + sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(z - new_max), axis=0) + max_val = new_max + + lse = max_val + tl.log(sum_exp) + tl.store(lse_ptr + row_idx, lse) + + total_loss = 0.0 + for k in range(n_predict): + target_idx = row_idx + k + if target_idx < n_rows: + weight = tl.load(mtp_weights_ptr + k) + if weight > 0: + target = tl.load(targets_ptr + target_idx).to(tl.int32) + if target >= 0 and target < n_cols: + val_target = tl.load(logits_row_ptr + target).to(tl.float32) + z_target = A * tl.sigmoid((val_target + B) / C) + total_loss += weight * (lse - z_target) + + tl.store(losses_ptr + row_idx, total_loss) + +@triton.jit +def fused_softcapped_entropy_bwd_kernel( + grad_input_ptr, grad_output_ptr, lse_ptr, logits_ptr, targets_ptr, mtp_weights_ptr, + stride_logits_n, stride_logits_v, stride_grad_n, stride_grad_v, + n_rows, n_cols, n_predict, + A, B, C, + BLOCK_SIZE: tl.constexpr +): + row_idx = tl.program_id(0).to(tl.int64) + + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + grad_row_ptr = grad_input_ptr + row_idx * stride_grad_n + + lse = tl.load(lse_ptr + row_idx) + grad_loss = tl.load(grad_output_ptr + row_idx) + + S_w = 0.0 + for k in range(n_predict): + if row_idx + k < n_rows: + S_w += tl.load(mtp_weights_ptr + k) + + for off in range(0, n_cols, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + val = tl.load(logits_row_ptr + cols, mask=mask, other=0.0).to(tl.float32) + u = (val + B) / C + sigmoid_u 
= tl.sigmoid(u) + z = A * sigmoid_u + p = tl.exp(z - lse) + + term1 = S_w * p + term2 = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for k in range(n_predict): + if row_idx + k < n_rows: + target = tl.load(targets_ptr + row_idx + k).to(tl.int32) + weight = tl.load(mtp_weights_ptr + k) + term2 += tl.where(cols == target, weight, 0.0) + + grad_z = grad_loss * (term1 - term2) + dz_dx = (1.0 / C) * z * (1.0 - sigmoid_u) + grad_x = grad_z * dz_dx + tl.store(grad_row_ptr + cols, grad_x.to(tl.bfloat16), mask=mask) + +class FusedSoftcappedCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, logits, targets, mtp_weights, A=23.0, B=5.0, C=7.5): + n_rows, n_cols = logits.shape + if mtp_weights is None: + mtp_weights = torch.tensor([1.0], device=logits.device, dtype=torch.float32) + n_predict = mtp_weights.shape[0] + + losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + lse = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + + logits = logits.contiguous() + targets = targets.contiguous() + mtp_weights = mtp_weights.contiguous() + + grid = (n_rows,) + fused_softcapped_entropy_fwd_kernel[grid]( + logits, losses, lse, targets, mtp_weights, + logits.stride(0), logits.stride(1), + n_rows, n_cols, n_predict, + A, B, C, + BLOCK_SIZE=1024, + num_warps=8, + num_stages=4 + ) + + ctx.save_for_backward(logits, targets, mtp_weights, lse) + ctx.params = (A, B, C) + return losses + + @staticmethod + def backward(ctx, grad_output): + logits, targets, mtp_weights, lse = ctx.saved_tensors + A, B, C = ctx.params + n_rows, n_cols = logits.shape + n_predict = mtp_weights.shape[0] + + grad_input = torch.empty((n_rows, n_cols), dtype=torch.bfloat16, device=logits.device) + grad_output = grad_output.contiguous() + + grid = (n_rows,) + fused_softcapped_entropy_bwd_kernel[grid]( + grad_input, grad_output, lse, logits, targets, mtp_weights, + logits.stride(0), logits.stride(1), grad_input.stride(0), grad_input.stride(1), + n_rows, n_cols, 
n_predict, + A, B, C, + BLOCK_SIZE=1024, + num_warps=8, + num_stages=4 + ) + return grad_input, None, None, None, None, None + +==================================================================================================== +Running Python 3.12.7 (main, Jan 31 2026, 04:21:49) [GCC 13.2.0] +Running PyTorch 2.10.0.dev20251210+cu126 compiled for CUDA 12.6 +Running Triton version 3.6.0 +Sun Feb 1 06:06:54 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:71:00.0 Off | 0 | +| N/A 39C P0 125W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:79:00.0 Off | 0 | +| N/A 34C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:7F:00.0 Off | 0 | +| N/A 32C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | +| N/A 39C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 37C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 15516 C /usr/local/bin/python 1510MiB | +| 1 N/A N/A 15517 C /usr/local/bin/python 1510MiB | +| 2 N/A N/A 15518 C /usr/local/bin/python 1510MiB | +| 3 N/A N/A 15519 C /usr/local/bin/python 1510MiB | +| 4 N/A N/A 15520 C /usr/local/bin/python 1510MiB | +| 5 N/A N/A 15521 C /usr/local/bin/python 1510MiB | +| 6 N/A N/A 15522 C /usr/local/bin/python 1510MiB | +| 7 N/A N/A 15523 C /usr/local/bin/python 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +Compiling model and warming up kernels (~7 minutes on first execution) +Sampling steps [0, 1, 2, 504, 505, 506, 1009, 1010, 1011, 1514, 1515, 1516] for warmup +Resetting Model +step:0/1555 val_loss:10.8306 train_time:0ms step_avg:0.03ms +step:1/1555 train_time:94ms step_avg:93.96ms +step:2/1555 train_time:121ms 
step_avg:60.66ms +step:3/1555 train_time:140ms step_avg:46.72ms +step:4/1555 train_time:159ms step_avg:39.65ms +step:5/1555 train_time:183ms step_avg:36.60ms +step:6/1555 train_time:220ms step_avg:36.68ms +step:7/1555 train_time:251ms step_avg:35.85ms +step:8/1555 train_time:288ms step_avg:36.05ms +step:9/1555 train_time:319ms step_avg:35.45ms +step:10/1555 train_time:356ms step_avg:35.65ms +step:11/1555 train_time:388ms step_avg:35.25ms +step:12/1555 train_time:425ms step_avg:35.43ms +step:13/1555 train_time:456ms step_avg:35.10ms +step:14/1555 train_time:494ms step_avg:35.26ms +step:15/1555 train_time:525ms step_avg:34.97ms +step:16/1555 train_time:562ms step_avg:35.15ms +step:17/1555 train_time:593ms step_avg:34.91ms +step:18/1555 train_time:631ms step_avg:35.04ms +step:19/1555 train_time:662ms step_avg:34.83ms +step:20/1555 train_time:700ms step_avg:34.99ms +step:21/1555 train_time:731ms step_avg:34.79ms +step:22/1555 train_time:768ms step_avg:34.91ms +step:23/1555 train_time:799ms step_avg:34.75ms +step:24/1555 train_time:837ms step_avg:34.86ms +step:25/1555 train_time:868ms step_avg:34.71ms +step:26/1555 train_time:905ms step_avg:34.82ms +step:27/1555 train_time:937ms step_avg:34.70ms +step:28/1555 train_time:974ms step_avg:34.79ms +step:29/1555 train_time:1006ms step_avg:34.68ms +step:30/1555 train_time:1045ms step_avg:34.82ms +step:31/1555 train_time:1076ms step_avg:34.71ms +step:32/1555 train_time:1113ms step_avg:34.79ms +step:33/1555 train_time:1145ms step_avg:34.70ms +step:34/1555 train_time:1183ms step_avg:34.80ms +step:35/1555 train_time:1214ms step_avg:34.69ms +step:36/1555 train_time:1252ms step_avg:34.77ms +step:37/1555 train_time:1284ms step_avg:34.69ms +step:38/1555 train_time:1321ms step_avg:34.77ms +step:39/1555 train_time:1353ms step_avg:34.68ms +step:40/1555 train_time:1390ms step_avg:34.75ms +step:41/1555 train_time:1421ms step_avg:34.66ms +step:42/1555 train_time:1458ms step_avg:34.72ms +step:43/1555 train_time:1489ms step_avg:34.64ms 
+step:44/1555 train_time:1528ms step_avg:34.72ms +step:45/1555 train_time:1558ms step_avg:34.62ms +step:46/1555 train_time:1595ms step_avg:34.68ms +step:47/1555 train_time:1626ms step_avg:34.60ms +step:48/1555 train_time:1664ms step_avg:34.66ms +step:49/1555 train_time:1695ms step_avg:34.58ms +step:50/1555 train_time:1732ms step_avg:34.64ms +step:51/1555 train_time:1763ms step_avg:34.56ms +step:52/1555 train_time:1800ms step_avg:34.62ms +step:53/1555 train_time:1832ms step_avg:34.56ms +step:54/1555 train_time:1869ms step_avg:34.61ms +step:55/1555 train_time:1900ms step_avg:34.54ms +step:56/1555 train_time:1938ms step_avg:34.60ms +step:57/1555 train_time:1969ms step_avg:34.54ms +step:58/1555 train_time:2007ms step_avg:34.60ms +step:59/1555 train_time:2038ms step_avg:34.54ms +step:60/1555 train_time:2075ms step_avg:34.59ms +step:61/1555 train_time:2107ms step_avg:34.54ms +step:62/1555 train_time:2145ms step_avg:34.59ms +step:63/1555 train_time:2176ms step_avg:34.54ms +step:64/1555 train_time:2214ms step_avg:34.59ms +step:65/1555 train_time:2245ms step_avg:34.54ms +step:66/1555 train_time:2283ms step_avg:34.60ms +step:67/1555 train_time:2315ms step_avg:34.55ms +step:68/1555 train_time:2352ms step_avg:34.59ms +step:69/1555 train_time:2384ms step_avg:34.55ms +step:70/1555 train_time:2422ms step_avg:34.60ms +step:71/1555 train_time:2453ms step_avg:34.55ms +step:72/1555 train_time:2490ms step_avg:34.59ms +step:73/1555 train_time:2522ms step_avg:34.55ms +step:74/1555 train_time:2560ms step_avg:34.59ms +step:75/1555 train_time:2591ms step_avg:34.55ms +step:76/1555 train_time:2629ms step_avg:34.59ms +step:77/1555 train_time:2660ms step_avg:34.54ms +step:78/1555 train_time:2697ms step_avg:34.58ms +step:79/1555 train_time:2728ms step_avg:34.54ms +step:80/1555 train_time:2766ms step_avg:34.57ms +step:81/1555 train_time:2797ms step_avg:34.53ms +step:82/1555 train_time:2834ms step_avg:34.56ms +step:83/1555 train_time:2865ms step_avg:34.52ms +step:84/1555 train_time:2903ms 
step_avg:34.56ms +step:85/1555 train_time:2935ms step_avg:34.52ms +step:86/1555 train_time:2972ms step_avg:34.56ms +step:87/1555 train_time:3004ms step_avg:34.53ms +step:88/1555 train_time:3042ms step_avg:34.57ms +step:89/1555 train_time:3073ms step_avg:34.53ms +step:90/1555 train_time:3110ms step_avg:34.56ms +step:91/1555 train_time:3141ms step_avg:34.52ms +step:92/1555 train_time:3179ms step_avg:34.56ms +step:93/1555 train_time:3210ms step_avg:34.52ms +step:94/1555 train_time:3247ms step_avg:34.55ms +step:95/1555 train_time:3278ms step_avg:34.51ms +step:96/1555 train_time:3315ms step_avg:34.53ms +step:97/1555 train_time:3347ms step_avg:34.50ms +step:98/1555 train_time:3384ms step_avg:34.53ms +step:99/1555 train_time:3415ms step_avg:34.50ms +step:100/1555 train_time:3452ms step_avg:34.52ms +step:101/1555 train_time:3483ms step_avg:34.49ms +step:102/1555 train_time:3521ms step_avg:34.52ms +step:103/1555 train_time:3552ms step_avg:34.48ms +step:104/1555 train_time:3589ms step_avg:34.51ms +step:105/1555 train_time:3620ms step_avg:34.48ms +step:106/1555 train_time:3658ms step_avg:34.51ms +step:107/1555 train_time:3689ms step_avg:34.47ms +step:108/1555 train_time:3726ms step_avg:34.50ms +step:109/1555 train_time:3757ms step_avg:34.47ms +step:110/1555 train_time:3794ms step_avg:34.49ms +step:111/1555 train_time:3826ms step_avg:34.46ms +step:112/1555 train_time:3863ms step_avg:34.49ms +step:113/1555 train_time:3894ms step_avg:34.46ms +step:114/1555 train_time:3931ms step_avg:34.48ms +step:115/1555 train_time:3962ms step_avg:34.46ms +step:116/1555 train_time:4000ms step_avg:34.48ms +step:117/1555 train_time:4031ms step_avg:34.45ms +step:118/1555 train_time:4068ms step_avg:34.48ms +step:119/1555 train_time:4099ms step_avg:34.45ms +step:120/1555 train_time:4137ms step_avg:34.47ms +step:121/1555 train_time:4168ms step_avg:34.45ms +step:122/1555 train_time:4206ms step_avg:34.48ms +step:123/1555 train_time:4236ms step_avg:34.44ms +step:124/1555 train_time:4274ms 
step_avg:34.47ms +step:125/1555 train_time:4305ms step_avg:34.44ms +step:126/1555 train_time:4343ms step_avg:34.47ms +step:127/1555 train_time:4374ms step_avg:34.44ms +step:128/1555 train_time:4411ms step_avg:34.46ms +step:129/1555 train_time:4443ms step_avg:34.44ms +step:130/1555 train_time:4480ms step_avg:34.46ms +step:131/1555 train_time:4511ms step_avg:34.44ms +step:132/1555 train_time:4549ms step_avg:34.46ms +step:133/1555 train_time:4581ms step_avg:34.44ms +step:134/1555 train_time:4619ms step_avg:34.47ms +step:135/1555 train_time:4650ms step_avg:34.44ms +step:136/1555 train_time:4688ms step_avg:34.47ms +step:137/1555 train_time:4719ms step_avg:34.44ms +step:138/1555 train_time:4756ms step_avg:34.46ms +step:139/1555 train_time:4787ms step_avg:34.44ms +step:140/1555 train_time:4824ms step_avg:34.46ms +step:141/1555 train_time:4855ms step_avg:34.43ms +step:142/1555 train_time:4893ms step_avg:34.45ms +step:143/1555 train_time:4924ms step_avg:34.43ms +step:144/1555 train_time:4961ms step_avg:34.45ms +step:145/1555 train_time:4993ms step_avg:34.43ms +step:146/1555 train_time:5030ms step_avg:34.45ms +step:147/1555 train_time:5061ms step_avg:34.43ms +step:148/1555 train_time:5099ms step_avg:34.45ms +step:149/1555 train_time:5130ms step_avg:34.43ms +step:150/1555 train_time:5167ms step_avg:34.45ms +step:151/1555 train_time:5198ms step_avg:34.43ms +step:152/1555 train_time:5236ms step_avg:34.45ms +step:153/1555 train_time:5267ms step_avg:34.42ms +step:154/1555 train_time:5304ms step_avg:34.44ms +step:155/1555 train_time:5336ms step_avg:34.43ms +step:156/1555 train_time:5373ms step_avg:34.45ms +step:157/1555 train_time:5405ms step_avg:34.43ms +step:158/1555 train_time:5442ms step_avg:34.45ms +step:159/1555 train_time:5474ms step_avg:34.43ms +step:160/1555 train_time:5511ms step_avg:34.44ms +step:161/1555 train_time:5543ms step_avg:34.43ms +step:162/1555 train_time:5580ms step_avg:34.45ms +step:163/1555 train_time:5612ms step_avg:34.43ms +step:164/1555 train_time:5649ms 
step_avg:34.44ms +step:165/1555 train_time:5680ms step_avg:34.42ms +step:166/1555 train_time:5718ms step_avg:34.44ms +step:167/1555 train_time:5749ms step_avg:34.42ms +step:168/1555 train_time:5786ms step_avg:34.44ms +step:169/1555 train_time:5817ms step_avg:34.42ms +step:170/1555 train_time:5854ms step_avg:34.44ms +step:171/1555 train_time:5885ms step_avg:34.42ms +step:172/1555 train_time:5923ms step_avg:34.44ms +step:173/1555 train_time:5954ms step_avg:34.41ms +step:174/1555 train_time:5991ms step_avg:34.43ms +step:175/1555 train_time:6022ms step_avg:34.41ms +step:176/1555 train_time:6060ms step_avg:34.43ms +step:177/1555 train_time:6091ms step_avg:34.41ms +step:178/1555 train_time:6129ms step_avg:34.43ms +step:179/1555 train_time:6160ms step_avg:34.41ms +step:180/1555 train_time:6197ms step_avg:34.43ms +step:181/1555 train_time:6230ms step_avg:34.42ms +step:182/1555 train_time:6266ms step_avg:34.43ms +step:183/1555 train_time:6297ms step_avg:34.41ms +step:184/1555 train_time:6334ms step_avg:34.42ms +step:185/1555 train_time:6366ms step_avg:34.41ms +step:186/1555 train_time:6403ms step_avg:34.43ms +step:187/1555 train_time:6434ms step_avg:34.41ms +step:188/1555 train_time:6472ms step_avg:34.42ms +step:189/1555 train_time:6503ms step_avg:34.41ms +step:190/1555 train_time:6540ms step_avg:34.42ms +step:191/1555 train_time:6571ms step_avg:34.40ms +step:192/1555 train_time:6609ms step_avg:34.42ms +step:193/1555 train_time:6640ms step_avg:34.41ms +step:194/1555 train_time:6678ms step_avg:34.42ms +step:195/1555 train_time:6709ms step_avg:34.41ms +step:196/1555 train_time:6747ms step_avg:34.42ms +step:197/1555 train_time:6778ms step_avg:34.41ms +step:198/1555 train_time:6815ms step_avg:34.42ms +step:199/1555 train_time:6846ms step_avg:34.40ms +step:200/1555 train_time:6883ms step_avg:34.42ms +step:201/1555 train_time:6914ms step_avg:34.40ms +step:202/1555 train_time:6951ms step_avg:34.41ms +step:203/1555 train_time:6982ms step_avg:34.40ms +step:204/1555 train_time:7020ms 
step_avg:34.41ms +step:205/1555 train_time:7051ms step_avg:34.39ms +step:206/1555 train_time:7088ms step_avg:34.41ms +step:207/1555 train_time:7119ms step_avg:34.39ms +step:208/1555 train_time:7157ms step_avg:34.41ms +step:209/1555 train_time:7188ms step_avg:34.39ms +step:210/1555 train_time:7225ms step_avg:34.41ms +step:211/1555 train_time:7256ms step_avg:34.39ms +step:212/1555 train_time:7293ms step_avg:34.40ms +step:213/1555 train_time:7325ms step_avg:34.39ms +step:214/1555 train_time:7362ms step_avg:34.40ms +step:215/1555 train_time:7393ms step_avg:34.39ms +step:216/1555 train_time:7431ms step_avg:34.40ms +step:217/1555 train_time:7462ms step_avg:34.39ms +step:218/1555 train_time:7499ms step_avg:34.40ms +step:219/1555 train_time:7530ms step_avg:34.39ms +step:220/1555 train_time:7568ms step_avg:34.40ms +step:221/1555 train_time:7599ms step_avg:34.38ms +step:222/1555 train_time:7636ms step_avg:34.40ms +step:223/1555 train_time:7668ms step_avg:34.38ms +step:224/1555 train_time:7705ms step_avg:34.40ms +step:225/1555 train_time:7736ms step_avg:34.38ms +step:226/1555 train_time:7774ms step_avg:34.40ms +step:227/1555 train_time:7805ms step_avg:34.38ms +step:228/1555 train_time:7843ms step_avg:34.40ms +step:229/1555 train_time:7874ms step_avg:34.38ms +step:230/1555 train_time:7911ms step_avg:34.40ms +step:231/1555 train_time:7942ms step_avg:34.38ms +step:232/1555 train_time:7980ms step_avg:34.39ms +step:233/1555 train_time:8011ms step_avg:34.38ms +step:234/1555 train_time:8048ms step_avg:34.39ms +step:235/1555 train_time:8079ms step_avg:34.38ms +step:236/1555 train_time:8117ms step_avg:34.39ms +step:237/1555 train_time:8148ms step_avg:34.38ms +step:238/1555 train_time:8185ms step_avg:34.39ms +step:239/1555 train_time:8216ms step_avg:34.38ms +step:240/1555 train_time:8254ms step_avg:34.39ms +step:241/1555 train_time:8286ms step_avg:34.38ms +step:242/1555 train_time:8324ms step_avg:34.40ms +step:243/1555 train_time:8355ms step_avg:34.38ms +step:244/1555 train_time:8392ms 
step_avg:34.39ms +step:245/1555 train_time:8423ms step_avg:34.38ms +step:246/1555 train_time:8461ms step_avg:34.39ms +step:247/1555 train_time:8492ms step_avg:34.38ms +step:248/1555 train_time:8529ms step_avg:34.39ms +step:249/1555 train_time:8560ms step_avg:34.38ms +step:250/1555 train_time:8598ms step_avg:34.39ms +step:250/1555 val_loss:4.5545 train_time:8648ms step_avg:34.59ms +step:251/1555 train_time:8668ms step_avg:34.53ms +step:252/1555 train_time:8687ms step_avg:34.47ms +step:253/1555 train_time:8704ms step_avg:34.40ms +step:254/1555 train_time:8737ms step_avg:34.40ms +step:255/1555 train_time:8771ms step_avg:34.39ms +step:256/1555 train_time:8809ms step_avg:34.41ms +step:257/1555 train_time:8840ms step_avg:34.40ms +step:258/1555 train_time:8878ms step_avg:34.41ms +step:259/1555 train_time:8910ms step_avg:34.40ms +step:260/1555 train_time:8948ms step_avg:34.41ms +step:261/1555 train_time:8979ms step_avg:34.40ms +step:262/1555 train_time:9016ms step_avg:34.41ms +step:263/1555 train_time:9047ms step_avg:34.40ms +step:264/1555 train_time:9084ms step_avg:34.41ms +step:265/1555 train_time:9115ms step_avg:34.40ms +step:266/1555 train_time:9153ms step_avg:34.41ms +step:267/1555 train_time:9184ms step_avg:34.40ms +step:268/1555 train_time:9221ms step_avg:34.41ms +step:269/1555 train_time:9252ms step_avg:34.39ms +step:270/1555 train_time:9289ms step_avg:34.40ms +step:271/1555 train_time:9320ms step_avg:34.39ms +step:272/1555 train_time:9357ms step_avg:34.40ms +step:273/1555 train_time:9388ms step_avg:34.39ms +step:274/1555 train_time:9425ms step_avg:34.40ms +step:275/1555 train_time:9456ms step_avg:34.39ms +step:276/1555 train_time:9493ms step_avg:34.40ms +step:277/1555 train_time:9524ms step_avg:34.38ms +step:278/1555 train_time:9561ms step_avg:34.39ms +step:279/1555 train_time:9592ms step_avg:34.38ms +step:280/1555 train_time:9630ms step_avg:34.39ms +step:281/1555 train_time:9661ms step_avg:34.38ms +step:282/1555 train_time:9698ms step_avg:34.39ms +step:283/1555 
train_time:9729ms step_avg:34.38ms +step:284/1555 train_time:9767ms step_avg:34.39ms +step:285/1555 train_time:9798ms step_avg:34.38ms +step:286/1555 train_time:9836ms step_avg:34.39ms +step:287/1555 train_time:9867ms step_avg:34.38ms +step:288/1555 train_time:9905ms step_avg:34.39ms +step:289/1555 train_time:9936ms step_avg:34.38ms +step:290/1555 train_time:9974ms step_avg:34.39ms +step:291/1555 train_time:10005ms step_avg:34.38ms +step:292/1555 train_time:10042ms step_avg:34.39ms +step:293/1555 train_time:10073ms step_avg:34.38ms +step:294/1555 train_time:10111ms step_avg:34.39ms +step:295/1555 train_time:10142ms step_avg:34.38ms +step:296/1555 train_time:10180ms step_avg:34.39ms +step:297/1555 train_time:10211ms step_avg:34.38ms +step:298/1555 train_time:10248ms step_avg:34.39ms +step:299/1555 train_time:10279ms step_avg:34.38ms +step:300/1555 train_time:10317ms step_avg:34.39ms +step:301/1555 train_time:10347ms step_avg:34.38ms +step:302/1555 train_time:10385ms step_avg:34.39ms +step:303/1555 train_time:10416ms step_avg:34.38ms +step:304/1555 train_time:10454ms step_avg:34.39ms +step:305/1555 train_time:10484ms step_avg:34.38ms +step:306/1555 train_time:10522ms step_avg:34.38ms +step:307/1555 train_time:10553ms step_avg:34.37ms +step:308/1555 train_time:10591ms step_avg:34.39ms +step:309/1555 train_time:10622ms step_avg:34.37ms +step:310/1555 train_time:10659ms step_avg:34.38ms +step:311/1555 train_time:10690ms step_avg:34.37ms +step:312/1555 train_time:10727ms step_avg:34.38ms +step:313/1555 train_time:10758ms step_avg:34.37ms +step:314/1555 train_time:10796ms step_avg:34.38ms +step:315/1555 train_time:10827ms step_avg:34.37ms +step:316/1555 train_time:10864ms step_avg:34.38ms +step:317/1555 train_time:10895ms step_avg:34.37ms +step:318/1555 train_time:10933ms step_avg:34.38ms +step:319/1555 train_time:10964ms step_avg:34.37ms +step:320/1555 train_time:11001ms step_avg:34.38ms +step:321/1555 train_time:11032ms step_avg:34.37ms +step:322/1555 train_time:11070ms 
step_avg:34.38ms +step:323/1555 train_time:11101ms step_avg:34.37ms +step:324/1555 train_time:11139ms step_avg:34.38ms +step:325/1555 train_time:11170ms step_avg:34.37ms +step:326/1555 train_time:11208ms step_avg:34.38ms +step:327/1555 train_time:11239ms step_avg:34.37ms +step:328/1555 train_time:11277ms step_avg:34.38ms +step:329/1555 train_time:11308ms step_avg:34.37ms +step:330/1555 train_time:11345ms step_avg:34.38ms +step:331/1555 train_time:11376ms step_avg:34.37ms +step:332/1555 train_time:11414ms step_avg:34.38ms +step:333/1555 train_time:11445ms step_avg:34.37ms +step:334/1555 train_time:11482ms step_avg:34.38ms +step:335/1555 train_time:11513ms step_avg:34.37ms +step:336/1555 train_time:11551ms step_avg:34.38ms +step:337/1555 train_time:11582ms step_avg:34.37ms +step:338/1555 train_time:11619ms step_avg:34.38ms +step:339/1555 train_time:11650ms step_avg:34.37ms +step:340/1555 train_time:11687ms step_avg:34.37ms +step:341/1555 train_time:11718ms step_avg:34.36ms +step:342/1555 train_time:11756ms step_avg:34.37ms +step:343/1555 train_time:11786ms step_avg:34.36ms +step:344/1555 train_time:11823ms step_avg:34.37ms +step:345/1555 train_time:11854ms step_avg:34.36ms +step:346/1555 train_time:11892ms step_avg:34.37ms +step:347/1555 train_time:11923ms step_avg:34.36ms +step:348/1555 train_time:11960ms step_avg:34.37ms +step:349/1555 train_time:11992ms step_avg:34.36ms +step:350/1555 train_time:12029ms step_avg:34.37ms +step:351/1555 train_time:12060ms step_avg:34.36ms +step:352/1555 train_time:12098ms step_avg:34.37ms +step:353/1555 train_time:12129ms step_avg:34.36ms +step:354/1555 train_time:12166ms step_avg:34.37ms +step:355/1555 train_time:12197ms step_avg:34.36ms +step:356/1555 train_time:12235ms step_avg:34.37ms +step:357/1555 train_time:12266ms step_avg:34.36ms +step:358/1555 train_time:12304ms step_avg:34.37ms +step:359/1555 train_time:12334ms step_avg:34.36ms +step:360/1555 train_time:12372ms step_avg:34.37ms +step:361/1555 train_time:12403ms 
step_avg:34.36ms +step:362/1555 train_time:12440ms step_avg:34.37ms +step:363/1555 train_time:12472ms step_avg:34.36ms +step:364/1555 train_time:12509ms step_avg:34.37ms +step:365/1555 train_time:12540ms step_avg:34.36ms +step:366/1555 train_time:12578ms step_avg:34.37ms +step:367/1555 train_time:12609ms step_avg:34.36ms +step:368/1555 train_time:12646ms step_avg:34.36ms +step:369/1555 train_time:12677ms step_avg:34.36ms +step:370/1555 train_time:12715ms step_avg:34.36ms +step:371/1555 train_time:12746ms step_avg:34.36ms +step:372/1555 train_time:12783ms step_avg:34.36ms +step:373/1555 train_time:12815ms step_avg:34.36ms +step:374/1555 train_time:12853ms step_avg:34.37ms +step:375/1555 train_time:12884ms step_avg:34.36ms +step:376/1555 train_time:12921ms step_avg:34.36ms +step:377/1555 train_time:12952ms step_avg:34.36ms +step:378/1555 train_time:12989ms step_avg:34.36ms +step:379/1555 train_time:13020ms step_avg:34.35ms +step:380/1555 train_time:13058ms step_avg:34.36ms +step:381/1555 train_time:13089ms step_avg:34.35ms +step:382/1555 train_time:13126ms step_avg:34.36ms +step:383/1555 train_time:13157ms step_avg:34.35ms +step:384/1555 train_time:13195ms step_avg:34.36ms +step:385/1555 train_time:13226ms step_avg:34.35ms +step:386/1555 train_time:13263ms step_avg:34.36ms +step:387/1555 train_time:13294ms step_avg:34.35ms +step:388/1555 train_time:13332ms step_avg:34.36ms +step:389/1555 train_time:13363ms step_avg:34.35ms +step:390/1555 train_time:13401ms step_avg:34.36ms +step:391/1555 train_time:13432ms step_avg:34.35ms +step:392/1555 train_time:13470ms step_avg:34.36ms +step:393/1555 train_time:13501ms step_avg:34.35ms +step:394/1555 train_time:13538ms step_avg:34.36ms +step:395/1555 train_time:13569ms step_avg:34.35ms +step:396/1555 train_time:13607ms step_avg:34.36ms +step:397/1555 train_time:13638ms step_avg:34.35ms +step:398/1555 train_time:13675ms step_avg:34.36ms +step:399/1555 train_time:13706ms step_avg:34.35ms +step:400/1555 train_time:13743ms 
step_avg:34.36ms +step:401/1555 train_time:13774ms step_avg:34.35ms +step:402/1555 train_time:13812ms step_avg:34.36ms +step:403/1555 train_time:13843ms step_avg:34.35ms +step:404/1555 train_time:13880ms step_avg:34.36ms +step:405/1555 train_time:13912ms step_avg:34.35ms +step:406/1555 train_time:13949ms step_avg:34.36ms +step:407/1555 train_time:13980ms step_avg:34.35ms +step:408/1555 train_time:14017ms step_avg:34.36ms +step:409/1555 train_time:14048ms step_avg:34.35ms +step:410/1555 train_time:14085ms step_avg:34.35ms +step:411/1555 train_time:14116ms step_avg:34.35ms +step:412/1555 train_time:14154ms step_avg:34.35ms +step:413/1555 train_time:14185ms step_avg:34.35ms +step:414/1555 train_time:14222ms step_avg:34.35ms +step:415/1555 train_time:14254ms step_avg:34.35ms +step:416/1555 train_time:14291ms step_avg:34.35ms +step:417/1555 train_time:14323ms step_avg:34.35ms +step:418/1555 train_time:14360ms step_avg:34.35ms +step:419/1555 train_time:14391ms step_avg:34.35ms +step:420/1555 train_time:14429ms step_avg:34.35ms +step:421/1555 train_time:14460ms step_avg:34.35ms +step:422/1555 train_time:14497ms step_avg:34.35ms +step:423/1555 train_time:14528ms step_avg:34.35ms +step:424/1555 train_time:14565ms step_avg:34.35ms +step:425/1555 train_time:14596ms step_avg:34.34ms +step:426/1555 train_time:14634ms step_avg:34.35ms +step:427/1555 train_time:14664ms step_avg:34.34ms +step:428/1555 train_time:14701ms step_avg:34.35ms +step:429/1555 train_time:14733ms step_avg:34.34ms +step:430/1555 train_time:14770ms step_avg:34.35ms +step:431/1555 train_time:14801ms step_avg:34.34ms +step:432/1555 train_time:14838ms step_avg:34.35ms +step:433/1555 train_time:14870ms step_avg:34.34ms +step:434/1555 train_time:14907ms step_avg:34.35ms +step:435/1555 train_time:14939ms step_avg:34.34ms +step:436/1555 train_time:14977ms step_avg:34.35ms +step:437/1555 train_time:15007ms step_avg:34.34ms +step:438/1555 train_time:15045ms step_avg:34.35ms +step:439/1555 train_time:15076ms 
step_avg:34.34ms +step:440/1555 train_time:15113ms step_avg:34.35ms +step:441/1555 train_time:15145ms step_avg:34.34ms +step:442/1555 train_time:15182ms step_avg:34.35ms +step:443/1555 train_time:15213ms step_avg:34.34ms +step:444/1555 train_time:15251ms step_avg:34.35ms +step:445/1555 train_time:15282ms step_avg:34.34ms +step:446/1555 train_time:15319ms step_avg:34.35ms +step:447/1555 train_time:15350ms step_avg:34.34ms +step:448/1555 train_time:15388ms step_avg:34.35ms +step:449/1555 train_time:15419ms step_avg:34.34ms +step:450/1555 train_time:15457ms step_avg:34.35ms +step:451/1555 train_time:15488ms step_avg:34.34ms +step:452/1555 train_time:15525ms step_avg:34.35ms +step:453/1555 train_time:15556ms step_avg:34.34ms +step:454/1555 train_time:15593ms step_avg:34.35ms +step:455/1555 train_time:15624ms step_avg:34.34ms +step:456/1555 train_time:15662ms step_avg:34.35ms +step:457/1555 train_time:15692ms step_avg:34.34ms +step:458/1555 train_time:15729ms step_avg:34.34ms +step:459/1555 train_time:15760ms step_avg:34.34ms +step:460/1555 train_time:15797ms step_avg:34.34ms +step:461/1555 train_time:15828ms step_avg:34.33ms +step:462/1555 train_time:15866ms step_avg:34.34ms +step:463/1555 train_time:15897ms step_avg:34.33ms +step:464/1555 train_time:15934ms step_avg:34.34ms +step:465/1555 train_time:15965ms step_avg:34.33ms +step:466/1555 train_time:16002ms step_avg:34.34ms +step:467/1555 train_time:16033ms step_avg:34.33ms +step:468/1555 train_time:16071ms step_avg:34.34ms +step:469/1555 train_time:16102ms step_avg:34.33ms +step:470/1555 train_time:16139ms step_avg:34.34ms +step:471/1555 train_time:16170ms step_avg:34.33ms +step:472/1555 train_time:16207ms step_avg:34.34ms +step:473/1555 train_time:16238ms step_avg:34.33ms +step:474/1555 train_time:16276ms step_avg:34.34ms +step:475/1555 train_time:16307ms step_avg:34.33ms +step:476/1555 train_time:16344ms step_avg:34.34ms +step:477/1555 train_time:16375ms step_avg:34.33ms +step:478/1555 train_time:16414ms 
step_avg:34.34ms +step:479/1555 train_time:16444ms step_avg:34.33ms +step:480/1555 train_time:16481ms step_avg:34.34ms +step:481/1555 train_time:16513ms step_avg:34.33ms +step:482/1555 train_time:16550ms step_avg:34.34ms +step:483/1555 train_time:16581ms step_avg:34.33ms +step:484/1555 train_time:16619ms step_avg:34.34ms +step:485/1555 train_time:16650ms step_avg:34.33ms +step:486/1555 train_time:16687ms step_avg:34.34ms +step:487/1555 train_time:16719ms step_avg:34.33ms +step:488/1555 train_time:16756ms step_avg:34.34ms +step:489/1555 train_time:16787ms step_avg:34.33ms +step:490/1555 train_time:16824ms step_avg:34.34ms +step:491/1555 train_time:16856ms step_avg:34.33ms +step:492/1555 train_time:16893ms step_avg:34.34ms +step:493/1555 train_time:16925ms step_avg:34.33ms +step:494/1555 train_time:16962ms step_avg:34.34ms +step:495/1555 train_time:16993ms step_avg:34.33ms +step:496/1555 train_time:17030ms step_avg:34.33ms +step:497/1555 train_time:17061ms step_avg:34.33ms +step:498/1555 train_time:17098ms step_avg:34.33ms +step:499/1555 train_time:17130ms step_avg:34.33ms +step:500/1555 train_time:17167ms step_avg:34.33ms +step:500/1555 val_loss:4.2700 train_time:17216ms step_avg:34.43ms +step:501/1555 train_time:17234ms step_avg:34.40ms +step:502/1555 train_time:17252ms step_avg:34.37ms +step:503/1555 train_time:17269ms step_avg:34.33ms +step:504/1555 train_time:17306ms step_avg:34.34ms +step:505/1555 train_time:17337ms step_avg:34.33ms +step:506/1555 train_time:17380ms step_avg:34.35ms +step:507/1555 train_time:17435ms step_avg:34.39ms +step:508/1555 train_time:17500ms step_avg:34.45ms +step:509/1555 train_time:17557ms step_avg:34.49ms +step:510/1555 train_time:17621ms step_avg:34.55ms +step:511/1555 train_time:17679ms step_avg:34.60ms +step:512/1555 train_time:17743ms step_avg:34.65ms +step:513/1555 train_time:17800ms step_avg:34.70ms +step:514/1555 train_time:17864ms step_avg:34.76ms +step:515/1555 train_time:17921ms step_avg:34.80ms +step:516/1555 
train_time:17985ms step_avg:34.85ms +step:517/1555 train_time:18042ms step_avg:34.90ms +step:518/1555 train_time:18106ms step_avg:34.95ms +step:519/1555 train_time:18164ms step_avg:35.00ms +step:520/1555 train_time:18230ms step_avg:35.06ms +step:521/1555 train_time:18290ms step_avg:35.10ms +step:522/1555 train_time:18355ms step_avg:35.16ms +step:523/1555 train_time:18413ms step_avg:35.21ms +step:524/1555 train_time:18477ms step_avg:35.26ms +step:525/1555 train_time:18534ms step_avg:35.30ms +step:526/1555 train_time:18599ms step_avg:35.36ms +step:527/1555 train_time:18656ms step_avg:35.40ms +step:528/1555 train_time:18720ms step_avg:35.46ms +step:529/1555 train_time:18777ms step_avg:35.50ms +step:530/1555 train_time:18842ms step_avg:35.55ms +step:531/1555 train_time:18898ms step_avg:35.59ms +step:532/1555 train_time:18962ms step_avg:35.64ms +step:533/1555 train_time:19019ms step_avg:35.68ms +step:534/1555 train_time:19083ms step_avg:35.74ms +step:535/1555 train_time:19140ms step_avg:35.78ms +step:536/1555 train_time:19205ms step_avg:35.83ms +step:537/1555 train_time:19264ms step_avg:35.87ms +step:538/1555 train_time:19330ms step_avg:35.93ms +step:539/1555 train_time:19389ms step_avg:35.97ms +step:540/1555 train_time:19453ms step_avg:36.02ms +step:541/1555 train_time:19511ms step_avg:36.06ms +step:542/1555 train_time:19575ms step_avg:36.12ms +step:543/1555 train_time:19634ms step_avg:36.16ms +step:544/1555 train_time:19697ms step_avg:36.21ms +step:545/1555 train_time:19755ms step_avg:36.25ms +step:546/1555 train_time:19818ms step_avg:36.30ms +step:547/1555 train_time:19875ms step_avg:36.33ms +step:548/1555 train_time:19939ms step_avg:36.39ms +step:549/1555 train_time:19996ms step_avg:36.42ms +step:550/1555 train_time:20059ms step_avg:36.47ms +step:551/1555 train_time:20117ms step_avg:36.51ms +step:552/1555 train_time:20181ms step_avg:36.56ms +step:553/1555 train_time:20240ms step_avg:36.60ms +step:554/1555 train_time:20305ms step_avg:36.65ms +step:555/1555 
train_time:20363ms step_avg:36.69ms +step:556/1555 train_time:20429ms step_avg:36.74ms +step:557/1555 train_time:20487ms step_avg:36.78ms +step:558/1555 train_time:20552ms step_avg:36.83ms +step:559/1555 train_time:20609ms step_avg:36.87ms +step:560/1555 train_time:20673ms step_avg:36.92ms +step:561/1555 train_time:20731ms step_avg:36.95ms +step:562/1555 train_time:20795ms step_avg:37.00ms +step:563/1555 train_time:20852ms step_avg:37.04ms +step:564/1555 train_time:20917ms step_avg:37.09ms +step:565/1555 train_time:20974ms step_avg:37.12ms +step:566/1555 train_time:21038ms step_avg:37.17ms +step:567/1555 train_time:21096ms step_avg:37.21ms +step:568/1555 train_time:21160ms step_avg:37.25ms +step:569/1555 train_time:21218ms step_avg:37.29ms +step:570/1555 train_time:21281ms step_avg:37.34ms +step:571/1555 train_time:21339ms step_avg:37.37ms +step:572/1555 train_time:21405ms step_avg:37.42ms +step:573/1555 train_time:21463ms step_avg:37.46ms +step:574/1555 train_time:21528ms step_avg:37.50ms +step:575/1555 train_time:21585ms step_avg:37.54ms +step:576/1555 train_time:21650ms step_avg:37.59ms +step:577/1555 train_time:21708ms step_avg:37.62ms +step:578/1555 train_time:21772ms step_avg:37.67ms +step:579/1555 train_time:21830ms step_avg:37.70ms +step:580/1555 train_time:21894ms step_avg:37.75ms +step:581/1555 train_time:21952ms step_avg:37.78ms +step:582/1555 train_time:22016ms step_avg:37.83ms +step:583/1555 train_time:22074ms step_avg:37.86ms +step:584/1555 train_time:22139ms step_avg:37.91ms +step:585/1555 train_time:22196ms step_avg:37.94ms +step:586/1555 train_time:22260ms step_avg:37.99ms +step:587/1555 train_time:22317ms step_avg:38.02ms +step:588/1555 train_time:22382ms step_avg:38.06ms +step:589/1555 train_time:22439ms step_avg:38.10ms +step:590/1555 train_time:22503ms step_avg:38.14ms +step:591/1555 train_time:22561ms step_avg:38.17ms +step:592/1555 train_time:22627ms step_avg:38.22ms +step:593/1555 train_time:22685ms step_avg:38.25ms +step:594/1555 
train_time:22749ms step_avg:38.30ms +step:595/1555 train_time:22806ms step_avg:38.33ms +step:596/1555 train_time:22871ms step_avg:38.37ms +step:597/1555 train_time:22929ms step_avg:38.41ms +step:598/1555 train_time:22994ms step_avg:38.45ms +step:599/1555 train_time:23052ms step_avg:38.48ms +step:600/1555 train_time:23117ms step_avg:38.53ms +step:601/1555 train_time:23175ms step_avg:38.56ms +step:602/1555 train_time:23238ms step_avg:38.60ms +step:603/1555 train_time:23295ms step_avg:38.63ms +step:604/1555 train_time:23359ms step_avg:38.67ms +step:605/1555 train_time:23417ms step_avg:38.71ms +step:606/1555 train_time:23481ms step_avg:38.75ms +step:607/1555 train_time:23540ms step_avg:38.78ms +step:608/1555 train_time:23605ms step_avg:38.82ms +step:609/1555 train_time:23662ms step_avg:38.85ms +step:610/1555 train_time:23727ms step_avg:38.90ms +step:611/1555 train_time:23785ms step_avg:38.93ms +step:612/1555 train_time:23849ms step_avg:38.97ms +step:613/1555 train_time:23907ms step_avg:39.00ms +step:614/1555 train_time:23972ms step_avg:39.04ms +step:615/1555 train_time:24030ms step_avg:39.07ms +step:616/1555 train_time:24094ms step_avg:39.11ms +step:617/1555 train_time:24152ms step_avg:39.14ms +step:618/1555 train_time:24217ms step_avg:39.19ms +step:619/1555 train_time:24274ms step_avg:39.22ms +step:620/1555 train_time:24339ms step_avg:39.26ms +step:621/1555 train_time:24396ms step_avg:39.29ms +step:622/1555 train_time:24460ms step_avg:39.32ms +step:623/1555 train_time:24517ms step_avg:39.35ms +step:624/1555 train_time:24582ms step_avg:39.39ms +step:625/1555 train_time:24640ms step_avg:39.42ms +step:626/1555 train_time:24704ms step_avg:39.46ms +step:627/1555 train_time:24762ms step_avg:39.49ms +step:628/1555 train_time:24827ms step_avg:39.53ms +step:629/1555 train_time:24885ms step_avg:39.56ms +step:630/1555 train_time:24950ms step_avg:39.60ms +step:631/1555 train_time:25008ms step_avg:39.63ms +step:632/1555 train_time:25072ms step_avg:39.67ms +step:633/1555 
train_time:25132ms step_avg:39.70ms +step:634/1555 train_time:25195ms step_avg:39.74ms +step:635/1555 train_time:25253ms step_avg:39.77ms +step:636/1555 train_time:25317ms step_avg:39.81ms +step:637/1555 train_time:25375ms step_avg:39.84ms +step:638/1555 train_time:25439ms step_avg:39.87ms +step:639/1555 train_time:25496ms step_avg:39.90ms +step:640/1555 train_time:25560ms step_avg:39.94ms +step:641/1555 train_time:25618ms step_avg:39.97ms +step:642/1555 train_time:25682ms step_avg:40.00ms +step:643/1555 train_time:25740ms step_avg:40.03ms +step:644/1555 train_time:25805ms step_avg:40.07ms +step:645/1555 train_time:25862ms step_avg:40.10ms +step:646/1555 train_time:25927ms step_avg:40.14ms +step:647/1555 train_time:25986ms step_avg:40.16ms +step:648/1555 train_time:26050ms step_avg:40.20ms +step:649/1555 train_time:26108ms step_avg:40.23ms +step:650/1555 train_time:26173ms step_avg:40.27ms +step:651/1555 train_time:26231ms step_avg:40.29ms +step:652/1555 train_time:26295ms step_avg:40.33ms +step:653/1555 train_time:26352ms step_avg:40.36ms +step:654/1555 train_time:26416ms step_avg:40.39ms +step:655/1555 train_time:26475ms step_avg:40.42ms +step:656/1555 train_time:26538ms step_avg:40.45ms +step:657/1555 train_time:26595ms step_avg:40.48ms +step:658/1555 train_time:26660ms step_avg:40.52ms +step:659/1555 train_time:26718ms step_avg:40.54ms +step:660/1555 train_time:26782ms step_avg:40.58ms +step:661/1555 train_time:26840ms step_avg:40.61ms +step:662/1555 train_time:26905ms step_avg:40.64ms +step:663/1555 train_time:26963ms step_avg:40.67ms +step:664/1555 train_time:27029ms step_avg:40.71ms +step:665/1555 train_time:27086ms step_avg:40.73ms +step:666/1555 train_time:27151ms step_avg:40.77ms +step:667/1555 train_time:27209ms step_avg:40.79ms +step:668/1555 train_time:27273ms step_avg:40.83ms +step:669/1555 train_time:27331ms step_avg:40.85ms +step:670/1555 train_time:27395ms step_avg:40.89ms +step:671/1555 train_time:27453ms step_avg:40.91ms +step:672/1555 
train_time:27517ms step_avg:40.95ms +step:673/1555 train_time:27574ms step_avg:40.97ms +step:674/1555 train_time:27638ms step_avg:41.01ms +step:675/1555 train_time:27695ms step_avg:41.03ms +step:676/1555 train_time:27758ms step_avg:41.06ms +step:677/1555 train_time:27817ms step_avg:41.09ms +step:678/1555 train_time:27882ms step_avg:41.12ms +step:679/1555 train_time:27940ms step_avg:41.15ms +step:680/1555 train_time:28005ms step_avg:41.18ms +step:681/1555 train_time:28064ms step_avg:41.21ms +step:682/1555 train_time:28129ms step_avg:41.24ms +step:683/1555 train_time:28187ms step_avg:41.27ms +step:684/1555 train_time:28251ms step_avg:41.30ms +step:685/1555 train_time:28309ms step_avg:41.33ms +step:686/1555 train_time:28373ms step_avg:41.36ms +step:687/1555 train_time:28431ms step_avg:41.38ms +step:688/1555 train_time:28495ms step_avg:41.42ms +step:689/1555 train_time:28553ms step_avg:41.44ms +step:690/1555 train_time:28616ms step_avg:41.47ms +step:691/1555 train_time:28674ms step_avg:41.50ms +step:692/1555 train_time:28739ms step_avg:41.53ms +step:693/1555 train_time:28797ms step_avg:41.55ms +step:694/1555 train_time:28862ms step_avg:41.59ms +step:695/1555 train_time:28919ms step_avg:41.61ms +step:696/1555 train_time:28984ms step_avg:41.64ms +step:697/1555 train_time:29041ms step_avg:41.67ms +step:698/1555 train_time:29106ms step_avg:41.70ms +step:699/1555 train_time:29163ms step_avg:41.72ms +step:700/1555 train_time:29228ms step_avg:41.75ms +step:701/1555 train_time:29286ms step_avg:41.78ms +step:702/1555 train_time:29350ms step_avg:41.81ms +step:703/1555 train_time:29408ms step_avg:41.83ms +step:704/1555 train_time:29473ms step_avg:41.87ms +step:705/1555 train_time:29532ms step_avg:41.89ms +step:706/1555 train_time:29595ms step_avg:41.92ms +step:707/1555 train_time:29653ms step_avg:41.94ms +step:708/1555 train_time:29717ms step_avg:41.97ms +step:709/1555 train_time:29775ms step_avg:42.00ms +step:710/1555 train_time:29839ms step_avg:42.03ms +step:711/1555 
train_time:29897ms step_avg:42.05ms +step:712/1555 train_time:29960ms step_avg:42.08ms +step:713/1555 train_time:30017ms step_avg:42.10ms +step:714/1555 train_time:30083ms step_avg:42.13ms +step:715/1555 train_time:30140ms step_avg:42.15ms +step:716/1555 train_time:30204ms step_avg:42.18ms +step:717/1555 train_time:30263ms step_avg:42.21ms +step:718/1555 train_time:30327ms step_avg:42.24ms +step:719/1555 train_time:30385ms step_avg:42.26ms +step:720/1555 train_time:30449ms step_avg:42.29ms +step:721/1555 train_time:30507ms step_avg:42.31ms +step:722/1555 train_time:30572ms step_avg:42.34ms +step:723/1555 train_time:30630ms step_avg:42.37ms +step:724/1555 train_time:30694ms step_avg:42.40ms +step:725/1555 train_time:30752ms step_avg:42.42ms +step:726/1555 train_time:30817ms step_avg:42.45ms +step:727/1555 train_time:30874ms step_avg:42.47ms +step:728/1555 train_time:30938ms step_avg:42.50ms +step:729/1555 train_time:30995ms step_avg:42.52ms +step:730/1555 train_time:31060ms step_avg:42.55ms +step:731/1555 train_time:31117ms step_avg:42.57ms +step:732/1555 train_time:31182ms step_avg:42.60ms +step:733/1555 train_time:31239ms step_avg:42.62ms +step:734/1555 train_time:31304ms step_avg:42.65ms +step:735/1555 train_time:31363ms step_avg:42.67ms +step:736/1555 train_time:31429ms step_avg:42.70ms +step:737/1555 train_time:31485ms step_avg:42.72ms +step:738/1555 train_time:31550ms step_avg:42.75ms +step:739/1555 train_time:31608ms step_avg:42.77ms +step:740/1555 train_time:31672ms step_avg:42.80ms +step:741/1555 train_time:31731ms step_avg:42.82ms +step:742/1555 train_time:31795ms step_avg:42.85ms +step:743/1555 train_time:31854ms step_avg:42.87ms +step:744/1555 train_time:31917ms step_avg:42.90ms +step:745/1555 train_time:31975ms step_avg:42.92ms +step:746/1555 train_time:32040ms step_avg:42.95ms +step:747/1555 train_time:32096ms step_avg:42.97ms +step:748/1555 train_time:32161ms step_avg:43.00ms +step:749/1555 train_time:32218ms step_avg:43.01ms +step:750/1555 
train_time:32282ms step_avg:43.04ms +step:750/1555 val_loss:3.8685 train_time:32365ms step_avg:43.15ms +step:751/1555 train_time:32388ms step_avg:43.13ms +step:752/1555 train_time:32410ms step_avg:43.10ms +step:753/1555 train_time:32466ms step_avg:43.12ms +step:754/1555 train_time:32532ms step_avg:43.15ms +step:755/1555 train_time:32592ms step_avg:43.17ms +step:756/1555 train_time:32657ms step_avg:43.20ms +step:757/1555 train_time:32713ms step_avg:43.21ms +step:758/1555 train_time:32777ms step_avg:43.24ms +step:759/1555 train_time:32834ms step_avg:43.26ms +step:760/1555 train_time:32897ms step_avg:43.29ms +step:761/1555 train_time:32954ms step_avg:43.30ms +step:762/1555 train_time:33018ms step_avg:43.33ms +step:763/1555 train_time:33074ms step_avg:43.35ms +step:764/1555 train_time:33138ms step_avg:43.37ms +step:765/1555 train_time:33195ms step_avg:43.39ms +step:766/1555 train_time:33259ms step_avg:43.42ms +step:767/1555 train_time:33317ms step_avg:43.44ms +step:768/1555 train_time:33383ms step_avg:43.47ms +step:769/1555 train_time:33441ms step_avg:43.49ms +step:770/1555 train_time:33508ms step_avg:43.52ms +step:771/1555 train_time:33567ms step_avg:43.54ms +step:772/1555 train_time:33632ms step_avg:43.56ms +step:773/1555 train_time:33690ms step_avg:43.58ms +step:774/1555 train_time:33754ms step_avg:43.61ms +step:775/1555 train_time:33811ms step_avg:43.63ms +step:776/1555 train_time:33875ms step_avg:43.65ms +step:777/1555 train_time:33932ms step_avg:43.67ms +step:778/1555 train_time:33995ms step_avg:43.70ms +step:779/1555 train_time:34053ms step_avg:43.71ms +step:780/1555 train_time:34116ms step_avg:43.74ms +step:781/1555 train_time:34173ms step_avg:43.76ms +step:782/1555 train_time:34238ms step_avg:43.78ms +step:783/1555 train_time:34296ms step_avg:43.80ms +step:784/1555 train_time:34360ms step_avg:43.83ms +step:785/1555 train_time:34418ms step_avg:43.84ms +step:786/1555 train_time:34482ms step_avg:43.87ms +step:787/1555 train_time:34540ms step_avg:43.89ms 
+step:788/1555 train_time:34606ms step_avg:43.92ms +step:789/1555 train_time:34664ms step_avg:43.93ms +step:790/1555 train_time:34729ms step_avg:43.96ms +step:791/1555 train_time:34787ms step_avg:43.98ms +step:792/1555 train_time:34851ms step_avg:44.00ms +step:793/1555 train_time:34908ms step_avg:44.02ms +step:794/1555 train_time:34972ms step_avg:44.05ms +step:795/1555 train_time:35031ms step_avg:44.06ms +step:796/1555 train_time:35094ms step_avg:44.09ms +step:797/1555 train_time:35152ms step_avg:44.11ms +step:798/1555 train_time:35217ms step_avg:44.13ms +step:799/1555 train_time:35274ms step_avg:44.15ms +step:800/1555 train_time:35339ms step_avg:44.17ms +step:801/1555 train_time:35398ms step_avg:44.19ms +step:802/1555 train_time:35462ms step_avg:44.22ms +step:803/1555 train_time:35520ms step_avg:44.23ms +step:804/1555 train_time:35584ms step_avg:44.26ms +step:805/1555 train_time:35641ms step_avg:44.27ms +step:806/1555 train_time:35705ms step_avg:44.30ms +step:807/1555 train_time:35763ms step_avg:44.32ms +step:808/1555 train_time:35827ms step_avg:44.34ms +step:809/1555 train_time:35885ms step_avg:44.36ms +step:810/1555 train_time:35950ms step_avg:44.38ms +step:811/1555 train_time:36007ms step_avg:44.40ms +step:812/1555 train_time:36072ms step_avg:44.42ms +step:813/1555 train_time:36129ms step_avg:44.44ms +step:814/1555 train_time:36193ms step_avg:44.46ms +step:815/1555 train_time:36252ms step_avg:44.48ms +step:816/1555 train_time:36317ms step_avg:44.51ms +step:817/1555 train_time:36375ms step_avg:44.52ms +step:818/1555 train_time:36439ms step_avg:44.55ms +step:819/1555 train_time:36498ms step_avg:44.56ms +step:820/1555 train_time:36562ms step_avg:44.59ms +step:821/1555 train_time:36619ms step_avg:44.60ms +step:822/1555 train_time:36683ms step_avg:44.63ms +step:823/1555 train_time:36740ms step_avg:44.64ms +step:824/1555 train_time:36805ms step_avg:44.67ms +step:825/1555 train_time:36862ms step_avg:44.68ms +step:826/1555 train_time:36927ms step_avg:44.71ms 
+step:827/1555 train_time:36984ms step_avg:44.72ms +step:828/1555 train_time:37050ms step_avg:44.75ms +step:829/1555 train_time:37107ms step_avg:44.76ms +step:830/1555 train_time:37171ms step_avg:44.78ms +step:831/1555 train_time:37229ms step_avg:44.80ms +step:832/1555 train_time:37295ms step_avg:44.83ms +step:833/1555 train_time:37353ms step_avg:44.84ms +step:834/1555 train_time:37417ms step_avg:44.86ms +step:835/1555 train_time:37475ms step_avg:44.88ms +step:836/1555 train_time:37539ms step_avg:44.90ms +step:837/1555 train_time:37597ms step_avg:44.92ms +step:838/1555 train_time:37661ms step_avg:44.94ms +step:839/1555 train_time:37718ms step_avg:44.96ms +step:840/1555 train_time:37782ms step_avg:44.98ms +step:841/1555 train_time:37840ms step_avg:44.99ms +step:842/1555 train_time:37903ms step_avg:45.02ms +step:843/1555 train_time:37961ms step_avg:45.03ms +step:844/1555 train_time:38027ms step_avg:45.06ms +step:845/1555 train_time:38085ms step_avg:45.07ms +step:846/1555 train_time:38150ms step_avg:45.10ms +step:847/1555 train_time:38209ms step_avg:45.11ms +step:848/1555 train_time:38273ms step_avg:45.13ms +step:849/1555 train_time:38331ms step_avg:45.15ms +step:850/1555 train_time:38395ms step_avg:45.17ms +step:851/1555 train_time:38453ms step_avg:45.19ms +step:852/1555 train_time:38518ms step_avg:45.21ms +step:853/1555 train_time:38575ms step_avg:45.22ms +step:854/1555 train_time:38640ms step_avg:45.25ms +step:855/1555 train_time:38698ms step_avg:45.26ms +step:856/1555 train_time:38761ms step_avg:45.28ms +step:857/1555 train_time:38819ms step_avg:45.30ms +step:858/1555 train_time:38882ms step_avg:45.32ms +step:859/1555 train_time:38940ms step_avg:45.33ms +step:860/1555 train_time:39005ms step_avg:45.35ms +step:861/1555 train_time:39062ms step_avg:45.37ms +step:862/1555 train_time:39128ms step_avg:45.39ms +step:863/1555 train_time:39185ms step_avg:45.41ms +step:864/1555 train_time:39250ms step_avg:45.43ms +step:865/1555 train_time:39307ms step_avg:45.44ms 
+step:866/1555 train_time:39372ms step_avg:45.46ms +step:867/1555 train_time:39430ms step_avg:45.48ms +step:868/1555 train_time:39494ms step_avg:45.50ms +step:869/1555 train_time:39552ms step_avg:45.51ms +step:870/1555 train_time:39617ms step_avg:45.54ms +step:871/1555 train_time:39676ms step_avg:45.55ms +step:872/1555 train_time:39740ms step_avg:45.57ms +step:873/1555 train_time:39797ms step_avg:45.59ms +step:874/1555 train_time:39861ms step_avg:45.61ms +step:875/1555 train_time:39919ms step_avg:45.62ms +step:876/1555 train_time:39983ms step_avg:45.64ms +step:877/1555 train_time:40040ms step_avg:45.66ms +step:878/1555 train_time:40104ms step_avg:45.68ms +step:879/1555 train_time:40162ms step_avg:45.69ms +step:880/1555 train_time:40226ms step_avg:45.71ms +step:881/1555 train_time:40284ms step_avg:45.73ms +step:882/1555 train_time:40349ms step_avg:45.75ms +step:883/1555 train_time:40406ms step_avg:45.76ms +step:884/1555 train_time:40471ms step_avg:45.78ms +step:885/1555 train_time:40529ms step_avg:45.80ms +step:886/1555 train_time:40593ms step_avg:45.82ms +step:887/1555 train_time:40652ms step_avg:45.83ms +step:888/1555 train_time:40716ms step_avg:45.85ms +step:889/1555 train_time:40774ms step_avg:45.86ms +step:890/1555 train_time:40838ms step_avg:45.89ms +step:891/1555 train_time:40896ms step_avg:45.90ms +step:892/1555 train_time:40960ms step_avg:45.92ms +step:893/1555 train_time:41017ms step_avg:45.93ms +step:894/1555 train_time:41081ms step_avg:45.95ms +step:895/1555 train_time:41139ms step_avg:45.97ms +step:896/1555 train_time:41203ms step_avg:45.99ms +step:897/1555 train_time:41261ms step_avg:46.00ms +step:898/1555 train_time:41325ms step_avg:46.02ms +step:899/1555 train_time:41383ms step_avg:46.03ms +step:900/1555 train_time:41448ms step_avg:46.05ms +step:901/1555 train_time:41505ms step_avg:46.07ms +step:902/1555 train_time:41570ms step_avg:46.09ms +step:903/1555 train_time:41629ms step_avg:46.10ms +step:904/1555 train_time:41693ms step_avg:46.12ms 
+step:905/1555 train_time:41752ms step_avg:46.13ms +step:906/1555 train_time:41817ms step_avg:46.16ms +step:907/1555 train_time:41874ms step_avg:46.17ms +step:908/1555 train_time:41939ms step_avg:46.19ms +step:909/1555 train_time:41996ms step_avg:46.20ms +step:910/1555 train_time:42060ms step_avg:46.22ms +step:911/1555 train_time:42118ms step_avg:46.23ms +step:912/1555 train_time:42181ms step_avg:46.25ms +step:913/1555 train_time:42239ms step_avg:46.26ms +step:914/1555 train_time:42304ms step_avg:46.28ms +step:915/1555 train_time:42360ms step_avg:46.30ms +step:916/1555 train_time:42425ms step_avg:46.32ms +step:917/1555 train_time:42482ms step_avg:46.33ms +step:918/1555 train_time:42548ms step_avg:46.35ms +step:919/1555 train_time:42607ms step_avg:46.36ms +step:920/1555 train_time:42672ms step_avg:46.38ms +step:921/1555 train_time:42730ms step_avg:46.39ms +step:922/1555 train_time:42794ms step_avg:46.41ms +step:923/1555 train_time:42852ms step_avg:46.43ms +step:924/1555 train_time:42916ms step_avg:46.45ms +step:925/1555 train_time:42974ms step_avg:46.46ms +step:926/1555 train_time:43038ms step_avg:46.48ms +step:927/1555 train_time:43096ms step_avg:46.49ms +step:928/1555 train_time:43159ms step_avg:46.51ms +step:929/1555 train_time:43217ms step_avg:46.52ms +step:930/1555 train_time:43281ms step_avg:46.54ms +step:931/1555 train_time:43339ms step_avg:46.55ms +step:932/1555 train_time:43403ms step_avg:46.57ms +step:933/1555 train_time:43461ms step_avg:46.58ms +step:934/1555 train_time:43525ms step_avg:46.60ms +step:935/1555 train_time:43584ms step_avg:46.61ms +step:936/1555 train_time:43648ms step_avg:46.63ms +step:937/1555 train_time:43705ms step_avg:46.64ms +step:938/1555 train_time:43770ms step_avg:46.66ms +step:939/1555 train_time:43828ms step_avg:46.68ms +step:940/1555 train_time:43893ms step_avg:46.69ms +step:941/1555 train_time:43952ms step_avg:46.71ms +step:942/1555 train_time:44015ms step_avg:46.73ms +step:943/1555 train_time:44075ms step_avg:46.74ms 
+step:944/1555 train_time:44137ms step_avg:46.76ms +step:945/1555 train_time:44194ms step_avg:46.77ms +step:946/1555 train_time:44259ms step_avg:46.79ms +step:947/1555 train_time:44317ms step_avg:46.80ms +step:948/1555 train_time:44380ms step_avg:46.81ms +step:949/1555 train_time:44438ms step_avg:46.83ms +step:950/1555 train_time:44502ms step_avg:46.84ms +step:951/1555 train_time:44559ms step_avg:46.85ms +step:952/1555 train_time:44624ms step_avg:46.87ms +step:953/1555 train_time:44680ms step_avg:46.88ms +step:954/1555 train_time:44746ms step_avg:46.90ms +step:955/1555 train_time:44803ms step_avg:46.91ms +step:956/1555 train_time:44869ms step_avg:46.93ms +step:957/1555 train_time:44927ms step_avg:46.95ms +step:958/1555 train_time:44991ms step_avg:46.96ms +step:959/1555 train_time:45050ms step_avg:46.98ms +step:960/1555 train_time:45114ms step_avg:46.99ms +step:961/1555 train_time:45172ms step_avg:47.01ms +step:962/1555 train_time:45238ms step_avg:47.02ms +step:963/1555 train_time:45295ms step_avg:47.04ms +step:964/1555 train_time:45359ms step_avg:47.05ms +step:965/1555 train_time:45417ms step_avg:47.06ms +step:966/1555 train_time:45480ms step_avg:47.08ms +step:967/1555 train_time:45538ms step_avg:47.09ms +step:968/1555 train_time:45602ms step_avg:47.11ms +step:969/1555 train_time:45659ms step_avg:47.12ms +step:970/1555 train_time:45723ms step_avg:47.14ms +step:971/1555 train_time:45781ms step_avg:47.15ms +step:972/1555 train_time:45846ms step_avg:47.17ms +step:973/1555 train_time:45904ms step_avg:47.18ms +step:974/1555 train_time:45970ms step_avg:47.20ms +step:975/1555 train_time:46028ms step_avg:47.21ms +step:976/1555 train_time:46092ms step_avg:47.23ms +step:977/1555 train_time:46150ms step_avg:47.24ms +step:978/1555 train_time:46215ms step_avg:47.25ms +step:979/1555 train_time:46273ms step_avg:47.27ms +step:980/1555 train_time:46337ms step_avg:47.28ms +step:981/1555 train_time:46395ms step_avg:47.29ms +step:982/1555 train_time:46460ms step_avg:47.31ms 
+step:983/1555 train_time:46517ms step_avg:47.32ms +step:984/1555 train_time:46580ms step_avg:47.34ms +step:985/1555 train_time:46638ms step_avg:47.35ms +step:986/1555 train_time:46702ms step_avg:47.37ms +step:987/1555 train_time:46760ms step_avg:47.38ms +step:988/1555 train_time:46823ms step_avg:47.39ms +step:989/1555 train_time:46881ms step_avg:47.40ms +step:990/1555 train_time:46945ms step_avg:47.42ms +step:991/1555 train_time:47003ms step_avg:47.43ms +step:992/1555 train_time:47068ms step_avg:47.45ms +step:993/1555 train_time:47126ms step_avg:47.46ms +step:994/1555 train_time:47191ms step_avg:47.48ms +step:995/1555 train_time:47250ms step_avg:47.49ms +step:996/1555 train_time:47314ms step_avg:47.50ms +step:997/1555 train_time:47371ms step_avg:47.51ms +step:998/1555 train_time:47436ms step_avg:47.53ms +step:999/1555 train_time:47494ms step_avg:47.54ms +step:1000/1555 train_time:47558ms step_avg:47.56ms +step:1000/1555 val_loss:3.5683 train_time:47640ms step_avg:47.64ms +step:1001/1555 train_time:47658ms step_avg:47.61ms +step:1002/1555 train_time:47681ms step_avg:47.59ms +step:1003/1555 train_time:47738ms step_avg:47.60ms +step:1004/1555 train_time:47807ms step_avg:47.62ms +step:1005/1555 train_time:47866ms step_avg:47.63ms +step:1006/1555 train_time:47931ms step_avg:47.64ms +step:1007/1555 train_time:47989ms step_avg:47.66ms +step:1008/1555 train_time:48052ms step_avg:47.67ms +step:1009/1555 train_time:48111ms step_avg:47.68ms +step:1010/1555 train_time:48174ms step_avg:47.70ms +step:1011/1555 train_time:48235ms step_avg:47.71ms +step:1012/1555 train_time:48320ms step_avg:47.75ms +step:1013/1555 train_time:48404ms step_avg:47.78ms +step:1014/1555 train_time:48493ms step_avg:47.82ms +step:1015/1555 train_time:48577ms step_avg:47.86ms +step:1016/1555 train_time:48668ms step_avg:47.90ms +step:1017/1555 train_time:48754ms step_avg:47.94ms +step:1018/1555 train_time:48847ms step_avg:47.98ms +step:1019/1555 train_time:48931ms step_avg:48.02ms +step:1020/1555 
train_time:49021ms step_avg:48.06ms +step:1021/1555 train_time:49105ms step_avg:48.09ms +step:1022/1555 train_time:49194ms step_avg:48.13ms +step:1023/1555 train_time:49278ms step_avg:48.17ms +step:1024/1555 train_time:49367ms step_avg:48.21ms +step:1025/1555 train_time:49449ms step_avg:48.24ms +step:1026/1555 train_time:49540ms step_avg:48.28ms +step:1027/1555 train_time:49625ms step_avg:48.32ms +step:1028/1555 train_time:49716ms step_avg:48.36ms +step:1029/1555 train_time:49802ms step_avg:48.40ms +step:1030/1555 train_time:49891ms step_avg:48.44ms +step:1031/1555 train_time:49977ms step_avg:48.47ms +step:1032/1555 train_time:50067ms step_avg:48.51ms +step:1033/1555 train_time:50151ms step_avg:48.55ms +step:1034/1555 train_time:50242ms step_avg:48.59ms +step:1035/1555 train_time:50325ms step_avg:48.62ms +step:1036/1555 train_time:50414ms step_avg:48.66ms +step:1037/1555 train_time:50498ms step_avg:48.70ms +step:1038/1555 train_time:50587ms step_avg:48.74ms +step:1039/1555 train_time:50672ms step_avg:48.77ms +step:1040/1555 train_time:50763ms step_avg:48.81ms +step:1041/1555 train_time:50848ms step_avg:48.85ms +step:1042/1555 train_time:50939ms step_avg:48.89ms +step:1043/1555 train_time:51023ms step_avg:48.92ms +step:1044/1555 train_time:51112ms step_avg:48.96ms +step:1045/1555 train_time:51196ms step_avg:48.99ms +step:1046/1555 train_time:51287ms step_avg:49.03ms +step:1047/1555 train_time:51370ms step_avg:49.06ms +step:1048/1555 train_time:51460ms step_avg:49.10ms +step:1049/1555 train_time:51543ms step_avg:49.14ms +step:1050/1555 train_time:51632ms step_avg:49.17ms +step:1051/1555 train_time:51718ms step_avg:49.21ms +step:1052/1555 train_time:51808ms step_avg:49.25ms +step:1053/1555 train_time:51892ms step_avg:49.28ms +step:1054/1555 train_time:51983ms step_avg:49.32ms +step:1055/1555 train_time:52068ms step_avg:49.35ms +step:1056/1555 train_time:52157ms step_avg:49.39ms +step:1057/1555 train_time:52242ms step_avg:49.42ms +step:1058/1555 train_time:52330ms 
step_avg:49.46ms +step:1059/1555 train_time:52414ms step_avg:49.49ms +step:1060/1555 train_time:52505ms step_avg:49.53ms +step:1061/1555 train_time:52588ms step_avg:49.56ms +step:1062/1555 train_time:52678ms step_avg:49.60ms +step:1063/1555 train_time:52762ms step_avg:49.63ms +step:1064/1555 train_time:52851ms step_avg:49.67ms +step:1065/1555 train_time:52938ms step_avg:49.71ms +step:1066/1555 train_time:53029ms step_avg:49.75ms +step:1067/1555 train_time:53113ms step_avg:49.78ms +step:1068/1555 train_time:53204ms step_avg:49.82ms +step:1069/1555 train_time:53287ms step_avg:49.85ms +step:1070/1555 train_time:53377ms step_avg:49.89ms +step:1071/1555 train_time:53461ms step_avg:49.92ms +step:1072/1555 train_time:53551ms step_avg:49.95ms +step:1073/1555 train_time:53635ms step_avg:49.99ms +step:1074/1555 train_time:53725ms step_avg:50.02ms +step:1075/1555 train_time:53809ms step_avg:50.05ms +step:1076/1555 train_time:53901ms step_avg:50.09ms +step:1077/1555 train_time:53985ms step_avg:50.12ms +step:1078/1555 train_time:54076ms step_avg:50.16ms +step:1079/1555 train_time:54160ms step_avg:50.19ms +step:1080/1555 train_time:54248ms step_avg:50.23ms +step:1081/1555 train_time:54333ms step_avg:50.26ms +step:1082/1555 train_time:54423ms step_avg:50.30ms +step:1083/1555 train_time:54507ms step_avg:50.33ms +step:1084/1555 train_time:54597ms step_avg:50.37ms +step:1085/1555 train_time:54681ms step_avg:50.40ms +step:1086/1555 train_time:54771ms step_avg:50.43ms +step:1087/1555 train_time:54857ms step_avg:50.47ms +step:1088/1555 train_time:54947ms step_avg:50.50ms +step:1089/1555 train_time:55031ms step_avg:50.53ms +step:1090/1555 train_time:55123ms step_avg:50.57ms +step:1091/1555 train_time:55206ms step_avg:50.60ms +step:1092/1555 train_time:55296ms step_avg:50.64ms +step:1093/1555 train_time:55380ms step_avg:50.67ms +step:1094/1555 train_time:55469ms step_avg:50.70ms +step:1095/1555 train_time:55553ms step_avg:50.73ms +step:1096/1555 train_time:55644ms step_avg:50.77ms 
+step:1097/1555 train_time:55727ms step_avg:50.80ms +step:1098/1555 train_time:55817ms step_avg:50.83ms +step:1099/1555 train_time:55901ms step_avg:50.87ms +step:1100/1555 train_time:55990ms step_avg:50.90ms +step:1101/1555 train_time:56075ms step_avg:50.93ms +step:1102/1555 train_time:56165ms step_avg:50.97ms +step:1103/1555 train_time:56249ms step_avg:51.00ms +step:1104/1555 train_time:56340ms step_avg:51.03ms +step:1105/1555 train_time:56424ms step_avg:51.06ms +step:1106/1555 train_time:56514ms step_avg:51.10ms +step:1107/1555 train_time:56598ms step_avg:51.13ms +step:1108/1555 train_time:56688ms step_avg:51.16ms +step:1109/1555 train_time:56772ms step_avg:51.19ms +step:1110/1555 train_time:56863ms step_avg:51.23ms +step:1111/1555 train_time:56946ms step_avg:51.26ms +step:1112/1555 train_time:57037ms step_avg:51.29ms +step:1113/1555 train_time:57121ms step_avg:51.32ms +step:1114/1555 train_time:57211ms step_avg:51.36ms +step:1115/1555 train_time:57296ms step_avg:51.39ms +step:1116/1555 train_time:57386ms step_avg:51.42ms +step:1117/1555 train_time:57470ms step_avg:51.45ms +step:1118/1555 train_time:57559ms step_avg:51.48ms +step:1119/1555 train_time:57643ms step_avg:51.51ms +step:1120/1555 train_time:57733ms step_avg:51.55ms +step:1121/1555 train_time:57819ms step_avg:51.58ms +step:1122/1555 train_time:57908ms step_avg:51.61ms +step:1123/1555 train_time:57992ms step_avg:51.64ms +step:1124/1555 train_time:58082ms step_avg:51.67ms +step:1125/1555 train_time:58165ms step_avg:51.70ms +step:1126/1555 train_time:58256ms step_avg:51.74ms +step:1127/1555 train_time:58341ms step_avg:51.77ms +step:1128/1555 train_time:58430ms step_avg:51.80ms +step:1129/1555 train_time:58514ms step_avg:51.83ms +step:1130/1555 train_time:58605ms step_avg:51.86ms +step:1131/1555 train_time:58688ms step_avg:51.89ms +step:1132/1555 train_time:58777ms step_avg:51.92ms +step:1133/1555 train_time:58862ms step_avg:51.95ms +step:1134/1555 train_time:58952ms step_avg:51.99ms +step:1135/1555 
train_time:59036ms step_avg:52.01ms +step:1136/1555 train_time:59126ms step_avg:52.05ms +step:1137/1555 train_time:59211ms step_avg:52.08ms +step:1138/1555 train_time:59302ms step_avg:52.11ms +step:1139/1555 train_time:59385ms step_avg:52.14ms +step:1140/1555 train_time:59475ms step_avg:52.17ms +step:1141/1555 train_time:59559ms step_avg:52.20ms +step:1142/1555 train_time:59649ms step_avg:52.23ms +step:1143/1555 train_time:59733ms step_avg:52.26ms +step:1144/1555 train_time:59825ms step_avg:52.29ms +step:1145/1555 train_time:59909ms step_avg:52.32ms +step:1146/1555 train_time:60000ms step_avg:52.36ms +step:1147/1555 train_time:60084ms step_avg:52.38ms +step:1148/1555 train_time:60172ms step_avg:52.41ms +step:1149/1555 train_time:60257ms step_avg:52.44ms +step:1150/1555 train_time:60348ms step_avg:52.48ms +step:1151/1555 train_time:60430ms step_avg:52.50ms +step:1152/1555 train_time:60522ms step_avg:52.54ms +step:1153/1555 train_time:60605ms step_avg:52.56ms +step:1154/1555 train_time:60694ms step_avg:52.59ms +step:1155/1555 train_time:60779ms step_avg:52.62ms +step:1156/1555 train_time:60868ms step_avg:52.65ms +step:1157/1555 train_time:60952ms step_avg:52.68ms +step:1158/1555 train_time:61044ms step_avg:52.72ms +step:1159/1555 train_time:61127ms step_avg:52.74ms +step:1160/1555 train_time:61217ms step_avg:52.77ms +step:1161/1555 train_time:61301ms step_avg:52.80ms +step:1162/1555 train_time:61391ms step_avg:52.83ms +step:1163/1555 train_time:61475ms step_avg:52.86ms +step:1164/1555 train_time:61566ms step_avg:52.89ms +step:1165/1555 train_time:61649ms step_avg:52.92ms +step:1166/1555 train_time:61740ms step_avg:52.95ms +step:1167/1555 train_time:61824ms step_avg:52.98ms +step:1168/1555 train_time:61915ms step_avg:53.01ms +step:1169/1555 train_time:61999ms step_avg:53.04ms +step:1170/1555 train_time:62088ms step_avg:53.07ms +step:1171/1555 train_time:62173ms step_avg:53.09ms +step:1172/1555 train_time:62263ms step_avg:53.13ms +step:1173/1555 train_time:62347ms 
step_avg:53.15ms +step:1174/1555 train_time:62437ms step_avg:53.18ms +step:1175/1555 train_time:62521ms step_avg:53.21ms +step:1176/1555 train_time:62610ms step_avg:53.24ms +step:1177/1555 train_time:62694ms step_avg:53.27ms +step:1178/1555 train_time:62785ms step_avg:53.30ms +step:1179/1555 train_time:62869ms step_avg:53.32ms +step:1180/1555 train_time:62958ms step_avg:53.35ms +step:1181/1555 train_time:63043ms step_avg:53.38ms +step:1182/1555 train_time:63132ms step_avg:53.41ms +step:1183/1555 train_time:63216ms step_avg:53.44ms +step:1184/1555 train_time:63306ms step_avg:53.47ms +step:1185/1555 train_time:63389ms step_avg:53.49ms +step:1186/1555 train_time:63480ms step_avg:53.52ms +step:1187/1555 train_time:63564ms step_avg:53.55ms +step:1188/1555 train_time:63654ms step_avg:53.58ms +step:1189/1555 train_time:63738ms step_avg:53.61ms +step:1190/1555 train_time:63828ms step_avg:53.64ms +step:1191/1555 train_time:63912ms step_avg:53.66ms +step:1192/1555 train_time:64004ms step_avg:53.69ms +step:1193/1555 train_time:64088ms step_avg:53.72ms +step:1194/1555 train_time:64179ms step_avg:53.75ms +step:1195/1555 train_time:64263ms step_avg:53.78ms +step:1196/1555 train_time:64352ms step_avg:53.81ms +step:1197/1555 train_time:64437ms step_avg:53.83ms +step:1198/1555 train_time:64527ms step_avg:53.86ms +step:1199/1555 train_time:64611ms step_avg:53.89ms +step:1200/1555 train_time:64703ms step_avg:53.92ms +step:1201/1555 train_time:64786ms step_avg:53.94ms +step:1202/1555 train_time:64876ms step_avg:53.97ms +step:1203/1555 train_time:64961ms step_avg:54.00ms +step:1204/1555 train_time:65049ms step_avg:54.03ms +step:1205/1555 train_time:65133ms step_avg:54.05ms +step:1206/1555 train_time:65225ms step_avg:54.08ms +step:1207/1555 train_time:65308ms step_avg:54.11ms +step:1208/1555 train_time:65400ms step_avg:54.14ms +step:1209/1555 train_time:65483ms step_avg:54.16ms +step:1210/1555 train_time:65572ms step_avg:54.19ms +step:1211/1555 train_time:65657ms step_avg:54.22ms 
+step:1212/1555 train_time:65748ms step_avg:54.25ms +step:1213/1555 train_time:65832ms step_avg:54.27ms +step:1214/1555 train_time:65923ms step_avg:54.30ms +step:1215/1555 train_time:66006ms step_avg:54.33ms +step:1216/1555 train_time:66095ms step_avg:54.35ms +step:1217/1555 train_time:66179ms step_avg:54.38ms +step:1218/1555 train_time:66269ms step_avg:54.41ms +step:1219/1555 train_time:66353ms step_avg:54.43ms +step:1220/1555 train_time:66444ms step_avg:54.46ms +step:1221/1555 train_time:66528ms step_avg:54.49ms +step:1222/1555 train_time:66618ms step_avg:54.52ms +step:1223/1555 train_time:66702ms step_avg:54.54ms +step:1224/1555 train_time:66794ms step_avg:54.57ms +step:1225/1555 train_time:66876ms step_avg:54.59ms +step:1226/1555 train_time:66967ms step_avg:54.62ms +step:1227/1555 train_time:67051ms step_avg:54.65ms +step:1228/1555 train_time:67142ms step_avg:54.68ms +step:1229/1555 train_time:67226ms step_avg:54.70ms +step:1230/1555 train_time:67316ms step_avg:54.73ms +step:1231/1555 train_time:67401ms step_avg:54.75ms +step:1232/1555 train_time:67490ms step_avg:54.78ms +step:1233/1555 train_time:67574ms step_avg:54.80ms +step:1234/1555 train_time:67665ms step_avg:54.83ms +step:1235/1555 train_time:67749ms step_avg:54.86ms +step:1236/1555 train_time:67842ms step_avg:54.89ms +step:1237/1555 train_time:67925ms step_avg:54.91ms +step:1238/1555 train_time:68014ms step_avg:54.94ms +step:1239/1555 train_time:68101ms step_avg:54.96ms +step:1240/1555 train_time:68190ms step_avg:54.99ms +step:1241/1555 train_time:68273ms step_avg:55.01ms +step:1242/1555 train_time:68365ms step_avg:55.04ms +step:1243/1555 train_time:68448ms step_avg:55.07ms +step:1244/1555 train_time:68539ms step_avg:55.10ms +step:1245/1555 train_time:68625ms step_avg:55.12ms +step:1246/1555 train_time:68714ms step_avg:55.15ms +step:1247/1555 train_time:68798ms step_avg:55.17ms +step:1248/1555 train_time:68888ms step_avg:55.20ms +step:1249/1555 train_time:68972ms step_avg:55.22ms +step:1250/1555 
train_time:69063ms step_avg:55.25ms +step:1250/1555 val_loss:3.3959 train_time:69177ms step_avg:55.34ms +step:1251/1555 train_time:69195ms step_avg:55.31ms +step:1252/1555 train_time:69237ms step_avg:55.30ms +step:1253/1555 train_time:69325ms step_avg:55.33ms +step:1254/1555 train_time:69417ms step_avg:55.36ms +step:1255/1555 train_time:69502ms step_avg:55.38ms +step:1256/1555 train_time:69591ms step_avg:55.41ms +step:1257/1555 train_time:69674ms step_avg:55.43ms +step:1258/1555 train_time:69762ms step_avg:55.45ms +step:1259/1555 train_time:69847ms step_avg:55.48ms +step:1260/1555 train_time:69935ms step_avg:55.50ms +step:1261/1555 train_time:70018ms step_avg:55.53ms +step:1262/1555 train_time:70109ms step_avg:55.55ms +step:1263/1555 train_time:70195ms step_avg:55.58ms +step:1264/1555 train_time:70288ms step_avg:55.61ms +step:1265/1555 train_time:70374ms step_avg:55.63ms +step:1266/1555 train_time:70463ms step_avg:55.66ms +step:1267/1555 train_time:70548ms step_avg:55.68ms +step:1268/1555 train_time:70638ms step_avg:55.71ms +step:1269/1555 train_time:70721ms step_avg:55.73ms +step:1270/1555 train_time:70812ms step_avg:55.76ms +step:1271/1555 train_time:70894ms step_avg:55.78ms +step:1272/1555 train_time:70983ms step_avg:55.80ms +step:1273/1555 train_time:71067ms step_avg:55.83ms +step:1274/1555 train_time:71158ms step_avg:55.85ms +step:1275/1555 train_time:71243ms step_avg:55.88ms +step:1276/1555 train_time:71335ms step_avg:55.90ms +step:1277/1555 train_time:71418ms step_avg:55.93ms +step:1278/1555 train_time:71510ms step_avg:55.95ms +step:1279/1555 train_time:71593ms step_avg:55.98ms +step:1280/1555 train_time:71682ms step_avg:56.00ms +step:1281/1555 train_time:71766ms step_avg:56.02ms +step:1282/1555 train_time:71856ms step_avg:56.05ms +step:1283/1555 train_time:71939ms step_avg:56.07ms +step:1284/1555 train_time:72029ms step_avg:56.10ms +step:1285/1555 train_time:72114ms step_avg:56.12ms +step:1286/1555 train_time:72204ms step_avg:56.15ms +step:1287/1555 
train_time:72289ms step_avg:56.17ms +step:1288/1555 train_time:72380ms step_avg:56.20ms +step:1289/1555 train_time:72465ms step_avg:56.22ms +step:1290/1555 train_time:72556ms step_avg:56.24ms +step:1291/1555 train_time:72640ms step_avg:56.27ms +step:1292/1555 train_time:72730ms step_avg:56.29ms +step:1293/1555 train_time:72814ms step_avg:56.31ms +step:1294/1555 train_time:72903ms step_avg:56.34ms +step:1295/1555 train_time:72987ms step_avg:56.36ms +step:1296/1555 train_time:73078ms step_avg:56.39ms +step:1297/1555 train_time:73161ms step_avg:56.41ms +step:1298/1555 train_time:73253ms step_avg:56.44ms +step:1299/1555 train_time:73337ms step_avg:56.46ms +step:1300/1555 train_time:73427ms step_avg:56.48ms +step:1301/1555 train_time:73512ms step_avg:56.50ms +step:1302/1555 train_time:73601ms step_avg:56.53ms +step:1303/1555 train_time:73686ms step_avg:56.55ms +step:1304/1555 train_time:73776ms step_avg:56.58ms +step:1305/1555 train_time:73859ms step_avg:56.60ms +step:1306/1555 train_time:73949ms step_avg:56.62ms +step:1307/1555 train_time:74033ms step_avg:56.64ms +step:1308/1555 train_time:74121ms step_avg:56.67ms +step:1309/1555 train_time:74207ms step_avg:56.69ms +step:1310/1555 train_time:74296ms step_avg:56.71ms +step:1311/1555 train_time:74379ms step_avg:56.73ms +step:1312/1555 train_time:74471ms step_avg:56.76ms +step:1313/1555 train_time:74555ms step_avg:56.78ms +step:1314/1555 train_time:74645ms step_avg:56.81ms +step:1315/1555 train_time:74729ms step_avg:56.83ms +step:1316/1555 train_time:74819ms step_avg:56.85ms +step:1317/1555 train_time:74903ms step_avg:56.87ms +step:1318/1555 train_time:74993ms step_avg:56.90ms +step:1319/1555 train_time:75076ms step_avg:56.92ms +step:1320/1555 train_time:75167ms step_avg:56.94ms +step:1321/1555 train_time:75252ms step_avg:56.97ms +step:1322/1555 train_time:75340ms step_avg:56.99ms +step:1323/1555 train_time:75425ms step_avg:57.01ms +step:1324/1555 train_time:75516ms step_avg:57.04ms +step:1325/1555 train_time:75599ms 
step_avg:57.06ms +step:1326/1555 train_time:75691ms step_avg:57.08ms +step:1327/1555 train_time:75775ms step_avg:57.10ms +step:1328/1555 train_time:75864ms step_avg:57.13ms +step:1329/1555 train_time:75948ms step_avg:57.15ms +step:1330/1555 train_time:76038ms step_avg:57.17ms +step:1331/1555 train_time:76121ms step_avg:57.19ms +step:1332/1555 train_time:76213ms step_avg:57.22ms +step:1333/1555 train_time:76296ms step_avg:57.24ms +step:1334/1555 train_time:76385ms step_avg:57.26ms +step:1335/1555 train_time:76470ms step_avg:57.28ms +step:1336/1555 train_time:76559ms step_avg:57.30ms +step:1337/1555 train_time:76643ms step_avg:57.32ms +step:1338/1555 train_time:76735ms step_avg:57.35ms +step:1339/1555 train_time:76819ms step_avg:57.37ms +step:1340/1555 train_time:76909ms step_avg:57.39ms +step:1341/1555 train_time:76993ms step_avg:57.41ms +step:1342/1555 train_time:77082ms step_avg:57.44ms +step:1343/1555 train_time:77166ms step_avg:57.46ms +step:1344/1555 train_time:77257ms step_avg:57.48ms +step:1345/1555 train_time:77341ms step_avg:57.50ms +step:1346/1555 train_time:77432ms step_avg:57.53ms +step:1347/1555 train_time:77516ms step_avg:57.55ms +step:1348/1555 train_time:77605ms step_avg:57.57ms +step:1349/1555 train_time:77690ms step_avg:57.59ms +step:1350/1555 train_time:77780ms step_avg:57.61ms +step:1351/1555 train_time:77863ms step_avg:57.63ms +step:1352/1555 train_time:77954ms step_avg:57.66ms +step:1353/1555 train_time:78037ms step_avg:57.68ms +step:1354/1555 train_time:78127ms step_avg:57.70ms +step:1355/1555 train_time:78211ms step_avg:57.72ms +step:1356/1555 train_time:78300ms step_avg:57.74ms +step:1357/1555 train_time:78384ms step_avg:57.76ms +step:1358/1555 train_time:78476ms step_avg:57.79ms +step:1359/1555 train_time:78559ms step_avg:57.81ms +step:1360/1555 train_time:78650ms step_avg:57.83ms +step:1361/1555 train_time:78734ms step_avg:57.85ms +step:1362/1555 train_time:78824ms step_avg:57.87ms +step:1363/1555 train_time:78909ms step_avg:57.89ms 
+step:1364/1555 train_time:78998ms step_avg:57.92ms +step:1365/1555 train_time:79082ms step_avg:57.94ms +step:1366/1555 train_time:79173ms step_avg:57.96ms +step:1367/1555 train_time:79256ms step_avg:57.98ms +step:1368/1555 train_time:79346ms step_avg:58.00ms +step:1369/1555 train_time:79431ms step_avg:58.02ms +step:1370/1555 train_time:79521ms step_avg:58.04ms +step:1371/1555 train_time:79605ms step_avg:58.06ms +step:1372/1555 train_time:79696ms step_avg:58.09ms +step:1373/1555 train_time:79780ms step_avg:58.11ms +step:1374/1555 train_time:79871ms step_avg:58.13ms +step:1375/1555 train_time:79954ms step_avg:58.15ms +step:1376/1555 train_time:80043ms step_avg:58.17ms +step:1377/1555 train_time:80129ms step_avg:58.19ms +step:1378/1555 train_time:80219ms step_avg:58.21ms +step:1379/1555 train_time:80302ms step_avg:58.23ms +step:1380/1555 train_time:80393ms step_avg:58.26ms +step:1381/1555 train_time:80476ms step_avg:58.27ms +step:1382/1555 train_time:80565ms step_avg:58.30ms +step:1383/1555 train_time:80652ms step_avg:58.32ms +step:1384/1555 train_time:80741ms step_avg:58.34ms +step:1385/1555 train_time:80825ms step_avg:58.36ms +step:1386/1555 train_time:80916ms step_avg:58.38ms +step:1387/1555 train_time:80999ms step_avg:58.40ms +step:1388/1555 train_time:81089ms step_avg:58.42ms +step:1389/1555 train_time:81174ms step_avg:58.44ms +step:1390/1555 train_time:81263ms step_avg:58.46ms +step:1391/1555 train_time:81347ms step_avg:58.48ms +step:1392/1555 train_time:81438ms step_avg:58.50ms +step:1393/1555 train_time:81522ms step_avg:58.52ms +step:1394/1555 train_time:81614ms step_avg:58.55ms +step:1395/1555 train_time:81697ms step_avg:58.56ms +step:1396/1555 train_time:81788ms step_avg:58.59ms +step:1397/1555 train_time:81872ms step_avg:58.61ms +step:1398/1555 train_time:81961ms step_avg:58.63ms +step:1399/1555 train_time:82045ms step_avg:58.65ms +step:1400/1555 train_time:82137ms step_avg:58.67ms +step:1401/1555 train_time:82220ms step_avg:58.69ms +step:1402/1555 
train_time:82311ms step_avg:58.71ms +step:1403/1555 train_time:82394ms step_avg:58.73ms +step:1404/1555 train_time:82483ms step_avg:58.75ms +step:1405/1555 train_time:82569ms step_avg:58.77ms +step:1406/1555 train_time:82659ms step_avg:58.79ms +step:1407/1555 train_time:82742ms step_avg:58.81ms +step:1408/1555 train_time:82833ms step_avg:58.83ms +step:1409/1555 train_time:82916ms step_avg:58.85ms +step:1410/1555 train_time:83005ms step_avg:58.87ms +step:1411/1555 train_time:83089ms step_avg:58.89ms +step:1412/1555 train_time:83179ms step_avg:58.91ms +step:1413/1555 train_time:83263ms step_avg:58.93ms +step:1414/1555 train_time:83353ms step_avg:58.95ms +step:1415/1555 train_time:83437ms step_avg:58.97ms +step:1416/1555 train_time:83527ms step_avg:58.99ms +step:1417/1555 train_time:83611ms step_avg:59.01ms +step:1418/1555 train_time:83700ms step_avg:59.03ms +step:1419/1555 train_time:83784ms step_avg:59.04ms +step:1420/1555 train_time:83875ms step_avg:59.07ms +step:1421/1555 train_time:83959ms step_avg:59.08ms +step:1422/1555 train_time:84049ms step_avg:59.11ms +step:1423/1555 train_time:84134ms step_avg:59.12ms +step:1424/1555 train_time:84224ms step_avg:59.15ms +step:1425/1555 train_time:84309ms step_avg:59.16ms +step:1426/1555 train_time:84399ms step_avg:59.19ms +step:1427/1555 train_time:84484ms step_avg:59.20ms +step:1428/1555 train_time:84575ms step_avg:59.23ms +step:1429/1555 train_time:84659ms step_avg:59.24ms +step:1430/1555 train_time:84748ms step_avg:59.26ms +step:1431/1555 train_time:84833ms step_avg:59.28ms +step:1432/1555 train_time:84922ms step_avg:59.30ms +step:1433/1555 train_time:85007ms step_avg:59.32ms +step:1434/1555 train_time:85097ms step_avg:59.34ms +step:1435/1555 train_time:85182ms step_avg:59.36ms +step:1436/1555 train_time:85273ms step_avg:59.38ms +step:1437/1555 train_time:85356ms step_avg:59.40ms +step:1438/1555 train_time:85446ms step_avg:59.42ms +step:1439/1555 train_time:85531ms step_avg:59.44ms +step:1440/1555 train_time:85621ms 
step_avg:59.46ms +step:1441/1555 train_time:85706ms step_avg:59.48ms +step:1442/1555 train_time:85795ms step_avg:59.50ms +step:1443/1555 train_time:85880ms step_avg:59.52ms +step:1444/1555 train_time:85970ms step_avg:59.54ms +step:1445/1555 train_time:86054ms step_avg:59.55ms +step:1446/1555 train_time:86144ms step_avg:59.57ms +step:1447/1555 train_time:86230ms step_avg:59.59ms +step:1448/1555 train_time:86319ms step_avg:59.61ms +step:1449/1555 train_time:86403ms step_avg:59.63ms +step:1450/1555 train_time:86493ms step_avg:59.65ms +step:1451/1555 train_time:86576ms step_avg:59.67ms +step:1452/1555 train_time:86667ms step_avg:59.69ms +step:1453/1555 train_time:86751ms step_avg:59.70ms +step:1454/1555 train_time:86841ms step_avg:59.73ms +step:1455/1555 train_time:86925ms step_avg:59.74ms +step:1456/1555 train_time:87017ms step_avg:59.76ms +step:1457/1555 train_time:87100ms step_avg:59.78ms +step:1458/1555 train_time:87190ms step_avg:59.80ms +step:1459/1555 train_time:87274ms step_avg:59.82ms +step:1460/1555 train_time:87364ms step_avg:59.84ms +step:1461/1555 train_time:87450ms step_avg:59.86ms +step:1462/1555 train_time:87539ms step_avg:59.88ms +step:1463/1555 train_time:87624ms step_avg:59.89ms +step:1464/1555 train_time:87716ms step_avg:59.92ms +step:1465/1555 train_time:87798ms step_avg:59.93ms +step:1466/1555 train_time:87889ms step_avg:59.95ms +step:1467/1555 train_time:87973ms step_avg:59.97ms +step:1468/1555 train_time:88062ms step_avg:59.99ms +step:1469/1555 train_time:88149ms step_avg:60.01ms +step:1470/1555 train_time:88239ms step_avg:60.03ms +step:1471/1555 train_time:88322ms step_avg:60.04ms +step:1472/1555 train_time:88414ms step_avg:60.06ms +step:1473/1555 train_time:88497ms step_avg:60.08ms +step:1474/1555 train_time:88587ms step_avg:60.10ms +step:1475/1555 train_time:88670ms step_avg:60.12ms +step:1476/1555 train_time:88760ms step_avg:60.14ms +step:1477/1555 train_time:88846ms step_avg:60.15ms +step:1478/1555 train_time:88936ms step_avg:60.17ms 
+step:1479/1555 train_time:89019ms step_avg:60.19ms +step:1480/1555 train_time:89111ms step_avg:60.21ms +step:1481/1555 train_time:89194ms step_avg:60.23ms +step:1482/1555 train_time:89283ms step_avg:60.25ms +step:1483/1555 train_time:89369ms step_avg:60.26ms +step:1484/1555 train_time:89459ms step_avg:60.28ms +step:1485/1555 train_time:89543ms step_avg:60.30ms +step:1486/1555 train_time:89633ms step_avg:60.32ms +step:1487/1555 train_time:89716ms step_avg:60.33ms +step:1488/1555 train_time:89807ms step_avg:60.35ms +step:1489/1555 train_time:89890ms step_avg:60.37ms +step:1490/1555 train_time:89980ms step_avg:60.39ms +step:1491/1555 train_time:90064ms step_avg:60.41ms +step:1492/1555 train_time:90155ms step_avg:60.43ms +step:1493/1555 train_time:90239ms step_avg:60.44ms +step:1494/1555 train_time:90329ms step_avg:60.46ms +step:1495/1555 train_time:90413ms step_avg:60.48ms +step:1496/1555 train_time:90503ms step_avg:60.50ms +step:1497/1555 train_time:90588ms step_avg:60.51ms +step:1498/1555 train_time:90677ms step_avg:60.53ms +step:1499/1555 train_time:90760ms step_avg:60.55ms +step:1500/1555 train_time:90851ms step_avg:60.57ms +step:1500/1555 val_loss:3.2927 train_time:90966ms step_avg:60.64ms +step:1501/1555 train_time:90986ms step_avg:60.62ms +step:1502/1555 train_time:91026ms step_avg:60.60ms +step:1503/1555 train_time:91112ms step_avg:60.62ms +step:1504/1555 train_time:91206ms step_avg:60.64ms +step:1505/1555 train_time:91291ms step_avg:60.66ms +step:1506/1555 train_time:91382ms step_avg:60.68ms +step:1507/1555 train_time:91465ms step_avg:60.69ms +step:1508/1555 train_time:91553ms step_avg:60.71ms +step:1509/1555 train_time:91637ms step_avg:60.73ms +step:1510/1555 train_time:91726ms step_avg:60.75ms +step:1511/1555 train_time:91808ms step_avg:60.76ms +step:1512/1555 train_time:91899ms step_avg:60.78ms +step:1513/1555 train_time:91985ms step_avg:60.80ms +step:1514/1555 train_time:92076ms step_avg:60.82ms +step:1515/1555 train_time:92163ms step_avg:60.83ms 
+step:1516/1555 train_time:92258ms step_avg:60.86ms +step:1517/1555 train_time:92345ms step_avg:60.87ms +step:1518/1555 train_time:92433ms step_avg:60.89ms +step:1519/1555 train_time:92518ms step_avg:60.91ms +step:1520/1555 train_time:92607ms step_avg:60.93ms +step:1521/1555 train_time:92689ms step_avg:60.94ms +step:1522/1555 train_time:92779ms step_avg:60.96ms +step:1523/1555 train_time:92864ms step_avg:60.97ms +step:1524/1555 train_time:92954ms step_avg:60.99ms +step:1525/1555 train_time:93041ms step_avg:61.01ms +step:1526/1555 train_time:93132ms step_avg:61.03ms +step:1527/1555 train_time:93218ms step_avg:61.05ms +step:1528/1555 train_time:93308ms step_avg:61.07ms +step:1529/1555 train_time:93393ms step_avg:61.08ms +step:1530/1555 train_time:93485ms step_avg:61.10ms +step:1531/1555 train_time:93568ms step_avg:61.12ms +step:1532/1555 train_time:93658ms step_avg:61.13ms +step:1533/1555 train_time:93741ms step_avg:61.15ms +step:1534/1555 train_time:93831ms step_avg:61.17ms +step:1535/1555 train_time:93915ms step_avg:61.18ms +step:1536/1555 train_time:94007ms step_avg:61.20ms +step:1537/1555 train_time:94092ms step_avg:61.22ms +step:1538/1555 train_time:94183ms step_avg:61.24ms +step:1539/1555 train_time:94268ms step_avg:61.25ms +step:1540/1555 train_time:94358ms step_avg:61.27ms +step:1541/1555 train_time:94443ms step_avg:61.29ms +step:1542/1555 train_time:94532ms step_avg:61.30ms +step:1543/1555 train_time:94617ms step_avg:61.32ms +step:1544/1555 train_time:94707ms step_avg:61.34ms +step:1545/1555 train_time:94791ms step_avg:61.35ms +step:1546/1555 train_time:94881ms step_avg:61.37ms +step:1547/1555 train_time:94966ms step_avg:61.39ms +step:1548/1555 train_time:95056ms step_avg:61.41ms +step:1549/1555 train_time:95141ms step_avg:61.42ms +step:1550/1555 train_time:95231ms step_avg:61.44ms +step:1551/1555 train_time:95317ms step_avg:61.46ms +step:1552/1555 train_time:95408ms step_avg:61.47ms +step:1553/1555 train_time:95493ms step_avg:61.49ms +step:1554/1555 
train_time:95584ms step_avg:61.51ms +step:1555/1555 train_time:95667ms step_avg:61.52ms +step:1555/1555 val_loss:3.2765 train_time:95782ms step_avg:61.60ms +peak memory allocated: 31630 MiB reserved: 46718 MiB diff --git a/records/track_1_short/2026-01-31-BigramHashH2D/9b20d90c-95f3-4222-befa-9fb695b83939.txt b/records/track_1_short/2026-01-31-BigramHashH2D/9b20d90c-95f3-4222-befa-9fb695b83939.txt new file mode 100644 index 000000000..34fd920a2 --- /dev/null +++ b/records/track_1_short/2026-01-31-BigramHashH2D/9b20d90c-95f3-4222-befa-9fb695b83939.txt @@ -0,0 +1,3976 @@ +import os +import sys + +# Read the current file and the kernels file code ASAP, for logging +with open(sys.argv[0], 'r') as f: + code = f.read() +with open(os.path.join(os.path.dirname(sys.argv[0]), 'triton_kernels.py'), 'r') as f: + code += f"\n\n{'-'*40}\n# triton_kernels.py\n{'-'*40}\n\n" + code += f.read() + +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate, pairwise +from pathlib import Path +import gc + +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" +import torch +import triton + +torch.empty( + 1, device=f"cuda:{os.environ['LOCAL_RANK']}", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +from kernels import get_kernel +from torch import Tensor, nn + +from triton_kernels import XXT, ba_plus_cAA, FusedLinearReLUSquareFunction, FusedSoftcappedCrossEntropy + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Distributed training setup +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size 
must be a divisor of 8" +grad_accum_steps = 8 // world_size +grad_scale = 2 / grad_accum_steps # consistent grad magnitudes between different num_devices +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng +# Transposed layout by @ChrisJMcCormick allows for faster gradient accumulation. + +@torch.library.custom_op("nanogpt::mm_t", mutates_args=()) +def mm_t_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + """Computes y = x @ w with F8 weights stored as (in_features, out_features).""" + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + assert x.shape[1] == w.shape[0] # x: (batch, in), w: (in, out) + + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + + # _scaled_mm requires column-major B. w_f8 is row-major (in, out). + # .T.contiguous().T creates a column-major view without changing logical shape. 
+ w_f8_col_major = w_f8.T.contiguous().T + + out = torch._scaled_mm( + x_f8, + w_f8_col_major, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_t_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[0] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_t_backward", mutates_args=()) +def mm_t_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + + x_scale = grad.new_tensor(x_s, dtype=torch.float32) + w_scale = grad.new_tensor(w_s, dtype=torch.float32) + grad_scale = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + + # grad_x = grad @ w.T + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=grad_scale, + scale_b=w_scale, + use_fast_accum=False, + ) + + # grad_w = x.T @ grad + # Result is (in, out), naturally matching weight storage. No final .T needed. 
+ grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_scale, + scale_b=grad_scale, + use_fast_accum=False, + ) + + return grad_x, grad_w + + grad_x, grad_w = impl(g, x_f8, w_f8) + + return grad_x, grad_w + +@mm_t_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.to(torch.float32) + +def backward_t(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_t_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context_t(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_t_op.register_autograd(backward_t, setup_context=setup_context_t) + +# ----------------------------------------------------------------------------- +# Polar Express + +# Computed for num_iters=5, safety_factor=2e-2, cushion=2 +polar_express_coeffs = [ + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323) +] + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def polar_express(G: torch.Tensor, split_baddbmm: bool = False): + """ + Polar Express Sign Method: https://arxiv.org/pdf/2505.16932 + by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. 
+ """ + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + # Select batched vs unbatched + if split_baddbmm: + BX_matmul = torch.bmm if X.ndim > 2 else torch.mm + else: + aX_plus_BX = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the iterations + for a, b, c in polar_express_coeffs: + XXT(X, out=A) # A = X @ X.mT + ba_plus_cAA(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + + # Referencing X twice causes pytorch to make a defensive copy, + # resulting in a cudaMemcpyAsync in baddbmm. + # For large matrices (i.e., the mlp weights), it's faster to split + # the operation into two kernels to avoid this. + if split_baddbmm: + BX_matmul(B, X, out=C) # C = B @ X + C.add_(X, alpha=a) # C = C + a*X (in-place, X only read) + else: + aX_plus_BX(X, B, X, beta=a, out=C) # C = a * X + B @ X + + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +# ----------------------------------------------------------------------------- +# Combined NorMuon + Adam Optimizer + +@dataclass +class ParamConfig: + """Per-parameter configuration for NorMuonAndAdam optimizer.""" + label: str + optim: str # "adam" or "normuon" + comms: str # "none", "replicated", or "sharded" + adam_betas: tuple[float, float] | None + lr_mul: float + wd_mul: float + lr: float + initial_lr: float + weight_decay: float + # Adam-specific + eps: float | None = None + # NorMuon-specific + reshape: tuple | None = None + chunk_size: int | None = None + momentum: float | None = None + beta2: float | None = None + per_matrix_lr_mul: list[float] | None = None + + +class NorMuonAndAdam: + """ + Combined optimizer that handles both NorMuon (for projection 
matrices) and + Adam (for embeddings/scalars/gate weights). + + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, Muon uses a Newton-Schulz iteration (replaced + here with Polar Express), which has the advantage that it can be stably run in bfloat16 on the GPU. + + Muon is applied only to the projection matrices in the attention and MLP layers, and is not recommended + for embeddings, scalars, or individual weight vectors (e.g., bias terms or gate weights). + + Differences from standard Muon: + - Newton-Shulz is replaced with Polar Express for the orthogonalization step + - NorMuon adds a low-rank variance estimator similar to Adafactor. https://arxiv.org/pdf/2510.05491 + - Cautious weight decay, a gated version of decoupled weight decay + - Mantissa tracking for precision + + Adam (for embeddings/scalars/gates): + - Standard Adam with bias correction + - Cautious weight decay + + Configuration: + Unlike torch.optim.Optimizer, this class uses per-parameter configs from a `param_table` dict + and does not include parameter "groups". All parameters require a .label attribute, and a + corresponding entry in the param_table to specify their hyperparameters (lr_mul, wd_mul, adam_betas, etc.). + + Communication and ordering: + Gradient communication is explicitly scheduled rather than hook-driven. + Reductions are launched in `scatter_order`, while update math and final + gathers are executed in `work_order`. These orders are independent and + must each contain every parameter label exactly once. + + Two communication modes are supported per parameter: + - 'replicated': Gradients are all-reduced and each rank computes the full update. 
+ - 'sharded': Gradients are reduce-scattered, each rank updates its shard, + and results are all-gathered. + + Adam parameters may be freely sharded. NorMuon operates on full matrices; sharding is + supported by grouping matrices into parameter banks. NorMuon parameters must have a + `.reshape` attribute that reshapes the bank so that the leading dimension is divisible + by world_size. + + # Contributors include @YouJiacheng, @KonstantinWilleke, @alexrgilbert, @adricarda, + # @tuttyfrutyee, @vdlad, @ryanyang0, @vagrawal, @varunneal, @chrisjmccormick + """ + def __init__(self, named_params, param_table: dict, scatter_order: list, work_order: list, + adam_defaults: dict, normuon_defaults: dict): + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + + # Store defaults for each optimizer type + self.adam_defaults = adam_defaults + self.normuon_defaults = normuon_defaults + self.param_table = param_table + self.scatter_order = scatter_order + self.work_order = work_order + + # Collect params by label and build config + self.param_cfgs: dict[nn.Parameter, ParamConfig] = {} + self.param_states: dict[nn.Parameter, dict] = {} + self._param_by_label: dict[str, nn.Parameter] = {} + for name, param in named_params: + label = getattr(param, "label", None) + assert label is not None and label in param_table # all params must have valid label + assert label not in self._param_by_label # exactly one param per label + self._param_by_label[label] = param + self._build_param_cfg(param, label) + + # Assert scatter_order and work_order match present labels exactly + present = set(self._param_by_label.keys()) + assert set(scatter_order) == present and set(work_order) == present + + # Handle world_size=1: overwrite comms to "none" + if self.world_size == 1: + for p_cfg in self.param_cfgs.values(): + p_cfg.comms = "none" + + # Initialize state for all params + self._init_state() + + # 0-D CPU tensors to avoid recompilation + self._step_size_t = torch.tensor(0.0, 
dtype=torch.float32, device="cpu") + self._eff_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._eff_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + + # Track async operations + self._reduce_futures: dict[nn.Parameter, tuple] = {} + + # Embed/lm_head tying state + self.split_embed = False + self._lm_head_param = self._param_by_label.get("lm_head") + self._embed_param = self._param_by_label.get("embed") + + def _build_param_cfg(self, param: nn.Parameter, label: str): + """Build config for a single parameter from param_table.""" + table_entry = self.param_table[label] + optim = table_entry["optim"] + comms = table_entry["comms"] + adam_betas = table_entry.get("adam_betas") + lr_mul = table_entry.get("lr_mul", 1.0) + wd_mul = table_entry.get("wd_mul", 1.0) + + if optim == "adam": + chunk_size = param.shape[0] // self.world_size if comms == "sharded" else None + p_cfg = ParamConfig( + label=label, + optim=optim, + comms=comms, + adam_betas=tuple(adam_betas) if adam_betas else None, + lr_mul=lr_mul, + wd_mul=wd_mul, + lr=self.adam_defaults["lr"], + initial_lr=self.adam_defaults["lr"], + weight_decay=self.adam_defaults["weight_decay"], + eps=self.adam_defaults["eps"], + chunk_size=chunk_size, + ) + elif optim == "normuon": + reshape = getattr(param, "reshape", None) + if reshape is None: + raise ValueError(f"NorMuon param {label} must have .reshape attribute") + if reshape[0] % self.world_size != 0: + raise ValueError(f"reshape[0]={reshape[0]} must be divisible by world_size") + + chunk_size = reshape[0] // self.world_size + chunk_shape = (chunk_size, *reshape[1:]) + # Shape-based LR multiplier for NorMuon + shape_mult = max(1.0, chunk_shape[-2] / chunk_shape[-1]) ** 0.5 if len(chunk_shape) >= 2 else 1.0 + lr_mul = shape_mult * lr_mul + + # Per-matrix LR multipliers for MLP c_proj (2x LR on odd indices) + per_matrix_lr_mul = None + if label == "mlp": + rank = dist.get_rank() if dist.is_initialized() else 0 + start_idx = rank * chunk_size + 
per_matrix_lr_mul = [] + for i in range(chunk_size): + global_idx = start_idx + i + is_c_proj = (global_idx % 2 == 1) + per_matrix_lr_mul.append(2.0 if is_c_proj else 1.0) + + p_cfg = ParamConfig( + label=label, + optim=optim, + comms=comms, + adam_betas=tuple(adam_betas) if adam_betas else None, + lr_mul=lr_mul, + wd_mul=wd_mul, + lr=self.normuon_defaults["lr"], + initial_lr=self.normuon_defaults["lr"], + weight_decay=self.normuon_defaults["weight_decay"], + reshape=reshape, + chunk_size=chunk_size, + momentum=self.normuon_defaults["momentum"], + beta2=self.normuon_defaults["beta2"], + per_matrix_lr_mul=per_matrix_lr_mul, + ) + else: + raise ValueError(f"Unknown optim type: {optim}") + + self.param_cfgs[param] = p_cfg + + def _init_state(self): + """Initialize optimizer state for all parameters.""" + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "adam": + # Sharded params use chunk state, replicated use full state + if p_cfg.comms == "sharded": + chunk = param[:p_cfg.chunk_size] + else: + chunk = param + exp_avg = torch.zeros_like(chunk, dtype=torch.float32, device=param.device) + self.param_states[param] = dict(step=0, exp_avg=exp_avg, exp_avg_sq=torch.zeros_like(exp_avg)) + + elif p_cfg.optim == "normuon": + chunk_shape = (p_cfg.chunk_size, *p_cfg.reshape[1:]) + + # Momentum buffer (FP32 for precision) + momentum_buffer = torch.zeros( + chunk_shape, dtype=torch.float32, device=param.device + ) + + # Second momentum buffer - reduced along one dimension + if chunk_shape[-2] >= chunk_shape[-1]: + second_mom_shape = (*chunk_shape[:-1], 1) + else: + second_mom_shape = (*chunk_shape[:-2], 1, chunk_shape[-1]) + second_momentum_buffer = torch.zeros( + second_mom_shape, dtype=torch.float32, device=param.device + ) + + # Mantissa buffer for precision tracking + mantissa = torch.zeros( + chunk_shape, dtype=torch.uint16, device=param.device + ) + + self.param_states[param] = dict( + momentum_buffer=momentum_buffer, + 
second_momentum_buffer=second_momentum_buffer, + mantissa=mantissa, + ) + + # ----------------------------------- + # Reduce/Gather operations + + def _launch_reduce(self, param: nn.Parameter, grad: Tensor): + """Launch async reduce for a parameter based on its comms policy.""" + p_cfg = self.param_cfgs[param] + + if p_cfg.comms == "none": + if p_cfg.optim == "normuon": + # NorMuon needs reshaped gradient even without communication + grad = grad.view(p_cfg.reshape) + self._reduce_futures[param] = (None, grad) + elif p_cfg.comms == "replicated": + future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future() + self._reduce_futures[param] = (future, grad) + elif p_cfg.comms == "sharded": + if p_cfg.optim == "normuon": + # NorMuon: reshape before reduce_scatter + grad_reshaped = grad.view(p_cfg.reshape) + grad_chunk = torch.empty( + (p_cfg.chunk_size, *grad_reshaped.shape[1:]), + dtype=grad.dtype, + device=grad.device + ) + future = dist.reduce_scatter_tensor( + grad_chunk, grad_reshaped.contiguous(), op=dist.ReduceOp.AVG, async_op=True + ).get_future() + self._reduce_futures[param] = (future, grad_chunk) + else: + # Adam: simple reduce_scatter + grad_chunk = torch.empty_like(grad[:p_cfg.chunk_size]) + future = dist.reduce_scatter_tensor( + grad_chunk, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + self._reduce_futures[param] = (future, grad_chunk) + + def _launch_gather(self, param: nn.Parameter, p_slice: Tensor) -> "torch.futures.Future": + """Launch async all_gather for a sharded parameter.""" + p_cfg = self.param_cfgs[param] + if p_cfg.optim == "normuon": + full_param = param.data.view(p_cfg.reshape) + assert full_param.is_contiguous() + return dist.all_gather_into_tensor( + full_param, p_slice.contiguous(), async_op=True + ).get_future() + else: + return dist.all_gather_into_tensor( + param, p_slice.contiguous(), async_op=True + ).get_future() + + # ----------------------------------- + # State management + + def reset(self): + 
"""Reset NorMuon momentum buffers and split_embed state (called on training reset).""" + self.split_embed = False + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "normuon": + p_state = self.param_states[param] + p_state["momentum_buffer"].zero_() + p_state["mantissa"].zero_() + p_state["second_momentum_buffer"].zero_() + + def copy_lm_state_to_embed(self): + """ + Copy the optimizer state from the lm_head to the embed at the untie point. + This requires an all-gather + reshard because of different sharding: + - lm_head (768, 50304) is sharded to (96, 50304) per rank (along model_dim) + - embed (50304, 768) is sharded to (6288, 768) per rank (along vocab_size) + + We all-gather the lm_head momentum, transpose it, then each rank takes their + embed shard to get the correct momentum state. + """ + lm_head = self._lm_head_param + embed = self._embed_param + lm_state = self.param_states[lm_head] + embed_state = self.param_states[embed] + lm_cfg = self.param_cfgs[lm_head] + embed_cfg = self.param_cfgs[embed] + + embed_state['step'] = lm_state['step'] # Preserve step count for bias correction + + # Copy optimizer state with all-gather + transpose + reshard + if self.world_size > 1: + rank = dist.get_rank() + lm_chunk_size = lm_cfg.chunk_size # 96 + embed_chunk_size = embed_cfg.chunk_size # 6288 + + # All-gather lm_head momentum to get full (768, 50304) tensor + for key in ["exp_avg", "exp_avg_sq"]: + lm_chunk = lm_state[key] # (96, 50304) + full_lm = torch.empty(lm_head.shape[0], lm_head.shape[1], dtype=lm_chunk.dtype, device=lm_chunk.device) + dist.all_gather_into_tensor(full_lm, lm_chunk.contiguous()) + embed_state[key].copy_(full_lm.T[rank * embed_chunk_size:(rank + 1) * embed_chunk_size]) + else: + # Single GPU: simple transpose + for key in ["exp_avg", "exp_avg_sq"]: + embed_state[key].copy_(lm_state[key].T) + + # Mark as split + self.split_embed = True + + def state_dict(self): + """Return the optimizer state as a dict.""" + return { + 
"param_states": {id(p): s for p, s in self.param_states.items()}, + "param_cfgs": {id(p): s for p, s in self.param_cfgs.items()}, + } + + def load_state_dict(self, state_dict): + """Load optimizer state from a dict.""" + # Build id->param mapping + id_to_param = {id(p): p for p in self.param_cfgs.keys()} + + # Load state, preserving dtypes + for param_id, saved_p_state in state_dict["param_states"].items(): + if param_id in id_to_param: + param = id_to_param[param_id] + p_state = self.param_states[param] + for k, v in saved_p_state.items(): + if isinstance(v, torch.Tensor) and k in p_state: + target_dtype = p_state[k].dtype + p_state[k] = v.to(dtype=target_dtype, device=p_state[k].device) + else: + p_state[k] = v + + # ----------------------------------- + # Unified optimizer step with explicit ordering + + @torch.no_grad() + def step(self, do_adam: bool = True): + """ + Combined optimizer step with explicit ordering. + + Args: + do_adam: If True, update Adam params. NorMuon params always updated. + + Flow: + 1. Scatter phase: Launch reduces in scatter_order + 2. Work phase: Process updates in work_order + - Wait for reduce, compute update, launch gather + 3. Finalize phase: Wait for gathers + + While the embeddings are tied: + - Comms and update math are only done on lm_head. + - We add embed.grad.T into lm_head.grad before comms. 
+ - After lm_head gather, we copy lm_head.data.T --> embed.data + """ + rank = dist.get_rank() if dist.is_initialized() else 0 + lm_param, embed_param = self._lm_head_param, self._embed_param + + # ===== Phase 1: Launch reduces in scatter_order ===== + for label in self.scatter_order: + param = self._param_by_label[label] + p_cfg = self.param_cfgs[param] + + if p_cfg.optim == "adam" and not do_adam: + continue + if param.grad is None: + continue + + # lm_head when tied: aggregate embed.grad.T (transposed shapes) + if label == "lm_head" and do_adam and not self.split_embed: + if embed_param is not None and embed_param.grad is not None: + param.grad.add_(embed_param.grad.T) + + # Skip embed when tied (copied from lm_head after gather) + if label == "embed" and not self.split_embed: + continue + + self._launch_reduce(param, param.grad) + + # ===== Phase 2: Process updates in work_order ===== + gather_futures = [] + lm_head_gather_future = None + + for label in self.work_order: + param = self._param_by_label[label] + if param not in self._reduce_futures: + continue + + p_cfg = self.param_cfgs[param] + if p_cfg.optim == "adam" and not do_adam: + continue + # Wait for reduce + future, grad_chunk = self._reduce_futures[param] + if future is not None: + future.wait() + # Apply update based on optim type + if p_cfg.optim == "adam": + p_slice = self._adam_update(param, grad_chunk, p_cfg, rank) + else: + p_slice = self._normuon_update(param, grad_chunk, p_cfg, rank) + # Launch gather for sharded params + if p_cfg.comms == "sharded" and self.world_size > 1: + gather_fut = self._launch_gather(param, p_slice) + if label == "lm_head": + lm_head_gather_future = gather_fut + else: + gather_futures.append(gather_fut) + + # ===== Phase 3: Wait for gathers, sync embed if tied ===== + # Wait for lm_head gather first so we can copy to embed while other gathers complete + if lm_head_gather_future is not None: + lm_head_gather_future.wait() + + # When tied: copy lm_head.T to embed + if 
do_adam and not self.split_embed and embed_param is not None and lm_param is not None: + embed_param.data.copy_(lm_param.data.T) + + # Wait for remaining gathers + for fut in gather_futures: + fut.wait() + + self._reduce_futures.clear() + + # Clear grads for updated params + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "adam" and not do_adam: + continue # Don't clear Adam grads on even steps + param.grad = None + + # ----------------------------------- + # Adam update + + def _adam_update(self, param: nn.Parameter, grad_chunk: Tensor, p_cfg: ParamConfig, rank: int) -> Tensor: + """Apply Adam update to a parameter. Returns the updated p_slice.""" + beta1, beta2 = p_cfg.adam_betas + lr = p_cfg.lr * p_cfg.lr_mul + + # Get parameter slice + if p_cfg.comms == "sharded": + p_slice = param[rank * p_cfg.chunk_size:(rank + 1) * p_cfg.chunk_size] + else: + p_slice = param + + p_state = self.param_states[param] + p_state["step"] += 1 + t = p_state["step"] + + bias1, bias2 = 1 - beta1 ** t, 1 - beta2 ** t + self._step_size_t.fill_(lr * (bias2 ** 0.5 / bias1)) + self._eff_wd_t.fill_(lr * lr * p_cfg.weight_decay * p_cfg.wd_mul) + + NorMuonAndAdam._adam_update_step( + p_slice, grad_chunk, p_state["exp_avg"], p_state["exp_avg_sq"], + beta1, beta2, p_cfg.eps, self._step_size_t, self._eff_wd_t + ) + + return p_slice + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _adam_update_step(p_slice, g_slice, exp_avg, exp_avg_sq, beta1, beta2, eps, step_size_t, eff_wd_t): + """Compiled Adam update step.""" + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + update = exp_avg.div(exp_avg_sq.sqrt().add_(eps)).mul_(step_size_t) + # Cautious weight decay + mask = (update * p_slice) > 0 + update.addcmul_(p_slice, mask, value=eff_wd_t) + p_slice.add_(other=update, alpha=-1.0) + + # ----------------------------------- + # NorMuon update + + def _normuon_update(self, param: nn.Parameter, 
grad_chunk: Tensor, p_cfg: ParamConfig, rank: int) -> Tensor: + """Apply NorMuon update to a parameter. Returns the updated p_slice.""" + chunk_shape = grad_chunk.shape + + p_state = self.param_states[param] + grad_chunk = grad_chunk.float() # FP32 for momentum + + # Momentum update + momentum_buffer = p_state["momentum_buffer"] + momentum_buffer.lerp_(grad_chunk, 1 - p_cfg.momentum) + updated_grads = grad_chunk.lerp_(momentum_buffer, p_cfg.momentum) + + self._eff_lr_t.fill_(p_cfg.lr_mul * p_cfg.lr) + self._eff_wd_t.fill_(p_cfg.wd_mul * p_cfg.weight_decay * p_cfg.lr) + + # Polar Express orthogonalization + is_large_matrix = chunk_shape[-2] > 1024 + v_chunk = polar_express(updated_grads, split_baddbmm=is_large_matrix) + + # Variance reduction + red_dim = -1 if chunk_shape[-2] >= chunk_shape[-1] else -2 + v_chunk = NorMuonAndAdam._apply_normuon_variance_reduction( + v_chunk, p_state["second_momentum_buffer"], p_cfg.beta2, red_dim + ) + + # Update parameter, in place, with cautious weight decay + param_view = param.data.view(p_cfg.reshape) + p_slice = param_view[rank * p_cfg.chunk_size:(rank + 1) * p_cfg.chunk_size] + + # MLP has per-matrix LR multipliers (c_proj gets 2x LR) + if p_cfg.per_matrix_lr_mul is not None: + for mat_idx in range(p_cfg.chunk_size): + self._eff_lr_t.fill_(p_cfg.lr_mul * p_cfg.per_matrix_lr_mul[mat_idx] * p_cfg.lr) + self._eff_wd_t.fill_(p_cfg.wd_mul * p_cfg.weight_decay * p_cfg.lr) + NorMuonAndAdam._cautious_wd_and_update_inplace( + p_slice[mat_idx].view(torch.uint16), p_state["mantissa"][mat_idx], v_chunk[mat_idx], + self._eff_wd_t, self._eff_lr_t + ) + else: + NorMuonAndAdam._cautious_wd_and_update_inplace( + p_slice.view(torch.uint16), p_state["mantissa"], v_chunk, + self._eff_wd_t, self._eff_lr_t + ) + + return p_slice + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _cautious_wd_and_update_inplace(p, mantissa, grad, wd_tensor, lr_tensor): + """ + Cautious weight decay + parameter update. 
wd_tensor and lr_tensor are 0-D CPU tensors. + Mantissa is tracked to enable higher precision updates on bfloat16 parameters. + bfloat16 format: 1 sign bit + 8 exponent bits + 7 mantissa bits = 16 bits total + float32 format: 1 sign bit + 8 exponent bits + 23 mantissa bits = 32 bits total + """ + assert p.dtype == mantissa.dtype == torch.uint16 + grad = grad.float() + wd_factor = wd_tensor.to(torch.float32) + lr_factor = lr_tensor.to(torch.float32) + p_precise_raw = (p.to(torch.uint32) << 16) | mantissa.to(torch.uint32) + p_precise = p_precise_raw.view(torch.float32) + mask = (grad * p_precise) >= 0 + p_precise.copy_(p_precise - (p_precise * mask * wd_factor * lr_factor) - (grad * lr_factor)) + p.copy_((p_precise_raw >> 16).to(torch.uint16)) + mantissa.copy_(p_precise_raw.to(torch.uint16)) + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _apply_normuon_variance_reduction(v_chunk, second_momentum_buffer, beta2, red_dim): + """NorMuon variance reduction. Algebraically fuses the normalization steps to minimize memory ops.""" + v_mean = v_chunk.float().square().mean(dim=red_dim, keepdim=True) + red_dim_size = v_chunk.size(red_dim) + v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True).mul_(red_dim_size) + v_norm = v_norm_sq.sqrt_() + second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2) + step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt_() + scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square() + v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt_() + final_scale = step_size * (v_norm / v_norm_new.clamp_min_(1e-10)) + return v_chunk.mul_(final_scale.type_as(v_chunk)) + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinearT(nn.Module): + """ + Linear layer with transposed weight storage (in_features, out_features) which + 
class CastedLinearT(nn.Module):
    """
    Linear layer whose weight is stored transposed, as (in_features, out_features).
    The transposed layout avoids the slow kernel previously used for gradient
    accumulation. @chrisjmccormick

    When use_fp8 is set and the module is training, the matmul runs through the
    custom nanogpt::mm_t FP8 op with the given scale factors; otherwise it is a
    plain bf16 matmul against the stored weight.
    """
    def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.use_fp8 = use_fp8
        self.x_s = x_s
        self.w_s = w_s
        self.grad_s = grad_s

        self.weight = nn.Parameter(torch.empty(in_features, out_features, dtype=torch.bfloat16))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Zero init @Grad62304977 and others
        with torch.no_grad():
            nn.init.zeros_(self.weight)

    def forward(self, x: Tensor):
        if not (self.use_fp8 and self.training):
            return x @ self.weight.type_as(x)
        flat = x.flatten(0, -2)
        out = torch.ops.nanogpt.mm_t(flat, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0]
        return out.reshape(*x.shape[:-1], -1)
= torch.outer(t, angular_freq) + self.factor1 = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.factor2 = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + else: + t_even = 2 * t + t_odd = 2 * t + 1 + theta1 = torch.outer(t_even, angular_freq) + theta2 = torch.outer(t_odd, angular_freq) + self.factor1 = nn.Buffer( + torch.cat((theta1.cos(), theta2.cos()), dim=-1).to(torch.bfloat16), + persistent=False + ) + self.factor2 = nn.Buffer( + torch.cat((theta1.sin(), theta2.sin()), dim=-1).to(torch.bfloat16), + persistent=False + ) + self.factor2[..., 1::2] *= -1 + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(2*self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + if not self.paired: + theta = torch.outer(t, self.angular_freq) + self.factor1.copy_(theta.cos()) + self.factor2.copy_(theta.sin()) + else: + t_even = 2 * t + t_odd = 2 * t + 1 + theta1 = torch.outer(t_even, self.angular_freq) + theta2 = torch.outer(t_odd, self.angular_freq) + self.factor1.copy_(torch.cat((theta1.cos(), theta2.cos()), dim=-1)) + self.factor2.copy_(torch.cat((theta1.sin(), theta2.sin()), dim=-1)) + self.factor2[..., 1::2] *= -1 + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + yarn: Yarn + key_offset: bool + attn_gate_w: torch.Tensor + ve_gate_w: torch.Tensor + +flash_attn_interface = 
get_kernel('varunneal/flash-attention-3').flash_attn_interface + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, paired: bool = False): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + self.paired = paired + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + # Weights are stored in parameter banks and passed via forward() + + def forward(self, x: Tensor, attn_args: AttnArgs, qkvo_w: Tensor): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + yarn = attn_args.yarn + ve, sa_lambdas, key_offset = attn_args.ve, attn_args.sa_lambdas, attn_args.key_offset + seqlens, bm_size = attn_args.seqlens, attn_args.bm_size + # sparse gated attention to enable context based no-op by @classiclarryd + # only include gates on layers with value embeds used on forward pass + attn_gate_w, ve_gate_w = attn_args.attn_gate_w, attn_args.ve_gate_w + + q, k, v = F.linear(x, sa_lambdas[0] * qkvo_w[:self.dim * 3].type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + q, k = norm(q), norm(k) # QK norm @Grad62304977 + + if not self.paired: + q, k = yarn.rotary(q), yarn.rotary(k) + + if key_offset: + # shift keys forward for the stationary head dims. Enables 1-layer induction. + k[:, 1:, :, self.head_dim // 2:] = k[:, :-1, :, self.head_dim // 2:] + + if ve is not None: + ve_gate_out = 2 * torch.sigmoid(F.linear(x[..., :12], ve_gate_w)).view(B, T, self.num_heads, 1) + v = v + ve_gate_out * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + + else: + # Paired heads: adjacent heads' queries attend to each other's keys. 
class MLP(nn.Module):
    """MLP block; c_fc / c_proj weights live in external parameter banks and
    are passed into forward()."""
    def __init__(self):
        super().__init__()

    def forward(self, x: Tensor, c_fc: Tensor, c_proj: Tensor):
        # relu(x)^2 activation (https://arxiv.org/abs/2109.08668v2), ~1-2% better
        # than GELU; suggested by @SKYLINEZ007 and @Grad62304977.
        # Single fused triton kernel computing relu(x @ W1.T)^2 @ W2.T.
        return FusedLinearReLUSquareFunction.apply(x, c_fc, c_proj)
super().__init__() + # skip attention of blocks.6 (the 7th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads, paired=use_paired_head) if has_attn else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP() if has_mlp else None + + def forward(self, x: Tensor, attn_args: AttnArgs, qkvo_w: Tensor = None, c_fc: Tensor = None, c_proj: Tensor = None): + if self.attn is not None: + x = x + self.attn(norm(x), attn_args, qkvo_w) + if self.mlp is not None: + x = x + self.mlp(norm(x), c_fc, c_proj) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +@dataclass +class ForwardScheduleConfig: + mtp_weights: torch.Tensor + ws_short: int + ws_long: int + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + self.num_layers = num_layers + self.vocab_size = next_multiple_of_n(vocab_size, n=128) + + self.smear_gate = nn.Linear(12, 1, bias=False) + nn.init.zeros_(self.smear_gate.weight) + self.smear_gate.weight.label = 'smear_gate' + + self.skip_gate = nn.Linear(12, 1, bias=False) + nn.init.zeros_(self.skip_gate.weight) + self.skip_gate.weight.label = 'skip_gate' + + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.Parameter(torch.zeros(5 * self.vocab_size, model_dim, dtype=torch.bfloat16)) + self.value_embeds.label = 'value_embed' + + # parameter banks for attention and value embedding gate weights + self.attn_gate_bank = nn.Parameter(torch.zeros(10, num_heads, 12)) # 10 layers + 
self.attn_gate_bank.label = 'attn_gate_bank' + self.ve_gate_bank = nn.Parameter(torch.zeros(5, num_heads, 12)) # 5 unique gates + self.ve_gate_bank.label = 've_gate_bank' + + # ----------------------------------- + # Parameter banks for sharded optimization, by @chrisjmccormick + + # Identify which layers have attention/MLP + # Attention is skipped in layer 6 by @YouJiacheng + self.attn_layer_indices = [i for i in range(num_layers) if i != 6] + # All layers have MLP (At 11 layers--dropped first layer @EmelyanenkoK) + self.mlp_layer_indices = list(range(num_layers)) + + hdim = num_heads * head_dim + mlp_hdim = 4 * model_dim + + # Create index mappings: layer_idx -> bank_idx + self.layer_to_attn_idx = {layer_idx: bank_idx for bank_idx, layer_idx in enumerate(self.attn_layer_indices)} + self.layer_to_mlp_idx = {layer_idx: bank_idx for bank_idx, layer_idx in enumerate(self.mlp_layer_indices)} + + # Attention bank: stores QKVO weights for all attention layers + # merged QKVO weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # Simplified layout by @chrisjmccormick + # Shape: (num_attn_layers, 4*model_dim, hdim) = (10, 3072, 768) + # Reshape for sharding: (40, 768, 768) for even distribution across 8 GPUs + self.attn_bank = nn.Parameter(torch.empty(len(self.attn_layer_indices), 4 * model_dim, hdim)) + self.attn_bank.label = 'attn' + self.attn_bank.reshape = (len(self.attn_layer_indices) * 4, hdim, hdim) # (40, 768, 768) + + # MLP bank: stores c_fc and c_proj for all MLP layers + # Shape: (num_mlp_layers + padding, 2, mlp_hdim, model_dim) = (12, 2, 3072, 768) + # We add 1 padding layer (index 11) to get 12*2=24 matrices for even distribution across 8 GPUs + # Reshape for sharding: (24, 3072, 768) + num_mlp_with_padding = len(self.mlp_layer_indices) + 1 # 11 + 1 = 12 + self.mlp_bank = nn.Parameter(torch.empty(num_mlp_with_padding, 2, mlp_hdim, model_dim)) + 
self.mlp_bank.label = 'mlp' + self.mlp_bank.reshape = (num_mlp_with_padding * 2, mlp_hdim, model_dim) # (24, 3072, 768) + + # improved init scale by @YouJiacheng and @srashedll + std = 0.5 * model_dim ** -0.5 + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.attn_bank.uniform_(-bound, bound) + self.mlp_bank[:, 0, :, :].uniform_(-bound, bound) # c_fc + self.mlp_bank[:, 1, :, :].zero_() # c_proj - zero init suggested by @Grad62304977 + + # Create blocks with has_attn/has_mlp flags + self.paired_head_layers = [0, 2, 5, 9] + self.blocks = nn.ModuleList([ + Block(model_dim, head_dim, num_heads, + has_attn=(i in self.layer_to_attn_idx), + has_mlp=(i in self.layer_to_mlp_idx), + use_paired_head=(i in self.paired_head_layers)) + for i in range(num_layers) + ]) + self.yarn = Yarn(head_dim, max_seq_len) + self.yarn_paired_head = Yarn(head_dim, max_seq_len, paired=True) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + # Transposed weight storage for faster gradient accumulation + self.lm_head = CastedLinearT(model_dim, self.vocab_size, use_fp8=use_fp8, x_s=100/448, w_s=1.6/448, grad_s=grad_scale * 0.75/448) + + nn.init.normal_(self.lm_head.weight, mean=0, std=0.005) + self.lm_head.weight.label = 'lm_head' + + self.embed = nn.Embedding(self.vocab_size, model_dim) + self.embed.weight.label = 'embed' + with torch.no_grad(): + self.embed.weight.copy_(self.lm_head.weight.T) + + self.bigram_embed = nn.Embedding(args.bigram_vocab_size, model_dim) + self.bigram_embed.weight.label = 'bigram_embed' + nn.init.zeros_(self.bigram_embed.weight) + + # x0_lambdas separated out for different optimizer treatment (no beta smoothing) + self.x0_lambdas = nn.Parameter(torch.zeros(num_layers)) + self.x0_lambdas.label = 'x0_lambdas' + + pad = (-num_layers * 3 - 3) % dist.get_world_size() # updated: 3*num_layers instead of 4* + self.scalars = nn.Parameter( + torch.cat( + [ + 1.1 * torch.ones(num_layers), # resid lambdas. 1.1 init such that layer i weight is i^(num_layers-i). + *[torch.tensor([0.5, 1.0]) for _ in range(num_layers)], # SA lambdas + 0.1 * torch.ones(num_layers), # bigram lambdas + torch.zeros(1), # smear_lambda + 0.5*torch.ones(1), # backout_lambda + -1.5 * torch.ones(1), # skip_lambda -> σ(-1.5) ≈ 0.18 + torch.ones(pad), + ] + ) + ) + self.scalars.label = 'scalars' + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _compute_bigram_hash(x: Tensor, mod: int) -> Tensor: + """ + Computes bigram hash on GPU for each position using [prev_token, curr_token]. + Mathematically identical to the CPU version but computed on device. 
+ """ + rand_int_1 = 36313 + rand_int_2 = 27191 + result = torch.empty_like(x) + result[0] = mod + result[1:] = torch.bitwise_xor(rand_int_1 * x[1:], rand_int_2 * x[:-1]) % mod + return result + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, schedule_cfg: ForwardScheduleConfig): + assert input_seq.ndim == 1 + + # unpack schedule_cfg + mtp_weights, ws_short, ws_long = schedule_cfg.mtp_weights, schedule_cfg.ws_short, schedule_cfg.ws_long + + # set configs + skip_connections = [] + skip_in = [3] # long attention window on layer 3 + skip_out = [6] # no attn op on layer 6 + x_backout = None + backout_layer = 7 + + # set lambdas + resid_lambdas = self.scalars[: 1 * self.num_layers] + x0_lambdas = self.x0_lambdas + sa_lambdas = self.scalars[1 * self.num_layers: 3 * self.num_layers].view(-1, 2) + bigram_lambdas = self.scalars[3 * self.num_layers: 4 * self.num_layers] + smear_lambda = self.scalars[4 * self.num_layers] + backout_lambda = self.scalars[4 * self.num_layers+1] + skip_lambda = self.scalars[4 * self.num_layers+2] + + # set block masks and key shift + bm_sizes = [ws_short, ws_short, ws_short, ws_long, ws_short, ws_short, None, ws_short, ws_short, ws_short, ws_long] + assert len(bm_sizes) == self.num_layers + key_offset = [b==ws_long for b in bm_sizes] # apply partial key offset to long windows + + # Embedding lookup - embed is synced from lm_head during tied phase by optimizer + x = self.embed(input_seq) + # Compute bigram hash on GPU (moved from CPU data loader) + bigram_seq = self._compute_bigram_hash(input_seq, args.bigram_vocab_size - 1) + x0_bigram = self.bigram_embed(bigram_seq)[None] + + # Value embeddings - always computed (not precomputed) + ve = self.value_embeds.view(5, self.vocab_size, -1)[:, input_seq] + # 01 ... 
234 structure on token value embeddings by @photomz + ve = [ve[0], ve[1]] + [None] * (self.num_layers - 5) + [ve[2], ve[3], ve[4]] + assert len(ve) == self.num_layers + + # smear token embed forward 1 position @classiclarryd + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # unbind gate banks to avoid select_backwards kernel + ag = [w.bfloat16() for w in self.attn_gate_bank.unbind(0)] + veg = [w.bfloat16() for w in self.ve_gate_bank.unbind(0)] + attn_gates = ag[:6] + [None] + ag[6:] + ve_gates = [veg[0], veg[1]] + [None] * (self.num_layers - 5) + [veg[2], veg[3], veg[4]] + assert len(attn_gates) == self.num_layers + assert len(ve_gates) == self.num_layers + + # unbind weight banks to avoid select_backwards kernel + attn_weights = self.attn_bank.unbind(0) # tuple of [4*dim, hdim] tensors + mlp_fcs = self.mlp_bank[:, 0, :, :].unbind(0) # tuple of [mlp_hdim, dim] tensors + mlp_projs = self.mlp_bank[:, 1, :, :].unbind(0) # tuple of [mlp_hdim, dim] tensors + + for i in range(self.num_layers): + yarn = self.yarn_paired_head if i in self.paired_head_layers else self.yarn + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + yarn=yarn, + key_offset=key_offset[i], + attn_gate_w=attn_gates[i], + ve_gate_w=ve_gates[i] + ) + if i in skip_out: + skip_gate_out = torch.sigmoid(skip_lambda) * 2 * torch.sigmoid(self.skip_gate(x0[..., :self.skip_gate.weight.size(-1)])) + x = x + skip_gate_out * skip_connections.pop() + if i == 0: + x = (resid_lambdas[0] + x0_lambdas[0]) * x + bigram_lambdas[0] * x0_bigram + else: + x = resid_lambdas[i] * x + x0_lambdas[i] * x0 + bigram_lambdas[i] * x0_bigram + + # Get weights for this layer from banks + qkvo_w = attn_weights[self.layer_to_attn_idx[i]] if i in self.layer_to_attn_idx else None + c_fc = mlp_fcs[self.layer_to_mlp_idx[i]] if i in 
self.layer_to_mlp_idx else None + c_proj = mlp_projs[self.layer_to_mlp_idx[i]] if i in self.layer_to_mlp_idx else None + + x = self.blocks[i](x, attn_args, qkvo_w, c_fc, c_proj) + if i in skip_in: + skip_connections.append(x) + if i == backout_layer: + x_backout = x + + # back out contributions from first 7 layers that are only required for downstream context and not direct prediction + x -= backout_lambda * x_backout + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15 + # @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1). @classiclarryd updated to 23*sigmoid((logits+5)/7.5) + if self.training: + losses = FusedSoftcappedCrossEntropy.apply(logits.view(-1, logits.size(-1)), target_seq, mtp_weights, 23.0, 5.0, 7.5) + loss = losses.sum() + else: + logits = 23 * torch.sigmoid((logits + 5) / 7.5) + logits_for_loss = logits.float() + loss = F.cross_entropy(logits_for_loss.view(-1, logits_for_loss.size(-1)), target_seq, reduction="mean") + return loss +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class Shard: + def __init__(self, tokens: Tensor, world_size: int = 1): + self.tokens = tokens + self.size = tokens.numel() + self.world_size = 
world_size + self.i = 0 + + # Partial index now, full index async + self.bos_idx = (tokens[:6_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self._full_idx = None + self._loader_thread = None + self._ready = threading.Event() + self._loader_thread = threading.Thread(target=self._scan) + self._loader_thread.start() + + def _scan(self): + self._full_idx = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self._ready.set() + + def _maybe_switch(self): + # Switch to full index as soon as async scan completes + if self.bos_idx is not self._full_idx and self._ready.is_set(): + self._loader_thread.join() + self.bos_idx = self._full_idx + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + self._maybe_switch() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + return starts, ends + + @staticmethod + def load_async(file: Path, world_size: int = 1): + """Returns getter function for async shard loading""" + result = {} + ready = threading.Event() + def load(): + tokens = _load_data_shard(file) + result['shard'] = Shard(tokens, world_size) + ready.set() + thread = threading.Thread(target=load) + thread.start() + def get(): + ready.wait() + thread.join() + return result['shard'] + return get + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each 
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + shard = Shard(tokens, world_size) + next_shard_getter = Shard.load_async(next(file_iter), world_size) + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = shard.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
+ shard = next_shard_getter() + tokens = shard.tokens + try: + next_shard_getter = Shard.load_async(next(file_iter), world_size) + except StopIteration: + next_shard_getter = None # no more shards to preload + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + # Cast to int32 on CPU before transfer to avoid dtype conversion during .to() + _inputs = _inputs.to(dtype=torch.int32) + _targets = _targets.to(dtype=torch.int64) + _cum_lengths = _cum_lengths.to(dtype=torch.int32) + # Bigram hash computation moved to GPU in forward() + + new_params = yield ( + _inputs.to(device="cuda", non_blocking=True), + _targets.to(device="cuda", non_blocking=True), + _cum_lengths.to(device="cuda", non_blocking=True), + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens // new_grad_accum_steps + max_seq_len = new_max_seq_len + +# ----------------------------------------------------------------------------- +# Training Management + +@dataclass 
+class Hyperparameters: + # data + data_path = os.environ.get("DATA_PATH", ".") + train_files: str = os.path.join(data_path, "data/fineweb10B/fineweb_train_*.bin") # input .bin to train on + val_files: str = os.path.join(data_path, "data/fineweb10B/fineweb_val_*.bin") # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + # batch sizes + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # schedule + num_scheduled_iterations: int = 1515 # number of steps to complete lr and ws schedule + num_extension_iterations: int = 40 # number of steps to continue training at final lr and ws + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 250 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # bigram hash embedding + bigram_vocab_size: int = 50304 * 5 + +args = Hyperparameters() + +@dataclass +class TrainingStage: + lr_mul: float + batch_size: int + window_sizes: tuple[int, int] # (short, long) in block units + mtp_weights_start: list[float] + mtp_weights_end: list[float] + duration: float = None + +class TrainingSchedule: + """ + Training schedule initialized via TRAINING_STAGES + 1. Multi Token Prediction schedule of [1, 0.5, 0.25->0] -> [1, 0.5->0] -> [1] @varunneal + 2. Sliding Attention window schedule of [1,3] -> [3,7] -> [5,11] -> [6,13] + 3. YaRN updates to RoPE on window changes + 4. Split embed and lm head at 2/3 of training + 5. Batch size schedule of 8 -> 16 -> 24 + 6. 
Post training extension of long windows from 13 to 20 + """ + + def __init__(self, stages: list[TrainingStage], scheduled_iterations: int, extension_iterations: int, + cooldown_frac: float = 0.5, split_embed_stage: int = 2, ws_post_yarn_ext: int = 20): + self.stages = stages + self.scheduled_iterations = scheduled_iterations + self.cooldown_frac = cooldown_frac + # increase final validation ws, used for YaRN extension and short window size @classiclarryd + self.ws_post_yarn_ext = ws_post_yarn_ext + + self.total_steps = self.scheduled_iterations + extension_iterations + + # Build stage boundaries (last is extension stage) + ends = [0] + [round(c * scheduled_iterations) for c in accumulate(s.duration for s in stages[:-1])] + [self.total_steps] + assert self.scheduled_iterations == ends[-2] + self.boundaries = list(pairwise(ends)) + + # Split embed at specified stage (ensure odd step for Adam) + self.split_step = self.boundaries[split_embed_stage][0] | 1 + + # Precompute MTP weights for all steps + self.mtp_weights = [] + for step in range(self.total_steps + 1): + stage, t = self.lookup(step) + w = [a + (b - a) * t for a, b in zip(stage.mtp_weights_start, stage.mtp_weights_end)] + self.mtp_weights.append(torch.tensor(w, device=device)) + + def lookup(self, step: int) -> tuple[TrainingStage, float]: + # Returns stage and % of the way through that stage + for i, (start, end) in enumerate(self.boundaries): + if step < end: + t = (step - start) / (end - start) + return self.stages[i], t + return self.stages[-1], 1.0 + + def get_lr(self, step: int) -> float: + # learning rate schedule: tied to batch size schedule, with cooldown at the end + stage, _ = self.lookup(step) + lr = stage.lr_mul + cd_start = int(self.scheduled_iterations * (1 - self.cooldown_frac)) + if step >= cd_start: + t = min(1.0, (step - cd_start) / (self.scheduled_iterations - cd_start)) + lr = lr * (1 - t) + 0.1 * t + return lr + +# window_sizes are in units of `block_size` tokens (defined in 
TrainingManager) +TRAINING_STAGES = [ + TrainingStage(duration=1/3, batch_size=8 * 2048 * 8, window_sizes=(1, 3), lr_mul=1.0, + mtp_weights_start=[1.0, 0.5, 0.25], mtp_weights_end=[1.0, 0.5, 0.0]), + TrainingStage(duration=1/3, batch_size=16 * 2048 * 8, window_sizes=(3, 7), lr_mul=1.52, # (16/8)**0.6 + mtp_weights_start=[1.0, 0.5], mtp_weights_end=[1.0, 0.0]), + TrainingStage(duration=1/3, batch_size=24 * 2048 * 8, window_sizes=(5, 11), lr_mul=1.73, # (24/8)**0.5 + mtp_weights_start=[1.0], mtp_weights_end=[1.0]), + # extension stage + TrainingStage(batch_size=24 * 2048 * 8, window_sizes=(6, 13), lr_mul=1.0, # lr_mul is not used + mtp_weights_start=[1.0], mtp_weights_end=[1.0]), +] + +training_schedule = TrainingSchedule(TRAINING_STAGES, args.num_scheduled_iterations, args.num_extension_iterations, cooldown_frac=0.55) + +def get_muon_momentum(step: int, muon_warmup_steps=300, muon_cooldown_steps=50, momentum_min=0.85, momentum_max=0.95): + # warmup phase: linearly increase momentum from min to max + # cooldown phase: linearly decrease momentum from max to min + momentum_cd_start = training_schedule.total_steps - muon_cooldown_steps + if step < muon_warmup_steps: + frac = step / muon_warmup_steps + momentum = momentum_min + frac * (momentum_max - momentum_min) + elif step > momentum_cd_start: + frac = (step - momentum_cd_start) / muon_cooldown_steps + momentum = momentum_max - frac * (momentum_max - momentum_min) + else: + momentum = momentum_max + return momentum + +class TrainingManager(): + """ + Manages the NorMuonAndAdam for all parameters with explicit ordering. + 1. Scalars are given higher momentum terms to smooth learning @ChrisJMcCormick + 2. Adam optimizers are only stepped on odd steps @classiclarryd + 3. Explicit scatter_order and work_order for communication scheduling (no backward hooks) + 4. Muon has a linear momentum warmup and cooldown schedule + 5. Learning rates follow a linear decay schedule + 6. 
Embed is tied to lm_head until split step (2/3 of training), then untied @classiclarryd + """ + def __init__(self, model): + self.model = model + self.block_size = 128 + + # - Ordering dictates when to launch reduce/reduce_scatter operations + # - "sharded" parameters use reduce_scatter/all_gather and "replicated" ones use all_reduce + # - lr_mul and wd_mul are per-parameter learning rate and weight decay multipliers + self.param_table = { + "attn": {"optim": "normuon", "comms": "sharded", "adam_betas": None}, + "mlp": {"optim": "normuon", "comms": "sharded", "adam_betas": None}, + "scalars": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 5.0, "wd_mul": 0.0}, + "value_embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.75, 0.95], "lr_mul": 75., "wd_mul": 5.0}, + "bigram_embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.75, 0.95], "lr_mul": 75., "wd_mul": 5.0}, + "smear_gate": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 0.01, "wd_mul": 0.0}, + "skip_gate": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 0.05, "wd_mul": 0.0}, + "attn_gate_bank": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99]}, + "ve_gate_bank": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99]}, + "x0_lambdas": {"optim": "adam", "comms": "replicated", "adam_betas": [0.65, 0.95], "lr_mul": 5.0, "wd_mul": 0.0}, + "lm_head": {"optim": "adam", "comms": "sharded", "adam_betas": [0.5, 0.95], "wd_mul": 150.}, + "embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.5, 0.95], "wd_mul": 150.}, + } + + # - Process smaller/faster params first while large reduces complete + # - lm_head must complete before embed sync (when tied) + self.work_order = [ + "scalars", "smear_gate", "skip_gate", "attn_gate_bank", "ve_gate_bank", "x0_lambdas", # Small, fast + "value_embed", "bigram_embed", # Medium + "lm_head", "embed", # lm_head must complete before embed 
sync (when tied) + "attn", "mlp", # Large, polar express - process last to maximize overlap + ] + + adam_defaults = dict( + lr=0.008, + eps=1e-10, + weight_decay=0.005, + ) + + normuon_defaults = dict( + lr=0.023, + momentum=0.95, + beta2=0.95, + weight_decay=1.2, + ) + + self.optimizer = NorMuonAndAdam( + model.named_parameters(), + param_table=self.param_table, + scatter_order=list(self.param_table.keys()), # Dict order defines scatter priority + work_order=self.work_order, + adam_defaults=adam_defaults, + normuon_defaults=normuon_defaults, + ) + + # Split embed from lm_head at 2/3 of training (on an odd step so Adam updates) + self.split_step = training_schedule.split_step + + self.reset() + + def apply_final_ws_ext(self): + self.ws_long = training_schedule.ws_post_yarn_ext + + def get_forward_args(self): + return ForwardScheduleConfig( + mtp_weights = self.mtp_weights, + ws_short = self.ws_short * self.block_size, + ws_long = self.ws_long * self.block_size + ) + + def _is_adam_step(self, step: int): + """Adam params are only updated on odd steps.""" + return step % 2 == 1 + + def get_transition_steps(self): + return [start for start, _ in training_schedule.boundaries[1:]] + + def advance_schedule(self, step: int): + stage, _ = training_schedule.lookup(step) + self.ws_short, new_ws_long = stage.window_sizes + if new_ws_long != self.ws_long: + self.model.yarn.apply(self.ws_long * self.block_size, new_ws_long * self.block_size) + self.model.yarn_paired_head.apply(self.ws_long * self.block_size, new_ws_long * self.block_size) + + new_batch_size = stage.batch_size + if new_batch_size != self.batch_size: + self.train_loader_send_args = (new_batch_size, args.train_max_seq_len, grad_accum_steps) + self.batch_size = new_batch_size + else: + self.train_loader_send_args = None + + self.ws_long = new_ws_long + self.mtp_weights = training_schedule.mtp_weights[step] + + def step_optimizers(self, step: int): + step_lr = training_schedule.get_lr(step) + muon_momentum = 
get_muon_momentum(step) + do_adam = self._is_adam_step(step) + + # Update learning rates and momentum for all params + for param, p_cfg in self.optimizer.param_cfgs.items(): + p_cfg.lr = p_cfg.initial_lr * step_lr + if p_cfg.optim == "normuon": + p_cfg.momentum = muon_momentum + + # Step optimizer with do_adam flag + self.optimizer.step(do_adam=do_adam) + + # At split step: copy lm_head optimizer state to embed and mark as split + if step == self.split_step: + self.optimizer.copy_lm_state_to_embed() + + def reset(self, state=None): + if state is not None: + self.optimizer.load_state_dict(state) + + # Reset NorMuon momentum buffers and split_embed state + self.optimizer.reset() + + stage, _ = training_schedule.lookup(0) + self.ws_short, self.ws_long = stage.window_sizes + self.batch_size = stage.batch_size + self.model.yarn.reset() + self.model.yarn_paired_head.reset() + + def get_state(self): + return copy.deepcopy(self.optimizer.state_dict()) + +# ----------------------------------------------------------------------------- +# int main + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=11, + num_heads=6, + 
head_dim=128, + model_dim=768, + max_seq_len=args.val_batch_size // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.weight.data = m.weight.data.bfloat16() +model.attn_gate_bank.data = model.attn_gate_bank.data.bfloat16() +model.ve_gate_bank.data = model.ve_gate_bank.data.bfloat16() +model.attn_bank.data = model.attn_bank.data.bfloat16() +model.mlp_bank.data = model.mlp_bank.data.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) +training_manager = TrainingManager(model) + +######################################## +# Warmup kernels # +######################################## +print0("Compiling model and warming up kernels (~7 minutes on first execution)", console=True) +# Warmup the training kernels, then re-initialize the state so we aren't cheating +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizer=training_manager.get_state()) # save the initial state +train_loader = distributed_data_generator(args.train_files, TRAINING_STAGES[0].batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + +transition_steps = training_manager.get_transition_steps() +# first few steps plus transitions +warmup_steps = sorted({0, 1, 2} | set(s + offset for s in transition_steps for offset in [-1, 0, 1] if s + offset >= 0)) +print0(f"Sampling steps {warmup_steps} for warmup", console=True) +for step in warmup_steps: + training_manager.advance_schedule(step) + model.eval() + with torch.no_grad(): + inputs, targets, cum_seqlens = next(val_loader) + model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) + model.train() + for idx in range(grad_accum_steps): + send_args = training_manager.train_loader_send_args + inputs, 
targets, cum_seqlens = train_loader.send(send_args) + (model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) * grad_scale).backward() + training_manager.step_optimizers(step) +print0("Resetting Model", console=True) +model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +training_manager.reset(initial_state["optimizer"]) +del val_loader, train_loader, initial_state +model.train() + +######################################## +# Training and validation # +######################################## +train_loader = distributed_data_generator(args.train_files, TRAINING_STAGES[0].batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) + +gc.collect() + +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = training_schedule.total_steps +for step in range(train_steps + 1): + last_step = (step == train_steps) + training_manager.advance_schedule(step) + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + if last_step: + training_manager.apply_final_ws_ext() + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) + val_loss /= val_steps + del val_loader + dist.reduce(val_loss, 0, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", 
console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizer=training_manager.get_state()) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for idx in range(grad_accum_steps): + inputs, targets, cum_seqlens = train_loader.send(training_manager.train_loader_send_args) + (model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) * grad_scale).backward() + training_manager.step_optimizers(step) + + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + + +---------------------------------------- +# triton_kernels.py +---------------------------------------- + +import torch +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * 
num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.jit +def XXT_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = 
m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def XXT(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + # Hardcoded configs based on H100 autotuning + if K == 768: + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 64 + num_stages, num_warps = 4, 4 + else: + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 64, 128, 128 + num_stages, num_warps = 4, 4 + + grid = (batch_size * triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(M, BLOCK_SIZE_N),) + XXT_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + LOWER_UPPER=1, + num_stages=num_stages, + num_warps=num_warps, + ) + return out + +@triton.jit +def ba_plus_cAA_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: 
tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from XXT_kernel, but also loads and adds a block of A + # Performance is slightly slower than XXT_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = 
C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ba_plus_cAA(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + # Hardcoded config based on H100 autotuning (M=768) + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 64 + num_stages, num_warps = 4, 4 + + grid = (batch_size * triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(M, BLOCK_SIZE_N),) + ba_plus_cAA_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + LOWER_UPPER=1, + num_stages=num_stages, + num_warps=num_warps, + ) + return out + +# ----------------------------------------------------------------------------- +# Triton kernel for MLP: relu(x @ W1.T)^2, by @andrewbriand, @jrauvola + +@triton.jit +def linear_relu_square_kernel(a_desc, b_desc, c_desc, 
aux_desc, + M, N, K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, + ): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tile_id_c = start_pid - NUM_SMS + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + + tile_id_c += NUM_SMS + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am_c = pid_m * BLOCK_SIZE_M + offs_bn_c = pid_n * BLOCK_SIZE_N + + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + + c0 = acc0.to(dtype) + if not FORWARD: + c0_pre = aux_desc.load([offs_am_c, offs_bn_c]) + c0 = 2 * c0 * tl.where(c0_pre > 0, c0_pre, 0) + + c_desc.store([offs_am_c, offs_bn_c], c0) + + if FORWARD: + c0_post = tl.maximum(c0, 0) + c0_post = c0_post * c0_post + aux_desc.store([offs_am_c, offs_bn_c], c0_post) + + c1 = acc1.to(dtype) + if not FORWARD: + c1_pre = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + c1 = 2 * c1 * tl.where(c1_pre > 0, c1_pre, 0) + + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + + if FORWARD: + c1_post = tl.maximum(c1, 0) + c1_post = c1_post * c1_post + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1_post) + + +def linear_relu_square(a, b, aux=None): + M, K = a.shape + N, K = b.shape + dtype = 
a.dtype + + c = torch.empty((M, N), device=a.device, dtype=dtype) + + FORWARD = False + if aux is None: + FORWARD = True + aux = torch.empty((M, N), device=a.device, dtype=dtype) + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 256 + BLOCK_SIZE_K = 64 + num_stages = 4 if FORWARD else 3 + num_warps = 8 + + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + + def grid(META): + return (min( + NUM_SMS, + triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), + ), ) + + linear_relu_square_kernel[grid]( + a_desc, b_desc, c_desc, aux_desc, + M, N, K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=1, + NUM_SMS=NUM_SMS, + FORWARD=FORWARD, + num_stages=num_stages, + num_warps=num_warps + ) + + if FORWARD: + return c, aux + else: + return c + +class FusedLinearReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, W1, W2): + pre, post = linear_relu_square(x.view((-1, x.shape[-1])), W1) + x3 = post @ W2 + ctx.save_for_backward(x, W1, W2, pre, post) + return x3.view(x.shape) + + @staticmethod + def backward(ctx, grad_output): + x, W1, W2, pre, post = ctx.saved_tensors + dW2 = post.T @ grad_output + dpre = linear_relu_square(grad_output.view((-1, grad_output.shape[-1])), W2, aux=pre) + dW1 = dpre.T @ x + dx = dpre @ W1 + return dx.view(x.shape), dW1, dW2 + +# ----------------------------------------------------------------------------- +# Fused Softcapped Cross Entropy + + +@triton.jit +def fused_softcapped_entropy_fwd_kernel( + logits_ptr, losses_ptr, lse_ptr, targets_ptr, mtp_weights_ptr, + stride_logits_n, stride_logits_v, + n_rows, n_cols, n_predict, + A, B, C, + 
BLOCK_SIZE: tl.constexpr +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + + max_val = -float('inf') + sum_exp = 0.0 + + for off in range(0, n_cols, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + val = tl.load(logits_row_ptr + cols, mask=mask, other=-float('inf')).to(tl.float32) + z = A * tl.sigmoid((val + B) / C) + z = tl.where(mask, z, -float('inf')) + curr_max = tl.max(z, axis=0) + new_max = tl.maximum(max_val, curr_max) + sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(z - new_max), axis=0) + max_val = new_max + + lse = max_val + tl.log(sum_exp) + tl.store(lse_ptr + row_idx, lse) + + total_loss = 0.0 + for k in range(n_predict): + target_idx = row_idx + k + if target_idx < n_rows: + weight = tl.load(mtp_weights_ptr + k) + if weight > 0: + target = tl.load(targets_ptr + target_idx).to(tl.int32) + if target >= 0 and target < n_cols: + val_target = tl.load(logits_row_ptr + target).to(tl.float32) + z_target = A * tl.sigmoid((val_target + B) / C) + total_loss += weight * (lse - z_target) + + tl.store(losses_ptr + row_idx, total_loss) + +@triton.jit +def fused_softcapped_entropy_bwd_kernel( + grad_input_ptr, grad_output_ptr, lse_ptr, logits_ptr, targets_ptr, mtp_weights_ptr, + stride_logits_n, stride_logits_v, stride_grad_n, stride_grad_v, + n_rows, n_cols, n_predict, + A, B, C, + BLOCK_SIZE: tl.constexpr +): + row_idx = tl.program_id(0).to(tl.int64) + + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + grad_row_ptr = grad_input_ptr + row_idx * stride_grad_n + + lse = tl.load(lse_ptr + row_idx) + grad_loss = tl.load(grad_output_ptr + row_idx) + + S_w = 0.0 + for k in range(n_predict): + if row_idx + k < n_rows: + S_w += tl.load(mtp_weights_ptr + k) + + for off in range(0, n_cols, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + val = tl.load(logits_row_ptr + cols, mask=mask, other=0.0).to(tl.float32) + u = (val + B) / C + sigmoid_u 
= tl.sigmoid(u) + z = A * sigmoid_u + p = tl.exp(z - lse) + + term1 = S_w * p + term2 = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for k in range(n_predict): + if row_idx + k < n_rows: + target = tl.load(targets_ptr + row_idx + k).to(tl.int32) + weight = tl.load(mtp_weights_ptr + k) + term2 += tl.where(cols == target, weight, 0.0) + + grad_z = grad_loss * (term1 - term2) + dz_dx = (1.0 / C) * z * (1.0 - sigmoid_u) + grad_x = grad_z * dz_dx + tl.store(grad_row_ptr + cols, grad_x.to(tl.bfloat16), mask=mask) + +class FusedSoftcappedCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, logits, targets, mtp_weights, A=23.0, B=5.0, C=7.5): + n_rows, n_cols = logits.shape + if mtp_weights is None: + mtp_weights = torch.tensor([1.0], device=logits.device, dtype=torch.float32) + n_predict = mtp_weights.shape[0] + + losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + lse = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + + logits = logits.contiguous() + targets = targets.contiguous() + mtp_weights = mtp_weights.contiguous() + + grid = (n_rows,) + fused_softcapped_entropy_fwd_kernel[grid]( + logits, losses, lse, targets, mtp_weights, + logits.stride(0), logits.stride(1), + n_rows, n_cols, n_predict, + A, B, C, + BLOCK_SIZE=1024, + num_warps=8, + num_stages=4 + ) + + ctx.save_for_backward(logits, targets, mtp_weights, lse) + ctx.params = (A, B, C) + return losses + + @staticmethod + def backward(ctx, grad_output): + logits, targets, mtp_weights, lse = ctx.saved_tensors + A, B, C = ctx.params + n_rows, n_cols = logits.shape + n_predict = mtp_weights.shape[0] + + grad_input = torch.empty((n_rows, n_cols), dtype=torch.bfloat16, device=logits.device) + grad_output = grad_output.contiguous() + + grid = (n_rows,) + fused_softcapped_entropy_bwd_kernel[grid]( + grad_input, grad_output, lse, logits, targets, mtp_weights, + logits.stride(0), logits.stride(1), grad_input.stride(0), grad_input.stride(1), + n_rows, n_cols, 
n_predict, + A, B, C, + BLOCK_SIZE=1024, + num_warps=8, + num_stages=4 + ) + return grad_input, None, None, None, None, None + +==================================================================================================== +Running Python 3.12.7 (main, Jan 31 2026, 04:21:49) [GCC 13.2.0] +Running PyTorch 2.10.0.dev20251210+cu126 compiled for CUDA 12.6 +Running Triton version 3.6.0 +Sun Feb 1 10:30:04 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 29C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 30C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:71:00.0 Off | 0 | +| N/A 32C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:79:00.0 Off | 0 | +| N/A 30C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:7F:00.0 Off | 0 | +| N/A 28C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 31C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 30C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 94 C /usr/local/bin/python 1510MiB | +| 1 N/A N/A 95 C /usr/local/bin/python 1510MiB | +| 2 N/A N/A 96 C /usr/local/bin/python 1510MiB | +| 3 N/A N/A 97 C /usr/local/bin/python 1510MiB | +| 4 N/A N/A 98 C /usr/local/bin/python 1510MiB | +| 5 N/A N/A 99 C /usr/local/bin/python 1510MiB | +| 6 N/A N/A 100 C /usr/local/bin/python 1510MiB | +| 7 N/A N/A 101 C /usr/local/bin/python 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +Compiling model and warming up kernels (~7 minutes on first execution) +Sampling steps [0, 1, 2, 504, 505, 506, 1009, 1010, 1011, 1514, 1515, 1516] for warmup +Resetting Model +step:0/1555 val_loss:10.8316 train_time:0ms step_avg:0.03ms +step:1/1555 train_time:87ms step_avg:87.43ms +step:2/1555 train_time:109ms step_avg:54.31ms 
+step:3/1555 train_time:128ms step_avg:42.67ms +step:4/1555 train_time:154ms step_avg:38.46ms +step:5/1555 train_time:185ms step_avg:36.94ms +step:6/1555 train_time:222ms step_avg:37.00ms +step:7/1555 train_time:253ms step_avg:36.14ms +step:8/1555 train_time:290ms step_avg:36.30ms +step:9/1555 train_time:322ms step_avg:35.76ms +step:10/1555 train_time:360ms step_avg:35.97ms +step:11/1555 train_time:390ms step_avg:35.50ms +step:12/1555 train_time:428ms step_avg:35.70ms +step:13/1555 train_time:459ms step_avg:35.34ms +step:14/1555 train_time:497ms step_avg:35.50ms +step:15/1555 train_time:528ms step_avg:35.21ms +step:16/1555 train_time:566ms step_avg:35.35ms +step:17/1555 train_time:597ms step_avg:35.09ms +step:18/1555 train_time:634ms step_avg:35.23ms +step:19/1555 train_time:665ms step_avg:35.00ms +step:20/1555 train_time:702ms step_avg:35.12ms +step:21/1555 train_time:734ms step_avg:34.94ms +step:22/1555 train_time:771ms step_avg:35.05ms +step:23/1555 train_time:802ms step_avg:34.88ms +step:24/1555 train_time:840ms step_avg:34.99ms +step:25/1555 train_time:871ms step_avg:34.83ms +step:26/1555 train_time:908ms step_avg:34.94ms +step:27/1555 train_time:939ms step_avg:34.79ms +step:28/1555 train_time:977ms step_avg:34.89ms +step:29/1555 train_time:1008ms step_avg:34.75ms +step:30/1555 train_time:1046ms step_avg:34.87ms +step:31/1555 train_time:1077ms step_avg:34.76ms +step:32/1555 train_time:1115ms step_avg:34.85ms +step:33/1555 train_time:1146ms step_avg:34.74ms +step:34/1555 train_time:1184ms step_avg:34.84ms +step:35/1555 train_time:1216ms step_avg:34.74ms +step:36/1555 train_time:1254ms step_avg:34.83ms +step:37/1555 train_time:1285ms step_avg:34.72ms +step:38/1555 train_time:1322ms step_avg:34.79ms +step:39/1555 train_time:1353ms step_avg:34.70ms +step:40/1555 train_time:1391ms step_avg:34.78ms +step:41/1555 train_time:1422ms step_avg:34.68ms +step:42/1555 train_time:1460ms step_avg:34.76ms +step:43/1555 train_time:1491ms step_avg:34.67ms +step:44/1555 
train_time:1529ms step_avg:34.75ms +step:45/1555 train_time:1559ms step_avg:34.65ms +step:46/1555 train_time:1597ms step_avg:34.72ms +step:47/1555 train_time:1628ms step_avg:34.64ms +step:48/1555 train_time:1666ms step_avg:34.70ms +step:49/1555 train_time:1697ms step_avg:34.62ms +step:50/1555 train_time:1734ms step_avg:34.69ms +step:51/1555 train_time:1765ms step_avg:34.61ms +step:52/1555 train_time:1803ms step_avg:34.68ms +step:53/1555 train_time:1834ms step_avg:34.61ms +step:54/1555 train_time:1871ms step_avg:34.66ms +step:55/1555 train_time:1903ms step_avg:34.59ms +step:56/1555 train_time:1940ms step_avg:34.65ms +step:57/1555 train_time:1971ms step_avg:34.58ms +step:58/1555 train_time:2009ms step_avg:34.63ms +step:59/1555 train_time:2040ms step_avg:34.57ms +step:60/1555 train_time:2077ms step_avg:34.62ms +step:61/1555 train_time:2108ms step_avg:34.56ms +step:62/1555 train_time:2146ms step_avg:34.61ms +step:63/1555 train_time:2177ms step_avg:34.56ms +step:64/1555 train_time:2215ms step_avg:34.61ms +step:65/1555 train_time:2246ms step_avg:34.55ms +step:66/1555 train_time:2284ms step_avg:34.60ms +step:67/1555 train_time:2315ms step_avg:34.55ms +step:68/1555 train_time:2353ms step_avg:34.60ms +step:69/1555 train_time:2383ms step_avg:34.54ms +step:70/1555 train_time:2421ms step_avg:34.59ms +step:71/1555 train_time:2452ms step_avg:34.54ms +step:72/1555 train_time:2490ms step_avg:34.58ms +step:73/1555 train_time:2521ms step_avg:34.54ms +step:74/1555 train_time:2560ms step_avg:34.59ms +step:75/1555 train_time:2591ms step_avg:34.54ms +step:76/1555 train_time:2629ms step_avg:34.59ms +step:77/1555 train_time:2660ms step_avg:34.54ms +step:78/1555 train_time:2698ms step_avg:34.59ms +step:79/1555 train_time:2729ms step_avg:34.54ms +step:80/1555 train_time:2767ms step_avg:34.58ms +step:81/1555 train_time:2798ms step_avg:34.54ms +step:82/1555 train_time:2836ms step_avg:34.58ms +step:83/1555 train_time:2867ms step_avg:34.54ms +step:84/1555 train_time:2905ms step_avg:34.58ms 
+step:85/1555 train_time:2935ms step_avg:34.53ms +step:86/1555 train_time:2973ms step_avg:34.57ms +step:87/1555 train_time:3004ms step_avg:34.52ms +step:88/1555 train_time:3041ms step_avg:34.56ms +step:89/1555 train_time:3072ms step_avg:34.51ms +step:90/1555 train_time:3109ms step_avg:34.55ms +step:91/1555 train_time:3141ms step_avg:34.51ms +step:92/1555 train_time:3178ms step_avg:34.54ms +step:93/1555 train_time:3209ms step_avg:34.51ms +step:94/1555 train_time:3247ms step_avg:34.54ms +step:95/1555 train_time:3278ms step_avg:34.51ms +step:96/1555 train_time:3316ms step_avg:34.54ms +step:97/1555 train_time:3347ms step_avg:34.50ms +step:98/1555 train_time:3385ms step_avg:34.54ms +step:99/1555 train_time:3416ms step_avg:34.50ms +step:100/1555 train_time:3453ms step_avg:34.53ms +step:101/1555 train_time:3484ms step_avg:34.50ms +step:102/1555 train_time:3522ms step_avg:34.53ms +step:103/1555 train_time:3553ms step_avg:34.49ms +step:104/1555 train_time:3590ms step_avg:34.52ms +step:105/1555 train_time:3621ms step_avg:34.49ms +step:106/1555 train_time:3659ms step_avg:34.52ms +step:107/1555 train_time:3690ms step_avg:34.49ms +step:108/1555 train_time:3727ms step_avg:34.51ms +step:109/1555 train_time:3759ms step_avg:34.48ms +step:110/1555 train_time:3797ms step_avg:34.51ms +step:111/1555 train_time:3827ms step_avg:34.48ms +step:112/1555 train_time:3865ms step_avg:34.51ms +step:113/1555 train_time:3896ms step_avg:34.48ms +step:114/1555 train_time:3935ms step_avg:34.51ms +step:115/1555 train_time:3966ms step_avg:34.48ms +step:116/1555 train_time:4004ms step_avg:34.51ms +step:117/1555 train_time:4035ms step_avg:34.48ms +step:118/1555 train_time:4072ms step_avg:34.51ms +step:119/1555 train_time:4103ms step_avg:34.48ms +step:120/1555 train_time:4141ms step_avg:34.51ms +step:121/1555 train_time:4171ms step_avg:34.47ms +step:122/1555 train_time:4209ms step_avg:34.50ms +step:123/1555 train_time:4240ms step_avg:34.47ms +step:124/1555 train_time:4277ms step_avg:34.49ms +step:125/1555 
train_time:4308ms step_avg:34.46ms +step:126/1555 train_time:4346ms step_avg:34.49ms +step:127/1555 train_time:4377ms step_avg:34.46ms +step:128/1555 train_time:4414ms step_avg:34.49ms +step:129/1555 train_time:4445ms step_avg:34.46ms +step:130/1555 train_time:4483ms step_avg:34.49ms +step:131/1555 train_time:4514ms step_avg:34.45ms +step:132/1555 train_time:4551ms step_avg:34.48ms +step:133/1555 train_time:4582ms step_avg:34.45ms +step:134/1555 train_time:4620ms step_avg:34.48ms +step:135/1555 train_time:4651ms step_avg:34.45ms +step:136/1555 train_time:4688ms step_avg:34.47ms +step:137/1555 train_time:4720ms step_avg:34.45ms +step:138/1555 train_time:4758ms step_avg:34.48ms +step:139/1555 train_time:4789ms step_avg:34.45ms +step:140/1555 train_time:4827ms step_avg:34.48ms +step:141/1555 train_time:4858ms step_avg:34.45ms +step:142/1555 train_time:4895ms step_avg:34.48ms +step:143/1555 train_time:4927ms step_avg:34.45ms +step:144/1555 train_time:4965ms step_avg:34.48ms +step:145/1555 train_time:4995ms step_avg:34.45ms +step:146/1555 train_time:5033ms step_avg:34.47ms +step:147/1555 train_time:5064ms step_avg:34.45ms +step:148/1555 train_time:5103ms step_avg:34.48ms +step:149/1555 train_time:5133ms step_avg:34.45ms +step:150/1555 train_time:5171ms step_avg:34.47ms +step:151/1555 train_time:5202ms step_avg:34.45ms +step:152/1555 train_time:5240ms step_avg:34.47ms +step:153/1555 train_time:5271ms step_avg:34.45ms +step:154/1555 train_time:5308ms step_avg:34.47ms +step:155/1555 train_time:5339ms step_avg:34.45ms +step:156/1555 train_time:5377ms step_avg:34.47ms +step:157/1555 train_time:5408ms step_avg:34.44ms +step:158/1555 train_time:5447ms step_avg:34.47ms +step:159/1555 train_time:5476ms step_avg:34.44ms +step:160/1555 train_time:5513ms step_avg:34.46ms +step:161/1555 train_time:5544ms step_avg:34.43ms +step:162/1555 train_time:5582ms step_avg:34.45ms +step:163/1555 train_time:5613ms step_avg:34.43ms +step:164/1555 train_time:5651ms step_avg:34.46ms +step:165/1555 
train_time:5682ms step_avg:34.43ms +step:166/1555 train_time:5719ms step_avg:34.45ms +step:167/1555 train_time:5750ms step_avg:34.43ms +step:168/1555 train_time:5788ms step_avg:34.45ms +step:169/1555 train_time:5818ms step_avg:34.43ms +step:170/1555 train_time:5856ms step_avg:34.44ms +step:171/1555 train_time:5886ms step_avg:34.42ms +step:172/1555 train_time:5924ms step_avg:34.44ms +step:173/1555 train_time:5955ms step_avg:34.42ms +step:174/1555 train_time:5992ms step_avg:34.44ms +step:175/1555 train_time:6023ms step_avg:34.42ms +step:176/1555 train_time:6061ms step_avg:34.44ms +step:177/1555 train_time:6092ms step_avg:34.42ms +step:178/1555 train_time:6130ms step_avg:34.44ms +step:179/1555 train_time:6161ms step_avg:34.42ms +step:180/1555 train_time:6198ms step_avg:34.43ms +step:181/1555 train_time:6229ms step_avg:34.42ms +step:182/1555 train_time:6266ms step_avg:34.43ms +step:183/1555 train_time:6297ms step_avg:34.41ms +step:184/1555 train_time:6335ms step_avg:34.43ms +step:185/1555 train_time:6366ms step_avg:34.41ms +step:186/1555 train_time:6404ms step_avg:34.43ms +step:187/1555 train_time:6435ms step_avg:34.41ms +step:188/1555 train_time:6473ms step_avg:34.43ms +step:189/1555 train_time:6503ms step_avg:34.41ms +step:190/1555 train_time:6541ms step_avg:34.43ms +step:191/1555 train_time:6572ms step_avg:34.41ms +step:192/1555 train_time:6610ms step_avg:34.43ms +step:193/1555 train_time:6641ms step_avg:34.41ms +step:194/1555 train_time:6679ms step_avg:34.43ms +step:195/1555 train_time:6709ms step_avg:34.41ms +step:196/1555 train_time:6747ms step_avg:34.42ms +step:197/1555 train_time:6777ms step_avg:34.40ms +step:198/1555 train_time:6815ms step_avg:34.42ms +step:199/1555 train_time:6846ms step_avg:34.40ms +step:200/1555 train_time:6883ms step_avg:34.42ms +step:201/1555 train_time:6914ms step_avg:34.40ms +step:202/1555 train_time:6952ms step_avg:34.41ms +step:203/1555 train_time:6982ms step_avg:34.40ms +step:204/1555 train_time:7020ms step_avg:34.41ms +step:205/1555 
train_time:7051ms step_avg:34.39ms +step:206/1555 train_time:7089ms step_avg:34.41ms +step:207/1555 train_time:7120ms step_avg:34.39ms +step:208/1555 train_time:7157ms step_avg:34.41ms +step:209/1555 train_time:7188ms step_avg:34.39ms +step:210/1555 train_time:7226ms step_avg:34.41ms +step:211/1555 train_time:7257ms step_avg:34.39ms +step:212/1555 train_time:7295ms step_avg:34.41ms +step:213/1555 train_time:7326ms step_avg:34.39ms +step:214/1555 train_time:7363ms step_avg:34.41ms +step:215/1555 train_time:7394ms step_avg:34.39ms +step:216/1555 train_time:7432ms step_avg:34.41ms +step:217/1555 train_time:7463ms step_avg:34.39ms +step:218/1555 train_time:7500ms step_avg:34.40ms +step:219/1555 train_time:7531ms step_avg:34.39ms +step:220/1555 train_time:7569ms step_avg:34.41ms +step:221/1555 train_time:7600ms step_avg:34.39ms +step:222/1555 train_time:7637ms step_avg:34.40ms +step:223/1555 train_time:7668ms step_avg:34.39ms +step:224/1555 train_time:7706ms step_avg:34.40ms +step:225/1555 train_time:7737ms step_avg:34.39ms +step:226/1555 train_time:7775ms step_avg:34.40ms +step:227/1555 train_time:7806ms step_avg:34.39ms +step:228/1555 train_time:7843ms step_avg:34.40ms +step:229/1555 train_time:7874ms step_avg:34.38ms +step:230/1555 train_time:7912ms step_avg:34.40ms +step:231/1555 train_time:7943ms step_avg:34.38ms +step:232/1555 train_time:7980ms step_avg:34.40ms +step:233/1555 train_time:8011ms step_avg:34.38ms +step:234/1555 train_time:8049ms step_avg:34.40ms +step:235/1555 train_time:8080ms step_avg:34.38ms +step:236/1555 train_time:8117ms step_avg:34.40ms +step:237/1555 train_time:8148ms step_avg:34.38ms +step:238/1555 train_time:8186ms step_avg:34.40ms +step:239/1555 train_time:8218ms step_avg:34.38ms +step:240/1555 train_time:8255ms step_avg:34.40ms +step:241/1555 train_time:8286ms step_avg:34.38ms +step:242/1555 train_time:8324ms step_avg:34.40ms +step:243/1555 train_time:8355ms step_avg:34.38ms +step:244/1555 train_time:8393ms step_avg:34.40ms +step:245/1555 
train_time:8423ms step_avg:34.38ms +step:246/1555 train_time:8461ms step_avg:34.39ms +step:247/1555 train_time:8492ms step_avg:34.38ms +step:248/1555 train_time:8529ms step_avg:34.39ms +step:249/1555 train_time:8560ms step_avg:34.38ms +step:250/1555 train_time:8597ms step_avg:34.39ms +step:250/1555 val_loss:4.5653 train_time:8647ms step_avg:34.59ms +step:251/1555 train_time:8668ms step_avg:34.53ms +step:252/1555 train_time:8695ms step_avg:34.50ms +step:253/1555 train_time:8714ms step_avg:34.44ms +step:254/1555 train_time:8737ms step_avg:34.40ms +step:255/1555 train_time:8769ms step_avg:34.39ms +step:256/1555 train_time:8808ms step_avg:34.41ms +step:257/1555 train_time:8840ms step_avg:34.40ms +step:258/1555 train_time:8878ms step_avg:34.41ms +step:259/1555 train_time:8909ms step_avg:34.40ms +step:260/1555 train_time:8947ms step_avg:34.41ms +step:261/1555 train_time:8977ms step_avg:34.40ms +step:262/1555 train_time:9015ms step_avg:34.41ms +step:263/1555 train_time:9046ms step_avg:34.39ms +step:264/1555 train_time:9083ms step_avg:34.41ms +step:265/1555 train_time:9114ms step_avg:34.39ms +step:266/1555 train_time:9152ms step_avg:34.41ms +step:267/1555 train_time:9182ms step_avg:34.39ms +step:268/1555 train_time:9220ms step_avg:34.40ms +step:269/1555 train_time:9250ms step_avg:34.39ms +step:270/1555 train_time:9288ms step_avg:34.40ms +step:271/1555 train_time:9319ms step_avg:34.39ms +step:272/1555 train_time:9356ms step_avg:34.40ms +step:273/1555 train_time:9387ms step_avg:34.38ms +step:274/1555 train_time:9425ms step_avg:34.40ms +step:275/1555 train_time:9455ms step_avg:34.38ms +step:276/1555 train_time:9492ms step_avg:34.39ms +step:277/1555 train_time:9523ms step_avg:34.38ms +step:278/1555 train_time:9560ms step_avg:34.39ms +step:279/1555 train_time:9591ms step_avg:34.38ms +step:280/1555 train_time:9630ms step_avg:34.39ms +step:281/1555 train_time:9660ms step_avg:34.38ms +step:282/1555 train_time:9697ms step_avg:34.39ms +step:283/1555 train_time:9729ms 
step_avg:34.38ms +step:284/1555 train_time:9766ms step_avg:34.39ms +step:285/1555 train_time:9797ms step_avg:34.38ms +step:286/1555 train_time:9835ms step_avg:34.39ms +step:287/1555 train_time:9866ms step_avg:34.38ms +step:288/1555 train_time:9904ms step_avg:34.39ms +step:289/1555 train_time:9935ms step_avg:34.38ms +step:290/1555 train_time:9972ms step_avg:34.39ms +step:291/1555 train_time:10003ms step_avg:34.37ms +step:292/1555 train_time:10041ms step_avg:34.39ms +step:293/1555 train_time:10072ms step_avg:34.38ms +step:294/1555 train_time:10111ms step_avg:34.39ms +step:295/1555 train_time:10140ms step_avg:34.37ms +step:296/1555 train_time:10178ms step_avg:34.38ms +step:297/1555 train_time:10208ms step_avg:34.37ms +step:298/1555 train_time:10246ms step_avg:34.38ms +step:299/1555 train_time:10277ms step_avg:34.37ms +step:300/1555 train_time:10314ms step_avg:34.38ms +step:301/1555 train_time:10345ms step_avg:34.37ms +step:302/1555 train_time:10382ms step_avg:34.38ms +step:303/1555 train_time:10413ms step_avg:34.37ms +step:304/1555 train_time:10451ms step_avg:34.38ms +step:305/1555 train_time:10481ms step_avg:34.37ms +step:306/1555 train_time:10519ms step_avg:34.38ms +step:307/1555 train_time:10550ms step_avg:34.37ms +step:308/1555 train_time:10588ms step_avg:34.38ms +step:309/1555 train_time:10619ms step_avg:34.37ms +step:310/1555 train_time:10657ms step_avg:34.38ms +step:311/1555 train_time:10688ms step_avg:34.37ms +step:312/1555 train_time:10726ms step_avg:34.38ms +step:313/1555 train_time:10757ms step_avg:34.37ms +step:314/1555 train_time:10794ms step_avg:34.38ms +step:315/1555 train_time:10825ms step_avg:34.37ms +step:316/1555 train_time:10863ms step_avg:34.38ms +step:317/1555 train_time:10894ms step_avg:34.37ms +step:318/1555 train_time:10932ms step_avg:34.38ms +step:319/1555 train_time:10963ms step_avg:34.37ms +step:320/1555 train_time:11001ms step_avg:34.38ms +step:321/1555 train_time:11031ms step_avg:34.37ms +step:322/1555 train_time:11070ms step_avg:34.38ms 
+step:323/1555 train_time:11100ms step_avg:34.37ms +step:324/1555 train_time:11138ms step_avg:34.38ms +step:325/1555 train_time:11169ms step_avg:34.37ms +step:326/1555 train_time:11206ms step_avg:34.38ms +step:327/1555 train_time:11237ms step_avg:34.36ms +step:328/1555 train_time:11275ms step_avg:34.37ms +step:329/1555 train_time:11305ms step_avg:34.36ms +step:330/1555 train_time:11343ms step_avg:34.37ms +step:331/1555 train_time:11374ms step_avg:34.36ms +step:332/1555 train_time:11411ms step_avg:34.37ms +step:333/1555 train_time:11442ms step_avg:34.36ms +step:334/1555 train_time:11480ms step_avg:34.37ms +step:335/1555 train_time:11511ms step_avg:34.36ms +step:336/1555 train_time:11550ms step_avg:34.37ms +step:337/1555 train_time:11580ms step_avg:34.36ms +step:338/1555 train_time:11618ms step_avg:34.37ms +step:339/1555 train_time:11648ms step_avg:34.36ms +step:340/1555 train_time:11686ms step_avg:34.37ms +step:341/1555 train_time:11717ms step_avg:34.36ms +step:342/1555 train_time:11754ms step_avg:34.37ms +step:343/1555 train_time:11785ms step_avg:34.36ms +step:344/1555 train_time:11822ms step_avg:34.37ms +step:345/1555 train_time:11853ms step_avg:34.36ms +step:346/1555 train_time:11891ms step_avg:34.37ms +step:347/1555 train_time:11922ms step_avg:34.36ms +step:348/1555 train_time:11959ms step_avg:34.37ms +step:349/1555 train_time:11991ms step_avg:34.36ms +step:350/1555 train_time:12028ms step_avg:34.37ms +step:351/1555 train_time:12059ms step_avg:34.36ms +step:352/1555 train_time:12096ms step_avg:34.36ms +step:353/1555 train_time:12127ms step_avg:34.35ms +step:354/1555 train_time:12165ms step_avg:34.36ms +step:355/1555 train_time:12196ms step_avg:34.35ms +step:356/1555 train_time:12234ms step_avg:34.36ms +step:357/1555 train_time:12264ms step_avg:34.35ms +step:358/1555 train_time:12302ms step_avg:34.36ms +step:359/1555 train_time:12333ms step_avg:34.35ms +step:360/1555 train_time:12370ms step_avg:34.36ms +step:361/1555 train_time:12401ms step_avg:34.35ms 
+step:362/1555 train_time:12439ms step_avg:34.36ms +step:363/1555 train_time:12469ms step_avg:34.35ms +step:364/1555 train_time:12507ms step_avg:34.36ms +step:365/1555 train_time:12538ms step_avg:34.35ms +step:366/1555 train_time:12575ms step_avg:34.36ms +step:367/1555 train_time:12606ms step_avg:34.35ms +step:368/1555 train_time:12644ms step_avg:34.36ms +step:369/1555 train_time:12675ms step_avg:34.35ms +step:370/1555 train_time:12712ms step_avg:34.36ms +step:371/1555 train_time:12743ms step_avg:34.35ms +step:372/1555 train_time:12781ms step_avg:34.36ms +step:373/1555 train_time:12812ms step_avg:34.35ms +step:374/1555 train_time:12850ms step_avg:34.36ms +step:375/1555 train_time:12880ms step_avg:34.35ms +step:376/1555 train_time:12918ms step_avg:34.36ms +step:377/1555 train_time:12949ms step_avg:34.35ms +step:378/1555 train_time:12987ms step_avg:34.36ms +step:379/1555 train_time:13018ms step_avg:34.35ms +step:380/1555 train_time:13056ms step_avg:34.36ms +step:381/1555 train_time:13086ms step_avg:34.35ms +step:382/1555 train_time:13123ms step_avg:34.35ms +step:383/1555 train_time:13154ms step_avg:34.35ms +step:384/1555 train_time:13193ms step_avg:34.36ms +step:385/1555 train_time:13223ms step_avg:34.35ms +step:386/1555 train_time:13261ms step_avg:34.35ms +step:387/1555 train_time:13292ms step_avg:34.35ms +step:388/1555 train_time:13329ms step_avg:34.35ms +step:389/1555 train_time:13360ms step_avg:34.34ms +step:390/1555 train_time:13397ms step_avg:34.35ms +step:391/1555 train_time:13428ms step_avg:34.34ms +step:392/1555 train_time:13466ms step_avg:34.35ms +step:393/1555 train_time:13497ms step_avg:34.34ms +step:394/1555 train_time:13534ms step_avg:34.35ms +step:395/1555 train_time:13565ms step_avg:34.34ms +step:396/1555 train_time:13603ms step_avg:34.35ms +step:397/1555 train_time:13634ms step_avg:34.34ms +step:398/1555 train_time:13671ms step_avg:34.35ms +step:399/1555 train_time:13701ms step_avg:34.34ms +step:400/1555 train_time:13739ms step_avg:34.35ms 
+step:401/1555 train_time:13770ms step_avg:34.34ms +step:402/1555 train_time:13807ms step_avg:34.35ms +step:403/1555 train_time:13838ms step_avg:34.34ms +step:404/1555 train_time:13875ms step_avg:34.34ms +step:405/1555 train_time:13906ms step_avg:34.34ms +step:406/1555 train_time:13944ms step_avg:34.34ms +step:407/1555 train_time:13975ms step_avg:34.34ms +step:408/1555 train_time:14013ms step_avg:34.34ms +step:409/1555 train_time:14044ms step_avg:34.34ms +step:410/1555 train_time:14081ms step_avg:34.34ms +step:411/1555 train_time:14112ms step_avg:34.34ms +step:412/1555 train_time:14150ms step_avg:34.35ms +step:413/1555 train_time:14181ms step_avg:34.34ms +step:414/1555 train_time:14218ms step_avg:34.34ms +step:415/1555 train_time:14250ms step_avg:34.34ms +step:416/1555 train_time:14288ms step_avg:34.35ms +step:417/1555 train_time:14319ms step_avg:34.34ms +step:418/1555 train_time:14357ms step_avg:34.35ms +step:419/1555 train_time:14388ms step_avg:34.34ms +step:420/1555 train_time:14426ms step_avg:34.35ms +step:421/1555 train_time:14456ms step_avg:34.34ms +step:422/1555 train_time:14494ms step_avg:34.35ms +step:423/1555 train_time:14525ms step_avg:34.34ms +step:424/1555 train_time:14563ms step_avg:34.35ms +step:425/1555 train_time:14594ms step_avg:34.34ms +step:426/1555 train_time:14632ms step_avg:34.35ms +step:427/1555 train_time:14663ms step_avg:34.34ms +step:428/1555 train_time:14700ms step_avg:34.35ms +step:429/1555 train_time:14731ms step_avg:34.34ms +step:430/1555 train_time:14768ms step_avg:34.35ms +step:431/1555 train_time:14799ms step_avg:34.34ms +step:432/1555 train_time:14837ms step_avg:34.34ms +step:433/1555 train_time:14867ms step_avg:34.34ms +step:434/1555 train_time:14905ms step_avg:34.34ms +step:435/1555 train_time:14936ms step_avg:34.34ms +step:436/1555 train_time:14973ms step_avg:34.34ms +step:437/1555 train_time:15004ms step_avg:34.33ms +step:438/1555 train_time:15042ms step_avg:34.34ms +step:439/1555 train_time:15073ms step_avg:34.33ms 
+step:440/1555 train_time:15111ms step_avg:34.34ms +step:441/1555 train_time:15141ms step_avg:34.33ms +step:442/1555 train_time:15179ms step_avg:34.34ms +step:443/1555 train_time:15210ms step_avg:34.33ms +step:444/1555 train_time:15247ms step_avg:34.34ms +step:445/1555 train_time:15278ms step_avg:34.33ms +step:446/1555 train_time:15315ms step_avg:34.34ms +step:447/1555 train_time:15345ms step_avg:34.33ms +step:448/1555 train_time:15382ms step_avg:34.34ms +step:449/1555 train_time:15414ms step_avg:34.33ms +step:450/1555 train_time:15451ms step_avg:34.34ms +step:451/1555 train_time:15482ms step_avg:34.33ms +step:452/1555 train_time:15519ms step_avg:34.33ms +step:453/1555 train_time:15551ms step_avg:34.33ms +step:454/1555 train_time:15589ms step_avg:34.34ms +step:455/1555 train_time:15619ms step_avg:34.33ms +step:456/1555 train_time:15657ms step_avg:34.33ms +step:457/1555 train_time:15687ms step_avg:34.33ms +step:458/1555 train_time:15725ms step_avg:34.33ms +step:459/1555 train_time:15755ms step_avg:34.32ms +step:460/1555 train_time:15793ms step_avg:34.33ms +step:461/1555 train_time:15824ms step_avg:34.33ms +step:462/1555 train_time:15862ms step_avg:34.33ms +step:463/1555 train_time:15892ms step_avg:34.32ms +step:464/1555 train_time:15930ms step_avg:34.33ms +step:465/1555 train_time:15961ms step_avg:34.33ms +step:466/1555 train_time:15999ms step_avg:34.33ms +step:467/1555 train_time:16030ms step_avg:34.33ms +step:468/1555 train_time:16068ms step_avg:34.33ms +step:469/1555 train_time:16099ms step_avg:34.33ms +step:470/1555 train_time:16137ms step_avg:34.33ms +step:471/1555 train_time:16168ms step_avg:34.33ms +step:472/1555 train_time:16206ms step_avg:34.33ms +step:473/1555 train_time:16237ms step_avg:34.33ms +step:474/1555 train_time:16274ms step_avg:34.33ms +step:475/1555 train_time:16305ms step_avg:34.33ms +step:476/1555 train_time:16343ms step_avg:34.33ms +step:477/1555 train_time:16374ms step_avg:34.33ms +step:478/1555 train_time:16411ms step_avg:34.33ms 
+step:479/1555 train_time:16442ms step_avg:34.33ms +step:480/1555 train_time:16480ms step_avg:34.33ms +step:481/1555 train_time:16511ms step_avg:34.33ms +step:482/1555 train_time:16549ms step_avg:34.34ms +step:483/1555 train_time:16580ms step_avg:34.33ms +step:484/1555 train_time:16618ms step_avg:34.33ms +step:485/1555 train_time:16648ms step_avg:34.33ms +step:486/1555 train_time:16686ms step_avg:34.33ms +step:487/1555 train_time:16716ms step_avg:34.33ms +step:488/1555 train_time:16754ms step_avg:34.33ms +step:489/1555 train_time:16785ms step_avg:34.32ms +step:490/1555 train_time:16822ms step_avg:34.33ms +step:491/1555 train_time:16853ms step_avg:34.32ms +step:492/1555 train_time:16890ms step_avg:34.33ms +step:493/1555 train_time:16922ms step_avg:34.32ms +step:494/1555 train_time:16959ms step_avg:34.33ms +step:495/1555 train_time:16990ms step_avg:34.32ms +step:496/1555 train_time:17028ms step_avg:34.33ms +step:497/1555 train_time:17058ms step_avg:34.32ms +step:498/1555 train_time:17096ms step_avg:34.33ms +step:499/1555 train_time:17127ms step_avg:34.32ms +step:500/1555 train_time:17165ms step_avg:34.33ms +step:500/1555 val_loss:4.2334 train_time:17214ms step_avg:34.43ms +step:501/1555 train_time:17234ms step_avg:34.40ms +step:502/1555 train_time:17255ms step_avg:34.37ms +step:503/1555 train_time:17274ms step_avg:34.34ms +step:504/1555 train_time:17303ms step_avg:34.33ms +step:505/1555 train_time:17337ms step_avg:34.33ms +step:506/1555 train_time:17379ms step_avg:34.35ms +step:507/1555 train_time:17434ms step_avg:34.39ms +step:508/1555 train_time:17498ms step_avg:34.45ms +step:509/1555 train_time:17556ms step_avg:34.49ms +step:510/1555 train_time:17619ms step_avg:34.55ms +step:511/1555 train_time:17676ms step_avg:34.59ms +step:512/1555 train_time:17739ms step_avg:34.65ms +step:513/1555 train_time:17796ms step_avg:34.69ms +step:514/1555 train_time:17858ms step_avg:34.74ms +step:515/1555 train_time:17915ms step_avg:34.79ms +step:516/1555 train_time:17978ms 
step_avg:34.84ms +step:517/1555 train_time:18036ms step_avg:34.89ms +step:518/1555 train_time:18100ms step_avg:34.94ms +step:519/1555 train_time:18157ms step_avg:34.98ms +step:520/1555 train_time:18224ms step_avg:35.05ms +step:521/1555 train_time:18284ms step_avg:35.09ms +step:522/1555 train_time:18351ms step_avg:35.15ms +step:523/1555 train_time:18409ms step_avg:35.20ms +step:524/1555 train_time:18475ms step_avg:35.26ms +step:525/1555 train_time:18533ms step_avg:35.30ms +step:526/1555 train_time:18598ms step_avg:35.36ms +step:527/1555 train_time:18655ms step_avg:35.40ms +step:528/1555 train_time:18720ms step_avg:35.45ms +step:529/1555 train_time:18777ms step_avg:35.49ms +step:530/1555 train_time:18840ms step_avg:35.55ms +step:531/1555 train_time:18897ms step_avg:35.59ms +step:532/1555 train_time:18960ms step_avg:35.64ms +step:533/1555 train_time:19017ms step_avg:35.68ms +step:534/1555 train_time:19081ms step_avg:35.73ms +step:535/1555 train_time:19138ms step_avg:35.77ms +step:536/1555 train_time:19202ms step_avg:35.83ms +step:537/1555 train_time:19261ms step_avg:35.87ms +step:538/1555 train_time:19327ms step_avg:35.92ms +step:539/1555 train_time:19386ms step_avg:35.97ms +step:540/1555 train_time:19451ms step_avg:36.02ms +step:541/1555 train_time:19510ms step_avg:36.06ms +step:542/1555 train_time:19575ms step_avg:36.12ms +step:543/1555 train_time:19632ms step_avg:36.16ms +step:544/1555 train_time:19696ms step_avg:36.21ms +step:545/1555 train_time:19754ms step_avg:36.25ms +step:546/1555 train_time:19818ms step_avg:36.30ms +step:547/1555 train_time:19876ms step_avg:36.34ms +step:548/1555 train_time:19939ms step_avg:36.39ms +step:549/1555 train_time:19996ms step_avg:36.42ms +step:550/1555 train_time:20059ms step_avg:36.47ms +step:551/1555 train_time:20117ms step_avg:36.51ms +step:552/1555 train_time:20182ms step_avg:36.56ms +step:553/1555 train_time:20240ms step_avg:36.60ms +step:554/1555 train_time:20304ms step_avg:36.65ms +step:555/1555 train_time:20363ms 
step_avg:36.69ms +step:556/1555 train_time:20428ms step_avg:36.74ms +step:557/1555 train_time:20486ms step_avg:36.78ms +step:558/1555 train_time:20552ms step_avg:36.83ms +step:559/1555 train_time:20610ms step_avg:36.87ms +step:560/1555 train_time:20675ms step_avg:36.92ms +step:561/1555 train_time:20732ms step_avg:36.96ms +step:562/1555 train_time:20797ms step_avg:37.01ms +step:563/1555 train_time:20855ms step_avg:37.04ms +step:564/1555 train_time:20920ms step_avg:37.09ms +step:565/1555 train_time:20977ms step_avg:37.13ms +step:566/1555 train_time:21040ms step_avg:37.17ms +step:567/1555 train_time:21097ms step_avg:37.21ms +step:568/1555 train_time:21161ms step_avg:37.26ms +step:569/1555 train_time:21218ms step_avg:37.29ms +step:570/1555 train_time:21283ms step_avg:37.34ms +step:571/1555 train_time:21341ms step_avg:37.37ms +step:572/1555 train_time:21407ms step_avg:37.42ms +step:573/1555 train_time:21465ms step_avg:37.46ms +step:574/1555 train_time:21530ms step_avg:37.51ms +step:575/1555 train_time:21588ms step_avg:37.54ms +step:576/1555 train_time:21653ms step_avg:37.59ms +step:577/1555 train_time:21711ms step_avg:37.63ms +step:578/1555 train_time:21775ms step_avg:37.67ms +step:579/1555 train_time:21833ms step_avg:37.71ms +step:580/1555 train_time:21898ms step_avg:37.76ms +step:581/1555 train_time:21956ms step_avg:37.79ms +step:582/1555 train_time:22020ms step_avg:37.84ms +step:583/1555 train_time:22078ms step_avg:37.87ms +step:584/1555 train_time:22141ms step_avg:37.91ms +step:585/1555 train_time:22198ms step_avg:37.95ms +step:586/1555 train_time:22262ms step_avg:37.99ms +step:587/1555 train_time:22320ms step_avg:38.02ms +step:588/1555 train_time:22384ms step_avg:38.07ms +step:589/1555 train_time:22442ms step_avg:38.10ms +step:590/1555 train_time:22507ms step_avg:38.15ms +step:591/1555 train_time:22566ms step_avg:38.18ms +step:592/1555 train_time:22630ms step_avg:38.23ms +step:593/1555 train_time:22689ms step_avg:38.26ms +step:594/1555 train_time:22754ms 
step_avg:38.31ms +step:595/1555 train_time:22812ms step_avg:38.34ms +step:596/1555 train_time:22877ms step_avg:38.38ms +step:597/1555 train_time:22934ms step_avg:38.42ms +step:598/1555 train_time:22999ms step_avg:38.46ms +step:599/1555 train_time:23056ms step_avg:38.49ms +step:600/1555 train_time:23120ms step_avg:38.53ms +step:601/1555 train_time:23177ms step_avg:38.56ms +step:602/1555 train_time:23241ms step_avg:38.61ms +step:603/1555 train_time:23298ms step_avg:38.64ms +step:604/1555 train_time:23362ms step_avg:38.68ms +step:605/1555 train_time:23420ms step_avg:38.71ms +step:606/1555 train_time:23485ms step_avg:38.75ms +step:607/1555 train_time:23543ms step_avg:38.79ms +step:608/1555 train_time:23608ms step_avg:38.83ms +step:609/1555 train_time:23666ms step_avg:38.86ms +step:610/1555 train_time:23730ms step_avg:38.90ms +step:611/1555 train_time:23789ms step_avg:38.93ms +step:612/1555 train_time:23852ms step_avg:38.97ms +step:613/1555 train_time:23911ms step_avg:39.01ms +step:614/1555 train_time:23976ms step_avg:39.05ms +step:615/1555 train_time:24034ms step_avg:39.08ms +step:616/1555 train_time:24098ms step_avg:39.12ms +step:617/1555 train_time:24156ms step_avg:39.15ms +step:618/1555 train_time:24221ms step_avg:39.19ms +step:619/1555 train_time:24278ms step_avg:39.22ms +step:620/1555 train_time:24342ms step_avg:39.26ms +step:621/1555 train_time:24399ms step_avg:39.29ms +step:622/1555 train_time:24463ms step_avg:39.33ms +step:623/1555 train_time:24520ms step_avg:39.36ms +step:624/1555 train_time:24586ms step_avg:39.40ms +step:625/1555 train_time:24644ms step_avg:39.43ms +step:626/1555 train_time:24709ms step_avg:39.47ms +step:627/1555 train_time:24767ms step_avg:39.50ms +step:628/1555 train_time:24832ms step_avg:39.54ms +step:629/1555 train_time:24889ms step_avg:39.57ms +step:630/1555 train_time:24954ms step_avg:39.61ms +step:631/1555 train_time:25012ms step_avg:39.64ms +step:632/1555 train_time:25077ms step_avg:39.68ms +step:633/1555 train_time:25134ms 
step_avg:39.71ms +step:634/1555 train_time:25198ms step_avg:39.75ms +step:635/1555 train_time:25256ms step_avg:39.77ms +step:636/1555 train_time:25321ms step_avg:39.81ms +step:637/1555 train_time:25378ms step_avg:39.84ms +step:638/1555 train_time:25441ms step_avg:39.88ms +step:639/1555 train_time:25499ms step_avg:39.90ms +step:640/1555 train_time:25563ms step_avg:39.94ms +step:641/1555 train_time:25620ms step_avg:39.97ms +step:642/1555 train_time:25685ms step_avg:40.01ms +step:643/1555 train_time:25743ms step_avg:40.04ms +step:644/1555 train_time:25808ms step_avg:40.07ms +step:645/1555 train_time:25866ms step_avg:40.10ms +step:646/1555 train_time:25931ms step_avg:40.14ms +step:647/1555 train_time:25989ms step_avg:40.17ms +step:648/1555 train_time:26055ms step_avg:40.21ms +step:649/1555 train_time:26113ms step_avg:40.24ms +step:650/1555 train_time:26178ms step_avg:40.27ms +step:651/1555 train_time:26235ms step_avg:40.30ms +step:652/1555 train_time:26300ms step_avg:40.34ms +step:653/1555 train_time:26357ms step_avg:40.36ms +step:654/1555 train_time:26421ms step_avg:40.40ms +step:655/1555 train_time:26478ms step_avg:40.42ms +step:656/1555 train_time:26542ms step_avg:40.46ms +step:657/1555 train_time:26599ms step_avg:40.49ms +step:658/1555 train_time:26664ms step_avg:40.52ms +step:659/1555 train_time:26721ms step_avg:40.55ms +step:660/1555 train_time:26786ms step_avg:40.58ms +step:661/1555 train_time:26844ms step_avg:40.61ms +step:662/1555 train_time:26908ms step_avg:40.65ms +step:663/1555 train_time:26967ms step_avg:40.67ms +step:664/1555 train_time:27032ms step_avg:40.71ms +step:665/1555 train_time:27090ms step_avg:40.74ms +step:666/1555 train_time:27155ms step_avg:40.77ms +step:667/1555 train_time:27213ms step_avg:40.80ms +step:668/1555 train_time:27278ms step_avg:40.83ms +step:669/1555 train_time:27335ms step_avg:40.86ms +step:670/1555 train_time:27399ms step_avg:40.89ms +step:671/1555 train_time:27456ms step_avg:40.92ms +step:672/1555 train_time:27520ms 
step_avg:40.95ms +step:673/1555 train_time:27577ms step_avg:40.98ms +step:674/1555 train_time:27641ms step_avg:41.01ms +step:675/1555 train_time:27698ms step_avg:41.03ms +step:676/1555 train_time:27762ms step_avg:41.07ms +step:677/1555 train_time:27819ms step_avg:41.09ms +step:678/1555 train_time:27885ms step_avg:41.13ms +step:679/1555 train_time:27943ms step_avg:41.15ms +step:680/1555 train_time:28008ms step_avg:41.19ms +step:681/1555 train_time:28066ms step_avg:41.21ms +step:682/1555 train_time:28131ms step_avg:41.25ms +step:683/1555 train_time:28190ms step_avg:41.27ms +step:684/1555 train_time:28254ms step_avg:41.31ms +step:685/1555 train_time:28313ms step_avg:41.33ms +step:686/1555 train_time:28376ms step_avg:41.36ms +step:687/1555 train_time:28433ms step_avg:41.39ms +step:688/1555 train_time:28498ms step_avg:41.42ms +step:689/1555 train_time:28556ms step_avg:41.45ms +step:690/1555 train_time:28620ms step_avg:41.48ms +step:691/1555 train_time:28677ms step_avg:41.50ms +step:692/1555 train_time:28740ms step_avg:41.53ms +step:693/1555 train_time:28797ms step_avg:41.55ms +step:694/1555 train_time:28862ms step_avg:41.59ms +step:695/1555 train_time:28920ms step_avg:41.61ms +step:696/1555 train_time:28985ms step_avg:41.64ms +step:697/1555 train_time:29042ms step_avg:41.67ms +step:698/1555 train_time:29108ms step_avg:41.70ms +step:699/1555 train_time:29166ms step_avg:41.73ms +step:700/1555 train_time:29231ms step_avg:41.76ms +step:701/1555 train_time:29289ms step_avg:41.78ms +step:702/1555 train_time:29354ms step_avg:41.81ms +step:703/1555 train_time:29413ms step_avg:41.84ms +step:704/1555 train_time:29476ms step_avg:41.87ms +step:705/1555 train_time:29534ms step_avg:41.89ms +step:706/1555 train_time:29598ms step_avg:41.92ms +step:707/1555 train_time:29656ms step_avg:41.95ms +step:708/1555 train_time:29719ms step_avg:41.98ms +step:709/1555 train_time:29777ms step_avg:42.00ms +step:710/1555 train_time:29841ms step_avg:42.03ms +step:711/1555 train_time:29898ms 
step_avg:42.05ms +step:712/1555 train_time:29963ms step_avg:42.08ms +step:713/1555 train_time:30020ms step_avg:42.10ms +step:714/1555 train_time:30085ms step_avg:42.14ms +step:715/1555 train_time:30143ms step_avg:42.16ms +step:716/1555 train_time:30208ms step_avg:42.19ms +step:717/1555 train_time:30267ms step_avg:42.21ms +step:718/1555 train_time:30331ms step_avg:42.24ms +step:719/1555 train_time:30389ms step_avg:42.27ms +step:720/1555 train_time:30454ms step_avg:42.30ms +step:721/1555 train_time:30512ms step_avg:42.32ms +step:722/1555 train_time:30576ms step_avg:42.35ms +step:723/1555 train_time:30634ms step_avg:42.37ms +step:724/1555 train_time:30698ms step_avg:42.40ms +step:725/1555 train_time:30757ms step_avg:42.42ms +step:726/1555 train_time:30821ms step_avg:42.45ms +step:727/1555 train_time:30878ms step_avg:42.47ms +step:728/1555 train_time:30941ms step_avg:42.50ms +step:729/1555 train_time:30998ms step_avg:42.52ms +step:730/1555 train_time:31063ms step_avg:42.55ms +step:731/1555 train_time:31120ms step_avg:42.57ms +step:732/1555 train_time:31185ms step_avg:42.60ms +step:733/1555 train_time:31244ms step_avg:42.62ms +step:734/1555 train_time:31308ms step_avg:42.65ms +step:735/1555 train_time:31367ms step_avg:42.68ms +step:736/1555 train_time:31431ms step_avg:42.70ms +step:737/1555 train_time:31489ms step_avg:42.73ms +step:738/1555 train_time:31553ms step_avg:42.76ms +step:739/1555 train_time:31612ms step_avg:42.78ms +step:740/1555 train_time:31676ms step_avg:42.81ms +step:741/1555 train_time:31734ms step_avg:42.83ms +step:742/1555 train_time:31799ms step_avg:42.86ms +step:743/1555 train_time:31856ms step_avg:42.88ms +step:744/1555 train_time:31921ms step_avg:42.90ms +step:745/1555 train_time:31978ms step_avg:42.92ms +step:746/1555 train_time:32041ms step_avg:42.95ms +step:747/1555 train_time:32098ms step_avg:42.97ms +step:748/1555 train_time:32162ms step_avg:43.00ms +step:749/1555 train_time:32220ms step_avg:43.02ms +step:750/1555 train_time:32285ms 
step_avg:43.05ms +step:750/1555 val_loss:3.8780 train_time:32368ms step_avg:43.16ms +step:751/1555 train_time:32389ms step_avg:43.13ms +step:752/1555 train_time:32412ms step_avg:43.10ms +step:753/1555 train_time:32466ms step_avg:43.12ms +step:754/1555 train_time:32536ms step_avg:43.15ms +step:755/1555 train_time:32594ms step_avg:43.17ms +step:756/1555 train_time:32658ms step_avg:43.20ms +step:757/1555 train_time:32715ms step_avg:43.22ms +step:758/1555 train_time:32779ms step_avg:43.24ms +step:759/1555 train_time:32835ms step_avg:43.26ms +step:760/1555 train_time:32899ms step_avg:43.29ms +step:761/1555 train_time:32956ms step_avg:43.31ms +step:762/1555 train_time:33020ms step_avg:43.33ms +step:763/1555 train_time:33077ms step_avg:43.35ms +step:764/1555 train_time:33140ms step_avg:43.38ms +step:765/1555 train_time:33198ms step_avg:43.40ms +step:766/1555 train_time:33261ms step_avg:43.42ms +step:767/1555 train_time:33320ms step_avg:43.44ms +step:768/1555 train_time:33385ms step_avg:43.47ms +step:769/1555 train_time:33446ms step_avg:43.49ms +step:770/1555 train_time:33511ms step_avg:43.52ms +step:771/1555 train_time:33570ms step_avg:43.54ms +step:772/1555 train_time:33635ms step_avg:43.57ms +step:773/1555 train_time:33692ms step_avg:43.59ms +step:774/1555 train_time:33756ms step_avg:43.61ms +step:775/1555 train_time:33813ms step_avg:43.63ms +step:776/1555 train_time:33877ms step_avg:43.66ms +step:777/1555 train_time:33934ms step_avg:43.67ms +step:778/1555 train_time:33999ms step_avg:43.70ms +step:779/1555 train_time:34055ms step_avg:43.72ms +step:780/1555 train_time:34119ms step_avg:43.74ms +step:781/1555 train_time:34175ms step_avg:43.76ms +step:782/1555 train_time:34240ms step_avg:43.79ms +step:783/1555 train_time:34297ms step_avg:43.80ms +step:784/1555 train_time:34362ms step_avg:43.83ms +step:785/1555 train_time:34421ms step_avg:43.85ms +step:786/1555 train_time:34486ms step_avg:43.88ms +step:787/1555 train_time:34545ms step_avg:43.90ms +step:788/1555 
train_time:34610ms step_avg:43.92ms +step:789/1555 train_time:34668ms step_avg:43.94ms +step:790/1555 train_time:34733ms step_avg:43.97ms +step:791/1555 train_time:34791ms step_avg:43.98ms +step:792/1555 train_time:34854ms step_avg:44.01ms +step:793/1555 train_time:34911ms step_avg:44.02ms +step:794/1555 train_time:34974ms step_avg:44.05ms +step:795/1555 train_time:35031ms step_avg:44.06ms +step:796/1555 train_time:35095ms step_avg:44.09ms +step:797/1555 train_time:35152ms step_avg:44.11ms +step:798/1555 train_time:35216ms step_avg:44.13ms +step:799/1555 train_time:35274ms step_avg:44.15ms +step:800/1555 train_time:35338ms step_avg:44.17ms +step:801/1555 train_time:35395ms step_avg:44.19ms +step:802/1555 train_time:35461ms step_avg:44.22ms +step:803/1555 train_time:35519ms step_avg:44.23ms +step:804/1555 train_time:35584ms step_avg:44.26ms +step:805/1555 train_time:35643ms step_avg:44.28ms +step:806/1555 train_time:35707ms step_avg:44.30ms +step:807/1555 train_time:35766ms step_avg:44.32ms +step:808/1555 train_time:35830ms step_avg:44.34ms +step:809/1555 train_time:35888ms step_avg:44.36ms +step:810/1555 train_time:35951ms step_avg:44.38ms +step:811/1555 train_time:36009ms step_avg:44.40ms +step:812/1555 train_time:36072ms step_avg:44.42ms +step:813/1555 train_time:36130ms step_avg:44.44ms +step:814/1555 train_time:36194ms step_avg:44.46ms +step:815/1555 train_time:36251ms step_avg:44.48ms +step:816/1555 train_time:36315ms step_avg:44.50ms +step:817/1555 train_time:36372ms step_avg:44.52ms +step:818/1555 train_time:36438ms step_avg:44.54ms +step:819/1555 train_time:36494ms step_avg:44.56ms +step:820/1555 train_time:36560ms step_avg:44.59ms +step:821/1555 train_time:36618ms step_avg:44.60ms +step:822/1555 train_time:36682ms step_avg:44.63ms +step:823/1555 train_time:36740ms step_avg:44.64ms +step:824/1555 train_time:36804ms step_avg:44.67ms +step:825/1555 train_time:36863ms step_avg:44.68ms +step:826/1555 train_time:36928ms step_avg:44.71ms +step:827/1555 
train_time:36986ms step_avg:44.72ms +step:828/1555 train_time:37050ms step_avg:44.75ms +step:829/1555 train_time:37107ms step_avg:44.76ms +step:830/1555 train_time:37171ms step_avg:44.78ms +step:831/1555 train_time:37230ms step_avg:44.80ms +step:832/1555 train_time:37293ms step_avg:44.82ms +step:833/1555 train_time:37350ms step_avg:44.84ms +step:834/1555 train_time:37415ms step_avg:44.86ms +step:835/1555 train_time:37471ms step_avg:44.88ms +step:836/1555 train_time:37536ms step_avg:44.90ms +step:837/1555 train_time:37594ms step_avg:44.91ms +step:838/1555 train_time:37658ms step_avg:44.94ms +step:839/1555 train_time:37717ms step_avg:44.95ms +step:840/1555 train_time:37782ms step_avg:44.98ms +step:841/1555 train_time:37840ms step_avg:44.99ms +step:842/1555 train_time:37904ms step_avg:45.02ms +step:843/1555 train_time:37962ms step_avg:45.03ms +step:844/1555 train_time:38026ms step_avg:45.05ms +step:845/1555 train_time:38084ms step_avg:45.07ms +step:846/1555 train_time:38148ms step_avg:45.09ms +step:847/1555 train_time:38206ms step_avg:45.11ms +step:848/1555 train_time:38270ms step_avg:45.13ms +step:849/1555 train_time:38329ms step_avg:45.15ms +step:850/1555 train_time:38393ms step_avg:45.17ms +step:851/1555 train_time:38450ms step_avg:45.18ms +step:852/1555 train_time:38514ms step_avg:45.20ms +step:853/1555 train_time:38571ms step_avg:45.22ms +step:854/1555 train_time:38636ms step_avg:45.24ms +step:855/1555 train_time:38693ms step_avg:45.26ms +step:856/1555 train_time:38758ms step_avg:45.28ms +step:857/1555 train_time:38816ms step_avg:45.29ms +step:858/1555 train_time:38880ms step_avg:45.32ms +step:859/1555 train_time:38939ms step_avg:45.33ms +step:860/1555 train_time:39003ms step_avg:45.35ms +step:861/1555 train_time:39060ms step_avg:45.37ms +step:862/1555 train_time:39125ms step_avg:45.39ms +step:863/1555 train_time:39184ms step_avg:45.40ms +step:864/1555 train_time:39249ms step_avg:45.43ms +step:865/1555 train_time:39307ms step_avg:45.44ms +step:866/1555 
train_time:39372ms step_avg:45.46ms +step:867/1555 train_time:39430ms step_avg:45.48ms +step:868/1555 train_time:39493ms step_avg:45.50ms +step:869/1555 train_time:39551ms step_avg:45.51ms +step:870/1555 train_time:39615ms step_avg:45.53ms +step:871/1555 train_time:39672ms step_avg:45.55ms +step:872/1555 train_time:39736ms step_avg:45.57ms +step:873/1555 train_time:39794ms step_avg:45.58ms +step:874/1555 train_time:39858ms step_avg:45.60ms +step:875/1555 train_time:39916ms step_avg:45.62ms +step:876/1555 train_time:39980ms step_avg:45.64ms +step:877/1555 train_time:40038ms step_avg:45.65ms +step:878/1555 train_time:40102ms step_avg:45.67ms +step:879/1555 train_time:40161ms step_avg:45.69ms +step:880/1555 train_time:40225ms step_avg:45.71ms +step:881/1555 train_time:40284ms step_avg:45.73ms +step:882/1555 train_time:40349ms step_avg:45.75ms +step:883/1555 train_time:40407ms step_avg:45.76ms +step:884/1555 train_time:40471ms step_avg:45.78ms +step:885/1555 train_time:40530ms step_avg:45.80ms +step:886/1555 train_time:40592ms step_avg:45.82ms +step:887/1555 train_time:40650ms step_avg:45.83ms +step:888/1555 train_time:40714ms step_avg:45.85ms +step:889/1555 train_time:40771ms step_avg:45.86ms +step:890/1555 train_time:40836ms step_avg:45.88ms +step:891/1555 train_time:40893ms step_avg:45.90ms +step:892/1555 train_time:40958ms step_avg:45.92ms +step:893/1555 train_time:41015ms step_avg:45.93ms +step:894/1555 train_time:41080ms step_avg:45.95ms +step:895/1555 train_time:41138ms step_avg:45.96ms +step:896/1555 train_time:41203ms step_avg:45.99ms +step:897/1555 train_time:41261ms step_avg:46.00ms +step:898/1555 train_time:41326ms step_avg:46.02ms +step:899/1555 train_time:41384ms step_avg:46.03ms +step:900/1555 train_time:41448ms step_avg:46.05ms +step:901/1555 train_time:41506ms step_avg:46.07ms +step:902/1555 train_time:41570ms step_avg:46.09ms +step:903/1555 train_time:41628ms step_avg:46.10ms +step:904/1555 train_time:41692ms step_avg:46.12ms +step:905/1555 
train_time:41750ms step_avg:46.13ms +step:906/1555 train_time:41814ms step_avg:46.15ms +step:907/1555 train_time:41872ms step_avg:46.17ms +step:908/1555 train_time:41936ms step_avg:46.18ms +step:909/1555 train_time:41992ms step_avg:46.20ms +step:910/1555 train_time:42057ms step_avg:46.22ms +step:911/1555 train_time:42113ms step_avg:46.23ms +step:912/1555 train_time:42178ms step_avg:46.25ms +step:913/1555 train_time:42236ms step_avg:46.26ms +step:914/1555 train_time:42301ms step_avg:46.28ms +step:915/1555 train_time:42359ms step_avg:46.29ms +step:916/1555 train_time:42423ms step_avg:46.31ms +step:917/1555 train_time:42481ms step_avg:46.33ms +step:918/1555 train_time:42545ms step_avg:46.35ms +step:919/1555 train_time:42605ms step_avg:46.36ms +step:920/1555 train_time:42668ms step_avg:46.38ms +step:921/1555 train_time:42727ms step_avg:46.39ms +step:922/1555 train_time:42791ms step_avg:46.41ms +step:923/1555 train_time:42849ms step_avg:46.42ms +step:924/1555 train_time:42913ms step_avg:46.44ms +step:925/1555 train_time:42971ms step_avg:46.46ms +step:926/1555 train_time:43034ms step_avg:46.47ms +step:927/1555 train_time:43091ms step_avg:46.48ms +step:928/1555 train_time:43155ms step_avg:46.50ms +step:929/1555 train_time:43213ms step_avg:46.52ms +step:930/1555 train_time:43278ms step_avg:46.54ms +step:931/1555 train_time:43335ms step_avg:46.55ms +step:932/1555 train_time:43400ms step_avg:46.57ms +step:933/1555 train_time:43458ms step_avg:46.58ms +step:934/1555 train_time:43523ms step_avg:46.60ms +step:935/1555 train_time:43581ms step_avg:46.61ms +step:936/1555 train_time:43645ms step_avg:46.63ms +step:937/1555 train_time:43704ms step_avg:46.64ms +step:938/1555 train_time:43767ms step_avg:46.66ms +step:939/1555 train_time:43825ms step_avg:46.67ms +step:940/1555 train_time:43889ms step_avg:46.69ms +step:941/1555 train_time:43948ms step_avg:46.70ms +step:942/1555 train_time:44011ms step_avg:46.72ms +step:943/1555 train_time:44069ms step_avg:46.73ms +step:944/1555 
train_time:44133ms step_avg:46.75ms +step:945/1555 train_time:44190ms step_avg:46.76ms +step:946/1555 train_time:44254ms step_avg:46.78ms +step:947/1555 train_time:44311ms step_avg:46.79ms +step:948/1555 train_time:44376ms step_avg:46.81ms +step:949/1555 train_time:44434ms step_avg:46.82ms +step:950/1555 train_time:44499ms step_avg:46.84ms +step:951/1555 train_time:44556ms step_avg:46.85ms +step:952/1555 train_time:44621ms step_avg:46.87ms +step:953/1555 train_time:44679ms step_avg:46.88ms +step:954/1555 train_time:44743ms step_avg:46.90ms +step:955/1555 train_time:44801ms step_avg:46.91ms +step:956/1555 train_time:44866ms step_avg:46.93ms +step:957/1555 train_time:44924ms step_avg:46.94ms +step:958/1555 train_time:44989ms step_avg:46.96ms +step:959/1555 train_time:45047ms step_avg:46.97ms +step:960/1555 train_time:45112ms step_avg:46.99ms +step:961/1555 train_time:45169ms step_avg:47.00ms +step:962/1555 train_time:45234ms step_avg:47.02ms +step:963/1555 train_time:45291ms step_avg:47.03ms +step:964/1555 train_time:45355ms step_avg:47.05ms +step:965/1555 train_time:45412ms step_avg:47.06ms +step:966/1555 train_time:45475ms step_avg:47.08ms +step:967/1555 train_time:45533ms step_avg:47.09ms +step:968/1555 train_time:45598ms step_avg:47.11ms +step:969/1555 train_time:45656ms step_avg:47.12ms +step:970/1555 train_time:45721ms step_avg:47.14ms +step:971/1555 train_time:45779ms step_avg:47.15ms +step:972/1555 train_time:45844ms step_avg:47.16ms +step:973/1555 train_time:45902ms step_avg:47.18ms +step:974/1555 train_time:45966ms step_avg:47.19ms +step:975/1555 train_time:46024ms step_avg:47.20ms +step:976/1555 train_time:46089ms step_avg:47.22ms +step:977/1555 train_time:46147ms step_avg:47.23ms +step:978/1555 train_time:46211ms step_avg:47.25ms +step:979/1555 train_time:46270ms step_avg:47.26ms +step:980/1555 train_time:46333ms step_avg:47.28ms +step:981/1555 train_time:46391ms step_avg:47.29ms +step:982/1555 train_time:46455ms step_avg:47.31ms +step:983/1555 
train_time:46512ms step_avg:47.32ms +step:984/1555 train_time:46576ms step_avg:47.33ms +step:985/1555 train_time:46635ms step_avg:47.34ms +step:986/1555 train_time:46699ms step_avg:47.36ms +step:987/1555 train_time:46757ms step_avg:47.37ms +step:988/1555 train_time:46820ms step_avg:47.39ms +step:989/1555 train_time:46878ms step_avg:47.40ms +step:990/1555 train_time:46944ms step_avg:47.42ms +step:991/1555 train_time:47002ms step_avg:47.43ms +step:992/1555 train_time:47066ms step_avg:47.45ms +step:993/1555 train_time:47125ms step_avg:47.46ms +step:994/1555 train_time:47189ms step_avg:47.47ms +step:995/1555 train_time:47247ms step_avg:47.48ms +step:996/1555 train_time:47311ms step_avg:47.50ms +step:997/1555 train_time:47370ms step_avg:47.51ms +step:998/1555 train_time:47434ms step_avg:47.53ms +step:999/1555 train_time:47492ms step_avg:47.54ms +step:1000/1555 train_time:47555ms step_avg:47.56ms +step:1000/1555 val_loss:3.5743 train_time:47637ms step_avg:47.64ms +step:1001/1555 train_time:47660ms step_avg:47.61ms +step:1002/1555 train_time:47684ms step_avg:47.59ms +step:1003/1555 train_time:47737ms step_avg:47.59ms +step:1004/1555 train_time:47807ms step_avg:47.62ms +step:1005/1555 train_time:47866ms step_avg:47.63ms +step:1006/1555 train_time:47930ms step_avg:47.64ms +step:1007/1555 train_time:47990ms step_avg:47.66ms +step:1008/1555 train_time:48053ms step_avg:47.67ms +step:1009/1555 train_time:48109ms step_avg:47.68ms +step:1010/1555 train_time:48172ms step_avg:47.69ms +step:1011/1555 train_time:48233ms step_avg:47.71ms +step:1012/1555 train_time:48317ms step_avg:47.74ms +step:1013/1555 train_time:48400ms step_avg:47.78ms +step:1014/1555 train_time:48491ms step_avg:47.82ms +step:1015/1555 train_time:48573ms step_avg:47.85ms +step:1016/1555 train_time:48663ms step_avg:47.90ms +step:1017/1555 train_time:48750ms step_avg:47.94ms +step:1018/1555 train_time:48843ms step_avg:47.98ms +step:1019/1555 train_time:48931ms step_avg:48.02ms +step:1020/1555 train_time:49021ms 
step_avg:48.06ms +step:1021/1555 train_time:49106ms step_avg:48.10ms +step:1022/1555 train_time:49196ms step_avg:48.14ms +step:1023/1555 train_time:49280ms step_avg:48.17ms +step:1024/1555 train_time:49369ms step_avg:48.21ms +step:1025/1555 train_time:49451ms step_avg:48.25ms +step:1026/1555 train_time:49540ms step_avg:48.28ms +step:1027/1555 train_time:49624ms step_avg:48.32ms +step:1028/1555 train_time:49715ms step_avg:48.36ms +step:1029/1555 train_time:49800ms step_avg:48.40ms +step:1030/1555 train_time:49894ms step_avg:48.44ms +step:1031/1555 train_time:49977ms step_avg:48.47ms +step:1032/1555 train_time:50067ms step_avg:48.51ms +step:1033/1555 train_time:50152ms step_avg:48.55ms +step:1034/1555 train_time:50241ms step_avg:48.59ms +step:1035/1555 train_time:50325ms step_avg:48.62ms +step:1036/1555 train_time:50413ms step_avg:48.66ms +step:1037/1555 train_time:50497ms step_avg:48.70ms +step:1038/1555 train_time:50586ms step_avg:48.73ms +step:1039/1555 train_time:50670ms step_avg:48.77ms +step:1040/1555 train_time:50760ms step_avg:48.81ms +step:1041/1555 train_time:50846ms step_avg:48.84ms +step:1042/1555 train_time:50936ms step_avg:48.88ms +step:1043/1555 train_time:51021ms step_avg:48.92ms +step:1044/1555 train_time:51111ms step_avg:48.96ms +step:1045/1555 train_time:51195ms step_avg:48.99ms +step:1046/1555 train_time:51284ms step_avg:49.03ms +step:1047/1555 train_time:51368ms step_avg:49.06ms +step:1048/1555 train_time:51457ms step_avg:49.10ms +step:1049/1555 train_time:51540ms step_avg:49.13ms +step:1050/1555 train_time:51631ms step_avg:49.17ms +step:1051/1555 train_time:51715ms step_avg:49.21ms +step:1052/1555 train_time:51807ms step_avg:49.25ms +step:1053/1555 train_time:51891ms step_avg:49.28ms +step:1054/1555 train_time:51981ms step_avg:49.32ms +step:1055/1555 train_time:52067ms step_avg:49.35ms +step:1056/1555 train_time:52156ms step_avg:49.39ms +step:1057/1555 train_time:52241ms step_avg:49.42ms +step:1058/1555 train_time:52332ms step_avg:49.46ms 
+step:1059/1555 train_time:52414ms step_avg:49.49ms +step:1060/1555 train_time:52504ms step_avg:49.53ms +step:1061/1555 train_time:52589ms step_avg:49.57ms +step:1062/1555 train_time:52678ms step_avg:49.60ms +step:1063/1555 train_time:52763ms step_avg:49.64ms +step:1064/1555 train_time:52852ms step_avg:49.67ms +step:1065/1555 train_time:52936ms step_avg:49.71ms +step:1066/1555 train_time:53029ms step_avg:49.75ms +step:1067/1555 train_time:53112ms step_avg:49.78ms +step:1068/1555 train_time:53202ms step_avg:49.81ms +step:1069/1555 train_time:53286ms step_avg:49.85ms +step:1070/1555 train_time:53375ms step_avg:49.88ms +step:1071/1555 train_time:53458ms step_avg:49.91ms +step:1072/1555 train_time:53548ms step_avg:49.95ms +step:1073/1555 train_time:53632ms step_avg:49.98ms +step:1074/1555 train_time:53723ms step_avg:50.02ms +step:1075/1555 train_time:53805ms step_avg:50.05ms +step:1076/1555 train_time:53895ms step_avg:50.09ms +step:1077/1555 train_time:53979ms step_avg:50.12ms +step:1078/1555 train_time:54071ms step_avg:50.16ms +step:1079/1555 train_time:54154ms step_avg:50.19ms +step:1080/1555 train_time:54245ms step_avg:50.23ms +step:1081/1555 train_time:54328ms step_avg:50.26ms +step:1082/1555 train_time:54419ms step_avg:50.29ms +step:1083/1555 train_time:54503ms step_avg:50.33ms +step:1084/1555 train_time:54592ms step_avg:50.36ms +step:1085/1555 train_time:54676ms step_avg:50.39ms +step:1086/1555 train_time:54766ms step_avg:50.43ms +step:1087/1555 train_time:54850ms step_avg:50.46ms +step:1088/1555 train_time:54939ms step_avg:50.50ms +step:1089/1555 train_time:55024ms step_avg:50.53ms +step:1090/1555 train_time:55113ms step_avg:50.56ms +step:1091/1555 train_time:55198ms step_avg:50.59ms +step:1092/1555 train_time:55288ms step_avg:50.63ms +step:1093/1555 train_time:55372ms step_avg:50.66ms +step:1094/1555 train_time:55462ms step_avg:50.70ms +step:1095/1555 train_time:55547ms step_avg:50.73ms +step:1096/1555 train_time:55636ms step_avg:50.76ms +step:1097/1555 
train_time:55720ms step_avg:50.79ms +step:1098/1555 train_time:55811ms step_avg:50.83ms +step:1099/1555 train_time:55895ms step_avg:50.86ms +step:1100/1555 train_time:55985ms step_avg:50.90ms +step:1101/1555 train_time:56070ms step_avg:50.93ms +step:1102/1555 train_time:56159ms step_avg:50.96ms +step:1103/1555 train_time:56243ms step_avg:50.99ms +step:1104/1555 train_time:56333ms step_avg:51.03ms +step:1105/1555 train_time:56418ms step_avg:51.06ms +step:1106/1555 train_time:56509ms step_avg:51.09ms +step:1107/1555 train_time:56592ms step_avg:51.12ms +step:1108/1555 train_time:56682ms step_avg:51.16ms +step:1109/1555 train_time:56766ms step_avg:51.19ms +step:1110/1555 train_time:56855ms step_avg:51.22ms +step:1111/1555 train_time:56939ms step_avg:51.25ms +step:1112/1555 train_time:57030ms step_avg:51.29ms +step:1113/1555 train_time:57114ms step_avg:51.32ms +step:1114/1555 train_time:57204ms step_avg:51.35ms +step:1115/1555 train_time:57288ms step_avg:51.38ms +step:1116/1555 train_time:57378ms step_avg:51.41ms +step:1117/1555 train_time:57462ms step_avg:51.44ms +step:1118/1555 train_time:57552ms step_avg:51.48ms +step:1119/1555 train_time:57636ms step_avg:51.51ms +step:1120/1555 train_time:57726ms step_avg:51.54ms +step:1121/1555 train_time:57810ms step_avg:51.57ms +step:1122/1555 train_time:57899ms step_avg:51.60ms +step:1123/1555 train_time:57983ms step_avg:51.63ms +step:1124/1555 train_time:58074ms step_avg:51.67ms +step:1125/1555 train_time:58157ms step_avg:51.70ms +step:1126/1555 train_time:58247ms step_avg:51.73ms +step:1127/1555 train_time:58332ms step_avg:51.76ms +step:1128/1555 train_time:58421ms step_avg:51.79ms +step:1129/1555 train_time:58505ms step_avg:51.82ms +step:1130/1555 train_time:58595ms step_avg:51.85ms +step:1131/1555 train_time:58679ms step_avg:51.88ms +step:1132/1555 train_time:58768ms step_avg:51.92ms +step:1133/1555 train_time:58852ms step_avg:51.94ms +step:1134/1555 train_time:58942ms step_avg:51.98ms +step:1135/1555 train_time:59026ms 
step_avg:52.01ms +step:1136/1555 train_time:59115ms step_avg:52.04ms +step:1137/1555 train_time:59198ms step_avg:52.07ms +step:1138/1555 train_time:59289ms step_avg:52.10ms +step:1139/1555 train_time:59373ms step_avg:52.13ms +step:1140/1555 train_time:59464ms step_avg:52.16ms +step:1141/1555 train_time:59548ms step_avg:52.19ms +step:1142/1555 train_time:59637ms step_avg:52.22ms +step:1143/1555 train_time:59721ms step_avg:52.25ms +step:1144/1555 train_time:59811ms step_avg:52.28ms +step:1145/1555 train_time:59895ms step_avg:52.31ms +step:1146/1555 train_time:59986ms step_avg:52.34ms +step:1147/1555 train_time:60070ms step_avg:52.37ms +step:1148/1555 train_time:60159ms step_avg:52.40ms +step:1149/1555 train_time:60244ms step_avg:52.43ms +step:1150/1555 train_time:60334ms step_avg:52.46ms +step:1151/1555 train_time:60418ms step_avg:52.49ms +step:1152/1555 train_time:60509ms step_avg:52.53ms +step:1153/1555 train_time:60593ms step_avg:52.55ms +step:1154/1555 train_time:60683ms step_avg:52.59ms +step:1155/1555 train_time:60767ms step_avg:52.61ms +step:1156/1555 train_time:60857ms step_avg:52.64ms +step:1157/1555 train_time:60941ms step_avg:52.67ms +step:1158/1555 train_time:61031ms step_avg:52.70ms +step:1159/1555 train_time:61115ms step_avg:52.73ms +step:1160/1555 train_time:61205ms step_avg:52.76ms +step:1161/1555 train_time:61289ms step_avg:52.79ms +step:1162/1555 train_time:61379ms step_avg:52.82ms +step:1163/1555 train_time:61463ms step_avg:52.85ms +step:1164/1555 train_time:61553ms step_avg:52.88ms +step:1165/1555 train_time:61638ms step_avg:52.91ms +step:1166/1555 train_time:61728ms step_avg:52.94ms +step:1167/1555 train_time:61811ms step_avg:52.97ms +step:1168/1555 train_time:61902ms step_avg:53.00ms +step:1169/1555 train_time:61986ms step_avg:53.02ms +step:1170/1555 train_time:62075ms step_avg:53.06ms +step:1171/1555 train_time:62158ms step_avg:53.08ms +step:1172/1555 train_time:62248ms step_avg:53.11ms +step:1173/1555 train_time:62332ms step_avg:53.14ms 
+step:1174/1555 train_time:62422ms step_avg:53.17ms +step:1175/1555 train_time:62506ms step_avg:53.20ms +step:1176/1555 train_time:62596ms step_avg:53.23ms +step:1177/1555 train_time:62679ms step_avg:53.25ms +step:1178/1555 train_time:62770ms step_avg:53.29ms +step:1179/1555 train_time:62854ms step_avg:53.31ms +step:1180/1555 train_time:62945ms step_avg:53.34ms +step:1181/1555 train_time:63028ms step_avg:53.37ms +step:1182/1555 train_time:63118ms step_avg:53.40ms +step:1183/1555 train_time:63202ms step_avg:53.43ms +step:1184/1555 train_time:63292ms step_avg:53.46ms +step:1185/1555 train_time:63376ms step_avg:53.48ms +step:1186/1555 train_time:63467ms step_avg:53.51ms +step:1187/1555 train_time:63551ms step_avg:53.54ms +step:1188/1555 train_time:63640ms step_avg:53.57ms +step:1189/1555 train_time:63724ms step_avg:53.59ms +step:1190/1555 train_time:63814ms step_avg:53.63ms +step:1191/1555 train_time:63898ms step_avg:53.65ms +step:1192/1555 train_time:63988ms step_avg:53.68ms +step:1193/1555 train_time:64073ms step_avg:53.71ms +step:1194/1555 train_time:64162ms step_avg:53.74ms +step:1195/1555 train_time:64246ms step_avg:53.76ms +step:1196/1555 train_time:64336ms step_avg:53.79ms +step:1197/1555 train_time:64420ms step_avg:53.82ms +step:1198/1555 train_time:64510ms step_avg:53.85ms +step:1199/1555 train_time:64594ms step_avg:53.87ms +step:1200/1555 train_time:64684ms step_avg:53.90ms +step:1201/1555 train_time:64768ms step_avg:53.93ms +step:1202/1555 train_time:64858ms step_avg:53.96ms +step:1203/1555 train_time:64943ms step_avg:53.98ms +step:1204/1555 train_time:65034ms step_avg:54.01ms +step:1205/1555 train_time:65117ms step_avg:54.04ms +step:1206/1555 train_time:65209ms step_avg:54.07ms +step:1207/1555 train_time:65293ms step_avg:54.10ms +step:1208/1555 train_time:65383ms step_avg:54.12ms +step:1209/1555 train_time:65467ms step_avg:54.15ms +step:1210/1555 train_time:65556ms step_avg:54.18ms +step:1211/1555 train_time:65641ms step_avg:54.20ms +step:1212/1555 
train_time:65732ms step_avg:54.23ms +step:1213/1555 train_time:65816ms step_avg:54.26ms +step:1214/1555 train_time:65905ms step_avg:54.29ms +step:1215/1555 train_time:65990ms step_avg:54.31ms +step:1216/1555 train_time:66080ms step_avg:54.34ms +step:1217/1555 train_time:66164ms step_avg:54.37ms +step:1218/1555 train_time:66254ms step_avg:54.40ms +step:1219/1555 train_time:66336ms step_avg:54.42ms +step:1220/1555 train_time:66427ms step_avg:54.45ms +step:1221/1555 train_time:66511ms step_avg:54.47ms +step:1222/1555 train_time:66603ms step_avg:54.50ms +step:1223/1555 train_time:66687ms step_avg:54.53ms +step:1224/1555 train_time:66776ms step_avg:54.56ms +step:1225/1555 train_time:66860ms step_avg:54.58ms +step:1226/1555 train_time:66951ms step_avg:54.61ms +step:1227/1555 train_time:67035ms step_avg:54.63ms +step:1228/1555 train_time:67124ms step_avg:54.66ms +step:1229/1555 train_time:67208ms step_avg:54.69ms +step:1230/1555 train_time:67299ms step_avg:54.71ms +step:1231/1555 train_time:67384ms step_avg:54.74ms +step:1232/1555 train_time:67475ms step_avg:54.77ms +step:1233/1555 train_time:67558ms step_avg:54.79ms +step:1234/1555 train_time:67649ms step_avg:54.82ms +step:1235/1555 train_time:67733ms step_avg:54.84ms +step:1236/1555 train_time:67822ms step_avg:54.87ms +step:1237/1555 train_time:67906ms step_avg:54.90ms +step:1238/1555 train_time:67996ms step_avg:54.92ms +step:1239/1555 train_time:68081ms step_avg:54.95ms +step:1240/1555 train_time:68171ms step_avg:54.98ms +step:1241/1555 train_time:68255ms step_avg:55.00ms +step:1242/1555 train_time:68344ms step_avg:55.03ms +step:1243/1555 train_time:68429ms step_avg:55.05ms +step:1244/1555 train_time:68518ms step_avg:55.08ms +step:1245/1555 train_time:68603ms step_avg:55.10ms +step:1246/1555 train_time:68694ms step_avg:55.13ms +step:1247/1555 train_time:68776ms step_avg:55.15ms +step:1248/1555 train_time:68867ms step_avg:55.18ms +step:1249/1555 train_time:68951ms step_avg:55.20ms +step:1250/1555 train_time:69041ms 
step_avg:55.23ms +step:1250/1555 val_loss:3.3996 train_time:69156ms step_avg:55.32ms +step:1251/1555 train_time:69176ms step_avg:55.30ms +step:1252/1555 train_time:69216ms step_avg:55.28ms +step:1253/1555 train_time:69303ms step_avg:55.31ms +step:1254/1555 train_time:69397ms step_avg:55.34ms +step:1255/1555 train_time:69480ms step_avg:55.36ms +step:1256/1555 train_time:69569ms step_avg:55.39ms +step:1257/1555 train_time:69652ms step_avg:55.41ms +step:1258/1555 train_time:69741ms step_avg:55.44ms +step:1259/1555 train_time:69824ms step_avg:55.46ms +step:1260/1555 train_time:69914ms step_avg:55.49ms +step:1261/1555 train_time:69997ms step_avg:55.51ms +step:1262/1555 train_time:70087ms step_avg:55.54ms +step:1263/1555 train_time:70172ms step_avg:55.56ms +step:1264/1555 train_time:70266ms step_avg:55.59ms +step:1265/1555 train_time:70354ms step_avg:55.62ms +step:1266/1555 train_time:70444ms step_avg:55.64ms +step:1267/1555 train_time:70529ms step_avg:55.67ms +step:1268/1555 train_time:70619ms step_avg:55.69ms +step:1269/1555 train_time:70701ms step_avg:55.71ms +step:1270/1555 train_time:70790ms step_avg:55.74ms +step:1271/1555 train_time:70874ms step_avg:55.76ms +step:1272/1555 train_time:70963ms step_avg:55.79ms +step:1273/1555 train_time:71047ms step_avg:55.81ms +step:1274/1555 train_time:71138ms step_avg:55.84ms +step:1275/1555 train_time:71223ms step_avg:55.86ms +step:1276/1555 train_time:71316ms step_avg:55.89ms +step:1277/1555 train_time:71400ms step_avg:55.91ms +step:1278/1555 train_time:71491ms step_avg:55.94ms +step:1279/1555 train_time:71575ms step_avg:55.96ms +step:1280/1555 train_time:71665ms step_avg:55.99ms +step:1281/1555 train_time:71749ms step_avg:56.01ms +step:1282/1555 train_time:71839ms step_avg:56.04ms +step:1283/1555 train_time:71921ms step_avg:56.06ms +step:1284/1555 train_time:72011ms step_avg:56.08ms +step:1285/1555 train_time:72096ms step_avg:56.11ms +step:1286/1555 train_time:72186ms step_avg:56.13ms +step:1287/1555 train_time:72272ms 
step_avg:56.16ms +step:1288/1555 train_time:72362ms step_avg:56.18ms +step:1289/1555 train_time:72447ms step_avg:56.20ms +step:1290/1555 train_time:72538ms step_avg:56.23ms +step:1291/1555 train_time:72622ms step_avg:56.25ms +step:1292/1555 train_time:72712ms step_avg:56.28ms +step:1293/1555 train_time:72796ms step_avg:56.30ms +step:1294/1555 train_time:72885ms step_avg:56.32ms +step:1295/1555 train_time:72969ms step_avg:56.35ms +step:1296/1555 train_time:73058ms step_avg:56.37ms +step:1297/1555 train_time:73142ms step_avg:56.39ms +step:1298/1555 train_time:73233ms step_avg:56.42ms +step:1299/1555 train_time:73317ms step_avg:56.44ms +step:1300/1555 train_time:73408ms step_avg:56.47ms +step:1301/1555 train_time:73492ms step_avg:56.49ms +step:1302/1555 train_time:73582ms step_avg:56.51ms +step:1303/1555 train_time:73666ms step_avg:56.54ms +step:1304/1555 train_time:73756ms step_avg:56.56ms +step:1305/1555 train_time:73839ms step_avg:56.58ms +step:1306/1555 train_time:73929ms step_avg:56.61ms +step:1307/1555 train_time:74014ms step_avg:56.63ms +step:1308/1555 train_time:74103ms step_avg:56.65ms +step:1309/1555 train_time:74188ms step_avg:56.68ms +step:1310/1555 train_time:74277ms step_avg:56.70ms +step:1311/1555 train_time:74361ms step_avg:56.72ms +step:1312/1555 train_time:74453ms step_avg:56.75ms +step:1313/1555 train_time:74537ms step_avg:56.77ms +step:1314/1555 train_time:74627ms step_avg:56.79ms +step:1315/1555 train_time:74711ms step_avg:56.81ms +step:1316/1555 train_time:74801ms step_avg:56.84ms +step:1317/1555 train_time:74884ms step_avg:56.86ms +step:1318/1555 train_time:74975ms step_avg:56.89ms +step:1319/1555 train_time:75059ms step_avg:56.91ms +step:1320/1555 train_time:75149ms step_avg:56.93ms +step:1321/1555 train_time:75233ms step_avg:56.95ms +step:1322/1555 train_time:75323ms step_avg:56.98ms +step:1323/1555 train_time:75407ms step_avg:57.00ms +step:1324/1555 train_time:75499ms step_avg:57.02ms +step:1325/1555 train_time:75582ms step_avg:57.04ms 
+step:1326/1555 train_time:75672ms step_avg:57.07ms +step:1327/1555 train_time:75756ms step_avg:57.09ms +step:1328/1555 train_time:75846ms step_avg:57.11ms +step:1329/1555 train_time:75931ms step_avg:57.13ms +step:1330/1555 train_time:76020ms step_avg:57.16ms +step:1331/1555 train_time:76104ms step_avg:57.18ms +step:1332/1555 train_time:76195ms step_avg:57.20ms +step:1333/1555 train_time:76279ms step_avg:57.22ms +step:1334/1555 train_time:76368ms step_avg:57.25ms +step:1335/1555 train_time:76454ms step_avg:57.27ms +step:1336/1555 train_time:76543ms step_avg:57.29ms +step:1337/1555 train_time:76627ms step_avg:57.31ms +step:1338/1555 train_time:76717ms step_avg:57.34ms +step:1339/1555 train_time:76802ms step_avg:57.36ms +step:1340/1555 train_time:76891ms step_avg:57.38ms +step:1341/1555 train_time:76975ms step_avg:57.40ms +step:1342/1555 train_time:77065ms step_avg:57.43ms +step:1343/1555 train_time:77150ms step_avg:57.45ms +step:1344/1555 train_time:77240ms step_avg:57.47ms +step:1345/1555 train_time:77323ms step_avg:57.49ms +step:1346/1555 train_time:77415ms step_avg:57.51ms +step:1347/1555 train_time:77499ms step_avg:57.53ms +step:1348/1555 train_time:77590ms step_avg:57.56ms +step:1349/1555 train_time:77675ms step_avg:57.58ms +step:1350/1555 train_time:77764ms step_avg:57.60ms +step:1351/1555 train_time:77848ms step_avg:57.62ms +step:1352/1555 train_time:77938ms step_avg:57.65ms +step:1353/1555 train_time:78021ms step_avg:57.67ms +step:1354/1555 train_time:78113ms step_avg:57.69ms +step:1355/1555 train_time:78197ms step_avg:57.71ms +step:1356/1555 train_time:78286ms step_avg:57.73ms +step:1357/1555 train_time:78372ms step_avg:57.75ms +step:1358/1555 train_time:78461ms step_avg:57.78ms +step:1359/1555 train_time:78546ms step_avg:57.80ms +step:1360/1555 train_time:78636ms step_avg:57.82ms +step:1361/1555 train_time:78719ms step_avg:57.84ms +step:1362/1555 train_time:78810ms step_avg:57.86ms +step:1363/1555 train_time:78895ms step_avg:57.88ms +step:1364/1555 
train_time:78983ms step_avg:57.91ms +step:1365/1555 train_time:79068ms step_avg:57.93ms +step:1366/1555 train_time:79157ms step_avg:57.95ms +step:1367/1555 train_time:79241ms step_avg:57.97ms +step:1368/1555 train_time:79331ms step_avg:57.99ms +step:1369/1555 train_time:79415ms step_avg:58.01ms +step:1370/1555 train_time:79505ms step_avg:58.03ms +step:1371/1555 train_time:79590ms step_avg:58.05ms +step:1372/1555 train_time:79680ms step_avg:58.08ms +step:1373/1555 train_time:79763ms step_avg:58.09ms +step:1374/1555 train_time:79854ms step_avg:58.12ms +step:1375/1555 train_time:79938ms step_avg:58.14ms +step:1376/1555 train_time:80028ms step_avg:58.16ms +step:1377/1555 train_time:80111ms step_avg:58.18ms +step:1378/1555 train_time:80201ms step_avg:58.20ms +step:1379/1555 train_time:80285ms step_avg:58.22ms +step:1380/1555 train_time:80375ms step_avg:58.24ms +step:1381/1555 train_time:80459ms step_avg:58.26ms +step:1382/1555 train_time:80549ms step_avg:58.28ms +step:1383/1555 train_time:80633ms step_avg:58.30ms +step:1384/1555 train_time:80723ms step_avg:58.33ms +step:1385/1555 train_time:80807ms step_avg:58.34ms +step:1386/1555 train_time:80897ms step_avg:58.37ms +step:1387/1555 train_time:80980ms step_avg:58.39ms +step:1388/1555 train_time:81071ms step_avg:58.41ms +step:1389/1555 train_time:81155ms step_avg:58.43ms +step:1390/1555 train_time:81245ms step_avg:58.45ms +step:1391/1555 train_time:81330ms step_avg:58.47ms +step:1392/1555 train_time:81420ms step_avg:58.49ms +step:1393/1555 train_time:81503ms step_avg:58.51ms +step:1394/1555 train_time:81594ms step_avg:58.53ms +step:1395/1555 train_time:81678ms step_avg:58.55ms +step:1396/1555 train_time:81768ms step_avg:58.57ms +step:1397/1555 train_time:81852ms step_avg:58.59ms +step:1398/1555 train_time:81942ms step_avg:58.61ms +step:1399/1555 train_time:82026ms step_avg:58.63ms +step:1400/1555 train_time:82118ms step_avg:58.66ms +step:1401/1555 train_time:82202ms step_avg:58.67ms +step:1402/1555 train_time:82292ms 
step_avg:58.70ms +step:1403/1555 train_time:82376ms step_avg:58.71ms +step:1404/1555 train_time:82465ms step_avg:58.74ms +step:1405/1555 train_time:82549ms step_avg:58.75ms +step:1406/1555 train_time:82639ms step_avg:58.78ms +step:1407/1555 train_time:82723ms step_avg:58.79ms +step:1408/1555 train_time:82814ms step_avg:58.82ms +step:1409/1555 train_time:82898ms step_avg:58.83ms +step:1410/1555 train_time:82988ms step_avg:58.86ms +step:1411/1555 train_time:83073ms step_avg:58.88ms +step:1412/1555 train_time:83163ms step_avg:58.90ms +step:1413/1555 train_time:83248ms step_avg:58.92ms +step:1414/1555 train_time:83339ms step_avg:58.94ms +step:1415/1555 train_time:83423ms step_avg:58.96ms +step:1416/1555 train_time:83513ms step_avg:58.98ms +step:1417/1555 train_time:83597ms step_avg:59.00ms +step:1418/1555 train_time:83687ms step_avg:59.02ms +step:1419/1555 train_time:83773ms step_avg:59.04ms +step:1420/1555 train_time:83861ms step_avg:59.06ms +step:1421/1555 train_time:83945ms step_avg:59.07ms +step:1422/1555 train_time:84036ms step_avg:59.10ms +step:1423/1555 train_time:84119ms step_avg:59.11ms +step:1424/1555 train_time:84210ms step_avg:59.14ms +step:1425/1555 train_time:84295ms step_avg:59.15ms +step:1426/1555 train_time:84385ms step_avg:59.18ms +step:1427/1555 train_time:84470ms step_avg:59.19ms +step:1428/1555 train_time:84559ms step_avg:59.21ms +step:1429/1555 train_time:84643ms step_avg:59.23ms +step:1430/1555 train_time:84734ms step_avg:59.25ms +step:1431/1555 train_time:84818ms step_avg:59.27ms +step:1432/1555 train_time:84908ms step_avg:59.29ms +step:1433/1555 train_time:84991ms step_avg:59.31ms +step:1434/1555 train_time:85081ms step_avg:59.33ms +step:1435/1555 train_time:85165ms step_avg:59.35ms +step:1436/1555 train_time:85256ms step_avg:59.37ms +step:1437/1555 train_time:85339ms step_avg:59.39ms +step:1438/1555 train_time:85430ms step_avg:59.41ms +step:1439/1555 train_time:85514ms step_avg:59.43ms +step:1440/1555 train_time:85604ms step_avg:59.45ms 
+step:1441/1555 train_time:85689ms step_avg:59.47ms +step:1442/1555 train_time:85780ms step_avg:59.49ms +step:1443/1555 train_time:85863ms step_avg:59.50ms +step:1444/1555 train_time:85953ms step_avg:59.52ms +step:1445/1555 train_time:86038ms step_avg:59.54ms +step:1446/1555 train_time:86128ms step_avg:59.56ms +step:1447/1555 train_time:86213ms step_avg:59.58ms +step:1448/1555 train_time:86302ms step_avg:59.60ms +step:1449/1555 train_time:86387ms step_avg:59.62ms +step:1450/1555 train_time:86477ms step_avg:59.64ms +step:1451/1555 train_time:86561ms step_avg:59.66ms +step:1452/1555 train_time:86651ms step_avg:59.68ms +step:1453/1555 train_time:86735ms step_avg:59.69ms +step:1454/1555 train_time:86825ms step_avg:59.71ms +step:1455/1555 train_time:86909ms step_avg:59.73ms +step:1456/1555 train_time:87000ms step_avg:59.75ms +step:1457/1555 train_time:87083ms step_avg:59.77ms +step:1458/1555 train_time:87174ms step_avg:59.79ms +step:1459/1555 train_time:87258ms step_avg:59.81ms +step:1460/1555 train_time:87348ms step_avg:59.83ms +step:1461/1555 train_time:87432ms step_avg:59.84ms +step:1462/1555 train_time:87521ms step_avg:59.86ms +step:1463/1555 train_time:87606ms step_avg:59.88ms +step:1464/1555 train_time:87698ms step_avg:59.90ms +step:1465/1555 train_time:87781ms step_avg:59.92ms +step:1466/1555 train_time:87872ms step_avg:59.94ms +step:1467/1555 train_time:87955ms step_avg:59.96ms +step:1468/1555 train_time:88045ms step_avg:59.98ms +step:1469/1555 train_time:88130ms step_avg:59.99ms +step:1470/1555 train_time:88219ms step_avg:60.01ms +step:1471/1555 train_time:88302ms step_avg:60.03ms +step:1472/1555 train_time:88393ms step_avg:60.05ms +step:1473/1555 train_time:88478ms step_avg:60.07ms +step:1474/1555 train_time:88567ms step_avg:60.09ms +step:1475/1555 train_time:88652ms step_avg:60.10ms +step:1476/1555 train_time:88741ms step_avg:60.12ms +step:1477/1555 train_time:88825ms step_avg:60.14ms +step:1478/1555 train_time:88916ms step_avg:60.16ms +step:1479/1555 
train_time:89000ms step_avg:60.18ms +step:1480/1555 train_time:89090ms step_avg:60.20ms +step:1481/1555 train_time:89176ms step_avg:60.21ms +step:1482/1555 train_time:89264ms step_avg:60.23ms +step:1483/1555 train_time:89349ms step_avg:60.25ms +step:1484/1555 train_time:89439ms step_avg:60.27ms +step:1485/1555 train_time:89523ms step_avg:60.28ms +step:1486/1555 train_time:89613ms step_avg:60.31ms +step:1487/1555 train_time:89698ms step_avg:60.32ms +step:1488/1555 train_time:89786ms step_avg:60.34ms +step:1489/1555 train_time:89870ms step_avg:60.36ms +step:1490/1555 train_time:89961ms step_avg:60.38ms +step:1491/1555 train_time:90044ms step_avg:60.39ms +step:1492/1555 train_time:90133ms step_avg:60.41ms +step:1493/1555 train_time:90218ms step_avg:60.43ms +step:1494/1555 train_time:90308ms step_avg:60.45ms +step:1495/1555 train_time:90393ms step_avg:60.46ms +step:1496/1555 train_time:90482ms step_avg:60.48ms +step:1497/1555 train_time:90568ms step_avg:60.50ms +step:1498/1555 train_time:90657ms step_avg:60.52ms +step:1499/1555 train_time:90740ms step_avg:60.53ms +step:1500/1555 train_time:90831ms step_avg:60.55ms +step:1500/1555 val_loss:3.2959 train_time:90947ms step_avg:60.63ms +step:1501/1555 train_time:90966ms step_avg:60.60ms +step:1502/1555 train_time:91009ms step_avg:60.59ms +step:1503/1555 train_time:91100ms step_avg:60.61ms +step:1504/1555 train_time:91192ms step_avg:60.63ms +step:1505/1555 train_time:91277ms step_avg:60.65ms +step:1506/1555 train_time:91368ms step_avg:60.67ms +step:1507/1555 train_time:91450ms step_avg:60.68ms +step:1508/1555 train_time:91539ms step_avg:60.70ms +step:1509/1555 train_time:91621ms step_avg:60.72ms +step:1510/1555 train_time:91710ms step_avg:60.74ms +step:1511/1555 train_time:91793ms step_avg:60.75ms +step:1512/1555 train_time:91884ms step_avg:60.77ms +step:1513/1555 train_time:91969ms step_avg:60.79ms +step:1514/1555 train_time:92064ms step_avg:60.81ms +step:1515/1555 train_time:92149ms step_avg:60.82ms +step:1516/1555 
train_time:92244ms step_avg:60.85ms +step:1517/1555 train_time:92330ms step_avg:60.86ms +step:1518/1555 train_time:92418ms step_avg:60.88ms +step:1519/1555 train_time:92502ms step_avg:60.90ms +step:1520/1555 train_time:92591ms step_avg:60.92ms +step:1521/1555 train_time:92675ms step_avg:60.93ms +step:1522/1555 train_time:92764ms step_avg:60.95ms +step:1523/1555 train_time:92847ms step_avg:60.96ms +step:1524/1555 train_time:92938ms step_avg:60.98ms +step:1525/1555 train_time:93026ms step_avg:61.00ms +step:1526/1555 train_time:93119ms step_avg:61.02ms +step:1527/1555 train_time:93204ms step_avg:61.04ms +step:1528/1555 train_time:93295ms step_avg:61.06ms +step:1529/1555 train_time:93379ms step_avg:61.07ms +step:1530/1555 train_time:93469ms step_avg:61.09ms +step:1531/1555 train_time:93552ms step_avg:61.11ms +step:1532/1555 train_time:93642ms step_avg:61.12ms +step:1533/1555 train_time:93726ms step_avg:61.14ms +step:1534/1555 train_time:93815ms step_avg:61.16ms +step:1535/1555 train_time:93899ms step_avg:61.17ms +step:1536/1555 train_time:93991ms step_avg:61.19ms +step:1537/1555 train_time:94077ms step_avg:61.21ms +step:1538/1555 train_time:94168ms step_avg:61.23ms +step:1539/1555 train_time:94253ms step_avg:61.24ms +step:1540/1555 train_time:94344ms step_avg:61.26ms +step:1541/1555 train_time:94428ms step_avg:61.28ms +step:1542/1555 train_time:94519ms step_avg:61.30ms +step:1543/1555 train_time:94603ms step_avg:61.31ms +step:1544/1555 train_time:94692ms step_avg:61.33ms +step:1545/1555 train_time:94777ms step_avg:61.34ms +step:1546/1555 train_time:94867ms step_avg:61.36ms +step:1547/1555 train_time:94951ms step_avg:61.38ms +step:1548/1555 train_time:95042ms step_avg:61.40ms +step:1549/1555 train_time:95128ms step_avg:61.41ms +step:1550/1555 train_time:95220ms step_avg:61.43ms +step:1551/1555 train_time:95304ms step_avg:61.45ms +step:1552/1555 train_time:95396ms step_avg:61.47ms +step:1553/1555 train_time:95480ms step_avg:61.48ms +step:1554/1555 train_time:95570ms 
step_avg:61.50ms +step:1555/1555 train_time:95654ms step_avg:61.51ms +step:1555/1555 val_loss:3.2796 train_time:95768ms step_avg:61.59ms +peak memory allocated: 31630 MiB reserved: 46498 MiB diff --git a/records/track_1_short/2026-01-31-BigramHashH2D/9e15fb87-54d3-4efa-b03c-75d0e4c2c734.txt b/records/track_1_short/2026-01-31-BigramHashH2D/9e15fb87-54d3-4efa-b03c-75d0e4c2c734.txt new file mode 100644 index 000000000..937027ae3 --- /dev/null +++ b/records/track_1_short/2026-01-31-BigramHashH2D/9e15fb87-54d3-4efa-b03c-75d0e4c2c734.txt @@ -0,0 +1,3976 @@ +import os +import sys + +# Read the current file and the kernels file code ASAP, for logging +with open(sys.argv[0], 'r') as f: + code = f.read() +with open(os.path.join(os.path.dirname(sys.argv[0]), 'triton_kernels.py'), 'r') as f: + code += f"\n\n{'-'*40}\n# triton_kernels.py\n{'-'*40}\n\n" + code += f.read() + +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate, pairwise +from pathlib import Path +import gc + +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" +import torch +import triton + +torch.empty( + 1, device=f"cuda:{os.environ['LOCAL_RANK']}", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +from kernels import get_kernel +from torch import Tensor, nn + +from triton_kernels import XXT, ba_plus_cAA, FusedLinearReLUSquareFunction, FusedSoftcappedCrossEntropy + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Distributed training setup +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor 
of 8" +grad_accum_steps = 8 // world_size +grad_scale = 2 / grad_accum_steps # consistent grad magnitudes between different num_devices +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng +# Transposed layout by @ChrisJMcCormick allows for faster gradient accumulation. + +@torch.library.custom_op("nanogpt::mm_t", mutates_args=()) +def mm_t_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + """Computes y = x @ w with F8 weights stored as (in_features, out_features).""" + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + assert x.shape[1] == w.shape[0] # x: (batch, in), w: (in, out) + + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + + # _scaled_mm requires column-major B. w_f8 is row-major (in, out). + # .T.contiguous().T creates a column-major view without changing logical shape. 
+ w_f8_col_major = w_f8.T.contiguous().T + + out = torch._scaled_mm( + x_f8, + w_f8_col_major, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_t_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[0] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_t_backward", mutates_args=()) +def mm_t_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + + x_scale = grad.new_tensor(x_s, dtype=torch.float32) + w_scale = grad.new_tensor(w_s, dtype=torch.float32) + grad_scale = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + + # grad_x = grad @ w.T + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=grad_scale, + scale_b=w_scale, + use_fast_accum=False, + ) + + # grad_w = x.T @ grad + # Result is (in, out), naturally matching weight storage. No final .T needed. 
+ grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_scale, + scale_b=grad_scale, + use_fast_accum=False, + ) + + return grad_x, grad_w + + grad_x, grad_w = impl(g, x_f8, w_f8) + + return grad_x, grad_w + +@mm_t_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.to(torch.float32) + +def backward_t(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_t_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context_t(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_t_op.register_autograd(backward_t, setup_context=setup_context_t) + +# ----------------------------------------------------------------------------- +# Polar Express + +# Computed for num_iters=5, safety_factor=2e-2, cushion=2 +polar_express_coeffs = [ + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323) +] + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def polar_express(G: torch.Tensor, split_baddbmm: bool = False): + """ + Polar Express Sign Method: https://arxiv.org/pdf/2505.16932 + by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. 
+ """ + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + # Select batched vs unbatched + if split_baddbmm: + BX_matmul = torch.bmm if X.ndim > 2 else torch.mm + else: + aX_plus_BX = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the iterations + for a, b, c in polar_express_coeffs: + XXT(X, out=A) # A = X @ X.mT + ba_plus_cAA(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + + # Referencing X twice causes pytorch to make a defensive copy, + # resulting in a cudaMemcpyAsync in baddbmm. + # For large matrices (i.e., the mlp weights), it's faster to split + # the operation into two kernels to avoid this. + if split_baddbmm: + BX_matmul(B, X, out=C) # C = B @ X + C.add_(X, alpha=a) # C = C + a*X (in-place, X only read) + else: + aX_plus_BX(X, B, X, beta=a, out=C) # C = a * X + B @ X + + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +# ----------------------------------------------------------------------------- +# Combined NorMuon + Adam Optimizer + +@dataclass +class ParamConfig: + """Per-parameter configuration for NorMuonAndAdam optimizer.""" + label: str + optim: str # "adam" or "normuon" + comms: str # "none", "replicated", or "sharded" + adam_betas: tuple[float, float] | None + lr_mul: float + wd_mul: float + lr: float + initial_lr: float + weight_decay: float + # Adam-specific + eps: float | None = None + # NorMuon-specific + reshape: tuple | None = None + chunk_size: int | None = None + momentum: float | None = None + beta2: float | None = None + per_matrix_lr_mul: list[float] | None = None + + +class NorMuonAndAdam: + """ + Combined optimizer that handles both NorMuon (for projection 
matrices) and + Adam (for embeddings/scalars/gate weights). + + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, Muon uses a Newton-Schulz iteration (replaced + here with Polar Express), which has the advantage that it can be stably run in bfloat16 on the GPU. + + Muon is applied only to the projection matrices in the attention and MLP layers, and is not recommended + for embeddings, scalars, or individual weight vectors (e.g., bias terms or gate weights). + + Differences from standard Muon: + - Newton-Shulz is replaced with Polar Express for the orthogonalization step + - NorMuon adds a low-rank variance estimator similar to Adafactor. https://arxiv.org/pdf/2510.05491 + - Cautious weight decay, a gated version of decoupled weight decay + - Mantissa tracking for precision + + Adam (for embeddings/scalars/gates): + - Standard Adam with bias correction + - Cautious weight decay + + Configuration: + Unlike torch.optim.Optimizer, this class uses per-parameter configs from a `param_table` dict + and does not include parameter "groups". All parameters require a .label attribute, and a + corresponding entry in the param_table to specify their hyperparameters (lr_mul, wd_mul, adam_betas, etc.). + + Communication and ordering: + Gradient communication is explicitly scheduled rather than hook-driven. + Reductions are launched in `scatter_order`, while update math and final + gathers are executed in `work_order`. These orders are independent and + must each contain every parameter label exactly once. + + Two communication modes are supported per parameter: + - 'replicated': Gradients are all-reduced and each rank computes the full update. 
+ - 'sharded': Gradients are reduce-scattered, each rank updates its shard, + and results are all-gathered. + + Adam parameters may be freely sharded. NorMuon operates on full matrices; sharding is + supported by grouping matrices into parameter banks. NorMuon parameters must have a + `.reshape` attribute that reshapes the bank so that the leading dimension is divisible + by world_size. + + # Contributors include @YouJiacheng, @KonstantinWilleke, @alexrgilbert, @adricarda, + # @tuttyfrutyee, @vdlad, @ryanyang0, @vagrawal, @varunneal, @chrisjmccormick + """ + def __init__(self, named_params, param_table: dict, scatter_order: list, work_order: list, + adam_defaults: dict, normuon_defaults: dict): + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + + # Store defaults for each optimizer type + self.adam_defaults = adam_defaults + self.normuon_defaults = normuon_defaults + self.param_table = param_table + self.scatter_order = scatter_order + self.work_order = work_order + + # Collect params by label and build config + self.param_cfgs: dict[nn.Parameter, ParamConfig] = {} + self.param_states: dict[nn.Parameter, dict] = {} + self._param_by_label: dict[str, nn.Parameter] = {} + for name, param in named_params: + label = getattr(param, "label", None) + assert label is not None and label in param_table # all params must have valid label + assert label not in self._param_by_label # exactly one param per label + self._param_by_label[label] = param + self._build_param_cfg(param, label) + + # Assert scatter_order and work_order match present labels exactly + present = set(self._param_by_label.keys()) + assert set(scatter_order) == present and set(work_order) == present + + # Handle world_size=1: overwrite comms to "none" + if self.world_size == 1: + for p_cfg in self.param_cfgs.values(): + p_cfg.comms = "none" + + # Initialize state for all params + self._init_state() + + # 0-D CPU tensors to avoid recompilation + self._step_size_t = torch.tensor(0.0, 
dtype=torch.float32, device="cpu") + self._eff_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._eff_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + + # Track async operations + self._reduce_futures: dict[nn.Parameter, tuple] = {} + + # Embed/lm_head tying state + self.split_embed = False + self._lm_head_param = self._param_by_label.get("lm_head") + self._embed_param = self._param_by_label.get("embed") + + def _build_param_cfg(self, param: nn.Parameter, label: str): + """Build config for a single parameter from param_table.""" + table_entry = self.param_table[label] + optim = table_entry["optim"] + comms = table_entry["comms"] + adam_betas = table_entry.get("adam_betas") + lr_mul = table_entry.get("lr_mul", 1.0) + wd_mul = table_entry.get("wd_mul", 1.0) + + if optim == "adam": + chunk_size = param.shape[0] // self.world_size if comms == "sharded" else None + p_cfg = ParamConfig( + label=label, + optim=optim, + comms=comms, + adam_betas=tuple(adam_betas) if adam_betas else None, + lr_mul=lr_mul, + wd_mul=wd_mul, + lr=self.adam_defaults["lr"], + initial_lr=self.adam_defaults["lr"], + weight_decay=self.adam_defaults["weight_decay"], + eps=self.adam_defaults["eps"], + chunk_size=chunk_size, + ) + elif optim == "normuon": + reshape = getattr(param, "reshape", None) + if reshape is None: + raise ValueError(f"NorMuon param {label} must have .reshape attribute") + if reshape[0] % self.world_size != 0: + raise ValueError(f"reshape[0]={reshape[0]} must be divisible by world_size") + + chunk_size = reshape[0] // self.world_size + chunk_shape = (chunk_size, *reshape[1:]) + # Shape-based LR multiplier for NorMuon + shape_mult = max(1.0, chunk_shape[-2] / chunk_shape[-1]) ** 0.5 if len(chunk_shape) >= 2 else 1.0 + lr_mul = shape_mult * lr_mul + + # Per-matrix LR multipliers for MLP c_proj (2x LR on odd indices) + per_matrix_lr_mul = None + if label == "mlp": + rank = dist.get_rank() if dist.is_initialized() else 0 + start_idx = rank * chunk_size + 
per_matrix_lr_mul = [] + for i in range(chunk_size): + global_idx = start_idx + i + is_c_proj = (global_idx % 2 == 1) + per_matrix_lr_mul.append(2.0 if is_c_proj else 1.0) + + p_cfg = ParamConfig( + label=label, + optim=optim, + comms=comms, + adam_betas=tuple(adam_betas) if adam_betas else None, + lr_mul=lr_mul, + wd_mul=wd_mul, + lr=self.normuon_defaults["lr"], + initial_lr=self.normuon_defaults["lr"], + weight_decay=self.normuon_defaults["weight_decay"], + reshape=reshape, + chunk_size=chunk_size, + momentum=self.normuon_defaults["momentum"], + beta2=self.normuon_defaults["beta2"], + per_matrix_lr_mul=per_matrix_lr_mul, + ) + else: + raise ValueError(f"Unknown optim type: {optim}") + + self.param_cfgs[param] = p_cfg + + def _init_state(self): + """Initialize optimizer state for all parameters.""" + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "adam": + # Sharded params use chunk state, replicated use full state + if p_cfg.comms == "sharded": + chunk = param[:p_cfg.chunk_size] + else: + chunk = param + exp_avg = torch.zeros_like(chunk, dtype=torch.float32, device=param.device) + self.param_states[param] = dict(step=0, exp_avg=exp_avg, exp_avg_sq=torch.zeros_like(exp_avg)) + + elif p_cfg.optim == "normuon": + chunk_shape = (p_cfg.chunk_size, *p_cfg.reshape[1:]) + + # Momentum buffer (FP32 for precision) + momentum_buffer = torch.zeros( + chunk_shape, dtype=torch.float32, device=param.device + ) + + # Second momentum buffer - reduced along one dimension + if chunk_shape[-2] >= chunk_shape[-1]: + second_mom_shape = (*chunk_shape[:-1], 1) + else: + second_mom_shape = (*chunk_shape[:-2], 1, chunk_shape[-1]) + second_momentum_buffer = torch.zeros( + second_mom_shape, dtype=torch.float32, device=param.device + ) + + # Mantissa buffer for precision tracking + mantissa = torch.zeros( + chunk_shape, dtype=torch.uint16, device=param.device + ) + + self.param_states[param] = dict( + momentum_buffer=momentum_buffer, + 
second_momentum_buffer=second_momentum_buffer, + mantissa=mantissa, + ) + + # ----------------------------------- + # Reduce/Gather operations + + def _launch_reduce(self, param: nn.Parameter, grad: Tensor): + """Launch async reduce for a parameter based on its comms policy.""" + p_cfg = self.param_cfgs[param] + + if p_cfg.comms == "none": + if p_cfg.optim == "normuon": + # NorMuon needs reshaped gradient even without communication + grad = grad.view(p_cfg.reshape) + self._reduce_futures[param] = (None, grad) + elif p_cfg.comms == "replicated": + future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future() + self._reduce_futures[param] = (future, grad) + elif p_cfg.comms == "sharded": + if p_cfg.optim == "normuon": + # NorMuon: reshape before reduce_scatter + grad_reshaped = grad.view(p_cfg.reshape) + grad_chunk = torch.empty( + (p_cfg.chunk_size, *grad_reshaped.shape[1:]), + dtype=grad.dtype, + device=grad.device + ) + future = dist.reduce_scatter_tensor( + grad_chunk, grad_reshaped.contiguous(), op=dist.ReduceOp.AVG, async_op=True + ).get_future() + self._reduce_futures[param] = (future, grad_chunk) + else: + # Adam: simple reduce_scatter + grad_chunk = torch.empty_like(grad[:p_cfg.chunk_size]) + future = dist.reduce_scatter_tensor( + grad_chunk, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + self._reduce_futures[param] = (future, grad_chunk) + + def _launch_gather(self, param: nn.Parameter, p_slice: Tensor) -> "torch.futures.Future": + """Launch async all_gather for a sharded parameter.""" + p_cfg = self.param_cfgs[param] + if p_cfg.optim == "normuon": + full_param = param.data.view(p_cfg.reshape) + assert full_param.is_contiguous() + return dist.all_gather_into_tensor( + full_param, p_slice.contiguous(), async_op=True + ).get_future() + else: + return dist.all_gather_into_tensor( + param, p_slice.contiguous(), async_op=True + ).get_future() + + # ----------------------------------- + # State management + + def reset(self): + 
"""Reset NorMuon momentum buffers and split_embed state (called on training reset).""" + self.split_embed = False + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "normuon": + p_state = self.param_states[param] + p_state["momentum_buffer"].zero_() + p_state["mantissa"].zero_() + p_state["second_momentum_buffer"].zero_() + + def copy_lm_state_to_embed(self): + """ + Copy the optimizer state from the lm_head to the embed at the untie point. + This requires an all-gather + reshard because of different sharding: + - lm_head (768, 50304) is sharded to (96, 50304) per rank (along model_dim) + - embed (50304, 768) is sharded to (6288, 768) per rank (along vocab_size) + + We all-gather the lm_head momentum, transpose it, then each rank takes their + embed shard to get the correct momentum state. + """ + lm_head = self._lm_head_param + embed = self._embed_param + lm_state = self.param_states[lm_head] + embed_state = self.param_states[embed] + lm_cfg = self.param_cfgs[lm_head] + embed_cfg = self.param_cfgs[embed] + + embed_state['step'] = lm_state['step'] # Preserve step count for bias correction + + # Copy optimizer state with all-gather + transpose + reshard + if self.world_size > 1: + rank = dist.get_rank() + lm_chunk_size = lm_cfg.chunk_size # 96 + embed_chunk_size = embed_cfg.chunk_size # 6288 + + # All-gather lm_head momentum to get full (768, 50304) tensor + for key in ["exp_avg", "exp_avg_sq"]: + lm_chunk = lm_state[key] # (96, 50304) + full_lm = torch.empty(lm_head.shape[0], lm_head.shape[1], dtype=lm_chunk.dtype, device=lm_chunk.device) + dist.all_gather_into_tensor(full_lm, lm_chunk.contiguous()) + embed_state[key].copy_(full_lm.T[rank * embed_chunk_size:(rank + 1) * embed_chunk_size]) + else: + # Single GPU: simple transpose + for key in ["exp_avg", "exp_avg_sq"]: + embed_state[key].copy_(lm_state[key].T) + + # Mark as split + self.split_embed = True + + def state_dict(self): + """Return the optimizer state as a dict.""" + return { + 
"param_states": {id(p): s for p, s in self.param_states.items()}, + "param_cfgs": {id(p): s for p, s in self.param_cfgs.items()}, + } + + def load_state_dict(self, state_dict): + """Load optimizer state from a dict.""" + # Build id->param mapping + id_to_param = {id(p): p for p in self.param_cfgs.keys()} + + # Load state, preserving dtypes + for param_id, saved_p_state in state_dict["param_states"].items(): + if param_id in id_to_param: + param = id_to_param[param_id] + p_state = self.param_states[param] + for k, v in saved_p_state.items(): + if isinstance(v, torch.Tensor) and k in p_state: + target_dtype = p_state[k].dtype + p_state[k] = v.to(dtype=target_dtype, device=p_state[k].device) + else: + p_state[k] = v + + # ----------------------------------- + # Unified optimizer step with explicit ordering + + @torch.no_grad() + def step(self, do_adam: bool = True): + """ + Combined optimizer step with explicit ordering. + + Args: + do_adam: If True, update Adam params. NorMuon params always updated. + + Flow: + 1. Scatter phase: Launch reduces in scatter_order + 2. Work phase: Process updates in work_order + - Wait for reduce, compute update, launch gather + 3. Finalize phase: Wait for gathers + + While the embeddings are tied: + - Comms and update math are only done on lm_head. + - We add embed.grad.T into lm_head.grad before comms. 
+ - After lm_head gather, we copy lm_head.data.T --> embed.data + """ + rank = dist.get_rank() if dist.is_initialized() else 0 + lm_param, embed_param = self._lm_head_param, self._embed_param + + # ===== Phase 1: Launch reduces in scatter_order ===== + for label in self.scatter_order: + param = self._param_by_label[label] + p_cfg = self.param_cfgs[param] + + if p_cfg.optim == "adam" and not do_adam: + continue + if param.grad is None: + continue + + # lm_head when tied: aggregate embed.grad.T (transposed shapes) + if label == "lm_head" and do_adam and not self.split_embed: + if embed_param is not None and embed_param.grad is not None: + param.grad.add_(embed_param.grad.T) + + # Skip embed when tied (copied from lm_head after gather) + if label == "embed" and not self.split_embed: + continue + + self._launch_reduce(param, param.grad) + + # ===== Phase 2: Process updates in work_order ===== + gather_futures = [] + lm_head_gather_future = None + + for label in self.work_order: + param = self._param_by_label[label] + if param not in self._reduce_futures: + continue + + p_cfg = self.param_cfgs[param] + if p_cfg.optim == "adam" and not do_adam: + continue + # Wait for reduce + future, grad_chunk = self._reduce_futures[param] + if future is not None: + future.wait() + # Apply update based on optim type + if p_cfg.optim == "adam": + p_slice = self._adam_update(param, grad_chunk, p_cfg, rank) + else: + p_slice = self._normuon_update(param, grad_chunk, p_cfg, rank) + # Launch gather for sharded params + if p_cfg.comms == "sharded" and self.world_size > 1: + gather_fut = self._launch_gather(param, p_slice) + if label == "lm_head": + lm_head_gather_future = gather_fut + else: + gather_futures.append(gather_fut) + + # ===== Phase 3: Wait for gathers, sync embed if tied ===== + # Wait for lm_head gather first so we can copy to embed while other gathers complete + if lm_head_gather_future is not None: + lm_head_gather_future.wait() + + # When tied: copy lm_head.T to embed + if 
do_adam and not self.split_embed and embed_param is not None and lm_param is not None: + embed_param.data.copy_(lm_param.data.T) + + # Wait for remaining gathers + for fut in gather_futures: + fut.wait() + + self._reduce_futures.clear() + + # Clear grads for updated params + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "adam" and not do_adam: + continue # Don't clear Adam grads on even steps + param.grad = None + + # ----------------------------------- + # Adam update + + def _adam_update(self, param: nn.Parameter, grad_chunk: Tensor, p_cfg: ParamConfig, rank: int) -> Tensor: + """Apply Adam update to a parameter. Returns the updated p_slice.""" + beta1, beta2 = p_cfg.adam_betas + lr = p_cfg.lr * p_cfg.lr_mul + + # Get parameter slice + if p_cfg.comms == "sharded": + p_slice = param[rank * p_cfg.chunk_size:(rank + 1) * p_cfg.chunk_size] + else: + p_slice = param + + p_state = self.param_states[param] + p_state["step"] += 1 + t = p_state["step"] + + bias1, bias2 = 1 - beta1 ** t, 1 - beta2 ** t + self._step_size_t.fill_(lr * (bias2 ** 0.5 / bias1)) + self._eff_wd_t.fill_(lr * lr * p_cfg.weight_decay * p_cfg.wd_mul) + + NorMuonAndAdam._adam_update_step( + p_slice, grad_chunk, p_state["exp_avg"], p_state["exp_avg_sq"], + beta1, beta2, p_cfg.eps, self._step_size_t, self._eff_wd_t + ) + + return p_slice + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _adam_update_step(p_slice, g_slice, exp_avg, exp_avg_sq, beta1, beta2, eps, step_size_t, eff_wd_t): + """Compiled Adam update step.""" + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + update = exp_avg.div(exp_avg_sq.sqrt().add_(eps)).mul_(step_size_t) + # Cautious weight decay + mask = (update * p_slice) > 0 + update.addcmul_(p_slice, mask, value=eff_wd_t) + p_slice.add_(other=update, alpha=-1.0) + + # ----------------------------------- + # NorMuon update + + def _normuon_update(self, param: nn.Parameter, 
grad_chunk: Tensor, p_cfg: ParamConfig, rank: int) -> Tensor: + """Apply NorMuon update to a parameter. Returns the updated p_slice.""" + chunk_shape = grad_chunk.shape + + p_state = self.param_states[param] + grad_chunk = grad_chunk.float() # FP32 for momentum + + # Momentum update + momentum_buffer = p_state["momentum_buffer"] + momentum_buffer.lerp_(grad_chunk, 1 - p_cfg.momentum) + updated_grads = grad_chunk.lerp_(momentum_buffer, p_cfg.momentum) + + self._eff_lr_t.fill_(p_cfg.lr_mul * p_cfg.lr) + self._eff_wd_t.fill_(p_cfg.wd_mul * p_cfg.weight_decay * p_cfg.lr) + + # Polar Express orthogonalization + is_large_matrix = chunk_shape[-2] > 1024 + v_chunk = polar_express(updated_grads, split_baddbmm=is_large_matrix) + + # Variance reduction + red_dim = -1 if chunk_shape[-2] >= chunk_shape[-1] else -2 + v_chunk = NorMuonAndAdam._apply_normuon_variance_reduction( + v_chunk, p_state["second_momentum_buffer"], p_cfg.beta2, red_dim + ) + + # Update parameter, in place, with cautious weight decay + param_view = param.data.view(p_cfg.reshape) + p_slice = param_view[rank * p_cfg.chunk_size:(rank + 1) * p_cfg.chunk_size] + + # MLP has per-matrix LR multipliers (c_proj gets 2x LR) + if p_cfg.per_matrix_lr_mul is not None: + for mat_idx in range(p_cfg.chunk_size): + self._eff_lr_t.fill_(p_cfg.lr_mul * p_cfg.per_matrix_lr_mul[mat_idx] * p_cfg.lr) + self._eff_wd_t.fill_(p_cfg.wd_mul * p_cfg.weight_decay * p_cfg.lr) + NorMuonAndAdam._cautious_wd_and_update_inplace( + p_slice[mat_idx].view(torch.uint16), p_state["mantissa"][mat_idx], v_chunk[mat_idx], + self._eff_wd_t, self._eff_lr_t + ) + else: + NorMuonAndAdam._cautious_wd_and_update_inplace( + p_slice.view(torch.uint16), p_state["mantissa"], v_chunk, + self._eff_wd_t, self._eff_lr_t + ) + + return p_slice + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _cautious_wd_and_update_inplace(p, mantissa, grad, wd_tensor, lr_tensor): + """ + Cautious weight decay + parameter update. 
wd_tensor and lr_tensor are 0-D CPU tensors. + Mantissa is tracked to enable higher precision updates on bfloat16 parameters. + bfloat16 format: 1 sign bit + 8 exponent bits + 7 mantissa bits = 16 bits total + float32 format: 1 sign bit + 8 exponent bits + 23 mantissa bits = 32 bits total + """ + assert p.dtype == mantissa.dtype == torch.uint16 + grad = grad.float() + wd_factor = wd_tensor.to(torch.float32) + lr_factor = lr_tensor.to(torch.float32) + p_precise_raw = (p.to(torch.uint32) << 16) | mantissa.to(torch.uint32) + p_precise = p_precise_raw.view(torch.float32) + mask = (grad * p_precise) >= 0 + p_precise.copy_(p_precise - (p_precise * mask * wd_factor * lr_factor) - (grad * lr_factor)) + p.copy_((p_precise_raw >> 16).to(torch.uint16)) + mantissa.copy_(p_precise_raw.to(torch.uint16)) + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _apply_normuon_variance_reduction(v_chunk, second_momentum_buffer, beta2, red_dim): + """NorMuon variance reduction. Algebraically fuses the normalization steps to minimize memory ops.""" + v_mean = v_chunk.float().square().mean(dim=red_dim, keepdim=True) + red_dim_size = v_chunk.size(red_dim) + v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True).mul_(red_dim_size) + v_norm = v_norm_sq.sqrt_() + second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2) + step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt_() + scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square() + v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt_() + final_scale = step_size * (v_norm / v_norm_new.clamp_min_(1e-10)) + return v_chunk.mul_(final_scale.type_as(v_chunk)) + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinearT(nn.Module): + """ + Linear layer with transposed weight storage (in_features, out_features) which + 
addresses the slow kernel that was used for gradient accumulation. @chrisjmccormick + """ + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + self.weight = nn.Parameter(torch.empty(in_features, out_features, dtype=torch.bfloat16)) + self.reset_parameters() + + def reset_parameters(self) -> None: + with torch.no_grad(): + nn.init.zeros_(self.weight) # @Grad62304977 and others + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out = torch.ops.nanogpt.mm_t(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return x @ self.weight.type_as(x) + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len, paired=False): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.paired = paired + self.reset() + + def rotary(self, x_BTHD): + assert self.factor1.size(0) >= x_BTHD.size(-3) + factor1, factor2 = ( + self.factor1[None, : x_BTHD.size(-3), None, :], + self.factor2[None, : x_BTHD.size(-3), None, :], + ) + x_flip = x_BTHD.view(*x_BTHD.shape[:-1], x_BTHD.shape[-1] // 2, 2).flip(-1).view(x_BTHD.shape) + return factor1 * x_BTHD + factor2 * x_flip + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + angular_freq = angular_freq.repeat_interleave(2) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//2)]) + t = torch.arange(2*self.max_seq_len, dtype=torch.float32, device=device) + if not self.paired: + theta 
= torch.outer(t, angular_freq) + self.factor1 = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.factor2 = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + else: + t_even = 2 * t + t_odd = 2 * t + 1 + theta1 = torch.outer(t_even, angular_freq) + theta2 = torch.outer(t_odd, angular_freq) + self.factor1 = nn.Buffer( + torch.cat((theta1.cos(), theta2.cos()), dim=-1).to(torch.bfloat16), + persistent=False + ) + self.factor2 = nn.Buffer( + torch.cat((theta1.sin(), theta2.sin()), dim=-1).to(torch.bfloat16), + persistent=False + ) + self.factor2[..., 1::2] *= -1 + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(2*self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + if not self.paired: + theta = torch.outer(t, self.angular_freq) + self.factor1.copy_(theta.cos()) + self.factor2.copy_(theta.sin()) + else: + t_even = 2 * t + t_odd = 2 * t + 1 + theta1 = torch.outer(t_even, self.angular_freq) + theta2 = torch.outer(t_odd, self.angular_freq) + self.factor1.copy_(torch.cat((theta1.cos(), theta2.cos()), dim=-1)) + self.factor2.copy_(torch.cat((theta1.sin(), theta2.sin()), dim=-1)) + self.factor2[..., 1::2] *= -1 + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + yarn: Yarn + key_offset: bool + attn_gate_w: torch.Tensor + ve_gate_w: torch.Tensor + +flash_attn_interface = 
get_kernel('varunneal/flash-attention-3').flash_attn_interface + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, paired: bool = False): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + self.paired = paired + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + # Weights are stored in parameter banks and passed via forward() + + def forward(self, x: Tensor, attn_args: AttnArgs, qkvo_w: Tensor): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + yarn = attn_args.yarn + ve, sa_lambdas, key_offset = attn_args.ve, attn_args.sa_lambdas, attn_args.key_offset + seqlens, bm_size = attn_args.seqlens, attn_args.bm_size + # sparse gated attention to enable context based no-op by @classiclarryd + # only include gates on layers with value embeds used on forward pass + attn_gate_w, ve_gate_w = attn_args.attn_gate_w, attn_args.ve_gate_w + + q, k, v = F.linear(x, sa_lambdas[0] * qkvo_w[:self.dim * 3].type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + q, k = norm(q), norm(k) # QK norm @Grad62304977 + + if not self.paired: + q, k = yarn.rotary(q), yarn.rotary(k) + + if key_offset: + # shift keys forward for the stationary head dims. Enables 1-layer induction. + k[:, 1:, :, self.head_dim // 2:] = k[:, :-1, :, self.head_dim // 2:] + + if ve is not None: + ve_gate_out = 2 * torch.sigmoid(F.linear(x[..., :12], ve_gate_w)).view(B, T, self.num_heads, 1) + v = v + ve_gate_out * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + + else: + # Paired heads: adjacent heads' queries attend to each other's keys. 
+ # Two copies of the input stream are interleaved to achieve this, which: + # - doubles the length of each sequence + # - halves the effective window size + q = q.view(B, T, self.num_heads // 2, self.head_dim * 2) + k = k.view(B, T, self.num_heads // 2, self.head_dim * 2) + v = v.reshape(B, T * 2, self.num_heads // 2, self.head_dim) + + q, k = yarn.rotary(q), yarn.rotary(k) + + q = q.view(B, T * 2, self.num_heads // 2, self.head_dim) + k = k.view(B, T * 2, self.num_heads // 2, self.head_dim) + + if ve is not None: + ve_gate_out = 2 * torch.sigmoid(F.linear(x[..., :12], ve_gate_w)).view(B, T * 2, self.num_heads // 2, 1) + v = v + ve_gate_out * ve.view_as(v) + + seqlens = 2 * seqlens + max_len = 2 * max_len + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_interface.flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, + max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=yarn.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(F.linear(x[..., :12], attn_gate_w)).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, sa_lambdas[1] * qkvo_w[self.dim * 3:].type_as(y)) # sa_lambdas[1] pre-multiplied to O @shenberg + return y + +class MLP(nn.Module): + def __init__(self): + super().__init__() + # Weights are stored in parameter banks and passed via forward() + + def forward(self, x: Tensor, c_fc: Tensor, c_proj: Tensor): + # relu(x)^2: + # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + # Fused triton kernel for relu(x @ W1.T)^2 @ W2.T + return FusedLinearReLUSquareFunction.apply(x, c_fc, c_proj) + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, has_attn: bool, has_mlp: bool, use_paired_head: bool): + 
super().__init__() + # skip attention of blocks.6 (the 7th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads, paired=use_paired_head) if has_attn else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP() if has_mlp else None + + def forward(self, x: Tensor, attn_args: AttnArgs, qkvo_w: Tensor = None, c_fc: Tensor = None, c_proj: Tensor = None): + if self.attn is not None: + x = x + self.attn(norm(x), attn_args, qkvo_w) + if self.mlp is not None: + x = x + self.mlp(norm(x), c_fc, c_proj) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +@dataclass +class ForwardScheduleConfig: + mtp_weights: torch.Tensor + ws_short: int + ws_long: int + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + self.num_layers = num_layers + self.vocab_size = next_multiple_of_n(vocab_size, n=128) + + self.smear_gate = nn.Linear(12, 1, bias=False) + nn.init.zeros_(self.smear_gate.weight) + self.smear_gate.weight.label = 'smear_gate' + + self.skip_gate = nn.Linear(12, 1, bias=False) + nn.init.zeros_(self.skip_gate.weight) + self.skip_gate.weight.label = 'skip_gate' + + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.Parameter(torch.zeros(5 * self.vocab_size, model_dim, dtype=torch.bfloat16)) + self.value_embeds.label = 'value_embed' + + # parameter banks for attention and value embedding gate weights + self.attn_gate_bank = nn.Parameter(torch.zeros(10, num_heads, 12)) # 10 layers + 
self.attn_gate_bank.label = 'attn_gate_bank' + self.ve_gate_bank = nn.Parameter(torch.zeros(5, num_heads, 12)) # 5 unique gates + self.ve_gate_bank.label = 've_gate_bank' + + # ----------------------------------- + # Parameter banks for sharded optimization, by @chrisjmccormick + + # Identify which layers have attention/MLP + # Attention is skipped in layer 6 by @YouJiacheng + self.attn_layer_indices = [i for i in range(num_layers) if i != 6] + # All layers have MLP (At 11 layers--dropped first layer @EmelyanenkoK) + self.mlp_layer_indices = list(range(num_layers)) + + hdim = num_heads * head_dim + mlp_hdim = 4 * model_dim + + # Create index mappings: layer_idx -> bank_idx + self.layer_to_attn_idx = {layer_idx: bank_idx for bank_idx, layer_idx in enumerate(self.attn_layer_indices)} + self.layer_to_mlp_idx = {layer_idx: bank_idx for bank_idx, layer_idx in enumerate(self.mlp_layer_indices)} + + # Attention bank: stores QKVO weights for all attention layers + # merged QKVO weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # Simplified layout by @chrisjmccormick + # Shape: (num_attn_layers, 4*model_dim, hdim) = (10, 3072, 768) + # Reshape for sharding: (40, 768, 768) for even distribution across 8 GPUs + self.attn_bank = nn.Parameter(torch.empty(len(self.attn_layer_indices), 4 * model_dim, hdim)) + self.attn_bank.label = 'attn' + self.attn_bank.reshape = (len(self.attn_layer_indices) * 4, hdim, hdim) # (40, 768, 768) + + # MLP bank: stores c_fc and c_proj for all MLP layers + # Shape: (num_mlp_layers + padding, 2, mlp_hdim, model_dim) = (12, 2, 3072, 768) + # We add 1 padding layer (index 11) to get 12*2=24 matrices for even distribution across 8 GPUs + # Reshape for sharding: (24, 3072, 768) + num_mlp_with_padding = len(self.mlp_layer_indices) + 1 # 11 + 1 = 12 + self.mlp_bank = nn.Parameter(torch.empty(num_mlp_with_padding, 2, mlp_hdim, model_dim)) + 
self.mlp_bank.label = 'mlp' + self.mlp_bank.reshape = (num_mlp_with_padding * 2, mlp_hdim, model_dim) # (24, 3072, 768) + + # improved init scale by @YouJiacheng and @srashedll + std = 0.5 * model_dim ** -0.5 + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.attn_bank.uniform_(-bound, bound) + self.mlp_bank[:, 0, :, :].uniform_(-bound, bound) # c_fc + self.mlp_bank[:, 1, :, :].zero_() # c_proj - zero init suggested by @Grad62304977 + + # Create blocks with has_attn/has_mlp flags + self.paired_head_layers = [0, 2, 5, 9] + self.blocks = nn.ModuleList([ + Block(model_dim, head_dim, num_heads, + has_attn=(i in self.layer_to_attn_idx), + has_mlp=(i in self.layer_to_mlp_idx), + use_paired_head=(i in self.paired_head_layers)) + for i in range(num_layers) + ]) + self.yarn = Yarn(head_dim, max_seq_len) + self.yarn_paired_head = Yarn(head_dim, max_seq_len, paired=True) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + # Transposed weight storage for faster gradient accumulation + self.lm_head = CastedLinearT(model_dim, self.vocab_size, use_fp8=use_fp8, x_s=100/448, w_s=1.6/448, grad_s=grad_scale * 0.75/448) + + nn.init.normal_(self.lm_head.weight, mean=0, std=0.005) + self.lm_head.weight.label = 'lm_head' + + self.embed = nn.Embedding(self.vocab_size, model_dim) + self.embed.weight.label = 'embed' + with torch.no_grad(): + self.embed.weight.copy_(self.lm_head.weight.T) + + self.bigram_embed = nn.Embedding(args.bigram_vocab_size, model_dim) + self.bigram_embed.weight.label = 'bigram_embed' + nn.init.zeros_(self.bigram_embed.weight) + + # x0_lambdas separated out for different optimizer treatment (no beta smoothing) + self.x0_lambdas = nn.Parameter(torch.zeros(num_layers)) + self.x0_lambdas.label = 'x0_lambdas' + + pad = (-num_layers * 3 - 3) % dist.get_world_size() # updated: 3*num_layers instead of 4* + self.scalars = nn.Parameter( + torch.cat( + [ + 1.1 * torch.ones(num_layers), # resid lambdas. 1.1 init such that layer i weight is i^(num_layers-i). + *[torch.tensor([0.5, 1.0]) for _ in range(num_layers)], # SA lambdas + 0.1 * torch.ones(num_layers), # bigram lambdas + torch.zeros(1), # smear_lambda + 0.5*torch.ones(1), # backout_lambda + -1.5 * torch.ones(1), # skip_lambda -> σ(-1.5) ≈ 0.18 + torch.ones(pad), + ] + ) + ) + self.scalars.label = 'scalars' + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _compute_bigram_hash(x: Tensor, mod: int) -> Tensor: + """ + Computes bigram hash on GPU for each position using [prev_token, curr_token]. + Mathematically identical to the CPU version but computed on device. 
+ """ + rand_int_1 = 36313 + rand_int_2 = 27191 + result = torch.empty_like(x) + result[0] = mod + result[1:] = torch.bitwise_xor(rand_int_1 * x[1:], rand_int_2 * x[:-1]) % mod + return result + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, schedule_cfg: ForwardScheduleConfig): + assert input_seq.ndim == 1 + + # unpack schedule_cfg + mtp_weights, ws_short, ws_long = schedule_cfg.mtp_weights, schedule_cfg.ws_short, schedule_cfg.ws_long + + # set configs + skip_connections = [] + skip_in = [3] # long attention window on layer 3 + skip_out = [6] # no attn op on layer 6 + x_backout = None + backout_layer = 7 + + # set lambdas + resid_lambdas = self.scalars[: 1 * self.num_layers] + x0_lambdas = self.x0_lambdas + sa_lambdas = self.scalars[1 * self.num_layers: 3 * self.num_layers].view(-1, 2) + bigram_lambdas = self.scalars[3 * self.num_layers: 4 * self.num_layers] + smear_lambda = self.scalars[4 * self.num_layers] + backout_lambda = self.scalars[4 * self.num_layers+1] + skip_lambda = self.scalars[4 * self.num_layers+2] + + # set block masks and key shift + bm_sizes = [ws_short, ws_short, ws_short, ws_long, ws_short, ws_short, None, ws_short, ws_short, ws_short, ws_long] + assert len(bm_sizes) == self.num_layers + key_offset = [b==ws_long for b in bm_sizes] # apply partial key offset to long windows + + # Embedding lookup - embed is synced from lm_head during tied phase by optimizer + x = self.embed(input_seq) + # Compute bigram hash on GPU (moved from CPU data loader) + bigram_seq = self._compute_bigram_hash(input_seq, args.bigram_vocab_size - 1) + x0_bigram = self.bigram_embed(bigram_seq)[None] + + # Value embeddings - always computed (not precomputed) + ve = self.value_embeds.view(5, self.vocab_size, -1)[:, input_seq] + # 01 ... 
234 structure on token value embeddings by @photomz + ve = [ve[0], ve[1]] + [None] * (self.num_layers - 5) + [ve[2], ve[3], ve[4]] + assert len(ve) == self.num_layers + + # smear token embed forward 1 position @classiclarryd + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # unbind gate banks to avoid select_backwards kernel + ag = [w.bfloat16() for w in self.attn_gate_bank.unbind(0)] + veg = [w.bfloat16() for w in self.ve_gate_bank.unbind(0)] + attn_gates = ag[:6] + [None] + ag[6:] + ve_gates = [veg[0], veg[1]] + [None] * (self.num_layers - 5) + [veg[2], veg[3], veg[4]] + assert len(attn_gates) == self.num_layers + assert len(ve_gates) == self.num_layers + + # unbind weight banks to avoid select_backwards kernel + attn_weights = self.attn_bank.unbind(0) # tuple of [4*dim, hdim] tensors + mlp_fcs = self.mlp_bank[:, 0, :, :].unbind(0) # tuple of [mlp_hdim, dim] tensors + mlp_projs = self.mlp_bank[:, 1, :, :].unbind(0) # tuple of [mlp_hdim, dim] tensors + + for i in range(self.num_layers): + yarn = self.yarn_paired_head if i in self.paired_head_layers else self.yarn + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + yarn=yarn, + key_offset=key_offset[i], + attn_gate_w=attn_gates[i], + ve_gate_w=ve_gates[i] + ) + if i in skip_out: + skip_gate_out = torch.sigmoid(skip_lambda) * 2 * torch.sigmoid(self.skip_gate(x0[..., :self.skip_gate.weight.size(-1)])) + x = x + skip_gate_out * skip_connections.pop() + if i == 0: + x = (resid_lambdas[0] + x0_lambdas[0]) * x + bigram_lambdas[0] * x0_bigram + else: + x = resid_lambdas[i] * x + x0_lambdas[i] * x0 + bigram_lambdas[i] * x0_bigram + + # Get weights for this layer from banks + qkvo_w = attn_weights[self.layer_to_attn_idx[i]] if i in self.layer_to_attn_idx else None + c_fc = mlp_fcs[self.layer_to_mlp_idx[i]] if i in 
self.layer_to_mlp_idx else None + c_proj = mlp_projs[self.layer_to_mlp_idx[i]] if i in self.layer_to_mlp_idx else None + + x = self.blocks[i](x, attn_args, qkvo_w, c_fc, c_proj) + if i in skip_in: + skip_connections.append(x) + if i == backout_layer: + x_backout = x + + # back out contributions from first 7 layers that are only required for downstream context and not direct prediction + x -= backout_lambda * x_backout + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15 + # @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1). @classiclarryd updated to 23*sigmoid((logits+5)/7.5) + if self.training: + losses = FusedSoftcappedCrossEntropy.apply(logits.view(-1, logits.size(-1)), target_seq, mtp_weights, 23.0, 5.0, 7.5) + loss = losses.sum() + else: + logits = 23 * torch.sigmoid((logits + 5) / 7.5) + logits_for_loss = logits.float() + loss = F.cross_entropy(logits_for_loss.view(-1, logits_for_loss.size(-1)), target_seq, reduction="mean") + return loss +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class Shard: + def __init__(self, tokens: Tensor, world_size: int = 1): + self.tokens = tokens + self.size = tokens.numel() + self.world_size = 
world_size + self.i = 0 + + # Partial index now, full index async + self.bos_idx = (tokens[:6_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self._full_idx = None + self._loader_thread = None + self._ready = threading.Event() + self._loader_thread = threading.Thread(target=self._scan) + self._loader_thread.start() + + def _scan(self): + self._full_idx = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self._ready.set() + + def _maybe_switch(self): + # Switch to full index as soon as async scan completes + if self.bos_idx is not self._full_idx and self._ready.is_set(): + self._loader_thread.join() + self.bos_idx = self._full_idx + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + self._maybe_switch() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + return starts, ends + + @staticmethod + def load_async(file: Path, world_size: int = 1): + """Returns getter function for async shard loading""" + result = {} + ready = threading.Event() + def load(): + tokens = _load_data_shard(file) + result['shard'] = Shard(tokens, world_size) + ready.set() + thread = threading.Thread(target=load) + thread.start() + def get(): + ready.wait() + thread.join() + return result['shard'] + return get + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each 
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + shard = Shard(tokens, world_size) + next_shard_getter = Shard.load_async(next(file_iter), world_size) + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = shard.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
+ shard = next_shard_getter() + tokens = shard.tokens + try: + next_shard_getter = Shard.load_async(next(file_iter), world_size) + except StopIteration: + next_shard_getter = None # no more shards to preload + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + # Cast to int32 on CPU before transfer to avoid dtype conversion during .to() + _inputs = _inputs.to(dtype=torch.int32) + _targets = _targets.to(dtype=torch.int64) + _cum_lengths = _cum_lengths.to(dtype=torch.int32) + # Bigram hash computation moved to GPU in forward() + + new_params = yield ( + _inputs.to(device="cuda", non_blocking=True), + _targets.to(device="cuda", non_blocking=True), + _cum_lengths.to(device="cuda", non_blocking=True), + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens // new_grad_accum_steps + max_seq_len = new_max_seq_len + +# ----------------------------------------------------------------------------- +# Training Management + +@dataclass 
+class Hyperparameters: + # data + data_path = os.environ.get("DATA_PATH", ".") + train_files: str = os.path.join(data_path, "data/fineweb10B/fineweb_train_*.bin") # input .bin to train on + val_files: str = os.path.join(data_path, "data/fineweb10B/fineweb_val_*.bin") # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + # batch sizes + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # schedule + num_scheduled_iterations: int = 1515 # number of steps to complete lr and ws schedule + num_extension_iterations: int = 40 # number of steps to continue training at final lr and ws + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 250 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # bigram hash embedding + bigram_vocab_size: int = 50304 * 5 + +args = Hyperparameters() + +@dataclass +class TrainingStage: + lr_mul: float + batch_size: int + window_sizes: tuple[int, int] # (short, long) in block units + mtp_weights_start: list[float] + mtp_weights_end: list[float] + duration: float = None + +class TrainingSchedule: + """ + Training schedule initialized via TRAINING_STAGES + 1. Multi Token Prediction schedule of [1, 0.5, 0.25->0] -> [1, 0.5->0] -> [1] @varunneal + 2. Sliding Attention window schedule of [1,3] -> [3,7] -> [5,11] -> [6,13] + 3. YaRN updates to RoPE on window changes + 4. Split embed and lm head at 2/3 of training + 5. Batch size schedule of 8 -> 16 -> 24 + 6. 
Post training extension of long windows from 13 to 20 + """ + + def __init__(self, stages: list[TrainingStage], scheduled_iterations: int, extension_iterations: int, + cooldown_frac: float = 0.5, split_embed_stage: int = 2, ws_post_yarn_ext: int = 20): + self.stages = stages + self.scheduled_iterations = scheduled_iterations + self.cooldown_frac = cooldown_frac + # increase final validation ws, used for YaRN extension and short window size @classiclarryd + self.ws_post_yarn_ext = ws_post_yarn_ext + + self.total_steps = self.scheduled_iterations + extension_iterations + + # Build stage boundaries (last is extension stage) + ends = [0] + [round(c * scheduled_iterations) for c in accumulate(s.duration for s in stages[:-1])] + [self.total_steps] + assert self.scheduled_iterations == ends[-2] + self.boundaries = list(pairwise(ends)) + + # Split embed at specified stage (ensure odd step for Adam) + self.split_step = self.boundaries[split_embed_stage][0] | 1 + + # Precompute MTP weights for all steps + self.mtp_weights = [] + for step in range(self.total_steps + 1): + stage, t = self.lookup(step) + w = [a + (b - a) * t for a, b in zip(stage.mtp_weights_start, stage.mtp_weights_end)] + self.mtp_weights.append(torch.tensor(w, device=device)) + + def lookup(self, step: int) -> tuple[TrainingStage, float]: + # Returns stage and % of the way through that stage + for i, (start, end) in enumerate(self.boundaries): + if step < end: + t = (step - start) / (end - start) + return self.stages[i], t + return self.stages[-1], 1.0 + + def get_lr(self, step: int) -> float: + # learning rate schedule: tied to batch size schedule, with cooldown at the end + stage, _ = self.lookup(step) + lr = stage.lr_mul + cd_start = int(self.scheduled_iterations * (1 - self.cooldown_frac)) + if step >= cd_start: + t = min(1.0, (step - cd_start) / (self.scheduled_iterations - cd_start)) + lr = lr * (1 - t) + 0.1 * t + return lr + +# window_sizes are in units of `block_size` tokens (defined in 
TrainingManager) +TRAINING_STAGES = [ + TrainingStage(duration=1/3, batch_size=8 * 2048 * 8, window_sizes=(1, 3), lr_mul=1.0, + mtp_weights_start=[1.0, 0.5, 0.25], mtp_weights_end=[1.0, 0.5, 0.0]), + TrainingStage(duration=1/3, batch_size=16 * 2048 * 8, window_sizes=(3, 7), lr_mul=1.52, # (16/8)**0.6 + mtp_weights_start=[1.0, 0.5], mtp_weights_end=[1.0, 0.0]), + TrainingStage(duration=1/3, batch_size=24 * 2048 * 8, window_sizes=(5, 11), lr_mul=1.73, # (24/8)**0.5 + mtp_weights_start=[1.0], mtp_weights_end=[1.0]), + # extension stage + TrainingStage(batch_size=24 * 2048 * 8, window_sizes=(6, 13), lr_mul=1.0, # lr_mul is not used + mtp_weights_start=[1.0], mtp_weights_end=[1.0]), +] + +training_schedule = TrainingSchedule(TRAINING_STAGES, args.num_scheduled_iterations, args.num_extension_iterations, cooldown_frac=0.55) + +def get_muon_momentum(step: int, muon_warmup_steps=300, muon_cooldown_steps=50, momentum_min=0.85, momentum_max=0.95): + # warmup phase: linearly increase momentum from min to max + # cooldown phase: linearly decrease momentum from max to min + momentum_cd_start = training_schedule.total_steps - muon_cooldown_steps + if step < muon_warmup_steps: + frac = step / muon_warmup_steps + momentum = momentum_min + frac * (momentum_max - momentum_min) + elif step > momentum_cd_start: + frac = (step - momentum_cd_start) / muon_cooldown_steps + momentum = momentum_max - frac * (momentum_max - momentum_min) + else: + momentum = momentum_max + return momentum + +class TrainingManager(): + """ + Manages the NorMuonAndAdam for all parameters with explicit ordering. + 1. Scalars are given higher momentum terms to smooth learning @ChrisJMcCormick + 2. Adam optimizers are only stepped on odd steps @classiclarryd + 3. Explicit scatter_order and work_order for communication scheduling (no backward hooks) + 4. Muon has a linear momentum warmup and cooldown schedule + 5. Learning rates follow a linear decay schedule + 6. 
Embed is tied to lm_head until split step (2/3 of training), then untied @classiclarryd + """ + def __init__(self, model): + self.model = model + self.block_size = 128 + + # - Ordering dictates when to launch reduce/reduce_scatter operations + # - "sharded" parameters use reduce_scatter/all_gather and "replicated" ones use all_reduce + # - lr_mul and wd_mul are per-parameter learning rate and weight decay multipliers + self.param_table = { + "attn": {"optim": "normuon", "comms": "sharded", "adam_betas": None}, + "mlp": {"optim": "normuon", "comms": "sharded", "adam_betas": None}, + "scalars": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 5.0, "wd_mul": 0.0}, + "value_embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.75, 0.95], "lr_mul": 75., "wd_mul": 5.0}, + "bigram_embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.75, 0.95], "lr_mul": 75., "wd_mul": 5.0}, + "smear_gate": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 0.01, "wd_mul": 0.0}, + "skip_gate": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 0.05, "wd_mul": 0.0}, + "attn_gate_bank": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99]}, + "ve_gate_bank": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99]}, + "x0_lambdas": {"optim": "adam", "comms": "replicated", "adam_betas": [0.65, 0.95], "lr_mul": 5.0, "wd_mul": 0.0}, + "lm_head": {"optim": "adam", "comms": "sharded", "adam_betas": [0.5, 0.95], "wd_mul": 150.}, + "embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.5, 0.95], "wd_mul": 150.}, + } + + # - Process smaller/faster params first while large reduces complete + # - lm_head must complete before embed sync (when tied) + self.work_order = [ + "scalars", "smear_gate", "skip_gate", "attn_gate_bank", "ve_gate_bank", "x0_lambdas", # Small, fast + "value_embed", "bigram_embed", # Medium + "lm_head", "embed", # lm_head must complete before embed 
sync (when tied) + "attn", "mlp", # Large, polar express - process last to maximize overlap + ] + + adam_defaults = dict( + lr=0.008, + eps=1e-10, + weight_decay=0.005, + ) + + normuon_defaults = dict( + lr=0.023, + momentum=0.95, + beta2=0.95, + weight_decay=1.2, + ) + + self.optimizer = NorMuonAndAdam( + model.named_parameters(), + param_table=self.param_table, + scatter_order=list(self.param_table.keys()), # Dict order defines scatter priority + work_order=self.work_order, + adam_defaults=adam_defaults, + normuon_defaults=normuon_defaults, + ) + + # Split embed from lm_head at 2/3 of training (on an odd step so Adam updates) + self.split_step = training_schedule.split_step + + self.reset() + + def apply_final_ws_ext(self): + self.ws_long = training_schedule.ws_post_yarn_ext + + def get_forward_args(self): + return ForwardScheduleConfig( + mtp_weights = self.mtp_weights, + ws_short = self.ws_short * self.block_size, + ws_long = self.ws_long * self.block_size + ) + + def _is_adam_step(self, step: int): + """Adam params are only updated on odd steps.""" + return step % 2 == 1 + + def get_transition_steps(self): + return [start for start, _ in training_schedule.boundaries[1:]] + + def advance_schedule(self, step: int): + stage, _ = training_schedule.lookup(step) + self.ws_short, new_ws_long = stage.window_sizes + if new_ws_long != self.ws_long: + self.model.yarn.apply(self.ws_long * self.block_size, new_ws_long * self.block_size) + self.model.yarn_paired_head.apply(self.ws_long * self.block_size, new_ws_long * self.block_size) + + new_batch_size = stage.batch_size + if new_batch_size != self.batch_size: + self.train_loader_send_args = (new_batch_size, args.train_max_seq_len, grad_accum_steps) + self.batch_size = new_batch_size + else: + self.train_loader_send_args = None + + self.ws_long = new_ws_long + self.mtp_weights = training_schedule.mtp_weights[step] + + def step_optimizers(self, step: int): + step_lr = training_schedule.get_lr(step) + muon_momentum = 
get_muon_momentum(step) + do_adam = self._is_adam_step(step) + + # Update learning rates and momentum for all params + for param, p_cfg in self.optimizer.param_cfgs.items(): + p_cfg.lr = p_cfg.initial_lr * step_lr + if p_cfg.optim == "normuon": + p_cfg.momentum = muon_momentum + + # Step optimizer with do_adam flag + self.optimizer.step(do_adam=do_adam) + + # At split step: copy lm_head optimizer state to embed and mark as split + if step == self.split_step: + self.optimizer.copy_lm_state_to_embed() + + def reset(self, state=None): + if state is not None: + self.optimizer.load_state_dict(state) + + # Reset NorMuon momentum buffers and split_embed state + self.optimizer.reset() + + stage, _ = training_schedule.lookup(0) + self.ws_short, self.ws_long = stage.window_sizes + self.batch_size = stage.batch_size + self.model.yarn.reset() + self.model.yarn_paired_head.reset() + + def get_state(self): + return copy.deepcopy(self.optimizer.state_dict()) + +# ----------------------------------------------------------------------------- +# int main + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=11, + num_heads=6, + 
head_dim=128, + model_dim=768, + max_seq_len=args.val_batch_size // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.weight.data = m.weight.data.bfloat16() +model.attn_gate_bank.data = model.attn_gate_bank.data.bfloat16() +model.ve_gate_bank.data = model.ve_gate_bank.data.bfloat16() +model.attn_bank.data = model.attn_bank.data.bfloat16() +model.mlp_bank.data = model.mlp_bank.data.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) +training_manager = TrainingManager(model) + +######################################## +# Warmup kernels # +######################################## +print0("Compiling model and warming up kernels (~7 minutes on first execution)", console=True) +# Warmup the training kernels, then re-initialize the state so we aren't cheating +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizer=training_manager.get_state()) # save the initial state +train_loader = distributed_data_generator(args.train_files, TRAINING_STAGES[0].batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + +transition_steps = training_manager.get_transition_steps() +# first few steps plus transitions +warmup_steps = sorted({0, 1, 2} | set(s + offset for s in transition_steps for offset in [-1, 0, 1] if s + offset >= 0)) +print0(f"Sampling steps {warmup_steps} for warmup", console=True) +for step in warmup_steps: + training_manager.advance_schedule(step) + model.eval() + with torch.no_grad(): + inputs, targets, cum_seqlens = next(val_loader) + model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) + model.train() + for idx in range(grad_accum_steps): + send_args = training_manager.train_loader_send_args + inputs, 
targets, cum_seqlens = train_loader.send(send_args) + (model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) * grad_scale).backward() + training_manager.step_optimizers(step) +print0("Resetting Model", console=True) +model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +training_manager.reset(initial_state["optimizer"]) +del val_loader, train_loader, initial_state +model.train() + +######################################## +# Training and validation # +######################################## +train_loader = distributed_data_generator(args.train_files, TRAINING_STAGES[0].batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) + +gc.collect() + +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = training_schedule.total_steps +for step in range(train_steps + 1): + last_step = (step == train_steps) + training_manager.advance_schedule(step) + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + if last_step: + training_manager.apply_final_ws_ext() + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) + val_loss /= val_steps + del val_loader + dist.reduce(val_loss, 0, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", 
console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizer=training_manager.get_state()) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for idx in range(grad_accum_steps): + inputs, targets, cum_seqlens = train_loader.send(training_manager.train_loader_send_args) + (model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) * grad_scale).backward() + training_manager.step_optimizers(step) + + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + + +---------------------------------------- +# triton_kernels.py +---------------------------------------- + +import torch +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * 
num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.jit +def XXT_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = 
m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def XXT(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + # Hardcoded configs based on H100 autotuning + if K == 768: + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 64 + num_stages, num_warps = 4, 4 + else: + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 64, 128, 128 + num_stages, num_warps = 4, 4 + + grid = (batch_size * triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(M, BLOCK_SIZE_N),) + XXT_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + LOWER_UPPER=1, + num_stages=num_stages, + num_warps=num_warps, + ) + return out + +@triton.jit +def ba_plus_cAA_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: 
tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from XXT_kernel, but also loads and adds a block of A + # Performance is slightly slower than XXT_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = 
C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ba_plus_cAA(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + # Hardcoded config based on H100 autotuning (M=768) + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 64 + num_stages, num_warps = 4, 4 + + grid = (batch_size * triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(M, BLOCK_SIZE_N),) + ba_plus_cAA_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + LOWER_UPPER=1, + num_stages=num_stages, + num_warps=num_warps, + ) + return out + +# ----------------------------------------------------------------------------- +# Triton kernel for MLP: relu(x @ W1.T)^2, by @andrewbriand, @jrauvola + +@triton.jit +def linear_relu_square_kernel(a_desc, b_desc, c_desc, 
aux_desc, + M, N, K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, + ): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tile_id_c = start_pid - NUM_SMS + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + + tile_id_c += NUM_SMS + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am_c = pid_m * BLOCK_SIZE_M + offs_bn_c = pid_n * BLOCK_SIZE_N + + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + + c0 = acc0.to(dtype) + if not FORWARD: + c0_pre = aux_desc.load([offs_am_c, offs_bn_c]) + c0 = 2 * c0 * tl.where(c0_pre > 0, c0_pre, 0) + + c_desc.store([offs_am_c, offs_bn_c], c0) + + if FORWARD: + c0_post = tl.maximum(c0, 0) + c0_post = c0_post * c0_post + aux_desc.store([offs_am_c, offs_bn_c], c0_post) + + c1 = acc1.to(dtype) + if not FORWARD: + c1_pre = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + c1 = 2 * c1 * tl.where(c1_pre > 0, c1_pre, 0) + + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + + if FORWARD: + c1_post = tl.maximum(c1, 0) + c1_post = c1_post * c1_post + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1_post) + + +def linear_relu_square(a, b, aux=None): + M, K = a.shape + N, K = b.shape + dtype = 
a.dtype + + c = torch.empty((M, N), device=a.device, dtype=dtype) + + FORWARD = False + if aux is None: + FORWARD = True + aux = torch.empty((M, N), device=a.device, dtype=dtype) + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 256 + BLOCK_SIZE_K = 64 + num_stages = 4 if FORWARD else 3 + num_warps = 8 + + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + + def grid(META): + return (min( + NUM_SMS, + triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), + ), ) + + linear_relu_square_kernel[grid]( + a_desc, b_desc, c_desc, aux_desc, + M, N, K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=1, + NUM_SMS=NUM_SMS, + FORWARD=FORWARD, + num_stages=num_stages, + num_warps=num_warps + ) + + if FORWARD: + return c, aux + else: + return c + +class FusedLinearReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, W1, W2): + pre, post = linear_relu_square(x.view((-1, x.shape[-1])), W1) + x3 = post @ W2 + ctx.save_for_backward(x, W1, W2, pre, post) + return x3.view(x.shape) + + @staticmethod + def backward(ctx, grad_output): + x, W1, W2, pre, post = ctx.saved_tensors + dW2 = post.T @ grad_output + dpre = linear_relu_square(grad_output.view((-1, grad_output.shape[-1])), W2, aux=pre) + dW1 = dpre.T @ x + dx = dpre @ W1 + return dx.view(x.shape), dW1, dW2 + +# ----------------------------------------------------------------------------- +# Fused Softcapped Cross Entropy + + +@triton.jit +def fused_softcapped_entropy_fwd_kernel( + logits_ptr, losses_ptr, lse_ptr, targets_ptr, mtp_weights_ptr, + stride_logits_n, stride_logits_v, + n_rows, n_cols, n_predict, + A, B, C, + 
BLOCK_SIZE: tl.constexpr +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + + max_val = -float('inf') + sum_exp = 0.0 + + for off in range(0, n_cols, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + val = tl.load(logits_row_ptr + cols, mask=mask, other=-float('inf')).to(tl.float32) + z = A * tl.sigmoid((val + B) / C) + z = tl.where(mask, z, -float('inf')) + curr_max = tl.max(z, axis=0) + new_max = tl.maximum(max_val, curr_max) + sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(z - new_max), axis=0) + max_val = new_max + + lse = max_val + tl.log(sum_exp) + tl.store(lse_ptr + row_idx, lse) + + total_loss = 0.0 + for k in range(n_predict): + target_idx = row_idx + k + if target_idx < n_rows: + weight = tl.load(mtp_weights_ptr + k) + if weight > 0: + target = tl.load(targets_ptr + target_idx).to(tl.int32) + if target >= 0 and target < n_cols: + val_target = tl.load(logits_row_ptr + target).to(tl.float32) + z_target = A * tl.sigmoid((val_target + B) / C) + total_loss += weight * (lse - z_target) + + tl.store(losses_ptr + row_idx, total_loss) + +@triton.jit +def fused_softcapped_entropy_bwd_kernel( + grad_input_ptr, grad_output_ptr, lse_ptr, logits_ptr, targets_ptr, mtp_weights_ptr, + stride_logits_n, stride_logits_v, stride_grad_n, stride_grad_v, + n_rows, n_cols, n_predict, + A, B, C, + BLOCK_SIZE: tl.constexpr +): + row_idx = tl.program_id(0).to(tl.int64) + + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + grad_row_ptr = grad_input_ptr + row_idx * stride_grad_n + + lse = tl.load(lse_ptr + row_idx) + grad_loss = tl.load(grad_output_ptr + row_idx) + + S_w = 0.0 + for k in range(n_predict): + if row_idx + k < n_rows: + S_w += tl.load(mtp_weights_ptr + k) + + for off in range(0, n_cols, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + val = tl.load(logits_row_ptr + cols, mask=mask, other=0.0).to(tl.float32) + u = (val + B) / C + sigmoid_u 
= tl.sigmoid(u) + z = A * sigmoid_u + p = tl.exp(z - lse) + + term1 = S_w * p + term2 = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for k in range(n_predict): + if row_idx + k < n_rows: + target = tl.load(targets_ptr + row_idx + k).to(tl.int32) + weight = tl.load(mtp_weights_ptr + k) + term2 += tl.where(cols == target, weight, 0.0) + + grad_z = grad_loss * (term1 - term2) + dz_dx = (1.0 / C) * z * (1.0 - sigmoid_u) + grad_x = grad_z * dz_dx + tl.store(grad_row_ptr + cols, grad_x.to(tl.bfloat16), mask=mask) + +class FusedSoftcappedCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, logits, targets, mtp_weights, A=23.0, B=5.0, C=7.5): + n_rows, n_cols = logits.shape + if mtp_weights is None: + mtp_weights = torch.tensor([1.0], device=logits.device, dtype=torch.float32) + n_predict = mtp_weights.shape[0] + + losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + lse = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + + logits = logits.contiguous() + targets = targets.contiguous() + mtp_weights = mtp_weights.contiguous() + + grid = (n_rows,) + fused_softcapped_entropy_fwd_kernel[grid]( + logits, losses, lse, targets, mtp_weights, + logits.stride(0), logits.stride(1), + n_rows, n_cols, n_predict, + A, B, C, + BLOCK_SIZE=1024, + num_warps=8, + num_stages=4 + ) + + ctx.save_for_backward(logits, targets, mtp_weights, lse) + ctx.params = (A, B, C) + return losses + + @staticmethod + def backward(ctx, grad_output): + logits, targets, mtp_weights, lse = ctx.saved_tensors + A, B, C = ctx.params + n_rows, n_cols = logits.shape + n_predict = mtp_weights.shape[0] + + grad_input = torch.empty((n_rows, n_cols), dtype=torch.bfloat16, device=logits.device) + grad_output = grad_output.contiguous() + + grid = (n_rows,) + fused_softcapped_entropy_bwd_kernel[grid]( + grad_input, grad_output, lse, logits, targets, mtp_weights, + logits.stride(0), logits.stride(1), grad_input.stride(0), grad_input.stride(1), + n_rows, n_cols, 
n_predict, + A, B, C, + BLOCK_SIZE=1024, + num_warps=8, + num_stages=4 + ) + return grad_input, None, None, None, None, None + +==================================================================================================== +Running Python 3.12.7 (main, Jan 31 2026, 04:21:49) [GCC 13.2.0] +Running PyTorch 2.10.0.dev20251210+cu126 compiled for CUDA 12.6 +Running Triton version 3.6.0 +Sun Feb 1 10:21:51 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 29C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 30C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:71:00.0 Off | 0 | +| N/A 31C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:79:00.0 Off | 0 | +| N/A 30C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:7F:00.0 Off | 0 | +| N/A 29C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 31C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 95 C /usr/local/bin/python 1510MiB | +| 1 N/A N/A 96 C /usr/local/bin/python 1510MiB | +| 2 N/A N/A 97 C /usr/local/bin/python 1510MiB | +| 3 N/A N/A 98 C /usr/local/bin/python 1510MiB | +| 4 N/A N/A 99 C /usr/local/bin/python 1510MiB | +| 5 N/A N/A 100 C /usr/local/bin/python 1510MiB | +| 6 N/A N/A 101 C /usr/local/bin/python 1510MiB | +| 7 N/A N/A 102 C /usr/local/bin/python 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +Compiling model and warming up kernels (~7 minutes on first execution) +Sampling steps [0, 1, 2, 504, 505, 506, 1009, 1010, 1011, 1514, 1515, 1516] for warmup +Resetting Model +step:0/1555 val_loss:10.8314 train_time:0ms step_avg:0.03ms +step:1/1555 train_time:93ms step_avg:93.13ms +step:2/1555 train_time:127ms step_avg:63.41ms 
+step:3/1555 train_time:155ms step_avg:51.75ms +step:4/1555 train_time:180ms step_avg:45.11ms +step:5/1555 train_time:203ms step_avg:40.56ms +step:6/1555 train_time:234ms step_avg:39.06ms +step:7/1555 train_time:265ms step_avg:37.90ms +step:8/1555 train_time:303ms step_avg:37.84ms +step:9/1555 train_time:334ms step_avg:37.07ms +step:10/1555 train_time:371ms step_avg:37.13ms +step:11/1555 train_time:402ms step_avg:36.54ms +step:12/1555 train_time:440ms step_avg:36.64ms +step:13/1555 train_time:471ms step_avg:36.21ms +step:14/1555 train_time:509ms step_avg:36.32ms +step:15/1555 train_time:539ms step_avg:35.96ms +step:16/1555 train_time:577ms step_avg:36.04ms +step:17/1555 train_time:607ms step_avg:35.73ms +step:18/1555 train_time:645ms step_avg:35.82ms +step:19/1555 train_time:676ms step_avg:35.57ms +step:20/1555 train_time:714ms step_avg:35.68ms +step:21/1555 train_time:745ms step_avg:35.45ms +step:22/1555 train_time:782ms step_avg:35.55ms +step:23/1555 train_time:813ms step_avg:35.36ms +step:24/1555 train_time:851ms step_avg:35.46ms +step:25/1555 train_time:882ms step_avg:35.29ms +step:26/1555 train_time:920ms step_avg:35.37ms +step:27/1555 train_time:950ms step_avg:35.19ms +step:28/1555 train_time:987ms step_avg:35.26ms +step:29/1555 train_time:1019ms step_avg:35.13ms +step:30/1555 train_time:1057ms step_avg:35.23ms +step:31/1555 train_time:1088ms step_avg:35.10ms +step:32/1555 train_time:1126ms step_avg:35.17ms +step:33/1555 train_time:1157ms step_avg:35.05ms +step:34/1555 train_time:1194ms step_avg:35.13ms +step:35/1555 train_time:1225ms step_avg:35.01ms +step:36/1555 train_time:1263ms step_avg:35.08ms +step:37/1555 train_time:1294ms step_avg:34.99ms +step:38/1555 train_time:1332ms step_avg:35.06ms +step:39/1555 train_time:1364ms step_avg:34.97ms +step:40/1555 train_time:1401ms step_avg:35.03ms +step:41/1555 train_time:1433ms step_avg:34.95ms +step:42/1555 train_time:1471ms step_avg:35.01ms +step:43/1555 train_time:1501ms step_avg:34.92ms +step:44/1555 
train_time:1539ms step_avg:34.97ms +step:45/1555 train_time:1570ms step_avg:34.89ms +step:46/1555 train_time:1607ms step_avg:34.94ms +step:47/1555 train_time:1638ms step_avg:34.86ms +step:48/1555 train_time:1676ms step_avg:34.91ms +step:49/1555 train_time:1707ms step_avg:34.84ms +step:50/1555 train_time:1744ms step_avg:34.89ms +step:51/1555 train_time:1776ms step_avg:34.81ms +step:52/1555 train_time:1814ms step_avg:34.88ms +step:53/1555 train_time:1845ms step_avg:34.81ms +step:54/1555 train_time:1882ms step_avg:34.85ms +step:55/1555 train_time:1913ms step_avg:34.78ms +step:56/1555 train_time:1951ms step_avg:34.84ms +step:57/1555 train_time:1982ms step_avg:34.77ms +step:58/1555 train_time:2020ms step_avg:34.82ms +step:59/1555 train_time:2051ms step_avg:34.76ms +step:60/1555 train_time:2088ms step_avg:34.80ms +step:61/1555 train_time:2119ms step_avg:34.74ms +step:62/1555 train_time:2157ms step_avg:34.79ms +step:63/1555 train_time:2188ms step_avg:34.73ms +step:64/1555 train_time:2226ms step_avg:34.78ms +step:65/1555 train_time:2257ms step_avg:34.72ms +step:66/1555 train_time:2294ms step_avg:34.76ms +step:67/1555 train_time:2325ms step_avg:34.70ms +step:68/1555 train_time:2362ms step_avg:34.74ms +step:69/1555 train_time:2394ms step_avg:34.70ms +step:70/1555 train_time:2432ms step_avg:34.74ms +step:71/1555 train_time:2463ms step_avg:34.69ms +step:72/1555 train_time:2500ms step_avg:34.73ms +step:73/1555 train_time:2532ms step_avg:34.68ms +step:74/1555 train_time:2569ms step_avg:34.72ms +step:75/1555 train_time:2600ms step_avg:34.67ms +step:76/1555 train_time:2638ms step_avg:34.71ms +step:77/1555 train_time:2669ms step_avg:34.66ms +step:78/1555 train_time:2707ms step_avg:34.70ms +step:79/1555 train_time:2738ms step_avg:34.65ms +step:80/1555 train_time:2775ms step_avg:34.69ms +step:81/1555 train_time:2806ms step_avg:34.64ms +step:82/1555 train_time:2843ms step_avg:34.67ms +step:83/1555 train_time:2874ms step_avg:34.63ms +step:84/1555 train_time:2912ms step_avg:34.67ms 
+step:85/1555 train_time:2943ms step_avg:34.62ms +step:86/1555 train_time:2980ms step_avg:34.65ms +step:87/1555 train_time:3012ms step_avg:34.62ms +step:88/1555 train_time:3049ms step_avg:34.65ms +step:89/1555 train_time:3080ms step_avg:34.61ms +step:90/1555 train_time:3118ms step_avg:34.64ms +step:91/1555 train_time:3148ms step_avg:34.60ms +step:92/1555 train_time:3186ms step_avg:34.63ms +step:93/1555 train_time:3216ms step_avg:34.59ms +step:94/1555 train_time:3254ms step_avg:34.62ms +step:95/1555 train_time:3286ms step_avg:34.59ms +step:96/1555 train_time:3323ms step_avg:34.62ms +step:97/1555 train_time:3355ms step_avg:34.58ms +step:98/1555 train_time:3392ms step_avg:34.62ms +step:99/1555 train_time:3424ms step_avg:34.58ms +step:100/1555 train_time:3461ms step_avg:34.61ms +step:101/1555 train_time:3492ms step_avg:34.58ms +step:102/1555 train_time:3530ms step_avg:34.61ms +step:103/1555 train_time:3561ms step_avg:34.58ms +step:104/1555 train_time:3599ms step_avg:34.60ms +step:105/1555 train_time:3629ms step_avg:34.57ms +step:106/1555 train_time:3667ms step_avg:34.59ms +step:107/1555 train_time:3698ms step_avg:34.56ms +step:108/1555 train_time:3736ms step_avg:34.59ms +step:109/1555 train_time:3766ms step_avg:34.55ms +step:110/1555 train_time:3804ms step_avg:34.58ms +step:111/1555 train_time:3835ms step_avg:34.55ms +step:112/1555 train_time:3873ms step_avg:34.58ms +step:113/1555 train_time:3904ms step_avg:34.55ms +step:114/1555 train_time:3941ms step_avg:34.57ms +step:115/1555 train_time:3973ms step_avg:34.54ms +step:116/1555 train_time:4011ms step_avg:34.58ms +step:117/1555 train_time:4042ms step_avg:34.55ms +step:118/1555 train_time:4080ms step_avg:34.57ms +step:119/1555 train_time:4111ms step_avg:34.55ms +step:120/1555 train_time:4149ms step_avg:34.57ms +step:121/1555 train_time:4179ms step_avg:34.54ms +step:122/1555 train_time:4217ms step_avg:34.57ms +step:123/1555 train_time:4248ms step_avg:34.54ms +step:124/1555 train_time:4286ms step_avg:34.56ms +step:125/1555 
train_time:4318ms step_avg:34.54ms +step:126/1555 train_time:4355ms step_avg:34.57ms +step:127/1555 train_time:4386ms step_avg:34.54ms +step:128/1555 train_time:4424ms step_avg:34.56ms +step:129/1555 train_time:4455ms step_avg:34.54ms +step:130/1555 train_time:4493ms step_avg:34.56ms +step:131/1555 train_time:4524ms step_avg:34.53ms +step:132/1555 train_time:4562ms step_avg:34.56ms +step:133/1555 train_time:4593ms step_avg:34.53ms +step:134/1555 train_time:4630ms step_avg:34.56ms +step:135/1555 train_time:4661ms step_avg:34.53ms +step:136/1555 train_time:4699ms step_avg:34.55ms +step:137/1555 train_time:4730ms step_avg:34.52ms +step:138/1555 train_time:4767ms step_avg:34.55ms +step:139/1555 train_time:4798ms step_avg:34.52ms +step:140/1555 train_time:4836ms step_avg:34.54ms +step:141/1555 train_time:4866ms step_avg:34.51ms +step:142/1555 train_time:4904ms step_avg:34.53ms +step:143/1555 train_time:4935ms step_avg:34.51ms +step:144/1555 train_time:4973ms step_avg:34.53ms +step:145/1555 train_time:5004ms step_avg:34.51ms +step:146/1555 train_time:5041ms step_avg:34.53ms +step:147/1555 train_time:5072ms step_avg:34.50ms +step:148/1555 train_time:5110ms step_avg:34.53ms +step:149/1555 train_time:5141ms step_avg:34.50ms +step:150/1555 train_time:5179ms step_avg:34.52ms +step:151/1555 train_time:5209ms step_avg:34.50ms +step:152/1555 train_time:5247ms step_avg:34.52ms +step:153/1555 train_time:5279ms step_avg:34.50ms +step:154/1555 train_time:5316ms step_avg:34.52ms +step:155/1555 train_time:5347ms step_avg:34.49ms +step:156/1555 train_time:5384ms step_avg:34.51ms +step:157/1555 train_time:5415ms step_avg:34.49ms +step:158/1555 train_time:5453ms step_avg:34.51ms +step:159/1555 train_time:5484ms step_avg:34.49ms +step:160/1555 train_time:5522ms step_avg:34.51ms +step:161/1555 train_time:5553ms step_avg:34.49ms +step:162/1555 train_time:5590ms step_avg:34.51ms +step:163/1555 train_time:5622ms step_avg:34.49ms +step:164/1555 train_time:5659ms step_avg:34.51ms +step:165/1555 
train_time:5690ms step_avg:34.48ms +step:166/1555 train_time:5727ms step_avg:34.50ms +step:167/1555 train_time:5758ms step_avg:34.48ms +step:168/1555 train_time:5795ms step_avg:34.50ms +step:169/1555 train_time:5826ms step_avg:34.47ms +step:170/1555 train_time:5863ms step_avg:34.49ms +step:171/1555 train_time:5894ms step_avg:34.47ms +step:172/1555 train_time:5932ms step_avg:34.49ms +step:173/1555 train_time:5963ms step_avg:34.47ms +step:174/1555 train_time:6001ms step_avg:34.49ms +step:175/1555 train_time:6032ms step_avg:34.47ms +step:176/1555 train_time:6069ms step_avg:34.48ms +step:177/1555 train_time:6100ms step_avg:34.46ms +step:178/1555 train_time:6138ms step_avg:34.48ms +step:179/1555 train_time:6168ms step_avg:34.46ms +step:180/1555 train_time:6206ms step_avg:34.48ms +step:181/1555 train_time:6237ms step_avg:34.46ms +step:182/1555 train_time:6275ms step_avg:34.48ms +step:183/1555 train_time:6306ms step_avg:34.46ms +step:184/1555 train_time:6343ms step_avg:34.47ms +step:185/1555 train_time:6374ms step_avg:34.46ms +step:186/1555 train_time:6412ms step_avg:34.47ms +step:187/1555 train_time:6443ms step_avg:34.45ms +step:188/1555 train_time:6480ms step_avg:34.47ms +step:189/1555 train_time:6511ms step_avg:34.45ms +step:190/1555 train_time:6549ms step_avg:34.47ms +step:191/1555 train_time:6580ms step_avg:34.45ms +step:192/1555 train_time:6617ms step_avg:34.46ms +step:193/1555 train_time:6648ms step_avg:34.44ms +step:194/1555 train_time:6685ms step_avg:34.46ms +step:195/1555 train_time:6717ms step_avg:34.44ms +step:196/1555 train_time:6754ms step_avg:34.46ms +step:197/1555 train_time:6785ms step_avg:34.44ms +step:198/1555 train_time:6822ms step_avg:34.46ms +step:199/1555 train_time:6854ms step_avg:34.44ms +step:200/1555 train_time:6891ms step_avg:34.46ms +step:201/1555 train_time:6922ms step_avg:34.44ms +step:202/1555 train_time:6959ms step_avg:34.45ms +step:203/1555 train_time:6990ms step_avg:34.43ms +step:204/1555 train_time:7028ms step_avg:34.45ms +step:205/1555 
train_time:7058ms step_avg:34.43ms +step:206/1555 train_time:7096ms step_avg:34.44ms +step:207/1555 train_time:7127ms step_avg:34.43ms +step:208/1555 train_time:7164ms step_avg:34.44ms +step:209/1555 train_time:7195ms step_avg:34.43ms +step:210/1555 train_time:7232ms step_avg:34.44ms +step:211/1555 train_time:7264ms step_avg:34.42ms +step:212/1555 train_time:7301ms step_avg:34.44ms +step:213/1555 train_time:7332ms step_avg:34.42ms +step:214/1555 train_time:7369ms step_avg:34.43ms +step:215/1555 train_time:7400ms step_avg:34.42ms +step:216/1555 train_time:7437ms step_avg:34.43ms +step:217/1555 train_time:7468ms step_avg:34.41ms +step:218/1555 train_time:7506ms step_avg:34.43ms +step:219/1555 train_time:7537ms step_avg:34.42ms +step:220/1555 train_time:7575ms step_avg:34.43ms +step:221/1555 train_time:7606ms step_avg:34.42ms +step:222/1555 train_time:7643ms step_avg:34.43ms +step:223/1555 train_time:7674ms step_avg:34.41ms +step:224/1555 train_time:7711ms step_avg:34.43ms +step:225/1555 train_time:7742ms step_avg:34.41ms +step:226/1555 train_time:7780ms step_avg:34.42ms +step:227/1555 train_time:7812ms step_avg:34.41ms +step:228/1555 train_time:7849ms step_avg:34.43ms +step:229/1555 train_time:7880ms step_avg:34.41ms +step:230/1555 train_time:7917ms step_avg:34.42ms +step:231/1555 train_time:7948ms step_avg:34.41ms +step:232/1555 train_time:7985ms step_avg:34.42ms +step:233/1555 train_time:8016ms step_avg:34.40ms +step:234/1555 train_time:8053ms step_avg:34.42ms +step:235/1555 train_time:8084ms step_avg:34.40ms +step:236/1555 train_time:8121ms step_avg:34.41ms +step:237/1555 train_time:8152ms step_avg:34.40ms +step:238/1555 train_time:8190ms step_avg:34.41ms +step:239/1555 train_time:8221ms step_avg:34.40ms +step:240/1555 train_time:8258ms step_avg:34.41ms +step:241/1555 train_time:8289ms step_avg:34.39ms +step:242/1555 train_time:8326ms step_avg:34.40ms +step:243/1555 train_time:8357ms step_avg:34.39ms +step:244/1555 train_time:8394ms step_avg:34.40ms +step:245/1555 
train_time:8425ms step_avg:34.39ms +step:246/1555 train_time:8462ms step_avg:34.40ms +step:247/1555 train_time:8494ms step_avg:34.39ms +step:248/1555 train_time:8532ms step_avg:34.40ms +step:249/1555 train_time:8563ms step_avg:34.39ms +step:250/1555 train_time:8600ms step_avg:34.40ms +step:250/1555 val_loss:4.5535 train_time:8650ms step_avg:34.60ms +step:251/1555 train_time:8669ms step_avg:34.54ms +step:252/1555 train_time:8688ms step_avg:34.48ms +step:253/1555 train_time:8706ms step_avg:34.41ms +step:254/1555 train_time:8738ms step_avg:34.40ms +step:255/1555 train_time:8771ms step_avg:34.39ms +step:256/1555 train_time:8809ms step_avg:34.41ms +step:257/1555 train_time:8841ms step_avg:34.40ms +step:258/1555 train_time:8879ms step_avg:34.41ms +step:259/1555 train_time:8910ms step_avg:34.40ms +step:260/1555 train_time:8947ms step_avg:34.41ms +step:261/1555 train_time:8979ms step_avg:34.40ms +step:262/1555 train_time:9016ms step_avg:34.41ms +step:263/1555 train_time:9047ms step_avg:34.40ms +step:264/1555 train_time:9084ms step_avg:34.41ms +step:265/1555 train_time:9115ms step_avg:34.40ms +step:266/1555 train_time:9152ms step_avg:34.41ms +step:267/1555 train_time:9183ms step_avg:34.39ms +step:268/1555 train_time:9220ms step_avg:34.40ms +step:269/1555 train_time:9251ms step_avg:34.39ms +step:270/1555 train_time:9288ms step_avg:34.40ms +step:271/1555 train_time:9319ms step_avg:34.39ms +step:272/1555 train_time:9356ms step_avg:34.40ms +step:273/1555 train_time:9387ms step_avg:34.39ms +step:274/1555 train_time:9425ms step_avg:34.40ms +step:275/1555 train_time:9455ms step_avg:34.38ms +step:276/1555 train_time:9492ms step_avg:34.39ms +step:277/1555 train_time:9524ms step_avg:34.38ms +step:278/1555 train_time:9561ms step_avg:34.39ms +step:279/1555 train_time:9592ms step_avg:34.38ms +step:280/1555 train_time:9629ms step_avg:34.39ms +step:281/1555 train_time:9660ms step_avg:34.38ms +step:282/1555 train_time:9698ms step_avg:34.39ms +step:283/1555 train_time:9729ms 
step_avg:34.38ms +step:284/1555 train_time:9767ms step_avg:34.39ms +step:285/1555 train_time:9798ms step_avg:34.38ms +step:286/1555 train_time:9836ms step_avg:34.39ms +step:287/1555 train_time:9867ms step_avg:34.38ms +step:288/1555 train_time:9905ms step_avg:34.39ms +step:289/1555 train_time:9936ms step_avg:34.38ms +step:290/1555 train_time:9973ms step_avg:34.39ms +step:291/1555 train_time:10004ms step_avg:34.38ms +step:292/1555 train_time:10042ms step_avg:34.39ms +step:293/1555 train_time:10073ms step_avg:34.38ms +step:294/1555 train_time:10110ms step_avg:34.39ms +step:295/1555 train_time:10141ms step_avg:34.38ms +step:296/1555 train_time:10179ms step_avg:34.39ms +step:297/1555 train_time:10209ms step_avg:34.38ms +step:298/1555 train_time:10247ms step_avg:34.39ms +step:299/1555 train_time:10278ms step_avg:34.37ms +step:300/1555 train_time:10315ms step_avg:34.38ms +step:301/1555 train_time:10346ms step_avg:34.37ms +step:302/1555 train_time:10383ms step_avg:34.38ms +step:303/1555 train_time:10414ms step_avg:34.37ms +step:304/1555 train_time:10452ms step_avg:34.38ms +step:305/1555 train_time:10483ms step_avg:34.37ms +step:306/1555 train_time:10520ms step_avg:34.38ms +step:307/1555 train_time:10551ms step_avg:34.37ms +step:308/1555 train_time:10588ms step_avg:34.38ms +step:309/1555 train_time:10620ms step_avg:34.37ms +step:310/1555 train_time:10657ms step_avg:34.38ms +step:311/1555 train_time:10689ms step_avg:34.37ms +step:312/1555 train_time:10726ms step_avg:34.38ms +step:313/1555 train_time:10757ms step_avg:34.37ms +step:314/1555 train_time:10794ms step_avg:34.38ms +step:315/1555 train_time:10826ms step_avg:34.37ms +step:316/1555 train_time:10863ms step_avg:34.38ms +step:317/1555 train_time:10894ms step_avg:34.37ms +step:318/1555 train_time:10931ms step_avg:34.37ms +step:319/1555 train_time:10963ms step_avg:34.37ms +step:320/1555 train_time:11001ms step_avg:34.38ms +step:321/1555 train_time:11032ms step_avg:34.37ms +step:322/1555 train_time:11069ms step_avg:34.38ms 
+step:323/1555 train_time:11100ms step_avg:34.37ms +step:324/1555 train_time:11138ms step_avg:34.38ms +step:325/1555 train_time:11169ms step_avg:34.37ms +step:326/1555 train_time:11206ms step_avg:34.37ms +step:327/1555 train_time:11237ms step_avg:34.36ms +step:328/1555 train_time:11274ms step_avg:34.37ms +step:329/1555 train_time:11305ms step_avg:34.36ms +step:330/1555 train_time:11343ms step_avg:34.37ms +step:331/1555 train_time:11374ms step_avg:34.36ms +step:332/1555 train_time:11411ms step_avg:34.37ms +step:333/1555 train_time:11442ms step_avg:34.36ms +step:334/1555 train_time:11480ms step_avg:34.37ms +step:335/1555 train_time:11511ms step_avg:34.36ms +step:336/1555 train_time:11548ms step_avg:34.37ms +step:337/1555 train_time:11580ms step_avg:34.36ms +step:338/1555 train_time:11617ms step_avg:34.37ms +step:339/1555 train_time:11649ms step_avg:34.36ms +step:340/1555 train_time:11686ms step_avg:34.37ms +step:341/1555 train_time:11717ms step_avg:34.36ms +step:342/1555 train_time:11754ms step_avg:34.37ms +step:343/1555 train_time:11785ms step_avg:34.36ms +step:344/1555 train_time:11823ms step_avg:34.37ms +step:345/1555 train_time:11853ms step_avg:34.36ms +step:346/1555 train_time:11891ms step_avg:34.37ms +step:347/1555 train_time:11922ms step_avg:34.36ms +step:348/1555 train_time:11960ms step_avg:34.37ms +step:349/1555 train_time:11991ms step_avg:34.36ms +step:350/1555 train_time:12029ms step_avg:34.37ms +step:351/1555 train_time:12060ms step_avg:34.36ms +step:352/1555 train_time:12097ms step_avg:34.37ms +step:353/1555 train_time:12128ms step_avg:34.36ms +step:354/1555 train_time:12166ms step_avg:34.37ms +step:355/1555 train_time:12197ms step_avg:34.36ms +step:356/1555 train_time:12235ms step_avg:34.37ms +step:357/1555 train_time:12266ms step_avg:34.36ms +step:358/1555 train_time:12304ms step_avg:34.37ms +step:359/1555 train_time:12334ms step_avg:34.36ms +step:360/1555 train_time:12371ms step_avg:34.37ms +step:361/1555 train_time:12403ms step_avg:34.36ms 
+step:362/1555 train_time:12440ms step_avg:34.36ms +step:363/1555 train_time:12471ms step_avg:34.36ms +step:364/1555 train_time:12508ms step_avg:34.36ms +step:365/1555 train_time:12539ms step_avg:34.35ms +step:366/1555 train_time:12577ms step_avg:34.36ms +step:367/1555 train_time:12608ms step_avg:34.35ms +step:368/1555 train_time:12645ms step_avg:34.36ms +step:369/1555 train_time:12676ms step_avg:34.35ms +step:370/1555 train_time:12713ms step_avg:34.36ms +step:371/1555 train_time:12744ms step_avg:34.35ms +step:372/1555 train_time:12782ms step_avg:34.36ms +step:373/1555 train_time:12813ms step_avg:34.35ms +step:374/1555 train_time:12850ms step_avg:34.36ms +step:375/1555 train_time:12881ms step_avg:34.35ms +step:376/1555 train_time:12918ms step_avg:34.36ms +step:377/1555 train_time:12949ms step_avg:34.35ms +step:378/1555 train_time:12987ms step_avg:34.36ms +step:379/1555 train_time:13018ms step_avg:34.35ms +step:380/1555 train_time:13055ms step_avg:34.35ms +step:381/1555 train_time:13086ms step_avg:34.35ms +step:382/1555 train_time:13124ms step_avg:34.36ms +step:383/1555 train_time:13155ms step_avg:34.35ms +step:384/1555 train_time:13192ms step_avg:34.35ms +step:385/1555 train_time:13223ms step_avg:34.35ms +step:386/1555 train_time:13261ms step_avg:34.36ms +step:387/1555 train_time:13292ms step_avg:34.35ms +step:388/1555 train_time:13330ms step_avg:34.35ms +step:389/1555 train_time:13361ms step_avg:34.35ms +step:390/1555 train_time:13398ms step_avg:34.35ms +step:391/1555 train_time:13429ms step_avg:34.34ms +step:392/1555 train_time:13466ms step_avg:34.35ms +step:393/1555 train_time:13497ms step_avg:34.34ms +step:394/1555 train_time:13534ms step_avg:34.35ms +step:395/1555 train_time:13565ms step_avg:34.34ms +step:396/1555 train_time:13603ms step_avg:34.35ms +step:397/1555 train_time:13633ms step_avg:34.34ms +step:398/1555 train_time:13670ms step_avg:34.35ms +step:399/1555 train_time:13702ms step_avg:34.34ms +step:400/1555 train_time:13739ms step_avg:34.35ms 
+step:401/1555 train_time:13770ms step_avg:34.34ms +step:402/1555 train_time:13807ms step_avg:34.35ms +step:403/1555 train_time:13838ms step_avg:34.34ms +step:404/1555 train_time:13876ms step_avg:34.35ms +step:405/1555 train_time:13907ms step_avg:34.34ms +step:406/1555 train_time:13944ms step_avg:34.35ms +step:407/1555 train_time:13975ms step_avg:34.34ms +step:408/1555 train_time:14012ms step_avg:34.34ms +step:409/1555 train_time:14043ms step_avg:34.34ms +step:410/1555 train_time:14081ms step_avg:34.34ms +step:411/1555 train_time:14112ms step_avg:34.34ms +step:412/1555 train_time:14149ms step_avg:34.34ms +step:413/1555 train_time:14180ms step_avg:34.33ms +step:414/1555 train_time:14218ms step_avg:34.34ms +step:415/1555 train_time:14249ms step_avg:34.33ms +step:416/1555 train_time:14287ms step_avg:34.34ms +step:417/1555 train_time:14318ms step_avg:34.34ms +step:418/1555 train_time:14355ms step_avg:34.34ms +step:419/1555 train_time:14386ms step_avg:34.33ms +step:420/1555 train_time:14424ms step_avg:34.34ms +step:421/1555 train_time:14454ms step_avg:34.33ms +step:422/1555 train_time:14492ms step_avg:34.34ms +step:423/1555 train_time:14523ms step_avg:34.33ms +step:424/1555 train_time:14560ms step_avg:34.34ms +step:425/1555 train_time:14591ms step_avg:34.33ms +step:426/1555 train_time:14628ms step_avg:34.34ms +step:427/1555 train_time:14659ms step_avg:34.33ms +step:428/1555 train_time:14697ms step_avg:34.34ms +step:429/1555 train_time:14728ms step_avg:34.33ms +step:430/1555 train_time:14765ms step_avg:34.34ms +step:431/1555 train_time:14796ms step_avg:34.33ms +step:432/1555 train_time:14833ms step_avg:34.34ms +step:433/1555 train_time:14864ms step_avg:34.33ms +step:434/1555 train_time:14901ms step_avg:34.34ms +step:435/1555 train_time:14932ms step_avg:34.33ms +step:436/1555 train_time:14970ms step_avg:34.33ms +step:437/1555 train_time:15001ms step_avg:34.33ms +step:438/1555 train_time:15039ms step_avg:34.34ms +step:439/1555 train_time:15070ms step_avg:34.33ms 
+step:440/1555 train_time:15107ms step_avg:34.33ms +step:441/1555 train_time:15138ms step_avg:34.33ms +step:442/1555 train_time:15176ms step_avg:34.33ms +step:443/1555 train_time:15207ms step_avg:34.33ms +step:444/1555 train_time:15245ms step_avg:34.34ms +step:445/1555 train_time:15276ms step_avg:34.33ms +step:446/1555 train_time:15313ms step_avg:34.33ms +step:447/1555 train_time:15344ms step_avg:34.33ms +step:448/1555 train_time:15382ms step_avg:34.33ms +step:449/1555 train_time:15413ms step_avg:34.33ms +step:450/1555 train_time:15450ms step_avg:34.33ms +step:451/1555 train_time:15481ms step_avg:34.33ms +step:452/1555 train_time:15518ms step_avg:34.33ms +step:453/1555 train_time:15549ms step_avg:34.32ms +step:454/1555 train_time:15586ms step_avg:34.33ms +step:455/1555 train_time:15618ms step_avg:34.32ms +step:456/1555 train_time:15655ms step_avg:34.33ms +step:457/1555 train_time:15686ms step_avg:34.32ms +step:458/1555 train_time:15724ms step_avg:34.33ms +step:459/1555 train_time:15754ms step_avg:34.32ms +step:460/1555 train_time:15792ms step_avg:34.33ms +step:461/1555 train_time:15822ms step_avg:34.32ms +step:462/1555 train_time:15860ms step_avg:34.33ms +step:463/1555 train_time:15891ms step_avg:34.32ms +step:464/1555 train_time:15928ms step_avg:34.33ms +step:465/1555 train_time:15959ms step_avg:34.32ms +step:466/1555 train_time:15996ms step_avg:34.33ms +step:467/1555 train_time:16027ms step_avg:34.32ms +step:468/1555 train_time:16064ms step_avg:34.33ms +step:469/1555 train_time:16095ms step_avg:34.32ms +step:470/1555 train_time:16133ms step_avg:34.32ms +step:471/1555 train_time:16163ms step_avg:34.32ms +step:472/1555 train_time:16201ms step_avg:34.32ms +step:473/1555 train_time:16232ms step_avg:34.32ms +step:474/1555 train_time:16269ms step_avg:34.32ms +step:475/1555 train_time:16300ms step_avg:34.32ms +step:476/1555 train_time:16338ms step_avg:34.32ms +step:477/1555 train_time:16369ms step_avg:34.32ms +step:478/1555 train_time:16406ms step_avg:34.32ms 
+step:479/1555 train_time:16438ms step_avg:34.32ms +step:480/1555 train_time:16475ms step_avg:34.32ms +step:481/1555 train_time:16506ms step_avg:34.32ms +step:482/1555 train_time:16543ms step_avg:34.32ms +step:483/1555 train_time:16574ms step_avg:34.32ms +step:484/1555 train_time:16612ms step_avg:34.32ms +step:485/1555 train_time:16643ms step_avg:34.31ms +step:486/1555 train_time:16680ms step_avg:34.32ms +step:487/1555 train_time:16711ms step_avg:34.31ms +step:488/1555 train_time:16748ms step_avg:34.32ms +step:489/1555 train_time:16779ms step_avg:34.31ms +step:490/1555 train_time:16817ms step_avg:34.32ms +step:491/1555 train_time:16848ms step_avg:34.31ms +step:492/1555 train_time:16886ms step_avg:34.32ms +step:493/1555 train_time:16916ms step_avg:34.31ms +step:494/1555 train_time:16953ms step_avg:34.32ms +step:495/1555 train_time:16984ms step_avg:34.31ms +step:496/1555 train_time:17022ms step_avg:34.32ms +step:497/1555 train_time:17052ms step_avg:34.31ms +step:498/1555 train_time:17090ms step_avg:34.32ms +step:499/1555 train_time:17121ms step_avg:34.31ms +step:500/1555 train_time:17159ms step_avg:34.32ms +step:500/1555 val_loss:4.2305 train_time:17208ms step_avg:34.42ms +step:501/1555 train_time:17227ms step_avg:34.39ms +step:502/1555 train_time:17247ms step_avg:34.36ms +step:503/1555 train_time:17265ms step_avg:34.32ms +step:504/1555 train_time:17297ms step_avg:34.32ms +step:505/1555 train_time:17329ms step_avg:34.32ms +step:506/1555 train_time:17372ms step_avg:34.33ms +step:507/1555 train_time:17427ms step_avg:34.37ms +step:508/1555 train_time:17491ms step_avg:34.43ms +step:509/1555 train_time:17549ms step_avg:34.48ms +step:510/1555 train_time:17612ms step_avg:34.53ms +step:511/1555 train_time:17670ms step_avg:34.58ms +step:512/1555 train_time:17733ms step_avg:34.63ms +step:513/1555 train_time:17790ms step_avg:34.68ms +step:514/1555 train_time:17853ms step_avg:34.73ms +step:515/1555 train_time:17910ms step_avg:34.78ms +step:516/1555 train_time:17972ms 
step_avg:34.83ms +step:517/1555 train_time:18029ms step_avg:34.87ms +step:518/1555 train_time:18093ms step_avg:34.93ms +step:519/1555 train_time:18151ms step_avg:34.97ms +step:520/1555 train_time:18217ms step_avg:35.03ms +step:521/1555 train_time:18276ms step_avg:35.08ms +step:522/1555 train_time:18342ms step_avg:35.14ms +step:523/1555 train_time:18399ms step_avg:35.18ms +step:524/1555 train_time:18464ms step_avg:35.24ms +step:525/1555 train_time:18523ms step_avg:35.28ms +step:526/1555 train_time:18588ms step_avg:35.34ms +step:527/1555 train_time:18646ms step_avg:35.38ms +step:528/1555 train_time:18709ms step_avg:35.43ms +step:529/1555 train_time:18766ms step_avg:35.48ms +step:530/1555 train_time:18830ms step_avg:35.53ms +step:531/1555 train_time:18888ms step_avg:35.57ms +step:532/1555 train_time:18951ms step_avg:35.62ms +step:533/1555 train_time:19008ms step_avg:35.66ms +step:534/1555 train_time:19071ms step_avg:35.71ms +step:535/1555 train_time:19130ms step_avg:35.76ms +step:536/1555 train_time:19195ms step_avg:35.81ms +step:537/1555 train_time:19254ms step_avg:35.85ms +step:538/1555 train_time:19318ms step_avg:35.91ms +step:539/1555 train_time:19375ms step_avg:35.95ms +step:540/1555 train_time:19439ms step_avg:36.00ms +step:541/1555 train_time:19497ms step_avg:36.04ms +step:542/1555 train_time:19561ms step_avg:36.09ms +step:543/1555 train_time:19619ms step_avg:36.13ms +step:544/1555 train_time:19685ms step_avg:36.19ms +step:545/1555 train_time:19744ms step_avg:36.23ms +step:546/1555 train_time:19808ms step_avg:36.28ms +step:547/1555 train_time:19866ms step_avg:36.32ms +step:548/1555 train_time:19931ms step_avg:36.37ms +step:549/1555 train_time:19988ms step_avg:36.41ms +step:550/1555 train_time:20051ms step_avg:36.46ms +step:551/1555 train_time:20109ms step_avg:36.50ms +step:552/1555 train_time:20174ms step_avg:36.55ms +step:553/1555 train_time:20232ms step_avg:36.59ms +step:554/1555 train_time:20296ms step_avg:36.64ms +step:555/1555 train_time:20353ms 
step_avg:36.67ms +step:556/1555 train_time:20418ms step_avg:36.72ms +step:557/1555 train_time:20475ms step_avg:36.76ms +step:558/1555 train_time:20539ms step_avg:36.81ms +step:559/1555 train_time:20597ms step_avg:36.85ms +step:560/1555 train_time:20661ms step_avg:36.89ms +step:561/1555 train_time:20719ms step_avg:36.93ms +step:562/1555 train_time:20783ms step_avg:36.98ms +step:563/1555 train_time:20842ms step_avg:37.02ms +step:564/1555 train_time:20906ms step_avg:37.07ms +step:565/1555 train_time:20965ms step_avg:37.11ms +step:566/1555 train_time:21030ms step_avg:37.16ms +step:567/1555 train_time:21088ms step_avg:37.19ms +step:568/1555 train_time:21152ms step_avg:37.24ms +step:569/1555 train_time:21209ms step_avg:37.27ms +step:570/1555 train_time:21273ms step_avg:37.32ms +step:571/1555 train_time:21332ms step_avg:37.36ms +step:572/1555 train_time:21396ms step_avg:37.41ms +step:573/1555 train_time:21453ms step_avg:37.44ms +step:574/1555 train_time:21517ms step_avg:37.49ms +step:575/1555 train_time:21574ms step_avg:37.52ms +step:576/1555 train_time:21639ms step_avg:37.57ms +step:577/1555 train_time:21696ms step_avg:37.60ms +step:578/1555 train_time:21760ms step_avg:37.65ms +step:579/1555 train_time:21817ms step_avg:37.68ms +step:580/1555 train_time:21882ms step_avg:37.73ms +step:581/1555 train_time:21941ms step_avg:37.76ms +step:582/1555 train_time:22006ms step_avg:37.81ms +step:583/1555 train_time:22063ms step_avg:37.84ms +step:584/1555 train_time:22128ms step_avg:37.89ms +step:585/1555 train_time:22186ms step_avg:37.92ms +step:586/1555 train_time:22252ms step_avg:37.97ms +step:587/1555 train_time:22309ms step_avg:38.00ms +step:588/1555 train_time:22373ms step_avg:38.05ms +step:589/1555 train_time:22431ms step_avg:38.08ms +step:590/1555 train_time:22495ms step_avg:38.13ms +step:591/1555 train_time:22552ms step_avg:38.16ms +step:592/1555 train_time:22616ms step_avg:38.20ms +step:593/1555 train_time:22673ms step_avg:38.24ms +step:594/1555 train_time:22737ms 
step_avg:38.28ms +step:595/1555 train_time:22795ms step_avg:38.31ms +step:596/1555 train_time:22859ms step_avg:38.35ms +step:597/1555 train_time:22916ms step_avg:38.38ms +step:598/1555 train_time:22980ms step_avg:38.43ms +step:599/1555 train_time:23037ms step_avg:38.46ms +step:600/1555 train_time:23102ms step_avg:38.50ms +step:601/1555 train_time:23159ms step_avg:38.53ms +step:602/1555 train_time:23225ms step_avg:38.58ms +step:603/1555 train_time:23284ms step_avg:38.61ms +step:604/1555 train_time:23348ms step_avg:38.66ms +step:605/1555 train_time:23407ms step_avg:38.69ms +step:606/1555 train_time:23471ms step_avg:38.73ms +step:607/1555 train_time:23530ms step_avg:38.76ms +step:608/1555 train_time:23595ms step_avg:38.81ms +step:609/1555 train_time:23652ms step_avg:38.84ms +step:610/1555 train_time:23715ms step_avg:38.88ms +step:611/1555 train_time:23774ms step_avg:38.91ms +step:612/1555 train_time:23837ms step_avg:38.95ms +step:613/1555 train_time:23895ms step_avg:38.98ms +step:614/1555 train_time:23957ms step_avg:39.02ms +step:615/1555 train_time:24015ms step_avg:39.05ms +step:616/1555 train_time:24080ms step_avg:39.09ms +step:617/1555 train_time:24138ms step_avg:39.12ms +step:618/1555 train_time:24203ms step_avg:39.16ms +step:619/1555 train_time:24260ms step_avg:39.19ms +step:620/1555 train_time:24325ms step_avg:39.23ms +step:621/1555 train_time:24384ms step_avg:39.27ms +step:622/1555 train_time:24449ms step_avg:39.31ms +step:623/1555 train_time:24507ms step_avg:39.34ms +step:624/1555 train_time:24572ms step_avg:39.38ms +step:625/1555 train_time:24630ms step_avg:39.41ms +step:626/1555 train_time:24694ms step_avg:39.45ms +step:627/1555 train_time:24752ms step_avg:39.48ms +step:628/1555 train_time:24816ms step_avg:39.52ms +step:629/1555 train_time:24873ms step_avg:39.54ms +step:630/1555 train_time:24937ms step_avg:39.58ms +step:631/1555 train_time:24994ms step_avg:39.61ms +step:632/1555 train_time:25057ms step_avg:39.65ms +step:633/1555 train_time:25114ms 
step_avg:39.67ms +step:634/1555 train_time:25179ms step_avg:39.71ms +step:635/1555 train_time:25236ms step_avg:39.74ms +step:636/1555 train_time:25301ms step_avg:39.78ms +step:637/1555 train_time:25359ms step_avg:39.81ms +step:638/1555 train_time:25423ms step_avg:39.85ms +step:639/1555 train_time:25482ms step_avg:39.88ms +step:640/1555 train_time:25547ms step_avg:39.92ms +step:641/1555 train_time:25606ms step_avg:39.95ms +step:642/1555 train_time:25670ms step_avg:39.98ms +step:643/1555 train_time:25728ms step_avg:40.01ms +step:644/1555 train_time:25793ms step_avg:40.05ms +step:645/1555 train_time:25850ms step_avg:40.08ms +step:646/1555 train_time:25914ms step_avg:40.11ms +step:647/1555 train_time:25972ms step_avg:40.14ms +step:648/1555 train_time:26036ms step_avg:40.18ms +step:649/1555 train_time:26093ms step_avg:40.21ms +step:650/1555 train_time:26157ms step_avg:40.24ms +step:651/1555 train_time:26213ms step_avg:40.27ms +step:652/1555 train_time:26278ms step_avg:40.30ms +step:653/1555 train_time:26335ms step_avg:40.33ms +step:654/1555 train_time:26399ms step_avg:40.37ms +step:655/1555 train_time:26456ms step_avg:40.39ms +step:656/1555 train_time:26521ms step_avg:40.43ms +step:657/1555 train_time:26578ms step_avg:40.45ms +step:658/1555 train_time:26643ms step_avg:40.49ms +step:659/1555 train_time:26702ms step_avg:40.52ms +step:660/1555 train_time:26767ms step_avg:40.56ms +step:661/1555 train_time:26826ms step_avg:40.58ms +step:662/1555 train_time:26891ms step_avg:40.62ms +step:663/1555 train_time:26948ms step_avg:40.65ms +step:664/1555 train_time:27013ms step_avg:40.68ms +step:665/1555 train_time:27071ms step_avg:40.71ms +step:666/1555 train_time:27135ms step_avg:40.74ms +step:667/1555 train_time:27193ms step_avg:40.77ms +step:668/1555 train_time:27256ms step_avg:40.80ms +step:669/1555 train_time:27313ms step_avg:40.83ms +step:670/1555 train_time:27376ms step_avg:40.86ms +step:671/1555 train_time:27434ms step_avg:40.88ms +step:672/1555 train_time:27498ms 
step_avg:40.92ms +step:673/1555 train_time:27555ms step_avg:40.94ms +step:674/1555 train_time:27619ms step_avg:40.98ms +step:675/1555 train_time:27676ms step_avg:41.00ms +step:676/1555 train_time:27742ms step_avg:41.04ms +step:677/1555 train_time:27799ms step_avg:41.06ms +step:678/1555 train_time:27865ms step_avg:41.10ms +step:679/1555 train_time:27924ms step_avg:41.12ms +step:680/1555 train_time:27989ms step_avg:41.16ms +step:681/1555 train_time:28047ms step_avg:41.19ms +step:682/1555 train_time:28112ms step_avg:41.22ms +step:683/1555 train_time:28170ms step_avg:41.24ms +step:684/1555 train_time:28234ms step_avg:41.28ms +step:685/1555 train_time:28291ms step_avg:41.30ms +step:686/1555 train_time:28354ms step_avg:41.33ms +step:687/1555 train_time:28412ms step_avg:41.36ms +step:688/1555 train_time:28475ms step_avg:41.39ms +step:689/1555 train_time:28532ms step_avg:41.41ms +step:690/1555 train_time:28596ms step_avg:41.44ms +step:691/1555 train_time:28654ms step_avg:41.47ms +step:692/1555 train_time:28718ms step_avg:41.50ms +step:693/1555 train_time:28775ms step_avg:41.52ms +step:694/1555 train_time:28840ms step_avg:41.56ms +step:695/1555 train_time:28898ms step_avg:41.58ms +step:696/1555 train_time:28962ms step_avg:41.61ms +step:697/1555 train_time:29020ms step_avg:41.64ms +step:698/1555 train_time:29085ms step_avg:41.67ms +step:699/1555 train_time:29144ms step_avg:41.69ms +step:700/1555 train_time:29209ms step_avg:41.73ms +step:701/1555 train_time:29267ms step_avg:41.75ms +step:702/1555 train_time:29331ms step_avg:41.78ms +step:703/1555 train_time:29389ms step_avg:41.81ms +step:704/1555 train_time:29453ms step_avg:41.84ms +step:705/1555 train_time:29511ms step_avg:41.86ms +step:706/1555 train_time:29575ms step_avg:41.89ms +step:707/1555 train_time:29633ms step_avg:41.91ms +step:708/1555 train_time:29696ms step_avg:41.94ms +step:709/1555 train_time:29753ms step_avg:41.97ms +step:710/1555 train_time:29817ms step_avg:42.00ms +step:711/1555 train_time:29875ms 
step_avg:42.02ms +step:712/1555 train_time:29939ms step_avg:42.05ms +step:713/1555 train_time:29997ms step_avg:42.07ms +step:714/1555 train_time:30062ms step_avg:42.10ms +step:715/1555 train_time:30119ms step_avg:42.12ms +step:716/1555 train_time:30184ms step_avg:42.16ms +step:717/1555 train_time:30243ms step_avg:42.18ms +step:718/1555 train_time:30307ms step_avg:42.21ms +step:719/1555 train_time:30365ms step_avg:42.23ms +step:720/1555 train_time:30430ms step_avg:42.26ms +step:721/1555 train_time:30489ms step_avg:42.29ms +step:722/1555 train_time:30552ms step_avg:42.32ms +step:723/1555 train_time:30610ms step_avg:42.34ms +step:724/1555 train_time:30674ms step_avg:42.37ms +step:725/1555 train_time:30732ms step_avg:42.39ms +step:726/1555 train_time:30796ms step_avg:42.42ms +step:727/1555 train_time:30852ms step_avg:42.44ms +step:728/1555 train_time:30916ms step_avg:42.47ms +step:729/1555 train_time:30974ms step_avg:42.49ms +step:730/1555 train_time:31038ms step_avg:42.52ms +step:731/1555 train_time:31094ms step_avg:42.54ms +step:732/1555 train_time:31159ms step_avg:42.57ms +step:733/1555 train_time:31216ms step_avg:42.59ms +step:734/1555 train_time:31281ms step_avg:42.62ms +step:735/1555 train_time:31338ms step_avg:42.64ms +step:736/1555 train_time:31403ms step_avg:42.67ms +step:737/1555 train_time:31462ms step_avg:42.69ms +step:738/1555 train_time:31528ms step_avg:42.72ms +step:739/1555 train_time:31585ms step_avg:42.74ms +step:740/1555 train_time:31650ms step_avg:42.77ms +step:741/1555 train_time:31708ms step_avg:42.79ms +step:742/1555 train_time:31772ms step_avg:42.82ms +step:743/1555 train_time:31831ms step_avg:42.84ms +step:744/1555 train_time:31894ms step_avg:42.87ms +step:745/1555 train_time:31952ms step_avg:42.89ms +step:746/1555 train_time:32016ms step_avg:42.92ms +step:747/1555 train_time:32074ms step_avg:42.94ms +step:748/1555 train_time:32138ms step_avg:42.97ms +step:749/1555 train_time:32195ms step_avg:42.98ms +step:750/1555 train_time:32260ms 
step_avg:43.01ms +step:750/1555 val_loss:3.8780 train_time:32341ms step_avg:43.12ms +step:751/1555 train_time:32361ms step_avg:43.09ms +step:752/1555 train_time:32381ms step_avg:43.06ms +step:753/1555 train_time:32438ms step_avg:43.08ms +step:754/1555 train_time:32506ms step_avg:43.11ms +step:755/1555 train_time:32563ms step_avg:43.13ms +step:756/1555 train_time:32627ms step_avg:43.16ms +step:757/1555 train_time:32685ms step_avg:43.18ms +step:758/1555 train_time:32749ms step_avg:43.20ms +step:759/1555 train_time:32806ms step_avg:43.22ms +step:760/1555 train_time:32870ms step_avg:43.25ms +step:761/1555 train_time:32926ms step_avg:43.27ms +step:762/1555 train_time:32991ms step_avg:43.30ms +step:763/1555 train_time:33048ms step_avg:43.31ms +step:764/1555 train_time:33111ms step_avg:43.34ms +step:765/1555 train_time:33169ms step_avg:43.36ms +step:766/1555 train_time:33234ms step_avg:43.39ms +step:767/1555 train_time:33293ms step_avg:43.41ms +step:768/1555 train_time:33358ms step_avg:43.43ms +step:769/1555 train_time:33417ms step_avg:43.45ms +step:770/1555 train_time:33482ms step_avg:43.48ms +step:771/1555 train_time:33540ms step_avg:43.50ms +step:772/1555 train_time:33605ms step_avg:43.53ms +step:773/1555 train_time:33662ms step_avg:43.55ms +step:774/1555 train_time:33725ms step_avg:43.57ms +step:775/1555 train_time:33782ms step_avg:43.59ms +step:776/1555 train_time:33845ms step_avg:43.61ms +step:777/1555 train_time:33902ms step_avg:43.63ms +step:778/1555 train_time:33965ms step_avg:43.66ms +step:779/1555 train_time:34022ms step_avg:43.67ms +step:780/1555 train_time:34086ms step_avg:43.70ms +step:781/1555 train_time:34142ms step_avg:43.72ms +step:782/1555 train_time:34207ms step_avg:43.74ms +step:783/1555 train_time:34264ms step_avg:43.76ms +step:784/1555 train_time:34330ms step_avg:43.79ms +step:785/1555 train_time:34389ms step_avg:43.81ms +step:786/1555 train_time:34455ms step_avg:43.84ms +step:787/1555 train_time:34514ms step_avg:43.85ms +step:788/1555 
train_time:34578ms step_avg:43.88ms +step:789/1555 train_time:34636ms step_avg:43.90ms +step:790/1555 train_time:34701ms step_avg:43.92ms +step:791/1555 train_time:34758ms step_avg:43.94ms +step:792/1555 train_time:34822ms step_avg:43.97ms +step:793/1555 train_time:34879ms step_avg:43.98ms +step:794/1555 train_time:34942ms step_avg:44.01ms +step:795/1555 train_time:35000ms step_avg:44.03ms +step:796/1555 train_time:35063ms step_avg:44.05ms +step:797/1555 train_time:35121ms step_avg:44.07ms +step:798/1555 train_time:35185ms step_avg:44.09ms +step:799/1555 train_time:35243ms step_avg:44.11ms +step:800/1555 train_time:35307ms step_avg:44.13ms +step:801/1555 train_time:35363ms step_avg:44.15ms +step:802/1555 train_time:35428ms step_avg:44.17ms +step:803/1555 train_time:35486ms step_avg:44.19ms +step:804/1555 train_time:35551ms step_avg:44.22ms +step:805/1555 train_time:35610ms step_avg:44.24ms +step:806/1555 train_time:35675ms step_avg:44.26ms +step:807/1555 train_time:35734ms step_avg:44.28ms +step:808/1555 train_time:35798ms step_avg:44.30ms +step:809/1555 train_time:35855ms step_avg:44.32ms +step:810/1555 train_time:35920ms step_avg:44.35ms +step:811/1555 train_time:35977ms step_avg:44.36ms +step:812/1555 train_time:36041ms step_avg:44.39ms +step:813/1555 train_time:36100ms step_avg:44.40ms +step:814/1555 train_time:36162ms step_avg:44.43ms +step:815/1555 train_time:36219ms step_avg:44.44ms +step:816/1555 train_time:36284ms step_avg:44.47ms +step:817/1555 train_time:36341ms step_avg:44.48ms +step:818/1555 train_time:36405ms step_avg:44.51ms +step:819/1555 train_time:36462ms step_avg:44.52ms +step:820/1555 train_time:36527ms step_avg:44.55ms +step:821/1555 train_time:36584ms step_avg:44.56ms +step:822/1555 train_time:36649ms step_avg:44.59ms +step:823/1555 train_time:36707ms step_avg:44.60ms +step:824/1555 train_time:36771ms step_avg:44.63ms +step:825/1555 train_time:36830ms step_avg:44.64ms +step:826/1555 train_time:36895ms step_avg:44.67ms +step:827/1555 
train_time:36953ms step_avg:44.68ms +step:828/1555 train_time:37017ms step_avg:44.71ms +step:829/1555 train_time:37075ms step_avg:44.72ms +step:830/1555 train_time:37139ms step_avg:44.75ms +step:831/1555 train_time:37196ms step_avg:44.76ms +step:832/1555 train_time:37261ms step_avg:44.78ms +step:833/1555 train_time:37319ms step_avg:44.80ms +step:834/1555 train_time:37383ms step_avg:44.82ms +step:835/1555 train_time:37441ms step_avg:44.84ms +step:836/1555 train_time:37505ms step_avg:44.86ms +step:837/1555 train_time:37562ms step_avg:44.88ms +step:838/1555 train_time:37625ms step_avg:44.90ms +step:839/1555 train_time:37683ms step_avg:44.91ms +step:840/1555 train_time:37747ms step_avg:44.94ms +step:841/1555 train_time:37804ms step_avg:44.95ms +step:842/1555 train_time:37868ms step_avg:44.97ms +step:843/1555 train_time:37926ms step_avg:44.99ms +step:844/1555 train_time:37991ms step_avg:45.01ms +step:845/1555 train_time:38049ms step_avg:45.03ms +step:846/1555 train_time:38114ms step_avg:45.05ms +step:847/1555 train_time:38172ms step_avg:45.07ms +step:848/1555 train_time:38236ms step_avg:45.09ms +step:849/1555 train_time:38295ms step_avg:45.11ms +step:850/1555 train_time:38360ms step_avg:45.13ms +step:851/1555 train_time:38418ms step_avg:45.14ms +step:852/1555 train_time:38482ms step_avg:45.17ms +step:853/1555 train_time:38539ms step_avg:45.18ms +step:854/1555 train_time:38604ms step_avg:45.20ms +step:855/1555 train_time:38662ms step_avg:45.22ms +step:856/1555 train_time:38725ms step_avg:45.24ms +step:857/1555 train_time:38783ms step_avg:45.25ms +step:858/1555 train_time:38847ms step_avg:45.28ms +step:859/1555 train_time:38904ms step_avg:45.29ms +step:860/1555 train_time:38968ms step_avg:45.31ms +step:861/1555 train_time:39025ms step_avg:45.33ms +step:862/1555 train_time:39089ms step_avg:45.35ms +step:863/1555 train_time:39147ms step_avg:45.36ms +step:864/1555 train_time:39212ms step_avg:45.38ms +step:865/1555 train_time:39270ms step_avg:45.40ms +step:866/1555 
train_time:39336ms step_avg:45.42ms +step:867/1555 train_time:39393ms step_avg:45.44ms +step:868/1555 train_time:39458ms step_avg:45.46ms +step:869/1555 train_time:39516ms step_avg:45.47ms +step:870/1555 train_time:39580ms step_avg:45.49ms +step:871/1555 train_time:39639ms step_avg:45.51ms +step:872/1555 train_time:39703ms step_avg:45.53ms +step:873/1555 train_time:39760ms step_avg:45.54ms +step:874/1555 train_time:39823ms step_avg:45.56ms +step:875/1555 train_time:39880ms step_avg:45.58ms +step:876/1555 train_time:39945ms step_avg:45.60ms +step:877/1555 train_time:40001ms step_avg:45.61ms +step:878/1555 train_time:40065ms step_avg:45.63ms +step:879/1555 train_time:40122ms step_avg:45.65ms +step:880/1555 train_time:40187ms step_avg:45.67ms +step:881/1555 train_time:40245ms step_avg:45.68ms +step:882/1555 train_time:40310ms step_avg:45.70ms +step:883/1555 train_time:40367ms step_avg:45.72ms +step:884/1555 train_time:40432ms step_avg:45.74ms +step:885/1555 train_time:40489ms step_avg:45.75ms +step:886/1555 train_time:40555ms step_avg:45.77ms +step:887/1555 train_time:40613ms step_avg:45.79ms +step:888/1555 train_time:40678ms step_avg:45.81ms +step:889/1555 train_time:40735ms step_avg:45.82ms +step:890/1555 train_time:40800ms step_avg:45.84ms +step:891/1555 train_time:40858ms step_avg:45.86ms +step:892/1555 train_time:40921ms step_avg:45.88ms +step:893/1555 train_time:40979ms step_avg:45.89ms +step:894/1555 train_time:41044ms step_avg:45.91ms +step:895/1555 train_time:41101ms step_avg:45.92ms +step:896/1555 train_time:41164ms step_avg:45.94ms +step:897/1555 train_time:41222ms step_avg:45.95ms +step:898/1555 train_time:41287ms step_avg:45.98ms +step:899/1555 train_time:41343ms step_avg:45.99ms +step:900/1555 train_time:41407ms step_avg:46.01ms +step:901/1555 train_time:41463ms step_avg:46.02ms +step:902/1555 train_time:41529ms step_avg:46.04ms +step:903/1555 train_time:41586ms step_avg:46.05ms +step:904/1555 train_time:41650ms step_avg:46.07ms +step:905/1555 
train_time:41709ms step_avg:46.09ms +step:906/1555 train_time:41773ms step_avg:46.11ms +step:907/1555 train_time:41831ms step_avg:46.12ms +step:908/1555 train_time:41896ms step_avg:46.14ms +step:909/1555 train_time:41954ms step_avg:46.15ms +step:910/1555 train_time:42018ms step_avg:46.17ms +step:911/1555 train_time:42076ms step_avg:46.19ms +step:912/1555 train_time:42140ms step_avg:46.21ms +step:913/1555 train_time:42199ms step_avg:46.22ms +step:914/1555 train_time:42262ms step_avg:46.24ms +step:915/1555 train_time:42320ms step_avg:46.25ms +step:916/1555 train_time:42384ms step_avg:46.27ms +step:917/1555 train_time:42441ms step_avg:46.28ms +step:918/1555 train_time:42505ms step_avg:46.30ms +step:919/1555 train_time:42563ms step_avg:46.31ms +step:920/1555 train_time:42627ms step_avg:46.33ms +step:921/1555 train_time:42684ms step_avg:46.35ms +step:922/1555 train_time:42749ms step_avg:46.37ms +step:923/1555 train_time:42807ms step_avg:46.38ms +step:924/1555 train_time:42872ms step_avg:46.40ms +step:925/1555 train_time:42931ms step_avg:46.41ms +step:926/1555 train_time:42995ms step_avg:46.43ms +step:927/1555 train_time:43053ms step_avg:46.44ms +step:928/1555 train_time:43118ms step_avg:46.46ms +step:929/1555 train_time:43176ms step_avg:46.48ms +step:930/1555 train_time:43241ms step_avg:46.50ms +step:931/1555 train_time:43299ms step_avg:46.51ms +step:932/1555 train_time:43363ms step_avg:46.53ms +step:933/1555 train_time:43421ms step_avg:46.54ms +step:934/1555 train_time:43484ms step_avg:46.56ms +step:935/1555 train_time:43542ms step_avg:46.57ms +step:936/1555 train_time:43605ms step_avg:46.59ms +step:937/1555 train_time:43663ms step_avg:46.60ms +step:938/1555 train_time:43727ms step_avg:46.62ms +step:939/1555 train_time:43783ms step_avg:46.63ms +step:940/1555 train_time:43847ms step_avg:46.65ms +step:941/1555 train_time:43905ms step_avg:46.66ms +step:942/1555 train_time:43969ms step_avg:46.68ms +step:943/1555 train_time:44027ms step_avg:46.69ms +step:944/1555 
train_time:44092ms step_avg:46.71ms +step:945/1555 train_time:44151ms step_avg:46.72ms +step:946/1555 train_time:44216ms step_avg:46.74ms +step:947/1555 train_time:44273ms step_avg:46.75ms +step:948/1555 train_time:44338ms step_avg:46.77ms +step:949/1555 train_time:44396ms step_avg:46.78ms +step:950/1555 train_time:44460ms step_avg:46.80ms +step:951/1555 train_time:44518ms step_avg:46.81ms +step:952/1555 train_time:44581ms step_avg:46.83ms +step:953/1555 train_time:44640ms step_avg:46.84ms +step:954/1555 train_time:44704ms step_avg:46.86ms +step:955/1555 train_time:44761ms step_avg:46.87ms +step:956/1555 train_time:44825ms step_avg:46.89ms +step:957/1555 train_time:44882ms step_avg:46.90ms +step:958/1555 train_time:44946ms step_avg:46.92ms +step:959/1555 train_time:45004ms step_avg:46.93ms +step:960/1555 train_time:45069ms step_avg:46.95ms +step:961/1555 train_time:45126ms step_avg:46.96ms +step:962/1555 train_time:45190ms step_avg:46.97ms +step:963/1555 train_time:45248ms step_avg:46.99ms +step:964/1555 train_time:45313ms step_avg:47.01ms +step:965/1555 train_time:45371ms step_avg:47.02ms +step:966/1555 train_time:45436ms step_avg:47.03ms +step:967/1555 train_time:45495ms step_avg:47.05ms +step:968/1555 train_time:45559ms step_avg:47.06ms +step:969/1555 train_time:45617ms step_avg:47.08ms +step:970/1555 train_time:45681ms step_avg:47.09ms +step:971/1555 train_time:45738ms step_avg:47.10ms +step:972/1555 train_time:45802ms step_avg:47.12ms +step:973/1555 train_time:45860ms step_avg:47.13ms +step:974/1555 train_time:45923ms step_avg:47.15ms +step:975/1555 train_time:45981ms step_avg:47.16ms +step:976/1555 train_time:46045ms step_avg:47.18ms +step:977/1555 train_time:46102ms step_avg:47.19ms +step:978/1555 train_time:46166ms step_avg:47.20ms +step:979/1555 train_time:46223ms step_avg:47.21ms +step:980/1555 train_time:46288ms step_avg:47.23ms +step:981/1555 train_time:46344ms step_avg:47.24ms +step:982/1555 train_time:46409ms step_avg:47.26ms +step:983/1555 
train_time:46467ms step_avg:47.27ms +step:984/1555 train_time:46532ms step_avg:47.29ms +step:985/1555 train_time:46590ms step_avg:47.30ms +step:986/1555 train_time:46655ms step_avg:47.32ms +step:987/1555 train_time:46713ms step_avg:47.33ms +step:988/1555 train_time:46778ms step_avg:47.35ms +step:989/1555 train_time:46836ms step_avg:47.36ms +step:990/1555 train_time:46901ms step_avg:47.38ms +step:991/1555 train_time:46959ms step_avg:47.39ms +step:992/1555 train_time:47022ms step_avg:47.40ms +step:993/1555 train_time:47080ms step_avg:47.41ms +step:994/1555 train_time:47144ms step_avg:47.43ms +step:995/1555 train_time:47201ms step_avg:47.44ms +step:996/1555 train_time:47264ms step_avg:47.45ms +step:997/1555 train_time:47321ms step_avg:47.46ms +step:998/1555 train_time:47386ms step_avg:47.48ms +step:999/1555 train_time:47442ms step_avg:47.49ms +step:1000/1555 train_time:47507ms step_avg:47.51ms +step:1000/1555 val_loss:3.5716 train_time:47590ms step_avg:47.59ms +step:1001/1555 train_time:47610ms step_avg:47.56ms +step:1002/1555 train_time:47631ms step_avg:47.54ms +step:1003/1555 train_time:47691ms step_avg:47.55ms +step:1004/1555 train_time:47759ms step_avg:47.57ms +step:1005/1555 train_time:47819ms step_avg:47.58ms +step:1006/1555 train_time:47884ms step_avg:47.60ms +step:1007/1555 train_time:47940ms step_avg:47.61ms +step:1008/1555 train_time:48003ms step_avg:47.62ms +step:1009/1555 train_time:48060ms step_avg:47.63ms +step:1010/1555 train_time:48123ms step_avg:47.65ms +step:1011/1555 train_time:48183ms step_avg:47.66ms +step:1012/1555 train_time:48268ms step_avg:47.70ms +step:1013/1555 train_time:48351ms step_avg:47.73ms +step:1014/1555 train_time:48442ms step_avg:47.77ms +step:1015/1555 train_time:48524ms step_avg:47.81ms +step:1016/1555 train_time:48615ms step_avg:47.85ms +step:1017/1555 train_time:48702ms step_avg:47.89ms +step:1018/1555 train_time:48792ms step_avg:47.93ms +step:1019/1555 train_time:48879ms step_avg:47.97ms +step:1020/1555 train_time:48969ms 
step_avg:48.01ms +step:1021/1555 train_time:49055ms step_avg:48.05ms +step:1022/1555 train_time:49145ms step_avg:48.09ms +step:1023/1555 train_time:49227ms step_avg:48.12ms +step:1024/1555 train_time:49316ms step_avg:48.16ms +step:1025/1555 train_time:49400ms step_avg:48.20ms +step:1026/1555 train_time:49488ms step_avg:48.23ms +step:1027/1555 train_time:49573ms step_avg:48.27ms +step:1028/1555 train_time:49664ms step_avg:48.31ms +step:1029/1555 train_time:49748ms step_avg:48.35ms +step:1030/1555 train_time:49841ms step_avg:48.39ms +step:1031/1555 train_time:49925ms step_avg:48.42ms +step:1032/1555 train_time:50015ms step_avg:48.46ms +step:1033/1555 train_time:50100ms step_avg:48.50ms +step:1034/1555 train_time:50189ms step_avg:48.54ms +step:1035/1555 train_time:50273ms step_avg:48.57ms +step:1036/1555 train_time:50363ms step_avg:48.61ms +step:1037/1555 train_time:50445ms step_avg:48.64ms +step:1038/1555 train_time:50534ms step_avg:48.68ms +step:1039/1555 train_time:50619ms step_avg:48.72ms +step:1040/1555 train_time:50710ms step_avg:48.76ms +step:1041/1555 train_time:50795ms step_avg:48.79ms +step:1042/1555 train_time:50886ms step_avg:48.83ms +step:1043/1555 train_time:50970ms step_avg:48.87ms +step:1044/1555 train_time:51061ms step_avg:48.91ms +step:1045/1555 train_time:51144ms step_avg:48.94ms +step:1046/1555 train_time:51234ms step_avg:48.98ms +step:1047/1555 train_time:51317ms step_avg:49.01ms +step:1048/1555 train_time:51407ms step_avg:49.05ms +step:1049/1555 train_time:51490ms step_avg:49.08ms +step:1050/1555 train_time:51581ms step_avg:49.12ms +step:1051/1555 train_time:51665ms step_avg:49.16ms +step:1052/1555 train_time:51757ms step_avg:49.20ms +step:1053/1555 train_time:51841ms step_avg:49.23ms +step:1054/1555 train_time:51930ms step_avg:49.27ms +step:1055/1555 train_time:52016ms step_avg:49.30ms +step:1056/1555 train_time:52107ms step_avg:49.34ms +step:1057/1555 train_time:52190ms step_avg:49.38ms +step:1058/1555 train_time:52279ms step_avg:49.41ms 
+step:1059/1555 train_time:52362ms step_avg:49.45ms +step:1060/1555 train_time:52452ms step_avg:49.48ms +step:1061/1555 train_time:52536ms step_avg:49.52ms +step:1062/1555 train_time:52626ms step_avg:49.55ms +step:1063/1555 train_time:52710ms step_avg:49.59ms +step:1064/1555 train_time:52800ms step_avg:49.62ms +step:1065/1555 train_time:52885ms step_avg:49.66ms +step:1066/1555 train_time:52975ms step_avg:49.70ms +step:1067/1555 train_time:53060ms step_avg:49.73ms +step:1068/1555 train_time:53149ms step_avg:49.77ms +step:1069/1555 train_time:53235ms step_avg:49.80ms +step:1070/1555 train_time:53324ms step_avg:49.84ms +step:1071/1555 train_time:53407ms step_avg:49.87ms +step:1072/1555 train_time:53497ms step_avg:49.90ms +step:1073/1555 train_time:53581ms step_avg:49.94ms +step:1074/1555 train_time:53670ms step_avg:49.97ms +step:1075/1555 train_time:53755ms step_avg:50.00ms +step:1076/1555 train_time:53845ms step_avg:50.04ms +step:1077/1555 train_time:53928ms step_avg:50.07ms +step:1078/1555 train_time:54019ms step_avg:50.11ms +step:1079/1555 train_time:54103ms step_avg:50.14ms +step:1080/1555 train_time:54193ms step_avg:50.18ms +step:1081/1555 train_time:54277ms step_avg:50.21ms +step:1082/1555 train_time:54367ms step_avg:50.25ms +step:1083/1555 train_time:54450ms step_avg:50.28ms +step:1084/1555 train_time:54540ms step_avg:50.31ms +step:1085/1555 train_time:54623ms step_avg:50.34ms +step:1086/1555 train_time:54713ms step_avg:50.38ms +step:1087/1555 train_time:54797ms step_avg:50.41ms +step:1088/1555 train_time:54888ms step_avg:50.45ms +step:1089/1555 train_time:54972ms step_avg:50.48ms +step:1090/1555 train_time:55063ms step_avg:50.52ms +step:1091/1555 train_time:55147ms step_avg:50.55ms +step:1092/1555 train_time:55235ms step_avg:50.58ms +step:1093/1555 train_time:55319ms step_avg:50.61ms +step:1094/1555 train_time:55409ms step_avg:50.65ms +step:1095/1555 train_time:55493ms step_avg:50.68ms +step:1096/1555 train_time:55588ms step_avg:50.72ms +step:1097/1555 
train_time:55667ms step_avg:50.74ms +step:1098/1555 train_time:55757ms step_avg:50.78ms +step:1099/1555 train_time:55843ms step_avg:50.81ms +step:1100/1555 train_time:55932ms step_avg:50.85ms +step:1101/1555 train_time:56016ms step_avg:50.88ms +step:1102/1555 train_time:56107ms step_avg:50.91ms +step:1103/1555 train_time:56191ms step_avg:50.94ms +step:1104/1555 train_time:56280ms step_avg:50.98ms +step:1105/1555 train_time:56364ms step_avg:51.01ms +step:1106/1555 train_time:56454ms step_avg:51.04ms +step:1107/1555 train_time:56537ms step_avg:51.07ms +step:1108/1555 train_time:56627ms step_avg:51.11ms +step:1109/1555 train_time:56710ms step_avg:51.14ms +step:1110/1555 train_time:56802ms step_avg:51.17ms +step:1111/1555 train_time:56885ms step_avg:51.20ms +step:1112/1555 train_time:56975ms step_avg:51.24ms +step:1113/1555 train_time:57059ms step_avg:51.27ms +step:1114/1555 train_time:57149ms step_avg:51.30ms +step:1115/1555 train_time:57234ms step_avg:51.33ms +step:1116/1555 train_time:57323ms step_avg:51.37ms +step:1117/1555 train_time:57407ms step_avg:51.39ms +step:1118/1555 train_time:57497ms step_avg:51.43ms +step:1119/1555 train_time:57582ms step_avg:51.46ms +step:1120/1555 train_time:57672ms step_avg:51.49ms +step:1121/1555 train_time:57757ms step_avg:51.52ms +step:1122/1555 train_time:57847ms step_avg:51.56ms +step:1123/1555 train_time:57930ms step_avg:51.58ms +step:1124/1555 train_time:58021ms step_avg:51.62ms +step:1125/1555 train_time:58105ms step_avg:51.65ms +step:1126/1555 train_time:58194ms step_avg:51.68ms +step:1127/1555 train_time:58279ms step_avg:51.71ms +step:1128/1555 train_time:58368ms step_avg:51.74ms +step:1129/1555 train_time:58452ms step_avg:51.77ms +step:1130/1555 train_time:58542ms step_avg:51.81ms +step:1131/1555 train_time:58625ms step_avg:51.83ms +step:1132/1555 train_time:58714ms step_avg:51.87ms +step:1133/1555 train_time:58798ms step_avg:51.90ms +step:1134/1555 train_time:58890ms step_avg:51.93ms +step:1135/1555 train_time:58974ms 
step_avg:51.96ms +step:1136/1555 train_time:59065ms step_avg:51.99ms +step:1137/1555 train_time:59148ms step_avg:52.02ms +step:1138/1555 train_time:59237ms step_avg:52.05ms +step:1139/1555 train_time:59322ms step_avg:52.08ms +step:1140/1555 train_time:59411ms step_avg:52.12ms +step:1141/1555 train_time:59497ms step_avg:52.14ms +step:1142/1555 train_time:59586ms step_avg:52.18ms +step:1143/1555 train_time:59669ms step_avg:52.20ms +step:1144/1555 train_time:59761ms step_avg:52.24ms +step:1145/1555 train_time:59845ms step_avg:52.27ms +step:1146/1555 train_time:59934ms step_avg:52.30ms +step:1147/1555 train_time:60018ms step_avg:52.33ms +step:1148/1555 train_time:60108ms step_avg:52.36ms +step:1149/1555 train_time:60192ms step_avg:52.39ms +step:1150/1555 train_time:60282ms step_avg:52.42ms +step:1151/1555 train_time:60365ms step_avg:52.45ms +step:1152/1555 train_time:60455ms step_avg:52.48ms +step:1153/1555 train_time:60539ms step_avg:52.51ms +step:1154/1555 train_time:60629ms step_avg:52.54ms +step:1155/1555 train_time:60713ms step_avg:52.57ms +step:1156/1555 train_time:60804ms step_avg:52.60ms +step:1157/1555 train_time:60887ms step_avg:52.62ms +step:1158/1555 train_time:60976ms step_avg:52.66ms +step:1159/1555 train_time:61061ms step_avg:52.68ms +step:1160/1555 train_time:61150ms step_avg:52.72ms +step:1161/1555 train_time:61235ms step_avg:52.74ms +step:1162/1555 train_time:61325ms step_avg:52.78ms +step:1163/1555 train_time:61409ms step_avg:52.80ms +step:1164/1555 train_time:61499ms step_avg:52.83ms +step:1165/1555 train_time:61584ms step_avg:52.86ms +step:1166/1555 train_time:61673ms step_avg:52.89ms +step:1167/1555 train_time:61758ms step_avg:52.92ms +step:1168/1555 train_time:61849ms step_avg:52.95ms +step:1169/1555 train_time:61932ms step_avg:52.98ms +step:1170/1555 train_time:62023ms step_avg:53.01ms +step:1171/1555 train_time:62107ms step_avg:53.04ms +step:1172/1555 train_time:62197ms step_avg:53.07ms +step:1173/1555 train_time:62281ms step_avg:53.10ms 
+step:1174/1555 train_time:62370ms step_avg:53.13ms +step:1175/1555 train_time:62454ms step_avg:53.15ms +step:1176/1555 train_time:62545ms step_avg:53.18ms +step:1177/1555 train_time:62628ms step_avg:53.21ms +step:1178/1555 train_time:62718ms step_avg:53.24ms +step:1179/1555 train_time:62803ms step_avg:53.27ms +step:1180/1555 train_time:62891ms step_avg:53.30ms +step:1181/1555 train_time:62976ms step_avg:53.32ms +step:1182/1555 train_time:63067ms step_avg:53.36ms +step:1183/1555 train_time:63150ms step_avg:53.38ms +step:1184/1555 train_time:63240ms step_avg:53.41ms +step:1185/1555 train_time:63324ms step_avg:53.44ms +step:1186/1555 train_time:63414ms step_avg:53.47ms +step:1187/1555 train_time:63498ms step_avg:53.49ms +step:1188/1555 train_time:63588ms step_avg:53.53ms +step:1189/1555 train_time:63671ms step_avg:53.55ms +step:1190/1555 train_time:63762ms step_avg:53.58ms +step:1191/1555 train_time:63845ms step_avg:53.61ms +step:1192/1555 train_time:63934ms step_avg:53.64ms +step:1193/1555 train_time:64020ms step_avg:53.66ms +step:1194/1555 train_time:64109ms step_avg:53.69ms +step:1195/1555 train_time:64194ms step_avg:53.72ms +step:1196/1555 train_time:64284ms step_avg:53.75ms +step:1197/1555 train_time:64368ms step_avg:53.77ms +step:1198/1555 train_time:64457ms step_avg:53.80ms +step:1199/1555 train_time:64542ms step_avg:53.83ms +step:1200/1555 train_time:64631ms step_avg:53.86ms +step:1201/1555 train_time:64715ms step_avg:53.88ms +step:1202/1555 train_time:64806ms step_avg:53.92ms +step:1203/1555 train_time:64889ms step_avg:53.94ms +step:1204/1555 train_time:64981ms step_avg:53.97ms +step:1205/1555 train_time:65064ms step_avg:53.99ms +step:1206/1555 train_time:65153ms step_avg:54.02ms +step:1207/1555 train_time:65237ms step_avg:54.05ms +step:1208/1555 train_time:65327ms step_avg:54.08ms +step:1209/1555 train_time:65411ms step_avg:54.10ms +step:1210/1555 train_time:65502ms step_avg:54.13ms +step:1211/1555 train_time:65585ms step_avg:54.16ms +step:1212/1555 
train_time:65674ms step_avg:54.19ms +step:1213/1555 train_time:65758ms step_avg:54.21ms +step:1214/1555 train_time:65849ms step_avg:54.24ms +step:1215/1555 train_time:65934ms step_avg:54.27ms +step:1216/1555 train_time:66024ms step_avg:54.30ms +step:1217/1555 train_time:66107ms step_avg:54.32ms +step:1218/1555 train_time:66198ms step_avg:54.35ms +step:1219/1555 train_time:66281ms step_avg:54.37ms +step:1220/1555 train_time:66371ms step_avg:54.40ms +step:1221/1555 train_time:66456ms step_avg:54.43ms +step:1222/1555 train_time:66547ms step_avg:54.46ms +step:1223/1555 train_time:66629ms step_avg:54.48ms +step:1224/1555 train_time:66720ms step_avg:54.51ms +step:1225/1555 train_time:66805ms step_avg:54.53ms +step:1226/1555 train_time:66894ms step_avg:54.56ms +step:1227/1555 train_time:66978ms step_avg:54.59ms +step:1228/1555 train_time:67067ms step_avg:54.62ms +step:1229/1555 train_time:67151ms step_avg:54.64ms +step:1230/1555 train_time:67242ms step_avg:54.67ms +step:1231/1555 train_time:67325ms step_avg:54.69ms +step:1232/1555 train_time:67416ms step_avg:54.72ms +step:1233/1555 train_time:67501ms step_avg:54.75ms +step:1234/1555 train_time:67590ms step_avg:54.77ms +step:1235/1555 train_time:67676ms step_avg:54.80ms +step:1236/1555 train_time:67766ms step_avg:54.83ms +step:1237/1555 train_time:67849ms step_avg:54.85ms +step:1238/1555 train_time:67938ms step_avg:54.88ms +step:1239/1555 train_time:68023ms step_avg:54.90ms +step:1240/1555 train_time:68112ms step_avg:54.93ms +step:1241/1555 train_time:68196ms step_avg:54.95ms +step:1242/1555 train_time:68287ms step_avg:54.98ms +step:1243/1555 train_time:68371ms step_avg:55.00ms +step:1244/1555 train_time:68461ms step_avg:55.03ms +step:1245/1555 train_time:68545ms step_avg:55.06ms +step:1246/1555 train_time:68634ms step_avg:55.08ms +step:1247/1555 train_time:68718ms step_avg:55.11ms +step:1248/1555 train_time:68807ms step_avg:55.13ms +step:1249/1555 train_time:68892ms step_avg:55.16ms +step:1250/1555 train_time:68982ms 
step_avg:55.19ms +step:1250/1555 val_loss:3.3982 train_time:69096ms step_avg:55.28ms +step:1251/1555 train_time:69116ms step_avg:55.25ms +step:1252/1555 train_time:69158ms step_avg:55.24ms +step:1253/1555 train_time:69246ms step_avg:55.26ms +step:1254/1555 train_time:69339ms step_avg:55.29ms +step:1255/1555 train_time:69423ms step_avg:55.32ms +step:1256/1555 train_time:69512ms step_avg:55.34ms +step:1257/1555 train_time:69595ms step_avg:55.37ms +step:1258/1555 train_time:69683ms step_avg:55.39ms +step:1259/1555 train_time:69767ms step_avg:55.41ms +step:1260/1555 train_time:69856ms step_avg:55.44ms +step:1261/1555 train_time:69939ms step_avg:55.46ms +step:1262/1555 train_time:70029ms step_avg:55.49ms +step:1263/1555 train_time:70114ms step_avg:55.51ms +step:1264/1555 train_time:70207ms step_avg:55.54ms +step:1265/1555 train_time:70294ms step_avg:55.57ms +step:1266/1555 train_time:70387ms step_avg:55.60ms +step:1267/1555 train_time:70472ms step_avg:55.62ms +step:1268/1555 train_time:70561ms step_avg:55.65ms +step:1269/1555 train_time:70645ms step_avg:55.67ms +step:1270/1555 train_time:70734ms step_avg:55.70ms +step:1271/1555 train_time:70816ms step_avg:55.72ms +step:1272/1555 train_time:70906ms step_avg:55.74ms +step:1273/1555 train_time:70990ms step_avg:55.77ms +step:1274/1555 train_time:71080ms step_avg:55.79ms +step:1275/1555 train_time:71167ms step_avg:55.82ms +step:1276/1555 train_time:71259ms step_avg:55.85ms +step:1277/1555 train_time:71345ms step_avg:55.87ms +step:1278/1555 train_time:71435ms step_avg:55.90ms +step:1279/1555 train_time:71520ms step_avg:55.92ms +step:1280/1555 train_time:71610ms step_avg:55.95ms +step:1281/1555 train_time:71693ms step_avg:55.97ms +step:1282/1555 train_time:71782ms step_avg:55.99ms +step:1283/1555 train_time:71865ms step_avg:56.01ms +step:1284/1555 train_time:71954ms step_avg:56.04ms +step:1285/1555 train_time:72037ms step_avg:56.06ms +step:1286/1555 train_time:72128ms step_avg:56.09ms +step:1287/1555 train_time:72213ms 
step_avg:56.11ms +step:1288/1555 train_time:72305ms step_avg:56.14ms +step:1289/1555 train_time:72390ms step_avg:56.16ms +step:1290/1555 train_time:72480ms step_avg:56.19ms +step:1291/1555 train_time:72564ms step_avg:56.21ms +step:1292/1555 train_time:72655ms step_avg:56.23ms +step:1293/1555 train_time:72737ms step_avg:56.25ms +step:1294/1555 train_time:72828ms step_avg:56.28ms +step:1295/1555 train_time:72911ms step_avg:56.30ms +step:1296/1555 train_time:73000ms step_avg:56.33ms +step:1297/1555 train_time:73084ms step_avg:56.35ms +step:1298/1555 train_time:73175ms step_avg:56.38ms +step:1299/1555 train_time:73260ms step_avg:56.40ms +step:1300/1555 train_time:73351ms step_avg:56.42ms +step:1301/1555 train_time:73435ms step_avg:56.44ms +step:1302/1555 train_time:73524ms step_avg:56.47ms +step:1303/1555 train_time:73609ms step_avg:56.49ms +step:1304/1555 train_time:73697ms step_avg:56.52ms +step:1305/1555 train_time:73781ms step_avg:56.54ms +step:1306/1555 train_time:73871ms step_avg:56.56ms +step:1307/1555 train_time:73954ms step_avg:56.58ms +step:1308/1555 train_time:74044ms step_avg:56.61ms +step:1309/1555 train_time:74128ms step_avg:56.63ms +step:1310/1555 train_time:74217ms step_avg:56.65ms +step:1311/1555 train_time:74302ms step_avg:56.68ms +step:1312/1555 train_time:74394ms step_avg:56.70ms +step:1313/1555 train_time:74477ms step_avg:56.72ms +step:1314/1555 train_time:74567ms step_avg:56.75ms +step:1315/1555 train_time:74651ms step_avg:56.77ms +step:1316/1555 train_time:74741ms step_avg:56.79ms +step:1317/1555 train_time:74825ms step_avg:56.81ms +step:1318/1555 train_time:74914ms step_avg:56.84ms +step:1319/1555 train_time:74998ms step_avg:56.86ms +step:1320/1555 train_time:75088ms step_avg:56.88ms +step:1321/1555 train_time:75172ms step_avg:56.91ms +step:1322/1555 train_time:75261ms step_avg:56.93ms +step:1323/1555 train_time:75346ms step_avg:56.95ms +step:1324/1555 train_time:75437ms step_avg:56.98ms +step:1325/1555 train_time:75520ms step_avg:57.00ms 
+step:1326/1555 train_time:75610ms step_avg:57.02ms +step:1327/1555 train_time:75694ms step_avg:57.04ms +step:1328/1555 train_time:75783ms step_avg:57.07ms +step:1329/1555 train_time:75868ms step_avg:57.09ms +step:1330/1555 train_time:75957ms step_avg:57.11ms +step:1331/1555 train_time:76040ms step_avg:57.13ms +step:1332/1555 train_time:76132ms step_avg:57.16ms +step:1333/1555 train_time:76215ms step_avg:57.18ms +step:1334/1555 train_time:76305ms step_avg:57.20ms +step:1335/1555 train_time:76391ms step_avg:57.22ms +step:1336/1555 train_time:76480ms step_avg:57.25ms +step:1337/1555 train_time:76565ms step_avg:57.27ms +step:1338/1555 train_time:76655ms step_avg:57.29ms +step:1339/1555 train_time:76738ms step_avg:57.31ms +step:1340/1555 train_time:76829ms step_avg:57.34ms +step:1341/1555 train_time:76912ms step_avg:57.35ms +step:1342/1555 train_time:77002ms step_avg:57.38ms +step:1343/1555 train_time:77087ms step_avg:57.40ms +step:1344/1555 train_time:77176ms step_avg:57.42ms +step:1345/1555 train_time:77260ms step_avg:57.44ms +step:1346/1555 train_time:77352ms step_avg:57.47ms +step:1347/1555 train_time:77436ms step_avg:57.49ms +step:1348/1555 train_time:77525ms step_avg:57.51ms +step:1349/1555 train_time:77610ms step_avg:57.53ms +step:1350/1555 train_time:77698ms step_avg:57.55ms +step:1351/1555 train_time:77783ms step_avg:57.57ms +step:1352/1555 train_time:77875ms step_avg:57.60ms +step:1353/1555 train_time:77958ms step_avg:57.62ms +step:1354/1555 train_time:78050ms step_avg:57.64ms +step:1355/1555 train_time:78132ms step_avg:57.66ms +step:1356/1555 train_time:78222ms step_avg:57.69ms +step:1357/1555 train_time:78308ms step_avg:57.71ms +step:1358/1555 train_time:78397ms step_avg:57.73ms +step:1359/1555 train_time:78481ms step_avg:57.75ms +step:1360/1555 train_time:78572ms step_avg:57.77ms +step:1361/1555 train_time:78655ms step_avg:57.79ms +step:1362/1555 train_time:78745ms step_avg:57.82ms +step:1363/1555 train_time:78829ms step_avg:57.84ms +step:1364/1555 
train_time:78919ms step_avg:57.86ms +step:1365/1555 train_time:79004ms step_avg:57.88ms +step:1366/1555 train_time:79093ms step_avg:57.90ms +step:1367/1555 train_time:79176ms step_avg:57.92ms +step:1368/1555 train_time:79266ms step_avg:57.94ms +step:1369/1555 train_time:79350ms step_avg:57.96ms +step:1370/1555 train_time:79440ms step_avg:57.99ms +step:1371/1555 train_time:79525ms step_avg:58.00ms +step:1372/1555 train_time:79615ms step_avg:58.03ms +step:1373/1555 train_time:79697ms step_avg:58.05ms +step:1374/1555 train_time:79788ms step_avg:58.07ms +step:1375/1555 train_time:79872ms step_avg:58.09ms +step:1376/1555 train_time:79962ms step_avg:58.11ms +step:1377/1555 train_time:80047ms step_avg:58.13ms +step:1378/1555 train_time:80136ms step_avg:58.15ms +step:1379/1555 train_time:80220ms step_avg:58.17ms +step:1380/1555 train_time:80310ms step_avg:58.20ms +step:1381/1555 train_time:80394ms step_avg:58.21ms +step:1382/1555 train_time:80484ms step_avg:58.24ms +step:1383/1555 train_time:80569ms step_avg:58.26ms +step:1384/1555 train_time:80658ms step_avg:58.28ms +step:1385/1555 train_time:80743ms step_avg:58.30ms +step:1386/1555 train_time:80834ms step_avg:58.32ms +step:1387/1555 train_time:80918ms step_avg:58.34ms +step:1388/1555 train_time:81010ms step_avg:58.36ms +step:1389/1555 train_time:81093ms step_avg:58.38ms +step:1390/1555 train_time:81183ms step_avg:58.41ms +step:1391/1555 train_time:81267ms step_avg:58.42ms +step:1392/1555 train_time:81357ms step_avg:58.45ms +step:1393/1555 train_time:81441ms step_avg:58.46ms +step:1394/1555 train_time:81532ms step_avg:58.49ms +step:1395/1555 train_time:81617ms step_avg:58.51ms +step:1396/1555 train_time:81707ms step_avg:58.53ms +step:1397/1555 train_time:81791ms step_avg:58.55ms +step:1398/1555 train_time:81881ms step_avg:58.57ms +step:1399/1555 train_time:81966ms step_avg:58.59ms +step:1400/1555 train_time:82056ms step_avg:58.61ms +step:1401/1555 train_time:82141ms step_avg:58.63ms +step:1402/1555 train_time:82230ms 
step_avg:58.65ms +step:1403/1555 train_time:82315ms step_avg:58.67ms +step:1404/1555 train_time:82405ms step_avg:58.69ms +step:1405/1555 train_time:82490ms step_avg:58.71ms +step:1406/1555 train_time:82579ms step_avg:58.73ms +step:1407/1555 train_time:82664ms step_avg:58.75ms +step:1408/1555 train_time:82753ms step_avg:58.77ms +step:1409/1555 train_time:82837ms step_avg:58.79ms +step:1410/1555 train_time:82927ms step_avg:58.81ms +step:1411/1555 train_time:83011ms step_avg:58.83ms +step:1412/1555 train_time:83100ms step_avg:58.85ms +step:1413/1555 train_time:83185ms step_avg:58.87ms +step:1414/1555 train_time:83276ms step_avg:58.89ms +step:1415/1555 train_time:83359ms step_avg:58.91ms +step:1416/1555 train_time:83450ms step_avg:58.93ms +step:1417/1555 train_time:83533ms step_avg:58.95ms +step:1418/1555 train_time:83622ms step_avg:58.97ms +step:1419/1555 train_time:83706ms step_avg:58.99ms +step:1420/1555 train_time:83797ms step_avg:59.01ms +step:1421/1555 train_time:83881ms step_avg:59.03ms +step:1422/1555 train_time:83972ms step_avg:59.05ms +step:1423/1555 train_time:84054ms step_avg:59.07ms +step:1424/1555 train_time:84144ms step_avg:59.09ms +step:1425/1555 train_time:84230ms step_avg:59.11ms +step:1426/1555 train_time:84320ms step_avg:59.13ms +step:1427/1555 train_time:84404ms step_avg:59.15ms +step:1428/1555 train_time:84494ms step_avg:59.17ms +step:1429/1555 train_time:84578ms step_avg:59.19ms +step:1430/1555 train_time:84668ms step_avg:59.21ms +step:1431/1555 train_time:84753ms step_avg:59.23ms +step:1432/1555 train_time:84843ms step_avg:59.25ms +step:1433/1555 train_time:84927ms step_avg:59.27ms +step:1434/1555 train_time:85016ms step_avg:59.29ms +step:1435/1555 train_time:85102ms step_avg:59.30ms +step:1436/1555 train_time:85191ms step_avg:59.33ms +step:1437/1555 train_time:85275ms step_avg:59.34ms +step:1438/1555 train_time:85365ms step_avg:59.36ms +step:1439/1555 train_time:85450ms step_avg:59.38ms +step:1440/1555 train_time:85539ms step_avg:59.40ms 
+step:1441/1555 train_time:85624ms step_avg:59.42ms +step:1442/1555 train_time:85714ms step_avg:59.44ms +step:1443/1555 train_time:85798ms step_avg:59.46ms +step:1444/1555 train_time:85889ms step_avg:59.48ms +step:1445/1555 train_time:85973ms step_avg:59.50ms +step:1446/1555 train_time:86062ms step_avg:59.52ms +step:1447/1555 train_time:86146ms step_avg:59.53ms +step:1448/1555 train_time:86236ms step_avg:59.56ms +step:1449/1555 train_time:86322ms step_avg:59.57ms +step:1450/1555 train_time:86411ms step_avg:59.59ms +step:1451/1555 train_time:86495ms step_avg:59.61ms +step:1452/1555 train_time:86585ms step_avg:59.63ms +step:1453/1555 train_time:86670ms step_avg:59.65ms +step:1454/1555 train_time:86759ms step_avg:59.67ms +step:1455/1555 train_time:86844ms step_avg:59.69ms +step:1456/1555 train_time:86934ms step_avg:59.71ms +step:1457/1555 train_time:87018ms step_avg:59.72ms +step:1458/1555 train_time:87107ms step_avg:59.74ms +step:1459/1555 train_time:87191ms step_avg:59.76ms +step:1460/1555 train_time:87281ms step_avg:59.78ms +step:1461/1555 train_time:87365ms step_avg:59.80ms +step:1462/1555 train_time:87455ms step_avg:59.82ms +step:1463/1555 train_time:87540ms step_avg:59.84ms +step:1464/1555 train_time:87630ms step_avg:59.86ms +step:1465/1555 train_time:87713ms step_avg:59.87ms +step:1466/1555 train_time:87802ms step_avg:59.89ms +step:1467/1555 train_time:87887ms step_avg:59.91ms +step:1468/1555 train_time:87977ms step_avg:59.93ms +step:1469/1555 train_time:88062ms step_avg:59.95ms +step:1470/1555 train_time:88152ms step_avg:59.97ms +step:1471/1555 train_time:88236ms step_avg:59.98ms +step:1472/1555 train_time:88325ms step_avg:60.00ms +step:1473/1555 train_time:88409ms step_avg:60.02ms +step:1474/1555 train_time:88499ms step_avg:60.04ms +step:1475/1555 train_time:88584ms step_avg:60.06ms +step:1476/1555 train_time:88675ms step_avg:60.08ms +step:1477/1555 train_time:88758ms step_avg:60.09ms +step:1478/1555 train_time:88849ms step_avg:60.11ms +step:1479/1555 
train_time:88932ms step_avg:60.13ms +step:1480/1555 train_time:89022ms step_avg:60.15ms +step:1481/1555 train_time:89107ms step_avg:60.17ms +step:1482/1555 train_time:89197ms step_avg:60.19ms +step:1483/1555 train_time:89281ms step_avg:60.20ms +step:1484/1555 train_time:89371ms step_avg:60.22ms +step:1485/1555 train_time:89454ms step_avg:60.24ms +step:1486/1555 train_time:89544ms step_avg:60.26ms +step:1487/1555 train_time:89628ms step_avg:60.27ms +step:1488/1555 train_time:89718ms step_avg:60.29ms +step:1489/1555 train_time:89803ms step_avg:60.31ms +step:1490/1555 train_time:89893ms step_avg:60.33ms +step:1491/1555 train_time:89977ms step_avg:60.35ms +step:1492/1555 train_time:90067ms step_avg:60.37ms +step:1493/1555 train_time:90151ms step_avg:60.38ms +step:1494/1555 train_time:90240ms step_avg:60.40ms +step:1495/1555 train_time:90324ms step_avg:60.42ms +step:1496/1555 train_time:90414ms step_avg:60.44ms +step:1497/1555 train_time:90498ms step_avg:60.45ms +step:1498/1555 train_time:90589ms step_avg:60.47ms +step:1499/1555 train_time:90672ms step_avg:60.49ms +step:1500/1555 train_time:90762ms step_avg:60.51ms +step:1500/1555 val_loss:3.2946 train_time:90878ms step_avg:60.59ms +step:1501/1555 train_time:90898ms step_avg:60.56ms +step:1502/1555 train_time:90940ms step_avg:60.55ms +step:1503/1555 train_time:91028ms step_avg:60.56ms +step:1504/1555 train_time:91121ms step_avg:60.59ms +step:1505/1555 train_time:91206ms step_avg:60.60ms +step:1506/1555 train_time:91295ms step_avg:60.62ms +step:1507/1555 train_time:91379ms step_avg:60.64ms +step:1508/1555 train_time:91468ms step_avg:60.66ms +step:1509/1555 train_time:91551ms step_avg:60.67ms +step:1510/1555 train_time:91641ms step_avg:60.69ms +step:1511/1555 train_time:91723ms step_avg:60.70ms +step:1512/1555 train_time:91812ms step_avg:60.72ms +step:1513/1555 train_time:91897ms step_avg:60.74ms +step:1514/1555 train_time:91990ms step_avg:60.76ms +step:1515/1555 train_time:92077ms step_avg:60.78ms +step:1516/1555 
train_time:92172ms step_avg:60.80ms +step:1517/1555 train_time:92258ms step_avg:60.82ms +step:1518/1555 train_time:92347ms step_avg:60.83ms +step:1519/1555 train_time:92431ms step_avg:60.85ms +step:1520/1555 train_time:92521ms step_avg:60.87ms +step:1521/1555 train_time:92604ms step_avg:60.88ms +step:1522/1555 train_time:92692ms step_avg:60.90ms +step:1523/1555 train_time:92776ms step_avg:60.92ms +step:1524/1555 train_time:92868ms step_avg:60.94ms +step:1525/1555 train_time:92953ms step_avg:60.95ms +step:1526/1555 train_time:93046ms step_avg:60.97ms +step:1527/1555 train_time:93131ms step_avg:60.99ms +step:1528/1555 train_time:93222ms step_avg:61.01ms +step:1529/1555 train_time:93306ms step_avg:61.02ms +step:1530/1555 train_time:93396ms step_avg:61.04ms +step:1531/1555 train_time:93480ms step_avg:61.06ms +step:1532/1555 train_time:93572ms step_avg:61.08ms +step:1533/1555 train_time:93655ms step_avg:61.09ms +step:1534/1555 train_time:93746ms step_avg:61.11ms +step:1535/1555 train_time:93830ms step_avg:61.13ms +step:1536/1555 train_time:93920ms step_avg:61.15ms +step:1537/1555 train_time:94005ms step_avg:61.16ms +step:1538/1555 train_time:94098ms step_avg:61.18ms +step:1539/1555 train_time:94185ms step_avg:61.20ms +step:1540/1555 train_time:94275ms step_avg:61.22ms +step:1541/1555 train_time:94359ms step_avg:61.23ms +step:1542/1555 train_time:94449ms step_avg:61.25ms +step:1543/1555 train_time:94533ms step_avg:61.27ms +step:1544/1555 train_time:94623ms step_avg:61.28ms +step:1545/1555 train_time:94707ms step_avg:61.30ms +step:1546/1555 train_time:94797ms step_avg:61.32ms +step:1547/1555 train_time:94882ms step_avg:61.33ms +step:1548/1555 train_time:94973ms step_avg:61.35ms +step:1549/1555 train_time:95059ms step_avg:61.37ms +step:1550/1555 train_time:95150ms step_avg:61.39ms +step:1551/1555 train_time:95235ms step_avg:61.40ms +step:1552/1555 train_time:95326ms step_avg:61.42ms +step:1553/1555 train_time:95410ms step_avg:61.44ms +step:1554/1555 train_time:95500ms 
step_avg:61.45ms +step:1555/1555 train_time:95585ms step_avg:61.47ms +step:1555/1555 val_loss:3.2787 train_time:95698ms step_avg:61.54ms +peak memory allocated: 31630 MiB reserved: 46618 MiB diff --git a/records/track_1_short/2026-01-31-BigramHashH2D/c8648449-1553-4f21-852c-9aa5f1293dad.txt b/records/track_1_short/2026-01-31-BigramHashH2D/c8648449-1553-4f21-852c-9aa5f1293dad.txt new file mode 100644 index 000000000..7752ad620 --- /dev/null +++ b/records/track_1_short/2026-01-31-BigramHashH2D/c8648449-1553-4f21-852c-9aa5f1293dad.txt @@ -0,0 +1,3976 @@ +import os +import sys + +# Read the current file and the kernels file code ASAP, for logging +with open(sys.argv[0], 'r') as f: + code = f.read() +with open(os.path.join(os.path.dirname(sys.argv[0]), 'triton_kernels.py'), 'r') as f: + code += f"\n\n{'-'*40}\n# triton_kernels.py\n{'-'*40}\n\n" + code += f.read() + +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate, pairwise +from pathlib import Path +import gc + +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" +import torch +import triton + +torch.empty( + 1, device=f"cuda:{os.environ['LOCAL_RANK']}", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +from kernels import get_kernel +from torch import Tensor, nn + +from triton_kernels import XXT, ba_plus_cAA, FusedLinearReLUSquareFunction, FusedSoftcappedCrossEntropy + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Distributed training setup +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor 
of 8" +grad_accum_steps = 8 // world_size +grad_scale = 2 / grad_accum_steps # consistent grad magnitudes between different num_devices +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng +# Transposed layout by @ChrisJMcCormick allows for faster gradient accumulation. + +@torch.library.custom_op("nanogpt::mm_t", mutates_args=()) +def mm_t_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + """Computes y = x @ w with F8 weights stored as (in_features, out_features).""" + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + assert x.shape[1] == w.shape[0] # x: (batch, in), w: (in, out) + + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + + # _scaled_mm requires column-major B. w_f8 is row-major (in, out). + # .T.contiguous().T creates a column-major view without changing logical shape. 
+ w_f8_col_major = w_f8.T.contiguous().T + + out = torch._scaled_mm( + x_f8, + w_f8_col_major, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_t_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[0] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_t_backward", mutates_args=()) +def mm_t_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + + x_scale = grad.new_tensor(x_s, dtype=torch.float32) + w_scale = grad.new_tensor(w_s, dtype=torch.float32) + grad_scale = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + + # grad_x = grad @ w.T + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=grad_scale, + scale_b=w_scale, + use_fast_accum=False, + ) + + # grad_w = x.T @ grad + # Result is (in, out), naturally matching weight storage. No final .T needed. 
+ grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_scale, + scale_b=grad_scale, + use_fast_accum=False, + ) + + return grad_x, grad_w + + grad_x, grad_w = impl(g, x_f8, w_f8) + + return grad_x, grad_w + +@mm_t_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.to(torch.float32) + +def backward_t(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_t_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context_t(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_t_op.register_autograd(backward_t, setup_context=setup_context_t) + +# ----------------------------------------------------------------------------- +# Polar Express + +# Computed for num_iters=5, safety_factor=2e-2, cushion=2 +polar_express_coeffs = [ + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323) +] + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def polar_express(G: torch.Tensor, split_baddbmm: bool = False): + """ + Polar Express Sign Method: https://arxiv.org/pdf/2505.16932 + by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. 
+ """ + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + # Select batched vs unbatched + if split_baddbmm: + BX_matmul = torch.bmm if X.ndim > 2 else torch.mm + else: + aX_plus_BX = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the iterations + for a, b, c in polar_express_coeffs: + XXT(X, out=A) # A = X @ X.mT + ba_plus_cAA(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + + # Referencing X twice causes pytorch to make a defensive copy, + # resulting in a cudaMemcpyAsync in baddbmm. + # For large matrices (i.e., the mlp weights), it's faster to split + # the operation into two kernels to avoid this. + if split_baddbmm: + BX_matmul(B, X, out=C) # C = B @ X + C.add_(X, alpha=a) # C = C + a*X (in-place, X only read) + else: + aX_plus_BX(X, B, X, beta=a, out=C) # C = a * X + B @ X + + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +# ----------------------------------------------------------------------------- +# Combined NorMuon + Adam Optimizer + +@dataclass +class ParamConfig: + """Per-parameter configuration for NorMuonAndAdam optimizer.""" + label: str + optim: str # "adam" or "normuon" + comms: str # "none", "replicated", or "sharded" + adam_betas: tuple[float, float] | None + lr_mul: float + wd_mul: float + lr: float + initial_lr: float + weight_decay: float + # Adam-specific + eps: float | None = None + # NorMuon-specific + reshape: tuple | None = None + chunk_size: int | None = None + momentum: float | None = None + beta2: float | None = None + per_matrix_lr_mul: list[float] | None = None + + +class NorMuonAndAdam: + """ + Combined optimizer that handles both NorMuon (for projection 
matrices) and + Adam (for embeddings/scalars/gate weights). + + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, Muon uses a Newton-Schulz iteration (replaced + here with Polar Express), which has the advantage that it can be stably run in bfloat16 on the GPU. + + Muon is applied only to the projection matrices in the attention and MLP layers, and is not recommended + for embeddings, scalars, or individual weight vectors (e.g., bias terms or gate weights). + + Differences from standard Muon: + - Newton-Shulz is replaced with Polar Express for the orthogonalization step + - NorMuon adds a low-rank variance estimator similar to Adafactor. https://arxiv.org/pdf/2510.05491 + - Cautious weight decay, a gated version of decoupled weight decay + - Mantissa tracking for precision + + Adam (for embeddings/scalars/gates): + - Standard Adam with bias correction + - Cautious weight decay + + Configuration: + Unlike torch.optim.Optimizer, this class uses per-parameter configs from a `param_table` dict + and does not include parameter "groups". All parameters require a .label attribute, and a + corresponding entry in the param_table to specify their hyperparameters (lr_mul, wd_mul, adam_betas, etc.). + + Communication and ordering: + Gradient communication is explicitly scheduled rather than hook-driven. + Reductions are launched in `scatter_order`, while update math and final + gathers are executed in `work_order`. These orders are independent and + must each contain every parameter label exactly once. + + Two communication modes are supported per parameter: + - 'replicated': Gradients are all-reduced and each rank computes the full update. 
+ - 'sharded': Gradients are reduce-scattered, each rank updates its shard, + and results are all-gathered. + + Adam parameters may be freely sharded. NorMuon operates on full matrices; sharding is + supported by grouping matrices into parameter banks. NorMuon parameters must have a + `.reshape` attribute that reshapes the bank so that the leading dimension is divisible + by world_size. + + # Contributors include @YouJiacheng, @KonstantinWilleke, @alexrgilbert, @adricarda, + # @tuttyfrutyee, @vdlad, @ryanyang0, @vagrawal, @varunneal, @chrisjmccormick + """ + def __init__(self, named_params, param_table: dict, scatter_order: list, work_order: list, + adam_defaults: dict, normuon_defaults: dict): + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + + # Store defaults for each optimizer type + self.adam_defaults = adam_defaults + self.normuon_defaults = normuon_defaults + self.param_table = param_table + self.scatter_order = scatter_order + self.work_order = work_order + + # Collect params by label and build config + self.param_cfgs: dict[nn.Parameter, ParamConfig] = {} + self.param_states: dict[nn.Parameter, dict] = {} + self._param_by_label: dict[str, nn.Parameter] = {} + for name, param in named_params: + label = getattr(param, "label", None) + assert label is not None and label in param_table # all params must have valid label + assert label not in self._param_by_label # exactly one param per label + self._param_by_label[label] = param + self._build_param_cfg(param, label) + + # Assert scatter_order and work_order match present labels exactly + present = set(self._param_by_label.keys()) + assert set(scatter_order) == present and set(work_order) == present + + # Handle world_size=1: overwrite comms to "none" + if self.world_size == 1: + for p_cfg in self.param_cfgs.values(): + p_cfg.comms = "none" + + # Initialize state for all params + self._init_state() + + # 0-D CPU tensors to avoid recompilation + self._step_size_t = torch.tensor(0.0, 
dtype=torch.float32, device="cpu") + self._eff_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._eff_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + + # Track async operations + self._reduce_futures: dict[nn.Parameter, tuple] = {} + + # Embed/lm_head tying state + self.split_embed = False + self._lm_head_param = self._param_by_label.get("lm_head") + self._embed_param = self._param_by_label.get("embed") + + def _build_param_cfg(self, param: nn.Parameter, label: str): + """Build config for a single parameter from param_table.""" + table_entry = self.param_table[label] + optim = table_entry["optim"] + comms = table_entry["comms"] + adam_betas = table_entry.get("adam_betas") + lr_mul = table_entry.get("lr_mul", 1.0) + wd_mul = table_entry.get("wd_mul", 1.0) + + if optim == "adam": + chunk_size = param.shape[0] // self.world_size if comms == "sharded" else None + p_cfg = ParamConfig( + label=label, + optim=optim, + comms=comms, + adam_betas=tuple(adam_betas) if adam_betas else None, + lr_mul=lr_mul, + wd_mul=wd_mul, + lr=self.adam_defaults["lr"], + initial_lr=self.adam_defaults["lr"], + weight_decay=self.adam_defaults["weight_decay"], + eps=self.adam_defaults["eps"], + chunk_size=chunk_size, + ) + elif optim == "normuon": + reshape = getattr(param, "reshape", None) + if reshape is None: + raise ValueError(f"NorMuon param {label} must have .reshape attribute") + if reshape[0] % self.world_size != 0: + raise ValueError(f"reshape[0]={reshape[0]} must be divisible by world_size") + + chunk_size = reshape[0] // self.world_size + chunk_shape = (chunk_size, *reshape[1:]) + # Shape-based LR multiplier for NorMuon + shape_mult = max(1.0, chunk_shape[-2] / chunk_shape[-1]) ** 0.5 if len(chunk_shape) >= 2 else 1.0 + lr_mul = shape_mult * lr_mul + + # Per-matrix LR multipliers for MLP c_proj (2x LR on odd indices) + per_matrix_lr_mul = None + if label == "mlp": + rank = dist.get_rank() if dist.is_initialized() else 0 + start_idx = rank * chunk_size + 
per_matrix_lr_mul = [] + for i in range(chunk_size): + global_idx = start_idx + i + is_c_proj = (global_idx % 2 == 1) + per_matrix_lr_mul.append(2.0 if is_c_proj else 1.0) + + p_cfg = ParamConfig( + label=label, + optim=optim, + comms=comms, + adam_betas=tuple(adam_betas) if adam_betas else None, + lr_mul=lr_mul, + wd_mul=wd_mul, + lr=self.normuon_defaults["lr"], + initial_lr=self.normuon_defaults["lr"], + weight_decay=self.normuon_defaults["weight_decay"], + reshape=reshape, + chunk_size=chunk_size, + momentum=self.normuon_defaults["momentum"], + beta2=self.normuon_defaults["beta2"], + per_matrix_lr_mul=per_matrix_lr_mul, + ) + else: + raise ValueError(f"Unknown optim type: {optim}") + + self.param_cfgs[param] = p_cfg + + def _init_state(self): + """Initialize optimizer state for all parameters.""" + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "adam": + # Sharded params use chunk state, replicated use full state + if p_cfg.comms == "sharded": + chunk = param[:p_cfg.chunk_size] + else: + chunk = param + exp_avg = torch.zeros_like(chunk, dtype=torch.float32, device=param.device) + self.param_states[param] = dict(step=0, exp_avg=exp_avg, exp_avg_sq=torch.zeros_like(exp_avg)) + + elif p_cfg.optim == "normuon": + chunk_shape = (p_cfg.chunk_size, *p_cfg.reshape[1:]) + + # Momentum buffer (FP32 for precision) + momentum_buffer = torch.zeros( + chunk_shape, dtype=torch.float32, device=param.device + ) + + # Second momentum buffer - reduced along one dimension + if chunk_shape[-2] >= chunk_shape[-1]: + second_mom_shape = (*chunk_shape[:-1], 1) + else: + second_mom_shape = (*chunk_shape[:-2], 1, chunk_shape[-1]) + second_momentum_buffer = torch.zeros( + second_mom_shape, dtype=torch.float32, device=param.device + ) + + # Mantissa buffer for precision tracking + mantissa = torch.zeros( + chunk_shape, dtype=torch.uint16, device=param.device + ) + + self.param_states[param] = dict( + momentum_buffer=momentum_buffer, + 
second_momentum_buffer=second_momentum_buffer, + mantissa=mantissa, + ) + + # ----------------------------------- + # Reduce/Gather operations + + def _launch_reduce(self, param: nn.Parameter, grad: Tensor): + """Launch async reduce for a parameter based on its comms policy.""" + p_cfg = self.param_cfgs[param] + + if p_cfg.comms == "none": + if p_cfg.optim == "normuon": + # NorMuon needs reshaped gradient even without communication + grad = grad.view(p_cfg.reshape) + self._reduce_futures[param] = (None, grad) + elif p_cfg.comms == "replicated": + future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future() + self._reduce_futures[param] = (future, grad) + elif p_cfg.comms == "sharded": + if p_cfg.optim == "normuon": + # NorMuon: reshape before reduce_scatter + grad_reshaped = grad.view(p_cfg.reshape) + grad_chunk = torch.empty( + (p_cfg.chunk_size, *grad_reshaped.shape[1:]), + dtype=grad.dtype, + device=grad.device + ) + future = dist.reduce_scatter_tensor( + grad_chunk, grad_reshaped.contiguous(), op=dist.ReduceOp.AVG, async_op=True + ).get_future() + self._reduce_futures[param] = (future, grad_chunk) + else: + # Adam: simple reduce_scatter + grad_chunk = torch.empty_like(grad[:p_cfg.chunk_size]) + future = dist.reduce_scatter_tensor( + grad_chunk, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + self._reduce_futures[param] = (future, grad_chunk) + + def _launch_gather(self, param: nn.Parameter, p_slice: Tensor) -> "torch.futures.Future": + """Launch async all_gather for a sharded parameter.""" + p_cfg = self.param_cfgs[param] + if p_cfg.optim == "normuon": + full_param = param.data.view(p_cfg.reshape) + assert full_param.is_contiguous() + return dist.all_gather_into_tensor( + full_param, p_slice.contiguous(), async_op=True + ).get_future() + else: + return dist.all_gather_into_tensor( + param, p_slice.contiguous(), async_op=True + ).get_future() + + # ----------------------------------- + # State management + + def reset(self): + 
"""Reset NorMuon momentum buffers and split_embed state (called on training reset).""" + self.split_embed = False + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "normuon": + p_state = self.param_states[param] + p_state["momentum_buffer"].zero_() + p_state["mantissa"].zero_() + p_state["second_momentum_buffer"].zero_() + + def copy_lm_state_to_embed(self): + """ + Copy the optimizer state from the lm_head to the embed at the untie point. + This requires an all-gather + reshard because of different sharding: + - lm_head (768, 50304) is sharded to (96, 50304) per rank (along model_dim) + - embed (50304, 768) is sharded to (6288, 768) per rank (along vocab_size) + + We all-gather the lm_head momentum, transpose it, then each rank takes their + embed shard to get the correct momentum state. + """ + lm_head = self._lm_head_param + embed = self._embed_param + lm_state = self.param_states[lm_head] + embed_state = self.param_states[embed] + lm_cfg = self.param_cfgs[lm_head] + embed_cfg = self.param_cfgs[embed] + + embed_state['step'] = lm_state['step'] # Preserve step count for bias correction + + # Copy optimizer state with all-gather + transpose + reshard + if self.world_size > 1: + rank = dist.get_rank() + lm_chunk_size = lm_cfg.chunk_size # 96 + embed_chunk_size = embed_cfg.chunk_size # 6288 + + # All-gather lm_head momentum to get full (768, 50304) tensor + for key in ["exp_avg", "exp_avg_sq"]: + lm_chunk = lm_state[key] # (96, 50304) + full_lm = torch.empty(lm_head.shape[0], lm_head.shape[1], dtype=lm_chunk.dtype, device=lm_chunk.device) + dist.all_gather_into_tensor(full_lm, lm_chunk.contiguous()) + embed_state[key].copy_(full_lm.T[rank * embed_chunk_size:(rank + 1) * embed_chunk_size]) + else: + # Single GPU: simple transpose + for key in ["exp_avg", "exp_avg_sq"]: + embed_state[key].copy_(lm_state[key].T) + + # Mark as split + self.split_embed = True + + def state_dict(self): + """Return the optimizer state as a dict.""" + return { + 
"param_states": {id(p): s for p, s in self.param_states.items()}, + "param_cfgs": {id(p): s for p, s in self.param_cfgs.items()}, + } + + def load_state_dict(self, state_dict): + """Load optimizer state from a dict.""" + # Build id->param mapping + id_to_param = {id(p): p for p in self.param_cfgs.keys()} + + # Load state, preserving dtypes + for param_id, saved_p_state in state_dict["param_states"].items(): + if param_id in id_to_param: + param = id_to_param[param_id] + p_state = self.param_states[param] + for k, v in saved_p_state.items(): + if isinstance(v, torch.Tensor) and k in p_state: + target_dtype = p_state[k].dtype + p_state[k] = v.to(dtype=target_dtype, device=p_state[k].device) + else: + p_state[k] = v + + # ----------------------------------- + # Unified optimizer step with explicit ordering + + @torch.no_grad() + def step(self, do_adam: bool = True): + """ + Combined optimizer step with explicit ordering. + + Args: + do_adam: If True, update Adam params. NorMuon params always updated. + + Flow: + 1. Scatter phase: Launch reduces in scatter_order + 2. Work phase: Process updates in work_order + - Wait for reduce, compute update, launch gather + 3. Finalize phase: Wait for gathers + + While the embeddings are tied: + - Comms and update math are only done on lm_head. + - We add embed.grad.T into lm_head.grad before comms. 
+ - After lm_head gather, we copy lm_head.data.T --> embed.data + """ + rank = dist.get_rank() if dist.is_initialized() else 0 + lm_param, embed_param = self._lm_head_param, self._embed_param + + # ===== Phase 1: Launch reduces in scatter_order ===== + for label in self.scatter_order: + param = self._param_by_label[label] + p_cfg = self.param_cfgs[param] + + if p_cfg.optim == "adam" and not do_adam: + continue + if param.grad is None: + continue + + # lm_head when tied: aggregate embed.grad.T (transposed shapes) + if label == "lm_head" and do_adam and not self.split_embed: + if embed_param is not None and embed_param.grad is not None: + param.grad.add_(embed_param.grad.T) + + # Skip embed when tied (copied from lm_head after gather) + if label == "embed" and not self.split_embed: + continue + + self._launch_reduce(param, param.grad) + + # ===== Phase 2: Process updates in work_order ===== + gather_futures = [] + lm_head_gather_future = None + + for label in self.work_order: + param = self._param_by_label[label] + if param not in self._reduce_futures: + continue + + p_cfg = self.param_cfgs[param] + if p_cfg.optim == "adam" and not do_adam: + continue + # Wait for reduce + future, grad_chunk = self._reduce_futures[param] + if future is not None: + future.wait() + # Apply update based on optim type + if p_cfg.optim == "adam": + p_slice = self._adam_update(param, grad_chunk, p_cfg, rank) + else: + p_slice = self._normuon_update(param, grad_chunk, p_cfg, rank) + # Launch gather for sharded params + if p_cfg.comms == "sharded" and self.world_size > 1: + gather_fut = self._launch_gather(param, p_slice) + if label == "lm_head": + lm_head_gather_future = gather_fut + else: + gather_futures.append(gather_fut) + + # ===== Phase 3: Wait for gathers, sync embed if tied ===== + # Wait for lm_head gather first so we can copy to embed while other gathers complete + if lm_head_gather_future is not None: + lm_head_gather_future.wait() + + # When tied: copy lm_head.T to embed + if 
do_adam and not self.split_embed and embed_param is not None and lm_param is not None: + embed_param.data.copy_(lm_param.data.T) + + # Wait for remaining gathers + for fut in gather_futures: + fut.wait() + + self._reduce_futures.clear() + + # Clear grads for updated params + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "adam" and not do_adam: + continue # Don't clear Adam grads on even steps + param.grad = None + + # ----------------------------------- + # Adam update + + def _adam_update(self, param: nn.Parameter, grad_chunk: Tensor, p_cfg: ParamConfig, rank: int) -> Tensor: + """Apply Adam update to a parameter. Returns the updated p_slice.""" + beta1, beta2 = p_cfg.adam_betas + lr = p_cfg.lr * p_cfg.lr_mul + + # Get parameter slice + if p_cfg.comms == "sharded": + p_slice = param[rank * p_cfg.chunk_size:(rank + 1) * p_cfg.chunk_size] + else: + p_slice = param + + p_state = self.param_states[param] + p_state["step"] += 1 + t = p_state["step"] + + bias1, bias2 = 1 - beta1 ** t, 1 - beta2 ** t + self._step_size_t.fill_(lr * (bias2 ** 0.5 / bias1)) + self._eff_wd_t.fill_(lr * lr * p_cfg.weight_decay * p_cfg.wd_mul) + + NorMuonAndAdam._adam_update_step( + p_slice, grad_chunk, p_state["exp_avg"], p_state["exp_avg_sq"], + beta1, beta2, p_cfg.eps, self._step_size_t, self._eff_wd_t + ) + + return p_slice + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _adam_update_step(p_slice, g_slice, exp_avg, exp_avg_sq, beta1, beta2, eps, step_size_t, eff_wd_t): + """Compiled Adam update step.""" + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + update = exp_avg.div(exp_avg_sq.sqrt().add_(eps)).mul_(step_size_t) + # Cautious weight decay + mask = (update * p_slice) > 0 + update.addcmul_(p_slice, mask, value=eff_wd_t) + p_slice.add_(other=update, alpha=-1.0) + + # ----------------------------------- + # NorMuon update + + def _normuon_update(self, param: nn.Parameter, 
grad_chunk: Tensor, p_cfg: ParamConfig, rank: int) -> Tensor: + """Apply NorMuon update to a parameter. Returns the updated p_slice.""" + chunk_shape = grad_chunk.shape + + p_state = self.param_states[param] + grad_chunk = grad_chunk.float() # FP32 for momentum + + # Momentum update + momentum_buffer = p_state["momentum_buffer"] + momentum_buffer.lerp_(grad_chunk, 1 - p_cfg.momentum) + updated_grads = grad_chunk.lerp_(momentum_buffer, p_cfg.momentum) + + self._eff_lr_t.fill_(p_cfg.lr_mul * p_cfg.lr) + self._eff_wd_t.fill_(p_cfg.wd_mul * p_cfg.weight_decay * p_cfg.lr) + + # Polar Express orthogonalization + is_large_matrix = chunk_shape[-2] > 1024 + v_chunk = polar_express(updated_grads, split_baddbmm=is_large_matrix) + + # Variance reduction + red_dim = -1 if chunk_shape[-2] >= chunk_shape[-1] else -2 + v_chunk = NorMuonAndAdam._apply_normuon_variance_reduction( + v_chunk, p_state["second_momentum_buffer"], p_cfg.beta2, red_dim + ) + + # Update parameter, in place, with cautious weight decay + param_view = param.data.view(p_cfg.reshape) + p_slice = param_view[rank * p_cfg.chunk_size:(rank + 1) * p_cfg.chunk_size] + + # MLP has per-matrix LR multipliers (c_proj gets 2x LR) + if p_cfg.per_matrix_lr_mul is not None: + for mat_idx in range(p_cfg.chunk_size): + self._eff_lr_t.fill_(p_cfg.lr_mul * p_cfg.per_matrix_lr_mul[mat_idx] * p_cfg.lr) + self._eff_wd_t.fill_(p_cfg.wd_mul * p_cfg.weight_decay * p_cfg.lr) + NorMuonAndAdam._cautious_wd_and_update_inplace( + p_slice[mat_idx].view(torch.uint16), p_state["mantissa"][mat_idx], v_chunk[mat_idx], + self._eff_wd_t, self._eff_lr_t + ) + else: + NorMuonAndAdam._cautious_wd_and_update_inplace( + p_slice.view(torch.uint16), p_state["mantissa"], v_chunk, + self._eff_wd_t, self._eff_lr_t + ) + + return p_slice + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _cautious_wd_and_update_inplace(p, mantissa, grad, wd_tensor, lr_tensor): + """ + Cautious weight decay + parameter update. 
wd_tensor and lr_tensor are 0-D CPU tensors. + Mantissa is tracked to enable higher precision updates on bfloat16 parameters. + bfloat16 format: 1 sign bit + 8 exponent bits + 7 mantissa bits = 16 bits total + float32 format: 1 sign bit + 8 exponent bits + 23 mantissa bits = 32 bits total + """ + assert p.dtype == mantissa.dtype == torch.uint16 + grad = grad.float() + wd_factor = wd_tensor.to(torch.float32) + lr_factor = lr_tensor.to(torch.float32) + p_precise_raw = (p.to(torch.uint32) << 16) | mantissa.to(torch.uint32) + p_precise = p_precise_raw.view(torch.float32) + mask = (grad * p_precise) >= 0 + p_precise.copy_(p_precise - (p_precise * mask * wd_factor * lr_factor) - (grad * lr_factor)) + p.copy_((p_precise_raw >> 16).to(torch.uint16)) + mantissa.copy_(p_precise_raw.to(torch.uint16)) + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _apply_normuon_variance_reduction(v_chunk, second_momentum_buffer, beta2, red_dim): + """NorMuon variance reduction. Algebraically fuses the normalization steps to minimize memory ops.""" + v_mean = v_chunk.float().square().mean(dim=red_dim, keepdim=True) + red_dim_size = v_chunk.size(red_dim) + v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True).mul_(red_dim_size) + v_norm = v_norm_sq.sqrt_() + second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2) + step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt_() + scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square() + v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt_() + final_scale = step_size * (v_norm / v_norm_new.clamp_min_(1e-10)) + return v_chunk.mul_(final_scale.type_as(v_chunk)) + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinearT(nn.Module): + """ + Linear layer with transposed weight storage (in_features, out_features) which + 
addresses the slow kernel that was used for gradient accumulation. @chrisjmccormick + """ + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + self.weight = nn.Parameter(torch.empty(in_features, out_features, dtype=torch.bfloat16)) + self.reset_parameters() + + def reset_parameters(self) -> None: + with torch.no_grad(): + nn.init.zeros_(self.weight) # @Grad62304977 and others + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out = torch.ops.nanogpt.mm_t(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return x @ self.weight.type_as(x) + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len, paired=False): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.paired = paired + self.reset() + + def rotary(self, x_BTHD): + assert self.factor1.size(0) >= x_BTHD.size(-3) + factor1, factor2 = ( + self.factor1[None, : x_BTHD.size(-3), None, :], + self.factor2[None, : x_BTHD.size(-3), None, :], + ) + x_flip = x_BTHD.view(*x_BTHD.shape[:-1], x_BTHD.shape[-1] // 2, 2).flip(-1).view(x_BTHD.shape) + return factor1 * x_BTHD + factor2 * x_flip + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + angular_freq = angular_freq.repeat_interleave(2) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//2)]) + t = torch.arange(2*self.max_seq_len, dtype=torch.float32, device=device) + if not self.paired: + theta 
= torch.outer(t, angular_freq) + self.factor1 = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.factor2 = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + else: + t_even = 2 * t + t_odd = 2 * t + 1 + theta1 = torch.outer(t_even, angular_freq) + theta2 = torch.outer(t_odd, angular_freq) + self.factor1 = nn.Buffer( + torch.cat((theta1.cos(), theta2.cos()), dim=-1).to(torch.bfloat16), + persistent=False + ) + self.factor2 = nn.Buffer( + torch.cat((theta1.sin(), theta2.sin()), dim=-1).to(torch.bfloat16), + persistent=False + ) + self.factor2[..., 1::2] *= -1 + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(2*self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + if not self.paired: + theta = torch.outer(t, self.angular_freq) + self.factor1.copy_(theta.cos()) + self.factor2.copy_(theta.sin()) + else: + t_even = 2 * t + t_odd = 2 * t + 1 + theta1 = torch.outer(t_even, self.angular_freq) + theta2 = torch.outer(t_odd, self.angular_freq) + self.factor1.copy_(torch.cat((theta1.cos(), theta2.cos()), dim=-1)) + self.factor2.copy_(torch.cat((theta1.sin(), theta2.sin()), dim=-1)) + self.factor2[..., 1::2] *= -1 + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + yarn: Yarn + key_offset: bool + attn_gate_w: torch.Tensor + ve_gate_w: torch.Tensor + +flash_attn_interface = 
get_kernel('varunneal/flash-attention-3').flash_attn_interface

class CausalSelfAttention(nn.Module):
    def __init__(self, dim: int, head_dim: int, num_heads: int, paired: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dim = dim
        self.hdim = num_heads * head_dim
        self.paired = paired
        assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim"
        # Weights are stored in parameter banks and passed via forward()

    def forward(self, x: Tensor, attn_args: AttnArgs, qkvo_w: Tensor):
        """Varlen causal self-attention over a packed batch (requires B == 1).

        qkvo_w stacks Q, K, V (first 3*dim rows) and O (last dim rows); the
        learnable sa_lambdas scale the QKV and O projections respectively.
        """
        B, T = x.size(0), x.size(1) # batch size, sequence length
        assert B == 1, "varlen sequences requires B == 1"
        assert T % 16 == 0
        # unpack attention args
        yarn = attn_args.yarn
        ve, sa_lambdas, key_offset = attn_args.ve, attn_args.sa_lambdas, attn_args.key_offset
        seqlens, bm_size = attn_args.seqlens, attn_args.bm_size
        # sparse gated attention to enable context based no-op by @classiclarryd
        # only include gates on layers with value embeds used on forward pass
        attn_gate_w, ve_gate_w = attn_args.attn_gate_w, attn_args.ve_gate_w

        q, k, v = F.linear(x, sa_lambdas[0] * qkvo_w[:self.dim * 3].type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
        max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size))

        q, k = norm(q), norm(k) # QK norm @Grad62304977

        if not self.paired:
            q, k = yarn.rotary(q), yarn.rotary(k)

            if key_offset:
                # shift keys forward for the stationary head dims. Enables 1-layer induction.
                k[:, 1:, :, self.head_dim // 2:] = k[:, :-1, :, self.head_dim // 2:]

            if ve is not None:
                # gates read only the first 12 channels of the residual stream
                ve_gate_out = 2 * torch.sigmoid(F.linear(x[..., :12], ve_gate_w)).view(B, T, self.num_heads, 1)
                v = v + ve_gate_out * ve.view_as(v) # @ KoszarskyB & @Grad62304977

        else:
            # Paired heads: adjacent heads' queries attend to each other's keys.
            # Two copies of the input stream are interleaved to achieve this, which:
            # - doubles the length of each sequence
            # - halves the effective window size
            q = q.view(B, T, self.num_heads // 2, self.head_dim * 2)
            k = k.view(B, T, self.num_heads // 2, self.head_dim * 2)
            v = v.reshape(B, T * 2, self.num_heads // 2, self.head_dim)

            q, k = yarn.rotary(q), yarn.rotary(k)

            q = q.view(B, T * 2, self.num_heads // 2, self.head_dim)
            k = k.view(B, T * 2, self.num_heads // 2, self.head_dim)

            if ve is not None:
                ve_gate_out = 2 * torch.sigmoid(F.linear(x[..., :12], ve_gate_w)).view(B, T * 2, self.num_heads // 2, 1)
                v = v + ve_gate_out * ve.view_as(v)

            # doubled stream => doubled cumulative lengths and max length
            seqlens = 2 * seqlens
            max_len = 2 * max_len

        # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng
        y = flash_attn_interface.flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens,
                                                        max_seqlen_q=max_len, max_seqlen_k=max_len,
                                                        causal=True, softmax_scale=yarn.attn_scale, window_size=(bm_size, 0))
        y = y.view(B, T, self.num_heads, self.head_dim)
        y = y * torch.sigmoid(F.linear(x[..., :12], attn_gate_w)).view(B, T, self.num_heads, 1)
        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
        y = F.linear(y, sa_lambdas[1] * qkvo_w[self.dim * 3:].type_as(y)) # sa_lambdas[1] pre-multiplied to O @shenberg
        return y

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        # Weights are stored in parameter banks and passed via forward()

    def forward(self, x: Tensor, c_fc: Tensor, c_proj: Tensor):
        # relu(x)^2:
        # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
        # Fused triton kernel for relu(x @ W1.T)^2 @ W2.T
        return FusedLinearReLUSquareFunction.apply(x, c_fc, c_proj)

class Block(nn.Module):
    def __init__(self, dim: int, head_dim: int, num_heads: int, has_attn: bool, has_mlp: bool, use_paired_head: bool):
        super().__init__()
        # skip attention of blocks.6 (the 7th layer) by @YouJiacheng
        self.attn = CausalSelfAttention(dim, head_dim, num_heads, paired=use_paired_head) if has_attn else None
        # skip MLP blocks for first MLP layer by @EmelyanenkoK
        self.mlp = MLP() if has_mlp else None

    def forward(self, x: Tensor, attn_args: AttnArgs, qkvo_w: Tensor = None, c_fc: Tensor = None, c_proj: Tensor = None):
        # standard pre-norm residual block; either sublayer may be absent
        if self.attn is not None:
            x = x + self.attn(norm(x), attn_args, qkvo_w)
        if self.mlp is not None:
            x = x + self.mlp(norm(x), c_fc, c_proj)
        return x

# -----------------------------------------------------------------------------
# The main model

def next_multiple_of_n(v: float | int, *, n: int):
    """Smallest multiple of n that is >= v."""
    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)

@dataclass
class ForwardScheduleConfig:
    """Per-step schedule inputs to GPT.forward (MTP weights and window sizes in tokens)."""
    mtp_weights: torch.Tensor
    ws_short: int
    ws_long: int

class GPT(nn.Module):
    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int):
        super().__init__()
        self.num_layers = num_layers
        self.vocab_size = next_multiple_of_n(vocab_size, n=128)

        # gates read only the first 12 residual channels (Linear(12, 1))
        self.smear_gate = nn.Linear(12, 1, bias=False)
        nn.init.zeros_(self.smear_gate.weight)
        self.smear_gate.weight.label = 'smear_gate'

        self.skip_gate = nn.Linear(12, 1, bias=False)
        nn.init.zeros_(self.skip_gate.weight)
        self.skip_gate.weight.label = 'skip_gate'

        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
        self.value_embeds = nn.Parameter(torch.zeros(5 * self.vocab_size, model_dim, dtype=torch.bfloat16))
        self.value_embeds.label = 'value_embed'

        # parameter banks for attention and value embedding gate weights
        self.attn_gate_bank = nn.Parameter(torch.zeros(10, num_heads, 12)) # 10 layers
        self.attn_gate_bank.label = 'attn_gate_bank'
        self.ve_gate_bank = nn.Parameter(torch.zeros(5, num_heads, 12)) # 5 unique gates
        self.ve_gate_bank.label = 've_gate_bank'

        # -----------------------------------
        # Parameter banks for sharded optimization, by @chrisjmccormick

        # Identify which layers have attention/MLP
        # Attention is skipped in layer 6 by @YouJiacheng
        self.attn_layer_indices = [i for i in range(num_layers) if i != 6]
        # All layers have MLP (At 11 layers--dropped first layer @EmelyanenkoK)
        self.mlp_layer_indices = list(range(num_layers))

        hdim = num_heads * head_dim
        mlp_hdim = 4 * model_dim

        # Create index mappings: layer_idx -> bank_idx
        self.layer_to_attn_idx = {layer_idx: bank_idx for bank_idx, layer_idx in enumerate(self.attn_layer_indices)}
        self.layer_to_mlp_idx = {layer_idx: bank_idx for bank_idx, layer_idx in enumerate(self.mlp_layer_indices)}

        # Attention bank: stores QKVO weights for all attention layers
        # merged QKVO weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
        # https://x.com/hi_tysam/status/1879699187107033311
        # Simplified layout by @chrisjmccormick
        # Shape: (num_attn_layers, 4*model_dim, hdim) = (10, 3072, 768)
        # Reshape for sharding: (40, 768, 768) for even distribution across 8 GPUs
        self.attn_bank = nn.Parameter(torch.empty(len(self.attn_layer_indices), 4 * model_dim, hdim))
        self.attn_bank.label = 'attn'
        self.attn_bank.reshape = (len(self.attn_layer_indices) * 4, hdim, hdim) # (40, 768, 768)

        # MLP bank: stores c_fc and c_proj for all MLP layers
        # Shape: (num_mlp_layers + padding, 2, mlp_hdim, model_dim) = (12, 2, 3072, 768)
        # We add 1 padding layer (index 11) to get 12*2=24 matrices for even distribution across 8 GPUs
        # Reshape for sharding: (24, 3072, 768)
        num_mlp_with_padding = len(self.mlp_layer_indices) + 1 # 11 + 1 = 12
        self.mlp_bank = nn.Parameter(torch.empty(num_mlp_with_padding, 2, mlp_hdim, model_dim))
        self.mlp_bank.label = 'mlp'
        self.mlp_bank.reshape = (num_mlp_with_padding * 2, mlp_hdim, model_dim) # (24, 3072, 768)

        # improved init scale by @YouJiacheng and @srashedll
        std = 0.5 * model_dim ** -0.5
        bound = (3 ** 0.5) * std
        with torch.no_grad():
            self.attn_bank.uniform_(-bound, bound)
            self.mlp_bank[:, 0, :, :].uniform_(-bound, bound) # c_fc
            self.mlp_bank[:, 1, :, :].zero_() # c_proj - zero init suggested by @Grad62304977

        # Create blocks with has_attn/has_mlp flags
        self.paired_head_layers = [0, 2, 5, 9]
        self.blocks = nn.ModuleList([
            Block(model_dim, head_dim, num_heads,
                  has_attn=(i in self.layer_to_attn_idx),
                  has_mlp=(i in self.layer_to_mlp_idx),
                  use_paired_head=(i in self.paired_head_layers))
            for i in range(num_layers)
        ])
        self.yarn = Yarn(head_dim, max_seq_len)
        self.yarn_paired_head = Yarn(head_dim, max_seq_len, paired=True)
        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
+        use_fp8 = not os.environ.get("DISABLE_FP8", False)
        # Transposed weight storage for faster gradient accumulation
        self.lm_head = CastedLinearT(model_dim, self.vocab_size, use_fp8=use_fp8, x_s=100/448, w_s=1.6/448, grad_s=grad_scale * 0.75/448)

        nn.init.normal_(self.lm_head.weight, mean=0, std=0.005)
        self.lm_head.weight.label = 'lm_head'

        self.embed = nn.Embedding(self.vocab_size, model_dim)
        self.embed.weight.label = 'embed'
        # tie embed to lm_head at init; the optimizer keeps them synced until the split step
        with torch.no_grad():
            self.embed.weight.copy_(self.lm_head.weight.T)

        self.bigram_embed = nn.Embedding(args.bigram_vocab_size, model_dim)
        self.bigram_embed.weight.label = 'bigram_embed'
        nn.init.zeros_(self.bigram_embed.weight)

        # x0_lambdas separated out for different optimizer treatment (no beta smoothing)
        self.x0_lambdas = nn.Parameter(torch.zeros(num_layers))
        self.x0_lambdas.label = 'x0_lambdas'

        # pad scalar vector so it shards evenly across the world size
        pad = (-num_layers * 3 - 3) % dist.get_world_size() # updated: 3*num_layers instead of 4*
        self.scalars = nn.Parameter(
            torch.cat(
                [
                    1.1 * torch.ones(num_layers), # resid lambdas. 1.1 init such that layer i weight is i^(num_layers-i).
                    *[torch.tensor([0.5, 1.0]) for _ in range(num_layers)], # SA lambdas
                    0.1 * torch.ones(num_layers), # bigram lambdas
                    torch.zeros(1), # smear_lambda
                    0.5*torch.ones(1), # backout_lambda
                    -1.5 * torch.ones(1), # skip_lambda -> σ(-1.5) ≈ 0.18
                    torch.ones(pad),
                ]
            )
        )
        self.scalars.label = 'scalars'

    @staticmethod
    @torch.compile(dynamic=False, fullgraph=True)
    def _compute_bigram_hash(x: Tensor, mod: int) -> Tensor:
        """
        Computes bigram hash on GPU for each position using [prev_token, curr_token].
        Mathematically identical to the CPU version but computed on device.
+ """
        # fixed multipliers for the xor-based bigram hash
        rand_int_1 = 36313
        rand_int_2 = 27191
        result = torch.empty_like(x)
        # position 0 has no predecessor, so it maps to the fixed bucket `mod`
        result[0] = mod
        result[1:] = torch.bitwise_xor(rand_int_1 * x[1:], rand_int_2 * x[:-1]) % mod
        return result

    def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, schedule_cfg: ForwardScheduleConfig):
        """Full forward pass over a packed 1-D token sequence; returns the loss.

        Training uses the fused softcapped MTP cross-entropy (summed); eval uses
        plain mean cross-entropy on softcapped logits.
        """
        assert input_seq.ndim == 1

        # unpack schedule_cfg
        mtp_weights, ws_short, ws_long = schedule_cfg.mtp_weights, schedule_cfg.ws_short, schedule_cfg.ws_long

        # set configs
        skip_connections = []
        skip_in = [3] # long attention window on layer 3
        skip_out = [6] # no attn op on layer 6
        x_backout = None
        backout_layer = 7

        # set lambdas (views into the flat self.scalars parameter; layout must
        # match the torch.cat order in __init__)
        resid_lambdas = self.scalars[: 1 * self.num_layers]
        x0_lambdas = self.x0_lambdas
        sa_lambdas = self.scalars[1 * self.num_layers: 3 * self.num_layers].view(-1, 2)
        bigram_lambdas = self.scalars[3 * self.num_layers: 4 * self.num_layers]
        smear_lambda = self.scalars[4 * self.num_layers]
        backout_lambda = self.scalars[4 * self.num_layers+1]
        skip_lambda = self.scalars[4 * self.num_layers+2]

        # set block masks and key shift
        bm_sizes = [ws_short, ws_short, ws_short, ws_long, ws_short, ws_short, None, ws_short, ws_short, ws_short, ws_long]
        assert len(bm_sizes) == self.num_layers
        key_offset = [b==ws_long for b in bm_sizes] # apply partial key offset to long windows

        # Embedding lookup - embed is synced from lm_head during tied phase by optimizer
        x = self.embed(input_seq)
        # Compute bigram hash on GPU (moved from CPU data loader)
        bigram_seq = self._compute_bigram_hash(input_seq, args.bigram_vocab_size - 1)
        x0_bigram = self.bigram_embed(bigram_seq)[None]

        # Value embeddings - always computed (not precomputed)
        ve = self.value_embeds.view(5, self.vocab_size, -1)[:, input_seq]
        # 01 ... 234 structure on token value embeddings by @photomz
        ve = [ve[0], ve[1]] + [None] * (self.num_layers - 5) + [ve[2], ve[3], ve[4]]
        assert len(ve) == self.num_layers

        # smear token embed forward 1 position @classiclarryd
        smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)]))
        x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]])
        x = x0 = norm(x[None])

        # unbind gate banks to avoid select_backwards kernel
        ag = [w.bfloat16() for w in self.attn_gate_bank.unbind(0)]
        veg = [w.bfloat16() for w in self.ve_gate_bank.unbind(0)]
        attn_gates = ag[:6] + [None] + ag[6:]
        ve_gates = [veg[0], veg[1]] + [None] * (self.num_layers - 5) + [veg[2], veg[3], veg[4]]
        assert len(attn_gates) == self.num_layers
        assert len(ve_gates) == self.num_layers

        # unbind weight banks to avoid select_backwards kernel
        attn_weights = self.attn_bank.unbind(0) # tuple of [4*dim, hdim] tensors
        mlp_fcs = self.mlp_bank[:, 0, :, :].unbind(0) # tuple of [mlp_hdim, dim] tensors
        mlp_projs = self.mlp_bank[:, 1, :, :].unbind(0) # tuple of [mlp_hdim, dim] tensors

        for i in range(self.num_layers):
            yarn = self.yarn_paired_head if i in self.paired_head_layers else self.yarn
            attn_args = AttnArgs(
                ve=ve[i],
                sa_lambdas=sa_lambdas[i],
                seqlens=seqlens,
                bm_size=bm_sizes[i],
                yarn=yarn,
                key_offset=key_offset[i],
                attn_gate_w=attn_gates[i],
                ve_gate_w=ve_gates[i]
            )
            # gated U-net style skip: layer 6 consumes the activation saved at layer 3
            if i in skip_out:
                skip_gate_out = torch.sigmoid(skip_lambda) * 2 * torch.sigmoid(self.skip_gate(x0[..., :self.skip_gate.weight.size(-1)]))
                x = x + skip_gate_out * skip_connections.pop()
            if i == 0:
                x = (resid_lambdas[0] + x0_lambdas[0]) * x + bigram_lambdas[0] * x0_bigram
            else:
                x = resid_lambdas[i] * x + x0_lambdas[i] * x0 + bigram_lambdas[i] * x0_bigram

            # Get weights for this layer from banks
            qkvo_w = attn_weights[self.layer_to_attn_idx[i]] if i in self.layer_to_attn_idx else None
            c_fc = mlp_fcs[self.layer_to_mlp_idx[i]] if i in self.layer_to_mlp_idx else None
            c_proj = mlp_projs[self.layer_to_mlp_idx[i]] if i in self.layer_to_mlp_idx else None

            x = self.blocks[i](x, attn_args, qkvo_w, c_fc, c_proj)
            if i in skip_in:
                skip_connections.append(x)
            if i == backout_layer:
                x_backout = x

        # back out contributions from first 7 layers that are only required for downstream context and not direct prediction
        x -= backout_lambda * x_backout
        x = norm(x)
        logits = self.lm_head(x)
        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15
        # @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1). @classiclarryd updated to 23*sigmoid((logits+5)/7.5)
        if self.training:
            losses = FusedSoftcappedCrossEntropy.apply(logits.view(-1, logits.size(-1)), target_seq, mtp_weights, 23.0, 5.0, 7.5)
            loss = losses.sum()
        else:
            logits = 23 * torch.sigmoid((logits + 5) / 7.5)
            logits_for_loss = logits.float()
            loss = F.cross_entropy(logits_for_loss.view(-1, logits_for_loss.size(-1)), target_seq, reduction="mean")
        return loss
# -----------------------------------------------------------------------------
# Distributed data loader

def _load_data_shard(file: Path):
    """Read one fineweb .bin shard into a pinned uint16 CPU tensor."""
    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
    assert header[1] == 1, "unsupported version"
    num_tokens = int(header[2]) # number of tokens (claimed)
    with file.open("rb", buffering=0) as f:
        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
        f.seek(256 * 4)
        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
    return tokens

BOS_ID = 50256

class Shard:
    def __init__(self, tokens: Tensor, world_size: int = 1):
        self.tokens = tokens
        self.size = tokens.numel()
        self.world_size = 
world_size
        self.i = 0

        # Partial index now, full index async
        self.bos_idx = (tokens[:6_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
        self._full_idx = None
        self._loader_thread = None
        self._ready = threading.Event()
        self._loader_thread = threading.Thread(target=self._scan)
        self._loader_thread.start()

    def _scan(self):
        # background thread: index every BOS position in the whole shard
        self._full_idx = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
        self._ready.set()

    def _maybe_switch(self):
        # Switch to full index as soon as async scan completes
        if self.bos_idx is not self._full_idx and self._ready.is_set():
            self._loader_thread.join()
            self.bos_idx = self._full_idx

    def next_batch(self, num_tokens_local: int, max_seq_len: int):
        """Greedily pack BOS-aligned documents until each rank has exactly
        num_tokens_local + 1 tokens; returns per-rank (starts, ends) lists.

        Raises StopIteration when the shard cannot supply another batch.
        """
        self._maybe_switch()
        n = len(self.bos_idx)
        starts = [[] for _ in range(self.world_size)]
        ends = [[] for _ in range(self.world_size)]

        idx = self.i
        for r in range(self.world_size):
            cur_len = 0
            while cur_len <= num_tokens_local:
                if idx >= n:
                    raise StopIteration(f"Insufficient BOS ahead; hit tail of shard.")
                cur = self.bos_idx[idx]
                starts[r].append(cur)
                # each doc is truncated at the next BOS, max_seq_len, or the
                # remaining budget for this rank (+1 for the target offset)
                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
                          cur + max_seq_len,
                          cur + num_tokens_local - cur_len + 1)
                ends[r].append(end)
                cur_len += end - cur
                idx += 1

            assert cur_len == num_tokens_local + 1
        self.i = idx
        return starts, ends

    @staticmethod
    def load_async(file: Path, world_size: int = 1):
        """Returns getter function for async shard loading"""
        result = {}
        ready = threading.Event()
        def load():
            tokens = _load_data_shard(file)
            result['shard'] = Shard(tokens, world_size)
            ready.set()
        thread = threading.Thread(target=load)
        thread.start()
        def get():
            ready.wait()
            thread.join()
            return result['shard']
        return get

def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
    # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len
    # Generator yields (inputs, targets, cum_lengths) GPU tensors per micro-batch;
    # callers may .send() new (num_tokens, max_seq_len, grad_accum_steps) to resize.
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size"
    num_tokens = num_tokens // grad_accum_steps

    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")

    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
    tokens = _load_data_shard(next(file_iter))
    if align_to_bos:
        shard = Shard(tokens, world_size)
        next_shard_getter = Shard.load_async(next(file_iter), world_size)
    else:
        pos = 0 # for unaligned case

    while True:
        num_tokens_local = num_tokens // world_size
        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400

        if align_to_bos:
            try:
                seq_starts, seq_ends = shard.next_batch(num_tokens_local, max_seq_len)
                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
            except StopIteration:
                # This shard is exhausted, load the next one in the next loop iteration.
+                shard = next_shard_getter()
                tokens = shard.tokens
                try:
                    next_shard_getter = Shard.load_async(next(file_iter), world_size)
                except StopIteration:
                    next_shard_getter = None # no more shards to preload
                continue

            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
            _inputs = buf[:-1]
            _targets = buf[1:]
            end_idxs[-1] -= 1 # last document was too long to account for _targets offset
            cum_lengths = (end_idxs - start_idxs).cumsum(0)

        else:
            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
                tokens, pos = _load_data_shard(next(file_iter)), 0

            pos_local = pos + rank * num_tokens_local
            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
            _inputs = buf[:-1].view(num_tokens_local, )
            _targets = buf[1:].view(num_tokens_local, )

            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
            pos += num_tokens


        # fixed-size cum-length tensor (padded with num_tokens_local) so shapes
        # stay static for torch.compile
        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
        _cum_lengths[0] = 0
        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths

        # Cast to int32 on CPU before transfer to avoid dtype conversion during .to()
        _inputs = _inputs.to(dtype=torch.int32)
        _targets = _targets.to(dtype=torch.int64)
        _cum_lengths = _cum_lengths.to(dtype=torch.int32)
        # Bigram hash computation moved to GPU in forward()

        new_params = yield (
            _inputs.to(device="cuda", non_blocking=True),
            _targets.to(device="cuda", non_blocking=True),
            _cum_lengths.to(device="cuda", non_blocking=True),
        )

        if new_params is not None:
            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "Num tokens must be divisible by world size"
            num_tokens = new_num_tokens // new_grad_accum_steps
            max_seq_len = new_max_seq_len

# -----------------------------------------------------------------------------
# Training Management

@dataclass
+class Hyperparameters:
    # data
    # NOTE: data_path has no annotation, so it is a plain class attribute
    # rather than a dataclass field (intentional: it is only used to build
    # the defaults below).
    data_path = os.environ.get("DATA_PATH", ".")
    train_files: str = os.path.join(data_path, "data/fineweb10B/fineweb_train_*.bin") # input .bin to train on
    val_files: str = os.path.join(data_path, "data/fineweb10B/fineweb_val_*.bin") # input .bin to eval validation loss on
    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
    # batch sizes
    train_max_seq_len: int = 128 * 16
    val_batch_size: int = 4 * 64 * 1024 * 8
    # schedule
    num_scheduled_iterations: int = 1515 # number of steps to complete lr and ws schedule
    num_extension_iterations: int = 40 # number of steps to continue training at final lr and ws
    # evaluation and logging
    run_id: str = f"{uuid.uuid4()}"
    val_loss_every: int = 250 # every how many steps to evaluate val loss? 0 for only at the end
    save_checkpoint: bool = False
    # bigram hash embedding
    bigram_vocab_size: int = 50304 * 5

args = Hyperparameters()

@dataclass
class TrainingStage:
    lr_mul: float
    batch_size: int
    window_sizes: tuple[int, int] # (short, long) in block units
    mtp_weights_start: list[float]
    mtp_weights_end: list[float]
    duration: float = None # fraction of scheduled iterations; None for the final extension stage

class TrainingSchedule:
    """
    Training schedule initialized via TRAINING_STAGES
    1. Multi Token Prediction schedule of [1, 0.5, 0.25->0] -> [1, 0.5->0] -> [1] @varunneal
    2. Sliding Attention window schedule of [1,3] -> [3,7] -> [5,11] -> [6,13]
    3. YaRN updates to RoPE on window changes
    4. Split embed and lm head at 2/3 of training
    5. Batch size schedule of 8 -> 16 -> 24
    6. Post training extension of long windows from 13 to 20
    """

    def __init__(self, stages: list[TrainingStage], scheduled_iterations: int, extension_iterations: int,
                 cooldown_frac: float = 0.5, split_embed_stage: int = 2, ws_post_yarn_ext: int = 20):
        self.stages = stages
        self.scheduled_iterations = scheduled_iterations
        self.cooldown_frac = cooldown_frac
        # increase final validation ws, used for YaRN extension and short window size @classiclarryd
        self.ws_post_yarn_ext = ws_post_yarn_ext

        self.total_steps = self.scheduled_iterations + extension_iterations

        # Build stage boundaries (last is extension stage)
        ends = [0] + [round(c * scheduled_iterations) for c in accumulate(s.duration for s in stages[:-1])] + [self.total_steps]
        assert self.scheduled_iterations == ends[-2]
        self.boundaries = list(pairwise(ends))

        # Split embed at specified stage (ensure odd step for Adam)
        self.split_step = self.boundaries[split_embed_stage][0] | 1

        # Precompute MTP weights for all steps (linear interpolation within each stage)
        self.mtp_weights = []
        for step in range(self.total_steps + 1):
            stage, t = self.lookup(step)
            w = [a + (b - a) * t for a, b in zip(stage.mtp_weights_start, stage.mtp_weights_end)]
            self.mtp_weights.append(torch.tensor(w, device=device))

    def lookup(self, step: int) -> tuple[TrainingStage, float]:
        # Returns stage and % of the way through that stage
        for i, (start, end) in enumerate(self.boundaries):
            if step < end:
                t = (step - start) / (end - start)
                return self.stages[i], t
        return self.stages[-1], 1.0

    def get_lr(self, step: int) -> float:
        # learning rate schedule: tied to batch size schedule, with cooldown at the end
        # (linear decay from cd_start down to 0.1x over the scheduled iterations;
        # extension steps stay at the final 0.1x multiplier)
        stage, _ = self.lookup(step)
        lr = stage.lr_mul
        cd_start = int(self.scheduled_iterations * (1 - self.cooldown_frac))
        if step >= cd_start:
            t = min(1.0, (step - cd_start) / (self.scheduled_iterations - cd_start))
            lr = lr * (1 - t) + 0.1 * t
        return lr

# window_sizes are in units of `block_size` tokens (defined in TrainingManager)
TRAINING_STAGES = [
    TrainingStage(duration=1/3, batch_size=8 * 2048 * 8, window_sizes=(1, 3), lr_mul=1.0,
                  mtp_weights_start=[1.0, 0.5, 0.25], mtp_weights_end=[1.0, 0.5, 0.0]),
    TrainingStage(duration=1/3, batch_size=16 * 2048 * 8, window_sizes=(3, 7), lr_mul=1.52, # (16/8)**0.6
                  mtp_weights_start=[1.0, 0.5], mtp_weights_end=[1.0, 0.0]),
    TrainingStage(duration=1/3, batch_size=24 * 2048 * 8, window_sizes=(5, 11), lr_mul=1.73, # (24/8)**0.5
                  mtp_weights_start=[1.0], mtp_weights_end=[1.0]),
    # extension stage
    TrainingStage(batch_size=24 * 2048 * 8, window_sizes=(6, 13), lr_mul=1.0, # lr_mul is not used
                  mtp_weights_start=[1.0], mtp_weights_end=[1.0]),
]

training_schedule = TrainingSchedule(TRAINING_STAGES, args.num_scheduled_iterations, args.num_extension_iterations, cooldown_frac=0.55)

def get_muon_momentum(step: int, muon_warmup_steps=300, muon_cooldown_steps=50, momentum_min=0.85, momentum_max=0.95):
    # warmup phase: linearly increase momentum from min to max
    # cooldown phase: linearly decrease momentum from max to min
    momentum_cd_start = training_schedule.total_steps - muon_cooldown_steps
    if step < muon_warmup_steps:
        frac = step / muon_warmup_steps
        momentum = momentum_min + frac * (momentum_max - momentum_min)
    elif step > momentum_cd_start:
        frac = (step - momentum_cd_start) / muon_cooldown_steps
        momentum = momentum_max - frac * (momentum_max - momentum_min)
    else:
        momentum = momentum_max
    return momentum

class TrainingManager():
    """
    Manages the NorMuonAndAdam for all parameters with explicit ordering.
    1. Scalars are given higher momentum terms to smooth learning @ChrisJMcCormick
    2. Adam optimizers are only stepped on odd steps @classiclarryd
    3. Explicit scatter_order and work_order for communication scheduling (no backward hooks)
    4. Muon has a linear momentum warmup and cooldown schedule
    5. Learning rates follow a linear decay schedule
    6. 
Embed is tied to lm_head until split step (2/3 of training), then untied @classiclarryd
    """
    def __init__(self, model):
        self.model = model
        self.block_size = 128  # window sizes throughout are in units of this many tokens

        # - Ordering dictates when to launch reduce/reduce_scatter operations
        # - "sharded" parameters use reduce_scatter/all_gather and "replicated" ones use all_reduce
        # - lr_mul and wd_mul are per-parameter learning rate and weight decay multipliers
        self.param_table = {
            "attn": {"optim": "normuon", "comms": "sharded", "adam_betas": None},
            "mlp": {"optim": "normuon", "comms": "sharded", "adam_betas": None},
            "scalars": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 5.0, "wd_mul": 0.0},
            "value_embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.75, 0.95], "lr_mul": 75., "wd_mul": 5.0},
            "bigram_embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.75, 0.95], "lr_mul": 75., "wd_mul": 5.0},
            "smear_gate": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 0.01, "wd_mul": 0.0},
            "skip_gate": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 0.05, "wd_mul": 0.0},
            "attn_gate_bank": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99]},
            "ve_gate_bank": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99]},
            "x0_lambdas": {"optim": "adam", "comms": "replicated", "adam_betas": [0.65, 0.95], "lr_mul": 5.0, "wd_mul": 0.0},
            "lm_head": {"optim": "adam", "comms": "sharded", "adam_betas": [0.5, 0.95], "wd_mul": 150.},
            "embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.5, 0.95], "wd_mul": 150.},
        }

        # - Process smaller/faster params first while large reduces complete
        # - lm_head must complete before embed sync (when tied)
        self.work_order = [
            "scalars", "smear_gate", "skip_gate", "attn_gate_bank", "ve_gate_bank", "x0_lambdas", # Small, fast
            "value_embed", "bigram_embed", # Medium
            "lm_head", "embed", # lm_head must complete before embed sync (when tied)
            "attn", "mlp", # Large, polar express - process last to maximize overlap
        ]

        adam_defaults = dict(
            lr=0.008,
            eps=1e-10,
            weight_decay=0.005,
        )

        normuon_defaults = dict(
            lr=0.023,
            momentum=0.95,
            beta2=0.95,
            weight_decay=1.2,
        )

        self.optimizer = NorMuonAndAdam(
            model.named_parameters(),
            param_table=self.param_table,
            scatter_order=list(self.param_table.keys()), # Dict order defines scatter priority
            work_order=self.work_order,
            adam_defaults=adam_defaults,
            normuon_defaults=normuon_defaults,
        )

        # Split embed from lm_head at 2/3 of training (on an odd step so Adam updates)
        self.split_step = training_schedule.split_step

        self.reset()

    def apply_final_ws_ext(self):
        # extend the long window for the final validation pass only
        self.ws_long = training_schedule.ws_post_yarn_ext

    def get_forward_args(self):
        # convert window sizes from block units to tokens for GPT.forward
        return ForwardScheduleConfig(
            mtp_weights = self.mtp_weights,
            ws_short = self.ws_short * self.block_size,
            ws_long = self.ws_long * self.block_size
        )

    def _is_adam_step(self, step: int):
        """Adam params are only updated on odd steps."""
        return step % 2 == 1

    def get_transition_steps(self):
        return [start for start, _ in training_schedule.boundaries[1:]]

    def advance_schedule(self, step: int):
        """Apply the stage schedule for this step: window sizes (with YaRN
        updates on change), batch size (signalled via train_loader_send_args),
        and the precomputed MTP weights."""
        stage, _ = training_schedule.lookup(step)
        self.ws_short, new_ws_long = stage.window_sizes
        if new_ws_long != self.ws_long:
            self.model.yarn.apply(self.ws_long * self.block_size, new_ws_long * self.block_size)
            self.model.yarn_paired_head.apply(self.ws_long * self.block_size, new_ws_long * self.block_size)

        new_batch_size = stage.batch_size
        if new_batch_size != self.batch_size:
            self.train_loader_send_args = (new_batch_size, args.train_max_seq_len, grad_accum_steps)
            self.batch_size = new_batch_size
        else:
            self.train_loader_send_args = None

        self.ws_long = new_ws_long
        self.mtp_weights = training_schedule.mtp_weights[step]

    def step_optimizers(self, step: int):
        step_lr = training_schedule.get_lr(step)
        muon_momentum = get_muon_momentum(step)
        do_adam = self._is_adam_step(step)

        # Update learning rates and momentum for all params
        # (scaling initial_lr, not lr, so multipliers don't compound across steps)
        for param, p_cfg in self.optimizer.param_cfgs.items():
            p_cfg.lr = p_cfg.initial_lr * step_lr
            if p_cfg.optim == "normuon":
                p_cfg.momentum = muon_momentum

        # Step optimizer with do_adam flag
        self.optimizer.step(do_adam=do_adam)

        # At split step: copy lm_head optimizer state to embed and mark as split
        if step == self.split_step:
            self.optimizer.copy_lm_state_to_embed()

    def reset(self, state=None):
        # restore optimizer state (if given) and re-derive stage-0 schedule values
        if state is not None:
            self.optimizer.load_state_dict(state)

        # Reset NorMuon momentum buffers and split_embed state
        self.optimizer.reset()

        stage, _ = training_schedule.lookup(0)
        self.ws_short, self.ws_long = stage.window_sizes
        self.batch_size = stage.batch_size
        self.model.yarn.reset()
        self.model.yarn_paired_head.reset()

    def get_state(self):
        return copy.deepcopy(self.optimizer.state_dict())

# -----------------------------------------------------------------------------
# int main

# begin logging
logfile = None
if master_process:
    run_id = args.run_id
    os.makedirs("logs", exist_ok=True)
    logfile = f"logs/{run_id}.txt"
    print(logfile)
def print0(s, console=False):
    # rank-0-only logging; logfile is None on other ranks but never touched there
    if master_process:
        with open(logfile, "a") as f:
            if console:
                print(s)
            print(s, file=f)

# begin by printing this file (the Python code)
print0(code)
print0("="*100)
# log information about the hardware/software environment this is running on
print0(f"Running Python {sys.version}")
print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
print0(f"Running Triton version {triton.__version__}")

def nvidia_smi():
    import subprocess # avoid top level import
    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
print0(nvidia_smi())
print0("="*100)

model: nn.Module = GPT(
    vocab_size=50257,
    num_layers=11,
    num_heads=6,
head_dim=128,
    model_dim=768,
    max_seq_len=args.val_batch_size // (grad_accum_steps * world_size)
).cuda()
# cast the large weight tensors to bf16 in place
for m in model.modules():
    if isinstance(m, (nn.Embedding, nn.Linear)):
        m.weight.data = m.weight.data.bfloat16()
model.attn_gate_bank.data = model.attn_gate_bank.data.bfloat16()
model.ve_gate_bank.data = model.ve_gate_bank.data.bfloat16()
model.attn_bank.data = model.attn_bank.data.bfloat16()
model.mlp_bank.data = model.mlp_bank.data.bfloat16()
# make all ranks start from rank 0's initialization
for param in model.parameters():
    dist.broadcast(param.detach(), 0)

model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
training_manager = TrainingManager(model)

########################################
#           Warmup kernels             #
########################################
print0("Compiling model and warming up kernels (~7 minutes on first execution)", console=True)
# Warmup the training kernels, then re-initialize the state so we aren't cheating
initial_state = dict(model=copy.deepcopy(model.state_dict()),
                     optimizer=training_manager.get_state()) # save the initial state
train_loader = distributed_data_generator(args.train_files, TRAINING_STAGES[0].batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)

transition_steps = training_manager.get_transition_steps()
# first few steps plus transitions (each transition step +/- 1, to hit every
# shape/config combination torch.compile will see during the real run)
warmup_steps = sorted({0, 1, 2} | set(s + offset for s in transition_steps for offset in [-1, 0, 1] if s + offset >= 0))
print0(f"Sampling steps {warmup_steps} for warmup", console=True)
for step in warmup_steps:
    training_manager.advance_schedule(step)
    model.eval()
    with torch.no_grad():
        inputs, targets, cum_seqlens = next(val_loader)
        model(inputs, targets, cum_seqlens, training_manager.get_forward_args())
    model.train()
    for idx in range(grad_accum_steps):
        send_args = training_manager.train_loader_send_args
        inputs, targets, cum_seqlens = train_loader.send(send_args)
        (model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) * grad_scale).backward()
    training_manager.step_optimizers(step)
print0("Resetting Model", console=True)
model.zero_grad(set_to_none=True)
model.load_state_dict(initial_state["model"])
training_manager.reset(initial_state["optimizer"])
del val_loader, train_loader, initial_state
model.train()

########################################
#      Training and validation         #
########################################
train_loader = distributed_data_generator(args.train_files, TRAINING_STAGES[0].batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)

gc.collect()

training_time_ms = 0
# start the clock
torch.cuda.synchronize()
t0 = time.perf_counter()
# begin training
train_steps = training_schedule.total_steps
for step in range(train_steps + 1):
    last_step = (step == train_steps)
    training_manager.advance_schedule(step)
    # --------------- VALIDATION SECTION -----------------
    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
        if last_step:
            training_manager.apply_final_ws_ext()
        # stop the clock (validation time is excluded from the reported training time)
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.perf_counter() - t0)
        model.eval()
        assert args.val_tokens % args.val_batch_size == 0
        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
        val_loss = 0
        with torch.no_grad():
            for _ in range(val_steps):
                inputs, targets, cum_seqlens = next(val_loader)
                val_loss += model(inputs, targets, cum_seqlens, training_manager.get_forward_args())
        val_loss /= val_steps
        del val_loader
        # val_loss is a GPU tensor; average it across ranks before logging
        dist.reduce(val_loss, 0, op=dist.ReduceOp.AVG)
        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
        model.train()
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.perf_counter()

    if last_step:
        if master_process and args.save_checkpoint:
            log = dict(step=step, code=code, model=model.state_dict(), optimizer=training_manager.get_state())
            os.makedirs(f"logs/{run_id}", exist_ok=True)
            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
        # the last step only has the validation loop, so break to avoid training
        break

    # --------------- TRAINING SECTION -----------------
    for idx in range(grad_accum_steps):
        inputs, targets, cum_seqlens = train_loader.send(training_manager.train_loader_send_args)
        (model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) * grad_scale).backward()
    training_manager.step_optimizers(step)

    # logging
    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)

print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
dist.destroy_process_group()


----------------------------------------
# triton_kernels.py
----------------------------------------

import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor

# -----------------------------------------------------------------------------
# Triton kernel for symmetric matrix multiplication by @byronxu99

@triton.jit
def _pid_to_block(
    pid,
    M,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_N)

    # Map PID to a single matrix in batch
    batch_idx = pid // (num_pid_m * 
num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.jit +def XXT_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = 
m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def XXT(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + # Hardcoded configs based on H100 autotuning + if K == 768: + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 64 + num_stages, num_warps = 4, 4 + else: + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 64, 128, 128 + num_stages, num_warps = 4, 4 + + grid = (batch_size * triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(M, BLOCK_SIZE_N),) + XXT_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + LOWER_UPPER=1, + num_stages=num_stages, + num_warps=num_warps, + ) + return out + +@triton.jit +def ba_plus_cAA_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: 
tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from XXT_kernel, but also loads and adds a block of A + # Performance is slightly slower than XXT_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = 
C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ba_plus_cAA(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + # Hardcoded config based on H100 autotuning (M=768) + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 64 + num_stages, num_warps = 4, 4 + + grid = (batch_size * triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(M, BLOCK_SIZE_N),) + ba_plus_cAA_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + LOWER_UPPER=1, + num_stages=num_stages, + num_warps=num_warps, + ) + return out + +# ----------------------------------------------------------------------------- +# Triton kernel for MLP: relu(x @ W1.T)^2, by @andrewbriand, @jrauvola + +@triton.jit +def linear_relu_square_kernel(a_desc, b_desc, c_desc, 
aux_desc, + M, N, K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, + ): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tile_id_c = start_pid - NUM_SMS + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + + tile_id_c += NUM_SMS + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am_c = pid_m * BLOCK_SIZE_M + offs_bn_c = pid_n * BLOCK_SIZE_N + + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + + c0 = acc0.to(dtype) + if not FORWARD: + c0_pre = aux_desc.load([offs_am_c, offs_bn_c]) + c0 = 2 * c0 * tl.where(c0_pre > 0, c0_pre, 0) + + c_desc.store([offs_am_c, offs_bn_c], c0) + + if FORWARD: + c0_post = tl.maximum(c0, 0) + c0_post = c0_post * c0_post + aux_desc.store([offs_am_c, offs_bn_c], c0_post) + + c1 = acc1.to(dtype) + if not FORWARD: + c1_pre = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + c1 = 2 * c1 * tl.where(c1_pre > 0, c1_pre, 0) + + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + + if FORWARD: + c1_post = tl.maximum(c1, 0) + c1_post = c1_post * c1_post + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1_post) + + +def linear_relu_square(a, b, aux=None): + M, K = a.shape + N, K = b.shape + dtype = 
a.dtype + + c = torch.empty((M, N), device=a.device, dtype=dtype) + + FORWARD = False + if aux is None: + FORWARD = True + aux = torch.empty((M, N), device=a.device, dtype=dtype) + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 256 + BLOCK_SIZE_K = 64 + num_stages = 4 if FORWARD else 3 + num_warps = 8 + + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + + def grid(META): + return (min( + NUM_SMS, + triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), + ), ) + + linear_relu_square_kernel[grid]( + a_desc, b_desc, c_desc, aux_desc, + M, N, K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=1, + NUM_SMS=NUM_SMS, + FORWARD=FORWARD, + num_stages=num_stages, + num_warps=num_warps + ) + + if FORWARD: + return c, aux + else: + return c + +class FusedLinearReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, W1, W2): + pre, post = linear_relu_square(x.view((-1, x.shape[-1])), W1) + x3 = post @ W2 + ctx.save_for_backward(x, W1, W2, pre, post) + return x3.view(x.shape) + + @staticmethod + def backward(ctx, grad_output): + x, W1, W2, pre, post = ctx.saved_tensors + dW2 = post.T @ grad_output + dpre = linear_relu_square(grad_output.view((-1, grad_output.shape[-1])), W2, aux=pre) + dW1 = dpre.T @ x + dx = dpre @ W1 + return dx.view(x.shape), dW1, dW2 + +# ----------------------------------------------------------------------------- +# Fused Softcapped Cross Entropy + + +@triton.jit +def fused_softcapped_entropy_fwd_kernel( + logits_ptr, losses_ptr, lse_ptr, targets_ptr, mtp_weights_ptr, + stride_logits_n, stride_logits_v, + n_rows, n_cols, n_predict, + A, B, C, + 
BLOCK_SIZE: tl.constexpr +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + + max_val = -float('inf') + sum_exp = 0.0 + + for off in range(0, n_cols, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + val = tl.load(logits_row_ptr + cols, mask=mask, other=-float('inf')).to(tl.float32) + z = A * tl.sigmoid((val + B) / C) + z = tl.where(mask, z, -float('inf')) + curr_max = tl.max(z, axis=0) + new_max = tl.maximum(max_val, curr_max) + sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(z - new_max), axis=0) + max_val = new_max + + lse = max_val + tl.log(sum_exp) + tl.store(lse_ptr + row_idx, lse) + + total_loss = 0.0 + for k in range(n_predict): + target_idx = row_idx + k + if target_idx < n_rows: + weight = tl.load(mtp_weights_ptr + k) + if weight > 0: + target = tl.load(targets_ptr + target_idx).to(tl.int32) + if target >= 0 and target < n_cols: + val_target = tl.load(logits_row_ptr + target).to(tl.float32) + z_target = A * tl.sigmoid((val_target + B) / C) + total_loss += weight * (lse - z_target) + + tl.store(losses_ptr + row_idx, total_loss) + +@triton.jit +def fused_softcapped_entropy_bwd_kernel( + grad_input_ptr, grad_output_ptr, lse_ptr, logits_ptr, targets_ptr, mtp_weights_ptr, + stride_logits_n, stride_logits_v, stride_grad_n, stride_grad_v, + n_rows, n_cols, n_predict, + A, B, C, + BLOCK_SIZE: tl.constexpr +): + row_idx = tl.program_id(0).to(tl.int64) + + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + grad_row_ptr = grad_input_ptr + row_idx * stride_grad_n + + lse = tl.load(lse_ptr + row_idx) + grad_loss = tl.load(grad_output_ptr + row_idx) + + S_w = 0.0 + for k in range(n_predict): + if row_idx + k < n_rows: + S_w += tl.load(mtp_weights_ptr + k) + + for off in range(0, n_cols, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + val = tl.load(logits_row_ptr + cols, mask=mask, other=0.0).to(tl.float32) + u = (val + B) / C + sigmoid_u 
= tl.sigmoid(u) + z = A * sigmoid_u + p = tl.exp(z - lse) + + term1 = S_w * p + term2 = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for k in range(n_predict): + if row_idx + k < n_rows: + target = tl.load(targets_ptr + row_idx + k).to(tl.int32) + weight = tl.load(mtp_weights_ptr + k) + term2 += tl.where(cols == target, weight, 0.0) + + grad_z = grad_loss * (term1 - term2) + dz_dx = (1.0 / C) * z * (1.0 - sigmoid_u) + grad_x = grad_z * dz_dx + tl.store(grad_row_ptr + cols, grad_x.to(tl.bfloat16), mask=mask) + +class FusedSoftcappedCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, logits, targets, mtp_weights, A=23.0, B=5.0, C=7.5): + n_rows, n_cols = logits.shape + if mtp_weights is None: + mtp_weights = torch.tensor([1.0], device=logits.device, dtype=torch.float32) + n_predict = mtp_weights.shape[0] + + losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + lse = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + + logits = logits.contiguous() + targets = targets.contiguous() + mtp_weights = mtp_weights.contiguous() + + grid = (n_rows,) + fused_softcapped_entropy_fwd_kernel[grid]( + logits, losses, lse, targets, mtp_weights, + logits.stride(0), logits.stride(1), + n_rows, n_cols, n_predict, + A, B, C, + BLOCK_SIZE=1024, + num_warps=8, + num_stages=4 + ) + + ctx.save_for_backward(logits, targets, mtp_weights, lse) + ctx.params = (A, B, C) + return losses + + @staticmethod + def backward(ctx, grad_output): + logits, targets, mtp_weights, lse = ctx.saved_tensors + A, B, C = ctx.params + n_rows, n_cols = logits.shape + n_predict = mtp_weights.shape[0] + + grad_input = torch.empty((n_rows, n_cols), dtype=torch.bfloat16, device=logits.device) + grad_output = grad_output.contiguous() + + grid = (n_rows,) + fused_softcapped_entropy_bwd_kernel[grid]( + grad_input, grad_output, lse, logits, targets, mtp_weights, + logits.stride(0), logits.stride(1), grad_input.stride(0), grad_input.stride(1), + n_rows, n_cols, 
n_predict, + A, B, C, + BLOCK_SIZE=1024, + num_warps=8, + num_stages=4 + ) + return grad_input, None, None, None, None, None + +==================================================================================================== +Running Python 3.12.7 (main, Jan 31 2026, 04:21:49) [GCC 13.2.0] +Running PyTorch 2.10.0.dev20251210+cu126 compiled for CUDA 12.6 +Running Triton version 3.6.0 +Sun Feb 1 06:13:14 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:71:00.0 Off | 0 | +| N/A 39C P0 125W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:79:00.0 Off | 0 | +| N/A 34C P0 126W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:7F:00.0 Off | 0 | +| N/A 32C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | +| N/A 39C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 37C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 20699 C /usr/local/bin/python 1510MiB | +| 1 N/A N/A 20700 C /usr/local/bin/python 1510MiB | +| 2 N/A N/A 20701 C /usr/local/bin/python 1510MiB | +| 3 N/A N/A 20702 C /usr/local/bin/python 1510MiB | +| 4 N/A N/A 20703 C /usr/local/bin/python 1510MiB | +| 5 N/A N/A 20704 C /usr/local/bin/python 1510MiB | +| 6 N/A N/A 20705 C /usr/local/bin/python 1510MiB | +| 7 N/A N/A 20706 C /usr/local/bin/python 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +Compiling model and warming up kernels (~7 minutes on first execution) +Sampling steps [0, 1, 2, 504, 505, 506, 1009, 1010, 1011, 1514, 1515, 1516] for warmup +Resetting Model +step:0/1555 val_loss:10.8327 train_time:0ms step_avg:0.03ms +step:1/1555 train_time:84ms step_avg:83.57ms +step:2/1555 train_time:108ms 
step_avg:54.04ms +step:3/1555 train_time:129ms step_avg:42.97ms +step:4/1555 train_time:150ms step_avg:37.48ms +step:5/1555 train_time:180ms step_avg:35.94ms +step:6/1555 train_time:217ms step_avg:36.15ms +step:7/1555 train_time:248ms step_avg:35.42ms +step:8/1555 train_time:286ms step_avg:35.76ms +step:9/1555 train_time:317ms step_avg:35.21ms +step:10/1555 train_time:355ms step_avg:35.45ms +step:11/1555 train_time:385ms step_avg:35.01ms +step:12/1555 train_time:423ms step_avg:35.23ms +step:13/1555 train_time:454ms step_avg:34.89ms +step:14/1555 train_time:491ms step_avg:35.10ms +step:15/1555 train_time:522ms step_avg:34.82ms +step:16/1555 train_time:560ms step_avg:35.02ms +step:17/1555 train_time:591ms step_avg:34.78ms +step:18/1555 train_time:629ms step_avg:34.92ms +step:19/1555 train_time:660ms step_avg:34.72ms +step:20/1555 train_time:697ms step_avg:34.86ms +step:21/1555 train_time:728ms step_avg:34.69ms +step:22/1555 train_time:766ms step_avg:34.81ms +step:23/1555 train_time:797ms step_avg:34.63ms +step:24/1555 train_time:834ms step_avg:34.75ms +step:25/1555 train_time:865ms step_avg:34.59ms +step:26/1555 train_time:903ms step_avg:34.71ms +step:27/1555 train_time:934ms step_avg:34.58ms +step:28/1555 train_time:971ms step_avg:34.69ms +step:29/1555 train_time:1002ms step_avg:34.57ms +step:30/1555 train_time:1041ms step_avg:34.69ms +step:31/1555 train_time:1072ms step_avg:34.57ms +step:32/1555 train_time:1110ms step_avg:34.68ms +step:33/1555 train_time:1141ms step_avg:34.57ms +step:34/1555 train_time:1179ms step_avg:34.67ms +step:35/1555 train_time:1210ms step_avg:34.57ms +step:36/1555 train_time:1248ms step_avg:34.65ms +step:37/1555 train_time:1279ms step_avg:34.57ms +step:38/1555 train_time:1317ms step_avg:34.67ms +step:39/1555 train_time:1348ms step_avg:34.56ms +step:40/1555 train_time:1386ms step_avg:34.64ms +step:41/1555 train_time:1417ms step_avg:34.55ms +step:42/1555 train_time:1454ms step_avg:34.62ms +step:43/1555 train_time:1485ms step_avg:34.54ms 
+step:44/1555 train_time:1523ms step_avg:34.62ms +step:45/1555 train_time:1554ms step_avg:34.54ms +step:46/1555 train_time:1592ms step_avg:34.61ms +step:47/1555 train_time:1623ms step_avg:34.53ms +step:48/1555 train_time:1660ms step_avg:34.59ms +step:49/1555 train_time:1692ms step_avg:34.52ms +step:50/1555 train_time:1730ms step_avg:34.61ms +step:51/1555 train_time:1763ms step_avg:34.56ms +step:52/1555 train_time:1800ms step_avg:34.61ms +step:53/1555 train_time:1831ms step_avg:34.54ms +step:54/1555 train_time:1868ms step_avg:34.60ms +step:55/1555 train_time:1899ms step_avg:34.53ms +step:56/1555 train_time:1937ms step_avg:34.59ms +step:57/1555 train_time:1968ms step_avg:34.53ms +step:58/1555 train_time:2006ms step_avg:34.58ms +step:59/1555 train_time:2036ms step_avg:34.51ms +step:60/1555 train_time:2074ms step_avg:34.57ms +step:61/1555 train_time:2105ms step_avg:34.51ms +step:62/1555 train_time:2143ms step_avg:34.56ms +step:63/1555 train_time:2173ms step_avg:34.50ms +step:64/1555 train_time:2211ms step_avg:34.55ms +step:65/1555 train_time:2242ms step_avg:34.49ms +step:66/1555 train_time:2279ms step_avg:34.54ms +step:67/1555 train_time:2310ms step_avg:34.48ms +step:68/1555 train_time:2348ms step_avg:34.53ms +step:69/1555 train_time:2379ms step_avg:34.48ms +step:70/1555 train_time:2417ms step_avg:34.53ms +step:71/1555 train_time:2448ms step_avg:34.48ms +step:72/1555 train_time:2485ms step_avg:34.52ms +step:73/1555 train_time:2516ms step_avg:34.47ms +step:74/1555 train_time:2554ms step_avg:34.52ms +step:75/1555 train_time:2586ms step_avg:34.48ms +step:76/1555 train_time:2623ms step_avg:34.52ms +step:77/1555 train_time:2655ms step_avg:34.47ms +step:78/1555 train_time:2692ms step_avg:34.51ms +step:79/1555 train_time:2723ms step_avg:34.47ms +step:80/1555 train_time:2761ms step_avg:34.51ms +step:81/1555 train_time:2792ms step_avg:34.47ms +step:82/1555 train_time:2829ms step_avg:34.51ms +step:83/1555 train_time:2861ms step_avg:34.47ms +step:84/1555 train_time:2899ms 
step_avg:34.52ms +step:85/1555 train_time:2930ms step_avg:34.47ms +step:86/1555 train_time:2967ms step_avg:34.50ms +step:87/1555 train_time:2999ms step_avg:34.47ms +step:88/1555 train_time:3037ms step_avg:34.51ms +step:89/1555 train_time:3068ms step_avg:34.47ms +step:90/1555 train_time:3106ms step_avg:34.51ms +step:91/1555 train_time:3138ms step_avg:34.48ms +step:92/1555 train_time:3176ms step_avg:34.52ms +step:93/1555 train_time:3207ms step_avg:34.49ms +step:94/1555 train_time:3245ms step_avg:34.52ms +step:95/1555 train_time:3276ms step_avg:34.48ms +step:96/1555 train_time:3313ms step_avg:34.51ms +step:97/1555 train_time:3344ms step_avg:34.48ms +step:98/1555 train_time:3382ms step_avg:34.51ms +step:99/1555 train_time:3413ms step_avg:34.47ms +step:100/1555 train_time:3450ms step_avg:34.50ms +step:101/1555 train_time:3481ms step_avg:34.46ms +step:102/1555 train_time:3518ms step_avg:34.49ms +step:103/1555 train_time:3550ms step_avg:34.46ms +step:104/1555 train_time:3587ms step_avg:34.49ms +step:105/1555 train_time:3618ms step_avg:34.46ms +step:106/1555 train_time:3656ms step_avg:34.49ms +step:107/1555 train_time:3687ms step_avg:34.46ms +step:108/1555 train_time:3724ms step_avg:34.48ms +step:109/1555 train_time:3755ms step_avg:34.45ms +step:110/1555 train_time:3792ms step_avg:34.47ms +step:111/1555 train_time:3823ms step_avg:34.44ms +step:112/1555 train_time:3861ms step_avg:34.47ms +step:113/1555 train_time:3892ms step_avg:34.44ms +step:114/1555 train_time:3929ms step_avg:34.47ms +step:115/1555 train_time:3960ms step_avg:34.44ms +step:116/1555 train_time:3998ms step_avg:34.47ms +step:117/1555 train_time:4030ms step_avg:34.44ms +step:118/1555 train_time:4067ms step_avg:34.47ms +step:119/1555 train_time:4099ms step_avg:34.44ms +step:120/1555 train_time:4137ms step_avg:34.47ms +step:121/1555 train_time:4168ms step_avg:34.44ms +step:122/1555 train_time:4206ms step_avg:34.47ms +step:123/1555 train_time:4237ms step_avg:34.44ms +step:124/1555 train_time:4274ms 
step_avg:34.47ms +step:125/1555 train_time:4305ms step_avg:34.44ms +step:126/1555 train_time:4343ms step_avg:34.47ms +step:127/1555 train_time:4374ms step_avg:34.44ms +step:128/1555 train_time:4411ms step_avg:34.46ms +step:129/1555 train_time:4443ms step_avg:34.44ms +step:130/1555 train_time:4481ms step_avg:34.47ms +step:131/1555 train_time:4512ms step_avg:34.44ms +step:132/1555 train_time:4549ms step_avg:34.46ms +step:133/1555 train_time:4580ms step_avg:34.44ms +step:134/1555 train_time:4618ms step_avg:34.46ms +step:135/1555 train_time:4648ms step_avg:34.43ms +step:136/1555 train_time:4686ms step_avg:34.45ms +step:137/1555 train_time:4717ms step_avg:34.43ms +step:138/1555 train_time:4754ms step_avg:34.45ms +step:139/1555 train_time:4785ms step_avg:34.43ms +step:140/1555 train_time:4823ms step_avg:34.45ms +step:141/1555 train_time:4854ms step_avg:34.43ms +step:142/1555 train_time:4892ms step_avg:34.45ms +step:143/1555 train_time:4922ms step_avg:34.42ms +step:144/1555 train_time:4961ms step_avg:34.45ms +step:145/1555 train_time:4992ms step_avg:34.43ms +step:146/1555 train_time:5029ms step_avg:34.45ms +step:147/1555 train_time:5061ms step_avg:34.43ms +step:148/1555 train_time:5098ms step_avg:34.45ms +step:149/1555 train_time:5129ms step_avg:34.42ms +step:150/1555 train_time:5166ms step_avg:34.44ms +step:151/1555 train_time:5198ms step_avg:34.42ms +step:152/1555 train_time:5235ms step_avg:34.44ms +step:153/1555 train_time:5266ms step_avg:34.42ms +step:154/1555 train_time:5303ms step_avg:34.44ms +step:155/1555 train_time:5334ms step_avg:34.42ms +step:156/1555 train_time:5372ms step_avg:34.43ms +step:157/1555 train_time:5402ms step_avg:34.41ms +step:158/1555 train_time:5440ms step_avg:34.43ms +step:159/1555 train_time:5471ms step_avg:34.41ms +step:160/1555 train_time:5508ms step_avg:34.42ms +step:161/1555 train_time:5539ms step_avg:34.40ms +step:162/1555 train_time:5577ms step_avg:34.42ms +step:163/1555 train_time:5607ms step_avg:34.40ms +step:164/1555 train_time:5645ms 
step_avg:34.42ms +step:165/1555 train_time:5676ms step_avg:34.40ms +step:166/1555 train_time:5713ms step_avg:34.41ms +step:167/1555 train_time:5744ms step_avg:34.40ms +step:168/1555 train_time:5782ms step_avg:34.42ms +step:169/1555 train_time:5814ms step_avg:34.40ms +step:170/1555 train_time:5851ms step_avg:34.42ms +step:171/1555 train_time:5882ms step_avg:34.40ms +step:172/1555 train_time:5919ms step_avg:34.41ms +step:173/1555 train_time:5950ms step_avg:34.39ms +step:174/1555 train_time:5987ms step_avg:34.41ms +step:175/1555 train_time:6018ms step_avg:34.39ms +step:176/1555 train_time:6056ms step_avg:34.41ms +step:177/1555 train_time:6087ms step_avg:34.39ms +step:178/1555 train_time:6124ms step_avg:34.40ms +step:179/1555 train_time:6155ms step_avg:34.39ms +step:180/1555 train_time:6192ms step_avg:34.40ms +step:181/1555 train_time:6224ms step_avg:34.39ms +step:182/1555 train_time:6262ms step_avg:34.40ms +step:183/1555 train_time:6293ms step_avg:34.39ms +step:184/1555 train_time:6330ms step_avg:34.40ms +step:185/1555 train_time:6361ms step_avg:34.38ms +step:186/1555 train_time:6399ms step_avg:34.40ms +step:187/1555 train_time:6429ms step_avg:34.38ms +step:188/1555 train_time:6466ms step_avg:34.40ms +step:189/1555 train_time:6498ms step_avg:34.38ms +step:190/1555 train_time:6535ms step_avg:34.40ms +step:191/1555 train_time:6566ms step_avg:34.38ms +step:192/1555 train_time:6603ms step_avg:34.39ms +step:193/1555 train_time:6634ms step_avg:34.38ms +step:194/1555 train_time:6672ms step_avg:34.39ms +step:195/1555 train_time:6703ms step_avg:34.37ms +step:196/1555 train_time:6741ms step_avg:34.39ms +step:197/1555 train_time:6772ms step_avg:34.37ms +step:198/1555 train_time:6809ms step_avg:34.39ms +step:199/1555 train_time:6839ms step_avg:34.37ms +step:200/1555 train_time:6877ms step_avg:34.38ms +step:201/1555 train_time:6908ms step_avg:34.37ms +step:202/1555 train_time:6945ms step_avg:34.38ms +step:203/1555 train_time:6976ms step_avg:34.37ms +step:204/1555 train_time:7014ms 
step_avg:34.38ms +step:205/1555 train_time:7045ms step_avg:34.36ms +step:206/1555 train_time:7083ms step_avg:34.38ms +step:207/1555 train_time:7114ms step_avg:34.37ms +step:208/1555 train_time:7151ms step_avg:34.38ms +step:209/1555 train_time:7182ms step_avg:34.36ms +step:210/1555 train_time:7219ms step_avg:34.38ms +step:211/1555 train_time:7251ms step_avg:34.36ms +step:212/1555 train_time:7288ms step_avg:34.38ms +step:213/1555 train_time:7319ms step_avg:34.36ms +step:214/1555 train_time:7357ms step_avg:34.38ms +step:215/1555 train_time:7388ms step_avg:34.36ms +step:216/1555 train_time:7426ms step_avg:34.38ms +step:217/1555 train_time:7457ms step_avg:34.36ms +step:218/1555 train_time:7495ms step_avg:34.38ms +step:219/1555 train_time:7525ms step_avg:34.36ms +step:220/1555 train_time:7563ms step_avg:34.38ms +step:221/1555 train_time:7594ms step_avg:34.36ms +step:222/1555 train_time:7631ms step_avg:34.37ms +step:223/1555 train_time:7663ms step_avg:34.36ms +step:224/1555 train_time:7700ms step_avg:34.38ms +step:225/1555 train_time:7731ms step_avg:34.36ms +step:226/1555 train_time:7768ms step_avg:34.37ms +step:227/1555 train_time:7799ms step_avg:34.36ms +step:228/1555 train_time:7837ms step_avg:34.37ms +step:229/1555 train_time:7868ms step_avg:34.36ms +step:230/1555 train_time:7905ms step_avg:34.37ms +step:231/1555 train_time:7936ms step_avg:34.36ms +step:232/1555 train_time:7973ms step_avg:34.37ms +step:233/1555 train_time:8005ms step_avg:34.35ms +step:234/1555 train_time:8042ms step_avg:34.37ms +step:235/1555 train_time:8073ms step_avg:34.35ms +step:236/1555 train_time:8110ms step_avg:34.37ms +step:237/1555 train_time:8141ms step_avg:34.35ms +step:238/1555 train_time:8179ms step_avg:34.36ms +step:239/1555 train_time:8210ms step_avg:34.35ms +step:240/1555 train_time:8247ms step_avg:34.36ms +step:241/1555 train_time:8278ms step_avg:34.35ms +step:242/1555 train_time:8316ms step_avg:34.36ms +step:243/1555 train_time:8347ms step_avg:34.35ms +step:244/1555 train_time:8384ms 
step_avg:34.36ms +step:245/1555 train_time:8416ms step_avg:34.35ms +step:246/1555 train_time:8453ms step_avg:34.36ms +step:247/1555 train_time:8484ms step_avg:34.35ms +step:248/1555 train_time:8521ms step_avg:34.36ms +step:249/1555 train_time:8553ms step_avg:34.35ms +step:250/1555 train_time:8590ms step_avg:34.36ms +step:250/1555 val_loss:4.5471 train_time:8640ms step_avg:34.56ms +step:251/1555 train_time:8658ms step_avg:34.50ms +step:252/1555 train_time:8678ms step_avg:34.44ms +step:253/1555 train_time:8695ms step_avg:34.37ms +step:254/1555 train_time:8732ms step_avg:34.38ms +step:255/1555 train_time:8765ms step_avg:34.37ms +step:256/1555 train_time:8804ms step_avg:34.39ms +step:257/1555 train_time:8835ms step_avg:34.38ms +step:258/1555 train_time:8873ms step_avg:34.39ms +step:259/1555 train_time:8904ms step_avg:34.38ms +step:260/1555 train_time:8941ms step_avg:34.39ms +step:261/1555 train_time:8972ms step_avg:34.38ms +step:262/1555 train_time:9010ms step_avg:34.39ms +step:263/1555 train_time:9041ms step_avg:34.38ms +step:264/1555 train_time:9079ms step_avg:34.39ms +step:265/1555 train_time:9109ms step_avg:34.37ms +step:266/1555 train_time:9146ms step_avg:34.38ms +step:267/1555 train_time:9177ms step_avg:34.37ms +step:268/1555 train_time:9215ms step_avg:34.38ms +step:269/1555 train_time:9246ms step_avg:34.37ms +step:270/1555 train_time:9283ms step_avg:34.38ms +step:271/1555 train_time:9314ms step_avg:34.37ms +step:272/1555 train_time:9351ms step_avg:34.38ms +step:273/1555 train_time:9382ms step_avg:34.37ms +step:274/1555 train_time:9419ms step_avg:34.38ms +step:275/1555 train_time:9450ms step_avg:34.36ms +step:276/1555 train_time:9487ms step_avg:34.37ms +step:277/1555 train_time:9518ms step_avg:34.36ms +step:278/1555 train_time:9555ms step_avg:34.37ms +step:279/1555 train_time:9586ms step_avg:34.36ms +step:280/1555 train_time:9623ms step_avg:34.37ms +step:281/1555 train_time:9654ms step_avg:34.36ms +step:282/1555 train_time:9693ms step_avg:34.37ms +step:283/1555 
train_time:9724ms step_avg:34.36ms +step:284/1555 train_time:9762ms step_avg:34.37ms +step:285/1555 train_time:9793ms step_avg:34.36ms +step:286/1555 train_time:9831ms step_avg:34.37ms +step:287/1555 train_time:9862ms step_avg:34.36ms +step:288/1555 train_time:9899ms step_avg:34.37ms +step:289/1555 train_time:9931ms step_avg:34.36ms +step:290/1555 train_time:9968ms step_avg:34.37ms +step:291/1555 train_time:9999ms step_avg:34.36ms +step:292/1555 train_time:10037ms step_avg:34.37ms +step:293/1555 train_time:10068ms step_avg:34.36ms +step:294/1555 train_time:10106ms step_avg:34.37ms +step:295/1555 train_time:10137ms step_avg:34.36ms +step:296/1555 train_time:10175ms step_avg:34.37ms +step:297/1555 train_time:10206ms step_avg:34.36ms +step:298/1555 train_time:10242ms step_avg:34.37ms +step:299/1555 train_time:10274ms step_avg:34.36ms +step:300/1555 train_time:10311ms step_avg:34.37ms +step:301/1555 train_time:10342ms step_avg:34.36ms +step:302/1555 train_time:10380ms step_avg:34.37ms +step:303/1555 train_time:10410ms step_avg:34.36ms +step:304/1555 train_time:10448ms step_avg:34.37ms +step:305/1555 train_time:10479ms step_avg:34.36ms +step:306/1555 train_time:10516ms step_avg:34.37ms +step:307/1555 train_time:10547ms step_avg:34.36ms +step:308/1555 train_time:10584ms step_avg:34.36ms +step:309/1555 train_time:10615ms step_avg:34.35ms +step:310/1555 train_time:10653ms step_avg:34.36ms +step:311/1555 train_time:10683ms step_avg:34.35ms +step:312/1555 train_time:10721ms step_avg:34.36ms +step:313/1555 train_time:10752ms step_avg:34.35ms +step:314/1555 train_time:10789ms step_avg:34.36ms +step:315/1555 train_time:10820ms step_avg:34.35ms +step:316/1555 train_time:10858ms step_avg:34.36ms +step:317/1555 train_time:10889ms step_avg:34.35ms +step:318/1555 train_time:10926ms step_avg:34.36ms +step:319/1555 train_time:10958ms step_avg:34.35ms +step:320/1555 train_time:10995ms step_avg:34.36ms +step:321/1555 train_time:11026ms step_avg:34.35ms +step:322/1555 train_time:11063ms 
step_avg:34.36ms +step:323/1555 train_time:11094ms step_avg:34.35ms +step:324/1555 train_time:11132ms step_avg:34.36ms +step:325/1555 train_time:11163ms step_avg:34.35ms +step:326/1555 train_time:11200ms step_avg:34.36ms +step:327/1555 train_time:11231ms step_avg:34.35ms +step:328/1555 train_time:11269ms step_avg:34.36ms +step:329/1555 train_time:11300ms step_avg:34.35ms +step:330/1555 train_time:11337ms step_avg:34.35ms +step:331/1555 train_time:11367ms step_avg:34.34ms +step:332/1555 train_time:11405ms step_avg:34.35ms +step:333/1555 train_time:11435ms step_avg:34.34ms +step:334/1555 train_time:11473ms step_avg:34.35ms +step:335/1555 train_time:11504ms step_avg:34.34ms +step:336/1555 train_time:11541ms step_avg:34.35ms +step:337/1555 train_time:11572ms step_avg:34.34ms +step:338/1555 train_time:11610ms step_avg:34.35ms +step:339/1555 train_time:11641ms step_avg:34.34ms +step:340/1555 train_time:11679ms step_avg:34.35ms +step:341/1555 train_time:11710ms step_avg:34.34ms +step:342/1555 train_time:11748ms step_avg:34.35ms +step:343/1555 train_time:11779ms step_avg:34.34ms +step:344/1555 train_time:11817ms step_avg:34.35ms +step:345/1555 train_time:11848ms step_avg:34.34ms +step:346/1555 train_time:11885ms step_avg:34.35ms +step:347/1555 train_time:11916ms step_avg:34.34ms +step:348/1555 train_time:11954ms step_avg:34.35ms +step:349/1555 train_time:11985ms step_avg:34.34ms +step:350/1555 train_time:12022ms step_avg:34.35ms +step:351/1555 train_time:12053ms step_avg:34.34ms +step:352/1555 train_time:12090ms step_avg:34.35ms +step:353/1555 train_time:12122ms step_avg:34.34ms +step:354/1555 train_time:12159ms step_avg:34.35ms +step:355/1555 train_time:12189ms step_avg:34.34ms +step:356/1555 train_time:12227ms step_avg:34.34ms +step:357/1555 train_time:12258ms step_avg:34.34ms +step:358/1555 train_time:12296ms step_avg:34.35ms +step:359/1555 train_time:12326ms step_avg:34.34ms +step:360/1555 train_time:12364ms step_avg:34.34ms +step:361/1555 train_time:12394ms 
step_avg:34.33ms +step:362/1555 train_time:12432ms step_avg:34.34ms +step:363/1555 train_time:12463ms step_avg:34.33ms +step:364/1555 train_time:12500ms step_avg:34.34ms +step:365/1555 train_time:12532ms step_avg:34.33ms +step:366/1555 train_time:12569ms step_avg:34.34ms +step:367/1555 train_time:12600ms step_avg:34.33ms +step:368/1555 train_time:12637ms step_avg:34.34ms +step:369/1555 train_time:12668ms step_avg:34.33ms +step:370/1555 train_time:12705ms step_avg:34.34ms +step:371/1555 train_time:12737ms step_avg:34.33ms +step:372/1555 train_time:12774ms step_avg:34.34ms +step:373/1555 train_time:12805ms step_avg:34.33ms +step:374/1555 train_time:12842ms step_avg:34.34ms +step:375/1555 train_time:12873ms step_avg:34.33ms +step:376/1555 train_time:12910ms step_avg:34.34ms +step:377/1555 train_time:12942ms step_avg:34.33ms +step:378/1555 train_time:12979ms step_avg:34.34ms +step:379/1555 train_time:13010ms step_avg:34.33ms +step:380/1555 train_time:13048ms step_avg:34.34ms +step:381/1555 train_time:13079ms step_avg:34.33ms +step:382/1555 train_time:13116ms step_avg:34.34ms +step:383/1555 train_time:13147ms step_avg:34.33ms +step:384/1555 train_time:13184ms step_avg:34.33ms +step:385/1555 train_time:13216ms step_avg:34.33ms +step:386/1555 train_time:13253ms step_avg:34.33ms +step:387/1555 train_time:13284ms step_avg:34.33ms +step:388/1555 train_time:13322ms step_avg:34.33ms +step:389/1555 train_time:13353ms step_avg:34.33ms +step:390/1555 train_time:13391ms step_avg:34.34ms +step:391/1555 train_time:13422ms step_avg:34.33ms +step:392/1555 train_time:13459ms step_avg:34.33ms +step:393/1555 train_time:13490ms step_avg:34.33ms +step:394/1555 train_time:13527ms step_avg:34.33ms +step:395/1555 train_time:13558ms step_avg:34.32ms +step:396/1555 train_time:13596ms step_avg:34.33ms +step:397/1555 train_time:13626ms step_avg:34.32ms +step:398/1555 train_time:13664ms step_avg:34.33ms +step:399/1555 train_time:13695ms step_avg:34.32ms +step:400/1555 train_time:13732ms 
step_avg:34.33ms +step:401/1555 train_time:13763ms step_avg:34.32ms +step:402/1555 train_time:13800ms step_avg:34.33ms +step:403/1555 train_time:13831ms step_avg:34.32ms +step:404/1555 train_time:13868ms step_avg:34.33ms +step:405/1555 train_time:13899ms step_avg:34.32ms +step:406/1555 train_time:13936ms step_avg:34.33ms +step:407/1555 train_time:13967ms step_avg:34.32ms +step:408/1555 train_time:14004ms step_avg:34.32ms +step:409/1555 train_time:14035ms step_avg:34.32ms +step:410/1555 train_time:14073ms step_avg:34.32ms +step:411/1555 train_time:14104ms step_avg:34.32ms +step:412/1555 train_time:14141ms step_avg:34.32ms +step:413/1555 train_time:14172ms step_avg:34.32ms +step:414/1555 train_time:14209ms step_avg:34.32ms +step:415/1555 train_time:14241ms step_avg:34.31ms +step:416/1555 train_time:14278ms step_avg:34.32ms +step:417/1555 train_time:14309ms step_avg:34.31ms +step:418/1555 train_time:14346ms step_avg:34.32ms +step:419/1555 train_time:14378ms step_avg:34.31ms +step:420/1555 train_time:14415ms step_avg:34.32ms +step:421/1555 train_time:14446ms step_avg:34.31ms +step:422/1555 train_time:14483ms step_avg:34.32ms +step:423/1555 train_time:14514ms step_avg:34.31ms +step:424/1555 train_time:14552ms step_avg:34.32ms +step:425/1555 train_time:14583ms step_avg:34.31ms +step:426/1555 train_time:14620ms step_avg:34.32ms +step:427/1555 train_time:14651ms step_avg:34.31ms +step:428/1555 train_time:14689ms step_avg:34.32ms +step:429/1555 train_time:14719ms step_avg:34.31ms +step:430/1555 train_time:14756ms step_avg:34.32ms +step:431/1555 train_time:14787ms step_avg:34.31ms +step:432/1555 train_time:14824ms step_avg:34.32ms +step:433/1555 train_time:14855ms step_avg:34.31ms +step:434/1555 train_time:14893ms step_avg:34.31ms +step:435/1555 train_time:14924ms step_avg:34.31ms +step:436/1555 train_time:14961ms step_avg:34.31ms +step:437/1555 train_time:14993ms step_avg:34.31ms +step:438/1555 train_time:15030ms step_avg:34.32ms +step:439/1555 train_time:15061ms 
step_avg:34.31ms +step:440/1555 train_time:15098ms step_avg:34.31ms +step:441/1555 train_time:15129ms step_avg:34.31ms +step:442/1555 train_time:15167ms step_avg:34.32ms +step:443/1555 train_time:15199ms step_avg:34.31ms +step:444/1555 train_time:15236ms step_avg:34.32ms +step:445/1555 train_time:15267ms step_avg:34.31ms +step:446/1555 train_time:15304ms step_avg:34.31ms +step:447/1555 train_time:15335ms step_avg:34.31ms +step:448/1555 train_time:15372ms step_avg:34.31ms +step:449/1555 train_time:15403ms step_avg:34.31ms +step:450/1555 train_time:15441ms step_avg:34.31ms +step:451/1555 train_time:15472ms step_avg:34.31ms +step:452/1555 train_time:15509ms step_avg:34.31ms +step:453/1555 train_time:15540ms step_avg:34.30ms +step:454/1555 train_time:15577ms step_avg:34.31ms +step:455/1555 train_time:15609ms step_avg:34.30ms +step:456/1555 train_time:15646ms step_avg:34.31ms +step:457/1555 train_time:15677ms step_avg:34.30ms +step:458/1555 train_time:15714ms step_avg:34.31ms +step:459/1555 train_time:15745ms step_avg:34.30ms +step:460/1555 train_time:15782ms step_avg:34.31ms +step:461/1555 train_time:15814ms step_avg:34.30ms +step:462/1555 train_time:15852ms step_avg:34.31ms +step:463/1555 train_time:15882ms step_avg:34.30ms +step:464/1555 train_time:15920ms step_avg:34.31ms +step:465/1555 train_time:15951ms step_avg:34.30ms +step:466/1555 train_time:15989ms step_avg:34.31ms +step:467/1555 train_time:16021ms step_avg:34.31ms +step:468/1555 train_time:16058ms step_avg:34.31ms +step:469/1555 train_time:16089ms step_avg:34.31ms +step:470/1555 train_time:16127ms step_avg:34.31ms +step:471/1555 train_time:16158ms step_avg:34.31ms +step:472/1555 train_time:16195ms step_avg:34.31ms +step:473/1555 train_time:16226ms step_avg:34.30ms +step:474/1555 train_time:16263ms step_avg:34.31ms +step:475/1555 train_time:16294ms step_avg:34.30ms +step:476/1555 train_time:16332ms step_avg:34.31ms +step:477/1555 train_time:16363ms step_avg:34.30ms +step:478/1555 train_time:16400ms 
step_avg:34.31ms +step:479/1555 train_time:16431ms step_avg:34.30ms +step:480/1555 train_time:16469ms step_avg:34.31ms +step:481/1555 train_time:16500ms step_avg:34.30ms +step:482/1555 train_time:16537ms step_avg:34.31ms +step:483/1555 train_time:16568ms step_avg:34.30ms +step:484/1555 train_time:16605ms step_avg:34.31ms +step:485/1555 train_time:16636ms step_avg:34.30ms +step:486/1555 train_time:16673ms step_avg:34.31ms +step:487/1555 train_time:16704ms step_avg:34.30ms +step:488/1555 train_time:16741ms step_avg:34.31ms +step:489/1555 train_time:16772ms step_avg:34.30ms +step:490/1555 train_time:16810ms step_avg:34.31ms +step:491/1555 train_time:16841ms step_avg:34.30ms +step:492/1555 train_time:16878ms step_avg:34.31ms +step:493/1555 train_time:16909ms step_avg:34.30ms +step:494/1555 train_time:16947ms step_avg:34.31ms +step:495/1555 train_time:16978ms step_avg:34.30ms +step:496/1555 train_time:17016ms step_avg:34.31ms +step:497/1555 train_time:17047ms step_avg:34.30ms +step:498/1555 train_time:17084ms step_avg:34.30ms +step:499/1555 train_time:17115ms step_avg:34.30ms +step:500/1555 train_time:17153ms step_avg:34.31ms +step:500/1555 val_loss:4.2237 train_time:17203ms step_avg:34.41ms +step:501/1555 train_time:17221ms step_avg:34.37ms +step:502/1555 train_time:17240ms step_avg:34.34ms +step:503/1555 train_time:17256ms step_avg:34.31ms +step:504/1555 train_time:17292ms step_avg:34.31ms +step:505/1555 train_time:17326ms step_avg:34.31ms +step:506/1555 train_time:17368ms step_avg:34.32ms +step:507/1555 train_time:17423ms step_avg:34.36ms +step:508/1555 train_time:17487ms step_avg:34.42ms +step:509/1555 train_time:17544ms step_avg:34.47ms +step:510/1555 train_time:17607ms step_avg:34.52ms +step:511/1555 train_time:17665ms step_avg:34.57ms +step:512/1555 train_time:17729ms step_avg:34.63ms +step:513/1555 train_time:17785ms step_avg:34.67ms +step:514/1555 train_time:17848ms step_avg:34.72ms +step:515/1555 train_time:17905ms step_avg:34.77ms +step:516/1555 
train_time:17968ms step_avg:34.82ms +step:517/1555 train_time:18025ms step_avg:34.86ms +step:518/1555 train_time:18089ms step_avg:34.92ms +step:519/1555 train_time:18146ms step_avg:34.96ms +step:520/1555 train_time:18211ms step_avg:35.02ms +step:521/1555 train_time:18270ms step_avg:35.07ms +step:522/1555 train_time:18336ms step_avg:35.13ms +step:523/1555 train_time:18395ms step_avg:35.17ms +step:524/1555 train_time:18461ms step_avg:35.23ms +step:525/1555 train_time:18519ms step_avg:35.27ms +step:526/1555 train_time:18583ms step_avg:35.33ms +step:527/1555 train_time:18641ms step_avg:35.37ms +step:528/1555 train_time:18704ms step_avg:35.42ms +step:529/1555 train_time:18761ms step_avg:35.47ms +step:530/1555 train_time:18825ms step_avg:35.52ms +step:531/1555 train_time:18883ms step_avg:35.56ms +step:532/1555 train_time:18946ms step_avg:35.61ms +step:533/1555 train_time:19003ms step_avg:35.65ms +step:534/1555 train_time:19067ms step_avg:35.71ms +step:535/1555 train_time:19125ms step_avg:35.75ms +step:536/1555 train_time:19191ms step_avg:35.80ms +step:537/1555 train_time:19250ms step_avg:35.85ms +step:538/1555 train_time:19315ms step_avg:35.90ms +step:539/1555 train_time:19372ms step_avg:35.94ms +step:540/1555 train_time:19437ms step_avg:35.99ms +step:541/1555 train_time:19495ms step_avg:36.04ms +step:542/1555 train_time:19560ms step_avg:36.09ms +step:543/1555 train_time:19618ms step_avg:36.13ms +step:544/1555 train_time:19683ms step_avg:36.18ms +step:545/1555 train_time:19741ms step_avg:36.22ms +step:546/1555 train_time:19805ms step_avg:36.27ms +step:547/1555 train_time:19862ms step_avg:36.31ms +step:548/1555 train_time:19927ms step_avg:36.36ms +step:549/1555 train_time:19984ms step_avg:36.40ms +step:550/1555 train_time:20049ms step_avg:36.45ms +step:551/1555 train_time:20107ms step_avg:36.49ms +step:552/1555 train_time:20170ms step_avg:36.54ms +step:553/1555 train_time:20229ms step_avg:36.58ms +step:554/1555 train_time:20292ms step_avg:36.63ms +step:555/1555 
train_time:20350ms step_avg:36.67ms +step:556/1555 train_time:20415ms step_avg:36.72ms +step:557/1555 train_time:20472ms step_avg:36.75ms +step:558/1555 train_time:20538ms step_avg:36.81ms +step:559/1555 train_time:20597ms step_avg:36.85ms +step:560/1555 train_time:20661ms step_avg:36.89ms +step:561/1555 train_time:20719ms step_avg:36.93ms +step:562/1555 train_time:20783ms step_avg:36.98ms +step:563/1555 train_time:20842ms step_avg:37.02ms +step:564/1555 train_time:20905ms step_avg:37.07ms +step:565/1555 train_time:20963ms step_avg:37.10ms +step:566/1555 train_time:21028ms step_avg:37.15ms +step:567/1555 train_time:21086ms step_avg:37.19ms +step:568/1555 train_time:21150ms step_avg:37.24ms +step:569/1555 train_time:21208ms step_avg:37.27ms +step:570/1555 train_time:21271ms step_avg:37.32ms +step:571/1555 train_time:21329ms step_avg:37.35ms +step:572/1555 train_time:21393ms step_avg:37.40ms +step:573/1555 train_time:21451ms step_avg:37.44ms +step:574/1555 train_time:21516ms step_avg:37.48ms +step:575/1555 train_time:21574ms step_avg:37.52ms +step:576/1555 train_time:21639ms step_avg:37.57ms +step:577/1555 train_time:21697ms step_avg:37.60ms +step:578/1555 train_time:21762ms step_avg:37.65ms +step:579/1555 train_time:21820ms step_avg:37.69ms +step:580/1555 train_time:21884ms step_avg:37.73ms +step:581/1555 train_time:21942ms step_avg:37.77ms +step:582/1555 train_time:22007ms step_avg:37.81ms +step:583/1555 train_time:22065ms step_avg:37.85ms +step:584/1555 train_time:22129ms step_avg:37.89ms +step:585/1555 train_time:22187ms step_avg:37.93ms +step:586/1555 train_time:22251ms step_avg:37.97ms +step:587/1555 train_time:22308ms step_avg:38.00ms +step:588/1555 train_time:22373ms step_avg:38.05ms +step:589/1555 train_time:22431ms step_avg:38.08ms +step:590/1555 train_time:22494ms step_avg:38.13ms +step:591/1555 train_time:22552ms step_avg:38.16ms +step:592/1555 train_time:22618ms step_avg:38.21ms +step:593/1555 train_time:22676ms step_avg:38.24ms +step:594/1555 
train_time:22741ms step_avg:38.28ms +step:595/1555 train_time:22799ms step_avg:38.32ms +step:596/1555 train_time:22863ms step_avg:38.36ms +step:597/1555 train_time:22920ms step_avg:38.39ms +step:598/1555 train_time:22985ms step_avg:38.44ms +step:599/1555 train_time:23044ms step_avg:38.47ms +step:600/1555 train_time:23108ms step_avg:38.51ms +step:601/1555 train_time:23166ms step_avg:38.55ms +step:602/1555 train_time:23230ms step_avg:38.59ms +step:603/1555 train_time:23288ms step_avg:38.62ms +step:604/1555 train_time:23352ms step_avg:38.66ms +step:605/1555 train_time:23410ms step_avg:38.69ms +step:606/1555 train_time:23473ms step_avg:38.73ms +step:607/1555 train_time:23530ms step_avg:38.77ms +step:608/1555 train_time:23595ms step_avg:38.81ms +step:609/1555 train_time:23652ms step_avg:38.84ms +step:610/1555 train_time:23718ms step_avg:38.88ms +step:611/1555 train_time:23775ms step_avg:38.91ms +step:612/1555 train_time:23840ms step_avg:38.95ms +step:613/1555 train_time:23898ms step_avg:38.99ms +step:614/1555 train_time:23963ms step_avg:39.03ms +step:615/1555 train_time:24021ms step_avg:39.06ms +step:616/1555 train_time:24086ms step_avg:39.10ms +step:617/1555 train_time:24145ms step_avg:39.13ms +step:618/1555 train_time:24209ms step_avg:39.17ms +step:619/1555 train_time:24266ms step_avg:39.20ms +step:620/1555 train_time:24330ms step_avg:39.24ms +step:621/1555 train_time:24388ms step_avg:39.27ms +step:622/1555 train_time:24452ms step_avg:39.31ms +step:623/1555 train_time:24509ms step_avg:39.34ms +step:624/1555 train_time:24573ms step_avg:39.38ms +step:625/1555 train_time:24630ms step_avg:39.41ms +step:626/1555 train_time:24696ms step_avg:39.45ms +step:627/1555 train_time:24754ms step_avg:39.48ms +step:628/1555 train_time:24819ms step_avg:39.52ms +step:629/1555 train_time:24877ms step_avg:39.55ms +step:630/1555 train_time:24941ms step_avg:39.59ms +step:631/1555 train_time:25000ms step_avg:39.62ms +step:632/1555 train_time:25065ms step_avg:39.66ms +step:633/1555 
train_time:25123ms step_avg:39.69ms +step:634/1555 train_time:25188ms step_avg:39.73ms +step:635/1555 train_time:25246ms step_avg:39.76ms +step:636/1555 train_time:25309ms step_avg:39.79ms +step:637/1555 train_time:25367ms step_avg:39.82ms +step:638/1555 train_time:25430ms step_avg:39.86ms +step:639/1555 train_time:25488ms step_avg:39.89ms +step:640/1555 train_time:25553ms step_avg:39.93ms +step:641/1555 train_time:25610ms step_avg:39.95ms +step:642/1555 train_time:25674ms step_avg:39.99ms +step:643/1555 train_time:25732ms step_avg:40.02ms +step:644/1555 train_time:25796ms step_avg:40.06ms +step:645/1555 train_time:25855ms step_avg:40.09ms +step:646/1555 train_time:25921ms step_avg:40.12ms +step:647/1555 train_time:25979ms step_avg:40.15ms +step:648/1555 train_time:26044ms step_avg:40.19ms +step:649/1555 train_time:26101ms step_avg:40.22ms +step:650/1555 train_time:26165ms step_avg:40.25ms +step:651/1555 train_time:26223ms step_avg:40.28ms +step:652/1555 train_time:26287ms step_avg:40.32ms +step:653/1555 train_time:26347ms step_avg:40.35ms +step:654/1555 train_time:26410ms step_avg:40.38ms +step:655/1555 train_time:26468ms step_avg:40.41ms +step:656/1555 train_time:26532ms step_avg:40.44ms +step:657/1555 train_time:26589ms step_avg:40.47ms +step:658/1555 train_time:26654ms step_avg:40.51ms +step:659/1555 train_time:26711ms step_avg:40.53ms +step:660/1555 train_time:26776ms step_avg:40.57ms +step:661/1555 train_time:26834ms step_avg:40.60ms +step:662/1555 train_time:26899ms step_avg:40.63ms +step:663/1555 train_time:26957ms step_avg:40.66ms +step:664/1555 train_time:27022ms step_avg:40.70ms +step:665/1555 train_time:27079ms step_avg:40.72ms +step:666/1555 train_time:27144ms step_avg:40.76ms +step:667/1555 train_time:27201ms step_avg:40.78ms +step:668/1555 train_time:27266ms step_avg:40.82ms +step:669/1555 train_time:27324ms step_avg:40.84ms +step:670/1555 train_time:27388ms step_avg:40.88ms +step:671/1555 train_time:27446ms step_avg:40.90ms +step:672/1555 
train_time:27509ms step_avg:40.94ms +step:673/1555 train_time:27567ms step_avg:40.96ms +step:674/1555 train_time:27631ms step_avg:41.00ms +step:675/1555 train_time:27688ms step_avg:41.02ms +step:676/1555 train_time:27753ms step_avg:41.05ms +step:677/1555 train_time:27810ms step_avg:41.08ms +step:678/1555 train_time:27875ms step_avg:41.11ms +step:679/1555 train_time:27933ms step_avg:41.14ms +step:680/1555 train_time:27998ms step_avg:41.17ms +step:681/1555 train_time:28056ms step_avg:41.20ms +step:682/1555 train_time:28121ms step_avg:41.23ms +step:683/1555 train_time:28178ms step_avg:41.26ms +step:684/1555 train_time:28242ms step_avg:41.29ms +step:685/1555 train_time:28300ms step_avg:41.31ms +step:686/1555 train_time:28365ms step_avg:41.35ms +step:687/1555 train_time:28423ms step_avg:41.37ms +step:688/1555 train_time:28487ms step_avg:41.41ms +step:689/1555 train_time:28546ms step_avg:41.43ms +step:690/1555 train_time:28609ms step_avg:41.46ms +step:691/1555 train_time:28667ms step_avg:41.49ms +step:692/1555 train_time:28731ms step_avg:41.52ms +step:693/1555 train_time:28788ms step_avg:41.54ms +step:694/1555 train_time:28853ms step_avg:41.57ms +step:695/1555 train_time:28909ms step_avg:41.60ms +step:696/1555 train_time:28975ms step_avg:41.63ms +step:697/1555 train_time:29033ms step_avg:41.65ms +step:698/1555 train_time:29098ms step_avg:41.69ms +step:699/1555 train_time:29155ms step_avg:41.71ms +step:700/1555 train_time:29220ms step_avg:41.74ms +step:701/1555 train_time:29277ms step_avg:41.77ms +step:702/1555 train_time:29342ms step_avg:41.80ms +step:703/1555 train_time:29400ms step_avg:41.82ms +step:704/1555 train_time:29465ms step_avg:41.85ms +step:705/1555 train_time:29522ms step_avg:41.88ms +step:706/1555 train_time:29586ms step_avg:41.91ms +step:707/1555 train_time:29644ms step_avg:41.93ms +step:708/1555 train_time:29710ms step_avg:41.96ms +step:709/1555 train_time:29767ms step_avg:41.98ms +step:710/1555 train_time:29831ms step_avg:42.02ms +step:711/1555 
train_time:29888ms step_avg:42.04ms +step:712/1555 train_time:29952ms step_avg:42.07ms +step:713/1555 train_time:30010ms step_avg:42.09ms +step:714/1555 train_time:30075ms step_avg:42.12ms +step:715/1555 train_time:30132ms step_avg:42.14ms +step:716/1555 train_time:30198ms step_avg:42.18ms +step:717/1555 train_time:30256ms step_avg:42.20ms +step:718/1555 train_time:30319ms step_avg:42.23ms +step:719/1555 train_time:30377ms step_avg:42.25ms +step:720/1555 train_time:30442ms step_avg:42.28ms +step:721/1555 train_time:30500ms step_avg:42.30ms +step:722/1555 train_time:30564ms step_avg:42.33ms +step:723/1555 train_time:30623ms step_avg:42.35ms +step:724/1555 train_time:30687ms step_avg:42.39ms +step:725/1555 train_time:30745ms step_avg:42.41ms +step:726/1555 train_time:30810ms step_avg:42.44ms +step:727/1555 train_time:30867ms step_avg:42.46ms +step:728/1555 train_time:30931ms step_avg:42.49ms +step:729/1555 train_time:30988ms step_avg:42.51ms +step:730/1555 train_time:31054ms step_avg:42.54ms +step:731/1555 train_time:31110ms step_avg:42.56ms +step:732/1555 train_time:31176ms step_avg:42.59ms +step:733/1555 train_time:31233ms step_avg:42.61ms +step:734/1555 train_time:31298ms step_avg:42.64ms +step:735/1555 train_time:31356ms step_avg:42.66ms +step:736/1555 train_time:31420ms step_avg:42.69ms +step:737/1555 train_time:31478ms step_avg:42.71ms +step:738/1555 train_time:31543ms step_avg:42.74ms +step:739/1555 train_time:31602ms step_avg:42.76ms +step:740/1555 train_time:31665ms step_avg:42.79ms +step:741/1555 train_time:31724ms step_avg:42.81ms +step:742/1555 train_time:31788ms step_avg:42.84ms +step:743/1555 train_time:31846ms step_avg:42.86ms +step:744/1555 train_time:31910ms step_avg:42.89ms +step:745/1555 train_time:31967ms step_avg:42.91ms +step:746/1555 train_time:32032ms step_avg:42.94ms +step:747/1555 train_time:32089ms step_avg:42.96ms +step:748/1555 train_time:32154ms step_avg:42.99ms +step:749/1555 train_time:32211ms step_avg:43.00ms +step:750/1555 
train_time:32276ms step_avg:43.03ms +step:750/1555 val_loss:3.8786 train_time:32358ms step_avg:43.14ms +step:751/1555 train_time:32377ms step_avg:43.11ms +step:752/1555 train_time:32399ms step_avg:43.08ms +step:753/1555 train_time:32458ms step_avg:43.11ms +step:754/1555 train_time:32527ms step_avg:43.14ms +step:755/1555 train_time:32588ms step_avg:43.16ms +step:756/1555 train_time:32652ms step_avg:43.19ms +step:757/1555 train_time:32709ms step_avg:43.21ms +step:758/1555 train_time:32773ms step_avg:43.24ms +step:759/1555 train_time:32830ms step_avg:43.25ms +step:760/1555 train_time:32894ms step_avg:43.28ms +step:761/1555 train_time:32950ms step_avg:43.30ms +step:762/1555 train_time:33016ms step_avg:43.33ms +step:763/1555 train_time:33073ms step_avg:43.35ms +step:764/1555 train_time:33137ms step_avg:43.37ms +step:765/1555 train_time:33194ms step_avg:43.39ms +step:766/1555 train_time:33257ms step_avg:43.42ms +step:767/1555 train_time:33315ms step_avg:43.44ms +step:768/1555 train_time:33379ms step_avg:43.46ms +step:769/1555 train_time:33439ms step_avg:43.48ms +step:770/1555 train_time:33504ms step_avg:43.51ms +step:771/1555 train_time:33564ms step_avg:43.53ms +step:772/1555 train_time:33629ms step_avg:43.56ms +step:773/1555 train_time:33686ms step_avg:43.58ms +step:774/1555 train_time:33750ms step_avg:43.60ms +step:775/1555 train_time:33807ms step_avg:43.62ms +step:776/1555 train_time:33870ms step_avg:43.65ms +step:777/1555 train_time:33927ms step_avg:43.66ms +step:778/1555 train_time:33991ms step_avg:43.69ms +step:779/1555 train_time:34049ms step_avg:43.71ms +step:780/1555 train_time:34112ms step_avg:43.73ms +step:781/1555 train_time:34170ms step_avg:43.75ms +step:782/1555 train_time:34233ms step_avg:43.78ms +step:783/1555 train_time:34290ms step_avg:43.79ms +step:784/1555 train_time:34355ms step_avg:43.82ms +step:785/1555 train_time:34413ms step_avg:43.84ms +step:786/1555 train_time:34479ms step_avg:43.87ms +step:787/1555 train_time:34537ms step_avg:43.88ms 
+step:788/1555 train_time:34602ms step_avg:43.91ms +step:789/1555 train_time:34660ms step_avg:43.93ms +step:790/1555 train_time:34724ms step_avg:43.96ms +step:791/1555 train_time:34782ms step_avg:43.97ms +step:792/1555 train_time:34848ms step_avg:44.00ms +step:793/1555 train_time:34906ms step_avg:44.02ms +step:794/1555 train_time:34971ms step_avg:44.04ms +step:795/1555 train_time:35028ms step_avg:44.06ms +step:796/1555 train_time:35091ms step_avg:44.08ms +step:797/1555 train_time:35149ms step_avg:44.10ms +step:798/1555 train_time:35212ms step_avg:44.13ms +step:799/1555 train_time:35270ms step_avg:44.14ms +step:800/1555 train_time:35333ms step_avg:44.17ms +step:801/1555 train_time:35390ms step_avg:44.18ms +step:802/1555 train_time:35455ms step_avg:44.21ms +step:803/1555 train_time:35513ms step_avg:44.23ms +step:804/1555 train_time:35579ms step_avg:44.25ms +step:805/1555 train_time:35637ms step_avg:44.27ms +step:806/1555 train_time:35701ms step_avg:44.29ms +step:807/1555 train_time:35760ms step_avg:44.31ms +step:808/1555 train_time:35823ms step_avg:44.34ms +step:809/1555 train_time:35882ms step_avg:44.35ms +step:810/1555 train_time:35947ms step_avg:44.38ms +step:811/1555 train_time:36005ms step_avg:44.40ms +step:812/1555 train_time:36069ms step_avg:44.42ms +step:813/1555 train_time:36126ms step_avg:44.44ms +step:814/1555 train_time:36189ms step_avg:44.46ms +step:815/1555 train_time:36248ms step_avg:44.48ms +step:816/1555 train_time:36312ms step_avg:44.50ms +step:817/1555 train_time:36370ms step_avg:44.52ms +step:818/1555 train_time:36433ms step_avg:44.54ms +step:819/1555 train_time:36491ms step_avg:44.56ms +step:820/1555 train_time:36556ms step_avg:44.58ms +step:821/1555 train_time:36613ms step_avg:44.60ms +step:822/1555 train_time:36679ms step_avg:44.62ms +step:823/1555 train_time:36737ms step_avg:44.64ms +step:824/1555 train_time:36802ms step_avg:44.66ms +step:825/1555 train_time:36859ms step_avg:44.68ms +step:826/1555 train_time:36924ms step_avg:44.70ms 
+step:827/1555 train_time:36982ms step_avg:44.72ms +step:828/1555 train_time:37046ms step_avg:44.74ms +step:829/1555 train_time:37103ms step_avg:44.76ms +step:830/1555 train_time:37168ms step_avg:44.78ms +step:831/1555 train_time:37224ms step_avg:44.79ms +step:832/1555 train_time:37289ms step_avg:44.82ms +step:833/1555 train_time:37347ms step_avg:44.83ms +step:834/1555 train_time:37411ms step_avg:44.86ms +step:835/1555 train_time:37469ms step_avg:44.87ms +step:836/1555 train_time:37532ms step_avg:44.89ms +step:837/1555 train_time:37591ms step_avg:44.91ms +step:838/1555 train_time:37656ms step_avg:44.93ms +step:839/1555 train_time:37713ms step_avg:44.95ms +step:840/1555 train_time:37779ms step_avg:44.97ms +step:841/1555 train_time:37836ms step_avg:44.99ms +step:842/1555 train_time:37902ms step_avg:45.01ms +step:843/1555 train_time:37959ms step_avg:45.03ms +step:844/1555 train_time:38023ms step_avg:45.05ms +step:845/1555 train_time:38081ms step_avg:45.07ms +step:846/1555 train_time:38145ms step_avg:45.09ms +step:847/1555 train_time:38203ms step_avg:45.10ms +step:848/1555 train_time:38269ms step_avg:45.13ms +step:849/1555 train_time:38327ms step_avg:45.14ms +step:850/1555 train_time:38390ms step_avg:45.16ms +step:851/1555 train_time:38448ms step_avg:45.18ms +step:852/1555 train_time:38512ms step_avg:45.20ms +step:853/1555 train_time:38570ms step_avg:45.22ms +step:854/1555 train_time:38636ms step_avg:45.24ms +step:855/1555 train_time:38692ms step_avg:45.25ms +step:856/1555 train_time:38756ms step_avg:45.28ms +step:857/1555 train_time:38814ms step_avg:45.29ms +step:858/1555 train_time:38879ms step_avg:45.31ms +step:859/1555 train_time:38938ms step_avg:45.33ms +step:860/1555 train_time:39002ms step_avg:45.35ms +step:861/1555 train_time:39059ms step_avg:45.36ms +step:862/1555 train_time:39123ms step_avg:45.39ms +step:863/1555 train_time:39181ms step_avg:45.40ms +step:864/1555 train_time:39246ms step_avg:45.42ms +step:865/1555 train_time:39305ms step_avg:45.44ms 
+step:866/1555 train_time:39369ms step_avg:45.46ms +step:867/1555 train_time:39427ms step_avg:45.47ms +step:868/1555 train_time:39491ms step_avg:45.50ms +step:869/1555 train_time:39549ms step_avg:45.51ms +step:870/1555 train_time:39614ms step_avg:45.53ms +step:871/1555 train_time:39671ms step_avg:45.55ms +step:872/1555 train_time:39736ms step_avg:45.57ms +step:873/1555 train_time:39792ms step_avg:45.58ms +step:874/1555 train_time:39857ms step_avg:45.60ms +step:875/1555 train_time:39914ms step_avg:45.62ms +step:876/1555 train_time:39979ms step_avg:45.64ms +step:877/1555 train_time:40037ms step_avg:45.65ms +step:878/1555 train_time:40101ms step_avg:45.67ms +step:879/1555 train_time:40159ms step_avg:45.69ms +step:880/1555 train_time:40224ms step_avg:45.71ms +step:881/1555 train_time:40282ms step_avg:45.72ms +step:882/1555 train_time:40346ms step_avg:45.74ms +step:883/1555 train_time:40404ms step_avg:45.76ms +step:884/1555 train_time:40470ms step_avg:45.78ms +step:885/1555 train_time:40528ms step_avg:45.79ms +step:886/1555 train_time:40592ms step_avg:45.82ms +step:887/1555 train_time:40650ms step_avg:45.83ms +step:888/1555 train_time:40714ms step_avg:45.85ms +step:889/1555 train_time:40771ms step_avg:45.86ms +step:890/1555 train_time:40835ms step_avg:45.88ms +step:891/1555 train_time:40893ms step_avg:45.90ms +step:892/1555 train_time:40957ms step_avg:45.92ms +step:893/1555 train_time:41014ms step_avg:45.93ms +step:894/1555 train_time:41079ms step_avg:45.95ms +step:895/1555 train_time:41137ms step_avg:45.96ms +step:896/1555 train_time:41202ms step_avg:45.98ms +step:897/1555 train_time:41260ms step_avg:46.00ms +step:898/1555 train_time:41324ms step_avg:46.02ms +step:899/1555 train_time:41381ms step_avg:46.03ms +step:900/1555 train_time:41446ms step_avg:46.05ms +step:901/1555 train_time:41504ms step_avg:46.06ms +step:902/1555 train_time:41569ms step_avg:46.09ms +step:903/1555 train_time:41627ms step_avg:46.10ms +step:904/1555 train_time:41691ms step_avg:46.12ms 
+step:905/1555 train_time:41750ms step_avg:46.13ms +step:906/1555 train_time:41813ms step_avg:46.15ms +step:907/1555 train_time:41872ms step_avg:46.17ms +step:908/1555 train_time:41934ms step_avg:46.18ms +step:909/1555 train_time:41992ms step_avg:46.20ms +step:910/1555 train_time:42056ms step_avg:46.22ms +step:911/1555 train_time:42113ms step_avg:46.23ms +step:912/1555 train_time:42178ms step_avg:46.25ms +step:913/1555 train_time:42236ms step_avg:46.26ms +step:914/1555 train_time:42301ms step_avg:46.28ms +step:915/1555 train_time:42359ms step_avg:46.29ms +step:916/1555 train_time:42424ms step_avg:46.31ms +step:917/1555 train_time:42481ms step_avg:46.33ms +step:918/1555 train_time:42546ms step_avg:46.35ms +step:919/1555 train_time:42604ms step_avg:46.36ms +step:920/1555 train_time:42668ms step_avg:46.38ms +step:921/1555 train_time:42726ms step_avg:46.39ms +step:922/1555 train_time:42791ms step_avg:46.41ms +step:923/1555 train_time:42850ms step_avg:46.42ms +step:924/1555 train_time:42913ms step_avg:46.44ms +step:925/1555 train_time:42971ms step_avg:46.46ms +step:926/1555 train_time:43034ms step_avg:46.47ms +step:927/1555 train_time:43092ms step_avg:46.49ms +step:928/1555 train_time:43156ms step_avg:46.50ms +step:929/1555 train_time:43213ms step_avg:46.52ms +step:930/1555 train_time:43279ms step_avg:46.54ms +step:931/1555 train_time:43337ms step_avg:46.55ms +step:932/1555 train_time:43401ms step_avg:46.57ms +step:933/1555 train_time:43459ms step_avg:46.58ms +step:934/1555 train_time:43524ms step_avg:46.60ms +step:935/1555 train_time:43582ms step_avg:46.61ms +step:936/1555 train_time:43647ms step_avg:46.63ms +step:937/1555 train_time:43705ms step_avg:46.64ms +step:938/1555 train_time:43770ms step_avg:46.66ms +step:939/1555 train_time:43828ms step_avg:46.68ms +step:940/1555 train_time:43892ms step_avg:46.69ms +step:941/1555 train_time:43950ms step_avg:46.71ms +step:942/1555 train_time:44014ms step_avg:46.72ms +step:943/1555 train_time:44071ms step_avg:46.74ms 
+step:944/1555 train_time:44135ms step_avg:46.75ms +step:945/1555 train_time:44192ms step_avg:46.76ms +step:946/1555 train_time:44257ms step_avg:46.78ms +step:947/1555 train_time:44314ms step_avg:46.79ms +step:948/1555 train_time:44380ms step_avg:46.81ms +step:949/1555 train_time:44437ms step_avg:46.83ms +step:950/1555 train_time:44502ms step_avg:46.84ms +step:951/1555 train_time:44560ms step_avg:46.86ms +step:952/1555 train_time:44625ms step_avg:46.87ms +step:953/1555 train_time:44682ms step_avg:46.89ms +step:954/1555 train_time:44747ms step_avg:46.90ms +step:955/1555 train_time:44804ms step_avg:46.92ms +step:956/1555 train_time:44869ms step_avg:46.93ms +step:957/1555 train_time:44927ms step_avg:46.95ms +step:958/1555 train_time:44991ms step_avg:46.96ms +step:959/1555 train_time:45049ms step_avg:46.98ms +step:960/1555 train_time:45113ms step_avg:46.99ms +step:961/1555 train_time:45171ms step_avg:47.00ms +step:962/1555 train_time:45235ms step_avg:47.02ms +step:963/1555 train_time:45292ms step_avg:47.03ms +step:964/1555 train_time:45357ms step_avg:47.05ms +step:965/1555 train_time:45414ms step_avg:47.06ms +step:966/1555 train_time:45480ms step_avg:47.08ms +step:967/1555 train_time:45538ms step_avg:47.09ms +step:968/1555 train_time:45602ms step_avg:47.11ms +step:969/1555 train_time:45660ms step_avg:47.12ms +step:970/1555 train_time:45724ms step_avg:47.14ms +step:971/1555 train_time:45781ms step_avg:47.15ms +step:972/1555 train_time:45847ms step_avg:47.17ms +step:973/1555 train_time:45905ms step_avg:47.18ms +step:974/1555 train_time:45969ms step_avg:47.20ms +step:975/1555 train_time:46027ms step_avg:47.21ms +step:976/1555 train_time:46091ms step_avg:47.22ms +step:977/1555 train_time:46149ms step_avg:47.24ms +step:978/1555 train_time:46212ms step_avg:47.25ms +step:979/1555 train_time:46269ms step_avg:47.26ms +step:980/1555 train_time:46334ms step_avg:47.28ms +step:981/1555 train_time:46392ms step_avg:47.29ms +step:982/1555 train_time:46457ms step_avg:47.31ms 
+step:983/1555 train_time:46514ms step_avg:47.32ms +step:984/1555 train_time:46579ms step_avg:47.34ms +step:985/1555 train_time:46638ms step_avg:47.35ms +step:986/1555 train_time:46702ms step_avg:47.37ms +step:987/1555 train_time:46760ms step_avg:47.38ms +step:988/1555 train_time:46824ms step_avg:47.39ms +step:989/1555 train_time:46883ms step_avg:47.40ms +step:990/1555 train_time:46947ms step_avg:47.42ms +step:991/1555 train_time:47006ms step_avg:47.43ms +step:992/1555 train_time:47070ms step_avg:47.45ms +step:993/1555 train_time:47127ms step_avg:47.46ms +step:994/1555 train_time:47191ms step_avg:47.48ms +step:995/1555 train_time:47250ms step_avg:47.49ms +step:996/1555 train_time:47315ms step_avg:47.51ms +step:997/1555 train_time:47371ms step_avg:47.51ms +step:998/1555 train_time:47435ms step_avg:47.53ms +step:999/1555 train_time:47493ms step_avg:47.54ms +step:1000/1555 train_time:47557ms step_avg:47.56ms +step:1000/1555 val_loss:3.5774 train_time:47640ms step_avg:47.64ms +step:1001/1555 train_time:47658ms step_avg:47.61ms +step:1002/1555 train_time:47681ms step_avg:47.59ms +step:1003/1555 train_time:47740ms step_avg:47.60ms +step:1004/1555 train_time:47807ms step_avg:47.62ms +step:1005/1555 train_time:47866ms step_avg:47.63ms +step:1006/1555 train_time:47930ms step_avg:47.64ms +step:1007/1555 train_time:47987ms step_avg:47.65ms +step:1008/1555 train_time:48051ms step_avg:47.67ms +step:1009/1555 train_time:48107ms step_avg:47.68ms +step:1010/1555 train_time:48171ms step_avg:47.69ms +step:1011/1555 train_time:48231ms step_avg:47.71ms +step:1012/1555 train_time:48316ms step_avg:47.74ms +step:1013/1555 train_time:48398ms step_avg:47.78ms +step:1014/1555 train_time:48488ms step_avg:47.82ms +step:1015/1555 train_time:48571ms step_avg:47.85ms +step:1016/1555 train_time:48663ms step_avg:47.90ms +step:1017/1555 train_time:48752ms step_avg:47.94ms +step:1018/1555 train_time:48843ms step_avg:47.98ms +step:1019/1555 train_time:48928ms step_avg:48.02ms +step:1020/1555 
train_time:49018ms step_avg:48.06ms +step:1021/1555 train_time:49102ms step_avg:48.09ms +step:1022/1555 train_time:49191ms step_avg:48.13ms +step:1023/1555 train_time:49274ms step_avg:48.17ms +step:1024/1555 train_time:49363ms step_avg:48.21ms +step:1025/1555 train_time:49446ms step_avg:48.24ms +step:1026/1555 train_time:49536ms step_avg:48.28ms +step:1027/1555 train_time:49620ms step_avg:48.32ms +step:1028/1555 train_time:49712ms step_avg:48.36ms +step:1029/1555 train_time:49795ms step_avg:48.39ms +step:1030/1555 train_time:49887ms step_avg:48.43ms +step:1031/1555 train_time:49973ms step_avg:48.47ms +step:1032/1555 train_time:50063ms step_avg:48.51ms +step:1033/1555 train_time:50147ms step_avg:48.54ms +step:1034/1555 train_time:50236ms step_avg:48.58ms +step:1035/1555 train_time:50319ms step_avg:48.62ms +step:1036/1555 train_time:50409ms step_avg:48.66ms +step:1037/1555 train_time:50493ms step_avg:48.69ms +step:1038/1555 train_time:50583ms step_avg:48.73ms +step:1039/1555 train_time:50668ms step_avg:48.77ms +step:1040/1555 train_time:50759ms step_avg:48.81ms +step:1041/1555 train_time:50845ms step_avg:48.84ms +step:1042/1555 train_time:50937ms step_avg:48.88ms +step:1043/1555 train_time:51020ms step_avg:48.92ms +step:1044/1555 train_time:51111ms step_avg:48.96ms +step:1045/1555 train_time:51194ms step_avg:48.99ms +step:1046/1555 train_time:51284ms step_avg:49.03ms +step:1047/1555 train_time:51368ms step_avg:49.06ms +step:1048/1555 train_time:51457ms step_avg:49.10ms +step:1049/1555 train_time:51541ms step_avg:49.13ms +step:1050/1555 train_time:51633ms step_avg:49.17ms +step:1051/1555 train_time:51717ms step_avg:49.21ms +step:1052/1555 train_time:51810ms step_avg:49.25ms +step:1053/1555 train_time:51893ms step_avg:49.28ms +step:1054/1555 train_time:51983ms step_avg:49.32ms +step:1055/1555 train_time:52066ms step_avg:49.35ms +step:1056/1555 train_time:52156ms step_avg:49.39ms +step:1057/1555 train_time:52241ms step_avg:49.42ms +step:1058/1555 train_time:52332ms 
step_avg:49.46ms +step:1059/1555 train_time:52414ms step_avg:49.49ms +step:1060/1555 train_time:52504ms step_avg:49.53ms +step:1061/1555 train_time:52588ms step_avg:49.56ms +step:1062/1555 train_time:52677ms step_avg:49.60ms +step:1063/1555 train_time:52762ms step_avg:49.64ms +step:1064/1555 train_time:52853ms step_avg:49.67ms +step:1065/1555 train_time:52936ms step_avg:49.71ms +step:1066/1555 train_time:53027ms step_avg:49.74ms +step:1067/1555 train_time:53112ms step_avg:49.78ms +step:1068/1555 train_time:53201ms step_avg:49.81ms +step:1069/1555 train_time:53284ms step_avg:49.85ms +step:1070/1555 train_time:53374ms step_avg:49.88ms +step:1071/1555 train_time:53458ms step_avg:49.91ms +step:1072/1555 train_time:53548ms step_avg:49.95ms +step:1073/1555 train_time:53633ms step_avg:49.98ms +step:1074/1555 train_time:53723ms step_avg:50.02ms +step:1075/1555 train_time:53807ms step_avg:50.05ms +step:1076/1555 train_time:53896ms step_avg:50.09ms +step:1077/1555 train_time:53980ms step_avg:50.12ms +step:1078/1555 train_time:54072ms step_avg:50.16ms +step:1079/1555 train_time:54156ms step_avg:50.19ms +step:1080/1555 train_time:54246ms step_avg:50.23ms +step:1081/1555 train_time:54330ms step_avg:50.26ms +step:1082/1555 train_time:54419ms step_avg:50.29ms +step:1083/1555 train_time:54503ms step_avg:50.33ms +step:1084/1555 train_time:54592ms step_avg:50.36ms +step:1085/1555 train_time:54676ms step_avg:50.39ms +step:1086/1555 train_time:54767ms step_avg:50.43ms +step:1087/1555 train_time:54851ms step_avg:50.46ms +step:1088/1555 train_time:54940ms step_avg:50.50ms +step:1089/1555 train_time:55025ms step_avg:50.53ms +step:1090/1555 train_time:55114ms step_avg:50.56ms +step:1091/1555 train_time:55198ms step_avg:50.59ms +step:1092/1555 train_time:55287ms step_avg:50.63ms +step:1093/1555 train_time:55371ms step_avg:50.66ms +step:1094/1555 train_time:55461ms step_avg:50.70ms +step:1095/1555 train_time:55547ms step_avg:50.73ms +step:1096/1555 train_time:55636ms step_avg:50.76ms 
+step:1097/1555 train_time:55720ms step_avg:50.79ms +step:1098/1555 train_time:55812ms step_avg:50.83ms +step:1099/1555 train_time:55894ms step_avg:50.86ms +step:1100/1555 train_time:55985ms step_avg:50.90ms +step:1101/1555 train_time:56070ms step_avg:50.93ms +step:1102/1555 train_time:56160ms step_avg:50.96ms +step:1103/1555 train_time:56244ms step_avg:50.99ms +step:1104/1555 train_time:56334ms step_avg:51.03ms +step:1105/1555 train_time:56418ms step_avg:51.06ms +step:1106/1555 train_time:56510ms step_avg:51.09ms +step:1107/1555 train_time:56593ms step_avg:51.12ms +step:1108/1555 train_time:56684ms step_avg:51.16ms +step:1109/1555 train_time:56768ms step_avg:51.19ms +step:1110/1555 train_time:56857ms step_avg:51.22ms +step:1111/1555 train_time:56941ms step_avg:51.25ms +step:1112/1555 train_time:57031ms step_avg:51.29ms +step:1113/1555 train_time:57114ms step_avg:51.32ms +step:1114/1555 train_time:57204ms step_avg:51.35ms +step:1115/1555 train_time:57287ms step_avg:51.38ms +step:1116/1555 train_time:57377ms step_avg:51.41ms +step:1117/1555 train_time:57461ms step_avg:51.44ms +step:1118/1555 train_time:57552ms step_avg:51.48ms +step:1119/1555 train_time:57636ms step_avg:51.51ms +step:1120/1555 train_time:57725ms step_avg:51.54ms +step:1121/1555 train_time:57810ms step_avg:51.57ms +step:1122/1555 train_time:57899ms step_avg:51.60ms +step:1123/1555 train_time:57984ms step_avg:51.63ms +step:1124/1555 train_time:58073ms step_avg:51.67ms +step:1125/1555 train_time:58157ms step_avg:51.69ms +step:1126/1555 train_time:58248ms step_avg:51.73ms +step:1127/1555 train_time:58331ms step_avg:51.76ms +step:1128/1555 train_time:58422ms step_avg:51.79ms +step:1129/1555 train_time:58507ms step_avg:51.82ms +step:1130/1555 train_time:58596ms step_avg:51.85ms +step:1131/1555 train_time:58679ms step_avg:51.88ms +step:1132/1555 train_time:58770ms step_avg:51.92ms +step:1133/1555 train_time:58854ms step_avg:51.94ms +step:1134/1555 train_time:58944ms step_avg:51.98ms +step:1135/1555 
train_time:59029ms step_avg:52.01ms +step:1136/1555 train_time:59118ms step_avg:52.04ms +step:1137/1555 train_time:59201ms step_avg:52.07ms +step:1138/1555 train_time:59291ms step_avg:52.10ms +step:1139/1555 train_time:59376ms step_avg:52.13ms +step:1140/1555 train_time:59466ms step_avg:52.16ms +step:1141/1555 train_time:59551ms step_avg:52.19ms +step:1142/1555 train_time:59640ms step_avg:52.22ms +step:1143/1555 train_time:59725ms step_avg:52.25ms +step:1144/1555 train_time:59815ms step_avg:52.29ms +step:1145/1555 train_time:59899ms step_avg:52.31ms +step:1146/1555 train_time:59989ms step_avg:52.35ms +step:1147/1555 train_time:60073ms step_avg:52.37ms +step:1148/1555 train_time:60163ms step_avg:52.41ms +step:1149/1555 train_time:60248ms step_avg:52.44ms +step:1150/1555 train_time:60337ms step_avg:52.47ms +step:1151/1555 train_time:60421ms step_avg:52.49ms +step:1152/1555 train_time:60513ms step_avg:52.53ms +step:1153/1555 train_time:60596ms step_avg:52.55ms +step:1154/1555 train_time:60686ms step_avg:52.59ms +step:1155/1555 train_time:60770ms step_avg:52.61ms +step:1156/1555 train_time:60859ms step_avg:52.65ms +step:1157/1555 train_time:60944ms step_avg:52.67ms +step:1158/1555 train_time:61035ms step_avg:52.71ms +step:1159/1555 train_time:61120ms step_avg:52.73ms +step:1160/1555 train_time:61211ms step_avg:52.77ms +step:1161/1555 train_time:61294ms step_avg:52.79ms +step:1162/1555 train_time:61383ms step_avg:52.83ms +step:1163/1555 train_time:61468ms step_avg:52.85ms +step:1164/1555 train_time:61557ms step_avg:52.88ms +step:1165/1555 train_time:61641ms step_avg:52.91ms +step:1166/1555 train_time:61733ms step_avg:52.94ms +step:1167/1555 train_time:61816ms step_avg:52.97ms +step:1168/1555 train_time:61907ms step_avg:53.00ms +step:1169/1555 train_time:61990ms step_avg:53.03ms +step:1170/1555 train_time:62080ms step_avg:53.06ms +step:1171/1555 train_time:62165ms step_avg:53.09ms +step:1172/1555 train_time:62255ms step_avg:53.12ms +step:1173/1555 train_time:62338ms 
step_avg:53.14ms +step:1174/1555 train_time:62429ms step_avg:53.18ms +step:1175/1555 train_time:62513ms step_avg:53.20ms +step:1176/1555 train_time:62603ms step_avg:53.23ms +step:1177/1555 train_time:62687ms step_avg:53.26ms +step:1178/1555 train_time:62777ms step_avg:53.29ms +step:1179/1555 train_time:62862ms step_avg:53.32ms +step:1180/1555 train_time:62953ms step_avg:53.35ms +step:1181/1555 train_time:63037ms step_avg:53.38ms +step:1182/1555 train_time:63127ms step_avg:53.41ms +step:1183/1555 train_time:63211ms step_avg:53.43ms +step:1184/1555 train_time:63300ms step_avg:53.46ms +step:1185/1555 train_time:63385ms step_avg:53.49ms +step:1186/1555 train_time:63474ms step_avg:53.52ms +step:1187/1555 train_time:63558ms step_avg:53.54ms +step:1188/1555 train_time:63648ms step_avg:53.58ms +step:1189/1555 train_time:63733ms step_avg:53.60ms +step:1190/1555 train_time:63823ms step_avg:53.63ms +step:1191/1555 train_time:63908ms step_avg:53.66ms +step:1192/1555 train_time:63997ms step_avg:53.69ms +step:1193/1555 train_time:64083ms step_avg:53.72ms +step:1194/1555 train_time:64173ms step_avg:53.75ms +step:1195/1555 train_time:64256ms step_avg:53.77ms +step:1196/1555 train_time:64346ms step_avg:53.80ms +step:1197/1555 train_time:64431ms step_avg:53.83ms +step:1198/1555 train_time:64521ms step_avg:53.86ms +step:1199/1555 train_time:64606ms step_avg:53.88ms +step:1200/1555 train_time:64696ms step_avg:53.91ms +step:1201/1555 train_time:64779ms step_avg:53.94ms +step:1202/1555 train_time:64870ms step_avg:53.97ms +step:1203/1555 train_time:64953ms step_avg:53.99ms +step:1204/1555 train_time:65043ms step_avg:54.02ms +step:1205/1555 train_time:65127ms step_avg:54.05ms +step:1206/1555 train_time:65216ms step_avg:54.08ms +step:1207/1555 train_time:65300ms step_avg:54.10ms +step:1208/1555 train_time:65391ms step_avg:54.13ms +step:1209/1555 train_time:65475ms step_avg:54.16ms +step:1210/1555 train_time:65564ms step_avg:54.19ms +step:1211/1555 train_time:65648ms step_avg:54.21ms 
+step:1212/1555 train_time:65738ms step_avg:54.24ms +step:1213/1555 train_time:65823ms step_avg:54.26ms +step:1214/1555 train_time:65914ms step_avg:54.30ms +step:1215/1555 train_time:65998ms step_avg:54.32ms +step:1216/1555 train_time:66088ms step_avg:54.35ms +step:1217/1555 train_time:66172ms step_avg:54.37ms +step:1218/1555 train_time:66261ms step_avg:54.40ms +step:1219/1555 train_time:66346ms step_avg:54.43ms +step:1220/1555 train_time:66435ms step_avg:54.46ms +step:1221/1555 train_time:66519ms step_avg:54.48ms +step:1222/1555 train_time:66609ms step_avg:54.51ms +step:1223/1555 train_time:66692ms step_avg:54.53ms +step:1224/1555 train_time:66784ms step_avg:54.56ms +step:1225/1555 train_time:66869ms step_avg:54.59ms +step:1226/1555 train_time:66959ms step_avg:54.62ms +step:1227/1555 train_time:67044ms step_avg:54.64ms +step:1228/1555 train_time:67134ms step_avg:54.67ms +step:1229/1555 train_time:67217ms step_avg:54.69ms +step:1230/1555 train_time:67309ms step_avg:54.72ms +step:1231/1555 train_time:67393ms step_avg:54.75ms +step:1232/1555 train_time:67483ms step_avg:54.77ms +step:1233/1555 train_time:67567ms step_avg:54.80ms +step:1234/1555 train_time:67656ms step_avg:54.83ms +step:1235/1555 train_time:67741ms step_avg:54.85ms +step:1236/1555 train_time:67831ms step_avg:54.88ms +step:1237/1555 train_time:67916ms step_avg:54.90ms +step:1238/1555 train_time:68006ms step_avg:54.93ms +step:1239/1555 train_time:68090ms step_avg:54.96ms +step:1240/1555 train_time:68178ms step_avg:54.98ms +step:1241/1555 train_time:68263ms step_avg:55.01ms +step:1242/1555 train_time:68353ms step_avg:55.03ms +step:1243/1555 train_time:68437ms step_avg:55.06ms +step:1244/1555 train_time:68527ms step_avg:55.09ms +step:1245/1555 train_time:68613ms step_avg:55.11ms +step:1246/1555 train_time:68702ms step_avg:55.14ms +step:1247/1555 train_time:68785ms step_avg:55.16ms +step:1248/1555 train_time:68876ms step_avg:55.19ms +step:1249/1555 train_time:68959ms step_avg:55.21ms +step:1250/1555 
train_time:69050ms step_avg:55.24ms +step:1250/1555 val_loss:3.4009 train_time:69163ms step_avg:55.33ms +step:1251/1555 train_time:69183ms step_avg:55.30ms +step:1252/1555 train_time:69224ms step_avg:55.29ms +step:1253/1555 train_time:69310ms step_avg:55.32ms +step:1254/1555 train_time:69404ms step_avg:55.35ms +step:1255/1555 train_time:69489ms step_avg:55.37ms +step:1256/1555 train_time:69578ms step_avg:55.40ms +step:1257/1555 train_time:69662ms step_avg:55.42ms +step:1258/1555 train_time:69750ms step_avg:55.45ms +step:1259/1555 train_time:69834ms step_avg:55.47ms +step:1260/1555 train_time:69923ms step_avg:55.49ms +step:1261/1555 train_time:70005ms step_avg:55.52ms +step:1262/1555 train_time:70095ms step_avg:55.54ms +step:1263/1555 train_time:70182ms step_avg:55.57ms +step:1264/1555 train_time:70273ms step_avg:55.60ms +step:1265/1555 train_time:70361ms step_avg:55.62ms +step:1266/1555 train_time:70451ms step_avg:55.65ms +step:1267/1555 train_time:70536ms step_avg:55.67ms +step:1268/1555 train_time:70626ms step_avg:55.70ms +step:1269/1555 train_time:70708ms step_avg:55.72ms +step:1270/1555 train_time:70797ms step_avg:55.75ms +step:1271/1555 train_time:70881ms step_avg:55.77ms +step:1272/1555 train_time:70969ms step_avg:55.79ms +step:1273/1555 train_time:71053ms step_avg:55.82ms +step:1274/1555 train_time:71144ms step_avg:55.84ms +step:1275/1555 train_time:71228ms step_avg:55.87ms +step:1276/1555 train_time:71319ms step_avg:55.89ms +step:1277/1555 train_time:71404ms step_avg:55.92ms +step:1278/1555 train_time:71494ms step_avg:55.94ms +step:1279/1555 train_time:71579ms step_avg:55.96ms +step:1280/1555 train_time:71668ms step_avg:55.99ms +step:1281/1555 train_time:71751ms step_avg:56.01ms +step:1282/1555 train_time:71840ms step_avg:56.04ms +step:1283/1555 train_time:71924ms step_avg:56.06ms +step:1284/1555 train_time:72013ms step_avg:56.08ms +step:1285/1555 train_time:72097ms step_avg:56.11ms +step:1286/1555 train_time:72188ms step_avg:56.13ms +step:1287/1555 
train_time:72272ms step_avg:56.16ms +step:1288/1555 train_time:72364ms step_avg:56.18ms +step:1289/1555 train_time:72448ms step_avg:56.20ms +step:1290/1555 train_time:72538ms step_avg:56.23ms +step:1291/1555 train_time:72622ms step_avg:56.25ms +step:1292/1555 train_time:72711ms step_avg:56.28ms +step:1293/1555 train_time:72795ms step_avg:56.30ms +step:1294/1555 train_time:72885ms step_avg:56.33ms +step:1295/1555 train_time:72968ms step_avg:56.35ms +step:1296/1555 train_time:73059ms step_avg:56.37ms +step:1297/1555 train_time:73142ms step_avg:56.39ms +step:1298/1555 train_time:73233ms step_avg:56.42ms +step:1299/1555 train_time:73319ms step_avg:56.44ms +step:1300/1555 train_time:73408ms step_avg:56.47ms +step:1301/1555 train_time:73493ms step_avg:56.49ms +step:1302/1555 train_time:73583ms step_avg:56.52ms +step:1303/1555 train_time:73667ms step_avg:56.54ms +step:1304/1555 train_time:73756ms step_avg:56.56ms +step:1305/1555 train_time:73841ms step_avg:56.58ms +step:1306/1555 train_time:73929ms step_avg:56.61ms +step:1307/1555 train_time:74014ms step_avg:56.63ms +step:1308/1555 train_time:74104ms step_avg:56.65ms +step:1309/1555 train_time:74188ms step_avg:56.68ms +step:1310/1555 train_time:74278ms step_avg:56.70ms +step:1311/1555 train_time:74364ms step_avg:56.72ms +step:1312/1555 train_time:74454ms step_avg:56.75ms +step:1313/1555 train_time:74539ms step_avg:56.77ms +step:1314/1555 train_time:74629ms step_avg:56.79ms +step:1315/1555 train_time:74713ms step_avg:56.82ms +step:1316/1555 train_time:74803ms step_avg:56.84ms +step:1317/1555 train_time:74887ms step_avg:56.86ms +step:1318/1555 train_time:74976ms step_avg:56.89ms +step:1319/1555 train_time:75060ms step_avg:56.91ms +step:1320/1555 train_time:75149ms step_avg:56.93ms +step:1321/1555 train_time:75234ms step_avg:56.95ms +step:1322/1555 train_time:75325ms step_avg:56.98ms +step:1323/1555 train_time:75408ms step_avg:57.00ms +step:1324/1555 train_time:75499ms step_avg:57.02ms +step:1325/1555 train_time:75583ms 
step_avg:57.04ms +step:1326/1555 train_time:75674ms step_avg:57.07ms +step:1327/1555 train_time:75757ms step_avg:57.09ms +step:1328/1555 train_time:75848ms step_avg:57.11ms +step:1329/1555 train_time:75931ms step_avg:57.13ms +step:1330/1555 train_time:76022ms step_avg:57.16ms +step:1331/1555 train_time:76105ms step_avg:57.18ms +step:1332/1555 train_time:76194ms step_avg:57.20ms +step:1333/1555 train_time:76278ms step_avg:57.22ms +step:1334/1555 train_time:76368ms step_avg:57.25ms +step:1335/1555 train_time:76453ms step_avg:57.27ms +step:1336/1555 train_time:76544ms step_avg:57.29ms +step:1337/1555 train_time:76627ms step_avg:57.31ms +step:1338/1555 train_time:76717ms step_avg:57.34ms +step:1339/1555 train_time:76801ms step_avg:57.36ms +step:1340/1555 train_time:76890ms step_avg:57.38ms +step:1341/1555 train_time:76975ms step_avg:57.40ms +step:1342/1555 train_time:77065ms step_avg:57.43ms +step:1343/1555 train_time:77148ms step_avg:57.44ms +step:1344/1555 train_time:77238ms step_avg:57.47ms +step:1345/1555 train_time:77323ms step_avg:57.49ms +step:1346/1555 train_time:77412ms step_avg:57.51ms +step:1347/1555 train_time:77499ms step_avg:57.53ms +step:1348/1555 train_time:77588ms step_avg:57.56ms +step:1349/1555 train_time:77671ms step_avg:57.58ms +step:1350/1555 train_time:77762ms step_avg:57.60ms +step:1351/1555 train_time:77846ms step_avg:57.62ms +step:1352/1555 train_time:77936ms step_avg:57.64ms +step:1353/1555 train_time:78020ms step_avg:57.66ms +step:1354/1555 train_time:78109ms step_avg:57.69ms +step:1355/1555 train_time:78193ms step_avg:57.71ms +step:1356/1555 train_time:78284ms step_avg:57.73ms +step:1357/1555 train_time:78368ms step_avg:57.75ms +step:1358/1555 train_time:78459ms step_avg:57.78ms +step:1359/1555 train_time:78543ms step_avg:57.79ms +step:1360/1555 train_time:78633ms step_avg:57.82ms +step:1361/1555 train_time:78717ms step_avg:57.84ms +step:1362/1555 train_time:78806ms step_avg:57.86ms +step:1363/1555 train_time:78890ms step_avg:57.88ms 
+step:1364/1555 train_time:78981ms step_avg:57.90ms +step:1365/1555 train_time:79064ms step_avg:57.92ms +step:1366/1555 train_time:79154ms step_avg:57.95ms +step:1367/1555 train_time:79239ms step_avg:57.97ms +step:1368/1555 train_time:79328ms step_avg:57.99ms +step:1369/1555 train_time:79411ms step_avg:58.01ms +step:1370/1555 train_time:79501ms step_avg:58.03ms +step:1371/1555 train_time:79585ms step_avg:58.05ms +step:1372/1555 train_time:79676ms step_avg:58.07ms +step:1373/1555 train_time:79760ms step_avg:58.09ms +step:1374/1555 train_time:79849ms step_avg:58.11ms +step:1375/1555 train_time:79934ms step_avg:58.13ms +step:1376/1555 train_time:80025ms step_avg:58.16ms +step:1377/1555 train_time:80107ms step_avg:58.18ms +step:1378/1555 train_time:80198ms step_avg:58.20ms +step:1379/1555 train_time:80282ms step_avg:58.22ms +step:1380/1555 train_time:80372ms step_avg:58.24ms +step:1381/1555 train_time:80456ms step_avg:58.26ms +step:1382/1555 train_time:80546ms step_avg:58.28ms +step:1383/1555 train_time:80630ms step_avg:58.30ms +step:1384/1555 train_time:80721ms step_avg:58.32ms +step:1385/1555 train_time:80805ms step_avg:58.34ms +step:1386/1555 train_time:80893ms step_avg:58.36ms +step:1387/1555 train_time:80978ms step_avg:58.38ms +step:1388/1555 train_time:81068ms step_avg:58.41ms +step:1389/1555 train_time:81152ms step_avg:58.42ms +step:1390/1555 train_time:81244ms step_avg:58.45ms +step:1391/1555 train_time:81327ms step_avg:58.47ms +step:1392/1555 train_time:81417ms step_avg:58.49ms +step:1393/1555 train_time:81501ms step_avg:58.51ms +step:1394/1555 train_time:81591ms step_avg:58.53ms +step:1395/1555 train_time:81675ms step_avg:58.55ms +step:1396/1555 train_time:81765ms step_avg:58.57ms +step:1397/1555 train_time:81849ms step_avg:58.59ms +step:1398/1555 train_time:81939ms step_avg:58.61ms +step:1399/1555 train_time:82024ms step_avg:58.63ms +step:1400/1555 train_time:82113ms step_avg:58.65ms +step:1401/1555 train_time:82198ms step_avg:58.67ms +step:1402/1555 
train_time:82287ms step_avg:58.69ms +step:1403/1555 train_time:82370ms step_avg:58.71ms +step:1404/1555 train_time:82462ms step_avg:58.73ms +step:1405/1555 train_time:82546ms step_avg:58.75ms +step:1406/1555 train_time:82636ms step_avg:58.77ms +step:1407/1555 train_time:82720ms step_avg:58.79ms +step:1408/1555 train_time:82810ms step_avg:58.81ms +step:1409/1555 train_time:82894ms step_avg:58.83ms +step:1410/1555 train_time:82985ms step_avg:58.85ms +step:1411/1555 train_time:83068ms step_avg:58.87ms +step:1412/1555 train_time:83160ms step_avg:58.90ms +step:1413/1555 train_time:83244ms step_avg:58.91ms +step:1414/1555 train_time:83333ms step_avg:58.93ms +step:1415/1555 train_time:83417ms step_avg:58.95ms +step:1416/1555 train_time:83506ms step_avg:58.97ms +step:1417/1555 train_time:83591ms step_avg:58.99ms +step:1418/1555 train_time:83680ms step_avg:59.01ms +step:1419/1555 train_time:83765ms step_avg:59.03ms +step:1420/1555 train_time:83856ms step_avg:59.05ms +step:1421/1555 train_time:83940ms step_avg:59.07ms +step:1422/1555 train_time:84029ms step_avg:59.09ms +step:1423/1555 train_time:84112ms step_avg:59.11ms +step:1424/1555 train_time:84203ms step_avg:59.13ms +step:1425/1555 train_time:84287ms step_avg:59.15ms +step:1426/1555 train_time:84377ms step_avg:59.17ms +step:1427/1555 train_time:84461ms step_avg:59.19ms +step:1428/1555 train_time:84551ms step_avg:59.21ms +step:1429/1555 train_time:84636ms step_avg:59.23ms +step:1430/1555 train_time:84726ms step_avg:59.25ms +step:1431/1555 train_time:84809ms step_avg:59.27ms +step:1432/1555 train_time:84900ms step_avg:59.29ms +step:1433/1555 train_time:84984ms step_avg:59.31ms +step:1434/1555 train_time:85074ms step_avg:59.33ms +step:1435/1555 train_time:85159ms step_avg:59.34ms +step:1436/1555 train_time:85249ms step_avg:59.37ms +step:1437/1555 train_time:85333ms step_avg:59.38ms +step:1438/1555 train_time:85425ms step_avg:59.41ms +step:1439/1555 train_time:85507ms step_avg:59.42ms +step:1440/1555 train_time:85597ms 
step_avg:59.44ms +step:1441/1555 train_time:85682ms step_avg:59.46ms +step:1442/1555 train_time:85772ms step_avg:59.48ms +step:1443/1555 train_time:85856ms step_avg:59.50ms +step:1444/1555 train_time:85947ms step_avg:59.52ms +step:1445/1555 train_time:86031ms step_avg:59.54ms +step:1446/1555 train_time:86123ms step_avg:59.56ms +step:1447/1555 train_time:86205ms step_avg:59.58ms +step:1448/1555 train_time:86295ms step_avg:59.60ms +step:1449/1555 train_time:86379ms step_avg:59.61ms +step:1450/1555 train_time:86468ms step_avg:59.63ms +step:1451/1555 train_time:86553ms step_avg:59.65ms +step:1452/1555 train_time:86645ms step_avg:59.67ms +step:1453/1555 train_time:86728ms step_avg:59.69ms +step:1454/1555 train_time:86818ms step_avg:59.71ms +step:1455/1555 train_time:86902ms step_avg:59.73ms +step:1456/1555 train_time:86991ms step_avg:59.75ms +step:1457/1555 train_time:87075ms step_avg:59.76ms +step:1458/1555 train_time:87166ms step_avg:59.78ms +step:1459/1555 train_time:87249ms step_avg:59.80ms +step:1460/1555 train_time:87339ms step_avg:59.82ms +step:1461/1555 train_time:87423ms step_avg:59.84ms +step:1462/1555 train_time:87511ms step_avg:59.86ms +step:1463/1555 train_time:87596ms step_avg:59.87ms +step:1464/1555 train_time:87686ms step_avg:59.89ms +step:1465/1555 train_time:87770ms step_avg:59.91ms +step:1466/1555 train_time:87861ms step_avg:59.93ms +step:1467/1555 train_time:87945ms step_avg:59.95ms +step:1468/1555 train_time:88035ms step_avg:59.97ms +step:1469/1555 train_time:88119ms step_avg:59.99ms +step:1470/1555 train_time:88208ms step_avg:60.01ms +step:1471/1555 train_time:88292ms step_avg:60.02ms +step:1472/1555 train_time:88382ms step_avg:60.04ms +step:1473/1555 train_time:88467ms step_avg:60.06ms +step:1474/1555 train_time:88556ms step_avg:60.08ms +step:1475/1555 train_time:88642ms step_avg:60.10ms +step:1476/1555 train_time:88731ms step_avg:60.12ms +step:1477/1555 train_time:88815ms step_avg:60.13ms +step:1478/1555 train_time:88905ms step_avg:60.15ms 
+step:1479/1555 train_time:88989ms step_avg:60.17ms +step:1480/1555 train_time:89079ms step_avg:60.19ms +step:1481/1555 train_time:89164ms step_avg:60.21ms +step:1482/1555 train_time:89253ms step_avg:60.22ms +step:1483/1555 train_time:89338ms step_avg:60.24ms +step:1484/1555 train_time:89428ms step_avg:60.26ms +step:1485/1555 train_time:89512ms step_avg:60.28ms +step:1486/1555 train_time:89601ms step_avg:60.30ms +step:1487/1555 train_time:89685ms step_avg:60.31ms +step:1488/1555 train_time:89776ms step_avg:60.33ms +step:1489/1555 train_time:89861ms step_avg:60.35ms +step:1490/1555 train_time:89952ms step_avg:60.37ms +step:1491/1555 train_time:90036ms step_avg:60.39ms +step:1492/1555 train_time:90128ms step_avg:60.41ms +step:1493/1555 train_time:90211ms step_avg:60.42ms +step:1494/1555 train_time:90301ms step_avg:60.44ms +step:1495/1555 train_time:90385ms step_avg:60.46ms +step:1496/1555 train_time:90474ms step_avg:60.48ms +step:1497/1555 train_time:90560ms step_avg:60.49ms +step:1498/1555 train_time:90649ms step_avg:60.51ms +step:1499/1555 train_time:90733ms step_avg:60.53ms +step:1500/1555 train_time:90824ms step_avg:60.55ms +step:1500/1555 val_loss:3.2972 train_time:90937ms step_avg:60.62ms +step:1501/1555 train_time:90956ms step_avg:60.60ms +step:1502/1555 train_time:90997ms step_avg:60.58ms +step:1503/1555 train_time:91086ms step_avg:60.60ms +step:1504/1555 train_time:91178ms step_avg:60.62ms +step:1505/1555 train_time:91261ms step_avg:60.64ms +step:1506/1555 train_time:91351ms step_avg:60.66ms +step:1507/1555 train_time:91434ms step_avg:60.67ms +step:1508/1555 train_time:91523ms step_avg:60.69ms +step:1509/1555 train_time:91605ms step_avg:60.71ms +step:1510/1555 train_time:91696ms step_avg:60.73ms +step:1511/1555 train_time:91778ms step_avg:60.74ms +step:1512/1555 train_time:91867ms step_avg:60.76ms +step:1513/1555 train_time:91953ms step_avg:60.78ms +step:1514/1555 train_time:92045ms step_avg:60.80ms +step:1515/1555 train_time:92135ms step_avg:60.82ms 
+step:1516/1555 train_time:92228ms step_avg:60.84ms +step:1517/1555 train_time:92315ms step_avg:60.85ms +step:1518/1555 train_time:92403ms step_avg:60.87ms +step:1519/1555 train_time:92487ms step_avg:60.89ms +step:1520/1555 train_time:92577ms step_avg:60.91ms +step:1521/1555 train_time:92660ms step_avg:60.92ms +step:1522/1555 train_time:92750ms step_avg:60.94ms +step:1523/1555 train_time:92834ms step_avg:60.95ms +step:1524/1555 train_time:92924ms step_avg:60.97ms +step:1525/1555 train_time:93009ms step_avg:60.99ms +step:1526/1555 train_time:93102ms step_avg:61.01ms +step:1527/1555 train_time:93187ms step_avg:61.03ms +step:1528/1555 train_time:93281ms step_avg:61.05ms +step:1529/1555 train_time:93365ms step_avg:61.06ms +step:1530/1555 train_time:93455ms step_avg:61.08ms +step:1531/1555 train_time:93539ms step_avg:61.10ms +step:1532/1555 train_time:93629ms step_avg:61.12ms +step:1533/1555 train_time:93712ms step_avg:61.13ms +step:1534/1555 train_time:93802ms step_avg:61.15ms +step:1535/1555 train_time:93887ms step_avg:61.16ms +step:1536/1555 train_time:93979ms step_avg:61.18ms +step:1537/1555 train_time:94064ms step_avg:61.20ms +step:1538/1555 train_time:94154ms step_avg:61.22ms +step:1539/1555 train_time:94240ms step_avg:61.23ms +step:1540/1555 train_time:94330ms step_avg:61.25ms +step:1541/1555 train_time:94414ms step_avg:61.27ms +step:1542/1555 train_time:94505ms step_avg:61.29ms +step:1543/1555 train_time:94589ms step_avg:61.30ms +step:1544/1555 train_time:94680ms step_avg:61.32ms +step:1545/1555 train_time:94763ms step_avg:61.34ms +step:1546/1555 train_time:94854ms step_avg:61.35ms +step:1547/1555 train_time:94939ms step_avg:61.37ms +step:1548/1555 train_time:95030ms step_avg:61.39ms +step:1549/1555 train_time:95114ms step_avg:61.40ms +step:1550/1555 train_time:95204ms step_avg:61.42ms +step:1551/1555 train_time:95289ms step_avg:61.44ms +step:1552/1555 train_time:95380ms step_avg:61.46ms +step:1553/1555 train_time:95464ms step_avg:61.47ms +step:1554/1555 
train_time:95555ms step_avg:61.49ms +step:1555/1555 train_time:95639ms step_avg:61.50ms +step:1555/1555 val_loss:3.2807 train_time:95753ms step_avg:61.58ms +peak memory allocated: 31630 MiB reserved: 46658 MiB diff --git a/records/track_1_short/2026-01-31-BigramHashH2D/e8b7eb4d-7eb4-46cf-a4bf-18e3e44b1b63.txt b/records/track_1_short/2026-01-31-BigramHashH2D/e8b7eb4d-7eb4-46cf-a4bf-18e3e44b1b63.txt new file mode 100644 index 000000000..e6cf442f9 --- /dev/null +++ b/records/track_1_short/2026-01-31-BigramHashH2D/e8b7eb4d-7eb4-46cf-a4bf-18e3e44b1b63.txt @@ -0,0 +1,3976 @@ +import os +import sys + +# Read the current file and the kernels file code ASAP, for logging +with open(sys.argv[0], 'r') as f: + code = f.read() +with open(os.path.join(os.path.dirname(sys.argv[0]), 'triton_kernels.py'), 'r') as f: + code += f"\n\n{'-'*40}\n# triton_kernels.py\n{'-'*40}\n\n" + code += f.read() + +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate, pairwise +from pathlib import Path +import gc + +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" +import torch +import triton + +torch.empty( + 1, device=f"cuda:{os.environ['LOCAL_RANK']}", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +from kernels import get_kernel +from torch import Tensor, nn + +from triton_kernels import XXT, ba_plus_cAA, FusedLinearReLUSquareFunction, FusedSoftcappedCrossEntropy + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Distributed training setup +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size 
must be a divisor of 8" +grad_accum_steps = 8 // world_size +grad_scale = 2 / grad_accum_steps # consistent grad magnitudes between different num_devices +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng +# Transposed layout by @ChrisJMcCormick allows for faster gradient accumulation. + +@torch.library.custom_op("nanogpt::mm_t", mutates_args=()) +def mm_t_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + """Computes y = x @ w with F8 weights stored as (in_features, out_features).""" + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + assert x.shape[1] == w.shape[0] # x: (batch, in), w: (in, out) + + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + + # _scaled_mm requires column-major B. w_f8 is row-major (in, out). + # .T.contiguous().T creates a column-major view without changing logical shape. 
+ w_f8_col_major = w_f8.T.contiguous().T + + out = torch._scaled_mm( + x_f8, + w_f8_col_major, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_t_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[0] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_t_backward", mutates_args=()) +def mm_t_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + + x_scale = grad.new_tensor(x_s, dtype=torch.float32) + w_scale = grad.new_tensor(w_s, dtype=torch.float32) + grad_scale = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + + # grad_x = grad @ w.T + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=grad_scale, + scale_b=w_scale, + use_fast_accum=False, + ) + + # grad_w = x.T @ grad + # Result is (in, out), naturally matching weight storage. No final .T needed. 
+ grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_scale, + scale_b=grad_scale, + use_fast_accum=False, + ) + + return grad_x, grad_w + + grad_x, grad_w = impl(g, x_f8, w_f8) + + return grad_x, grad_w + +@mm_t_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.to(torch.float32) + +def backward_t(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_t_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context_t(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_t_op.register_autograd(backward_t, setup_context=setup_context_t) + +# ----------------------------------------------------------------------------- +# Polar Express + +# Computed for num_iters=5, safety_factor=2e-2, cushion=2 +polar_express_coeffs = [ + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323) +] + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def polar_express(G: torch.Tensor, split_baddbmm: bool = False): + """ + Polar Express Sign Method: https://arxiv.org/pdf/2505.16932 + by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. 
+ """ + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + # Select batched vs unbatched + if split_baddbmm: + BX_matmul = torch.bmm if X.ndim > 2 else torch.mm + else: + aX_plus_BX = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the iterations + for a, b, c in polar_express_coeffs: + XXT(X, out=A) # A = X @ X.mT + ba_plus_cAA(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + + # Referencing X twice causes pytorch to make a defensive copy, + # resulting in a cudaMemcpyAsync in baddbmm. + # For large matrices (i.e., the mlp weights), it's faster to split + # the operation into two kernels to avoid this. + if split_baddbmm: + BX_matmul(B, X, out=C) # C = B @ X + C.add_(X, alpha=a) # C = C + a*X (in-place, X only read) + else: + aX_plus_BX(X, B, X, beta=a, out=C) # C = a * X + B @ X + + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +# ----------------------------------------------------------------------------- +# Combined NorMuon + Adam Optimizer + +@dataclass +class ParamConfig: + """Per-parameter configuration for NorMuonAndAdam optimizer.""" + label: str + optim: str # "adam" or "normuon" + comms: str # "none", "replicated", or "sharded" + adam_betas: tuple[float, float] | None + lr_mul: float + wd_mul: float + lr: float + initial_lr: float + weight_decay: float + # Adam-specific + eps: float | None = None + # NorMuon-specific + reshape: tuple | None = None + chunk_size: int | None = None + momentum: float | None = None + beta2: float | None = None + per_matrix_lr_mul: list[float] | None = None + + +class NorMuonAndAdam: + """ + Combined optimizer that handles both NorMuon (for projection 
matrices) and + Adam (for embeddings/scalars/gate weights). + + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, Muon uses a Newton-Schulz iteration (replaced + here with Polar Express), which has the advantage that it can be stably run in bfloat16 on the GPU. + + Muon is applied only to the projection matrices in the attention and MLP layers, and is not recommended + for embeddings, scalars, or individual weight vectors (e.g., bias terms or gate weights). + + Differences from standard Muon: + - Newton-Shulz is replaced with Polar Express for the orthogonalization step + - NorMuon adds a low-rank variance estimator similar to Adafactor. https://arxiv.org/pdf/2510.05491 + - Cautious weight decay, a gated version of decoupled weight decay + - Mantissa tracking for precision + + Adam (for embeddings/scalars/gates): + - Standard Adam with bias correction + - Cautious weight decay + + Configuration: + Unlike torch.optim.Optimizer, this class uses per-parameter configs from a `param_table` dict + and does not include parameter "groups". All parameters require a .label attribute, and a + corresponding entry in the param_table to specify their hyperparameters (lr_mul, wd_mul, adam_betas, etc.). + + Communication and ordering: + Gradient communication is explicitly scheduled rather than hook-driven. + Reductions are launched in `scatter_order`, while update math and final + gathers are executed in `work_order`. These orders are independent and + must each contain every parameter label exactly once. + + Two communication modes are supported per parameter: + - 'replicated': Gradients are all-reduced and each rank computes the full update. 
+ - 'sharded': Gradients are reduce-scattered, each rank updates its shard, + and results are all-gathered. + + Adam parameters may be freely sharded. NorMuon operates on full matrices; sharding is + supported by grouping matrices into parameter banks. NorMuon parameters must have a + `.reshape` attribute that reshapes the bank so that the leading dimension is divisible + by world_size. + + # Contributors include @YouJiacheng, @KonstantinWilleke, @alexrgilbert, @adricarda, + # @tuttyfrutyee, @vdlad, @ryanyang0, @vagrawal, @varunneal, @chrisjmccormick + """ + def __init__(self, named_params, param_table: dict, scatter_order: list, work_order: list, + adam_defaults: dict, normuon_defaults: dict): + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + + # Store defaults for each optimizer type + self.adam_defaults = adam_defaults + self.normuon_defaults = normuon_defaults + self.param_table = param_table + self.scatter_order = scatter_order + self.work_order = work_order + + # Collect params by label and build config + self.param_cfgs: dict[nn.Parameter, ParamConfig] = {} + self.param_states: dict[nn.Parameter, dict] = {} + self._param_by_label: dict[str, nn.Parameter] = {} + for name, param in named_params: + label = getattr(param, "label", None) + assert label is not None and label in param_table # all params must have valid label + assert label not in self._param_by_label # exactly one param per label + self._param_by_label[label] = param + self._build_param_cfg(param, label) + + # Assert scatter_order and work_order match present labels exactly + present = set(self._param_by_label.keys()) + assert set(scatter_order) == present and set(work_order) == present + + # Handle world_size=1: overwrite comms to "none" + if self.world_size == 1: + for p_cfg in self.param_cfgs.values(): + p_cfg.comms = "none" + + # Initialize state for all params + self._init_state() + + # 0-D CPU tensors to avoid recompilation + self._step_size_t = torch.tensor(0.0, 
dtype=torch.float32, device="cpu") + self._eff_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._eff_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + + # Track async operations + self._reduce_futures: dict[nn.Parameter, tuple] = {} + + # Embed/lm_head tying state + self.split_embed = False + self._lm_head_param = self._param_by_label.get("lm_head") + self._embed_param = self._param_by_label.get("embed") + + def _build_param_cfg(self, param: nn.Parameter, label: str): + """Build config for a single parameter from param_table.""" + table_entry = self.param_table[label] + optim = table_entry["optim"] + comms = table_entry["comms"] + adam_betas = table_entry.get("adam_betas") + lr_mul = table_entry.get("lr_mul", 1.0) + wd_mul = table_entry.get("wd_mul", 1.0) + + if optim == "adam": + chunk_size = param.shape[0] // self.world_size if comms == "sharded" else None + p_cfg = ParamConfig( + label=label, + optim=optim, + comms=comms, + adam_betas=tuple(adam_betas) if adam_betas else None, + lr_mul=lr_mul, + wd_mul=wd_mul, + lr=self.adam_defaults["lr"], + initial_lr=self.adam_defaults["lr"], + weight_decay=self.adam_defaults["weight_decay"], + eps=self.adam_defaults["eps"], + chunk_size=chunk_size, + ) + elif optim == "normuon": + reshape = getattr(param, "reshape", None) + if reshape is None: + raise ValueError(f"NorMuon param {label} must have .reshape attribute") + if reshape[0] % self.world_size != 0: + raise ValueError(f"reshape[0]={reshape[0]} must be divisible by world_size") + + chunk_size = reshape[0] // self.world_size + chunk_shape = (chunk_size, *reshape[1:]) + # Shape-based LR multiplier for NorMuon + shape_mult = max(1.0, chunk_shape[-2] / chunk_shape[-1]) ** 0.5 if len(chunk_shape) >= 2 else 1.0 + lr_mul = shape_mult * lr_mul + + # Per-matrix LR multipliers for MLP c_proj (2x LR on odd indices) + per_matrix_lr_mul = None + if label == "mlp": + rank = dist.get_rank() if dist.is_initialized() else 0 + start_idx = rank * chunk_size + 
per_matrix_lr_mul = [] + for i in range(chunk_size): + global_idx = start_idx + i + is_c_proj = (global_idx % 2 == 1) + per_matrix_lr_mul.append(2.0 if is_c_proj else 1.0) + + p_cfg = ParamConfig( + label=label, + optim=optim, + comms=comms, + adam_betas=tuple(adam_betas) if adam_betas else None, + lr_mul=lr_mul, + wd_mul=wd_mul, + lr=self.normuon_defaults["lr"], + initial_lr=self.normuon_defaults["lr"], + weight_decay=self.normuon_defaults["weight_decay"], + reshape=reshape, + chunk_size=chunk_size, + momentum=self.normuon_defaults["momentum"], + beta2=self.normuon_defaults["beta2"], + per_matrix_lr_mul=per_matrix_lr_mul, + ) + else: + raise ValueError(f"Unknown optim type: {optim}") + + self.param_cfgs[param] = p_cfg + + def _init_state(self): + """Initialize optimizer state for all parameters.""" + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "adam": + # Sharded params use chunk state, replicated use full state + if p_cfg.comms == "sharded": + chunk = param[:p_cfg.chunk_size] + else: + chunk = param + exp_avg = torch.zeros_like(chunk, dtype=torch.float32, device=param.device) + self.param_states[param] = dict(step=0, exp_avg=exp_avg, exp_avg_sq=torch.zeros_like(exp_avg)) + + elif p_cfg.optim == "normuon": + chunk_shape = (p_cfg.chunk_size, *p_cfg.reshape[1:]) + + # Momentum buffer (FP32 for precision) + momentum_buffer = torch.zeros( + chunk_shape, dtype=torch.float32, device=param.device + ) + + # Second momentum buffer - reduced along one dimension + if chunk_shape[-2] >= chunk_shape[-1]: + second_mom_shape = (*chunk_shape[:-1], 1) + else: + second_mom_shape = (*chunk_shape[:-2], 1, chunk_shape[-1]) + second_momentum_buffer = torch.zeros( + second_mom_shape, dtype=torch.float32, device=param.device + ) + + # Mantissa buffer for precision tracking + mantissa = torch.zeros( + chunk_shape, dtype=torch.uint16, device=param.device + ) + + self.param_states[param] = dict( + momentum_buffer=momentum_buffer, + 
second_momentum_buffer=second_momentum_buffer, + mantissa=mantissa, + ) + + # ----------------------------------- + # Reduce/Gather operations + + def _launch_reduce(self, param: nn.Parameter, grad: Tensor): + """Launch async reduce for a parameter based on its comms policy.""" + p_cfg = self.param_cfgs[param] + + if p_cfg.comms == "none": + if p_cfg.optim == "normuon": + # NorMuon needs reshaped gradient even without communication + grad = grad.view(p_cfg.reshape) + self._reduce_futures[param] = (None, grad) + elif p_cfg.comms == "replicated": + future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future() + self._reduce_futures[param] = (future, grad) + elif p_cfg.comms == "sharded": + if p_cfg.optim == "normuon": + # NorMuon: reshape before reduce_scatter + grad_reshaped = grad.view(p_cfg.reshape) + grad_chunk = torch.empty( + (p_cfg.chunk_size, *grad_reshaped.shape[1:]), + dtype=grad.dtype, + device=grad.device + ) + future = dist.reduce_scatter_tensor( + grad_chunk, grad_reshaped.contiguous(), op=dist.ReduceOp.AVG, async_op=True + ).get_future() + self._reduce_futures[param] = (future, grad_chunk) + else: + # Adam: simple reduce_scatter + grad_chunk = torch.empty_like(grad[:p_cfg.chunk_size]) + future = dist.reduce_scatter_tensor( + grad_chunk, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + self._reduce_futures[param] = (future, grad_chunk) + + def _launch_gather(self, param: nn.Parameter, p_slice: Tensor) -> "torch.futures.Future": + """Launch async all_gather for a sharded parameter.""" + p_cfg = self.param_cfgs[param] + if p_cfg.optim == "normuon": + full_param = param.data.view(p_cfg.reshape) + assert full_param.is_contiguous() + return dist.all_gather_into_tensor( + full_param, p_slice.contiguous(), async_op=True + ).get_future() + else: + return dist.all_gather_into_tensor( + param, p_slice.contiguous(), async_op=True + ).get_future() + + # ----------------------------------- + # State management + + def reset(self): + 
"""Reset NorMuon momentum buffers and split_embed state (called on training reset).""" + self.split_embed = False + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "normuon": + p_state = self.param_states[param] + p_state["momentum_buffer"].zero_() + p_state["mantissa"].zero_() + p_state["second_momentum_buffer"].zero_() + + def copy_lm_state_to_embed(self): + """ + Copy the optimizer state from the lm_head to the embed at the untie point. + This requires an all-gather + reshard because of different sharding: + - lm_head (768, 50304) is sharded to (96, 50304) per rank (along model_dim) + - embed (50304, 768) is sharded to (6288, 768) per rank (along vocab_size) + + We all-gather the lm_head momentum, transpose it, then each rank takes their + embed shard to get the correct momentum state. + """ + lm_head = self._lm_head_param + embed = self._embed_param + lm_state = self.param_states[lm_head] + embed_state = self.param_states[embed] + lm_cfg = self.param_cfgs[lm_head] + embed_cfg = self.param_cfgs[embed] + + embed_state['step'] = lm_state['step'] # Preserve step count for bias correction + + # Copy optimizer state with all-gather + transpose + reshard + if self.world_size > 1: + rank = dist.get_rank() + lm_chunk_size = lm_cfg.chunk_size # 96 + embed_chunk_size = embed_cfg.chunk_size # 6288 + + # All-gather lm_head momentum to get full (768, 50304) tensor + for key in ["exp_avg", "exp_avg_sq"]: + lm_chunk = lm_state[key] # (96, 50304) + full_lm = torch.empty(lm_head.shape[0], lm_head.shape[1], dtype=lm_chunk.dtype, device=lm_chunk.device) + dist.all_gather_into_tensor(full_lm, lm_chunk.contiguous()) + embed_state[key].copy_(full_lm.T[rank * embed_chunk_size:(rank + 1) * embed_chunk_size]) + else: + # Single GPU: simple transpose + for key in ["exp_avg", "exp_avg_sq"]: + embed_state[key].copy_(lm_state[key].T) + + # Mark as split + self.split_embed = True + + def state_dict(self): + """Return the optimizer state as a dict.""" + return { + 
"param_states": {id(p): s for p, s in self.param_states.items()}, + "param_cfgs": {id(p): s for p, s in self.param_cfgs.items()}, + } + + def load_state_dict(self, state_dict): + """Load optimizer state from a dict.""" + # Build id->param mapping + id_to_param = {id(p): p for p in self.param_cfgs.keys()} + + # Load state, preserving dtypes + for param_id, saved_p_state in state_dict["param_states"].items(): + if param_id in id_to_param: + param = id_to_param[param_id] + p_state = self.param_states[param] + for k, v in saved_p_state.items(): + if isinstance(v, torch.Tensor) and k in p_state: + target_dtype = p_state[k].dtype + p_state[k] = v.to(dtype=target_dtype, device=p_state[k].device) + else: + p_state[k] = v + + # ----------------------------------- + # Unified optimizer step with explicit ordering + + @torch.no_grad() + def step(self, do_adam: bool = True): + """ + Combined optimizer step with explicit ordering. + + Args: + do_adam: If True, update Adam params. NorMuon params always updated. + + Flow: + 1. Scatter phase: Launch reduces in scatter_order + 2. Work phase: Process updates in work_order + - Wait for reduce, compute update, launch gather + 3. Finalize phase: Wait for gathers + + While the embeddings are tied: + - Comms and update math are only done on lm_head. + - We add embed.grad.T into lm_head.grad before comms. 
+ - After lm_head gather, we copy lm_head.data.T --> embed.data + """ + rank = dist.get_rank() if dist.is_initialized() else 0 + lm_param, embed_param = self._lm_head_param, self._embed_param + + # ===== Phase 1: Launch reduces in scatter_order ===== + for label in self.scatter_order: + param = self._param_by_label[label] + p_cfg = self.param_cfgs[param] + + if p_cfg.optim == "adam" and not do_adam: + continue + if param.grad is None: + continue + + # lm_head when tied: aggregate embed.grad.T (transposed shapes) + if label == "lm_head" and do_adam and not self.split_embed: + if embed_param is not None and embed_param.grad is not None: + param.grad.add_(embed_param.grad.T) + + # Skip embed when tied (copied from lm_head after gather) + if label == "embed" and not self.split_embed: + continue + + self._launch_reduce(param, param.grad) + + # ===== Phase 2: Process updates in work_order ===== + gather_futures = [] + lm_head_gather_future = None + + for label in self.work_order: + param = self._param_by_label[label] + if param not in self._reduce_futures: + continue + + p_cfg = self.param_cfgs[param] + if p_cfg.optim == "adam" and not do_adam: + continue + # Wait for reduce + future, grad_chunk = self._reduce_futures[param] + if future is not None: + future.wait() + # Apply update based on optim type + if p_cfg.optim == "adam": + p_slice = self._adam_update(param, grad_chunk, p_cfg, rank) + else: + p_slice = self._normuon_update(param, grad_chunk, p_cfg, rank) + # Launch gather for sharded params + if p_cfg.comms == "sharded" and self.world_size > 1: + gather_fut = self._launch_gather(param, p_slice) + if label == "lm_head": + lm_head_gather_future = gather_fut + else: + gather_futures.append(gather_fut) + + # ===== Phase 3: Wait for gathers, sync embed if tied ===== + # Wait for lm_head gather first so we can copy to embed while other gathers complete + if lm_head_gather_future is not None: + lm_head_gather_future.wait() + + # When tied: copy lm_head.T to embed + if 
do_adam and not self.split_embed and embed_param is not None and lm_param is not None: + embed_param.data.copy_(lm_param.data.T) + + # Wait for remaining gathers + for fut in gather_futures: + fut.wait() + + self._reduce_futures.clear() + + # Clear grads for updated params + for param, p_cfg in self.param_cfgs.items(): + if p_cfg.optim == "adam" and not do_adam: + continue # Don't clear Adam grads on even steps + param.grad = None + + # ----------------------------------- + # Adam update + + def _adam_update(self, param: nn.Parameter, grad_chunk: Tensor, p_cfg: ParamConfig, rank: int) -> Tensor: + """Apply Adam update to a parameter. Returns the updated p_slice.""" + beta1, beta2 = p_cfg.adam_betas + lr = p_cfg.lr * p_cfg.lr_mul + + # Get parameter slice + if p_cfg.comms == "sharded": + p_slice = param[rank * p_cfg.chunk_size:(rank + 1) * p_cfg.chunk_size] + else: + p_slice = param + + p_state = self.param_states[param] + p_state["step"] += 1 + t = p_state["step"] + + bias1, bias2 = 1 - beta1 ** t, 1 - beta2 ** t + self._step_size_t.fill_(lr * (bias2 ** 0.5 / bias1)) + self._eff_wd_t.fill_(lr * lr * p_cfg.weight_decay * p_cfg.wd_mul) + + NorMuonAndAdam._adam_update_step( + p_slice, grad_chunk, p_state["exp_avg"], p_state["exp_avg_sq"], + beta1, beta2, p_cfg.eps, self._step_size_t, self._eff_wd_t + ) + + return p_slice + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _adam_update_step(p_slice, g_slice, exp_avg, exp_avg_sq, beta1, beta2, eps, step_size_t, eff_wd_t): + """Compiled Adam update step.""" + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + update = exp_avg.div(exp_avg_sq.sqrt().add_(eps)).mul_(step_size_t) + # Cautious weight decay + mask = (update * p_slice) > 0 + update.addcmul_(p_slice, mask, value=eff_wd_t) + p_slice.add_(other=update, alpha=-1.0) + + # ----------------------------------- + # NorMuon update + + def _normuon_update(self, param: nn.Parameter, 
grad_chunk: Tensor, p_cfg: ParamConfig, rank: int) -> Tensor: + """Apply NorMuon update to a parameter. Returns the updated p_slice.""" + chunk_shape = grad_chunk.shape + + p_state = self.param_states[param] + grad_chunk = grad_chunk.float() # FP32 for momentum + + # Momentum update + momentum_buffer = p_state["momentum_buffer"] + momentum_buffer.lerp_(grad_chunk, 1 - p_cfg.momentum) + updated_grads = grad_chunk.lerp_(momentum_buffer, p_cfg.momentum) + + self._eff_lr_t.fill_(p_cfg.lr_mul * p_cfg.lr) + self._eff_wd_t.fill_(p_cfg.wd_mul * p_cfg.weight_decay * p_cfg.lr) + + # Polar Express orthogonalization + is_large_matrix = chunk_shape[-2] > 1024 + v_chunk = polar_express(updated_grads, split_baddbmm=is_large_matrix) + + # Variance reduction + red_dim = -1 if chunk_shape[-2] >= chunk_shape[-1] else -2 + v_chunk = NorMuonAndAdam._apply_normuon_variance_reduction( + v_chunk, p_state["second_momentum_buffer"], p_cfg.beta2, red_dim + ) + + # Update parameter, in place, with cautious weight decay + param_view = param.data.view(p_cfg.reshape) + p_slice = param_view[rank * p_cfg.chunk_size:(rank + 1) * p_cfg.chunk_size] + + # MLP has per-matrix LR multipliers (c_proj gets 2x LR) + if p_cfg.per_matrix_lr_mul is not None: + for mat_idx in range(p_cfg.chunk_size): + self._eff_lr_t.fill_(p_cfg.lr_mul * p_cfg.per_matrix_lr_mul[mat_idx] * p_cfg.lr) + self._eff_wd_t.fill_(p_cfg.wd_mul * p_cfg.weight_decay * p_cfg.lr) + NorMuonAndAdam._cautious_wd_and_update_inplace( + p_slice[mat_idx].view(torch.uint16), p_state["mantissa"][mat_idx], v_chunk[mat_idx], + self._eff_wd_t, self._eff_lr_t + ) + else: + NorMuonAndAdam._cautious_wd_and_update_inplace( + p_slice.view(torch.uint16), p_state["mantissa"], v_chunk, + self._eff_wd_t, self._eff_lr_t + ) + + return p_slice + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _cautious_wd_and_update_inplace(p, mantissa, grad, wd_tensor, lr_tensor): + """ + Cautious weight decay + parameter update. 
wd_tensor and lr_tensor are 0-D CPU tensors. + Mantissa is tracked to enable higher precision updates on bfloat16 parameters. + bfloat16 format: 1 sign bit + 8 exponent bits + 7 mantissa bits = 16 bits total + float32 format: 1 sign bit + 8 exponent bits + 23 mantissa bits = 32 bits total + """ + assert p.dtype == mantissa.dtype == torch.uint16 + grad = grad.float() + wd_factor = wd_tensor.to(torch.float32) + lr_factor = lr_tensor.to(torch.float32) + p_precise_raw = (p.to(torch.uint32) << 16) | mantissa.to(torch.uint32) + p_precise = p_precise_raw.view(torch.float32) + mask = (grad * p_precise) >= 0 + p_precise.copy_(p_precise - (p_precise * mask * wd_factor * lr_factor) - (grad * lr_factor)) + p.copy_((p_precise_raw >> 16).to(torch.uint16)) + mantissa.copy_(p_precise_raw.to(torch.uint16)) + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _apply_normuon_variance_reduction(v_chunk, second_momentum_buffer, beta2, red_dim): + """NorMuon variance reduction. Algebraically fuses the normalization steps to minimize memory ops.""" + v_mean = v_chunk.float().square().mean(dim=red_dim, keepdim=True) + red_dim_size = v_chunk.size(red_dim) + v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True).mul_(red_dim_size) + v_norm = v_norm_sq.sqrt_() + second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2) + step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt_() + scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square() + v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt_() + final_scale = step_size * (v_norm / v_norm_new.clamp_min_(1e-10)) + return v_chunk.mul_(final_scale.type_as(v_chunk)) + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinearT(nn.Module): + """ + Linear layer with transposed weight storage (in_features, out_features) which + 
addresses the slow kernel that was used for gradient accumulation. @chrisjmccormick + """ + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + self.weight = nn.Parameter(torch.empty(in_features, out_features, dtype=torch.bfloat16)) + self.reset_parameters() + + def reset_parameters(self) -> None: + with torch.no_grad(): + nn.init.zeros_(self.weight) # @Grad62304977 and others + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out = torch.ops.nanogpt.mm_t(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return x @ self.weight.type_as(x) + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len, paired=False): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.paired = paired + self.reset() + + def rotary(self, x_BTHD): + assert self.factor1.size(0) >= x_BTHD.size(-3) + factor1, factor2 = ( + self.factor1[None, : x_BTHD.size(-3), None, :], + self.factor2[None, : x_BTHD.size(-3), None, :], + ) + x_flip = x_BTHD.view(*x_BTHD.shape[:-1], x_BTHD.shape[-1] // 2, 2).flip(-1).view(x_BTHD.shape) + return factor1 * x_BTHD + factor2 * x_flip + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + angular_freq = angular_freq.repeat_interleave(2) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//2)]) + t = torch.arange(2*self.max_seq_len, dtype=torch.float32, device=device) + if not self.paired: + theta 
= torch.outer(t, angular_freq) + self.factor1 = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.factor2 = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + else: + t_even = 2 * t + t_odd = 2 * t + 1 + theta1 = torch.outer(t_even, angular_freq) + theta2 = torch.outer(t_odd, angular_freq) + self.factor1 = nn.Buffer( + torch.cat((theta1.cos(), theta2.cos()), dim=-1).to(torch.bfloat16), + persistent=False + ) + self.factor2 = nn.Buffer( + torch.cat((theta1.sin(), theta2.sin()), dim=-1).to(torch.bfloat16), + persistent=False + ) + self.factor2[..., 1::2] *= -1 + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(2*self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + if not self.paired: + theta = torch.outer(t, self.angular_freq) + self.factor1.copy_(theta.cos()) + self.factor2.copy_(theta.sin()) + else: + t_even = 2 * t + t_odd = 2 * t + 1 + theta1 = torch.outer(t_even, self.angular_freq) + theta2 = torch.outer(t_odd, self.angular_freq) + self.factor1.copy_(torch.cat((theta1.cos(), theta2.cos()), dim=-1)) + self.factor2.copy_(torch.cat((theta1.sin(), theta2.sin()), dim=-1)) + self.factor2[..., 1::2] *= -1 + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + yarn: Yarn + key_offset: bool + attn_gate_w: torch.Tensor + ve_gate_w: torch.Tensor + +flash_attn_interface = 
get_kernel('varunneal/flash-attention-3').flash_attn_interface + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, paired: bool = False): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.dim = dim + self.hdim = num_heads * head_dim + self.paired = paired + assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" + # Weights are stored in parameter banks and passed via forward() + + def forward(self, x: Tensor, attn_args: AttnArgs, qkvo_w: Tensor): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + yarn = attn_args.yarn + ve, sa_lambdas, key_offset = attn_args.ve, attn_args.sa_lambdas, attn_args.key_offset + seqlens, bm_size = attn_args.seqlens, attn_args.bm_size + # sparse gated attention to enable context based no-op by @classiclarryd + # only include gates on layers with value embeds used on forward pass + attn_gate_w, ve_gate_w = attn_args.attn_gate_w, attn_args.ve_gate_w + + q, k, v = F.linear(x, sa_lambdas[0] * qkvo_w[:self.dim * 3].type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + q, k = norm(q), norm(k) # QK norm @Grad62304977 + + if not self.paired: + q, k = yarn.rotary(q), yarn.rotary(k) + + if key_offset: + # shift keys forward for the stationary head dims. Enables 1-layer induction. + k[:, 1:, :, self.head_dim // 2:] = k[:, :-1, :, self.head_dim // 2:] + + if ve is not None: + ve_gate_out = 2 * torch.sigmoid(F.linear(x[..., :12], ve_gate_w)).view(B, T, self.num_heads, 1) + v = v + ve_gate_out * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + + else: + # Paired heads: adjacent heads' queries attend to each other's keys. 
+ # Two copies of the input stream are interleaved to achieve this, which: + # - doubles the length of each sequence + # - halves the effective window size + q = q.view(B, T, self.num_heads // 2, self.head_dim * 2) + k = k.view(B, T, self.num_heads // 2, self.head_dim * 2) + v = v.reshape(B, T * 2, self.num_heads // 2, self.head_dim) + + q, k = yarn.rotary(q), yarn.rotary(k) + + q = q.view(B, T * 2, self.num_heads // 2, self.head_dim) + k = k.view(B, T * 2, self.num_heads // 2, self.head_dim) + + if ve is not None: + ve_gate_out = 2 * torch.sigmoid(F.linear(x[..., :12], ve_gate_w)).view(B, T * 2, self.num_heads // 2, 1) + v = v + ve_gate_out * ve.view_as(v) + + seqlens = 2 * seqlens + max_len = 2 * max_len + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_interface.flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, + max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=yarn.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(F.linear(x[..., :12], attn_gate_w)).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, sa_lambdas[1] * qkvo_w[self.dim * 3:].type_as(y)) # sa_lambdas[1] pre-multiplied to O @shenberg + return y + +class MLP(nn.Module): + def __init__(self): + super().__init__() + # Weights are stored in parameter banks and passed via forward() + + def forward(self, x: Tensor, c_fc: Tensor, c_proj: Tensor): + # relu(x)^2: + # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + # Fused triton kernel for relu(x @ W1.T)^2 @ W2.T + return FusedLinearReLUSquareFunction.apply(x, c_fc, c_proj) + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, has_attn: bool, has_mlp: bool, use_paired_head: bool): + 
super().__init__() + # skip attention of blocks.6 (the 7th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads, paired=use_paired_head) if has_attn else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP() if has_mlp else None + + def forward(self, x: Tensor, attn_args: AttnArgs, qkvo_w: Tensor = None, c_fc: Tensor = None, c_proj: Tensor = None): + if self.attn is not None: + x = x + self.attn(norm(x), attn_args, qkvo_w) + if self.mlp is not None: + x = x + self.mlp(norm(x), c_fc, c_proj) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +@dataclass +class ForwardScheduleConfig: + mtp_weights: torch.Tensor + ws_short: int + ws_long: int + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + self.num_layers = num_layers + self.vocab_size = next_multiple_of_n(vocab_size, n=128) + + self.smear_gate = nn.Linear(12, 1, bias=False) + nn.init.zeros_(self.smear_gate.weight) + self.smear_gate.weight.label = 'smear_gate' + + self.skip_gate = nn.Linear(12, 1, bias=False) + nn.init.zeros_(self.skip_gate.weight) + self.skip_gate.weight.label = 'skip_gate' + + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.Parameter(torch.zeros(5 * self.vocab_size, model_dim, dtype=torch.bfloat16)) + self.value_embeds.label = 'value_embed' + + # parameter banks for attention and value embedding gate weights + self.attn_gate_bank = nn.Parameter(torch.zeros(10, num_heads, 12)) # 10 layers + 
self.attn_gate_bank.label = 'attn_gate_bank' + self.ve_gate_bank = nn.Parameter(torch.zeros(5, num_heads, 12)) # 5 unique gates + self.ve_gate_bank.label = 've_gate_bank' + + # ----------------------------------- + # Parameter banks for sharded optimization, by @chrisjmccormick + + # Identify which layers have attention/MLP + # Attention is skipped in layer 6 by @YouJiacheng + self.attn_layer_indices = [i for i in range(num_layers) if i != 6] + # All layers have MLP (At 11 layers--dropped first layer @EmelyanenkoK) + self.mlp_layer_indices = list(range(num_layers)) + + hdim = num_heads * head_dim + mlp_hdim = 4 * model_dim + + # Create index mappings: layer_idx -> bank_idx + self.layer_to_attn_idx = {layer_idx: bank_idx for bank_idx, layer_idx in enumerate(self.attn_layer_indices)} + self.layer_to_mlp_idx = {layer_idx: bank_idx for bank_idx, layer_idx in enumerate(self.mlp_layer_indices)} + + # Attention bank: stores QKVO weights for all attention layers + # merged QKVO weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + # Simplified layout by @chrisjmccormick + # Shape: (num_attn_layers, 4*model_dim, hdim) = (10, 3072, 768) + # Reshape for sharding: (40, 768, 768) for even distribution across 8 GPUs + self.attn_bank = nn.Parameter(torch.empty(len(self.attn_layer_indices), 4 * model_dim, hdim)) + self.attn_bank.label = 'attn' + self.attn_bank.reshape = (len(self.attn_layer_indices) * 4, hdim, hdim) # (40, 768, 768) + + # MLP bank: stores c_fc and c_proj for all MLP layers + # Shape: (num_mlp_layers + padding, 2, mlp_hdim, model_dim) = (12, 2, 3072, 768) + # We add 1 padding layer (index 11) to get 12*2=24 matrices for even distribution across 8 GPUs + # Reshape for sharding: (24, 3072, 768) + num_mlp_with_padding = len(self.mlp_layer_indices) + 1 # 11 + 1 = 12 + self.mlp_bank = nn.Parameter(torch.empty(num_mlp_with_padding, 2, mlp_hdim, model_dim)) + 
self.mlp_bank.label = 'mlp' + self.mlp_bank.reshape = (num_mlp_with_padding * 2, mlp_hdim, model_dim) # (24, 3072, 768) + + # improved init scale by @YouJiacheng and @srashedll + std = 0.5 * model_dim ** -0.5 + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.attn_bank.uniform_(-bound, bound) + self.mlp_bank[:, 0, :, :].uniform_(-bound, bound) # c_fc + self.mlp_bank[:, 1, :, :].zero_() # c_proj - zero init suggested by @Grad62304977 + + # Create blocks with has_attn/has_mlp flags + self.paired_head_layers = [0, 2, 5, 9] + self.blocks = nn.ModuleList([ + Block(model_dim, head_dim, num_heads, + has_attn=(i in self.layer_to_attn_idx), + has_mlp=(i in self.layer_to_mlp_idx), + use_paired_head=(i in self.paired_head_layers)) + for i in range(num_layers) + ]) + self.yarn = Yarn(head_dim, max_seq_len) + self.yarn_paired_head = Yarn(head_dim, max_seq_len, paired=True) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + # Transposed weight storage for faster gradient accumulation + self.lm_head = CastedLinearT(model_dim, self.vocab_size, use_fp8=use_fp8, x_s=100/448, w_s=1.6/448, grad_s=grad_scale * 0.75/448) + + nn.init.normal_(self.lm_head.weight, mean=0, std=0.005) + self.lm_head.weight.label = 'lm_head' + + self.embed = nn.Embedding(self.vocab_size, model_dim) + self.embed.weight.label = 'embed' + with torch.no_grad(): + self.embed.weight.copy_(self.lm_head.weight.T) + + self.bigram_embed = nn.Embedding(args.bigram_vocab_size, model_dim) + self.bigram_embed.weight.label = 'bigram_embed' + nn.init.zeros_(self.bigram_embed.weight) + + # x0_lambdas separated out for different optimizer treatment (no beta smoothing) + self.x0_lambdas = nn.Parameter(torch.zeros(num_layers)) + self.x0_lambdas.label = 'x0_lambdas' + + pad = (-num_layers * 3 - 3) % dist.get_world_size() # updated: 3*num_layers instead of 4* + self.scalars = nn.Parameter( + torch.cat( + [ + 1.1 * torch.ones(num_layers), # resid lambdas. 1.1 init such that layer i weight is i^(num_layers-i). + *[torch.tensor([0.5, 1.0]) for _ in range(num_layers)], # SA lambdas + 0.1 * torch.ones(num_layers), # bigram lambdas + torch.zeros(1), # smear_lambda + 0.5*torch.ones(1), # backout_lambda + -1.5 * torch.ones(1), # skip_lambda -> σ(-1.5) ≈ 0.18 + torch.ones(pad), + ] + ) + ) + self.scalars.label = 'scalars' + + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _compute_bigram_hash(x: Tensor, mod: int) -> Tensor: + """ + Computes bigram hash on GPU for each position using [prev_token, curr_token]. + Mathematically identical to the CPU version but computed on device. 
+ """ + rand_int_1 = 36313 + rand_int_2 = 27191 + result = torch.empty_like(x) + result[0] = mod + result[1:] = torch.bitwise_xor(rand_int_1 * x[1:], rand_int_2 * x[:-1]) % mod + return result + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, schedule_cfg: ForwardScheduleConfig): + assert input_seq.ndim == 1 + + # unpack schedule_cfg + mtp_weights, ws_short, ws_long = schedule_cfg.mtp_weights, schedule_cfg.ws_short, schedule_cfg.ws_long + + # set configs + skip_connections = [] + skip_in = [3] # long attention window on layer 3 + skip_out = [6] # no attn op on layer 6 + x_backout = None + backout_layer = 7 + + # set lambdas + resid_lambdas = self.scalars[: 1 * self.num_layers] + x0_lambdas = self.x0_lambdas + sa_lambdas = self.scalars[1 * self.num_layers: 3 * self.num_layers].view(-1, 2) + bigram_lambdas = self.scalars[3 * self.num_layers: 4 * self.num_layers] + smear_lambda = self.scalars[4 * self.num_layers] + backout_lambda = self.scalars[4 * self.num_layers+1] + skip_lambda = self.scalars[4 * self.num_layers+2] + + # set block masks and key shift + bm_sizes = [ws_short, ws_short, ws_short, ws_long, ws_short, ws_short, None, ws_short, ws_short, ws_short, ws_long] + assert len(bm_sizes) == self.num_layers + key_offset = [b==ws_long for b in bm_sizes] # apply partial key offset to long windows + + # Embedding lookup - embed is synced from lm_head during tied phase by optimizer + x = self.embed(input_seq) + # Compute bigram hash on GPU (moved from CPU data loader) + bigram_seq = self._compute_bigram_hash(input_seq, args.bigram_vocab_size - 1) + x0_bigram = self.bigram_embed(bigram_seq)[None] + + # Value embeddings - always computed (not precomputed) + ve = self.value_embeds.view(5, self.vocab_size, -1)[:, input_seq] + # 01 ... 
234 structure on token value embeddings by @photomz + ve = [ve[0], ve[1]] + [None] * (self.num_layers - 5) + [ve[2], ve[3], ve[4]] + assert len(ve) == self.num_layers + + # smear token embed forward 1 position @classiclarryd + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # unbind gate banks to avoid select_backwards kernel + ag = [w.bfloat16() for w in self.attn_gate_bank.unbind(0)] + veg = [w.bfloat16() for w in self.ve_gate_bank.unbind(0)] + attn_gates = ag[:6] + [None] + ag[6:] + ve_gates = [veg[0], veg[1]] + [None] * (self.num_layers - 5) + [veg[2], veg[3], veg[4]] + assert len(attn_gates) == self.num_layers + assert len(ve_gates) == self.num_layers + + # unbind weight banks to avoid select_backwards kernel + attn_weights = self.attn_bank.unbind(0) # tuple of [4*dim, hdim] tensors + mlp_fcs = self.mlp_bank[:, 0, :, :].unbind(0) # tuple of [mlp_hdim, dim] tensors + mlp_projs = self.mlp_bank[:, 1, :, :].unbind(0) # tuple of [mlp_hdim, dim] tensors + + for i in range(self.num_layers): + yarn = self.yarn_paired_head if i in self.paired_head_layers else self.yarn + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + yarn=yarn, + key_offset=key_offset[i], + attn_gate_w=attn_gates[i], + ve_gate_w=ve_gates[i] + ) + if i in skip_out: + skip_gate_out = torch.sigmoid(skip_lambda) * 2 * torch.sigmoid(self.skip_gate(x0[..., :self.skip_gate.weight.size(-1)])) + x = x + skip_gate_out * skip_connections.pop() + if i == 0: + x = (resid_lambdas[0] + x0_lambdas[0]) * x + bigram_lambdas[0] * x0_bigram + else: + x = resid_lambdas[i] * x + x0_lambdas[i] * x0 + bigram_lambdas[i] * x0_bigram + + # Get weights for this layer from banks + qkvo_w = attn_weights[self.layer_to_attn_idx[i]] if i in self.layer_to_attn_idx else None + c_fc = mlp_fcs[self.layer_to_mlp_idx[i]] if i in 
self.layer_to_mlp_idx else None + c_proj = mlp_projs[self.layer_to_mlp_idx[i]] if i in self.layer_to_mlp_idx else None + + x = self.blocks[i](x, attn_args, qkvo_w, c_fc, c_proj) + if i in skip_in: + skip_connections.append(x) + if i == backout_layer: + x_backout = x + + # back out contributions from first 7 layers that are only required for downstream context and not direct prediction + x -= backout_lambda * x_backout + x = norm(x) + logits = self.lm_head(x) + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15 + # @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1). @classiclarryd updated to 23*sigmoid((logits+5)/7.5) + if self.training: + losses = FusedSoftcappedCrossEntropy.apply(logits.view(-1, logits.size(-1)), target_seq, mtp_weights, 23.0, 5.0, 7.5) + loss = losses.sum() + else: + logits = 23 * torch.sigmoid((logits + 5) / 7.5) + logits_for_loss = logits.float() + loss = F.cross_entropy(logits_for_loss.view(-1, logits_for_loss.size(-1)), target_seq, reduction="mean") + return loss +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class Shard: + def __init__(self, tokens: Tensor, world_size: int = 1): + self.tokens = tokens + self.size = tokens.numel() + self.world_size = 
world_size + self.i = 0 + + # Partial index now, full index async + self.bos_idx = (tokens[:6_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self._full_idx = None + self._loader_thread = None + self._ready = threading.Event() + self._loader_thread = threading.Thread(target=self._scan) + self._loader_thread.start() + + def _scan(self): + self._full_idx = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self._ready.set() + + def _maybe_switch(self): + # Switch to full index as soon as async scan completes + if self.bos_idx is not self._full_idx and self._ready.is_set(): + self._loader_thread.join() + self.bos_idx = self._full_idx + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + self._maybe_switch() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + return starts, ends + + @staticmethod + def load_async(file: Path, world_size: int = 1): + """Returns getter function for async shard loading""" + result = {} + ready = threading.Event() + def load(): + tokens = _load_data_shard(file) + result['shard'] = Shard(tokens, world_size) + ready.set() + thread = threading.Thread(target=load) + thread.start() + def get(): + ready.wait() + thread.join() + return result['shard'] + return get + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each 
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + shard = Shard(tokens, world_size) + next_shard_getter = Shard.load_async(next(file_iter), world_size) + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = shard.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
+ shard = next_shard_getter() + tokens = shard.tokens + try: + next_shard_getter = Shard.load_async(next(file_iter), world_size) + except StopIteration: + next_shard_getter = None # no more shards to preload + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + # Cast to int32 on CPU before transfer to avoid dtype conversion during .to() + _inputs = _inputs.to(dtype=torch.int32) + _targets = _targets.to(dtype=torch.int64) + _cum_lengths = _cum_lengths.to(dtype=torch.int32) + # Bigram hash computation moved to GPU in forward() + + new_params = yield ( + _inputs.to(device="cuda", non_blocking=True), + _targets.to(device="cuda", non_blocking=True), + _cum_lengths.to(device="cuda", non_blocking=True), + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens // new_grad_accum_steps + max_seq_len = new_max_seq_len + +# ----------------------------------------------------------------------------- +# Training Management + +@dataclass 
+class Hyperparameters: + # data + data_path = os.environ.get("DATA_PATH", ".") + train_files: str = os.path.join(data_path, "data/fineweb10B/fineweb_train_*.bin") # input .bin to train on + val_files: str = os.path.join(data_path, "data/fineweb10B/fineweb_val_*.bin") # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + # batch sizes + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # schedule + num_scheduled_iterations: int = 1515 # number of steps to complete lr and ws schedule + num_extension_iterations: int = 40 # number of steps to continue training at final lr and ws + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = 250 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # bigram hash embedding + bigram_vocab_size: int = 50304 * 5 + +args = Hyperparameters() + +@dataclass +class TrainingStage: + lr_mul: float + batch_size: int + window_sizes: tuple[int, int] # (short, long) in block units + mtp_weights_start: list[float] + mtp_weights_end: list[float] + duration: float = None + +class TrainingSchedule: + """ + Training schedule initialized via TRAINING_STAGES + 1. Multi Token Prediction schedule of [1, 0.5, 0.25->0] -> [1, 0.5->0] -> [1] @varunneal + 2. Sliding Attention window schedule of [1,3] -> [3,7] -> [5,11] -> [6,13] + 3. YaRN updates to RoPE on window changes + 4. Split embed and lm head at 2/3 of training + 5. Batch size schedule of 8 -> 16 -> 24 + 6. 
Post training extension of long windows from 13 to 20 + """ + + def __init__(self, stages: list[TrainingStage], scheduled_iterations: int, extension_iterations: int, + cooldown_frac: float = 0.5, split_embed_stage: int = 2, ws_post_yarn_ext: int = 20): + self.stages = stages + self.scheduled_iterations = scheduled_iterations + self.cooldown_frac = cooldown_frac + # increase final validation ws, used for YaRN extension and short window size @classiclarryd + self.ws_post_yarn_ext = ws_post_yarn_ext + + self.total_steps = self.scheduled_iterations + extension_iterations + + # Build stage boundaries (last is extension stage) + ends = [0] + [round(c * scheduled_iterations) for c in accumulate(s.duration for s in stages[:-1])] + [self.total_steps] + assert self.scheduled_iterations == ends[-2] + self.boundaries = list(pairwise(ends)) + + # Split embed at specified stage (ensure odd step for Adam) + self.split_step = self.boundaries[split_embed_stage][0] | 1 + + # Precompute MTP weights for all steps + self.mtp_weights = [] + for step in range(self.total_steps + 1): + stage, t = self.lookup(step) + w = [a + (b - a) * t for a, b in zip(stage.mtp_weights_start, stage.mtp_weights_end)] + self.mtp_weights.append(torch.tensor(w, device=device)) + + def lookup(self, step: int) -> tuple[TrainingStage, float]: + # Returns stage and % of the way through that stage + for i, (start, end) in enumerate(self.boundaries): + if step < end: + t = (step - start) / (end - start) + return self.stages[i], t + return self.stages[-1], 1.0 + + def get_lr(self, step: int) -> float: + # learning rate schedule: tied to batch size schedule, with cooldown at the end + stage, _ = self.lookup(step) + lr = stage.lr_mul + cd_start = int(self.scheduled_iterations * (1 - self.cooldown_frac)) + if step >= cd_start: + t = min(1.0, (step - cd_start) / (self.scheduled_iterations - cd_start)) + lr = lr * (1 - t) + 0.1 * t + return lr + +# window_sizes are in units of `block_size` tokens (defined in 
TrainingManager) +TRAINING_STAGES = [ + TrainingStage(duration=1/3, batch_size=8 * 2048 * 8, window_sizes=(1, 3), lr_mul=1.0, + mtp_weights_start=[1.0, 0.5, 0.25], mtp_weights_end=[1.0, 0.5, 0.0]), + TrainingStage(duration=1/3, batch_size=16 * 2048 * 8, window_sizes=(3, 7), lr_mul=1.52, # (16/8)**0.6 + mtp_weights_start=[1.0, 0.5], mtp_weights_end=[1.0, 0.0]), + TrainingStage(duration=1/3, batch_size=24 * 2048 * 8, window_sizes=(5, 11), lr_mul=1.73, # (24/8)**0.5 + mtp_weights_start=[1.0], mtp_weights_end=[1.0]), + # extension stage + TrainingStage(batch_size=24 * 2048 * 8, window_sizes=(6, 13), lr_mul=1.0, # lr_mul is not used + mtp_weights_start=[1.0], mtp_weights_end=[1.0]), +] + +training_schedule = TrainingSchedule(TRAINING_STAGES, args.num_scheduled_iterations, args.num_extension_iterations, cooldown_frac=0.55) + +def get_muon_momentum(step: int, muon_warmup_steps=300, muon_cooldown_steps=50, momentum_min=0.85, momentum_max=0.95): + # warmup phase: linearly increase momentum from min to max + # cooldown phase: linearly decrease momentum from max to min + momentum_cd_start = training_schedule.total_steps - muon_cooldown_steps + if step < muon_warmup_steps: + frac = step / muon_warmup_steps + momentum = momentum_min + frac * (momentum_max - momentum_min) + elif step > momentum_cd_start: + frac = (step - momentum_cd_start) / muon_cooldown_steps + momentum = momentum_max - frac * (momentum_max - momentum_min) + else: + momentum = momentum_max + return momentum + +class TrainingManager(): + """ + Manages the NorMuonAndAdam for all parameters with explicit ordering. + 1. Scalars are given higher momentum terms to smooth learning @ChrisJMcCormick + 2. Adam optimizers are only stepped on odd steps @classiclarryd + 3. Explicit scatter_order and work_order for communication scheduling (no backward hooks) + 4. Muon has a linear momentum warmup and cooldown schedule + 5. Learning rates follow a linear decay schedule + 6. 
Embed is tied to lm_head until split step (2/3 of training), then untied @classiclarryd + """ + def __init__(self, model): + self.model = model + self.block_size = 128 + + # - Ordering dictates when to launch reduce/reduce_scatter operations + # - "sharded" parameters use reduce_scatter/all_gather and "replicated" ones use all_reduce + # - lr_mul and wd_mul are per-parameter learning rate and weight decay multipliers + self.param_table = { + "attn": {"optim": "normuon", "comms": "sharded", "adam_betas": None}, + "mlp": {"optim": "normuon", "comms": "sharded", "adam_betas": None}, + "scalars": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 5.0, "wd_mul": 0.0}, + "value_embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.75, 0.95], "lr_mul": 75., "wd_mul": 5.0}, + "bigram_embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.75, 0.95], "lr_mul": 75., "wd_mul": 5.0}, + "smear_gate": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 0.01, "wd_mul": 0.0}, + "skip_gate": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99], "lr_mul": 0.05, "wd_mul": 0.0}, + "attn_gate_bank": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99]}, + "ve_gate_bank": {"optim": "adam", "comms": "replicated", "adam_betas": [0.9, 0.99]}, + "x0_lambdas": {"optim": "adam", "comms": "replicated", "adam_betas": [0.65, 0.95], "lr_mul": 5.0, "wd_mul": 0.0}, + "lm_head": {"optim": "adam", "comms": "sharded", "adam_betas": [0.5, 0.95], "wd_mul": 150.}, + "embed": {"optim": "adam", "comms": "sharded", "adam_betas": [0.5, 0.95], "wd_mul": 150.}, + } + + # - Process smaller/faster params first while large reduces complete + # - lm_head must complete before embed sync (when tied) + self.work_order = [ + "scalars", "smear_gate", "skip_gate", "attn_gate_bank", "ve_gate_bank", "x0_lambdas", # Small, fast + "value_embed", "bigram_embed", # Medium + "lm_head", "embed", # lm_head must complete before embed 
sync (when tied) + "attn", "mlp", # Large, polar express - process last to maximize overlap + ] + + adam_defaults = dict( + lr=0.008, + eps=1e-10, + weight_decay=0.005, + ) + + normuon_defaults = dict( + lr=0.023, + momentum=0.95, + beta2=0.95, + weight_decay=1.2, + ) + + self.optimizer = NorMuonAndAdam( + model.named_parameters(), + param_table=self.param_table, + scatter_order=list(self.param_table.keys()), # Dict order defines scatter priority + work_order=self.work_order, + adam_defaults=adam_defaults, + normuon_defaults=normuon_defaults, + ) + + # Split embed from lm_head at 2/3 of training (on an odd step so Adam updates) + self.split_step = training_schedule.split_step + + self.reset() + + def apply_final_ws_ext(self): + self.ws_long = training_schedule.ws_post_yarn_ext + + def get_forward_args(self): + return ForwardScheduleConfig( + mtp_weights = self.mtp_weights, + ws_short = self.ws_short * self.block_size, + ws_long = self.ws_long * self.block_size + ) + + def _is_adam_step(self, step: int): + """Adam params are only updated on odd steps.""" + return step % 2 == 1 + + def get_transition_steps(self): + return [start for start, _ in training_schedule.boundaries[1:]] + + def advance_schedule(self, step: int): + stage, _ = training_schedule.lookup(step) + self.ws_short, new_ws_long = stage.window_sizes + if new_ws_long != self.ws_long: + self.model.yarn.apply(self.ws_long * self.block_size, new_ws_long * self.block_size) + self.model.yarn_paired_head.apply(self.ws_long * self.block_size, new_ws_long * self.block_size) + + new_batch_size = stage.batch_size + if new_batch_size != self.batch_size: + self.train_loader_send_args = (new_batch_size, args.train_max_seq_len, grad_accum_steps) + self.batch_size = new_batch_size + else: + self.train_loader_send_args = None + + self.ws_long = new_ws_long + self.mtp_weights = training_schedule.mtp_weights[step] + + def step_optimizers(self, step: int): + step_lr = training_schedule.get_lr(step) + muon_momentum = 
get_muon_momentum(step) + do_adam = self._is_adam_step(step) + + # Update learning rates and momentum for all params + for param, p_cfg in self.optimizer.param_cfgs.items(): + p_cfg.lr = p_cfg.initial_lr * step_lr + if p_cfg.optim == "normuon": + p_cfg.momentum = muon_momentum + + # Step optimizer with do_adam flag + self.optimizer.step(do_adam=do_adam) + + # At split step: copy lm_head optimizer state to embed and mark as split + if step == self.split_step: + self.optimizer.copy_lm_state_to_embed() + + def reset(self, state=None): + if state is not None: + self.optimizer.load_state_dict(state) + + # Reset NorMuon momentum buffers and split_embed state + self.optimizer.reset() + + stage, _ = training_schedule.lookup(0) + self.ws_short, self.ws_long = stage.window_sizes + self.batch_size = stage.batch_size + self.model.yarn.reset() + self.model.yarn_paired_head.reset() + + def get_state(self): + return copy.deepcopy(self.optimizer.state_dict()) + +# ----------------------------------------------------------------------------- +# int main + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=11, + num_heads=6, + 
head_dim=128, + model_dim=768, + max_seq_len=args.val_batch_size // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.weight.data = m.weight.data.bfloat16() +model.attn_gate_bank.data = model.attn_gate_bank.data.bfloat16() +model.ve_gate_bank.data = model.ve_gate_bank.data.bfloat16() +model.attn_bank.data = model.attn_bank.data.bfloat16() +model.mlp_bank.data = model.mlp_bank.data.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) +training_manager = TrainingManager(model) + +######################################## +# Warmup kernels # +######################################## +print0("Compiling model and warming up kernels (~7 minutes on first execution)", console=True) +# Warmup the training kernels, then re-initialize the state so we aren't cheating +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizer=training_manager.get_state()) # save the initial state +train_loader = distributed_data_generator(args.train_files, TRAINING_STAGES[0].batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + +transition_steps = training_manager.get_transition_steps() +# first few steps plus transitions +warmup_steps = sorted({0, 1, 2} | set(s + offset for s in transition_steps for offset in [-1, 0, 1] if s + offset >= 0)) +print0(f"Sampling steps {warmup_steps} for warmup", console=True) +for step in warmup_steps: + training_manager.advance_schedule(step) + model.eval() + with torch.no_grad(): + inputs, targets, cum_seqlens = next(val_loader) + model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) + model.train() + for idx in range(grad_accum_steps): + send_args = training_manager.train_loader_send_args + inputs, 
targets, cum_seqlens = train_loader.send(send_args) + (model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) * grad_scale).backward() + training_manager.step_optimizers(step) +print0("Resetting Model", console=True) +model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +training_manager.reset(initial_state["optimizer"]) +del val_loader, train_loader, initial_state +model.train() + +######################################## +# Training and validation # +######################################## +train_loader = distributed_data_generator(args.train_files, TRAINING_STAGES[0].batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) + +gc.collect() + +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = training_schedule.total_steps +for step in range(train_steps + 1): + last_step = (step == train_steps) + training_manager.advance_schedule(step) + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + if last_step: + training_manager.apply_final_ws_ext() + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) + val_loss /= val_steps + del val_loader + dist.reduce(val_loss, 0, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", 
console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizer=training_manager.get_state()) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for idx in range(grad_accum_steps): + inputs, targets, cum_seqlens = train_loader.send(training_manager.train_loader_send_args) + (model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) * grad_scale).backward() + training_manager.step_optimizers(step) + + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + + +---------------------------------------- +# triton_kernels.py +---------------------------------------- + +import torch +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * 
num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.jit +def XXT_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = 
m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def XXT(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + # Hardcoded configs based on H100 autotuning + if K == 768: + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 64 + num_stages, num_warps = 4, 4 + else: + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 64, 128, 128 + num_stages, num_warps = 4, 4 + + grid = (batch_size * triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(M, BLOCK_SIZE_N),) + XXT_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + LOWER_UPPER=1, + num_stages=num_stages, + num_warps=num_warps, + ) + return out + +@triton.jit +def ba_plus_cAA_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: 
tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from XXT_kernel, but also loads and adds a block of A + # Performance is slightly slower than XXT_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = 
C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ba_plus_cAA(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + # Hardcoded config based on H100 autotuning (M=768) + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 64 + num_stages, num_warps = 4, 4 + + grid = (batch_size * triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(M, BLOCK_SIZE_N),) + ba_plus_cAA_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + LOWER_UPPER=1, + num_stages=num_stages, + num_warps=num_warps, + ) + return out + +# ----------------------------------------------------------------------------- +# Triton kernel for MLP: relu(x @ W1.T)^2, by @andrewbriand, @jrauvola + +@triton.jit +def linear_relu_square_kernel(a_desc, b_desc, c_desc, 
aux_desc, + M, N, K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, + ): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tile_id_c = start_pid - NUM_SMS + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + + tile_id_c += NUM_SMS + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am_c = pid_m * BLOCK_SIZE_M + offs_bn_c = pid_n * BLOCK_SIZE_N + + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + + c0 = acc0.to(dtype) + if not FORWARD: + c0_pre = aux_desc.load([offs_am_c, offs_bn_c]) + c0 = 2 * c0 * tl.where(c0_pre > 0, c0_pre, 0) + + c_desc.store([offs_am_c, offs_bn_c], c0) + + if FORWARD: + c0_post = tl.maximum(c0, 0) + c0_post = c0_post * c0_post + aux_desc.store([offs_am_c, offs_bn_c], c0_post) + + c1 = acc1.to(dtype) + if not FORWARD: + c1_pre = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + c1 = 2 * c1 * tl.where(c1_pre > 0, c1_pre, 0) + + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + + if FORWARD: + c1_post = tl.maximum(c1, 0) + c1_post = c1_post * c1_post + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1_post) + + +def linear_relu_square(a, b, aux=None): + M, K = a.shape + N, K = b.shape + dtype = 
a.dtype + + c = torch.empty((M, N), device=a.device, dtype=dtype) + + FORWARD = False + if aux is None: + FORWARD = True + aux = torch.empty((M, N), device=a.device, dtype=dtype) + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 256 + BLOCK_SIZE_K = 64 + num_stages = 4 if FORWARD else 3 + num_warps = 8 + + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + + def grid(META): + return (min( + NUM_SMS, + triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), + ), ) + + linear_relu_square_kernel[grid]( + a_desc, b_desc, c_desc, aux_desc, + M, N, K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=1, + NUM_SMS=NUM_SMS, + FORWARD=FORWARD, + num_stages=num_stages, + num_warps=num_warps + ) + + if FORWARD: + return c, aux + else: + return c + +class FusedLinearReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, W1, W2): + pre, post = linear_relu_square(x.view((-1, x.shape[-1])), W1) + x3 = post @ W2 + ctx.save_for_backward(x, W1, W2, pre, post) + return x3.view(x.shape) + + @staticmethod + def backward(ctx, grad_output): + x, W1, W2, pre, post = ctx.saved_tensors + dW2 = post.T @ grad_output + dpre = linear_relu_square(grad_output.view((-1, grad_output.shape[-1])), W2, aux=pre) + dW1 = dpre.T @ x + dx = dpre @ W1 + return dx.view(x.shape), dW1, dW2 + +# ----------------------------------------------------------------------------- +# Fused Softcapped Cross Entropy + + +@triton.jit +def fused_softcapped_entropy_fwd_kernel( + logits_ptr, losses_ptr, lse_ptr, targets_ptr, mtp_weights_ptr, + stride_logits_n, stride_logits_v, + n_rows, n_cols, n_predict, + A, B, C, + 
BLOCK_SIZE: tl.constexpr +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + + max_val = -float('inf') + sum_exp = 0.0 + + for off in range(0, n_cols, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + val = tl.load(logits_row_ptr + cols, mask=mask, other=-float('inf')).to(tl.float32) + z = A * tl.sigmoid((val + B) / C) + z = tl.where(mask, z, -float('inf')) + curr_max = tl.max(z, axis=0) + new_max = tl.maximum(max_val, curr_max) + sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(z - new_max), axis=0) + max_val = new_max + + lse = max_val + tl.log(sum_exp) + tl.store(lse_ptr + row_idx, lse) + + total_loss = 0.0 + for k in range(n_predict): + target_idx = row_idx + k + if target_idx < n_rows: + weight = tl.load(mtp_weights_ptr + k) + if weight > 0: + target = tl.load(targets_ptr + target_idx).to(tl.int32) + if target >= 0 and target < n_cols: + val_target = tl.load(logits_row_ptr + target).to(tl.float32) + z_target = A * tl.sigmoid((val_target + B) / C) + total_loss += weight * (lse - z_target) + + tl.store(losses_ptr + row_idx, total_loss) + +@triton.jit +def fused_softcapped_entropy_bwd_kernel( + grad_input_ptr, grad_output_ptr, lse_ptr, logits_ptr, targets_ptr, mtp_weights_ptr, + stride_logits_n, stride_logits_v, stride_grad_n, stride_grad_v, + n_rows, n_cols, n_predict, + A, B, C, + BLOCK_SIZE: tl.constexpr +): + row_idx = tl.program_id(0).to(tl.int64) + + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + grad_row_ptr = grad_input_ptr + row_idx * stride_grad_n + + lse = tl.load(lse_ptr + row_idx) + grad_loss = tl.load(grad_output_ptr + row_idx) + + S_w = 0.0 + for k in range(n_predict): + if row_idx + k < n_rows: + S_w += tl.load(mtp_weights_ptr + k) + + for off in range(0, n_cols, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + val = tl.load(logits_row_ptr + cols, mask=mask, other=0.0).to(tl.float32) + u = (val + B) / C + sigmoid_u 
= tl.sigmoid(u) + z = A * sigmoid_u + p = tl.exp(z - lse) + + term1 = S_w * p + term2 = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for k in range(n_predict): + if row_idx + k < n_rows: + target = tl.load(targets_ptr + row_idx + k).to(tl.int32) + weight = tl.load(mtp_weights_ptr + k) + term2 += tl.where(cols == target, weight, 0.0) + + grad_z = grad_loss * (term1 - term2) + dz_dx = (1.0 / C) * z * (1.0 - sigmoid_u) + grad_x = grad_z * dz_dx + tl.store(grad_row_ptr + cols, grad_x.to(tl.bfloat16), mask=mask) + +class FusedSoftcappedCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, logits, targets, mtp_weights, A=23.0, B=5.0, C=7.5): + n_rows, n_cols = logits.shape + if mtp_weights is None: + mtp_weights = torch.tensor([1.0], device=logits.device, dtype=torch.float32) + n_predict = mtp_weights.shape[0] + + losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + lse = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + + logits = logits.contiguous() + targets = targets.contiguous() + mtp_weights = mtp_weights.contiguous() + + grid = (n_rows,) + fused_softcapped_entropy_fwd_kernel[grid]( + logits, losses, lse, targets, mtp_weights, + logits.stride(0), logits.stride(1), + n_rows, n_cols, n_predict, + A, B, C, + BLOCK_SIZE=1024, + num_warps=8, + num_stages=4 + ) + + ctx.save_for_backward(logits, targets, mtp_weights, lse) + ctx.params = (A, B, C) + return losses + + @staticmethod + def backward(ctx, grad_output): + logits, targets, mtp_weights, lse = ctx.saved_tensors + A, B, C = ctx.params + n_rows, n_cols = logits.shape + n_predict = mtp_weights.shape[0] + + grad_input = torch.empty((n_rows, n_cols), dtype=torch.bfloat16, device=logits.device) + grad_output = grad_output.contiguous() + + grid = (n_rows,) + fused_softcapped_entropy_bwd_kernel[grid]( + grad_input, grad_output, lse, logits, targets, mtp_weights, + logits.stride(0), logits.stride(1), grad_input.stride(0), grad_input.stride(1), + n_rows, n_cols, 
n_predict, + A, B, C, + BLOCK_SIZE=1024, + num_warps=8, + num_stages=4 + ) + return grad_input, None, None, None, None, None + +==================================================================================================== +Running Python 3.12.7 (main, Jan 31 2026, 04:21:49) [GCC 13.2.0] +Running PyTorch 2.10.0.dev20251210+cu126 compiled for CUDA 12.6 +Running Triton version 3.6.0 +Sun Feb 1 06:16:24 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:71:00.0 Off | 0 | +| N/A 39C P0 125W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:79:00.0 Off | 0 | +| N/A 34C P0 126W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:7F:00.0 Off | 0 | +| N/A 33C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | +| N/A 39C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 37C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 23290 C /usr/local/bin/python 1510MiB | +| 1 N/A N/A 23291 C /usr/local/bin/python 1510MiB | +| 2 N/A N/A 23292 C /usr/local/bin/python 1510MiB | +| 3 N/A N/A 23293 C /usr/local/bin/python 1510MiB | +| 4 N/A N/A 23294 C /usr/local/bin/python 1510MiB | +| 5 N/A N/A 23295 C /usr/local/bin/python 1510MiB | +| 6 N/A N/A 23296 C /usr/local/bin/python 1510MiB | +| 7 N/A N/A 23297 C /usr/local/bin/python 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +Compiling model and warming up kernels (~7 minutes on first execution) +Sampling steps [0, 1, 2, 504, 505, 506, 1009, 1010, 1011, 1514, 1515, 1516] for warmup +Resetting Model +step:0/1555 val_loss:10.8319 train_time:0ms step_avg:0.03ms +step:1/1555 train_time:83ms step_avg:82.60ms +step:2/1555 train_time:111ms 
step_avg:55.25ms +step:3/1555 train_time:139ms step_avg:46.28ms +step:4/1555 train_time:165ms step_avg:41.29ms +step:5/1555 train_time:190ms step_avg:37.97ms +step:6/1555 train_time:228ms step_avg:37.96ms +step:7/1555 train_time:258ms step_avg:36.91ms +step:8/1555 train_time:296ms step_avg:37.00ms +step:9/1555 train_time:327ms step_avg:36.33ms +step:10/1555 train_time:365ms step_avg:36.49ms +step:11/1555 train_time:396ms step_avg:35.98ms +step:12/1555 train_time:433ms step_avg:36.12ms +step:13/1555 train_time:465ms step_avg:35.74ms +step:14/1555 train_time:502ms step_avg:35.86ms +step:15/1555 train_time:533ms step_avg:35.54ms +step:16/1555 train_time:571ms step_avg:35.69ms +step:17/1555 train_time:602ms step_avg:35.40ms +step:18/1555 train_time:639ms step_avg:35.51ms +step:19/1555 train_time:670ms step_avg:35.27ms +step:20/1555 train_time:708ms step_avg:35.41ms +step:21/1555 train_time:739ms step_avg:35.20ms +step:22/1555 train_time:777ms step_avg:35.30ms +step:23/1555 train_time:808ms step_avg:35.12ms +step:24/1555 train_time:845ms step_avg:35.22ms +step:25/1555 train_time:876ms step_avg:35.05ms +step:26/1555 train_time:914ms step_avg:35.17ms +step:27/1555 train_time:945ms step_avg:35.01ms +step:28/1555 train_time:983ms step_avg:35.12ms +step:29/1555 train_time:1014ms step_avg:34.98ms +step:30/1555 train_time:1052ms step_avg:35.08ms +step:31/1555 train_time:1084ms step_avg:34.95ms +step:32/1555 train_time:1122ms step_avg:35.06ms +step:33/1555 train_time:1153ms step_avg:34.94ms +step:34/1555 train_time:1191ms step_avg:35.02ms +step:35/1555 train_time:1222ms step_avg:34.91ms +step:36/1555 train_time:1260ms step_avg:34.99ms +step:37/1555 train_time:1291ms step_avg:34.89ms +step:38/1555 train_time:1329ms step_avg:34.98ms +step:39/1555 train_time:1360ms step_avg:34.87ms +step:40/1555 train_time:1398ms step_avg:34.94ms +step:41/1555 train_time:1429ms step_avg:34.86ms +step:42/1555 train_time:1467ms step_avg:34.94ms +step:43/1555 train_time:1498ms step_avg:34.85ms 
+step:44/1555 train_time:1536ms step_avg:34.91ms +step:45/1555 train_time:1567ms step_avg:34.83ms +step:46/1555 train_time:1606ms step_avg:34.90ms +step:47/1555 train_time:1636ms step_avg:34.82ms +step:48/1555 train_time:1674ms step_avg:34.88ms +step:49/1555 train_time:1705ms step_avg:34.79ms +step:50/1555 train_time:1743ms step_avg:34.85ms +step:51/1555 train_time:1774ms step_avg:34.78ms +step:52/1555 train_time:1811ms step_avg:34.83ms +step:53/1555 train_time:1842ms step_avg:34.75ms +step:54/1555 train_time:1879ms step_avg:34.80ms +step:55/1555 train_time:1910ms step_avg:34.74ms +step:56/1555 train_time:1951ms step_avg:34.84ms +step:57/1555 train_time:1980ms step_avg:34.73ms +step:58/1555 train_time:2017ms step_avg:34.78ms +step:59/1555 train_time:2048ms step_avg:34.72ms +step:60/1555 train_time:2087ms step_avg:34.78ms +step:61/1555 train_time:2118ms step_avg:34.72ms +step:62/1555 train_time:2155ms step_avg:34.77ms +step:63/1555 train_time:2187ms step_avg:34.72ms +step:64/1555 train_time:2225ms step_avg:34.77ms +step:65/1555 train_time:2256ms step_avg:34.71ms +step:66/1555 train_time:2294ms step_avg:34.75ms +step:67/1555 train_time:2325ms step_avg:34.70ms +step:68/1555 train_time:2363ms step_avg:34.75ms +step:69/1555 train_time:2394ms step_avg:34.70ms +step:70/1555 train_time:2432ms step_avg:34.74ms +step:71/1555 train_time:2463ms step_avg:34.69ms +step:72/1555 train_time:2501ms step_avg:34.74ms +step:73/1555 train_time:2532ms step_avg:34.68ms +step:74/1555 train_time:2569ms step_avg:34.72ms +step:75/1555 train_time:2600ms step_avg:34.67ms +step:76/1555 train_time:2637ms step_avg:34.70ms +step:77/1555 train_time:2668ms step_avg:34.66ms +step:78/1555 train_time:2706ms step_avg:34.69ms +step:79/1555 train_time:2737ms step_avg:34.65ms +step:80/1555 train_time:2775ms step_avg:34.68ms +step:81/1555 train_time:2806ms step_avg:34.64ms +step:82/1555 train_time:2843ms step_avg:34.68ms +step:83/1555 train_time:2874ms step_avg:34.63ms +step:84/1555 train_time:2912ms 
step_avg:34.67ms +step:85/1555 train_time:2943ms step_avg:34.62ms +step:86/1555 train_time:2980ms step_avg:34.66ms +step:87/1555 train_time:3012ms step_avg:34.62ms +step:88/1555 train_time:3050ms step_avg:34.66ms +step:89/1555 train_time:3081ms step_avg:34.61ms +step:90/1555 train_time:3118ms step_avg:34.65ms +step:91/1555 train_time:3149ms step_avg:34.61ms +step:92/1555 train_time:3187ms step_avg:34.65ms +step:93/1555 train_time:3219ms step_avg:34.61ms +step:94/1555 train_time:3256ms step_avg:34.64ms +step:95/1555 train_time:3288ms step_avg:34.61ms +step:96/1555 train_time:3326ms step_avg:34.65ms +step:97/1555 train_time:3357ms step_avg:34.60ms +step:98/1555 train_time:3394ms step_avg:34.64ms +step:99/1555 train_time:3426ms step_avg:34.60ms +step:100/1555 train_time:3463ms step_avg:34.63ms +step:101/1555 train_time:3495ms step_avg:34.60ms +step:102/1555 train_time:3533ms step_avg:34.64ms +step:103/1555 train_time:3564ms step_avg:34.60ms +step:104/1555 train_time:3602ms step_avg:34.64ms +step:105/1555 train_time:3633ms step_avg:34.60ms +step:106/1555 train_time:3671ms step_avg:34.63ms +step:107/1555 train_time:3702ms step_avg:34.60ms +step:108/1555 train_time:3740ms step_avg:34.63ms +step:109/1555 train_time:3771ms step_avg:34.60ms +step:110/1555 train_time:3809ms step_avg:34.62ms +step:111/1555 train_time:3840ms step_avg:34.59ms +step:112/1555 train_time:3877ms step_avg:34.62ms +step:113/1555 train_time:3908ms step_avg:34.59ms +step:114/1555 train_time:3946ms step_avg:34.61ms +step:115/1555 train_time:3977ms step_avg:34.58ms +step:116/1555 train_time:4015ms step_avg:34.61ms +step:117/1555 train_time:4046ms step_avg:34.58ms +step:118/1555 train_time:4084ms step_avg:34.61ms +step:119/1555 train_time:4115ms step_avg:34.58ms +step:120/1555 train_time:4153ms step_avg:34.61ms +step:121/1555 train_time:4184ms step_avg:34.58ms +step:122/1555 train_time:4221ms step_avg:34.60ms +step:123/1555 train_time:4252ms step_avg:34.57ms +step:124/1555 train_time:4289ms 
step_avg:34.59ms +step:125/1555 train_time:4320ms step_avg:34.56ms +step:126/1555 train_time:4358ms step_avg:34.59ms +step:127/1555 train_time:4389ms step_avg:34.56ms +step:128/1555 train_time:4428ms step_avg:34.59ms +step:129/1555 train_time:4459ms step_avg:34.56ms +step:130/1555 train_time:4496ms step_avg:34.59ms +step:131/1555 train_time:4528ms step_avg:34.56ms +step:132/1555 train_time:4566ms step_avg:34.59ms +step:133/1555 train_time:4597ms step_avg:34.56ms +step:134/1555 train_time:4634ms step_avg:34.58ms +step:135/1555 train_time:4665ms step_avg:34.56ms +step:136/1555 train_time:4703ms step_avg:34.58ms +step:137/1555 train_time:4734ms step_avg:34.55ms +step:138/1555 train_time:4772ms step_avg:34.58ms +step:139/1555 train_time:4802ms step_avg:34.55ms +step:140/1555 train_time:4840ms step_avg:34.57ms +step:141/1555 train_time:4871ms step_avg:34.55ms +step:142/1555 train_time:4908ms step_avg:34.57ms +step:143/1555 train_time:4939ms step_avg:34.54ms +step:144/1555 train_time:4976ms step_avg:34.56ms +step:145/1555 train_time:5008ms step_avg:34.54ms +step:146/1555 train_time:5046ms step_avg:34.56ms +step:147/1555 train_time:5077ms step_avg:34.54ms +step:148/1555 train_time:5115ms step_avg:34.56ms +step:149/1555 train_time:5146ms step_avg:34.54ms +step:150/1555 train_time:5184ms step_avg:34.56ms +step:151/1555 train_time:5215ms step_avg:34.54ms +step:152/1555 train_time:5253ms step_avg:34.56ms +step:153/1555 train_time:5284ms step_avg:34.53ms +step:154/1555 train_time:5322ms step_avg:34.56ms +step:155/1555 train_time:5353ms step_avg:34.53ms +step:156/1555 train_time:5390ms step_avg:34.55ms +step:157/1555 train_time:5421ms step_avg:34.53ms +step:158/1555 train_time:5459ms step_avg:34.55ms +step:159/1555 train_time:5490ms step_avg:34.53ms +step:160/1555 train_time:5528ms step_avg:34.55ms +step:161/1555 train_time:5559ms step_avg:34.53ms +step:162/1555 train_time:5596ms step_avg:34.54ms +step:163/1555 train_time:5627ms step_avg:34.52ms +step:164/1555 train_time:5665ms 
step_avg:34.54ms +step:165/1555 train_time:5696ms step_avg:34.52ms +step:166/1555 train_time:5734ms step_avg:34.54ms +step:167/1555 train_time:5765ms step_avg:34.52ms +step:168/1555 train_time:5802ms step_avg:34.54ms +step:169/1555 train_time:5833ms step_avg:34.52ms +step:170/1555 train_time:5871ms step_avg:34.53ms +step:171/1555 train_time:5902ms step_avg:34.51ms +step:172/1555 train_time:5939ms step_avg:34.53ms +step:173/1555 train_time:5970ms step_avg:34.51ms +step:174/1555 train_time:6008ms step_avg:34.53ms +step:175/1555 train_time:6039ms step_avg:34.51ms +step:176/1555 train_time:6076ms step_avg:34.52ms +step:177/1555 train_time:6107ms step_avg:34.50ms +step:178/1555 train_time:6145ms step_avg:34.52ms +step:179/1555 train_time:6175ms step_avg:34.50ms +step:180/1555 train_time:6213ms step_avg:34.52ms +step:181/1555 train_time:6244ms step_avg:34.50ms +step:182/1555 train_time:6281ms step_avg:34.51ms +step:183/1555 train_time:6313ms step_avg:34.50ms +step:184/1555 train_time:6351ms step_avg:34.51ms +step:185/1555 train_time:6381ms step_avg:34.49ms +step:186/1555 train_time:6419ms step_avg:34.51ms +step:187/1555 train_time:6450ms step_avg:34.49ms +step:188/1555 train_time:6488ms step_avg:34.51ms +step:189/1555 train_time:6519ms step_avg:34.49ms +step:190/1555 train_time:6557ms step_avg:34.51ms +step:191/1555 train_time:6588ms step_avg:34.49ms +step:192/1555 train_time:6626ms step_avg:34.51ms +step:193/1555 train_time:6656ms step_avg:34.49ms +step:194/1555 train_time:6694ms step_avg:34.50ms +step:195/1555 train_time:6725ms step_avg:34.49ms +step:196/1555 train_time:6763ms step_avg:34.51ms +step:197/1555 train_time:6794ms step_avg:34.49ms +step:198/1555 train_time:6832ms step_avg:34.50ms +step:199/1555 train_time:6863ms step_avg:34.49ms +step:200/1555 train_time:6901ms step_avg:34.50ms +step:201/1555 train_time:6931ms step_avg:34.48ms +step:202/1555 train_time:6969ms step_avg:34.50ms +step:203/1555 train_time:7000ms step_avg:34.48ms +step:204/1555 train_time:7037ms 
step_avg:34.50ms +step:205/1555 train_time:7069ms step_avg:34.48ms +step:206/1555 train_time:7107ms step_avg:34.50ms +step:207/1555 train_time:7137ms step_avg:34.48ms +step:208/1555 train_time:7175ms step_avg:34.49ms +step:209/1555 train_time:7205ms step_avg:34.48ms +step:210/1555 train_time:7243ms step_avg:34.49ms +step:211/1555 train_time:7274ms step_avg:34.47ms +step:212/1555 train_time:7312ms step_avg:34.49ms +step:213/1555 train_time:7343ms step_avg:34.47ms +step:214/1555 train_time:7380ms step_avg:34.49ms +step:215/1555 train_time:7412ms step_avg:34.47ms +step:216/1555 train_time:7450ms step_avg:34.49ms +step:217/1555 train_time:7480ms step_avg:34.47ms +step:218/1555 train_time:7518ms step_avg:34.48ms +step:219/1555 train_time:7549ms step_avg:34.47ms +step:220/1555 train_time:7587ms step_avg:34.49ms +step:221/1555 train_time:7618ms step_avg:34.47ms +step:222/1555 train_time:7655ms step_avg:34.48ms +step:223/1555 train_time:7686ms step_avg:34.47ms +step:224/1555 train_time:7724ms step_avg:34.48ms +step:225/1555 train_time:7755ms step_avg:34.47ms +step:226/1555 train_time:7793ms step_avg:34.48ms +step:227/1555 train_time:7824ms step_avg:34.46ms +step:228/1555 train_time:7861ms step_avg:34.48ms +step:229/1555 train_time:7892ms step_avg:34.46ms +step:230/1555 train_time:7929ms step_avg:34.48ms +step:231/1555 train_time:7960ms step_avg:34.46ms +step:232/1555 train_time:7998ms step_avg:34.47ms +step:233/1555 train_time:8029ms step_avg:34.46ms +step:234/1555 train_time:8067ms step_avg:34.47ms +step:235/1555 train_time:8097ms step_avg:34.46ms +step:236/1555 train_time:8135ms step_avg:34.47ms +step:237/1555 train_time:8166ms step_avg:34.46ms +step:238/1555 train_time:8203ms step_avg:34.47ms +step:239/1555 train_time:8234ms step_avg:34.45ms +step:240/1555 train_time:8272ms step_avg:34.47ms +step:241/1555 train_time:8303ms step_avg:34.45ms +step:242/1555 train_time:8340ms step_avg:34.46ms +step:243/1555 train_time:8371ms step_avg:34.45ms +step:244/1555 train_time:8409ms 
step_avg:34.46ms +step:245/1555 train_time:8440ms step_avg:34.45ms +step:246/1555 train_time:8477ms step_avg:34.46ms +step:247/1555 train_time:8508ms step_avg:34.45ms +step:248/1555 train_time:8546ms step_avg:34.46ms +step:249/1555 train_time:8577ms step_avg:34.45ms +step:250/1555 train_time:8615ms step_avg:34.46ms +step:250/1555 val_loss:4.5442 train_time:8665ms step_avg:34.66ms +step:251/1555 train_time:8685ms step_avg:34.60ms +step:252/1555 train_time:8710ms step_avg:34.56ms +step:253/1555 train_time:8734ms step_avg:34.52ms +step:254/1555 train_time:8759ms step_avg:34.48ms +step:255/1555 train_time:8787ms step_avg:34.46ms +step:256/1555 train_time:8825ms step_avg:34.47ms +step:257/1555 train_time:8857ms step_avg:34.46ms +step:258/1555 train_time:8895ms step_avg:34.48ms +step:259/1555 train_time:8926ms step_avg:34.46ms +step:260/1555 train_time:8963ms step_avg:34.47ms +step:261/1555 train_time:8994ms step_avg:34.46ms +step:262/1555 train_time:9032ms step_avg:34.47ms +step:263/1555 train_time:9063ms step_avg:34.46ms +step:264/1555 train_time:9100ms step_avg:34.47ms +step:265/1555 train_time:9131ms step_avg:34.46ms +step:266/1555 train_time:9169ms step_avg:34.47ms +step:267/1555 train_time:9199ms step_avg:34.45ms +step:268/1555 train_time:9237ms step_avg:34.47ms +step:269/1555 train_time:9267ms step_avg:34.45ms +step:270/1555 train_time:9305ms step_avg:34.46ms +step:271/1555 train_time:9336ms step_avg:34.45ms +step:272/1555 train_time:9373ms step_avg:34.46ms +step:273/1555 train_time:9404ms step_avg:34.45ms +step:274/1555 train_time:9441ms step_avg:34.45ms +step:275/1555 train_time:9471ms step_avg:34.44ms +step:276/1555 train_time:9508ms step_avg:34.45ms +step:277/1555 train_time:9540ms step_avg:34.44ms +step:278/1555 train_time:9577ms step_avg:34.45ms +step:279/1555 train_time:9608ms step_avg:34.44ms +step:280/1555 train_time:9645ms step_avg:34.45ms +step:281/1555 train_time:9676ms step_avg:34.44ms +step:282/1555 train_time:9715ms step_avg:34.45ms +step:283/1555 
train_time:9746ms step_avg:34.44ms +step:284/1555 train_time:9783ms step_avg:34.45ms +step:285/1555 train_time:9815ms step_avg:34.44ms +step:286/1555 train_time:9853ms step_avg:34.45ms +step:287/1555 train_time:9884ms step_avg:34.44ms +step:288/1555 train_time:9922ms step_avg:34.45ms +step:289/1555 train_time:9953ms step_avg:34.44ms +step:290/1555 train_time:9990ms step_avg:34.45ms +step:291/1555 train_time:10022ms step_avg:34.44ms +step:292/1555 train_time:10059ms step_avg:34.45ms +step:293/1555 train_time:10090ms step_avg:34.44ms +step:294/1555 train_time:10127ms step_avg:34.45ms +step:295/1555 train_time:10158ms step_avg:34.43ms +step:296/1555 train_time:10196ms step_avg:34.45ms +step:297/1555 train_time:10228ms step_avg:34.44ms +step:298/1555 train_time:10265ms step_avg:34.45ms +step:299/1555 train_time:10296ms step_avg:34.43ms +step:300/1555 train_time:10334ms step_avg:34.45ms +step:301/1555 train_time:10364ms step_avg:34.43ms +step:302/1555 train_time:10402ms step_avg:34.44ms +step:303/1555 train_time:10433ms step_avg:34.43ms +step:304/1555 train_time:10471ms step_avg:34.44ms +step:305/1555 train_time:10501ms step_avg:34.43ms +step:306/1555 train_time:10538ms step_avg:34.44ms +step:307/1555 train_time:10569ms step_avg:34.43ms +step:308/1555 train_time:10606ms step_avg:34.44ms +step:309/1555 train_time:10637ms step_avg:34.42ms +step:310/1555 train_time:10675ms step_avg:34.44ms +step:311/1555 train_time:10705ms step_avg:34.42ms +step:312/1555 train_time:10744ms step_avg:34.44ms +step:313/1555 train_time:10774ms step_avg:34.42ms +step:314/1555 train_time:10812ms step_avg:34.43ms +step:315/1555 train_time:10844ms step_avg:34.42ms +step:316/1555 train_time:10881ms step_avg:34.43ms +step:317/1555 train_time:10912ms step_avg:34.42ms +step:318/1555 train_time:10949ms step_avg:34.43ms +step:319/1555 train_time:10981ms step_avg:34.42ms +step:320/1555 train_time:11018ms step_avg:34.43ms +step:321/1555 train_time:11049ms step_avg:34.42ms +step:322/1555 train_time:11087ms 
step_avg:34.43ms +step:323/1555 train_time:11117ms step_avg:34.42ms +step:324/1555 train_time:11155ms step_avg:34.43ms +step:325/1555 train_time:11186ms step_avg:34.42ms +step:326/1555 train_time:11223ms step_avg:34.43ms +step:327/1555 train_time:11254ms step_avg:34.42ms +step:328/1555 train_time:11292ms step_avg:34.43ms +step:329/1555 train_time:11323ms step_avg:34.42ms +step:330/1555 train_time:11360ms step_avg:34.43ms +step:331/1555 train_time:11391ms step_avg:34.41ms +step:332/1555 train_time:11429ms step_avg:34.42ms +step:333/1555 train_time:11460ms step_avg:34.41ms +step:334/1555 train_time:11497ms step_avg:34.42ms +step:335/1555 train_time:11528ms step_avg:34.41ms +step:336/1555 train_time:11565ms step_avg:34.42ms +step:337/1555 train_time:11596ms step_avg:34.41ms +step:338/1555 train_time:11634ms step_avg:34.42ms +step:339/1555 train_time:11665ms step_avg:34.41ms +step:340/1555 train_time:11702ms step_avg:34.42ms +step:341/1555 train_time:11733ms step_avg:34.41ms +step:342/1555 train_time:11771ms step_avg:34.42ms +step:343/1555 train_time:11802ms step_avg:34.41ms +step:344/1555 train_time:11839ms step_avg:34.42ms +step:345/1555 train_time:11870ms step_avg:34.41ms +step:346/1555 train_time:11908ms step_avg:34.42ms +step:347/1555 train_time:11940ms step_avg:34.41ms +step:348/1555 train_time:11978ms step_avg:34.42ms +step:349/1555 train_time:12008ms step_avg:34.41ms +step:350/1555 train_time:12046ms step_avg:34.42ms +step:351/1555 train_time:12077ms step_avg:34.41ms +step:352/1555 train_time:12114ms step_avg:34.42ms +step:353/1555 train_time:12145ms step_avg:34.41ms +step:354/1555 train_time:12183ms step_avg:34.42ms +step:355/1555 train_time:12214ms step_avg:34.41ms +step:356/1555 train_time:12252ms step_avg:34.42ms +step:357/1555 train_time:12283ms step_avg:34.41ms +step:358/1555 train_time:12321ms step_avg:34.42ms +step:359/1555 train_time:12352ms step_avg:34.41ms +step:360/1555 train_time:12390ms step_avg:34.42ms +step:361/1555 train_time:12421ms 
step_avg:34.41ms +step:362/1555 train_time:12458ms step_avg:34.41ms +step:363/1555 train_time:12489ms step_avg:34.40ms +step:364/1555 train_time:12526ms step_avg:34.41ms +step:365/1555 train_time:12557ms step_avg:34.40ms +step:366/1555 train_time:12594ms step_avg:34.41ms +step:367/1555 train_time:12625ms step_avg:34.40ms +step:368/1555 train_time:12663ms step_avg:34.41ms +step:369/1555 train_time:12694ms step_avg:34.40ms +step:370/1555 train_time:12731ms step_avg:34.41ms +step:371/1555 train_time:12762ms step_avg:34.40ms +step:372/1555 train_time:12800ms step_avg:34.41ms +step:373/1555 train_time:12831ms step_avg:34.40ms +step:374/1555 train_time:12868ms step_avg:34.41ms +step:375/1555 train_time:12899ms step_avg:34.40ms +step:376/1555 train_time:12937ms step_avg:34.41ms +step:377/1555 train_time:12968ms step_avg:34.40ms +step:378/1555 train_time:13006ms step_avg:34.41ms +step:379/1555 train_time:13036ms step_avg:34.40ms +step:380/1555 train_time:13074ms step_avg:34.41ms +step:381/1555 train_time:13105ms step_avg:34.40ms +step:382/1555 train_time:13142ms step_avg:34.40ms +step:383/1555 train_time:13173ms step_avg:34.40ms +step:384/1555 train_time:13212ms step_avg:34.41ms +step:385/1555 train_time:13243ms step_avg:34.40ms +step:386/1555 train_time:13281ms step_avg:34.41ms +step:387/1555 train_time:13311ms step_avg:34.40ms +step:388/1555 train_time:13349ms step_avg:34.40ms +step:389/1555 train_time:13380ms step_avg:34.40ms +step:390/1555 train_time:13418ms step_avg:34.41ms +step:391/1555 train_time:13449ms step_avg:34.40ms +step:392/1555 train_time:13487ms step_avg:34.40ms +step:393/1555 train_time:13518ms step_avg:34.40ms +step:394/1555 train_time:13555ms step_avg:34.40ms +step:395/1555 train_time:13586ms step_avg:34.40ms +step:396/1555 train_time:13623ms step_avg:34.40ms +step:397/1555 train_time:13654ms step_avg:34.39ms +step:398/1555 train_time:13692ms step_avg:34.40ms +step:399/1555 train_time:13723ms step_avg:34.39ms +step:400/1555 train_time:13760ms 
step_avg:34.40ms +step:401/1555 train_time:13791ms step_avg:34.39ms +step:402/1555 train_time:13828ms step_avg:34.40ms +step:403/1555 train_time:13859ms step_avg:34.39ms +step:404/1555 train_time:13897ms step_avg:34.40ms +step:405/1555 train_time:13928ms step_avg:34.39ms +step:406/1555 train_time:13965ms step_avg:34.40ms +step:407/1555 train_time:13996ms step_avg:34.39ms +step:408/1555 train_time:14034ms step_avg:34.40ms +step:409/1555 train_time:14065ms step_avg:34.39ms +step:410/1555 train_time:14102ms step_avg:34.40ms +step:411/1555 train_time:14133ms step_avg:34.39ms +step:412/1555 train_time:14171ms step_avg:34.39ms +step:413/1555 train_time:14201ms step_avg:34.39ms +step:414/1555 train_time:14239ms step_avg:34.39ms +step:415/1555 train_time:14270ms step_avg:34.38ms +step:416/1555 train_time:14307ms step_avg:34.39ms +step:417/1555 train_time:14338ms step_avg:34.38ms +step:418/1555 train_time:14375ms step_avg:34.39ms +step:419/1555 train_time:14406ms step_avg:34.38ms +step:420/1555 train_time:14444ms step_avg:34.39ms +step:421/1555 train_time:14475ms step_avg:34.38ms +step:422/1555 train_time:14513ms step_avg:34.39ms +step:423/1555 train_time:14544ms step_avg:34.38ms +step:424/1555 train_time:14582ms step_avg:34.39ms +step:425/1555 train_time:14613ms step_avg:34.38ms +step:426/1555 train_time:14651ms step_avg:34.39ms +step:427/1555 train_time:14681ms step_avg:34.38ms +step:428/1555 train_time:14719ms step_avg:34.39ms +step:429/1555 train_time:14750ms step_avg:34.38ms +step:430/1555 train_time:14787ms step_avg:34.39ms +step:431/1555 train_time:14818ms step_avg:34.38ms +step:432/1555 train_time:14855ms step_avg:34.39ms +step:433/1555 train_time:14886ms step_avg:34.38ms +step:434/1555 train_time:14924ms step_avg:34.39ms +step:435/1555 train_time:14955ms step_avg:34.38ms +step:436/1555 train_time:14993ms step_avg:34.39ms +step:437/1555 train_time:15024ms step_avg:34.38ms +step:438/1555 train_time:15061ms step_avg:34.39ms +step:439/1555 train_time:15092ms 
step_avg:34.38ms +step:440/1555 train_time:15130ms step_avg:34.39ms +step:441/1555 train_time:15161ms step_avg:34.38ms +step:442/1555 train_time:15198ms step_avg:34.39ms +step:443/1555 train_time:15229ms step_avg:34.38ms +step:444/1555 train_time:15267ms step_avg:34.38ms +step:445/1555 train_time:15297ms step_avg:34.38ms +step:446/1555 train_time:15335ms step_avg:34.38ms +step:447/1555 train_time:15366ms step_avg:34.38ms +step:448/1555 train_time:15403ms step_avg:34.38ms +step:449/1555 train_time:15434ms step_avg:34.37ms +step:450/1555 train_time:15472ms step_avg:34.38ms +step:451/1555 train_time:15503ms step_avg:34.37ms +step:452/1555 train_time:15540ms step_avg:34.38ms +step:453/1555 train_time:15571ms step_avg:34.37ms +step:454/1555 train_time:15609ms step_avg:34.38ms +step:455/1555 train_time:15640ms step_avg:34.37ms +step:456/1555 train_time:15678ms step_avg:34.38ms +step:457/1555 train_time:15708ms step_avg:34.37ms +step:458/1555 train_time:15746ms step_avg:34.38ms +step:459/1555 train_time:15777ms step_avg:34.37ms +step:460/1555 train_time:15815ms step_avg:34.38ms +step:461/1555 train_time:15845ms step_avg:34.37ms +step:462/1555 train_time:15883ms step_avg:34.38ms +step:463/1555 train_time:15914ms step_avg:34.37ms +step:464/1555 train_time:15952ms step_avg:34.38ms +step:465/1555 train_time:15983ms step_avg:34.37ms +step:466/1555 train_time:16020ms step_avg:34.38ms +step:467/1555 train_time:16051ms step_avg:34.37ms +step:468/1555 train_time:16089ms step_avg:34.38ms +step:469/1555 train_time:16119ms step_avg:34.37ms +step:470/1555 train_time:16157ms step_avg:34.38ms +step:471/1555 train_time:16187ms step_avg:34.37ms +step:472/1555 train_time:16225ms step_avg:34.37ms +step:473/1555 train_time:16255ms step_avg:34.37ms +step:474/1555 train_time:16293ms step_avg:34.37ms +step:475/1555 train_time:16324ms step_avg:34.37ms +step:476/1555 train_time:16362ms step_avg:34.37ms +step:477/1555 train_time:16392ms step_avg:34.37ms +step:478/1555 train_time:16430ms 
step_avg:34.37ms +step:479/1555 train_time:16461ms step_avg:34.36ms +step:480/1555 train_time:16498ms step_avg:34.37ms +step:481/1555 train_time:16529ms step_avg:34.36ms +step:482/1555 train_time:16566ms step_avg:34.37ms +step:483/1555 train_time:16596ms step_avg:34.36ms +step:484/1555 train_time:16634ms step_avg:34.37ms +step:485/1555 train_time:16665ms step_avg:34.36ms +step:486/1555 train_time:16703ms step_avg:34.37ms +step:487/1555 train_time:16733ms step_avg:34.36ms +step:488/1555 train_time:16772ms step_avg:34.37ms +step:489/1555 train_time:16802ms step_avg:34.36ms +step:490/1555 train_time:16839ms step_avg:34.37ms +step:491/1555 train_time:16870ms step_avg:34.36ms +step:492/1555 train_time:16908ms step_avg:34.37ms +step:493/1555 train_time:16939ms step_avg:34.36ms +step:494/1555 train_time:16976ms step_avg:34.36ms +step:495/1555 train_time:17007ms step_avg:34.36ms +step:496/1555 train_time:17045ms step_avg:34.37ms +step:497/1555 train_time:17076ms step_avg:34.36ms +step:498/1555 train_time:17114ms step_avg:34.36ms +step:499/1555 train_time:17145ms step_avg:34.36ms +step:500/1555 train_time:17183ms step_avg:34.37ms +step:500/1555 val_loss:4.2184 train_time:17232ms step_avg:34.46ms +step:501/1555 train_time:17252ms step_avg:34.43ms +step:502/1555 train_time:17272ms step_avg:34.41ms +step:503/1555 train_time:17290ms step_avg:34.37ms +step:504/1555 train_time:17322ms step_avg:34.37ms +step:505/1555 train_time:17354ms step_avg:34.36ms +step:506/1555 train_time:17396ms step_avg:34.38ms +step:507/1555 train_time:17450ms step_avg:34.42ms +step:508/1555 train_time:17513ms step_avg:34.48ms +step:509/1555 train_time:17570ms step_avg:34.52ms +step:510/1555 train_time:17636ms step_avg:34.58ms +step:511/1555 train_time:17692ms step_avg:34.62ms +step:512/1555 train_time:17757ms step_avg:34.68ms +step:513/1555 train_time:17813ms step_avg:34.72ms +step:514/1555 train_time:17878ms step_avg:34.78ms +step:515/1555 train_time:17935ms step_avg:34.82ms +step:516/1555 
train_time:17999ms step_avg:34.88ms +step:517/1555 train_time:18056ms step_avg:34.92ms +step:518/1555 train_time:18120ms step_avg:34.98ms +step:519/1555 train_time:18178ms step_avg:35.03ms +step:520/1555 train_time:18244ms step_avg:35.09ms +step:521/1555 train_time:18303ms step_avg:35.13ms +step:522/1555 train_time:18369ms step_avg:35.19ms +step:523/1555 train_time:18428ms step_avg:35.24ms +step:524/1555 train_time:18491ms step_avg:35.29ms +step:525/1555 train_time:18549ms step_avg:35.33ms +step:526/1555 train_time:18613ms step_avg:35.39ms +step:527/1555 train_time:18671ms step_avg:35.43ms +step:528/1555 train_time:18735ms step_avg:35.48ms +step:529/1555 train_time:18792ms step_avg:35.52ms +step:530/1555 train_time:18856ms step_avg:35.58ms +step:531/1555 train_time:18913ms step_avg:35.62ms +step:532/1555 train_time:18977ms step_avg:35.67ms +step:533/1555 train_time:19034ms step_avg:35.71ms +step:534/1555 train_time:19098ms step_avg:35.76ms +step:535/1555 train_time:19156ms step_avg:35.81ms +step:536/1555 train_time:19221ms step_avg:35.86ms +step:537/1555 train_time:19279ms step_avg:35.90ms +step:538/1555 train_time:19344ms step_avg:35.96ms +step:539/1555 train_time:19405ms step_avg:36.00ms +step:540/1555 train_time:19468ms step_avg:36.05ms +step:541/1555 train_time:19527ms step_avg:36.09ms +step:542/1555 train_time:19591ms step_avg:36.15ms +step:543/1555 train_time:19648ms step_avg:36.18ms +step:544/1555 train_time:19712ms step_avg:36.24ms +step:545/1555 train_time:19770ms step_avg:36.28ms +step:546/1555 train_time:19835ms step_avg:36.33ms +step:547/1555 train_time:19892ms step_avg:36.37ms +step:548/1555 train_time:19956ms step_avg:36.42ms +step:549/1555 train_time:20013ms step_avg:36.45ms +step:550/1555 train_time:20078ms step_avg:36.50ms +step:551/1555 train_time:20135ms step_avg:36.54ms +step:552/1555 train_time:20199ms step_avg:36.59ms +step:553/1555 train_time:20257ms step_avg:36.63ms +step:554/1555 train_time:20323ms step_avg:36.68ms +step:555/1555 
train_time:20381ms step_avg:36.72ms +step:556/1555 train_time:20446ms step_avg:36.77ms +step:557/1555 train_time:20504ms step_avg:36.81ms +step:558/1555 train_time:20568ms step_avg:36.86ms +step:559/1555 train_time:20627ms step_avg:36.90ms +step:560/1555 train_time:20691ms step_avg:36.95ms +step:561/1555 train_time:20748ms step_avg:36.98ms +step:562/1555 train_time:20813ms step_avg:37.03ms +step:563/1555 train_time:20869ms step_avg:37.07ms +step:564/1555 train_time:20934ms step_avg:37.12ms +step:565/1555 train_time:20990ms step_avg:37.15ms +step:566/1555 train_time:21055ms step_avg:37.20ms +step:567/1555 train_time:21112ms step_avg:37.23ms +step:568/1555 train_time:21177ms step_avg:37.28ms +step:569/1555 train_time:21235ms step_avg:37.32ms +step:570/1555 train_time:21300ms step_avg:37.37ms +step:571/1555 train_time:21358ms step_avg:37.40ms +step:572/1555 train_time:21424ms step_avg:37.45ms +step:573/1555 train_time:21483ms step_avg:37.49ms +step:574/1555 train_time:21547ms step_avg:37.54ms +step:575/1555 train_time:21604ms step_avg:37.57ms +step:576/1555 train_time:21669ms step_avg:37.62ms +step:577/1555 train_time:21728ms step_avg:37.66ms +step:578/1555 train_time:21791ms step_avg:37.70ms +step:579/1555 train_time:21849ms step_avg:37.74ms +step:580/1555 train_time:21913ms step_avg:37.78ms +step:581/1555 train_time:21970ms step_avg:37.81ms +step:582/1555 train_time:22035ms step_avg:37.86ms +step:583/1555 train_time:22092ms step_avg:37.89ms +step:584/1555 train_time:22157ms step_avg:37.94ms +step:585/1555 train_time:22215ms step_avg:37.97ms +step:586/1555 train_time:22280ms step_avg:38.02ms +step:587/1555 train_time:22338ms step_avg:38.05ms +step:588/1555 train_time:22404ms step_avg:38.10ms +step:589/1555 train_time:22461ms step_avg:38.13ms +step:590/1555 train_time:22526ms step_avg:38.18ms +step:591/1555 train_time:22583ms step_avg:38.21ms +step:592/1555 train_time:22647ms step_avg:38.25ms +step:593/1555 train_time:22705ms step_avg:38.29ms +step:594/1555 
train_time:22769ms step_avg:38.33ms +step:595/1555 train_time:22828ms step_avg:38.37ms +step:596/1555 train_time:22891ms step_avg:38.41ms +step:597/1555 train_time:22949ms step_avg:38.44ms +step:598/1555 train_time:23014ms step_avg:38.48ms +step:599/1555 train_time:23071ms step_avg:38.52ms +step:600/1555 train_time:23136ms step_avg:38.56ms +step:601/1555 train_time:23193ms step_avg:38.59ms +step:602/1555 train_time:23258ms step_avg:38.63ms +step:603/1555 train_time:23316ms step_avg:38.67ms +step:604/1555 train_time:23380ms step_avg:38.71ms +step:605/1555 train_time:23439ms step_avg:38.74ms +step:606/1555 train_time:23503ms step_avg:38.78ms +step:607/1555 train_time:23561ms step_avg:38.82ms +step:608/1555 train_time:23625ms step_avg:38.86ms +step:609/1555 train_time:23682ms step_avg:38.89ms +step:610/1555 train_time:23747ms step_avg:38.93ms +step:611/1555 train_time:23805ms step_avg:38.96ms +step:612/1555 train_time:23869ms step_avg:39.00ms +step:613/1555 train_time:23928ms step_avg:39.03ms +step:614/1555 train_time:23992ms step_avg:39.07ms +step:615/1555 train_time:24049ms step_avg:39.10ms +step:616/1555 train_time:24113ms step_avg:39.14ms +step:617/1555 train_time:24171ms step_avg:39.17ms +step:618/1555 train_time:24235ms step_avg:39.21ms +step:619/1555 train_time:24293ms step_avg:39.25ms +step:620/1555 train_time:24358ms step_avg:39.29ms +step:621/1555 train_time:24416ms step_avg:39.32ms +step:622/1555 train_time:24482ms step_avg:39.36ms +step:623/1555 train_time:24540ms step_avg:39.39ms +step:624/1555 train_time:24603ms step_avg:39.43ms +step:625/1555 train_time:24661ms step_avg:39.46ms +step:626/1555 train_time:24726ms step_avg:39.50ms +step:627/1555 train_time:24783ms step_avg:39.53ms +step:628/1555 train_time:24848ms step_avg:39.57ms +step:629/1555 train_time:24906ms step_avg:39.60ms +step:630/1555 train_time:24970ms step_avg:39.63ms +step:631/1555 train_time:25029ms step_avg:39.67ms +step:632/1555 train_time:25092ms step_avg:39.70ms +step:633/1555 
train_time:25150ms step_avg:39.73ms +step:634/1555 train_time:25214ms step_avg:39.77ms +step:635/1555 train_time:25272ms step_avg:39.80ms +step:636/1555 train_time:25336ms step_avg:39.84ms +step:637/1555 train_time:25395ms step_avg:39.87ms +step:638/1555 train_time:25459ms step_avg:39.90ms +step:639/1555 train_time:25517ms step_avg:39.93ms +step:640/1555 train_time:25581ms step_avg:39.97ms +step:641/1555 train_time:25640ms step_avg:40.00ms +step:642/1555 train_time:25704ms step_avg:40.04ms +step:643/1555 train_time:25761ms step_avg:40.06ms +step:644/1555 train_time:25826ms step_avg:40.10ms +step:645/1555 train_time:25884ms step_avg:40.13ms +step:646/1555 train_time:25949ms step_avg:40.17ms +step:647/1555 train_time:26008ms step_avg:40.20ms +step:648/1555 train_time:26071ms step_avg:40.23ms +step:649/1555 train_time:26130ms step_avg:40.26ms +step:650/1555 train_time:26193ms step_avg:40.30ms +step:651/1555 train_time:26250ms step_avg:40.32ms +step:652/1555 train_time:26315ms step_avg:40.36ms +step:653/1555 train_time:26373ms step_avg:40.39ms +step:654/1555 train_time:26438ms step_avg:40.42ms +step:655/1555 train_time:26496ms step_avg:40.45ms +step:656/1555 train_time:26561ms step_avg:40.49ms +step:657/1555 train_time:26618ms step_avg:40.51ms +step:658/1555 train_time:26682ms step_avg:40.55ms +step:659/1555 train_time:26739ms step_avg:40.58ms +step:660/1555 train_time:26804ms step_avg:40.61ms +step:661/1555 train_time:26862ms step_avg:40.64ms +step:662/1555 train_time:26928ms step_avg:40.68ms +step:663/1555 train_time:26986ms step_avg:40.70ms +step:664/1555 train_time:27050ms step_avg:40.74ms +step:665/1555 train_time:27108ms step_avg:40.76ms +step:666/1555 train_time:27172ms step_avg:40.80ms +step:667/1555 train_time:27230ms step_avg:40.82ms +step:668/1555 train_time:27293ms step_avg:40.86ms +step:669/1555 train_time:27350ms step_avg:40.88ms +step:670/1555 train_time:27414ms step_avg:40.92ms +step:671/1555 train_time:27471ms step_avg:40.94ms +step:672/1555 
train_time:27536ms step_avg:40.98ms +step:673/1555 train_time:27596ms step_avg:41.00ms +step:674/1555 train_time:27660ms step_avg:41.04ms +step:675/1555 train_time:27718ms step_avg:41.06ms +step:676/1555 train_time:27782ms step_avg:41.10ms +step:677/1555 train_time:27840ms step_avg:41.12ms +step:678/1555 train_time:27904ms step_avg:41.16ms +step:679/1555 train_time:27962ms step_avg:41.18ms +step:680/1555 train_time:28027ms step_avg:41.22ms +step:681/1555 train_time:28084ms step_avg:41.24ms +step:682/1555 train_time:28149ms step_avg:41.27ms +step:683/1555 train_time:28208ms step_avg:41.30ms +step:684/1555 train_time:28273ms step_avg:41.33ms +step:685/1555 train_time:28330ms step_avg:41.36ms +step:686/1555 train_time:28394ms step_avg:41.39ms +step:687/1555 train_time:28451ms step_avg:41.41ms +step:688/1555 train_time:28517ms step_avg:41.45ms +step:689/1555 train_time:28575ms step_avg:41.47ms +step:690/1555 train_time:28640ms step_avg:41.51ms +step:691/1555 train_time:28698ms step_avg:41.53ms +step:692/1555 train_time:28763ms step_avg:41.56ms +step:693/1555 train_time:28820ms step_avg:41.59ms +step:694/1555 train_time:28884ms step_avg:41.62ms +step:695/1555 train_time:28942ms step_avg:41.64ms +step:696/1555 train_time:29006ms step_avg:41.68ms +step:697/1555 train_time:29064ms step_avg:41.70ms +step:698/1555 train_time:29129ms step_avg:41.73ms +step:699/1555 train_time:29186ms step_avg:41.75ms +step:700/1555 train_time:29250ms step_avg:41.79ms +step:701/1555 train_time:29308ms step_avg:41.81ms +step:702/1555 train_time:29372ms step_avg:41.84ms +step:703/1555 train_time:29430ms step_avg:41.86ms +step:704/1555 train_time:29494ms step_avg:41.90ms +step:705/1555 train_time:29552ms step_avg:41.92ms +step:706/1555 train_time:29617ms step_avg:41.95ms +step:707/1555 train_time:29675ms step_avg:41.97ms +step:708/1555 train_time:29739ms step_avg:42.00ms +step:709/1555 train_time:29798ms step_avg:42.03ms +step:710/1555 train_time:29863ms step_avg:42.06ms +step:711/1555 
train_time:29920ms step_avg:42.08ms +step:712/1555 train_time:29984ms step_avg:42.11ms +step:713/1555 train_time:30042ms step_avg:42.13ms +step:714/1555 train_time:30107ms step_avg:42.17ms +step:715/1555 train_time:30165ms step_avg:42.19ms +step:716/1555 train_time:30230ms step_avg:42.22ms +step:717/1555 train_time:30287ms step_avg:42.24ms +step:718/1555 train_time:30352ms step_avg:42.27ms +step:719/1555 train_time:30410ms step_avg:42.29ms +step:720/1555 train_time:30474ms step_avg:42.32ms +step:721/1555 train_time:30532ms step_avg:42.35ms +step:722/1555 train_time:30595ms step_avg:42.38ms +step:723/1555 train_time:30653ms step_avg:42.40ms +step:724/1555 train_time:30718ms step_avg:42.43ms +step:725/1555 train_time:30775ms step_avg:42.45ms +step:726/1555 train_time:30840ms step_avg:42.48ms +step:727/1555 train_time:30899ms step_avg:42.50ms +step:728/1555 train_time:30963ms step_avg:42.53ms +step:729/1555 train_time:31021ms step_avg:42.55ms +step:730/1555 train_time:31086ms step_avg:42.58ms +step:731/1555 train_time:31143ms step_avg:42.60ms +step:732/1555 train_time:31208ms step_avg:42.63ms +step:733/1555 train_time:31266ms step_avg:42.65ms +step:734/1555 train_time:31330ms step_avg:42.68ms +step:735/1555 train_time:31388ms step_avg:42.70ms +step:736/1555 train_time:31452ms step_avg:42.73ms +step:737/1555 train_time:31510ms step_avg:42.75ms +step:738/1555 train_time:31574ms step_avg:42.78ms +step:739/1555 train_time:31632ms step_avg:42.80ms +step:740/1555 train_time:31696ms step_avg:42.83ms +step:741/1555 train_time:31753ms step_avg:42.85ms +step:742/1555 train_time:31818ms step_avg:42.88ms +step:743/1555 train_time:31876ms step_avg:42.90ms +step:744/1555 train_time:31941ms step_avg:42.93ms +step:745/1555 train_time:32000ms step_avg:42.95ms +step:746/1555 train_time:32064ms step_avg:42.98ms +step:747/1555 train_time:32122ms step_avg:43.00ms +step:748/1555 train_time:32186ms step_avg:43.03ms +step:749/1555 train_time:32244ms step_avg:43.05ms +step:750/1555 
train_time:32308ms step_avg:43.08ms +step:750/1555 val_loss:3.8666 train_time:32390ms step_avg:43.19ms +step:751/1555 train_time:32411ms step_avg:43.16ms +step:752/1555 train_time:32432ms step_avg:43.13ms +step:753/1555 train_time:32491ms step_avg:43.15ms +step:754/1555 train_time:32557ms step_avg:43.18ms +step:755/1555 train_time:32615ms step_avg:43.20ms +step:756/1555 train_time:32681ms step_avg:43.23ms +step:757/1555 train_time:32740ms step_avg:43.25ms +step:758/1555 train_time:32803ms step_avg:43.28ms +step:759/1555 train_time:32861ms step_avg:43.30ms +step:760/1555 train_time:32925ms step_avg:43.32ms +step:761/1555 train_time:32983ms step_avg:43.34ms +step:762/1555 train_time:33047ms step_avg:43.37ms +step:763/1555 train_time:33104ms step_avg:43.39ms +step:764/1555 train_time:33167ms step_avg:43.41ms +step:765/1555 train_time:33224ms step_avg:43.43ms +step:766/1555 train_time:33288ms step_avg:43.46ms +step:767/1555 train_time:33346ms step_avg:43.48ms +step:768/1555 train_time:33411ms step_avg:43.50ms +step:769/1555 train_time:33470ms step_avg:43.52ms +step:770/1555 train_time:33536ms step_avg:43.55ms +step:771/1555 train_time:33593ms step_avg:43.57ms +step:772/1555 train_time:33659ms step_avg:43.60ms +step:773/1555 train_time:33716ms step_avg:43.62ms +step:774/1555 train_time:33780ms step_avg:43.64ms +step:775/1555 train_time:33838ms step_avg:43.66ms +step:776/1555 train_time:33902ms step_avg:43.69ms +step:777/1555 train_time:33960ms step_avg:43.71ms +step:778/1555 train_time:34024ms step_avg:43.73ms +step:779/1555 train_time:34081ms step_avg:43.75ms +step:780/1555 train_time:34145ms step_avg:43.78ms +step:781/1555 train_time:34202ms step_avg:43.79ms +step:782/1555 train_time:34267ms step_avg:43.82ms +step:783/1555 train_time:34324ms step_avg:43.84ms +step:784/1555 train_time:34388ms step_avg:43.86ms +step:785/1555 train_time:34447ms step_avg:43.88ms +step:786/1555 train_time:34512ms step_avg:43.91ms +step:787/1555 train_time:34570ms step_avg:43.93ms 
+step:788/1555 train_time:34635ms step_avg:43.95ms +step:789/1555 train_time:34694ms step_avg:43.97ms +step:790/1555 train_time:34758ms step_avg:44.00ms +step:791/1555 train_time:34815ms step_avg:44.01ms +step:792/1555 train_time:34880ms step_avg:44.04ms +step:793/1555 train_time:34937ms step_avg:44.06ms +step:794/1555 train_time:35002ms step_avg:44.08ms +step:795/1555 train_time:35059ms step_avg:44.10ms +step:796/1555 train_time:35123ms step_avg:44.12ms +step:797/1555 train_time:35181ms step_avg:44.14ms +step:798/1555 train_time:35245ms step_avg:44.17ms +step:799/1555 train_time:35303ms step_avg:44.18ms +step:800/1555 train_time:35367ms step_avg:44.21ms +step:801/1555 train_time:35427ms step_avg:44.23ms +step:802/1555 train_time:35490ms step_avg:44.25ms +step:803/1555 train_time:35549ms step_avg:44.27ms +step:804/1555 train_time:35614ms step_avg:44.30ms +step:805/1555 train_time:35672ms step_avg:44.31ms +step:806/1555 train_time:35735ms step_avg:44.34ms +step:807/1555 train_time:35792ms step_avg:44.35ms +step:808/1555 train_time:35856ms step_avg:44.38ms +step:809/1555 train_time:35913ms step_avg:44.39ms +step:810/1555 train_time:35978ms step_avg:44.42ms +step:811/1555 train_time:36035ms step_avg:44.43ms +step:812/1555 train_time:36101ms step_avg:44.46ms +step:813/1555 train_time:36159ms step_avg:44.48ms +step:814/1555 train_time:36223ms step_avg:44.50ms +step:815/1555 train_time:36282ms step_avg:44.52ms +step:816/1555 train_time:36347ms step_avg:44.54ms +step:817/1555 train_time:36404ms step_avg:44.56ms +step:818/1555 train_time:36468ms step_avg:44.58ms +step:819/1555 train_time:36526ms step_avg:44.60ms +step:820/1555 train_time:36590ms step_avg:44.62ms +step:821/1555 train_time:36649ms step_avg:44.64ms +step:822/1555 train_time:36712ms step_avg:44.66ms +step:823/1555 train_time:36770ms step_avg:44.68ms +step:824/1555 train_time:36835ms step_avg:44.70ms +step:825/1555 train_time:36893ms step_avg:44.72ms +step:826/1555 train_time:36958ms step_avg:44.74ms 
+step:827/1555 train_time:37014ms step_avg:44.76ms +step:828/1555 train_time:37080ms step_avg:44.78ms +step:829/1555 train_time:37137ms step_avg:44.80ms +step:830/1555 train_time:37200ms step_avg:44.82ms +step:831/1555 train_time:37258ms step_avg:44.84ms +step:832/1555 train_time:37323ms step_avg:44.86ms +step:833/1555 train_time:37380ms step_avg:44.87ms +step:834/1555 train_time:37445ms step_avg:44.90ms +step:835/1555 train_time:37504ms step_avg:44.91ms +step:836/1555 train_time:37568ms step_avg:44.94ms +step:837/1555 train_time:37627ms step_avg:44.95ms +step:838/1555 train_time:37691ms step_avg:44.98ms +step:839/1555 train_time:37749ms step_avg:44.99ms +step:840/1555 train_time:37813ms step_avg:45.02ms +step:841/1555 train_time:37870ms step_avg:45.03ms +step:842/1555 train_time:37934ms step_avg:45.05ms +step:843/1555 train_time:37992ms step_avg:45.07ms +step:844/1555 train_time:38056ms step_avg:45.09ms +step:845/1555 train_time:38113ms step_avg:45.10ms +step:846/1555 train_time:38178ms step_avg:45.13ms +step:847/1555 train_time:38236ms step_avg:45.14ms +step:848/1555 train_time:38301ms step_avg:45.17ms +step:849/1555 train_time:38360ms step_avg:45.18ms +step:850/1555 train_time:38424ms step_avg:45.20ms +step:851/1555 train_time:38482ms step_avg:45.22ms +step:852/1555 train_time:38546ms step_avg:45.24ms +step:853/1555 train_time:38604ms step_avg:45.26ms +step:854/1555 train_time:38668ms step_avg:45.28ms +step:855/1555 train_time:38726ms step_avg:45.29ms +step:856/1555 train_time:38791ms step_avg:45.32ms +step:857/1555 train_time:38850ms step_avg:45.33ms +step:858/1555 train_time:38915ms step_avg:45.36ms +step:859/1555 train_time:38972ms step_avg:45.37ms +step:860/1555 train_time:39036ms step_avg:45.39ms +step:861/1555 train_time:39093ms step_avg:45.40ms +step:862/1555 train_time:39158ms step_avg:45.43ms +step:863/1555 train_time:39216ms step_avg:45.44ms +step:864/1555 train_time:39280ms step_avg:45.46ms +step:865/1555 train_time:39338ms step_avg:45.48ms 
+step:866/1555 train_time:39402ms step_avg:45.50ms +step:867/1555 train_time:39460ms step_avg:45.51ms +step:868/1555 train_time:39526ms step_avg:45.54ms +step:869/1555 train_time:39584ms step_avg:45.55ms +step:870/1555 train_time:39648ms step_avg:45.57ms +step:871/1555 train_time:39706ms step_avg:45.59ms +step:872/1555 train_time:39770ms step_avg:45.61ms +step:873/1555 train_time:39829ms step_avg:45.62ms +step:874/1555 train_time:39892ms step_avg:45.64ms +step:875/1555 train_time:39950ms step_avg:45.66ms +step:876/1555 train_time:40014ms step_avg:45.68ms +step:877/1555 train_time:40072ms step_avg:45.69ms +step:878/1555 train_time:40136ms step_avg:45.71ms +step:879/1555 train_time:40193ms step_avg:45.73ms +step:880/1555 train_time:40257ms step_avg:45.75ms +step:881/1555 train_time:40314ms step_avg:45.76ms +step:882/1555 train_time:40379ms step_avg:45.78ms +step:883/1555 train_time:40437ms step_avg:45.79ms +step:884/1555 train_time:40502ms step_avg:45.82ms +step:885/1555 train_time:40561ms step_avg:45.83ms +step:886/1555 train_time:40626ms step_avg:45.85ms +step:887/1555 train_time:40683ms step_avg:45.87ms +step:888/1555 train_time:40747ms step_avg:45.89ms +step:889/1555 train_time:40805ms step_avg:45.90ms +step:890/1555 train_time:40869ms step_avg:45.92ms +step:891/1555 train_time:40927ms step_avg:45.93ms +step:892/1555 train_time:40992ms step_avg:45.96ms +step:893/1555 train_time:41051ms step_avg:45.97ms +step:894/1555 train_time:41114ms step_avg:45.99ms +step:895/1555 train_time:41171ms step_avg:46.00ms +step:896/1555 train_time:41236ms step_avg:46.02ms +step:897/1555 train_time:41292ms step_avg:46.03ms +step:898/1555 train_time:41357ms step_avg:46.06ms +step:899/1555 train_time:41415ms step_avg:46.07ms +step:900/1555 train_time:41481ms step_avg:46.09ms +step:901/1555 train_time:41539ms step_avg:46.10ms +step:902/1555 train_time:41603ms step_avg:46.12ms +step:903/1555 train_time:41661ms step_avg:46.14ms +step:904/1555 train_time:41725ms step_avg:46.16ms 
+step:905/1555 train_time:41784ms step_avg:46.17ms +step:906/1555 train_time:41848ms step_avg:46.19ms +step:907/1555 train_time:41905ms step_avg:46.20ms +step:908/1555 train_time:41969ms step_avg:46.22ms +step:909/1555 train_time:42028ms step_avg:46.24ms +step:910/1555 train_time:42092ms step_avg:46.26ms +step:911/1555 train_time:42151ms step_avg:46.27ms +step:912/1555 train_time:42213ms step_avg:46.29ms +step:913/1555 train_time:42271ms step_avg:46.30ms +step:914/1555 train_time:42335ms step_avg:46.32ms +step:915/1555 train_time:42392ms step_avg:46.33ms +step:916/1555 train_time:42458ms step_avg:46.35ms +step:917/1555 train_time:42515ms step_avg:46.36ms +step:918/1555 train_time:42581ms step_avg:46.38ms +step:919/1555 train_time:42638ms step_avg:46.40ms +step:920/1555 train_time:42703ms step_avg:46.42ms +step:921/1555 train_time:42761ms step_avg:46.43ms +step:922/1555 train_time:42827ms step_avg:46.45ms +step:923/1555 train_time:42884ms step_avg:46.46ms +step:924/1555 train_time:42949ms step_avg:46.48ms +step:925/1555 train_time:43007ms step_avg:46.49ms +step:926/1555 train_time:43071ms step_avg:46.51ms +step:927/1555 train_time:43128ms step_avg:46.52ms +step:928/1555 train_time:43192ms step_avg:46.54ms +step:929/1555 train_time:43251ms step_avg:46.56ms +step:930/1555 train_time:43315ms step_avg:46.58ms +step:931/1555 train_time:43372ms step_avg:46.59ms +step:932/1555 train_time:43436ms step_avg:46.61ms +step:933/1555 train_time:43494ms step_avg:46.62ms +step:934/1555 train_time:43559ms step_avg:46.64ms +step:935/1555 train_time:43617ms step_avg:46.65ms +step:936/1555 train_time:43682ms step_avg:46.67ms +step:937/1555 train_time:43740ms step_avg:46.68ms +step:938/1555 train_time:43804ms step_avg:46.70ms +step:939/1555 train_time:43862ms step_avg:46.71ms +step:940/1555 train_time:43926ms step_avg:46.73ms +step:941/1555 train_time:43985ms step_avg:46.74ms +step:942/1555 train_time:44049ms step_avg:46.76ms +step:943/1555 train_time:44107ms step_avg:46.77ms 
+step:944/1555 train_time:44171ms step_avg:46.79ms +step:945/1555 train_time:44229ms step_avg:46.80ms +step:946/1555 train_time:44293ms step_avg:46.82ms +step:947/1555 train_time:44351ms step_avg:46.83ms +step:948/1555 train_time:44416ms step_avg:46.85ms +step:949/1555 train_time:44473ms step_avg:46.86ms +step:950/1555 train_time:44537ms step_avg:46.88ms +step:951/1555 train_time:44594ms step_avg:46.89ms +step:952/1555 train_time:44659ms step_avg:46.91ms +step:953/1555 train_time:44717ms step_avg:46.92ms +step:954/1555 train_time:44783ms step_avg:46.94ms +step:955/1555 train_time:44840ms step_avg:46.95ms +step:956/1555 train_time:44906ms step_avg:46.97ms +step:957/1555 train_time:44963ms step_avg:46.98ms +step:958/1555 train_time:45026ms step_avg:47.00ms +step:959/1555 train_time:45085ms step_avg:47.01ms +step:960/1555 train_time:45149ms step_avg:47.03ms +step:961/1555 train_time:45206ms step_avg:47.04ms +step:962/1555 train_time:45270ms step_avg:47.06ms +step:963/1555 train_time:45329ms step_avg:47.07ms +step:964/1555 train_time:45393ms step_avg:47.09ms +step:965/1555 train_time:45451ms step_avg:47.10ms +step:966/1555 train_time:45514ms step_avg:47.12ms +step:967/1555 train_time:45571ms step_avg:47.13ms +step:968/1555 train_time:45636ms step_avg:47.14ms +step:969/1555 train_time:45694ms step_avg:47.16ms +step:970/1555 train_time:45759ms step_avg:47.17ms +step:971/1555 train_time:45816ms step_avg:47.18ms +step:972/1555 train_time:45881ms step_avg:47.20ms +step:973/1555 train_time:45939ms step_avg:47.21ms +step:974/1555 train_time:46004ms step_avg:47.23ms +step:975/1555 train_time:46062ms step_avg:47.24ms +step:976/1555 train_time:46127ms step_avg:47.26ms +step:977/1555 train_time:46184ms step_avg:47.27ms +step:978/1555 train_time:46249ms step_avg:47.29ms +step:979/1555 train_time:46307ms step_avg:47.30ms +step:980/1555 train_time:46371ms step_avg:47.32ms +step:981/1555 train_time:46429ms step_avg:47.33ms +step:982/1555 train_time:46492ms step_avg:47.34ms 
+step:983/1555 train_time:46551ms step_avg:47.36ms +step:984/1555 train_time:46614ms step_avg:47.37ms +step:985/1555 train_time:46672ms step_avg:47.38ms +step:986/1555 train_time:46736ms step_avg:47.40ms +step:987/1555 train_time:46794ms step_avg:47.41ms +step:988/1555 train_time:46860ms step_avg:47.43ms +step:989/1555 train_time:46918ms step_avg:47.44ms +step:990/1555 train_time:46983ms step_avg:47.46ms +step:991/1555 train_time:47041ms step_avg:47.47ms +step:992/1555 train_time:47105ms step_avg:47.49ms +step:993/1555 train_time:47162ms step_avg:47.49ms +step:994/1555 train_time:47227ms step_avg:47.51ms +step:995/1555 train_time:47285ms step_avg:47.52ms +step:996/1555 train_time:47350ms step_avg:47.54ms +step:997/1555 train_time:47407ms step_avg:47.55ms +step:998/1555 train_time:47471ms step_avg:47.57ms +step:999/1555 train_time:47530ms step_avg:47.58ms +step:1000/1555 train_time:47593ms step_avg:47.59ms +step:1000/1555 val_loss:3.5703 train_time:47675ms step_avg:47.68ms +step:1001/1555 train_time:47698ms step_avg:47.65ms +step:1002/1555 train_time:47720ms step_avg:47.62ms +step:1003/1555 train_time:47773ms step_avg:47.63ms +step:1004/1555 train_time:47841ms step_avg:47.65ms +step:1005/1555 train_time:47902ms step_avg:47.66ms +step:1006/1555 train_time:47966ms step_avg:47.68ms +step:1007/1555 train_time:48023ms step_avg:47.69ms +step:1008/1555 train_time:48086ms step_avg:47.70ms +step:1009/1555 train_time:48143ms step_avg:47.71ms +step:1010/1555 train_time:48206ms step_avg:47.73ms +step:1011/1555 train_time:48267ms step_avg:47.74ms +step:1012/1555 train_time:48352ms step_avg:47.78ms +step:1013/1555 train_time:48434ms step_avg:47.81ms +step:1014/1555 train_time:48524ms step_avg:47.85ms +step:1015/1555 train_time:48609ms step_avg:47.89ms +step:1016/1555 train_time:48700ms step_avg:47.93ms +step:1017/1555 train_time:48786ms step_avg:47.97ms +step:1018/1555 train_time:48877ms step_avg:48.01ms +step:1019/1555 train_time:48964ms step_avg:48.05ms +step:1020/1555 
train_time:49054ms step_avg:48.09ms +step:1021/1555 train_time:49138ms step_avg:48.13ms +step:1022/1555 train_time:49229ms step_avg:48.17ms +step:1023/1555 train_time:49313ms step_avg:48.20ms +step:1024/1555 train_time:49401ms step_avg:48.24ms +step:1025/1555 train_time:49484ms step_avg:48.28ms +step:1026/1555 train_time:49573ms step_avg:48.32ms +step:1027/1555 train_time:49657ms step_avg:48.35ms +step:1028/1555 train_time:49749ms step_avg:48.39ms +step:1029/1555 train_time:49834ms step_avg:48.43ms +step:1030/1555 train_time:49925ms step_avg:48.47ms +step:1031/1555 train_time:50010ms step_avg:48.51ms +step:1032/1555 train_time:50100ms step_avg:48.55ms +step:1033/1555 train_time:50184ms step_avg:48.58ms +step:1034/1555 train_time:50273ms step_avg:48.62ms +step:1035/1555 train_time:50356ms step_avg:48.65ms +step:1036/1555 train_time:50445ms step_avg:48.69ms +step:1037/1555 train_time:50529ms step_avg:48.73ms +step:1038/1555 train_time:50618ms step_avg:48.77ms +step:1039/1555 train_time:50704ms step_avg:48.80ms +step:1040/1555 train_time:50794ms step_avg:48.84ms +step:1041/1555 train_time:50878ms step_avg:48.87ms +step:1042/1555 train_time:50970ms step_avg:48.92ms +step:1043/1555 train_time:51052ms step_avg:48.95ms +step:1044/1555 train_time:51142ms step_avg:48.99ms +step:1045/1555 train_time:51228ms step_avg:49.02ms +step:1046/1555 train_time:51317ms step_avg:49.06ms +step:1047/1555 train_time:51402ms step_avg:49.09ms +step:1048/1555 train_time:51493ms step_avg:49.13ms +step:1049/1555 train_time:51576ms step_avg:49.17ms +step:1050/1555 train_time:51666ms step_avg:49.21ms +step:1051/1555 train_time:51751ms step_avg:49.24ms +step:1052/1555 train_time:51841ms step_avg:49.28ms +step:1053/1555 train_time:51927ms step_avg:49.31ms +step:1054/1555 train_time:52017ms step_avg:49.35ms +step:1055/1555 train_time:52102ms step_avg:49.39ms +step:1056/1555 train_time:52193ms step_avg:49.42ms +step:1057/1555 train_time:52276ms step_avg:49.46ms +step:1058/1555 train_time:52366ms 
step_avg:49.50ms +step:1059/1555 train_time:52451ms step_avg:49.53ms +step:1060/1555 train_time:52540ms step_avg:49.57ms +step:1061/1555 train_time:52625ms step_avg:49.60ms +step:1062/1555 train_time:52714ms step_avg:49.64ms +step:1063/1555 train_time:52798ms step_avg:49.67ms +step:1064/1555 train_time:52889ms step_avg:49.71ms +step:1065/1555 train_time:52972ms step_avg:49.74ms +step:1066/1555 train_time:53063ms step_avg:49.78ms +step:1067/1555 train_time:53147ms step_avg:49.81ms +step:1068/1555 train_time:53236ms step_avg:49.85ms +step:1069/1555 train_time:53321ms step_avg:49.88ms +step:1070/1555 train_time:53412ms step_avg:49.92ms +step:1071/1555 train_time:53495ms step_avg:49.95ms +step:1072/1555 train_time:53585ms step_avg:49.99ms +step:1073/1555 train_time:53668ms step_avg:50.02ms +step:1074/1555 train_time:53757ms step_avg:50.05ms +step:1075/1555 train_time:53842ms step_avg:50.09ms +step:1076/1555 train_time:53933ms step_avg:50.12ms +step:1077/1555 train_time:54016ms step_avg:50.15ms +step:1078/1555 train_time:54107ms step_avg:50.19ms +step:1079/1555 train_time:54191ms step_avg:50.22ms +step:1080/1555 train_time:54281ms step_avg:50.26ms +step:1081/1555 train_time:54365ms step_avg:50.29ms +step:1082/1555 train_time:54455ms step_avg:50.33ms +step:1083/1555 train_time:54539ms step_avg:50.36ms +step:1084/1555 train_time:54631ms step_avg:50.40ms +step:1085/1555 train_time:54715ms step_avg:50.43ms +step:1086/1555 train_time:54804ms step_avg:50.46ms +step:1087/1555 train_time:54889ms step_avg:50.50ms +step:1088/1555 train_time:54978ms step_avg:50.53ms +step:1089/1555 train_time:55062ms step_avg:50.56ms +step:1090/1555 train_time:55153ms step_avg:50.60ms +step:1091/1555 train_time:55237ms step_avg:50.63ms +step:1092/1555 train_time:55329ms step_avg:50.67ms +step:1093/1555 train_time:55412ms step_avg:50.70ms +step:1094/1555 train_time:55501ms step_avg:50.73ms +step:1095/1555 train_time:55587ms step_avg:50.76ms +step:1096/1555 train_time:55676ms step_avg:50.80ms 
+step:1097/1555 train_time:55760ms step_avg:50.83ms +step:1098/1555 train_time:55851ms step_avg:50.87ms +step:1099/1555 train_time:55935ms step_avg:50.90ms +step:1100/1555 train_time:56025ms step_avg:50.93ms +step:1101/1555 train_time:56109ms step_avg:50.96ms +step:1102/1555 train_time:56199ms step_avg:51.00ms +step:1103/1555 train_time:56283ms step_avg:51.03ms +step:1104/1555 train_time:56372ms step_avg:51.06ms +step:1105/1555 train_time:56455ms step_avg:51.09ms +step:1106/1555 train_time:56546ms step_avg:51.13ms +step:1107/1555 train_time:56631ms step_avg:51.16ms +step:1108/1555 train_time:56721ms step_avg:51.19ms +step:1109/1555 train_time:56805ms step_avg:51.22ms +step:1110/1555 train_time:56895ms step_avg:51.26ms +step:1111/1555 train_time:56979ms step_avg:51.29ms +step:1112/1555 train_time:57070ms step_avg:51.32ms +step:1113/1555 train_time:57154ms step_avg:51.35ms +step:1114/1555 train_time:57244ms step_avg:51.39ms +step:1115/1555 train_time:57328ms step_avg:51.42ms +step:1116/1555 train_time:57417ms step_avg:51.45ms +step:1117/1555 train_time:57502ms step_avg:51.48ms +step:1118/1555 train_time:57593ms step_avg:51.51ms +step:1119/1555 train_time:57676ms step_avg:51.54ms +step:1120/1555 train_time:57767ms step_avg:51.58ms +step:1121/1555 train_time:57851ms step_avg:51.61ms +step:1122/1555 train_time:57940ms step_avg:51.64ms +step:1123/1555 train_time:58025ms step_avg:51.67ms +step:1124/1555 train_time:58116ms step_avg:51.70ms +step:1125/1555 train_time:58199ms step_avg:51.73ms +step:1126/1555 train_time:58290ms step_avg:51.77ms +step:1127/1555 train_time:58372ms step_avg:51.79ms +step:1128/1555 train_time:58463ms step_avg:51.83ms +step:1129/1555 train_time:58548ms step_avg:51.86ms +step:1130/1555 train_time:58637ms step_avg:51.89ms +step:1131/1555 train_time:58722ms step_avg:51.92ms +step:1132/1555 train_time:58812ms step_avg:51.95ms +step:1133/1555 train_time:58896ms step_avg:51.98ms +step:1134/1555 train_time:58987ms step_avg:52.02ms +step:1135/1555 
train_time:59071ms step_avg:52.04ms +step:1136/1555 train_time:59162ms step_avg:52.08ms +step:1137/1555 train_time:59246ms step_avg:52.11ms +step:1138/1555 train_time:59334ms step_avg:52.14ms +step:1139/1555 train_time:59418ms step_avg:52.17ms +step:1140/1555 train_time:59509ms step_avg:52.20ms +step:1141/1555 train_time:59593ms step_avg:52.23ms +step:1142/1555 train_time:59683ms step_avg:52.26ms +step:1143/1555 train_time:59767ms step_avg:52.29ms +step:1144/1555 train_time:59857ms step_avg:52.32ms +step:1145/1555 train_time:59941ms step_avg:52.35ms +step:1146/1555 train_time:60032ms step_avg:52.38ms +step:1147/1555 train_time:60115ms step_avg:52.41ms +step:1148/1555 train_time:60206ms step_avg:52.44ms +step:1149/1555 train_time:60291ms step_avg:52.47ms +step:1150/1555 train_time:60381ms step_avg:52.51ms +step:1151/1555 train_time:60464ms step_avg:52.53ms +step:1152/1555 train_time:60554ms step_avg:52.56ms +step:1153/1555 train_time:60637ms step_avg:52.59ms +step:1154/1555 train_time:60728ms step_avg:52.62ms +step:1155/1555 train_time:60812ms step_avg:52.65ms +step:1156/1555 train_time:60902ms step_avg:52.68ms +step:1157/1555 train_time:60987ms step_avg:52.71ms +step:1158/1555 train_time:61076ms step_avg:52.74ms +step:1159/1555 train_time:61161ms step_avg:52.77ms +step:1160/1555 train_time:61251ms step_avg:52.80ms +step:1161/1555 train_time:61335ms step_avg:52.83ms +step:1162/1555 train_time:61424ms step_avg:52.86ms +step:1163/1555 train_time:61509ms step_avg:52.89ms +step:1164/1555 train_time:61598ms step_avg:52.92ms +step:1165/1555 train_time:61684ms step_avg:52.95ms +step:1166/1555 train_time:61773ms step_avg:52.98ms +step:1167/1555 train_time:61857ms step_avg:53.01ms +step:1168/1555 train_time:61948ms step_avg:53.04ms +step:1169/1555 train_time:62031ms step_avg:53.06ms +step:1170/1555 train_time:62121ms step_avg:53.09ms +step:1171/1555 train_time:62205ms step_avg:53.12ms +step:1172/1555 train_time:62295ms step_avg:53.15ms +step:1173/1555 train_time:62379ms 
step_avg:53.18ms +step:1174/1555 train_time:62469ms step_avg:53.21ms +step:1175/1555 train_time:62553ms step_avg:53.24ms +step:1176/1555 train_time:62643ms step_avg:53.27ms +step:1177/1555 train_time:62727ms step_avg:53.29ms +step:1178/1555 train_time:62817ms step_avg:53.33ms +step:1179/1555 train_time:62902ms step_avg:53.35ms +step:1180/1555 train_time:62993ms step_avg:53.38ms +step:1181/1555 train_time:63076ms step_avg:53.41ms +step:1182/1555 train_time:63167ms step_avg:53.44ms +step:1183/1555 train_time:63251ms step_avg:53.47ms +step:1184/1555 train_time:63341ms step_avg:53.50ms +step:1185/1555 train_time:63425ms step_avg:53.52ms +step:1186/1555 train_time:63514ms step_avg:53.55ms +step:1187/1555 train_time:63598ms step_avg:53.58ms +step:1188/1555 train_time:63689ms step_avg:53.61ms +step:1189/1555 train_time:63772ms step_avg:53.64ms +step:1190/1555 train_time:63863ms step_avg:53.67ms +step:1191/1555 train_time:63949ms step_avg:53.69ms +step:1192/1555 train_time:64038ms step_avg:53.72ms +step:1193/1555 train_time:64123ms step_avg:53.75ms +step:1194/1555 train_time:64213ms step_avg:53.78ms +step:1195/1555 train_time:64297ms step_avg:53.80ms +step:1196/1555 train_time:64389ms step_avg:53.84ms +step:1197/1555 train_time:64472ms step_avg:53.86ms +step:1198/1555 train_time:64563ms step_avg:53.89ms +step:1199/1555 train_time:64646ms step_avg:53.92ms +step:1200/1555 train_time:64736ms step_avg:53.95ms +step:1201/1555 train_time:64820ms step_avg:53.97ms +step:1202/1555 train_time:64911ms step_avg:54.00ms +step:1203/1555 train_time:64994ms step_avg:54.03ms +step:1204/1555 train_time:65085ms step_avg:54.06ms +step:1205/1555 train_time:65169ms step_avg:54.08ms +step:1206/1555 train_time:65259ms step_avg:54.11ms +step:1207/1555 train_time:65342ms step_avg:54.14ms +step:1208/1555 train_time:65433ms step_avg:54.17ms +step:1209/1555 train_time:65516ms step_avg:54.19ms +step:1210/1555 train_time:65608ms step_avg:54.22ms +step:1211/1555 train_time:65691ms step_avg:54.25ms 
+step:1212/1555 train_time:65783ms step_avg:54.28ms +step:1213/1555 train_time:65866ms step_avg:54.30ms +step:1214/1555 train_time:65957ms step_avg:54.33ms +step:1215/1555 train_time:66039ms step_avg:54.35ms +step:1216/1555 train_time:66129ms step_avg:54.38ms +step:1217/1555 train_time:66213ms step_avg:54.41ms +step:1218/1555 train_time:66304ms step_avg:54.44ms +step:1219/1555 train_time:66389ms step_avg:54.46ms +step:1220/1555 train_time:66478ms step_avg:54.49ms +step:1221/1555 train_time:66562ms step_avg:54.51ms +step:1222/1555 train_time:66652ms step_avg:54.54ms +step:1223/1555 train_time:66736ms step_avg:54.57ms +step:1224/1555 train_time:66826ms step_avg:54.60ms +step:1225/1555 train_time:66910ms step_avg:54.62ms +step:1226/1555 train_time:66999ms step_avg:54.65ms +step:1227/1555 train_time:67085ms step_avg:54.67ms +step:1228/1555 train_time:67173ms step_avg:54.70ms +step:1229/1555 train_time:67258ms step_avg:54.73ms +step:1230/1555 train_time:67348ms step_avg:54.75ms +step:1231/1555 train_time:67432ms step_avg:54.78ms +step:1232/1555 train_time:67522ms step_avg:54.81ms +step:1233/1555 train_time:67606ms step_avg:54.83ms +step:1234/1555 train_time:67695ms step_avg:54.86ms +step:1235/1555 train_time:67781ms step_avg:54.88ms +step:1236/1555 train_time:67872ms step_avg:54.91ms +step:1237/1555 train_time:67956ms step_avg:54.94ms +step:1238/1555 train_time:68047ms step_avg:54.97ms +step:1239/1555 train_time:68131ms step_avg:54.99ms +step:1240/1555 train_time:68221ms step_avg:55.02ms +step:1241/1555 train_time:68306ms step_avg:55.04ms +step:1242/1555 train_time:68395ms step_avg:55.07ms +step:1243/1555 train_time:68478ms step_avg:55.09ms +step:1244/1555 train_time:68569ms step_avg:55.12ms +step:1245/1555 train_time:68654ms step_avg:55.14ms +step:1246/1555 train_time:68743ms step_avg:55.17ms +step:1247/1555 train_time:68829ms step_avg:55.20ms +step:1248/1555 train_time:68917ms step_avg:55.22ms +step:1249/1555 train_time:69002ms step_avg:55.25ms +step:1250/1555 
train_time:69093ms step_avg:55.27ms +step:1250/1555 val_loss:3.3964 train_time:69208ms step_avg:55.37ms +step:1251/1555 train_time:69229ms step_avg:55.34ms +step:1252/1555 train_time:69270ms step_avg:55.33ms +step:1253/1555 train_time:69354ms step_avg:55.35ms +step:1254/1555 train_time:69447ms step_avg:55.38ms +step:1255/1555 train_time:69531ms step_avg:55.40ms +step:1256/1555 train_time:69621ms step_avg:55.43ms +step:1257/1555 train_time:69705ms step_avg:55.45ms +step:1258/1555 train_time:69795ms step_avg:55.48ms +step:1259/1555 train_time:69878ms step_avg:55.50ms +step:1260/1555 train_time:69968ms step_avg:55.53ms +step:1261/1555 train_time:70051ms step_avg:55.55ms +step:1262/1555 train_time:70141ms step_avg:55.58ms +step:1263/1555 train_time:70228ms step_avg:55.60ms +step:1264/1555 train_time:70318ms step_avg:55.63ms +step:1265/1555 train_time:70405ms step_avg:55.66ms +step:1266/1555 train_time:70497ms step_avg:55.69ms +step:1267/1555 train_time:70582ms step_avg:55.71ms +step:1268/1555 train_time:70672ms step_avg:55.73ms +step:1269/1555 train_time:70755ms step_avg:55.76ms +step:1270/1555 train_time:70845ms step_avg:55.78ms +step:1271/1555 train_time:70928ms step_avg:55.81ms +step:1272/1555 train_time:71017ms step_avg:55.83ms +step:1273/1555 train_time:71100ms step_avg:55.85ms +step:1274/1555 train_time:71192ms step_avg:55.88ms +step:1275/1555 train_time:71276ms step_avg:55.90ms +step:1276/1555 train_time:71368ms step_avg:55.93ms +step:1277/1555 train_time:71454ms step_avg:55.95ms +step:1278/1555 train_time:71544ms step_avg:55.98ms +step:1279/1555 train_time:71627ms step_avg:56.00ms +step:1280/1555 train_time:71717ms step_avg:56.03ms +step:1281/1555 train_time:71800ms step_avg:56.05ms +step:1282/1555 train_time:71892ms step_avg:56.08ms +step:1283/1555 train_time:71973ms step_avg:56.10ms +step:1284/1555 train_time:72064ms step_avg:56.12ms +step:1285/1555 train_time:72147ms step_avg:56.15ms +step:1286/1555 train_time:72238ms step_avg:56.17ms +step:1287/1555 
train_time:72323ms step_avg:56.20ms +step:1288/1555 train_time:72414ms step_avg:56.22ms +step:1289/1555 train_time:72498ms step_avg:56.24ms +step:1290/1555 train_time:72590ms step_avg:56.27ms +step:1291/1555 train_time:72672ms step_avg:56.29ms +step:1292/1555 train_time:72762ms step_avg:56.32ms +step:1293/1555 train_time:72846ms step_avg:56.34ms +step:1294/1555 train_time:72936ms step_avg:56.36ms +step:1295/1555 train_time:73021ms step_avg:56.39ms +step:1296/1555 train_time:73111ms step_avg:56.41ms +step:1297/1555 train_time:73195ms step_avg:56.43ms +step:1298/1555 train_time:73285ms step_avg:56.46ms +step:1299/1555 train_time:73369ms step_avg:56.48ms +step:1300/1555 train_time:73459ms step_avg:56.51ms +step:1301/1555 train_time:73544ms step_avg:56.53ms +step:1302/1555 train_time:73634ms step_avg:56.55ms +step:1303/1555 train_time:73717ms step_avg:56.58ms +step:1304/1555 train_time:73807ms step_avg:56.60ms +step:1305/1555 train_time:73891ms step_avg:56.62ms +step:1306/1555 train_time:73980ms step_avg:56.65ms +step:1307/1555 train_time:74065ms step_avg:56.67ms +step:1308/1555 train_time:74155ms step_avg:56.69ms +step:1309/1555 train_time:74238ms step_avg:56.71ms +step:1310/1555 train_time:74331ms step_avg:56.74ms +step:1311/1555 train_time:74413ms step_avg:56.76ms +step:1312/1555 train_time:74504ms step_avg:56.79ms +step:1313/1555 train_time:74589ms step_avg:56.81ms +step:1314/1555 train_time:74678ms step_avg:56.83ms +step:1315/1555 train_time:74763ms step_avg:56.85ms +step:1316/1555 train_time:74853ms step_avg:56.88ms +step:1317/1555 train_time:74936ms step_avg:56.90ms +step:1318/1555 train_time:75026ms step_avg:56.92ms +step:1319/1555 train_time:75110ms step_avg:56.94ms +step:1320/1555 train_time:75201ms step_avg:56.97ms +step:1321/1555 train_time:75285ms step_avg:56.99ms +step:1322/1555 train_time:75376ms step_avg:57.02ms +step:1323/1555 train_time:75461ms step_avg:57.04ms +step:1324/1555 train_time:75551ms step_avg:57.06ms +step:1325/1555 train_time:75634ms 
step_avg:57.08ms +step:1326/1555 train_time:75724ms step_avg:57.11ms +step:1327/1555 train_time:75808ms step_avg:57.13ms +step:1328/1555 train_time:75898ms step_avg:57.15ms +step:1329/1555 train_time:75981ms step_avg:57.17ms +step:1330/1555 train_time:76071ms step_avg:57.20ms +step:1331/1555 train_time:76155ms step_avg:57.22ms +step:1332/1555 train_time:76245ms step_avg:57.24ms +step:1333/1555 train_time:76329ms step_avg:57.26ms +step:1334/1555 train_time:76418ms step_avg:57.29ms +step:1335/1555 train_time:76504ms step_avg:57.31ms +step:1336/1555 train_time:76594ms step_avg:57.33ms +step:1337/1555 train_time:76676ms step_avg:57.35ms +step:1338/1555 train_time:76768ms step_avg:57.37ms +step:1339/1555 train_time:76852ms step_avg:57.40ms +step:1340/1555 train_time:76941ms step_avg:57.42ms +step:1341/1555 train_time:77025ms step_avg:57.44ms +step:1342/1555 train_time:77116ms step_avg:57.46ms +step:1343/1555 train_time:77199ms step_avg:57.48ms +step:1344/1555 train_time:77290ms step_avg:57.51ms +step:1345/1555 train_time:77373ms step_avg:57.53ms +step:1346/1555 train_time:77463ms step_avg:57.55ms +step:1347/1555 train_time:77548ms step_avg:57.57ms +step:1348/1555 train_time:77637ms step_avg:57.59ms +step:1349/1555 train_time:77722ms step_avg:57.61ms +step:1350/1555 train_time:77814ms step_avg:57.64ms +step:1351/1555 train_time:77897ms step_avg:57.66ms +step:1352/1555 train_time:77987ms step_avg:57.68ms +step:1353/1555 train_time:78071ms step_avg:57.70ms +step:1354/1555 train_time:78160ms step_avg:57.73ms +step:1355/1555 train_time:78244ms step_avg:57.74ms +step:1356/1555 train_time:78334ms step_avg:57.77ms +step:1357/1555 train_time:78419ms step_avg:57.79ms +step:1358/1555 train_time:78510ms step_avg:57.81ms +step:1359/1555 train_time:78593ms step_avg:57.83ms +step:1360/1555 train_time:78684ms step_avg:57.86ms +step:1361/1555 train_time:78768ms step_avg:57.87ms +step:1362/1555 train_time:78858ms step_avg:57.90ms +step:1363/1555 train_time:78941ms step_avg:57.92ms 
+step:1364/1555 train_time:79032ms step_avg:57.94ms +step:1365/1555 train_time:79116ms step_avg:57.96ms +step:1366/1555 train_time:79207ms step_avg:57.98ms +step:1367/1555 train_time:79291ms step_avg:58.00ms +step:1368/1555 train_time:79379ms step_avg:58.03ms +step:1369/1555 train_time:79464ms step_avg:58.05ms +step:1370/1555 train_time:79555ms step_avg:58.07ms +step:1371/1555 train_time:79638ms step_avg:58.09ms +step:1372/1555 train_time:79729ms step_avg:58.11ms +step:1373/1555 train_time:79813ms step_avg:58.13ms +step:1374/1555 train_time:79903ms step_avg:58.15ms +step:1375/1555 train_time:79988ms step_avg:58.17ms +step:1376/1555 train_time:80077ms step_avg:58.20ms +step:1377/1555 train_time:80161ms step_avg:58.21ms +step:1378/1555 train_time:80252ms step_avg:58.24ms +step:1379/1555 train_time:80336ms step_avg:58.26ms +step:1380/1555 train_time:80427ms step_avg:58.28ms +step:1381/1555 train_time:80511ms step_avg:58.30ms +step:1382/1555 train_time:80601ms step_avg:58.32ms +step:1383/1555 train_time:80684ms step_avg:58.34ms +step:1384/1555 train_time:80775ms step_avg:58.36ms +step:1385/1555 train_time:80859ms step_avg:58.38ms +step:1386/1555 train_time:80949ms step_avg:58.40ms +step:1387/1555 train_time:81033ms step_avg:58.42ms +step:1388/1555 train_time:81124ms step_avg:58.45ms +step:1389/1555 train_time:81208ms step_avg:58.47ms +step:1390/1555 train_time:81300ms step_avg:58.49ms +step:1391/1555 train_time:81382ms step_avg:58.51ms +step:1392/1555 train_time:81473ms step_avg:58.53ms +step:1393/1555 train_time:81556ms step_avg:58.55ms +step:1394/1555 train_time:81646ms step_avg:58.57ms +step:1395/1555 train_time:81732ms step_avg:58.59ms +step:1396/1555 train_time:81821ms step_avg:58.61ms +step:1397/1555 train_time:81906ms step_avg:58.63ms +step:1398/1555 train_time:81998ms step_avg:58.65ms +step:1399/1555 train_time:82081ms step_avg:58.67ms +step:1400/1555 train_time:82171ms step_avg:58.69ms +step:1401/1555 train_time:82255ms step_avg:58.71ms +step:1402/1555 
train_time:82346ms step_avg:58.73ms +step:1403/1555 train_time:82431ms step_avg:58.75ms +step:1404/1555 train_time:82519ms step_avg:58.77ms +step:1405/1555 train_time:82603ms step_avg:58.79ms +step:1406/1555 train_time:82693ms step_avg:58.81ms +step:1407/1555 train_time:82776ms step_avg:58.83ms +step:1408/1555 train_time:82867ms step_avg:58.85ms +step:1409/1555 train_time:82951ms step_avg:58.87ms +step:1410/1555 train_time:83041ms step_avg:58.89ms +step:1411/1555 train_time:83126ms step_avg:58.91ms +step:1412/1555 train_time:83216ms step_avg:58.93ms +step:1413/1555 train_time:83300ms step_avg:58.95ms +step:1414/1555 train_time:83392ms step_avg:58.98ms +step:1415/1555 train_time:83474ms step_avg:58.99ms +step:1416/1555 train_time:83565ms step_avg:59.01ms +step:1417/1555 train_time:83649ms step_avg:59.03ms +step:1418/1555 train_time:83738ms step_avg:59.05ms +step:1419/1555 train_time:83823ms step_avg:59.07ms +step:1420/1555 train_time:83914ms step_avg:59.09ms +step:1421/1555 train_time:83997ms step_avg:59.11ms +step:1422/1555 train_time:84089ms step_avg:59.13ms +step:1423/1555 train_time:84172ms step_avg:59.15ms +step:1424/1555 train_time:84262ms step_avg:59.17ms +step:1425/1555 train_time:84347ms step_avg:59.19ms +step:1426/1555 train_time:84436ms step_avg:59.21ms +step:1427/1555 train_time:84520ms step_avg:59.23ms +step:1428/1555 train_time:84611ms step_avg:59.25ms +step:1429/1555 train_time:84695ms step_avg:59.27ms +step:1430/1555 train_time:84785ms step_avg:59.29ms +step:1431/1555 train_time:84869ms step_avg:59.31ms +step:1432/1555 train_time:84959ms step_avg:59.33ms +step:1433/1555 train_time:85043ms step_avg:59.35ms +step:1434/1555 train_time:85134ms step_avg:59.37ms +step:1435/1555 train_time:85218ms step_avg:59.39ms +step:1436/1555 train_time:85308ms step_avg:59.41ms +step:1437/1555 train_time:85392ms step_avg:59.42ms +step:1438/1555 train_time:85482ms step_avg:59.44ms +step:1439/1555 train_time:85567ms step_avg:59.46ms +step:1440/1555 train_time:85657ms 
step_avg:59.48ms +step:1441/1555 train_time:85740ms step_avg:59.50ms +step:1442/1555 train_time:85832ms step_avg:59.52ms +step:1443/1555 train_time:85915ms step_avg:59.54ms +step:1444/1555 train_time:86005ms step_avg:59.56ms +step:1445/1555 train_time:86090ms step_avg:59.58ms +step:1446/1555 train_time:86179ms step_avg:59.60ms +step:1447/1555 train_time:86265ms step_avg:59.62ms +step:1448/1555 train_time:86354ms step_avg:59.64ms +step:1449/1555 train_time:86438ms step_avg:59.65ms +step:1450/1555 train_time:86528ms step_avg:59.67ms +step:1451/1555 train_time:86613ms step_avg:59.69ms +step:1452/1555 train_time:86703ms step_avg:59.71ms +step:1453/1555 train_time:86786ms step_avg:59.73ms +step:1454/1555 train_time:86876ms step_avg:59.75ms +step:1455/1555 train_time:86960ms step_avg:59.77ms +step:1456/1555 train_time:87050ms step_avg:59.79ms +step:1457/1555 train_time:87134ms step_avg:59.80ms +step:1458/1555 train_time:87224ms step_avg:59.82ms +step:1459/1555 train_time:87309ms step_avg:59.84ms +step:1460/1555 train_time:87398ms step_avg:59.86ms +step:1461/1555 train_time:87482ms step_avg:59.88ms +step:1462/1555 train_time:87572ms step_avg:59.90ms +step:1463/1555 train_time:87656ms step_avg:59.92ms +step:1464/1555 train_time:87747ms step_avg:59.94ms +step:1465/1555 train_time:87832ms step_avg:59.95ms +step:1466/1555 train_time:87922ms step_avg:59.97ms +step:1467/1555 train_time:88006ms step_avg:59.99ms +step:1468/1555 train_time:88097ms step_avg:60.01ms +step:1469/1555 train_time:88181ms step_avg:60.03ms +step:1470/1555 train_time:88271ms step_avg:60.05ms +step:1471/1555 train_time:88356ms step_avg:60.07ms +step:1472/1555 train_time:88445ms step_avg:60.08ms +step:1473/1555 train_time:88529ms step_avg:60.10ms +step:1474/1555 train_time:88618ms step_avg:60.12ms +step:1475/1555 train_time:88704ms step_avg:60.14ms +step:1476/1555 train_time:88795ms step_avg:60.16ms +step:1477/1555 train_time:88878ms step_avg:60.17ms +step:1478/1555 train_time:88970ms step_avg:60.20ms 
+step:1479/1555 train_time:89053ms step_avg:60.21ms +step:1480/1555 train_time:89143ms step_avg:60.23ms +step:1481/1555 train_time:89229ms step_avg:60.25ms +step:1482/1555 train_time:89318ms step_avg:60.27ms +step:1483/1555 train_time:89403ms step_avg:60.28ms +step:1484/1555 train_time:89494ms step_avg:60.31ms +step:1485/1555 train_time:89576ms step_avg:60.32ms +step:1486/1555 train_time:89666ms step_avg:60.34ms +step:1487/1555 train_time:89751ms step_avg:60.36ms +step:1488/1555 train_time:89840ms step_avg:60.38ms +step:1489/1555 train_time:89925ms step_avg:60.39ms +step:1490/1555 train_time:90016ms step_avg:60.41ms +step:1491/1555 train_time:90099ms step_avg:60.43ms +step:1492/1555 train_time:90190ms step_avg:60.45ms +step:1493/1555 train_time:90274ms step_avg:60.46ms +step:1494/1555 train_time:90364ms step_avg:60.48ms +step:1495/1555 train_time:90448ms step_avg:60.50ms +step:1496/1555 train_time:90539ms step_avg:60.52ms +step:1497/1555 train_time:90623ms step_avg:60.54ms +step:1498/1555 train_time:90714ms step_avg:60.56ms +step:1499/1555 train_time:90798ms step_avg:60.57ms +step:1500/1555 train_time:90889ms step_avg:60.59ms +step:1500/1555 val_loss:3.2935 train_time:91004ms step_avg:60.67ms +step:1501/1555 train_time:91025ms step_avg:60.64ms +step:1502/1555 train_time:91064ms step_avg:60.63ms +step:1503/1555 train_time:91154ms step_avg:60.65ms +step:1504/1555 train_time:91245ms step_avg:60.67ms +step:1505/1555 train_time:91330ms step_avg:60.68ms +step:1506/1555 train_time:91420ms step_avg:60.70ms +step:1507/1555 train_time:91503ms step_avg:60.72ms +step:1508/1555 train_time:91591ms step_avg:60.74ms +step:1509/1555 train_time:91674ms step_avg:60.75ms +step:1510/1555 train_time:91763ms step_avg:60.77ms +step:1511/1555 train_time:91846ms step_avg:60.78ms +step:1512/1555 train_time:91936ms step_avg:60.80ms +step:1513/1555 train_time:92022ms step_avg:60.82ms +step:1514/1555 train_time:92115ms step_avg:60.84ms +step:1515/1555 train_time:92202ms step_avg:60.86ms 
+step:1516/1555 train_time:92297ms step_avg:60.88ms +step:1517/1555 train_time:92382ms step_avg:60.90ms +step:1518/1555 train_time:92471ms step_avg:60.92ms +step:1519/1555 train_time:92555ms step_avg:60.93ms +step:1520/1555 train_time:92644ms step_avg:60.95ms +step:1521/1555 train_time:92727ms step_avg:60.96ms +step:1522/1555 train_time:92817ms step_avg:60.98ms +step:1523/1555 train_time:92901ms step_avg:61.00ms +step:1524/1555 train_time:92991ms step_avg:61.02ms +step:1525/1555 train_time:93078ms step_avg:61.03ms +step:1526/1555 train_time:93169ms step_avg:61.05ms +step:1527/1555 train_time:93255ms step_avg:61.07ms +step:1528/1555 train_time:93346ms step_avg:61.09ms +step:1529/1555 train_time:93432ms step_avg:61.11ms +step:1530/1555 train_time:93522ms step_avg:61.13ms +step:1531/1555 train_time:93605ms step_avg:61.14ms +step:1532/1555 train_time:93695ms step_avg:61.16ms +step:1533/1555 train_time:93779ms step_avg:61.17ms +step:1534/1555 train_time:93868ms step_avg:61.19ms +step:1535/1555 train_time:93952ms step_avg:61.21ms +step:1536/1555 train_time:94044ms step_avg:61.23ms +step:1537/1555 train_time:94130ms step_avg:61.24ms +step:1538/1555 train_time:94222ms step_avg:61.26ms +step:1539/1555 train_time:94307ms step_avg:61.28ms +step:1540/1555 train_time:94399ms step_avg:61.30ms +step:1541/1555 train_time:94481ms step_avg:61.31ms +step:1542/1555 train_time:94571ms step_avg:61.33ms +step:1543/1555 train_time:94655ms step_avg:61.35ms +step:1544/1555 train_time:94746ms step_avg:61.36ms +step:1545/1555 train_time:94830ms step_avg:61.38ms +step:1546/1555 train_time:94920ms step_avg:61.40ms +step:1547/1555 train_time:95004ms step_avg:61.41ms +step:1548/1555 train_time:95095ms step_avg:61.43ms +step:1549/1555 train_time:95180ms step_avg:61.45ms +step:1550/1555 train_time:95271ms step_avg:61.47ms +step:1551/1555 train_time:95357ms step_avg:61.48ms +step:1552/1555 train_time:95446ms step_avg:61.50ms +step:1553/1555 train_time:95530ms step_avg:61.51ms +step:1554/1555 
train_time:95620ms step_avg:61.53ms +step:1555/1555 train_time:95704ms step_avg:61.55ms +step:1555/1555 val_loss:3.2769 train_time:95818ms step_avg:61.62ms +peak memory allocated: 30920 MiB reserved: 46638 MiB diff --git a/train_gpt.py b/train_gpt.py index 02edd0b68..4e955b2a2 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -1134,7 +1134,21 @@ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: i ) self.scalars.label = 'scalars' - def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, bigram_input_seq: Tensor, schedule_cfg: ForwardScheduleConfig): + @staticmethod + @torch.compile(dynamic=False, fullgraph=True) + def _compute_bigram_hash(x: Tensor, mod: int) -> Tensor: + """ + Computes bigram hash on GPU for each position using [prev_token, curr_token]. + Mathematically identical to the CPU version but computed on device. + """ + rand_int_1 = 36313 + rand_int_2 = 27191 + result = torch.empty_like(x) + result[0] = mod + result[1:] = torch.bitwise_xor(rand_int_1 * x[1:], rand_int_2 * x[:-1]) % mod + return result + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, schedule_cfg: ForwardScheduleConfig): assert input_seq.ndim == 1 # unpack schedule_cfg @@ -1163,7 +1177,9 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, bigram # Embedding lookup - embed is synced from lm_head during tied phase by optimizer x = self.embed(input_seq) - x0_bigram = self.bigram_embed(bigram_input_seq)[None] + # Compute bigram hash on GPU (moved from CPU data loader) + bigram_seq = self._compute_bigram_hash(input_seq, args.bigram_vocab_size - 1) + x0_bigram = self.bigram_embed(bigram_seq)[None] # Value embeddings - always computed (not precomputed) ve = self.value_embeds.view(5, self.vocab_size, -1)[:, input_seq] @@ -1318,21 +1334,6 @@ def get(): return result['shard'] return get -def get_bigram_hash(x): - """ - Computes bigram hash for each position using [prev_token, curr_token]. 
- Multiply by arbitary large ints to get even spread over int32 range. - Position 0 is mapped to the reserved index (vocab_size - 1). - BOS_tokens within the batch will hash based on last token of prior doc. Masking this ran slower and showed no improvement. - """ - rand_int_1 = 36313 - rand_int_2 = 27191 - mod = args.bigram_vocab_size-1 - x = x.to(torch.int32).clone() - x[0] = mod - x[1:] = torch.bitwise_xor(rand_int_1 * x[1:], rand_int_2 * x[:-1]) % mod - return x - def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len rank = dist.get_rank() if dist.is_initialized() else 0 @@ -1397,13 +1398,12 @@ def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_l _inputs = _inputs.to(dtype=torch.int32) _targets = _targets.to(dtype=torch.int64) _cum_lengths = _cum_lengths.to(dtype=torch.int32) - _bigram_inputs = get_bigram_hash(_inputs) + # Bigram hash computation moved to GPU in forward() new_params = yield ( _inputs.to(device="cuda", non_blocking=True), _targets.to(device="cuda", non_blocking=True), _cum_lengths.to(device="cuda", non_blocking=True), - _bigram_inputs.to(device="cuda", non_blocking=True) ) if new_params is not None: @@ -1736,13 +1736,13 @@ def nvidia_smi(): training_manager.advance_schedule(step) model.eval() with torch.no_grad(): - inputs, targets, cum_seqlens, bigram_inputs = next(val_loader) - model(inputs, targets, cum_seqlens, bigram_inputs, training_manager.get_forward_args()) + inputs, targets, cum_seqlens = next(val_loader) + model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) model.train() for idx in range(grad_accum_steps): send_args = training_manager.train_loader_send_args - inputs, targets, cum_seqlens, bigram_inputs = train_loader.send(send_args) - (model(inputs, targets, cum_seqlens, bigram_inputs, 
training_manager.get_forward_args()) * grad_scale).backward() + inputs, targets, cum_seqlens = train_loader.send(send_args) + (model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) * grad_scale).backward() training_manager.step_optimizers(step) print0("Resetting Model", console=True) model.zero_grad(set_to_none=True) @@ -1781,8 +1781,8 @@ def nvidia_smi(): val_loss = 0 with torch.no_grad(): for _ in range(val_steps): - inputs, targets, cum_seqlens, bigram_inputs = next(val_loader) - val_loss += model(inputs, targets, cum_seqlens, bigram_inputs, training_manager.get_forward_args()) + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) val_loss /= val_steps del val_loader dist.reduce(val_loss, 0, op=dist.ReduceOp.AVG) @@ -1802,8 +1802,8 @@ def nvidia_smi(): # --------------- TRAINING SECTION ----------------- for idx in range(grad_accum_steps): - inputs, targets, cum_seqlens, bigram_inputs = train_loader.send(training_manager.train_loader_send_args) - (model(inputs, targets, cum_seqlens, bigram_inputs, training_manager.get_forward_args()) * grad_scale).backward() + inputs, targets, cum_seqlens = train_loader.send(training_manager.train_loader_send_args) + (model(inputs, targets, cum_seqlens, training_manager.get_forward_args()) * grad_scale).backward() training_manager.step_optimizers(step) # logging